Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/test_cast.py
8424 views
1
from __future__ import annotations
2
3
import operator
4
from datetime import date, datetime, time, timedelta
5
from decimal import Decimal
6
from typing import TYPE_CHECKING, Any
7
8
import pytest
9
10
import polars as pl
11
from polars._utils.constants import MS_PER_SECOND, NS_PER_SECOND, US_PER_SECOND
12
from polars.exceptions import ComputeError, InvalidOperationError
13
from polars.testing import assert_frame_equal
14
from polars.testing.asserts.series import assert_series_equal
15
from tests.unit.conftest import INTEGER_DTYPES, NUMERIC_DTYPES
16
17
if TYPE_CHECKING:
18
from collections.abc import Callable
19
20
from polars._typing import PolarsDataType, PythonDataType
21
22
23
@pytest.mark.parametrize("dtype", [pl.Date(), pl.Date, date])
def test_string_date(dtype: PolarsDataType | PythonDataType) -> None:
    """Cast an ISO date string to Date using each accepted dtype spelling."""
    source = pl.DataFrame({"x1": ["2021-01-01"]})
    casted = source.with_columns(pl.col("x1").cast(dtype).alias("x1-date"))

    out = casted.select("x1-date")
    expected = pl.DataFrame({"x1-date": [date(2021, 1, 1)]})
    assert_frame_equal(expected, out)
31
32
33
def test_invalid_string_date() -> None:
    """Casting a malformed date string to Date must raise."""
    frame = pl.DataFrame({"x1": ["2021-01-aa"]})
    # Building the expression is lazy; only the frame evaluation can raise.
    bad_cast = pl.col("x1").cast(pl.Date).alias("x1-date")

    with pytest.raises(InvalidOperationError):
        frame.with_columns(bad_cast)
38
39
40
def test_string_datetime() -> None:
    """Cast ISO datetime strings to Datetime in ns, ms and us resolution."""
    units = ["ns", "ms", "us"]
    names = [f"x1-datetime-{unit}" for unit in units]

    raw = pl.DataFrame({"x1": ["2021-12-19T00:39:57", "2022-12-19T16:39:57"]})
    casted = raw.with_columns(
        [
            pl.col("x1").cast(pl.Datetime(time_unit=unit)).alias(name)  # type: ignore[arg-type]
            for unit, name in zip(units, names)
        ]
    )

    rows = [
        datetime(year=2021, month=12, day=19, hour=0, minute=39, second=57),
        datetime(year=2022, month=12, day=19, hour=16, minute=39, second=57),
    ]
    # Same values in every column; only the time unit differs per column.
    expected = pl.DataFrame(dict.fromkeys(names, rows)).select(
        [
            pl.col(name).dt.cast_time_unit(unit)  # type: ignore[arg-type]
            for unit, name in zip(units, names)
        ]
    )

    out = casted.select(names)
    assert_frame_equal(expected, out)
68
69
70
def test_invalid_string_datetime() -> None:
    """Casting these strings to Datetime('ns') must raise InvalidOperationError."""
    frame = pl.DataFrame({"x1": ["2021-12-19 00:39:57", "2022-12-19 16:39:57"]})
    to_ns = pl.col("x1").cast(pl.Datetime(time_unit="ns")).alias("x1-datetime-ns")

    with pytest.raises(InvalidOperationError):
        frame.with_columns(to_ns)
76
77
78
def test_string_datetime_timezone() -> None:
    """Cast offset-carrying strings to zone-aware Datetime in each time unit."""
    # One (time unit -> target zone) pair per output column.
    zone_by_unit = {"ns": "America/Caracas", "ms": "America/Santiago", "us": "UTC"}
    # Naive values that are given the column's zone via `replace_time_zone`
    # below before being compared with the cast result.
    stamps_by_unit = {
        "ns": [
            datetime(year=1996, month=12, day=19, hour=12, minute=39, second=57),
            datetime(year=2022, month=12, day=18, hour=20, minute=39, second=57),
        ],
        "ms": [
            datetime(year=1996, month=12, day=19, hour=13, minute=39, second=57),
            datetime(year=2022, month=12, day=18, hour=21, minute=39, second=57),
        ],
        "us": [
            datetime(year=1996, month=12, day=19, hour=16, minute=39, second=57),
            datetime(year=2022, month=12, day=19, hour=0, minute=39, second=57),
        ],
    }

    raw = pl.DataFrame(
        {"x1": ["1996-12-19T16:39:57 +00:00", "2022-12-19T00:39:57 +00:00"]}
    )
    casted = raw.with_columns(
        [
            pl.col("x1")
            .cast(pl.Datetime(time_unit=unit, time_zone=zone))  # type: ignore[arg-type]
            .alias(f"x1-datetime-{unit}")
            for unit, zone in zone_by_unit.items()
        ]
    )

    expected = pl.DataFrame(
        {f"x1-datetime-{unit}": stamps for unit, stamps in stamps_by_unit.items()}
    ).select(
        [
            pl.col(f"x1-datetime-{unit}")
            .dt.cast_time_unit(unit)  # type: ignore[arg-type]
            .dt.replace_time_zone(zone)
            for unit, zone in zone_by_unit.items()
        ]
    )

    out = casted.select([f"x1-datetime-{unit}" for unit in zone_by_unit])

    assert_frame_equal(expected, out)
124
125
126
@pytest.mark.parametrize("dtype", [pl.Int8, pl.Int16, pl.Int32, pl.Int64])
def test_leading_plus_zero_int(dtype: pl.DataType) -> None:
    """Signs and leading zeros are accepted when parsing signed integers."""
    # (input text, parsed value) pairs, kept side by side for readability.
    cases = [
        ("-000000000000002", -2),
        ("-1", -1),
        ("-0", 0),
        ("0", 0),
        ("+0", 0),
        ("1", 1),
        ("+1", 1),
        ("0000000000000000000002", 2),
        ("+000000000000000000003", 3),
    ]
    texts = [text for text, _ in cases]
    numbers = [number for _, number in cases]

    assert_series_equal(pl.Series(texts).cast(dtype), pl.Series(numbers, dtype=dtype))
144
145
146
@pytest.mark.parametrize("dtype", [pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64])
def test_leading_plus_zero_uint(dtype: pl.DataType) -> None:
    """A leading plus sign and leading zeros are accepted for unsigned integers."""
    cases = [
        ("0", 0),
        ("+0", 0),
        ("1", 1),
        ("+1", 1),
        ("0000000000000000000002", 2),
        ("+000000000000000000003", 3),
    ]
    texts = [text for text, _ in cases]
    numbers = [number for _, number in cases]

    assert_series_equal(pl.Series(texts).cast(dtype), pl.Series(numbers, dtype=dtype))
152
153
154
@pytest.mark.parametrize("dtype", [pl.Float32, pl.Float64])
def test_leading_plus_zero_float(dtype: pl.DataType) -> None:
    """Signs, leading zeros and bare decimal points are accepted for floats."""
    cases = [
        ("-000000000000002.0", -2.0),
        ("-1.0", -1.0),
        ("-.5", -0.5),
        ("-0.0", 0.0),
        ("0.", 0.0),
        ("+0", 0.0),
        ("+.5", 0.5),
        ("1", 1.0),
        ("+1", 1.0),
        ("0000000000000000000002", 2.0),
        ("+000000000000000000003", 3.0),
    ]
    texts = [text for text, _ in cases]
    numbers = [number for _, number in cases]

    assert_series_equal(
        pl.Series(texts).cast(dtype),
        pl.Series(numbers, dtype=dtype),
    )
177
178
179
def _cast_series(
    val: int | datetime | date | time | timedelta,
    dtype_in: PolarsDataType,
    dtype_out: PolarsDataType,
    strict: bool,
) -> int | datetime | date | time | timedelta | None:
    """Cast a one-element Series named "a" and return the resulting scalar."""
    source = pl.Series("a", [val], dtype=dtype_in)
    casted = source.cast(dtype_out, strict=strict)
    return casted.item()  # type: ignore[no-any-return]
186
187
188
def _cast_expr(
    val: int | datetime | date | time | timedelta,
    dtype_in: PolarsDataType,
    dtype_out: PolarsDataType,
    strict: bool,
) -> int | datetime | date | time | timedelta | None:
    """Cast a one-row column through the expression API and return the scalar."""
    frame = pl.Series("a", [val], dtype=dtype_in).to_frame()
    casted = frame.select(pl.col("a").cast(dtype_out, strict=strict))
    return casted.item()  # type: ignore[no-any-return]
200
201
202
def _cast_lit(
    val: int | datetime | date | time | timedelta,
    dtype_in: PolarsDataType,
    dtype_out: PolarsDataType,
    strict: bool,
) -> int | datetime | date | time | timedelta | None:
    """Cast a typed literal inside `pl.select` and return the resulting scalar."""
    literal = pl.lit(val, dtype=dtype_in)
    casted = pl.select(literal.cast(dtype_out, strict=strict))
    return casted.item()  # type: ignore[no-any-return]
209
210
211
@pytest.mark.parametrize(
    ("value", "from_dtype", "to_dtype", "should_succeed", "expected_value"),
    [
        (-1, pl.Int8, pl.UInt8, False, None),
        (-1, pl.Int16, pl.UInt16, False, None),
        (-1, pl.Int32, pl.UInt32, False, None),
        (-1, pl.Int64, pl.UInt64, False, None),
        (2**7, pl.UInt8, pl.Int8, False, None),
        (2**15, pl.UInt16, pl.Int16, False, None),
        (2**31, pl.UInt32, pl.Int32, False, None),
        (2**63, pl.UInt64, pl.Int64, False, None),
        (2**7 - 1, pl.UInt8, pl.Int8, True, 2**7 - 1),
        (2**15 - 1, pl.UInt16, pl.Int16, True, 2**15 - 1),
        (2**31 - 1, pl.UInt32, pl.Int32, True, 2**31 - 1),
        (2**63 - 1, pl.UInt64, pl.Int64, True, 2**63 - 1),
    ],
)
def test_strict_cast_int(
    value: int,
    from_dtype: PolarsDataType,
    to_dtype: PolarsDataType,
    should_succeed: bool,
    expected_value: Any,
) -> None:
    """Strict signed/unsigned casts raise out of range and succeed in range.

    The same expectation is checked through the Series, expression and
    literal code paths, in that order.
    """
    for cast_fn in (_cast_series, _cast_expr, _cast_lit):
        if should_succeed:
            assert cast_fn(value, from_dtype, to_dtype, True) == expected_value
        else:
            with pytest.raises(InvalidOperationError):
                cast_fn(value, from_dtype, to_dtype, True)
247
248
249
@pytest.mark.parametrize(
    ("value", "from_dtype", "to_dtype", "expected_value"),
    [
        (-1, pl.Int8, pl.UInt8, None),
        (-1, pl.Int16, pl.UInt16, None),
        (-1, pl.Int32, pl.UInt32, None),
        (-1, pl.Int64, pl.UInt64, None),
        (2**7, pl.UInt8, pl.Int8, None),
        (2**15, pl.UInt16, pl.Int16, None),
        (2**31, pl.UInt32, pl.Int32, None),
        (2**63, pl.UInt64, pl.Int64, None),
        (2**7 - 1, pl.UInt8, pl.Int8, 2**7 - 1),
        (2**15 - 1, pl.UInt16, pl.Int16, 2**15 - 1),
        (2**31 - 1, pl.UInt32, pl.Int32, 2**31 - 1),
        (2**63 - 1, pl.UInt64, pl.Int64, 2**63 - 1),
    ],
)
def test_cast_int(
    value: int,
    from_dtype: PolarsDataType,
    to_dtype: PolarsDataType,
    expected_value: Any,
) -> None:
    """Non-strict integer casts yield None out of range and the value in range.

    Checked through the Series, expression and literal code paths in turn.
    """
    for cast_fn in (_cast_series, _cast_expr, _cast_lit):
        assert cast_fn(value, from_dtype, to_dtype, False) == expected_value
276
277
278
def _cast_series_t(
    val: int | datetime | date | time | timedelta,
    dtype_in: PolarsDataType,
    dtype_out: PolarsDataType,
    strict: bool,
) -> pl.Series:
    """Cast a one-element Series named "a" and return the cast Series."""
    source = pl.Series("a", [val], dtype=dtype_in)
    return source.cast(dtype_out, strict=strict)
285
286
287
def _cast_expr_t(
    val: int | datetime | date | time | timedelta,
    dtype_in: PolarsDataType,
    dtype_out: PolarsDataType,
    strict: bool,
) -> pl.Series:
    """Cast a one-row column through the expression API and return it as a Series."""
    frame = pl.Series("a", [val], dtype=dtype_in).to_frame()
    cast_expr = pl.col("a").cast(dtype_out, strict=strict)
    return frame.select(cast_expr).to_series()
299
300
301
def _cast_lit_t(
    val: int | datetime | date | time | timedelta,
    dtype_in: PolarsDataType,
    dtype_out: PolarsDataType,
    strict: bool,
) -> pl.Series:
    """Cast a typed literal inside `pl.select` and return the result as a Series."""
    cast_expr = pl.lit(val, dtype=dtype_in).cast(dtype_out, strict=strict)
    return pl.select(cast_expr).to_series()
310
311
312
@pytest.mark.parametrize(
    (
        "value",
        "from_dtype",
        "to_dtype",
        "should_succeed",
        "expected_value",
    ),
    [
        # date to datetime
        (date(1970, 1, 1), pl.Date, pl.Datetime("ms"), True, datetime(1970, 1, 1)),
        (date(1970, 1, 1), pl.Date, pl.Datetime("us"), True, datetime(1970, 1, 1)),
        (date(1970, 1, 1), pl.Date, pl.Datetime("ns"), True, datetime(1970, 1, 1)),
        # datetime to date
        (datetime(1970, 1, 1), pl.Datetime("ms"), pl.Date, True, date(1970, 1, 1)),
        (datetime(1970, 1, 1), pl.Datetime("us"), pl.Date, True, date(1970, 1, 1)),
        (datetime(1970, 1, 1), pl.Datetime("ns"), pl.Date, True, date(1970, 1, 1)),
        # datetime to time
        (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("ms"), pl.Time, True, time(hour=1)),
        (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("us"), pl.Time, True, time(hour=1)),
        (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("ns"), pl.Time, True, time(hour=1)),
        # duration to int
        (timedelta(seconds=1), pl.Duration("ms"), pl.Int32, True, MS_PER_SECOND),
        (timedelta(seconds=1), pl.Duration("us"), pl.Int64, True, US_PER_SECOND),
        (timedelta(seconds=1), pl.Duration("ns"), pl.Int64, True, NS_PER_SECOND),
        # time to duration
        (time(hour=1), pl.Time, pl.Duration("ms"), True, timedelta(hours=1)),
        (time(hour=1), pl.Time, pl.Duration("us"), True, timedelta(hours=1)),
        (time(hour=1), pl.Time, pl.Duration("ns"), True, timedelta(hours=1)),
        # int to date
        (100, pl.UInt8, pl.Date, True, date(1970, 4, 11)),
        (100, pl.UInt16, pl.Date, True, date(1970, 4, 11)),
        (100, pl.UInt32, pl.Date, True, date(1970, 4, 11)),
        (100, pl.UInt64, pl.Date, True, date(1970, 4, 11)),
        (100, pl.Int8, pl.Date, True, date(1970, 4, 11)),
        (100, pl.Int16, pl.Date, True, date(1970, 4, 11)),
        (100, pl.Int32, pl.Date, True, date(1970, 4, 11)),
        (100, pl.Int64, pl.Date, True, date(1970, 4, 11)),
        # failures
        (2**63 - 1, pl.Int64, pl.Date, False, None),
        (-(2**62), pl.Int64, pl.Date, False, None),
        (date(1970, 5, 10), pl.Date, pl.Int8, False, None),
        (date(2149, 6, 7), pl.Date, pl.Int16, False, None),
        (datetime(9999, 12, 31), pl.Datetime, pl.Int8, False, None),
        (datetime(9999, 12, 31), pl.Datetime, pl.Int16, False, None),
    ],
)
def test_strict_cast_temporal(
    value: int | datetime | date | time | timedelta,
    from_dtype: PolarsDataType,
    to_dtype: PolarsDataType,
    should_succeed: bool,
    expected_value: Any,
) -> None:
    """Strict temporal casts return the expected value and dtype, or raise.

    Each case is run through the Series, expression and literal helpers in
    that order. `value` is annotated with the helpers' own input union, since
    the parametrization passes dates, datetimes, times and timedeltas as well
    as ints.
    """
    for caster in (_cast_series_t, _cast_expr_t, _cast_lit_t):
        if should_succeed:
            out = caster(value, from_dtype, to_dtype, True)
            assert out.item() == expected_value
            assert out.dtype == to_dtype
        else:
            with pytest.raises(InvalidOperationError):
                caster(value, from_dtype, to_dtype, True)
384
385
386
@pytest.mark.parametrize(
    (
        "value",
        "from_dtype",
        "to_dtype",
        "expected_value",
    ),
    [
        # date to datetime
        (date(1970, 1, 1), pl.Date, pl.Datetime("ms"), datetime(1970, 1, 1)),
        (date(1970, 1, 1), pl.Date, pl.Datetime("us"), datetime(1970, 1, 1)),
        (date(1970, 1, 1), pl.Date, pl.Datetime("ns"), datetime(1970, 1, 1)),
        # datetime to date
        (datetime(1970, 1, 1), pl.Datetime("ms"), pl.Date, date(1970, 1, 1)),
        (datetime(1970, 1, 1), pl.Datetime("us"), pl.Date, date(1970, 1, 1)),
        (datetime(1970, 1, 1), pl.Datetime("ns"), pl.Date, date(1970, 1, 1)),
        # datetime to time
        (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("ms"), pl.Time, time(hour=1)),
        (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("us"), pl.Time, time(hour=1)),
        (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("ns"), pl.Time, time(hour=1)),
        # duration to int
        (timedelta(seconds=1), pl.Duration("ms"), pl.Int32, MS_PER_SECOND),
        (timedelta(seconds=1), pl.Duration("us"), pl.Int64, US_PER_SECOND),
        (timedelta(seconds=1), pl.Duration("ns"), pl.Int64, NS_PER_SECOND),
        # time to duration
        (time(hour=1), pl.Time, pl.Duration("ms"), timedelta(hours=1)),
        (time(hour=1), pl.Time, pl.Duration("us"), timedelta(hours=1)),
        (time(hour=1), pl.Time, pl.Duration("ns"), timedelta(hours=1)),
        # int to date
        (100, pl.UInt8, pl.Date, date(1970, 4, 11)),
        (100, pl.UInt16, pl.Date, date(1970, 4, 11)),
        (100, pl.UInt32, pl.Date, date(1970, 4, 11)),
        (100, pl.UInt64, pl.Date, date(1970, 4, 11)),
        (100, pl.Int8, pl.Date, date(1970, 4, 11)),
        (100, pl.Int16, pl.Date, date(1970, 4, 11)),
        (100, pl.Int32, pl.Date, date(1970, 4, 11)),
        (100, pl.Int64, pl.Date, date(1970, 4, 11)),
        # failures
        (2**63 - 1, pl.Int64, pl.Date, None),
        (-(2**62), pl.Int64, pl.Date, None),
        (date(1970, 5, 10), pl.Date, pl.Int8, None),
        (date(2149, 6, 7), pl.Date, pl.Int16, None),
        (datetime(9999, 12, 31), pl.Datetime, pl.Int8, None),
        (datetime(9999, 12, 31), pl.Datetime, pl.Int16, None),
    ],
)
def test_cast_temporal(
    value: int | datetime | date | time | timedelta,
    from_dtype: PolarsDataType,
    to_dtype: PolarsDataType,
    expected_value: Any,
) -> None:
    """Non-strict temporal casts yield the expected value, or null on failure.

    Each case is run through the Series, expression and literal helpers in
    that order; the output dtype is checked in both the success and the null
    case. `value` is annotated with the helpers' own input union, since the
    parametrization passes more than ints.
    """
    for caster in (_cast_series_t, _cast_expr_t, _cast_lit_t):
        out = caster(value, from_dtype, to_dtype, False)
        if expected_value is None:
            assert out.item() is None
        else:
            assert out.item() == expected_value
        assert out.dtype == to_dtype
459
460
461
@pytest.mark.parametrize(
    (
        "value",
        "from_dtype",
        "to_dtype",
        "expected_value",
    ),
    [
        (str(2**7 - 1), pl.String, pl.Int8, 2**7 - 1),
        (str(2**15 - 1), pl.String, pl.Int16, 2**15 - 1),
        (str(2**31 - 1), pl.String, pl.Int32, 2**31 - 1),
        (str(2**63 - 1), pl.String, pl.Int64, 2**63 - 1),
        ("1.0", pl.String, pl.Float32, 1.0),
        ("1.0", pl.String, pl.Float64, 1.0),
        # overflow
        (str(2**7), pl.String, pl.Int8, None),
        (str(2**15), pl.String, pl.Int16, None),
        (str(2**31), pl.String, pl.Int32, None),
        (str(2**63), pl.String, pl.Int64, None),
    ],
)
def test_cast_string(
    value: str,
    from_dtype: PolarsDataType,
    to_dtype: PolarsDataType,
    expected_value: Any,
) -> None:
    """Non-strict string-to-number casts yield the value, or null on overflow.

    Each case is run through the Series, expression and literal helpers in
    that order; the output dtype is checked in both outcomes. `value` is
    annotated as `str` because every parametrized value is a string.
    """
    for caster in (_cast_series_t, _cast_expr_t, _cast_lit_t):
        # The helpers' `val` annotation does not include `str`.
        out = caster(value, from_dtype, to_dtype, False)  # type: ignore[arg-type]
        if expected_value is None:
            assert out.item() is None
        else:
            assert out.item() == expected_value
        assert out.dtype == to_dtype
509
510
511
@pytest.mark.parametrize(
    (
        "value",
        "from_dtype",
        "to_dtype",
        "should_succeed",
        "expected_value",
    ),
    [
        (str(2**7 - 1), pl.String, pl.Int8, True, 2**7 - 1),
        (str(2**15 - 1), pl.String, pl.Int16, True, 2**15 - 1),
        (str(2**31 - 1), pl.String, pl.Int32, True, 2**31 - 1),
        (str(2**63 - 1), pl.String, pl.Int64, True, 2**63 - 1),
        ("1.0", pl.String, pl.Float32, True, 1.0),
        ("1.0", pl.String, pl.Float64, True, 1.0),
        # overflow
        (str(2**7), pl.String, pl.Int8, False, None),
        (str(2**15), pl.String, pl.Int16, False, None),
        (str(2**31), pl.String, pl.Int32, False, None),
        (str(2**63), pl.String, pl.Int64, False, None),
    ],
)
def test_strict_cast_string(
    value: str,
    from_dtype: PolarsDataType,
    to_dtype: PolarsDataType,
    should_succeed: bool,
    expected_value: Any,
) -> None:
    """Strict string-to-number casts return the value and dtype, or raise.

    Each case is run through the Series, expression and literal helpers in
    that order. `value` is annotated as `str` because every parametrized
    value is a string.
    """
    for caster in (_cast_series_t, _cast_expr_t, _cast_lit_t):
        if should_succeed:
            # The helpers' `val` annotation does not include `str`.
            out = caster(value, from_dtype, to_dtype, True)  # type: ignore[arg-type]
            assert out.item() == expected_value
            assert out.dtype == to_dtype
        else:
            with pytest.raises(InvalidOperationError):
                caster(value, from_dtype, to_dtype, True)  # type: ignore[arg-type]
558
559
560
@pytest.mark.parametrize("dtype_in", [pl.Categorical, pl.Enum(["1"])])
@pytest.mark.parametrize(
    "dtype_out",
    [pl.String, pl.Categorical, pl.Enum(["1", "2"])],
)
def test_cast_categorical_name_retention(
    dtype_in: PolarsDataType, dtype_out: PolarsDataType
) -> None:
    """The Series name is kept when casting from Categorical or Enum."""
    source = pl.Series("a", ["1"], dtype=dtype_in)
    casted = source.cast(dtype_out)
    assert casted.name == "a"
576
577
578
def test_cast_date_to_time() -> None:
    """Casting Date to Time is rejected with an explicit message."""
    dates = pl.Series([date(1970, 1, 1), date(2000, 12, 31)])
    with pytest.raises(
        InvalidOperationError, match="casting from Date to Time not supported"
    ):
        dates.cast(pl.Time)
583
584
585
def test_cast_time_to_date() -> None:
    """Casting Time to Date is rejected with an explicit message."""
    times = pl.Series([time(0, 0), time(20, 0)])
    with pytest.raises(
        InvalidOperationError, match="casting from Time to Date not supported"
    ):
        times.cast(pl.Date)
590
591
592
def test_cast_decimal_to_boolean() -> None:
    """Decimal casts to Boolean: zero is False, the non-zero values are True."""
    decimals = pl.Series("s", [Decimal("0.0"), Decimal("1.5"), Decimal("-1.5")])
    truth = [False, True, True]

    # Series API.
    assert_series_equal(decimals.cast(pl.Boolean), pl.Series("s", truth))

    # Expression API on the same data.
    via_expr = decimals.to_frame().select(pl.col("s").cast(pl.Boolean))
    assert_frame_equal(via_expr, pl.DataFrame({"s": truth}))
601
602
603
def test_cast_array_to_different_width() -> None:
    """An Array cannot be cast to an Array of another width."""
    width_two = pl.Series([[1, 2], [3, 4]], dtype=pl.Array(pl.Int8, 2))
    width_three = pl.Array(pl.Int16, 3)

    msg = "cannot cast Array to a different width"
    with pytest.raises(InvalidOperationError, match=msg):
        width_two.cast(width_three)
609
610
611
def test_cast_decimal_to_decimal_high_precision() -> None:
    """A 22-digit Decimal survives a cast to an explicit precision of 22."""
    digits = 22
    all_nines = [Decimal("9" * digits)]
    source = pl.Series(all_nines, dtype=pl.Decimal(None, 0))

    wanted = pl.Decimal(digits, 0)
    casted = source.cast(wanted)

    assert casted.dtype == wanted
    assert casted.to_list() == all_nines
621
622
623
@pytest.mark.parametrize("value", [float("inf"), float("nan")])
def test_invalid_cast_float_to_decimal(value: float) -> None:
    """Infinity and NaN cannot be cast from Float64 to Decimal(10, 2)."""
    non_finite = pl.Series([value], dtype=pl.Float64)
    msg = r"conversion from `f64` to `decimal\[10,2\]` failed"
    with pytest.raises(InvalidOperationError, match=msg):
        non_finite.cast(pl.Decimal(10, 2))
631
632
633
def test_err_on_time_datetime_cast() -> None:
    """Casting Time to Datetime raises and the message points to `dt.combine`."""
    times = pl.Series([time(10, 0, 0), time(11, 30, 59)])
    msg = r"casting from Time to Datetime\('μs'\) not supported; consider using `dt\.combine`"
    with pytest.raises(InvalidOperationError, match=msg):
        times.cast(pl.Datetime)
640
641
642
def test_err_on_invalid_time_zone_cast() -> None:
    """Casting to a Datetime with an unparseable time zone raises ComputeError."""
    stamps = pl.Series([datetime(2021, 1, 1)])
    bogus_zone = pl.Datetime("us", "qwerty")
    msg = r"unable to parse time zone: 'qwerty'"
    with pytest.raises(ComputeError, match=msg):
        stamps.cast(bogus_zone)
646
647
648
def test_invalid_inner_type_cast_list() -> None:
    """A List of Int64 cannot be cast to a List of Categorical."""
    nested_ints = pl.Series([[-1, 1]])
    msg = r"cannot cast List inner type: 'Int64' to Categorical"
    with pytest.raises(InvalidOperationError, match=msg):
        nested_ints.cast(pl.List(pl.Categorical))
655
656
657
@pytest.mark.parametrize(
    ("values", "result"),
    [
        ([[]], [b""]),
        ([[1, 2], [3, 4]], [b"\x01\x02", b"\x03\x04"]),
        ([[1, 2], None, [3, 4]], [b"\x01\x02", None, b"\x03\x04"]),
        (
            [None, [111, 110, 101], [12, None], [116, 119, 111], list(range(256))],
            [
                None,
                b"one",
                # A list with a null in it gets turned into a null:
                None,
                b"two",
                bytes(range(256)),
            ],
        ),
    ],
)
def test_list_uint8_to_bytes(
    values: list[list[int | None] | None], result: list[bytes | None]
) -> None:
    """Non-strict cast of List(UInt8) to Binary gives one bytes value per list."""
    byte_lists = pl.Series(values, dtype=pl.List(pl.UInt8()))
    as_binary = byte_lists.cast(pl.Binary(), strict=False)
    assert as_binary.to_list() == result
684
685
686
def test_list_uint8_to_bytes_strict() -> None:
    """Strict List(UInt8) -> Binary works without nulls and raises with one."""
    list_u8 = pl.List(pl.UInt8())

    # No inner nulls: the strict cast succeeds.
    clean = pl.Series([[1, 2], [3, 4]], dtype=list_u8)
    assert clean.cast(pl.Binary(), strict=True).to_list() == [b"\x01\x02", b"\x03\x04"]

    # A null inside a list: the strict cast raises, naming the column and value.
    with_null = pl.Series("mycol", [[1, 2], [3, None]], dtype=list_u8)
    with pytest.raises(
        InvalidOperationError,
        match="conversion from `list\\[u8\\]` to `binary` failed in column 'mycol' for 1 out of 2 values: \\[\\[3, null\\]\\]",
    ):
        with_null.cast(pl.Binary(), strict=True)
703
704
705
def test_all_null_cast_5826() -> None:
    """An all-null String column casts to Boolean and keeps its null."""
    null_strings = pl.Series("a", [None], dtype=pl.String)
    frame = pl.DataFrame(data=[null_strings])

    casted = frame.with_columns(pl.col("a").cast(pl.Boolean))

    assert casted.dtypes == [pl.Boolean]
    assert casted.item() is None
710
711
712
@pytest.mark.parametrize("dtype", INTEGER_DTYPES)
def test_bool_numeric_supertype(dtype: PolarsDataType) -> None:
    """A summed boolean cast to any integer dtype divides to a float ratio.

    Two of the six values are below 3, so the ratio must be one third.
    """
    df = pl.DataFrame({"v": [1, 2, 3, 4, 5, 6]})
    result = df.select((pl.col("v") < 3).sum().cast(dtype) / pl.len())
    # Compare the absolute difference: without `abs`, any result that is too
    # small (even 0 or a negative number) would satisfy the tolerance check.
    assert abs(result.item() - 0.3333333) <= 0.00001
717
718
719
@pytest.mark.parametrize("dtype", [pl.String(), pl.String, str])
def test_cast_consistency(dtype: PolarsDataType | PythonDataType) -> None:
    """A column and a literal holding 0.0 cast to the same string."""
    base = pl.DataFrame().with_columns(a=pl.lit(0.0))
    casted = base.with_columns(
        b=pl.col("a").cast(dtype),
        c=pl.lit(0.0).cast(dtype),
    )
    expected = {"a": [0.0], "b": ["0.0"], "c": ["0.0"]}
    assert casted.to_dict(as_series=False) == expected
724
725
726
def test_cast_int_to_string_unsets_sorted_flag_19424() -> None:
    """The ascending-sorted flag is set on the ints and absent after the String cast."""
    sorted_ints = pl.Series([1, 2]).set_sorted()
    assert sorted_ints.flags["SORTED_ASC"]

    as_strings = sorted_ints.cast(pl.String)
    assert not as_strings.flags["SORTED_ASC"]
730
731
732
def test_cast_integer_to_decimal() -> None:
    """Integers cast to Decimal(10, 2) gain two fractional digits."""
    target = pl.Decimal(10, 2)
    casted = pl.Series([1, 2, 3]).cast(target)

    two_places = [Decimal("1.00"), Decimal("2.00"), Decimal("3.00")]
    assert_series_equal(casted, pl.Series("", two_places, target))
739
740
741
def test_cast_python_dtypes() -> None:
    """Python builtin types passed to `cast` map to the matching Polars dtypes."""
    ints = pl.Series([0, 1])
    # Checked in the same order as listed: int, float, bool, str.
    for py_type, pl_dtype in (
        (int, pl.Int64),
        (float, pl.Float64),
        (bool, pl.Boolean),
        (str, pl.String),
    ):
        assert ints.cast(py_type).dtype == pl_dtype
747
748
749
def test_overflowing_cast_literals_21023() -> None:
    """A wrapping cast of literal 128 to Int8 gives -128 with and without optimizations."""
    wrapped = pl.Series([-128], dtype=pl.Int8).to_frame()

    for optimizations in [pl.QueryOptFlags(), pl.QueryOptFlags.none()]:
        literal = pl.lit(pl.Series([128], dtype=pl.Int64))
        query = pl.LazyFrame().select(literal.cast(pl.Int8, wrap_numerical=True))
        assert_frame_equal(query.collect(optimizations=optimizations), wrapped)
763
764
765
@pytest.mark.parametrize("value", [True, False])
@pytest.mark.parametrize(
    "dtype",
    [
        pl.Enum(["a", "b"]),
        pl.Series(["a", "b"], dtype=pl.Categorical).dtype,
    ],
)
def test_invalid_bool_to_cat(value: bool, dtype: PolarsDataType) -> None:
    """Casting a Boolean to either an Enum or a Categorical dtype must raise."""
    flag = pl.Series([value])
    msg = "cannot cast Boolean to Categorical"
    with pytest.raises(InvalidOperationError, match=msg):
        flag.cast(dtype)
780
781
782
@pytest.mark.parametrize(
    ("values", "from_dtype", "to_dtype", "pre_apply"),
    [
        ([["A"]], pl.List(pl.String), pl.List(pl.Int8), None),
        ([["A"]], pl.Array(pl.String, 1), pl.List(pl.Int8), None),
        ([[["A"]]], pl.List(pl.List(pl.String)), pl.List(pl.List(pl.Int8)), None),
        (
            [
                {"x": "1", "y": "2"},
                {"x": "A", "y": "B"},
                {"x": "3", "y": "4"},
                {"x": "X", "y": "Y"},
                {"x": "5", "y": "6"},
            ],
            pl.Struct({"x": pl.String, "y": pl.String}),
            pl.Struct({"x": pl.Int8, "y": pl.Int32}),
            None,
        ),
    ],
)
def test_nested_strict_casts_failing(
    values: list[Any],
    from_dtype: pl.DataType,
    to_dtype: pl.DataType,
    pre_apply: Callable[[pl.Series], pl.Series] | None,
) -> None:
    """Strict casts of nested data containing unparseable strings must raise."""
    series = pl.Series(values, dtype=from_dtype)
    # Optional transformation applied before the cast; every case above passes None.
    prepared = series if pre_apply is None else pre_apply(series)

    with pytest.raises(InvalidOperationError, match=r"conversion from"):
        prepared.cast(to_dtype)
828
829
830
@pytest.mark.parametrize(
    ("values", "from_dtype", "pre_apply", "to"),
    [
        (
            [["A"], ["1"], ["2"]],
            pl.List(pl.String),
            lambda s: s.slice(1, 2),
            pl.Series([[1], [2]]),
        ),
        (
            [["1"], ["A"], ["2"], ["B"], ["3"]],
            pl.List(pl.String),
            lambda s: s.filter(pl.Series([True, False, True, False, True])),
            pl.Series([[1], [2], [3]]),
        ),
        (
            [
                {"x": "1", "y": "2"},
                {"x": "A", "y": "B"},
                {"x": "3", "y": "4"},
                {"x": "X", "y": "Y"},
                {"x": "5", "y": "6"},
            ],
            pl.Struct({"x": pl.String, "y": pl.String}),
            lambda s: s.filter(pl.Series([True, False, True, False, True])),
            pl.Series(
                [
                    {"x": 1, "y": 2},
                    {"x": 3, "y": 4},
                    {"x": 5, "y": 6},
                ]
            ),
        ),
        (
            [
                {"x": "1", "y": "2"},
                {"x": "A", "y": "B"},
                {"x": "3", "y": "4"},
                {"x": "X", "y": "Y"},
                {"x": "5", "y": "6"},
            ],
            pl.Struct({"x": pl.String, "y": pl.String}),
            lambda s: pl.select(
                pl.when(pl.Series([True, False, True, False, True])).then(s)
            ).to_series(),
            pl.Series(
                [
                    {"x": 1, "y": 2},
                    None,
                    {"x": 3, "y": 4},
                    None,
                    {"x": 5, "y": 6},
                ]
            ),
        ),
    ],
)
def test_nested_strict_casts_succeeds(
    values: list[Any],
    from_dtype: pl.DataType,
    pre_apply: Callable[[pl.Series], pl.Series] | None,
    to: pl.Series,
) -> None:
    """Strict nested casts succeed once `pre_apply` has removed the bad rows.

    Each case slices, filters or masks the unparseable entries away before
    casting to the dtype of the expected Series `to`.
    """
    series = pl.Series(values, dtype=from_dtype)
    prepared = series if pre_apply is None else pre_apply(series)

    casted = prepared.cast(to.dtype)
    assert_series_equal(casted, to)
912
913
914
def test_nested_struct_cast_22744() -> None:
    """Casting a nested struct to a wider schema adds the new field as null."""
    source = pl.Series("x", [{"attrs": {"class": "a"}}])
    # Target schema: the inner struct gains an extra String field "other".
    wider = pl.Struct({"attrs": pl.Struct({"class": pl.String, "other": pl.String})})

    # Build the expected frame by adding the null field explicitly.
    inner = pl.field("attrs").struct.with_fields(
        [pl.field("class"), pl.lit(None, dtype=pl.String()).alias("other")]
    )
    expected = pl.select(pl.lit(source).struct.with_fields(inner))

    # Series-level cast.
    assert_series_equal(source.cast(wider), expected.to_series())
    # Frame-level cast with a per-column mapping.
    assert_frame_equal(pl.DataFrame([source]).cast({"x": wider}), expected)
944
945
946
def test_cast_to_self_is_pruned() -> None:
    """A cast to the column's own dtype shows up in the plan as a plain alias."""
    base = pl.LazyFrame({"x": 1}, schema={"x": pl.Int64})
    query = base.with_columns(y=pl.col("x").cast(pl.Int64))

    assert 'col("x").alias("y")' in query.explain()

    assert_frame_equal(query.collect(), pl.DataFrame({"x": 1, "y": 1}))
955
956
957
@pytest.mark.parametrize(
    ("s", "to", "should_fail"),
    [
        (
            pl.Series([datetime(2025, 1, 1)]),
            pl.Datetime("ns"),
            False,
        ),
        (
            pl.Series([datetime(9999, 1, 1)]),
            pl.Datetime("ns"),
            True,
        ),
        (
            pl.Series([datetime(2025, 1, 1), datetime(9999, 1, 1)]),
            pl.Datetime("ns"),
            True,
        ),
        (
            pl.Series([[datetime(2025, 1, 1)], [datetime(9999, 1, 1)]]),
            pl.List(pl.Datetime("ns")),
            True,
        ),
        # lower date limit for nanosecond
        (pl.Series([date(1677, 9, 22)]), pl.Datetime("ns"), False),
        (pl.Series([date(1677, 9, 21)]), pl.Datetime("ns"), True),
        # upper date limit for nanosecond
        (pl.Series([date(2262, 4, 11)]), pl.Datetime("ns"), False),
        (pl.Series([date(2262, 4, 12)]), pl.Datetime("ns"), True),
    ],
)
def test_cast_temporals_overflow_16039(
    s: pl.Series, to: pl.DataType, should_fail: bool
) -> None:
    """Casts that overflow the nanosecond range raise; in-range casts succeed."""
    if should_fail:
        with pytest.raises(InvalidOperationError, match="conversion from"):
            s.cast(to)
    else:
        # Check the result as well, so the success path asserts something
        # beyond "did not raise".
        assert s.cast(to).dtype == to
998
999
1000
@pytest.mark.parametrize("dtype", NUMERIC_DTYPES)
def test_prune_superfluous_cast(dtype: PolarsDataType) -> None:
    """A cast to the column's existing dtype leaves no `strict_cast` in the plan."""
    same_dtype_cast = pl.col("a").cast(dtype)
    query = pl.LazyFrame({"a": [1, 2, 3]}, schema={"a": dtype}).select(same_dtype_cast)

    plan = query.explain()
    assert "strict_cast" not in plan
1005
1006
1007
def test_not_prune_necessary_cast() -> None:
    """A UInt16 -> UInt8 cast stays in the plan as a `strict_cast`."""
    narrowing_cast = pl.col("a").cast(pl.UInt8)
    query = pl.LazyFrame({"a": [1, 2, 3]}, schema={"a": pl.UInt16}).select(
        narrowing_cast
    )

    plan = query.explain()
    assert "strict_cast" in plan
1011
1012
1013
@pytest.mark.parametrize("target_dtype", NUMERIC_DTYPES)
@pytest.mark.parametrize("inner_dtype", NUMERIC_DTYPES)
@pytest.mark.parametrize("op", [operator.mul, operator.truediv])
def test_cast_optimizer_in_list_eval_23924(
    inner_dtype: PolarsDataType,
    target_dtype: PolarsDataType,
    op: Callable[[pl.Expr, pl.Expr], pl.Expr],
) -> None:
    """The predicted schema matches the collected schema for casts in `list.eval`.

    Runs over the full product of numeric inner/target dtypes and two
    arithmetic operators. The leftover debug `print` was removed; pytest
    already shows the parameters in the test id.
    """
    # Seed with an int or a float so the literal matches the list's inner dtype.
    if target_dtype in INTEGER_DTYPES:
        df = pl.Series("a", [[1]], dtype=pl.List(target_dtype)).to_frame()
    else:
        df = pl.Series("a", [[1.0]], dtype=pl.List(target_dtype)).to_frame()
    q = df.lazy().select(
        pl.col("a").list.eval(
            (op(pl.element(), pl.element().cast(inner_dtype))).cast(target_dtype)
        )
    )
    assert q.collect_schema() == q.collect().schema
1032
1033
1034
def test_lit_cast_arithmetic_23677() -> None:
    """Dividing a Float32 column by an Int32 literal collects as Float64."""
    frame = pl.DataFrame({"a": [1]}, schema={"a": pl.Float32})
    query = frame.lazy().select(pl.col("a") / pl.lit(1, pl.Int32))

    assert query.collect().schema == pl.Schema({"a": pl.Float64})
1039
1040
1041
@pytest.mark.parametrize("col_dtype", NUMERIC_DTYPES + [pl.Unknown])
@pytest.mark.parametrize("lit_dtype", NUMERIC_DTYPES + [pl.Unknown])
@pytest.mark.parametrize("op", [operator.mul, operator.truediv])
def test_lit_cast_arithmetic_matrix_schema(
    col_dtype: PolarsDataType,
    lit_dtype: PolarsDataType,
    op: Callable[[pl.Expr, pl.Expr], pl.Expr],
) -> None:
    """Predicted and collected schemas agree for every column/literal dtype pair."""
    # Note (hacky): simply casting to 'pl.Unknown' would create
    # `Unknown(UnknownKind::Any())` which is not what we want: the
    # default maps to `Unknown(UnknownKind::Int(_)))` so we adjust
    # by omitting the dtype entirely whenever `pl.Unknown` is requested.
    if col_dtype == pl.Unknown:
        frame = pl.DataFrame({"a": [1]})
    else:
        frame = pl.DataFrame({"a": [1]}, schema={"a": col_dtype})

    if lit_dtype == pl.Unknown:
        literal = pl.lit(1)
    else:
        literal = pl.lit(1, lit_dtype)

    query = frame.lazy().select(op(pl.col("a"), literal))
    assert query.collect_schema() == query.collect().schema
1063
1064
1065
def test_strict_cast_nested() -> None:
    """String -> Struct casts raise on bad input when strict, null it otherwise."""
    frame = pl.DataFrame({"a": ["42", "10a"]})
    target = pl.Struct({"x": pl.Int32})

    # "10a" is not an integer, so the strict cast must fail.
    with pytest.raises(InvalidOperationError):
        frame.cast(target, strict=True)

    # The non-strict cast keeps the row and nulls the unparseable field.
    lenient = frame.cast(target, strict=False)
    expected = pl.DataFrame({"a": [{"x": 42}, {"x": None}]}, schema={"a": target})
    assert_frame_equal(lenient, expected)
1075
1076