# Path: blob/main/py-polars/tests/unit/streaming/test_streaming_join.py
# (8412 views)
from __future__ import annotations12import typing3from datetime import datetime, timedelta4from typing import TYPE_CHECKING, Any, Literal56import hypothesis.strategies as st7import numpy as np8import pandas as pd9import pytest10from hypothesis import given, settings1112import polars as pl13from polars._typing import AsofJoinStrategy14from polars.datatypes.group import (15FLOAT_DTYPES,16INTEGER_DTYPES,17)18from polars.testing import assert_frame_equal, assert_series_equal19from polars.testing.parametric.strategies.core import dataframes2021if TYPE_CHECKING:22from pathlib import Path2324from polars._typing import AsofJoinStrategy, JoinStrategy, MaintainOrderJoin2526pytestmark = pytest.mark.xdist_group("streaming")272829def test_streaming_full_outer_joins() -> None:30n = 10031dfa = pl.DataFrame(32{33"a": np.random.randint(0, 40, n),34"idx": np.arange(0, n),35}36)3738n = 10039dfb = pl.DataFrame(40{41"a": np.random.randint(0, 40, n),42"idx": np.arange(0, n),43}44)4546join_strategies: list[tuple[JoinStrategy, bool]] = [47("full", False),48("full", True),49]50for how, coalesce in join_strategies:51q = (52dfa.lazy()53.join(dfb.lazy(), on="a", how=how, coalesce=coalesce)54.sort(["idx"])55)56a = q.collect(engine="streaming")57b = q.collect(engine="in-memory")58assert_frame_equal(a, b, check_row_order=False)596061def test_streaming_joins() -> None:62n = 10063dfa = pd.DataFrame(64{65"a": np.random.randint(0, 40, n),66"b": np.arange(0, n),67}68)6970n = 10071dfb = pd.DataFrame(72{73"a": np.random.randint(0, 40, n),74"b": np.arange(0, n),75}76)77dfa_pl = pl.from_pandas(dfa).sort("a")78dfb_pl = pl.from_pandas(dfb)7980join_strategies: list[Literal["inner", "left"]] = ["inner", "left"]81for how in join_strategies:82pd_result = dfa.merge(dfb, on="a", how=how)83pd_result.columns = pd.Index(["a", "b", "b_right"])8485pl_result = (86dfa_pl.lazy()87.join(dfb_pl.lazy(), on="a", how=how)88.sort(["a", "b", "b_right"])89.collect(engine="streaming")90)9192a = 
(93pl.from_pandas(pd_result)94.with_columns(pl.all().cast(int))95.sort(["a", "b", "b_right"])96)97assert_frame_equal(a, pl_result, check_dtypes=False)9899pd_result = dfa.merge(dfb, on=["a", "b"], how=how)100101pl_result = (102dfa_pl.lazy()103.join(dfb_pl.lazy(), on=["a", "b"], how=how)104.sort(["a", "b"])105.collect(engine="streaming")106)107108# we cast to integer because pandas joins creates floats109a = pl.from_pandas(pd_result).with_columns(pl.all().cast(int)).sort(["a", "b"])110assert_frame_equal(a, pl_result, check_dtypes=False)111112113def test_streaming_cross_join_empty() -> None:114df1 = pl.LazyFrame(data={"col1": ["a"]})115116df2 = pl.LazyFrame(117data={"col1": []},118schema={"col1": str},119)120121out = df1.join(df2, how="cross").collect(engine="streaming")122assert out.shape == (0, 2)123assert out.columns == ["col1", "col1_right"]124125126def test_streaming_join_rechunk_12498() -> None:127rows = pl.int_range(0, 2)128129a = pl.select(A=rows).lazy()130b = pl.select(B=rows).lazy()131132q = a.join(b, how="cross")133assert q.collect(engine="streaming").sort(["B", "A"]).to_dict(as_series=False) == {134"A": [0, 1, 0, 1],135"B": [0, 0, 1, 1],136}137138139@pytest.mark.parametrize("maintain_order", [False, True])140def test_join_null_matches(maintain_order: bool) -> None:141# null values in joins should never find a match.142df_a = pl.LazyFrame(143{144"idx_a": [0, 1, 2],145"a": [None, 1, 2],146}147)148149df_b = pl.LazyFrame(150{151"idx_b": [0, 1, 2, 3],152"a": [None, 2, 1, None],153}154)155# Semi156assert_series_equal(157df_a.join(158df_b,159on="a",160how="semi",161nulls_equal=True,162maintain_order="left" if maintain_order else "none",163).collect()["idx_a"],164pl.Series("idx_a", [0, 1, 2]),165check_order=maintain_order,166)167assert_series_equal(168df_a.join(169df_b,170on="a",171how="semi",172nulls_equal=False,173maintain_order="left" if maintain_order else "none",174).collect()["idx_a"],175pl.Series("idx_a", [1, 2]),176check_order=maintain_order,177)178179# 
Inner180expected = pl.DataFrame({"idx_a": [2, 1], "a": [2, 1], "idx_b": [1, 2]})181assert_frame_equal(182df_a.join(183df_b,184on="a",185how="inner",186maintain_order="right" if maintain_order else "none",187).collect(),188expected,189check_row_order=maintain_order,190)191192# Left outer193expected = pl.DataFrame(194{"idx_a": [0, 1, 2], "a": [None, 1, 2], "idx_b": [None, 2, 1]}195)196assert_frame_equal(197df_a.join(198df_b,199on="a",200how="left",201maintain_order="left" if maintain_order else "none",202).collect(),203expected,204check_row_order=maintain_order,205)206# Full outer207expected = pl.DataFrame(208{209"idx_a": [None, 2, 1, None, 0],210"a": [None, 2, 1, None, None],211"idx_b": [0, 1, 2, 3, None],212"a_right": [None, 2, 1, None, None],213}214)215assert_frame_equal(216df_a.join(217df_b,218on="a",219how="full",220maintain_order="right" if maintain_order else "none",221).collect(),222expected,223check_row_order=maintain_order,224)225226227@pytest.mark.parametrize("streaming", [False, True])228def test_join_null_matches_multiple_keys(streaming: bool) -> None:229df_a = pl.LazyFrame(230{231"a": [None, 1, 2],232"idx": [0, 1, 2],233}234)235236df_b = pl.LazyFrame(237{238"a": [None, 2, 1, None, 1],239"idx": [0, 1, 2, 3, 1],240"c": [10, 20, 30, 40, 50],241}242)243244expected = pl.DataFrame({"a": [1], "idx": [1], "c": [50]})245assert_frame_equal(246df_a.join(df_b, on=["a", "idx"], how="inner").collect(247engine="streaming" if streaming else "in-memory"248),249expected,250check_row_order=False,251)252expected = pl.DataFrame(253{"a": [None, 1, 2], "idx": [0, 1, 2], "c": [None, 50, None]}254)255assert_frame_equal(256df_a.join(df_b, on=["a", "idx"], how="left").collect(257engine="streaming" if streaming else "in-memory"258),259expected,260check_row_order=False,261)262263expected = pl.DataFrame(264{265"a": [None, None, None, None, None, 1, 2],266"idx": [None, None, None, None, 0, 1, 2],267"a_right": [None, 2, 1, None, None, 1, None],268"idx_right": [0, 1, 2, 3, None, 1, 
None],269"c": [10, 20, 30, 40, None, 50, None],270}271)272assert_frame_equal(273df_a.join(df_b, on=["a", "idx"], how="full").sort("a").collect(),274expected,275check_row_order=False,276)277278279def test_streaming_join_and_union() -> None:280a = pl.LazyFrame({"a": [1, 2]})281282b = pl.LazyFrame({"a": [1, 2, 4, 8]})283284c = a.join(b, on="a", maintain_order="left_right")285# The join node latest ensures that the dispatcher286# needs to replace placeholders in unions.287q = pl.concat([a, b, c])288289out = q.collect(engine="streaming")290assert_frame_equal(out, q.collect(engine="in-memory"))291assert out.to_series().to_list() == [1, 2, 1, 2, 4, 8, 1, 2]292293294def test_non_coalescing_streaming_left_join() -> None:295df1 = pl.LazyFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]})296297df2 = pl.LazyFrame({"a": [1, 2], "c": ["j", "i"]})298299q = df1.join(df2, on="a", how="left", coalesce=False)300assert_frame_equal(301q.collect(engine="streaming"),302pl.DataFrame(303{304"a": [1, 2, 3],305"b": ["a", "b", "c"],306"a_right": [1, 2, None],307"c": ["j", "i", None],308}309),310check_row_order=False,311)312313314@pytest.mark.write_disk315def test_streaming_outer_join_partial_flush(tmp_path: Path) -> None:316data = {317"value_at": [datetime(2024, i + 1, 1) for i in range(6)],318"value": list(range(6)),319}320321parquet_path = tmp_path / "data.parquet"322pl.DataFrame(data=data).write_parquet(parquet_path)323324other_parquet_path = tmp_path / "data2.parquet"325pl.DataFrame(data=data).write_parquet(other_parquet_path)326327lf1 = pl.scan_parquet(other_parquet_path)328lf2 = pl.scan_parquet(parquet_path)329330join_cols = set(lf1.collect_schema()).intersection(set(lf2.collect_schema()))331final_lf = lf1.join(lf2, on=list(join_cols), how="full", coalesce=True)332333assert_frame_equal(334final_lf.collect(engine="streaming"),335pl.DataFrame(336{337"value_at": [338datetime(2024, 1, 1, 0, 0),339datetime(2024, 2, 1, 0, 0),340datetime(2024, 3, 1, 0, 0),341datetime(2024, 4, 1, 0, 
0),342datetime(2024, 5, 1, 0, 0),343datetime(2024, 6, 1, 0, 0),344],345"value": [0, 1, 2, 3, 4, 5],346}347),348check_row_order=False,349)350351352def test_flush_join_and_operation_19040() -> None:353df_A = pl.LazyFrame({"K": [True, False], "A": [1, 1]})354355df_B = pl.LazyFrame({"K": [True], "B": [1]})356357df_C = pl.LazyFrame({"K": [True], "C": [1]})358359q = (360df_A.join(df_B, how="full", on=["K"], coalesce=True)361.join(df_C, how="full", on=["K"], coalesce=True)362.with_columns(B=pl.col("B"))363.sort("K")364)365assert q.collect(engine="streaming").to_dict(as_series=False) == {366"K": [False, True],367"A": [1, 1],368"B": [None, 1],369"C": [None, 1],370}371372373def test_full_coalesce_join_and_rename_15583() -> None:374df1 = pl.LazyFrame({"a": [1, 2, 3]})375df2 = pl.LazyFrame({"a": [3, 4, 5]})376377result = (378df1.join(df2, on="a", how="full", coalesce=True)379.select(pl.all().name.map(lambda c: c.upper()))380.sort("A")381.collect(engine="streaming")382)383assert result["A"].to_list() == [1, 2, 3, 4, 5]384385386def test_invert_order_full_join_22295() -> None:387lf = pl.LazyFrame(388{389"value_at": [datetime(2024, i + 1, 1) for i in range(6)],390"value": list(range(6)),391}392)393lf.join(lf, on=["value", "value_at"], how="full", coalesce=True).collect(394engine="streaming"395)396397398def test_cross_join_with_literal_column_25544() -> None:399df0 = pl.LazyFrame({"c0": [0]})400df1 = pl.LazyFrame({"c0": [1]})401402result = df0.join(403df1.select(pl.col("c0")).with_columns(pl.lit(1)),404on=True, # type: ignore[arg-type]405).select("c0")406407in_memory_result = result.collect(engine="in-memory")408streaming_result = result.collect(engine="streaming")409410assert_frame_equal(streaming_result, in_memory_result)411assert streaming_result.item() == 0412413414@pytest.mark.parametrize("on", [["key"], ["key", "key_ext"]])415@pytest.mark.parametrize("how", ["inner", "left", "right", "full"])416@pytest.mark.parametrize("descending", [False, 
True])417@pytest.mark.parametrize("nulls_last", [False, True])418@pytest.mark.parametrize("nulls_equal", [False, True])419@pytest.mark.parametrize("coalesce", [None, True, False])420@pytest.mark.parametrize("maintain_order", ["none", "left_right", "right_left"])421@given(data=st.data())422@settings(max_examples=10)423def test_merge_join(424on: list[str],425how: JoinStrategy,426descending: bool,427nulls_last: bool,428nulls_equal: bool,429coalesce: bool | None,430maintain_order: MaintainOrderJoin,431data: st.DataObject,432) -> None:433check_row_order = maintain_order in {"left_right", "right_left"}434435df_st = dataframes(min_cols=len(on), max_cols=len(on), allowed_dtypes=[pl.Int16])436left_df = data.draw(df_st)437right_df = data.draw(df_st)438439left = left_df.rename(dict(zip(left_df.columns, ["key", "key_ext"], strict=False)))440right = right_df.rename(441dict(zip(right_df.columns, ["key", "key_ext"], strict=False))442)443444def df_sorted(df: pl.DataFrame) -> pl.LazyFrame:445return (446df.lazy()447.sort(448*on,449descending=descending,450nulls_last=nulls_last,451maintain_order=True,452multithreaded=False,453)454.set_sorted(on, descending=descending, nulls_last=nulls_last)455)456457q = df_sorted(left).join(458df_sorted(right),459on=on,460how=how,461nulls_equal=nulls_equal,462coalesce=coalesce,463maintain_order=maintain_order,464)465dot = q.show_graph(engine="streaming", plan_stage="physical", raw_output=True)466expected = q.collect(engine="in-memory")467actual = q.collect(engine="streaming")468469assert "merge-join" in typing.cast("str", dot), "merge-join not used in plan"470assert_frame_equal(actual, expected, check_row_order=check_row_order)471472473@pytest.mark.parametrize(474("keys", "dtype"),475[476([False, True, False], pl.Boolean),477([1, 3, 2], pl.Int8),478([1, 3, 2], pl.Int16),479([1, 3, 2], pl.Int32),480([1, 3, 2], pl.Int64),481([1, 3, 2], pl.Int128),482([1, 3, 2], pl.UInt8),483([1, 3, 2], pl.UInt16),484([1, 3, 2], pl.UInt32),485([1, 3, 2], 
pl.UInt64),486([1, 3, 2], pl.UInt128),487([1.0, 3.0, 2.0], pl.Float16),488([1.0, 3.0, 2.0], pl.Float32),489([1.0, 3.0, 2.0], pl.Float64),490(["a", "b", "c"], pl.String),491([b"a", b"b", b"c"], pl.Binary),492([datetime(2024, 1, x) for x in [1, 3, 2]], pl.Date),493([datetime(2024, 1, x, 12, 0) for x in [1, 3, 2]], pl.Time),494([datetime(2024, 1, x, 12, 0) for x in [1, 3, 2]], pl.Datetime),495([timedelta(days=x) for x in [1, 3, 2]], pl.Duration),496([1, 3, 2], pl.Decimal),497([pl.Null, pl.Null, pl.Null], pl.Null),498(["a", "c", "b"], pl.Enum(["a", "b", "c"])),499(["a", "c", "b"], pl.Categorical),500],501)502@pytest.mark.parametrize("how", ["inner", "left", "right", "full"])503@pytest.mark.parametrize("nulls_equal", [False, True])504def test_join_dtypes(505keys: list[Any], dtype: pl.DataType, how: JoinStrategy, nulls_equal: bool506) -> None:507df_left = pl.DataFrame({"key": pl.Series("key", keys[:2], dtype=dtype)})508df_right = pl.DataFrame({"key": pl.Series("key", keys[2:], dtype=dtype)})509510def df_sorted(df: pl.DataFrame) -> pl.LazyFrame:511return (512df.lazy()513.sort(514"key",515maintain_order=True,516multithreaded=False,517)518.set_sorted("key")519)520521q_hashjoin = df_left.lazy().join(522df_right.lazy(),523on="key",524how=how,525nulls_equal=nulls_equal,526maintain_order="none",527)528dot = q_hashjoin.show_graph(529engine="streaming", plan_stage="physical", raw_output=True530)531expected = q_hashjoin.collect(engine="in-memory")532actual = q_hashjoin.collect(engine="streaming")533assert "equi-join" in typing.cast("str", dot), "hash-join not used in plan"534assert_frame_equal(actual, expected, check_row_order=False)535536q_mergejoin = df_sorted(df_left).join(537df_sorted(df_right),538on="key",539how=how,540nulls_equal=nulls_equal,541maintain_order="none",542)543dot = q_mergejoin.show_graph(544engine="streaming", plan_stage="physical", raw_output=True545)546expected = q_mergejoin.collect(engine="in-memory")547actual = 
q_mergejoin.collect(engine="streaming")548assert "merge-join" in typing.cast("str", dot), "merge-join not used in plan"549assert_frame_equal(actual, expected, check_row_order=False)550551552def test_merge_join_exprs() -> None:553left = pl.LazyFrame(554{555"key": ["", "a", "c"],556"key_ext": [1, 2, 3],557"value": [1, 2, 3],558}559).set_sorted("key", "key_ext")560right = pl.LazyFrame(561{562"key": ["", "a", "b"],563"key_ext": [3, 2, 3],564"value": [4, 5, 6],565}566).set_sorted("key", "key_ext")567568q = left.join(569right,570left_on="key",571right_on=pl.concat_str(pl.col("key"), ignore_nulls=False),572how="full",573maintain_order="none",574)575dot = q.show_graph(engine="streaming", plan_stage="physical", raw_output=True)576assert "merge-join" in typing.cast("str", dot), "merge-join not used in plan"577assert_frame_equal(q.collect(engine="streaming"), q.collect(engine="in-memory"))578579580@pytest.mark.parametrize("left_descending", [False, True])581@pytest.mark.parametrize("right_descending", [False, True])582@pytest.mark.parametrize("left_nulls_last", [False, True])583@pytest.mark.parametrize("right_nulls_last", [False, True])584def test_merge_join_applicable(585left_descending: bool,586right_descending: bool,587left_nulls_last: bool,588right_nulls_last: bool,589) -> None:590left = pl.LazyFrame({"key": [1]}).set_sorted(591"key", descending=left_descending, nulls_last=left_nulls_last592)593right = pl.LazyFrame({"key": [2]}).set_sorted(594"key", descending=right_descending, nulls_last=right_nulls_last595)596q = left.join(right, on="key", how="full", maintain_order="left_right")597dot = q.show_graph(engine="streaming", plan_stage="physical", raw_output=True)598if (left_descending, left_nulls_last) == (right_descending, right_nulls_last):599assert "merge-join" in typing.cast("str", dot)600else:601assert "merge-join" not in typing.cast("str", dot)602assert_frame_equal(q.collect(engine="streaming"), 
q.collect(engine="in-memory"))603604605@pytest.mark.parametrize("strategy", ["backward", "forward", "nearest"])606@pytest.mark.parametrize("allow_exact_matches", [False, True])607@pytest.mark.parametrize("coalesce", [False, True])608@pytest.mark.parametrize(609"dtypes",610[611FLOAT_DTYPES,612INTEGER_DTYPES,613{pl.String, pl.Binary},614{pl.Date},615{616pl.Datetime("ms"),617pl.Datetime("us"),618pl.Datetime("ns"),619},620{621pl.Datetime("ms", time_zone="Europe/Amsterdam"),622pl.Datetime("us", time_zone="Europe/Amsterdam"),623pl.Datetime("ns", time_zone="Europe/Amsterdam"),624},625{pl.Time},626{pl.Duration("ms"), pl.Duration("us"), pl.Duration("ns")},627],628)629@given(data=st.data())630def test_streaming_asof_join(631data: st.DataObject,632strategy: AsofJoinStrategy,633allow_exact_matches: bool,634coalesce: bool,635dtypes: set[pl.DataType],636) -> None:637if dtypes & {pl.String, pl.Binary} and strategy == "nearest":638pytest.skip("asof join with string/binary does not support 'nearest' strategy")639640dtype = data.draw(st.sampled_from(list(dtypes)))641df_st = dataframes(642min_cols=1, max_cols=1, allowed_dtypes=[dtype], allow_time_zones=False643)644left_df = data.draw(df_st)645right_df = data.draw(df_st)646647left = left_df.rename(lambda _: "key").sort("key").with_row_index().lazy()648right = right_df.rename(lambda _: "key").sort("key").with_row_index().lazy()649650q = left.join_asof(651right,652on="key",653strategy=strategy,654allow_exact_matches=allow_exact_matches,655coalesce=coalesce,656)657expected = q.collect(engine="in-memory")658actual = q.collect(engine="streaming")659assert_frame_equal(actual, expected)660661662