Path: blob/main/py-polars/tests/unit/operations/test_join.py
8424 views
from __future__ import annotations12import typing3import warnings4from datetime import date, datetime5from typing import TYPE_CHECKING, Any, Literal67import numpy as np8import pandas as pd9import pytest1011import polars as pl12from polars.exceptions import (13ColumnNotFoundError,14ComputeError,15DuplicateError,16InvalidOperationError,17SchemaError,18)19from polars.testing import assert_frame_equal, assert_series_equal20from tests.unit.conftest import time_func2122if TYPE_CHECKING:23from collections.abc import Callable2425from polars._typing import JoinStrategy, PolarsDataType262728def test_semi_anti_join() -> None:29df_a = pl.DataFrame({"key": [1, 2, 3], "payload": ["f", "i", None]})3031df_b = pl.DataFrame({"key": [3, 4, 5, None]})3233assert df_a.join(df_b, on="key", how="anti").to_dict(as_series=False) == {34"key": [1, 2],35"payload": ["f", "i"],36}37assert df_a.join(df_b, on="key", how="semi").to_dict(as_series=False) == {38"key": [3],39"payload": [None],40}4142# lazy43result = df_a.lazy().join(df_b.lazy(), on="key", how="anti").collect()44expected_values = {"key": [1, 2], "payload": ["f", "i"]}45assert result.to_dict(as_series=False) == expected_values4647result = df_a.lazy().join(df_b.lazy(), on="key", how="semi").collect()48expected_values = {"key": [3], "payload": [None]}49assert result.to_dict(as_series=False) == expected_values5051df_a = pl.DataFrame(52{"a": [1, 2, 3, 1], "b": ["a", "b", "c", "a"], "payload": [10, 20, 30, 40]}53)5455df_b = pl.DataFrame({"a": [3, 3, 4, 5], "b": ["c", "c", "d", "e"]})5657assert df_a.join(df_b, on=["a", "b"], how="anti").to_dict(as_series=False) == {58"a": [1, 2, 1],59"b": ["a", "b", "a"],60"payload": [10, 20, 40],61}62assert df_a.join(df_b, on=["a", "b"], how="semi").to_dict(as_series=False) == {63"a": [3],64"b": ["c"],65"payload": [30],66}676869def test_join_same_cat_src() -> None:70df = pl.DataFrame(71data={"column": ["a", "a", "b"], "more": [1, 2, 3]},72schema=[("column", pl.Categorical), ("more", pl.Int32)],73)74df_agg = df.group_by("column").agg(pl.col("more").mean())75assert_frame_equal(76df.join(df_agg, on="column"),77pl.DataFrame(78{79"column": ["a", "a", "b"],80"more": [1, 2, 3],81"more_right": [1.5, 1.5, 3.0],82},83schema=[84("column", pl.Categorical),85("more", pl.Int32),86("more_right", pl.Float64),87],88),89check_row_order=False,90)919293@pytest.mark.parametrize("reverse", [False, True])94def test_sorted_merge_joins(reverse: bool) -> None:95n = 3096df_a = pl.DataFrame({"a": np.sort(np.random.randint(0, n // 2, n))}).with_row_index(97"row_a"98)99df_b = pl.DataFrame(100{"a": np.sort(np.random.randint(0, n // 2, n // 2))}101).with_row_index("row_b")102103if reverse:104df_a = df_a.select(pl.all().reverse())105df_b = df_b.select(pl.all().reverse())106107join_strategies: list[JoinStrategy] = ["left", "inner"]108for cast_to in [int, str, float]:109for how in join_strategies:110df_a_ = df_a.with_columns(pl.col("a").cast(cast_to))111df_b_ = df_b.with_columns(pl.col("a").cast(cast_to))112113# hash join114out_hash_join = df_a_.join(df_b_, on="a", how=how)115116# sorted merge join117out_sorted_merge_join = df_a_.with_columns(118pl.col("a").set_sorted(descending=reverse)119).join(120df_b_.with_columns(pl.col("a").set_sorted(descending=reverse)),121on="a",122how=how,123)124125assert_frame_equal(126out_hash_join, out_sorted_merge_join, check_row_order=False127)128129130def test_join_negative_integers() -> None:131expected = pl.DataFrame({"a": [-6, -1, 0], "b": [-6, -1, 0]})132df1 = pl.DataFrame(133{134"a": [-1, -6, -3, 0],135}136)137138df2 = pl.DataFrame(139{140"a": [-6, -1, -4, -2, 0],141"b": [-6, -1, -4, -2, 0],142}143)144145for dt in [pl.Int8, pl.Int16, pl.Int32, pl.Int64]:146assert_frame_equal(147df1.with_columns(pl.all().cast(dt)).join(148df2.with_columns(pl.all().cast(dt)), on="a", how="inner"149),150expected.select(pl.all().cast(dt)),151check_row_order=False,152)153154155def test_deprecated() -> None:156df = pl.DataFrame({"a": [1, 2], "b": [3, 4]})157other = pl.DataFrame({"a": [1, 2], "c": [3, 4]})158result = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [3, 4]})159160np.testing.assert_equal(161df.join(other=other, on="a", maintain_order="left").to_numpy(),162result.to_numpy(),163)164np.testing.assert_equal(165df.lazy()166.join(other=other.lazy(), on="a", maintain_order="left")167.collect()168.to_numpy(),169result.to_numpy(),170)171172173def test_deprecated_parameter_join_nulls() -> None:174df = pl.DataFrame({"a": [1, None]})175with pytest.deprecated_call(176match=r"the argument `join_nulls` for `DataFrame.join` is deprecated. It was renamed to `nulls_equal`"177):178result = df.join(df, on="a", join_nulls=True) # type: ignore[call-arg]179assert_frame_equal(result, df, check_row_order=False)180181182def test_join_on_expressions() -> None:183df_a = pl.DataFrame({"a": [1, 2, 3]})184185df_b = pl.DataFrame({"b": [1, 4, 9, 9, 0]})186187assert_frame_equal(188df_a.join(df_b, left_on=(pl.col("a") ** 2).cast(int), right_on=pl.col("b")),189pl.DataFrame({"a": [1, 2, 3, 3], "b": [1, 4, 9, 9]}),190check_row_order=False,191)192193194def test_join_lazy_frame_on_expression() -> None:195# Tests a lazy frame projection pushdown bug196# https://github.com/pola-rs/polars/issues/19822197198df = pl.DataFrame(data={"a": [0, 1], "b": [2, 3]})199200lazy_join = (201df.lazy()202.join(df.lazy(), left_on=pl.coalesce("b", "a"), right_on="a")203.select("a")204.collect()205)206207eager_join = df.join(df, left_on=pl.coalesce("b", "a"), right_on="a").select("a")208209assert lazy_join.shape == eager_join.shape210211212def test_right_join_schema_maintained_22516() -> None:213df_left = pl.DataFrame({"number": [1]})214df_right = pl.DataFrame({"invoice_number": [1]})215eager_join = df_left.join(216df_right, left_on="number", right_on="invoice_number", how="right"217).select(pl.len())218219lazy_join = (220df_left.lazy()221.join(df_right.lazy(), left_on="number", right_on="invoice_number", how="right")222.select(pl.len())223.collect()224)225226assert lazy_join.item() == eager_join.item()227228229def test_join() -> None:230df_left = pl.DataFrame(231{232"a": ["a", "b", "a", "z"],233"b": [1, 2, 3, 4],234"c": [6, 5, 4, 3],235}236)237df_right = pl.DataFrame(238{239"a": ["b", "c", "b", "a"],240"k": [0, 3, 9, 6],241"c": [1, 0, 2, 1],242}243)244245joined = df_left.join(246df_right, left_on="a", right_on="a", maintain_order="left_right"247).sort("a")248assert_series_equal(joined["b"], pl.Series("b", [1, 3, 2, 2]))249250joined = df_left.join(251df_right, left_on="a", right_on="a", how="left", maintain_order="left_right"252).sort("a")253assert joined["c_right"].is_null().sum() == 1254assert_series_equal(joined["b"], pl.Series("b", [1, 3, 2, 2, 4]))255256joined = df_left.join(df_right, left_on="a", right_on="a", how="full").sort("a")257assert joined["c_right"].null_count() == 1258assert joined["c"].null_count() == 1259assert joined["b"].null_count() == 1260assert joined["k"].null_count() == 1261assert joined["a"].null_count() == 1262263# we need to pass in a column to join on, either by supplying `on`, or both264# `left_on` and `right_on`265with pytest.raises(ValueError):266df_left.join(df_right)267with pytest.raises(ValueError):268df_left.join(df_right, right_on="a")269with pytest.raises(ValueError):270df_left.join(df_right, left_on="a")271272df_a = pl.DataFrame({"a": [1, 2, 1, 1], "b": ["a", "b", "c", "c"]})273df_b = pl.DataFrame(274{"foo": [1, 1, 1], "bar": ["a", "c", "c"], "ham": ["let", "var", "const"]}275)276277# just check if join on multiple columns runs278df_a.join(df_b, left_on=["a", "b"], right_on=["foo", "bar"])279eager_join = df_a.join(df_b, left_on="a", right_on="foo")280lazy_join = df_a.lazy().join(df_b.lazy(), left_on="a", right_on="foo").collect()281282cols = ["a", "b", "bar", "ham"]283assert lazy_join.shape == eager_join.shape284assert_frame_equal(lazy_join.sort(by=cols), eager_join.sort(by=cols))285286287def test_joins_dispatch() -> None:288# this just flexes the dispatch a bit289290# don't change the data of this dataframe, this triggered:291# https://github.com/pola-rs/polars/issues/1688292dfa = pl.DataFrame(293{294"a": ["a", "b", "c", "a"],295"b": [1, 2, 3, 1],296"date": ["2021-01-01", "2021-01-02", "2021-01-03", "2021-01-01"],297"datetime": [13241324, 12341256, 12341234, 13241324],298}299).with_columns(300pl.col("date").str.strptime(pl.Date), pl.col("datetime").cast(pl.Datetime)301)302303join_strategies: list[JoinStrategy] = ["left", "inner", "full"]304for how in join_strategies:305dfa.join(dfa, on=["a", "b", "date", "datetime"], how=how)306dfa.join(dfa, on=["date", "datetime"], how=how)307dfa.join(dfa, on=["date", "datetime", "a"], how=how)308dfa.join(dfa, on=["date", "a"], how=how)309dfa.join(dfa, on=["a", "datetime"], how=how)310dfa.join(dfa, on=["date"], how=how)311312313def test_join_on_cast() -> None:314df_a = (315pl.DataFrame({"a": [-5, -2, 3, 3, 9, 10]})316.with_row_index()317.with_columns(pl.col("a").cast(pl.Int32))318)319320df_b = pl.DataFrame({"a": [-2, -3, 3, 10]})321322assert_frame_equal(323df_a.join(df_b, on=pl.col("a").cast(pl.Int64)),324pl.DataFrame(325{326"index": [1, 2, 3, 5],327"a": [-2, 3, 3, 10],328"a_right": [-2, 3, 3, 10],329}330),331check_row_order=False,332check_dtypes=False,333)334assert df_a.lazy().join(335df_b.lazy(),336on=pl.col("a").cast(pl.Int64),337maintain_order="left",338).collect().to_dict(as_series=False) == {339"index": [1, 2, 3, 5],340"a": [-2, 3, 3, 10],341"a_right": [-2, 3, 3, 10],342}343344345def test_join_chunks_alignment_4720() -> None:346# https://github.com/pola-rs/polars/issues/4720347348df1 = pl.DataFrame(349{350"index1": pl.arange(0, 2, eager=True),351"index2": pl.arange(10, 12, eager=True),352}353)354355df2 = pl.DataFrame(356{357"index3": pl.arange(100, 102, eager=True),358}359)360361df3 = pl.DataFrame(362{363"index1": pl.arange(0, 2, eager=True),364"index2": pl.arange(10, 12, eager=True),365"index3": pl.arange(100, 102, eager=True),366}367)368assert_frame_equal(369df1.join(df2, how="cross").join(370df3,371on=["index1", "index2", "index3"],372how="left",373),374pl.DataFrame(375{376"index1": [0, 0, 1, 1],377"index2": [10, 10, 11, 11],378"index3": [100, 101, 100, 101],379}380),381check_row_order=False,382)383384assert_frame_equal(385df1.join(df2, how="cross").join(386df3,387on=["index3", "index1", "index2"],388how="left",389),390pl.DataFrame(391{392"index1": [0, 0, 1, 1],393"index2": [10, 10, 11, 11],394"index3": [100, 101, 100, 101],395}396),397check_row_order=False,398)399400401def test_jit_sort_joins() -> None:402n = 200403# Explicitly specify numpy dtype because of different defaults on Windows404dfa = pd.DataFrame(405{406"a": np.random.randint(0, 100, n, dtype=np.int64),407"b": np.arange(0, n, dtype=np.int64),408}409)410411n = 40412dfb = pd.DataFrame(413{414"a": np.random.randint(0, 100, n, dtype=np.int64),415"b": np.arange(0, n, dtype=np.int64),416}417)418dfa_pl = pl.from_pandas(dfa).sort("a")419dfb_pl = pl.from_pandas(dfb)420421join_strategies: list[Literal["left", "inner"]] = ["left", "inner"]422for how in join_strategies:423pd_result = dfa.merge(dfb, on="a", how=how)424pd_result.columns = pd.Index(["a", "b", "b_right"])425426# left key sorted right is not427pl_result = dfa_pl.join(dfb_pl, on="a", how=how).sort(["a", "b", "b_right"])428429a = (430pl.from_pandas(pd_result)431.with_columns(pl.all().cast(int))432.sort(["a", "b", "b_right"])433)434assert_frame_equal(a, pl_result)435assert pl_result["a"].flags["SORTED_ASC"]436437# left key sorted right is not438pd_result = dfb.merge(dfa, on="a", how=how)439pd_result.columns = pd.Index(["a", "b", "b_right"])440pl_result = dfb_pl.join(dfa_pl, on="a", how=how).sort(["a", "b", "b_right"])441442a = (443pl.from_pandas(pd_result)444.with_columns(pl.all().cast(int))445.sort(["a", "b", "b_right"])446)447assert_frame_equal(a, pl_result)448assert pl_result["a"].flags["SORTED_ASC"]449450451def test_join_panic_on_binary_expr_5915() -> None:452df_a = pl.DataFrame({"a": [1, 2, 3]}).lazy()453df_b = pl.DataFrame({"b": [1, 4, 9, 9, 0]}).lazy()454455z = df_a.join(df_b, left_on=[(pl.col("a") + 1).cast(int)], right_on=[pl.col("b")])456assert z.collect().to_dict(as_series=False) == {"a": [3], "b": [4]}457458459def test_semi_join_projection_pushdown_6423() -> None:460df1 = pl.DataFrame({"x": [1]}).lazy()461df2 = pl.DataFrame({"y": [1], "x": [1]}).lazy()462463assert (464df1.join(df2, left_on="x", right_on="y", how="semi")465.join(df2, left_on="x", right_on="y", how="semi")466.select(["x"])467).collect().to_dict(as_series=False) == {"x": [1]}468469470def test_semi_join_projection_pushdown_6455() -> None:471df = pl.DataFrame(472{473"id": [1, 1, 2],474"timestamp": [475datetime(2022, 12, 11),476datetime(2022, 12, 12),477datetime(2022, 1, 1),478],479"value": [1, 2, 4],480}481).lazy()482483latest = df.group_by("id").agg(pl.col("timestamp").max())484df = df.join(latest, on=["id", "timestamp"], how="semi")485assert df.select(["id", "value"]).collect().to_dict(as_series=False) == {486"id": [1, 2],487"value": [2, 4],488}489490491def test_update() -> None:492df1 = pl.DataFrame(493{494"key1": [1, 2, 3, 4],495"key2": [1, 2, 3, 4],496"a": [1, 2, 3, 4],497"b": [1, 2, 3, 4],498"c": ["1", "2", "3", "4"],499"d": [500date(2023, 1, 1),501date(2023, 1, 2),502date(2023, 1, 3),503date(2023, 1, 4),504],505}506)507508df2 = pl.DataFrame(509{510"key1": [1, 2, 3, 4],511"key2": [1, 2, 3, 5],512"a": [1, 1, 1, 1],513"b": [2, 2, 2, 2],514"c": ["3", "3", "3", "3"],515"d": [516date(2023, 5, 5),517date(2023, 5, 5),518date(2023, 5, 5),519date(2023, 5, 5),520],521}522)523524# update only on key1525expected = pl.DataFrame(526{527"key1": [1, 2, 3, 4],528"key2": [1, 2, 3, 5],529"a": [1, 1, 1, 1],530"b": [2, 2, 2, 2],531"c": ["3", "3", "3", "3"],532"d": [533date(2023, 5, 5),534date(2023, 5, 5),535date(2023, 5, 5),536date(2023, 5, 5),537],538}539)540assert_frame_equal(df1.update(df2, on="key1"), expected)541542# update on key1 using different left/right names543assert_frame_equal(544df1.update(545df2.rename({"key1": "key1b"}),546left_on="key1",547right_on="key1b",548),549expected,550)551552# update on key1 and key2. This should fail to match the last item.553expected = pl.DataFrame(554{555"key1": [1, 2, 3, 4],556"key2": [1, 2, 3, 4],557"a": [1, 1, 1, 4],558"b": [2, 2, 2, 4],559"c": ["3", "3", "3", "4"],560"d": [561date(2023, 5, 5),562date(2023, 5, 5),563date(2023, 5, 5),564date(2023, 1, 4),565],566}567)568assert_frame_equal(df1.update(df2, on=["key1", "key2"]), expected)569570# update on key1 and key2 using different left/right names571assert_frame_equal(572df1.update(573df2.rename({"key1": "key1b", "key2": "key2b"}),574left_on=["key1", "key2"],575right_on=["key1b", "key2b"],576),577expected,578)579580df = pl.DataFrame({"A": [1, 2, 3, 4], "B": [400, 500, 600, 700]})581582new_df = pl.DataFrame({"B": [4, None, 6], "C": [7, 8, 9]})583584assert df.update(new_df).to_dict(as_series=False) == {585"A": [1, 2, 3, 4],586"B": [4, 500, 6, 700],587}588df1 = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})589df2 = pl.DataFrame({"a": [2, 3], "b": [8, 9]})590591assert df1.update(df2, on="a").to_dict(as_series=False) == {592"a": [1, 2, 3],593"b": [4, 8, 9],594}595596a = pl.LazyFrame({"a": [1, 2, 3]})597b = pl.LazyFrame({"b": [4, 5], "c": [3, 1]})598c = a.update(b)599600assert_frame_equal(a, c)601602# check behaviour of 'how' param603result = a.update(b, left_on="a", right_on="c")604assert result.collect().to_series().to_list() == [1, 2, 3]605606result = a.update(b, how="inner", left_on="a", right_on="c")607assert sorted(result.collect().to_series().to_list()) == [1, 3]608609result = a.update(b.rename({"b": "a"}), how="full", on="a")610assert sorted(result.collect().to_series().sort().to_list()) == [1, 2, 3, 4, 5]611612# check behavior of include_nulls=True613df = pl.DataFrame(614{615"A": [1, 2, 3, 4],616"B": [400, 500, 600, 700],617}618)619new_df = pl.DataFrame(620{621"B": [-66, None, -99],622"C": [5, 3, 1],623}624)625out = df.update(new_df, left_on="A", right_on="C", how="full", include_nulls=True)626expected = pl.DataFrame(627{628"A": [1, 2, 3, 4, 5],629"B": [-99, 500, None, 700, -66],630}631)632assert_frame_equal(out, expected, check_row_order=False)633634# edge-case #11684635x = pl.DataFrame({"a": [0, 1]})636y = pl.DataFrame({"a": [2, 3]})637assert sorted(x.update(y, on="a", how="full")["a"].to_list()) == [0, 1, 2, 3]638639# disallowed join strategies640for join_strategy in ("cross", "anti", "semi"):641with pytest.raises(642ValueError,643match=f"`how` must be one of {{'left', 'inner', 'full'}}; found '{join_strategy}'",644):645a.update(b, how=join_strategy) # type: ignore[arg-type]646647648def test_join_frame_consistency() -> None:649df = pl.DataFrame({"A": [1, 2, 3]})650ldf = pl.DataFrame({"A": [1, 2, 5]}).lazy()651652with pytest.raises(TypeError, match=r"expected `other`.*LazyFrame"):653_ = ldf.join(df, on="A") # type: ignore[arg-type]654with pytest.raises(TypeError, match=r"expected `other`.*DataFrame"):655_ = df.join(ldf, on="A") # type: ignore[arg-type]656with pytest.raises(TypeError, match=r"expected `other`.*LazyFrame"):657_ = ldf.join_asof(df, on="A") # type: ignore[arg-type]658with pytest.raises(TypeError, match=r"expected `other`.*DataFrame"):659_ = df.join_asof(ldf, on="A") # type: ignore[arg-type]660661662def test_join_concat_projection_pd_case_7071() -> None:663ldf = pl.DataFrame({"id": [1, 2], "value": [100, 200]}).lazy()664ldf2 = pl.DataFrame({"id": [1, 3], "value": [100, 300]}).lazy()665666ldf = ldf.join(ldf2, on=["id", "value"])667ldf = pl.concat([ldf, ldf2])668result = ldf.select("id")669670expected = pl.DataFrame({"id": [1, 1, 3]}).lazy()671assert_frame_equal(result, expected)672673674@pytest.mark.may_fail_auto_streaming # legacy full join is not order-preserving whereas new-streaming is675def test_join_sorted_fast_paths_null() -> None:676df1 = pl.DataFrame({"x": [0, 1, 0]}).sort("x")677df2 = pl.DataFrame({"x": [0, None], "y": [0, 1]})678assert df1.join(df2, on="x", how="inner").to_dict(as_series=False) == {679"x": [0, 0],680"y": [0, 0],681}682assert df1.join(df2, on="x", how="left").to_dict(as_series=False) == {683"x": [0, 0, 1],684"y": [0, 0, None],685}686assert df1.join(df2, on="x", how="anti").to_dict(as_series=False) == {"x": [1]}687assert df1.join(df2, on="x", how="semi").to_dict(as_series=False) == {"x": [0, 0]}688assert df1.join(df2, on="x", how="full").to_dict(as_series=False) == {689"x": [0, 0, 1, None],690"x_right": [0, 0, None, None],691"y": [0, 0, None, 1],692}693694695def test_full_outer_join_list_() -> None:696schema = {"id": pl.Int64, "vals": pl.List(pl.Float64)}697join_schema = {**schema, **{k + "_right": t for (k, t) in schema.items()}}698df1 = pl.DataFrame({"id": [1], "vals": [[]]}, schema=schema) # type: ignore[arg-type]699df2 = pl.DataFrame({"id": [2, 3], "vals": [[], [4]]}, schema=schema) # type: ignore[arg-type]700expected = pl.DataFrame(701{702"id": [None, None, 1],703"vals": [None, None, []],704"id_right": [2, 3, None],705"vals_right": [[], [4.0], None],706},707schema=join_schema, # type: ignore[arg-type]708)709out = df1.join(df2, on="id", how="full", maintain_order="right_left")710assert_frame_equal(out, expected)711712713@pytest.mark.slow714def test_join_validation() -> None:715def test_each_join_validation(716unique: pl.DataFrame, duplicate: pl.DataFrame, on: str, how: JoinStrategy717) -> None:718# one_to_many719_one_to_many_success_inner = unique.join(720duplicate, on=on, how=how, validate="1:m"721)722723with pytest.raises(ComputeError):724_one_to_many_fail_inner = duplicate.join(725unique, on=on, how=how, validate="1:m"726)727728# one to one729with pytest.raises(ComputeError):730_one_to_one_fail_1_inner = unique.join(731duplicate, on=on, how=how, validate="1:1"732)733734with pytest.raises(ComputeError):735_one_to_one_fail_2_inner = duplicate.join(736unique, on=on, how=how, validate="1:1"737)738739# many to one740with pytest.raises(ComputeError):741_many_to_one_fail_inner = unique.join(742duplicate, on=on, how=how, validate="m:1"743)744745_many_to_one_success_inner = duplicate.join(746unique, on=on, how=how, validate="m:1"747)748749# many to many750_many_to_many_success_1_inner = duplicate.join(751unique, on=on, how=how, validate="m:m"752)753754_many_to_many_success_2_inner = unique.join(755duplicate, on=on, how=how, validate="m:m"756)757758# test data759short_unique = pl.DataFrame(760{761"id": [1, 2, 3, 4],762"id_str": ["1", "2", "3", "4"],763"name": ["hello", "world", "rust", "polars"],764}765)766short_duplicate = pl.DataFrame(767{"id": [1, 2, 3, 1], "id_str": ["1", "2", "3", "1"], "cnt": [2, 4, 6, 1]}768)769long_unique = pl.DataFrame(770{771"id": [1, 2, 3, 4, 5],772"id_str": ["1", "2", "3", "4", "5"],773"name": ["hello", "world", "rust", "polars", "meow"],774}775)776long_duplicate = pl.DataFrame(777{778"id": [1, 2, 3, 1, 5],779"id_str": ["1", "2", "3", "1", "5"],780"cnt": [2, 4, 6, 1, 8],781}782)783784join_strategies: list[JoinStrategy] = ["inner", "full", "left"]785786for join_col in ["id", "id_str"]:787for how in join_strategies:788# same size789test_each_join_validation(long_unique, long_duplicate, join_col, how)790791# left longer792test_each_join_validation(long_unique, short_duplicate, join_col, how)793794# right longer795test_each_join_validation(short_unique, long_duplicate, join_col, how)796797798@typing.no_type_check799def test_join_validation_many_keys() -> None:800# unique in both801df1 = pl.DataFrame(802{803"val1": [11, 12, 13, 14],804"val2": [1, 2, 3, 4],805}806)807df2 = pl.DataFrame(808{809"val1": [11, 12, 13, 14],810"val2": [1, 2, 3, 4],811}812)813for join_type in ["inner", "left", "full"]:814for val in ["m:m", "m:1", "1:1", "1:m"]:815df1.join(df2, on=["val1", "val2"], how=join_type, validate=val)816817# many in lhs818df1 = pl.DataFrame(819{820"val1": [11, 11, 12, 13, 14],821"val2": [1, 1, 2, 3, 4],822}823)824825for join_type in ["inner", "left", "full"]:826for val in ["1:1", "1:m"]:827with pytest.raises(ComputeError):828df1.join(df2, on=["val1", "val2"], how=join_type, validate=val)829830# many in rhs831df1 = pl.DataFrame(832{833"val1": [11, 12, 13, 14],834"val2": [1, 2, 3, 4],835}836)837df2 = pl.DataFrame(838{839"val1": [11, 11, 12, 13, 14],840"val2": [1, 1, 2, 3, 4],841}842)843844for join_type in ["inner", "left", "full"]:845for val in ["m:1", "1:1"]:846with pytest.raises(ComputeError):847df1.join(df2, on=["val1", "val2"], how=join_type, validate=val)848849850def test_full_outer_join_bool() -> None:851df1 = pl.DataFrame({"id": [True, False], "val": [1, 2]})852df2 = pl.DataFrame({"id": [True, False], "val": [0, -1]})853assert df1.join(df2, on="id", how="full", maintain_order="right").to_dict(854as_series=False855) == {856"id": [True, False],857"val": [1, 2],858"id_right": [True, False],859"val_right": [0, -1],860}861862863def test_full_outer_join_coalesce_different_names_13450() -> None:864df1 = pl.DataFrame({"L1": ["a", "b", "c"], "L3": ["b", "c", "d"], "L2": [1, 2, 3]})865df2 = pl.DataFrame({"L3": ["a", "c", "d"], "R2": [7, 8, 9]})866867expected = pl.DataFrame(868{869"L1": ["a", "c", "d", "b"],870"L3": ["b", "d", None, "c"],871"L2": [1, 3, None, 2],872"R2": [7, 8, 9, None],873}874)875876out = df1.join(df2, left_on="L1", right_on="L3", how="full", coalesce=True)877assert_frame_equal(out, expected, check_row_order=False)878879880# https://github.com/pola-rs/polars/issues/10663881def test_join_on_wildcard_error() -> None:882df = pl.DataFrame({"x": [1]})883df2 = pl.DataFrame({"x": [1], "y": [2]})884with pytest.raises(885InvalidOperationError,886):887df.join(df2, on=pl.all())888889890def test_join_on_nth_error() -> None:891df = pl.DataFrame({"x": [1]})892df2 = pl.DataFrame({"x": [1], "y": [2]})893with pytest.raises(894InvalidOperationError,895):896df.join(df2, on=pl.first())897898899def test_join_results_in_duplicate_names() -> None:900df = pl.DataFrame(901{902"a": [1, 2, 3],903"b": [4, 5, 6],904"c": [1, 2, 3],905"c_right": [1, 2, 3],906}907)908909def f(x: Any) -> Any:910return x.join(x, on=["a", "b"], how="left")911912# Ensure it also contains the hint913match_str = "(?s)column with name 'c_right' already exists.*You may want to try"914915# Ensure it fails immediately when resolving schema.916with pytest.raises(DuplicateError, match=match_str):917f(df.lazy()).collect_schema()918919with pytest.raises(DuplicateError, match=match_str):920f(df.lazy()).collect()921922with pytest.raises(DuplicateError, match=match_str):923f(df).collect()924925926def test_join_duplicate_suffixed_columns_from_join_key_column_21048() -> None:927df = pl.DataFrame({"a": 1, "b": 1, "b_right": 1})928929def f(x: Any) -> Any:930return x.join(x, on="a")931932# Ensure it also contains the hint933match_str = "(?s)column with name 'b_right' already exists.*You may want to try"934935# Ensure it fails immediately when resolving schema.936with pytest.raises(DuplicateError, match=match_str):937f(df.lazy()).collect_schema()938939with pytest.raises(DuplicateError, match=match_str):940f(df.lazy()).collect()941942with pytest.raises(DuplicateError, match=match_str):943f(df)944945946def test_join_projection_invalid_name_contains_suffix_15243() -> None:947df1 = pl.DataFrame({"a": [1, 2, 3]}).lazy()948df2 = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).lazy()949950with pytest.raises(ColumnNotFoundError):951(952df1.join(df2, on="a")953.select(pl.col("b").filter(pl.col("b") == pl.col("foo_right")))954.collect()955)956957958def test_join_list_non_numeric() -> None:959assert (960pl.DataFrame(961{962"lists": [963["a", "b", "c"],964["a", "c", "b"],965["a", "c", "b"],966["a", "c", "d"],967]968}969)970).group_by("lists", maintain_order=True).agg(pl.len().alias("count")).to_dict(971as_series=False972) == {973"lists": [["a", "b", "c"], ["a", "c", "b"], ["a", "c", "d"]],974"count": [1, 2, 1],975}976977978@pytest.mark.slow979def test_join_4_columns_with_validity() -> None:980# join on 4 columns so we trigger combine validities981# use 138 as that is 2 u64 and a remainder982a = pl.DataFrame(983{"a": [None if a % 6 == 0 else a for a in range(138)]}984).with_columns(985b=pl.col("a"),986c=pl.col("a"),987d=pl.col("a"),988)989990assert a.join(a, on=["a", "b", "c", "d"], how="inner", nulls_equal=True).shape == (991644,9924,993)994assert a.join(a, on=["a", "b", "c", "d"], how="inner", nulls_equal=False).shape == (995115,9964,997)9989991000@pytest.mark.release1001def test_cross_join() -> None:1002# triggers > 100 rows implementation1003# https://github.com/pola-rs/polars/blob/5f5acb2a523ce01bc710768b396762b8e69a9e07/polars/polars-core/src/frame/cross_join.rs#L341004df1 = pl.DataFrame({"col1": ["a"], "col2": ["d"]})1005df2 = pl.DataFrame({"frame2": pl.arange(0, 100, eager=True)})1006out = df2.join(df1, how="cross")1007df2 = pl.DataFrame({"frame2": pl.arange(0, 101, eager=True)})1008assert_frame_equal(1009df2.join(df1, how="cross", maintain_order="left_right").slice(0, 100), out1010)101110121013@pytest.mark.release1014def test_cross_join_slice_pushdown() -> None:1015# this will likely go out of memory if we did not pushdown the slice1016df = (1017pl.Series("x", pl.arange(0, 2**16 - 1, eager=True, dtype=pl.UInt16) % 2**15)1018).to_frame()10191020result = (1021df.lazy()1022.join(df.lazy(), how="cross", maintain_order="left_right", suffix="_")1023.slice(-5, 10)1024.collect()1025)1026expected = pl.DataFrame(1027{1028"x": [32766, 32766, 32766, 32766, 32766],1029"x_": [32762, 32763, 32764, 32765, 32766],1030},1031schema={"x": pl.UInt16, "x_": pl.UInt16},1032)1033assert_frame_equal(result, expected)10341035result = (1036df.lazy()1037.join(df.lazy(), how="cross", maintain_order="left_right", suffix="_")1038.slice(2, 10)1039.collect()1040)1041expected = pl.DataFrame(1042{1043"x": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],1044"x_": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11],1045},1046schema={"x": pl.UInt16, "x_": pl.UInt16},1047)1048assert_frame_equal(result, expected)104910501051@pytest.mark.parametrize("how", ["left", "inner"])1052def test_join_coalesce(how: JoinStrategy) -> None:1053a = pl.LazyFrame({"a": [1, 2], "b": [1, 2]})1054b = pl.LazyFrame(1055{1056"a": [1, 2, 1, 2],1057"b": [5, 7, 8, 9],1058"c": [1, 2, 1, 2],1059}1060)10611062how = "inner"1063q = a.join(b, on="a", coalesce=False, how=how)1064out = q.collect()1065assert q.collect_schema() == out.schema1066assert out.columns == ["a", "b", "a_right", "b_right", "c"]10671068q = a.join(b, on=["a", "b"], coalesce=False, how=how)1069out = q.collect()1070assert q.collect_schema() == out.schema1071assert out.columns == ["a", "b", "a_right", "b_right", "c"]10721073q = a.join(b, on=["a", "b"], coalesce=True, how=how)1074out = q.collect()1075assert q.collect_schema() == out.schema1076assert out.columns == ["a", "b", "c"]107710781079@pytest.mark.parametrize("how", ["left", "inner", "full"])1080def test_join_empties(how: JoinStrategy) -> None:1081df1 = pl.DataFrame({"col1": [], "col2": [], "col3": []})1082df2 = pl.DataFrame({"col2": [], "col4": [], "col5": []})10831084df = df1.join(df2, on="col2", how=how)1085assert df.height == 0108610871088def test_join_raise_on_redundant_keys() -> None:1089left = pl.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5], "c": [5, 6, 7]})1090right = pl.DataFrame({"a": [2, 3, 4], "c": [4, 5, 6]})1091with pytest.raises(InvalidOperationError, match="already joined on"):1092left.join(right, on=["a", "a"], how="full", coalesce=True)109310941095@pytest.mark.parametrize("coalesce", [False, True])1096def test_join_raise_on_repeated_expression_key_names(coalesce: bool) -> None:1097left = pl.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5], "c": [5, 6, 7]})1098right = pl.DataFrame({"a": [2, 3, 4], "c": [4, 5, 6]})1099with ( # noqa: PT0121100pytest.raises(InvalidOperationError, match="already joined on"),1101warnings.catch_warnings(),1102):1103warnings.simplefilter(action="ignore", category=UserWarning)1104left.join(1105right, on=[pl.col("a"), pl.col("a") % 2], how="full", coalesce=coalesce1106)110711081109def test_join_lit_panic_11410() -> None:1110df = pl.LazyFrame({"date": [1, 2, 3], "symbol": [4, 5, 6]})1111dates = df.select("date").unique(maintain_order=True)1112symbols = df.select("symbol").unique(maintain_order=True)11131114assert symbols.join(1115dates, left_on=pl.lit(1), right_on=pl.lit(1), maintain_order="left_right"1116).collect().to_dict(as_series=False) == {1117"symbol": [4, 4, 4, 5, 5, 5, 6, 6, 6],1118"date": [1, 2, 3, 1, 2, 3, 1, 2, 3],1119}112011211122def test_join_empty_literal_17027() -> None:1123df1 = pl.DataFrame({"a": [1]})1124df2 = pl.DataFrame(schema={"a": pl.Int64})11251126assert df1.join(df2, on=pl.lit(0), how="left").height == 11127assert df1.join(df2, on=pl.lit(0), how="inner").height == 01128assert (1129df1.lazy()1130.join(df2.lazy(), on=pl.lit(0), how="inner")1131.collect(engine="streaming")1132.height1133== 01134)1135assert (1136df1.lazy()1137.join(df2.lazy(), on=pl.lit(0), how="left")1138.collect(engine="streaming")1139.height1140== 11141)114211431144@pytest.mark.parametrize(1145("left_on", "right_on"),1146zip(1147[pl.col("a"), pl.col("a").sort(), [pl.col("a"), pl.col("b")]],1148[pl.col("a").slice(0, 2) * 2, pl.col("b"), [pl.col("a"), pl.col("b").head()]],1149strict=False,1150),1151)1152def test_join_non_elementwise_keys_raises(left_on: pl.Expr, right_on: pl.Expr) -> None:1153# https://github.com/pola-rs/polars/issues/171841154left = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 4, 5]})1155right = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 4, 5]})11561157q = left.join(1158right,1159left_on=left_on,1160right_on=right_on,1161how="inner",1162)11631164with pytest.raises(pl.exceptions.InvalidOperationError):1165q.collect()116611671168def test_join_coalesce_not_supported_warning() -> None:1169# https://github.com/pola-rs/polars/issues/171841170left = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 4, 5]})1171right = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 4, 5]})11721173q = left.join(1174right,1175left_on=[pl.col("a") * 2],1176right_on=[pl.col("a") * 2],1177how="inner",1178coalesce=True,1179)1180with pytest.warns(UserWarning, match="turning off key coalescing"):1181got = q.collect()1182expect = pl.DataFrame(1183{"a": [1, 2, 3], "b": [3, 4, 5], "a_right": [1, 2, 3], "b_right": [3, 4, 5]}1184)11851186assert_frame_equal(expect, got, check_row_order=False)118711881189@pytest.mark.parametrize(1190("on_args"),1191[1192{"on": "a", "left_on": "a"},1193{"on": "a", "right_on": "a"},1194{"on": "a", "left_on": "a", "right_on": "a"},1195],1196)1197def test_join_on_and_left_right_on(on_args: dict[str, str]) -> None:1198df1 = pl.DataFrame({"a": [1], "b": [2]})1199df2 = pl.DataFrame({"a": [1], "c": [3]})1200msg = "cannot use 'on' in conjunction with 'left_on' or 'right_on'"1201with pytest.raises(ValueError, match=msg):1202df1.join(df2, **on_args) # type: ignore[arg-type]120312041205@pytest.mark.parametrize(1206("on_args"),1207[1208{"left_on": "a"},1209{"right_on": "a"},1210],1211)1212def test_join_only_left_or_right_on(on_args: dict[str, str]) -> None:1213df1 = pl.DataFrame({"a": [1]})1214df2 = pl.DataFrame({"a": [1]})1215msg = "'left_on' requires corresponding 'right_on'"1216with pytest.raises(ValueError, match=msg):1217df1.join(df2, **on_args) # type: ignore[arg-type]121812191220@pytest.mark.parametrize(1221("on_args"),1222[1223{"on": "a"},1224{"left_on": "a", "right_on": "a"},1225],1226)1227def test_cross_join_no_on_keys(on_args: dict[str, str]) -> None:1228df1 = pl.DataFrame({"a": [1, 2]})1229df2 = pl.DataFrame({"b": [3, 4]})1230msg = "cross join should not pass join keys"1231with pytest.raises(ValueError, match=msg):1232df1.join(df2, how="cross", **on_args) # type: ignore[arg-type]123312341235@pytest.mark.parametrize("set_sorted", [True, False])1236def test_left_join_slice_pushdown_19405(set_sorted: bool) -> None:1237left = pl.LazyFrame({"k": [1, 2, 3, 4, 0]})1238right = pl.LazyFrame({"k": [1, 1, 1, 1, 0]})12391240if set_sorted:1241# The data isn't actually sorted on purpose to ensure we default to a1242# hash join unless we set the sorted flag here, in case there is new1243# code in the future that automatically identifies sortedness during1244# Series construction from Python.1245left = left.set_sorted("k")1246right = right.set_sorted("k")12471248q = left.join(right, on="k", how="left", maintain_order="left_right").head(5)1249assert_frame_equal(q.collect(), pl.DataFrame({"k": [1, 1, 1, 1, 2]}))125012511252def test_join_key_type_coercion_19597() -> None:1253left = pl.LazyFrame({"a": pl.Series([1, 2, 3], dtype=pl.Float64)})1254right = pl.LazyFrame({"a": pl.Series([1, 2, 3], dtype=pl.Int64)})12551256with pytest.raises(SchemaError, match="datatypes of join keys don't match"):1257left.join(right, left_on=pl.col("a"), right_on=pl.col("a")).collect_schema()12581259with pytest.raises(SchemaError, match="datatypes of join keys don't match"):1260left.join(1261right, left_on=pl.col("a") * 2, right_on=pl.col("a") * 21262).collect_schema()126312641265def test_array_explode_join_19763() -> None:1266q = pl.LazyFrame().select(1267pl.lit(pl.Series([[1], [2]], dtype=pl.Array(pl.Int64, 1))).explode().alias("k")1268)12691270q = q.join(pl.LazyFrame({"k": [1, 2]}), on="k")12711272assert_frame_equal(q.collect().sort("k"), pl.DataFrame({"k": [1, 2]}))127312741275def test_join_full_19814() -> None:1276schema = {"a": pl.Int64, "c": pl.Categorical}1277a = pl.LazyFrame({"a": [1], "c": [None]}, schema=schema)1278b = pl.LazyFrame({"a": [1, 3, 4]})1279assert_frame_equal(1280a.join(b, on="a", how="full", coalesce=True).collect(),1281pl.DataFrame({"a": [1, 3, 4], "c": [None, None, None]}, schema=schema),1282check_row_order=False,1283)128412851286def test_join_preserve_order_inner() -> None:1287left = pl.LazyFrame({"a": [None, 2, 1, 1, 5]})1288right = pl.LazyFrame({"a": [1, 1, None, 2], "b": [6, 7, 8, 9]})12891290# Inner joins12911292inner_left = left.join(right, on="a", how="inner", maintain_order="left").collect()1293assert inner_left.get_column("a").cast(pl.UInt32).to_list() == [2, 1, 1, 1, 1]1294inner_left_right = left.join(1295right, on="a", how="inner", maintain_order="left"1296).collect()1297assert inner_left.get_column("a").equals(inner_left_right.get_column("a"))12981299inner_right = left.join(1300right, on="a", how="inner", maintain_order="right"1301).collect()1302assert inner_right.get_column("a").cast(pl.UInt32).to_list() == [1, 1, 1, 1, 2]1303inner_right_left = left.join(1304right, on="a", how="inner", maintain_order="right"1305).collect()1306assert inner_right.get_column("a").equals(inner_right_left.get_column("a"))130713081309# The new streaming engine does not provide the same maintain_order="none"1310# ordering guarantee that is currently kept for compatibility on the in-memory1311# engine.1312@pytest.mark.may_fail_auto_streaming1313def test_join_preserve_order_left() -> None:1314left = pl.LazyFrame({"a": [None, 2, 1, 1, 5]})1315right = pl.LazyFrame({"a": [1, None, 2, 6], "b": [6, 7, 8, 9]})13161317# Right now the left join algorithm is ordered without explicitly setting any order1318# This behaviour is deprecated but can only be removed in 2.01319left_none = left.join(right, on="a", how="left", maintain_order="none").collect()1320assert left_none.get_column("a").cast(pl.UInt32).to_list() == [1321None,13222,13231,13241,13255,1326]13271328left_left = left.join(right, on="a", how="left", maintain_order="left").collect()1329assert left_left.get_column("a").cast(pl.UInt32).to_list() == [1330None,13312,13321,13331,13345,1335]13361337left_left_right = left.join(1338right, on="a", how="left", maintain_order="left_right"1339).collect()1340# If the left order is preserved then there are no unsorted right rows1341assert left_left.get_column("a").equals(left_left_right.get_column("a"))13421343left_right = left.join(right, on="a", how="left", maintain_order="right").collect()1344assert left_right.get_column("a").cast(pl.UInt32).to_list()[:5] == [13451,13461,13472,1348None,13495,1350]13511352left_right_left = left.join(1353right, on="a", how="left", maintain_order="right_left"1354).collect()1355assert left_right_left.get_column("a").cast(pl.UInt32).to_list() == [13561,13571,13582,1359None,13605,1361]13621363right_left = left.join(right, on="a", how="right", maintain_order="left").collect()1364assert right_left.get_column("a").cast(pl.UInt32).to_list() == [2, 1, 1, None, 6]13651366right_right = left.join(1367right, on="a", how="right", maintain_order="right"1368).collect()1369assert right_right.get_column("a").cast(pl.UInt32).to_list() == [13701,13711,1372None,13732,13746,1375]137613771378def test_join_preserve_order_full() -> None:1379left = pl.LazyFrame({"a": [None, 2, 1, 1, 5]})1380right = pl.LazyFrame({"a": [1, None, 2, 6], "b": [6, 7, 8, 9]})13811382full_left = left.join(right, on="a", how="full", maintain_order="left").collect()1383assert full_left.get_column("a").cast(pl.UInt32).to_list()[:5] == [1384None,13852,13861,13871,13885,1389]1390full_right = left.join(right, on="a", how="full", maintain_order="right").collect()1391assert full_right.get_column("a").cast(pl.UInt32).to_list()[:5] == [13921,13931,1394None,13952,1396None,1397]13981399full_left_right = left.join(1400right, on="a", how="full", maintain_order="left_right"1401).collect()1402assert full_left_right.get_column("a_right").cast(pl.UInt32).to_list() == [1403None,14042,14051,14061,1407None,1408None,14096,1410]14111412full_right_left = left.join(1413right, on="a", how="full", maintain_order="right_left"1414).collect()1415assert full_right_left.get_column("a").cast(pl.UInt32).to_list() == [14161,14171,1418None,14192,1420None,1421None,14225,1423]142414251426@pytest.mark.parametrize(1427"dtypes",1428[1429["Int128", "Int128", "Int64"],1430["Int128", "Int128", "Int32"],1431["Int128", "Int128", "Int16"],1432["Int128", "Int128", "Int8"],1433["Int128", "UInt64", "Int128"],1434["Int128", "UInt64", "Int64"],1435["Int128", "UInt64", "Int32"],1436["Int128", "UInt64", "Int16"],1437["Int128", "UInt64", "Int8"],1438["Int128", "UInt32", "Int128"],1439["Int128", "UInt16", "Int128"],1440["Int128", "UInt8", "Int128"],14411442["Int64", "Int64", "Int32"],1443["Int64", "Int64", "Int16"],1444["Int64", "Int64", "Int8"],1445["Int64", "UInt32", "Int64"],1446["Int64", "UInt32", "Int32"],1447["Int64", "UInt32", "Int16"],1448["Int64", "UInt32", "Int8"],1449["Int64", "UInt16", "Int64"],1450["Int64", "UInt8", "Int64"],14511452["Int32", "Int32", "Int16"],1453["Int32", "Int32", "Int8"],1454["Int32", "UInt16", "Int32"],1455["Int32", "UInt16", "Int16"],1456["Int32", "UInt16", "Int8"],1457["Int32", "UInt8", "Int32"],14581459["Int16", "Int16", "Int8"],1460["Int16", "UInt8", "Int16"],1461["Int16", "UInt8", "Int8"],14621463["UInt128", "UInt128", "UInt64"],1464["UInt128", "UInt128", "UInt32"],1465["UInt128", "UInt128", "UInt16"],1466["UInt128", "UInt128", "UInt8"],1467["UInt128", "UInt64", "UInt128"],1468["UInt128", "UInt32", "UInt128"],1469["UInt128", "UInt16", "UInt128"],1470["UInt128", "UInt8", "UInt128"],14711472["UInt64", "UInt64", "UInt32"],1473["UInt64", "UInt64", "UInt16"],1474["UInt64", "UInt64", "UInt8"],14751476["UInt32", "UInt32", "UInt16"],1477["UInt32", "UInt32", "UInt8"],14781479["UInt16", "UInt16", "UInt8"],14801481["Float64", "Float64", "Float32"],1482["Float32", "Float32", "Float16"],1483],1484) # fmt: skip1485@pytest.mark.parametrize("swap", [True, False])1486def test_join_numeric_key_upcast_15338(1487dtypes: tuple[str, str, str], swap: bool1488) -> None:1489supertype, ltype, rtype = (getattr(pl, x) for x in dtypes)1490ltype, rtype = (rtype, ltype) if swap else (ltype, rtype)14911492left = pl.select(pl.Series("a", [1, 1, 3]).cast(ltype)).lazy()1493right = pl.select(pl.Series("a", [1]).cast(rtype), b=pl.lit("A")).lazy()14941495assert_frame_equal(1496left.join(right, on="a", how="left").collect(),1497pl.select(a=pl.Series([1, 1, 3]).cast(ltype), b=pl.Series(["A", "A", None])),1498check_row_order=False,1499)15001501assert_frame_equal(1502left.join(right, on="a", how="left", coalesce=False).drop("a_right").collect(),1503pl.select(a=pl.Series([1, 1, 3]).cast(ltype), b=pl.Series(["A", "A", None])),1504check_row_order=False,1505)15061507assert_frame_equal(1508left.join(right, on="a", how="full").collect(),1509pl.select(1510a=pl.Series([1, 1, 3]).cast(ltype),1511a_right=pl.Series([1, 1, None]).cast(rtype),1512b=pl.Series(["A", "A", None]),1513),1514check_row_order=False,1515)15161517assert_frame_equal(1518left.join(right, on="a", how="full", coalesce=True).collect(),1519pl.select(1520a=pl.Series([1, 1, 3]).cast(supertype),1521b=pl.Series(["A", "A", None]),1522),1523check_row_order=False,1524)15251526assert_frame_equal(1527left.join(right, on="a", how="semi").collect(),1528pl.select(a=pl.Series([1, 1]).cast(ltype)),1529)15301531# join_where1532for optimizations in [pl.QueryOptFlags(), pl.QueryOptFlags.none()]:1533assert_frame_equal(1534left.join_where(right, pl.col("a") == pl.col("a_right")).collect(1535optimizations=optimizations,1536),1537pl.select(1538a=pl.Series([1, 1]).cast(ltype),1539a_right=pl.lit(1, dtype=rtype),1540b=pl.Series(["A", "A"]),1541),1542)154315441545def test_join_numeric_key_upcast_forbid_float_int() -> None:1546ltype = pl.Float641547rtype = pl.Int12815481549left = pl.LazyFrame({"a": [1.0, 0.0]}, schema={"a": ltype})1550right = pl.LazyFrame({"a": [1, 2]}, schema={"a": rtype})15511552# Establish baseline: In a non-join context, comparisons between ltype and1553# rtype succeed even if the upcast is lossy.1554assert_frame_equal(1555left.with_columns(right.collect()["a"].alias("a_right"))1556.select(pl.col("a") == pl.col("a_right"))1557.collect(),1558pl.DataFrame({"a": [True, False]}),1559)15601561with pytest.raises(SchemaError, match="datatypes of join keys don't match"):1562left.join(right, on="a", how="left").collect()15631564for optimizations in [pl.QueryOptFlags(), pl.QueryOptFlags.none()]:1565with pytest.raises(1566SchemaError, match="'join_where' cannot compare Float64 with Int128"1567):1568left.join_where(right, pl.col("a") == pl.col("a_right")).collect(1569optimizations=optimizations,1570)15711572with pytest.raises(1573SchemaError, match="'join_where' cannot compare Float64 with Int128"1574):1575left.join_where(1576right, pl.col("a") == (pl.col("a") == pl.col("a_right"))1577).collect(optimizations=optimizations)157815791580def test_join_numeric_key_upcast_order() -> None:1581# E.g. when we are joining on this expression:1582# * col('a') + 1271583#1584# and we want to upcast, ensure that we upcast like this:1585# * ( col('a') + 127 ) .cast(<type>)1586#1587# and *not* like this:1588# * ( col('a').cast(<type>) + lit(127).cast(<type>) )1589#1590# as otherwise the results would be different.15911592left = pl.select(pl.Series("a", [1], dtype=pl.Int8)).lazy()1593right = pl.select(1594pl.Series("a", [1, 128, -128], dtype=pl.Int64), b=pl.lit("A")1595).lazy()15961597# col('a') in `left` is Int8, the result will overflow to become -1281598left_expr = pl.col("a") + 12715991600assert_frame_equal(1601left.join(right, left_on=left_expr, right_on="a", how="inner").collect(),1602pl.DataFrame(1603{1604"a": pl.Series([1], dtype=pl.Int8),1605"a_right": pl.Series([-128], dtype=pl.Int64),1606"b": "A",1607}1608),1609)16101611assert_frame_equal(1612left.join_where(right, left_expr == pl.col("a_right")).collect(),1613pl.DataFrame(1614{1615"a": pl.Series([1], dtype=pl.Int8),1616"a_right": pl.Series([-128], dtype=pl.Int64),1617"b": "A",1618}1619),1620)16211622assert_frame_equal(1623(1624left.join(right, left_on=left_expr, right_on="a", how="full")1625.collect()1626.sort(pl.all())1627),1628pl.DataFrame(1629{1630"a": pl.Series([1, None, None], dtype=pl.Int8),1631"a_right": pl.Series([-128, 1, 128], dtype=pl.Int64),1632"b": ["A", "A", "A"],1633}1634).sort(pl.all()),1635)163616371638def test_no_collapse_join_when_maintain_order_20725() -> None:1639df1 = pl.LazyFrame({"Fraction_1": [0, 25, 50, 75, 100]})1640df2 = pl.LazyFrame({"Fraction_2": [0, 1]})1641df3 = pl.LazyFrame({"Fraction_3": [0, 1]})16421643ldf = df1.join(df2, how="cross", maintain_order="left_right").join(1644df3, how="cross", maintain_order="left_right"1645)16461647df_pl_lazy = ldf.filter(pl.col("Fraction_1") == 100).collect()1648df_pl_eager = ldf.collect().filter(pl.col("Fraction_1") == 100)16491650assert_frame_equal(df_pl_lazy, df_pl_eager)165116521653def test_join_where_predicate_type_coercion_21009() -> None:1654left_frame = pl.LazyFrame(1655{1656"left_match": ["A", "B", "C", "D", "E", "F"],1657"left_date_start": range(6),1658}1659)16601661right_frame = pl.LazyFrame(1662{1663"right_match": ["D", "E", "F", "G", "H", "I"],1664"right_date": range(6),1665}1666)16671668# Note: Cannot eq the plans as the operand sides are non-deterministic16691670q1 = left_frame.join_where(1671right_frame,1672pl.col("left_match") == pl.col("right_match"),1673pl.col("right_date") >= pl.col("left_date_start"),1674)16751676plan = q1.explain().splitlines()1677assert plan[0].strip().startswith("FILTER")1678assert plan[1] == "FROM"1679assert plan[2].strip().startswith("INNER JOIN")16801681q2 = left_frame.join_where(1682right_frame,1683pl.all_horizontal(pl.col("left_match") == pl.col("right_match")),1684pl.col("right_date") >= pl.col("left_date_start"),1685)16861687plan = q2.explain().splitlines()1688assert plan[0].strip().startswith("FILTER")1689assert plan[1] == "FROM"1690assert plan[2].strip().startswith("INNER JOIN")16911692assert_frame_equal(q1.collect(), q2.collect())169316941695def test_join_right_predicate_pushdown_21142() -> None:1696left = pl.LazyFrame({"key": [1, 2, 4], "values": ["a", "b", "c"]})1697right = pl.LazyFrame({"key": [1, 2, 3], "values": ["d", "e", "f"]})16981699rjoin = left.join(right, on="key", how="right")17001701q = rjoin.filter(pl.col("values").is_null())17021703expect = pl.select(1704pl.Series("values", [None], pl.String),1705pl.Series("key", [3], pl.Int64),1706pl.Series("values_right", ["f"], pl.String),1707)17081709assert_frame_equal(q.collect(), expect)17101711# Ensure for right join, filter on RHS key-columns are pushed down.1712q = rjoin.filter(pl.col("values_right").is_null())17131714plan = q.explain()1715assert plan.index("FILTER") > plan.index("RIGHT PLAN ON")17161717assert_frame_equal(q.collect(), expect.clear())171817191720def test_join_where_nested_expr_21066() -> None:1721left = pl.LazyFrame({"a": [1, 2]})1722right = pl.LazyFrame({"a": [1]})17231724q = left.join_where(right, pl.col("a") == (pl.col("a_right") + 1))17251726assert_frame_equal(q.collect(), pl.DataFrame({"a": 2, "a_right": 1}))172717281729def test_select_after_join_where_20831() -> None:1730left = pl.LazyFrame(1731{1732"a": [1, 2, 3, 1, None],1733"b": [1, 2, 3, 4, 5],1734"c": [2, 3, 4, 5, 6],1735}1736)17371738right = pl.LazyFrame(1739{1740"a": [1, 4, 3, 7, None, None, 1],1741"c": [2, 3, 4, 5, 6, 7, 8],1742"d": [6, None, 7, 8, -1, 2, 4],1743}1744)17451746q = left.join_where(1747right, pl.col("b") * 2 <= pl.col("a_right"), pl.col("a") < pl.col("c_right")1748)17491750assert_frame_equal(1751q.select("d").collect().sort("d"),1752pl.Series("d", [None, None, 7, 8, 8, 8]).to_frame(),1753)17541755assert q.select(pl.len()).collect().item() == 617561757q = (1758left.join(right, how="cross")1759.filter(pl.col("b") * 2 <= pl.col("a_right"))1760.filter(pl.col("a") < pl.col("c_right"))1761)17621763assert_frame_equal(1764q.select("d").collect().sort("d"),1765pl.Series("d", [None, None, 7, 8, 8, 8]).to_frame(),1766)17671768assert q.select(pl.len()).collect().item() == 6176917701771@pytest.mark.parametrize(1772("dtype", "data"),1773[1774(pl.Struct, [{"x": 1}, {"x": 2}, {"x": 3}, {"x": 4}]),1775(pl.List, [[1], [2, 2], [3, 3, 3], [4, 4, 4, 4]]),1776(pl.Array(pl.Int64, 2), [[1, 1], [2, 2], [3, 3], [4, 4]]),1777],1778)1779def test_join_on_nested(dtype: PolarsDataType, data: list[Any]) -> None:1780lhs = pl.DataFrame(1781{1782"a": data[:3],1783"b": [1, 2, 3],1784}1785)1786rhs = pl.DataFrame(1787{1788"a": [data[3], data[1]],1789"c": [4, 2],1790}1791)17921793assert_frame_equal(1794lhs.join(rhs, on="a", how="left", maintain_order="left"),1795pl.select(1796a=pl.Series(data[:3]),1797b=pl.Series([1, 2, 3]),1798c=pl.Series([None, 2, None]),1799),1800)1801assert_frame_equal(1802lhs.join(rhs, on="a", how="right", maintain_order="right"),1803pl.select(1804b=pl.Series([None, 2]),1805a=pl.Series([data[3], data[1]]),1806c=pl.Series([4, 2]),1807),1808)1809assert_frame_equal(1810lhs.join(rhs, on="a", how="inner"),1811pl.select(1812a=pl.Series([data[1]]),1813b=pl.Series([2]),1814c=pl.Series([2]),1815),1816)1817assert_frame_equal(1818lhs.join(rhs, on="a", how="full", maintain_order="left_right"),1819pl.select(1820a=pl.Series(data[:3] + [None]),1821b=pl.Series([1, 2, 3, None]),1822a_right=pl.Series([None, data[1], None, data[3]]),1823c=pl.Series([None, 2, None, 4]),1824),1825)1826assert_frame_equal(1827lhs.join(rhs, on="a", how="semi"),1828pl.select(1829a=pl.Series([data[1]]),1830b=pl.Series([2]),1831),1832)1833assert_frame_equal(1834lhs.join(rhs, on="a", how="anti", maintain_order="left"),1835pl.select(1836a=pl.Series([data[0], data[2]]),1837b=pl.Series([1, 3]),1838),1839)1840assert_frame_equal(1841lhs.join(rhs, how="cross", maintain_order="left_right"),1842pl.select(1843a=pl.Series([data[0], data[0], data[1], data[1], data[2], data[2]]),1844b=pl.Series([1, 1, 2, 2, 3, 3]),1845a_right=pl.Series([data[3], data[1], data[3], data[1], data[3], data[1]]),1846c=pl.Series([4, 2, 4, 2, 4, 2]),1847),1848)184918501851def test_empty_join_result_with_array_15474() -> None:1852lhs = pl.DataFrame(1853{1854"x": [1, 2],1855"y": pl.Series([[1, 2, 3], [4, 5, 6]], dtype=pl.Array(pl.Int64, 3)),1856}1857)1858rhs = pl.DataFrame({"x": [0]})1859result = lhs.join(rhs, on="x")1860expected = pl.DataFrame(schema={"x": pl.Int64, "y": pl.Array(pl.Int64, 3)})1861assert_frame_equal(result, expected)186218631864@pytest.mark.slow1865def test_join_where_eager_perf_21145() -> None:1866left = pl.Series("left", range(3_000)).to_frame()1867right = pl.Series("right", range(1_000)).to_frame()18681869p = pl.col("left").is_between(pl.lit(0, dtype=pl.Int64), pl.col("right"))1870runtime_eager = time_func(lambda: left.join_where(right, p))1871runtime_lazy = time_func(lambda: left.lazy().join_where(right.lazy(), p).collect())1872runtime_ratio = runtime_eager / runtime_lazy18731874# Pick as high as reasonably possible for CI stability1875# * Was observed to be >=5 seconds on the bugged version, so 3 is a safe bet.1876threshold = 318771878if runtime_ratio > threshold:1879msg = f"runtime_ratio ({runtime_ratio}) > {threshold}x ({runtime_eager = }, {runtime_lazy = })"1880raise ValueError(msg)188118821883def test_select_len_after_semi_anti_join_21343() -> None:1884lhs = pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6]})1885rhs = pl.LazyFrame({"a": [1, 2, 3]})18861887q = lhs.join(rhs, on="a", how="anti").select(pl.len())18881889assert q.collect().item() == 0189018911892def test_multi_leftjoin_empty_right_21701() -> None:1893parent_data = {1894"id": [1, 30, 80],1895"parent_field1": [3, 20, 17],1896}1897parent_df = pl.LazyFrame(parent_data)1898child_df = pl.LazyFrame(1899[],1900schema={"id": pl.Int32(), "parent_id": pl.Int32(), "child_field1": pl.Int32()},1901)1902subchild_df = pl.LazyFrame(1903[], schema={"child_id": pl.Int32(), "subchild_field1": pl.Int32()}1904)19051906joined_df = parent_df.join(1907child_df.join(1908subchild_df, left_on=pl.col("id"), right_on=pl.col("child_id"), how="left"1909),1910left_on=pl.col("id"),1911right_on=pl.col("parent_id"),1912how="left",1913)1914joined_df = joined_df.select("id", "parent_field1")1915assert_frame_equal(joined_df.collect(), parent_df.collect(), check_row_order=False)191619171918@pytest.mark.parametrize("order", ["none", "left_right", "right_left"])1919def test_join_null_equal(order: Literal["none", "left_right", "right_left"]) -> None:1920lhs = pl.DataFrame({"x": [1, None, None], "y": [1, 2, 3]})1921with_null = pl.DataFrame({"x": [1, None], "z": [1, 2]})1922without_null = pl.DataFrame({"x": [1, 3], "z": [1, 3]})1923check_row_order = order != "none"19241925# Inner join.1926assert_frame_equal(1927lhs.join(with_null, on="x", nulls_equal=True, maintain_order=order),1928pl.DataFrame({"x": [1, None, None], "y": [1, 2, 3], "z": [1, 2, 2]}),1929check_row_order=check_row_order,1930)1931assert_frame_equal(1932lhs.join(without_null, on="x", nulls_equal=True),1933pl.DataFrame({"x": [1], "y": [1], "z": [1]}),1934)19351936# Left join.1937assert_frame_equal(1938lhs.join(with_null, on="x", how="left", nulls_equal=True, maintain_order=order),1939pl.DataFrame({"x": [1, None, None], "y": [1, 2, 3], "z": [1, 2, 2]}),1940check_row_order=check_row_order,1941)1942assert_frame_equal(1943lhs.join(1944without_null, on="x", how="left", nulls_equal=True, maintain_order=order1945),1946pl.DataFrame({"x": [1, None, None], "y": [1, 2, 3], "z": [1, None, None]}),1947check_row_order=check_row_order,1948)19491950# Full join.1951assert_frame_equal(1952lhs.join(1953with_null,1954on="x",1955how="full",1956nulls_equal=True,1957coalesce=True,1958maintain_order=order,1959),1960pl.DataFrame({"x": [1, None, None], "y": [1, 2, 3], "z": [1, 2, 2]}),1961check_row_order=check_row_order,1962)1963if order == "left_right":1964expected = pl.DataFrame(1965{1966"x": [1, None, None, None],1967"x_right": [1, None, None, 3],1968"y": [1, 2, 3, None],1969"z": [1, None, None, 3],1970}1971)1972else:1973expected = pl.DataFrame(1974{1975"x": [1, None, None, None],1976"x_right": [1, 3, None, None],1977"y": [1, None, 2, 3],1978"z": [1, 3, None, None],1979}1980)1981assert_frame_equal(1982lhs.join(1983without_null, on="x", how="full", nulls_equal=True, maintain_order=order1984),1985expected,1986check_row_order=check_row_order,1987check_column_order=False,1988)198919901991def test_join_categorical_21815() -> None:1992left = pl.DataFrame({"x": ["a", "b", "c", "d"]}).with_columns(1993xc=pl.col.x.cast(pl.Categorical)1994)1995right = pl.DataFrame({"x": ["c", "d", "e", "f"]}).with_columns(1996xc=pl.col.x.cast(pl.Categorical)1997)19981999# As key.2000cat_key = left.join(right, on="xc", how="full")20012002# As payload.2003cat_payload = left.join(right, on="x", how="full")20042005expected = pl.DataFrame(2006{2007"x": ["a", "b", "c", "d", None, None],2008"x_right": [None, None, "c", "d", "e", "f"],2009}2010).with_columns(2011xc=pl.col.x.cast(pl.Categorical),2012xc_right=pl.col.x_right.cast(pl.Categorical),2013)20142015assert_frame_equal(2016cat_key, expected, check_row_order=False, check_column_order=False2017)2018assert_frame_equal(2019cat_payload, expected, check_row_order=False, check_column_order=False2020)202120222023def test_join_where_nested_boolean() -> None:2024df1 = pl.DataFrame({"a": [1, 9, 22], "b": [6, 4, 50]})2025df2 = pl.DataFrame({"c": [1]})20262027predicate = (pl.col("a") < pl.col("b")).cast(pl.Int32) < pl.col("c")2028result = df1.join_where(df2, predicate)2029expected = pl.DataFrame(2030{2031"a": [9],2032"b": [4],2033"c": [1],2034}2035)2036assert_frame_equal(result, expected)203720382039def test_join_where_dtype_upcast() -> None:2040df1 = pl.DataFrame(2041{2042"a": pl.Series([1, 9, 22], dtype=pl.Int8),2043"b": [6, 4, 50],2044}2045)2046df2 = pl.DataFrame({"c": [10]})20472048predicate = (pl.col("a") + (pl.col("b") > 0)) < pl.col("c")2049result = df1.join_where(df2, predicate)2050expected = pl.DataFrame(2051{2052"a": pl.Series([1], dtype=pl.Int8),2053"b": [6],2054"c": [10],2055}2056)2057assert_frame_equal(result, expected)205820592060def test_join_where_valid_dtype_upcast_same_side() -> None:2061# Unsafe comparisons are all contained entirely within one table (LHS)2062# Safe comparisons across both tables.2063df1 = pl.DataFrame(2064{2065"a": pl.Series([1, 9, 22], dtype=pl.Float32),2066"b": [6, 4, 50],2067}2068)2069df2 = pl.DataFrame({"c": [10, 1, 5]})20702071predicate = ((pl.col("a") < pl.col("b")).cast(pl.Int32) + 3) < pl.col("c")2072result = df1.join_where(df2, predicate).sort("a", "b", "c")2073expected = pl.DataFrame(2074{2075"a": pl.Series([1, 1, 9, 9, 22, 22], dtype=pl.Float32),2076"b": [6, 6, 4, 4, 50, 50],2077"c": [5, 10, 5, 10, 5, 10],2078}2079)2080assert_frame_equal(result, expected)208120822083def test_join_where_invalid_dtype_upcast_different_side() -> None:2084# Unsafe comparisons exist across tables.2085df1 = pl.DataFrame(2086{2087"a": pl.Series([1, 9, 22], dtype=pl.Float32),2088"b": pl.Series([6, 4, 50], dtype=pl.Float64),2089}2090)2091df2 = pl.DataFrame({"c": [10, 1, 5]})20922093predicate = ((pl.col("a") >= pl.col("c")) + 3) < 42094with pytest.raises(2095SchemaError, match="'join_where' cannot compare Float32 with Int64"2096):2097df1.join_where(df2, predicate)20982099# add in a cast to predicate to fix2100predicate = ((pl.col("a").cast(pl.UInt8) >= pl.col("c")) + 3) < 42101result = df1.join_where(df2, predicate).sort("a", "b", "c")2102expected = pl.DataFrame(2103{2104"a": pl.Series([1, 1, 9], dtype=pl.Float32),2105"b": pl.Series([6, 6, 4], dtype=pl.Float64),2106"c": [5, 10, 10],2107}2108)2109assert_frame_equal(result, expected)211021112112@pytest.mark.parametrize("dtype", [pl.Int32, pl.Float32])2113def test_join_where_literals(dtype: PolarsDataType) -> None:2114df1 = pl.DataFrame({"a": pl.Series([0, 1], dtype=dtype)})2115df2 = pl.DataFrame({"b": pl.Series([1, 2], dtype=dtype)})2116result = df1.join_where(df2, (pl.col("a") + pl.col("b")) < 2)2117expected = pl.DataFrame(2118{2119"a": pl.Series([0], dtype=dtype),2120"b": pl.Series([1], dtype=dtype),2121}2122)2123assert_frame_equal(result, expected)212421252126def test_join_where_categorical_string_compare() -> None:2127dt = pl.Enum(["a", "b", "c"])2128df1 = pl.DataFrame({"a": pl.Series(["a", "a", "b", "c"], dtype=dt)})2129df2 = pl.DataFrame({"b": [1, 6, 4]})2130predicate = pl.col("a").is_in(["a", "b"]) & (pl.col("b") < 5)2131result = df1.join_where(df2, predicate).sort("a", "b")2132expected = pl.DataFrame(2133{2134"a": pl.Series(["a", "a", "a", "a", "b", "b"], dtype=dt),2135"b": [1, 1, 4, 4, 1, 4],2136}2137)2138assert_frame_equal(result, expected)213921402141def test_join_where_nonboolean_predicate() -> None:2142df1 = pl.DataFrame({"a": [1, 2, 3]})2143df2 = pl.DataFrame({"b": [1, 2, 3]})2144with pytest.raises(2145ComputeError, match="'join_where' predicates must resolve to boolean"2146):2147df1.join_where(df2, pl.col("a") * 2)214821492150def test_empty_outer_join_22206() -> None:2151df = pl.LazyFrame({"a": [5, 6], "b": [1, 2]})2152empty = pl.LazyFrame(schema=df.collect_schema())2153assert_frame_equal(2154df.join(empty, on=["a", "b"], how="full", coalesce=True),2155df,2156check_row_order=False,2157)2158assert_frame_equal(2159empty.join(df, on=["a", "b"], how="full", coalesce=True),2160df,2161check_row_order=False,2162)216321642165def test_join_coalesce_22498() -> None:2166df_a = pl.DataFrame({"y": [2]})2167df_b = pl.DataFrame({"x": [1], "y": [2]})2168df_j = df_a.lazy().join(df_b.lazy(), how="full", on="y", coalesce=True)2169assert_frame_equal(df_j.collect(), pl.DataFrame({"y": [2], "x": [1]}))217021712172def _extract_plan_joins_and_filters(plan: str) -> list[str]:2173return [2174x2175for x in (x.strip() for x in plan.splitlines())2176if x.startswith("LEFT PLAN") # noqa: PIE8102177or x.startswith("RIGHT PLAN")2178or x.startswith("FILTER")2179]218021812182def test_join_filter_pushdown_inner_join() -> None:2183lhs = pl.LazyFrame(2184{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}2185)2186rhs = pl.LazyFrame(2187{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}2188)21892190# Filter on key output column is pushed to both sides.2191q = lhs.join(rhs, on=["a", "b"], how="inner", maintain_order="left_right").filter(2192pl.col("b") <= 22193)21942195expect = pl.DataFrame(2196{"a": [1, 2], "b": [1, 2], "c": ["a", "b"], "c_right": ["A", "B"]}2197)21982199plan = q.explain()22002201assert _extract_plan_joins_and_filters(plan) == [2202'LEFT PLAN ON: [col("a"), col("b")]',2203'FILTER [(col("b")) <= (2)]',2204'RIGHT PLAN ON: [col("a"), col("b")]',2205'FILTER [(col("b")) <= (2)]',2206]22072208assert_frame_equal(q.collect(), expect)2209assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)22102211# Side-specific filters are all pushed for inner join.2212q = (2213lhs.join(rhs, on=["a", "b"], how="inner", maintain_order="left_right")2214.filter(pl.col("b") <= 2)2215.filter(pl.col("c") == "a", pl.col("c_right") == "A")2216)22172218expect = pl.DataFrame({"a": [1], "b": [1], "c": ["a"], "c_right": ["A"]})22192220plan = q.explain()22212222extract = _extract_plan_joins_and_filters(plan)22232224assert extract[0] == 'LEFT PLAN ON: [col("a"), col("b")]'2225assert 'col("c")) == ("a")' in extract[1]2226assert 'col("b")) <= (2)' in extract[1]22272228assert extract[2] == 'RIGHT PLAN ON: [col("a"), col("b")]'2229assert 'col("b")) <= (2)' in extract[3]2230assert 'col("c")) == ("A")' in extract[3]22312232assert_frame_equal(q.collect(), expect)2233assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)22342235# Filter applied to the non-coalesced `_right` column of an inner-join is2236# also pushed to the left2237# input table.2238q = lhs.join(2239rhs, on=["a", "b"], how="inner", coalesce=False, maintain_order="left_right"2240).filter(pl.col("a_right") <= 2)22412242expect = pl.DataFrame(2243{2244"a": [1, 2],2245"b": [1, 2],2246"c": ["a", "b"],2247"a_right": [1, 2],2248"b_right": [1, 2],2249"c_right": ["A", "B"],2250}2251)22522253plan = q.explain()22542255extract = _extract_plan_joins_and_filters(plan)2256assert extract == [2257'LEFT PLAN ON: [col("a"), col("b")]',2258'FILTER [(col("a")) <= (2)]',2259'RIGHT PLAN ON: [col("a"), col("b")]',2260'FILTER [(col("a")) <= (2)]',2261]22622263assert_frame_equal(q.collect(), expect)2264assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)22652266# Different names in left_on and right_on2267q = lhs.join(2268rhs, left_on="a", right_on="b", how="inner", maintain_order="left_right"2269).filter(pl.col("a") <= 2)22702271expect = pl.DataFrame(2272{2273"a": [1, 2],2274"b": [1, 2],2275"c": ["a", "b"],2276"a_right": [1, 2],2277"c_right": ["A", "B"],2278}2279)22802281plan = q.explain()22822283extract = _extract_plan_joins_and_filters(plan)2284assert extract == [2285'LEFT PLAN ON: [col("a")]',2286'FILTER [(col("a")) <= (2)]',2287'RIGHT PLAN ON: [col("b")]',2288'FILTER [(col("b")) <= (2)]',2289]22902291assert_frame_equal(q.collect(), expect)2292assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)22932294# Different names in left_on and right_on, coalesce=False2295q = lhs.join(2296rhs,2297left_on="a",2298right_on="b",2299how="inner",2300coalesce=False,2301maintain_order="left_right",2302).filter(pl.col("a") <= 2)23032304expect = pl.DataFrame(2305{2306"a": [1, 2],2307"b": [1, 2],2308"c": ["a", "b"],2309"a_right": [1, 2],2310"b_right": [1, 2],2311"c_right": ["A", "B"],2312}2313)23142315plan = q.explain()23162317extract = _extract_plan_joins_and_filters(plan)2318assert extract == [2319'LEFT PLAN ON: [col("a")]',2320'FILTER [(col("a")) <= (2)]',2321'RIGHT PLAN ON: [col("b")]',2322'FILTER [(col("b")) <= (2)]',2323]23242325assert_frame_equal(q.collect(), expect)2326assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)23272328# left_on=col(A), right_on=lit(1). Filters referencing col(A) can only push2329# to the left side.2330q = lhs.join(2331rhs,2332left_on=["a", pl.lit(1)],2333right_on=[pl.lit(1), "b"],2334how="inner",2335coalesce=False,2336maintain_order="left_right",2337).filter(2338pl.col("a") == 1,2339pl.col("b") >= 1,2340pl.col("a_right") <= 1,2341pl.col("b_right") >= 0,2342)23432344expect = pl.DataFrame(2345{2346"a": [1],2347"b": [1],2348"c": ["a"],2349"a_right": [1],2350"b_right": [1],2351"c_right": ["A"],2352}2353)23542355plan = q.explain()23562357extract = _extract_plan_joins_and_filters(plan)23582359assert (2360extract[0]2361== 'LEFT PLAN ON: [col("a").cast(Int64), col("_POLARS_0").cast(Int64)]'2362)2363assert '(col("a")) == (1)' in extract[1]2364assert '(col("b")) >= (1)' in extract[1]2365assert (2366extract[2]2367== 'RIGHT PLAN ON: [col("_POLARS_1").cast(Int64), col("b").cast(Int64)]'2368)2369assert '(col("b")) >= (0)' in extract[3]2370assert 'col("a")) <= (1)' in extract[3]23712372assert_frame_equal(q.collect(), expect)2373assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)23742375# Filters don't pass if they refer to columns from both tables2376# TODO: In the optimizer we can add additional equalities into the join2377# condition itself for some cases.2378q = lhs.join(rhs, on=["a"], how="inner", maintain_order="left_right").filter(2379pl.col("b") == pl.col("b_right")2380)23812382expect = pl.DataFrame(2383{2384"a": [1, 2, 3],2385"b": [1, 2, 3],2386"c": ["a", "b", "c"],2387"b_right": [1, 2, 3],2388"c_right": ["A", "B", "C"],2389}2390)23912392plan = q.explain()23932394extract = _extract_plan_joins_and_filters(plan)2395assert extract == [2396'FILTER [(col("b")) == (col("b_right"))]',2397'LEFT PLAN ON: [col("a")]',2398'RIGHT PLAN ON: [col("a")]',2399]24002401assert_frame_equal(q.collect(), expect)2402assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)24032404# Duplicate filter removal - https://github.com/pola-rs/polars/issues/232432405q = (2406pl.LazyFrame({"x": [1, 2, 3]})2407.join(pl.LazyFrame({"x": [1, 2, 3]}), on="x", how="inner", coalesce=False)2408.filter(2409pl.col("x") == 2,2410pl.col("x_right") == 2,2411)2412)24132414expect = pl.DataFrame(2415[2416pl.Series("x", [2], dtype=pl.Int64),2417pl.Series("x_right", [2], dtype=pl.Int64),2418]2419)24202421plan = q.explain()24222423extract = _extract_plan_joins_and_filters(plan)24242425assert extract == [2426'LEFT PLAN ON: [col("x")]',2427'FILTER [(col("x")) == (2)]',2428'RIGHT PLAN ON: [col("x")]',2429'FILTER [(col("x")) == (2)]',2430]24312432assert_frame_equal(q.collect(), expect)2433assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)243424352436def test_join_filter_pushdown_left_join() -> None:2437lhs = pl.LazyFrame(2438{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}2439)2440rhs = pl.LazyFrame(2441{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}2442)24432444# Filter on key output column is pushed to both sides.2445q = lhs.join(rhs, on=["a", "b"], how="left", maintain_order="left_right").filter(2446pl.col("b") <= 22447)24482449expect = pl.DataFrame(2450{"a": [1, 2], "b": [1, 2], "c": ["a", "b"], "c_right": ["A", "B"]}2451)24522453plan = q.explain()24542455extract = _extract_plan_joins_and_filters(plan)2456assert extract == [2457'LEFT PLAN ON: [col("a"), col("b")]',2458'FILTER [(col("b")) <= (2)]',2459'RIGHT PLAN ON: [col("a"), col("b")]',2460'FILTER [(col("b")) <= (2)]',2461]24622463assert_frame_equal(q.collect(), expect)2464assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)24652466# Filter on key output column is pushed to both sides.2467# This tests joins on differing left/right names.2468q = lhs.join(2469rhs, left_on="a", right_on="b", how="left", maintain_order="left_right"2470).filter(pl.col("a") <= 2)24712472expect = pl.DataFrame(2473{2474"a": [1, 2],2475"b": [1, 2],2476"c": ["a", "b"],2477"a_right": [1, 2],2478"c_right": ["A", "B"],2479}2480)24812482plan = q.explain()24832484extract = _extract_plan_joins_and_filters(plan)2485assert extract == [2486'LEFT PLAN ON: [col("a")]',2487'FILTER [(col("a")) <= (2)]',2488'RIGHT PLAN ON: [col("b")]',2489'FILTER [(col("b")) <= (2)]',2490]24912492assert_frame_equal(q.collect(), expect)2493assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)24942495# Filters referring to columns that exist only in the left table can be pushed.2496q = lhs.join(rhs, on=["a", "b"], how="left", maintain_order="left_right").filter(2497pl.col("c") == "b"2498)24992500expect = pl.DataFrame({"a": [2], "b": [2], "c": ["b"], "c_right": ["B"]})25012502plan = q.explain()25032504extract = _extract_plan_joins_and_filters(plan)2505assert extract == [2506'LEFT PLAN ON: [col("a"), col("b")]',2507'FILTER [(col("c")) == ("b")]',2508'RIGHT PLAN ON: [col("a"), col("b")]',2509]25102511assert_frame_equal(q.collect(), expect)2512assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)25132514# Filters referring to columns that exist only in the right table cannot be2515# pushed for left-join2516q = lhs.join(rhs, on=["a", "b"], how="left", maintain_order="left_right").filter(2517# Note: `eq_missing` to block join downgrade.2518pl.col("c_right").eq_missing("B")2519)25202521expect = pl.DataFrame({"a": [2], "b": [2], "c": ["b"], "c_right": ["B"]})25222523plan = q.explain()25242525extract = _extract_plan_joins_and_filters(plan)2526assert extract == [2527'FILTER [(col("c_right")) ==v ("B")]',2528'LEFT PLAN ON: [col("a"), col("b")]',2529'RIGHT PLAN ON: [col("a"), col("b")]',2530]25312532assert_frame_equal(q.collect(), expect)2533assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)25342535# Filters referring to a non-coalesced key column originating from the right2536# table cannot be pushed.2537#2538# Note, technically it's possible to push these filters if we can guarantee that2539# they do not remove NULLs (or otherwise if we also apply the filter on the2540# result table). But this is not something we do at the moment.2541q = lhs.join(2542rhs, on=["a", "b"], how="left", coalesce=False, maintain_order="left_right"2543).filter(pl.col("b_right").eq_missing(2))25442545expect = pl.DataFrame(2546{2547"a": [2],2548"b": [2],2549"c": ["b"],2550"a_right": [2],2551"b_right": [2],2552"c_right": ["B"],2553}2554)25552556plan = q.explain()25572558extract = _extract_plan_joins_and_filters(plan)2559assert extract == [2560'FILTER [(col("b_right")) ==v (2)]',2561'LEFT PLAN ON: [col("a"), col("b")]',2562'RIGHT PLAN ON: [col("a"), col("b")]',2563]25642565assert_frame_equal(q.collect(), expect)2566assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)256725682569def test_join_filter_pushdown_right_join() -> None:2570lhs = pl.LazyFrame(2571{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}2572)2573rhs = pl.LazyFrame(2574{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}2575)25762577# Filter on key output column is pushed to both sides.2578q = lhs.join(rhs, on=["a", "b"], how="right", maintain_order="left_right").filter(2579pl.col("b") <= 22580)25812582expect = pl.DataFrame(2583{"c": ["a", "b"], "a": [1, 2], "b": [1, 2], "c_right": ["A", "B"]}2584)25852586plan = q.explain()25872588extract = _extract_plan_joins_and_filters(plan)2589assert extract == [2590'LEFT PLAN ON: [col("a"), col("b")]',2591'FILTER [(col("b")) <= (2)]',2592'RIGHT PLAN ON: [col("a"), col("b")]',2593'FILTER [(col("b")) <= (2)]',2594]25952596assert_frame_equal(q.collect(), expect)2597assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)25982599# Filter on key output column is pushed to both sides.2600# This tests joins on differing left/right names.2601# col(A) is coalesced into col(B) (from right), but col(B) is named as2602# col(B_right) in the output because the LHS table also has a col(B).2603q = lhs.join(2604rhs, left_on="a", right_on="b", how="right", maintain_order="left_right"2605).filter(pl.col("b_right") <= 2)26062607expect = pl.DataFrame(2608{2609"b": [1, 2],2610"c": ["a", "b"],2611"a": [1, 2],2612"b_right": [1, 2],2613"c_right": ["A", "B"],2614}2615)26162617plan = q.explain()26182619extract = _extract_plan_joins_and_filters(plan)2620assert extract == [2621'LEFT PLAN ON: [col("a")]',2622'FILTER [(col("a")) <= (2)]',2623'RIGHT PLAN ON: [col("b")]',2624'FILTER [(col("b")) <= (2)]',2625]26262627assert_frame_equal(q.collect(), expect)2628assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)26292630# Filters referring to columns that exist only in the right table can be pushed.2631q = lhs.join(rhs, on=["a", "b"], how="right", maintain_order="left_right").filter(2632pl.col("c_right") == "B"2633)26342635expect = pl.DataFrame({"c": ["b"], "a": [2], "b": [2], "c_right": ["B"]})26362637plan = q.explain()26382639extract = _extract_plan_joins_and_filters(plan)2640assert extract == [2641'LEFT PLAN ON: [col("a"), col("b")]',2642'RIGHT PLAN ON: [col("a"), col("b")]',2643'FILTER [(col("c")) == ("B")]',2644]26452646assert_frame_equal(q.collect(), expect)2647assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)26482649# Filters referring to columns that exist only in the left table cannot be2650# pushed for right-join2651q = lhs.join(rhs, on=["a", "b"], how="right", maintain_order="left_right").filter(2652# Note: eq_missing to block join downgrade2653pl.col("c").eq_missing("b")2654)26552656expect = pl.DataFrame({"c": ["b"], "a": [2], "b": [2], "c_right": ["B"]})26572658plan = q.explain()26592660extract = _extract_plan_joins_and_filters(plan)2661assert extract == [2662'FILTER [(col("c")) ==v ("b")]',2663'LEFT PLAN ON: [col("a"), col("b")]',2664'RIGHT PLAN ON: [col("a"), col("b")]',2665]26662667assert_frame_equal(q.collect(), expect)2668assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)26692670# Filters referring to a non-coalesced key column originating from the left2671# table cannot be pushed for right-join.2672q = lhs.join(2673rhs, on=["a", "b"], how="right", coalesce=False, maintain_order="left_right"2674).filter(pl.col("b").eq_missing(2))26752676expect = pl.DataFrame(2677{2678"a": [2],2679"b": [2],2680"c": ["b"],2681"a_right": [2],2682"b_right": [2],2683"c_right": ["B"],2684}2685)26862687plan = q.explain()26882689extract = _extract_plan_joins_and_filters(plan)2690assert extract == [2691'FILTER [(col("b")) ==v (2)]',2692'LEFT PLAN ON: [col("a"), col("b")]',2693'RIGHT PLAN ON: [col("a"), col("b")]',2694]26952696assert_frame_equal(q.collect(), expect)2697assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)269826992700def test_join_filter_pushdown_full_join() -> None:2701lhs = pl.LazyFrame(2702{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}2703)2704rhs = pl.LazyFrame(2705{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}2706)27072708# Full join can only push filters that refer to coalesced key columns.2709q = lhs.join(2710rhs,2711left_on="a",2712right_on="b",2713how="full",2714coalesce=True,2715maintain_order="left_right",2716).filter(pl.col("a") == 2)27172718expect = pl.DataFrame(2719{2720"a": [2],2721"b": [2],2722"c": ["b"],2723"a_right": [2],2724"c_right": ["B"],2725}2726)27272728plan = q.explain()2729extract = _extract_plan_joins_and_filters(plan)27302731assert extract == [2732'LEFT PLAN ON: [col("a")]',2733'FILTER [(col("a")) == (2)]',2734'RIGHT PLAN ON: [col("b")]',2735'FILTER [(col("b")) == (2)]',2736]27372738assert_frame_equal(q.collect(), expect)2739assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)27402741# Non-coalescing full-join cannot push any filters2742# Note: We add fill_null to bypass non-NULL filter mask detection.2743q = lhs.join(2744rhs,2745left_on="a",2746right_on="b",2747how="full",2748coalesce=False,2749maintain_order="left_right",2750).filter(2751pl.col("a").fill_null(0) >= 2,2752pl.col("a").fill_null(0) <= 2,2753)27542755expect = pl.DataFrame(2756{2757"a": [2],2758"b": [2],2759"c": ["b"],2760"a_right": [2],2761"b_right": [2],2762"c_right": ["B"],2763}2764)27652766plan = q.explain()2767extract = _extract_plan_joins_and_filters(plan)27682769assert extract[0].startswith("FILTER ")2770assert extract[1:] == [2771'LEFT PLAN ON: [col("a")]',2772'RIGHT PLAN ON: [col("b")]',2773]27742775assert_frame_equal(q.collect(), expect)2776assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)277727782779def test_join_filter_pushdown_semi_join() -> None:2780lhs = pl.LazyFrame(2781{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}2782)2783rhs = pl.LazyFrame(2784{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}2785)27862787q = lhs.join(2788rhs,2789left_on=["a", "b"],2790right_on=["b", pl.lit(2)],2791how="semi",2792maintain_order="left_right",2793).filter(pl.col("a") == 2, pl.col("b") == 2, pl.col("c") == "b")27942795expect = pl.DataFrame(2796{2797"a": [2],2798"b": [2],2799"c": ["b"],2800}2801)28022803plan = q.explain()2804extract = _extract_plan_joins_and_filters(plan)28052806# * filter on col(a) is pushed to both sides (renamed to col(b) in the right side)2807# * filter on col(b) is pushed only to left, as the right join key is a literal2808# * filter on col(c) is pushed only to left, as the column does not exist in2809# the right.28102811assert extract[0] == 'LEFT PLAN ON: [col("a"), col("b").cast(Int64)]'2812assert 'col("a")) == (2)' in extract[1]2813assert 'col("b")) == (2)' in extract[1]2814assert 'col("c")) == ("b")' in extract[1]28152816assert extract[2:] == [2817'RIGHT PLAN ON: [col("b"), col("_POLARS_0").cast(Int64)]',2818'FILTER [(col("b")) == (2)]',2819]28202821assert_frame_equal(q.collect(), expect)2822assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)282328242825def test_join_filter_pushdown_anti_join() -> None:2826lhs = pl.LazyFrame(2827{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}2828)2829rhs = pl.LazyFrame(2830{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}2831)28322833q = lhs.join(2834rhs,2835left_on=["a", "b"],2836right_on=["b", pl.lit(1)],2837how="anti",2838maintain_order="left_right",2839).filter(pl.col("a") == 2, pl.col("b") == 2, pl.col("c") == "b")28402841expect = pl.DataFrame(2842{2843"a": [2],2844"b": [2],2845"c": ["b"],2846}2847)28482849plan = q.explain()2850extract = _extract_plan_joins_and_filters(plan)28512852assert extract[0] == 'LEFT PLAN ON: [col("a"), col("b").cast(Int64)]'2853assert 'col("a")) == (2)' in extract[1]2854assert 'col("b")) == (2)' in extract[1]2855assert 'col("c")) == ("b")' in extract[1]28562857assert extract[2:] == [2858'RIGHT PLAN ON: [col("b"), col("_POLARS_0").cast(Int64)]',2859'FILTER [(col("b")) == (2)]',2860]28612862assert_frame_equal(q.collect(), expect)2863assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)286428652866def test_join_filter_pushdown_cross_join() -> None:2867lhs = pl.LazyFrame(2868{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}2869)2870rhs = pl.LazyFrame(2871{"a": [0, 0, 0, 0, 0], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}2872)28732874# Nested loop join for `!=`2875q = (2876lhs.with_row_index()2877.join(rhs, how="cross")2878.filter(2879pl.col("a") <= 4, pl.col("c_right") <= "B", pl.col("a") != pl.col("a_right")2880)2881.sort("index")2882)28832884expect = pl.DataFrame(2885[2886pl.Series("index", [0, 0, 1, 1, 2, 2, 3, 3], dtype=pl.get_index_type()),2887pl.Series("a", [1, 1, 2, 2, 3, 3, 4, 4], dtype=pl.Int64),2888pl.Series("b", [1, 1, 2, 2, 3, 3, 4, 4], dtype=pl.Int64),2889pl.Series("c", ["a", "a", "b", "b", "c", "c", "d", "d"], dtype=pl.String),2890pl.Series("a_right", [0, 0, 0, 0, 0, 0, 0, 0], dtype=pl.Int64),2891pl.Series("b_right", [1, 2, 1, 2, 1, 2, 1, 2], dtype=pl.Int64),2892pl.Series(2893"c_right", ["A", "B", "A", "B", "A", "B", "A", "B"], dtype=pl.String2894),2895]2896)28972898plan = q.explain()28992900assert 'NESTED LOOP JOIN ON [(col("a")) != (col("a_right"))]' in plan29012902extract = _extract_plan_joins_and_filters(plan)29032904assert extract == [2905"LEFT PLAN:",2906'FILTER [(col("a")) <= (4)]',2907"RIGHT PLAN:",2908'FILTER [(col("c")) <= ("B")]',2909]29102911assert_frame_equal(q.collect(), expect)2912assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)29132914# Conversion to inner-join for `==`2915q = lhs.join(rhs, how="cross", maintain_order="left_right").filter(2916pl.col("a") <= 4,2917pl.col("c_right") <= "B",2918pl.col("a") == (pl.col("a_right") + 1),2919)29202921expect = pl.DataFrame(2922{2923"a": [1, 1],2924"b": [1, 1],2925"c": ["a", "a"],2926"a_right": [0, 0],2927"b_right": [1, 2],2928"c_right": ["A", "B"],2929}2930)29312932plan = q.explain()29332934extract = _extract_plan_joins_and_filters(plan)29352936assert extract == [2937'LEFT PLAN ON: [col("a")]',2938'FILTER [(col("a")) <= (4)]',2939'RIGHT PLAN ON: [[(col("a")) + (1)]]',2940'FILTER [(col("c")) <= ("B")]',2941]29422943assert_frame_equal(q.collect(), expect)2944assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)29452946# Avoid conversion for order maintaining cross-join2947q = (2948pl.LazyFrame(2949[2950pl.Series("a", [2, 4, 8, 9, 11], dtype=pl.Int64),2951pl.Series("b", [1, 2, 3, 4, 5], dtype=pl.Int64),2952]2953)2954.join(2955pl.LazyFrame(2956{2957"c": [0, 1, 2, 3, 4],2958}2959),2960how="cross",2961maintain_order="left_right",2962)2963.filter(pl.col("c") <= pl.col("b"))2964)29652966expect = pl.DataFrame(2967[2968pl.Series(2969"a",2970[2, 2, 4, 4, 4, 8, 8, 8, 8, 9, 9, 9, 9, 9, 11, 11, 11, 11, 11],2971dtype=pl.Int64,2972),2973pl.Series(2974"b",2975[1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5],2976dtype=pl.Int64,2977),2978pl.Series(2979"c",2980[0, 1, 0, 1, 2, 0, 1, 2, 3, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4],2981dtype=pl.Int64,2982),2983]2984)29852986plan = q.explain()29872988assert plan.startswith('NESTED LOOP JOIN ON [(col("c")) <= (col("b"))]:')29892990assert_frame_equal(q.collect(), expect)2991assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)299229932994def test_join_filter_pushdown_iejoin() -> None:2995lhs = pl.LazyFrame(2996{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}2997)2998rhs = pl.LazyFrame(2999{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}3000)30013002q = (3003lhs.with_row_index()3004.join_where(3005rhs,3006pl.col("a") >= 1,3007pl.col("a") == pl.col("a_right"),3008pl.col("c_right") <= "B",3009)3010.sort("index")3011)30123013expect = pl.DataFrame(3014{3015"a": [1, 2],3016"b": [1, 2],3017"c": ["a", "b"],3018"a_right": [1, 2],3019"b_right": [1, 2],3020"c_right": ["A", "B"],3021}3022).with_row_index()30233024plan = q.explain()30253026assert "INNER JOIN" in plan30273028extract = _extract_plan_joins_and_filters(plan)30293030assert extract[:3] == [3031'LEFT PLAN ON: [col("a")]',3032'FILTER [(col("a")) >= (1)]',3033'RIGHT PLAN ON: [col("a")]',3034]30353036assert extract[3].startswith("FILTER")3037assert 'col("c")) <= ("B")' in extract[3]3038assert '(col("a")) >= (1)' in extract[3]3039assert len(extract) == 430403041assert_frame_equal(q.collect(), expect)3042assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)30433044q = (3045lhs.with_row_index()3046.join_where(3047rhs,3048pl.col("a") >= 1,3049pl.col("a") >= pl.col("a_right"),3050pl.col("c_right") <= "B",3051)3052.sort("index")3053)30543055expect = pl.DataFrame(3056[3057pl.Series("index", [0, 1, 1, 2, 2, 3, 3, 4, 4], dtype=pl.get_index_type()),3058pl.Series("a", [1, 2, 2, 3, 3, 4, 4, 5, 5], dtype=pl.Int64),3059pl.Series("b", [1, 2, 2, 3, 3, 4, 4, None, None], dtype=pl.Int64),3060pl.Series(3061"c", ["a", "b", "b", "c", "c", "d", "d", "e", "e"], dtype=pl.String3062),3063pl.Series("a_right", [1, 1, 2, 1, 2, 1, 2, 1, 2], dtype=pl.Int64),3064pl.Series("b_right", [1, 1, 2, 1, 2, 1, 2, 1, 2], dtype=pl.Int64),3065pl.Series(3066"c_right",3067["A", "A", "B", "A", "B", "A", "B", "A", "B"],3068dtype=pl.String,3069),3070]3071)30723073plan = q.explain()30743075assert "IEJOIN" in plan30763077extract = _extract_plan_joins_and_filters(plan)30783079assert extract == [3080'LEFT PLAN ON: [col("a")]',3081'FILTER [(col("a")) >= (1)]',3082'RIGHT PLAN ON: [col("a")]',3083'FILTER [(col("c")) <= ("B")]',3084]30853086assert_frame_equal(q.collect().sort(pl.all()), expect)3087assert_frame_equal(3088q.collect(optimizations=pl.QueryOptFlags.none()).sort(pl.all()),3089expect,3090)30913092q = pl.LazyFrame({"x": [1, 2, 3]}).join_where(3093pl.LazyFrame({"x": [1, 2, 3]}),3094pl.col("x") > pl.col("x_right"),3095pl.col("x") > 1,3096)30973098expect = pl.DataFrame(3099[3100pl.Series("x", [2, 3, 3], dtype=pl.Int64),3101pl.Series("x_right", [1, 1, 2], dtype=pl.Int64),3102]3103)31043105plan = q.explain()31063107assert "IEJOIN" in plan31083109extract = _extract_plan_joins_and_filters(plan)31103111assert extract == [3112'LEFT PLAN ON: [col("x")]',3113'FILTER [(col("x")) > (1)]',3114'RIGHT PLAN ON: [col("x")]',3115]31163117assert_frame_equal(q.collect().sort(pl.all()), expect)3118assert_frame_equal(3119q.collect(optimizations=pl.QueryOptFlags.none()).sort(pl.all()),3120expect,3121)31223123# Join filter pushdown inside CSE - https://github.com/pola-rs/polars/issues/2348931243125lf_x = pl.LazyFrame({"a": [1, 2, 3], "b": [1, 2, 3]})3126lf_y = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 3, 3]})31273128lf_xy = lf_x.join_where(lf_y, pl.col("a") > pl.col("a_right"))31293130q = pl.concat([lf_xy, lf_xy]).filter(3131pl.col("b") < pl.col("b_right"), pl.col("a") > 03132)31333134expect = pl.DataFrame(3135[3136pl.Series("a", [2, 2], dtype=pl.Int64),3137pl.Series("b", [2, 2], dtype=pl.Int64),3138pl.Series("a_right", [1, 1], dtype=pl.Int64),3139pl.Series("b_right", [3, 3], dtype=pl.Int64),3140]3141)31423143plan = q.explain()31443145assert "IEJOIN" in plan31463147extract = _extract_plan_joins_and_filters(plan)31483149assert extract[0] in {3150'LEFT PLAN ON: [col("a"), col("b")]',3151'LEFT PLAN ON: [col("b"), col("a")]',3152}3153assert extract[1] == 'FILTER [(col("a")) > (0)]'3154assert extract[2] in {3155'RIGHT PLAN ON: [col("a"), col("b")]',3156'RIGHT PLAN ON: [col("b"), col("a")]',3157}3158assert extract[3] in {3159'LEFT PLAN ON: [col("a"), col("b")]',3160'LEFT PLAN ON: [col("b"), col("a")]',3161}3162assert extract[4] == 'FILTER [(col("a")) > (0)]'3163assert extract[5] in {3164'RIGHT PLAN ON: [col("a"), col("b")]',3165'RIGHT PLAN ON: [col("b"), col("a")]',3166}3167assert len(extract) == 631683169assert_frame_equal(q.collect().sort(pl.all()), expect)3170assert_frame_equal(3171q.collect(optimizations=pl.QueryOptFlags.none()).sort(pl.all()),3172expect,3173)317431753176def test_join_filter_pushdown_asof_join() -> None:3177lhs = pl.LazyFrame(3178{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}3179)3180rhs = pl.LazyFrame(3181{3182"a": [1, 2, 3, 4, 5],3183"b": [1, 2, 3, None, None],3184"c": ["A", "B", "C", "D", "E"],3185}3186)31873188q = lhs.join_asof(3189rhs,3190left_on=pl.col("a").set_sorted(),3191right_on=pl.col("b").set_sorted(),3192tolerance=0,3193).filter(3194pl.col("a") >= 2,3195pl.col("b") >= 3,3196pl.col("c") >= "A",3197pl.col("c_right") >= "B",3198)31993200expect = pl.DataFrame(3201{3202"a": [3],3203"b": [3],3204"c": ["c"],3205"a_right": [3],3206"b_right": [3],3207"c_right": ["C"],3208}3209)32103211plan = q.explain()3212extract = _extract_plan_joins_and_filters(plan)32133214assert extract[:2] == [3215'FILTER [(col("c_right")) >= ("B")]',3216'LEFT PLAN ON: [col("a").set_sorted()]',3217]32183219assert 'col("b")) >= (3)' in extract[2]3220assert 'col("c")) >= ("A")' in extract[2]3221assert 'col("a")) >= (2)' in extract[2]32223223assert extract[3:] == ['RIGHT PLAN ON: [col("b").set_sorted()]']32243225assert_frame_equal(q.collect(), expect)3226assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)32273228# With "by" columns3229q = lhs.join_asof(3230rhs,3231left_on="a",3232right_on="b",3233tolerance=99,3234by_left="b",3235by_right="a",3236).filter(3237pl.col("a") >= 2,3238pl.col("b") >= 3,3239pl.col("c") >= "A",3240pl.col("c_right") >= "B",3241)32423243expect = pl.DataFrame(3244{3245"a": [3],3246"b": [3],3247"c": ["c"],3248"b_right": [3],3249"c_right": ["C"],3250}3251)32523253plan = q.explain()3254extract = _extract_plan_joins_and_filters(plan)32553256assert extract[:2] == [3257'FILTER [(col("c_right")) >= ("B")]',3258'LEFT PLAN ON: [col("a")]',3259]3260assert 'col("a")) >= (2)' in extract[2]3261assert 'col("b")) >= (3)' in extract[2]32623263assert extract[3:] == [3264'RIGHT PLAN ON: [col("b")]',3265'FILTER [(col("a")) >= (3)]',3266]32673268assert_frame_equal(q.collect(), expect)3269assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)327032713272def test_join_filter_pushdown_full_join_rewrite() -> None:3273lhs = pl.LazyFrame(3274{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}3275)3276rhs = pl.LazyFrame(3277{3278"a": [1, 2, 3, 4, None],3279"b": [1, 2, 3, None, 5],3280"c": ["A", "B", "C", "D", "E"],3281}3282)32833284# Downgrades to left-join3285q = lhs.join(rhs, on=["a", "b"], how="full", maintain_order="left_right").filter(3286pl.col("b") >= 33287)32883289expect = pl.DataFrame(3290[3291pl.Series("a", [3, 4], dtype=pl.Int64),3292pl.Series("b", [3, 4], dtype=pl.Int64),3293pl.Series("c", ["c", "d"], dtype=pl.String),3294pl.Series("a_right", [3, None], dtype=pl.Int64),3295pl.Series("b_right", [3, None], dtype=pl.Int64),3296pl.Series("c_right", ["C", None], dtype=pl.String),3297]3298)32993300plan = q.explain()33013302assert "FULL JOIN" not in plan3303assert plan.startswith("LEFT JOIN")33043305extract = _extract_plan_joins_and_filters(plan)33063307assert extract == [3308'LEFT PLAN ON: [col("a"), col("b")]',3309'FILTER [(col("b")) >= (3)]',3310'RIGHT PLAN ON: [col("a"), col("b")]',3311'FILTER [(col("b")) >= (3)]',3312]33133314assert_frame_equal(q.collect(), expect)3315assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)33163317# Downgrades to right-join3318q = lhs.join(3319rhs, left_on="a", right_on="b", how="full", maintain_order="left_right"3320).filter(pl.col("b_right") >= 3)33213322expect = pl.DataFrame(3323[3324pl.Series("a", [3, 5], dtype=pl.Int64),3325pl.Series("b", [3, None], dtype=pl.Int64),3326pl.Series("c", ["c", "e"], dtype=pl.String),3327pl.Series("a_right", [3, None], dtype=pl.Int64),3328pl.Series("b_right", [3, 5], dtype=pl.Int64),3329pl.Series("c_right", ["C", "E"], dtype=pl.String),3330]3331)33323333plan = q.explain()33343335assert "FULL JOIN" not in plan3336assert "RIGHT JOIN" in plan33373338extract = _extract_plan_joins_and_filters(plan)33393340assert extract == [3341'LEFT PLAN ON: [col("a")]',3342'FILTER [(col("a")) >= (3)]',3343'RIGHT PLAN ON: [col("b")]',3344'FILTER [(col("b")) >= (3)]',3345]33463347assert_frame_equal(q.collect(), expect)3348assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)33493350# Downgrades to right-join3351q = lhs.join(3352rhs,3353left_on="a",3354right_on="b",3355how="full",3356coalesce=True,3357maintain_order="left_right",3358).filter(3359(pl.col("a") >= 1) | pl.col("a").is_null(), # col(a) from LHS3360pl.col("a_right") >= 3, # col(a) from RHS3361(pl.col("b") >= 2) | pl.col("b").is_null(), # col(b) from LHS3362pl.col("c_right") >= "C", # col(c) from RHS3363)33643365expect = pl.DataFrame(3366[3367pl.Series("a", [3, None], dtype=pl.Int64),3368pl.Series("b", [3, None], dtype=pl.Int64),3369pl.Series("c", ["c", None], dtype=pl.String),3370pl.Series("a_right", [3, 4], dtype=pl.Int64),3371pl.Series("c_right", ["C", "D"], dtype=pl.String),3372]3373)33743375plan = q.explain()33763377assert "FULL JOIN" not in plan3378assert "RIGHT JOIN" in plan33793380extract = _extract_plan_joins_and_filters(plan)33813382assert [3383'FILTER [([(col("b")) >= (2)]) | (col("b").is_null())]',3384'LEFT PLAN ON: [col("a")]',3385'FILTER [([(col("a")) >= (1)]) | (col("a").is_null())]',3386'RIGHT PLAN ON: [col("b")]',3387]33883389assert 'col("a")) >= (3)' in extract[4]3390assert '(col("b")) >= (1)]) | (col("b").alias("a").is_null())' in extract[4]3391assert 'col("c")) >= ("C")' in extract[4]33923393assert len(extract) == 533943395assert_frame_equal(q.collect(), expect)3396assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)33973398# Downgrades to inner-join3399q = lhs.join(rhs, on=["a", "b"], how="full", maintain_order="left_right").filter(3400pl.col("b").is_not_null(), pl.col("b_right").is_not_null()3401)34023403expect = pl.DataFrame(3404[3405pl.Series("a", [1, 2, 3], dtype=pl.Int64),3406pl.Series("b", [1, 2, 3], dtype=pl.Int64),3407pl.Series("c", ["a", "b", "c"], dtype=pl.String),3408pl.Series("a_right", [1, 2, 3], dtype=pl.Int64),3409pl.Series("b_right", [1, 2, 3], dtype=pl.Int64),3410pl.Series("c_right", ["A", "B", "C"], dtype=pl.String),3411]3412)34133414plan = q.explain()34153416assert "FULL JOIN" not in plan3417assert plan.startswith("INNER JOIN")34183419extract = _extract_plan_joins_and_filters(plan)34203421assert extract[0] == 'LEFT PLAN ON: [col("a"), col("b")]'3422assert 'col("b").is_not_null()' in extract[1]3423assert 'col("b").alias("b_right").is_not_null()' in extract[1]34243425assert extract[2] == 'RIGHT PLAN ON: [col("a"), col("b")]'3426assert 'col("b").is_not_null()' in extract[3]3427assert 'col("b").alias("b_right").is_not_null()' in extract[3]34283429assert len(extract) == 434303431assert_frame_equal(q.collect(), expect)3432assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)34333434# Does not downgrade because col(b) is a coalesced key-column, but the filter3435# is still pushed to both sides.3436q = lhs.join(3437rhs, on=["a", "b"], how="full", coalesce=True, maintain_order="left_right"3438).filter(pl.col("b") >= 3)34393440expect = pl.DataFrame(3441[3442pl.Series("a", [3, 4, None], dtype=pl.Int64),3443pl.Series("b", [3, 4, 5], dtype=pl.Int64),3444pl.Series("c", ["c", "d", None], dtype=pl.String),3445pl.Series("c_right", ["C", None, "E"], dtype=pl.String),3446]3447)34483449plan = q.explain()3450assert plan.startswith("FULL JOIN")34513452extract = _extract_plan_joins_and_filters(plan)34533454assert extract == [3455'LEFT PLAN ON: [col("a"), col("b")]',3456'FILTER [(col("b")) >= (3)]',3457'RIGHT PLAN ON: [col("a"), col("b")]',3458'FILTER [(col("b")) >= (3)]',3459]34603461assert_frame_equal(q.collect(), expect)3462assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)346334643465def test_join_filter_pushdown_right_join_rewrite() -> None:3466lhs = pl.LazyFrame(3467{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}3468)3469rhs = pl.LazyFrame(3470{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}3471)34723473# Downgrades to inner-join3474q = lhs.join(3475rhs,3476left_on="a",3477right_on="b",3478how="right",3479coalesce=True,3480maintain_order="left_right",3481).filter(3482pl.col("a") <= 7, # col(a) from RHS (LHS col(a) is coalesced into col(b_right))3483pl.col("b_right") <= 10, # Key-column filter3484pl.col("c") <= "b", # col(c) from LHS3485)34863487expect = pl.DataFrame(3488[3489pl.Series("b", [1, 2], dtype=pl.Int64),3490pl.Series("c", ["a", "b"], dtype=pl.String),3491pl.Series("a", [1, 2], dtype=pl.Int64),3492pl.Series("b_right", [1, 2], dtype=pl.Int64),3493pl.Series("c_right", ["A", "B"], dtype=pl.String),3494]3495)34963497plan = q.explain()34983499assert "RIGHT JOIN" not in plan3500assert "INNER JOIN" in plan35013502extract = _extract_plan_joins_and_filters(plan)35033504assert extract[0] == 'LEFT PLAN ON: [col("a")]'3505assert 'col("a")) <= (10)' in extract[1]3506assert 'col("c")) <= ("b")' in extract[1]35073508assert extract[2] == 'RIGHT PLAN ON: [col("b")]'3509assert 'col("a")) <= (7)' in extract[3]3510assert 'col("b")) <= (10)' in extract[3]35113512assert len(extract) == 435133514assert_frame_equal(q.collect(), expect)3515assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)351635173518def test_join_filter_pushdown_join_rewrite_equality_above_and() -> None:3519lhs = pl.LazyFrame(3520{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}3521)3522rhs = pl.LazyFrame(3523{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}3524)35253526q = lhs.join(3527rhs,3528left_on="a",3529right_on="b",3530how="full",3531coalesce=False,3532maintain_order="left_right",3533).filter(((pl.col("b") >= 3) & False) >= False)35343535expect = pl.DataFrame(3536[3537pl.Series("a", [1, 2, 3, 4, 5, None], dtype=pl.Int64),3538pl.Series("b", [1, 2, 3, 4, None, None], dtype=pl.Int64),3539pl.Series("c", ["a", "b", "c", "d", "e", None], dtype=pl.String),3540pl.Series("a_right", [1, 2, 3, None, 5, 4], dtype=pl.Int64),3541pl.Series("b_right", [1, 2, 3, None, 5, None], dtype=pl.Int64),3542pl.Series("c_right", ["A", "B", "C", None, "E", "D"], dtype=pl.String),3543]3544)35453546assert_frame_equal(q.collect(), expect)3547assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)354835493550def test_join_filter_pushdown_left_join_rewrite() -> None:3551lhs = pl.LazyFrame(3552{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}3553)3554rhs = pl.LazyFrame(3555{"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", None, "D", "E"]}3556)35573558# Downgrades to inner-join3559q = lhs.join(3560rhs,3561left_on="a",3562right_on="b",3563how="left",3564coalesce=True,3565maintain_order="left_right",3566).filter(pl.col("c_right") <= "B")35673568expect = pl.DataFrame(3569[3570pl.Series("a", [1, 2], dtype=pl.Int64),3571pl.Series("b", [1, 2], dtype=pl.Int64),3572pl.Series("c", ["a", "b"], dtype=pl.String),3573pl.Series("a_right", [1, 2], dtype=pl.Int64),3574pl.Series("c_right", ["A", "B"], dtype=pl.String),3575]3576)35773578plan = q.explain()35793580assert "LEFT JOIN" not in plan3581assert plan.startswith("INNER JOIN")35823583extract = _extract_plan_joins_and_filters(plan)35843585assert extract == [3586'LEFT PLAN ON: [col("a")]',3587'RIGHT PLAN ON: [col("b")]',3588'FILTER [(col("c")) <= ("B")]',3589]35903591assert_frame_equal(q.collect(), expect)3592assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)359335943595def test_join_filter_pushdown_left_join_rewrite_23133() -> None:3596lhs = pl.LazyFrame(3597{3598"foo": [1, 2, 3],3599"bar": [6.0, 7.0, 8.0],3600"ham": ["a", "b", "c"],3601}3602)36033604rhs = pl.LazyFrame(3605{3606"apple": ["x", "y", "z"],3607"ham": ["a", "b", "d"],3608"bar": ["a", "b", "c"],3609"foo2": [1, 2, 3],3610}3611)36123613q = lhs.join(rhs, how="left", on="ham", maintain_order="left_right").filter(3614pl.col("ham") == "a", pl.col("apple") == "x", pl.col("foo") <= 23615)36163617expect = pl.DataFrame(3618[3619pl.Series("foo", [1], dtype=pl.Int64),3620pl.Series("bar", [6.0], dtype=pl.Float64),3621pl.Series("ham", ["a"], dtype=pl.String),3622pl.Series("apple", ["x"], dtype=pl.String),3623pl.Series("bar_right", ["a"], dtype=pl.String),3624pl.Series("foo2", [1], dtype=pl.Int64),3625]3626)36273628plan = q.explain()3629assert "FULL JOIN" not in plan3630assert plan.startswith("INNER JOIN")36313632extract = _extract_plan_joins_and_filters(plan)36333634assert extract[0] == 'LEFT PLAN ON: [col("ham")]'3635assert '(col("foo")) <= (2)' in extract[1]3636assert 'col("ham")) == ("a")' in extract[1]36373638assert extract[2] == 'RIGHT PLAN ON: [col("ham")]'3639assert 'col("ham")) == ("a")' in extract[3]3640assert 'col("apple")) == ("x")' in extract[3]36413642assert len(extract) == 436433644assert_frame_equal(q.collect(), expect)3645assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)364636473648def test_join_rewrite_panic_23307() -> None:3649lhs = pl.select(a=pl.lit(1, dtype=pl.Int8)).lazy()3650rhs = pl.select(a=pl.lit(1, dtype=pl.Int16), x=pl.lit(1, dtype=pl.Int32)).lazy()36513652q = lhs.join(rhs, on="a", how="left", coalesce=True).filter(pl.col("x") >= 1)36533654assert_frame_equal(3655q.collect(),3656pl.select(3657a=pl.lit(1, dtype=pl.Int8),3658x=pl.lit(1, dtype=pl.Int32),3659),3660)36613662lhs = pl.select(a=pl.lit(999, dtype=pl.Int16)).lazy()36633664# Note: -25 matches to (999).overflowing_cast(Int8).3665# This is specially chosen to test that we don't accidentally push the filter3666# to the RHS.3667rhs = pl.LazyFrame(3668{"a": [1, -25], "x": [1, 2]}, schema={"a": pl.Int8, "x": pl.Int32}3669)36703671q = lhs.join(3672rhs,3673on=pl.col("a").cast(pl.Int8, strict=False, wrap_numerical=True),3674how="left",3675coalesce=False,3676).filter(pl.col("a") >= 0)36773678expect = pl.DataFrame(3679{"a": 999, "a_right": -25, "x": 2},3680schema={"a": pl.Int16, "a_right": pl.Int8, "x": pl.Int32},3681)36823683plan = q.explain()36843685assert not plan.startswith("FILTER")36863687assert_frame_equal(q.collect(), expect)3688assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)368936903691@pytest.mark.parametrize(3692("expr_first_input", "expr_func"),3693[3694(pl.lit(None, dtype=pl.Int64), lambda col: col >= 1),3695(pl.lit(None, dtype=pl.Int64), lambda col: (col >= 1).is_not_null()),3696(pl.lit(None, dtype=pl.Int64), lambda col: (~(col >= 1)).is_not_null()),3697(pl.lit(None, dtype=pl.Int64), lambda col: ~(col >= 1).is_null()),3698#3699(pl.lit(None, dtype=pl.Int64), lambda col: col.is_in([1])),3700(pl.lit(None, dtype=pl.Int64), lambda col: ~col.is_in([1])),3701#3702(pl.lit(None, dtype=pl.Int64), lambda col: col.is_between(1, 1)),3703(1, lambda col: col.is_between(None, 1)),3704(1, lambda col: col.is_between(1, None)),3705#3706(pl.lit(None, dtype=pl.Int64), lambda col: col.is_close(1)),3707(1, lambda col: col.is_close(pl.lit(None, dtype=pl.Int64))),3708#3709(pl.lit(None, dtype=pl.Int64), lambda col: col.is_nan()),3710(pl.lit(None, dtype=pl.Int64), lambda col: col.is_not_nan()),3711(pl.lit(None, dtype=pl.Int64), lambda col: col.is_finite()),3712(pl.lit(None, dtype=pl.Int64), lambda col: col.is_infinite()),3713#3714(pl.lit(None, dtype=pl.Float64), lambda col: col.is_nan()),3715(pl.lit(None, dtype=pl.Float64), lambda col: col.is_not_nan()),3716(pl.lit(None, dtype=pl.Float64), lambda col: col.is_finite()),3717(pl.lit(None, dtype=pl.Float64), lambda col: col.is_infinite()),3718],3719)3720def test_join_rewrite_null_preserving_exprs(3721expr_first_input: Any, expr_func: Callable[[pl.Expr], pl.Expr]3722) -> None:3723lhs = pl.LazyFrame({"a": 1})3724rhs = pl.select(a=1, x=expr_first_input).lazy()37253726assert (3727pl.select(expr_first_input)3728.select(expr_func(pl.first()))3729.select(pl.first().is_null() | ~pl.first())3730.to_series()3731.item()3732)37333734q = lhs.join(rhs, on="a", how="left", maintain_order="left_right").filter(3735expr_func(pl.col("x"))3736)37373738plan = q.explain()3739assert plan.startswith("INNER JOIN")37403741out = q.collect()37423743assert out.height == 03744assert_frame_equal(out, q.collect(optimizations=pl.QueryOptFlags.none()))374537463747@pytest.mark.parametrize(3748("expr_first_input", "expr_func"),3749[3750(3751pl.lit(None, dtype=pl.Int64),3752lambda x: ~(x.is_in([1, None], nulls_equal=True)),3753),3754(3755pl.lit(None, dtype=pl.Int64),3756lambda x: x.is_in([1, None], nulls_equal=True) > True,3757),3758(3759pl.lit(None, dtype=pl.Int64),3760lambda x: x.is_in([1], nulls_equal=True),3761),3762],3763)3764def test_join_rewrite_forbid_exprs(3765expr_first_input: Any, expr_func: Callable[[pl.Expr], pl.Expr]3766) -> None:3767lhs = pl.LazyFrame({"a": 1})3768rhs = pl.select(a=1, x=expr_first_input).lazy()37693770q = lhs.join(rhs, on="a", how="left", maintain_order="left_right").filter(3771expr_func(pl.col("x"))3772)37733774plan = q.explain()3775assert plan.startswith("FILTER")37763777assert_frame_equal(q.collect(), q.collect(optimizations=pl.QueryOptFlags.none()))377837793780def test_join_coalesce_column_order_23177() -> None:3781df1 = pl.DataFrame({"time": ["09:00:21"], "symbol": [5253]})3782df2 = pl.DataFrame({"symbol": [5253], "time": ["09:00:21"]})37833784q = df1.lazy().join(df2.lazy(), on=["time", "symbol"], how="full", coalesce=True)37853786expect = pl.DataFrame({"time": ["09:00:21"], "symbol": [5253]})37873788assert_frame_equal(q.collect(), expect)378937903791def test_join_filter_pushdown_iejoin_cse_23469() -> None:3792lf_x = pl.LazyFrame({"x": [1, 2, 3]})3793lf_y = pl.LazyFrame({"y": [1, 2, 3]})37943795lf_xy = lf_x.join(lf_y, how="cross").filter(pl.col("x") > pl.col("y"))37963797q = pl.concat([lf_xy, lf_xy])37983799assert_frame_equal(3800q.collect().sort(pl.all()),3801pl.DataFrame(3802{3803"x": [2, 2, 3, 3, 3, 3],3804"y": [1, 1, 1, 1, 2, 2],3805},3806),3807)38083809q = pl.concat([lf_xy, lf_xy]).filter(pl.col("x") > pl.col("y"))38103811assert_frame_equal(3812q.collect().sort(pl.all()),3813pl.DataFrame(3814{3815"x": [2, 2, 3, 3, 3, 3],3816"y": [1, 1, 1, 1, 2, 2],3817},3818),3819)38203821q = (3822lf_x.join_where(lf_y, pl.col("x") == pl.col("y"))3823.cache()3824.filter(pl.col("x") >= 0)3825)38263827assert_frame_equal(3828q.collect().sort(pl.all()), pl.DataFrame({"x": [1, 2, 3], "y": [1, 2, 3]})3829)383038313832def test_join_cast_type_coercion_23236() -> None:3833lhs = pl.LazyFrame([{"name": "a"}]).rename({"name": "newname"})3834rhs = pl.LazyFrame([{"name": "a"}])38353836q = lhs.join(rhs, left_on=pl.col("newname").cast(pl.String), right_on="name")38373838assert_frame_equal(q.collect(), pl.DataFrame({"newname": "a", "name": "a"}))383938403841@pytest.mark.parametrize(3842("how", "expected"),3843[3844(3845"inner",3846pl.DataFrame(schema={"a": pl.Int128, "a_right": pl.Int128}),3847),3848(3849"left",3850pl.DataFrame(3851{"a": [1, 1, 2], "a_right": None},3852schema={"a": pl.Int128, "a_right": pl.Int128},3853),3854),3855(3856"right",3857pl.DataFrame(3858{3859"a": None,3860"a_right": [3861-9223372036854775808,3862-9223372036854775807,3863-9223372036854775806,3864],3865},3866schema={"a": pl.Int128, "a_right": pl.Int128},3867),3868),3869(3870"full",3871pl.DataFrame(3872[3873pl.Series("a", [None, None, None, 1, 1, 2], dtype=pl.Int128),3874pl.Series(3875"a_right",3876[3877-9223372036854775808,3878-9223372036854775807,3879-9223372036854775806,3880None,3881None,3882None,3883],3884dtype=pl.Int128,3885),3886]3887),3888),3889(3890"semi",3891pl.DataFrame([pl.Series("a", [], dtype=pl.Int128)]),3892),3893(3894"anti",3895pl.DataFrame([pl.Series("a", [1, 1, 2], dtype=pl.Int128)]),3896),3897],3898)3899@pytest.mark.parametrize(3900("sort_left", "sort_right"),3901[(True, True), (True, False), (False, True), (False, False)],3902)3903def test_join_i128_23688(3904how: str, expected: pl.DataFrame, sort_left: bool, sort_right: bool3905) -> None:3906lhs = pl.LazyFrame({"a": pl.Series([1, 1, 2], dtype=pl.Int128)})39073908rhs = pl.LazyFrame(3909{3910"a": pl.Series(3911[3912-9223372036854775808,3913-9223372036854775807,3914-9223372036854775806,3915],3916dtype=pl.Int128,3917)3918}3919)39203921lhs = lhs.collect().sort("a").lazy() if sort_left else lhs3922rhs = rhs.collect().sort("a").lazy() if sort_right else rhs39233924q = lhs.join(rhs, on="a", how=how, coalesce=False) # type: ignore[arg-type]39253926assert_frame_equal(3927q.collect().sort(pl.all()),3928expected,3929)39303931q = (3932lhs.with_columns(b=pl.col("a"))3933.join(3934rhs.with_columns(b=pl.col("a")),3935on=["a", "b"],3936how=how, # type: ignore[arg-type]3937coalesce=False,3938)3939.select(expected.columns)3940)39413942assert_frame_equal(3943q.collect().sort(pl.all()),3944expected,3945)394639473948def test_join_asof_by_i128() -> None:3949lhs = pl.LazyFrame({"a": pl.Series([1, 1, 2], dtype=pl.Int128), "i": 1})39503951rhs = pl.LazyFrame(3952{3953"a": pl.Series(3954[3955-9223372036854775808,3956-9223372036854775807,3957-9223372036854775806,3958],3959dtype=pl.Int128,3960),3961"i": 1,3962}3963).with_columns(b=pl.col("a"))39643965q = lhs.join_asof(rhs, on="i", by="a")39663967assert_frame_equal(3968q.collect().sort(pl.all()),3969pl.DataFrame(3970{"a": [1, 1, 2], "i": 1, "b": None},3971schema={"a": pl.Int128, "i": pl.Int32, "b": pl.Int128},3972),3973)397439753976def test_join_lazyframe_with_itself_after_sort_25395() -> None:3977lf = pl.LazyFrame({"a": [1]})3978result = lf.sort("a").join(lf, on="a").collect()39793980assert_frame_equal(result, pl.DataFrame({"a": [1]}))398139823983def test_join_right_with_cast_predicate_pushdown() -> None:3984lhs = pl.LazyFrame({"x": [0, 1], "z": [4, 5]})3985rhs = pl.LazyFrame({"y": [2, 3]}).cast(pl.Int32)39863987out = (3988lhs.join(rhs, left_on="x", right_on="y", how="right")3989.filter(pl.col("z") >= 6)3990.collect()3991)39923993ret = pl.DataFrame(3994{3995"z": [],3996"y": [],3997},3998schema={"z": pl.Int64, "y": pl.Int64},3999)4000assert_frame_equal(out, ret, check_column_order=True, check_row_order=False)400140024003def test_full_join_rewrite_to_right_with_cast() -> None:4004lhs = pl.LazyFrame({"x": [0, 1], "a": [10, 20]})4005rhs = pl.LazyFrame({"y": [2, 3], "b": [30, 40]}).cast(pl.Int32)40064007out = (4008lhs.join(rhs, left_on="x", right_on="y", how="full")4009.filter(pl.col("b") >= 0)4010.collect()4011)40124013ret = pl.DataFrame(4014{4015"x": [None, None],4016"a": [None, None],4017"y": [2, 3],4018"b": [30, 40],4019},4020schema={4021"x": pl.Int64,4022"a": pl.Int64,4023"y": pl.Int32,4024"b": pl.Int32,4025},4026)4027assert_frame_equal(out, ret, check_column_order=True, check_row_order=False)402840294030