Path: blob/main/py-polars/tests/benchmark/test_join_where.py
6939 views
"""Benchmark tests for join_where with inequality conditions."""12from __future__ import annotations34import numpy as np5import pytest67import polars as pl8from polars.exceptions import ColumnNotFoundError9from polars.testing import assert_frame_equal1011pytestmark = pytest.mark.benchmark()121314def test_strict_inequalities(east_west: tuple[pl.DataFrame, pl.DataFrame]) -> None:15east, west = east_west16result = (17east.lazy()18.join_where(19west.lazy(),20[pl.col("dur") < pl.col("time"), pl.col("rev") > pl.col("cost")],21)22.collect()23)2425assert len(result) > 0262728def test_non_strict_inequalities(east_west: tuple[pl.DataFrame, pl.DataFrame]) -> None:29east, west = east_west30result = (31east.lazy()32.join_where(33west.lazy(),34[pl.col("dur") <= pl.col("time"), pl.col("rev") >= pl.col("cost")],35)36.collect()37)3839assert len(result) > 0404142def test_single_inequality(east_west: tuple[pl.DataFrame, pl.DataFrame]) -> None:43east, west = east_west44result = (45east.lazy()46# Reduce the number of results by scaling LHS dur column up47.with_columns((pl.col("dur") * 30).alias("scaled_dur"))48.join_where(49west.lazy(),50pl.col("scaled_dur") < pl.col("time"),51)52.collect()53)5455assert len(result) > 0565758@pytest.fixture(scope="module")59def east_west() -> tuple[pl.DataFrame, pl.DataFrame]:60num_rows_left, num_rows_right = 50_000, 5_00061rng = np.random.default_rng(42)6263# Generate two separate datasets where revenue/cost are linearly related to64# duration/time, but add some noise to the west table so that there are some65# rows where the cost for the same or greater time will be less than the east table.66east_dur = rng.integers(1_000, 50_000, num_rows_left)67east_rev = (east_dur * 0.123).astype(np.int32)68west_time = rng.integers(1_000, 50_000, num_rows_right)69west_cost = west_time * 0.12370west_cost += rng.normal(0.0, 1.0, num_rows_right)71west_cost = west_cost.astype(np.int32)7273east = pl.DataFrame(74{75"id": np.arange(0, num_rows_left),76"dur": east_dur,77"rev": east_rev,78"cores": rng.integers(1, 10, num_rows_left),79}80)81west = pl.DataFrame(82{83"t_id": np.arange(0, num_rows_right),84"time": west_time,85"cost": west_cost,86"cores": rng.integers(1, 10, num_rows_right),87}88)8990return east, west919293def test_join_where_invalid_column() -> None:94df = pl.DataFrame({"x": 1})95with pytest.raises(ColumnNotFoundError, match="y"):96df.join_where(df, pl.col("x") < pl.col("y"))9798# Nested column99df1 = pl.DataFrame({"a": [1, 2, 3], "b": [True, False, True]})100df2 = pl.DataFrame(101{102"a": [2, 3, 4],103"c": ["a", "b", "c"],104}105)106with pytest.raises(ColumnNotFoundError, match="d"):107df = df1.join_where(108df2,109((pl.col("a") - pl.col("b")) > (pl.col("c") == "a").cast(pl.Int32))110> (pl.col("a") - pl.col("d")),111)112113114def test_join_where_not_elementwise_24134() -> None:115out = (116pl.LazyFrame({"a": [0, 1, 2, 16]})117.join_where(118pl.LazyFrame({"b": [0, 1, 2, 16]}),119pl.col.a == pl.len(),120)121.collect()122)123124expected = pl.DataFrame({"a": [16, 16, 16, 16], "b": [0, 1, 2, 16]})125assert_frame_equal(out, expected, check_row_order=False)126127128