# Path: blob/main/py-polars/tests/unit/streaming/test_streaming_cse.py
# 8415 views
from __future__ import annotations

from typing import TYPE_CHECKING, Literal

import pytest

import polars as pl
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
    from tests.conftest import PlMonkeyPatch

pytestmark = pytest.mark.xdist_group("streaming")


def _cse_flags() -> pl.QueryOptFlags:
    """Return a new ``QueryOptFlags`` with common-subexpression elimination on.

    A fresh instance is built on every call, so no flags object is shared
    between the queries below.
    """
    return pl.QueryOptFlags(comm_subexpr_elim=True)


def test_cse_expr_selection_streaming(plmonkeypatch: PlMonkeyPatch) -> None:
    """CSE-optimized ``select``/``with_columns`` give correct streaming results.

    ``d2`` is built from ``derived`` and ``d3`` from ``derived2``, so the
    projection contains repeated sub-expressions for CSE to factor out.
    """
    # Verbose mode is switched on for the whole test; its output is not
    # asserted on (presumably this is to exercise the verbose code paths).
    plmonkeypatch.setenv("POLARS_VERBOSE", "1")
    q = pl.LazyFrame(
        {
            "a": [1, 2, 3, 4],
            "b": [1, 2, 3, 4],
            "c": [1, 2, 3, 4],
        }
    )

    derived = pl.col("a") * pl.col("b")
    derived2 = derived * derived

    exprs = [
        derived.alias("d1"),
        derived2.alias("d2"),
        (derived2 * 10).alias("d3"),
    ]

    # select: only the derived columns come back.
    result = q.select(exprs).collect(optimizations=_cse_flags(), engine="streaming")
    expected = pl.DataFrame(
        {"d1": [1, 4, 9, 16], "d2": [1, 16, 81, 256], "d3": [10, 160, 810, 2560]}
    )
    assert_frame_equal(result, expected)

    # with_columns: the original columns are kept and the derived ones appended.
    result = q.with_columns(exprs).collect(
        optimizations=_cse_flags(), engine="streaming"
    )
    expected = pl.DataFrame(
        {
            "a": [1, 2, 3, 4],
            "b": [1, 2, 3, 4],
            "c": [1, 2, 3, 4],
            "d1": [1, 4, 9, 16],
            "d2": [1, 16, 81, 256],
            "d3": [10, 160, 810, 2560],
        }
    )
    assert_frame_equal(result, expected)


@pytest.mark.parametrize("engine", ["streaming", "in-memory"])
def test_cse_expr_group_by(engine: Literal["streaming", "in-memory"]) -> None:
    """CSE applies inside ``group_by().agg()`` and both engines agree.

    ``derived`` is used by two aggregations, so it is a CSE candidate.
    """
    q = pl.LazyFrame(
        {
            "a": [1, 2, 3, 4],
            "b": [1, 2, 3, 4],
            "c": [1, 2, 3, 4],
        }
    )

    derived = pl.col("a") * pl.col("b")

    q = (
        q.group_by("a")
        .agg(derived.sum().alias("sum"), derived.min().alias("min"))
        # group_by does not guarantee output order; sort for a stable comparison.
        .sort("min")
    )

    # The plan should contain a CSE-generated column for the shared expression.
    assert "__POLARS_CSER" in q.explain(optimizations=_cse_flags())

    expected = pl.DataFrame(
        {"a": [1, 2, 3, 4], "sum": [1, 4, 9, 16], "min": [1, 4, 9, 16]}
    )
    out = q.collect(optimizations=_cse_flags(), engine=engine)
    assert_frame_equal(out, expected)