GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/aggregation/test_aggregations.py
from __future__ import annotations

from datetime import date, datetime, time, timedelta
from typing import TYPE_CHECKING, cast
from zoneinfo import ZoneInfo

import numpy as np
import pytest
from hypothesis import given

import polars as pl
from polars.testing import assert_frame_equal
from polars.testing.parametric import dataframes

if TYPE_CHECKING:
    from collections.abc import Callable
    from typing import Any

    import numpy.typing as npt

    from polars._typing import PolarsDataType, TimeUnit


def test_quantile_expr_input() -> None:
    df = pl.DataFrame({"a": [1, 2, 3, 4, 5], "b": [0.0, 0.0, 0.3, 0.2, 0.0]})

    assert_frame_equal(
        df.select([pl.col("a").quantile(pl.col("b").sum() + 0.1)]),
        df.select(pl.col("a").quantile(0.6)),
    )

    df = pl.DataFrame({"x": [1, 2, 3, 4], "y": [0.25, 0.3, 0.4, 0.75]})

    assert_frame_equal(
        df.select(
            pl.col.x.quantile(pl.concat_list(pl.col.y.min(), pl.col.y.max().first()))
        ),
        df.select(pl.col.x.quantile([0.25, 0.75])),
    )


def test_boolean_aggs() -> None:
    df = pl.DataFrame({"bool": [True, False, None, True]})

    aggs = [
        pl.mean("bool").alias("mean"),
        pl.std("bool").alias("std"),
        pl.var("bool").alias("var"),
    ]
    assert df.select(aggs).to_dict(as_series=False) == {
        "mean": [0.6666666666666666],
        "std": [0.5773502691896258],
        "var": [0.33333333333333337],
    }

    assert df.group_by(pl.lit(1)).agg(aggs).to_dict(as_series=False) == {
        "literal": [1],
        "mean": [0.6666666666666666],
        "std": [0.5773502691896258],
        "var": [0.33333333333333337],
    }


def test_duration_aggs() -> None:
    df = pl.DataFrame(
        {
            "time1": pl.datetime_range(
                start=datetime(2022, 12, 12),
                end=datetime(2022, 12, 18),
                interval="1d",
                eager=True,
            ),
            "time2": pl.datetime_range(
                start=datetime(2023, 1, 12),
                end=datetime(2023, 1, 18),
                interval="1d",
                eager=True,
            ),
        }
    )

    df = df.with_columns((pl.col("time2") - pl.col("time1")).alias("time_difference"))

    assert df.select("time_difference").mean().to_dict(as_series=False) == {
        "time_difference": [timedelta(days=31)]
    }
    assert df.group_by(pl.lit(1)).agg(pl.mean("time_difference")).to_dict(
        as_series=False
    ) == {
        "literal": [1],
        "time_difference": [timedelta(days=31)],
    }


def test_list_aggregation_that_filters_all_data_6017() -> None:
    out = (
        pl.DataFrame({"col_to_group_by": [2], "flt": [1672740910.967138], "col3": [1]})
        .group_by("col_to_group_by")
        .agg((pl.col("flt").filter(col3=0).diff() * 1000).diff().alias("calc"))
    )

    assert out.schema == {"col_to_group_by": pl.Int64, "calc": pl.List(pl.Float64)}
    assert out.to_dict(as_series=False) == {"col_to_group_by": [2], "calc": [[]]}


def test_median() -> None:
    s = pl.Series([1, 2, 3])
    assert s.median() == 2


def test_single_element_std() -> None:
    s = pl.Series([1])
    assert s.std(ddof=1) is None
    assert s.std(ddof=0) == 0.0


def test_quantile() -> None:
    s = pl.Series([1, 2, 3])
    assert s.quantile(0.5, "nearest") == 2
    assert s.quantile(0.5, "lower") == 2
    assert s.quantile(0.5, "higher") == 2
    assert s.quantile([0.25, 0.75], "linear") == [1.5, 2.5]

    df = pl.DataFrame({"a": [1.0, 2.0, 3.0]})
    expected = pl.DataFrame({"a": [[2.0]]})
    assert_frame_equal(
        df.select(pl.col("a").quantile([0.5], interpolation="linear")), expected
    )


def test_quantile_error_checking() -> None:
    s = pl.Series([1, 2, 3])
    with pytest.raises(pl.exceptions.ComputeError):
        s.quantile(-0.1)
    with pytest.raises(pl.exceptions.ComputeError):
        s.quantile(1.1)
    with pytest.raises(pl.exceptions.ComputeError):
        s.quantile([0.0, 1.2])


def test_quantile_date() -> None:
    s = pl.Series(
        "a", [date(2025, 1, 1), date(2025, 1, 2), date(2025, 1, 3), date(2025, 1, 4)]
    )
    assert s.quantile(0.5, "nearest") == datetime(2025, 1, 3)
    assert s.quantile(0.5, "lower") == datetime(2025, 1, 2)
    assert s.quantile(0.5, "higher") == datetime(2025, 1, 3)
    assert s.quantile(0.5, "linear") == datetime(2025, 1, 2, 12)

    df = s.to_frame().lazy()
    result = df.select(
        nearest=pl.col("a").quantile(0.5, "nearest"),
        lower=pl.col("a").quantile(0.5, "lower"),
        higher=pl.col("a").quantile(0.5, "higher"),
        linear=pl.col("a").quantile(0.5, "linear"),
    )
    dt = pl.Datetime("us")
    assert result.collect_schema() == pl.Schema(
        {
            "nearest": dt,
            "lower": dt,
            "higher": dt,
            "linear": dt,
        }
    )
    expected = pl.DataFrame(
        {
            "nearest": pl.Series([datetime(2025, 1, 3)], dtype=dt),
            "lower": pl.Series([datetime(2025, 1, 2)], dtype=dt),
            "higher": pl.Series([datetime(2025, 1, 3)], dtype=dt),
            "linear": pl.Series([datetime(2025, 1, 2, 12)], dtype=dt),
        }
    )
    assert_frame_equal(result.collect(), expected)


@pytest.mark.parametrize("tu", ["ms", "us", "ns"])
@pytest.mark.parametrize("tz", [None, "Asia/Tokyo", "UTC"])
def test_quantile_datetime(tu: TimeUnit, tz: str) -> None:
    time_zone = ZoneInfo(tz) if tz else None
    dt = pl.Datetime(tu, time_zone)

    s = pl.Series(
        "a",
        [
            datetime(2025, 1, 1, tzinfo=time_zone),
            datetime(2025, 1, 2, tzinfo=time_zone),
            datetime(2025, 1, 3, tzinfo=time_zone),
            datetime(2025, 1, 4, tzinfo=time_zone),
        ],
        dtype=dt,
    )
    assert s.quantile(0.5, "nearest") == datetime(2025, 1, 3, tzinfo=time_zone)
    assert s.quantile(0.5, "lower") == datetime(2025, 1, 2, tzinfo=time_zone)
    assert s.quantile(0.5, "higher") == datetime(2025, 1, 3, tzinfo=time_zone)
    assert s.quantile(0.5, "linear") == datetime(2025, 1, 2, 12, tzinfo=time_zone)

    df = s.to_frame().lazy()
    result = df.select(
        nearest=pl.col("a").quantile(0.5, "nearest"),
        lower=pl.col("a").quantile(0.5, "lower"),
        higher=pl.col("a").quantile(0.5, "higher"),
        linear=pl.col("a").quantile(0.5, "linear"),
    )
    assert result.collect_schema() == pl.Schema(
        {
            "nearest": dt,
            "lower": dt,
            "higher": dt,
            "linear": dt,
        }
    )
    expected = pl.DataFrame(
        {
            "nearest": pl.Series([datetime(2025, 1, 3, tzinfo=time_zone)], dtype=dt),
            "lower": pl.Series([datetime(2025, 1, 2, tzinfo=time_zone)], dtype=dt),
            "higher": pl.Series([datetime(2025, 1, 3, tzinfo=time_zone)], dtype=dt),
            "linear": pl.Series([datetime(2025, 1, 2, 12, tzinfo=time_zone)], dtype=dt),
        }
    )
    assert_frame_equal(result.collect(), expected)


@pytest.mark.parametrize("tu", ["ms", "us", "ns"])
def test_quantile_duration(tu: TimeUnit) -> None:
    dt = pl.Duration(tu)

    s = pl.Series(
        "a",
        [timedelta(days=1), timedelta(days=2), timedelta(days=3), timedelta(days=4)],
        dtype=dt,
    )
    assert s.quantile(0.5, "nearest") == timedelta(days=3)
    assert s.quantile(0.5, "lower") == timedelta(days=2)
    assert s.quantile(0.5, "higher") == timedelta(days=3)
    assert s.quantile(0.5, "linear") == timedelta(days=2, hours=12)

    df = s.to_frame().lazy()
    result = df.select(
        nearest=pl.col("a").quantile(0.5, "nearest"),
        lower=pl.col("a").quantile(0.5, "lower"),
        higher=pl.col("a").quantile(0.5, "higher"),
        linear=pl.col("a").quantile(0.5, "linear"),
    )
    assert result.collect_schema() == pl.Schema(
        {
            "nearest": dt,
            "lower": dt,
            "higher": dt,
            "linear": dt,
        }
    )
    expected = pl.DataFrame(
        {
            "nearest": pl.Series([timedelta(days=3)], dtype=dt),
            "lower": pl.Series([timedelta(days=2)], dtype=dt),
            "higher": pl.Series([timedelta(days=3)], dtype=dt),
            "linear": pl.Series([timedelta(days=2, hours=12)], dtype=dt),
        }
    )
    assert_frame_equal(result.collect(), expected)


def test_quantile_time() -> None:
    s = pl.Series("a", [time(hour=1), time(hour=2), time(hour=3), time(hour=4)])
    assert s.quantile(0.5, "nearest") == time(hour=3)
    assert s.quantile(0.5, "lower") == time(hour=2)
    assert s.quantile(0.5, "higher") == time(hour=3)
    assert s.quantile(0.5, "linear") == time(hour=2, minute=30)

    df = s.to_frame().lazy()
    result = df.select(
        nearest=pl.col("a").quantile(0.5, "nearest"),
        lower=pl.col("a").quantile(0.5, "lower"),
        higher=pl.col("a").quantile(0.5, "higher"),
        linear=pl.col("a").quantile(0.5, "linear"),
    )
    assert result.collect_schema() == pl.Schema(
        {
            "nearest": pl.Time,
            "lower": pl.Time,
            "higher": pl.Time,
            "linear": pl.Time,
        }
    )
    expected = pl.DataFrame(
        {
            "nearest": pl.Series([time(hour=3)]),
            "lower": pl.Series([time(hour=2)]),
            "higher": pl.Series([time(hour=3)]),
            "linear": pl.Series([time(hour=2, minute=30)]),
        }
    )
    assert_frame_equal(result.collect(), expected)


@pytest.mark.slow
@pytest.mark.parametrize("tp", [int, float])
@pytest.mark.parametrize("n", [1, 2, 10, 100])
def test_quantile_vs_numpy(tp: type, n: int) -> None:
    a: np.ndarray[Any, Any] = np.random.randint(0, 50, n).astype(tp)
    np_result: npt.ArrayLike | None = np.median(a)
    # nan check
    if np_result != np_result:
        np_result = None
    median = pl.Series(a).median()
    if median is not None:
        assert np.isclose(median, np_result)  # type: ignore[arg-type]
    else:
        assert np_result is None

    q = np.random.sample()
    try:
        np_result = np.quantile(a, q)
    except IndexError:
        np_result = None
    if np_result:
        # nan check
        if np_result != np_result:
            np_result = None
        assert np.isclose(
            pl.Series(a).quantile(q, interpolation="linear"),  # type: ignore[arg-type]
            np_result,  # type: ignore[arg-type]
        )

    df = pl.DataFrame({"a": a})

    expected = df.select(
        pl.col.a.quantile(0.25).alias("low"), pl.col.a.quantile(0.75).alias("high")
    ).select(pl.concat_list(["low", "high"]).alias("quantiles"))

    result = df.select(pl.col.a.quantile([0.25, 0.75]).alias("quantiles"))

    assert_frame_equal(expected, result)


def test_mean_overflow() -> None:
    assert np.isclose(
        pl.Series([9_223_372_036_854_775_800, 100]).mean(),  # type: ignore[arg-type]
        4.611686018427388e18,
    )


def test_mean_null_simd() -> None:
    for dtype in [int, float]:
        df = (
            pl.Series(np.random.randint(0, 100, 1000))
            .cast(dtype)
            .to_frame("a")
            .select(pl.when(pl.col("a") > 40).then(pl.col("a")))
        )

        s = df["a"]
        assert s.mean() == s.to_pandas().mean()


def test_literal_group_agg_chunked_7968() -> None:
    df = pl.DataFrame({"A": [1, 1], "B": [1, 3]})
    ser = pl.concat([pl.Series([3]), pl.Series([4, 5])], rechunk=False)

    assert_frame_equal(
        df.group_by("A").agg(pl.col("B").search_sorted(ser)),
        pl.DataFrame(
            [
                pl.Series("A", [1], dtype=pl.Int64),
                pl.Series("B", [[1, 2, 2]], dtype=pl.List(pl.get_index_type())),
            ]
        ),
    )


def test_duration_function_literal() -> None:
    df = pl.DataFrame(
        {
            "A": ["x", "x", "y", "y", "y"],
            "T": pl.datetime_range(
                date(2022, 1, 1), date(2022, 5, 1), interval="1mo", eager=True
            ),
            "S": [1, 2, 4, 8, 16],
        }
    )

    result = df.group_by("A", maintain_order=True).agg(
        (pl.col("T").max() + pl.duration(seconds=1)) - pl.col("T")
    )

    # this checks if the `pl.duration` is flagged as AggState::Literal
    expected = pl.DataFrame(
        {
            "A": ["x", "y"],
            "T": [
                [timedelta(days=31, seconds=1), timedelta(seconds=1)],
                [
                    timedelta(days=61, seconds=1),
                    timedelta(days=30, seconds=1),
                    timedelta(seconds=1),
                ],
            ],
        }
    )
    assert_frame_equal(result, expected)


def test_string_par_materialize_8207() -> None:
    df = pl.LazyFrame(
        {
            "a": ["a", "b", "d", "c", "e"],
            "b": ["P", "L", "R", "T", "a long string"],
        }
    )

    assert df.group_by(["a"]).agg(pl.min("b")).sort("a").collect().to_dict(
        as_series=False
    ) == {
        "a": ["a", "b", "c", "d", "e"],
        "b": ["P", "L", "T", "R", "a long string"],
    }


def test_online_variance() -> None:
    df = pl.DataFrame(
        {
            "id": [1] * 5,
            "no_nulls": [1, 2, 3, 4, 5],
            "nulls": [1, None, 3, None, 5],
        }
    )

    assert_frame_equal(
        df.group_by("id")
        .agg(pl.all().exclude("id").std())
        .select(["no_nulls", "nulls"]),
        df.select(pl.all().exclude("id").std()),
    )


def test_implode_and_agg() -> None:
    df = pl.DataFrame({"type": ["water", "fire", "water", "earth"]})

    assert_frame_equal(
        df.group_by("type").agg(pl.col("type").implode().first().alias("foo")),
        pl.DataFrame(
            {
                "type": ["water", "fire", "earth"],
                "foo": [["water", "water"], ["fire"], ["earth"]],
            }
        ),
        check_row_order=False,
    )

    # implode + function should be allowed in group_by
    assert df.group_by("type", maintain_order=True).agg(
        pl.col("type").implode().list.head().alias("foo")
    ).to_dict(as_series=False) == {
        "type": ["water", "fire", "earth"],
        "foo": [["water", "water"], ["fire"], ["earth"]],
    }
    assert df.select(pl.col("type").implode().list.head(1).over("type")).to_dict(
        as_series=False
    ) == {"type": [["water"], ["fire"], ["water"], ["earth"]]}


def test_mapped_literal_to_literal_9217() -> None:
    df = pl.DataFrame({"unique_id": ["a", "b"]})
    assert df.group_by(True).agg(
        pl.struct(pl.lit("unique_id").alias("unique_id"))
    ).to_dict(as_series=False) == {
        "literal": [True],
        "unique_id": [{"unique_id": "unique_id"}],
    }


def test_sum_empty_and_null_set() -> None:
    series = pl.Series("a", [], dtype=pl.Float32)
    assert series.sum() == 0

    series = pl.Series("a", [None], dtype=pl.Float32)
    assert series.sum() == 0

    df = pl.DataFrame(
        {"a": [None, None, None], "b": [1, 1, 1]},
        schema={"a": pl.Float32, "b": pl.Int64},
    )
    assert df.select(pl.sum("a")).item() == 0.0
    assert df.group_by("b").agg(pl.sum("a"))["a"].item() == 0.0


def test_horizontal_sum_null_to_identity() -> None:
    assert pl.DataFrame({"a": [1, 5], "b": [10, None]}).select(
        pl.sum_horizontal(["a", "b"])
    ).to_series().to_list() == [11, 5]


def test_horizontal_sum_bool_dtype() -> None:
    out = pl.DataFrame({"a": [True, False]}).select(pl.sum_horizontal("a"))
    assert_frame_equal(
        out, pl.DataFrame({"a": pl.Series([1, 0], dtype=pl.get_index_type())})
    )


def test_horizontal_sum_in_group_by_15102() -> None:
    nbr_records = 1000
    out = (
        pl.LazyFrame(
            {
                "x": [None, "two", None] * nbr_records,
                "y": ["one", "two", None] * nbr_records,
                "z": [None, "two", None] * nbr_records,
            }
        )
        .select(pl.sum_horizontal(pl.all().is_null()).alias("num_null"))
        .group_by("num_null")
        .len()
        .sort(by="num_null")
        .collect()
    )
    assert_frame_equal(
        out,
        pl.DataFrame(
            {
                "num_null": pl.Series([0, 2, 3], dtype=pl.get_index_type()),
                "len": pl.Series([nbr_records] * 3, dtype=pl.get_index_type()),
            }
        ),
    )


def test_first_last_unit_length_12363() -> None:
    df = pl.DataFrame(
        {
            "a": [1, 2],
            "b": [None, None],
        }
    )

    assert df.select(
        pl.all().drop_nulls().first().name.suffix("_first"),
        pl.all().drop_nulls().last().name.suffix("_last"),
    ).to_dict(as_series=False) == {
        "a_first": [1],
        "b_first": [None],
        "a_last": [2],
        "b_last": [None],
    }


def test_binary_op_agg_context_no_simplify_expr_12423() -> None:
    expect = pl.DataFrame({"x": [1], "y": [1]}, schema={"x": pl.Int64, "y": pl.Int32})

    for simplify_expression in (True, False):
        assert_frame_equal(
            expect,
            pl.LazyFrame({"x": [1]})
            .group_by("x")
            .agg(y=pl.lit(1) * pl.lit(1))
            .collect(
                optimizations=pl.QueryOptFlags(simplify_expression=simplify_expression)
            ),
        )


def test_nan_inf_aggregation() -> None:
    df = pl.DataFrame(
        [
            ("both nan", np.nan),
            ("both nan", np.nan),
            ("nan and 5", np.nan),
            ("nan and 5", 5),
            ("nan and null", np.nan),
            ("nan and null", None),
            ("both none", None),
            ("both none", None),
            ("both inf", np.inf),
            ("both inf", np.inf),
            ("inf and null", np.inf),
            ("inf and null", None),
        ],
        schema=["group", "value"],
        orient="row",
    )

    assert_frame_equal(
        df.group_by("group", maintain_order=True).agg(
            min=pl.col("value").min(),
            max=pl.col("value").max(),
            mean=pl.col("value").mean(),
        ),
        pl.DataFrame(
            [
                ("both nan", np.nan, np.nan, np.nan),
                ("nan and 5", 5, 5, np.nan),
                ("nan and null", np.nan, np.nan, np.nan),
                ("both none", None, None, None),
                ("both inf", np.inf, np.inf, np.inf),
                ("inf and null", np.inf, np.inf, np.inf),
            ],
            schema=["group", "min", "max", "mean"],
            orient="row",
        ),
    )


@pytest.mark.parametrize("dtype", [pl.Int16, pl.UInt16])
def test_int16_max_12904(dtype: PolarsDataType) -> None:
    s = pl.Series([None, 1], dtype=dtype)

    assert s.min() == 1
    assert s.max() == 1


def test_agg_filter_over_empty_df_13610() -> None:
    ldf = pl.LazyFrame(
        {
            "a": [1, 1, 1, 2, 3],
            "b": [True, True, True, True, True],
            "c": [None, None, None, None, None],
        }
    )

    out = (
        ldf.drop_nulls()
        .group_by(["a"], maintain_order=True)
        .agg(pl.col("b").filter(pl.col("b").shift(1)))
        .collect()
    )
    expected = pl.DataFrame(schema={"a": pl.Int64, "b": pl.List(pl.Boolean)})
    assert_frame_equal(out, expected)

    df = pl.DataFrame(schema={"a": pl.Int64, "b": pl.Boolean})
    out = df.group_by("a").agg(pl.col("b").filter(pl.col("b").shift()))
    expected = pl.DataFrame(schema={"a": pl.Int64, "b": pl.List(pl.Boolean)})
    assert_frame_equal(out, expected)


@pytest.mark.may_fail_cloud  # reason: output order is defined for this in cloud
@pytest.mark.may_fail_auto_streaming
@pytest.mark.slow
def test_agg_empty_sum_after_filter_14734() -> None:
    f = (
        pl.DataFrame({"a": [1, 2], "b": [1, 2]})
        .lazy()
        .group_by("a")
        .agg(pl.col("b").filter(pl.lit(False)).sum())
        .collect
    )

    last = f()

    # We need both possible output orders, which should happen within
    # 1000 iterations (during testing it usually happens within 10).
    limit = 1000
    i = 0
    while (curr := f()).equals(last):
        i += 1
        assert i != limit

    expect = pl.Series("b", [0, 0]).to_frame()
    assert_frame_equal(expect, last.select("b"))
    assert_frame_equal(expect, curr.select("b"))


@pytest.mark.slow
def test_grouping_hash_14749() -> None:
    n_groups = 251
    rows_per_group = 4
    assert (
        pl.DataFrame(
            {
                "grp": np.repeat(np.arange(n_groups), rows_per_group),
                "x": np.tile(np.arange(rows_per_group), n_groups),
            }
        )
        .select(pl.col("x").max().over("grp"))["x"]
        .value_counts()
    ).to_dict(as_series=False) == {"x": [3], "count": [1004]}


@pytest.mark.parametrize(
    ("in_dtype", "out_dtype"),
    [
        (pl.Boolean, pl.Float64),
        (pl.UInt8, pl.Float64),
        (pl.UInt16, pl.Float64),
        (pl.UInt32, pl.Float64),
        (pl.UInt64, pl.Float64),
        (pl.Int8, pl.Float64),
        (pl.Int16, pl.Float64),
        (pl.Int32, pl.Float64),
        (pl.Int64, pl.Float64),
        (pl.Float32, pl.Float32),
        (pl.Float64, pl.Float64),
    ],
)
def test_horizontal_mean_single_column(
    in_dtype: PolarsDataType,
    out_dtype: PolarsDataType,
) -> None:
    out = (
        pl.LazyFrame({"a": pl.Series([1, 0]).cast(in_dtype)})
        .select(pl.mean_horizontal(pl.all()))
        .collect()
    )

    assert_frame_equal(out, pl.DataFrame({"a": pl.Series([1.0, 0.0]).cast(out_dtype)}))


def test_horizontal_mean_in_group_by_15115() -> None:
    nbr_records = 1000
    out = (
        pl.LazyFrame(
            {
                "w": [None, "one", "two", "three"] * nbr_records,
                "x": [None, None, "two", "three"] * nbr_records,
                "y": [None, None, None, "three"] * nbr_records,
                "z": [None, None, None, None] * nbr_records,
            }
        )
        .select(pl.mean_horizontal(pl.all().is_null()).alias("mean_null"))
        .group_by("mean_null")
        .len()
        .sort(by="mean_null")
        .collect()
    )
    assert_frame_equal(
        out,
        pl.DataFrame(
            {
                "mean_null": pl.Series([0.25, 0.5, 0.75, 1.0], dtype=pl.Float64),
                "len": pl.Series([nbr_records] * 4, dtype=pl.get_index_type()),
            }
        ),
    )


def test_group_count_over_null_column_15705() -> None:
    df = pl.DataFrame(
        {"a": [1, 1, 2, 2, 3, 3], "c": [None, None, None, None, None, None]}
    )
    out = df.group_by("a", maintain_order=True).agg(pl.col("c").count())
    assert out["c"].to_list() == [0, 0, 0]


@pytest.mark.release
def test_min_max_2850() -> None:
    # https://github.com/pola-rs/polars/issues/2850
    df = pl.DataFrame(
        {
            "id": [
                130352432,
                130352277,
                130352611,
                130352833,
                130352305,
                130352258,
                130352764,
                130352475,
                130352368,
                130352346,
            ]
        }
    )

    minimum = 130352258
    maximum = 130352833.0

    for _ in range(10):
        permuted = df.sample(fraction=1.0, seed=0)
        computed = permuted.select(
            pl.col("id").min().alias("min"), pl.col("id").max().alias("max")
        )
        assert cast("int", computed[0, "min"]) == minimum
        assert cast("float", computed[0, "max"]) == maximum


def test_multi_arg_structify_15834() -> None:
    df = pl.DataFrame(
        {
            "group": [1, 2, 1, 2],
            "value": [
                0.1973209146402105,
                0.13380719982405365,
                0.6152394463707009,
                0.4558767896005155,
            ],
        }
    )

    assert df.lazy().group_by("group").agg(
        pl.struct(a=1, value=pl.col("value").sum())
    ).collect().sort("group").to_dict(as_series=False) == {
        "group": [1, 2],
        "a": [
            {"a": 1, "value": 0.8125603610109114},
            {"a": 1, "value": 0.5896839894245691},
        ],
    }


def test_filter_aggregation_16642() -> None:
    df = pl.DataFrame(
        {
            "datetime": [
                datetime(2022, 1, 1, 11, 0),
                datetime(2022, 1, 1, 11, 1),
                datetime(2022, 1, 1, 11, 2),
                datetime(2022, 1, 1, 11, 3),
                datetime(2022, 1, 1, 11, 4),
                datetime(2022, 1, 1, 11, 5),
                datetime(2022, 1, 1, 11, 6),
                datetime(2022, 1, 1, 11, 7),
                datetime(2022, 1, 1, 11, 8),
                datetime(2022, 1, 1, 11, 9, 1),
                datetime(2022, 1, 2, 11, 0),
                datetime(2022, 1, 2, 11, 1),
                datetime(2022, 1, 2, 11, 2),
                datetime(2022, 1, 2, 11, 3),
                datetime(2022, 1, 2, 11, 4),
                datetime(2022, 1, 2, 11, 5),
                datetime(2022, 1, 2, 11, 6),
                datetime(2022, 1, 2, 11, 7),
                datetime(2022, 1, 2, 11, 8),
                datetime(2022, 1, 2, 11, 9, 1),
            ],
            "alpha": [
                "A",
                "B",
                "C",
                "D",
                "E",
                "F",
                "G",
                "H",
                "I",
                "J",
                "A",
                "B",
                "C",
                "D",
                "E",
                "F",
                "G",
                "H",
                "I",
                "J",
            ],
            "num": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        }
    )
    grouped = df.group_by(pl.col("datetime").dt.date())

    ts_filter = pl.col("datetime").dt.time() <= pl.time(11, 3)

    report = grouped.agg(pl.col("num").filter(ts_filter).max()).sort("datetime")
    assert report.to_dict(as_series=False) == {
        "datetime": [date(2022, 1, 1), date(2022, 1, 2)],
        "num": [3, 3],
    }


def test_sort_by_over_single_nulls_first() -> None:
    key = [0, 0, 0, 0, 1, 1, 1, 1]
    df = pl.DataFrame(
        {
            "key": key,
            "value": [2, None, 1, 0, 2, None, 1, 0],
        }
    )
    out = df.select(
        pl.all().sort_by("value", nulls_last=False, maintain_order=True).over("key")
    )
    expected = pl.DataFrame(
        {
            "key": key,
            "value": [None, 0, 1, 2, None, 0, 1, 2],
        }
    )
    assert_frame_equal(out, expected)


def test_sort_by_over_single_nulls_last() -> None:
    key = [0, 0, 0, 0, 1, 1, 1, 1]
    df = pl.DataFrame(
        {
            "key": key,
            "value": [2, None, 1, 0, 2, None, 1, 0],
        }
    )
    out = df.select(
        pl.all().sort_by("value", nulls_last=True, maintain_order=True).over("key")
    )
    expected = pl.DataFrame(
        {
            "key": key,
            "value": [0, 1, 2, None, 0, 1, 2, None],
        }
    )
    assert_frame_equal(out, expected)


def test_sort_by_over_multiple_nulls_first() -> None:
    key1 = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
    key2 = [0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1]
    df = pl.DataFrame(
        {
            "key1": key1,
            "key2": key2,
            "value": [1, None, 0, 1, None, 0, 1, None, 0, None, 1, 0],
        }
    )
    out = df.select(
        pl.all()
        .sort_by("value", nulls_last=False, maintain_order=True)
        .over("key1", "key2")
    )
    expected = pl.DataFrame(
        {
            "key1": key1,
            "key2": key2,
            "value": [None, 0, 1, None, 0, 1, None, 0, 1, None, 0, 1],
        }
    )
    assert_frame_equal(out, expected)


def test_sort_by_over_multiple_nulls_last() -> None:
    key1 = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
    key2 = [0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1]
    df = pl.DataFrame(
        {
            "key1": key1,
            "key2": key2,
            "value": [1, None, 0, 1, None, 0, 1, None, 0, None, 1, 0],
        }
    )
    out = df.select(
        pl.all()
        .sort_by("value", nulls_last=True, maintain_order=True)
        .over("key1", "key2")
    )
    expected = pl.DataFrame(
        {
            "key1": key1,
            "key2": key2,
            "value": [0, 1, None, 0, 1, None, 0, 1, None, 0, 1, None],
        }
    )
    assert_frame_equal(out, expected)


def test_slice_after_agg() -> None:
    assert_frame_equal(
        pl.select(a=pl.lit(1, dtype=pl.Int64), b=pl.lit(1, dtype=pl.Int64))
        .group_by("a")
        .agg(pl.col("b").first().slice(99, 0)),
        pl.DataFrame({"a": [1], "b": [[]]}, schema_overrides={"b": pl.List(pl.Int64)}),
    )


def test_agg_scalar_empty_groups_20115() -> None:
    assert_frame_equal(
        (
            pl.DataFrame({"key": [123], "value": [456]})
            .group_by("key")
            .agg(pl.col("value").slice(1, 1).first())
        ),
        pl.select(key=pl.lit(123, pl.Int64), value=pl.lit(None, pl.Int64)),
    )


def test_agg_expr_returns_list_type_15574() -> None:
    assert (
        pl.LazyFrame({"a": [1, None], "b": [1, 2]})
        .group_by("b")
        .agg(pl.col("a").drop_nulls())
        .collect_schema()
    ) == {"b": pl.Int64, "a": pl.List(pl.Int64)}


def test_empty_agg_22005() -> None:
    out = (
        pl.concat([pl.LazyFrame({"a": [1, 2]}), pl.LazyFrame({"a": [1, 2]})])
        .limit(0)
        .select(pl.col("a").sum())
    )
    assert_frame_equal(out.collect(), pl.DataFrame({"a": 0}))


@pytest.mark.parametrize("wrap_numerical", [True, False])
@pytest.mark.parametrize("strict_cast", [True, False])
def test_agg_with_filter_then_cast_23682(
    strict_cast: bool, wrap_numerical: bool
) -> None:
    assert_frame_equal(
        pl.DataFrame([{"a": 123, "b": 12}, {"a": 123, "b": 257}])
        .group_by("a")
        .agg(
            pl.col("b")
            .filter(pl.col("b") < 256)
            .cast(pl.UInt8, strict=strict_cast, wrap_numerical=wrap_numerical)
        ),
        pl.DataFrame(
            [{"a": 123, "b": [12]}], schema={"a": pl.Int64, "b": pl.List(pl.UInt8)}
        ),
    )


@pytest.mark.parametrize("wrap_numerical", [True, False])
@pytest.mark.parametrize("strict_cast", [True, False])
def test_agg_with_slice_then_cast_23682(
    strict_cast: bool, wrap_numerical: bool
) -> None:
    assert_frame_equal(
        pl.DataFrame([{"a": 123, "b": 12}, {"a": 123, "b": 257}])
        .group_by("a")
        .agg(
            pl.col("b")
            .slice(0, 1)
            .cast(pl.UInt8, strict=strict_cast, wrap_numerical=wrap_numerical)
        ),
        pl.DataFrame(
            [{"a": 123, "b": [12]}], schema={"a": pl.Int64, "b": pl.List(pl.UInt8)}
        ),
    )


@pytest.mark.parametrize(
    ("op", "expr"),
    [
        ("any", pl.all().cast(pl.Boolean).any()),
        ("all", pl.all().cast(pl.Boolean).all()),
        ("arg_max", pl.all().arg_max()),
        ("arg_min", pl.all().arg_min()),
        ("min", pl.all().min()),
        ("max", pl.all().max()),
        ("mean", pl.all().mean()),
        ("median", pl.all().median()),
        ("product", pl.all().product()),
        ("quantile", pl.all().quantile(0.5)),
        ("std", pl.all().std()),
        ("var", pl.all().var()),
        ("sum", pl.all().sum()),
        ("first", pl.all().first()),
        ("last", pl.all().last()),
        ("approx_n_unique", pl.all().approx_n_unique()),
        ("bitwise_and", pl.all().bitwise_and()),
        ("bitwise_or", pl.all().bitwise_or()),
        ("bitwise_xor", pl.all().bitwise_xor()),
    ],
)
@pytest.mark.parametrize(
    "df",
    [
        pl.DataFrame({"a": [[10]]}, schema={"a": pl.Array(shape=(1,), inner=pl.Int32)}),
        pl.DataFrame({"a": [[1]]}, schema={"a": pl.Struct(fields={"a": pl.Int32})}),
        pl.DataFrame({"a": [True]}, schema={"a": pl.Boolean}),
        pl.DataFrame({"a": ["a"]}, schema={"a": pl.Categorical}),
        pl.DataFrame({"a": [b"a"]}, schema={"a": pl.Binary}),
        pl.DataFrame({"a": ["a"]}, schema={"a": pl.Utf8}),
        pl.DataFrame({"a": [10]}, schema={"a": pl.Int32}),
        pl.DataFrame({"a": [10]}, schema={"a": pl.Float16}),
        pl.DataFrame({"a": [10]}, schema={"a": pl.Float32}),
        pl.DataFrame({"a": [10]}, schema={"a": pl.Float64}),
        pl.DataFrame({"a": [10]}, schema={"a": pl.Int128}),
        pl.DataFrame({"a": [10]}, schema={"a": pl.UInt128}),
        pl.DataFrame({"a": ["a"]}, schema={"a": pl.String}),
        pl.DataFrame({"a": [None]}, schema={"a": pl.Null}),
        pl.DataFrame({"a": [10]}, schema={"a": pl.Decimal()}),
        pl.DataFrame({"a": [datetime.now()]}, schema={"a": pl.Datetime}),
        pl.DataFrame({"a": [date.today()]}, schema={"a": pl.Date}),
        pl.DataFrame({"a": [timedelta(seconds=10)]}, schema={"a": pl.Duration}),
    ],
)
def test_agg_invalid_same_engines_behavior(
    op: str, expr: pl.Expr, df: pl.DataFrame
) -> None:
    # If the in-memory engine produces a good result, then the streaming engine
    # should also produce a good result, and then it should match the in-memory result.

    if isinstance(df.schema["a"], pl.Struct) and op in {"any", "all"}:
        # TODO: Remove this exception when #24509 is resolved
        pytest.skip("polars/#24509")

    if isinstance(df.schema["a"], pl.Duration) and op in {"std", "var"}:
        # TODO: Remove this exception when std & var are implemented for Duration
        pytest.skip(f"'{op}' aggregation not yet implemented for Duration")

    inmemory_result, inmemory_error = None, None
    streaming_result, streaming_error = None, None

    try:
        inmemory_result = df.select(expr)
    except pl.exceptions.PolarsError as e:
        inmemory_error = e

    try:
        streaming_result = df.lazy().select(expr).collect(engine="streaming")
    except pl.exceptions.PolarsError as e:
        streaming_error = e

    assert (streaming_error is None) == (inmemory_error is None), (
        f"mismatch in errors for: {streaming_error} != {inmemory_error}"
    )
    if inmemory_error:
        assert streaming_error, (
            f"streaming engine did not error (expected in-memory error: {inmemory_error})"
        )
        assert streaming_error.__class__ == inmemory_error.__class__

    if not inmemory_error:
        assert streaming_result is not None
        assert inmemory_result is not None
        assert_frame_equal(streaming_result, inmemory_result)


@pytest.mark.parametrize(
    ("op", "expr"),
    [
        ("sum", pl.all().sum()),
        ("mean", pl.all().mean()),
        ("median", pl.all().median()),
        ("std", pl.all().std()),
        ("var", pl.all().var()),
        ("quantile", pl.all().quantile(0.5)),
        ("cum_sum", pl.all().cum_sum()),
    ],
)
@pytest.mark.parametrize(
    "df",
    [
        pl.DataFrame({"a": [[10]]}, schema={"a": pl.Array(shape=(1), inner=pl.Int32)}),
        pl.DataFrame({"a": [[1]]}, schema={"a": pl.Struct(fields={"a": pl.Int32})}),
        pl.DataFrame({"a": ["a"]}, schema={"a": pl.Categorical}),
        pl.DataFrame({"a": [b"a"]}, schema={"a": pl.Binary}),
        pl.DataFrame({"a": ["a"]}, schema={"a": pl.Utf8}),
        pl.DataFrame({"a": ["a"]}, schema={"a": pl.String}),
    ],
)
def test_invalid_agg_dtypes_should_raise(
    op: str, expr: pl.Expr, df: pl.DataFrame
) -> None:
    with pytest.raises(
        pl.exceptions.PolarsError, match=rf"`{op}` operation not supported for dtype"
    ):
        df.select(expr)
    with pytest.raises(
        pl.exceptions.PolarsError, match=rf"`{op}` operation not supported for dtype"
    ):
        df.lazy().select(expr).collect(engine="streaming")


@given(
    df=dataframes(
        min_size=1,
        max_size=1,
        excluded_dtypes=[
            # TODO: polars/#24936
            pl.Struct,
        ],
    )
)
def test_single(df: pl.DataFrame) -> None:
    q = df.lazy().select(pl.all(ignore_nulls=False).item())
    assert_frame_equal(q.collect(), df)
    assert_frame_equal(q.collect(engine="streaming"), df)


@given(df=dataframes(max_size=0))
def test_single_empty(df: pl.DataFrame) -> None:
    q = df.lazy().select(pl.all().item())
    match = "aggregation 'item' expected a single value, got none"
    with pytest.raises(pl.exceptions.ComputeError, match=match):
        q.collect()
    with pytest.raises(pl.exceptions.ComputeError, match=match):
        q.collect(engine="streaming")


@given(df=dataframes(min_size=2))
def test_item_too_many(df: pl.DataFrame) -> None:
    q = df.lazy().select(pl.all(ignore_nulls=False).item())
    match = f"aggregation 'item' expected a single value, got {df.height} values"
    with pytest.raises(pl.exceptions.ComputeError, match=match):
        q.collect()
    with pytest.raises(pl.exceptions.ComputeError, match=match):
        q.collect(engine="streaming")


@given(
    df=dataframes(
        min_size=1,
        max_size=1,
        allow_null=False,
        excluded_dtypes=[
            # TODO: polars/#24936
            pl.Struct,
        ],
    )
)
def test_item_on_groups(df: pl.DataFrame) -> None:
    df = df.with_columns(pl.col("col0").alias("key"))
    q = df.lazy().group_by("col0").agg(pl.all(ignore_nulls=False).item())
    assert_frame_equal(q.collect(), df)
    assert_frame_equal(q.collect(engine="streaming"), df)


def test_item_on_groups_empty() -> None:
    df = pl.DataFrame({"col0": [[]]})
    q = df.lazy().select(pl.all().list.item())
    match = "aggregation 'item' expected a single value, got none"
    with pytest.raises(pl.exceptions.ComputeError, match=match):
        q.collect()
    with pytest.raises(pl.exceptions.ComputeError, match=match):
        q.collect(engine="streaming")


def test_item_on_groups_too_many() -> None:
    df = pl.DataFrame({"col0": [[1, 2, 3]]})
    q = df.lazy().select(pl.all().list.item())
    match = "aggregation 'item' expected a single value, got 3 values"
    with pytest.raises(pl.exceptions.ComputeError, match=match):
        q.collect()
    with pytest.raises(pl.exceptions.ComputeError, match=match):
        q.collect(engine="streaming")


def test_all_any_on_list_raises_error() -> None:
    # Ensure boolean reductions on non-boolean columns raise an error.
    # (regression for #24942).
    lf = pl.LazyFrame({"x": [[True]]}, schema={"x": pl.List(pl.Boolean)})

    # for in-memory engine
    for expr in (pl.col("x").all(), pl.col("x").any()):
        with pytest.raises(
            pl.exceptions.InvalidOperationError, match=r"expected boolean"
        ):
            lf.select(expr).collect()

    # for streaming engine
    for expr in (pl.col("x").all(), pl.col("x").any()):
        with pytest.raises(
            pl.exceptions.InvalidOperationError, match=r"expected boolean"
        ):
            lf.select(expr).collect(engine="streaming")


@pytest.mark.parametrize("null_endpoints", [True, False])
@pytest.mark.parametrize("ignore_nulls", [True, False])
@pytest.mark.parametrize(
    ("dtype", "first_value", "last_value"),
    [
        # Struct
        (
            pl.Struct({"x": pl.Enum(["c0", "c1"]), "y": pl.Float32}),
            {"x": "c0", "y": 1.2},
            {"x": "c1", "y": 3.4},
        ),
        # List
        (pl.List(pl.UInt8), [1], [2]),
        # Array
        (pl.Array(pl.Int16, 2), [1, 2], [3, 4]),
        # Date (logical test)
        (pl.Date, date(2025, 1, 1), date(2025, 1, 2)),
        # Float (primitive test)
        (pl.Float32, 1.0, 2.0),
    ],
)
def test_first_last_nested(
    null_endpoints: bool,
    ignore_nulls: bool,
    dtype: PolarsDataType,
    first_value: Any,
    last_value: Any,
) -> None:
    s = pl.Series([first_value, last_value], dtype=dtype)
    if null_endpoints:
        # Test the case where the first/last value is null
        null = pl.Series([None], dtype=dtype)
        s = pl.concat((null, s, null))

    lf = pl.LazyFrame({"a": s})

    # first
    result = lf.select(pl.col("a").first(ignore_nulls=ignore_nulls)).collect()
    expected = pl.DataFrame(
        {
            "a": pl.Series(
                [None if null_endpoints and not ignore_nulls else first_value],
                dtype=dtype,
            )
        }
    )
    assert_frame_equal(result, expected)

    # last
    result = lf.select(pl.col("a").last(ignore_nulls=ignore_nulls)).collect()
    expected = pl.DataFrame(
        {
            "a": pl.Series(
                [None if null_endpoints and not ignore_nulls else last_value],
                dtype=dtype,
            ),
        }
    )
    assert_frame_equal(result, expected)


def test_struct_enum_agg_streaming_24936() -> None:
    s = (
        pl.Series(
            "a",
            [{"f0": "c0"}],
            dtype=pl.Struct({"f0": pl.Enum(categories=["c0"])}),
        ),
    )
    df = pl.DataFrame(s)

    q = df.lazy().select(pl.all(ignore_nulls=False).first())
    assert_frame_equal(q.collect(), df)


def test_sum_inf_not_nan_25849() -> None:
    data = [10.0, None, 10.0, 10.0, 10.0, 10.0, float("inf"), 10.0, 10.0]
    df = pl.DataFrame({"x": data, "g": ["X"] * len(data)})
    assert df.group_by("g").agg(pl.col("x").sum())["x"].item() == float("inf")


COLS = ["flt", "dec", "int", "str", "cat", "enum", "date", "dt"]


@pytest.mark.parametrize(
    "agg_funcs", [(pl.Expr.min_by, pl.Expr.min), (pl.Expr.max_by, pl.Expr.max)]
)
@pytest.mark.parametrize("by_col", COLS)
def test_min_max_by(agg_funcs: Any, by_col: str) -> None:
    agg_by, agg = agg_funcs
    df = pl.DataFrame(
        {
            "flt": [3.0, 2.0, float("nan"), 5.0, None, 4.0],
            "dec": [3, 2, None, 5, None, 4],
            "int": [3, 2, None, 5, None, 4],
            "str": ["c", "b", None, "e", None, "d"],
            "cat": ["c", "b", None, "e", None, "d"],
            "enum": ["c", "b", None, "e", None, "d"],
            "date": [
                date(2023, 3, 3),
                date(2023, 2, 2),
                None,
                date(2023, 5, 5),
                None,
                date(2023, 4, 4),
            ],
            "dt": [
                datetime(2023, 3, 3),
                datetime(2023, 2, 2),
                None,
                datetime(2023, 5, 5),
                None,
                datetime(2023, 4, 4),
            ],
            "g": [1, 1, 1, 2, 2, 2],
        },
        schema_overrides={
            "dec": pl.Decimal(scale=5),
            "cat": pl.Categorical,
            "enum": pl.Enum(["a", "b", "c", "d", "e", "f"]),
        },
    )

    result = df.select([agg_by(pl.col(c), pl.col(by_col)) for c in COLS])
    expected = df.select([agg(pl.col(c)) for c in COLS])
    assert_frame_equal(result, expected)

    # TODO: remove after https://github.com/pola-rs/polars/issues/25906.
    if by_col != "cat":
        df = df.drop("cat")
        cols = [c for c in COLS if c != "cat"]

        result = df.group_by("g").agg([agg_by(pl.col(c), pl.col(by_col)) for c in cols])
        expected = df.group_by("g").agg([agg(pl.col(c)) for c in cols])
        assert_frame_equal(result, expected, check_row_order=False)


@pytest.mark.parametrize(("agg", "expected"), [("max", 2), ("min", 0)])
def test_grouped_minmax_after_reverse_on_sorted_column_26141(
    agg: str, expected: int
) -> None:
    df = pl.DataFrame({"a": [0, 1, 2]}).sort("a")

    expr = getattr(pl.col("a").reverse(), agg)()
    out = df.group_by(1).agg(expr)

    expected_df = pl.DataFrame(
        {
            "literal": pl.Series([1], dtype=pl.Int32),
            "a": [expected],
        }
    )
    assert_frame_equal(out, expected_df)


@pytest.mark.may_fail_auto_streaming
@pytest.mark.parametrize("agg_by", [pl.Expr.min_by, pl.Expr.max_by])
def test_min_max_by_series_length_mismatch_26049(
    agg_by: Callable[[pl.Expr, pl.Expr], pl.Expr],
) -> None:
    lf = pl.LazyFrame(
        {
            "a": [0, 10, 20, 30, 40, 50, 60, 70, 80, 90],
            "b": [18, 5, 8, 8, 4, 5, 6, 8, 1, -10],
            "group": ["A", "A", "A", "A", "A", "B", "B", "C", "C", "C"],
        }
    )

    q = lf.with_columns(
        agg_by(pl.col("group").filter(pl.col("b") % 2 == 0), pl.col("a"))
    )

    with pytest.raises(
        pl.exceptions.ShapeError,
        match=r"^'by' column in (min|max)_by expression has incorrect length: expected \d+, got \d+$",
    ):
        q.collect(engine="in-memory")
    with pytest.raises(
        pl.exceptions.ShapeError,
        match=r"^zip node received non-equal length inputs$",
    ):
        q.collect(engine="streaming")

    actual = (
        lf.group_by("group")
        .agg(
            pl.col("a")
            .max_by(pl.col("b").filter(pl.col("b") < 20).abs())
            .alias("max_by")
        )
        .sort("group")
    ).collect()
    expected = pl.DataFrame(
        {
            "group": ["A", "B", "C"],
            "max_by": [0, 60, 90],
        }
    )
    assert_frame_equal(actual, expected)

    q = (
        lf.group_by("group")
        .agg(
            pl.col("a")
            .max_by(pl.col("b").filter(pl.col("b") < 7).abs())
            .alias("group_length_mismatch")
        )
        .sort("group")
    )
    with pytest.raises(
        pl.exceptions.ShapeError,
        match=r"^expressions must have matching group lengths$",
    ):
        q.collect(engine="in-memory")


@pytest.mark.parametrize(
    "by_expr",
    [
        pl.struct("b", "c"),
        pl.concat_list("b", "c"),
    ],
)
def test_min_by_max_by_nested_type_key_26268(by_expr: pl.Expr) -> None:
    df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 6, 5], "c": [7, 5, 2]})

    with pytest.raises(
        pl.exceptions.InvalidOperationError,
        match="cannot use a nested type as `by` argument in `min_by`/`max_by`",
    ):
        df.select(pl.col("a").min_by(by_expr))