Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/lazyframe/cuda/test_node_visitor.py
8424 views
1
from __future__ import annotations
2
3
import json
4
import time
5
from functools import lru_cache, partial
6
from typing import TYPE_CHECKING, Any
7
8
import polars as pl
9
from polars._plr import _ir_nodes
10
from polars._utils.wrap import wrap_df
11
from tests.unit.io.conftest import format_file_uri
12
13
if TYPE_CHECKING:
14
from collections.abc import Callable
15
from pathlib import Path
16
17
import pandas as pd
18
19
20
class Timer:
    """Collect named wall-clock intervals for executed nodes.

    Intervals are measured with ``time.monotonic_ns`` and stored as offsets
    from ``start``. When ``start`` is ``None``, timing is disabled and
    nothing is recorded.
    """

    def __init__(self, start: int | None) -> None:
        self.start = start
        # Each entry is (begin_offset_ns, end_offset_ns, name).
        self.timings: list[tuple[int, int, str]] = []

    def record(self, fn: Callable[[], pd.DataFrame], name: str) -> pd.DataFrame:
        """Call ``fn`` and return its result, logging its interval under ``name``.

        The interval is only appended when timing is enabled (``start`` is
        not ``None``). If ``fn`` raises, nothing is recorded.
        """
        began = time.monotonic_ns()
        value = fn()
        finished = time.monotonic_ns()
        origin = self.start
        if origin is None:
            return value
        self.timings.append((began - origin, finished - origin, name))
        return value
34
35
36
def test_run_on_pandas() -> None:
    """Execute a simple join through a pandas-backed post-optimization callback.

    The callback walks the optimized IR with the node traverser and builds a
    tree of deferred pandas computations. Only ``Join`` and ``DataFrameScan``
    nodes are supported. The tree is registered with ``set_udf``. The UDF
    doubles its output, so the asserted (doubled) results prove that the
    pandas path produced the final frame.
    """

    # Simple join example, missing multiple columns, slices, etc.
    def join(
        inputs: list[Callable[[], pd.DataFrame]],
        obj: Any,
        _node_traverser: Any,
        timer: Timer,
    ) -> Callable[[], pd.DataFrame]:
        assert len(obj.left_on) == 1
        assert len(obj.right_on) == 1
        left_on = obj.left_on[0].output_name
        right_on = obj.right_on[0].output_name

        assert len(inputs) == 2

        def run(inputs: list[Callable[[], pd.DataFrame]]) -> pd.DataFrame:
            # Materialize inputs first so only the merge itself is timed.
            dataframes = [call() for call in inputs]
            return timer.record(
                lambda: dataframes[0].merge(
                    dataframes[1], left_on=left_on, right_on=right_on
                ),
                "pandas-join",
            )

        return partial(run, inputs)

    # Simple scan example, missing predicates, columns pruning, slices, etc.
    # ``_inputs`` receives the (empty) list of child callables from get_input.
    def df_scan(
        _inputs: list[Callable[[], pd.DataFrame]],
        obj: Any,
        _: Any,
        timer: Timer,
    ) -> Callable[[], pd.DataFrame]:
        assert obj.selection is None
        return lambda: timer.record(lambda: wrap_df(obj.df).to_pandas(), "pandas-scan")

    @lru_cache(1)
    def get_node_converters() -> dict[
        type, Callable[[Any, Any, Any, Timer], Callable[[], pd.DataFrame]]
    ]:
        return {
            _ir_nodes.Join: join,
            _ir_nodes.DataFrameScan: df_scan,
        }

    def get_input(node_traverser: Any, *, timer: Timer) -> Callable[[], pd.DataFrame]:
        """Recursively convert the traverser's current node into a pandas thunk."""
        current_node = node_traverser.get_node()

        # Convert children first. set_node moves the traverser's cursor, so it
        # is restored below before the current node is viewed.
        inputs_callable = []
        for inp in node_traverser.get_inputs():
            node_traverser.set_node(inp)
            inputs_callable.append(get_input(node_traverser, timer=timer))

        node_traverser.set_node(current_node)
        ir_node = node_traverser.view_current_node()
        converter = get_node_converters().get(type(ir_node))
        if converter is None:
            # Fail with a descriptive error instead of a bare KeyError.
            msg = f"no pandas converter for IR node {type(ir_node).__name__}"
            raise NotImplementedError(msg)
        return converter(inputs_callable, ir_node, node_traverser, timer)

    def run_on_pandas(node_traverser: Any, query_start: int | None) -> None:
        # Timer offsets are measured against monotonic_ns() - query_start.
        # NOTE(review): presumably query_start is the elapsed ns since the
        # query began, and None when profiling is not requested; confirm.
        timer = Timer(
            time.monotonic_ns() - query_start if query_start is not None else None
        )
        current_node = node_traverser.get_node()

        callback = get_input(node_traverser, timer=timer)

        def run_callback(
            columns: list[str] | None,
            _: Any,
            n_rows: int | None,
            should_time: bool,
        ) -> pl.DataFrame | tuple[pl.DataFrame, list[tuple[int, int, str]]]:
            assert n_rows is None
            assert columns is None

            # produce a wrong result to ensure the callback has run.
            result = pl.from_pandas(callback() * 2)
            if should_time:
                return result, timer.timings
            else:
                return result

        node_traverser.set_node(current_node)
        node_traverser.set_udf(run_callback)

    # Polars query that will run on pandas
    q1 = pl.LazyFrame({"foo": [1, 2, 3]})
    q2 = pl.LazyFrame({"foo": [1], "bar": [2]})
    q = q1.join(q2, on="foo")
    assert q.collect(
        post_opt_callback=run_on_pandas  # type: ignore[call-overload]
    ).to_dict(as_series=False) == {
        "foo": [2],
        "bar": [4],
    }

    result, timings = q.profile(post_opt_callback=run_on_pandas)
    assert result.to_dict(as_series=False) == {
        "foo": [2],
        "bar": [4],
    }
    assert timings["node"].to_list() == [
        "optimization",
        "pandas-scan",
        "pandas-scan",
        "pandas-join",
    ]
142
143
144
def test_path_uri_to_python_conversion_22766(tmp_path: Path) -> None:
    """A ``file://`` scan path must come back unchanged from the IR node view."""
    parquet_uri = format_file_uri(str(tmp_path / "data.parquet"))

    pl.DataFrame({"x": 1}).write_parquet(parquet_uri)

    scan_node = pl.scan_parquet(parquet_uri)._ldf.visit().view_current_node()
    paths: list[str] = scan_node.paths

    assert len(paths) == 1
    assert paths[0].startswith("file://")
    assert paths == [parquet_uri]
157
158
159
def test_node_traverse_sink(tmp_path: Path) -> None:
    """The post-optimization callback should see a ``File`` sink node.

    The sink node's JSON payload must expose ``target``, ``file_format`` and
    ``unified_sink_args`` under ``"File"``, in that order.
    """
    seen_payloads: list[dict[str, Any]] = []

    def callback(node_traverser: Any, query_start: int | None) -> None:
        payload = json.loads(node_traverser.view_current_node().payload)
        seen_payloads.append(payload)
        assert list(payload["File"]) == [
            "target",
            "file_format",
            "unified_sink_args",
        ]

    q = pl.LazyFrame({"x": [0, 1, 2]}).sink_parquet(tmp_path / "a", lazy=True)
    q.collect(
        post_opt_callback=callback  # type: ignore[call-overload]
    )

    # Guard against a vacuous pass: the payload assertions above only count
    # if polars actually invoked the callback.
    assert seen_payloads, "post_opt_callback was never invoked"
171
172