# Source: py-polars/tests/unit/lazyframe/cuda/test_node_visitor.py
from __future__ import annotations12import json3import time4from functools import lru_cache, partial5from typing import TYPE_CHECKING, Any67import polars as pl8from polars._plr import _ir_nodes9from polars._utils.wrap import wrap_df10from tests.unit.io.conftest import format_file_uri1112if TYPE_CHECKING:13from collections.abc import Callable14from pathlib import Path1516import pandas as pd171819class Timer:20"""Simple-minded timing of nodes."""2122def __init__(self, start: int | None) -> None:23self.start = start24self.timings: list[tuple[int, int, str]] = []2526def record(self, fn: Callable[[], pd.DataFrame], name: str) -> pd.DataFrame:27start = time.monotonic_ns()28result = fn()29end = time.monotonic_ns()30if self.start is not None:31self.timings.append((start - self.start, end - self.start, name))32return result333435def test_run_on_pandas() -> None:36# Simple join example, missing multiple columns, slices, etc.37def join(38inputs: list[Callable[[], pd.DataFrame]],39obj: Any,40_node_traverser: Any,41timer: Timer,42) -> Callable[[], pd.DataFrame]:43assert len(obj.left_on) == 144assert len(obj.right_on) == 145left_on = obj.left_on[0].output_name46right_on = obj.right_on[0].output_name4748assert len(inputs) == 24950def run(inputs: list[Callable[[], pd.DataFrame]]) -> pd.DataFrame:51# materialize inputs52dataframes = [call() for call in inputs]53return timer.record(54lambda: dataframes[0].merge(55dataframes[1], left_on=left_on, right_on=right_on56),57"pandas-join",58)5960return partial(run, inputs)6162# Simple scan example, missing predicates, columns pruning, slices, etc.63def df_scan(64_inputs: None, obj: Any, _: Any, timer: Timer65) -> Callable[[], pd.DataFrame]:66assert obj.selection is None67return lambda: timer.record(lambda: wrap_df(obj.df).to_pandas(), "pandas-scan")6869@lru_cache(1)70def get_node_converters() -> dict[71type, Callable[[Any, Any, Any, Timer], Callable[[], pd.DataFrame]]72]:73return {74_ir_nodes.Join: join,75_ir_nodes.DataFrameScan: 
df_scan,76}7778def get_input(node_traverser: Any, *, timer: Timer) -> Callable[[], pd.DataFrame]:79current_node = node_traverser.get_node()8081inputs_callable = []82for inp in node_traverser.get_inputs():83node_traverser.set_node(inp)84inputs_callable.append(get_input(node_traverser, timer=timer))8586node_traverser.set_node(current_node)87ir_node = node_traverser.view_current_node()88return get_node_converters()[ir_node.__class__](89inputs_callable, ir_node, node_traverser, timer90)9192def run_on_pandas(node_traverser: Any, query_start: int | None) -> None:93timer = Timer(94time.monotonic_ns() - query_start if query_start is not None else None95)96current_node = node_traverser.get_node()9798callback = get_input(node_traverser, timer=timer)99100def run_callback(101columns: list[str] | None,102_: Any,103n_rows: int | None,104should_time: bool,105) -> pl.DataFrame | tuple[pl.DataFrame, list[tuple[int, int, str]]]:106assert n_rows is None107assert columns is None108109# produce a wrong result to ensure the callback has run.110result = pl.from_pandas(callback() * 2)111if should_time:112return result, timer.timings113else:114return result115116node_traverser.set_node(current_node)117node_traverser.set_udf(run_callback)118119# Polars query that will run on pandas120q1 = pl.LazyFrame({"foo": [1, 2, 3]})121q2 = pl.LazyFrame({"foo": [1], "bar": [2]})122q = q1.join(q2, on="foo")123assert q.collect(124post_opt_callback=run_on_pandas # type: ignore[call-overload]125).to_dict(as_series=False) == {126"foo": [2],127"bar": [4],128}129130result, timings = q.profile(post_opt_callback=run_on_pandas)131assert result.to_dict(as_series=False) == {132"foo": [2],133"bar": [4],134}135assert timings["node"].to_list() == [136"optimization",137"pandas-scan",138"pandas-scan",139"pandas-join",140]141142143def test_path_uri_to_python_conversion_22766(tmp_path: Path) -> None:144path = format_file_uri(f"{tmp_path / 'data.parquet'}")145146df = pl.DataFrame({"x": 1})147df.write_parquet(path)148149q = 
pl.scan_parquet(path)150151out: list[str] = q._ldf.visit().view_current_node().paths152assert len(out) == 1153154assert out[0].startswith("file://")155assert out == [path]156157158def test_node_traverse_sink(tmp_path: Path) -> None:159def callback(node_traverser: Any, query_start: int | None) -> None:160assert list(json.loads(node_traverser.view_current_node().payload)["File"]) == [161"target",162"file_format",163"unified_sink_args",164]165166q = pl.LazyFrame({"x": [0, 1, 2]}).sink_parquet(tmp_path / "a", lazy=True)167q.collect(168post_opt_callback=callback # type: ignore[call-overload]169)170171172