Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/functions/range/test_int_range.py
6939 views
1
from __future__ import annotations
2
3
from typing import Any
4
5
import pytest
6
7
import polars as pl
8
from polars.exceptions import (
9
ComputeError,
10
InvalidOperationError,
11
SchemaError,
12
ShapeError,
13
)
14
from polars.testing import assert_frame_equal, assert_series_equal
15
16
17
def test_int_range() -> None:
    """`int_range(0, 3)` evaluated through `pl.select` yields 0, 1, 2."""
    expr = pl.int_range(0, 3)
    actual = pl.select(int_range=expr).to_series()
    assert_series_equal(actual, pl.Series("int_range", [0, 1, 2]))
21
22
23
def test_int_range_alias() -> None:
    """`pl.arange` can be used inside a lazy filter predicate."""
    # note: `arange` is an alias for `int_range`
    predicate = pl.col("a") >= pl.arange(0, 3)
    lazy_frame = pl.LazyFrame({"a": [1, 1, 1]})
    filtered = lazy_frame.filter(predicate).collect()
    # Rows are compared against 0, 1, 2; only the first two satisfy `1 >= n`.
    assert_frame_equal(filtered, pl.DataFrame({"a": [1, 1]}))
29
30
31
def test_int_range_decreasing() -> None:
    """Negative steps produce the same values as Python's built-in `range`."""
    for start, end, step in [(10, 1, -2), (10, -1, -1)]:
        produced = pl.int_range(start, end, step, eager=True).to_list()
        assert produced == list(range(start, end, step))
34
35
36
def test_int_range_expr() -> None:
    """The `end` bound may be an expression; eager `arange` honours `step`."""
    frame = pl.DataFrame({"a": ["foobar", "barfoo"]})
    upper = pl.col("a").count() * 10  # two non-null rows -> end of 20
    selected = frame.select(pl.int_range(0, upper))
    assert selected.shape == (20, 1)
    last_value = selected.to_series(0)[-1]
    assert last_value == 19

    # eager arange
    stepped = pl.arange(0, 10, 2, eager=True).to_list()
    assert stepped == [0, 2, 4, 6, 8]
45
46
47
def test_int_range_short_syntax() -> None:
    """A single positional argument acts as the exclusive end, starting at 0."""
    actual = pl.select(int=pl.int_range(3)).to_series()
    wanted = pl.Series("int", [0, 1, 2])
    assert_series_equal(actual, wanted)
51
52
53
def test_int_ranges_short_syntax() -> None:
    """`int_ranges(3)` yields a single list row holding 0, 1, 2."""
    actual = pl.select(int=pl.int_ranges(3)).to_series()
    wanted = pl.Series("int", [[0, 1, 2]])
    assert_series_equal(actual, wanted)
57
58
59
def test_int_range_start_default() -> None:
    """Passing only `end=` by keyword starts the range at 0."""
    end_only = pl.int_range(end=3)
    actual = pl.select(int=end_only).to_series()
    assert_series_equal(actual, pl.Series("int", [0, 1, 2]))
63
64
65
def test_int_ranges_start_default() -> None:
    """With only `end=` given as a column, every row's range starts at 0."""
    frame = pl.DataFrame({"end": [3, 2]})
    per_row = pl.int_ranges(end="end")
    actual = frame.select(int_range=per_row)
    wanted = pl.DataFrame({"int_range": [[0, 1, 2], [0, 1]]})
    assert_frame_equal(actual, wanted)
70
71
72
def test_int_range_eager() -> None:
    """With `eager=True` the result is a Series named "literal"."""
    series = pl.int_range(0, 3, eager=True)
    assert_series_equal(series, pl.Series("literal", [0, 1, 2]))
76
77
78
def test_int_range_lazy() -> None:
    """`pl.select(..., eager=False)` result matches the expected LazyFrame."""
    countdown = pl.int_range(8, 0, -2)
    lazy_result = pl.select(n=countdown, eager=False)
    assert_frame_equal(lazy_result, pl.LazyFrame({"n": [8, 6, 4, 2]}))
82
83
84
def test_int_range_schema() -> None:
    """The planned schema and the collected schema both report Int64."""
    query = pl.LazyFrame().select(int=pl.int_range(-3, 3))
    wanted = {"int": pl.Int64}

    # Check the schema resolved without executing, then the executed one.
    planned = query.collect_schema()
    assert planned == wanted
    materialised = query.collect().schema
    assert materialised == wanted
90
91
92
@pytest.mark.parametrize(
    ("start", "end", "expected"),
    [
        # Both bounds taken from columns.
        ("a", "b", pl.Series("a", [[1, 2], [2, 3]])),
        # Integer start, column end: the output is named "literal".
        (-1, "a", pl.Series("literal", [[-1, 0], [-1, 0, 1]])),
        # Column start, integer end: the second row is empty (4 up to 4).
        ("b", 4, pl.Series("b", [[3], []])),
    ],
)
def test_int_ranges(start: Any, end: Any, expected: pl.Series) -> None:
    """Column names and integer literals can be mixed for `start` and `end`."""
    frame = pl.DataFrame({"a": [1, 2], "b": [3, 4]})
    ranges = frame.select(pl.int_ranges(start, end)).to_series()
    assert_series_equal(ranges, expected)
105
106
107
def test_int_ranges_decreasing() -> None:
    """A negative step counts down, both eagerly and as an expression."""
    wanted = pl.Series("literal", [[5, 4, 3, 2, 1]], dtype=pl.List(pl.Int64))

    eager_result = pl.int_ranges(5, 0, -1, eager=True)
    assert_series_equal(eager_result, wanted)

    expr_result = pl.select(pl.int_ranges(5, 0, -1)).to_series()
    assert_series_equal(expr_result, wanted)
111
112
113
@pytest.mark.parametrize(
    ("start", "end", "step"),
    [(0, -5, 1), (5, 0, 1), (0, 5, -1)],
)
def test_int_ranges_empty(start: int, end: int, step: int) -> None:
    """A step pointing away from `end` gives an empty range in every mode."""
    int64_list = pl.List(pl.Int64)

    # Eager results carry the "literal" name.
    eager_flat = pl.int_range(start, end, step, eager=True)
    assert_series_equal(eager_flat, pl.Series("literal", [], dtype=pl.Int64))
    eager_nested = pl.int_ranges(start, end, step, eager=True)
    assert_series_equal(eager_nested, pl.Series("literal", [[]], dtype=int64_list))

    # Expression results take their name from the `select` keyword.
    expr_flat = pl.select(int=pl.int_range(start, end, step)).to_series()
    assert_series_equal(pl.Series("int", [], dtype=pl.Int64), expr_flat)
    expr_nested = pl.select(int_range=pl.int_ranges(start, end, step)).to_series()
    assert_series_equal(pl.Series("int_range", [[]], dtype=int64_list), expr_nested)
138
139
140
def test_int_ranges_eager() -> None:
    """A Series `start` gives one range per element and keeps the name "s"."""
    starts = pl.Series("s", [1, 2])
    wanted = pl.Series("s", [[1, 2, 3], [2, 3]])
    assert_series_equal(pl.int_ranges(starts, 4, eager=True), wanted)
146
147
148
def test_int_ranges_schema_dtype_default() -> None:
    """Without a `dtype` argument the schema reports `List(Int64)`."""
    source = pl.LazyFrame({"start": [1, 2], "end": [3, 4]})
    query = source.select(pl.int_ranges("start", "end"))
    wanted = {"start": pl.List(pl.Int64)}

    # The planned schema and the executed schema must agree.
    planned = query.collect_schema()
    assert planned == wanted
    materialised = query.collect().schema
    assert materialised == wanted
156
157
158
def test_int_ranges_schema_dtype_arg() -> None:
    """An explicit `dtype=pl.UInt16` shows up as `List(UInt16)` in the schema."""
    source = pl.LazyFrame({"start": [1, 2], "end": [3, 4]})
    query = source.select(pl.int_ranges("start", "end", dtype=pl.UInt16))
    wanted = {"start": pl.List(pl.UInt16)}

    # The planned schema and the executed schema must agree.
    planned = query.collect_schema()
    assert planned == wanted
    materialised = query.collect().schema
    assert materialised == wanted
166
167
168
def test_int_range_input_shape_empty() -> None:
    """An empty Series as `start` and/or `end` raises `ShapeError`."""
    empty = pl.Series(dtype=pl.Time)
    single = pl.Series([5])

    bad_bounds = [(empty, single), (single, empty), (empty, empty)]
    for lower, upper in bad_bounds:
        with pytest.raises(ShapeError):
            pl.int_range(lower, upper, eager=True)
178
179
180
def test_int_range_input_shape_multiple_values() -> None:
    """A multi-element Series as `start` and/or `end` raises `ShapeError`."""
    single = pl.Series([5])
    multiple = pl.Series([10, 15])

    bad_bounds = [(multiple, single), (single, multiple), (multiple, multiple)]
    for lower, upper in bad_bounds:
        with pytest.raises(ShapeError):
            pl.int_range(lower, upper, eager=True)
190
191
192
# https://github.com/pola-rs/polars/issues/10867
def test_int_range_index_type_negative() -> None:
    """A UInt32 `start` with an `end` and `step` of -1 counts down from 3 to 0."""
    unsigned_start = pl.lit(3).cast(pl.UInt32).alias("start")
    countdown = pl.select(pl.int_range(unsigned_start, -1, -1))
    assert_frame_equal(countdown, pl.DataFrame({"start": [3, 2, 1, 0]}))
197
198
199
def test_int_range_null_input() -> None:
    """A null literal as `end` is rejected with a `ComputeError`."""
    null_end = pl.lit(None)
    with pytest.raises(ComputeError, match="invalid null input for `int_range`"):
        pl.select(pl.int_range(3, null_end, -1, dtype=pl.UInt32))
202
203
204
def test_int_range_invalid_conversion() -> None:
    """A negative bound cannot be converted to the requested UInt32 dtype."""
    message = "conversion from `i128` to `u32` failed"
    with pytest.raises(InvalidOperationError, match=message):
        pl.select(pl.int_range(3, -1, -1, dtype=pl.UInt32))
209
210
211
def test_int_range_non_integer_dtype() -> None:
    """Requesting `dtype=pl.Float64` from `int_range` raises a `SchemaError`."""
    message = "non-integer `dtype` passed to `int_range`: 'f64'"
    with pytest.raises(SchemaError, match=message):
        pl.select(pl.int_range(3, -1, -1, dtype=pl.Float64))  # type: ignore[arg-type]
216
217
218
def test_int_ranges_broadcasting() -> None:
    """Scalar arguments to `int_ranges` are broadcast against column inputs."""
    frame = pl.DataFrame({"int": [1, 2, 3]})

    # Each alias names the argument(s) passed as a scalar, i.e. the ones
    # that have to be broadcast to the length of the column/Series inputs.
    exprs = [
        pl.int_ranges(1, pl.Series([2, 4, 6]), "int").alias("start"),
        pl.int_ranges("int", 6, "int").alias("end"),
        pl.int_ranges("int", pl.col("int") + 2, 1).alias("step"),
        pl.int_ranges("int", 3, 1).alias("end_step"),
        pl.int_ranges(1, "int", 1).alias("start_step"),
        pl.int_ranges(1, 6, "int").alias("start_end"),
        pl.int_ranges("int", pl.Series([4, 5, 10]), "int").alias("no_broadcast"),
    ]
    actual = frame.select(*exprs)

    wanted = pl.DataFrame(
        {
            "start": [[1], [1, 3], [1, 4]],
            "end": [[1, 2, 3, 4, 5], [2, 4], [3]],
            "step": [[1, 2], [2, 3], [3, 4]],
            "end_step": [[1, 2], [2], []],
            "start_step": [[], [1], [1, 2]],
            "start_end": [[1, 2, 3, 4, 5], [1, 3, 5], [1, 4]],
            "no_broadcast": [[1, 2, 3], [2, 4], [3, 6, 9]],
        }
    )
    assert_frame_equal(actual, wanted)
258
259
260
# https://github.com/pola-rs/polars/issues/15307
def test_int_range_non_int_dtype() -> None:
    """Eager `int_range` with `dtype=pl.String` raises a `SchemaError`."""
    message = "non-integer `dtype` passed to `int_range`: 'str'"
    with pytest.raises(SchemaError, match=message):
        pl.int_range(0, 3, dtype=pl.String, eager=True)  # type: ignore[arg-type]
266
267
268
# https://github.com/pola-rs/polars/issues/15307
def test_int_ranges_non_int_dtype() -> None:
    """Eager `int_ranges` with `dtype=pl.String` raises a `SchemaError`."""
    message = "non-integer `dtype` passed to `int_ranges`: 'str'"
    with pytest.raises(SchemaError, match=message):
        pl.int_ranges(0, 3, dtype=pl.String, eager=True)  # type: ignore[arg-type]
274
275
276
# https://github.com/pola-rs/polars/issues/22640
def test_int_ranges_non_numeric_input_should_error() -> None:
    """String columns passed as `start`/`end` raise `InvalidOperationError`."""
    df = pl.DataFrame(
        {
            "start": ["a", "b"],
            "end": ["c", "d"],
        }
    )

    # Use the module-level `InvalidOperationError` import and `match=`, like
    # the other error tests in this file; the pattern contains no regex
    # metacharacters, so it is the same substring check as before.
    with pytest.raises(
        InvalidOperationError, match="conversion from `str` to `i64` failed"
    ):
        df.select(pl.int_ranges("start", "end"))
289
290
291
def test_int_range_len_count() -> None:
    """`int_range` accepts `len`/`count` aggregations as its `end` bound.

    The column holds six values, one of them null: the `len` variants include
    the null in the bound, the `count` variants do not.
    """
    values = [1, 2, None, 4, 5, 6]
    lf = pl.Series("a", values).to_frame().lazy()
    index_dtype = pl.get_index_type()

    def irange(e: pl.Expr) -> pl.LazyFrame:
        return lf.select(r=pl.int_range(0, e, dtype=index_dtype))

    without_four = pl.col.a.filter(pl.col.a.ne_missing(4))

    # (end expression, number of elements expected in the resulting range)
    cases = [
        (pl.len(), 6),
        (pl.col.a.len(), 6),
        (without_four.len(), 5),
        (pl.col.a.count(), 5),
        (without_four.count(), 4),
    ]
    for end_expr, size in cases:
        produced = irange(end_expr).collect().to_series()
        assert_series_equal(produced, pl.Series("r", list(range(size)), index_dtype))
328
329