GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/test_queries.py
from __future__ import annotations

from datetime import date, datetime, time, timedelta

import numpy as np
import pandas as pd
import pytest

import polars as pl
from polars.testing import assert_frame_equal
from tests.unit.conftest import NUMERIC_DTYPES


def test_sort_by_bools() -> None:
    # tests that sorting dispatches correctly for boolean sort keys
    df = pl.DataFrame(
        {
            "foo": [1, 2, 3],
            "bar": [6.0, 7.0, 8.0],
            "ham": ["a", "b", "c"],
        }
    )
    out = df.with_columns((pl.col("foo") % 2 == 1).alias("foo_odd")).sort(
        by=["foo_odd", "foo"]
    )
    assert out.rows() == [
        (2, 7.0, "b", False),
        (1, 6.0, "a", True),
        (3, 8.0, "c", True),
    ]
    assert out.shape == (3, 4)


def test_repeat_expansion_in_group_by() -> None:
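    # pl.repeat(1, pl.len()) should expand to one value per row of each group,
    # so cum_sum yields [1], [1, 2], [1, 2, 3] for groups of size 1, 2 and 3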
    out = (
        pl.DataFrame({"g": [1, 2, 2, 3, 3, 3]})
        .group_by("g", maintain_order=True)
        .agg(pl.repeat(1, pl.len()).cum_sum())
        .to_dict(as_series=False)
    )
    assert out == {"g": [1, 2, 3], "repeat": [[1], [1, 2], [1, 2, 3]]}


def test_agg_after_head() -> None:
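    # head(3) inside the aggregation keeps at most the first three rows of
    # each group before summing (group 3 has four rows: 6 + 7 + 8 = 21)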
    a = [1, 1, 1, 2, 2, 3, 3, 3, 3]

    df = pl.DataFrame({"a": a, "b": pl.arange(1, len(a) + 1, eager=True)})

    expected = pl.DataFrame({"a": [1, 2, 3], "b": [6, 9, 21]})

    for maintain_order in [True, False]:
        out = df.group_by("a", maintain_order=maintain_order).agg(
            [pl.col("b").head(3).sum()]
        )

        if not maintain_order:
            out = out.sort("a")

        assert_frame_equal(out, expected)


def test_overflow_uint16_agg_mean() -> None:
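    # 1025 * 64 = 65600 exceeds UInt16::MAX (65535), so the mean must be
    # computed in a wider type rather than by summing in UInt16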
    assert (
        pl.DataFrame(
            {
                "col1": ["A" for _ in range(1025)],
                "col3": [64 for _ in range(1025)],
            }
        )
        .with_columns(pl.col("col3").cast(pl.UInt16))
        .group_by(["col1"])
        .agg(pl.col("col3").mean())
        .to_dict(as_series=False)
    ) == {"col1": ["A"], "col3": [64.0]}


def test_binary_on_list_agg_3345() -> None:
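    # regression test for issue 3345: binary arithmetic on aggregated lists;
    # the expression below is -(p * ln(p)).sum() with p = unique_counts / len,
    # i.e. the Shannon entropy of each group's id distribution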
    df = pl.DataFrame(
        {
            "group": ["A", "A", "A", "B", "B", "B", "B"],
            "id": [1, 2, 1, 4, 5, 4, 6],
        }
    )

    assert (
        df.group_by(["group"], maintain_order=True)
        .agg(
            [
                (
                    (pl.col("id").unique_counts() / pl.col("id").len()).log()
                    * -1
                    * (pl.col("id").unique_counts() / pl.col("id").len())
                ).sum()
            ]
        )
        .to_dict(as_series=False)
    ) == {"group": ["A", "B"], "id": [0.6365141682948128, 1.0397207708399179]}


def test_maintain_order_after_sampling() -> None:
    # internally samples cardinality
    # check if the maintain_order kwarg is dispatched
    df = pl.DataFrame(
        {
            "type": ["A", "B", "C", "D", "A", "B", "C", "D"],
            "value": [1, 3, 2, 3, 4, 5, 3, 4],
        }
    )

    result = df.group_by("type", maintain_order=True).agg(pl.col("value").sum())
    expected = {"type": ["A", "B", "C", "D"], "value": [5, 8, 5, 7]}
    assert result.to_dict(as_series=False) == expected


@pytest.mark.may_fail_auto_streaming
@pytest.mark.parametrize("descending", [False, True])
@pytest.mark.parametrize("nulls_last", [False, True])
@pytest.mark.parametrize("maintain_order", [False, True])
def test_sorted_group_by_optimization(
    descending: bool, nulls_last: bool, maintain_order: bool
) -> None:
    df = pl.DataFrame({"a": np.random.randint(0, 5, 20)})

    # the sorted optimization should not randomize the groups,
    # so this tests that we hit the sorted optimization
    sorted_implicit = (
        df.with_columns(pl.col("a").sort(descending=descending, nulls_last=nulls_last))
        .group_by("a", maintain_order=maintain_order)
        .agg(pl.len())
    )
    sorted_explicit = (
        df.group_by("a", maintain_order=maintain_order)
        .agg(pl.len())
        .sort("a", descending=descending, nulls_last=nulls_last)
    )
    assert_frame_equal(
        sorted_explicit,
        sorted_implicit,
        check_row_order=maintain_order,
    )


def test_median_on_shifted_col_3522() -> None:
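    # diff() leaves a leading null; median should ignore it and return the
    # midpoint of the two valid gaps (927s and 72730s), i.e. 36828.5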
    df = pl.DataFrame(
        {
            "foo": [
                datetime(2022, 5, 5, 12, 31, 34),
                datetime(2022, 5, 5, 12, 47, 1),
                datetime(2022, 5, 6, 8, 59, 11),
            ]
        }
    )
    diffs = df.select(pl.col("foo").diff().dt.total_seconds())
    assert diffs.select(pl.col("foo").median()).to_series()[0] == 36828.5


def test_group_by_agg_equals_zero_3535() -> None:
    # setup test frame
    df = pl.DataFrame(
        data=[
            # note: the 'bb'-keyed values should clearly sum to 0
            ("aa", 10, None),
            ("bb", -10, 0.5),
            ("bb", 10, -0.5),
            ("cc", -99, 10.5),
            ("cc", None, 0.0),
        ],
        schema=[
            ("key", pl.String),
            ("val1", pl.Int16),
            ("val2", pl.Float32),
        ],
        orient="row",
    )
    # group by the key, aggregating the two numeric cols
    assert df.group_by(pl.col("key"), maintain_order=True).agg(
        [pl.col("val1").sum(), pl.col("val2").sum()]
    ).to_dict(as_series=False) == {
        "key": ["aa", "bb", "cc"],
        "val1": [10, 0, -99],
        "val2": [0.0, 0.0, 10.5],
    }


def test_group_by_followed_by_limit() -> None:
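    # head(n) after a group_by/agg limits the number of groups in the output;
    # head(10) with only three groups must still return all of them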
    lf = pl.LazyFrame(
        {
            "key": ["xx", "yy", "zz", "xx", "zz", "zz"],
            "val1": [15, 25, 10, 20, 20, 20],
            "val2": [-33, 20, 44, -2, 16, 71],
        }
    )
    grp = lf.group_by("key", maintain_order=True).agg(pl.col("val1", "val2").sum())
    assert sorted(grp.collect().rows()) == [
        ("xx", 35, -35),
        ("yy", 25, 20),
        ("zz", 50, 131),
    ]
    assert sorted(grp.head(2).collect().rows()) == [
        ("xx", 35, -35),
        ("yy", 25, 20),
    ]
    assert sorted(grp.head(10).collect().rows()) == [
        ("xx", 35, -35),
        ("yy", 25, 20),
        ("zz", 50, 131),
    ]


def test_dtype_concat_3735() -> None:
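    # regression test for issue 3735: concat must preserve the column dtype
    # for every numeric dtype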
    for dt in NUMERIC_DTYPES:
        d1 = pl.DataFrame([pl.Series("val", [1, 2], dtype=dt)])

        d2 = pl.DataFrame([pl.Series("val", [3, 4], dtype=dt)])
        df = pl.concat([d1, d2])

        assert df.shape == (4, 1)
        assert df.columns == ["val"]
        assert df.to_series().to_list() == [1, 2, 3, 4]


def test_opaque_filter_on_lists_3784() -> None:
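    # regression test for issue 3784: filter a list column through an opaque
    # Python predicate (map_elements) that the optimizer cannot inspect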
    df = pl.DataFrame(
        {"str": ["A", "B", "A", "B", "C"], "group": [1, 1, 2, 1, 2]}
    ).lazy()
    df = df.with_columns(pl.col("str").cast(pl.Categorical))

    df_groups = df.group_by("group").agg([pl.col("str").alias("str_list")])

    pre = "A"
    succ = "B"

    assert (
        df_groups.filter(
            pl.col("str_list").map_elements(
                lambda variant: pre in variant
                and succ in variant
                and variant.to_list().index(pre) < variant.to_list().index(succ),
                return_dtype=pl.Boolean,
            )
        )
    ).collect().to_dict(as_series=False) == {
        "group": [1],
        "str_list": [["A", "B", "B"]],
    }


def test_ternary_none_struct() -> None:
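    # when/then/otherwise inside an aggregation, where the branches yield
    # either a struct or null, must produce a valid struct column with
    # nulls for the all-null groups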
    ignore_nulls = False

    def map_expr(name: str) -> pl.Expr:
        return (
            pl.when(ignore_nulls or pl.col(name).null_count() == 0)
            .then(
                pl.struct(
                    [
                        pl.sum(name).alias("sum"),
                        (pl.len() - pl.col(name).null_count()).alias("count"),
                    ]
                ),
            )
            .otherwise(None)
        ).alias("out")

    assert (
        pl.DataFrame({"groups": [1, 2, 3, 4], "values": [None, None, 1, 2]})
        .group_by("groups", maintain_order=True)
        .agg([map_expr("values")])
    ).to_dict(as_series=False) == {
        "groups": [1, 2, 3, 4],
        "out": [
            None,
            None,
            {"sum": 1, "count": 1},
            {"sum": 2, "count": 1},
        ],
    }


def test_edge_cast_string_duplicates_4259() -> None:
    # carefully constructed data.
    # note that rows 2 and 3, concatenated, form the same string ('5461214484')
    df = pl.DataFrame(
        {
            "a": [99, 54612, 546121],
            "b": [1, 14484, 4484],
        }
    ).with_columns(pl.all().cast(pl.String))

    mask = df.select(["a", "b"]).is_duplicated()
    df_filtered = df.filter(pl.lit(mask))

    assert df_filtered.shape == (0, 2)
    assert df_filtered.rows() == []


def test_query_4438() -> None:
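    # regression test for issue 4438: a rolling_max chained after a backward
    # fill of a prior rolling_max should propagate windows correctly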
    df = pl.DataFrame({"x": [1, 2, 3, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 1, 1]})

    q = (
        df.lazy()
        .with_columns(pl.col("x").rolling_max(window_size=3).alias("rolling_max"))
        .fill_null(strategy="backward")
        .with_columns(
            pl.col("rolling_max").rolling_max(window_size=3).alias("rolling_max_2")
        )
    )
    assert q.collect()["rolling_max_2"].to_list() == [
        None,
        None,
        3,
        10,
        10,
        10,
        10,
        10,
        9,
        8,
        7,
        6,
        5,
        4,
        3,
    ]


def test_query_4538() -> None:
    df = pl.DataFrame(
        [
            pl.Series("value", ["aaa", "bbb"]),
        ]
    )
    assert df.select([pl.col("value").str.to_uppercase().is_in(["AAA"])])[
        "value"
    ].to_list() == [True, False]


def test_none_comparison_4773() -> None:
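    # regression test for issue 4773: null != null evaluates to null,
    # so the all-null row must be dropped by the filter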
    df = pl.DataFrame(
        {
            "x": [0, 1, None, 2],
            "y": [1, 2, None, 3],
        }
    ).filter(pl.col("x") != pl.col("y"))
    assert df.shape == (3, 2)
    assert df.rows() == [(0, 1), (1, 2), (2, 3)]


def test_datetime_supertype_5236() -> None:
    df = pd.DataFrame(
        {
            "StartDateTime": [pd.Timestamp.now(tz="UTC"), pd.Timestamp.now(tz="UTC")],
            "EndDateTime": [pd.Timestamp.now(tz="UTC"), pd.Timestamp.now(tz="UTC")],
        }
    )
    out = pl.from_pandas(df).filter(
        pl.col("StartDateTime")
        < (pl.col("EndDateTime").dt.truncate("1d").max() - timedelta(days=1))
    )
    assert out.shape == (0, 2)
    assert out.dtypes == [pl.Datetime("ns", "UTC")] * 2


def test_shift_drop_nulls_10875() -> None:
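    # regression test for issue 10875: shift(1) introduces a leading null
    # that drop_nulls must remove in the lazy engine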
    assert pl.LazyFrame({"a": [1, 2, 3]}).shift(1).drop_nulls().collect()[
        "a"
    ].to_list() == [1, 2]


def test_temporal_downcasts() -> None:
    s = pl.Series([-1, 0, 1]).cast(pl.Datetime("us"))

    assert s.to_list() == [
        datetime(1969, 12, 31, 23, 59, 59, 999999),
        datetime(1970, 1, 1),
        datetime(1970, 1, 1, 0, 0, 0, 1),
    ]

    # downcast (from us to ms, or from datetime to date) should NOT change the date
    for s_dt in (s.dt.date(), s.cast(pl.Date)):
        assert s_dt.to_list() == [
            date(1969, 12, 31),
            date(1970, 1, 1),
            date(1970, 1, 1),
        ]
    assert s.cast(pl.Datetime("ms")).to_list() == [
        datetime(1969, 12, 31, 23, 59, 59, 999000),
        datetime(1970, 1, 1),
        datetime(1970, 1, 1),
    ]


def test_slice_pushdown_queries() -> None:
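    # exercise slice pushdown over a cached input; results must match with
    # unique/group_by operations stacked on top of the concatenation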
    lf = pl.LazyFrame({"a": range(100)}).cache()

    q1 = lf.slice(50).select(pl.col.a + 200)
    q2 = lf

    q = pl.concat([q1, q2])

    expected = pl.Series("a", list(range(250, 300)) + list(range(100))).to_frame()

    assert_frame_equal(q.collect(), expected)

    nq = q.unique()
    assert_frame_equal(nq.collect(), expected, check_row_order=False)

    nq = q.group_by("a").agg([])
    assert_frame_equal(nq.collect(), expected, check_row_order=False)

    nq = q.group_by("a", maintain_order=True).agg([])
    assert_frame_equal(nq.collect(), expected)

    nq = q.group_by("a").agg(b=pl.col.a.first()).select(a=pl.col.b)
    assert_frame_equal(nq.collect(), expected, check_row_order=False)

    nq = q.group_by("a", maintain_order=True).agg(b=pl.col.a.first()).select(a=pl.col.b)
    assert_frame_equal(nq.collect(), expected)


def test_temporal_time_casts() -> None:
    s = pl.Series([-1, 0, 1]).cast(pl.Datetime("us"))

    for s_dt in (s.dt.time(), s.cast(pl.Time)):
        assert s_dt.to_list() == [
            time(23, 59, 59, 999999),
            time(0, 0, 0, 0),
            time(0, 0, 0, 1),
        ]


def assert_unopt_frame_equal(
    lf: pl.LazyFrame, *, check_row_order: bool = False
) -> None:
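    # helper: the query must give the same result with and without optimizations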
    assert_frame_equal(
        lf.collect(),
        lf.collect(optimizations=pl.QueryOptFlags.none()),
        check_row_order=check_row_order,
    )


def test_order_queries() -> None:
    lf = pl.LazyFrame({"a": range(100), "b": range(100)})

    q = lf.group_by(["a", "b"], maintain_order=True).agg([]).cache()
    q1 = q.with_columns(pl.col.a.cum_sum())
    q2 = q.with_columns(pl.col.b + 1).unique("a")

    assert_unopt_frame_equal(pl.concat([q1, q]), check_row_order=True)
    assert_unopt_frame_equal(pl.concat([q1, q]).unique(), check_row_order=False)
    assert_unopt_frame_equal(pl.concat([q1, q2]).unique(), check_row_order=False)

    q = lf.cache()
    q1 = q.with_columns(pl.col.a.cum_sum())
    q2 = q.with_columns(pl.col.b + 1).unique("a")

    assert_unopt_frame_equal(pl.concat([q1, q]), check_row_order=True)
    assert_unopt_frame_equal(pl.concat([q1, q]).unique(), check_row_order=False)
    assert_unopt_frame_equal(pl.concat([q1, q2]).unique(), check_row_order=False)