# Path: blob/main/py-polars/tests/unit/operations/test_over.py
# 8422 views
"""Unit tests for window expressions built with ``Expr.over``."""

from typing import Any

import pytest

import polars as pl
from polars.testing import assert_frame_equal, assert_series_equal


def test_implode_explode_over_22188() -> None:
    """Multiplying by an imploded-then-exploded literal in a window keeps ``x``.

    The literal series is all ones, so the windowed product must equal the
    original ``x`` column. (The numeric suffix presumably refers to an issue
    number -- confirm against the tracker.)
    """
    df = pl.DataFrame(
        {
            "x": [1, 2, 3, 1, 2, 3, 1, 2, 3],
            "y": [2, 2, 2, 3, 3, 3, 4, 4, 4],
        }
    )
    result = df.select(
        (pl.col.x * (pl.lit(pl.Series([1, 1, 1])).implode().explode())).over(pl.col.y),
    )

    assert_series_equal(result.to_series(), df.get_column("x"))


def test_implode_in_over_22188() -> None:
    """An imploded literal can be used as a ``list.set_union`` operand in a window."""
    df = pl.DataFrame(
        {
            "x": [[1], [2], [3]],
            "y": [2, 3, 4],
        }
    ).select(pl.col.x.list.set_union(pl.lit(pl.Series([1])).implode()).over(pl.col.y))
    assert_series_equal(df.to_series(), pl.Series("x", [[1], [2, 1], [3, 1]]))


def test_over_no_partition_by() -> None:
    """``over(order_by=...)`` without partition columns applies the ordering only.

    The cumulative sum of ``a`` is computed in ``i`` order and the result is
    reported in the original row order.
    """
    df = pl.DataFrame({"a": [1, 1, 2], "i": [2, 1, 3]})
    result = df.with_columns(b=pl.col("a").cum_sum().over(order_by="i"))
    expected = pl.DataFrame({"a": [1, 1, 2], "i": [2, 1, 3], "b": [2, 1, 4]})
    assert_frame_equal(result, expected)


def test_over_no_partition_by_no_over() -> None:
    """Calling ``over()`` with no arguments raises ``InvalidOperationError``."""
    df = pl.DataFrame({"a": [1, 1, 2], "i": [2, 1, 3]})
    with pytest.raises(pl.exceptions.InvalidOperationError):
        df.with_columns(b=pl.col("a").cum_sum().over())


def test_over_explode_22770() -> None:
    """Exploding and re-joining a list column in a window matches the plain column.

    ``list.diff`` on the window result must equal ``list.diff`` on ``x`` itself.
    """
    df = pl.DataFrame({"x": [[1.0], [2.0]], "idx": [1, 2]})
    e = pl.col("x").list.explode().over("idx", mapping_strategy="join")

    assert_frame_equal(
        df.select(pl.col("x").list.diff()),
        df.select(e.list.diff()),
    )


def test_over_replace_strict_22870() -> None:
    """``replace_strict`` gives the same values with and without ``.over("cat")``.

    Checked twice: once with the old/new mapping supplied as series from a
    lookup frame, and once with the mapping supplied as column expressions.
    """
    lookup = pl.DataFrame(
        {
            "cat": ["a", "b", "c"],
            "val": [102, 100, 101],
        }
    )

    df = pl.DataFrame(
        {
            "cat": ["a", "b", "a", "a", "b"],
            "data": [2, 3, 4, 5, 6],
            "a": ["a", "b", "c", "d", "e"],
            "b": [102, 100, 101, 109, 110],
        }
    )

    # Mapping given as materialized series.
    out = (
        df.lazy()
        .select(
            pl.col("cat")
            .replace_strict(lookup["cat"], lookup["val"], default=-1)
            .alias("val"),
            pl.col("cat")
            .replace_strict(lookup["cat"], lookup["val"], default=-1)
            .over("cat")
            .alias("val_over"),
        )
        .collect()
    )
    assert_series_equal(
        out.get_column("val"), out.get_column("val_over"), check_names=False
    )

    # Mapping given as column expressions of the same frame.
    out = (
        df.lazy()
        .select(
            pl.col("cat").replace_strict(pl.col.a, pl.col.b, default=-1).alias("val"),
            pl.col("cat")
            .replace_strict(pl.col.a, pl.col.b, default=-1)
            .over("cat")
            .alias("val_over"),
        )
        .collect()
    )
    assert_series_equal(
        out.get_column("val"), out.get_column("val_over"), check_names=False
    )


@pytest.mark.parametrize(
    "col",
    [
        [1, 2, 3],
        [[11, 12], [21], [31]],
    ],
)
def test_implode_explode_list_over_24616(col: list[Any]) -> None:
    """``implode().explode()`` inside a window returns the input column unchanged.

    Runs for a flat column and a list column, first with a literal partition
    key (``over(1)``) and then with a real grouping column ``g``. The plain
    ``x.over(...)`` query is checked against the same expectation as a baseline.
    """
    df = pl.DataFrame({"x": col})
    q = df.lazy().select(pl.col.x.implode().explode().over(1))
    q_base = df.lazy().select(pl.col.x.over(1))
    expected = df
    assert_frame_equal(q.collect(), expected)
    assert_frame_equal(q_base.collect(), expected)

    df = pl.DataFrame({"g": [10, 10, 20], "x": col})
    q = df.lazy().with_columns(pl.col.x.implode().explode().over("g"))
    q_base = df.lazy().with_columns(pl.col.x.over("g"))
    expected = df
    assert_frame_equal(q.collect(), expected)
    assert_frame_equal(q_base.collect(), expected)


def test_first_last_over() -> None:
    """``first``/``last`` in a window, with and without ``ignore_nulls``.

    Group ``a == 1`` ends with a null and group ``a == 2`` starts with one, so
    the default variants surface the null while ``ignore_nulls=True`` skips it.
    """
    df = pl.DataFrame(
        {
            "a": [1, 1, 1, 1, 2, 2, 2, 2],
            "b": pl.Series([1, 2, 3, None, None, 4, 5, 6], dtype=pl.Int32),
        }
    )

    result = df.select(pl.col("b").first().over("a"))
    expected = pl.DataFrame(
        {"b": pl.Series([1, 1, 1, 1, None, None, None, None], dtype=pl.Int32)}
    )
    assert_frame_equal(result, expected)

    result = df.select(pl.col("b").first(ignore_nulls=True).over("a"))
    expected = pl.DataFrame({"b": pl.Series([1, 1, 1, 1, 4, 4, 4, 4], dtype=pl.Int32)})
    assert_frame_equal(result, expected)

    result = df.select(pl.col("b").last().over("a"))
    expected = pl.DataFrame(
        {"b": pl.Series([None, None, None, None, 6, 6, 6, 6], dtype=pl.Int32)}
    )
    assert_frame_equal(result, expected)

    result = df.select(pl.col("b").last(ignore_nulls=True).over("a"))
    expected = pl.DataFrame({"b": pl.Series([3, 3, 3, 3, 6, 6, 6, 6], dtype=pl.Int32)})
    assert_frame_equal(result, expected)