# Source: py-polars/tests/unit/streaming/test_streaming_group_by.py
from __future__ import annotations12from datetime import date3from typing import TYPE_CHECKING, Any45import numpy as np6import pytest78import polars as pl9from polars.exceptions import DuplicateError10from polars.testing import assert_frame_equal11from tests.unit.conftest import INTEGER_DTYPES1213if TYPE_CHECKING:14from pathlib import Path1516from tests.conftest import PlMonkeyPatch1718pytestmark = pytest.mark.xdist_group("streaming")192021@pytest.mark.slow22def test_streaming_group_by_sorted_fast_path_nulls_10273() -> None:23df = pl.Series(24name="x",25values=(26*(i for i in range(4) for _ in range(100)),27*(None for _ in range(100)),28),29).to_frame()3031assert (32df.set_sorted("x")33.lazy()34.group_by("x")35.agg(pl.len())36.collect(engine="streaming")37.sort("x")38).to_dict(as_series=False) == {39"x": [None, 0, 1, 2, 3],40"len": [100, 100, 100, 100, 100],41}424344def test_streaming_group_by_types() -> None:45df = pl.DataFrame(46{47"person_id": [1, 1],48"year": [1995, 1995],49"person_name": ["bob", "foo"],50"bool": [True, False],51"date": [date(2022, 1, 1), date(2022, 1, 1)],52}53)5455for by in ["person_id", "year", "date", ["person_id", "year"]]:56out = (57(58df.lazy()59.group_by(by)60.agg(61[62pl.col("person_name").first().alias("str_first"),63pl.col("person_name").last().alias("str_last"),64pl.col("bool").first().alias("bool_first"),65pl.col("bool").last().alias("bool_last"),66pl.col("bool").mean().alias("bool_mean"),67pl.col("bool").sum().alias("bool_sum"),68# pl.col("date").sum().alias("date_sum"),69# Date streaming mean/median has been temporarily disabled70# pl.col("date").mean().alias("date_mean"),71pl.col("date").first().alias("date_first"),72pl.col("date").last().alias("date_last"),73pl.col("date").min().alias("date_min"),74pl.col("date").max().alias("date_max"),75]76)77)78.select(pl.all().exclude(by))79.collect(engine="streaming")80)81assert out.schema == {82"str_first": pl.String,83"str_last": pl.String,84"bool_first": pl.Boolean,85"bool_last": 
pl.Boolean,86"bool_mean": pl.Float64,87"bool_sum": pl.get_index_type(),88# "date_sum": pl.Date,89# "date_mean": pl.Date,90"date_first": pl.Date,91"date_last": pl.Date,92"date_min": pl.Date,93"date_max": pl.Date,94}9596assert out.to_dict(as_series=False) == {97"str_first": ["bob"],98"str_last": ["foo"],99"bool_first": [True],100"bool_last": [False],101"bool_mean": [0.5],102"bool_sum": [1],103# "date_sum": [None],104# Date streaming mean/median has been temporarily disabled105# "date_mean": [date(2022, 1, 1)],106"date_first": [date(2022, 1, 1)],107"date_last": [date(2022, 1, 1)],108"date_min": [date(2022, 1, 1)],109"date_max": [date(2022, 1, 1)],110}111112with pytest.raises(DuplicateError):113(114df.lazy()115.group_by("person_id")116.agg(117[118pl.col("person_name").first().alias("str_first"),119pl.col("person_name").last().alias("str_last"),120pl.col("person_name").mean().alias("str_mean"),121pl.col("bool").first().alias("bool_first"),122pl.col("bool").last().alias("bool_first"),123]124)125.select(pl.all().exclude("person_id"))126.collect(engine="streaming")127)128129130def test_streaming_group_by_min_max() -> None:131df = pl.DataFrame(132{133"person_id": [1, 2, 3, 4, 5, 6],134"year": [1995, 1995, 1995, 2, 2, 2],135}136)137out = (138df.lazy()139.group_by("year")140.agg([pl.min("person_id").alias("min"), pl.max("person_id").alias("max")])141.collect()142.sort("year")143)144assert out["min"].to_list() == [4, 1]145assert out["max"].to_list() == [6, 3]146147148def test_streaming_non_streaming_gb() -> None:149n = 100150df = pl.DataFrame({"a": np.random.randint(0, 20, n)})151q = df.lazy().group_by("a").agg(pl.len()).sort("a")152assert_frame_equal(q.collect(engine="streaming"), q.collect())153154q = df.lazy().with_columns(pl.col("a").cast(pl.String))155q = q.group_by("a").agg(pl.len()).sort("a")156assert_frame_equal(q.collect(engine="streaming"), q.collect())157q = df.lazy().with_columns(pl.col("a").alias("b"))158q = q.group_by(["a", "b"]).agg(pl.len(), 
pl.col("a").sum().alias("sum_a")).sort("a")159assert_frame_equal(q.collect(engine="streaming"), q.collect())160161162def test_streaming_group_by_sorted_fast_path() -> None:163a = np.random.randint(0, 20, 80)164df = pl.DataFrame(165{166# test on int8 as that also tests proper conversions167"a": pl.Series(np.sort(a), dtype=pl.Int8)168}169).with_row_index()170171df_sorted = df.with_columns(pl.col("a").set_sorted())172173for streaming in [True, False]:174results = []175for df_ in [df, df_sorted]:176out = (177df_.lazy()178.group_by("a")179.agg(180[181pl.first("a").alias("first"),182pl.last("a").alias("last"),183pl.sum("a").alias("sum"),184pl.mean("a").alias("mean"),185pl.count("a").alias("count"),186pl.min("a").alias("min"),187pl.max("a").alias("max"),188]189)190.sort("a")191.collect(engine="streaming" if streaming else "in-memory")192)193results.append(out)194195assert_frame_equal(results[0], results[1])196197198@pytest.fixture(scope="module")199def random_integers() -> pl.Series:200np.random.seed(1)201return pl.Series("a", np.random.randint(0, 10, 100), dtype=pl.Int64)202203204@pytest.mark.write_disk205def test_streaming_group_by_ooc_q1(206random_integers: pl.Series,207tmp_path: Path,208plmonkeypatch: PlMonkeyPatch,209) -> None:210tmp_path.mkdir(exist_ok=True)211plmonkeypatch.setenv("POLARS_TEMP_DIR", str(tmp_path))212plmonkeypatch.setenv("POLARS_FORCE_OOC", "1")213214lf = random_integers.to_frame().lazy()215result = (216lf.group_by("a")217.agg(pl.first("a").alias("a_first"), pl.last("a").alias("a_last"))218.sort("a")219.collect(engine="streaming")220)221222expected = pl.DataFrame(223{224"a": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],225"a_first": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],226"a_last": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],227}228)229assert_frame_equal(result, expected)230231232@pytest.mark.write_disk233def test_streaming_group_by_ooc_q2(234random_integers: pl.Series,235tmp_path: Path,236plmonkeypatch: PlMonkeyPatch,237) -> 
None:238tmp_path.mkdir(exist_ok=True)239plmonkeypatch.setenv("POLARS_TEMP_DIR", str(tmp_path))240plmonkeypatch.setenv("POLARS_FORCE_OOC", "1")241242lf = random_integers.cast(str).to_frame().lazy()243result = (244lf.group_by("a")245.agg(pl.first("a").alias("a_first"), pl.last("a").alias("a_last"))246.sort("a")247.collect(engine="streaming")248)249250expected = pl.DataFrame(251{252"a": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],253"a_first": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],254"a_last": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],255}256)257assert_frame_equal(result, expected)258259260@pytest.mark.write_disk261def test_streaming_group_by_ooc_q3(262random_integers: pl.Series,263tmp_path: Path,264plmonkeypatch: PlMonkeyPatch,265) -> None:266tmp_path.mkdir(exist_ok=True)267plmonkeypatch.setenv("POLARS_TEMP_DIR", str(tmp_path))268plmonkeypatch.setenv("POLARS_FORCE_OOC", "1")269270lf = pl.LazyFrame({"a": random_integers, "b": random_integers})271result = (272lf.group_by("a", "b")273.agg(pl.first("a").alias("a_first"), pl.last("a").alias("a_last"))274.sort("a")275.collect(engine="streaming")276)277278expected = pl.DataFrame(279{280"a": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],281"b": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],282"a_first": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],283"a_last": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],284}285)286assert_frame_equal(result, expected)287288289def test_streaming_group_by_struct_key() -> None:290df = pl.DataFrame(291{"A": [1, 2, 3, 2], "B": ["google", "ms", "apple", "ms"], "C": [2, 3, 4, 3]}292)293df1 = df.lazy().with_columns(pl.struct(["A", "C"]).alias("tuples"))294assert df1.group_by("tuples").agg(pl.len(), pl.col("B").first()).sort("B").collect(295engine="streaming"296).to_dict(as_series=False) == {297"tuples": [{"A": 3, "C": 4}, {"A": 1, "C": 2}, {"A": 2, "C": 3}],298"len": [1, 1, 2],299"B": ["apple", "google", "ms"],300}301302303@pytest.mark.slow304def test_streaming_group_by_all_numeric_types_stability_8570() -> None:305m = 1000306n 
= 1000307308rng = np.random.default_rng(seed=0)309dfa = pl.DataFrame({"x": pl.arange(start=0, end=n, eager=True)})310dfb = pl.DataFrame(311{312"y": rng.integers(low=0, high=10, size=m),313"z": rng.integers(low=0, high=2, size=m),314}315)316dfc = dfa.join(dfb, how="cross")317318for keys in [["x", "y"], "z"]:319for dtype in [*INTEGER_DTYPES, pl.Boolean]:320# the alias checks if the schema is correctly handled321dfd = (322dfc.lazy()323.with_columns(pl.col("z").cast(dtype))324.group_by(keys)325.agg(pl.col("z").sum().alias("z_sum"))326.collect(engine="streaming")327)328assert dfd["z_sum"].sum() == dfc["z"].sum()329330331def test_streaming_group_by_categorical_aggregate() -> None:332out = (333pl.LazyFrame(334{335"a": pl.Series(336["a", "a", "b", "b", "c", "c", None, None], dtype=pl.Categorical337),338"b": pl.Series(339pl.date_range(340date(2023, 4, 28),341date(2023, 5, 5),342eager=True,343).to_list(),344dtype=pl.Date,345),346}347)348.group_by(["a", "b"])349.agg([pl.col("a").first().alias("sum")])350.collect(engine="streaming")351)352353assert out.sort("b").to_dict(as_series=False) == {354"a": ["a", "a", "b", "b", "c", "c", None, None],355"b": [356date(2023, 4, 28),357date(2023, 4, 29),358date(2023, 4, 30),359date(2023, 5, 1),360date(2023, 5, 2),361date(2023, 5, 3),362date(2023, 5, 4),363date(2023, 5, 5),364],365"sum": ["a", "a", "b", "b", "c", "c", None, None],366}367368369def test_streaming_group_by_list_9758() -> None:370payload = {"a": [[1, 2]]}371assert (372pl.LazyFrame(payload)373.group_by("a")374.first()375.collect(engine="streaming")376.to_dict(as_series=False)377== payload378)379380381def test_group_by_min_max_string_type() -> None:382table = pl.from_dict({"a": [1, 1, 2, 2, 2], "b": ["a", "b", "c", "d", None]})383384expected = {"a": [1, 2], "min": ["a", "c"], "max": ["b", "d"]}385386for streaming in [True, False]:387assert (388table.lazy()389.group_by("a")390.agg([pl.min("b").alias("min"), pl.max("b").alias("max")])391.collect(engine="streaming" if streaming else 
"in-memory")392.sort("a")393.to_dict(as_series=False)394== expected395)396397398@pytest.mark.parametrize("literal", [True, "foo", 1])399def test_streaming_group_by_literal(literal: Any) -> None:400df = pl.LazyFrame({"a": range(20)})401402assert df.group_by(pl.lit(literal)).agg(403[404pl.col("a").count().alias("a_count"),405pl.col("a").sum().alias("a_sum"),406]407).collect(engine="streaming").to_dict(as_series=False) == {408"literal": [literal],409"a_count": [20],410"a_sum": [190],411}412413414@pytest.mark.parametrize("streaming", [True, False])415def test_group_by_multiple_keys_one_literal(streaming: bool) -> None:416df = pl.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]})417418expected = {"a": [1, 2], "literal": [1, 1], "b": [5, 6]}419assert (420df.lazy()421.group_by("a", pl.lit(1))422.agg(pl.col("b").max())423.sort(["a", "b"])424.collect(engine="streaming" if streaming else "in-memory")425.to_dict(as_series=False)426== expected427)428429430def test_streaming_group_null_count() -> None:431df = pl.DataFrame({"g": [1] * 6, "a": ["yes", None] * 3}).lazy()432assert df.group_by("g").agg(pl.col("a").count()).collect(433engine="streaming"434).to_dict(as_series=False) == {"g": [1], "a": [3]}435436437def test_streaming_group_by_binary_15116() -> None:438assert (439pl.LazyFrame(440{441"str": [442"A",443"A",444"BB",445"BB",446"CCCC",447"CCCC",448"DDDDDDDD",449"DDDDDDDD",450"EEEEEEEEEEEEEEEE",451"A",452]453}454)455.select([pl.col("str").cast(pl.Binary)])456.group_by(["str"])457.agg([pl.len().alias("count")])458).sort("str").collect(engine="streaming").to_dict(as_series=False) == {459"str": [b"A", b"BB", b"CCCC", b"DDDDDDDD", b"EEEEEEEEEEEEEEEE"],460"count": [3, 2, 2, 2, 1],461}462463464def test_streaming_group_by_convert_15380(partition_limit: int) -> None:465assert (466pl.DataFrame({"a": [1] * partition_limit}).group_by(b="a").len()["len"].item()467== partition_limit468)469470471@pytest.mark.parametrize("streaming", [True, False])472@pytest.mark.parametrize("n_rows_limit_offset", 
[-1, +3])473def test_streaming_group_by_boolean_mean_15610(474n_rows_limit_offset: int, streaming: bool, partition_limit: int475) -> None:476n_rows = partition_limit + n_rows_limit_offset477478# Also test non-streaming because it sometimes dispatched to streaming agg.479expect = pl.DataFrame({"a": [False, True], "c": [0.0, 0.5]})480481n_repeats = n_rows // 3482assert n_repeats > 0483484out = (485pl.select(486a=pl.repeat([True, False, True], n_repeats).explode(),487b=pl.repeat([True, False, False], n_repeats).explode(),488)489.lazy()490.group_by("a")491.agg(c=pl.mean("b"))492.sort("a")493.collect(engine="streaming" if streaming else "in-memory")494)495496assert_frame_equal(out, expect)497498499def test_streaming_group_by_all_null_21593() -> None:500df = pl.DataFrame(501{502"col_1": ["A", "B", "C", "D"],503"col_2": ["test", None, None, None],504}505)506507out = df.lazy().group_by(pl.all()).min().collect(engine="streaming")508assert_frame_equal(df, out, check_row_order=False)509510511