# Path: blob/main/py-polars/tests/unit/lazyframe/test_cwc.py
# 8424 views
# Tests for the optimization pass cluster WITH_COLUMNS

import pytest

import polars as pl
from polars.exceptions import ColumnNotFoundError
from polars.testing import assert_frame_equal


def test_basic_cwc() -> None:
    """Three chained `with_columns` that only read "a" show up as one expression list."""
    lf = pl.LazyFrame({"a": [1, 2]})
    lf = lf.with_columns(pl.col("a").alias("b") * 2)
    lf = lf.with_columns(pl.col("a").alias("c") * 3)
    lf = lf.with_columns(pl.col("a").alias("d") * 4)

    plan = lf.explain()

    assert (
        """[[(col("a")) * (2)].alias("b"), [(col("a")) * (3)].alias("c"), [(col("a")) * (4)].alias("d")]"""
        in plan
    )


def test_disable_cwc() -> None:
    """With `cluster_with_columns=False` each expression stays in its own list."""
    lf = pl.LazyFrame({"a": [1, 2]})
    lf = lf.with_columns(pl.col("a").alias("b") * 2)
    lf = lf.with_columns(pl.col("a").alias("c") * 3)
    lf = lf.with_columns(pl.col("a").alias("d") * 4)

    flags = pl.QueryOptFlags(cluster_with_columns=False)
    plan = lf.explain(optimizations=flags)

    expected_fragments = (
        """[[(col("a")) * (2)].alias("b")]""",
        """[[(col("a")) * (3)].alias("c")]""",
        """[[(col("a")) * (4)].alias("d")]""",
    )
    for fragment in expected_fragments:
        assert fragment in plan


def test_refuse_with_deps() -> None:
    """Each step reads the column made by the previous one, so the lists stay separate."""
    lf = pl.LazyFrame({"a": [1, 2]})
    lf = lf.with_columns(pl.col("a").alias("b") * 2)
    lf = lf.with_columns(pl.col("b").alias("c") * 3)
    lf = lf.with_columns(pl.col("c").alias("d") * 4)

    plan = lf.explain()

    expected_fragments = (
        """[[(col("a")) * (2)].alias("b")]""",
        """[[(col("b")) * (3)].alias("c")]""",
        """[[(col("c")) * (4)].alias("d")]""",
    )
    for fragment in expected_fragments:
        assert fragment in plan


def test_partial_deps() -> None:
    """Expressions reading "a" are grouped with "b"; those reading "b" form a second list."""
    second_step = [
        pl.col("a").alias("c") * 3,
        pl.col("b").alias("d") * 4,
        pl.col("a").alias("e") * 5,
    ]
    lf = (
        pl.LazyFrame({"a": [1, 2]})
        .with_columns(pl.col("a").alias("b") * 2)
        .with_columns(*second_step)
        .with_columns(pl.col("b").alias("f") * 6)
    )

    plan = lf.explain()

    assert (
        """[[(col("b")) * (4)].alias("d"), [(col("b")) * (6)].alias("f")]""" in plan
    )
    assert (
        """[[(col("a")) * (2)].alias("b"), [(col("a")) * (3)].alias("c"), [(col("a")) * (5)].alias("e")]"""
        in plan
    )


def test_swap_remove() -> None:
    """Check the collected values, both expression lists and the `simple π` node."""
    second_step = [
        pl.col("b").alias("f") * 6,
        pl.col("a").alias("c") * 3,
        pl.col("b").alias("d") * 4,
        pl.col("b").alias("e") * 5,
    ]
    lf = (
        pl.LazyFrame({"a": [1, 2]})
        .with_columns(pl.col("a").alias("b") * 2)
        .with_columns(*second_step)
    )

    plan = lf.explain()

    result = lf.collect()
    expected = pl.DataFrame(
        {
            "a": [1, 2],
            "b": [2, 4],
            "f": [12, 24],
            "c": [3, 6],
            "d": [8, 16],
            "e": [10, 20],
        }
    )
    assert result.equals(expected)

    assert (
        """[[(col("b")) * (6)].alias("f"), [(col("b")) * (4)].alias("d"), [(col("b")) * (5)].alias("e")]"""
        in plan
    )
    assert (
        """[[(col("a")) * (2)].alias("b"), [(col("a")) * (3)].alias("c")]""" in plan
    )
    assert """simple π""" in plan


def test_try_remove_simple_project() -> None:
    """Compare the two orderings of "c" and "d" in the second `with_columns`.

    With ("d", "c") the plan must not contain `simple π`; with ("c", "d") it must.
    The expression lists in the plan are the same in both cases.
    """
    # Ordering 1: the independent column "d" comes before the dependent column "c".
    query = (
        pl.LazyFrame({"a": [1, 2]})
        .with_columns(pl.col("a").alias("b") * 2)
        .with_columns(pl.col("a").alias("d") * 4, pl.col("b").alias("c") * 3)
    )

    result = query.collect()
    expected = pl.DataFrame(
        [
            pl.Series("a", [1, 2], dtype=pl.Int64),
            pl.Series("b", [2, 4], dtype=pl.Int64),
            pl.Series("d", [4, 8], dtype=pl.Int64),
            pl.Series("c", [6, 12], dtype=pl.Int64),
        ]
    )
    assert_frame_equal(result, expected)

    plan = query.explain()

    assert """[[(col("a")) * (2)].alias("b"), [(col("a")) * (4)].alias("d")]""" in plan
    assert """[[(col("b")) * (3)].alias("c")]""" in plan
    assert """simple π""" not in plan

    # Ordering 2: the dependent column "c" comes before the independent column "d".
    query = (
        pl.LazyFrame({"a": [1, 2]})
        .with_columns(pl.col("a").alias("b") * 2)
        .with_columns(pl.col("b").alias("c") * 3, pl.col("a").alias("d") * 4)
    )

    result = query.collect()
    expected = pl.DataFrame(
        [
            pl.Series("a", [1, 2], dtype=pl.Int64),
            pl.Series("b", [2, 4], dtype=pl.Int64),
            pl.Series("c", [6, 12], dtype=pl.Int64),
            pl.Series("d", [4, 8], dtype=pl.Int64),
        ]
    )
    assert_frame_equal(result, expected)

    plan = query.explain()

    assert """[[(col("a")) * (2)].alias("b"), [(col("a")) * (4)].alias("d")]""" in plan
    assert """[[(col("b")) * (3)].alias("c")]""" in plan
    assert """simple π""" in plan


def test_cwc_with_internal_aliases() -> None:
    """An alias nested inside `any_horizontal` does not keep the two steps apart."""
    inner = (pl.col("a") == 2).alias("b")
    lf = (
        pl.LazyFrame({"a": [1, 2], "b": [3, 4]})
        .with_columns(pl.any_horizontal(inner).alias("c"))
        .with_columns(pl.col("b").alias("d") * 3)
    )

    plan = lf.explain()

    assert (
        """[[(col("a")) == (2)].alias("c"), [(col("b")) * (3)].alias("d")]""" in plan
    )


def test_read_of_pushed_column_16436() -> None:
    """Collect a query that overwrites "z" with an expression reading "z".

    There are no value assertions; the test only requires `collect` to finish.
    """
    frame = pl.DataFrame(
        {
            "x": [1.12, 2.21, 4.2, 3.21],
            "y": [2.11, 3.32, 2.1, 6.12],
        }
    )

    ratio = (pl.col("y") / pl.col("x")).alias("z")
    cleaned = (
        pl.when(pl.col("z").is_infinite()).then(0).otherwise(pl.col("z")).alias("z")
    )

    frame.lazy().with_columns(ratio).with_columns(cleaned).fill_nan(0).collect()


def test_multiple_simple_projections_16435() -> None:
    """Collect a chain of plain column copies and literals; it only has to succeed."""
    lf = pl.DataFrame({"a": [1]}).lazy()

    lf = lf.with_columns(b=pl.col("a"))
    lf = lf.with_columns(c=pl.col("b"))
    lf = lf.with_columns(l2a=pl.lit(2))
    lf = lf.with_columns(l2b=pl.col("l2a"))
    lf = lf.with_columns(m=pl.lit(3))

    lf.collect()


def test_reverse_order() -> None:
    """Collect a chain whose last step swaps "a" and "b"; it only has to succeed."""
    lf = pl.LazyFrame({"a": [1], "b": [2]})

    lf = lf.with_columns(a=pl.col("a"), b=pl.col("b"), c=pl.col("a") * pl.col("b"))
    lf = lf.with_columns(x=pl.col("a"), y=pl.col("b"))
    lf = lf.with_columns(b=pl.col("a"), a=pl.col("b"))

    lf.collect()


def test_realias_of_unread_column_16530() -> None:
    """Overwriting "y" before anything reads it leaves a single WITH_COLUMNS in the plan."""
    lf = pl.LazyFrame({"x": [True]})
    lf = lf.with_columns(x=pl.lit(False))
    lf = lf.with_columns(y=~pl.col("x"))
    lf = lf.with_columns(y=pl.lit(False))

    plan = lf.explain()

    assert plan.count("WITH_COLUMNS") == 1

    result = lf.collect()
    assert result.equals(pl.DataFrame({"x": [False], "y": [False]}))


def test_realias_with_dependencies() -> None:
    """When "z" reads the "y" that the same step overwrites, three WITH_COLUMNS remain."""
    lf = pl.LazyFrame({"x": [True]})
    lf = lf.with_columns(x=pl.lit(False))
    lf = lf.with_columns(y=~pl.col("x"))
    lf = lf.with_columns(y=pl.lit(False), z=pl.col("y") | True)

    plan = lf.explain()

    assert plan.count("WITH_COLUMNS") == 3

    result = lf.collect()
    assert result.equals(pl.DataFrame({"x": [False], "y": [False], "z": [True]}))


def test_refuse_pushdown_with_aliases() -> None:
    """Same shape as above but "y" starts as a literal; two WITH_COLUMNS are expected."""
    lf = pl.LazyFrame({"x": [True]})
    lf = lf.with_columns(x=pl.lit(False))
    lf = lf.with_columns(y=pl.lit(True))
    lf = lf.with_columns(y=pl.lit(False), z=pl.col("y") | True)

    plan = lf.explain()

    assert plan.count("WITH_COLUMNS") == 2

    result = lf.collect()
    assert result.equals(pl.DataFrame({"x": [False], "y": [False], "z": [True]}))


def test_neighbour_live_expr() -> None:
    """"z" reads the input "x" while "x" is overwritten in the same step.

    The plan must contain one WITH_COLUMNS and "z" must still be computed from
    the original value of "x".
    """
    lf = pl.LazyFrame({"x": [True]})
    lf = lf.with_columns(y=pl.lit(False))
    lf = lf.with_columns(x=pl.lit(False), z=pl.col("x") | False)

    plan = lf.explain()

    assert plan.count("WITH_COLUMNS") == 1

    result = lf.collect()
    assert result.equals(pl.DataFrame({"x": [False], "y": [False], "z": [True]}))


def test_cluster_with_columns_collect_all_panic_26092() -> None:
    """`collect_all` on the same LazyFrame twice returns the expected frame both times."""
    lf = pl.LazyFrame()
    lf = lf.with_columns(pl.lit(1.0).cast(pl.Float64()).alias("numbers1"))
    lf = lf.with_columns(pl.lit(2.0).cast(pl.Float64()).alias("numbers2"))

    first, second = pl.collect_all([lf, lf])

    assert_frame_equal(first, pl.DataFrame({"numbers1": 1.0, "numbers2": 2.0}))
    assert_frame_equal(second, pl.DataFrame({"numbers1": 1.0, "numbers2": 2.0}))


def test_cluster_with_columns_schema_update_26417() -> None:
    """Casting "y" to an Array and then calling `arr.get(0)` on it yields Float64."""
    lf = pl.LazyFrame({"x": [[0.0, 1.0]], "y": [[2.0]]})

    query = lf.with_columns(pl.col("x").cast(pl.Array(pl.Float64, shape=2)))
    query = query.with_columns(pl.col("y").cast(pl.Array(pl.Float64, shape=1)))
    query = query.with_columns(pl.col("y").arr.get(0))

    result = query.collect()
    expected = pl.DataFrame(
        [
            pl.Series("x", [[0.0, 1.0]], dtype=pl.Array(pl.Float64, shape=(2,))),
            pl.Series("y", [2.0], dtype=pl.Float64),
        ]
    )
    assert_frame_equal(result, expected)


def test_cluster_with_columns_use_existing_names_26456() -> None:
    """Overwriting "a" and "b" in one step: "b" is computed from the old value of "a"."""
    query = pl.LazyFrame({"a": [1, 2, 3]})
    query = query.with_columns(pl.lit(1).alias("b"))
    query = query.with_columns(pl.col("a") + 1, pl.col("b") + pl.col("a"))

    result = query.collect()
    expected = pl.DataFrame(
        [
            pl.Series("a", [2, 3, 4], dtype=pl.Int64),
            pl.Series("b", [2, 3, 4], dtype=pl.Int64),
        ]
    )
    assert_frame_equal(result, expected)


def test_cluster_with_columns_prune_col() -> None:
    """A no-op `pl.col("buzz")` leaves one WITH_COLUMNS; a missing column still raises."""
    query = (
        pl.LazyFrame({"foo": [0.5, 1.7, 3.2], "bar": [4.1, 1.5, 9.2]})
        .with_columns(pl.col("foo").alias("buzz"))
        .with_columns(pl.col("buzz"), pl.col("foo") * 2.0)
    )

    plan = query.explain()

    assert plan.count("WITH_COLUMNS") == 1

    result = query.collect()
    expected = pl.DataFrame(
        [
            pl.Series("foo", [1.0, 3.4, 6.4], dtype=pl.Float64),
            pl.Series("bar", [4.1, 1.5, 9.2], dtype=pl.Float64),
            pl.Series("buzz", [0.5, 1.7, 3.2], dtype=pl.Float64),
        ]
    )
    assert_frame_equal(result, expected)

    # Selecting a column that does not exist must raise at collect time.
    query = pl.LazyFrame({"a": 1}).with_columns(pl.col("a")).with_columns(pl.col("b"))

    with pytest.raises(ColumnNotFoundError, match='unable to find column "b"'):
        query.collect()