from __future__ import annotations

import typing
from collections import OrderedDict
from datetime import date, datetime, timedelta
from typing import TYPE_CHECKING, Any

import numpy as np
import pytest

import polars as pl
import polars.selectors as cs
from polars.exceptions import ColumnNotFoundError
from polars.meta import get_index_type
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
    from polars._typing import PolarsDataType


def test_group_by() -> None:
    df = pl.DataFrame(
        {
            "a": ["a", "b", "a", "b", "b", "c"],
            "b": [1, 2, 3, 4, 5, 6],
            "c": [6, 5, 4, 3, 2, 1],
        }
    )

    # Use lazy API in eager group_by
    assert sorted(df.group_by("a").agg([pl.sum("b")]).rows()) == [
        ("a", 4),
        ("b", 11),
        ("c", 6),
    ]
    # test if it accepts a single expression
    assert df.group_by("a", maintain_order=True).agg(pl.sum("b")).rows() == [
        ("a", 4),
        ("b", 11),
        ("c", 6),
    ]

    df = pl.DataFrame(
        {
            "a": [1, 2, 3, 4, 5],
            "b": ["a", "a", "b", "b", "b"],
            "c": [None, 1, None, 1, None],
        }
    )

    # check if this query runs and thus column names propagate
    df.group_by("b").agg(pl.col("c").fill_null(strategy="forward")).explode("c")

    # get a specific column
    result = df.group_by("b", maintain_order=True).agg(pl.count("a"))
    assert result.rows() == [("a", 2), ("b", 3)]
    assert result.columns == ["b", "a"]


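# Added note: without maintain_order=True the row order of group_by results is
# not deterministic, which is why the first assertion above sorts its rows.
# A minimal sketch (the leading underscore keeps pytest from collecting it):
def _demo_group_by_row_order() -> None:
    df = pl.DataFrame({"a": ["x", "y", "x"], "b": [1, 2, 3]})
    out = df.group_by("a").agg(pl.sum("b"))
    # sort before comparing, as the group order is unspecified
    assert sorted(out.rows()) == [("x", 4), ("y", 2)]

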
@pytest.mark.parametrize(
    ("input", "expected", "input_dtype", "output_dtype"),
    [
        ([1, 2, 3, 4], [2, 4], pl.UInt8, pl.Float64),
        ([1, 2, 3, 4], [2, 4], pl.Int8, pl.Float64),
        ([1, 2, 3, 4], [2, 4], pl.UInt16, pl.Float64),
        ([1, 2, 3, 4], [2, 4], pl.Int16, pl.Float64),
        ([1, 2, 3, 4], [2, 4], pl.UInt32, pl.Float64),
        ([1, 2, 3, 4], [2, 4], pl.Int32, pl.Float64),
        ([1, 2, 3, 4], [2, 4], pl.UInt64, pl.Float64),
        ([1, 2, 3, 4], [2, 4], pl.Float32, pl.Float32),
        ([1, 2, 3, 4], [2, 4], pl.Float64, pl.Float64),
        ([False, True, True, True], [2 / 3, 1], pl.Boolean, pl.Float64),
        (
            [date(2023, 1, 1), date(2023, 1, 2), date(2023, 1, 4), date(2023, 1, 5)],
            [datetime(2023, 1, 2, 8, 0, 0), datetime(2023, 1, 5)],
            pl.Date,
            pl.Datetime("us"),
        ),
        (
            [
                datetime(2023, 1, 1),
                datetime(2023, 1, 2),
                datetime(2023, 1, 3),
                datetime(2023, 1, 4),
            ],
            [datetime(2023, 1, 2), datetime(2023, 1, 4)],
            pl.Datetime("ms"),
            pl.Datetime("ms"),
        ),
        (
            [
                datetime(2023, 1, 1),
                datetime(2023, 1, 2),
                datetime(2023, 1, 3),
                datetime(2023, 1, 4),
            ],
            [datetime(2023, 1, 2), datetime(2023, 1, 4)],
            pl.Datetime("us"),
            pl.Datetime("us"),
        ),
        (
            [
                datetime(2023, 1, 1),
                datetime(2023, 1, 2),
                datetime(2023, 1, 3),
                datetime(2023, 1, 4),
            ],
            [datetime(2023, 1, 2), datetime(2023, 1, 4)],
            pl.Datetime("ns"),
            pl.Datetime("ns"),
        ),
        (
            [timedelta(1), timedelta(2), timedelta(3), timedelta(4)],
            [timedelta(2), timedelta(4)],
            pl.Duration("ms"),
            pl.Duration("ms"),
        ),
        (
            [timedelta(1), timedelta(2), timedelta(3), timedelta(4)],
            [timedelta(2), timedelta(4)],
            pl.Duration("us"),
            pl.Duration("us"),
        ),
        (
            [timedelta(1), timedelta(2), timedelta(3), timedelta(4)],
            [timedelta(2), timedelta(4)],
            pl.Duration("ns"),
            pl.Duration("ns"),
        ),
    ],
)
def test_group_by_mean_by_dtype(
    input: list[Any],
    expected: list[Any],
    input_dtype: PolarsDataType,
    output_dtype: PolarsDataType,
) -> None:
    # groups are defined by first 3 values, then last value
    name = str(input_dtype)
    key = ["a", "a", "a", "b"]
    df = pl.LazyFrame(
        {
            "key": key,
            name: pl.Series(input, dtype=input_dtype),
        }
    )
    result = df.group_by("key", maintain_order=True).mean()
    df_expected = pl.DataFrame(
        {
            "key": ["a", "b"],
            name: pl.Series(expected, dtype=output_dtype),
        }
    )
    assert result.collect_schema() == df_expected.schema
    assert_frame_equal(result.collect(), df_expected)


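# Added note: the cases above encode polars' mean dtype promotion: integer and
# boolean inputs come back as Float64, Float32 stays Float32, and a mean over
# pl.Date is returned as pl.Datetime("us"). A minimal sketch:
def _demo_mean_dtype_promotion() -> None:
    df = pl.DataFrame({"key": ["a", "a"], "v": pl.Series([1, 2], dtype=pl.Int8)})
    out = df.group_by("key").mean()
    # the Int8 input is averaged into a Float64 column
    assert out.schema["v"] == pl.Float64

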
@pytest.mark.parametrize(
    ("input", "expected", "input_dtype", "output_dtype"),
    [
        ([1, 2, 4, 5], [2, 5], pl.UInt8, pl.Float64),
        ([1, 2, 4, 5], [2, 5], pl.Int8, pl.Float64),
        ([1, 2, 4, 5], [2, 5], pl.UInt16, pl.Float64),
        ([1, 2, 4, 5], [2, 5], pl.Int16, pl.Float64),
        ([1, 2, 4, 5], [2, 5], pl.UInt32, pl.Float64),
        ([1, 2, 4, 5], [2, 5], pl.Int32, pl.Float64),
        ([1, 2, 4, 5], [2, 5], pl.UInt64, pl.Float64),
        ([1, 2, 4, 5], [2, 5], pl.Float32, pl.Float32),
        ([1, 2, 4, 5], [2, 5], pl.Float64, pl.Float64),
        ([False, True, True, True], [1, 1], pl.Boolean, pl.Float64),
        (
            [date(2023, 1, 1), date(2023, 1, 2), date(2023, 1, 4), date(2023, 1, 5)],
            [datetime(2023, 1, 2), datetime(2023, 1, 5)],
            pl.Date,
            pl.Datetime("us"),
        ),
        (
            [
                datetime(2023, 1, 1),
                datetime(2023, 1, 2),
                datetime(2023, 1, 4),
                datetime(2023, 1, 5),
            ],
            [datetime(2023, 1, 2), datetime(2023, 1, 5)],
            pl.Datetime("ms"),
            pl.Datetime("ms"),
        ),
        (
            [
                datetime(2023, 1, 1),
                datetime(2023, 1, 2),
                datetime(2023, 1, 4),
                datetime(2023, 1, 5),
            ],
            [datetime(2023, 1, 2), datetime(2023, 1, 5)],
            pl.Datetime("us"),
            pl.Datetime("us"),
        ),
        (
            [
                datetime(2023, 1, 1),
                datetime(2023, 1, 2),
                datetime(2023, 1, 4),
                datetime(2023, 1, 5),
            ],
            [datetime(2023, 1, 2), datetime(2023, 1, 5)],
            pl.Datetime("ns"),
            pl.Datetime("ns"),
        ),
        (
            [timedelta(1), timedelta(2), timedelta(4), timedelta(5)],
            [timedelta(2), timedelta(5)],
            pl.Duration("ms"),
            pl.Duration("ms"),
        ),
        (
            [timedelta(1), timedelta(2), timedelta(4), timedelta(5)],
            [timedelta(2), timedelta(5)],
            pl.Duration("us"),
            pl.Duration("us"),
        ),
        (
            [timedelta(1), timedelta(2), timedelta(4), timedelta(5)],
            [timedelta(2), timedelta(5)],
            pl.Duration("ns"),
            pl.Duration("ns"),
        ),
    ],
)
def test_group_by_median_by_dtype(
    input: list[Any],
    expected: list[Any],
    input_dtype: PolarsDataType,
    output_dtype: PolarsDataType,
) -> None:
    # groups are defined by first 3 values, then last value
    name = str(input_dtype)
    key = ["a", "a", "a", "b"]
    df = pl.LazyFrame(
        {
            "key": key,
            name: pl.Series(input, dtype=input_dtype),
        }
    )
    result = df.group_by("key", maintain_order=True).median()
    df_expected = pl.DataFrame(
        {
            "key": ["a", "b"],
            name: pl.Series(expected, dtype=output_dtype),
        }
    )
    assert result.collect_schema() == df_expected.schema
    assert_frame_equal(result.collect(), df_expected)


@pytest.fixture
def df() -> pl.DataFrame:
    return pl.DataFrame(
        {
            "a": [1, 2, 3, 4, 5],
            "b": ["a", "a", "b", "b", "b"],
            "c": [None, 1, None, 1, None],
        }
    )


@pytest.mark.parametrize(
    ("method", "expected"),
    [
        ("all", [("a", [1, 2], [None, 1]), ("b", [3, 4, 5], [None, 1, None])]),
        ("len", [("a", 2), ("b", 3)]),
        ("first", [("a", 1, None), ("b", 3, None)]),
        ("last", [("a", 2, 1), ("b", 5, None)]),
        ("max", [("a", 2, 1), ("b", 5, 1)]),
        ("mean", [("a", 1.5, 1.0), ("b", 4.0, 1.0)]),
        ("median", [("a", 1.5, 1.0), ("b", 4.0, 1.0)]),
        ("min", [("a", 1, 1), ("b", 3, 1)]),
        ("n_unique", [("a", 2, 2), ("b", 3, 2)]),
    ],
)
def test_group_by_shorthands(
    df: pl.DataFrame, method: str, expected: list[tuple[Any]]
) -> None:
    gb = df.group_by("b", maintain_order=True)
    result = getattr(gb, method)()
    assert result.rows() == expected

    gb_lazy = df.lazy().group_by("b", maintain_order=True)
    result = getattr(gb_lazy, method)().collect()
    assert result.rows() == expected


def test_group_by_shorthand_quantile(df: pl.DataFrame) -> None:
    result = df.group_by("b", maintain_order=True).quantile(0.5)
    expected = [("a", 2.0, 1.0), ("b", 4.0, 1.0)]
    assert result.rows() == expected

    result = df.lazy().group_by("b", maintain_order=True).quantile(0.5).collect()
    assert result.rows() == expected


def test_group_by_args() -> None:
    df = pl.DataFrame(
        {
            "a": ["a", "b", "a", "b", "b", "c"],
            "b": [1, 2, 3, 4, 5, 6],
            "c": [6, 5, 4, 3, 2, 1],
        }
    )

    # Single column name
    assert df.group_by("a").agg("b").columns == ["a", "b"]
    # Column names as list
    expected = ["a", "b", "c"]
    assert df.group_by(["a", "b"]).agg("c").columns == expected
    # Column names as positional arguments
    assert df.group_by("a", "b").agg("c").columns == expected
    # With keyword argument
    assert df.group_by("a", "b", maintain_order=True).agg("c").columns == expected
    # Multiple aggregations as list
    assert df.group_by("a").agg(["b", "c"]).columns == expected
    # Multiple aggregations as positional arguments
    assert df.group_by("a").agg("b", "c").columns == expected
    # Multiple aggregations as keyword arguments
    assert df.group_by("a").agg(q="b", r="c").columns == ["a", "q", "r"]


def test_group_by_empty() -> None:
    df = pl.DataFrame({"a": [1, 1, 2]})
    result = df.group_by("a").agg()
    expected = pl.DataFrame({"a": [1, 2]})
    assert_frame_equal(result, expected, check_row_order=False)


def test_group_by_iteration() -> None:
    df = pl.DataFrame(
        {
            "foo": ["a", "b", "a", "b", "b", "c"],
            "bar": [1, 2, 3, 4, 5, 6],
            "baz": [6, 5, 4, 3, 2, 1],
        }
    )
    expected_names = ["a", "b", "c"]
    expected_rows = [
        [("a", 1, 6), ("a", 3, 4)],
        [("b", 2, 5), ("b", 4, 3), ("b", 5, 2)],
        [("c", 6, 1)],
    ]
    gb_iter = enumerate(df.group_by("foo", maintain_order=True))
    for i, (group, data) in gb_iter:
        assert group == (expected_names[i],)
        assert data.rows() == expected_rows[i]

    # Grouping by ALL columns should give groups of a single row
    result = list(df.group_by(["foo", "bar", "baz"]))
    assert len(result) == 6

    # Iterating over groups should also work when grouping by expressions
    result2 = list(df.group_by(["foo", pl.col("bar") * pl.col("baz")]))
    assert len(result2) == 5

    # Single expression, alias in group_by
    df = pl.DataFrame({"foo": [1, 2, 3, 4, 5, 6]})
    gb = df.group_by((pl.col("foo") // 2).alias("bar"), maintain_order=True)
    result3 = [(group, df.rows()) for group, df in gb]
    expected3 = [
        ((0,), [(1,)]),
        ((1,), [(2,), (3,)]),
        ((2,), [(4,), (5,)]),
        ((3,), [(6,)]),
    ]
    assert result3 == expected3


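# Added note: iterating a GroupBy yields (key, DataFrame) pairs, and the key is
# always a tuple, even when grouping by a single column. A minimal sketch:
def _demo_group_by_iteration_keys() -> None:
    df = pl.DataFrame({"g": [1, 1, 2], "v": [1, 2, 3]})
    for key, frame in df.group_by("g", maintain_order=True):
        assert isinstance(key, tuple)  # e.g. (1,) for the first group
        assert isinstance(frame, pl.DataFrame)

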
def test_group_by_iteration_selector() -> None:
    df = pl.DataFrame({"a": ["one", "two", "one", "two"], "b": [1, 2, 3, 4]})
    result = dict(df.group_by(cs.string()))
    # group keys are tuples, so a single-column key is indexed with a 1-tuple
    result_first = result["one",]
    assert result_first.to_dict(as_series=False) == {"a": ["one", "one"], "b": [1, 3]}


@pytest.mark.parametrize("input", [[pl.col("b").sum()], pl.col("b").sum()])
def test_group_by_agg_input_types(input: Any) -> None:
    df = pl.LazyFrame({"a": [1, 1, 2, 2], "b": [1, 2, 3, 4]})
    result = df.group_by("a", maintain_order=True).agg(input)
    expected = pl.LazyFrame({"a": [1, 2], "b": [3, 7]})
    assert_frame_equal(result, expected)


@pytest.mark.parametrize("input", [str, "b".join])
def test_group_by_agg_bad_input_types(input: Any) -> None:
    df = pl.LazyFrame({"a": [1, 1, 2, 2], "b": [1, 2, 3, 4]})
    with pytest.raises(TypeError):
        df.group_by("a").agg(input)


def test_group_by_sorted_empty_dataframe_3680() -> None:
    df = (
        pl.DataFrame(
            [
                pl.Series("key", [], dtype=pl.Categorical),
                pl.Series("val", [], dtype=pl.Float64),
            ]
        )
        .lazy()
        .sort("key")
        .group_by("key")
        .tail(1)
        .collect(optimizations=pl.QueryOptFlags(check_order_observe=False))
    )
    assert df.rows() == []
    assert df.shape == (0, 2)
    assert df.schema == {"key": pl.Categorical(ordering="lexical"), "val": pl.Float64}


def test_group_by_custom_agg_empty_list() -> None:
    assert (
        pl.DataFrame(
            [
                pl.Series("key", [], dtype=pl.Categorical),
                pl.Series("val", [], dtype=pl.Float64),
            ]
        )
        .group_by("key")
        .agg(
            [
                pl.col("val").mean().alias("mean"),
                pl.col("val").std().alias("std"),
                pl.col("val").skew().alias("skew"),
                pl.col("val").kurtosis().alias("kurt"),
            ]
        )
    ).dtypes == [pl.Categorical, pl.Float64, pl.Float64, pl.Float64, pl.Float64]


def test_apply_after_take_in_group_by_3869() -> None:
    assert (
        pl.DataFrame(
            {
                "k": list("aaabbb"),
                "t": [1, 2, 3, 4, 5, 6],
                "v": [3, 1, 2, 5, 6, 4],
            }
        )
        .group_by("k", maintain_order=True)
        .agg(
            pl.col("v").get(pl.col("t").arg_max()).sqrt()
        )  # <- fails for sqrt, exp, log, pow, etc.
    ).to_dict(as_series=False) == {"k": ["a", "b"], "v": [1.4142135623730951, 2.0]}


def test_group_by_signed_transmutes() -> None:
    df = pl.DataFrame({"foo": [-1, -2, -3, -4, -5], "bar": [500, 600, 700, 800, 900]})

    for dt in [pl.Int8, pl.Int16, pl.Int32, pl.Int64]:
        df = (
            df.with_columns([pl.col("foo").cast(dt), pl.col("bar")])
            .group_by("foo", maintain_order=True)
            .agg(pl.col("bar").median())
        )

        assert df.to_dict(as_series=False) == {
            "foo": [-1, -2, -3, -4, -5],
            "bar": [500.0, 600.0, 700.0, 800.0, 900.0],
        }


def test_arg_sort_sort_by_groups_update__4360() -> None:
    df = pl.DataFrame(
        {
            "group": ["a"] * 3 + ["b"] * 3 + ["c"] * 3,
            "col1": [1, 2, 3] * 3,
            "col2": [1, 2, 3, 3, 2, 1, 2, 3, 1],
        }
    )

    out = df.with_columns(
        pl.col("col2").arg_sort().over("group").alias("col2_arg_sort")
    ).with_columns(
        pl.col("col1").sort_by(pl.col("col2_arg_sort")).over("group").alias("result_a"),
        pl.col("col1")
        .sort_by(pl.col("col2").arg_sort())
        .over("group")
        .alias("result_b"),
    )

    assert_series_equal(out["result_a"], out["result_b"], check_names=False)
    assert out["result_a"].to_list() == [1, 2, 3, 3, 2, 1, 2, 3, 1]


def test_unique_order() -> None:
    df = pl.DataFrame({"a": [1, 2, 1]}).with_row_index()
    assert df.unique(keep="last", subset="a", maintain_order=True).to_dict(
        as_series=False
    ) == {
        "index": [1, 2],
        "a": [2, 1],
    }
    assert df.unique(keep="first", subset="a", maintain_order=True).to_dict(
        as_series=False
    ) == {
        "index": [0, 1],
        "a": [1, 2],
    }


def test_group_by_dynamic_flat_agg_4814() -> None:
    df = pl.DataFrame({"a": [1, 2, 2], "b": [1, 8, 12]}).set_sorted("a")

    assert df.group_by_dynamic("a", every="1i", period="2i").agg(
        [
            (pl.col("b").sum() / pl.col("a").sum()).alias("sum_ratio_1"),
            (pl.col("b").last() / pl.col("a").last()).alias("last_ratio_1"),
            (pl.col("b") / pl.col("a")).last().alias("last_ratio_2"),
        ]
    ).to_dict(as_series=False) == {
        "a": [1, 2],
        "sum_ratio_1": [4.2, 5.0],
        "last_ratio_1": [6.0, 6.0],
        "last_ratio_2": [6.0, 6.0],
    }


@pytest.mark.parametrize(
    ("every", "period"),
    [
        ("10s", timedelta(seconds=100)),
        (timedelta(seconds=10), "100s"),
    ],
)
@pytest.mark.parametrize("time_zone", [None, "UTC", "Asia/Kathmandu"])
def test_group_by_dynamic_overlapping_groups_flat_apply_multiple_5038(
    every: str | timedelta, period: str | timedelta, time_zone: str | None
) -> None:
    res = (
        (
            pl.DataFrame(
                {
                    "a": [
                        datetime(2021, 1, 1) + timedelta(seconds=2**i)
                        for i in range(10)
                    ],
                    "b": [float(i) for i in range(10)],
                }
            )
            .with_columns(pl.col("a").dt.replace_time_zone(time_zone))
            .lazy()
            .set_sorted("a")
            .group_by_dynamic("a", every=every, period=period)
            .agg([pl.col("b").var().sqrt().alias("corr")])
        )
        .collect()
        .sum()
        .to_dict(as_series=False)
    )

    assert res["corr"] == pytest.approx([6.988674024215477])
    assert res["a"] == [None]


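# Added note: group_by_dynamic starts a window at every `every` step and each
# window spans `period`, so with period > every the windows overlap and a row
# can land in several groups. A minimal sketch on an integer index:
def _demo_group_by_dynamic_overlap() -> None:
    df = pl.DataFrame({"t": [1, 2, 3], "v": [1, 1, 1]}).set_sorted("t")
    out = df.group_by_dynamic("t", every="1i", period="2i").agg(pl.sum("v"))
    # windows [1, 3), [2, 4) and [3, 5) hold 2, 2 and 1 rows respectively
    assert out["v"].to_list() == [2, 2, 1]

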
def test_take_in_group_by() -> None:
    df = pl.DataFrame({"group": [1, 1, 1, 2, 2, 2], "values": [10, 200, 3, 40, 500, 6]})
    assert df.group_by("group").agg(
        pl.col("values").get(1) - pl.col("values").get(2)
    ).sort("group").to_dict(as_series=False) == {"group": [1, 2], "values": [197, 494]}


def test_group_by_wildcard() -> None:
    df = pl.DataFrame(
        {
            "a": [1, 2],
            "b": [1, 2],
        }
    )
    assert df.group_by([pl.col("*")], maintain_order=True).agg(
        [pl.col("a").first().name.suffix("_agg")]
    ).to_dict(as_series=False) == {"a": [1, 2], "b": [1, 2], "a_agg": [1, 2]}


def test_group_by_all_masked_out() -> None:
    df = pl.DataFrame(
        {
            "val": pl.Series(
                [None, None, None, None], dtype=pl.Categorical, nan_to_null=True
            ).set_sorted(),
            "col": [4, 4, 4, 4],
        }
    )
    parts = df.partition_by("val")
    assert len(parts) == 1
    assert_frame_equal(parts[0], df)


def test_group_by_null_propagation_6185() -> None:
    df_1 = pl.DataFrame({"A": [0, 0], "B": [1, 2]})

    expr = pl.col("A").filter(pl.col("A") > 0)

    expected = {"B": [1, 2], "A": [None, None]}
    assert (
        df_1.group_by("B")
        .agg((expr - expr.mean()).mean())
        .sort("B")
        .to_dict(as_series=False)
        == expected
    )


def test_group_by_when_then_with_binary_and_agg_in_pred_6202() -> None:
    df = pl.DataFrame(
        {"code": ["a", "b", "b", "b", "a"], "xx": [1.0, -1.5, -0.2, -3.9, 3.0]}
    )
    assert (
        df.group_by("code", maintain_order=True).agg(
            [pl.when(pl.col("xx") > pl.min("xx")).then(True).otherwise(False)]
        )
    ).to_dict(as_series=False) == {
        "code": ["a", "b"],
        "literal": [[False, True], [True, True, False]],
    }


def test_group_by_binary_agg_with_literal() -> None:
    df = pl.DataFrame({"id": ["a", "a", "b", "b"], "value": [1, 2, 3, 4]})

    out = df.group_by("id", maintain_order=True).agg(
        pl.col("value") + pl.Series([1, 3])
    )
    assert out.to_dict(as_series=False) == {"id": ["a", "b"], "value": [[2, 5], [4, 7]]}

    out = df.group_by("id", maintain_order=True).agg(pl.col("value") + pl.lit(1))
    assert out.to_dict(as_series=False) == {"id": ["a", "b"], "value": [[2, 3], [4, 5]]}

    out = df.group_by("id", maintain_order=True).agg(pl.lit(1) + pl.lit(2))
    assert out.to_dict(as_series=False) == {"id": ["a", "b"], "literal": [3, 3]}

    out = df.group_by("id", maintain_order=True).agg(pl.lit(1) + pl.Series([2, 3]))
    assert out.to_dict(as_series=False) == {
        "id": ["a", "b"],
        "literal": [[3, 4], [3, 4]],
    }

    out = df.group_by("id", maintain_order=True).agg(
        value=pl.lit(pl.Series([1, 2])) + pl.lit(pl.Series([3, 4]))
    )
    assert out.to_dict(as_series=False) == {"id": ["a", "b"], "value": [[4, 6], [4, 6]]}


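# Added note on the cases above: inside agg, a Series literal is broadcast to
# every group unchanged, while combining it with a grouped column zips it
# against that group's values; purely scalar literals fold to one value per group.

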
@pytest.mark.slow
@pytest.mark.parametrize("dtype", [pl.Int32, pl.UInt32])
def test_overflow_mean_partitioned_group_by_5194(dtype: PolarsDataType) -> None:
    df = pl.DataFrame(
        [
            pl.Series("data", [10_000_000] * 100_000, dtype=dtype),
            pl.Series("group", [1, 2] * 50_000, dtype=dtype),
        ]
    )
    result = df.group_by("group").agg(pl.col("data").mean()).sort(by="group")
    expected = {"group": [1, 2], "data": [10000000.0, 10000000.0]}
    assert result.to_dict(as_series=False) == expected


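# Added note: with 50_000 rows of 10_000_000 per group the running sum is
# 5e11, which does not fit a 32-bit accumulator; the test pins the fix for
# issue 5194, presumably by accumulating the mean in a wider type.

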
# https://github.com/pola-rs/polars/issues/7181
def test_group_by_multiple_column_reference() -> None:
    df = pl.DataFrame(
        {
            "gr": ["a", "b", "a", "b", "a", "b"],
            "val": [1, 20, 100, 2000, 10000, 200000],
        }
    )
    result = df.group_by("gr").agg(
        pl.col("val") + pl.col("val").shift().fill_null(0),
    )

    assert result.sort("gr").to_dict(as_series=False) == {
        "gr": ["a", "b"],
        "val": [[1, 101, 10100], [20, 2020, 202000]],
    }


@pytest.mark.parametrize(
    ("aggregation", "args", "expected_values", "expected_dtype"),
    [
        ("first", [], [1, None], pl.Int64),
        ("last", [], [1, None], pl.Int64),
        ("max", [], [1, None], pl.Int64),
        ("mean", [], [1.0, None], pl.Float64),
        ("median", [], [1.0, None], pl.Float64),
        ("min", [], [1, None], pl.Int64),
        ("n_unique", [], [1, 0], pl.UInt32),
        ("quantile", [0.5], [1.0, None], pl.Float64),
    ],
)
def test_group_by_empty_groups(
    aggregation: str,
    args: list[object],
    expected_values: list[object],
    expected_dtype: pl.DataType,
) -> None:
    df = pl.DataFrame({"a": [1, 2], "b": [1, 2]})
    result = df.group_by("b", maintain_order=True).agg(
        getattr(pl.col("a").filter(pl.col("b") != 2), aggregation)(*args)
    )
    expected = pl.DataFrame({"b": [1, 2], "a": expected_values}).with_columns(
        pl.col("a").cast(expected_dtype)
    )
    assert_frame_equal(result, expected)


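# Added note: for an empty group every aggregate above yields null, except
# n_unique, which counts zero distinct values; output dtypes still follow the
# aggregate (e.g. mean/median/quantile widen integers to Float64).

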
# https://github.com/pola-rs/polars/issues/8663
def test_perfect_hash_table_null_values() -> None:
    # fmt: off
    values = ["3", "41", "17", "5", "26", "27", "43", "45", "41", "13", "45", "48", "17", "22", "31", "25", "28", "13", "7", "26", "17", "4", "43", "47", "30", "28", "8", "27", "6", "7", "26", "11", "37", "29", "49", "20", "29", "28", "23", "9", None, "38", "19", "7", "38", "3", "30", "37", "41", "5", "16", "26", "31", "6", "25", "11", "17", "31", "31", "20", "26", None, "39", "10", "38", "4", "39", "15", "13", "35", "38", "11", "39", "11", "48", "36", "18", "11", "34", "16", "28", "9", "37", "8", "17", "48", "44", "28", "25", "30", "37", "30", "18", "12", None, "27", "10", "3", "16", "27", "6"]
    groups = ["3", "41", "17", "5", "26", "27", "43", "45", "13", "48", "22", "31", "25", "28", "7", "4", "47", "30", "8", "6", "11", "37", "29", "49", "20", "23", "9", None, "38", "19", "16", "39", "10", "15", "35", "36", "18", "34", "44", "12"]
    # fmt: on

    s = pl.Series("a", values, dtype=pl.Categorical)

    result = (
        s.to_frame("a").group_by("a", maintain_order=True).agg(pl.col("a").alias("agg"))
    )

    agg_values = [
        ["3", "3", "3"],
        ["41", "41", "41"],
        ["17", "17", "17", "17", "17"],
        ["5", "5"],
        ["26", "26", "26", "26", "26"],
        ["27", "27", "27", "27"],
        ["43", "43"],
        ["45", "45"],
        ["13", "13", "13"],
        ["48", "48", "48"],
        ["22"],
        ["31", "31", "31", "31"],
        ["25", "25", "25"],
        ["28", "28", "28", "28", "28"],
        ["7", "7", "7"],
        ["4", "4"],
        ["47"],
        ["30", "30", "30", "30"],
        ["8", "8"],
        ["6", "6", "6"],
        ["11", "11", "11", "11", "11"],
        ["37", "37", "37", "37"],
        ["29", "29"],
        ["49"],
        ["20", "20"],
        ["23"],
        ["9", "9"],
        [None, None, None],
        ["38", "38", "38", "38"],
        ["19"],
        ["16", "16", "16"],
        ["39", "39", "39"],
        ["10", "10"],
        ["15"],
        ["35"],
        ["36"],
        ["18", "18"],
        ["34"],
        ["44"],
        ["12"],
    ]
    expected = pl.DataFrame(
        {
            "a": groups,
            "agg": agg_values,
        },
        schema={"a": pl.Categorical, "agg": pl.List(pl.Categorical)},
    )
    assert_frame_equal(result, expected)


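# Added note (assumption about internals, inferred from the test name and
# issue 8663): for low-cardinality categoricals polars can group via a
# "perfect hash table", indexing groups directly by the category's physical
# integer id, so null needs its own dedicated slot; this test pins that
# null handling.

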
def test_group_by_partitioned_ending_cast(monkeypatch: Any) -> None:
    monkeypatch.setenv("POLARS_FORCE_PARTITION", "1")
    df = pl.DataFrame({"a": [1] * 5, "b": [1] * 5})
    out = df.group_by(["a", "b"]).agg(pl.len().cast(pl.Int64).alias("num"))
    expected = pl.DataFrame({"a": [1], "b": [1], "num": [5]})
    assert_frame_equal(out, expected)


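# Added note: POLARS_FORCE_PARTITION (set above and in later tests) forces the
# partitioned group_by engine, which pre-aggregates partial results per
# partition and then combines them; these tests pin behaviors of that path.

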
def test_group_by_series_partitioned(partition_limit: int) -> None:
    # test 15354
    df = pl.DataFrame([0, 0] * partition_limit)
    groups = pl.Series([0, 1] * partition_limit)
    df.group_by(groups).agg(pl.all().is_not_null().sum())


def test_group_by_list_scalar_11749() -> None:
    df = pl.DataFrame(
        {
            "group_name": ["a;b", "a;b", "c;d", "c;d", "a;b", "a;b"],
            "parent_name": ["a", "b", "c", "d", "a", "b"],
            "measurement": [
                ["x1", "x2"],
                ["x1", "x2"],
                ["y1", "y2"],
                ["z1", "z2"],
                ["x1", "x2"],
                ["x1", "x2"],
            ],
        }
    )
    assert (
        df.group_by("group_name").agg(
            (pl.col("measurement").first() == pl.col("measurement")).alias("eq"),
        )
    ).sort("group_name").to_dict(as_series=False) == {
        "group_name": ["a;b", "c;d"],
        "eq": [[True, True, True, True], [True, False]],
    }


def test_group_by_with_expr_as_key() -> None:
    gb = pl.select(x=1).group_by(pl.col("x").alias("key"))
    result = gb.agg(pl.all().first())
    expected = gb.agg(pl.first("x"))
    assert_frame_equal(result, expected)

    # tests: 11766
    result = gb.head(0)
    expected = gb.agg(pl.col("x").head(0)).explode("x")
    assert_frame_equal(result, expected)

    result = gb.tail(0)
    expected = gb.agg(pl.col("x").tail(0)).explode("x")
    assert_frame_equal(result, expected)


def test_lazy_group_by_reuse_11767() -> None:
    lgb = pl.select(x=1).lazy().group_by("x")
    a = lgb.len()
    b = lgb.len()
    assert_frame_equal(a, b)


def test_group_by_double_on_empty_12194() -> None:
    df = pl.DataFrame({"group": [1], "x": [1]}).clear()
    squared_deviation_sum = ((pl.col("x") - pl.col("x").mean()) ** 2).sum()
    assert df.group_by("group").agg(squared_deviation_sum).schema == OrderedDict(
        [("group", pl.Int64), ("x", pl.Float64)]
    )


def test_group_by_when_then_no_aggregation_predicate() -> None:
    df = pl.DataFrame(
        {
            "key": ["aa", "aa", "bb", "bb", "aa", "aa"],
            "val": [-3, -2, 1, 4, -3, 5],
        }
    )
    assert df.group_by("key").agg(
        pos=pl.when(pl.col("val") >= 0).then(pl.col("val")).sum(),
        neg=pl.when(pl.col("val") < 0).then(pl.col("val")).sum(),
    ).sort("key").to_dict(as_series=False) == {
        "key": ["aa", "bb"],
        "pos": [5, 5],
        "neg": [-8, 0],
    }


def test_group_by_apply_first_input_is_literal() -> None:
    df = pl.DataFrame({"x": [1, 2, 3, 4, 5], "g": [1, 1, 2, 2, 2]})
    pow = df.group_by("g").agg(2 ** pl.col("x"))
    assert pow.sort("g").to_dict(as_series=False) == {
        "g": [1, 2],
        "literal": [[2.0, 4.0], [8.0, 16.0, 32.0]],
    }


def test_group_by_all_12869() -> None:
    df = pl.DataFrame({"a": [1]})
    result = next(iter(df.group_by(pl.all())))[1]
    assert_frame_equal(df, result)


def test_group_by_named() -> None:
    df = pl.DataFrame({"a": [1, 1, 2, 2, 3, 3], "b": range(6)})
    result = df.group_by(z=pl.col("a") * 2, maintain_order=True).agg(pl.col("b").min())
    expected = df.group_by((pl.col("a") * 2).alias("z"), maintain_order=True).agg(
        pl.col("b").min()
    )
    assert_frame_equal(result, expected)


def test_group_by_with_null() -> None:
    df = pl.DataFrame(
        {"a": [None, None, None, None], "b": [1, 1, 2, 2], "c": ["x", "y", "z", "u"]}
    )
    expected = pl.DataFrame(
        {"a": [None, None], "b": [1, 2], "c": [["x", "y"], ["z", "u"]]}
    )
    output = df.group_by(["a", "b"], maintain_order=True).agg(pl.col("c"))
    assert_frame_equal(expected, output)


def test_partitioned_group_by_14954(monkeypatch: Any) -> None:
    monkeypatch.setenv("POLARS_FORCE_PARTITION", "1")
    assert (
        pl.DataFrame({"a": range(20)})
        .select(pl.col("a") % 2)
        .group_by("a")
        .agg(
            (pl.col("a") > 1000).alias("a > 1000"),
        )
    ).sort("a").to_dict(as_series=False) == {
        "a": [0, 1],
        "a > 1000": [
            [False, False, False, False, False, False, False, False, False, False],
            [False, False, False, False, False, False, False, False, False, False],
        ],
    }


def test_partitioned_group_by_nulls_mean_21838() -> None:
    size = 10
    a = [1 for i in range(size)] + [2 for i in range(size)] + [3 for i in range(size)]
    b = [1 for i in range(size)] + [None for i in range(size * 2)]
    df = pl.DataFrame({"a": a, "b": b})
    assert df.group_by("a").mean().sort("a").to_dict(as_series=False) == {
        "a": [1, 2, 3],
        "b": [1.0, None, None],
    }


def test_aggregated_scalar_elementwise_15602() -> None:
    df = pl.DataFrame({"group": [1, 2, 1]})

    out = df.group_by("group", maintain_order=True).agg(
        foo=pl.col("group").is_between(1, pl.max("group"))
    )
    expected = pl.DataFrame({"group": [1, 2], "foo": [[True, True], [True]]})
    assert_frame_equal(out, expected)


def test_group_by_multiple_null_cols_15623() -> None:
    df = pl.DataFrame(schema={"a": pl.Null, "b": pl.Null}).group_by(pl.all()).len()
    assert df.is_empty()


@pytest.mark.release
def test_categorical_vs_str_group_by() -> None:
    # this triggers the perfect hash table
    s = pl.Series("a", np.random.randint(0, 50, 100))
    s_with_nulls = pl.select(
        pl.when(s < 3).then(None).otherwise(s).alias("a")
    ).to_series()

    for s_ in [s, s_with_nulls]:
        s_ = s_.cast(str)
        cat_out = (
            s_.cast(pl.Categorical)
            .to_frame("a")
            .group_by("a")
            .agg(pl.first().alias("first"))
        )

        str_out = s_.to_frame("a").group_by("a").agg(pl.first().alias("first"))
        cat_out.with_columns(pl.col("a").cast(str))
        assert_frame_equal(
            cat_out.with_columns(
                pl.col("a").cast(str), pl.col("first").cast(pl.List(str))
            ).sort("a"),
            str_out.sort("a"),
        )


@pytest.mark.release
def test_boolean_min_max_agg() -> None:
    np.random.seed(0)
    idx = np.random.randint(0, 500, 1000)
    c = np.random.randint(0, 500, 1000) > 250

    df = pl.DataFrame({"idx": idx, "c": c})
    aggs = [pl.col("c").min().alias("c_min"), pl.col("c").max().alias("c_max")]

    result = df.group_by("idx").agg(aggs).sum()

    schema = {"idx": pl.Int64, "c_min": pl.UInt32, "c_max": pl.UInt32}
    expected = pl.DataFrame(
        {
            "idx": [107583],
            "c_min": [120],
            "c_max": [321],
        },
        schema=schema,
    )
    assert_frame_equal(result, expected)

    nulls = np.random.randint(0, 500, 1000) < 100

    result = (
        df.with_columns(c=pl.when(pl.lit(nulls)).then(None).otherwise(pl.col("c")))
        .group_by("idx")
        .agg(aggs)
        .sum()
    )

    expected = pl.DataFrame(
        {
            "idx": [107583],
            "c_min": [133],
            "c_max": [276],
        },
        schema=schema,
    )
    assert_frame_equal(result, expected)


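# Added note: over Boolean columns min() behaves like "all" and max() like
# "any", and the trailing .sum() above counts the True groups, hence the
# UInt32 c_min/c_max totals. A minimal sketch:
def _demo_boolean_min_max() -> None:
    df = pl.DataFrame({"g": [1, 1], "c": [True, False]})
    out = df.group_by("g").agg(
        pl.col("c").min().alias("c_min"), pl.col("c").max().alias("c_max")
    )
    assert out["c_min"].to_list() == [False]  # not all values are True
    assert out["c_max"].to_list() == [True]  # at least one value is True

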
def test_partitioned_group_by_chunked(partition_limit: int) -> None:
    n = partition_limit
    df1 = pl.DataFrame(np.random.randn(n, 2))
    df2 = pl.DataFrame(np.random.randn(n, 2))
    gps = pl.Series(name="oo", values=[0] * n + [1] * n)
    df = pl.concat([df1, df2], rechunk=False)
    assert_frame_equal(
        df.group_by(gps).sum().sort("oo"),
        df.rechunk().group_by(gps, maintain_order=True).sum(),
    )


def test_schema_on_agg() -> None:
    lf = pl.LazyFrame({"a": ["x", "x", "y", "n"], "b": [1, 2, 3, 4]})

    result = lf.group_by("a").agg(
        pl.col("b").min().alias("min"),
        pl.col("b").max().alias("max"),
        pl.col("b").sum().alias("sum"),
        pl.col("b").first().alias("first"),
        pl.col("b").last().alias("last"),
    )
    expected_schema = {
        "a": pl.String,
        "min": pl.Int64,
        "max": pl.Int64,
        "sum": pl.Int64,
        "first": pl.Int64,
        "last": pl.Int64,
    }
    assert result.collect_schema() == expected_schema


def test_group_by_schema_err() -> None:
    lf = pl.LazyFrame({"foo": [None, 1, 2], "bar": [1, 2, 3]})
    with pytest.raises(ColumnNotFoundError):
        lf.group_by("not-existent").agg(
            pl.col("bar").max().alias("max_bar")
        ).collect_schema()


@pytest.mark.parametrize(
    ("data", "expr", "expected_select", "expected_gb"),
    [
        (
            {"x": ["x"], "y": ["y"]},
            pl.coalesce(pl.col("x"), pl.col("y")),
            {"x": pl.String},
            {"x": pl.List(pl.String)},
        ),
        (
            {"x": [True]},
            pl.col("x").sum(),
            {"x": pl.UInt32},
            {"x": pl.UInt32},
        ),
        (
            {"a": [[1, 2]]},
            pl.col("a").list.sum(),
            {"a": pl.Int64},
            {"a": pl.List(pl.Int64)},
        ),
    ],
)
def test_schemas(
    data: dict[str, list[Any]],
    expr: pl.Expr,
    expected_select: dict[str, PolarsDataType],
    expected_gb: dict[str, PolarsDataType],
) -> None:
    df = pl.DataFrame(data)

    # test selection schema
    schema = df.select(expr).schema
    for key, dtype in expected_select.items():
        assert schema[key] == dtype

    # test group_by schema
    schema = df.group_by(pl.lit(1)).agg(expr).schema
    for key, dtype in expected_gb.items():
        assert schema[key] == dtype


def test_lit_iter_schema() -> None:
    df = pl.DataFrame(
        {
            "key": ["A", "A", "A", "A"],
            "dates": [
                date(1970, 1, 1),
                date(1970, 1, 1),
                date(1970, 1, 2),
                date(1970, 1, 3),
            ],
        }
    )

    result = df.group_by("key").agg(pl.col("dates").unique() + timedelta(days=1))
    expected = {
        "key": ["A"],
        "dates": [[date(1970, 1, 2), date(1970, 1, 3), date(1970, 1, 4)]],
    }
    assert result.to_dict(as_series=False) == expected


def test_absence_off_null_prop_8224() -> None:
    # a reminder to self not to do null propagation:
    # it is inconsistent and makes the output dtype
    # dependent on the data, big no!

    def sub_col_min(column: str, min_column: str) -> pl.Expr:
        return pl.col(column).sub(pl.col(min_column).min())

    df = pl.DataFrame(
        {
            "group": [1, 1, 2, 2],
            "vals_num": [10.0, 11.0, 12.0, 13.0],
            "vals_partial": [None, None, 12.0, 13.0],
            "vals_null": [None, None, None, None],
        }
    )

    q = (
        df.lazy()
        .group_by("group")
        .agg(
            sub_col_min("vals_num", "vals_num").alias("sub_num"),
            sub_col_min("vals_num", "vals_partial").alias("sub_partial"),
            sub_col_min("vals_num", "vals_null").alias("sub_null"),
        )
    )

    assert q.collect().dtypes == [
        pl.Int64,
        pl.List(pl.Float64),
        pl.List(pl.Float64),
        pl.List(pl.Float64),
    ]


@pytest.mark.parametrize("maintain_order", [False, True])
def test_grouped_slice_literals(maintain_order: bool) -> None:
    df = pl.DataFrame({"idx": [1, 2, 3]})
    q = (
        df.lazy()
        .group_by(True, maintain_order=maintain_order)
        .agg(
            x=pl.lit([1, 2]).slice(
                -1, 1
            ),  # slicing a list of one element leaves that element unchanged
            x2=pl.lit(pl.Series([1, 2])).slice(-1, 1),
            x3=pl.lit(pl.Series([[1, 2]])).slice(-1, 1),
        )
    )
    out = q.collect()
    expected = pl.DataFrame(
        {"literal": [True], "x": [[[1, 2]]], "x2": [[2]], "x3": [[[1, 2]]]}
    )
    assert_frame_equal(
        out,
        expected,
        check_row_order=maintain_order,
    )
    assert q.collect_schema() == q.collect().schema


def test_positional_by_with_list_or_tuple_17540() -> None:
    with pytest.raises(TypeError, match="Hint: if you"):
        pl.DataFrame({"a": [1, 2, 3]}).group_by(by=["a"])
    with pytest.raises(TypeError, match="Hint: if you"):
        pl.LazyFrame({"a": [1, 2, 3]}).group_by(by=["a"])


def test_group_by_agg_19173() -> None:
    df = pl.DataFrame({"x": [1.0], "g": [0]})
    out = df.head(0).group_by("g").agg((pl.col.x - pl.col.x.sum() * pl.col.x) ** 2)
    assert out.to_dict(as_series=False) == {"g": [], "x": []}
    assert out.schema == pl.Schema([("g", pl.Int64), ("x", pl.List(pl.Float64))])


def test_group_by_map_groups_slice_pushdown_20002() -> None:
    schema = {
        "a": pl.Int8,
        "b": pl.UInt8,
    }

    df = (
        pl.LazyFrame(
            data={"a": [1, 2, 3, 4, 5], "b": [90, 80, 70, 60, 50]},
            schema=schema,
        )
        .group_by("a", maintain_order=True)
        .map_groups(lambda df: df * 2.0, schema=schema)
        .head(3)
        .collect()
    )

    assert_frame_equal(
        df,
        pl.DataFrame(
            {
                "a": [2.0, 4.0, 6.0],
                "b": [180.0, 160.0, 140.0],
            }
        ),
    )


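# Added note (reading of issue 20002): map_groups runs an opaque Python
# callback per group, so the optimizer must not push head(3) below it; the
# slice has to apply to the mapped result, which the expected frame pins down.

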
@typing.no_type_check
def test_group_by_lit_series(capfd: Any, monkeypatch: Any) -> None:
    monkeypatch.setenv("POLARS_VERBOSE", "1")
    n = 10
    df = pl.DataFrame({"x": np.ones(2 * n), "y": n * list(range(2))})
    a = np.ones(n, dtype=float)
    df.lazy().group_by("y").agg(pl.col("x").dot(a)).collect()
    captured = capfd.readouterr().err
    assert "are not partitionable" in captured


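# Added note: with POLARS_VERBOSE=1 the engine logs why a group_by cannot use
# the partitioned path; aggregating against a non-expression literal (the
# numpy array above) is one such case, hence the "are not partitionable" line.

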
def test_group_by_list_column() -> None:
    df = pl.DataFrame({"a": [1, 2, 3], "b": [[1, 2], [3], [1, 2]]})
    result = df.group_by("b").agg(pl.sum("a")).sort("b")
    expected = pl.DataFrame({"b": [[1, 2], [3]], "a": [4, 2]})
    assert_frame_equal(result, expected)


def test_enum_perfect_group_by_21360() -> None:
    dtype = pl.Enum(categories=["a", "b"])

    assert_frame_equal(
        pl.from_dicts([{"col": "a"}], schema={"col": dtype})
        .group_by("col")
        .agg(pl.len()),
        pl.DataFrame(
            [
                pl.Series("col", ["a"], dtype),
                pl.Series("len", [1], get_index_type()),
            ]
        ),
    )


def test_partitioned_group_by_21634(partition_limit: int) -> None:
    n = partition_limit
    df = pl.DataFrame({"grp": [1] * n, "x": [1] * n})
    # the positional `True` is an extra literal grouping key, hence the "literal" column
    assert df.group_by("grp", True).agg().to_dict(as_series=False) == {
        "grp": [1],
        "literal": [True],
    }


def test_group_by_cse_dup_key_alias_22238() -> None:
    df = pl.LazyFrame({"a": [1, 1, 2, 2, -1], "x": [0, 1, 2, 3, 10]})
    result = df.group_by(
        pl.col("a").abs(),
        pl.col("a").abs().alias("a_with_alias"),
    ).agg(pl.col("x").sum())
    assert_frame_equal(
        result.collect(),
        pl.DataFrame({"a": [1, 2], "a_with_alias": [1, 2], "x": [11, 5]}),
        check_row_order=False,
    )


def test_group_by_22328() -> None:
    N = 20

    df1 = pl.select(
        x=pl.repeat(1, N // 2).append(pl.repeat(2, N // 2)).shuffle(),
        y=pl.lit(3.0, pl.Float32),
    ).lazy()

    df2 = pl.select(x=pl.repeat(4, N)).lazy()

    assert (
        df2.join(df1.group_by("x").mean().with_columns(z="y"), how="left", on="x")
        .with_columns(pl.col("z").fill_null(0))
        .collect()
    ).shape == (20, 3)


@pytest.mark.parametrize("maintain_order", [False, True])
def test_group_by_arrays_22574(maintain_order: bool) -> None:
    assert_frame_equal(
        pl.Series("a", [[1], [2], [2]], pl.Array(pl.Int64, 1))
        .to_frame()
        .group_by("a", maintain_order=maintain_order)
        .agg(pl.len()),
        pl.DataFrame(
            [
                pl.Series("a", [[1], [2]], pl.Array(pl.Int64, 1)),
                pl.Series("len", [1, 2], pl.get_index_type()),
            ]
        ),
        check_row_order=maintain_order,
    )

    assert_frame_equal(
        pl.Series(
            "a", [[[1, 2]], [[2, 3]], [[2, 3]]], pl.Array(pl.Array(pl.Int64, 2), 1)
        )
        .to_frame()
        .group_by("a", maintain_order=maintain_order)
        .agg(pl.len()),
        pl.DataFrame(
            [
                pl.Series(
                    "a", [[[1, 2]], [[2, 3]]], pl.Array(pl.Array(pl.Int64, 2), 1)
                ),
                pl.Series("len", [1, 2], pl.get_index_type()),
            ]
        ),
        check_row_order=maintain_order,
    )


def test_group_by_empty_rows_with_literal_21959() -> None:
    out = (
        pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [1, 1, 3]})
        .filter(pl.col("c") == 99)
        .group_by(pl.lit(1).alias("d"), pl.col("a"), pl.col("b"))
        .agg()
        .collect()
    )
    expected = pl.DataFrame(
        {"d": [], "a": [], "b": []},
        schema={"d": pl.Int32, "a": pl.Int64, "b": pl.Int64},
    )
    assert_frame_equal(out, expected)


def test_group_by_empty_dtype_22716() -> None:
    df = pl.DataFrame(schema={"a": pl.String, "b": pl.Int64})
    out = df.group_by("a").agg(x=(pl.col("b") == pl.int_range(pl.len())).all())
    assert_frame_equal(out, pl.DataFrame(schema={"a": pl.String, "x": pl.Boolean}))


def test_group_by_implode_22870() -> None:
    out = (
        pl.DataFrame({"x": ["a", "b"]})
        .group_by(pl.col.x)
        .agg(
            y=pl.col.x.replace_strict(
                pl.lit(pl.Series(["a", "b"])).implode(),
                pl.lit(pl.Series([1, 2])).implode(),
                default=-1,
            )
        )
    )
    assert_frame_equal(
        out,
        pl.DataFrame({"x": ["a", "b"], "y": [[1], [2]]}),
        check_row_order=False,
    )


# Note: the underlying bug is not guaranteed to manifest itself, as it depends
# on the internal group order, i.e., for the bug to materialize there must be
# empty groups before the non-empty group
def test_group_by_empty_groups_23338() -> None:
    # we need one non-empty group and many empty ones
    df = pl.DataFrame(
        {
            "k": [10, 10, 20, 30, 40, 50, 60, 70, 80, 90],
            "a": [1, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        }
    )
    out = df.group_by("k").agg(
        pl.col("a").filter(pl.col("a") == 1).fill_nan(None).sum()
    )
    expected = df.group_by("k").agg(pl.col("a").filter(pl.col("a") == 1).sum())
    assert_frame_equal(out.sort("k"), expected.sort("k"))


def test_group_by_filter_all_22955() -> None:
    df = pl.DataFrame(
        {
            "grp": [1, 2, 3, 4, 5],
            "value": [10, 20, 30, 40, 50],
        }
    )

    assert_frame_equal(
        df.group_by("grp").agg(
            pl.all().filter(pl.col("value") > 20),
        ),
        pl.DataFrame(
            {
                "grp": [1, 2, 3, 4, 5],
                "value": [[], [], [30], [40], [50]],
            }
        ),
        check_row_order=False,
    )


@pytest.mark.parametrize("maintain_order", [False, True])
def test_group_by_series_lit_22103(maintain_order: bool) -> None:
    df = pl.DataFrame(
        {
            "g": [0, 1],
        }
    )
    assert_frame_equal(
        df.group_by("g", maintain_order=maintain_order).agg(
            foo=pl.lit(pl.Series([42, 2, 3]))
        ),
        pl.DataFrame(
            {
                "g": [0, 1],
                "foo": [[42, 2, 3], [42, 2, 3]],
            }
        ),
        check_row_order=maintain_order,
    )


@pytest.mark.parametrize("maintain_order", [False, True])
def test_group_by_filter_sum_23897(maintain_order: bool) -> None:
    testdf = pl.DataFrame(
        {
            "id": [8113, 9110, 9110],
            "value": [None, None, 1.0],
            "weight": [1.0, 1.0, 1.0],
        }
    )

    w = pl.col("weight").filter(pl.col("value").is_finite())

    w = w / w.sum()

    result = w.sum()

    assert_frame_equal(
        testdf.group_by("id", maintain_order=maintain_order).agg(result),
        pl.DataFrame({"id": [8113, 9110], "weight": [0.0, 1.0]}),
        check_row_order=maintain_order,
    )


@pytest.mark.parametrize("maintain_order", [False, True])
def test_group_by_shift_filter_23910(maintain_order: bool) -> None:
    df = pl.DataFrame({"a": [3, 7, 5, 9, 2, 1], "b": [2, 2, 2, 3, 3, 1]})

    out = df.group_by("b", maintain_order=maintain_order).agg(
        pl.col("a").filter(pl.col("a") > pl.col("a").shift(1)).sum().alias("tt")
    )

    assert_frame_equal(
        out,
        pl.DataFrame(
            {
                "b": [2, 3, 1],
                "tt": [7, 0, 0],
            }
        ),
        check_row_order=maintain_order,
    )


def test_group_by_tuple_typing_24112() -> None:
    df = pl.DataFrame({"id": ["a", "b", "a"], "val": [1, 2, 3]})
    for (id_,), _ in df.group_by("id"):
        _should_work: str = id_


def test_group_by_input_independent_with_len_23868() -> None:
    out = pl.DataFrame({"a": ["A", "B", "C"]}).group_by(pl.lit("G")).agg(pl.len())
    assert_frame_equal(
        out,
        pl.DataFrame(
            {"literal": "G", "len": 3},
            schema={"literal": pl.String, "len": pl.get_index_type()},
        ),
    )


@pytest.mark.parametrize("maintain_order", [False, True])
def test_group_by_head_tail_24215(maintain_order: bool) -> None:
    df = pl.DataFrame(
        {
            "station": ["A", "A", "B"],
            "num_rides": [1, 2, 3],
        }
    )
    expected = pl.DataFrame(
        {"station": ["A", "B"], "num_rides": [1.5, 3], "rides_per_day": [[1, 2], [3]]}
    )

    result = (
        df.group_by("station", maintain_order=maintain_order)
        .agg(
            cs.numeric().mean(),
            pl.col("num_rides").alias("rides_per_day"),
        )
        .group_by("station", maintain_order=maintain_order)
        .head(1)
    )
    assert_frame_equal(result, expected, check_row_order=maintain_order)

    result = (
        df.group_by("station", maintain_order=maintain_order)
        .agg(
            cs.numeric().mean(),
            pl.col("num_rides").alias("rides_per_day"),
        )
        .group_by("station", maintain_order=maintain_order)
        .tail(1)
    )
    assert_frame_equal(result, expected, check_row_order=maintain_order)


def test_slice_group_by_offset_24259() -> None:
    df = pl.DataFrame(
        {
            "letters": ["c", "c", "a", "c", "a", "b", "d"],
            "nrs": [1, 2, 3, 4, 5, 6, None],
        }
    )
    assert df.group_by("letters").agg(
        x=pl.col("nrs").drop_nulls(),
        tail=pl.col("nrs").drop_nulls().tail(1),
    ).sort("letters").to_dict(as_series=False) == {
        "letters": ["a", "b", "c", "d"],
        "x": [[3, 5], [6], [1, 2, 4], []],
        "tail": [[5], [6], [4], []],
    }


def test_group_by_first_nondet_24278() -> None:
    values = [
        96, 86, 0, 86, 43, 50, 9, 14, 98, 39, 93, 7, 71, 1, 93, 41, 56,
        56, 93, 41, 58, 91, 81, 29, 81, 68, 5, 9, 32, 93, 78, 34, 17, 40,
        14, 2, 52, 77, 81, 4, 56, 42, 64, 12, 29, 58, 71, 98, 32, 49, 34,
        86, 29, 94, 37, 21, 41, 36, 9, 72, 23, 28, 71, 9, 66, 72, 84, 81,
        23, 12, 64, 57, 99, 15, 77, 38, 95, 64, 13, 91, 43, 61, 70, 47,
        39, 75, 47, 93, 45, 1, 95, 55, 29, 5, 83, 8, 3, 6, 45, 84,
    ]  # fmt: skip
    q = (
        pl.LazyFrame({"a": values, "idx": range(100)})
        .group_by("a")
        .agg(pl.col.idx.first())
        .select(a=pl.col.idx)
    )

    fst_value = q.collect().to_series().sum()
    for _ in range(10):
        assert q.collect().to_series().sum() == fst_value