# Path: blob/main/py-polars/tests/unit/streaming/test_streaming_join.py
# 6939 views
from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING, Literal

import numpy as np
import pandas as pd
import pytest

import polars as pl
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
    from pathlib import Path

    from polars._typing import JoinStrategy

pytestmark = pytest.mark.xdist_group("streaming")


def test_streaming_full_outer_joins() -> None:
    """Streaming and in-memory engines must agree on full joins.

    Both the coalescing and the non-coalescing variant are exercised on two
    frames of random keys.
    """
    n = 100
    dfa = pl.DataFrame(
        {
            "a": np.random.randint(0, 40, n),
            "idx": np.arange(0, n),
        }
    )

    dfb = pl.DataFrame(
        {
            "a": np.random.randint(0, 40, n),
            "idx": np.arange(0, n),
        }
    )

    join_strategies: list[tuple[JoinStrategy, bool]] = [
        ("full", False),
        ("full", True),
    ]
    for how, coalesce in join_strategies:
        q = (
            dfa.lazy()
            .join(dfb.lazy(), on="a", how=how, coalesce=coalesce)
            .sort(["idx"])
        )
        a = q.collect(engine="streaming")
        b = q.collect(engine="in-memory")
        assert_frame_equal(a, b, check_row_order=False)


def test_streaming_joins() -> None:
    """Compare streaming inner/left joins against the pandas ``merge`` result.

    Each strategy is checked once with a single key and once with two keys.
    """
    n = 100
    dfa = pd.DataFrame(
        {
            "a": np.random.randint(0, 40, n),
            "b": np.arange(0, n),
        }
    )

    dfb = pd.DataFrame(
        {
            "a": np.random.randint(0, 40, n),
            "b": np.arange(0, n),
        }
    )
    dfa_pl = pl.from_pandas(dfa).sort("a")
    dfb_pl = pl.from_pandas(dfb)

    join_strategies: list[Literal["inner", "left"]] = ["inner", "left"]
    for how in join_strategies:
        pd_result = dfa.merge(dfb, on="a", how=how)
        # Rename the pandas output columns to the names polars produces.
        pd_result.columns = pd.Index(["a", "b", "b_right"])

        pl_result = (
            dfa_pl.lazy()
            .join(dfb_pl.lazy(), on="a", how=how)
            .sort(["a", "b", "b_right"])
            .collect(engine="streaming")
        )

        a = (
            pl.from_pandas(pd_result)
            .with_columns(pl.all().cast(int))
            .sort(["a", "b", "b_right"])
        )
        assert_frame_equal(a, pl_result, check_dtypes=False)

        pd_result = dfa.merge(dfb, on=["a", "b"], how=how)

        pl_result = (
            dfa_pl.lazy()
            .join(dfb_pl.lazy(), on=["a", "b"], how=how)
            .sort(["a", "b"])
            .collect(engine="streaming")
        )

        # we cast to integer because pandas joins creates floats
        a = pl.from_pandas(pd_result).with_columns(pl.all().cast(int)).sort(["a", "b"])
        assert_frame_equal(a, pl_result, check_dtypes=False)


def test_streaming_cross_join_empty() -> None:
    """A cross join with an empty right side yields zero rows, both columns."""
    df1 = pl.LazyFrame(data={"col1": ["a"]})

    df2 = pl.LazyFrame(
        data={"col1": []},
        schema={"col1": str},
    )

    out = df1.join(df2, how="cross").collect(engine="streaming")
    assert out.shape == (0, 2)
    assert out.columns == ["col1", "col1_right"]


def test_streaming_join_rechunk_12498() -> None:
    """A streaming cross join of two 2-row frames yields all four pairs."""
    rows = pl.int_range(0, 2)

    a = pl.select(A=rows).lazy()
    b = pl.select(B=rows).lazy()

    q = a.join(b, how="cross")
    assert q.collect(engine="streaming").sort(["B", "A"]).to_dict(as_series=False) == {
        "A": [0, 1, 0, 1],
        "B": [0, 0, 1, 1],
    }


@pytest.mark.parametrize("maintain_order", [False, True])
def test_join_null_matches(maintain_order: bool) -> None:
    """Check null-key handling for semi, inner, left and full joins.

    Row order is only asserted when ``maintain_order`` is requested.
    """
    # Null keys never find a match unless `nulls_equal=True` is passed
    # (exercised by the first semi join below).
    df_a = pl.LazyFrame(
        {
            "idx_a": [0, 1, 2],
            "a": [None, 1, 2],
        }
    )

    df_b = pl.LazyFrame(
        {
            "idx_b": [0, 1, 2, 3],
            "a": [None, 2, 1, None],
        }
    )
    # Semi
    assert_series_equal(
        df_a.join(
            df_b,
            on="a",
            how="semi",
            nulls_equal=True,
            maintain_order="left" if maintain_order else "none",
        ).collect()["idx_a"],
        pl.Series("idx_a", [0, 1, 2]),
        check_order=maintain_order,
    )
    assert_series_equal(
        df_a.join(
            df_b,
            on="a",
            how="semi",
            nulls_equal=False,
            maintain_order="left" if maintain_order else "none",
        ).collect()["idx_a"],
        pl.Series("idx_a", [1, 2]),
        check_order=maintain_order,
    )

    # Inner
    expected = pl.DataFrame({"idx_a": [2, 1], "a": [2, 1], "idx_b": [1, 2]})
    assert_frame_equal(
        df_a.join(
            df_b,
            on="a",
            how="inner",
            maintain_order="right" if maintain_order else "none",
        ).collect(),
        expected,
        check_row_order=maintain_order,
    )

    # Left outer
    expected = pl.DataFrame(
        {"idx_a": [0, 1, 2], "a": [None, 1, 2], "idx_b": [None, 2, 1]}
    )
    assert_frame_equal(
        df_a.join(
            df_b,
            on="a",
            how="left",
            maintain_order="left" if maintain_order else "none",
        ).collect(),
        expected,
        check_row_order=maintain_order,
    )
    # Full outer
    expected = pl.DataFrame(
        {
            "idx_a": [None, 2, 1, None, 0],
            "a": [None, 2, 1, None, None],
            "idx_b": [0, 1, 2, 3, None],
            "a_right": [None, 2, 1, None, None],
        }
    )
    assert_frame_equal(
        df_a.join(
            df_b,
            on="a",
            how="full",
            maintain_order="right" if maintain_order else "none",
        ).collect(),
        expected,
        check_row_order=maintain_order,
    )


@pytest.mark.parametrize("streaming", [False, True])
def test_join_null_matches_multiple_keys(streaming: bool) -> None:
    """Check null-key handling for two-key inner, left and full joins.

    Every join is collected on the engine selected by ``streaming``.
    """
    df_a = pl.LazyFrame(
        {
            "a": [None, 1, 2],
            "idx": [0, 1, 2],
        }
    )

    df_b = pl.LazyFrame(
        {
            "a": [None, 2, 1, None, 1],
            "idx": [0, 1, 2, 3, 1],
            "c": [10, 20, 30, 40, 50],
        }
    )

    expected = pl.DataFrame({"a": [1], "idx": [1], "c": [50]})
    assert_frame_equal(
        df_a.join(df_b, on=["a", "idx"], how="inner").collect(
            engine="streaming" if streaming else "in-memory"
        ),
        expected,
        check_row_order=False,
    )
    expected = pl.DataFrame(
        {"a": [None, 1, 2], "idx": [0, 1, 2], "c": [None, 50, None]}
    )
    assert_frame_equal(
        df_a.join(df_b, on=["a", "idx"], how="left").collect(
            engine="streaming" if streaming else "in-memory"
        ),
        expected,
        check_row_order=False,
    )

    expected = pl.DataFrame(
        {
            "a": [None, None, None, None, None, 1, 2],
            "idx": [None, None, None, None, 0, 1, 2],
            "a_right": [None, 2, 1, None, None, 1, None],
            "idx_right": [0, 1, 2, 3, None, 1, None],
            "c": [10, 20, 30, 40, None, 50, None],
        }
    )
    # Pass the engine here as well; otherwise the `streaming` parameter is
    # ignored for the full join and the default engine is tested twice.
    assert_frame_equal(
        df_a.join(df_b, on=["a", "idx"], how="full")
        .sort("a")
        .collect(engine="streaming" if streaming else "in-memory"),
        expected,
        check_row_order=False,
    )


def test_streaming_join_and_union() -> None:
    """Concatenating two frames with their join matches the in-memory result."""
    a = pl.LazyFrame({"a": [1, 2]})

    b = pl.LazyFrame({"a": [1, 2, 4, 8]})

    c = a.join(b, on="a", maintain_order="left_right")
    # The join node latest ensures that the dispatcher
    # needs to replace placeholders in unions.
    q = pl.concat([a, b, c])

    out = q.collect(engine="streaming")
    assert_frame_equal(out, q.collect(engine="in-memory"))
    assert out.to_series().to_list() == [1, 2, 1, 2, 4, 8, 1, 2]


def test_non_coalescing_streaming_left_join() -> None:
    """A left join with ``coalesce=False`` keeps the right key as ``a_right``."""
    df1 = pl.LazyFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]})

    df2 = pl.LazyFrame({"a": [1, 2], "c": ["j", "i"]})

    q = df1.join(df2, on="a", how="left", coalesce=False)
    assert_frame_equal(
        q.collect(engine="streaming"),
        pl.DataFrame(
            {
                "a": [1, 2, 3],
                "b": ["a", "b", "c"],
                "a_right": [1, 2, None],
                "c": ["j", "i", None],
            }
        ),
        check_row_order=False,
    )


@pytest.mark.write_disk
def test_streaming_outer_join_partial_flush(tmp_path: Path) -> None:
    """Full coalescing join of two identical parquet scans returns the data."""
    data = {
        "value_at": [datetime(2024, i + 1, 1) for i in range(6)],
        "value": list(range(6)),
    }

    parquet_path = tmp_path / "data.parquet"
    pl.DataFrame(data=data).write_parquet(parquet_path)

    other_parquet_path = tmp_path / "data2.parquet"
    pl.DataFrame(data=data).write_parquet(other_parquet_path)

    lf1 = pl.scan_parquet(other_parquet_path)
    lf2 = pl.scan_parquet(parquet_path)

    # Join on every column the two scans have in common.
    join_cols = set(lf1.collect_schema()).intersection(set(lf2.collect_schema()))
    final_lf = lf1.join(lf2, on=list(join_cols), how="full", coalesce=True)

    assert_frame_equal(
        final_lf.collect(engine="streaming"),
        pl.DataFrame(
            {
                "value_at": [
                    datetime(2024, 1, 1, 0, 0),
                    datetime(2024, 2, 1, 0, 0),
                    datetime(2024, 3, 1, 0, 0),
                    datetime(2024, 4, 1, 0, 0),
                    datetime(2024, 5, 1, 0, 0),
                    datetime(2024, 6, 1, 0, 0),
                ],
                "value": [0, 1, 2, 3, 4, 5],
            }
        ),
        check_row_order=False,
    )


def test_flush_join_and_operation_19040() -> None:
    """Two chained full joins followed by ``with_columns`` and a sort."""
    df_A = pl.LazyFrame({"K": [True, False], "A": [1, 1]})

    df_B = pl.LazyFrame({"K": [True], "B": [1]})

    df_C = pl.LazyFrame({"K": [True], "C": [1]})

    q = (
        df_A.join(df_B, how="full", on=["K"], coalesce=True)
        .join(df_C, how="full", on=["K"], coalesce=True)
        .with_columns(B=pl.col("B"))
        .sort("K")
    )
    assert q.collect(engine="streaming").to_dict(as_series=False) == {
        "K": [False, True],
        "A": [1, 1],
        "B": [None, 1],
        "C": [None, 1],
    }


def test_full_coalesce_join_and_rename_15583() -> None:
    """A coalescing full join followed by a column rename keeps all keys."""
    df1 = pl.LazyFrame({"a": [1, 2, 3]})
    df2 = pl.LazyFrame({"a": [3, 4, 5]})

    result = (
        df1.join(df2, on="a", how="full", coalesce=True)
        .select(pl.all().name.map(lambda c: c.upper()))
        .sort("A")
        .collect(engine="streaming")
    )
    assert result["A"].to_list() == [1, 2, 3, 4, 5]


def test_invert_order_full_join_22295() -> None:
    """Smoke test: a full self-join with keys in inverted order must not raise.

    The result is intentionally not inspected.
    """
    lf = pl.LazyFrame(
        {
            "value_at": [datetime(2024, i + 1, 1) for i in range(6)],
            "value": list(range(6)),
        }
    )

    lf.join(lf, on=["value", "value_at"], how="full", coalesce=True).collect(
        engine="streaming"
    )