# Scraped from: pola-rs/polars, py-polars/tests/unit/operations/test_gather.py
1
import numpy as np
import pytest

import polars as pl
from polars.exceptions import ComputeError, OutOfBoundsError
from polars.testing import assert_frame_equal, assert_series_equal


def test_negative_index() -> None:
    """Negative gather indices count from the end, both in select and in agg."""
    frame = pl.DataFrame({"a": [1, 2, 3, 4, 5, 6]})

    selected = frame.select(pl.col("a").gather([0, -1])).to_dict(as_series=False)
    assert selected == {"a": [1, 6]}

    grouped = frame.group_by(pl.col("a") % 2).agg(b=pl.col("a").gather([0, -1]))
    assert_frame_equal(
        grouped,
        pl.DataFrame({"a": [0, 1], "b": [[2, 6], [1, 5]]}),
        check_row_order=False,
    )
def test_gather_agg_schema() -> None:
    """`get` inside a group_by agg must resolve to the element dtype, not a list."""
    lf = pl.DataFrame(
        {
            "group": ["one", "one", "one", "two", "two", "two"],
            "value": [1, 98, 2, 3, 99, 4],
        }
    ).lazy()

    schema = (
        lf.group_by("group", maintain_order=True)
        .agg(pl.col("value").get(1))
        .collect_schema()
    )
    assert schema["value"] == pl.Int64
def test_gather_lit_single_16535() -> None:
    """Gather with a single-element literal index list in an agg (issue #16535)."""
    frame = pl.DataFrame({"x": [1, 2, 2, 1], "y": [1, 2, 3, 4]})

    result = (
        frame.group_by(["x"], maintain_order=True)
        .agg(pl.all().gather([1]))
        .to_dict(as_series=False)
    )
    assert result == {"x": [1, 2], "y": [[4], [3]]}
def test_list_get_null_offset_17248() -> None:
    """A null produced by when/then feeding list.get yields null (issue #17248)."""
    frame = pl.DataFrame({"material": [["PB", "PVC", "CI"], ["CI"], ["CI"]]})

    out = frame.select(
        result=pl.when(pl.col.material.list.len() == 1).then("material").list.get(0),
    )
    assert out["result"].to_list() == [None, "CI", "CI"]
def test_list_get_null_oob_17252() -> None:
    """list.get(0) on a null list row propagates null, no error (issue #17252)."""
    frame = pl.DataFrame({"name": ["BOB-3", "BOB", None]})

    # Splitting on "-" leaves a null list for the null row.
    split = frame.with_columns(pl.col("name").str.split("-"))
    first_parts = split.with_columns(pl.col("name").list.get(0))["name"].to_list()
    assert first_parts == ["BOB", "BOB", None]
def test_list_get_null_on_oob_false_success() -> None:
    """In-bounds list.get with null_on_oob=False succeeds, with and without nulls.

    Every non-null list here has at least two elements, so offset 1 is always
    in bounds and ``null_on_oob=False`` must not raise.
    """
    # test Series (single offset) with nulls
    expected = pl.Series("a", [2, None, 2], dtype=pl.Int64)
    s_nulls = pl.Series("a", [[1, 2], None, [1, 2, 3]])
    out = s_nulls.list.get(1, null_on_oob=False)
    assert_series_equal(out, expected)

    # test Expr (multiple offsets) with nulls
    # FIX: previously passed null_on_oob=True, contradicting this test's name
    # and purpose; offset 1 is in bounds for every row, so False is safe and
    # exercises the intended code path.
    df = s_nulls.to_frame().with_columns(pl.lit(1).alias("idx"))
    out = df.select(pl.col("a").list.get("idx", null_on_oob=False)).to_series()
    assert_series_equal(out, expected)

    # test Series (single offset) with no nulls
    expected = pl.Series("a", [2, 2, 2], dtype=pl.Int64)
    s_no_nulls = pl.Series("a", [[1, 2], [1, 2], [1, 2, 3]])
    out = s_no_nulls.list.get(1, null_on_oob=False)
    assert_series_equal(out, expected)

    # test Expr (multiple offsets) with no nulls
    # FIX: same as above — use the False path this test is meant to cover.
    df = s_no_nulls.to_frame().with_columns(pl.lit(1).alias("idx"))
    out = df.select(pl.col("a").list.get("idx", null_on_oob=False)).to_series()
    assert_series_equal(out, expected)
def test_list_get_null_on_oob_false_failure() -> None:
    """Out-of-bounds list.get with null_on_oob=False raises ComputeError."""
    oob_msg = "get index is out of bounds"

    # test Series (single offset) with nulls
    s_nulls = pl.Series("a", [[1, 2], None, [1, 2, 3]])
    with pytest.raises(ComputeError, match=oob_msg):
        s_nulls.list.get(2, null_on_oob=False)

    # test Expr (multiple offsets) with nulls
    df = s_nulls.to_frame().with_columns(pl.lit(2).alias("idx"))
    with pytest.raises(ComputeError, match=oob_msg):
        df.select(pl.col("a").list.get("idx", null_on_oob=False))

    # test Series (single offset) with no nulls
    s_no_nulls = pl.Series("a", [[1, 2], [1], [1, 2, 3]])
    with pytest.raises(ComputeError, match=oob_msg):
        s_no_nulls.list.get(2, null_on_oob=False)

    # test Expr (multiple offsets) with no nulls
    df = s_no_nulls.to_frame().with_columns(pl.lit(2).alias("idx"))
    with pytest.raises(ComputeError, match=oob_msg):
        df.select(pl.col("a").list.get("idx", null_on_oob=False))
def test_list_get_null_on_oob_true() -> None:
    """Out-of-bounds list.get with null_on_oob=True yields null instead of raising."""
    # test Series (single offset) with nulls
    with_nulls = pl.Series("a", [[1, 2], None, [1, 2, 3]])
    expected = pl.Series("a", [None, None, 3], dtype=pl.Int64)
    assert_series_equal(with_nulls.list.get(2, null_on_oob=True), expected)

    # test Expr (multiple offsets) with nulls
    frame = with_nulls.to_frame().with_columns(pl.lit(2).alias("idx"))
    via_expr = frame.select(pl.col("a").list.get("idx", null_on_oob=True)).to_series()
    assert_series_equal(via_expr, expected)

    # test Series (single offset) with no nulls
    without_nulls = pl.Series("a", [[1, 2], [1], [1, 2, 3]])
    expected = pl.Series("a", [None, None, 3], dtype=pl.Int64)
    assert_series_equal(without_nulls.list.get(2, null_on_oob=True), expected)

    # test Expr (multiple offsets) with no nulls
    frame = without_nulls.to_frame().with_columns(pl.lit(2).alias("idx"))
    via_expr = frame.select(pl.col("a").list.get("idx", null_on_oob=True)).to_series()
    assert_series_equal(via_expr, expected)
def test_chunked_gather_phys_repr_17446() -> None:
    """Left join against a chunked frame with all-null temporal columns (issue #17446)."""
    dfa = pl.DataFrame({"replace_unique_id": range(2)})

    for dt in (pl.Date, pl.Time, pl.Duration):
        dfb = dfa.clone().with_columns(ds_start_date_right=pl.lit(None).cast(dt))
        # Concatenating with itself makes the right side multi-chunk.
        dfb = pl.concat([dfb, dfb])

        joined = dfa.join(dfb, how="left", on=pl.col("replace_unique_id"))
        assert joined.shape == (4, 2)
def test_gather_str_col_18099() -> None:
    """gather accepts a column name string as the index argument (issue #18099)."""
    frame = pl.DataFrame({"foo": [1, 2, 3], "idx": [0, 0, 1]})

    result = frame.with_columns(pl.col("foo").gather("idx")).to_dict(as_series=False)
    assert result == {
        "foo": [1, 1, 2],
        "idx": [0, 0, 1],
    }
def test_gather_list_19243() -> None:
    """list.gather with a literal list column and null_on_oob=True (issue #19243)."""
    frame = pl.DataFrame({"a": [[0.1, 0.2, 0.3]]})

    result = (
        frame.with_columns(pl.lit([0]).alias("c"))
        .with_columns(gather=pl.col("a").list.gather(pl.col("c"), null_on_oob=True))
        .to_dict(as_series=False)
    )
    assert result == {
        "a": [[0.1, 0.2, 0.3]],
        "c": [[0]],
        "gather": [[0.1]],
    }
def test_gather_array_list_null_19302() -> None:
    """list.get on a null List(Array) row stays null (issue #19302)."""
    frame = pl.DataFrame(
        {"data": [None]}, schema_overrides={"data": pl.List(pl.Array(pl.Float32, 1))}
    )

    result = frame.select(pl.col("data").list.get(0)).to_dict(as_series=False)
    assert result == {"data": [None]}
def test_gather_array() -> None:
    """Gather on an Array column matches numpy fancy indexing; None index -> null."""
    raw = np.arange(16).reshape(-1, 2, 2)
    s = pl.Series(raw)

    for idx in ([1, 2], [0, 0], [1, 0], [1, 1, 1, 1, 1, 1, 1, 1]):
        assert (s.gather(idx).to_numpy() == raw[idx]).all()

    picked = s[[0, 1, None, 3]]  # type: ignore[list-item]
    assert picked[2] is None
def test_gather_array_outer_validity_19482() -> None:
    """Gather must respect the outer validity of an Array series (issue #19482)."""
    # when/then masks out the second row, producing an outer null.
    masked = (
        pl.Series([[1], [1]], dtype=pl.Array(pl.Int64, 1))
        .to_frame()
        .select(pl.when(pl.int_range(pl.len()) == 0).then(pl.first()))
        .to_series()
    )

    expect = pl.Series([[1], None], dtype=pl.Array(pl.Int64, 1))
    assert_series_equal(masked, expect)
    assert_series_equal(masked.gather([0, 1]), expect)
def test_gather_len_19561() -> None:
    """`gather(...).len()` inside an agg counts the gathered rows (issue #19561)."""
    N = 4
    frame = pl.DataFrame({"foo": ["baz"] * N, "bar": range(N)})

    # Indices 1,2,2,3,3,3 — six indices in total.
    idxs = (
        pl.int_range(1, N)
        .repeat_by(pl.int_range(1, N))
        .list.explode(keep_nulls=False, empty_as_null=False)
    )
    gathered = pl.col("bar").gather(idxs).alias("gather")

    result = frame.group_by("foo").agg(gathered.len()).to_dict(as_series=False)
    assert result == {
        "foo": ["baz"],
        "gather": [6],
    }
def test_gather_agg_group_update_scalar() -> None:
    """`gather` in an agg must rewrite group offsets for downstream exprs.

    If it does not, `first` would read index 2 (the original position of group
    `1`'s first element) from a two-element gather output — out of bounds.
    """
    out = (
        pl.DataFrame({"gid": [0, 0, 1, 1], "x": ["0:0", "0:1", "1:0", "1:1"]})
        .lazy()
        .group_by("gid", maintain_order=True)
        .agg(x_at_gid=pl.col("x").gather(pl.col("gid").last()).first())
        .collect(optimizations=pl.QueryOptFlags.none())
    )
    expected = pl.DataFrame({"gid": [0, 1], "x_at_gid": ["0:0", "1:1"]})
    assert_frame_equal(out, expected)
def test_gather_agg_group_update_literal() -> None:
    """Literal-index gather in an agg must rewrite group offsets for `first`.

    Otherwise `first` would read index 2 (the original position of group `1`'s
    first element) from a two-element gather output — out of bounds.
    """
    out = (
        pl.DataFrame({"gid": [0, 0, 1], "x": ["0:0", "0:1", "1:0"]})
        .lazy()
        .group_by("gid", maintain_order=True)
        .agg(x_at_0=pl.col("x").gather(0).first())
        .collect(optimizations=pl.QueryOptFlags.none())
    )
    expected = pl.DataFrame({"gid": [0, 1], "x_at_0": ["0:0", "1:0"]})
    assert_frame_equal(out, expected)
def test_gather_agg_group_update_negative() -> None:
    """Negative-index gather in an agg must rewrite group offsets for `first`.

    Otherwise `first` would read index 2 (the original position of group `1`'s
    first element) from a two-element gather output — out of bounds.
    """
    out = (
        pl.DataFrame({"gid": [0, 0, 1], "x": ["0:0", "0:1", "1:0"]})
        .lazy()
        .group_by("gid", maintain_order=True)
        .agg(x_last=pl.col("x").gather(-1).first())
        .collect(optimizations=pl.QueryOptFlags.none())
    )
    expected = pl.DataFrame({"gid": [0, 1], "x_last": ["0:1", "1:0"]})
    assert_frame_equal(out, expected)
def test_gather_agg_group_update_multiple() -> None:
    """Multi-index gather in an agg must rewrite group offsets for `first`.

    Otherwise `first` would read index 4 (the original position of group `1`'s
    first element) from a four-element gather output (two per group) — out of
    bounds.
    """
    out = (
        pl.DataFrame(
            {
                "gid": [0, 0, 0, 0, 1, 1],
                "x": ["0:0", "0:1", "0:2", "0:3", "1:0", "1:1"],
            }
        )
        .lazy()
        .group_by("gid", maintain_order=True)
        .agg(x_at_0=pl.col("x").gather([0, 1]).first())
        .collect(optimizations=pl.QueryOptFlags.none())
    )
    expected = pl.DataFrame({"gid": [0, 1], "x_at_0": ["0:0", "1:0"]})
    assert_frame_equal(out, expected)
def test_get_agg_group_update_literal_21610() -> None:
    """`get` with a literal index broadcasts within each group (issue #21610)."""
    out = (
        pl.DataFrame(
            {
                "group": [100, 100, 100, 200, 200, 200],
                "value": [1, 2, 3, 2, 3, 4],
            }
        )
        .group_by("group", maintain_order=True)
        .agg(pl.col("value") - pl.col("value").get(0))
    )

    expected = pl.DataFrame({"group": [100, 200], "value": [[0, 1, 2], [0, 1, 2]]})
    assert_frame_equal(out, expected)
def test_get_agg_group_update_scalar_21610() -> None:
    """`get` with a per-group scalar index broadcasts within each group (issue #21610)."""
    out = (
        pl.DataFrame(
            {
                "group": [100, 100, 100, 200, 200, 200],
                "value": [1, 2, 3, 2, 3, 4],
            }
        )
        .group_by("group", maintain_order=True)
        .agg(pl.col("value") - pl.col("value").get(pl.col("value").first()))
    )

    expected = pl.DataFrame({"group": [100, 200], "value": [[-1, 0, 1], [-2, -1, 0]]})
    assert_frame_equal(out, expected)
def test_get_dt_truncate_21533() -> None:
    """`get(...).dt.truncate(...)` inside an agg keeps one row per group (issue #21533)."""
    frame = pl.DataFrame(
        {
            "timestamp": pl.datetime_range(
                pl.datetime(2016, 1, 1),
                pl.datetime(2017, 12, 31),
                interval="1d",
                eager=True,
            ),
        }
    ).with_columns(month=pl.col.timestamp.dt.month())

    report = frame.group_by("month", maintain_order=True).agg(
        trunc_ts=pl.col.timestamp.get(0).dt.truncate("1m")
    )
    # Twelve calendar months, one truncated timestamp each.
    assert report.shape == (12, 2)
@pytest.mark.parametrize("maintain_order", [False, True])
def test_gather_group_by_23696(maintain_order: bool) -> None:
    """get/gather variants inside a group_by agree with expected rows (issue #23696)."""
    source = pl.DataFrame(
        {
            "a": [1, 2, 3, 4],
            "b": [0, 0, 1, 1],
            "c": [0, 0, -1, -1],
        }
    )
    out = source.group_by(pl.col.a % 2, maintain_order=maintain_order).agg(
        get_first=pl.col.a.get(pl.col.b.get(0)),
        get_last=pl.col.a.get(pl.col.b.get(1)),
        normal=pl.col.a.gather(pl.col.b),
        signed=pl.col.a.gather(pl.col.c),
        drop_nulls=pl.col.a.gather(pl.col.b.drop_nulls()),
        drop_nulls_signed=pl.col.a.gather(pl.col.c.drop_nulls()),
        literal=pl.col.a.gather([0, 1]),
        literal_signed=pl.col.a.gather([0, -1]),
    )

    expected = pl.DataFrame(
        {
            "a": [1, 0],
            "get_first": [1, 2],
            "get_last": [3, 4],
            "normal": [[1, 3], [2, 4]],
            "signed": [[1, 3], [2, 4]],
            "drop_nulls": [[1, 3], [2, 4]],
            "drop_nulls_signed": [[1, 3], [2, 4]],
            "literal": [[1, 3], [2, 4]],
            "literal_signed": [[1, 3], [2, 4]],
        }
    )

    assert_frame_equal(out, expected, check_row_order=maintain_order)
def test_gather_invalid_indices_groupby_24182() -> None:
    """Gather with a non-integer (string) index in an agg raises (issue #24182)."""
    frame = pl.DataFrame({"x": [1, 2]})
    with pytest.raises(pl.exceptions.InvalidOperationError):
        frame.group_by(True).agg(pl.col("x").gather(pl.lit("y")))
@pytest.mark.parametrize("maintain_order", [False, True])
def test_gather_group_by_lit(maintain_order: bool) -> None:
    """Gathering from a list literal inside a group_by broadcasts per group."""
    out = (
        pl.DataFrame(
            {
                "a": [1, 2, 3],
            }
        )
        .group_by("a", maintain_order=maintain_order)
        .agg(pl.lit([1]).gather([0, 0, 0]))
    )
    expected = pl.DataFrame({"a": [1, 2, 3], "literal": [[[1], [1], [1]]] * 3})
    assert_frame_equal(out, expected, check_row_order=maintain_order)
def test_get_window_with_filtered_empty_groups_23029() -> None:
    """filter(...).get(0, null_on_oob=True).over(...) matches filter(...).first().over(...).

    See https://github.com/pola-rs/polars/issues/23029 — groups whose filter
    removes every row must yield null, not an error.
    """
    frame = pl.DataFrame(
        {
            "group": [1, 1, 2, 2, 3, 3],
            "value": [10, 20, 30, 40, 50, 60],
            "filter_condition": [False, True, False, False, True, True],
        }
    )

    filtered = pl.col("value").filter(pl.col("filter_condition"))
    result = frame.with_columns(
        get_first=filtered.get(0, null_on_oob=True).over("group"),
        first_value=filtered.first().over("group"),
    )

    # Both formulations must agree row-by-row.
    assert_series_equal(
        result["get_first"],
        result["first_value"],
        check_names=False,
    )

    # And the concrete expected values are:
    expected = pl.DataFrame(
        {
            "group": [1, 1, 2, 2, 3, 3],
            "value": [10, 20, 30, 40, 50, 60],
            "filter_condition": [False, True, False, False, True, True],
            "get_first": [20, 20, None, None, 50, 50],
            "first_value": [20, 20, None, None, 50, 50],
        }
    )
    assert_frame_equal(result, expected)
@pytest.mark.parametrize("idx_dtype", [pl.Int64, pl.UInt64, pl.Int128, pl.UInt128])
def test_get_typed_index_null_on_oob_true(idx_dtype: pl.DataType) -> None:
    """OOB typed index with null_on_oob=True yields null for each integer dtype."""
    frame = pl.DataFrame({"value": [1, 2, 10]})

    oob_index = pl.lit(5, dtype=idx_dtype)
    out = frame.select(v=pl.col("value").get(oob_index, null_on_oob=True))
    assert out["v"].to_list() == [None]
@pytest.mark.parametrize("idx_dtype", [pl.Int64, pl.UInt64, pl.Int128, pl.UInt128])
def test_get_typed_index_null_on_oob_false_raises(idx_dtype: pl.DataType) -> None:
    """OOB typed index with null_on_oob=False raises OutOfBoundsError per dtype."""
    frame = pl.DataFrame({"value": [10, 11]})

    oob_index = pl.lit(5, dtype=idx_dtype)
    with pytest.raises(OutOfBoundsError, match="gather indices are out of bounds"):
        frame.select(pl.col("value").get(oob_index, null_on_oob=False))
@pytest.mark.parametrize("idx_dtype", [pl.Int64, pl.UInt64, pl.Int128, pl.UInt128])
def test_get_typed_index_default_raises_out_of_bounds(idx_dtype: pl.DataType) -> None:
    """Omitting null_on_oob defaults to the raising (False) behavior."""
    frame = pl.DataFrame({"value": [10, 11]})

    oob_index = pl.lit(5, dtype=idx_dtype)
    with pytest.raises(OutOfBoundsError, match="gather indices are out of bounds"):
        frame.select(pl.col("value").get(oob_index))