# Path: blob/main/py-polars/tests/unit/operations/test_cross_join.py
# 8424 views
"""Unit tests for cross joins and the related nested-loop join path."""

from datetime import datetime
from zoneinfo import ZoneInfo

import pytest

import polars as pl
from polars._typing import MaintainOrderJoin
from polars.testing import assert_frame_equal


def test_cross_join_predicate_pushdown_block_16956() -> None:
    """Filter a self cross join with a predicate spanning both sides.

    The ``is_between`` predicate references ``end_datetime_right`` (right
    side) and ``start_datetime`` (left side), and the query is collected with
    predicate pushdown explicitly enabled.
    NOTE(review): the ``16956`` suffix presumably refers to the upstream issue
    this guards against -- confirm.
    """
    lf = pl.LazyFrame(
        [
            [1718085600000, 1718172000000, 1718776800000],
            [1718114400000, 1718200800000, 1718805600000],
        ],
        schema=["start_datetime", "end_datetime"],
    ).cast(pl.Datetime("ms", "Europe/Amsterdam"))

    query = (
        lf.join(lf, how="cross")
        .filter(
            pl.col.end_datetime_right.is_between(
                pl.col.start_datetime, pl.col.start_datetime.dt.offset_by("132h")
            )
        )
        .select("start_datetime", "end_datetime_right")
    )
    result = query.collect(
        optimizations=pl.QueryOptFlags(predicate_pushdown=True)
    ).to_dict(as_series=False)

    expected = {
        "start_datetime": [
            datetime(2024, 6, 11, 8, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")),
            datetime(2024, 6, 11, 8, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")),
            datetime(2024, 6, 12, 8, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")),
            datetime(2024, 6, 19, 8, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")),
        ],
        "end_datetime_right": [
            datetime(2024, 6, 11, 16, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")),
            datetime(2024, 6, 12, 16, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")),
            datetime(2024, 6, 12, 16, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")),
            datetime(2024, 6, 19, 16, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")),
        ],
    }
    assert result == expected


def test_cross_join_raise_on_keys() -> None:
    """A cross join that is also given join keys must raise ``ValueError``."""
    df = pl.DataFrame({"a": [0, 1], "b": ["x", "y"]})

    with pytest.raises(ValueError):
        df.join(df, how="cross", left_on="a", right_on="b")


def test_nested_loop_join() -> None:
    """``join_where`` with ``!=`` shows a nested loop join in the plan.

    Checks both the explained plan text and the collected rows; row order is
    not part of the contract here (``check_row_order=False``).
    """
    left = pl.LazyFrame(
        {
            "a": [1, 2, 1, 3],
            "b": [1, 2, 3, 4],
        }
    )
    right = pl.LazyFrame(
        {
            "c": [4, 1, 2],
            "d": [1, 2, 3],
        }
    )

    actual = left.join_where(right, pl.col("a") != pl.col("c"))
    plan = actual.explain()
    assert "NESTED LOOP JOIN" in plan
    expected = pl.DataFrame(
        {
            "a": [1, 1, 2, 2, 1, 1, 3, 3, 3],
            "b": [1, 1, 2, 2, 3, 3, 4, 4, 4],
            "c": [4, 2, 4, 1, 4, 2, 4, 1, 2],
            "d": [1, 3, 1, 2, 1, 3, 1, 2, 3],
        }
    )
    assert_frame_equal(
        actual.collect(), expected, check_row_order=False, check_exact=True
    )


def test_cross_join_chunking_panic_22793() -> None:
    """Cross join with an empty frame followed by a filter must not fail.

    Column ``a`` is built by concatenating ``N`` single-element Series, where
    ``N`` is derived from the thread pool size.
    NOTE(review): presumably this yields a multi-chunk column, which is what
    the test name's "chunking" refers to -- confirm.
    """
    N = int(pl.thread_pool_size() ** 0.5) * 2
    df = pl.DataFrame(
        [pl.concat([pl.Series("a", [0]) for _ in range(N)]), pl.Series("b", [0] * N)],
    )
    # The right side is an empty DataFrame, so the expected output is the
    # frame created from ``df``'s schema alone.
    assert_frame_equal(
        df.lazy()
        .join(pl.DataFrame().lazy(), how="cross")
        .filter(pl.col("a") == pl.col("a"))
        .collect(),
        df.schema.to_frame(),
    )


@pytest.mark.parametrize(
    "maintain_order", ["left", "right", "left_right", "right_left"]
)
def test_cross_join_maintain_order_24663(maintain_order: MaintainOrderJoin) -> None:
    """Cross join row order follows ``maintain_order``, with and without a filter.

    For the ``left*`` settings the expected output keeps ``x`` as the
    slow-varying column; for the ``right*`` settings ``y`` varies slowest.
    """
    df = pl.DataFrame({"x": [0, 1, 2, 3, 4]})
    df2 = pl.DataFrame({"y": [0, 1, 2, 3, 4]})
    # primary: 0,0,0,0,0,1,1,... (slow-varying); secondary: 0,1,2,3,4,0,1,...
    primary = [x for x in range(5) for _ in range(5)]
    secondary = [x for _ in range(5) for x in range(5)]
    if maintain_order.startswith("left"):
        expected = pl.DataFrame({"x": primary, "y": secondary})
    else:
        expected = pl.DataFrame({"x": secondary, "y": primary})

    assert_frame_equal(
        df.join(df2, how="cross", maintain_order=maintain_order), expected
    )

    # Test with fused filter as well.
    even_sum = (pl.col.x + pl.col.y) % 2 == 0
    assert_frame_equal(
        df.lazy()
        .join(df2.lazy(), how="cross", maintain_order=maintain_order)
        .filter(even_sum),
        expected.lazy().filter(even_sum),
    )