Path: blob/main/py-polars/tests/unit/operations/test_join.py
6939 views
from __future__ import annotations12import typing3import warnings4from datetime import date, datetime5from typing import TYPE_CHECKING, Any, Callable, Literal67import numpy as np8import pandas as pd9import pytest1011import polars as pl12from polars.exceptions import (13ColumnNotFoundError,14ComputeError,15DuplicateError,16InvalidOperationError,17SchemaError,18)19from polars.testing import assert_frame_equal, assert_series_equal20from tests.unit.conftest import time_func2122if TYPE_CHECKING:23from polars._typing import JoinStrategy, PolarsDataType242526def test_semi_anti_join() -> None:27df_a = pl.DataFrame({"key": [1, 2, 3], "payload": ["f", "i", None]})2829df_b = pl.DataFrame({"key": [3, 4, 5, None]})3031assert df_a.join(df_b, on="key", how="anti").to_dict(as_series=False) == {32"key": [1, 2],33"payload": ["f", "i"],34}35assert df_a.join(df_b, on="key", how="semi").to_dict(as_series=False) == {36"key": [3],37"payload": [None],38}3940# lazy41result = df_a.lazy().join(df_b.lazy(), on="key", how="anti").collect()42expected_values = {"key": [1, 2], "payload": ["f", "i"]}43assert result.to_dict(as_series=False) == expected_values4445result = df_a.lazy().join(df_b.lazy(), on="key", how="semi").collect()46expected_values = {"key": [3], "payload": [None]}47assert result.to_dict(as_series=False) == expected_values4849df_a = pl.DataFrame(50{"a": [1, 2, 3, 1], "b": ["a", "b", "c", "a"], "payload": [10, 20, 30, 40]}51)5253df_b = pl.DataFrame({"a": [3, 3, 4, 5], "b": ["c", "c", "d", "e"]})5455assert df_a.join(df_b, on=["a", "b"], how="anti").to_dict(as_series=False) == {56"a": [1, 2, 1],57"b": ["a", "b", "a"],58"payload": [10, 20, 40],59}60assert df_a.join(df_b, on=["a", "b"], how="semi").to_dict(as_series=False) == {61"a": [3],62"b": ["c"],63"payload": [30],64}656667def test_join_same_cat_src() -> None:68df = pl.DataFrame(69data={"column": ["a", "a", "b"], "more": [1, 2, 3]},70schema=[("column", pl.Categorical), ("more", pl.Int32)],71)72df_agg = 
df.group_by("column").agg(pl.col("more").mean())73assert_frame_equal(74df.join(df_agg, on="column"),75pl.DataFrame(76{77"column": ["a", "a", "b"],78"more": [1, 2, 3],79"more_right": [1.5, 1.5, 3.0],80},81schema=[82("column", pl.Categorical),83("more", pl.Int32),84("more_right", pl.Float64),85],86),87check_row_order=False,88)899091@pytest.mark.parametrize("reverse", [False, True])92def test_sorted_merge_joins(reverse: bool) -> None:93n = 3094df_a = pl.DataFrame({"a": np.sort(np.random.randint(0, n // 2, n))}).with_row_index(95"row_a"96)97df_b = pl.DataFrame(98{"a": np.sort(np.random.randint(0, n // 2, n // 2))}99).with_row_index("row_b")100101if reverse:102df_a = df_a.select(pl.all().reverse())103df_b = df_b.select(pl.all().reverse())104105join_strategies: list[JoinStrategy] = ["left", "inner"]106for cast_to in [int, str, float]:107for how in join_strategies:108df_a_ = df_a.with_columns(pl.col("a").cast(cast_to))109df_b_ = df_b.with_columns(pl.col("a").cast(cast_to))110111# hash join112out_hash_join = df_a_.join(df_b_, on="a", how=how)113114# sorted merge join115out_sorted_merge_join = df_a_.with_columns(116pl.col("a").set_sorted(descending=reverse)117).join(118df_b_.with_columns(pl.col("a").set_sorted(descending=reverse)),119on="a",120how=how,121)122123assert_frame_equal(124out_hash_join, out_sorted_merge_join, check_row_order=False125)126127128def test_join_negative_integers() -> None:129expected = pl.DataFrame({"a": [-6, -1, 0], "b": [-6, -1, 0]})130df1 = pl.DataFrame(131{132"a": [-1, -6, -3, 0],133}134)135136df2 = pl.DataFrame(137{138"a": [-6, -1, -4, -2, 0],139"b": [-6, -1, -4, -2, 0],140}141)142143for dt in [pl.Int8, pl.Int16, pl.Int32, pl.Int64]:144assert_frame_equal(145df1.with_columns(pl.all().cast(dt)).join(146df2.with_columns(pl.all().cast(dt)), on="a", how="inner"147),148expected.select(pl.all().cast(dt)),149check_row_order=False,150)151152153def test_deprecated() -> None:154df = pl.DataFrame({"a": [1, 2], "b": [3, 4]})155other = pl.DataFrame({"a": [1, 
2], "c": [3, 4]})156result = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [3, 4]})157158np.testing.assert_equal(159df.join(other=other, on="a", maintain_order="left").to_numpy(),160result.to_numpy(),161)162np.testing.assert_equal(163df.lazy()164.join(other=other.lazy(), on="a", maintain_order="left")165.collect()166.to_numpy(),167result.to_numpy(),168)169170171def test_deprecated_parameter_join_nulls() -> None:172df = pl.DataFrame({"a": [1, None]})173with pytest.deprecated_call(174match=r"the argument `join_nulls` for `DataFrame.join` is deprecated. It was renamed to `nulls_equal`"175):176result = df.join(df, on="a", join_nulls=True) # type: ignore[call-arg]177assert_frame_equal(result, df, check_row_order=False)178179180def test_join_on_expressions() -> None:181df_a = pl.DataFrame({"a": [1, 2, 3]})182183df_b = pl.DataFrame({"b": [1, 4, 9, 9, 0]})184185assert_frame_equal(186df_a.join(df_b, left_on=(pl.col("a") ** 2).cast(int), right_on=pl.col("b")),187pl.DataFrame({"a": [1, 2, 3, 3], "b": [1, 4, 9, 9]}),188check_row_order=False,189)190191192def test_join_lazy_frame_on_expression() -> None:193# Tests a lazy frame projection pushdown bug194# https://github.com/pola-rs/polars/issues/19822195196df = pl.DataFrame(data={"a": [0, 1], "b": [2, 3]})197198lazy_join = (199df.lazy()200.join(df.lazy(), left_on=pl.coalesce("b", "a"), right_on="a")201.select("a")202.collect()203)204205eager_join = df.join(df, left_on=pl.coalesce("b", "a"), right_on="a").select("a")206207assert lazy_join.shape == eager_join.shape208209210def test_right_join_schema_maintained_22516() -> None:211df_left = pl.DataFrame({"number": [1]})212df_right = pl.DataFrame({"invoice_number": [1]})213eager_join = df_left.join(214df_right, left_on="number", right_on="invoice_number", how="right"215).select(pl.len())216217lazy_join = (218df_left.lazy()219.join(df_right.lazy(), left_on="number", right_on="invoice_number", how="right")220.select(pl.len())221.collect()222)223224assert lazy_join.item() == 
eager_join.item()225226227def test_join() -> None:228df_left = pl.DataFrame(229{230"a": ["a", "b", "a", "z"],231"b": [1, 2, 3, 4],232"c": [6, 5, 4, 3],233}234)235df_right = pl.DataFrame(236{237"a": ["b", "c", "b", "a"],238"k": [0, 3, 9, 6],239"c": [1, 0, 2, 1],240}241)242243joined = df_left.join(244df_right, left_on="a", right_on="a", maintain_order="left_right"245).sort("a")246assert_series_equal(joined["b"], pl.Series("b", [1, 3, 2, 2]))247248joined = df_left.join(249df_right, left_on="a", right_on="a", how="left", maintain_order="left_right"250).sort("a")251assert joined["c_right"].is_null().sum() == 1252assert_series_equal(joined["b"], pl.Series("b", [1, 3, 2, 2, 4]))253254joined = df_left.join(df_right, left_on="a", right_on="a", how="full").sort("a")255assert joined["c_right"].null_count() == 1256assert joined["c"].null_count() == 1257assert joined["b"].null_count() == 1258assert joined["k"].null_count() == 1259assert joined["a"].null_count() == 1260261# we need to pass in a column to join on, either by supplying `on`, or both262# `left_on` and `right_on`263with pytest.raises(ValueError):264df_left.join(df_right)265with pytest.raises(ValueError):266df_left.join(df_right, right_on="a")267with pytest.raises(ValueError):268df_left.join(df_right, left_on="a")269270df_a = pl.DataFrame({"a": [1, 2, 1, 1], "b": ["a", "b", "c", "c"]})271df_b = pl.DataFrame(272{"foo": [1, 1, 1], "bar": ["a", "c", "c"], "ham": ["let", "var", "const"]}273)274275# just check if join on multiple columns runs276df_a.join(df_b, left_on=["a", "b"], right_on=["foo", "bar"])277eager_join = df_a.join(df_b, left_on="a", right_on="foo")278lazy_join = df_a.lazy().join(df_b.lazy(), left_on="a", right_on="foo").collect()279280cols = ["a", "b", "bar", "ham"]281assert lazy_join.shape == eager_join.shape282assert_frame_equal(lazy_join.sort(by=cols), eager_join.sort(by=cols))283284285def test_joins_dispatch() -> None:286# this just flexes the dispatch a bit287288# don't change the data of this 
dataframe, this triggered:289# https://github.com/pola-rs/polars/issues/1688290dfa = pl.DataFrame(291{292"a": ["a", "b", "c", "a"],293"b": [1, 2, 3, 1],294"date": ["2021-01-01", "2021-01-02", "2021-01-03", "2021-01-01"],295"datetime": [13241324, 12341256, 12341234, 13241324],296}297).with_columns(298pl.col("date").str.strptime(pl.Date), pl.col("datetime").cast(pl.Datetime)299)300301join_strategies: list[JoinStrategy] = ["left", "inner", "full"]302for how in join_strategies:303dfa.join(dfa, on=["a", "b", "date", "datetime"], how=how)304dfa.join(dfa, on=["date", "datetime"], how=how)305dfa.join(dfa, on=["date", "datetime", "a"], how=how)306dfa.join(dfa, on=["date", "a"], how=how)307dfa.join(dfa, on=["a", "datetime"], how=how)308dfa.join(dfa, on=["date"], how=how)309310311def test_join_on_cast() -> None:312df_a = (313pl.DataFrame({"a": [-5, -2, 3, 3, 9, 10]})314.with_row_index()315.with_columns(pl.col("a").cast(pl.Int32))316)317318df_b = pl.DataFrame({"a": [-2, -3, 3, 10]})319320assert_frame_equal(321df_a.join(df_b, on=pl.col("a").cast(pl.Int64)),322pl.DataFrame(323{324"index": [1, 2, 3, 5],325"a": [-2, 3, 3, 10],326"a_right": [-2, 3, 3, 10],327}328),329check_row_order=False,330check_dtypes=False,331)332assert df_a.lazy().join(333df_b.lazy(),334on=pl.col("a").cast(pl.Int64),335maintain_order="left",336).collect().to_dict(as_series=False) == {337"index": [1, 2, 3, 5],338"a": [-2, 3, 3, 10],339"a_right": [-2, 3, 3, 10],340}341342343def test_join_chunks_alignment_4720() -> None:344# https://github.com/pola-rs/polars/issues/4720345346df1 = pl.DataFrame(347{348"index1": pl.arange(0, 2, eager=True),349"index2": pl.arange(10, 12, eager=True),350}351)352353df2 = pl.DataFrame(354{355"index3": pl.arange(100, 102, eager=True),356}357)358359df3 = pl.DataFrame(360{361"index1": pl.arange(0, 2, eager=True),362"index2": pl.arange(10, 12, eager=True),363"index3": pl.arange(100, 102, eager=True),364}365)366assert_frame_equal(367df1.join(df2, how="cross").join(368df3,369on=["index1", 
"index2", "index3"],370how="left",371),372pl.DataFrame(373{374"index1": [0, 0, 1, 1],375"index2": [10, 10, 11, 11],376"index3": [100, 101, 100, 101],377}378),379check_row_order=False,380)381382assert_frame_equal(383df1.join(df2, how="cross").join(384df3,385on=["index3", "index1", "index2"],386how="left",387),388pl.DataFrame(389{390"index1": [0, 0, 1, 1],391"index2": [10, 10, 11, 11],392"index3": [100, 101, 100, 101],393}394),395check_row_order=False,396)397398399def test_jit_sort_joins() -> None:400n = 200401# Explicitly specify numpy dtype because of different defaults on Windows402dfa = pd.DataFrame(403{404"a": np.random.randint(0, 100, n, dtype=np.int64),405"b": np.arange(0, n, dtype=np.int64),406}407)408409n = 40410dfb = pd.DataFrame(411{412"a": np.random.randint(0, 100, n, dtype=np.int64),413"b": np.arange(0, n, dtype=np.int64),414}415)416dfa_pl = pl.from_pandas(dfa).sort("a")417dfb_pl = pl.from_pandas(dfb)418419join_strategies: list[Literal["left", "inner"]] = ["left", "inner"]420for how in join_strategies:421pd_result = dfa.merge(dfb, on="a", how=how)422pd_result.columns = pd.Index(["a", "b", "b_right"])423424# left key sorted right is not425pl_result = dfa_pl.join(dfb_pl, on="a", how=how).sort(["a", "b", "b_right"])426427a = (428pl.from_pandas(pd_result)429.with_columns(pl.all().cast(int))430.sort(["a", "b", "b_right"])431)432assert_frame_equal(a, pl_result)433assert pl_result["a"].flags["SORTED_ASC"]434435# left key sorted right is not436pd_result = dfb.merge(dfa, on="a", how=how)437pd_result.columns = pd.Index(["a", "b", "b_right"])438pl_result = dfb_pl.join(dfa_pl, on="a", how=how).sort(["a", "b", "b_right"])439440a = (441pl.from_pandas(pd_result)442.with_columns(pl.all().cast(int))443.sort(["a", "b", "b_right"])444)445assert_frame_equal(a, pl_result)446assert pl_result["a"].flags["SORTED_ASC"]447448449def test_join_panic_on_binary_expr_5915() -> None:450df_a = pl.DataFrame({"a": [1, 2, 3]}).lazy()451df_b = pl.DataFrame({"b": [1, 4, 9, 9, 
0]}).lazy()452453z = df_a.join(df_b, left_on=[(pl.col("a") + 1).cast(int)], right_on=[pl.col("b")])454assert z.collect().to_dict(as_series=False) == {"a": [3], "b": [4]}455456457def test_semi_join_projection_pushdown_6423() -> None:458df1 = pl.DataFrame({"x": [1]}).lazy()459df2 = pl.DataFrame({"y": [1], "x": [1]}).lazy()460461assert (462df1.join(df2, left_on="x", right_on="y", how="semi")463.join(df2, left_on="x", right_on="y", how="semi")464.select(["x"])465).collect().to_dict(as_series=False) == {"x": [1]}466467468def test_semi_join_projection_pushdown_6455() -> None:469df = pl.DataFrame(470{471"id": [1, 1, 2],472"timestamp": [473datetime(2022, 12, 11),474datetime(2022, 12, 12),475datetime(2022, 1, 1),476],477"value": [1, 2, 4],478}479).lazy()480481latest = df.group_by("id").agg(pl.col("timestamp").max())482df = df.join(latest, on=["id", "timestamp"], how="semi")483assert df.select(["id", "value"]).collect().to_dict(as_series=False) == {484"id": [1, 2],485"value": [2, 4],486}487488489def test_update() -> None:490df1 = pl.DataFrame(491{492"key1": [1, 2, 3, 4],493"key2": [1, 2, 3, 4],494"a": [1, 2, 3, 4],495"b": [1, 2, 3, 4],496"c": ["1", "2", "3", "4"],497"d": [498date(2023, 1, 1),499date(2023, 1, 2),500date(2023, 1, 3),501date(2023, 1, 4),502],503}504)505506df2 = pl.DataFrame(507{508"key1": [1, 2, 3, 4],509"key2": [1, 2, 3, 5],510"a": [1, 1, 1, 1],511"b": [2, 2, 2, 2],512"c": ["3", "3", "3", "3"],513"d": [514date(2023, 5, 5),515date(2023, 5, 5),516date(2023, 5, 5),517date(2023, 5, 5),518],519}520)521522# update only on key1523expected = pl.DataFrame(524{525"key1": [1, 2, 3, 4],526"key2": [1, 2, 3, 5],527"a": [1, 1, 1, 1],528"b": [2, 2, 2, 2],529"c": ["3", "3", "3", "3"],530"d": [531date(2023, 5, 5),532date(2023, 5, 5),533date(2023, 5, 5),534date(2023, 5, 5),535],536}537)538assert_frame_equal(df1.update(df2, on="key1"), expected)539540# update on key1 using different left/right names541assert_frame_equal(542df1.update(543df2.rename({"key1": 
"key1b"}),544left_on="key1",545right_on="key1b",546),547expected,548)549550# update on key1 and key2. This should fail to match the last item.551expected = pl.DataFrame(552{553"key1": [1, 2, 3, 4],554"key2": [1, 2, 3, 4],555"a": [1, 1, 1, 4],556"b": [2, 2, 2, 4],557"c": ["3", "3", "3", "4"],558"d": [559date(2023, 5, 5),560date(2023, 5, 5),561date(2023, 5, 5),562date(2023, 1, 4),563],564}565)566assert_frame_equal(df1.update(df2, on=["key1", "key2"]), expected)567568# update on key1 and key2 using different left/right names569assert_frame_equal(570df1.update(571df2.rename({"key1": "key1b", "key2": "key2b"}),572left_on=["key1", "key2"],573right_on=["key1b", "key2b"],574),575expected,576)577578df = pl.DataFrame({"A": [1, 2, 3, 4], "B": [400, 500, 600, 700]})579580new_df = pl.DataFrame({"B": [4, None, 6], "C": [7, 8, 9]})581582assert df.update(new_df).to_dict(as_series=False) == {583"A": [1, 2, 3, 4],584"B": [4, 500, 6, 700],585}586df1 = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})587df2 = pl.DataFrame({"a": [2, 3], "b": [8, 9]})588589assert df1.update(df2, on="a").to_dict(as_series=False) == {590"a": [1, 2, 3],591"b": [4, 8, 9],592}593594a = pl.LazyFrame({"a": [1, 2, 3]})595b = pl.LazyFrame({"b": [4, 5], "c": [3, 1]})596c = a.update(b)597598assert_frame_equal(a, c)599600# check behaviour of 'how' param601result = a.update(b, left_on="a", right_on="c")602assert result.collect().to_series().to_list() == [1, 2, 3]603604result = a.update(b, how="inner", left_on="a", right_on="c")605assert sorted(result.collect().to_series().to_list()) == [1, 3]606607result = a.update(b.rename({"b": "a"}), how="full", on="a")608assert sorted(result.collect().to_series().sort().to_list()) == [1, 2, 3, 4, 5]609610# check behavior of include_nulls=True611df = pl.DataFrame(612{613"A": [1, 2, 3, 4],614"B": [400, 500, 600, 700],615}616)617new_df = pl.DataFrame(618{619"B": [-66, None, -99],620"C": [5, 3, 1],621}622)623out = df.update(new_df, left_on="A", right_on="C", how="full", 
include_nulls=True)624expected = pl.DataFrame(625{626"A": [1, 2, 3, 4, 5],627"B": [-99, 500, None, 700, -66],628}629)630assert_frame_equal(out, expected, check_row_order=False)631632# edge-case #11684633x = pl.DataFrame({"a": [0, 1]})634y = pl.DataFrame({"a": [2, 3]})635assert sorted(x.update(y, on="a", how="full")["a"].to_list()) == [0, 1, 2, 3]636637# disallowed join strategies638for join_strategy in ("cross", "anti", "semi"):639with pytest.raises(640ValueError,641match=f"`how` must be one of {{'left', 'inner', 'full'}}; found '{join_strategy}'",642):643a.update(b, how=join_strategy) # type: ignore[arg-type]644645646def test_join_frame_consistency() -> None:647df = pl.DataFrame({"A": [1, 2, 3]})648ldf = pl.DataFrame({"A": [1, 2, 5]}).lazy()649650with pytest.raises(TypeError, match="expected `other`.*LazyFrame"):651_ = ldf.join(df, on="A") # type: ignore[arg-type]652with pytest.raises(TypeError, match="expected `other`.*DataFrame"):653_ = df.join(ldf, on="A") # type: ignore[arg-type]654with pytest.raises(TypeError, match="expected `other`.*LazyFrame"):655_ = ldf.join_asof(df, on="A") # type: ignore[arg-type]656with pytest.raises(TypeError, match="expected `other`.*DataFrame"):657_ = df.join_asof(ldf, on="A") # type: ignore[arg-type]658659660def test_join_concat_projection_pd_case_7071() -> None:661ldf = pl.DataFrame({"id": [1, 2], "value": [100, 200]}).lazy()662ldf2 = pl.DataFrame({"id": [1, 3], "value": [100, 300]}).lazy()663664ldf = ldf.join(ldf2, on=["id", "value"])665ldf = pl.concat([ldf, ldf2])666result = ldf.select("id")667668expected = pl.DataFrame({"id": [1, 1, 3]}).lazy()669assert_frame_equal(result, expected)670671672@pytest.mark.may_fail_auto_streaming # legacy full join is not order-preserving whereas new-streaming is673def test_join_sorted_fast_paths_null() -> None:674df1 = pl.DataFrame({"x": [0, 1, 0]}).sort("x")675df2 = pl.DataFrame({"x": [0, None], "y": [0, 1]})676assert df1.join(df2, on="x", how="inner").to_dict(as_series=False) == {677"x": [0, 
0],678"y": [0, 0],679}680assert df1.join(df2, on="x", how="left").to_dict(as_series=False) == {681"x": [0, 0, 1],682"y": [0, 0, None],683}684assert df1.join(df2, on="x", how="anti").to_dict(as_series=False) == {"x": [1]}685assert df1.join(df2, on="x", how="semi").to_dict(as_series=False) == {"x": [0, 0]}686assert df1.join(df2, on="x", how="full").to_dict(as_series=False) == {687"x": [0, 0, 1, None],688"x_right": [0, 0, None, None],689"y": [0, 0, None, 1],690}691692693def test_full_outer_join_list_() -> None:694schema = {"id": pl.Int64, "vals": pl.List(pl.Float64)}695join_schema = {**schema, **{k + "_right": t for (k, t) in schema.items()}}696df1 = pl.DataFrame({"id": [1], "vals": [[]]}, schema=schema) # type: ignore[arg-type]697df2 = pl.DataFrame({"id": [2, 3], "vals": [[], [4]]}, schema=schema) # type: ignore[arg-type]698expected = pl.DataFrame(699{700"id": [None, None, 1],701"vals": [None, None, []],702"id_right": [2, 3, None],703"vals_right": [[], [4.0], None],704},705schema=join_schema, # type: ignore[arg-type]706)707out = df1.join(df2, on="id", how="full", maintain_order="right_left")708assert_frame_equal(out, expected)709710711@pytest.mark.slow712def test_join_validation() -> None:713def test_each_join_validation(714unique: pl.DataFrame, duplicate: pl.DataFrame, on: str, how: JoinStrategy715) -> None:716# one_to_many717_one_to_many_success_inner = unique.join(718duplicate, on=on, how=how, validate="1:m"719)720721with pytest.raises(ComputeError):722_one_to_many_fail_inner = duplicate.join(723unique, on=on, how=how, validate="1:m"724)725726# one to one727with pytest.raises(ComputeError):728_one_to_one_fail_1_inner = unique.join(729duplicate, on=on, how=how, validate="1:1"730)731732with pytest.raises(ComputeError):733_one_to_one_fail_2_inner = duplicate.join(734unique, on=on, how=how, validate="1:1"735)736737# many to one738with pytest.raises(ComputeError):739_many_to_one_fail_inner = unique.join(740duplicate, on=on, how=how, 
validate="m:1"741)742743_many_to_one_success_inner = duplicate.join(744unique, on=on, how=how, validate="m:1"745)746747# many to many748_many_to_many_success_1_inner = duplicate.join(749unique, on=on, how=how, validate="m:m"750)751752_many_to_many_success_2_inner = unique.join(753duplicate, on=on, how=how, validate="m:m"754)755756# test data757short_unique = pl.DataFrame(758{759"id": [1, 2, 3, 4],760"id_str": ["1", "2", "3", "4"],761"name": ["hello", "world", "rust", "polars"],762}763)764short_duplicate = pl.DataFrame(765{"id": [1, 2, 3, 1], "id_str": ["1", "2", "3", "1"], "cnt": [2, 4, 6, 1]}766)767long_unique = pl.DataFrame(768{769"id": [1, 2, 3, 4, 5],770"id_str": ["1", "2", "3", "4", "5"],771"name": ["hello", "world", "rust", "polars", "meow"],772}773)774long_duplicate = pl.DataFrame(775{776"id": [1, 2, 3, 1, 5],777"id_str": ["1", "2", "3", "1", "5"],778"cnt": [2, 4, 6, 1, 8],779}780)781782join_strategies: list[JoinStrategy] = ["inner", "full", "left"]783784for join_col in ["id", "id_str"]:785for how in join_strategies:786# same size787test_each_join_validation(long_unique, long_duplicate, join_col, how)788789# left longer790test_each_join_validation(long_unique, short_duplicate, join_col, how)791792# right longer793test_each_join_validation(short_unique, long_duplicate, join_col, how)794795796@typing.no_type_check797def test_join_validation_many_keys() -> None:798# unique in both799df1 = pl.DataFrame(800{801"val1": [11, 12, 13, 14],802"val2": [1, 2, 3, 4],803}804)805df2 = pl.DataFrame(806{807"val1": [11, 12, 13, 14],808"val2": [1, 2, 3, 4],809}810)811for join_type in ["inner", "left", "full"]:812for val in ["m:m", "m:1", "1:1", "1:m"]:813df1.join(df2, on=["val1", "val2"], how=join_type, validate=val)814815# many in lhs816df1 = pl.DataFrame(817{818"val1": [11, 11, 12, 13, 14],819"val2": [1, 1, 2, 3, 4],820}821)822823for join_type in ["inner", "left", "full"]:824for val in ["1:1", "1:m"]:825with pytest.raises(ComputeError):826df1.join(df2, on=["val1", "val2"], 
how=join_type, validate=val)827828# many in rhs829df1 = pl.DataFrame(830{831"val1": [11, 12, 13, 14],832"val2": [1, 2, 3, 4],833}834)835df2 = pl.DataFrame(836{837"val1": [11, 11, 12, 13, 14],838"val2": [1, 1, 2, 3, 4],839}840)841842for join_type in ["inner", "left", "full"]:843for val in ["m:1", "1:1"]:844with pytest.raises(ComputeError):845df1.join(df2, on=["val1", "val2"], how=join_type, validate=val)846847848def test_full_outer_join_bool() -> None:849df1 = pl.DataFrame({"id": [True, False], "val": [1, 2]})850df2 = pl.DataFrame({"id": [True, False], "val": [0, -1]})851assert df1.join(df2, on="id", how="full", maintain_order="right").to_dict(852as_series=False853) == {854"id": [True, False],855"val": [1, 2],856"id_right": [True, False],857"val_right": [0, -1],858}859860861def test_full_outer_join_coalesce_different_names_13450() -> None:862df1 = pl.DataFrame({"L1": ["a", "b", "c"], "L3": ["b", "c", "d"], "L2": [1, 2, 3]})863df2 = pl.DataFrame({"L3": ["a", "c", "d"], "R2": [7, 8, 9]})864865expected = pl.DataFrame(866{867"L1": ["a", "c", "d", "b"],868"L3": ["b", "d", None, "c"],869"L2": [1, 3, None, 2],870"R2": [7, 8, 9, None],871}872)873874out = df1.join(df2, left_on="L1", right_on="L3", how="full", coalesce=True)875assert_frame_equal(out, expected, check_row_order=False)876877878# https://github.com/pola-rs/polars/issues/10663879def test_join_on_wildcard_error() -> None:880df = pl.DataFrame({"x": [1]})881df2 = pl.DataFrame({"x": [1], "y": [2]})882with pytest.raises(883InvalidOperationError,884):885df.join(df2, on=pl.all())886887888def test_join_on_nth_error() -> None:889df = pl.DataFrame({"x": [1]})890df2 = pl.DataFrame({"x": [1], "y": [2]})891with pytest.raises(892InvalidOperationError,893):894df.join(df2, on=pl.first())895896897def test_join_results_in_duplicate_names() -> None:898df = pl.DataFrame(899{900"a": [1, 2, 3],901"b": [4, 5, 6],902"c": [1, 2, 3],903"c_right": [1, 2, 3],904}905)906907def f(x: Any) -> Any:908return x.join(x, on=["a", "b"], 
how="left")909910# Ensure it also contains the hint911match_str = "(?s)column with name 'c_right' already exists.*You may want to try"912913# Ensure it fails immediately when resolving schema.914with pytest.raises(DuplicateError, match=match_str):915f(df.lazy()).collect_schema()916917with pytest.raises(DuplicateError, match=match_str):918f(df.lazy()).collect()919920with pytest.raises(DuplicateError, match=match_str):921f(df).collect()922923924def test_join_duplicate_suffixed_columns_from_join_key_column_21048() -> None:925df = pl.DataFrame({"a": 1, "b": 1, "b_right": 1})926927def f(x: Any) -> Any:928return x.join(x, on="a")929930# Ensure it also contains the hint931match_str = "(?s)column with name 'b_right' already exists.*You may want to try"932933# Ensure it fails immediately when resolving schema.934with pytest.raises(DuplicateError, match=match_str):935f(df.lazy()).collect_schema()936937with pytest.raises(DuplicateError, match=match_str):938f(df.lazy()).collect()939940with pytest.raises(DuplicateError, match=match_str):941f(df)942943944def test_join_projection_invalid_name_contains_suffix_15243() -> None:945df1 = pl.DataFrame({"a": [1, 2, 3]}).lazy()946df2 = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).lazy()947948with pytest.raises(ColumnNotFoundError):949(950df1.join(df2, on="a")951.select(pl.col("b").filter(pl.col("b") == pl.col("foo_right")))952.collect()953)954955956def test_join_list_non_numeric() -> None:957assert (958pl.DataFrame(959{960"lists": [961["a", "b", "c"],962["a", "c", "b"],963["a", "c", "b"],964["a", "c", "d"],965]966}967)968).group_by("lists", maintain_order=True).agg(pl.len().alias("count")).to_dict(969as_series=False970) == {971"lists": [["a", "b", "c"], ["a", "c", "b"], ["a", "c", "d"]],972"count": [1, 2, 1],973}974975976@pytest.mark.slow977def test_join_4_columns_with_validity() -> None:978# join on 4 columns so we trigger combine validities979# use 138 as that is 2 u64 and a remainder980a = pl.DataFrame(981{"a": [None if a % 6 == 0 
else a for a in range(138)]}982).with_columns(983b=pl.col("a"),984c=pl.col("a"),985d=pl.col("a"),986)987988assert a.join(a, on=["a", "b", "c", "d"], how="inner", nulls_equal=True).shape == (989644,9904,991)992assert a.join(a, on=["a", "b", "c", "d"], how="inner", nulls_equal=False).shape == (993115,9944,995)996997998@pytest.mark.release999def test_cross_join() -> None:1000# triggers > 100 rows implementation1001# https://github.com/pola-rs/polars/blob/5f5acb2a523ce01bc710768b396762b8e69a9e07/polars/polars-core/src/frame/cross_join.rs#L341002df1 = pl.DataFrame({"col1": ["a"], "col2": ["d"]})1003df2 = pl.DataFrame({"frame2": pl.arange(0, 100, eager=True)})1004out = df2.join(df1, how="cross")1005df2 = pl.DataFrame({"frame2": pl.arange(0, 101, eager=True)})1006assert_frame_equal(df2.join(df1, how="cross").slice(0, 100), out)100710081009@pytest.mark.release1010def test_cross_join_slice_pushdown() -> None:1011# this will likely go out of memory if we did not pushdown the slice1012df = (1013pl.Series("x", pl.arange(0, 2**16 - 1, eager=True, dtype=pl.UInt16) % 2**15)1014).to_frame()10151016result = df.lazy().join(df.lazy(), how="cross", suffix="_").slice(-5, 10).collect()1017expected = pl.DataFrame(1018{1019"x": [32766, 32766, 32766, 32766, 32766],1020"x_": [32762, 32763, 32764, 32765, 32766],1021},1022schema={"x": pl.UInt16, "x_": pl.UInt16},1023)1024assert_frame_equal(result, expected)10251026result = df.lazy().join(df.lazy(), how="cross", suffix="_").slice(2, 10).collect()1027expected = pl.DataFrame(1028{1029"x": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],1030"x_": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11],1031},1032schema={"x": pl.UInt16, "x_": pl.UInt16},1033)1034assert_frame_equal(result, expected)103510361037@pytest.mark.parametrize("how", ["left", "inner"])1038def test_join_coalesce(how: JoinStrategy) -> None:1039a = pl.LazyFrame({"a": [1, 2], "b": [1, 2]})1040b = pl.LazyFrame(1041{1042"a": [1, 2, 1, 2],1043"b": [5, 7, 8, 9],1044"c": [1, 2, 1, 2],1045}1046)10471048how = "inner"1049q = 
a.join(b, on="a", coalesce=False, how=how)1050out = q.collect()1051assert q.collect_schema() == out.schema1052assert out.columns == ["a", "b", "a_right", "b_right", "c"]10531054q = a.join(b, on=["a", "b"], coalesce=False, how=how)1055out = q.collect()1056assert q.collect_schema() == out.schema1057assert out.columns == ["a", "b", "a_right", "b_right", "c"]10581059q = a.join(b, on=["a", "b"], coalesce=True, how=how)1060out = q.collect()1061assert q.collect_schema() == out.schema1062assert out.columns == ["a", "b", "c"]106310641065@pytest.mark.parametrize("how", ["left", "inner", "full"])1066def test_join_empties(how: JoinStrategy) -> None:1067df1 = pl.DataFrame({"col1": [], "col2": [], "col3": []})1068df2 = pl.DataFrame({"col2": [], "col4": [], "col5": []})10691070df = df1.join(df2, on="col2", how=how)1071assert df.height == 0107210731074def test_join_raise_on_redundant_keys() -> None:1075left = pl.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5], "c": [5, 6, 7]})1076right = pl.DataFrame({"a": [2, 3, 4], "c": [4, 5, 6]})1077with pytest.raises(InvalidOperationError, match="already joined on"):1078left.join(right, on=["a", "a"], how="full", coalesce=True)107910801081@pytest.mark.parametrize("coalesce", [False, True])1082def test_join_raise_on_repeated_expression_key_names(coalesce: bool) -> None:1083left = pl.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5], "c": [5, 6, 7]})1084right = pl.DataFrame({"a": [2, 3, 4], "c": [4, 5, 6]})1085with ( # noqa: PT0121086pytest.raises(InvalidOperationError, match="already joined on"),1087warnings.catch_warnings(),1088):1089warnings.simplefilter(action="ignore", category=UserWarning)1090left.join(1091right, on=[pl.col("a"), pl.col("a") % 2], how="full", coalesce=coalesce1092)109310941095def test_join_lit_panic_11410() -> None:1096df = pl.LazyFrame({"date": [1, 2, 3], "symbol": [4, 5, 6]})1097dates = df.select("date").unique(maintain_order=True)1098symbols = df.select("symbol").unique(maintain_order=True)10991100assert symbols.join(1101dates, 
left_on=pl.lit(1), right_on=pl.lit(1), maintain_order="left_right"1102).collect().to_dict(as_series=False) == {1103"symbol": [4, 4, 4, 5, 5, 5, 6, 6, 6],1104"date": [1, 2, 3, 1, 2, 3, 1, 2, 3],1105}110611071108def test_join_empty_literal_17027() -> None:1109df1 = pl.DataFrame({"a": [1]})1110df2 = pl.DataFrame(schema={"a": pl.Int64})11111112assert df1.join(df2, on=pl.lit(0), how="left").height == 11113assert df1.join(df2, on=pl.lit(0), how="inner").height == 01114assert (1115df1.lazy()1116.join(df2.lazy(), on=pl.lit(0), how="inner")1117.collect(engine="streaming")1118.height1119== 01120)1121assert (1122df1.lazy()1123.join(df2.lazy(), on=pl.lit(0), how="left")1124.collect(engine="streaming")1125.height1126== 11127)112811291130@pytest.mark.parametrize(1131("left_on", "right_on"),1132zip(1133[pl.col("a"), pl.col("a").sort(), [pl.col("a"), pl.col("b")]],1134[pl.col("a").slice(0, 2) * 2, pl.col("b"), [pl.col("a"), pl.col("b").head()]],1135),1136)1137def test_join_non_elementwise_keys_raises(left_on: pl.Expr, right_on: pl.Expr) -> None:1138# https://github.com/pola-rs/polars/issues/171841139left = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 4, 5]})1140right = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 4, 5]})11411142q = left.join(1143right,1144left_on=left_on,1145right_on=right_on,1146how="inner",1147)11481149with pytest.raises(pl.exceptions.InvalidOperationError):1150q.collect()115111521153def test_join_coalesce_not_supported_warning() -> None:1154# https://github.com/pola-rs/polars/issues/171841155left = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 4, 5]})1156right = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 4, 5]})11571158q = left.join(1159right,1160left_on=[pl.col("a") * 2],1161right_on=[pl.col("a") * 2],1162how="inner",1163coalesce=True,1164)1165with pytest.warns(UserWarning, match="turning off key coalescing"):1166got = q.collect()1167expect = pl.DataFrame(1168{"a": [1, 2, 3], "b": [3, 4, 5], "a_right": [1, 2, 3], "b_right": [3, 4, 5]}1169)11701171assert_frame_equal(expect, got, 
check_row_order=False)117211731174@pytest.mark.parametrize(1175("on_args"),1176[1177{"on": "a", "left_on": "a"},1178{"on": "a", "right_on": "a"},1179{"on": "a", "left_on": "a", "right_on": "a"},1180],1181)1182def test_join_on_and_left_right_on(on_args: dict[str, str]) -> None:1183df1 = pl.DataFrame({"a": [1], "b": [2]})1184df2 = pl.DataFrame({"a": [1], "c": [3]})1185msg = "cannot use 'on' in conjunction with 'left_on' or 'right_on'"1186with pytest.raises(ValueError, match=msg):1187df1.join(df2, **on_args) # type: ignore[arg-type]118811891190@pytest.mark.parametrize(1191("on_args"),1192[1193{"left_on": "a"},1194{"right_on": "a"},1195],1196)1197def test_join_only_left_or_right_on(on_args: dict[str, str]) -> None:1198df1 = pl.DataFrame({"a": [1]})1199df2 = pl.DataFrame({"a": [1]})1200msg = "'left_on' requires corresponding 'right_on'"1201with pytest.raises(ValueError, match=msg):1202df1.join(df2, **on_args) # type: ignore[arg-type]120312041205@pytest.mark.parametrize(1206("on_args"),1207[1208{"on": "a"},1209{"left_on": "a", "right_on": "a"},1210],1211)1212def test_cross_join_no_on_keys(on_args: dict[str, str]) -> None:1213df1 = pl.DataFrame({"a": [1, 2]})1214df2 = pl.DataFrame({"b": [3, 4]})1215msg = "cross join should not pass join keys"1216with pytest.raises(ValueError, match=msg):1217df1.join(df2, how="cross", **on_args) # type: ignore[arg-type]121812191220@pytest.mark.parametrize("set_sorted", [True, False])1221def test_left_join_slice_pushdown_19405(set_sorted: bool) -> None:1222left = pl.LazyFrame({"k": [1, 2, 3, 4, 0]})1223right = pl.LazyFrame({"k": [1, 1, 1, 1, 0]})12241225if set_sorted:1226# The data isn't actually sorted on purpose to ensure we default to a1227# hash join unless we set the sorted flag here, in case there is new1228# code in the future that automatically identifies sortedness during1229# Series construction from Python.1230left = left.set_sorted("k")1231right = right.set_sorted("k")12321233q = left.join(right, on="k", how="left", 
def test_join_preserve_order_inner() -> None:
    """Check the key-column order of inner joins for each `maintain_order` mode.

    The two-sided modes ("left_right", "right_left") are compared against the
    corresponding one-sided mode: their key column must come out in the same
    order.
    """
    left = pl.LazyFrame({"a": [None, 2, 1, 1, 5]})
    right = pl.LazyFrame({"a": [1, 1, None, 2], "b": [6, 7, 8, 9]})

    # Inner joins

    inner_left = left.join(right, on="a", how="inner", maintain_order="left").collect()
    assert inner_left.get_column("a").cast(pl.UInt32).to_list() == [2, 1, 1, 1, 1]
    # Use the two-sided mode here; repeating "left" would compare a query
    # against itself and make the assertion below vacuous.
    inner_left_right = left.join(
        right, on="a", how="inner", maintain_order="left_right"
    ).collect()
    assert inner_left.get_column("a").equals(inner_left_right.get_column("a"))

    inner_right = left.join(
        right, on="a", how="inner", maintain_order="right"
    ).collect()
    assert inner_right.get_column("a").cast(pl.UInt32).to_list() == [1, 1, 1, 1, 2]
    # Same reasoning as above for the right-first ordering.
    inner_right_left = left.join(
        right, on="a", how="inner", maintain_order="right_left"
    ).collect()
    assert inner_right.get_column("a").equals(inner_right_left.get_column("a"))
def test_join_preserve_order_full() -> None:
    """Check the key-column order of full joins for each `maintain_order` mode."""
    lhs = pl.LazyFrame({"a": [None, 2, 1, 1, 5]})
    rhs = pl.LazyFrame({"a": [1, None, 2, 6], "b": [6, 7, 8, 9]})

    def joined_values(
        order: Literal["left", "right", "left_right", "right_left"], column: str
    ) -> list[Any]:
        # Run the full join with the given ordering and return one column as
        # a Python list (cast to UInt32, as in every assertion below).
        out = lhs.join(rhs, on="a", how="full", maintain_order=order).collect()
        return out.get_column(column).cast(pl.UInt32).to_list()

    # For the one-sided orderings only the first five rows are checked.
    assert joined_values("left", "a")[:5] == [None, 2, 1, 1, 5]
    assert joined_values("right", "a")[:5] == [1, 1, None, 2, None]

    # For the two-sided orderings the complete column is checked.
    assert joined_values("left_right", "a_right") == [None, 2, 1, 1, None, None, 6]
    assert joined_values("right_left", "a") == [1, 1, None, 2, None, None, 5]
"Int16"],1433["Int64", "UInt32", "Int8"],1434["Int64", "UInt16", "Int64"],1435["Int64", "UInt8", "Int64"],14361437["Int32", "Int32", "Int16"],1438["Int32", "Int32", "Int8"],1439["Int32", "UInt16", "Int32"],1440["Int32", "UInt16", "Int16"],1441["Int32", "UInt16", "Int8"],1442["Int32", "UInt8", "Int32"],14431444["Int16", "Int16", "Int8"],1445["Int16", "UInt8", "Int16"],1446["Int16", "UInt8", "Int8"],14471448["UInt64", "UInt64", "UInt32"],1449["UInt64", "UInt64", "UInt16"],1450["UInt64", "UInt64", "UInt8"],14511452["UInt32", "UInt32", "UInt16"],1453["UInt32", "UInt32", "UInt8"],14541455["UInt16", "UInt16", "UInt8"],14561457["Float64", "Float64", "Float32"],1458],1459) # fmt: skip1460@pytest.mark.parametrize("swap", [True, False])1461def test_join_numeric_key_upcast_15338(1462dtypes: tuple[str, str, str], swap: bool1463) -> None:1464supertype, ltype, rtype = (getattr(pl, x) for x in dtypes)1465ltype, rtype = (rtype, ltype) if swap else (ltype, rtype)14661467left = pl.select(pl.Series("a", [1, 1, 3]).cast(ltype)).lazy()1468right = pl.select(pl.Series("a", [1]).cast(rtype), b=pl.lit("A")).lazy()14691470assert_frame_equal(1471left.join(right, on="a", how="left").collect(),1472pl.select(a=pl.Series([1, 1, 3]).cast(ltype), b=pl.Series(["A", "A", None])),1473check_row_order=False,1474)14751476assert_frame_equal(1477left.join(right, on="a", how="left", coalesce=False).drop("a_right").collect(),1478pl.select(a=pl.Series([1, 1, 3]).cast(ltype), b=pl.Series(["A", "A", None])),1479check_row_order=False,1480)14811482assert_frame_equal(1483left.join(right, on="a", how="full").collect(),1484pl.select(1485a=pl.Series([1, 1, 3]).cast(ltype),1486a_right=pl.Series([1, 1, None]).cast(rtype),1487b=pl.Series(["A", "A", None]),1488),1489check_row_order=False,1490)14911492assert_frame_equal(1493left.join(right, on="a", how="full", coalesce=True).collect(),1494pl.select(1495a=pl.Series([1, 1, 3]).cast(supertype),1496b=pl.Series(["A", "A", 
def test_join_numeric_key_upcast_forbid_float_int() -> None:
    """Float64 and Int128 keys are rejected by both `join` and `join_where`."""
    float_side = pl.LazyFrame({"a": [1.0, 0.0]}, schema={"a": pl.Float64})
    int_side = pl.LazyFrame({"a": [1, 2]}, schema={"a": pl.Int128})

    # Baseline: outside of a join, comparing the two dtypes succeeds even if
    # the upcast is lossy.
    compared = (
        float_side.with_columns(int_side.collect()["a"].alias("a_right"))
        .select(pl.col("a") == pl.col("a_right"))
        .collect()
    )
    assert_frame_equal(compared, pl.DataFrame({"a": [True, False]}))

    with pytest.raises(SchemaError, match="datatypes of join keys don't match"):
        float_side.join(int_side, on="a", how="left").collect()

    join_where_msg = "'join_where' cannot compare Float64 with Int128"
    predicates = [
        # Direct comparison of the two key columns.
        pl.col("a") == pl.col("a_right"),
        # The same comparison nested inside another expression.
        pl.col("a") == (pl.col("a") == pl.col("a_right")),
    ]

    # Both predicates must fail with and without query optimizations.
    for optimizations in [pl.QueryOptFlags(), pl.QueryOptFlags.none()]:
        for predicate in predicates:
            with pytest.raises(SchemaError, match=join_where_msg):
                float_side.join_where(int_side, predicate).collect(
                    optimizations=optimizations,
                )
def test_no_collapse_join_when_maintain_order_20725() -> None:
    """Filtering before or after collecting an ordered cross-join chain agrees."""
    fractions_1 = pl.LazyFrame({"Fraction_1": [0, 25, 50, 75, 100]})
    fractions_2 = pl.LazyFrame({"Fraction_2": [0, 1]})
    fractions_3 = pl.LazyFrame({"Fraction_3": [0, 1]})

    # Two chained cross joins, each requesting left-then-right ordering.
    first = fractions_1.join(fractions_2, how="cross", maintain_order="left_right")
    chained = first.join(fractions_3, how="cross", maintain_order="left_right")

    keep = pl.col("Fraction_1") == 100

    # The filter is part of the lazy plan in one case and is applied to the
    # materialized frame in the other.
    filtered_in_plan = chained.filter(keep).collect()
    filtered_after_collect = chained.collect().filter(keep)

    assert_frame_equal(filtered_in_plan, filtered_after_collect)
def test_join_right_predicate_pushdown_21142() -> None:
    """Filters applied after a right join return the expected rows."""
    lhs = pl.LazyFrame({"key": [1, 2, 4], "values": ["a", "b", "c"]})
    rhs = pl.LazyFrame({"key": [1, 2, 3], "values": ["d", "e", "f"]})

    right_joined = lhs.join(rhs, on="key", how="right")

    # Right key 3 has no match on the left, so its left `values` is null.
    unmatched = pl.select(
        pl.Series("values", [None], pl.String),
        pl.Series("key", [3], pl.Int64),
        pl.Series("values_right", ["f"], pl.String),
    )

    null_left = right_joined.filter(pl.col("values").is_null())
    assert_frame_equal(null_left.collect(), unmatched)

    # A filter on a column coming from the right table must show up after the
    # "RIGHT PLAN ON" marker in the plan text.
    null_right = right_joined.filter(pl.col("values_right").is_null())
    plan_text = null_right.explain()
    assert plan_text.index("FILTER") > plan_text.index("RIGHT PLAN ON")

    # No row has a null `values_right`; only the schema is compared.
    assert_frame_equal(null_right.collect(), unmatched.clear())
def test_select_after_join_where_20831() -> None:
    """Selecting after `join_where` matches a cross join followed by filters."""
    lhs = pl.LazyFrame(
        {
            "a": [1, 2, 3, 1, None],
            "b": [1, 2, 3, 4, 5],
            "c": [2, 3, 4, 5, 6],
        }
    )
    rhs = pl.LazyFrame(
        {
            "a": [1, 4, 3, 7, None, None, 1],
            "c": [2, 3, 4, 5, 6, 7, 8],
            "d": [6, None, 7, 8, -1, 2, 4],
        }
    )

    b_bound = pl.col("b") * 2 <= pl.col("a_right")
    a_bound = pl.col("a") < pl.col("c_right")

    via_join_where = lhs.join_where(rhs, b_bound, a_bound)
    via_cross_filter = lhs.join(rhs, how="cross").filter(b_bound).filter(a_bound)

    # Both formulations are checked against the same expected `d` values and
    # the same row count.
    expected_d = pl.Series("d", [None, None, 7, 8, 8, 8]).to_frame()
    for q in (via_join_where, via_cross_filter):
        assert_frame_equal(q.select("d").collect().sort("d"), expected_d)
        assert q.select(pl.len()).collect().item() == 6
def test_empty_join_result_with_array_15474() -> None:
    """A join with no matching keys keeps the Array column in its schema."""
    array_dtype = pl.Array(pl.Int64, 3)
    with_array = pl.DataFrame(
        {
            "x": [1, 2],
            "y": pl.Series([[1, 2, 3], [4, 5, 6]], dtype=array_dtype),
        }
    )
    # `x == 0` does not occur in `with_array`, so the join result is empty.
    no_match = pl.DataFrame({"x": [0]})

    assert_frame_equal(
        with_array.join(no_match, on="x"),
        pl.DataFrame(schema={"x": pl.Int64, "y": array_dtype}),
    )
def test_multi_leftjoin_empty_right_21701() -> None:
    """Nested left joins against empty right-hand frames keep the parent rows."""
    parents = pl.LazyFrame(
        {
            "id": [1, 30, 80],
            "parent_field1": [3, 20, 17],
        }
    )
    # Both right-hand frames have a schema but no rows.
    children = pl.LazyFrame(
        [],
        schema={"id": pl.Int32(), "parent_id": pl.Int32(), "child_field1": pl.Int32()},
    )
    subchildren = pl.LazyFrame(
        [], schema={"child_id": pl.Int32(), "subchild_field1": pl.Int32()}
    )

    children_with_sub = children.join(
        subchildren, left_on=pl.col("id"), right_on=pl.col("child_id"), how="left"
    )
    parents_joined = parents.join(
        children_with_sub,
        left_on=pl.col("id"),
        right_on=pl.col("parent_id"),
        how="left",
    )

    # Projecting back to the parent columns must reproduce the parent frame.
    projected = parents_joined.select("id", "parent_field1")
    assert_frame_equal(projected.collect(), parents.collect(), check_row_order=False)
def test_join_categorical_21815() -> None:
    """Full joins give the same frame with a Categorical as key or as payload."""
    lhs = pl.DataFrame({"x": ["a", "b", "c", "d"]}).with_columns(
        xc=pl.col("x").cast(pl.Categorical)
    )
    rhs = pl.DataFrame({"x": ["c", "d", "e", "f"]}).with_columns(
        xc=pl.col("x").cast(pl.Categorical)
    )

    # Joining on "xc" uses the categorical column as the key; joining on "x"
    # carries it along as a payload column.
    joined_on_categorical = lhs.join(rhs, on="xc", how="full")
    joined_on_string = lhs.join(rhs, on="x", how="full")

    expected = pl.DataFrame(
        {
            "x": ["a", "b", "c", "d", None, None],
            "x_right": [None, None, "c", "d", "e", "f"],
        }
    ).with_columns(
        xc=pl.col("x").cast(pl.Categorical),
        xc_right=pl.col("x_right").cast(pl.Categorical),
    )

    # Row and column order are not part of the comparison.
    for joined in (joined_on_categorical, joined_on_string):
        assert_frame_equal(
            joined, expected, check_row_order=False, check_column_order=False
        )
@pytest.mark.parametrize("dtype", [pl.Int32, pl.Float32])
def test_join_where_literals(dtype: PolarsDataType) -> None:
    """`join_where` accepts a predicate comparing a column sum with a literal."""
    lhs = pl.DataFrame({"a": pl.Series([0, 1], dtype=dtype)})
    rhs = pl.DataFrame({"b": pl.Series([1, 2], dtype=dtype)})

    sum_below_two = (pl.col("a") + pl.col("b")) < 2

    # Only the pair (a=0, b=1) has a sum below 2.
    assert_frame_equal(
        lhs.join_where(rhs, sum_below_two),
        pl.DataFrame(
            {
                "a": pl.Series([0], dtype=dtype),
                "b": pl.Series([1], dtype=dtype),
            }
        ),
    )
def _extract_plan_joins_and_filters(plan: str) -> list[str]:
    """Return the join-side and filter lines of a query plan, in plan order.

    Each line of `plan` is stripped of surrounding whitespace and kept only
    if it starts with "LEFT PLAN", "RIGHT PLAN" or "FILTER".
    """
    # `str.startswith` accepts a tuple of prefixes, so one call covers all three.
    wanted_prefixes = ("LEFT PLAN", "RIGHT PLAN", "FILTER")
    return [
        line
        for line in (raw.strip() for raw in plan.splitlines())
        if line.startswith(wanted_prefixes)
    ]
maintain_order="left_right")2189.filter(pl.col("b") <= 2)2190.filter(pl.col("c") == "a", pl.col("c_right") == "A")2191)21922193expect = pl.DataFrame({"a": [1], "b": [1], "c": ["a"], "c_right": ["A"]})21942195plan = q.explain()21962197extract = _extract_plan_joins_and_filters(plan)21982199assert extract[0] == 'LEFT PLAN ON: [col("a"), col("b")]'2200assert 'col("c")) == ("a")' in extract[1]2201assert 'col("b")) <= (2)' in extract[1]22022203assert extract[2] == 'RIGHT PLAN ON: [col("a"), col("b")]'2204assert 'col("b")) <= (2)' in extract[3]2205assert 'col("c")) == ("A")' in extract[3]22062207assert_frame_equal(q.collect(), expect)2208assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)22092210# Filter applied to the non-coalesced `_right` column of an inner-join is2211# also pushed to the left2212# input table.2213q = lhs.join(2214rhs, on=["a", "b"], how="inner", coalesce=False, maintain_order="left_right"2215).filter(pl.col("a_right") <= 2)22162217expect = pl.DataFrame(2218{2219"a": [1, 2],2220"b": [1, 2],2221"c": ["a", "b"],2222"a_right": [1, 2],2223"b_right": [1, 2],2224"c_right": ["A", "B"],2225}2226)22272228plan = q.explain()22292230extract = _extract_plan_joins_and_filters(plan)2231assert extract == [2232'LEFT PLAN ON: [col("a"), col("b")]',2233'FILTER [(col("a")) <= (2)]',2234'RIGHT PLAN ON: [col("a"), col("b")]',2235'FILTER [(col("a")) <= (2)]',2236]22372238assert_frame_equal(q.collect(), expect)2239assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)22402241# Different names in left_on and right_on2242q = lhs.join(2243rhs, left_on="a", right_on="b", how="inner", maintain_order="left_right"2244).filter(pl.col("a") <= 2)22452246expect = pl.DataFrame(2247{2248"a": [1, 2],2249"b": [1, 2],2250"c": ["a", "b"],2251"a_right": [1, 2],2252"c_right": ["A", "B"],2253}2254)22552256plan = q.explain()22572258extract = _extract_plan_joins_and_filters(plan)2259assert extract == [2260'LEFT PLAN ON: [col("a")]',2261'FILTER 
[(col("a")) <= (2)]',2262'RIGHT PLAN ON: [col("b")]',2263'FILTER [(col("b")) <= (2)]',2264]22652266assert_frame_equal(q.collect(), expect)2267assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)22682269# Different names in left_on and right_on, coalesce=False2270q = lhs.join(2271rhs,2272left_on="a",2273right_on="b",2274how="inner",2275coalesce=False,2276maintain_order="left_right",2277).filter(pl.col("a") <= 2)22782279expect = pl.DataFrame(2280{2281"a": [1, 2],2282"b": [1, 2],2283"c": ["a", "b"],2284"a_right": [1, 2],2285"b_right": [1, 2],2286"c_right": ["A", "B"],2287}2288)22892290plan = q.explain()22912292extract = _extract_plan_joins_and_filters(plan)2293assert extract == [2294'LEFT PLAN ON: [col("a")]',2295'FILTER [(col("a")) <= (2)]',2296'RIGHT PLAN ON: [col("b")]',2297'FILTER [(col("b")) <= (2)]',2298]22992300assert_frame_equal(q.collect(), expect)2301assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)23022303# left_on=col(A), right_on=lit(1). 
Filters referencing col(A) can only push2304# to the left side.2305q = lhs.join(2306rhs,2307left_on=["a", pl.lit(1)],2308right_on=[pl.lit(1), "b"],2309how="inner",2310coalesce=False,2311maintain_order="left_right",2312).filter(2313pl.col("a") == 1,2314pl.col("b") >= 1,2315pl.col("a_right") <= 1,2316pl.col("b_right") >= 0,2317)23182319expect = pl.DataFrame(2320{2321"a": [1],2322"b": [1],2323"c": ["a"],2324"a_right": [1],2325"b_right": [1],2326"c_right": ["A"],2327}2328)23292330plan = q.explain()23312332extract = _extract_plan_joins_and_filters(plan)23332334assert (2335extract[0]2336== 'LEFT PLAN ON: [col("a").cast(Int64), col("_POLARS_0").cast(Int64)]'2337)2338assert '(col("a")) == (1)' in extract[1]2339assert '(col("b")) >= (1)' in extract[1]2340assert (2341extract[2]2342== 'RIGHT PLAN ON: [col("_POLARS_1").cast(Int64), col("b").cast(Int64)]'2343)2344assert '(col("b")) >= (0)' in extract[3]2345assert 'col("a")) <= (1)' in extract[3]23462347assert_frame_equal(q.collect(), expect)2348assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)23492350# Filters don't pass if they refer to columns from both tables2351# TODO: In the optimizer we can add additional equalities into the join2352# condition itself for some cases.2353q = lhs.join(rhs, on=["a"], how="inner", maintain_order="left_right").filter(2354pl.col("b") == pl.col("b_right")2355)23562357expect = pl.DataFrame(2358{2359"a": [1, 2, 3],2360"b": [1, 2, 3],2361"c": ["a", "b", "c"],2362"b_right": [1, 2, 3],2363"c_right": ["A", "B", "C"],2364}2365)23662367plan = q.explain()23682369extract = _extract_plan_joins_and_filters(plan)2370assert extract == [2371'FILTER [(col("b")) == (col("b_right"))]',2372'LEFT PLAN ON: [col("a")]',2373'RIGHT PLAN ON: [col("a")]',2374]23752376assert_frame_equal(q.collect(), expect)2377assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)23782379# Duplicate filter removal - https://github.com/pola-rs/polars/issues/232432380q = 
(2381pl.LazyFrame({"x": [1, 2, 3]})2382.join(pl.LazyFrame({"x": [1, 2, 3]}), on="x", how="inner", coalesce=False)2383.filter(2384pl.col("x") == 2,2385pl.col("x_right") == 2,2386)2387)23882389expect = pl.DataFrame(2390[2391pl.Series("x", [2], dtype=pl.Int64),2392pl.Series("x_right", [2], dtype=pl.Int64),2393]2394)23952396plan = q.explain()23972398extract = _extract_plan_joins_and_filters(plan)23992400assert extract == [2401'LEFT PLAN ON: [col("x")]',2402'FILTER [(col("x")) == (2)]',2403'RIGHT PLAN ON: [col("x")]',2404'FILTER [(col("x")) == (2)]',2405]24062407assert_frame_equal(q.collect(), expect)2408assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)240924102411def test_join_filter_pushdown_left_join() -> None:2412lhs = pl.LazyFrame(2413{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}2414)2415rhs = pl.LazyFrame(2416{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}2417)24182419# Filter on key output column is pushed to both sides.2420q = lhs.join(rhs, on=["a", "b"], how="left", maintain_order="left_right").filter(2421pl.col("b") <= 22422)24232424expect = pl.DataFrame(2425{"a": [1, 2], "b": [1, 2], "c": ["a", "b"], "c_right": ["A", "B"]}2426)24272428plan = q.explain()24292430extract = _extract_plan_joins_and_filters(plan)2431assert extract == [2432'LEFT PLAN ON: [col("a"), col("b")]',2433'FILTER [(col("b")) <= (2)]',2434'RIGHT PLAN ON: [col("a"), col("b")]',2435'FILTER [(col("b")) <= (2)]',2436]24372438assert_frame_equal(q.collect(), expect)2439assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)24402441# Filter on key output column is pushed to both sides.2442# This tests joins on differing left/right names.2443q = lhs.join(2444rhs, left_on="a", right_on="b", how="left", maintain_order="left_right"2445).filter(pl.col("a") <= 2)24462447expect = pl.DataFrame(2448{2449"a": [1, 2],2450"b": [1, 2],2451"c": ["a", "b"],2452"a_right": [1, 2],2453"c_right": ["A", 
"B"],2454}2455)24562457plan = q.explain()24582459extract = _extract_plan_joins_and_filters(plan)2460assert extract == [2461'LEFT PLAN ON: [col("a")]',2462'FILTER [(col("a")) <= (2)]',2463'RIGHT PLAN ON: [col("b")]',2464'FILTER [(col("b")) <= (2)]',2465]24662467assert_frame_equal(q.collect(), expect)2468assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)24692470# Filters referring to columns that exist only in the left table can be pushed.2471q = lhs.join(rhs, on=["a", "b"], how="left", maintain_order="left_right").filter(2472pl.col("c") == "b"2473)24742475expect = pl.DataFrame({"a": [2], "b": [2], "c": ["b"], "c_right": ["B"]})24762477plan = q.explain()24782479extract = _extract_plan_joins_and_filters(plan)2480assert extract == [2481'LEFT PLAN ON: [col("a"), col("b")]',2482'FILTER [(col("c")) == ("b")]',2483'RIGHT PLAN ON: [col("a"), col("b")]',2484]24852486assert_frame_equal(q.collect(), expect)2487assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)24882489# Filters referring to columns that exist only in the right table cannot be2490# pushed for left-join2491q = lhs.join(rhs, on=["a", "b"], how="left", maintain_order="left_right").filter(2492# Note: `eq_missing` to block join downgrade.2493pl.col("c_right").eq_missing("B")2494)24952496expect = pl.DataFrame({"a": [2], "b": [2], "c": ["b"], "c_right": ["B"]})24972498plan = q.explain()24992500extract = _extract_plan_joins_and_filters(plan)2501assert extract == [2502'FILTER [(col("c_right")) ==v ("B")]',2503'LEFT PLAN ON: [col("a"), col("b")]',2504'RIGHT PLAN ON: [col("a"), col("b")]',2505]25062507assert_frame_equal(q.collect(), expect)2508assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)25092510# Filters referring to a non-coalesced key column originating from the right2511# table cannot be pushed.2512#2513# Note, technically it's possible to push these filters if we can guarantee that2514# they do not remove NULLs (or otherwise if we also 
apply the filter on the2515# result table). But this is not something we do at the moment.2516q = lhs.join(2517rhs, on=["a", "b"], how="left", coalesce=False, maintain_order="left_right"2518).filter(pl.col("b_right").eq_missing(2))25192520expect = pl.DataFrame(2521{2522"a": [2],2523"b": [2],2524"c": ["b"],2525"a_right": [2],2526"b_right": [2],2527"c_right": ["B"],2528}2529)25302531plan = q.explain()25322533extract = _extract_plan_joins_and_filters(plan)2534assert extract == [2535'FILTER [(col("b_right")) ==v (2)]',2536'LEFT PLAN ON: [col("a"), col("b")]',2537'RIGHT PLAN ON: [col("a"), col("b")]',2538]25392540assert_frame_equal(q.collect(), expect)2541assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)254225432544def test_join_filter_pushdown_right_join() -> None:2545lhs = pl.LazyFrame(2546{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}2547)2548rhs = pl.LazyFrame(2549{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}2550)25512552# Filter on key output column is pushed to both sides.2553q = lhs.join(rhs, on=["a", "b"], how="right", maintain_order="left_right").filter(2554pl.col("b") <= 22555)25562557expect = pl.DataFrame(2558{"c": ["a", "b"], "a": [1, 2], "b": [1, 2], "c_right": ["A", "B"]}2559)25602561plan = q.explain()25622563extract = _extract_plan_joins_and_filters(plan)2564assert extract == [2565'LEFT PLAN ON: [col("a"), col("b")]',2566'FILTER [(col("b")) <= (2)]',2567'RIGHT PLAN ON: [col("a"), col("b")]',2568'FILTER [(col("b")) <= (2)]',2569]25702571assert_frame_equal(q.collect(), expect)2572assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)25732574# Filter on key output column is pushed to both sides.2575# This tests joins on differing left/right names.2576# col(A) is coalesced into col(B) (from right), but col(B) is named as2577# col(B_right) in the output because the LHS table also has a col(B).2578q = lhs.join(2579rhs, left_on="a", right_on="b", 
how="right", maintain_order="left_right"2580).filter(pl.col("b_right") <= 2)25812582expect = pl.DataFrame(2583{2584"b": [1, 2],2585"c": ["a", "b"],2586"a": [1, 2],2587"b_right": [1, 2],2588"c_right": ["A", "B"],2589}2590)25912592plan = q.explain()25932594extract = _extract_plan_joins_and_filters(plan)2595assert extract == [2596'LEFT PLAN ON: [col("a")]',2597'FILTER [(col("a")) <= (2)]',2598'RIGHT PLAN ON: [col("b")]',2599'FILTER [(col("b")) <= (2)]',2600]26012602assert_frame_equal(q.collect(), expect)2603assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)26042605# Filters referring to columns that exist only in the right table can be pushed.2606q = lhs.join(rhs, on=["a", "b"], how="right", maintain_order="left_right").filter(2607pl.col("c_right") == "B"2608)26092610expect = pl.DataFrame({"c": ["b"], "a": [2], "b": [2], "c_right": ["B"]})26112612plan = q.explain()26132614extract = _extract_plan_joins_and_filters(plan)2615assert extract == [2616'LEFT PLAN ON: [col("a"), col("b")]',2617'RIGHT PLAN ON: [col("a"), col("b")]',2618'FILTER [(col("c")) == ("B")]',2619]26202621assert_frame_equal(q.collect(), expect)2622assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)26232624# Filters referring to columns that exist only in the left table cannot be2625# pushed for right-join2626q = lhs.join(rhs, on=["a", "b"], how="right", maintain_order="left_right").filter(2627# Note: eq_missing to block join downgrade2628pl.col("c").eq_missing("b")2629)26302631expect = pl.DataFrame({"c": ["b"], "a": [2], "b": [2], "c_right": ["B"]})26322633plan = q.explain()26342635extract = _extract_plan_joins_and_filters(plan)2636assert extract == [2637'FILTER [(col("c")) ==v ("b")]',2638'LEFT PLAN ON: [col("a"), col("b")]',2639'RIGHT PLAN ON: [col("a"), col("b")]',2640]26412642assert_frame_equal(q.collect(), expect)2643assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)26442645# Filters referring to a non-coalesced key column 
originating from the left2646# table cannot be pushed for right-join.2647q = lhs.join(2648rhs, on=["a", "b"], how="right", coalesce=False, maintain_order="left_right"2649).filter(pl.col("b").eq_missing(2))26502651expect = pl.DataFrame(2652{2653"a": [2],2654"b": [2],2655"c": ["b"],2656"a_right": [2],2657"b_right": [2],2658"c_right": ["B"],2659}2660)26612662plan = q.explain()26632664extract = _extract_plan_joins_and_filters(plan)2665assert extract == [2666'FILTER [(col("b")) ==v (2)]',2667'LEFT PLAN ON: [col("a"), col("b")]',2668'RIGHT PLAN ON: [col("a"), col("b")]',2669]26702671assert_frame_equal(q.collect(), expect)2672assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)267326742675def test_join_filter_pushdown_full_join() -> None:2676lhs = pl.LazyFrame(2677{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}2678)2679rhs = pl.LazyFrame(2680{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}2681)26822683# Full join can only push filters that refer to coalesced key columns.2684q = lhs.join(2685rhs,2686left_on="a",2687right_on="b",2688how="full",2689coalesce=True,2690maintain_order="left_right",2691).filter(pl.col("a") == 2)26922693expect = pl.DataFrame(2694{2695"a": [2],2696"b": [2],2697"c": ["b"],2698"a_right": [2],2699"c_right": ["B"],2700}2701)27022703plan = q.explain()2704extract = _extract_plan_joins_and_filters(plan)27052706assert extract == [2707'LEFT PLAN ON: [col("a")]',2708'FILTER [(col("a")) == (2)]',2709'RIGHT PLAN ON: [col("b")]',2710'FILTER [(col("b")) == (2)]',2711]27122713assert_frame_equal(q.collect(), expect)2714assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)27152716# Non-coalescing full-join cannot push any filters2717# Note: We add fill_null to bypass non-NULL filter mask detection.2718q = 
lhs.join(2719rhs,2720left_on="a",2721right_on="b",2722how="full",2723coalesce=False,2724maintain_order="left_right",2725).filter(2726pl.col("a").fill_null(0) >= 2,2727pl.col("a").fill_null(0) <= 2,2728)27292730expect = pl.DataFrame(2731{2732"a": [2],2733"b": [2],2734"c": ["b"],2735"a_right": [2],2736"b_right": [2],2737"c_right": ["B"],2738}2739)27402741plan = q.explain()2742extract = _extract_plan_joins_and_filters(plan)27432744assert extract[0].startswith("FILTER ")2745assert extract[1:] == [2746'LEFT PLAN ON: [col("a")]',2747'RIGHT PLAN ON: [col("b")]',2748]27492750assert_frame_equal(q.collect(), expect)2751assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)275227532754def test_join_filter_pushdown_semi_join() -> None:2755lhs = pl.LazyFrame(2756{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}2757)2758rhs = pl.LazyFrame(2759{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}2760)27612762q = lhs.join(2763rhs,2764left_on=["a", "b"],2765right_on=["b", pl.lit(2)],2766how="semi",2767maintain_order="left_right",2768).filter(pl.col("a") == 2, pl.col("b") == 2, pl.col("c") == "b")27692770expect = pl.DataFrame(2771{2772"a": [2],2773"b": [2],2774"c": ["b"],2775}2776)27772778plan = q.explain()2779extract = _extract_plan_joins_and_filters(plan)27802781# * filter on col(a) is pushed to both sides (renamed to col(b) in the right side)2782# * filter on col(b) is pushed only to left, as the right join key is a literal2783# * filter on col(c) is pushed only to left, as the column does not exist in2784# the right.27852786assert extract[0] == 'LEFT PLAN ON: [col("a"), col("b").cast(Int64)]'2787assert 'col("a")) == (2)' in extract[1]2788assert 'col("b")) == (2)' in extract[1]2789assert 'col("c")) == ("b")' in extract[1]27902791assert extract[2:] == [2792'RIGHT PLAN ON: [col("b"), col("_POLARS_0").cast(Int64)]',2793'FILTER [(col("b")) == (2)]',2794]27952796assert_frame_equal(q.collect(), 
expect)2797assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)279827992800def test_join_filter_pushdown_anti_join() -> None:2801lhs = pl.LazyFrame(2802{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}2803)2804rhs = pl.LazyFrame(2805{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}2806)28072808q = lhs.join(2809rhs,2810left_on=["a", "b"],2811right_on=["b", pl.lit(1)],2812how="anti",2813maintain_order="left_right",2814).filter(pl.col("a") == 2, pl.col("b") == 2, pl.col("c") == "b")28152816expect = pl.DataFrame(2817{2818"a": [2],2819"b": [2],2820"c": ["b"],2821}2822)28232824plan = q.explain()2825extract = _extract_plan_joins_and_filters(plan)28262827assert extract[0] == 'LEFT PLAN ON: [col("a"), col("b").cast(Int64)]'2828assert 'col("a")) == (2)' in extract[1]2829assert 'col("b")) == (2)' in extract[1]2830assert 'col("c")) == ("b")' in extract[1]28312832assert extract[2:] == [2833'RIGHT PLAN ON: [col("b"), col("_POLARS_0").cast(Int64)]',2834'FILTER [(col("b")) == (2)]',2835]28362837assert_frame_equal(q.collect(), expect)2838assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)283928402841def test_join_filter_pushdown_cross_join() -> None:2842lhs = pl.LazyFrame(2843{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}2844)2845rhs = pl.LazyFrame(2846{"a": [0, 0, 0, 0, 0], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}2847)28482849# Nested loop join for `!=`2850q = (2851lhs.with_row_index()2852.join(rhs, how="cross")2853.filter(2854pl.col("a") <= 4, pl.col("c_right") <= "B", pl.col("a") != pl.col("a_right")2855)2856.sort("index")2857)28582859expect = pl.DataFrame(2860[2861pl.Series("index", [0, 0, 1, 1, 2, 2, 3, 3], dtype=pl.get_index_type()),2862pl.Series("a", [1, 1, 2, 2, 3, 3, 4, 4], dtype=pl.Int64),2863pl.Series("b", [1, 1, 2, 2, 3, 3, 4, 4], dtype=pl.Int64),2864pl.Series("c", ["a", "a", "b", "b", "c", "c", "d", "d"], 
dtype=pl.String),2865pl.Series("a_right", [0, 0, 0, 0, 0, 0, 0, 0], dtype=pl.Int64),2866pl.Series("b_right", [1, 2, 1, 2, 1, 2, 1, 2], dtype=pl.Int64),2867pl.Series(2868"c_right", ["A", "B", "A", "B", "A", "B", "A", "B"], dtype=pl.String2869),2870]2871)28722873plan = q.explain()28742875assert 'NESTED LOOP JOIN ON [(col("a")) != (col("a_right"))]' in plan28762877extract = _extract_plan_joins_and_filters(plan)28782879assert extract == [2880"LEFT PLAN:",2881'FILTER [(col("a")) <= (4)]',2882"RIGHT PLAN:",2883'FILTER [(col("c")) <= ("B")]',2884]28852886assert_frame_equal(q.collect(), expect)2887assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)28882889# Conversion to inner-join for `==`2890q = lhs.join(rhs, how="cross", maintain_order="left_right").filter(2891pl.col("a") <= 4,2892pl.col("c_right") <= "B",2893pl.col("a") == (pl.col("a_right") + 1),2894)28952896expect = pl.DataFrame(2897{2898"a": [1, 1],2899"b": [1, 1],2900"c": ["a", "a"],2901"a_right": [0, 0],2902"b_right": [1, 2],2903"c_right": ["A", "B"],2904}2905)29062907plan = q.explain()29082909extract = _extract_plan_joins_and_filters(plan)29102911assert extract == [2912'LEFT PLAN ON: [col("a")]',2913'FILTER [(col("a")) <= (4)]',2914'RIGHT PLAN ON: [[(col("a")) + (1)]]',2915'FILTER [(col("c")) <= ("B")]',2916]29172918assert_frame_equal(q.collect(), expect)2919assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)292029212922def test_join_filter_pushdown_iejoin() -> None:2923lhs = pl.LazyFrame(2924{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}2925)2926rhs = pl.LazyFrame(2927{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}2928)29292930q = (2931lhs.with_row_index()2932.join_where(2933rhs,2934pl.col("a") >= 1,2935pl.col("a") == pl.col("a_right"),2936pl.col("c_right") <= "B",2937)2938.sort("index")2939)29402941expect = pl.DataFrame(2942{2943"a": [1, 2],2944"b": [1, 2],2945"c": ["a", "b"],2946"a_right": [1, 
2],2947"b_right": [1, 2],2948"c_right": ["A", "B"],2949}2950).with_row_index()29512952plan = q.explain()29532954assert "INNER JOIN" in plan29552956extract = _extract_plan_joins_and_filters(plan)29572958assert extract == [2959'LEFT PLAN ON: [col("a")]',2960'FILTER [(col("a")) >= (1)]',2961'RIGHT PLAN ON: [col("a")]',2962'FILTER [(col("c")) <= ("B")]',2963]29642965assert_frame_equal(q.collect(), expect)2966assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)29672968q = (2969lhs.with_row_index()2970.join_where(2971rhs,2972pl.col("a") >= 1,2973pl.col("a") >= pl.col("a_right"),2974pl.col("c_right") <= "B",2975)2976.sort("index")2977)29782979expect = pl.DataFrame(2980[2981pl.Series("index", [0, 1, 1, 2, 2, 3, 3, 4, 4], dtype=pl.get_index_type()),2982pl.Series("a", [1, 2, 2, 3, 3, 4, 4, 5, 5], dtype=pl.Int64),2983pl.Series("b", [1, 2, 2, 3, 3, 4, 4, None, None], dtype=pl.Int64),2984pl.Series(2985"c", ["a", "b", "b", "c", "c", "d", "d", "e", "e"], dtype=pl.String2986),2987pl.Series("a_right", [1, 2, 1, 2, 1, 2, 1, 2, 1], dtype=pl.Int64),2988pl.Series("b_right", [1, 2, 1, 2, 1, 2, 1, 2, 1], dtype=pl.Int64),2989pl.Series(2990"c_right",2991["A", "B", "A", "B", "A", "B", "A", "B", "A"],2992dtype=pl.String,2993),2994]2995)29962997plan = q.explain()29982999assert "IEJOIN" in plan30003001extract = _extract_plan_joins_and_filters(plan)30023003assert extract == [3004'LEFT PLAN ON: [col("a")]',3005'FILTER [(col("a")) >= (1)]',3006'RIGHT PLAN ON: [col("a")]',3007'FILTER [(col("c")) <= ("B")]',3008]30093010assert_frame_equal(q.collect(), expect)3011assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)301230133014def test_join_filter_pushdown_asof_join() -> None:3015lhs = pl.LazyFrame(3016{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}3017)3018rhs = pl.LazyFrame(3019{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}3020)30213022q = 
lhs.join_asof(3023rhs,3024left_on=pl.col("a").set_sorted(),3025right_on=pl.col("b").set_sorted(),3026tolerance=0,3027).filter(3028pl.col("a") >= 2,3029pl.col("b") >= 3,3030pl.col("c") >= "A",3031pl.col("c_right") >= "B",3032)30333034expect = pl.DataFrame(3035{3036"a": [3],3037"b": [3],3038"c": ["c"],3039"a_right": [3],3040"b_right": [3],3041"c_right": ["C"],3042}3043)30443045plan = q.explain()3046extract = _extract_plan_joins_and_filters(plan)30473048assert extract[:2] == [3049'FILTER [(col("c_right")) >= ("B")]',3050'LEFT PLAN ON: [col("a").set_sorted()]',3051]30523053assert 'col("b")) >= (3)' in extract[2]3054assert 'col("c")) >= ("A")' in extract[2]3055assert 'col("a")) >= (2)' in extract[2]30563057assert extract[3:] == ['RIGHT PLAN ON: [col("b").set_sorted()]']30583059assert_frame_equal(q.collect(), expect)3060assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)30613062# With "by" columns3063q = lhs.join_asof(3064rhs,3065left_on="a",3066right_on="b",3067tolerance=99,3068by_left="b",3069by_right="a",3070).filter(3071pl.col("a") >= 2,3072pl.col("b") >= 3,3073pl.col("c") >= "A",3074pl.col("c_right") >= "B",3075)30763077expect = pl.DataFrame(3078{3079"a": [3],3080"b": [3],3081"c": ["c"],3082"b_right": [3],3083"c_right": ["C"],3084}3085)30863087plan = q.explain()3088extract = _extract_plan_joins_and_filters(plan)30893090assert extract[:2] == [3091'FILTER [(col("c_right")) >= ("B")]',3092'LEFT PLAN ON: [col("a")]',3093]3094assert 'col("a")) >= (2)' in extract[2]3095assert 'col("b")) >= (3)' in extract[2]30963097assert extract[3:] == [3098'RIGHT PLAN ON: [col("b")]',3099'FILTER [(col("a")) >= (3)]',3100]31013102assert_frame_equal(q.collect(), expect)3103assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)310431053106def test_join_filter_pushdown_full_join_rewrite() -> None:3107lhs = pl.LazyFrame(3108{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}3109)3110rhs = pl.LazyFrame(3111{3112"a": 
[1, 2, 3, 4, None],3113"b": [1, 2, 3, None, 5],3114"c": ["A", "B", "C", "D", "E"],3115}3116)31173118# Downgrades to left-join3119q = lhs.join(rhs, on=["a", "b"], how="full", maintain_order="left_right").filter(3120pl.col("b") >= 33121)31223123expect = pl.DataFrame(3124[3125pl.Series("a", [3, 4], dtype=pl.Int64),3126pl.Series("b", [3, 4], dtype=pl.Int64),3127pl.Series("c", ["c", "d"], dtype=pl.String),3128pl.Series("a_right", [3, None], dtype=pl.Int64),3129pl.Series("b_right", [3, None], dtype=pl.Int64),3130pl.Series("c_right", ["C", None], dtype=pl.String),3131]3132)31333134plan = q.explain()31353136assert "FULL JOIN" not in plan3137assert plan.startswith("LEFT JOIN")31383139extract = _extract_plan_joins_and_filters(plan)31403141assert extract == [3142'LEFT PLAN ON: [col("a"), col("b")]',3143'FILTER [(col("b")) >= (3)]',3144'RIGHT PLAN ON: [col("a"), col("b")]',3145'FILTER [(col("b")) >= (3)]',3146]31473148assert_frame_equal(q.collect(), expect)3149assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)31503151# Downgrades to right-join3152q = lhs.join(3153rhs, left_on="a", right_on="b", how="full", maintain_order="left_right"3154).filter(pl.col("b_right") >= 3)31553156expect = pl.DataFrame(3157[3158pl.Series("a", [3, 5], dtype=pl.Int64),3159pl.Series("b", [3, None], dtype=pl.Int64),3160pl.Series("c", ["c", "e"], dtype=pl.String),3161pl.Series("a_right", [3, None], dtype=pl.Int64),3162pl.Series("b_right", [3, 5], dtype=pl.Int64),3163pl.Series("c_right", ["C", "E"], dtype=pl.String),3164]3165)31663167plan = q.explain()31683169assert "FULL JOIN" not in plan3170assert "RIGHT JOIN" in plan31713172extract = _extract_plan_joins_and_filters(plan)31733174assert extract == [3175'LEFT PLAN ON: [col("a")]',3176'FILTER [(col("a")) >= (3)]',3177'RIGHT PLAN ON: [col("b")]',3178'FILTER [(col("b")) >= (3)]',3179]31803181assert_frame_equal(q.collect(), expect)3182assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)31833184# Downgrades to 
right-join3185q = lhs.join(3186rhs,3187left_on="a",3188right_on="b",3189how="full",3190coalesce=True,3191maintain_order="left_right",3192).filter(3193(pl.col("a") >= 1) | pl.col("a").is_null(), # col(a) from LHS3194pl.col("a_right") >= 3, # col(a) from RHS3195(pl.col("b") >= 2) | pl.col("b").is_null(), # col(b) from LHS3196pl.col("c_right") >= "C", # col(c) from RHS3197)31983199expect = pl.DataFrame(3200[3201pl.Series("a", [3, None], dtype=pl.Int64),3202pl.Series("b", [3, None], dtype=pl.Int64),3203pl.Series("c", ["c", None], dtype=pl.String),3204pl.Series("a_right", [3, 4], dtype=pl.Int64),3205pl.Series("c_right", ["C", "D"], dtype=pl.String),3206]3207)32083209plan = q.explain()32103211assert "FULL JOIN" not in plan3212assert "RIGHT JOIN" in plan32133214extract = _extract_plan_joins_and_filters(plan)32153216assert [3217'FILTER [([(col("b")) >= (2)]) | (col("b").is_null())]',3218'LEFT PLAN ON: [col("a")]',3219'FILTER [([(col("a")) >= (1)]) | (col("a").is_null())]',3220'RIGHT PLAN ON: [col("b")]',3221]32223223assert 'col("a")) >= (3)' in extract[4]3224assert '(col("b")) >= (1)]) | (col("b").alias("a").is_null())' in extract[4]3225assert 'col("c")) >= ("C")' in extract[4]32263227assert len(extract) == 532283229assert_frame_equal(q.collect(), expect)3230assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)32313232# Downgrades to inner-join3233q = lhs.join(rhs, on=["a", "b"], how="full", maintain_order="left_right").filter(3234pl.col("b").is_not_null(), pl.col("b_right").is_not_null()3235)32363237expect = pl.DataFrame(3238[3239pl.Series("a", [1, 2, 3], dtype=pl.Int64),3240pl.Series("b", [1, 2, 3], dtype=pl.Int64),3241pl.Series("c", ["a", "b", "c"], dtype=pl.String),3242pl.Series("a_right", [1, 2, 3], dtype=pl.Int64),3243pl.Series("b_right", [1, 2, 3], dtype=pl.Int64),3244pl.Series("c_right", ["A", "B", "C"], dtype=pl.String),3245]3246)32473248plan = q.explain()32493250assert "FULL JOIN" not in plan3251assert plan.startswith("INNER 
JOIN")32523253extract = _extract_plan_joins_and_filters(plan)32543255assert extract[0] == 'LEFT PLAN ON: [col("a"), col("b")]'3256assert 'col("b").is_not_null()' in extract[1]3257assert 'col("b").alias("b_right").is_not_null()' in extract[1]32583259assert extract[2] == 'RIGHT PLAN ON: [col("a"), col("b")]'3260assert 'col("b").is_not_null()' in extract[3]3261assert 'col("b").alias("b_right").is_not_null()' in extract[3]32623263assert len(extract) == 432643265assert_frame_equal(q.collect(), expect)3266assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)32673268# Does not downgrade because col(b) is a coalesced key-column, but the filter3269# is still pushed to both sides.3270q = lhs.join(3271rhs, on=["a", "b"], how="full", coalesce=True, maintain_order="left_right"3272).filter(pl.col("b") >= 3)32733274expect = pl.DataFrame(3275[3276pl.Series("a", [3, 4, None], dtype=pl.Int64),3277pl.Series("b", [3, 4, 5], dtype=pl.Int64),3278pl.Series("c", ["c", "d", None], dtype=pl.String),3279pl.Series("c_right", ["C", None, "E"], dtype=pl.String),3280]3281)32823283plan = q.explain()3284assert plan.startswith("FULL JOIN")32853286extract = _extract_plan_joins_and_filters(plan)32873288assert extract == [3289'LEFT PLAN ON: [col("a"), col("b")]',3290'FILTER [(col("b")) >= (3)]',3291'RIGHT PLAN ON: [col("a"), col("b")]',3292'FILTER [(col("b")) >= (3)]',3293]32943295assert_frame_equal(q.collect(), expect)3296assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)329732983299def test_join_filter_pushdown_right_join_rewrite() -> None:3300lhs = pl.LazyFrame(3301{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}3302)3303rhs = pl.LazyFrame(3304{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}3305)33063307# Downgrades to inner-join3308q = lhs.join(3309rhs,3310left_on="a",3311right_on="b",3312how="right",3313coalesce=True,3314maintain_order="left_right",3315).filter(3316pl.col("a") <= 7, # 
col(a) from RHS (LHS col(a) is coalesced into col(b_right))3317pl.col("b_right") <= 10, # Key-column filter3318pl.col("c") <= "b", # col(c) from LHS3319)33203321expect = pl.DataFrame(3322[3323pl.Series("b", [1, 2], dtype=pl.Int64),3324pl.Series("c", ["a", "b"], dtype=pl.String),3325pl.Series("a", [1, 2], dtype=pl.Int64),3326pl.Series("b_right", [1, 2], dtype=pl.Int64),3327pl.Series("c_right", ["A", "B"], dtype=pl.String),3328]3329)33303331plan = q.explain()33323333assert "RIGHT JOIN" not in plan3334assert "INNER JOIN" in plan33353336extract = _extract_plan_joins_and_filters(plan)33373338assert extract[0] == 'LEFT PLAN ON: [col("a")]'3339assert 'col("a")) <= (10)' in extract[1]3340assert 'col("c")) <= ("b")' in extract[1]33413342assert extract[2] == 'RIGHT PLAN ON: [col("b")]'3343assert 'col("a")) <= (7)' in extract[3]3344assert 'col("b")) <= (10)' in extract[3]33453346assert len(extract) == 433473348assert_frame_equal(q.collect(), expect)3349assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)335033513352def test_join_filter_pushdown_join_rewrite_equality_above_and() -> None:3353lhs = pl.LazyFrame(3354{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}3355)3356rhs = pl.LazyFrame(3357{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}3358)33593360q = lhs.join(3361rhs,3362left_on="a",3363right_on="b",3364how="full",3365coalesce=False,3366maintain_order="left_right",3367).filter(((pl.col("b") >= 3) & False) >= False)33683369expect = pl.DataFrame(3370[3371pl.Series("a", [1, 2, 3, 4, 5, None], dtype=pl.Int64),3372pl.Series("b", [1, 2, 3, 4, None, None], dtype=pl.Int64),3373pl.Series("c", ["a", "b", "c", "d", "e", None], dtype=pl.String),3374pl.Series("a_right", [1, 2, 3, None, 5, 4], dtype=pl.Int64),3375pl.Series("b_right", [1, 2, 3, None, 5, None], dtype=pl.Int64),3376pl.Series("c_right", ["A", "B", "C", None, "E", "D"], dtype=pl.String),3377]3378)33793380assert_frame_equal(q.collect(), 
expect)3381assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)338233833384def test_join_filter_pushdown_left_join_rewrite() -> None:3385lhs = pl.LazyFrame(3386{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}3387)3388rhs = pl.LazyFrame(3389{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", None, "D", "E"]}3390)33913392# Downgrades to inner-join3393q = lhs.join(3394rhs,3395left_on="a",3396right_on="b",3397how="left",3398coalesce=True,3399maintain_order="left_right",3400).filter(pl.col("c_right") <= "B")34013402expect = pl.DataFrame(3403[3404pl.Series("a", [1, 2], dtype=pl.Int64),3405pl.Series("b", [1, 2], dtype=pl.Int64),3406pl.Series("c", ["a", "b"], dtype=pl.String),3407pl.Series("a_right", [1, 2], dtype=pl.Int64),3408pl.Series("c_right", ["A", "B"], dtype=pl.String),3409]3410)34113412plan = q.explain()34133414assert "LEFT JOIN" not in plan3415assert plan.startswith("INNER JOIN")34163417extract = _extract_plan_joins_and_filters(plan)34183419assert extract == [3420'LEFT PLAN ON: [col("a")]',3421'RIGHT PLAN ON: [col("b")]',3422'FILTER [(col("c")) <= ("B")]',3423]34243425assert_frame_equal(q.collect(), expect)3426assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)342734283429def test_join_filter_pushdown_left_join_rewrite_23133() -> None:3430lhs = pl.LazyFrame(3431{3432"foo": [1, 2, 3],3433"bar": [6.0, 7.0, 8.0],3434"ham": ["a", "b", "c"],3435}3436)34373438rhs = pl.LazyFrame(3439{3440"apple": ["x", "y", "z"],3441"ham": ["a", "b", "d"],3442"bar": ["a", "b", "c"],3443"foo2": [1, 2, 3],3444}3445)34463447q = lhs.join(rhs, how="left", on="ham", maintain_order="left_right").filter(3448pl.col("ham") == "a", pl.col("apple") == "x", pl.col("foo") <= 23449)34503451expect = pl.DataFrame(3452[3453pl.Series("foo", [1], dtype=pl.Int64),3454pl.Series("bar", [6.0], dtype=pl.Float64),3455pl.Series("ham", ["a"], dtype=pl.String),3456pl.Series("apple", ["x"], 
dtype=pl.String),3457pl.Series("bar_right", ["a"], dtype=pl.String),3458pl.Series("foo2", [1], dtype=pl.Int64),3459]3460)34613462plan = q.explain()3463assert "FULL JOIN" not in plan3464assert plan.startswith("INNER JOIN")34653466extract = _extract_plan_joins_and_filters(plan)34673468assert extract[0] == 'LEFT PLAN ON: [col("ham")]'3469assert '(col("foo")) <= (2)' in extract[1]3470assert 'col("ham")) == ("a")' in extract[1]34713472assert extract[2] == 'RIGHT PLAN ON: [col("ham")]'3473assert 'col("ham")) == ("a")' in extract[3]3474assert 'col("apple")) == ("x")' in extract[3]34753476assert len(extract) == 434773478assert_frame_equal(q.collect(), expect)3479assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)348034813482def test_join_rewrite_panic_23307() -> None:3483lhs = pl.select(a=pl.lit(1, dtype=pl.Int8)).lazy()3484rhs = pl.select(a=pl.lit(1, dtype=pl.Int16), x=pl.lit(1, dtype=pl.Int32)).lazy()34853486q = lhs.join(rhs, on="a", how="left", coalesce=True).filter(pl.col("x") >= 1)34873488assert_frame_equal(3489q.collect(),3490pl.select(3491a=pl.lit(1, dtype=pl.Int8),3492x=pl.lit(1, dtype=pl.Int32),3493),3494)34953496lhs = pl.select(a=pl.lit(999, dtype=pl.Int16)).lazy()34973498# Note: -25 matches to (999).overflowing_cast(Int8).3499# This is specially chosen to test that we don't accidentally push the filter3500# to the RHS.3501rhs = pl.LazyFrame(3502{"a": [1, -25], "x": [1, 2]}, schema={"a": pl.Int8, "x": pl.Int32}3503)35043505q = lhs.join(3506rhs,3507on=pl.col("a").cast(pl.Int8, strict=False, wrap_numerical=True),3508how="left",3509coalesce=False,3510).filter(pl.col("a") >= 0)35113512expect = pl.DataFrame(3513{"a": 999, "a_right": -25, "x": 2},3514schema={"a": pl.Int16, "a_right": pl.Int8, "x": pl.Int32},3515)35163517plan = q.explain()35183519assert not plan.startswith("FILTER")35203521assert_frame_equal(q.collect(), expect)3522assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), 
expect)352335243525@pytest.mark.parametrize(3526("expr_first_input", "expr_func"),3527[3528(pl.lit(None, dtype=pl.Int64), lambda col: col >= 1),3529(pl.lit(None, dtype=pl.Int64), lambda col: (col >= 1).is_not_null()),3530(pl.lit(None, dtype=pl.Int64), lambda col: (~(col >= 1)).is_not_null()),3531(pl.lit(None, dtype=pl.Int64), lambda col: ~(col >= 1).is_null()),3532#3533(pl.lit(None, dtype=pl.Int64), lambda col: col.is_in([1])),3534(pl.lit(None, dtype=pl.Int64), lambda col: ~col.is_in([1])),3535#3536(pl.lit(None, dtype=pl.Int64), lambda col: col.is_between(1, 1)),3537(1, lambda col: col.is_between(None, 1)),3538(1, lambda col: col.is_between(1, None)),3539#3540(pl.lit(None, dtype=pl.Int64), lambda col: col.is_close(1)),3541(1, lambda col: col.is_close(pl.lit(None, dtype=pl.Int64))),3542#3543(pl.lit(None, dtype=pl.Int64), lambda col: col.is_nan()),3544(pl.lit(None, dtype=pl.Int64), lambda col: col.is_not_nan()),3545(pl.lit(None, dtype=pl.Int64), lambda col: col.is_finite()),3546(pl.lit(None, dtype=pl.Int64), lambda col: col.is_infinite()),3547#3548(pl.lit(None, dtype=pl.Float64), lambda col: col.is_nan()),3549(pl.lit(None, dtype=pl.Float64), lambda col: col.is_not_nan()),3550(pl.lit(None, dtype=pl.Float64), lambda col: col.is_finite()),3551(pl.lit(None, dtype=pl.Float64), lambda col: col.is_infinite()),3552],3553)3554def test_join_rewrite_null_preserving_exprs(3555expr_first_input: Any, expr_func: Callable[[pl.Expr], pl.Expr]3556) -> None:3557lhs = pl.LazyFrame({"a": 1})3558rhs = pl.select(a=1, x=expr_first_input).lazy()35593560assert (3561pl.select(expr_first_input)3562.select(expr_func(pl.first()))3563.select(pl.first().is_null() | ~pl.first())3564.to_series()3565.item()3566)35673568q = lhs.join(rhs, on="a", how="left", maintain_order="left_right").filter(3569expr_func(pl.col("x"))3570)35713572plan = q.explain()3573assert plan.startswith("INNER JOIN")35743575out = q.collect()35763577assert out.height == 03578assert_frame_equal(out, 
q.collect(optimizations=pl.QueryOptFlags.none()))357935803581@pytest.mark.parametrize(3582("expr_first_input", "expr_func"),3583[3584(3585pl.lit(None, dtype=pl.Int64),3586lambda x: ~(x.is_in([1, None], nulls_equal=True)),3587),3588(3589pl.lit(None, dtype=pl.Int64),3590lambda x: x.is_in([1, None], nulls_equal=True) > True,3591),3592(3593pl.lit(None, dtype=pl.Int64),3594lambda x: x.is_in([1], nulls_equal=True),3595),3596],3597)3598def test_join_rewrite_forbid_exprs(3599expr_first_input: Any, expr_func: Callable[[pl.Expr], pl.Expr]3600) -> None:3601lhs = pl.LazyFrame({"a": 1})3602rhs = pl.select(a=1, x=expr_first_input).lazy()36033604q = lhs.join(rhs, on="a", how="left", maintain_order="left_right").filter(3605expr_func(pl.col("x"))3606)36073608plan = q.explain()3609assert plan.startswith("FILTER")36103611assert_frame_equal(q.collect(), q.collect(optimizations=pl.QueryOptFlags.none()))361236133614def test_join_filter_pushdown_iejoin_cse_23469() -> None:3615lf_x = pl.LazyFrame({"x": [1, 2, 3]})3616lf_y = pl.LazyFrame({"y": [1, 2, 3]})36173618lf_xy = lf_x.join(lf_y, how="cross").filter(pl.col("x") > pl.col("y"))36193620q = pl.concat([lf_xy, lf_xy])36213622assert_frame_equal(3623q.collect().sort(pl.all()),3624pl.DataFrame(3625{3626"x": [2, 2, 3, 3, 3, 3],3627"y": [1, 1, 1, 1, 2, 2],3628},3629),3630)36313632q = pl.concat([lf_xy, lf_xy]).filter(pl.col("x") > pl.col("y"))36333634assert_frame_equal(3635q.collect().sort(pl.all()),3636pl.DataFrame(3637{3638"x": [2, 2, 3, 3, 3, 3],3639"y": [1, 1, 1, 1, 2, 2],3640},3641),3642)36433644q = (3645lf_x.join_where(lf_y, pl.col("x") == pl.col("y"))3646.cache()3647.filter(pl.col("x") >= 0)3648)36493650assert_frame_equal(3651q.collect().sort(pl.all()), pl.DataFrame({"x": [1, 2, 3], "y": [1, 2, 3]})3652)365336543655def test_join_cast_type_coercion_23236() -> None:3656lhs = pl.LazyFrame([{"name": "a"}]).rename({"name": "newname"})3657rhs = pl.LazyFrame([{"name": "a"}])36583659q = lhs.join(rhs, left_on=pl.col("newname").cast(pl.String), 
right_on="name")36603661assert_frame_equal(q.collect(), pl.DataFrame({"newname": "a", "name": "a"}))366236633664@pytest.mark.parametrize(3665("how", "expected"),3666[3667(3668"inner",3669pl.DataFrame(schema={"a": pl.Int128, "a_right": pl.Int128}),3670),3671(3672"left",3673pl.DataFrame(3674{"a": [1, 1, 2], "a_right": None},3675schema={"a": pl.Int128, "a_right": pl.Int128},3676),3677),3678(3679"right",3680pl.DataFrame(3681{3682"a": None,3683"a_right": [3684-9223372036854775808,3685-9223372036854775807,3686-9223372036854775806,3687],3688},3689schema={"a": pl.Int128, "a_right": pl.Int128},3690),3691),3692(3693"full",3694pl.DataFrame(3695[3696pl.Series("a", [None, None, None, 1, 1, 2], dtype=pl.Int128),3697pl.Series(3698"a_right",3699[3700-9223372036854775808,3701-9223372036854775807,3702-9223372036854775806,3703None,3704None,3705None,3706],3707dtype=pl.Int128,3708),3709]3710),3711),3712(3713"semi",3714pl.DataFrame([pl.Series("a", [], dtype=pl.Int128)]),3715),3716(3717"anti",3718pl.DataFrame([pl.Series("a", [1, 1, 2], dtype=pl.Int128)]),3719),3720],3721)3722@pytest.mark.parametrize(3723("sort_left", "sort_right"),3724[(True, True), (True, False), (False, True), (False, False)],3725)3726def test_join_i128_23688(3727how: str, expected: pl.DataFrame, sort_left: bool, sort_right: bool3728) -> None:3729lhs = pl.LazyFrame({"a": pl.Series([1, 1, 2], dtype=pl.Int128)})37303731rhs = pl.LazyFrame(3732{3733"a": pl.Series(3734[3735-9223372036854775808,3736-9223372036854775807,3737-9223372036854775806,3738],3739dtype=pl.Int128,3740)3741}3742)37433744lhs = lhs.collect().sort("a").lazy() if sort_left else lhs3745rhs = rhs.collect().sort("a").lazy() if sort_right else rhs37463747q = lhs.join(rhs, on="a", how=how, coalesce=False) # type: ignore[arg-type]37483749assert_frame_equal(3750q.collect().sort(pl.all()),3751expected,3752)37533754q = (3755lhs.with_columns(b=pl.col("a"))3756.join(3757rhs.with_columns(b=pl.col("a")),3758on=["a", "b"],3759how=how, # type: 
ignore[arg-type]3760coalesce=False,3761)3762.select(expected.columns)3763)37643765assert_frame_equal(3766q.collect().sort(pl.all()),3767expected,3768)376937703771def test_join_asof_by_i128() -> None:3772lhs = pl.LazyFrame({"a": pl.Series([1, 1, 2], dtype=pl.Int128), "i": 1})37733774rhs = pl.LazyFrame(3775{3776"a": pl.Series(3777[3778-9223372036854775808,3779-9223372036854775807,3780-9223372036854775806,3781],3782dtype=pl.Int128,3783),3784"i": 1,3785}3786).with_columns(b=pl.col("a"))37873788q = lhs.join_asof(rhs, on="i", by="a")37893790assert_frame_equal(3791q.collect().sort(pl.all()),3792pl.DataFrame(3793{"a": [1, 1, 2], "i": 1, "b": None},3794schema={"a": pl.Int128, "i": pl.Int32, "b": pl.Int128},3795),3796)379737983799