Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/sql/test_conditional.py
6939 views
1
from __future__ import annotations
2
3
from datetime import date
4
from pathlib import Path
5
6
import pytest
7
8
import polars as pl
9
from polars.exceptions import SQLSyntaxError
10
from polars.testing import assert_frame_equal
11
12
13
@pytest.fixture
def foods_ipc_path() -> Path:
    """Return the path of ``io/files/foods1.ipc`` relative to this test module.

    The path is anchored at the parent of the directory that contains this file.
    """
    # parents[0] is this file's directory; parents[1] is the directory above it.
    base_dir = Path(__file__).parents[1]
    return base_dir / "io" / "files" / "foods1.ipc"
16
17
18
def test_case_when() -> None:
    """Check a CASE WHEN expression that labels rows by ``COALESCE(v1, v2)`` parity."""
    source = {
        "v1": [None, 2, None, 4],
        "v2": [101, 202, 303, 404],
    }
    # `SELECT *` keeps both input columns; the CASE expression appends "v3".
    expected = {
        "v1": [None, 2, None, 4],
        "v2": [101, 202, 303, 404],
        "v3": ["odd", "even", "odd", "even"],
    }
    query = """
        SELECT *, CASE WHEN COALESCE(v1, v2) % 2 != 0 THEN 'odd' ELSE 'even' END as "v3"
        FROM test_data
    """
    frame = pl.LazyFrame(source)
    with pl.SQLContext(test_data=frame, eager=True) as ctx:
        result = ctx.execute(query)
        assert result.to_dict(as_series=False) == expected
37
38
39
@pytest.mark.parametrize("else_clause", ["ELSE NULL ", ""])
def test_case_when_optional_else(else_clause: str) -> None:
    """Check that CASE WHEN gives the same AVG with and without ``ELSE NULL``.

    Rows satisfying ``a <= b`` contribute the ``c`` values 3, 4, 0 and 3, so both
    parametrized variants are expected to yield a conditional mean of 2.5.
    """
    frame = pl.DataFrame(
        {
            "a": [1, 2, 3, 4, 5, 6, 7],
            "b": [7, 6, 5, 4, 3, 2, 1],
            "c": [3, 4, 0, 3, 4, 1, 1],
        }
    )
    expected = {"conditional_mean": [2.5]}

    # `else_clause` is either an explicit "ELSE NULL " or empty (ELSE omitted).
    result = frame.sql(
        f"""
        SELECT
          AVG(CASE WHEN a <= b THEN c {else_clause}END) AS conditional_mean
        FROM self
        """
    )
    assert result.to_dict(as_series=False) == expected
55
56
57
def test_control_flow(foods_ipc_path: Path) -> None:
    """Check COALESCE, NULLIF, IFNULL and IF, plus arity errors for the null funcs.

    The first part runs one query that combines the functions over three nullable
    integer columns and compares the result with fixed expectations. The second
    part asserts that IFNULL and NULLIF raise ``SQLSyntaxError`` when they are
    called with three arguments.
    """
    # NOTE(review): `foods_ipc_path` is requested but never read in this body.
    frame = pl.LazyFrame(
        {
            "x": [1, None, 2, 3, None, 4],
            "y": [5, 4, None, 3, None, 2],
            "z": [3, 4, None, 3, 6, None],
        }
    )
    query = """
        SELECT
          COALESCE(x,y,z) as "coalsc",
          NULLIF(x, y) as "nullif x_y",
          NULLIF(y, z) as "nullif y_z",
          IFNULL(x, y) as "ifnull x_y",
          IFNULL(y,-1) as "inullf y_z",
          COALESCE(x, NULLIF(y,z)) as "both",
          IF(x = y, 'eq', 'ne') as "x_eq_y",
        FROM df
        """
    expected = {
        "coalsc": [1, 4, 2, 3, 6, 4],
        "nullif x_y": [1, None, 2, None, None, 4],
        "nullif y_z": [5, None, None, None, None, 2],
        "ifnull x_y": [1, 4, 2, 3, None, 4],
        "inullf y_z": [5, 4, -1, 3, -1, 2],
        "both": [1, None, 2, 3, None, 4],
        "x_eq_y": ["ne", "ne", "ne", "eq", "ne", "ne"],
    }

    result = pl.SQLContext(df=frame).execute(query, eager=True)
    assert result.to_dict(as_series=False) == expected

    # Both two-argument null functions must reject a three-argument call.
    arity_error = r"(IFNULL|NULLIF) expects 2 arguments \(found 3\)"
    for null_func in ("IFNULL", "NULLIF"):
        bad_query = f"SELECT {null_func}(x,y,z) FROM df"
        with pytest.raises(SQLSyntaxError, match=arity_error):
            pl.SQLContext(df=frame).execute(bad_query)
95
96
97
def test_greatest_least() -> None:
    """Check row-wise GREATEST and LEAST over numeric, string and date columns.

    Each query mixes column references with literals (``0`` and a date cast from
    a string). The expected frames show null inputs being skipped, not
    propagated, when at least one other argument is non-null.
    """
    frame = pl.DataFrame(
        {
            "a": [-100, None, 200, 99],
            "b": [None, -0.1, 99.0, 100.0],
            "c": ["bb", "aa", "dd", "cc"],
            "d": ["cc", "bb", "aa", "dd"],
            "e": [date(1969, 12, 31), date(2021, 1, 2), None, date(2021, 1, 4)],
            "f": [date(1970, 1, 1), date(2000, 10, 20), date(2077, 7, 5), None],
        }
    )

    greatest_query = """
        SELECT
          GREATEST("a", 0, "b") AS max_ab_zero,
          GREATEST("a", "b") AS max_ab,
          GREATEST("c", "d", ) AS max_cd,
          GREATEST("e", "f") AS max_ef,
          GREATEST('1999-12-31'::date, "e", "f") AS max_efx
        FROM df
        """
    expected_max = pl.DataFrame(
        {
            "max_ab_zero": [0.0, 0.0, 200.0, 100.0],
            "max_ab": [-100.0, -0.1, 200.0, 100.0],
            "max_cd": ["cc", "bb", "dd", "dd"],
            "max_ef": [
                date(1970, 1, 1),
                date(2021, 1, 2),
                date(2077, 7, 5),
                date(2021, 1, 4),
            ],
            "max_efx": [
                date(1999, 12, 31),
                date(2021, 1, 2),
                date(2077, 7, 5),
                date(2021, 1, 4),
            ],
        }
    )

    least_query = """
        SELECT
          LEAST("b", "a", 0) AS min_ab_zero,
          LEAST("a", "b") AS min_ab,
          LEAST("c", "d") AS min_cd,
          LEAST("e", "f") AS min_ef,
          LEAST("f", "e", '1999-12-31'::date) AS min_efx
        FROM df
        """
    expected_min = pl.DataFrame(
        {
            "min_ab_zero": [-100.0, -0.1, 0.0, 0.0],
            "min_ab": [-100.0, -0.1, 99.0, 99.0],
            "min_cd": ["bb", "aa", "aa", "cc"],
            "min_ef": [
                date(1969, 12, 31),
                date(2000, 10, 20),
                date(2077, 7, 5),
                date(2021, 1, 4),
            ],
            "min_efx": [
                date(1969, 12, 31),
                date(1999, 12, 31),
                date(1999, 12, 31),
                date(1999, 12, 31),
            ],
        }
    )

    # The context is not eager here, so each result is collected explicitly.
    with pl.SQLContext(df=frame) as ctx:
        df_max_horizontal = ctx.execute(greatest_query).collect()
        assert_frame_equal(df_max_horizontal, expected_max)

        df_min_horizontal = ctx.execute(least_query).collect()
        assert_frame_equal(df_min_horizontal, expected_min)
178
179