# Path: blob/main/py-polars/tests/unit/lazyframe/cuda/test_node_visitor.py
# 6940 views
"""Tests that walk polars' optimized IR from Python via a node traverser."""

from __future__ import annotations

import sys
import time
from functools import lru_cache, partial
from typing import TYPE_CHECKING, Any, Callable

import polars as pl
from polars._plr import _ir_nodes
from polars._utils.wrap import wrap_df

if TYPE_CHECKING:
    from pathlib import Path

    import pandas as pd


class Timer:
    """
    Simple-minded timing of nodes.

    Each call to ``record`` runs a thunk and, when timing is enabled, appends
    ``(start, end, name)`` to ``timings``. Both times are nanoseconds on the
    ``time.monotonic_ns()`` clock, measured relative to ``self.start``.

    Parameters
    ----------
    start
        Reference point on the ``time.monotonic_ns()`` clock that recorded
        times are offset against, or ``None`` to disable recording. When it
        is ``None``, ``record`` only runs the thunk.
    """

    def __init__(self, start: int | None) -> None:
        # Origin for recorded offsets; ``None`` disables recording.
        self.start = start
        # One ``(start_offset_ns, end_offset_ns, node_name)`` per recorded call.
        self.timings: list[tuple[int, int, str]] = []

    def record(self, fn: Callable[[], pd.DataFrame], name: str) -> pd.DataFrame:
        """Run ``fn`` and return its result, recording its span under ``name``."""
        start = time.monotonic_ns()
        result = fn()
        end = time.monotonic_ns()
        if self.start is not None:
            self.timings.append((start - self.start, end - self.start, name))
        return result


def test_run_on_pandas() -> None:
    """
    Execute an optimized polars plan with pandas via ``post_opt_callback``.

    The callback walks the IR with the node traverser and builds pandas thunks
    for each node. It then installs a UDF whose result is the pandas output
    multiplied by 2. The doubled values in the assertions therefore prove that
    the callback's result, and not polars' own execution, was returned.
    """

    # Simple join example: only single-key joins are supported, and no
    # slicing etc. is applied.
    def join(
        inputs: list[Callable[[], pd.DataFrame]],
        obj: Any,
        _node_traverser: Any,
        timer: Timer,
    ) -> Callable[[], pd.DataFrame]:
        """
        Convert a Join IR node into a thunk that merges its two inputs.

        Only ``left_on``/``right_on`` are read from ``obj``. Any other join
        option on the node is not consulted, so pandas' default ``merge``
        settings apply.
        """
        assert len(obj.left_on) == 1
        assert len(obj.right_on) == 1
        left_on = obj.left_on[0].output_name
        right_on = obj.right_on[0].output_name

        assert len(inputs) == 2

        def run(inputs: list[Callable[[], pd.DataFrame]]) -> pd.DataFrame:
            # Materialize both inputs first, so their own timings are recorded
            # before the merge is timed.
            dataframes = [call() for call in inputs]
            return timer.record(
                lambda: dataframes[0].merge(
                    dataframes[1], left_on=left_on, right_on=right_on
                ),
                "pandas-join",
            )

        # Defer all work until the returned thunk is called.
        return partial(run, inputs)

    # Simple scan example: no predicates, no column pruning, no slices, etc.
    def df_scan(
        _inputs: list[Callable[[], pd.DataFrame]],
        obj: Any,
        _: Any,
        timer: Timer,
    ) -> Callable[[], pd.DataFrame]:
        """
        Convert a DataFrameScan IR node into a thunk yielding a pandas frame.

        Only ``obj.df`` is used, and a predicate (``obj.selection``) is
        asserted absent. ``_inputs`` is the list built by ``get_input`` and is
        ignored here; presumably it is always empty for a scan (TODO confirm).
        """
        assert obj.selection is None
        return lambda: timer.record(lambda: wrap_df(obj.df).to_pandas(), "pandas-scan")

    # Built once per invocation of this test, then cached.
    @lru_cache(1)
    def get_node_converters() -> dict[
        type, Callable[[Any, Any, Any, Timer], Callable[[], pd.DataFrame]]
    ]:
        """Map IR node classes to the converter that builds their thunk."""
        return {
            _ir_nodes.Join: join,
            _ir_nodes.DataFrameScan: df_scan,
        }

    def get_input(node_traverser: Any, *, timer: Timer) -> Callable[[], pd.DataFrame]:
        """
        Recursively build a thunk for the traverser's current node.

        First a thunk is built for every input (child) node. Then the current
        node is converted with those thunks. Node types missing from
        ``get_node_converters()`` raise ``KeyError``.
        """
        current_node = node_traverser.get_node()

        inputs_callable: list[Callable[[], pd.DataFrame]] = []
        for inp in node_traverser.get_inputs():
            # ``set_node`` moves the traverser's cursor to the child.
            node_traverser.set_node(inp)
            inputs_callable.append(get_input(node_traverser, timer=timer))

        # Restore the cursor (moved while visiting children) before viewing
        # this node.
        node_traverser.set_node(current_node)
        ir_node = node_traverser.view_current_node()
        return get_node_converters()[ir_node.__class__](
            inputs_callable, ir_node, node_traverser, timer
        )

    def run_on_pandas(node_traverser: Any, query_start: int | None) -> None:
        """
        Post-optimization callback: run the optimized plan with pandas.

        ``query_start`` is subtracted from ``time.monotonic_ns()`` to form the
        Timer origin. Presumably it is the nanoseconds elapsed since the query
        started, and ``None`` when not profiling. TODO confirm against the
        polars caller.
        """
        timer = Timer(
            time.monotonic_ns() - query_start if query_start is not None else None
        )
        current_node = node_traverser.get_node()

        callback = get_input(node_traverser, timer=timer)

        def run_callback(
            columns: list[str] | None,
            _: Any,
            n_rows: int | None,
            should_time: bool,
        ) -> pl.DataFrame | tuple[pl.DataFrame, list[tuple[int, int, str]]]:
            """
            Produce the query result from the pandas thunks.

            Projection (``columns``) and row limits (``n_rows``) are asserted
            absent. When ``should_time`` is true, the recorded node timings are
            returned alongside the frame.
            """
            assert n_rows is None
            assert columns is None

            # produce a wrong result to ensure the callback has run.
            result = pl.from_pandas(callback() * 2)
            if should_time:
                return result, timer.timings
            else:
                return result

        # Return the cursor to the node we started on before installing the
        # UDF there.
        node_traverser.set_node(current_node)
        node_traverser.set_udf(run_callback)

    # Polars query that will run on pandas
    q1 = pl.LazyFrame({"foo": [1, 2, 3]})
    q2 = pl.LazyFrame({"foo": [1], "bar": [2]})
    q = q1.join(q2, on="foo")
    # The real join gives foo=[1], bar=[2]; run_callback doubles every value.
    assert q.collect(
        post_opt_callback=run_on_pandas  # type: ignore[call-overload]
    ).to_dict(as_series=False) == {
        "foo": [2],
        "bar": [4],
    }

    result, timings = q.profile(post_opt_callback=run_on_pandas)
    assert result.to_dict(as_series=False) == {
        "foo": [2],
        "bar": [4],
    }
    # "optimization" is presumably recorded by polars itself (not by Timer).
    # Both scans come before the join because ``run`` materializes its inputs
    # before timing the merge.
    assert timings["node"].to_list() == [
        "optimization",
        "pandas-scan",
        "pandas-scan",
        "pandas-join",
    ]


def test_path_uri_to_python_conversion_22766(tmp_path: Path) -> None:
    """
    Check that a ``file://`` scan path keeps its URI form in the visited IR.

    The ``22766`` suffix presumably refers to the polars issue tracker.
    """
    path = f"file://{tmp_path / 'data.parquet'}"

    df = pl.DataFrame({"x": 1})
    df.write_parquet(path)

    q = pl.scan_parquet(path)

    out: list[str] = q._ldf.visit().view_current_node().paths
    assert len(out) == 1

    assert out[0].startswith("file://")

    # On Windows the path separators come back as `\\`, so the exact
    # comparison is only done on other platforms.
    if sys.platform != "win32":
        assert out == [path]