Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/lazyframe/cuda/test_node_visitor.py
6940 views
1
from __future__ import annotations
2
3
import sys
4
import time
5
from functools import lru_cache, partial
6
from typing import TYPE_CHECKING, Any, Callable
7
8
import polars as pl
9
from polars._plr import _ir_nodes
10
from polars._utils.wrap import wrap_df
11
12
if TYPE_CHECKING:
13
from pathlib import Path
14
15
import pandas as pd
16
17
18
class Timer:
    """
    Record when callables start and finish, relative to a reference point.

    Each entry in ``timings`` is ``(start_offset, end_offset, name)``, where
    both offsets are ``time.monotonic_ns()`` readings minus ``start``.
    """

    def __init__(self, start: int | None) -> None:
        # Collected (start_offset, end_offset, name) triples, in call order.
        self.timings: list[tuple[int, int, str]] = []
        # Reference point subtracted from every reading; None disables recording.
        self.start = start

    def record(self, fn: Callable[[], pd.DataFrame], name: str) -> pd.DataFrame:
        """
        Call ``fn`` and return its result, timing the call under ``name``.

        Nothing is appended to ``timings`` when the timer was created with
        ``start=None``; ``fn`` is still invoked exactly once.
        """
        began = time.monotonic_ns()
        output = fn()
        finished = time.monotonic_ns()

        origin = self.start
        if origin is None:
            return output

        self.timings.append((began - origin, finished - origin, name))
        return output
32
33
34
def test_run_on_pandas() -> None:
    """
    Execute a polars join through a pandas-backed ``post_opt_callback``.

    The callback walks the IR via the node traverser, builds one pandas
    callable per node, and installs a UDF that runs them. The UDF doubles the
    result on purpose so the assertions prove it was the callback that ran.
    """

    # Simple join example, missing multiple columns, slices, etc.
    def join(
        inputs: list[Callable[[], pd.DataFrame]],
        obj: Any,
        _node_traverser: Any,
        timer: Timer,
    ) -> Callable[[], pd.DataFrame]:
        # Only single-key joins are handled by this toy converter.
        assert len(obj.left_on) == 1
        assert len(obj.right_on) == 1
        left_key = obj.left_on[0].output_name
        right_key = obj.right_on[0].output_name

        assert len(inputs) == 2

        def run(inputs: list[Callable[[], pd.DataFrame]]) -> pd.DataFrame:
            # Materialize both sides (left first) before timing the merge itself.
            left_df, right_df = (produce() for produce in inputs)

            def merge() -> pd.DataFrame:
                return left_df.merge(right_df, left_on=left_key, right_on=right_key)

            return timer.record(merge, "pandas-join")

        return partial(run, inputs)

    # Simple scan example, missing predicates, columns pruning, slices, etc.
    def df_scan(
        _inputs: None, obj: Any, _: Any, timer: Timer
    ) -> Callable[[], pd.DataFrame]:
        assert obj.selection is None

        def scan() -> pd.DataFrame:
            # The conversion to pandas is deferred until the timed call runs.
            return timer.record(lambda: wrap_df(obj.df).to_pandas(), "pandas-scan")

        return scan

    @lru_cache(1)
    def get_node_converters() -> dict[
        type, Callable[[Any, Any, Any, Timer], Callable[[], pd.DataFrame]]
    ]:
        # Dispatch table from IR node class to its pandas converter.
        converters = {
            _ir_nodes.Join: join,
            _ir_nodes.DataFrameScan: df_scan,
        }
        return converters

    def get_input(node_traverser: Any, *, timer: Timer) -> Callable[[], pd.DataFrame]:
        # Remember where we are: visiting the inputs moves the traverser.
        here = node_traverser.get_node()

        child_callables = []
        for child in node_traverser.get_inputs():
            node_traverser.set_node(child)
            child_callables.append(get_input(node_traverser, timer=timer))

        # Restore the position before viewing and converting this node.
        node_traverser.set_node(here)
        ir_node = node_traverser.view_current_node()
        convert = get_node_converters()[ir_node.__class__]
        return convert(child_callables, ir_node, node_traverser, timer)

    def run_on_pandas(node_traverser: Any, query_start: int | None) -> None:
        if query_start is not None:
            timer_origin = time.monotonic_ns() - query_start
        else:
            timer_origin = None
        timer = Timer(timer_origin)

        root = node_traverser.get_node()
        callback = get_input(node_traverser, timer=timer)

        def run_callback(
            columns: list[str] | None,
            _: Any,
            n_rows: int | None,
            should_time: bool,
        ) -> pl.DataFrame | tuple[pl.DataFrame, list[tuple[int, int, str]]]:
            assert n_rows is None
            assert columns is None

            # produce a wrong result to ensure the callback has run.
            doubled = pl.from_pandas(callback() * 2)
            if not should_time:
                return doubled
            return doubled, timer.timings

        node_traverser.set_node(root)
        node_traverser.set_udf(run_callback)

    # Polars query that will run on pandas
    left = pl.LazyFrame({"foo": [1, 2, 3]})
    right = pl.LazyFrame({"foo": [1], "bar": [2]})
    q = left.join(right, on="foo")

    expected = {"foo": [2], "bar": [4]}

    collected = q.collect(
        post_opt_callback=run_on_pandas  # type: ignore[call-overload]
    )
    assert collected.to_dict(as_series=False) == expected

    result, timings = q.profile(post_opt_callback=run_on_pandas)
    assert result.to_dict(as_series=False) == expected
    assert timings["node"].to_list() == [
        "optimization",
        "pandas-scan",
        "pandas-scan",
        "pandas-join",
    ]
140
141
142
def test_path_uri_to_python_conversion_22766(tmp_path: Path) -> None:
    """Check that a ``file://`` scan path is exposed unchanged on the IR node."""
    target = tmp_path / "data.parquet"
    uri = f"file://{target}"

    pl.DataFrame({"x": 1}).write_parquet(uri)

    scan_node = pl.scan_parquet(uri)._ldf.visit().view_current_node()
    paths: list[str] = scan_node.paths
    assert len(paths) == 1

    assert paths[0].startswith("file://")

    # Windows fails because it turns everything into `\\`
    if sys.platform == "win32":
        return
    assert paths == [uri]
158
159