# GitHub repository: pola-rs/polars
# Path: blob/main/py-polars/tests/unit/operations/aggregation/test_aggregations.py
from __future__ import annotations

from datetime import date, datetime, timedelta
from typing import TYPE_CHECKING, Any, cast

import numpy as np
import pytest

import polars as pl
from polars.exceptions import InvalidOperationError
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
    import numpy.typing as npt

    from polars._typing import PolarsDataType


def test_quantile_expr_input() -> None:
    df = pl.DataFrame({"a": [1, 2, 3, 4, 5], "b": [0.0, 0.0, 0.3, 0.2, 0.0]})

    assert_frame_equal(
        df.select([pl.col("a").quantile(pl.col("b").sum() + 0.1)]),
        df.select(pl.col("a").quantile(0.6)),
    )
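# Worked check for test_quantile_expr_input above: pl.col("b").sum() evaluates
# to 0.5, so the expression-based quantile level is 0.5 + 0.1 == 0.6, matching
# the literal in the second select. A minimal sketch of the sum:
#
#     >>> df = pl.DataFrame({"b": [0.0, 0.0, 0.3, 0.2, 0.0]})
#     >>> df.select(pl.col("b").sum()).item()
#     0.5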


def test_boolean_aggs() -> None:
    df = pl.DataFrame({"bool": [True, False, None, True]})

    aggs = [
        pl.mean("bool").alias("mean"),
        pl.std("bool").alias("std"),
        pl.var("bool").alias("var"),
    ]
    assert df.select(aggs).to_dict(as_series=False) == {
        "mean": [0.6666666666666666],
        "std": [0.5773502691896258],
        "var": [0.33333333333333337],
    }

    assert df.group_by(pl.lit(1)).agg(aggs).to_dict(as_series=False) == {
        "literal": [1],
        "mean": [0.6666666666666666],
        "std": [0.5773502691896258],
        "var": [0.33333333333333337],
    }
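# Worked arithmetic for test_boolean_aggs above: the null is skipped and the
# booleans aggregate as [1, 0, 1], so mean = 2/3 and, with the default ddof=1,
# var = ((1/3)**2 + (2/3)**2 + (1/3)**2) / 2 = 1/3 and std = sqrt(1/3) ~ 0.5774.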


def test_duration_aggs() -> None:
    df = pl.DataFrame(
        {
            "time1": pl.datetime_range(
                start=datetime(2022, 12, 12),
                end=datetime(2022, 12, 18),
                interval="1d",
                eager=True,
            ),
            "time2": pl.datetime_range(
                start=datetime(2023, 1, 12),
                end=datetime(2023, 1, 18),
                interval="1d",
                eager=True,
            ),
        }
    )

    df = df.with_columns((pl.col("time2") - pl.col("time1")).alias("time_difference"))

    assert df.select("time_difference").mean().to_dict(as_series=False) == {
        "time_difference": [timedelta(days=31)]
    }
    assert df.group_by(pl.lit(1)).agg(pl.mean("time_difference")).to_dict(
        as_series=False
    ) == {
        "literal": [1],
        "time_difference": [timedelta(days=31)],
    }
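# Note on test_duration_aggs above: December has 31 days, so time2 - time1 is
# exactly 31 days in every row, and the mean of a constant Duration column is
# that same 31-day value.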


def test_list_aggregation_that_filters_all_data_6017() -> None:
    out = (
        pl.DataFrame({"col_to_group_by": [2], "flt": [1672740910.967138], "col3": [1]})
        .group_by("col_to_group_by")
        .agg((pl.col("flt").filter(col3=0).diff() * 1000).diff().alias("calc"))
    )

    assert out.schema == {"col_to_group_by": pl.Int64, "calc": pl.List(pl.Float64)}
    assert out.to_dict(as_series=False) == {"col_to_group_by": [2], "calc": [[]]}


def test_median() -> None:
    s = pl.Series([1, 2, 3])
    assert s.median() == 2


def test_single_element_std() -> None:
    s = pl.Series([1])
    assert s.std(ddof=1) is None
    assert s.std(ddof=0) == 0.0
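# Note on test_single_element_std above: with ddof=1 the divisor n - ddof is
# zero for a single element, so the sample standard deviation is undefined and
# returned as null; with ddof=0 the population deviation of one value is 0.0.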


def test_quantile() -> None:
    s = pl.Series([1, 2, 3])
    assert s.quantile(0.5, "nearest") == 2
    assert s.quantile(0.5, "lower") == 2
    assert s.quantile(0.5, "higher") == 2


@pytest.mark.slow
@pytest.mark.parametrize("tp", [int, float])
@pytest.mark.parametrize("n", [1, 2, 10, 100])
def test_quantile_vs_numpy(tp: type, n: int) -> None:
    a: np.ndarray[Any, Any] = np.random.randint(0, 50, n).astype(tp)
    np_result: npt.ArrayLike | None = np.median(a)
    # NaN check: NaN is the only value that does not compare equal to itself.
    if np_result != np_result:
        np_result = None
    median = pl.Series(a).median()
    if median is not None:
        assert np.isclose(median, np_result)  # type: ignore[arg-type]
    else:
        assert np_result is None

    q = np.random.sample()
    try:
        np_result = np.quantile(a, q)
    except IndexError:
        np_result = None
    if np_result:
        # NaN check, as above.
        if np_result != np_result:
            np_result = None
        assert np.isclose(
            pl.Series(a).quantile(q, interpolation="linear"),  # type: ignore[arg-type]
            np_result,  # type: ignore[arg-type]
        )


def test_mean_overflow() -> None:
    assert np.isclose(
        pl.Series([9_223_372_036_854_775_800, 100]).mean(),  # type: ignore[arg-type]
        4.611686018427388e18,
    )
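# Note on test_mean_overflow above: 9_223_372_036_854_775_800 is only 7 below
# i64::MAX, so a naive i64 sum of the two values would overflow; the expected
# mean is (9_223_372_036_854_775_800 + 100) / 2 ~ 4.611686018427388e18.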


def test_mean_null_simd() -> None:
    for dtype in [int, float]:
        df = (
            pl.Series(np.random.randint(0, 100, 1000))
            .cast(dtype)
            .to_frame("a")
            .select(pl.when(pl.col("a") > 40).then(pl.col("a")))
        )

        s = df["a"]
        assert s.mean() == s.to_pandas().mean()


def test_literal_group_agg_chunked_7968() -> None:
    df = pl.DataFrame({"A": [1, 1], "B": [1, 3]})
    ser = pl.concat([pl.Series([3]), pl.Series([4, 5])], rechunk=False)

    assert_frame_equal(
        df.group_by("A").agg(pl.col("B").search_sorted(ser)),
        pl.DataFrame(
            [
                pl.Series("A", [1], dtype=pl.Int64),
                pl.Series("B", [[1, 2, 2]], dtype=pl.List(pl.UInt32)),
            ]
        ),
    )


def test_duration_function_literal() -> None:
    df = pl.DataFrame(
        {
            "A": ["x", "x", "y", "y", "y"],
            "T": pl.datetime_range(
                date(2022, 1, 1), date(2022, 5, 1), interval="1mo", eager=True
            ),
            "S": [1, 2, 4, 8, 16],
        }
    )

    result = df.group_by("A", maintain_order=True).agg(
        (pl.col("T").max() + pl.duration(seconds=1)) - pl.col("T")
    )

    # This checks that `pl.duration` is flagged as AggState::Literal.
    expected = pl.DataFrame(
        {
            "A": ["x", "y"],
            "T": [
                [timedelta(days=31, seconds=1), timedelta(seconds=1)],
                [
                    timedelta(days=61, seconds=1),
                    timedelta(days=30, seconds=1),
                    timedelta(seconds=1),
                ],
            ],
        }
    )
    assert_frame_equal(result, expected)


def test_string_par_materialize_8207() -> None:
    df = pl.LazyFrame(
        {
            "a": ["a", "b", "d", "c", "e"],
            "b": ["P", "L", "R", "T", "a long string"],
        }
    )

    assert df.group_by(["a"]).agg(pl.min("b")).sort("a").collect().to_dict(
        as_series=False
    ) == {
        "a": ["a", "b", "c", "d", "e"],
        "b": ["P", "L", "T", "R", "a long string"],
    }


def test_online_variance() -> None:
    df = pl.DataFrame(
        {
            "id": [1] * 5,
            "no_nulls": [1, 2, 3, 4, 5],
            "nulls": [1, None, 3, None, 5],
        }
    )

    assert_frame_equal(
        df.group_by("id")
        .agg(pl.all().exclude("id").std())
        .select(["no_nulls", "nulls"]),
        df.select(pl.all().exclude("id").std()),
    )


def test_implode_and_agg() -> None:
    df = pl.DataFrame({"type": ["water", "fire", "water", "earth"]})

    # This would read out of bounds.
    with pytest.raises(
        InvalidOperationError,
        match=r"'implode' followed by an aggregation is not allowed",
    ):
        df.group_by("type").agg(pl.col("type").implode().first().alias("foo"))

    # implode + function should be allowed in group_by
    assert df.group_by("type", maintain_order=True).agg(
        pl.col("type").implode().list.head().alias("foo")
    ).to_dict(as_series=False) == {
        "type": ["water", "fire", "earth"],
        "foo": [["water", "water"], ["fire"], ["earth"]],
    }
    assert df.select(pl.col("type").implode().list.head(1).over("type")).to_dict(
        as_series=False
    ) == {"type": [["water"], ["fire"], ["water"], ["earth"]]}


def test_mapped_literal_to_literal_9217() -> None:
    df = pl.DataFrame({"unique_id": ["a", "b"]})
    assert df.group_by(True).agg(
        pl.struct(pl.lit("unique_id").alias("unique_id"))
    ).to_dict(as_series=False) == {
        "literal": [True],
        "unique_id": [{"unique_id": "unique_id"}],
    }


def test_sum_empty_and_null_set() -> None:
    series = pl.Series("a", [], dtype=pl.Float32)
    assert series.sum() == 0

    series = pl.Series("a", [None], dtype=pl.Float32)
    assert series.sum() == 0

    df = pl.DataFrame(
        {"a": [None, None, None], "b": [1, 1, 1]},
        schema={"a": pl.Float32, "b": pl.Int64},
    )
    assert df.select(pl.sum("a")).item() == 0.0
    assert df.group_by("b").agg(pl.sum("a"))["a"].item() == 0.0
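# Note on test_sum_empty_and_null_set above: summing an empty or an all-null
# set yields the additive identity 0 rather than null.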


def test_horizontal_sum_null_to_identity() -> None:
    assert pl.DataFrame({"a": [1, 5], "b": [10, None]}).select(
        pl.sum_horizontal(["a", "b"])
    ).to_series().to_list() == [11, 5]
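# The same identity applies row-wise: nulls contribute nothing to
# sum_horizontal, so the rows above sum to 1 + 10 = 11 and 5 + 0 = 5.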


def test_horizontal_sum_bool_dtype() -> None:
    out = pl.DataFrame({"a": [True, False]}).select(pl.sum_horizontal("a"))
    assert_frame_equal(out, pl.DataFrame({"a": pl.Series([1, 0], dtype=pl.UInt32)}))


def test_horizontal_sum_in_group_by_15102() -> None:
    nbr_records = 1000
    out = (
        pl.LazyFrame(
            {
                "x": [None, "two", None] * nbr_records,
                "y": ["one", "two", None] * nbr_records,
                "z": [None, "two", None] * nbr_records,
            }
        )
        .select(pl.sum_horizontal(pl.all().is_null()).alias("num_null"))
        .group_by("num_null")
        .len()
        .sort(by="num_null")
        .collect()
    )
    assert_frame_equal(
        out,
        pl.DataFrame(
            {
                "num_null": pl.Series([0, 2, 3], dtype=pl.UInt32),
                "len": pl.Series([nbr_records] * 3, dtype=pl.UInt32),
            }
        ),
    )
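# Worked counts for test_horizontal_sum_in_group_by_15102 above: the repeating
# row patterns are (None, "one", None) -> 2 nulls, ("two", "two", "two") -> 0
# and (None, None, None) -> 3, so each count 0, 2 and 3 occurs nbr_records
# times.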


def test_first_last_unit_length_12363() -> None:
    df = pl.DataFrame(
        {
            "a": [1, 2],
            "b": [None, None],
        }
    )

    assert df.select(
        pl.all().drop_nulls().first().name.suffix("_first"),
        pl.all().drop_nulls().last().name.suffix("_last"),
    ).to_dict(as_series=False) == {
        "a_first": [1],
        "b_first": [None],
        "a_last": [2],
        "b_last": [None],
    }
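# Note on test_first_last_unit_length_12363 above: drop_nulls leaves column "b"
# empty, and first/last on an empty column yield null rather than raising.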


def test_binary_op_agg_context_no_simplify_expr_12423() -> None:
    expect = pl.DataFrame({"x": [1], "y": [1]}, schema={"x": pl.Int64, "y": pl.Int32})

    for simplify_expression in (True, False):
        assert_frame_equal(
            expect,
            pl.LazyFrame({"x": [1]})
            .group_by("x")
            .agg(y=pl.lit(1) * pl.lit(1))
            .collect(
                optimizations=pl.QueryOptFlags(simplify_expression=simplify_expression)
            ),
        )


def test_nan_inf_aggregation() -> None:
    df = pl.DataFrame(
        [
            ("both nan", np.nan),
            ("both nan", np.nan),
            ("nan and 5", np.nan),
            ("nan and 5", 5),
            ("nan and null", np.nan),
            ("nan and null", None),
            ("both none", None),
            ("both none", None),
            ("both inf", np.inf),
            ("both inf", np.inf),
            ("inf and null", np.inf),
            ("inf and null", None),
        ],
        schema=["group", "value"],
        orient="row",
    )

    assert_frame_equal(
        df.group_by("group", maintain_order=True).agg(
            min=pl.col("value").min(),
            max=pl.col("value").max(),
            mean=pl.col("value").mean(),
        ),
        pl.DataFrame(
            [
                ("both nan", np.nan, np.nan, np.nan),
                ("nan and 5", 5, 5, np.nan),
                ("nan and null", np.nan, np.nan, np.nan),
                ("both none", None, None, None),
                ("both inf", np.inf, np.inf, np.inf),
                ("inf and null", np.inf, np.inf, np.inf),
            ],
            schema=["group", "min", "max", "mean"],
            orient="row",
        ),
    )
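# Semantics encoded by the expected frame in test_nan_inf_aggregation above:
# nulls are skipped by min/max/mean; NaN propagates through mean, while min and
# max return a non-NaN value whenever one exists in the group
# ("nan and 5" -> min = max = 5).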


@pytest.mark.parametrize("dtype", [pl.Int16, pl.UInt16])
def test_int16_max_12904(dtype: PolarsDataType) -> None:
    s = pl.Series([None, 1], dtype=dtype)

    assert s.min() == 1
    assert s.max() == 1


def test_agg_filter_over_empty_df_13610() -> None:
    ldf = pl.LazyFrame(
        {
            "a": [1, 1, 1, 2, 3],
            "b": [True, True, True, True, True],
            "c": [None, None, None, None, None],
        }
    )

    out = (
        ldf.drop_nulls()
        .group_by(["a"], maintain_order=True)
        .agg(pl.col("b").filter(pl.col("b").shift(1)))
        .collect()
    )
    expected = pl.DataFrame(schema={"a": pl.Int64, "b": pl.List(pl.Boolean)})
    assert_frame_equal(out, expected)

    df = pl.DataFrame(schema={"a": pl.Int64, "b": pl.Boolean})
    out = df.group_by("a").agg(pl.col("b").filter(pl.col("b").shift()))
    expected = pl.DataFrame(schema={"a": pl.Int64, "b": pl.List(pl.Boolean)})
    assert_frame_equal(out, expected)


@pytest.mark.may_fail_cloud  # reason: output order is defined for this in cloud
@pytest.mark.slow
def test_agg_empty_sum_after_filter_14734() -> None:
    f = (
        pl.DataFrame({"a": [1, 2], "b": [1, 2]})
        .lazy()
        .group_by("a")
        .agg(pl.col("b").filter(pl.lit(False)).sum())
        .collect
    )

    last = f()

    # We need to observe both possible output orders, which should happen
    # within 1000 iterations (during testing it usually happens within 10).
    limit = 1000
    i = 0
    while (curr := f()).equals(last):
        i += 1
        assert i != limit

    expect = pl.Series("b", [0, 0]).to_frame()
    assert_frame_equal(expect, last.select("b"))
    assert_frame_equal(expect, curr.select("b"))


@pytest.mark.slow
def test_grouping_hash_14749() -> None:
    n_groups = 251
    rows_per_group = 4
    assert (
        pl.DataFrame(
            {
                "grp": np.repeat(np.arange(n_groups), rows_per_group),
                "x": np.tile(np.arange(rows_per_group), n_groups),
            }
        )
        .select(pl.col("x").max().over("grp"))["x"]
        .value_counts()
    ).to_dict(as_series=False) == {"x": [3], "count": [1004]}
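# Worked count for test_grouping_hash_14749 above: "x" tiles 0..3 within every
# group, so max().over("grp") broadcasts 3 to all n_groups * rows_per_group
# = 251 * 4 = 1004 rows.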


@pytest.mark.parametrize(
    ("in_dtype", "out_dtype"),
    [
        (pl.Boolean, pl.Float64),
        (pl.UInt8, pl.Float64),
        (pl.UInt16, pl.Float64),
        (pl.UInt32, pl.Float64),
        (pl.UInt64, pl.Float64),
        (pl.Int8, pl.Float64),
        (pl.Int16, pl.Float64),
        (pl.Int32, pl.Float64),
        (pl.Int64, pl.Float64),
        (pl.Float32, pl.Float32),
        (pl.Float64, pl.Float64),
    ],
)
def test_horizontal_mean_single_column(
    in_dtype: PolarsDataType,
    out_dtype: PolarsDataType,
) -> None:
    out = (
        pl.LazyFrame({"a": pl.Series([1, 0]).cast(in_dtype)})
        .select(pl.mean_horizontal(pl.all()))
        .collect()
    )

    assert_frame_equal(out, pl.DataFrame({"a": pl.Series([1.0, 0.0]).cast(out_dtype)}))


def test_horizontal_mean_in_group_by_15115() -> None:
    nbr_records = 1000
    out = (
        pl.LazyFrame(
            {
                "w": [None, "one", "two", "three"] * nbr_records,
                "x": [None, None, "two", "three"] * nbr_records,
                "y": [None, None, None, "three"] * nbr_records,
                "z": [None, None, None, None] * nbr_records,
            }
        )
        .select(pl.mean_horizontal(pl.all().is_null()).alias("mean_null"))
        .group_by("mean_null")
        .len()
        .sort(by="mean_null")
        .collect()
    )
    assert_frame_equal(
        out,
        pl.DataFrame(
            {
                "mean_null": pl.Series([0.25, 0.5, 0.75, 1.0], dtype=pl.Float64),
                "len": pl.Series([nbr_records] * 4, dtype=pl.UInt32),
            }
        ),
    )
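# Worked fractions for test_horizontal_mean_in_group_by_15115 above: the four
# repeating row patterns contain 4, 3, 2 and 1 nulls across the four columns,
# giving null fractions 1.0, 0.75, 0.5 and 0.25, each occurring nbr_records
# times.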


def test_group_count_over_null_column_15705() -> None:
    df = pl.DataFrame(
        {"a": [1, 1, 2, 2, 3, 3], "c": [None, None, None, None, None, None]}
    )
    out = df.group_by("a", maintain_order=True).agg(pl.col("c").count())
    assert out["c"].to_list() == [0, 0, 0]


@pytest.mark.release
def test_min_max_2850() -> None:
    # https://github.com/pola-rs/polars/issues/2850
    df = pl.DataFrame(
        {
            "id": [
                130352432,
                130352277,
                130352611,
                130352833,
                130352305,
                130352258,
                130352764,
                130352475,
                130352368,
                130352346,
            ]
        }
    )

    minimum = 130352258
    maximum = 130352833.0

    for _ in range(10):
        permuted = df.sample(fraction=1.0, seed=0)
        computed = permuted.select(
            pl.col("id").min().alias("min"), pl.col("id").max().alias("max")
        )
        assert cast(int, computed[0, "min"]) == minimum
        assert cast(float, computed[0, "max"]) == maximum


def test_multi_arg_structify_15834() -> None:
    df = pl.DataFrame(
        {
            "group": [1, 2, 1, 2],
            "value": [
                0.1973209146402105,
                0.13380719982405365,
                0.6152394463707009,
                0.4558767896005155,
            ],
        }
    )

    assert df.lazy().group_by("group").agg(
        pl.struct(a=1, value=pl.col("value").sum())
    ).collect().sort("group").to_dict(as_series=False) == {
        "group": [1, 2],
        "a": [
            {"a": 1, "value": 0.8125603610109114},
            {"a": 1, "value": 0.5896839894245691},
        ],
    }


def test_filter_aggregation_16642() -> None:
    df = pl.DataFrame(
        {
            "datetime": [
                datetime(2022, 1, 1, 11, 0),
                datetime(2022, 1, 1, 11, 1),
                datetime(2022, 1, 1, 11, 2),
                datetime(2022, 1, 1, 11, 3),
                datetime(2022, 1, 1, 11, 4),
                datetime(2022, 1, 1, 11, 5),
                datetime(2022, 1, 1, 11, 6),
                datetime(2022, 1, 1, 11, 7),
                datetime(2022, 1, 1, 11, 8),
                datetime(2022, 1, 1, 11, 9, 1),
                datetime(2022, 1, 2, 11, 0),
                datetime(2022, 1, 2, 11, 1),
                datetime(2022, 1, 2, 11, 2),
                datetime(2022, 1, 2, 11, 3),
                datetime(2022, 1, 2, 11, 4),
                datetime(2022, 1, 2, 11, 5),
                datetime(2022, 1, 2, 11, 6),
                datetime(2022, 1, 2, 11, 7),
                datetime(2022, 1, 2, 11, 8),
                datetime(2022, 1, 2, 11, 9, 1),
            ],
            "alpha": [
                "A",
                "B",
                "C",
                "D",
                "E",
                "F",
                "G",
                "H",
                "I",
                "J",
                "A",
                "B",
                "C",
                "D",
                "E",
                "F",
                "G",
                "H",
                "I",
                "J",
            ],
            "num": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        }
    )
    grouped = df.group_by(pl.col("datetime").dt.date())

    ts_filter = pl.col("datetime").dt.time() <= pl.time(11, 3)

    report = grouped.agg(pl.col("num").filter(ts_filter).max()).sort("datetime")
    assert report.to_dict(as_series=False) == {
        "datetime": [date(2022, 1, 1), date(2022, 1, 2)],
        "num": [3, 3],
    }


def test_sort_by_over_single_nulls_first() -> None:
    key = [0, 0, 0, 0, 1, 1, 1, 1]
    df = pl.DataFrame(
        {
            "key": key,
            "value": [2, None, 1, 0, 2, None, 1, 0],
        }
    )
    out = df.select(
        pl.all().sort_by("value", nulls_last=False, maintain_order=True).over("key")
    )
    expected = pl.DataFrame(
        {
            "key": key,
            "value": [None, 0, 1, 2, None, 0, 1, 2],
        }
    )
    assert_frame_equal(out, expected)


def test_sort_by_over_single_nulls_last() -> None:
    key = [0, 0, 0, 0, 1, 1, 1, 1]
    df = pl.DataFrame(
        {
            "key": key,
            "value": [2, None, 1, 0, 2, None, 1, 0],
        }
    )
    out = df.select(
        pl.all().sort_by("value", nulls_last=True, maintain_order=True).over("key")
    )
    expected = pl.DataFrame(
        {
            "key": key,
            "value": [0, 1, 2, None, 0, 1, 2, None],
        }
    )
    assert_frame_equal(out, expected)


def test_sort_by_over_multiple_nulls_first() -> None:
    key1 = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
    key2 = [0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1]
    df = pl.DataFrame(
        {
            "key1": key1,
            "key2": key2,
            "value": [1, None, 0, 1, None, 0, 1, None, 0, None, 1, 0],
        }
    )
    out = df.select(
        pl.all()
        .sort_by("value", nulls_last=False, maintain_order=True)
        .over("key1", "key2")
    )
    expected = pl.DataFrame(
        {
            "key1": key1,
            "key2": key2,
            "value": [None, 0, 1, None, 0, 1, None, 0, 1, None, 0, 1],
        }
    )
    assert_frame_equal(out, expected)


def test_sort_by_over_multiple_nulls_last() -> None:
    key1 = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
    key2 = [0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1]
    df = pl.DataFrame(
        {
            "key1": key1,
            "key2": key2,
            "value": [1, None, 0, 1, None, 0, 1, None, 0, None, 1, 0],
        }
    )
    out = df.select(
        pl.all()
        .sort_by("value", nulls_last=True, maintain_order=True)
        .over("key1", "key2")
    )
    expected = pl.DataFrame(
        {
            "key1": key1,
            "key2": key2,
            "value": [0, 1, None, 0, 1, None, 0, 1, None, 0, 1, None],
        }
    )
    assert_frame_equal(out, expected)


def test_slice_after_agg_raises() -> None:
    with pytest.raises(
        InvalidOperationError, match=r"cannot slice\(\) an aggregated scalar value"
    ):
        pl.select(a=1, b=1).group_by("a").agg(pl.col("b").first().slice(99, 0))


def test_agg_scalar_empty_groups_20115() -> None:
    assert_frame_equal(
        (
            pl.DataFrame({"key": [123], "value": [456]})
            .group_by("key")
            .agg(pl.col("value").slice(1, 1).first())
        ),
        pl.select(key=pl.lit(123, pl.Int64), value=pl.lit(None, pl.Int64)),
    )
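# Note on test_agg_scalar_empty_groups_20115 above: slice(1, 1) on a
# single-element group selects nothing, and first() on an empty aggregation
# yields null instead of raising.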


def test_agg_expr_returns_list_type_15574() -> None:
    assert (
        pl.LazyFrame({"a": [1, None], "b": [1, 2]})
        .group_by("b")
        .agg(pl.col("a").drop_nulls())
        .collect_schema()
    ) == {"b": pl.Int64, "a": pl.List(pl.Int64)}


def test_empty_agg_22005() -> None:
    out = (
        pl.concat([pl.LazyFrame({"a": [1, 2]}), pl.LazyFrame({"a": [1, 2]})])
        .limit(0)
        .select(pl.col("a").sum())
    )
    assert_frame_equal(out.collect(), pl.DataFrame({"a": 0}))


@pytest.mark.parametrize("wrap_numerical", [True, False])
@pytest.mark.parametrize("strict_cast", [True, False])
def test_agg_with_filter_then_cast_23682(
    strict_cast: bool, wrap_numerical: bool
) -> None:
    assert_frame_equal(
        pl.DataFrame([{"a": 123, "b": 12}, {"a": 123, "b": 257}])
        .group_by("a")
        .agg(
            pl.col("b")
            .filter(pl.col("b") < 256)
            .cast(pl.UInt8, strict=strict_cast, wrap_numerical=wrap_numerical)
        ),
        pl.DataFrame(
            [{"a": 123, "b": [12]}], schema={"a": pl.Int64, "b": pl.List(pl.UInt8)}
        ),
    )


@pytest.mark.parametrize("wrap_numerical", [True, False])
@pytest.mark.parametrize("strict_cast", [True, False])
def test_agg_with_slice_then_cast_23682(
    strict_cast: bool, wrap_numerical: bool
) -> None:
    assert_frame_equal(
        pl.DataFrame([{"a": 123, "b": 12}, {"a": 123, "b": 257}])
        .group_by("a")
        .agg(
            pl.col("b")
            .slice(0, 1)
            .cast(pl.UInt8, strict=strict_cast, wrap_numerical=wrap_numerical)
        ),
        pl.DataFrame(
            [{"a": 123, "b": [12]}], schema={"a": pl.Int64, "b": pl.List(pl.UInt8)}
        ),
    )