from __future__ import annotations

import typing
from collections import OrderedDict
from datetime import date, datetime, time, timedelta
from typing import TYPE_CHECKING, Any
from zoneinfo import ZoneInfo

import numpy as np
import pytest
from hypothesis import given

import polars as pl
import polars.selectors as cs
from polars import Expr
from polars.exceptions import (
    ColumnNotFoundError,
    InvalidOperationError,
)
from polars.meta import get_index_type
from polars.testing import assert_frame_equal, assert_series_equal
from polars.testing.parametric import column, dataframes, series

if TYPE_CHECKING:
    from collections.abc import Callable

    from polars._typing import PolarsDataType, TimeUnit
    from tests.conftest import PlMonkeyPatch


def test_group_by() -> None:
    df = pl.DataFrame(
        {
            "a": ["a", "b", "a", "b", "b", "c"],
            "b": [1, 2, 3, 4, 5, 6],
            "c": [6, 5, 4, 3, 2, 1],
        }
    )

    # Use lazy API in eager group_by
    assert sorted(df.group_by("a").agg([pl.sum("b")]).rows()) == [
        ("a", 4),
        ("b", 11),
        ("c", 6),
    ]
    # test if it accepts a single expression
    assert df.group_by("a", maintain_order=True).agg(pl.sum("b")).rows() == [
        ("a", 4),
        ("b", 11),
        ("c", 6),
    ]

    df = pl.DataFrame(
        {
            "a": [1, 2, 3, 4, 5],
            "b": ["a", "a", "b", "b", "b"],
            "c": [None, 1, None, 1, None],
        }
    )

    # check if this query runs and thus column names propagate
    df.group_by("b").agg(pl.col("c").fill_null(strategy="forward")).explode("c")

    # get a specific column
    result = df.group_by("b", maintain_order=True).agg(pl.count("a"))
    assert result.rows() == [("a", 2), ("b", 3)]
    assert result.columns == ["b", "a"]


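# The parametrization below documents the dtype policy of `mean` in a group by:
# integer and boolean inputs yield Float64, floats keep their width, Date yields
# Datetime("us"), and Datetime/Duration keep their time unit.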
@pytest.mark.parametrize(
    ("input", "expected", "input_dtype", "output_dtype"),
    [
        ([1, 2, 3, 4], [2, 4], pl.UInt8, pl.Float64),
        ([1, 2, 3, 4], [2, 4], pl.Int8, pl.Float64),
        ([1, 2, 3, 4], [2, 4], pl.UInt16, pl.Float64),
        ([1, 2, 3, 4], [2, 4], pl.Int16, pl.Float64),
        ([1, 2, 3, 4], [2, 4], pl.UInt32, pl.Float64),
        ([1, 2, 3, 4], [2, 4], pl.Int32, pl.Float64),
        ([1, 2, 3, 4], [2, 4], pl.UInt64, pl.Float64),
        ([1, 2, 3, 4], [2, 4], pl.Float32, pl.Float32),
        ([1, 2, 3, 4], [2, 4], pl.Float64, pl.Float64),
        ([False, True, True, True], [2 / 3, 1], pl.Boolean, pl.Float64),
        (
            [date(2023, 1, 1), date(2023, 1, 2), date(2023, 1, 4), date(2023, 1, 5)],
            [datetime(2023, 1, 2, 8, 0, 0), datetime(2023, 1, 5)],
            pl.Date,
            pl.Datetime("us"),
        ),
        (
            [
                datetime(2023, 1, 1),
                datetime(2023, 1, 2),
                datetime(2023, 1, 3),
                datetime(2023, 1, 4),
            ],
            [datetime(2023, 1, 2), datetime(2023, 1, 4)],
            pl.Datetime("ms"),
            pl.Datetime("ms"),
        ),
        (
            [
                datetime(2023, 1, 1),
                datetime(2023, 1, 2),
                datetime(2023, 1, 3),
                datetime(2023, 1, 4),
            ],
            [datetime(2023, 1, 2), datetime(2023, 1, 4)],
            pl.Datetime("us"),
            pl.Datetime("us"),
        ),
        (
            [
                datetime(2023, 1, 1),
                datetime(2023, 1, 2),
                datetime(2023, 1, 3),
                datetime(2023, 1, 4),
            ],
            [datetime(2023, 1, 2), datetime(2023, 1, 4)],
            pl.Datetime("ns"),
            pl.Datetime("ns"),
        ),
        (
            [timedelta(1), timedelta(2), timedelta(3), timedelta(4)],
            [timedelta(2), timedelta(4)],
            pl.Duration("ms"),
            pl.Duration("ms"),
        ),
        (
            [timedelta(1), timedelta(2), timedelta(3), timedelta(4)],
            [timedelta(2), timedelta(4)],
            pl.Duration("us"),
            pl.Duration("us"),
        ),
        (
            [timedelta(1), timedelta(2), timedelta(3), timedelta(4)],
            [timedelta(2), timedelta(4)],
            pl.Duration("ns"),
            pl.Duration("ns"),
        ),
    ],
)
def test_group_by_mean_by_dtype(
    input: list[Any],
    expected: list[Any],
    input_dtype: PolarsDataType,
    output_dtype: PolarsDataType,
) -> None:
    # groups are defined by the first 3 values, then the last value
    name = str(input_dtype)
    key = ["a", "a", "a", "b"]
    df = pl.LazyFrame(
        {
            "key": key,
            name: pl.Series(input, dtype=input_dtype),
        }
    )
    result = df.group_by("key", maintain_order=True).mean()
    df_expected = pl.DataFrame(
        {
            "key": ["a", "b"],
            name: pl.Series(expected, dtype=output_dtype),
        }
    )
    assert result.collect_schema() == df_expected.schema
    assert_frame_equal(result.collect(), df_expected)


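# `median` follows the same dtype policy as `mean` above; only the expected
# values differ (true medians rather than means).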
@pytest.mark.parametrize(
    ("input", "expected", "input_dtype", "output_dtype"),
    [
        ([1, 2, 4, 5], [2, 5], pl.UInt8, pl.Float64),
        ([1, 2, 4, 5], [2, 5], pl.Int8, pl.Float64),
        ([1, 2, 4, 5], [2, 5], pl.UInt16, pl.Float64),
        ([1, 2, 4, 5], [2, 5], pl.Int16, pl.Float64),
        ([1, 2, 4, 5], [2, 5], pl.UInt32, pl.Float64),
        ([1, 2, 4, 5], [2, 5], pl.Int32, pl.Float64),
        ([1, 2, 4, 5], [2, 5], pl.UInt64, pl.Float64),
        ([1, 2, 4, 5], [2, 5], pl.Float32, pl.Float32),
        ([1, 2, 4, 5], [2, 5], pl.Float64, pl.Float64),
        ([False, True, True, True], [1, 1], pl.Boolean, pl.Float64),
        (
            [date(2023, 1, 1), date(2023, 1, 2), date(2023, 1, 4), date(2023, 1, 5)],
            [datetime(2023, 1, 2), datetime(2023, 1, 5)],
            pl.Date,
            pl.Datetime("us"),
        ),
        (
            [
                datetime(2023, 1, 1),
                datetime(2023, 1, 2),
                datetime(2023, 1, 4),
                datetime(2023, 1, 5),
            ],
            [datetime(2023, 1, 2), datetime(2023, 1, 5)],
            pl.Datetime("ms"),
            pl.Datetime("ms"),
        ),
        (
            [
                datetime(2023, 1, 1),
                datetime(2023, 1, 2),
                datetime(2023, 1, 4),
                datetime(2023, 1, 5),
            ],
            [datetime(2023, 1, 2), datetime(2023, 1, 5)],
            pl.Datetime("us"),
            pl.Datetime("us"),
        ),
        (
            [
                datetime(2023, 1, 1),
                datetime(2023, 1, 2),
                datetime(2023, 1, 4),
                datetime(2023, 1, 5),
            ],
            [datetime(2023, 1, 2), datetime(2023, 1, 5)],
            pl.Datetime("ns"),
            pl.Datetime("ns"),
        ),
        (
            [timedelta(1), timedelta(2), timedelta(4), timedelta(5)],
            [timedelta(2), timedelta(5)],
            pl.Duration("ms"),
            pl.Duration("ms"),
        ),
        (
            [timedelta(1), timedelta(2), timedelta(4), timedelta(5)],
            [timedelta(2), timedelta(5)],
            pl.Duration("us"),
            pl.Duration("us"),
        ),
        (
            [timedelta(1), timedelta(2), timedelta(4), timedelta(5)],
            [timedelta(2), timedelta(5)],
            pl.Duration("ns"),
            pl.Duration("ns"),
        ),
    ],
)
def test_group_by_median_by_dtype(
    input: list[Any],
    expected: list[Any],
    input_dtype: PolarsDataType,
    output_dtype: PolarsDataType,
) -> None:
    # groups are defined by the first 3 values, then the last value
    name = str(input_dtype)
    key = ["a", "a", "a", "b"]
    df = pl.LazyFrame(
        {
            "key": key,
            name: pl.Series(input, dtype=input_dtype),
        }
    )
    result = df.group_by("key", maintain_order=True).median()
    df_expected = pl.DataFrame(
        {
            "key": ["a", "b"],
            name: pl.Series(expected, dtype=output_dtype),
        }
    )
    assert result.collect_schema() == df_expected.schema
    assert_frame_equal(result.collect(), df_expected)


@pytest.fixture
def df() -> pl.DataFrame:
    return pl.DataFrame(
        {
            "a": [1, 2, 3, 4, 5],
            "b": ["a", "a", "b", "b", "b"],
            "c": [None, 1, None, 1, None],
        }
    )


@pytest.mark.parametrize(
    ("method", "expected"),
    [
        ("all", [("a", [1, 2], [None, 1]), ("b", [3, 4, 5], [None, 1, None])]),
        ("len", [("a", 2), ("b", 3)]),
        ("first", [("a", 1, None), ("b", 3, None)]),
        ("last", [("a", 2, 1), ("b", 5, None)]),
        ("max", [("a", 2, 1), ("b", 5, 1)]),
        ("mean", [("a", 1.5, 1.0), ("b", 4.0, 1.0)]),
        ("median", [("a", 1.5, 1.0), ("b", 4.0, 1.0)]),
        ("min", [("a", 1, 1), ("b", 3, 1)]),
        ("n_unique", [("a", 2, 2), ("b", 3, 2)]),
    ],
)
def test_group_by_shorthands(
    df: pl.DataFrame, method: str, expected: list[tuple[Any]]
) -> None:
    gb = df.group_by("b", maintain_order=True)
    result = getattr(gb, method)()
    assert result.rows() == expected

    gb_lazy = df.lazy().group_by("b", maintain_order=True)
    result = getattr(gb_lazy, method)().collect()
    assert result.rows() == expected


def test_group_by_shorthand_quantile(df: pl.DataFrame) -> None:
    result = df.group_by("b", maintain_order=True).quantile(0.5)
    expected = [("a", 2.0, 1.0), ("b", 4.0, 1.0)]
    assert result.rows() == expected

    result = df.lazy().group_by("b", maintain_order=True).quantile(0.5).collect()
    assert result.rows() == expected


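# The quantile tests below cover all four interpolation strategies on temporal
# dtypes; note that "linear" interpolation on Date can produce sub-day values,
# which is presumably why the result dtype is Datetime("us") rather than Date.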
def test_group_by_quantile_date() -> None:
    df = pl.DataFrame(
        {
            "group": [1, 1, 1, 1, 2, 2, 2, 2],
            "value": [date(2025, 1, x) for x in range(1, 9)],
        }
    )
    result = (
        df.lazy()
        .group_by("group", maintain_order=True)
        .agg(
            nearest=pl.col("value").quantile(0.5, "nearest"),
            higher=pl.col("value").quantile(0.5, "higher"),
            lower=pl.col("value").quantile(0.5, "lower"),
            linear=pl.col("value").quantile(0.5, "linear"),
        )
    )
    dt = pl.Datetime("us")
    expected = pl.DataFrame(
        {
            "group": [1, 2],
            "nearest": pl.Series(
                [datetime(2025, 1, 3), datetime(2025, 1, 7)], dtype=dt
            ),
            "higher": pl.Series([datetime(2025, 1, 3), datetime(2025, 1, 7)], dtype=dt),
            "lower": pl.Series([datetime(2025, 1, 2), datetime(2025, 1, 6)], dtype=dt),
            "linear": pl.Series(
                [datetime(2025, 1, 2, 12), datetime(2025, 1, 6, 12)], dtype=dt
            ),
        }
    )
    assert result.collect_schema() == pl.Schema(
        {  # type: ignore[arg-type]
            "group": pl.Int64,
            "nearest": dt,
            "higher": dt,
            "lower": dt,
            "linear": dt,
        }
    )
    assert_frame_equal(result.collect(), expected)


@pytest.mark.parametrize("tu", ["ms", "us", "ns"])
@pytest.mark.parametrize("time_zone", [None, "Asia/Tokyo"])
def test_group_by_quantile_datetime(tu: TimeUnit, time_zone: str | None) -> None:
    dt = pl.Datetime(tu, time_zone)
    tz = ZoneInfo(time_zone) if time_zone else None
    df = pl.DataFrame(
        {
            "group": [1, 1, 1, 1, 2, 2, 2, 2],
            "value": pl.Series(
                [datetime(2025, 1, x, tzinfo=tz) for x in range(1, 9)],
                dtype=dt,
            ),
        }
    )
    result = (
        df.lazy()
        .group_by("group", maintain_order=True)
        .agg(
            nearest=pl.col("value").quantile(0.5, "nearest"),
            higher=pl.col("value").quantile(0.5, "higher"),
            lower=pl.col("value").quantile(0.5, "lower"),
            linear=pl.col("value").quantile(0.5, "linear"),
        )
    )
    expected = pl.DataFrame(
        {
            "group": [1, 2],
            "nearest": pl.Series(
                [datetime(2025, 1, 3, tzinfo=tz), datetime(2025, 1, 7, tzinfo=tz)],
                dtype=dt,
            ),
            "higher": pl.Series(
                [datetime(2025, 1, 3, tzinfo=tz), datetime(2025, 1, 7, tzinfo=tz)],
                dtype=dt,
            ),
            "lower": pl.Series(
                [datetime(2025, 1, 2, tzinfo=tz), datetime(2025, 1, 6, tzinfo=tz)],
                dtype=dt,
            ),
            "linear": pl.Series(
                [
                    datetime(2025, 1, 2, 12, tzinfo=tz),
                    datetime(2025, 1, 6, 12, tzinfo=tz),
                ],
                dtype=dt,
            ),
        }
    )
    assert result.collect_schema() == pl.Schema(
        {  # type: ignore[arg-type]
            "group": pl.Int64,
            "nearest": dt,
            "higher": dt,
            "lower": dt,
            "linear": dt,
        }
    )
    assert_frame_equal(result.collect(), expected)


@pytest.mark.parametrize("tu", ["ms", "us", "ns"])
def test_group_by_quantile_duration(tu: TimeUnit) -> None:
    dt = pl.Duration(tu)
    df = pl.DataFrame(
        {
            "group": [1, 1, 1, 1, 2, 2, 2, 2],
            "value": pl.Series([timedelta(hours=x) for x in range(1, 9)], dtype=dt),
        }
    )
    result = (
        df.lazy()
        .group_by("group", maintain_order=True)
        .agg(
            nearest=pl.col("value").quantile(0.5, "nearest"),
            higher=pl.col("value").quantile(0.5, "higher"),
            lower=pl.col("value").quantile(0.5, "lower"),
            linear=pl.col("value").quantile(0.5, "linear"),
        )
    )
    expected = pl.DataFrame(
        {
            "group": [1, 2],
            "nearest": pl.Series([timedelta(hours=3), timedelta(hours=7)], dtype=dt),
            "higher": pl.Series([timedelta(hours=3), timedelta(hours=7)], dtype=dt),
            "lower": pl.Series([timedelta(hours=2), timedelta(hours=6)], dtype=dt),
            "linear": pl.Series(
                [timedelta(hours=2, minutes=30), timedelta(hours=6, minutes=30)],
                dtype=dt,
            ),
        }
    )
    assert result.collect_schema() == pl.Schema(
        {  # type: ignore[arg-type]
            "group": pl.Int64,
            "nearest": dt,
            "higher": dt,
            "lower": dt,
            "linear": dt,
        }
    )
    assert_frame_equal(result.collect(), expected)


def test_group_by_quantile_time() -> None:
    df = pl.DataFrame(
        {
            "group": [1, 1, 1, 1, 2, 2, 2, 2],
            "value": pl.Series([time(hour=x) for x in range(1, 9)]),
        }
    )
    result = (
        df.lazy()
        .group_by("group", maintain_order=True)
        .agg(
            nearest=pl.col("value").quantile(0.5, "nearest"),
            higher=pl.col("value").quantile(0.5, "higher"),
            lower=pl.col("value").quantile(0.5, "lower"),
            linear=pl.col("value").quantile(0.5, "linear"),
        )
    )
    expected = pl.DataFrame(
        {
            "group": [1, 2],
            "nearest": pl.Series([time(hour=3), time(hour=7)]),
            "higher": pl.Series([time(hour=3), time(hour=7)]),
            "lower": pl.Series([time(hour=2), time(hour=6)]),
            "linear": pl.Series([time(hour=2, minute=30), time(hour=6, minute=30)]),
        }
    )
    assert result.collect_schema() == pl.Schema(
        {
            "group": pl.Int64,
            "nearest": pl.Time,
            "higher": pl.Time,
            "lower": pl.Time,
            "linear": pl.Time,
        }
    )
    assert_frame_equal(result.collect(), expected)


def test_group_by_args() -> None:
    df = pl.DataFrame(
        {
            "a": ["a", "b", "a", "b", "b", "c"],
            "b": [1, 2, 3, 4, 5, 6],
            "c": [6, 5, 4, 3, 2, 1],
        }
    )

    # Single column name
    assert df.group_by("a").agg("b").columns == ["a", "b"]
    # Column names as list
    expected = ["a", "b", "c"]
    assert df.group_by(["a", "b"]).agg("c").columns == expected
    # Column names as positional arguments
    assert df.group_by("a", "b").agg("c").columns == expected
    # With keyword argument
    assert df.group_by("a", "b", maintain_order=True).agg("c").columns == expected
    # Multiple aggregations as list
    assert df.group_by("a").agg(["b", "c"]).columns == expected
    # Multiple aggregations as positional arguments
    assert df.group_by("a").agg("b", "c").columns == expected
    # Multiple aggregations as keyword arguments
    assert df.group_by("a").agg(q="b", r="c").columns == ["a", "q", "r"]


def test_group_by_empty() -> None:
    df = pl.DataFrame({"a": [1, 1, 2]})
    result = df.group_by("a").agg()
    expected = pl.DataFrame({"a": [1, 2]})
    assert_frame_equal(result, expected, check_row_order=False)


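# Iterating over a GroupBy yields (key_tuple, DataFrame) pairs, e.g.:
#
#     for (key,), frame in df.group_by("foo", maintain_order=True):
#         ...
#
# The key is always a tuple, even for a single grouping column.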
def test_group_by_iteration() -> None:
    df = pl.DataFrame(
        {
            "foo": ["a", "b", "a", "b", "b", "c"],
            "bar": [1, 2, 3, 4, 5, 6],
            "baz": [6, 5, 4, 3, 2, 1],
        }
    )
    expected_names = ["a", "b", "c"]
    expected_rows = [
        [("a", 1, 6), ("a", 3, 4)],
        [("b", 2, 5), ("b", 4, 3), ("b", 5, 2)],
        [("c", 6, 1)],
    ]
    gb_iter = enumerate(df.group_by("foo", maintain_order=True))
    for i, (group, data) in gb_iter:
        assert group == (expected_names[i],)
        assert data.rows() == expected_rows[i]

    # Grouping by ALL columns should give groups of a single row
    result = list(df.group_by(["foo", "bar", "baz"]))
    assert len(result) == 6

    # Iterating over groups should also work when grouping by expressions
    result2 = list(df.group_by(["foo", pl.col("bar") * pl.col("baz")]))
    assert len(result2) == 5

    # Single expression, alias in group_by
    df = pl.DataFrame({"foo": [1, 2, 3, 4, 5, 6]})
    gb = df.group_by((pl.col("foo") // 2).alias("bar"), maintain_order=True)
    result3 = [(group, df.rows()) for group, df in gb]
    expected3 = [
        ((0,), [(1,)]),
        ((1,), [(2,), (3,)]),
        ((2,), [(4,), (5,)]),
        ((3,), [(6,)]),
    ]
    assert result3 == expected3


def test_group_by_iteration_selector() -> None:
    df = pl.DataFrame({"a": ["one", "two", "one", "two"], "b": [1, 2, 3, 4]})
    result = dict(df.group_by(cs.string()))
    result_first = result["one",]
    assert result_first.to_dict(as_series=False) == {"a": ["one", "one"], "b": [1, 3]}


@pytest.mark.parametrize("input", [[pl.col("b").sum()], pl.col("b").sum()])
def test_group_by_agg_input_types(input: Any) -> None:
    df = pl.LazyFrame({"a": [1, 1, 2, 2], "b": [1, 2, 3, 4]})
    result = df.group_by("a", maintain_order=True).agg(input)
    expected = pl.LazyFrame({"a": [1, 2], "b": [3, 7]})
    assert_frame_equal(result, expected)


@pytest.mark.parametrize("input", [str, "b".join])
def test_group_by_agg_bad_input_types(input: Any) -> None:
    df = pl.LazyFrame({"a": [1, 1, 2, 2], "b": [1, 2, 3, 4]})
    with pytest.raises(TypeError):
        df.group_by("a").agg(input)


def test_group_by_sorted_empty_dataframe_3680() -> None:
    df = (
        pl.DataFrame(
            [
                pl.Series("key", [], dtype=pl.Categorical),
                pl.Series("val", [], dtype=pl.Float64),
            ]
        )
        .lazy()
        .sort("key")
        .group_by("key")
        .tail(1)
        .collect(optimizations=pl.QueryOptFlags(check_order_observe=False))
    )
    assert df.rows() == []
    assert df.shape == (0, 2)
    assert df.schema == {"key": pl.Categorical(), "val": pl.Float64}


def test_group_by_custom_agg_empty_list() -> None:
    assert (
        pl.DataFrame(
            [
                pl.Series("key", [], dtype=pl.Categorical),
                pl.Series("val", [], dtype=pl.Float64),
            ]
        )
        .group_by("key")
        .agg(
            [
                pl.col("val").mean().alias("mean"),
                pl.col("val").std().alias("std"),
                pl.col("val").skew().alias("skew"),
                pl.col("val").kurtosis().alias("kurt"),
            ]
        )
    ).dtypes == [pl.Categorical, pl.Float64, pl.Float64, pl.Float64, pl.Float64]


def test_apply_after_take_in_group_by_3869() -> None:
    assert (
        pl.DataFrame(
            {
                "k": list("aaabbb"),
                "t": [1, 2, 3, 4, 5, 6],
                "v": [3, 1, 2, 5, 6, 4],
            }
        )
        .group_by("k", maintain_order=True)
        .agg(
            pl.col("v").get(pl.col("t").arg_max()).sqrt()
        )  # <- fails for sqrt, exp, log, pow, etc.
    ).to_dict(as_series=False) == {"k": ["a", "b"], "v": [1.4142135623730951, 2.0]}


def test_group_by_signed_transmutes() -> None:
    df = pl.DataFrame({"foo": [-1, -2, -3, -4, -5], "bar": [500, 600, 700, 800, 900]})

    for dt in [pl.Int8, pl.Int16, pl.Int32, pl.Int64]:
        df = (
            df.with_columns([pl.col("foo").cast(dt), pl.col("bar")])
            .group_by("foo", maintain_order=True)
            .agg(pl.col("bar").median())
        )

        assert df.to_dict(as_series=False) == {
            "foo": [-1, -2, -3, -4, -5],
            "bar": [500.0, 600.0, 700.0, 800.0, 900.0],
        }


def test_arg_sort_sort_by_groups_update__4360() -> None:
    df = pl.DataFrame(
        {
            "group": ["a"] * 3 + ["b"] * 3 + ["c"] * 3,
            "col1": [1, 2, 3] * 3,
            "col2": [1, 2, 3, 3, 2, 1, 2, 3, 1],
        }
    )

    out = df.with_columns(
        pl.col("col2").arg_sort().over("group").alias("col2_arg_sort")
    ).with_columns(
        pl.col("col1").sort_by(pl.col("col2_arg_sort")).over("group").alias("result_a"),
        pl.col("col1")
        .sort_by(pl.col("col2").arg_sort())
        .over("group")
        .alias("result_b"),
    )

    assert_series_equal(out["result_a"], out["result_b"], check_names=False)
    assert out["result_a"].to_list() == [1, 2, 3, 3, 2, 1, 2, 3, 1]


def test_unique_order() -> None:
    df = pl.DataFrame({"a": [1, 2, 1]}).with_row_index()
    assert df.unique(keep="last", subset="a", maintain_order=True).to_dict(
        as_series=False
    ) == {
        "index": [1, 2],
        "a": [2, 1],
    }
    assert df.unique(keep="first", subset="a", maintain_order=True).to_dict(
        as_series=False
    ) == {
        "index": [0, 1],
        "a": [1, 2],
    }


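# `group_by_dynamic` with an integer index column uses "i" units: every="1i"
# starts a window at each index value and period="2i" makes each window span
# two index values, so windows may overlap.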
def test_group_by_dynamic_flat_agg_4814() -> None:
    df = pl.DataFrame({"a": [1, 2, 2], "b": [1, 8, 12]}).set_sorted("a")

    assert df.group_by_dynamic("a", every="1i", period="2i").agg(
        [
            (pl.col("b").sum() / pl.col("a").sum()).alias("sum_ratio_1"),
            (pl.col("b").last() / pl.col("a").last()).alias("last_ratio_1"),
            (pl.col("b") / pl.col("a")).last().alias("last_ratio_2"),
        ]
    ).to_dict(as_series=False) == {
        "a": [1, 2],
        "sum_ratio_1": [4.2, 5.0],
        "last_ratio_1": [6.0, 6.0],
        "last_ratio_2": [6.0, 6.0],
    }


@pytest.mark.parametrize(
    ("every", "period"),
    [
        ("10s", timedelta(seconds=100)),
        (timedelta(seconds=10), "100s"),
    ],
)
@pytest.mark.parametrize("time_zone", [None, "UTC", "Asia/Kathmandu"])
def test_group_by_dynamic_overlapping_groups_flat_apply_multiple_5038(
    every: str | timedelta, period: str | timedelta, time_zone: str | None
) -> None:
    res = (
        (
            pl.DataFrame(
                {
                    "a": [
                        datetime(2021, 1, 1) + timedelta(seconds=2**i)
                        for i in range(10)
                    ],
                    "b": [float(i) for i in range(10)],
                }
            )
            .with_columns(pl.col("a").dt.replace_time_zone(time_zone))
            .lazy()
            .set_sorted("a")
            .group_by_dynamic("a", every=every, period=period)
            .agg([pl.col("b").var().sqrt().alias("corr")])
        )
        .collect()
        .sum()
        .to_dict(as_series=False)
    )

    assert res["corr"] == pytest.approx([6.988674024215477])
    assert res["a"] == [None]


def test_take_in_group_by() -> None:
    df = pl.DataFrame({"group": [1, 1, 1, 2, 2, 2], "values": [10, 200, 3, 40, 500, 6]})
    assert df.group_by("group").agg(
        pl.col("values").get(1) - pl.col("values").get(2)
    ).sort("group").to_dict(as_series=False) == {"group": [1, 2], "values": [197, 494]}


def test_group_by_wildcard() -> None:
    df = pl.DataFrame(
        {
            "a": [1, 2],
            "b": [1, 2],
        }
    )
    assert df.group_by([pl.col("*")], maintain_order=True).agg(
        [pl.col("a").first().name.suffix("_agg")]
    ).to_dict(as_series=False) == {"a": [1, 2], "b": [1, 2], "a_agg": [1, 2]}


def test_group_by_all_masked_out() -> None:
    df = pl.DataFrame(
        {
            "val": pl.Series(
                [None, None, None, None], dtype=pl.Categorical, nan_to_null=True
            ).set_sorted(),
            "col": [4, 4, 4, 4],
        }
    )
    parts = df.partition_by("val")
    assert len(parts) == 1
    assert_frame_equal(parts[0], df)


def test_group_by_null_propagation_6185() -> None:
    df_1 = pl.DataFrame({"A": [0, 0], "B": [1, 2]})

    expr = pl.col("A").filter(pl.col("A") > 0)

    expected = {"B": [1, 2], "A": [None, None]}
    assert (
        df_1.group_by("B")
        .agg((expr - expr.mean()).mean())
        .sort("B")
        .to_dict(as_series=False)
        == expected
    )


def test_group_by_when_then_with_binary_and_agg_in_pred_6202() -> None:
    df = pl.DataFrame(
        {"code": ["a", "b", "b", "b", "a"], "xx": [1.0, -1.5, -0.2, -3.9, 3.0]}
    )
    assert (
        df.group_by("code", maintain_order=True).agg(
            [pl.when(pl.col("xx") > pl.min("xx")).then(True).otherwise(False)]
        )
    ).to_dict(as_series=False) == {
        "code": ["a", "b"],
        "literal": [[False, True], [True, True, False]],
    }


def test_group_by_binary_agg_with_literal() -> None:
    df = pl.DataFrame({"id": ["a", "a", "b", "b"], "value": [1, 2, 3, 4]})

    out = df.group_by("id", maintain_order=True).agg(
        pl.col("value") + pl.Series([1, 3])
    )
    assert out.to_dict(as_series=False) == {"id": ["a", "b"], "value": [[2, 5], [4, 7]]}

    out = df.group_by("id", maintain_order=True).agg(pl.col("value") + pl.lit(1))
    assert out.to_dict(as_series=False) == {"id": ["a", "b"], "value": [[2, 3], [4, 5]]}

    out = df.group_by("id", maintain_order=True).agg(pl.lit(1) + pl.lit(2))
    assert out.to_dict(as_series=False) == {"id": ["a", "b"], "literal": [3, 3]}

    out = df.group_by("id", maintain_order=True).agg(pl.lit(1) + pl.Series([2, 3]))
    assert out.to_dict(as_series=False) == {
        "id": ["a", "b"],
        "literal": [[3, 4], [3, 4]],
    }

    out = df.group_by("id", maintain_order=True).agg(
        value=pl.lit(pl.Series([1, 2])) + pl.lit(pl.Series([3, 4]))
    )
    assert out.to_dict(as_series=False) == {"id": ["a", "b"], "value": [[4, 6], [4, 6]]}


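# Regression test for issue 5194: a partitioned (multi-threaded) group by mean
# over many Int32/UInt32 rows must not overflow its intermediate accumulator.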
@pytest.mark.slow
@pytest.mark.parametrize("dtype", [pl.Int32, pl.UInt32])
def test_overflow_mean_partitioned_group_by_5194(dtype: PolarsDataType) -> None:
    df = pl.DataFrame(
        [
            pl.Series("data", [10_00_00_00] * 100_000, dtype=dtype),
            pl.Series("group", [1, 2] * 50_000, dtype=dtype),
        ]
    )
    result = df.group_by("group").agg(pl.col("data").mean()).sort(by="group")
    expected = {"group": [1, 2], "data": [10000000.0, 10000000.0]}
    assert result.to_dict(as_series=False) == expected


# https://github.com/pola-rs/polars/issues/7181
def test_group_by_multiple_column_reference() -> None:
    df = pl.DataFrame(
        {
            "gr": ["a", "b", "a", "b", "a", "b"],
            "val": [1, 20, 100, 2000, 10000, 200000],
        }
    )
    result = df.group_by("gr").agg(
        pl.col("val") + pl.col("val").shift().fill_null(0),
    )

    assert result.sort("gr").to_dict(as_series=False) == {
        "gr": ["a", "b"],
        "val": [[1, 101, 10100], [20, 2020, 202000]],
    }


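# When a group's values are all filtered out, most aggregations return null;
# `n_unique` returns 0 (with the index dtype), as the table below documents.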
@pytest.mark.parametrize(
    ("aggregation", "args", "expected_values", "expected_dtype"),
    [
        ("first", [], [1, None], pl.Int64),
        ("last", [], [1, None], pl.Int64),
        ("max", [], [1, None], pl.Int64),
        ("mean", [], [1.0, None], pl.Float64),
        ("median", [], [1.0, None], pl.Float64),
        ("min", [], [1, None], pl.Int64),
        ("n_unique", [], [1, 0], pl.get_index_type()),
        ("quantile", [0.5], [1.0, None], pl.Float64),
    ],
)
def test_group_by_empty_groups(
    aggregation: str,
    args: list[object],
    expected_values: list[object],
    expected_dtype: pl.DataType,
) -> None:
    df = pl.DataFrame({"a": [1, 2], "b": [1, 2]})
    result = df.group_by("b", maintain_order=True).agg(
        getattr(pl.col("a").filter(pl.col("b") != 2), aggregation)(*args)
    )
    expected = pl.DataFrame({"b": [1, 2], "a": expected_values}).with_columns(
        pl.col("a").cast(expected_dtype)
    )
    assert_frame_equal(result, expected)


# https://github.com/pola-rs/polars/issues/8663
def test_perfect_hash_table_null_values() -> None:
    # fmt: off
    values = ["3", "41", "17", "5", "26", "27", "43", "45", "41", "13", "45", "48", "17", "22", "31", "25", "28", "13", "7", "26", "17", "4", "43", "47", "30", "28", "8", "27", "6", "7", "26", "11", "37", "29", "49", "20", "29", "28", "23", "9", None, "38", "19", "7", "38", "3", "30", "37", "41", "5", "16", "26", "31", "6", "25", "11", "17", "31", "31", "20", "26", None, "39", "10", "38", "4", "39", "15", "13", "35", "38", "11", "39", "11", "48", "36", "18", "11", "34", "16", "28", "9", "37", "8", "17", "48", "44", "28", "25", "30", "37", "30", "18", "12", None, "27", "10", "3", "16", "27", "6"]
    groups = ["3", "41", "17", "5", "26", "27", "43", "45", "13", "48", "22", "31", "25", "28", "7", "4", "47", "30", "8", "6", "11", "37", "29", "49", "20", "23", "9", None, "38", "19", "16", "39", "10", "15", "35", "36", "18", "34", "44", "12"]
    # fmt: on

    s = pl.Series("a", values, dtype=pl.Categorical)

    result = (
        s.to_frame("a").group_by("a", maintain_order=True).agg(pl.col("a").alias("agg"))
    )

    agg_values = [
        ["3", "3", "3"],
        ["41", "41", "41"],
        ["17", "17", "17", "17", "17"],
        ["5", "5"],
        ["26", "26", "26", "26", "26"],
        ["27", "27", "27", "27"],
        ["43", "43"],
        ["45", "45"],
        ["13", "13", "13"],
        ["48", "48", "48"],
        ["22"],
        ["31", "31", "31", "31"],
        ["25", "25", "25"],
        ["28", "28", "28", "28", "28"],
        ["7", "7", "7"],
        ["4", "4"],
        ["47"],
        ["30", "30", "30", "30"],
        ["8", "8"],
        ["6", "6", "6"],
        ["11", "11", "11", "11", "11"],
        ["37", "37", "37", "37"],
        ["29", "29"],
        ["49"],
        ["20", "20"],
        ["23"],
        ["9", "9"],
        [None, None, None],
        ["38", "38", "38", "38"],
        ["19"],
        ["16", "16", "16"],
        ["39", "39", "39"],
        ["10", "10"],
        ["15"],
        ["35"],
        ["36"],
        ["18", "18"],
        ["34"],
        ["44"],
        ["12"],
    ]
    expected = pl.DataFrame(
        {
            "a": groups,
            "agg": agg_values,
        },
        schema={"a": pl.Categorical, "agg": pl.List(pl.Categorical)},
    )
    assert_frame_equal(result, expected)


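# The POLARS_FORCE_PARTITION env var (set via the plmonkeypatch fixture below)
# appears to force the partitioned, multi-threaded group by strategy even for
# small inputs, so these tests exercise that code path deterministically.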
def test_group_by_partitioned_ending_cast(plmonkeypatch: PlMonkeyPatch) -> None:
    plmonkeypatch.setenv("POLARS_FORCE_PARTITION", "1")
    df = pl.DataFrame({"a": [1] * 5, "b": [1] * 5})
    out = df.group_by(["a", "b"]).agg(pl.len().cast(pl.Int64).alias("num"))
    expected = pl.DataFrame({"a": [1], "b": [1], "num": [5]})
    assert_frame_equal(out, expected)


def test_group_by_series_partitioned(partition_limit: int) -> None:
    # test 15354
    df = pl.DataFrame([0, 0] * partition_limit)
    groups = pl.Series([0, 1] * partition_limit)
    df.group_by(groups).agg(pl.all().is_not_null().sum())


def test_group_by_list_scalar_11749() -> None:
    df = pl.DataFrame(
        {
            "group_name": ["a;b", "a;b", "c;d", "c;d", "a;b", "a;b"],
            "parent_name": ["a", "b", "c", "d", "a", "b"],
            "measurement": [
                ["x1", "x2"],
                ["x1", "x2"],
                ["y1", "y2"],
                ["z1", "z2"],
                ["x1", "x2"],
                ["x1", "x2"],
            ],
        }
    )
    assert (
        df.group_by("group_name").agg(
            (pl.col("measurement").first() == pl.col("measurement")).alias("eq"),
        )
    ).sort("group_name").to_dict(as_series=False) == {
        "group_name": ["a;b", "c;d"],
        "eq": [[True, True, True, True], [True, False]],
    }


def test_group_by_with_expr_as_key() -> None:
    gb = pl.select(x=1).group_by(pl.col("x").alias("key"))
    result = gb.agg(pl.all().first())
    expected = gb.agg(pl.first("x"))
    assert_frame_equal(result, expected)

    # tests: 11766
    result = gb.head(0)
    expected = gb.agg(pl.col("x").head(0)).explode("x")
    assert_frame_equal(result, expected)

    result = gb.tail(0)
    expected = gb.agg(pl.col("x").tail(0)).explode("x")
    assert_frame_equal(result, expected)


def test_lazy_group_by_reuse_11767() -> None:
    lgb = pl.select(x=1).lazy().group_by("x")
    a = lgb.len()
    b = lgb.len()
    assert_frame_equal(a, b)


def test_group_by_double_on_empty_12194() -> None:
    df = pl.DataFrame({"group": [1], "x": [1]}).clear()
    squared_deviation_sum = ((pl.col("x") - pl.col("x").mean()) ** 2).sum()
    assert df.group_by("group").agg(squared_deviation_sum).schema == OrderedDict(
        [("group", pl.Int64), ("x", pl.Float64)]
    )


def test_group_by_when_then_no_aggregation_predicate() -> None:
    df = pl.DataFrame(
        {
            "key": ["aa", "aa", "bb", "bb", "aa", "aa"],
            "val": [-3, -2, 1, 4, -3, 5],
        }
    )
    assert df.group_by("key").agg(
        pos=pl.when(pl.col("val") >= 0).then(pl.col("val")).sum(),
        neg=pl.when(pl.col("val") < 0).then(pl.col("val")).sum(),
    ).sort("key").to_dict(as_series=False) == {
        "key": ["aa", "bb"],
        "pos": [5, 5],
        "neg": [-8, 0],
    }


def test_group_by_apply_first_input_is_literal() -> None:
    df = pl.DataFrame({"x": [1, 2, 3, 4, 5], "g": [1, 1, 2, 2, 2]})
    pow = df.group_by("g").agg(2 ** pl.col("x"))
    assert pow.sort("g").to_dict(as_series=False) == {
        "g": [1, 2],
        "literal": [[2.0, 4.0], [8.0, 16.0, 32.0]],
    }


def test_group_by_all_12869() -> None:
    df = pl.DataFrame({"a": [1]})
    result = next(iter(df.group_by(pl.all())))[1]
    assert_frame_equal(df, result)


def test_group_by_named() -> None:
    df = pl.DataFrame({"a": [1, 1, 2, 2, 3, 3], "b": range(6)})
    result = df.group_by(z=pl.col("a") * 2, maintain_order=True).agg(pl.col("b").min())
    expected = df.group_by((pl.col("a") * 2).alias("z"), maintain_order=True).agg(
        pl.col("b").min()
    )
    assert_frame_equal(result, expected)


def test_group_by_with_null() -> None:
    df = pl.DataFrame(
        {"a": [None, None, None, None], "b": [1, 1, 2, 2], "c": ["x", "y", "z", "u"]}
    )
    expected = pl.DataFrame(
        {"a": [None, None], "b": [1, 2], "c": [["x", "y"], ["z", "u"]]}
    )
    output = df.group_by(["a", "b"], maintain_order=True).agg(pl.col("c"))
    assert_frame_equal(expected, output)


def test_partitioned_group_by_14954(plmonkeypatch: PlMonkeyPatch) -> None:
    plmonkeypatch.setenv("POLARS_FORCE_PARTITION", "1")
    assert (
        pl.DataFrame({"a": range(20)})
        .select(pl.col("a") % 2)
        .group_by("a")
        .agg(
            (pl.col("a") > 1000).alias("a > 1000"),
        )
    ).sort("a").to_dict(as_series=False) == {
        "a": [0, 1],
        "a > 1000": [
            [False, False, False, False, False, False, False, False, False, False],
            [False, False, False, False, False, False, False, False, False, False],
        ],
    }


def test_partitioned_group_by_nulls_mean_21838() -> None:
    size = 10
    a = [1 for i in range(size)] + [2 for i in range(size)] + [3 for i in range(size)]
    b = [1 for i in range(size)] + [None for i in range(size * 2)]
    df = pl.DataFrame({"a": a, "b": b})
    assert df.group_by("a").mean().sort("a").to_dict(as_series=False) == {
        "a": [1, 2, 3],
        "b": [1.0, None, None],
    }


def test_aggregated_scalar_elementwise_15602() -> None:
    df = pl.DataFrame({"group": [1, 2, 1]})

    out = df.group_by("group", maintain_order=True).agg(
        foo=pl.col("group").is_between(1, pl.max("group"))
    )
    expected = pl.DataFrame({"group": [1, 2], "foo": [[True, True], [True]]})
    assert_frame_equal(out, expected)


def test_group_by_multiple_null_cols_15623() -> None:
    df = pl.DataFrame(schema={"a": pl.Null, "b": pl.Null}).group_by(pl.all()).len()
    assert df.is_empty()


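# The release-marked tests below compare the perfect-hash group by path used
# for Categorical keys against the generic string path on random data.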
@pytest.mark.release
def test_categorical_vs_str_group_by() -> None:
    # this triggers the perfect hash table
    s = pl.Series("a", np.random.randint(0, 50, 100))
    s_with_nulls = pl.select(
        pl.when(s < 3).then(None).otherwise(s).alias("a")
    ).to_series()

    for s_ in [s, s_with_nulls]:
        s_ = s_.cast(str)
        cat_out = (
            s_.cast(pl.Categorical)
            .to_frame("a")
            .group_by("a")
            .agg(pl.first().alias("first"))
        )

        str_out = s_.to_frame("a").group_by("a").agg(pl.first().alias("first"))
        cat_out.with_columns(pl.col("a").cast(str))
        assert_frame_equal(
            cat_out.with_columns(
                pl.col("a").cast(str), pl.col("first").cast(pl.List(str))
            ).sort("a"),
            str_out.sort("a"),
        )


@pytest.mark.release
def test_boolean_min_max_agg() -> None:
    np.random.seed(0)
    idx = np.random.randint(0, 500, 1000)
    c = np.random.randint(0, 500, 1000) > 250

    df = pl.DataFrame({"idx": idx, "c": c})
    aggs = [pl.col("c").min().alias("c_min"), pl.col("c").max().alias("c_max")]

    result = df.group_by("idx").agg(aggs).sum()

    schema = {"idx": pl.Int64, "c_min": pl.UInt32, "c_max": pl.UInt32}
    expected = pl.DataFrame(
        {
            "idx": [107583],
            "c_min": [120],
            "c_max": [321],
        },
        schema=schema,
    )
    assert_frame_equal(result, expected)

    nulls = np.random.randint(0, 500, 1000) < 100

    result = (
        df.with_columns(c=pl.when(pl.lit(nulls)).then(None).otherwise(pl.col("c")))
        .group_by("idx")
        .agg(aggs)
        .sum()
    )

    expected = pl.DataFrame(
        {
            "idx": [107583],
            "c_min": [133],
            "c_max": [276],
        },
        schema=schema,
    )
    assert_frame_equal(result, expected)


def test_partitioned_group_by_chunked(partition_limit: int) -> None:
    n = partition_limit
    df1 = pl.DataFrame(np.random.randn(n, 2))
    df2 = pl.DataFrame(np.random.randn(n, 2))
    gps = pl.Series(name="oo", values=[0] * n + [1] * n)
    df = pl.concat([df1, df2], rechunk=False)
    assert_frame_equal(
        df.group_by(gps).sum().sort("oo"),
        df.rechunk().group_by(gps, maintain_order=True).sum(),
    )


def test_schema_on_agg() -> None:
    lf = pl.LazyFrame({"a": ["x", "x", "y", "n"], "b": [1, 2, 3, 4]})

    result = lf.group_by("a").agg(
        pl.col("b").min().alias("min"),
        pl.col("b").max().alias("max"),
        pl.col("b").sum().alias("sum"),
        pl.col("b").first().alias("first"),
        pl.col("b").last().alias("last"),
        pl.col("b").item().alias("item"),
    )
    expected_schema = {
        "a": pl.String,
        "min": pl.Int64,
        "max": pl.Int64,
        "sum": pl.Int64,
        "first": pl.Int64,
        "last": pl.Int64,
        "item": pl.Int64,
    }
    assert result.collect_schema() == expected_schema


def test_group_by_schema_err() -> None:
    lf = pl.LazyFrame({"foo": [None, 1, 2], "bar": [1, 2, 3]})
    with pytest.raises(ColumnNotFoundError):
        lf.group_by("not-existent").agg(
            pl.col("bar").max().alias("max_bar")
        ).collect_schema()


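# `select` and `group_by(...).agg(...)` can disagree on output dtype: inside
# `agg`, length-preserving expressions are wrapped in a List per group, while
# true aggregations keep their scalar dtype. The cases below pin this down.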
@pytest.mark.parametrize(
    ("data", "expr", "expected_select", "expected_gb"),
    [
        (
            {"x": ["x"], "y": ["y"]},
            pl.coalesce(pl.col("x"), pl.col("y")),
            {"x": pl.String},
            {"x": pl.List(pl.String)},
        ),
        (
            {"x": [True]},
            pl.col("x").sum(),
            {"x": pl.get_index_type()},
            {"x": pl.get_index_type()},
        ),
        (
            {"a": [[1, 2]]},
            pl.col("a").list.sum(),
            {"a": pl.Int64},
            {"a": pl.List(pl.Int64)},
        ),
    ],
)
def test_schemas(
    data: dict[str, list[Any]],
    expr: pl.Expr,
    expected_select: dict[str, PolarsDataType],
    expected_gb: dict[str, PolarsDataType],
) -> None:
    df = pl.DataFrame(data)

    # test selection schema
    schema = df.select(expr).schema
    for key, dtype in expected_select.items():
        assert schema[key] == dtype

    # test group_by schema
    schema = df.group_by(pl.lit(1)).agg(expr).schema
    for key, dtype in expected_gb.items():
        assert schema[key] == dtype


def test_lit_iter_schema() -> None:
    df = pl.DataFrame(
        {
            "key": ["A", "A", "A", "A"],
            "dates": [
                date(1970, 1, 1),
                date(1970, 1, 1),
                date(1970, 1, 2),
                date(1970, 1, 3),
            ],
        }
    )

    result = df.group_by("key").agg(pl.col("dates").unique() + timedelta(days=1))
    expected = {
        "key": ["A"],
        "dates": [[date(1970, 1, 2), date(1970, 1, 3), date(1970, 1, 4)]],
    }
    assert result.to_dict(as_series=False) == expected


def test_absence_off_null_prop_8224() -> None:
    # A reminder to self not to do null propagation: it is inconsistent and
    # makes the output dtype dependent on the data. Big no!

    def sub_col_min(column: str, min_column: str) -> pl.Expr:
        return pl.col(column).sub(pl.col(min_column).min())

    df = pl.DataFrame(
        {
            "group": [1, 1, 2, 2],
            "vals_num": [10.0, 11.0, 12.0, 13.0],
            "vals_partial": [None, None, 12.0, 13.0],
            "vals_null": [None, None, None, None],
        }
    )

    q = (
        df.lazy()
        .group_by("group")
        .agg(
            sub_col_min("vals_num", "vals_num").alias("sub_num"),
            sub_col_min("vals_num", "vals_partial").alias("sub_partial"),
            sub_col_min("vals_num", "vals_null").alias("sub_null"),
        )
    )

    assert q.collect().dtypes == [
        pl.Int64,
        pl.List(pl.Float64),
        pl.List(pl.Float64),
        pl.List(pl.Float64),
    ]


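# Literals used inside a grouped aggregation are broadcast per group; slicing a
# literal therefore operates on the literal itself, as the next test checks.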
@pytest.mark.parametrize("maintain_order", [False, True])
def test_grouped_slice_literals(maintain_order: bool) -> None:
    df = pl.DataFrame({"idx": [1, 2, 3]})
    q = (
        df.lazy()
        .group_by(True, maintain_order=maintain_order)
        .agg(
            x=pl.lit([1, 2]).slice(
                -1, 1
            ),  # slicing a one-element list literal leaves it unchanged
            x2=pl.lit(pl.Series([1, 2])).slice(-1, 1),
            x3=pl.lit(pl.Series([[1, 2]])).slice(-1, 1),
        )
    )
    out = q.collect()
    expected = pl.DataFrame(
        {"literal": [True], "x": [[[1, 2]]], "x2": [[2]], "x3": [[[1, 2]]]}
    )
    assert_frame_equal(
        out,
        expected,
        check_row_order=maintain_order,
    )
    assert q.collect_schema() == q.collect().schema


def test_positional_by_with_list_or_tuple_17540() -> None:
    with pytest.raises(TypeError, match="Hint: if you"):
        pl.DataFrame({"a": [1, 2, 3]}).group_by(by=["a"])
    with pytest.raises(TypeError, match="Hint: if you"):
        pl.LazyFrame({"a": [1, 2, 3]}).group_by(by=["a"])


def test_group_by_agg_19173() -> None:
    df = pl.DataFrame({"x": [1.0], "g": [0]})
    out = df.head(0).group_by("g").agg((pl.col.x - pl.col.x.sum() * pl.col.x) ** 2)
    assert out.to_dict(as_series=False) == {"g": [], "x": []}
    assert out.schema == pl.Schema([("g", pl.Int64), ("x", pl.List(pl.Float64))])


def test_group_by_map_groups_slice_pushdown_20002() -> None:
    schema = {
        "a": pl.Int8,
        "b": pl.UInt8,
    }

    df = (
        pl.LazyFrame(
            data={"a": [1, 2, 3, 4, 5], "b": [90, 80, 70, 60, 50]},
            schema=schema,
        )
        .group_by("a", maintain_order=True)
        .map_groups(lambda df: df * 2.0, schema=schema)
        .head(3)
        .collect()
    )

    assert_frame_equal(
        df,
        pl.DataFrame(
            {
                "a": [2.0, 4.0, 6.0],
                "b": [180.0, 160.0, 140.0],
            }
        ),
    )


@typing.no_type_check
def test_group_by_lit_series(capfd: Any, plmonkeypatch: PlMonkeyPatch) -> None:
    plmonkeypatch.setenv("POLARS_VERBOSE", "1")
    n = 10
    df = pl.DataFrame({"x": np.ones(2 * n), "y": n * list(range(2))})
    a = np.ones(n, dtype=float)
    df.lazy().group_by("y").agg(pl.col("x").dot(a)).collect()
    captured = capfd.readouterr().err
    assert "are not partitionable" in captured


def test_group_by_list_column() -> None:
    df = pl.DataFrame({"a": [1, 2, 3], "b": [[1, 2], [3], [1, 2]]})
    result = df.group_by("b").agg(pl.sum("a")).sort("b")
    expected = pl.DataFrame({"b": [[1, 2], [3]], "a": [4, 2]})
    assert_frame_equal(result, expected)


def test_enum_perfect_group_by_21360() -> None:
    dtype = pl.Enum(categories=["a", "b"])

    assert_frame_equal(
        pl.from_dicts([{"col": "a"}], schema={"col": dtype})
        .group_by("col")
        .agg(pl.len()),
        pl.DataFrame(
            [
                pl.Series("col", ["a"], dtype),
                pl.Series("len", [1], get_index_type()),
            ]
        ),
    )


def test_partitioned_group_by_21634(partition_limit: int) -> None:
    n = partition_limit
    df = pl.DataFrame({"grp": [1] * n, "x": [1] * n})
    assert df.group_by("grp", True).agg().to_dict(as_series=False) == {
        "grp": [1],
        "literal": [True],
    }


def test_group_by_cse_dup_key_alias_22238() -> None:
    df = pl.LazyFrame({"a": [1, 1, 2, 2, -1], "x": [0, 1, 2, 3, 10]})
    result = df.group_by(
        pl.col("a").abs(),
        pl.col("a").abs().alias("a_with_alias"),
    ).agg(pl.col("x").sum())
    assert_frame_equal(
        result.collect(),
        pl.DataFrame({"a": [1, 2], "a_with_alias": [1, 2], "x": [11, 5]}),
        check_row_order=False,
    )


def test_group_by_22328() -> None:
    N = 20

    df1 = pl.select(
        x=pl.repeat(1, N // 2).append(pl.repeat(2, N // 2)).shuffle(),
        y=pl.lit(3.0, pl.Float32),
    ).lazy()

    df2 = pl.select(x=pl.repeat(4, N)).lazy()

    assert (
        df2.join(df1.group_by("x").mean().with_columns(z="y"), how="left", on="x")
        .with_columns(pl.col("z").fill_null(0))
        .collect()
    ).shape == (20, 3)


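# Group keys of Array dtype (including nested arrays) must hash and compare
# correctly; issue 22574 covered this.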
@pytest.mark.parametrize("maintain_order", [False, True])
def test_group_by_arrays_22574(maintain_order: bool) -> None:
    assert_frame_equal(
        pl.Series("a", [[1], [2], [2]], pl.Array(pl.Int64, 1))
        .to_frame()
        .group_by("a", maintain_order=maintain_order)
        .agg(pl.len()),
        pl.DataFrame(
            [
                pl.Series("a", [[1], [2]], pl.Array(pl.Int64, 1)),
                pl.Series("len", [1, 2], pl.get_index_type()),
            ]
        ),
        check_row_order=maintain_order,
    )

    assert_frame_equal(
        pl.Series(
            "a", [[[1, 2]], [[2, 3]], [[2, 3]]], pl.Array(pl.Array(pl.Int64, 2), 1)
        )
        .to_frame()
        .group_by("a", maintain_order=maintain_order)
        .agg(pl.len()),
        pl.DataFrame(
            [
                pl.Series(
                    "a", [[[1, 2]], [[2, 3]]], pl.Array(pl.Array(pl.Int64, 2), 1)
                ),
                pl.Series("len", [1, 2], pl.get_index_type()),
            ]
        ),
        check_row_order=maintain_order,
    )


def test_group_by_empty_rows_with_literal_21959() -> None:
    out = (
        pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [1, 1, 3]})
        .filter(pl.col("c") == 99)
        .group_by(pl.lit(1).alias("d"), pl.col("a"), pl.col("b"))
        .agg()
        .collect()
    )
    expected = pl.DataFrame(
        {"d": [], "a": [], "b": []},
        schema={"d": pl.Int32, "a": pl.Int64, "b": pl.Int64},
    )
    assert_frame_equal(out, expected)


def test_group_by_empty_dtype_22716() -> None:
    df = pl.DataFrame(schema={"a": pl.String, "b": pl.Int64})
    out = df.group_by("a").agg(x=(pl.col("b") == pl.int_range(pl.len())).all())
    assert_frame_equal(out, pl.DataFrame(schema={"a": pl.String, "x": pl.Boolean}))


def test_group_by_implode_22870() -> None:
    out = (
        pl.DataFrame({"x": ["a", "b"]})
        .group_by(pl.col.x)
        .agg(
            y=pl.col.x.replace_strict(
                pl.lit(pl.Series(["a", "b"])).implode(),
                pl.lit(pl.Series([1, 2])).implode(),
                default=-1,
            )
        )
    )
    assert_frame_equal(
        out,
        pl.DataFrame({"x": ["a", "b"], "y": [[1], [2]]}),
        check_row_order=False,
    )


# Note: the underlying bug is not guaranteed to manifest itself as it depends
# on the internal group order, i.e., for the bug to materialize, there must be
# empty groups before the non-empty group
def test_group_by_empty_groups_23338() -> None:
    # We need one non-empty group and many groups
    df = pl.DataFrame(
        {
            "k": [10, 10, 20, 30, 40, 50, 60, 70, 80, 90],
            "a": [1, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        }
    )
    out = df.group_by("k").agg(
        pl.col("a").filter(pl.col("a") == 1).fill_nan(None).sum()
    )
    expected = df.group_by("k").agg(pl.col("a").filter(pl.col("a") == 1).sum())
    assert_frame_equal(out.sort("k"), expected.sort("k"))


def test_group_by_filter_all_22955() -> None:
    df = pl.DataFrame(
        {
            "grp": [1, 2, 3, 4, 5],
            "value": [10, 20, 30, 40, 50],
        }
    )

    assert_frame_equal(
        df.group_by("grp").agg(
            pl.all().filter(pl.col("value") > 20),
        ),
        pl.DataFrame(
            {
                "grp": [1, 2, 3, 4, 5],
                "value": [[], [], [30], [40], [50]],
            }
        ),
        check_row_order=False,
    )


@pytest.mark.parametrize("maintain_order", [False, True])
def test_group_by_series_lit_22103(maintain_order: bool) -> None:
    df = pl.DataFrame(
        {
            "g": [0, 1],
        }
    )
    assert_frame_equal(
        df.group_by("g", maintain_order=maintain_order).agg(
            foo=pl.lit(pl.Series([42, 2, 3]))
        ),
        pl.DataFrame(
            {
                "g": [0, 1],
                "foo": [[42, 2, 3], [42, 2, 3]],
            }
        ),
        check_row_order=maintain_order,
    )


@pytest.mark.parametrize("maintain_order", [False, True])
def test_group_by_filter_sum_23897(maintain_order: bool) -> None:
    testdf = pl.DataFrame(
        {
            "id": [8113, 9110, 9110],
            "value": [None, None, 1.0],
            "weight": [1.0, 1.0, 1.0],
        }
    )

    w = pl.col("weight").filter(pl.col("value").is_finite())

    w = w / w.sum()

    result = w.sum()

    assert_frame_equal(
        testdf.group_by("id", maintain_order=maintain_order).agg(result),
        pl.DataFrame({"id": [8113, 9110], "weight": [0.0, 1.0]}),
        check_row_order=maintain_order,
    )


@pytest.mark.parametrize("maintain_order", [False, True])
def test_group_by_shift_filter_23910(maintain_order: bool) -> None:
    df = pl.DataFrame({"a": [3, 7, 5, 9, 2, 1], "b": [2, 2, 2, 3, 3, 1]})

    out = df.group_by("b", maintain_order=maintain_order).agg(
        pl.col("a").filter(pl.col("a") > pl.col("a").shift(1)).sum().alias("tt")
    )

    assert_frame_equal(
        out,
        pl.DataFrame(
            {
                "b": [2, 3, 1],
                "tt": [7, 0, 0],
            }
        ),
        check_row_order=maintain_order,
    )


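# `having` keeps only the groups for which the aggregate predicate holds,
# analogous to SQL's HAVING clause.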
@pytest.mark.parametrize("maintain_order", [False, True])
def test_group_by_having(maintain_order: bool) -> None:
    df = pl.DataFrame(
        {
            "grp": ["A", "A", "B", "B", "C", "C"],
            "value": [10, 15, 5, 15, 5, 10],
        }
    )

    result = (
        df.group_by("grp", maintain_order=maintain_order)
        .having(pl.col("value").mean() >= 10)
        .agg()
    )
    expected = pl.DataFrame({"grp": ["A", "B"]})
    assert_frame_equal(result, expected, check_row_order=maintain_order)


def test_group_by_tuple_typing_24112() -> None:
    df = pl.DataFrame({"id": ["a", "b", "a"], "val": [1, 2, 3]})
    for (id_,), _ in df.group_by("id"):
        _should_work: str = id_


def test_group_by_input_independent_with_len_23868() -> None:
    out = pl.DataFrame({"a": ["A", "B", "C"]}).group_by(pl.lit("G")).agg(pl.len())
    assert_frame_equal(
        out,
        pl.DataFrame(
            {"literal": "G", "len": 3},
            schema={"literal": pl.String, "len": pl.get_index_type()},
        ),
    )


@pytest.mark.parametrize("maintain_order", [False, True])
def test_group_by_head_tail_24215(maintain_order: bool) -> None:
    df = pl.DataFrame(
        {
            "station": ["A", "A", "B"],
            "num_rides": [1, 2, 3],
        }
    )
    expected = pl.DataFrame(
        {"station": ["A", "B"], "num_rides": [1.5, 3], "rides_per_day": [[1, 2], [3]]}
    )

    result = (
        df.group_by("station", maintain_order=maintain_order)
        .agg(
            cs.numeric().mean(),
            pl.col("num_rides").alias("rides_per_day"),
        )
        .group_by("station", maintain_order=maintain_order)
        .head(1)
    )
    assert_frame_equal(result, expected, check_row_order=maintain_order)

    result = (
        df.group_by("station", maintain_order=maintain_order)
        .agg(
            cs.numeric().mean(),
            pl.col("num_rides").alias("rides_per_day"),
        )
        .group_by("station", maintain_order=maintain_order)
        .tail(1)
    )
    assert_frame_equal(result, expected, check_row_order=maintain_order)


def test_slice_group_by_offset_24259() -> None:
    df = pl.DataFrame(
        {
            "letters": ["c", "c", "a", "c", "a", "b", "d"],
            "nrs": [1, 2, 3, 4, 5, 6, None],
        }
    )
    assert df.group_by("letters").agg(
        x=pl.col("nrs").drop_nulls(),
        tail=pl.col("nrs").drop_nulls().tail(1),
    ).sort("letters").to_dict(as_series=False) == {
        "letters": ["a", "b", "c", "d"],
        "x": [[3, 5], [6], [1, 2, 4], []],
        "tail": [[5], [6], [4], []],
    }


def test_group_by_first_nondet_24278() -> None:
    values = [
        96, 86, 0, 86, 43, 50, 9, 14, 98, 39, 93, 7, 71, 1, 93, 41, 56,
        56, 93, 41, 58, 91, 81, 29, 81, 68, 5, 9, 32, 93, 78, 34, 17, 40,
        14, 2, 52, 77, 81, 4, 56, 42, 64, 12, 29, 58, 71, 98, 32, 49, 34,
        86, 29, 94, 37, 21, 41, 36, 9, 72, 23, 28, 71, 9, 66, 72, 84, 81,
        23, 12, 64, 57, 99, 15, 77, 38, 95, 64, 13, 91, 43, 61, 70, 47,
        39, 75, 47, 93, 45, 1, 95, 55, 29, 5, 83, 8, 3, 6, 45, 84,
    ]  # fmt: skip
    q = (
        pl.LazyFrame({"a": values, "idx": range(100)})
        .group_by("a")
        .agg(pl.col.idx.first())
        .select(a=pl.col.idx)
    )

    fst_value = q.collect().to_series().sum()
    for _ in range(10):
        assert q.collect().to_series().sum() == fst_value


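# Aggregating a scalar literal inside `agg` should give the same result as the
# equivalent `select`, broadcast once per group; the test below sweeps a wide
# range of aggregations to check this.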
@pytest.mark.parametrize("maintain_order", [False, True])
def test_group_by_agg_on_lit(maintain_order: bool) -> None:
    fs: list[Callable[[Expr], Expr]] = [
        Expr.min,
        Expr.max,
        Expr.mean,
        Expr.sum,
        Expr.len,
        Expr.count,
        Expr.first,
        Expr.last,
        Expr.n_unique,
        Expr.implode,
        Expr.std,
        Expr.var,
        lambda e: e.quantile(0.5),
        Expr.nan_min,
        Expr.nan_max,
        Expr.skew,
        Expr.null_count,
        Expr.product,
        lambda e: pl.corr(e, e),
    ]

    df = pl.DataFrame({"a": [1, 2], "b": [1, 1]})

    assert_frame_equal(
        df.group_by("a", maintain_order=maintain_order).agg(
            f(pl.lit(1)).alias(f"c{i}") for i, f in enumerate(fs)
        ),
        pl.select(
            [pl.lit(pl.Series("a", [1, 2]))]
            + [f(pl.lit(1)).alias(f"c{i}") for i, f in enumerate(fs)]
        ),
        check_row_order=maintain_order,
    )

    df = pl.DataFrame({"a": [1, 2], "b": [None, 1]})

    assert_frame_equal(
        df.group_by("a", maintain_order=maintain_order).agg(
            f(pl.lit(1)).alias(f"c{i}") for i, f in enumerate(fs)
        ),
        pl.select(
            [pl.lit(pl.Series("a", [1, 2]))]
            + [f(pl.lit(1)).alias(f"c{i}") for i, f in enumerate(fs)]
        ),
        check_row_order=maintain_order,
    )


def test_group_by_cum_sum_key_24489() -> None:
    df = pl.LazyFrame({"x": [1, 2]})
    out = df.group_by((pl.col.x > 1).cum_sum()).agg().collect()
    expected = pl.DataFrame({"x": [0, 1]}, schema={"x": pl.UInt32})
    assert_frame_equal(out, expected, check_row_order=False)


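# In the test below each group holds exactly one row, so applying an
# aggregation to the column directly and to its `first()` scalar must agree.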
@pytest.mark.parametrize("maintain_order", [False, True])
1833
def test_double_aggregations(maintain_order: bool) -> None:
1834
fs: list[Callable[[pl.Expr], pl.Expr]] = [
1835
Expr.min,
1836
Expr.max,
1837
Expr.mean,
1838
Expr.sum,
1839
Expr.len,
1840
Expr.count,
1841
Expr.first,
1842
Expr.last,
1843
Expr.n_unique,
1844
Expr.implode,
1845
Expr.std,
1846
Expr.var,
1847
lambda e: e.quantile(0.5),
1848
Expr.nan_min,
1849
Expr.nan_max,
1850
Expr.skew,
1851
Expr.null_count,
1852
Expr.product,
1853
lambda e: pl.corr(e, e),
1854
]
1855
1856
df = pl.DataFrame({"a": [1, 2], "b": [1, 1]})
1857
1858
assert_frame_equal(
1859
df.group_by("a", maintain_order=maintain_order).agg(
1860
f(pl.col.b).alias(f"c{i}") for i, f in enumerate(fs)
1861
),
1862
df.group_by("a", maintain_order=maintain_order).agg(
1863
f(pl.col.b.first()).alias(f"c{i}") for i, f in enumerate(fs)
1864
),
1865
check_row_order=maintain_order,
1866
)
1867
1868
df = pl.DataFrame({"a": [1, 2], "b": [None, 1]})
1869
1870
assert_frame_equal(
1871
df.group_by("a", maintain_order=maintain_order).agg(
1872
f(pl.col.b).alias(f"c{i}") for i, f in enumerate(fs)
1873
),
1874
df.group_by("a", maintain_order=maintain_order).agg(
1875
f(pl.col.b.first()).alias(f"c{i}") for i, f in enumerate(fs)
1876
),
1877
check_row_order=maintain_order,
1878
)
1879
1880
1881
def test_group_by_length_preserving_on_scalar() -> None:
1882
df = pl.DataFrame({"a": [[1], [2], [3]]})
1883
df = df.group_by(pl.lit(1, pl.Int64)).agg(
1884
a=pl.col.a.first().reverse(),
1885
b=pl.col.a.first(),
1886
c=pl.col.a.reverse(),
1887
d=pl.lit(1, pl.Int64).reverse(),
1888
e=pl.lit(1, pl.Int64).unique(),
1889
)
1890
1891
assert_frame_equal(
1892
df,
1893
pl.DataFrame(
1894
{
1895
"literal": [1],
1896
"a": [[1]],
1897
"b": [[1]],
1898
"c": [[[3], [2], [1]]],
1899
"d": [1],
1900
"e": [[1]],
1901
}
1902
),
1903
)
1904
1905
1906
def test_group_by_enum_min_max_18394() -> None:
1907
df = pl.DataFrame(
1908
{
1909
"id": ["a", "a", "b", "b", "c", "c"],
1910
"degree": ["low", "high", "high", "mid", "mid", "low"],
1911
}
1912
).with_columns(pl.col("degree").cast(pl.Enum(["low", "mid", "high"])))
1913
out = df.group_by("id").agg(
1914
min_degree=pl.col("degree").min(),
1915
max_degree=pl.col("degree").max(),
1916
)
1917
expected = pl.DataFrame(
1918
{
1919
"id": ["a", "b", "c"],
1920
"min_degree": ["low", "mid", "low"],
1921
"max_degree": ["high", "high", "mid"],
1922
},
1923
schema={
1924
"id": pl.String,
1925
"min_degree": pl.Enum(["low", "mid", "high"]),
1926
"max_degree": pl.Enum(["low", "mid", "high"]),
1927
},
1928
)
1929
assert_frame_equal(out, expected, check_row_order=False)
1930
1931
1932
@pytest.mark.parametrize("maintain_order", [False, True])
def test_group_by_filter_24838(maintain_order: bool) -> None:
    df = pl.DataFrame({"a": [1, 1, 2, 2, 3], "b": [1, 2, 1, 2, 1]})

    assert_frame_equal(
        df.group_by("a", maintain_order=maintain_order).agg(
            b=pl.lit(2, pl.Int64).filter(pl.col.b != 1)
        ),
        pl.DataFrame(
            [
                pl.Series("a", [1, 2, 3], pl.Int64),
                pl.Series("b", [[2], [2], []], pl.List(pl.Int64)),
            ]
        ),
        check_row_order=maintain_order,
    )

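# The parametric test below cross-checks `filter` inside a group_by
# aggregation against the same expression in a plain select: with a single
# all-encompassing group (pl.lit(1)), both code paths must agree for every
# combination of lhs expression, filter mask, and final aggregation.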
@pytest.mark.parametrize(
    "lhs",
    [
        pl.lit(2),
        pl.col.a,
        pl.col.a.first(),
        pl.col.a.reverse(),
        pl.col.a.fill_null(strategy="forward"),
    ],
)
@pytest.mark.parametrize(
    "rhs",
    [
        pl.col.b == 3,
        pl.col.b != 3,
        pl.col.b.reverse() == 3,
        pl.col.b.reverse() != 3,
        pl.col.b.fill_null(1) != 3,
        pl.col.b.fill_null(1) == 3,
        pl.lit(True),
        pl.lit(False),
        pl.lit(pl.Series([True])),
        pl.lit(pl.Series([False])),
        pl.lit(pl.Series([True])).first(),
        pl.lit(pl.Series([False])).first(),
    ],
)
@pytest.mark.parametrize(
    "agg",
    [
        Expr.implode,
        Expr.sum,
        Expr.first,
    ],
)
def test_group_by_filter_parametric(
    lhs: pl.Expr, rhs: pl.Expr, agg: Callable[[pl.Expr], pl.Expr]
) -> None:
    df = pl.DataFrame({"a": [1, 1, 2, 2, 3], "b": [1, 2, 1, 2, 1]})
    gb = df.group_by(pl.lit(1)).agg(a=agg(lhs.filter(rhs))).to_series(1)
    gb = gb.rename("a")
    sl = df.select(a=agg(lhs.filter(rhs))).to_series()
    assert_series_equal(gb, sl)

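# The next test checks `unique`/`n_unique` in three evaluation contexts
# (select, group_by over one literal group, and list.eval on an imploded
# column); order is only compared when the variant guarantees a stable order.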
@given(s=series(name="a", min_size=1))
@pytest.mark.parametrize(
    ("expr", "is_scalar", "maintain_order"),
    [
        (pl.Expr.n_unique, True, True),
        (pl.Expr.unique, False, False),
        (lambda e: e.unique(maintain_order=True), False, True),
    ],
)
def test_group_by_unique_parametric(
    s: pl.Series,
    expr: Callable[[pl.Expr], pl.Expr],
    is_scalar: bool,
    maintain_order: bool,
) -> None:
    df = s.to_frame()

    sl = df.select(expr(pl.col.a))
    gb = df.group_by(pl.lit(1)).agg(expr(pl.col.a)).drop("literal")
    if not is_scalar:
        gb = gb.select(pl.col.a.explode())
    assert_frame_equal(sl, gb, check_row_order=maintain_order)

    # check scalar case
    sl_first = df.select(expr(pl.col.a.first()))
    gb = df.group_by(pl.lit(1)).agg(expr(pl.col.a.first())).drop("literal")
    if not is_scalar:
        gb = gb.select(pl.col.a.explode())
    assert_frame_equal(sl_first, gb, check_row_order=maintain_order)

    li = df.select(pl.col.a.implode().list.eval(expr(pl.element())))
    li = li.select(pl.col.a.explode())
    assert_frame_equal(sl, li, check_row_order=maintain_order)

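# The next test exercises boolean `any`/`all` (with and without null
# propagation) over all two-element combinations of True/False/None, and
# verifies that select, group_by, and list.agg contexts agree, including for
# an empty boolean column.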
@pytest.mark.parametrize(
    "expr",
    [
        pl.Expr.any,
        pl.Expr.all,
        lambda e: e.any(ignore_nulls=False),
        lambda e: e.all(ignore_nulls=False),
    ],
)
def test_group_by_any_all(expr: Callable[[pl.Expr], pl.Expr]) -> None:
    combinations = [
        [True, None],
        [None, None],
        [False, None],
        [True, True],
        [False, False],
        [True, False],
    ]

    cl = cs.starts_with("x")
    df = pl.DataFrame(
        [pl.Series("g", [1, 1])]
        + [pl.Series(f"x{i}", c, pl.Boolean()) for i, c in enumerate(combinations)]
    )

    # verify that we are actually calculating something
    assert len(df.lazy().select(expr(cl)).collect_schema()) == len(combinations)

    assert_frame_equal(
        df.select(expr(cl)),
        df.group_by(lit=pl.lit(1)).agg(expr(cl)).drop("lit"),
    )

    assert_frame_equal(
        df.select(expr(cl)),
        df.group_by("g").agg(expr(cl)).drop("g"),
    )

    assert_frame_equal(
        df.select(expr(cl)),
        df.select(cl.implode().list.agg(expr(pl.element()))),
    )

    df = pl.Schema({"x": pl.Boolean()}).to_frame()

    assert_frame_equal(
        df.select(expr(cl)),
        pl.DataFrame({"x": [None]})
        .group_by(lit=pl.lit(1))
        .agg(expr(pl.lit(pl.Series("x", [], pl.Boolean()))))
        .drop("lit"),
    )

    assert_frame_equal(
        df.select(expr(cl)),
        df.select(cl.implode().list.agg(expr(pl.element()))),
    )

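# The property-based test below compares skew/kurtosis variants between
# select and group_by contexts; the grouped comparisons are skipped for empty
# input because grouping by a literal then yields zero groups.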
@given(
    s=series(
        name="f",
        dtype=pl.Float64(),
        allow_chunks=False,  # bug: See #24960
    )
)
@pytest.mark.may_fail_auto_streaming  # bug: See #24960
def test_group_by_skew_kurtosis(s: pl.Series) -> None:
    df = s.to_frame()

    exprs: dict[str, Callable[[pl.Expr], pl.Expr]] = {
        "skew": lambda e: e.skew(),
        "skew_b": lambda e: e.skew(bias=False),
        "kurt": lambda e: e.kurtosis(),
        "kurt_f": lambda e: e.kurtosis(fisher=False),
        "kurt_b": lambda e: e.kurtosis(bias=False),
        "kurt_fb": lambda e: e.kurtosis(fisher=False, bias=False),
    }

    sl = df.select([e(pl.col.f).alias(n) for n, e in exprs.items()])
    if s.len() > 0:
        gb = (
            df.group_by(pl.lit(1))
            .agg([e(pl.col.f).alias(n) for n, e in exprs.items()])
            .drop("literal")
        )
        assert_frame_equal(sl, gb)

        # check scalar case
        sl_first = df.select([e(pl.col.f.first()).alias(n) for n, e in exprs.items()])
        gb = (
            df.group_by(pl.lit(1))
            .agg([e(pl.col.f.first()).alias(n) for n, e in exprs.items()])
            .drop("literal")
        )
        assert_frame_equal(sl_first, gb)

    li = df.select(pl.col.f.implode()).select(
        [pl.col.f.list.agg(e(pl.element())).alias(n) for n, e in exprs.items()]
    )
    assert_frame_equal(sl, li)

def test_group_by_rolling_fill_null_25036() -> None:
    frame = pl.DataFrame(
        {
            "date": [date(2013, 1, 1), date(2013, 1, 2), date(2013, 1, 3)] * 2,
            "group": ["A"] * 3 + ["B"] * 3,
            "value": [None, None, 3, 4, None, None],
        }
    )
    result = frame.rolling(index_column="date", period="2d", group_by="group").agg(
        pl.col("value").forward_fill(limit=None).last()
    )

    expected = pl.DataFrame(
        {
            "group": ["A"] * 3 + ["B"] * 3,
            "date": [date(2013, 1, 1), date(2013, 1, 2), date(2013, 1, 3)] * 2,
            "value": [None, None, 3, 4, 4, None],
        }
    )

    assert_frame_equal(result, expected)

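# Shared lhs/rhs expressions for the broadcast test below: a column, a
# filtered column, a scalar aggregate, a scalar literal, and a unit-length
# series literal.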
exprs = [
    pl.col.a,
    pl.col.a.filter(pl.col.a <= 1),
    pl.col.a.first(),
    pl.lit(1).alias("one"),
    pl.lit(pl.Series([1])),
]


@pytest.mark.parametrize("lhs", exprs)
@pytest.mark.parametrize("rhs", exprs)
@pytest.mark.parametrize("op", [pl.Expr.add, pl.Expr.pow])
def test_group_broadcast_binary_apply_expr_25046(
    lhs: pl.Expr, rhs: pl.Expr, op: Any
) -> None:
    df = pl.DataFrame({"g": [10, 10, 20], "a": [1, 2, 3]})
    groups = pl.lit(1)
    out = df.group_by(groups).agg((op(lhs, rhs)).implode()).to_series(1)
    expected = df.select((op(lhs, rhs)).implode()).to_series()
    assert_series_equal(out, expected)

def test_group_by_explode_none_dtype_25045() -> None:
    df = pl.DataFrame({"a": [None, None, None], "b": [1.0, 2.0, None]})
    out_a = df.group_by(pl.lit(1)).agg(pl.col.a.explode())
    expected_a = pl.DataFrame({"literal": 1, "a": [[None, None, None]]})
    assert_frame_equal(out_a, expected_a)

    out_b = df.group_by(pl.lit(1)).agg(pl.col.b.explode())
    assert len(out_a["a"][0]) == len(out_b["b"][0])

    out_c = df.select(
        pl.coalesce(pl.col.a.explode(), pl.col.b.explode())
        .implode()
        .over(pl.int_range(pl.len()))
    )
    expected_c = pl.DataFrame({"a": [[1.0], [2.0], [None]]})
    assert_frame_equal(out_c, expected_c)

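# The next test runs forward_fill/backward_fill (with several limits, plus
# scalar follow-up aggregations) over null patterns covering leading,
# trailing, isolated, and all-null cases, in select, group_by, and list.eval
# contexts, including an empty column.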
@pytest.mark.parametrize(
    ("expr", "is_scalar"),
    [
        (pl.Expr.forward_fill, False),
        (pl.Expr.backward_fill, False),
        (lambda e: e.forward_fill(1), False),
        (lambda e: e.backward_fill(1), False),
        (lambda e: e.forward_fill(2), False),
        (lambda e: e.backward_fill(2), False),
        (lambda e: e.forward_fill().min(), True),
        (lambda e: e.backward_fill().min(), True),
        (lambda e: e.forward_fill().first(), True),
        (lambda e: e.backward_fill().first(), True),
    ],
)
def test_group_by_forward_backward_fill(
    expr: Callable[[pl.Expr], pl.Expr], is_scalar: bool
) -> None:
    combinations = [
        [1, None, 2, None, None],
        [None, 1, 2, 3, 4],
        [None, None, None, None, None],
        [1, 2, 3, 4, 5],
        [1, None, 2, 3, 4],
        [None, None, None, None, 1],
        [1, None, None, None, None],
        [None, None, None, 1, None],
        [None, 1, None, None, None],
    ]

    cl = cs.starts_with("x")
    df = pl.DataFrame(
        [pl.Series("g", [1] * 5)]
        + [pl.Series(f"x{i}", c, pl.Int64()) for i, c in enumerate(combinations)]
    )

    # verify that we are actually calculating something
    assert len(df.lazy().select(expr(cl)).collect_schema()) == len(combinations)

    data = df.group_by(lit=pl.lit(1)).agg(expr(cl)).drop("lit")
    if not is_scalar:
        data = data.explode(cs.all())
    assert_frame_equal(df.select(expr(cl)), data)

    data = df.group_by("g").agg(expr(cl)).drop("g")
    if not is_scalar:
        data = data.explode(cs.all())
    assert_frame_equal(df.select(expr(cl)), data)

    assert_frame_equal(
        df.select(expr(cl)),
        df.select(cl.implode().list.eval(expr(pl.element())).explode()),
    )

    df = pl.Schema({"x": pl.Int64()}).to_frame()

    data = (
        pl.DataFrame({"x": [None]})
        .group_by(lit=pl.lit(1))
        .agg(expr(pl.lit(pl.Series("x", [], pl.Int64()))))
        .drop("lit")
    )
    if not is_scalar:
        data = data.select(cs.all().reshape((-1,)))
    assert_frame_equal(df.select(expr(cl)), data)

    assert_frame_equal(
        df.select(expr(cl)),
        df.select(cl.implode().list.eval(expr(pl.element())).reshape((-1,))),
    )

@given(s=series())
def test_group_by_drop_nulls(s: pl.Series) -> None:
    df = s.rename("f").to_frame()

    data = (
        df.group_by(lit=pl.lit(1))
        .agg(pl.col.f.drop_nulls())
        .drop("lit")
        .select(pl.col.f.reshape((-1,)))
    )
    assert_frame_equal(df.select(pl.col.f.drop_nulls()), data)

    assert_frame_equal(
        df.select(pl.col.f.drop_nulls()),
        df.select(
            pl.col.f.implode().list.eval(pl.element().drop_nulls()).reshape((-1,))
        ),
    )

    df = pl.Schema({"f": pl.Int64()}).to_frame()

    data = (
        pl.DataFrame({"x": [None]})
        .group_by(lit=pl.lit(1))
        .agg(pl.lit(pl.Series("f", [], pl.Int64())).drop_nulls())
        .drop("lit")
    )
    data = data.select(cs.all().reshape((-1,)))
    assert_frame_equal(df.select(pl.col.f.drop_nulls()), data)

    assert_frame_equal(
        df.select(pl.col.f.drop_nulls()),
        df.select(
            pl.col.f.implode().list.eval(pl.element().drop_nulls()).reshape((-1,))
        ),
    )

@given(s=series())
def test_group_by_drop_nans(s: pl.Series) -> None:
    df = s.rename("f").to_frame()

    data = (
        df.group_by(lit=pl.lit(1))
        .agg(pl.col.f.drop_nans())
        .select(pl.col.f.reshape((-1,)))
    )
    assert_frame_equal(df.select(pl.col.f.drop_nans()), data)

    assert_frame_equal(
        df.select(pl.col.f.drop_nans()),
        df.select(
            pl.col.f.implode().list.eval(pl.element().drop_nans()).reshape((-1,))
        ),
    )

    df = pl.Schema({"f": pl.Int64()}).to_frame()

    data = (
        pl.DataFrame({"x": [None]})
        .group_by(lit=pl.lit(1))
        .agg(pl.lit(pl.Series("f", [], pl.Int64())).drop_nans())
        .drop("lit")
    )
    data = data.select(cs.all().reshape((-1,)))
    assert_frame_equal(df.select(pl.col.f.drop_nans()), data)

    assert_frame_equal(
        df.select(pl.col.f.drop_nans()),
        df.select(
            pl.col.f.implode().list.eval(pl.element().drop_nans()).reshape((-1,))
        ),
    )

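# The hypothesis test below is the broadest consistency check in this module:
# for many aggregation expressions it compares, per key, a plain select over
# just that group's rows against both the group_by().agg result and the
# list.agg path, accounting for scalar vs list outputs and window functions.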
@given(
    df=dataframes(
        min_size=1,
        include_cols=[column(name="key", dtype=pl.UInt8, allow_null=False)],
    ),
)
@pytest.mark.parametrize(
    ("expr", "check_order", "returns_scalar", "length_preserving", "is_window"),
    [
        (pl.Expr.unique, False, False, False, False),
        (lambda e: e.unique(maintain_order=True), True, False, False, False),
        (pl.Expr.drop_nans, True, False, False, False),
        (pl.Expr.drop_nulls, True, False, False, False),
        (pl.Expr.null_count, True, False, False, False),
        (pl.Expr.n_unique, True, True, False, False),
        (
            lambda e: e.filter(pl.int_range(0, e.len()) % 3 == 0),
            True,
            False,
            False,
            False,
        ),
        (pl.Expr.shift, True, False, True, False),
        (pl.Expr.forward_fill, True, False, True, False),
        (pl.Expr.backward_fill, True, False, True, False),
        (pl.Expr.reverse, True, False, True, False),
        (
            lambda e: (pl.int_range(e.len() - e.len(), e.len()) % 3 == 0).any(),
            True,
            True,
            False,
            False,
        ),
        (
            lambda e: (pl.int_range(e.len() - e.len(), e.len()) % 3 == 0).all(),
            True,
            True,
            False,
            False,
        ),
        (lambda e: e.head(2), True, False, False, False),
        (pl.Expr.first, True, True, False, False),
        (pl.Expr.mode, False, False, False, False),
        (lambda e: e.fill_null(e.first()).over(e), True, False, True, True),
        (lambda e: e.first().over(e), True, False, True, True),
        (
            lambda e: e.fill_null(e.first()).over(e, mapping_strategy="join"),
            True,
            False,
            True,
            True,
        ),
        (
            lambda e: e.fill_null(e.first()).over(e, mapping_strategy="explode"),
            True,
            False,
            False,
            True,
        ),
        (
            lambda e: e.fill_null(strategy="forward").over([e, e]),
            True,
            False,
            True,
            True,
        ),
        (lambda e: e.fill_null(e.first()).over(e, order_by=e), True, False, True, True),
        (
            lambda e: e.fill_null(e.first()).over(e, order_by=e, descending=True),
            True,
            False,
            True,
            True,
        ),
        (
            lambda e: e.gather(pl.int_range(0, e.len()).slice(1, 3)),
            True,
            False,
            False,
            False,
        ),
    ],
)
def test_grouped_agg_parametric(
    df: pl.DataFrame,
    expr: Callable[[pl.Expr], pl.Expr],
    check_order: bool,
    returns_scalar: bool,
    length_preserving: bool,
    is_window: bool,
) -> None:
    types: dict[str, tuple[Callable[[pl.Expr], pl.Expr], bool, bool]] = {
        "basic": (lambda e: e, False, True),
    }

    if not is_window:
        types["first"] = (pl.Expr.first, True, False)
        types["slice"] = (lambda e: e.slice(1, 3), False, False)
        types["impl_expl"] = (lambda e: e.implode().explode(), False, False)
        types["rolling"] = (
            lambda e: e.rolling(pl.row_index(), period="3i"),
            False,
            True,
        )
        types["over"] = (lambda e: e.forward_fill().over(e), False, True)

    def slit(s: pl.Series) -> pl.Expr:
        import polars._plr as plr

        return pl.Expr._from_pyexpr(plr.lit(s._s, False, is_scalar=True))

    df = df.with_columns(pl.col.key % 4)
    gb = df.group_by("key").agg(
        *[
            expr(t(~cs.by_name("key"))).name.prefix(f"{k}_")
            for k, (t, _, _) in types.items()
        ],
        *[
            expr(slit(df[c].head(1))).alias(f"literal_{c}")
            for c in filter(lambda c: c != "key", df.columns)
        ],
    )
    ls = (
        df.group_by("key")
        .agg(pl.all())
        .select(
            pl.col.key,
            *[
                (~cs.by_name("key"))
                .list.agg(expr(t(pl.element())))
                .name.prefix(f"{k}_")
                for k, (t, _, _) in types.items()
            ],
            *[
                pl.col(c).list.agg(expr(slit(df[c].head(1)))).alias(f"literal_{c}")
                for c in filter(lambda c: c != "key", df.columns)
            ],
        )
    )

    if not is_window:
        types["literal"] = (lambda e: e, True, False)

    def verify_index(i: int) -> None:
        idx_df = df.filter(pl.col.key == pl.lit(i, pl.UInt8))
        idx_gb = gb.filter(pl.col.key == pl.lit(i, pl.UInt8))
        idx_ls = ls.filter(pl.col.key == pl.lit(i, pl.UInt8))

        for col in df.columns:
            if col == "key":
                continue

            for k, (t, t_is_scalar, t_is_length_preserving) in types.items():
                c = f"{k}_{col}"

                if k == "literal":
                    df_s = idx_df.select(
                        expr(t(slit(df[col].head(1)))).alias(c)
                    ).to_series()
                else:
                    df_s = idx_df.select(expr(t(pl.col(col))).alias(c)).to_series()

                gb_s = idx_gb[c]
                ls_s = idx_ls[c]

                result_is_scalar = False
                result_is_scalar |= returns_scalar and t_is_length_preserving
                result_is_scalar |= t_is_scalar and length_preserving
                result_is_scalar &= not is_window

                if not result_is_scalar:
                    gb_s = gb_s.explode(empty_as_null=False)
                    ls_s = ls_s.explode(empty_as_null=False)

                assert_series_equal(df_s, gb_s, check_order=check_order)
                assert_series_equal(df_s, ls_s, check_order=check_order)

    if 0 in df["key"]:
        verify_index(0)
    if 1 in df["key"]:
        verify_index(1)
    if 2 in df["key"]:
        verify_index(2)
    if 3 in df["key"]:
        verify_index(3)

@pytest.mark.parametrize("maintain_order", [False, True])
@pytest.mark.parametrize(
    ("df", "out"),
    [
        (
            pl.DataFrame(
                {
                    "key": [0, 0, 0, 0, 1],
                    "a": [True, False, False, False, False],
                }
            ).with_columns(
                a=pl.when(pl.Series([False, False, False, False, True])).then(pl.col.a)
            ),
            pl.DataFrame(
                {
                    "key": [0, 1],
                    "a": [1, 1],
                },
                schema_overrides={"a": pl.get_index_type()},
            ),
        ),
        (
            pl.DataFrame(
                {
                    "key": [0, 0, 1, 1],
                    "a": [False, False, False, False],
                }
            ).with_columns(
                a=pl.when(pl.Series([False, False, True, True])).then(pl.col.a)
            ),
            pl.DataFrame(
                {
                    "key": [0, 1],
                    "a": [1, 1],
                },
                schema_overrides={"a": pl.get_index_type()},
            ),
        ),
    ],
)
def test_n_unique_masked_bools(
    maintain_order: bool, df: pl.DataFrame, out: pl.DataFrame
) -> None:
    assert_frame_equal(
        df.group_by("key", maintain_order=maintain_order).agg(pl.col.a.n_unique()),
        out,
        check_row_order=maintain_order,
    )
    assert_frame_equal(
        df.group_by("key", maintain_order=maintain_order)
        .agg(pl.col.a)
        .with_columns(pl.col.a.list.agg(pl.element().n_unique())),
        out,
        check_row_order=maintain_order,
    )

@pytest.mark.parametrize("maintain_order", [False, True])
@pytest.mark.parametrize("stable", [False, True])
def test_group_bool_unique_25267(maintain_order: bool, stable: bool) -> None:
    df = pl.DataFrame(
        {
            "id": ["A", "A", "B", "B", "C", "C"],
            "str_values": ["D", "E", "F", "F", "G", "G"],
            "bool_values": [True, False, True, True, False, False],
        }
    )

    gb = df.group_by("id", maintain_order=maintain_order).agg(
        pl.col("str_values", "bool_values").unique(maintain_order=stable),
    )

    ls = (
        df.group_by("id", maintain_order=maintain_order)
        .agg("str_values", "bool_values")
        .with_columns(
            pl.col("str_values", "bool_values").list.agg(
                pl.element().unique(maintain_order=stable)
            )
        )
    )

    for i in ["A", "B", "C"]:
        for c in ["str_values", "bool_values"]:
            df_s = (
                df.select(pl.col(c).filter(pl.col.id == pl.lit(i)))
                .to_series()
                .unique(maintain_order=stable)
            )
            gb_s = gb.select(
                pl.col(c).filter(pl.col.id == pl.lit(i)).reshape((-1,))
            ).to_series()
            ls_s = ls.select(
                pl.col(c).filter(pl.col.id == pl.lit(i)).reshape((-1,))
            ).to_series()

            assert_series_equal(df_s, gb_s, check_order=stable)
            assert_series_equal(df_s, ls_s, check_order=stable)

@pytest.mark.parametrize("group_as_slice", [False, True])
@pytest.mark.parametrize("n", [10, 100, 519])
@pytest.mark.parametrize(
    "dtype", [pl.Int32, pl.Boolean, pl.String, pl.Categorical, pl.List(pl.Int32)]
)
def test_group_by_first_last(
    group_as_slice: bool, n: int, dtype: PolarsDataType
) -> None:
    group_by_first_last_test_impl(group_as_slice, n, dtype)


@pytest.mark.slow
@pytest.mark.parametrize("group_as_slice", [False, True])
@pytest.mark.parametrize("n", [1056, 10_432])
@pytest.mark.parametrize(
    "dtype", [pl.Int32, pl.Boolean, pl.String, pl.Categorical, pl.List(pl.Int32)]
)
def test_group_by_first_last_big(
    group_as_slice: bool, n: int, dtype: PolarsDataType
) -> None:
    group_by_first_last_test_impl(group_as_slice, n, dtype)

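# Shared implementation for the first/last tests above: it builds one frame
# whose groups are padded with nulls at both ends and one without any nulls,
# then checks first/last with and without ignore_nulls, optionally through
# the sorted (GroupSlice) path.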
def group_by_first_last_test_impl(
    group_as_slice: bool, n: int, dtype: PolarsDataType
) -> None:
    idx = pl.Series([1, 2, 3, 4, 5], dtype=pl.Int32)

    lf = pl.LazyFrame(
        {
            "idx": pl.Series(
                [1] * n + [2] * n + [3] * n + [4] * n + [5] * n, dtype=pl.Int32
            ),
            # Each successive group has one more None padding each end of its elements
            "a": pl.Series(
                [
                    *[None] * 0, *list(range(1, n + 1)), *[None] * 0,  # idx = 1
                    *[None] * 1, *list(range(2, n - 0)), *[None] * 1,  # idx = 2
                    *[None] * 2, *list(range(3, n - 1)), *[None] * 2,  # idx = 3
                    *[None] * 3, *list(range(4, n - 2)), *[None] * 3,  # idx = 4
                    *[None] * 4, *list(range(5, n - 3)), *[None] * 4,  # idx = 5
                ],
                dtype=pl.Int32,
            ),
        }
    )  # fmt: skip
    if group_as_slice:
        lf = lf.set_sorted("idx")  # Use GroupSlice path

    if dtype == pl.Categorical:
        # for Categorical, we must first go through String
        lf = lf.with_columns(pl.col("a").cast(pl.String))
    lf = lf.with_columns(pl.col("a").cast(dtype))

    # first()
    result = lf.group_by("idx", maintain_order=True).agg(pl.col("a").first()).collect()
    expected_vals = pl.Series([1, None, None, None, None])
    if dtype == pl.Categorical:
        # for Categorical, we must first go through String
        expected_vals = expected_vals.cast(pl.String)

    expected_vals = expected_vals.cast(dtype)
    expected = pl.DataFrame({"idx": idx, "a": expected_vals})
    assert_frame_equal(result, expected)
    result = lf.group_by("idx", maintain_order=True).first().collect()
    assert_frame_equal(result, expected)

    # first(ignore_nulls=True)
    result = (
        lf.group_by("idx", maintain_order=True)
        .agg(pl.col("a").first(ignore_nulls=True))
        .collect()
    )
    expected_vals = pl.Series([1, 2, 3, 4, 5])
    if dtype == pl.Categorical:
        # for Categorical, we must first go through String
        expected_vals = expected_vals.cast(pl.String)

    expected_vals = expected_vals.cast(dtype)
    expected = pl.DataFrame({"idx": idx, "a": expected_vals})
    assert_frame_equal(result, expected)
    result = lf.group_by("idx", maintain_order=True).first(ignore_nulls=True).collect()
    assert_frame_equal(result, expected)

    # last()
    result = lf.group_by("idx", maintain_order=True).agg(pl.col("a").last()).collect()
    expected_vals = pl.Series([n, None, None, None, None])
    if dtype == pl.Categorical:
        # for Categorical, we must first go through String
        expected_vals = expected_vals.cast(pl.String)

    expected_vals = expected_vals.cast(dtype)
    expected = pl.DataFrame({"idx": idx, "a": expected_vals})
    assert_frame_equal(result, expected)
    result = lf.group_by("idx", maintain_order=True).last().collect()
    assert_frame_equal(result, expected)

    # last(ignore_nulls=True)
    result = (
        lf.group_by("idx", maintain_order=True)
        .agg(pl.col("a").last(ignore_nulls=True))
        .collect()
    )
    expected_vals = pl.Series([n, n - 1, n - 2, n - 3, n - 4])
    if dtype == pl.Categorical:
        # for Categorical, we must first go through String
        expected_vals = expected_vals.cast(pl.String)

    expected_vals = expected_vals.cast(dtype)
    expected = pl.DataFrame({"idx": idx, "a": expected_vals})
    assert_frame_equal(result, expected)
    result = lf.group_by("idx", maintain_order=True).last(ignore_nulls=True).collect()
    assert_frame_equal(result, expected)

    # Test with no nulls
    lf = pl.LazyFrame(
        {
            "idx": pl.Series(
                [1] * n + [2] * n + [3] * n + [4] * n + [5] * n, dtype=pl.Int32
            ),
            # No group contains any nulls here
            "a": pl.Series(
                [
                    *list(range(1, n + 1)),  # idx = 1
                    *list(range(2, n + 2)),  # idx = 2
                    *list(range(3, n + 3)),  # idx = 3
                    *list(range(4, n + 4)),  # idx = 4
                    *list(range(5, n + 5)),  # idx = 5
                ],
                dtype=pl.Int32,
            ),
        }
    )
    if group_as_slice:
        lf = lf.set_sorted("idx")  # Use GroupSlice path

    if dtype == pl.Categorical:
        # for Categorical, we must first go through String
        lf = lf.with_columns(pl.col("a").cast(pl.String))
    lf = lf.with_columns(pl.col("a").cast(dtype))

    # first()
    expected_vals = pl.Series([1, 2, 3, 4, 5])
    if dtype == pl.Categorical:
        # for Categorical, we must first go through String
        expected_vals = expected_vals.cast(pl.String)

    expected_vals = expected_vals.cast(dtype)
    expected = pl.DataFrame({"idx": idx, "a": expected_vals})
    result = lf.group_by("idx", maintain_order=True).agg(pl.col("a").first()).collect()
    assert_frame_equal(result, expected)
    result = lf.group_by("idx", maintain_order=True).first().collect()
    assert_frame_equal(result, expected)

    # first(ignore_nulls=True)
    result = (
        lf.group_by("idx", maintain_order=True)
        .agg(pl.col("a").first(ignore_nulls=True))
        .collect()
    )
    assert_frame_equal(result, expected)
    result = lf.group_by("idx", maintain_order=True).first(ignore_nulls=True).collect()
    assert_frame_equal(result, expected)

    # last()
    expected_vals = pl.Series([n, n + 1, n + 2, n + 3, n + 4])
    if dtype == pl.Categorical:
        # for Categorical, we must first go through String
        expected_vals = expected_vals.cast(pl.String)

    expected_vals = expected_vals.cast(dtype)
    expected = pl.DataFrame({"idx": idx, "a": expected_vals})
    result = lf.group_by("idx", maintain_order=True).agg(pl.col("a").last()).collect()
    assert_frame_equal(result, expected)
    result = lf.group_by("idx", maintain_order=True).last().collect()
    assert_frame_equal(result, expected)

    # last(ignore_nulls=True)
    result = (
        lf.group_by("idx", maintain_order=True)
        .agg(pl.col("a").last(ignore_nulls=True))
        .collect()
    )
    assert_frame_equal(result, expected)
    result = lf.group_by("idx", maintain_order=True).last(ignore_nulls=True).collect()
    assert_frame_equal(result, expected)

def test_sorted_group_by() -> None:
    lf = pl.LazyFrame(
        {
            "a": [1, 1, 2, 2, 3, 3, 3],
            "b": [4, 5, 8, 1, 0, 1, 3],
        }
    )

    lf1 = lf
    lf2 = lf.set_sorted("a")

    assert_frame_equal(
        *[
            q.group_by("a")
            .agg(b_first=pl.col.b.first(), b_sum=pl.col.b.sum(), b=pl.col.b)
            .collect(engine="streaming")
            for q in (lf1, lf2)
        ],
        check_row_order=False,
    )

    lf = lf.with_columns(c=pl.col.a.rle_id())
    lf1 = lf
    lf2 = lf.set_sorted("a", "c")

    assert_frame_equal(
        *[
            q.group_by("a", "c")
            .agg(b_first=pl.col.b.first(), b_sum=pl.col.b.sum(), b=pl.col.b)
            .collect(engine="streaming")
            for q in (lf1, lf2)
        ],
        check_row_order=False,
    )

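# The next test checks that slicing a maintain_order group_by on a set_sorted
# key produces the same rows as slicing the fully materialized result, across
# head/tail and various offsets.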
def test_sorted_group_by_slice() -> None:
    lf = (
        pl.DataFrame({"a": [0, 5, 2, 1, 3] * 50})
        .with_row_index()
        .with_columns(pl.col.index // 5)
        .lazy()
        .set_sorted("index")
        .group_by("index", maintain_order=True)
        .agg(pl.col.a.sum() + pl.col.index.first())
    )

    expected = pl.DataFrame(
        [
            pl.Series("index", range(50), pl.get_index_type()),
            pl.Series("a", range(11, 11 + 50), pl.Int64),
        ]
    )

    assert_frame_equal(lf.head(2).collect(), expected.head(2))
    assert_frame_equal(lf.slice(1, 3).collect(), expected.slice(1, 3))
    assert_frame_equal(lf.tail(2).collect(), expected.tail(2))
    assert_frame_equal(lf.slice(5, 1).collect(), expected.slice(5, 1))
    assert_frame_equal(lf.slice(5, 0).collect(), expected.slice(5, 0))
    assert_frame_equal(lf.slice(2, 1).collect(), expected.slice(2, 1))
    assert_frame_equal(lf.slice(50, 1).collect(), expected.slice(50, 1))
    assert_frame_equal(lf.slice(20, 30).collect(), expected.slice(20, 30))

def test_agg_first_last_non_null_25405() -> None:
    lf = pl.LazyFrame(
        {
            "a": [1, 1, 1, 1, 2, 2, 2, 2, 2],
            "b": pl.Series([1, 2, 3, None, None, 4, 5, 6, None]),
        }
    )

    # first
    result = lf.group_by("a", maintain_order=True).agg(
        pl.col("b").first(ignore_nulls=True)
    )
    expected = pl.DataFrame(
        {
            "a": [1, 2],
            "b": [1, 4],
        }
    )
    assert_frame_equal(result.collect(), expected)

    result = lf.with_columns(pl.col("b").first(ignore_nulls=True).over("a"))
    expected = pl.DataFrame(
        {
            "a": [1, 1, 1, 1, 2, 2, 2, 2, 2],
            "b": [1, 1, 1, 1, 4, 4, 4, 4, 4],
        }
    )
    assert_frame_equal(result.collect(), expected)

    # last
    result = lf.group_by("a", maintain_order=True).agg(
        pl.col("b").last(ignore_nulls=True)
    )
    expected = pl.DataFrame(
        {
            "a": [1, 2],
            "b": [3, 6],
        }
    )
    assert_frame_equal(result.collect(), expected)

    result = lf.with_columns(pl.col("b").last(ignore_nulls=True).over("a"))
    expected = pl.DataFrame(
        {
            "a": [1, 1, 1, 1, 2, 2, 2, 2, 2],
            "b": [3, 3, 3, 3, 6, 6, 6, 6, 6],
        }
    )
    assert_frame_equal(result.collect(), expected)

def test_group_by_sum_on_strings_should_error_24659() -> None:
    with pytest.raises(
        InvalidOperationError,
        match=r"`sum`.*operation not supported for dtype.*str",
    ):
        pl.DataFrame({"str": ["a", "b"]}).group_by(1).agg(pl.col.str.sum())

@pytest.mark.parametrize("tail", [0, 1, 4, 5, 6, 10])
2945
def test_unique_head_tail_26429(tail: int) -> None:
2946
df = pl.DataFrame(
2947
{
2948
"x": [1, 2, 3, 4, 5],
2949
}
2950
)
2951
out = df.lazy().unique().tail(tail).collect()
2952
expected = min(tail, df.height)
2953
assert len(out) == expected
2954
2955
2956
def test_group_by_cse_alias_26423() -> None:
    df = pl.LazyFrame({"a": [1, 2, 1, 2, 3, 4]})
    result = df.group_by("a").agg(pl.len(), pl.len().alias("len_a")).collect()
    expected = pl.DataFrame(
        {"a": [1, 2, 3, 4], "len": [2, 2, 1, 1], "len_a": [2, 2, 1, 1]},
        schema={
            "a": pl.Int64,
            "len": pl.get_index_type(),
            "len_a": pl.get_index_type(),
        },
    )
    assert_frame_equal(result, expected, check_row_order=False)