# py-polars/tests/unit/streaming/test_streaming_cse.py
"""Tests for common subexpression elimination (CSE) with the streaming engine."""

from __future__ import annotations

from typing import Any

import pytest

import polars as pl
from polars.testing import assert_frame_equal

# Put every test in this module in the "streaming" xdist group, so that under
# pytest-xdist's ``--dist loadgroup`` they are scheduled on the same worker.
pytestmark = pytest.mark.xdist_group("streaming")


def test_cse_expr_selection_streaming(monkeypatch: Any) -> None:
    """Check CSE-optimized ``select`` / ``with_columns`` results on the streaming engine.

    ``d2`` is built from ``d1``'s product, and ``d3`` is built from ``d2``.
    The projections therefore share subexpressions that CSE can factor out.
    """
    # Turn on polars' verbose diagnostics while these queries run. The
    # diagnostic output itself is not asserted on in this test.
    monkeypatch.setenv("POLARS_VERBOSE", "1")
    cse_on = pl.QueryOptFlags(comm_subexpr_elim=True)

    q = pl.LazyFrame(
        {
            "a": [1, 2, 3, 4],
            "b": [1, 2, 3, 4],
            "c": [1, 2, 3, 4],
        }
    )

    derived = pl.col("a") * pl.col("b")
    derived2 = derived * derived

    exprs = [
        derived.alias("d1"),
        derived2.alias("d2"),
        (derived2 * 10).alias("d3"),
    ]

    # select: only the derived columns are returned.
    result = q.select(exprs).collect(optimizations=cse_on, engine="streaming")
    expected = pl.DataFrame(
        {"d1": [1, 4, 9, 16], "d2": [1, 16, 81, 256], "d3": [10, 160, 810, 2560]}
    )
    assert_frame_equal(result, expected)

    # with_columns: the original columns are kept and the derived ones appended.
    result = q.with_columns(exprs).collect(optimizations=cse_on, engine="streaming")
    expected = pl.DataFrame(
        {
            "a": [1, 2, 3, 4],
            "b": [1, 2, 3, 4],
            "c": [1, 2, 3, 4],
            "d1": [1, 4, 9, 16],
            "d2": [1, 16, 81, 256],
            "d3": [10, 160, 810, 2560],
        }
    )
    assert_frame_equal(result, expected)


def test_cse_expr_group_by() -> None:
    """Check that CSE fires inside ``group_by().agg`` and both engines agree."""
    q = pl.LazyFrame(
        {
            "a": [1, 2, 3, 4],
            "b": [1, 2, 3, 4],
            "c": [1, 2, 3, 4],
        }
    )

    derived = pl.col("a") * pl.col("b")

    q = (
        q.group_by("a")
        .agg(derived.sum().alias("sum"), derived.min().alias("min"))
        # group_by does not promise an output row order (maintain_order is not
        # set), so sort to make the comparison deterministic.
        .sort("min")
    )

    cse_on = pl.QueryOptFlags(comm_subexpr_elim=True)

    # The optimized plan should contain a ``__POLARS_CSER`` column, showing
    # that the ``a * b`` shared by both aggregations was extracted.
    assert "__POLARS_CSER" in q.explain(optimizations=cse_on)

    expected = pl.DataFrame(
        {"a": [1, 2, 3, 4], "sum": [1, 4, 9, 16], "min": [1, 4, 9, 16]}
    )
    for engine in ("streaming", "in-memory"):
        out = q.collect(optimizations=cse_on, engine=engine)
        assert_frame_equal(out, expected)