# pola-rs/polars — py-polars/tests/unit/operations/test_gather.py
import numpy as np
import pytest

import polars as pl
from polars.exceptions import ComputeError
from polars.testing import assert_frame_equal, assert_series_equal
def test_negative_index() -> None:
    """Negative gather indices count from the end, both flat and per group."""
    df = pl.DataFrame({"a": [1, 2, 3, 4, 5, 6]})

    flat = df.select(pl.col("a").gather([0, -1])).to_dict(as_series=False)
    assert flat == {"a": [1, 6]}

    grouped = df.group_by(pl.col("a") % 2).agg(b=pl.col("a").gather([0, -1]))
    assert_frame_equal(
        grouped,
        pl.DataFrame({"a": [0, 1], "b": [[2, 6], [1, 5]]}),
        check_row_order=False,
    )
def test_gather_agg_schema() -> None:
    """`get` inside a group-by agg resolves to the element dtype, not a list."""
    df = pl.DataFrame(
        {
            "group": ["one", "one", "one", "two", "two", "two"],
            "value": [1, 98, 2, 3, 99, 4],
        }
    )
    schema = (
        df.lazy()
        .group_by("group", maintain_order=True)
        .agg(pl.col("value").get(1))
        .collect_schema()
    )
    assert schema["value"] == pl.Int64
def test_gather_lit_single_16535() -> None:
    """Gathering a single literal index inside a group-by agg (issue #16535)."""
    df = pl.DataFrame({"x": [1, 2, 2, 1], "y": [1, 2, 3, 4]})

    out = df.group_by(["x"], maintain_order=True).agg(pl.all().gather([1]))
    assert out.to_dict(as_series=False) == {"x": [1, 2], "y": [[4], [3]]}
def test_list_get_null_offset_17248() -> None:
    """`list.get` after a when/then that produced null lists (issue #17248)."""
    df = pl.DataFrame({"material": [["PB", "PVC", "CI"], ["CI"], ["CI"]]})

    out = df.select(
        result=pl.when(pl.col.material.list.len() == 1).then("material").list.get(0),
    )
    assert out["result"].to_list() == [None, "CI", "CI"]
def test_list_get_null_oob_17252() -> None:
    """`list.get(0)` propagates a null row instead of raising (issue #17252)."""
    df = pl.DataFrame({"name": ["BOB-3", "BOB", None]})

    split = df.with_columns(pl.col("name").str.split("-"))
    got = split.with_columns(pl.col("name").list.get(0))["name"].to_list()
    assert got == ["BOB", "BOB", None]
def test_list_get_null_on_oob_false_success() -> None:
    """In-bounds `list.get` succeeds, with and without outer nulls."""
    expected = pl.Series("a", [2, None, 2], dtype=pl.Int64)

    # Series (single offset) with nulls
    s_nulls = pl.Series("a", [[1, 2], None, [1, 2, 3]])
    assert_series_equal(s_nulls.list.get(1, null_on_oob=False), expected)

    # Expr (multiple offsets) with nulls
    # NOTE(review): these Expr cases pass null_on_oob=True despite the test
    # name — confirm whether False was intended.
    df = s_nulls.to_frame().with_columns(pl.lit(1).alias("idx"))
    out = df.select(pl.col("a").list.get("idx", null_on_oob=True)).to_series()
    assert_series_equal(out, expected)

    # Series (single offset) with no nulls
    expected = pl.Series("a", [2, 2, 2], dtype=pl.Int64)
    s_no_nulls = pl.Series("a", [[1, 2], [1, 2], [1, 2, 3]])
    assert_series_equal(s_no_nulls.list.get(1, null_on_oob=False), expected)

    # Expr (multiple offsets) with no nulls
    df = s_no_nulls.to_frame().with_columns(pl.lit(1).alias("idx"))
    out = df.select(pl.col("a").list.get("idx", null_on_oob=True)).to_series()
    assert_series_equal(out, expected)
def test_list_get_null_on_oob_false_failure() -> None:
    """Out-of-bounds `list.get` raises ComputeError when null_on_oob=False."""
    oob_msg = "get index is out of bounds"

    # Series (single offset) with nulls
    s_nulls = pl.Series("a", [[1, 2], None, [1, 2, 3]])
    with pytest.raises(ComputeError, match=oob_msg):
        s_nulls.list.get(2, null_on_oob=False)

    # Expr (multiple offsets) with nulls
    df = s_nulls.to_frame().with_columns(pl.lit(2).alias("idx"))
    with pytest.raises(ComputeError, match=oob_msg):
        df.select(pl.col("a").list.get("idx", null_on_oob=False))

    # Series (single offset) with no nulls
    s_no_nulls = pl.Series("a", [[1, 2], [1], [1, 2, 3]])
    with pytest.raises(ComputeError, match=oob_msg):
        s_no_nulls.list.get(2, null_on_oob=False)

    # Expr (multiple offsets) with no nulls
    df = s_no_nulls.to_frame().with_columns(pl.lit(2).alias("idx"))
    with pytest.raises(ComputeError, match=oob_msg):
        df.select(pl.col("a").list.get("idx", null_on_oob=False))
def test_list_get_null_on_oob_true() -> None:
    """Out-of-bounds `list.get` yields null when null_on_oob=True."""
    expected = pl.Series("a", [None, None, 3], dtype=pl.Int64)

    # Series (single offset) with nulls
    s_nulls = pl.Series("a", [[1, 2], None, [1, 2, 3]])
    assert_series_equal(s_nulls.list.get(2, null_on_oob=True), expected)

    # Expr (multiple offsets) with nulls
    df = s_nulls.to_frame().with_columns(pl.lit(2).alias("idx"))
    out = df.select(pl.col("a").list.get("idx", null_on_oob=True)).to_series()
    assert_series_equal(out, expected)

    # Series (single offset) with no nulls
    s_no_nulls = pl.Series("a", [[1, 2], [1], [1, 2, 3]])
    assert_series_equal(s_no_nulls.list.get(2, null_on_oob=True), expected)

    # Expr (multiple offsets) with no nulls
    df = s_no_nulls.to_frame().with_columns(pl.lit(2).alias("idx"))
    out = df.select(pl.col("a").list.get("idx", null_on_oob=True)).to_series()
    assert_series_equal(out, expected)
def test_chunked_gather_phys_repr_17446() -> None:
    """Left join over a chunked temporal column (issue #17446)."""
    dfa = pl.DataFrame({"replace_unique_id": range(2)})

    for dt in [pl.Date, pl.Time, pl.Duration]:
        dfb = dfa.clone().with_columns(ds_start_date_right=pl.lit(None).cast(dt))
        # self-concat forces a multi-chunk right-hand side
        dfb = pl.concat([dfb, dfb])

        joined = dfa.join(dfb, how="left", on=pl.col("replace_unique_id"))
        assert joined.shape == (4, 2)
def test_gather_str_col_18099() -> None:
    """`gather` accepts a column name as the index argument (issue #18099)."""
    df = pl.DataFrame({"foo": [1, 2, 3], "idx": [0, 0, 1]})
    out = df.with_columns(pl.col("foo").gather("idx"))
    assert out.to_dict(as_series=False) == {"foo": [1, 1, 2], "idx": [0, 0, 1]}
def test_gather_list_19243() -> None:
    """`list.gather` with a literal list column as indices (issue #19243)."""
    df = pl.DataFrame({"a": [[0.1, 0.2, 0.3]]})

    out = df.with_columns(pl.lit([0]).alias("c")).with_columns(
        gather=pl.col("a").list.gather(pl.col("c"), null_on_oob=True)
    )
    expected = {
        "a": [[0.1, 0.2, 0.3]],
        "c": [[0]],
        "gather": [[0.1]],
    }
    assert out.to_dict(as_series=False) == expected
def test_gather_array_list_null_19302() -> None:
    """`list.get` on a null row of List(Array) returns null (issue #19302)."""
    df = pl.DataFrame(
        {"data": [None]}, schema_overrides={"data": pl.List(pl.Array(pl.Float32, 1))}
    )
    out = df.select(pl.col("data").list.get(0))
    assert out.to_dict(as_series=False) == {"data": [None]}
def test_gather_array() -> None:
    """Gathering from an Array dtype matches numpy fancy indexing."""
    arr = np.arange(16).reshape(-1, 2, 2)
    s = pl.Series(arr)

    for idx in [[1, 2], [0, 0], [1, 0], [1, 1, 1, 1, 1, 1, 1, 1]]:
        assert (s.gather(idx).to_numpy() == arr[idx]).all()

    # a null index yields a null element
    picked = s[[0, 1, None, 3]]  # type: ignore[list-item]
    assert picked[2] is None
def test_gather_array_outer_validity_19482() -> None:
    """Gather preserves the outer validity of Array values (issue #19482)."""
    frame = pl.Series([[1], [1]], dtype=pl.Array(pl.Int64, 1)).to_frame()
    s = frame.select(pl.when(pl.int_range(pl.len()) == 0).then(pl.first())).to_series()

    expected = pl.Series([[1], None], dtype=pl.Array(pl.Int64, 1))
    assert_series_equal(s, expected)
    # gathering the full series must keep the null in slot 1
    assert_series_equal(s.gather([0, 1]), expected)
def test_gather_len_19561() -> None:
    """`.len()` after a gather inside an agg counts gathered rows (issue #19561)."""
    n = 4
    df = pl.DataFrame({"foo": ["baz"] * n, "bar": range(n)})
    # indices 1,2,2,3,3,3 -> six gathered rows in total
    idxs = pl.int_range(1, n).repeat_by(pl.int_range(1, n)).flatten()
    gather = pl.col.bar.gather(idxs).alias("gather")

    out = df.group_by("foo").agg(gather.len())
    assert out.to_dict(as_series=False) == {"foo": ["baz"], "gather": [6]}
def test_gather_agg_group_update_scalar() -> None:
    """`gather` with a scalar expression index must update group offsets.

    Otherwise the following `first` reads the original index of group `1`'s
    first element (2), past the two-element gather output.
    """
    out = (
        pl.DataFrame({"gid": [0, 0, 1, 1], "x": ["0:0", "0:1", "1:0", "1:1"]})
        .lazy()
        .group_by("gid", maintain_order=True)
        .agg(x_at_gid=pl.col("x").gather(pl.col("gid").last()).first())
        .collect(optimizations=pl.QueryOptFlags.none())
    )
    expected = pl.DataFrame({"gid": [0, 1], "x_at_gid": ["0:0", "1:1"]})
    assert_frame_equal(out, expected)
def test_gather_agg_group_update_literal() -> None:
    """`gather` with a literal index must update group offsets.

    Otherwise the following `first` reads the original index of group `1`'s
    first element (2), past the two-element gather output.
    """
    out = (
        pl.DataFrame({"gid": [0, 0, 1], "x": ["0:0", "0:1", "1:0"]})
        .lazy()
        .group_by("gid", maintain_order=True)
        .agg(x_at_0=pl.col("x").gather(0).first())
        .collect(optimizations=pl.QueryOptFlags.none())
    )
    expected = pl.DataFrame({"gid": [0, 1], "x_at_0": ["0:0", "1:0"]})
    assert_frame_equal(out, expected)
def test_gather_agg_group_update_negative() -> None:
    """`gather` with a negative index must update group offsets.

    Otherwise the following `first` reads the original index of group `1`'s
    first element (2), past the two-element gather output.
    """
    out = (
        pl.DataFrame({"gid": [0, 0, 1], "x": ["0:0", "0:1", "1:0"]})
        .lazy()
        .group_by("gid", maintain_order=True)
        .agg(x_last=pl.col("x").gather(-1).first())
        .collect(optimizations=pl.QueryOptFlags.none())
    )
    expected = pl.DataFrame({"gid": [0, 1], "x_last": ["0:1", "1:0"]})
    assert_frame_equal(out, expected)
def test_gather_agg_group_update_multiple() -> None:
    """`gather` with multiple indices must update group offsets.

    Otherwise the following `first` reads the original index of group `1`'s
    first element (4), past the four-element gather output (two per group).
    """
    out = (
        pl.DataFrame(
            {
                "gid": [0, 0, 0, 0, 1, 1],
                "x": ["0:0", "0:1", "0:2", "0:3", "1:0", "1:1"],
            }
        )
        .lazy()
        .group_by("gid", maintain_order=True)
        .agg(x_at_0=pl.col("x").gather([0, 1]).first())
        .collect(optimizations=pl.QueryOptFlags.none())
    )
    expected = pl.DataFrame({"gid": [0, 1], "x_at_0": ["0:0", "1:0"]})
    assert_frame_equal(out, expected)
def test_get_agg_group_update_literal_21610() -> None:
    """`get` with a literal index broadcasts per group in an agg (issue #21610)."""
    out = (
        pl.DataFrame(
            {
                "group": [100, 100, 100, 200, 200, 200],
                "value": [1, 2, 3, 2, 3, 4],
            }
        )
        .group_by("group", maintain_order=True)
        .agg(pl.col("value") - pl.col("value").get(0))
    )

    expected = pl.DataFrame({"group": [100, 200], "value": [[0, 1, 2], [0, 1, 2]]})
    assert_frame_equal(out, expected)
def test_get_agg_group_update_scalar_21610() -> None:
    """`get` with a scalar expression index broadcasts per group (issue #21610)."""
    out = (
        pl.DataFrame(
            {
                "group": [100, 100, 100, 200, 200, 200],
                "value": [1, 2, 3, 2, 3, 4],
            }
        )
        .group_by("group", maintain_order=True)
        .agg(pl.col("value") - pl.col("value").get(pl.col("value").first()))
    )

    expected = pl.DataFrame({"group": [100, 200], "value": [[-1, 0, 1], [-2, -1, 0]]})
    assert_frame_equal(out, expected)
def test_get_dt_truncate_21533() -> None:
    """Chaining `dt.truncate` after `get` inside an agg (issue #21533)."""
    timestamps = pl.datetime_range(
        pl.datetime(2016, 1, 1),
        pl.datetime(2017, 12, 31),
        interval="1d",
        eager=True,
    )
    df = pl.DataFrame({"timestamp": timestamps}).with_columns(
        month=pl.col.timestamp.dt.month(),
    )

    report = df.group_by("month", maintain_order=True).agg(
        trunc_ts=pl.col.timestamp.get(0).dt.truncate("1m")
    )
    # one row per calendar month
    assert report.shape == (12, 2)
@pytest.mark.parametrize("maintain_order", [False, True])
def test_gather_group_by_23696(maintain_order: bool) -> None:
    """`get`/`gather` index variants inside a group-by agg (issue #23696)."""
    out = (
        pl.DataFrame(
            {
                "a": [1, 2, 3, 4],
                "b": [0, 0, 1, 1],
                "c": [0, 0, -1, -1],
            }
        )
        .group_by(pl.col.a % 2, maintain_order=maintain_order)
        .agg(
            get_first=pl.col.a.get(pl.col.b.get(0)),
            get_last=pl.col.a.get(pl.col.b.get(1)),
            normal=pl.col.a.gather(pl.col.b),
            signed=pl.col.a.gather(pl.col.c),
            drop_nulls=pl.col.a.gather(pl.col.b.drop_nulls()),
            drop_nulls_signed=pl.col.a.gather(pl.col.c.drop_nulls()),
            literal=pl.col.a.gather([0, 1]),
            literal_signed=pl.col.a.gather([0, -1]),
        )
    )

    # every gather variant picks the same two elements per parity group
    pair = [[1, 3], [2, 4]]
    expected = pl.DataFrame(
        {
            "a": [1, 0],
            "get_first": [1, 2],
            "get_last": [3, 4],
            "normal": pair,
            "signed": pair,
            "drop_nulls": pair,
            "drop_nulls_signed": pair,
            "literal": pair,
            "literal_signed": pair,
        }
    )

    assert_frame_equal(out, expected, check_row_order=maintain_order)
def test_gather_invalid_indices_groupby_24182() -> None:
    """Gathering with a string index in a group-by raises (issue #24182)."""
    df = pl.DataFrame({"x": [1, 2]})
    with pytest.raises(pl.exceptions.InvalidOperationError):
        df.group_by(True).agg(pl.col("x").gather(pl.lit("y")))
@pytest.mark.parametrize("maintain_order", [False, True])
def test_gather_group_by_lit(maintain_order: bool) -> None:
    """Gathering from a literal list inside a group-by agg broadcasts per group."""
    out = (
        pl.DataFrame(
            {
                "a": [1, 2, 3],
            }
        )
        .group_by("a", maintain_order=maintain_order)
        .agg(pl.lit([1]).gather([0, 0, 0]))
    )
    expected = pl.DataFrame({"a": [1, 2, 3], "literal": [[[1], [1], [1]]] * 3})
    assert_frame_equal(out, expected, check_row_order=maintain_order)