# Path: blob/main/py-polars/tests/unit/operations/test_cross_join.py
# 6939 views
from datetime import datetime
from zoneinfo import ZoneInfo

import pytest

import polars as pl
from polars.testing import assert_frame_equal


def test_cross_join_predicate_pushdown_block_16956() -> None:
    """Check a cross join filtered by a predicate spanning both sides.

    The filter compares ``end_datetime_right`` (right side) against bounds
    derived from ``start_datetime`` (left side), and the query is collected
    with predicate pushdown explicitly enabled.
    NOTE(review): the ``16956`` suffix is presumably the upstream issue
    number this test guards against -- confirm against the tracker.
    """
    lf = pl.LazyFrame(
        [
            [1718085600000, 1718172000000, 1718776800000],
            [1718114400000, 1718200800000, 1718805600000],
        ],
        schema=["start_datetime", "end_datetime"],
    ).cast(pl.Datetime("ms", "Europe/Amsterdam"))

    assert (
        lf.join(lf, how="cross")
        .filter(
            pl.col.end_datetime_right.is_between(
                pl.col.start_datetime, pl.col.start_datetime.dt.offset_by("132h")
            )
        )
        .select("start_datetime", "end_datetime_right")
    ).collect(optimizations=pl.QueryOptFlags(predicate_pushdown=True)).to_dict(
        as_series=False
    ) == {
        "start_datetime": [
            datetime(2024, 6, 11, 8, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")),
            datetime(2024, 6, 11, 8, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")),
            datetime(2024, 6, 12, 8, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")),
            datetime(2024, 6, 19, 8, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")),
        ],
        "end_datetime_right": [
            datetime(2024, 6, 11, 16, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")),
            datetime(2024, 6, 12, 16, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")),
            datetime(2024, 6, 12, 16, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")),
            datetime(2024, 6, 19, 16, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")),
        ],
    }


def test_cross_join_raise_on_keys() -> None:
    """Expect ``ValueError`` when a cross join is given join keys."""
    df = pl.DataFrame({"a": [0, 1], "b": ["x", "y"]})

    with pytest.raises(ValueError):
        df.join(df, how="cross", left_on="a", right_on="b")


def test_nested_loop_join() -> None:
    """Check the plan and result of ``join_where`` with a ``!=`` predicate.

    Asserts that the explained plan mentions ``NESTED LOOP JOIN`` and that
    the collected rows match the expected pairs, ignoring row order.
    """
    left = pl.LazyFrame(
        {
            "a": [1, 2, 1, 3],
            "b": [1, 2, 3, 4],
        }
    )
    right = pl.LazyFrame(
        {
            "c": [4, 1, 2],
            "d": [1, 2, 3],
        }
    )

    actual = left.join_where(right, pl.col("a") != pl.col("c"))
    plan = actual.explain()
    assert "NESTED LOOP JOIN" in plan
    expected = pl.DataFrame(
        {
            "a": [1, 1, 2, 2, 1, 1, 3, 3, 3],
            "b": [1, 1, 2, 2, 3, 3, 4, 4, 4],
            "c": [4, 2, 4, 1, 4, 2, 4, 1, 2],
            "d": [1, 3, 1, 2, 1, 3, 1, 2, 3],
        }
    )
    assert_frame_equal(
        actual.collect(), expected, check_row_order=False, check_exact=True
    )


def test_cross_join_chunking_panic_22793() -> None:
    """Cross-join a concatenated frame with an empty frame and filter it.

    Column ``a`` is built by concatenating ``N`` one-element series, where
    ``N`` is derived from the thread pool size. The result is compared with
    ``df.schema.to_frame()``.
    NOTE(review): the ``22793`` suffix is presumably the upstream issue
    number for the panic named in the test -- confirm against the tracker.
    """
    N = int(pl.thread_pool_size() ** 0.5) * 2
    df = pl.DataFrame(
        [pl.concat([pl.Series("a", [0]) for _ in range(N)]), pl.Series("b", [0] * N)],
    )
    assert_frame_equal(
        df.lazy()
        .join(pl.DataFrame().lazy(), how="cross")
        .filter(pl.col("a") == pl.col("a"))
        .collect(),
        df.schema.to_frame(),
    )