Path: blob/main/py-polars/tests/unit/operations/test_inequality_join.py
6939 views
from __future__ import annotations12from datetime import datetime3from typing import TYPE_CHECKING, Any45import hypothesis.strategies as st6import numpy as np7import pytest8from hypothesis import given910import polars as pl11from polars.testing import assert_frame_equal12from polars.testing.parametric.strategies import series1314if TYPE_CHECKING:15from hypothesis.strategies import DrawFn, SearchStrategy161718@pytest.mark.parametrize(19("pred_1", "pred_2"),20[21(pl.col("time") > pl.col("time_right"), pl.col("cost") < pl.col("cost_right")),22(pl.col("time_right") < pl.col("time"), pl.col("cost_right") > pl.col("cost")),23],24)25def test_self_join(pred_1: pl.Expr, pred_2: pl.Expr) -> None:26west = pl.DataFrame(27{28"t_id": [404, 498, 676, 742],29"time": [100, 140, 80, 90],30"cost": [6, 11, 10, 5],31"cores": [4, 2, 1, 4],32}33)3435actual = west.join_where(west, pred_1, pred_2)3637expected = pl.DataFrame(38{39"t_id": [742, 404],40"time": [90, 100],41"cost": [5, 6],42"cores": [4, 4],43"t_id_right": [676, 676],44"time_right": [80, 80],45"cost_right": [10, 10],46"cores_right": [1, 1],47}48)49assert_frame_equal(actual, expected, check_row_order=False, check_exact=True)505152def test_basic_ie_join() -> None:53east = pl.DataFrame(54{55"id": [100, 101, 102],56"dur": [140, 100, 90],57"rev": [12, 12, 5],58"cores": [2, 8, 4],59}60)61west = pl.DataFrame(62{63"t_id": [404, 498, 676, 742],64"time": [100, 140, 80, 90],65"cost": [6, 11, 10, 5],66"cores": [4, 2, 1, 4],67}68)6970actual = east.join_where(71west,72pl.col("dur") < pl.col("time"),73pl.col("rev") > pl.col("cost"),74)7576expected = pl.DataFrame(77{78"id": [101],79"dur": [100],80"rev": [12],81"cores": [8],82"t_id": [498],83"time": [140],84"cost": [11],85"cores_right": [2],86}87)88assert_frame_equal(actual, expected, check_row_order=False, check_exact=True)899091@given(92offset=st.integers(-6, 5),93length=st.integers(0, 6),94)95def test_ie_join_with_slice(offset: int, length: int) -> None:96east = pl.DataFrame(97{98"id": [100, 101, 102],99"dur": [120, 140, 160],100"rev": [12, 14, 16],101"cores": [2, 8, 4],102}103).lazy()104west = pl.DataFrame(105{106"t_id": [404, 498, 676, 742],107"time": [90, 130, 150, 170],108"cost": [9, 13, 15, 16],109"cores": [4, 2, 1, 4],110}111).lazy()112113actual = (114east.join_where(115west,116pl.col("dur") < pl.col("time"),117pl.col("rev") < pl.col("cost"),118)119.slice(offset, length)120.collect()121)122123expected_full = pl.DataFrame(124{125"id": [101, 101, 100, 100, 100],126"dur": [140, 140, 120, 120, 120],127"rev": [14, 14, 12, 12, 12],128"cores": [8, 8, 2, 2, 2],129"t_id": [676, 742, 498, 676, 742],130"time": [150, 170, 130, 150, 170],131"cost": [15, 16, 13, 15, 16],132"cores_right": [1, 4, 2, 1, 4],133}134)135# The ordering of the result is arbitrary, so we can136# only verify that each row of the slice is present in the full expected result.137assert len(actual) == len(expected_full.slice(offset, length))138139expected_rows = set(expected_full.iter_rows())140for row in actual.iter_rows():141assert row in expected_rows, f"{row} not in expected rows"142143144def test_ie_join_with_expressions() -> None:145east = pl.DataFrame(146{147"id": [100, 101, 102],148"dur": [70, 50, 45],149"rev": [12, 12, 5],150"cores": [2, 8, 4],151}152)153west = pl.DataFrame(154{155"t_id": [404, 498, 676, 742],156"time": [100, 140, 80, 90],157"cost": [12, 22, 20, 10],158"cores": [4, 2, 1, 4],159}160)161162actual = east.join_where(163west,164(pl.col("dur") * 2) < pl.col("time"),165pl.col("rev") > (pl.col("cost").cast(pl.Int32) // 2).cast(pl.Int64),166)167168expected = pl.DataFrame(169{170"id": [101],171"dur": [50],172"rev": [12],173"cores": [8],174"t_id": [498],175"time": [140],176"cost": [22],177"cores_right": [2],178}179)180assert_frame_equal(actual, expected, check_row_order=False, check_exact=True)181182183@pytest.mark.parametrize(184"range_constraint",185[186[187# can write individual components188pl.col("time") >= pl.col("start_time"),189pl.col("time") < pl.col("end_time"),190],191[192# or a single `is_between` expression193pl.col("time").is_between("start_time", "end_time", closed="left")194],195],196)197def test_join_where_predicates(range_constraint: list[pl.Expr]) -> None:198left = pl.DataFrame(199{200"id": [0, 1, 2, 3, 4, 5],201"group": [0, 0, 0, 1, 1, 1],202"time": [203datetime(2024, 8, 26, 15, 34, 30),204datetime(2024, 8, 26, 15, 35, 30),205datetime(2024, 8, 26, 15, 36, 30),206datetime(2024, 8, 26, 15, 37, 30),207datetime(2024, 8, 26, 15, 38, 0),208datetime(2024, 8, 26, 15, 39, 0),209],210}211)212right = pl.DataFrame(213{214"id": [0, 1, 2],215"group": [0, 1, 1],216"start_time": [217datetime(2024, 8, 26, 15, 34, 0),218datetime(2024, 8, 26, 15, 35, 0),219datetime(2024, 8, 26, 15, 38, 0),220],221"end_time": [222datetime(2024, 8, 26, 15, 36, 0),223datetime(2024, 8, 26, 15, 37, 0),224datetime(2024, 8, 26, 15, 39, 0),225],226}227)228229actual = left.join_where(right, *range_constraint).select("id", "id_right")230231expected = pl.DataFrame(232{233"id": [0, 1, 1, 2, 4],234"id_right": [0, 0, 1, 1, 2],235}236)237assert_frame_equal(actual, expected, check_row_order=False, check_exact=True)238239q = (240left.lazy()241.join_where(242right.lazy(),243pl.col("group_right") == pl.col("group"),244*range_constraint,245)246.select("id", "id_right", "group")247.sort("id")248)249250explained = q.explain()251assert "INNER JOIN" in explained252assert "FILTER" in explained253actual = q.collect()254255expected = (256left.join(right, how="cross")257.filter(pl.col("group") == pl.col("group_right"), *range_constraint)258.select("id", "id_right", "group")259.sort("id")260)261assert_frame_equal(actual, expected, check_exact=True)262263q = (264left.lazy()265.join_where(266right.lazy(),267pl.col("group") != pl.col("group_right"),268*range_constraint,269)270.select("id", "id_right", "group")271.sort("id")272)273274explained = q.explain()275assert "IEJOIN" in explained276assert "FILTER" in explained277actual = q.collect()278279expected = (280left.join(right, how="cross")281.filter(pl.col("group") != pl.col("group_right"), *range_constraint)282.select("id", "id_right", "group")283.sort("id")284)285assert_frame_equal(actual, expected, check_exact=True)286287q = (288left.lazy()289.join_where(290right.lazy(),291pl.col("group") != pl.col("group_right"),292)293.select("id", "group", "group_right")294.sort("id")295.select("group", "group_right")296)297298explained = q.explain()299assert "NESTED LOOP" in explained300actual = q.collect()301assert actual.to_dict(as_series=False) == {302"group": [0, 0, 0, 0, 0, 0, 1, 1, 1],303"group_right": [1, 1, 1, 1, 1, 1, 0, 0, 0],304}305306307def _inequality_expression(col1: str, op: str, col2: str) -> pl.Expr:308if op == "<":309return pl.col(col1) < pl.col(col2)310elif op == "<=":311return pl.col(col1) <= pl.col(col2)312elif op == ">":313return pl.col(col1) > pl.col(col2)314elif op == ">=":315return pl.col(col1) >= pl.col(col2)316else:317message = f"Invalid operator '{op}'"318raise ValueError(message)319320321def operators() -> SearchStrategy[str]:322valid_operators = ["<", "<=", ">", ">="]323return st.sampled_from(valid_operators)324325326@st.composite327def east_df(328draw: DrawFn, with_nulls: bool = False, use_floats: bool = False329) -> pl.DataFrame:330height = draw(st.integers(min_value=0, max_value=20))331332if use_floats:333dur_strategy: SearchStrategy[Any] = st.floats(allow_nan=True)334rev_strategy: SearchStrategy[Any] = st.floats(allow_nan=True)335dur_dtype: type[pl.DataType] = pl.Float32336rev_dtype: type[pl.DataType] = pl.Float32337else:338dur_strategy = st.integers(min_value=100, max_value=105)339rev_strategy = st.integers(min_value=9, max_value=13)340dur_dtype = pl.Int64341rev_dtype = pl.Int64342343if with_nulls:344dur_strategy = dur_strategy | st.none()345rev_strategy = rev_strategy | st.none()346347cores_strategy = st.integers(min_value=1, max_value=10)348349ids = np.arange(0, height)350dur = draw(st.lists(dur_strategy, min_size=height, max_size=height))351rev = draw(st.lists(rev_strategy, min_size=height, max_size=height))352cores = draw(st.lists(cores_strategy, min_size=height, max_size=height))353354return pl.DataFrame(355[356pl.Series("id", ids, dtype=pl.Int64),357pl.Series("dur", dur, dtype=dur_dtype),358pl.Series("rev", rev, dtype=rev_dtype),359pl.Series("cores", cores, dtype=pl.Int64),360]361)362363364@st.composite365def west_df(366draw: DrawFn, with_nulls: bool = False, use_floats: bool = False367) -> pl.DataFrame:368height = draw(st.integers(min_value=0, max_value=20))369370if use_floats:371time_strategy: SearchStrategy[Any] = st.floats(allow_nan=True)372cost_strategy: SearchStrategy[Any] = st.floats(allow_nan=True)373time_dtype: type[pl.DataType] = pl.Float32374cost_dtype: type[pl.DataType] = pl.Float32375else:376time_strategy = st.integers(min_value=100, max_value=105)377cost_strategy = st.integers(min_value=9, max_value=13)378time_dtype = pl.Int64379cost_dtype = pl.Int64380381if with_nulls:382time_strategy = time_strategy | st.none()383cost_strategy = cost_strategy | st.none()384385cores_strategy = st.integers(min_value=1, max_value=10)386387t_id = np.arange(100, 100 + height)388time = draw(st.lists(time_strategy, min_size=height, max_size=height))389cost = draw(st.lists(cost_strategy, min_size=height, max_size=height))390cores = draw(st.lists(cores_strategy, min_size=height, max_size=height))391392return pl.DataFrame(393[394pl.Series("t_id", t_id, dtype=pl.Int64),395pl.Series("time", time, dtype=time_dtype),396pl.Series("cost", cost, dtype=cost_dtype),397pl.Series("cores", cores, dtype=pl.Int64),398]399)400401402@given(403east=east_df(),404west=west_df(),405op1=operators(),406op2=operators(),407)408def test_ie_join(east: pl.DataFrame, west: pl.DataFrame, op1: str, op2: str) -> None:409expr0 = _inequality_expression("dur", op1, "time")410expr1 = _inequality_expression("rev", op2, "cost")411412actual = east.join_where(west, expr0 & expr1)413414expected = east.join(west, how="cross").filter(expr0 & expr1)415assert_frame_equal(actual, expected, check_row_order=False, check_exact=True)416417418@given(419east=east_df(with_nulls=True),420west=west_df(with_nulls=True),421op1=operators(),422op2=operators(),423)424def test_ie_join_with_nulls(425east: pl.DataFrame, west: pl.DataFrame, op1: str, op2: str426) -> None:427expr0 = _inequality_expression("dur", op1, "time")428expr1 = _inequality_expression("rev", op2, "cost")429430actual = east.join_where(west, expr0 & expr1)431432expected = east.join(west, how="cross").filter(expr0 & expr1)433assert_frame_equal(actual, expected, check_row_order=False, check_exact=True)434435436@given(437east=east_df(use_floats=True),438west=west_df(use_floats=True),439op1=operators(),440op2=operators(),441)442def test_ie_join_with_floats(443east: pl.DataFrame, west: pl.DataFrame, op1: str, op2: str444) -> None:445expr0 = _inequality_expression("dur", op1, "time")446expr1 = _inequality_expression("rev", op2, "cost")447448actual = east.join_where(west, expr0, expr1)449450expected = east.join(west, how="cross").filter(expr0 & expr1)451assert_frame_equal(actual, expected, check_row_order=False, check_exact=True)452453454def test_raise_invalid_input_join_where() -> None:455df = pl.DataFrame({"id": [1, 2]})456with pytest.raises(457pl.exceptions.InvalidOperationError,458match="expected join keys/predicates",459):460df.join_where(df)461462463def test_ie_join_use_keys_multiple() -> None:464a = pl.LazyFrame({"a": [1, 2, 3], "x": [7, 2, 1]})465b = pl.LazyFrame({"b": [2, 2, 2], "x": [7, 1, 3]})466467assert a.join_where(468b,469pl.col.a >= pl.col.b,470pl.col.a <= pl.col.b,471).collect().sort("x_right").to_dict(as_series=False) == {472"a": [2, 2, 2],473"x": [2, 2, 2],474"b": [2, 2, 2],475"x_right": [1, 3, 7],476}477478479@given(480left=series(481dtype=pl.Int64,482strategy=st.integers(min_value=0, max_value=10) | st.none(),483max_size=10,484),485right=series(486dtype=pl.Int64,487strategy=st.integers(min_value=-10, max_value=10) | st.none(),488max_size=10,489),490op=operators(),491)492def test_single_inequality(left: pl.Series, right: pl.Series, op: str) -> None:493expr = _inequality_expression("x", op, "y")494495left_df = pl.DataFrame(496{497"id": np.arange(len(left)),498"x": left,499}500)501right_df = pl.DataFrame(502{503"id": np.arange(len(right)),504"y": right,505}506)507508actual = left_df.join_where(right_df, expr)509510expected = left_df.join(right_df, how="cross").filter(expr)511assert_frame_equal(actual, expected, check_row_order=False, check_exact=True)512513514@given(515offset=st.integers(-6, 5),516length=st.integers(0, 6),517)518def test_single_inequality_with_slice(offset: int, length: int) -> None:519left = pl.DataFrame(520{521"id": list(range(8)),522"x": [0, 1, 1, 2, 3, 5, 5, 7],523}524)525right = pl.DataFrame(526{527"id": list(range(6)),528"y": [-1, 2, 4, 4, 6, 9],529}530)531532expr = pl.col("x") > pl.col("y")533actual = left.join_where(right, expr).slice(offset, length)534535expected_full = left.join(right, how="cross").filter(expr)536537assert len(actual) == len(expected_full.slice(offset, length))538539expected_rows = set(expected_full.iter_rows())540for row in actual.iter_rows():541assert row in expected_rows, f"{row} not in expected rows"542543544def test_ie_join_projection_pd_19005() -> None:545lf = pl.LazyFrame({"a": [1, 2], "b": [3, 4]}).with_row_index()546q = (547lf.join_where(548lf,549pl.col.index < pl.col.index_right,550pl.col.index.cast(pl.Int64) + pl.col.a > pl.col.a_right,551)552.group_by(pl.col.index)553.agg(pl.col.index_right)554)555556out = q.collect()557assert out.schema == pl.Schema(558[("index", pl.get_index_type()), ("index_right", pl.List(pl.get_index_type()))]559)560assert out.shape == (0, 2)561562563def test_single_sided_predicate() -> None:564left = pl.LazyFrame({"a": [1, -1, 2]}).with_row_index()565right = pl.LazyFrame({"b": [1, 2]})566567result = (568left.join_where(right, pl.col.index >= pl.col.a)569.collect()570.sort("index", "a", "b")571)572expected = pl.DataFrame(573{574"index": pl.Series([1, 1, 2, 2], dtype=pl.get_index_type()),575"a": [-1, -1, 2, 2],576"b": [1, 2, 1, 2],577}578)579assert_frame_equal(result, expected)580581582def test_join_on_strings() -> None:583df = pl.LazyFrame(584{585"a": ["a", "b", "c"],586"b": ["b", "b", "b"],587}588)589590q = df.join_where(df, pl.col("a").ge(pl.col("a_right")))591592assert "NESTED LOOP JOIN" in q.explain()593# Note: Output is flaky without sort when POLARS_MAX_THREADS=1594assert q.collect().sort(pl.all()).to_dict(as_series=False) == {595"a": ["a", "b", "b", "c", "c", "c"],596"b": ["b", "b", "b", "b", "b", "b"],597"a_right": ["a", "a", "b", "a", "b", "c"],598"b_right": ["b", "b", "b", "b", "b", "b"],599}600601602def test_join_partial_column_name_overlap_19119() -> None:603left = pl.LazyFrame({"a": [1], "b": [2]})604right = pl.LazyFrame({"a": [2], "d": [0]})605606q = left.join_where(right, pl.col("a") > pl.col("d"))607608assert q.collect().to_dict(as_series=False) == {609"a": [1],610"b": [2],611"a_right": [2],612"d": [0],613}614615616def test_join_predicate_pushdown_19580() -> None:617left = pl.LazyFrame(618{619"a": [1, 2, 3, 1],620"b": [1, 2, 3, 4],621"c": [2, 3, 4, 5],622}623)624625right = pl.LazyFrame({"a": [1, 3], "c": [2, 4], "d": [6, 3]})626627q = left.join_where(628right,629pl.col("b") < pl.col("c_right"),630pl.col("a") < pl.col("a_right"),631pl.col("a") < pl.col("d"),632)633634expect = (635left.join(right, how="cross")636.collect()637.filter(638(pl.col("a") < pl.col("d"))639& (pl.col("b") < pl.col("c_right"))640& (pl.col("a") < pl.col("a_right"))641)642)643644assert_frame_equal(expect, q.collect(), check_row_order=False)645646647def test_join_where_literal_20061() -> None:648df_left = pl.DataFrame(649{"id": [1, 2, 3], "value_left": [10, 20, 30], "flag": [1, 0, 1]}650)651652df_right = pl.DataFrame(653{654"id": [1, 2, 3],655"value_right": [5, 5, 25],656"flag": [1, 0, 1],657}658)659660assert df_left.join_where(661df_right,662pl.col("value_left") > pl.col("value_right"),663pl.col("flag_right") == pl.lit(1, dtype=pl.Int8),664).sort(pl.all()).to_dict(as_series=False) == {665"id": [1, 2, 3, 3],666"value_left": [10, 20, 30, 30],667"flag": [1, 0, 1, 1],668"id_right": [1, 1, 1, 3],669"value_right": [5, 5, 5, 25],670"flag_right": [1, 1, 1, 1],671}672673674def test_boolean_predicate_join_where() -> None:675urls = pl.LazyFrame({"url": "abcd.com/page"})676categories = pl.LazyFrame({"base_url": "abcd.com", "category": "landing page"})677assert (678"NESTED LOOP JOIN"679in urls.join_where(680categories, pl.col("url").str.starts_with(pl.col("base_url"))681).explain()682)683684685