Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/functions/test_repeat.py
6939 views
1
from __future__ import annotations
2
3
from datetime import date, datetime, time, timedelta
4
from typing import TYPE_CHECKING, Any
5
6
import pytest
7
8
import polars as pl
9
from polars.exceptions import ComputeError, SchemaError, ShapeError
10
from polars.testing import assert_frame_equal, assert_series_equal
11
12
if TYPE_CHECKING:
13
from polars._typing import PolarsDataType
14
15
16
@pytest.mark.parametrize(
    ("value", "n", "dtype", "expected_dtype"),
    [
        # integer literals: the expected dtype flips from Int32 to Int64 just
        # outside the signed 32-bit range
        (2**31, 5, None, pl.Int64),
        (2**31 - 1, 5, None, pl.Int32),
        (-(2**31) - 1, 3, None, pl.Int64),
        (-(2**31), 3, None, pl.Int32),
        # other literals with no explicit dtype
        ("foo", 2, None, pl.String),
        (1.0, 5, None, pl.Float64),
        (True, 4, None, pl.Boolean),
        (None, 7, None, pl.Null),
        (0, 0, None, pl.Int32),
        (datetime(2023, 2, 2), 3, None, pl.Datetime),
        (date(2023, 2, 2), 3, None, pl.Date),
        (time(10, 15), 1, None, pl.Time),
        (timedelta(hours=3), 10, None, pl.Duration),
        # explicit dtype supplied by the caller
        (8, 2, pl.UInt8, pl.UInt8),
        (date(2023, 2, 2), 3, pl.Datetime, pl.Datetime),
        (7.5, 5, pl.UInt16, pl.UInt16),
        ([1, 2, 3], 2, pl.List(pl.Int64), pl.List(pl.Int64)),
        (b"ab12", 3, pl.Binary, pl.Binary),
    ],
)
def test_repeat(
    value: Any,
    n: int,
    dtype: PolarsDataType,
    expected_dtype: PolarsDataType,
) -> None:
    """Compare eager and lazy `pl.repeat` output with a reference series.

    The reference is a series named "repeat" holding `value` `n` times,
    cast to `expected_dtype`.
    """
    reference = pl.Series("repeat", [value] * n).cast(expected_dtype)

    eager_out = pl.repeat(value, n=n, dtype=dtype, eager=True)
    assert_series_equal(eager_out, reference)

    lazy_expr = pl.repeat(value, n=n, dtype=dtype, eager=False)
    lazy_out = pl.select(lazy_expr).to_series()
    assert_series_equal(lazy_out, reference)
52
53
54
def test_repeat_expr_input_eager() -> None:
    """`pl.repeat` with `eager=True` and `n` given as a literal expression."""
    repeated = pl.repeat(1, n=pl.lit(3), eager=True)
    selected = pl.select(repeated).to_series()
    assert_series_equal(selected, pl.Series("repeat", [1, 1, 1], dtype=pl.Int32))
58
59
60
def test_repeat_expr_input_lazy() -> None:
    """`pl.repeat` accepts expressions for both `n` and the repeated value."""
    # n taken from the first row of column "a" (3)
    frame = pl.DataFrame({"a": [3, 2, 1]})
    count = pl.col("a").first()
    repeated = frame.select(pl.repeat(1, n=count)).to_series()
    assert_series_equal(repeated, pl.Series("repeat", [1, 1, 1], dtype=pl.Int32))

    # value taken from an aggregation: sum of "a" is 6, repeated twice
    totals_frame = pl.DataFrame({"a": [3, 2, 1]})
    summed = totals_frame.select(pl.repeat(pl.sum("a"), n=2)).to_series()
    assert summed.to_list() == [6, 6]
68
69
70
def test_repeat_n_zero() -> None:
    """Repeating a value zero times gives a series of length 0."""
    empty = pl.repeat(1, n=0, eager=True)
    assert empty.len() == 0
72
73
74
@pytest.mark.parametrize(
    "n",
    [1.5, 2.0, date(1971, 1, 2), "hello"],
)
def test_repeat_n_non_integer(n: Any) -> None:
    """A non-integer literal passed as `n` is rejected with a `SchemaError`."""
    count = pl.lit(n)
    with pytest.raises(SchemaError, match="expected expression of dtype 'integer'"):
        pl.repeat(1, n=count, eager=True)
81
82
83
def test_repeat_n_empty() -> None:
    """Using an empty column as `n` raises a `ShapeError`."""
    empty_frame = pl.DataFrame(schema={"a": pl.Int32})
    with pytest.raises(ShapeError, match="'n' must be a scalar value"):
        empty_frame.select(pl.repeat(1, n=pl.col("a")))
87
88
89
def test_repeat_n_negative() -> None:
    """A negative `n` raises a `ComputeError`."""
    message = "could not parse value '-1' as a size"
    with pytest.raises(ComputeError, match=message):
        pl.repeat(1, n=-1, eager=True)
92
93
94
@pytest.mark.parametrize(
    ("n", "value", "dtype"),
    [
        (2, 1, pl.UInt32),
        (0, 1, pl.Int16),
        (3, 1, pl.Float32),
        (1, "1", pl.Utf8),
        (2, ["1"], pl.List(pl.Utf8)),
        (4, True, pl.Boolean),
        (2, [True], pl.List(pl.Boolean)),
        (2, [1], pl.Array(pl.Int16, shape=1)),
        (2, [1, 1, 1], pl.Array(pl.Int8, shape=3)),
        (1, [1], pl.List(pl.UInt32)),
    ],
)
def test_ones(
    n: int,
    value: Any,
    dtype: PolarsDataType,
) -> None:
    """Compare eager and lazy `pl.ones` output with a reference series.

    `value` is the per-row "one" expected for `dtype`; the reference is a
    series named "ones" holding it `n` times.
    """
    reference = pl.Series("ones", [value] * n, dtype=dtype)

    eager_out = pl.ones(n=n, dtype=dtype, eager=True)
    assert_series_equal(eager_out, reference)

    lazy_expr = pl.ones(n=n, dtype=dtype, eager=False)
    lazy_out = pl.select(lazy_expr).to_series()
    assert_series_equal(lazy_out, reference)
121
122
123
@pytest.mark.parametrize(
    ("n", "value", "dtype"),
    [
        (2, 0, pl.UInt8),
        (0, 0, pl.Int32),
        (3, 0, pl.Float32),
        (1, "0", pl.Utf8),
        (2, ["0"], pl.List(pl.Utf8)),
        (4, False, pl.Boolean),
        (2, [False], pl.List(pl.Boolean)),
        (3, [0], pl.Array(pl.UInt32, shape=1)),
        (2, [0, 0, 0], pl.Array(pl.UInt32, shape=3)),
        (1, [0], pl.List(pl.UInt32)),
    ],
)
def test_zeros(
    n: int,
    value: Any,
    dtype: PolarsDataType,
) -> None:
    """Compare eager and lazy `pl.zeros` output with a reference series.

    `value` is the per-row "zero" expected for `dtype`; the reference is a
    series named "zeros" holding it `n` times.
    """
    reference = pl.Series("zeros", [value] * n, dtype=dtype)

    eager_out = pl.zeros(n=n, dtype=dtype, eager=True)
    assert_series_equal(eager_out, reference)

    lazy_expr = pl.zeros(n=n, dtype=dtype, eager=False)
    lazy_out = pl.select(lazy_expr).to_series()
    assert_series_equal(lazy_out, reference)
150
151
152
def test_ones_zeros_misc() -> None:
    """Default dtype and invalid-dtype handling of `pl.ones` / `pl.zeros`."""
    # with no dtype given, both constructors should produce Float64
    default_ones = pl.ones(n=2, eager=True)
    default_zeros = pl.zeros(n=2, eager=True)
    assert default_ones.dtype == default_zeros.dtype == pl.Float64

    # a struct dtype must be rejected with a TypeError naming the function
    unsupported = pl.Struct({"x": pl.Date, "y": pl.Duration})

    with pytest.raises(TypeError, match="invalid dtype for `ones`"):
        pl.ones(n=2, dtype=unsupported, eager=True)

    with pytest.raises(TypeError, match="invalid dtype for `zeros`"):
        pl.zeros(n=2, dtype=unsupported, eager=True)
165
166
167
def test_repeat_by_logical_dtype() -> None:
    """`repeat_by` on Date and Categorical columns keeps the inner dtype."""
    day = date(2021, 1, 1)
    counts = [1, 2, 3]
    labels = ["a", "b", "c"]

    frame = pl.DataFrame(
        {"repeat": counts, "date": [day] * 3, "cat": labels},
        schema={"repeat": pl.Int32, "date": pl.Date, "cat": pl.Categorical},
    )
    result = frame.select(
        pl.col("date").repeat_by("repeat"),
        pl.col("cat").repeat_by("repeat"),
    )

    # row i holds its own value repeated counts[i] times
    expected = pl.DataFrame(
        {
            "date": [[day] * k for k in counts],
            "cat": [[label] * k for label, k in zip(labels, counts)],
        },
        schema={"date": pl.List(pl.Date), "cat": pl.List(pl.Categorical)},
    )

    assert_frame_equal(result, expected)
193
194
195
def test_repeat_by_list() -> None:
    """`repeat_by` on a List column, including a null value and a null count."""
    inner = pl.List(pl.UInt8)
    frame = pl.DataFrame(
        {
            "repeat": [1, 2, 3, None],
            "value": [None, [1, 2, 3], [4, None], [1, 2]],
        },
        schema={"repeat": pl.UInt32, "value": inner},
    )
    result = frame.select(pl.col("value").repeat_by("repeat"))

    expected = pl.DataFrame(
        {
            "value": [
                [None],  # null value repeated once
                [[1, 2, 3]] * 2,
                [[4, None]] * 3,
                None,  # null repeat count gives a null row
            ],
        },
        schema={"value": pl.List(inner)},
    )

    assert_frame_equal(result, expected)
218
219
220
def test_repeat_by_nested_list() -> None:
    """`repeat_by` on a List(List(Int16)) column adds one more list level."""
    inner = pl.List(pl.List(pl.Int16))
    second_row = [[1], [2, 2]]
    third_row = [[3, 3], None, [4, None]]

    frame = pl.DataFrame(
        {
            "repeat": [1, 2, 3],
            "value": [None, second_row, third_row],
        },
        schema={"repeat": pl.UInt32, "value": inner},
    )
    result = frame.select(pl.col("value").repeat_by("repeat"))

    expected = pl.DataFrame(
        {"value": [[None], [second_row] * 2, [third_row] * 3]},
        schema={"value": pl.List(inner)},
    )

    assert_frame_equal(result, expected)
246
247
248
def test_repeat_by_struct() -> None:
    """`repeat_by` on a Struct column yields a List(Struct) column."""
    struct_dtype = pl.Struct({"a": pl.Int8, "b": pl.Int32})
    second_row = {"a": 1, "b": 2}
    third_row = {"a": 3, "b": None}

    frame = pl.DataFrame(
        {
            "repeat": [1, 2, 3],
            "value": [None, second_row, third_row],
        },
        schema={"repeat": pl.UInt32, "value": struct_dtype},
    )
    result = frame.select(pl.col("value").repeat_by("repeat"))

    expected = pl.DataFrame(
        {"value": [[None], [second_row] * 2, [third_row] * 3]},
        schema={"value": pl.List(struct_dtype)},
    )

    assert_frame_equal(result, expected)
270
271
272
def test_repeat_by_nested_struct() -> None:
    """`repeat_by` on a struct that itself contains a struct field."""
    struct_dtype = pl.Struct(
        {"a": pl.Struct({"x": pl.Int64, "y": pl.Int128}), "b": pl.Int32}
    )
    second_row = {"a": {"x": 1, "y": 1}, "b": 2}
    third_row = {"a": {"x": None, "y": 3}, "b": None}

    frame = pl.DataFrame(
        {
            "repeat": [1, 2, 3],
            "value": [None, second_row, third_row],
        },
        schema={"repeat": pl.UInt32, "value": struct_dtype},
    )
    result = frame.select(pl.col("value").repeat_by("repeat"))

    expected = pl.DataFrame(
        {"value": [[None], [second_row] * 2, [third_row] * 3]},
        schema={"value": pl.List(struct_dtype)},
    )

    assert_frame_equal(result, expected)
313
314
315
def test_repeat_by_struct_in_list() -> None:
    """`repeat_by` on a List(Struct) column with String and Enum fields."""
    list_dtype = pl.List(pl.Struct({"a": pl.String, "b": pl.Enum(["A", "B"])}))
    second_row = [{"a": "foo", "b": "A"}, None]
    third_row = [{"a": None, "b": "B"}, {"a": "test", "b": "B"}]

    frame = pl.DataFrame(
        {
            "repeat": [1, 2, 3],
            "value": [None, second_row, third_row],
        },
        schema={"repeat": pl.UInt32, "value": list_dtype},
    )
    result = frame.select(pl.col("value").repeat_by("repeat"))

    expected = pl.DataFrame(
        {"value": [[None], [second_row] * 2, [third_row] * 3]},
        schema={"value": pl.List(list_dtype)},
    )

    assert_frame_equal(result, expected)
352
353
354
def test_repeat_by_list_in_struct() -> None:
    """`repeat_by` on a Struct column whose fields are lists."""
    struct_dtype = pl.Struct({"a": pl.List(pl.Int8), "b": pl.List(pl.String)})
    second_row = {"a": [1, 2, 3], "b": ["x", "y", None]}
    third_row = {"a": [None, 5, 6], "b": None}

    frame = pl.DataFrame(
        {
            "repeat": [1, 2, 3],
            "value": [None, second_row, third_row],
        },
        schema={"repeat": pl.UInt32, "value": struct_dtype},
    )
    result = frame.select(pl.col("value").repeat_by("repeat"))

    expected = pl.DataFrame(
        {"value": [[None], [second_row] * 2, [third_row] * 3]},
        schema={"value": pl.List(struct_dtype)},
    )

    assert_frame_equal(result, expected)
394
395
396
@pytest.mark.parametrize(
    ("data", "expected_data"),
    [
        (["a", "b", None], [["a", "a"], None, [None, None, None]]),
        ([1, 2, None], [[1, 1], None, [None, None, None]]),
        ([1.1, 2.2, None], [[1.1, 1.1], None, [None, None, None]]),
        ([True, False, None], [[True, True], None, [None, None, None]]),
    ],
)
def test_repeat_by_none_13053(data: list[Any], expected_data: list[list[Any]]) -> None:
    """`repeat_by` with a null count gives a null row; a null value is repeated."""
    frame = pl.DataFrame({"x": data, "by": [2, None, 3]})
    repeated = frame.select(repeat=pl.col("x").repeat_by("by")).to_series()
    assert_series_equal(repeated, pl.Series("repeat", expected_data))
410
411
412
def test_repeat_by_literal_none_20268() -> None:
    """`repeat_by` with a null literal or plain `None` yields all-null lists."""
    frame = pl.DataFrame({"x": ["a", "b"]})
    all_null = pl.Series("repeat", [None, None], dtype=pl.List(pl.String))

    # null passed as an explicit literal expression
    via_literal = frame.select(repeat=pl.col("x").repeat_by(pl.lit(None)))
    assert_series_equal(via_literal.to_series(), all_null)

    # null passed as a bare Python None
    via_none = frame.select(repeat=pl.col("x").repeat_by(None))  # type: ignore[arg-type]
    assert_series_equal(via_none.to_series(), all_null)
421
422
423
@pytest.mark.parametrize("value", [pl.Series([]), pl.Series([1, 2])])
def test_repeat_nonscalar_value(value: pl.Series) -> None:
    """A series of length 0 or 2 passed as the repeated value raises `ShapeError`."""
    # `value` is already a Series, so it is passed through as-is rather than
    # being wrapped in a second `pl.Series(...)` constructor call.
    with pytest.raises(ShapeError, match="'value' must be a scalar value"):
        pl.select(pl.repeat(value, n=1))
427
428
429
@pytest.mark.parametrize("n", [[], [1, 2]])
def test_repeat_nonscalar_n(n: list[int]) -> None:
    """A column of length 0 or 2 used as `n` raises a `ShapeError`."""
    frame = pl.DataFrame({"n": n})
    with pytest.raises(ShapeError, match="'n' must be a scalar value"):
        frame.select(pl.repeat("a", pl.col("n")))
434
435
436
def test_repeat_value_first() -> None:
    """Both the value and `n` may come from `.first()` of a column."""
    frame = pl.DataFrame({"a": ["a", "b", "c"], "n": [4, 5, 6]})
    first_value = pl.col("a").first()
    first_count = pl.col("n").first()

    # first value "a" repeated first count (4) times
    repeated = frame.select(rep=pl.repeat(first_value, n=first_count))
    assert_frame_equal(repeated, pl.DataFrame({"rep": ["a", "a", "a", "a"]}))
441
442
443
def test_repeat_by_arr() -> None:
    """`Series.repeat_by` on an Array column yields a List(Array) series."""
    array_dtype = pl.Array(pl.String, 2)
    source = pl.Series([["a", "b"], ["a", "c"]], dtype=array_dtype)
    doubled = source.repeat_by(2)

    expected = pl.Series(
        [[["a", "b"], ["a", "b"]], [["a", "c"], ["a", "c"]]],
        dtype=pl.List(array_dtype),
    )
    assert_series_equal(doubled, expected)
451
452
453
def test_repeat_by_null() -> None:
    """`Series.repeat_by` on a Null-dtype series yields a List(Null) series."""
    nulls = pl.Series([None, None], dtype=pl.Null)
    doubled = nulls.repeat_by(2)

    expected = pl.Series([[None, None], [None, None]], dtype=pl.List(pl.Null))
    assert_series_equal(doubled, expected)
458
459