Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/streaming/test_streaming_cse.py
8415 views
1
from __future__ import annotations
2
3
from typing import TYPE_CHECKING
4
5
import pytest
6
7
import polars as pl
8
from polars.testing import assert_frame_equal
9
10
if TYPE_CHECKING:
11
from tests.conftest import PlMonkeyPatch
12
13
# Tag every test in this module with the "streaming" xdist group
# (presumably so pytest-xdist schedules them together on one worker —
# confirm against the project's xdist distribution mode).
pytestmark = pytest.mark.xdist_group("streaming")
14
15
16
def test_cse_expr_selection_streaming(plmonkeypatch: PlMonkeyPatch) -> None:
    """Check select/with_columns results with CSE enabled on the streaming engine.

    Three expressions that share the sub-expression ``a * b`` are evaluated
    through both ``select`` and ``with_columns``; each result is compared
    against a hand-computed frame.
    """
    plmonkeypatch.setenv("POLARS_VERBOSE", "1")

    def collect_streaming(frame: pl.LazyFrame) -> pl.DataFrame:
        # Build fresh optimization flags for every collect call.
        return frame.collect(
            optimizations=pl.QueryOptFlags(comm_subexpr_elim=True),
            engine="streaming",
        )

    input_columns = {
        "a": [1, 2, 3, 4],
        "b": [1, 2, 3, 4],
        "c": [1, 2, 3, 4],
    }
    derived_columns = {
        "d1": [1, 4, 9, 16],
        "d2": [1, 16, 81, 256],
        "d3": [10, 160, 810, 2560],
    }
    lf = pl.LazyFrame(input_columns)

    # `product` occurs in every expression below; `squared` occurs in two.
    product = pl.col("a") * pl.col("b")
    squared = product * product
    projections = [
        product.alias("d1"),
        squared.alias("d2"),
        (squared * 10).alias("d3"),
    ]

    # select: only the derived columns are returned.
    assert_frame_equal(
        collect_streaming(lf.select(projections)),
        pl.DataFrame(derived_columns),
    )

    # with_columns: the input columns are kept and the derived ones appended.
    assert_frame_equal(
        collect_streaming(lf.with_columns(projections)),
        pl.DataFrame({**input_columns, **derived_columns}),
    )
57
58
59
def test_cse_expr_group_by() -> None:
    """Check that a sub-expression shared by group-by aggregations is CSE'd.

    The optimized plan must mention ``__POLARS_CSER``, and both the streaming
    and in-memory engines must return the same expected frame.
    """
    lf = pl.LazyFrame(
        {
            "a": [1, 2, 3, 4],
            "b": [1, 2, 3, 4],
            "c": [1, 2, 3, 4],
        }
    )

    # Both aggregations are built from the same `a * b` expression.
    product = pl.col("a") * pl.col("b")
    aggregations = [product.sum().alias("sum"), product.min().alias("min")]
    grouped = lf.group_by("a").agg(aggregations).sort("min")

    plan = grouped.explain(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))
    assert "__POLARS_CSER" in plan

    expected = pl.DataFrame(
        {"a": [1, 2, 3, 4], "sum": [1, 4, 9, 16], "min": [1, 4, 9, 16]}
    )
    # Streaming first, then in-memory.
    for engine in ("streaming", "in-memory"):
        out = grouped.collect(
            optimizations=pl.QueryOptFlags(comm_subexpr_elim=True),
            engine=engine,
        )
        assert_frame_equal(out, expected)
89
90