# Path: blob/main/py-polars/tests/unit/operations/test_over.py
# 6939 views
import pytest

import polars as pl
from polars.testing import assert_frame_equal, assert_series_equal


def test_implode_explode_over_22188() -> None:
    """``implode().explode()`` on a literal round-trips inside a window.

    Every ``y`` group has exactly three rows, matching the three-element
    literal of ones, so multiplying ``x`` by it per group must return ``x``.
    """
    df = pl.DataFrame(
        {
            "x": [1, 2, 3, 1, 2, 3, 1, 2, 3],
            "y": [2, 2, 2, 3, 3, 3, 4, 4, 4],
        }
    )
    result = df.select(
        (pl.col.x * (pl.lit(pl.Series([1, 1, 1])).implode().explode())).over(pl.col.y),
    )

    assert_series_equal(result.to_series(), df.get_column("x"))


def test_implode_in_over_22188() -> None:
    """An imploded literal can be used as a list operand inside ``over``.

    Each row is its own ``y`` group; ``set_union`` with ``[1]`` leaves ``[1]``
    unchanged and appends ``1`` to the other singleton lists.
    """
    result = pl.DataFrame(
        {
            "x": [[1], [2], [3]],
            "y": [2, 3, 4],
        }
    ).select(pl.col.x.list.set_union(pl.lit(pl.Series([1])).implode()).over(pl.col.y))
    assert_series_equal(result.to_series(), pl.Series("x", [[1], [2, 1], [3, 1]]))


def test_over_no_partition_by() -> None:
    """``over(order_by=...)`` works without any partition keys.

    The cumulative sum is taken in ``i`` order (rows 1, 0, 2 -> 1, 2, 4) and
    the results are written back in the frame's original row order.
    """
    df = pl.DataFrame({"a": [1, 1, 2], "i": [2, 1, 3]})
    result = df.with_columns(b=pl.col("a").cum_sum().over(order_by="i"))
    expected = pl.DataFrame({"a": [1, 1, 2], "i": [2, 1, 3], "b": [2, 1, 4]})
    assert_frame_equal(result, expected)


def test_over_no_partition_by_no_over() -> None:
    """``over()`` with neither partition keys nor ``order_by`` is rejected."""
    df = pl.DataFrame({"a": [1, 1, 2], "i": [2, 1, 3]})
    with pytest.raises(pl.exceptions.InvalidOperationError):
        df.with_columns(b=pl.col("a").cum_sum().over())


def test_over_explode_22770() -> None:
    """Exploding inside ``over(..., mapping_strategy="join")`` stays list-like.

    ``list.diff`` on the windowed expression must match ``list.diff`` on the
    original list column.
    """
    df = pl.DataFrame({"x": [[1.0], [2.0]], "idx": [1, 2]})
    # "join" maps each group's values back to its rows as a list.
    e = pl.col("x").list.explode().over("idx", mapping_strategy="join")

    assert_frame_equal(
        df.select(pl.col("x").list.diff()),
        df.select(e.list.diff()),
    )


def test_over_replace_strict_22870() -> None:
    """``replace_strict`` gives the same result with and without ``over``.

    Checked for a mapping given as literal Series (``lookup``) and for a
    mapping given as column expressions of the frame itself.
    """
    lookup = pl.DataFrame(
        {
            "cat": ["a", "b", "c"],
            "val": [102, 100, 101],
        }
    )

    df = pl.DataFrame(
        {
            "cat": ["a", "b", "a", "a", "b"],
            "data": [2, 3, 4, 5, 6],
            "a": ["a", "b", "c", "d", "e"],
            "b": [102, 100, 101, 109, 110],
        }
    )

    # Case 1: old/new values supplied as Series.
    out = (
        df.lazy()
        .select(
            pl.col("cat")
            .replace_strict(lookup["cat"], lookup["val"], default=-1)
            .alias("val"),
            pl.col("cat")
            .replace_strict(lookup["cat"], lookup["val"], default=-1)
            .over("cat")
            .alias("val_over"),
        )
        .collect()
    )
    assert_series_equal(
        out.get_column("val"), out.get_column("val_over"), check_names=False
    )

    # Case 2: old/new values supplied as column expressions of ``df``.
    out = (
        df.lazy()
        .select(
            pl.col("cat").replace_strict(pl.col.a, pl.col.b, default=-1).alias("val"),
            pl.col("cat")
            .replace_strict(pl.col.a, pl.col.b, default=-1)
            .over("cat")
            .alias("val_over"),
        )
        .collect()
    )
    assert_series_equal(
        out.get_column("val"), out.get_column("val_over"), check_names=False
    )