Path: blob/main/py-polars/tests/unit/sql/test_conditional.py
6939 views
from __future__ import annotations12from datetime import date3from pathlib import Path45import pytest67import polars as pl8from polars.exceptions import SQLSyntaxError9from polars.testing import assert_frame_equal101112@pytest.fixture13def foods_ipc_path() -> Path:14return Path(__file__).parent.parent / "io" / "files" / "foods1.ipc"151617def test_case_when() -> None:18lf = pl.LazyFrame(19{20"v1": [None, 2, None, 4],21"v2": [101, 202, 303, 404],22}23)24with pl.SQLContext(test_data=lf, eager=True) as ctx:25out = ctx.execute(26"""27SELECT *, CASE WHEN COALESCE(v1, v2) % 2 != 0 THEN 'odd' ELSE 'even' END as "v3"28FROM test_data29"""30)31assert out.to_dict(as_series=False) == {32"v1": [None, 2, None, 4],33"v2": [101, 202, 303, 404],34"v3": ["odd", "even", "odd", "even"],35}363738@pytest.mark.parametrize("else_clause", ["ELSE NULL ", ""])39def test_case_when_optional_else(else_clause: str) -> None:40df = pl.DataFrame(41{42"a": [1, 2, 3, 4, 5, 6, 7],43"b": [7, 6, 5, 4, 3, 2, 1],44"c": [3, 4, 0, 3, 4, 1, 1],45}46)47query = f"""48SELECT49AVG(CASE WHEN a <= b THEN c {else_clause}END) AS conditional_mean50FROM self51"""52res = df.sql(query)53assert res.to_dict(as_series=False) == {"conditional_mean": [2.5]}545556def test_control_flow(foods_ipc_path: Path) -> None:57nums = pl.LazyFrame(58{59"x": [1, None, 2, 3, None, 4],60"y": [5, 4, None, 3, None, 2],61"z": [3, 4, None, 3, 6, None],62}63)64res = pl.SQLContext(df=nums).execute(65"""66SELECT67COALESCE(x,y,z) as "coalsc",68NULLIF(x, y) as "nullif x_y",69NULLIF(y, z) as "nullif y_z",70IFNULL(x, y) as "ifnull x_y",71IFNULL(y,-1) as "inullf y_z",72COALESCE(x, NULLIF(y,z)) as "both",73IF(x = y, 'eq', 'ne') as "x_eq_y",74FROM df75""",76eager=True,77)78assert res.to_dict(as_series=False) == {79"coalsc": [1, 4, 2, 3, 6, 4],80"nullif x_y": [1, None, 2, None, None, 4],81"nullif y_z": [5, None, None, None, None, 2],82"ifnull x_y": [1, 4, 2, 3, None, 4],83"inullf y_z": [5, 4, -1, 3, -1, 2],84"both": [1, None, 2, 3, None, 4],85"x_eq_y": ["ne", "ne", "ne", "eq", "ne", "ne"],86}8788for null_func in ("IFNULL", "NULLIF"):89with pytest.raises(90SQLSyntaxError,91match=r"(IFNULL|NULLIF) expects 2 arguments \(found 3\)",92):93pl.SQLContext(df=nums).execute(f"SELECT {null_func}(x,y,z) FROM df")949596def test_greatest_least() -> None:97df = pl.DataFrame(98{99"a": [-100, None, 200, 99],100"b": [None, -0.1, 99.0, 100.0],101"c": ["bb", "aa", "dd", "cc"],102"d": ["cc", "bb", "aa", "dd"],103"e": [date(1969, 12, 31), date(2021, 1, 2), None, date(2021, 1, 4)],104"f": [date(1970, 1, 1), date(2000, 10, 20), date(2077, 7, 5), None],105}106)107with pl.SQLContext(df=df) as ctx:108df_max_horizontal = ctx.execute(109"""110SELECT111GREATEST("a", 0, "b") AS max_ab_zero,112GREATEST("a", "b") AS max_ab,113GREATEST("c", "d", ) AS max_cd,114GREATEST("e", "f") AS max_ef,115GREATEST('1999-12-31'::date, "e", "f") AS max_efx116FROM df117"""118).collect()119120assert_frame_equal(121df_max_horizontal,122pl.DataFrame(123{124"max_ab_zero": [0.0, 0.0, 200.0, 100.0],125"max_ab": [-100.0, -0.1, 200.0, 100.0],126"max_cd": ["cc", "bb", "dd", "dd"],127"max_ef": [128date(1970, 1, 1),129date(2021, 1, 2),130date(2077, 7, 5),131date(2021, 1, 4),132],133"max_efx": [134date(1999, 12, 31),135date(2021, 1, 2),136date(2077, 7, 5),137date(2021, 1, 4),138],139}140),141)142143df_min_horizontal = ctx.execute(144"""145SELECT146LEAST("b", "a", 0) AS min_ab_zero,147LEAST("a", "b") AS min_ab,148LEAST("c", "d") AS min_cd,149LEAST("e", "f") AS min_ef,150LEAST("f", "e", '1999-12-31'::date) AS min_efx151FROM df152"""153).collect()154155assert_frame_equal(156df_min_horizontal,157pl.DataFrame(158{159"min_ab_zero": [-100.0, -0.1, 0.0, 0.0],160"min_ab": [-100.0, -0.1, 99.0, 99.0],161"min_cd": ["bb", "aa", "aa", "cc"],162"min_ef": [163date(1969, 12, 31),164date(2000, 10, 20),165date(2077, 7, 5),166date(2021, 1, 4),167],168"min_efx": [169date(1969, 12, 31),170date(1999, 12, 31),171date(1999, 12, 31),172date(1999, 12, 31),173],174}175),176)177178179