Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/test_cast.py
6939 views
1
from __future__ import annotations
2
3
import operator
4
from datetime import date, datetime, time, timedelta
5
from decimal import Decimal
6
from typing import TYPE_CHECKING, Any, Callable
7
8
import pytest
9
10
import polars as pl
11
from polars._utils.constants import MS_PER_SECOND, NS_PER_SECOND, US_PER_SECOND
12
from polars.exceptions import ComputeError, InvalidOperationError
13
from polars.testing import assert_frame_equal
14
from polars.testing.asserts.series import assert_series_equal
15
from tests.unit.conftest import INTEGER_DTYPES, NUMERIC_DTYPES
16
17
if TYPE_CHECKING:
18
from polars._typing import PolarsDataType, PythonDataType
19
20
21
@pytest.mark.parametrize("dtype", [pl.Date(), pl.Date, date])
def test_string_date(dtype: PolarsDataType | PythonDataType) -> None:
    """An ISO date string casts to Date whether the target is an instance, class or ``date``."""
    to_date = pl.col("x1").cast(dtype)
    df = pl.DataFrame({"x1": ["2021-01-01"]}).with_columns(**{"x1-date": to_date})

    result = df.select(pl.col("x1-date"))
    assert_frame_equal(pl.DataFrame({"x1-date": [date(2021, 1, 1)]}), result)
29
30
31
def test_invalid_string_date() -> None:
    """A string that is not a valid date fails a strict cast to Date."""
    frame = pl.DataFrame({"x1": ["2021-01-aa"]})
    bad_cast = {"x1-date": pl.col("x1").cast(pl.Date)}

    with pytest.raises(InvalidOperationError):
        frame.with_columns(**bad_cast)
36
37
38
def test_string_datetime() -> None:
    """ISO 8601 'T'-separated strings cast to Datetime in every time unit."""
    x1 = pl.col("x1")
    df = pl.DataFrame(
        {"x1": ["2021-12-19T00:39:57", "2022-12-19T16:39:57"]}
    ).with_columns(
        **{
            "x1-datetime-ns": x1.cast(pl.Datetime(time_unit="ns")),
            "x1-datetime-ms": x1.cast(pl.Datetime(time_unit="ms")),
            "x1-datetime-us": x1.cast(pl.Datetime(time_unit="us")),
        }
    )

    rows = [datetime(2021, 12, 19, 0, 39, 57), datetime(2022, 12, 19, 16, 39, 57)]
    names = ["x1-datetime-ns", "x1-datetime-ms", "x1-datetime-us"]
    # Build the same wall-clock values for every column, then pin each unit.
    expected = pl.DataFrame(dict.fromkeys(names, rows)).select(
        pl.col("x1-datetime-ns").dt.cast_time_unit("ns"),
        pl.col("x1-datetime-ms").dt.cast_time_unit("ms"),
        pl.col("x1-datetime-us").dt.cast_time_unit("us"),
    )

    assert_frame_equal(expected, df.select(names))
66
67
68
def test_invalid_string_datetime() -> None:
    """Space-separated timestamps are rejected by a strict Datetime cast."""
    df = pl.DataFrame({"x1": ["2021-12-19 00:39:57", "2022-12-19 16:39:57"]})
    target = pl.Datetime(time_unit="ns")

    with pytest.raises(InvalidOperationError):
        df.with_columns(**{"x1-datetime-ns": pl.col("x1").cast(target)})
74
75
76
def test_string_datetime_timezone() -> None:
    """UTC-offset strings cast to a tz-aware Datetime show the zone's local time."""
    ccs_tz = "America/Caracas"
    stg_tz = "America/Santiago"
    utc_tz = "UTC"

    x1 = pl.col("x1")
    df = pl.DataFrame(
        {"x1": ["1996-12-19T16:39:57 +00:00", "2022-12-19T00:39:57 +00:00"]}
    ).with_columns(
        **{
            "x1-datetime-ns": x1.cast(pl.Datetime(time_unit="ns", time_zone=ccs_tz)),
            "x1-datetime-ms": x1.cast(pl.Datetime(time_unit="ms", time_zone=stg_tz)),
            "x1-datetime-us": x1.cast(pl.Datetime(time_unit="us", time_zone=utc_tz)),
        }
    )

    # Naive local wall-clock times for the two instants, stamped with each zone below.
    caracas_local = [
        datetime(1996, 12, 19, 12, 39, 57),
        datetime(2022, 12, 18, 20, 39, 57),
    ]
    santiago_local = [
        datetime(1996, 12, 19, 13, 39, 57),
        datetime(2022, 12, 18, 21, 39, 57),
    ]
    utc_local = [
        datetime(1996, 12, 19, 16, 39, 57),
        datetime(2022, 12, 19, 0, 39, 57),
    ]
    expected = pl.DataFrame(
        {
            "x1-datetime-ns": caracas_local,
            "x1-datetime-ms": santiago_local,
            "x1-datetime-us": utc_local,
        }
    ).select(
        pl.col("x1-datetime-ns").dt.cast_time_unit("ns").dt.replace_time_zone(ccs_tz),
        pl.col("x1-datetime-ms").dt.cast_time_unit("ms").dt.replace_time_zone(stg_tz),
        pl.col("x1-datetime-us").dt.cast_time_unit("us").dt.replace_time_zone(utc_tz),
    )

    out = df.select("x1-datetime-ns", "x1-datetime-ms", "x1-datetime-us")
    assert_frame_equal(expected, out)
122
123
124
@pytest.mark.parametrize("dtype", [pl.Int8, pl.Int16, pl.Int32, pl.Int64])
def test_leading_plus_zero_int(dtype: pl.DataType) -> None:
    """Signed integer parsing accepts leading zeros and an explicit '+'/'-' sign."""
    parsed = {
        "-000000000000002": -2,
        "-1": -1,
        "-0": 0,
        "0": 0,
        "+0": 0,
        "1": 1,
        "+1": 1,
        "0000000000000000000002": 2,
        "+000000000000000000003": 3,
    }
    s_int = pl.Series(list(parsed.keys()))
    expected = pl.Series(list(parsed.values()), dtype=dtype)
    assert_series_equal(s_int.cast(dtype), expected)
142
143
144
@pytest.mark.parametrize("dtype", [pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64])
def test_leading_plus_zero_uint(dtype: pl.DataType) -> None:
    """Unsigned integer parsing accepts leading zeros and an explicit '+' sign."""
    parsed = {
        "0": 0,
        "+0": 0,
        "1": 1,
        "+1": 1,
        "0000000000000000000002": 2,
        "+000000000000000000003": 3,
    }
    s_int = pl.Series(list(parsed.keys()))
    expected = pl.Series(list(parsed.values()), dtype=dtype)
    assert_series_equal(s_int.cast(dtype), expected)
150
151
152
@pytest.mark.parametrize("dtype", [pl.Float32, pl.Float64])
def test_leading_plus_zero_float(dtype: pl.DataType) -> None:
    """Float parsing accepts signs, leading zeros and a missing integer or fraction part."""
    parsed = {
        "-000000000000002.0": -2.0,
        "-1.0": -1.0,
        "-.5": -0.5,
        "-0.0": 0.0,
        "0.": 0.0,
        "+0": 0.0,
        "+.5": 0.5,
        "1": 1.0,
        "+1": 1.0,
        "0000000000000000000002": 2.0,
        "+000000000000000000003": 3.0,
    }
    s_float = pl.Series(list(parsed.keys()))
    expected = pl.Series(list(parsed.values()), dtype=dtype)
    assert_series_equal(s_float.cast(dtype), expected)
175
176
177
def _cast_series(
    val: int | datetime | date | time | timedelta,
    dtype_in: PolarsDataType,
    dtype_out: PolarsDataType,
    strict: bool,
) -> int | datetime | date | time | timedelta | None:
    """Cast one value through ``Series.cast`` and return the resulting scalar."""
    series = pl.Series("a", [val], dtype=dtype_in)
    casted = series.cast(dtype_out, strict=strict)
    return casted.item()  # type: ignore[no-any-return]
184
185
186
def _cast_expr(
    val: int | datetime | date | time | timedelta,
    dtype_in: PolarsDataType,
    dtype_out: PolarsDataType,
    strict: bool,
) -> int | datetime | date | time | timedelta | None:
    """Cast one value through a column expression and return the resulting scalar."""
    frame = pl.Series("a", [val], dtype=dtype_in).to_frame()
    result = frame.select(pl.col("a").cast(dtype_out, strict=strict))
    return result.item()  # type: ignore[no-any-return]
198
199
200
def _cast_lit(
    val: int | datetime | date | time | timedelta,
    dtype_in: PolarsDataType,
    dtype_out: PolarsDataType,
    strict: bool,
) -> int | datetime | date | time | timedelta | None:
    """Cast one value as a typed literal expression and return the resulting scalar."""
    expr = pl.lit(val, dtype=dtype_in).cast(dtype_out, strict=strict)
    return pl.select(expr).item()  # type: ignore[no-any-return]
207
208
209
@pytest.mark.parametrize(
    ("value", "from_dtype", "to_dtype", "should_succeed", "expected_value"),
    [
        (-1, pl.Int8, pl.UInt8, False, None),
        (-1, pl.Int16, pl.UInt16, False, None),
        (-1, pl.Int32, pl.UInt32, False, None),
        (-1, pl.Int64, pl.UInt64, False, None),
        (2**7, pl.UInt8, pl.Int8, False, None),
        (2**15, pl.UInt16, pl.Int16, False, None),
        (2**31, pl.UInt32, pl.Int32, False, None),
        (2**63, pl.UInt64, pl.Int64, False, None),
        (2**7 - 1, pl.UInt8, pl.Int8, True, 2**7 - 1),
        (2**15 - 1, pl.UInt16, pl.Int16, True, 2**15 - 1),
        (2**31 - 1, pl.UInt32, pl.Int32, True, 2**31 - 1),
        (2**63 - 1, pl.UInt64, pl.Int64, True, 2**63 - 1),
    ],
)
def test_strict_cast_int(
    value: int,
    from_dtype: PolarsDataType,
    to_dtype: PolarsDataType,
    should_succeed: bool,
    expected_value: Any,
) -> None:
    """Strict signed/unsigned integer casts keep in-range values and raise otherwise.

    Every case is exercised through the Series, expression and literal cast paths.
    """
    args = [value, from_dtype, to_dtype, True]
    for cast_fn in (_cast_series, _cast_expr, _cast_lit):
        if not should_succeed:
            with pytest.raises(InvalidOperationError):
                cast_fn(*args)  # type: ignore[arg-type]
            continue
        assert cast_fn(*args) == expected_value  # type: ignore[arg-type]
245
246
247
@pytest.mark.parametrize(
    ("value", "from_dtype", "to_dtype", "expected_value"),
    [
        (-1, pl.Int8, pl.UInt8, None),
        (-1, pl.Int16, pl.UInt16, None),
        (-1, pl.Int32, pl.UInt32, None),
        (-1, pl.Int64, pl.UInt64, None),
        (2**7, pl.UInt8, pl.Int8, None),
        (2**15, pl.UInt16, pl.Int16, None),
        (2**31, pl.UInt32, pl.Int32, None),
        (2**63, pl.UInt64, pl.Int64, None),
        (2**7 - 1, pl.UInt8, pl.Int8, 2**7 - 1),
        (2**15 - 1, pl.UInt16, pl.Int16, 2**15 - 1),
        (2**31 - 1, pl.UInt32, pl.Int32, 2**31 - 1),
        (2**63 - 1, pl.UInt64, pl.Int64, 2**63 - 1),
    ],
)
def test_cast_int(
    value: int,
    from_dtype: PolarsDataType,
    to_dtype: PolarsDataType,
    expected_value: Any,
) -> None:
    """Non-strict integer casts keep in-range values and turn overflows into null."""
    args = [value, from_dtype, to_dtype, False]
    for cast_fn in (_cast_series, _cast_expr, _cast_lit):
        assert cast_fn(*args) == expected_value  # type: ignore[arg-type]
274
275
276
def _cast_series_t(
    val: int | datetime | date | time | timedelta,
    dtype_in: PolarsDataType,
    dtype_out: PolarsDataType,
    strict: bool,
) -> pl.Series:
    """Cast one value through ``Series.cast`` and return the resulting Series."""
    source = pl.Series("a", [val], dtype=dtype_in)
    return source.cast(dtype_out, strict=strict)
283
284
285
def _cast_expr_t(
    val: int | datetime | date | time | timedelta,
    dtype_in: PolarsDataType,
    dtype_out: PolarsDataType,
    strict: bool,
) -> pl.Series:
    """Cast one value through a column expression and return the resulting Series."""
    frame = pl.Series("a", [val], dtype=dtype_in).to_frame()
    casted = frame.select(pl.col("a").cast(dtype_out, strict=strict))
    return casted.to_series()
297
298
299
def _cast_lit_t(
    val: int | datetime | date | time | timedelta,
    dtype_in: PolarsDataType,
    dtype_out: PolarsDataType,
    strict: bool,
) -> pl.Series:
    """Cast one value as a typed literal expression and return the resulting Series."""
    expr = pl.lit(val, dtype=dtype_in).cast(dtype_out, strict=strict)
    return pl.select(expr).to_series()
308
309
310
@pytest.mark.parametrize(
    (
        "value",
        "from_dtype",
        "to_dtype",
        "should_succeed",
        "expected_value",
    ),
    [
        # date to datetime
        (date(1970, 1, 1), pl.Date, pl.Datetime("ms"), True, datetime(1970, 1, 1)),
        (date(1970, 1, 1), pl.Date, pl.Datetime("us"), True, datetime(1970, 1, 1)),
        (date(1970, 1, 1), pl.Date, pl.Datetime("ns"), True, datetime(1970, 1, 1)),
        # datetime to date
        (datetime(1970, 1, 1), pl.Datetime("ms"), pl.Date, True, date(1970, 1, 1)),
        (datetime(1970, 1, 1), pl.Datetime("us"), pl.Date, True, date(1970, 1, 1)),
        (datetime(1970, 1, 1), pl.Datetime("ns"), pl.Date, True, date(1970, 1, 1)),
        # datetime to time
        (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("ms"), pl.Time, True, time(hour=1)),
        (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("us"), pl.Time, True, time(hour=1)),
        (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("ns"), pl.Time, True, time(hour=1)),
        # duration to int
        (timedelta(seconds=1), pl.Duration("ms"), pl.Int32, True, MS_PER_SECOND),
        (timedelta(seconds=1), pl.Duration("us"), pl.Int64, True, US_PER_SECOND),
        (timedelta(seconds=1), pl.Duration("ns"), pl.Int64, True, NS_PER_SECOND),
        # time to duration
        (time(hour=1), pl.Time, pl.Duration("ms"), True, timedelta(hours=1)),
        (time(hour=1), pl.Time, pl.Duration("us"), True, timedelta(hours=1)),
        (time(hour=1), pl.Time, pl.Duration("ns"), True, timedelta(hours=1)),
        # int to date
        (100, pl.UInt8, pl.Date, True, date(1970, 4, 11)),
        (100, pl.UInt16, pl.Date, True, date(1970, 4, 11)),
        (100, pl.UInt32, pl.Date, True, date(1970, 4, 11)),
        (100, pl.UInt64, pl.Date, True, date(1970, 4, 11)),
        (100, pl.Int8, pl.Date, True, date(1970, 4, 11)),
        (100, pl.Int16, pl.Date, True, date(1970, 4, 11)),
        (100, pl.Int32, pl.Date, True, date(1970, 4, 11)),
        (100, pl.Int64, pl.Date, True, date(1970, 4, 11)),
        # failures
        (2**63 - 1, pl.Int64, pl.Date, False, None),
        (-(2**62), pl.Int64, pl.Date, False, None),
        (date(1970, 5, 10), pl.Date, pl.Int8, False, None),
        (date(2149, 6, 7), pl.Date, pl.Int16, False, None),
        (datetime(9999, 12, 31), pl.Datetime, pl.Int8, False, None),
        (datetime(9999, 12, 31), pl.Datetime, pl.Int16, False, None),
    ],
)
def test_strict_cast_temporal(
    value: int | date | datetime | time | timedelta,
    from_dtype: PolarsDataType,
    to_dtype: PolarsDataType,
    should_succeed: bool,
    expected_value: Any,
) -> None:
    """Strict temporal casts return the expected value and dtype, or raise.

    ``value`` is a temporal object or an integer, depending on ``from_dtype``.
    Each case runs through the Series, expression and literal cast paths.
    """
    args = [value, from_dtype, to_dtype, True]
    for cast_fn in (_cast_series_t, _cast_expr_t, _cast_lit_t):
        if should_succeed:
            out = cast_fn(*args)  # type: ignore[arg-type]
            assert out.item() == expected_value
            assert out.dtype == to_dtype
        else:
            with pytest.raises(InvalidOperationError):
                cast_fn(*args)  # type: ignore[arg-type]
382
383
384
@pytest.mark.parametrize(
    (
        "value",
        "from_dtype",
        "to_dtype",
        "expected_value",
    ),
    [
        # date to datetime
        (date(1970, 1, 1), pl.Date, pl.Datetime("ms"), datetime(1970, 1, 1)),
        (date(1970, 1, 1), pl.Date, pl.Datetime("us"), datetime(1970, 1, 1)),
        (date(1970, 1, 1), pl.Date, pl.Datetime("ns"), datetime(1970, 1, 1)),
        # datetime to date
        (datetime(1970, 1, 1), pl.Datetime("ms"), pl.Date, date(1970, 1, 1)),
        (datetime(1970, 1, 1), pl.Datetime("us"), pl.Date, date(1970, 1, 1)),
        (datetime(1970, 1, 1), pl.Datetime("ns"), pl.Date, date(1970, 1, 1)),
        # datetime to time
        (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("ms"), pl.Time, time(hour=1)),
        (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("us"), pl.Time, time(hour=1)),
        (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("ns"), pl.Time, time(hour=1)),
        # duration to int
        (timedelta(seconds=1), pl.Duration("ms"), pl.Int32, MS_PER_SECOND),
        (timedelta(seconds=1), pl.Duration("us"), pl.Int64, US_PER_SECOND),
        (timedelta(seconds=1), pl.Duration("ns"), pl.Int64, NS_PER_SECOND),
        # time to duration
        (time(hour=1), pl.Time, pl.Duration("ms"), timedelta(hours=1)),
        (time(hour=1), pl.Time, pl.Duration("us"), timedelta(hours=1)),
        (time(hour=1), pl.Time, pl.Duration("ns"), timedelta(hours=1)),
        # int to date
        (100, pl.UInt8, pl.Date, date(1970, 4, 11)),
        (100, pl.UInt16, pl.Date, date(1970, 4, 11)),
        (100, pl.UInt32, pl.Date, date(1970, 4, 11)),
        (100, pl.UInt64, pl.Date, date(1970, 4, 11)),
        (100, pl.Int8, pl.Date, date(1970, 4, 11)),
        (100, pl.Int16, pl.Date, date(1970, 4, 11)),
        (100, pl.Int32, pl.Date, date(1970, 4, 11)),
        (100, pl.Int64, pl.Date, date(1970, 4, 11)),
        # failures
        (2**63 - 1, pl.Int64, pl.Date, None),
        (-(2**62), pl.Int64, pl.Date, None),
        (date(1970, 5, 10), pl.Date, pl.Int8, None),
        (date(2149, 6, 7), pl.Date, pl.Int16, None),
        (datetime(9999, 12, 31), pl.Datetime, pl.Int8, None),
        (datetime(9999, 12, 31), pl.Datetime, pl.Int16, None),
    ],
)
def test_cast_temporal(
    value: int | date | datetime | time | timedelta,
    from_dtype: PolarsDataType,
    to_dtype: PolarsDataType,
    expected_value: Any,
) -> None:
    """Non-strict temporal casts return the expected value (or null) and dtype.

    ``value`` is a temporal object or an integer, depending on ``from_dtype``.
    Each case runs through the Series, expression and literal cast paths.
    """
    args = [value, from_dtype, to_dtype, False]
    for cast_fn in (_cast_series_t, _cast_expr_t, _cast_lit_t):
        out = cast_fn(*args)  # type: ignore[arg-type]
        if expected_value is None:
            # The "failures" cases above produce null instead of raising.
            assert out.item() is None
        else:
            assert out.item() == expected_value
        assert out.dtype == to_dtype
457
458
459
@pytest.mark.parametrize(
    (
        "value",
        "from_dtype",
        "to_dtype",
        "expected_value",
    ),
    [
        (str(2**7 - 1), pl.String, pl.Int8, 2**7 - 1),
        (str(2**15 - 1), pl.String, pl.Int16, 2**15 - 1),
        (str(2**31 - 1), pl.String, pl.Int32, 2**31 - 1),
        (str(2**63 - 1), pl.String, pl.Int64, 2**63 - 1),
        ("1.0", pl.String, pl.Float32, 1.0),
        ("1.0", pl.String, pl.Float64, 1.0),
        # overflow
        (str(2**7), pl.String, pl.Int8, None),
        (str(2**15), pl.String, pl.Int16, None),
        (str(2**31), pl.String, pl.Int32, None),
        (str(2**63), pl.String, pl.Int64, None),
    ],
)
def test_cast_string(
    value: str,
    from_dtype: PolarsDataType,
    to_dtype: PolarsDataType,
    expected_value: Any,
) -> None:
    """Non-strict String -> numeric casts parse in-range text and null out overflows.

    Each case runs through the Series, expression and literal cast paths.
    """
    args = [value, from_dtype, to_dtype, False]
    for cast_fn in (_cast_series_t, _cast_expr_t, _cast_lit_t):
        out = cast_fn(*args)  # type: ignore[arg-type]
        if expected_value is None:
            assert out.item() is None
        else:
            assert out.item() == expected_value
        assert out.dtype == to_dtype
507
508
509
@pytest.mark.parametrize(
    (
        "value",
        "from_dtype",
        "to_dtype",
        "should_succeed",
        "expected_value",
    ),
    [
        (str(2**7 - 1), pl.String, pl.Int8, True, 2**7 - 1),
        (str(2**15 - 1), pl.String, pl.Int16, True, 2**15 - 1),
        (str(2**31 - 1), pl.String, pl.Int32, True, 2**31 - 1),
        (str(2**63 - 1), pl.String, pl.Int64, True, 2**63 - 1),
        ("1.0", pl.String, pl.Float32, True, 1.0),
        ("1.0", pl.String, pl.Float64, True, 1.0),
        # overflow
        (str(2**7), pl.String, pl.Int8, False, None),
        (str(2**15), pl.String, pl.Int16, False, None),
        (str(2**31), pl.String, pl.Int32, False, None),
        (str(2**63), pl.String, pl.Int64, False, None),
    ],
)
def test_strict_cast_string(
    value: str,
    from_dtype: PolarsDataType,
    to_dtype: PolarsDataType,
    should_succeed: bool,
    expected_value: Any,
) -> None:
    """Strict String -> numeric casts parse in-range text and raise on overflow.

    Each case runs through the Series, expression and literal cast paths.
    """
    args = [value, from_dtype, to_dtype, True]
    for cast_fn in (_cast_series_t, _cast_expr_t, _cast_lit_t):
        if should_succeed:
            out = cast_fn(*args)  # type: ignore[arg-type]
            assert out.item() == expected_value
            assert out.dtype == to_dtype
        else:
            with pytest.raises(InvalidOperationError):
                cast_fn(*args)  # type: ignore[arg-type]
556
557
558
@pytest.mark.parametrize("dtype_in", [pl.Categorical, pl.Enum(["1"])])
@pytest.mark.parametrize("dtype_out", [pl.String, pl.Categorical, pl.Enum(["1", "2"])])
def test_cast_categorical_name_retention(
    dtype_in: PolarsDataType, dtype_out: PolarsDataType
) -> None:
    """Casting out of a Categorical/Enum Series keeps the Series name."""
    source = pl.Series("a", ["1"], dtype=dtype_in)
    assert source.cast(dtype_out).name == "a"
574
575
576
def test_cast_date_to_time() -> None:
    """Date -> Time is rejected as an unsupported cast."""
    dates = pl.Series([date(1970, 1, 1), date(2000, 12, 31)])
    with pytest.raises(
        InvalidOperationError, match="casting from Date to Time not supported"
    ):
        dates.cast(pl.Time)
581
582
583
def test_cast_time_to_date() -> None:
    """Time -> Date is rejected as an unsupported cast."""
    times = pl.Series([time(0, 0), time(20, 0)])
    with pytest.raises(
        InvalidOperationError, match="casting from Time to Date not supported"
    ):
        times.cast(pl.Date)
588
589
590
def test_cast_decimal_to_boolean() -> None:
    """Decimal -> Boolean maps zero to False and non-zero values to True."""
    values = [Decimal("0.0"), Decimal("1.5"), Decimal("-1.5")]
    flags = [False, True, True]
    s = pl.Series("s", values)

    # Eager Series cast.
    assert_series_equal(s.cast(pl.Boolean), pl.Series("s", flags))
    # Same cast through an expression on a DataFrame.
    assert_frame_equal(
        s.to_frame().select(pl.col("s").cast(pl.Boolean)),
        pl.DataFrame({"s": flags}),
    )
599
600
601
def test_cast_array_to_different_width() -> None:
    """An Array column cannot be cast to an Array of another width."""
    two_wide = pl.Series([[1, 2], [3, 4]], dtype=pl.Array(pl.Int8, 2))
    three_wide = pl.Array(pl.Int16, 3)
    with pytest.raises(
        InvalidOperationError, match="cannot cast Array to a different width"
    ):
        two_wide.cast(three_wide)
607
608
609
def test_cast_decimal_to_decimal_high_precision() -> None:
    """A 22-digit Decimal keeps its value when cast to an explicit-precision Decimal."""
    precision = 22
    values = [Decimal("9" * precision)]
    target_dtype = pl.Decimal(precision, 0)

    result = pl.Series(values, dtype=pl.Decimal(None, 0)).cast(target_dtype)

    assert result.dtype == target_dtype
    assert result.to_list() == values
619
620
621
@pytest.mark.parametrize("value", [float("inf"), float("nan")])
def test_invalid_cast_float_to_decimal(value: float) -> None:
    """Non-finite floats fail to convert to Decimal."""
    floats = pl.Series([value], dtype=pl.Float64)
    expected_error = r"conversion from `f64` to `decimal\[\*,0\]` failed"
    with pytest.raises(InvalidOperationError, match=expected_error):
        floats.cast(pl.Decimal)
629
630
631
def test_err_on_time_datetime_cast() -> None:
    """Time -> Datetime is rejected with a hint pointing at ``dt.combine``."""
    times = pl.Series([time(10, 0, 0), time(11, 30, 59)])
    expected_msg = "casting from Time to Datetime\\('μs'\\) not supported; consider using `dt.combine`"
    with pytest.raises(InvalidOperationError, match=expected_msg):
        times.cast(pl.Datetime)
638
639
640
def test_err_on_invalid_time_zone_cast() -> None:
    """An unknown time-zone name in the target Datetime raises ComputeError."""
    stamps = pl.Series([datetime(2021, 1, 1)])
    # The dtype is built inside the context in case construction itself validates.
    with pytest.raises(ComputeError, match=r"unable to parse time zone: 'qwerty'"):
        stamps.cast(pl.Datetime("us", "qwerty"))
644
645
646
def test_invalid_inner_type_cast_list() -> None:
    """List(Int64) elements cannot be cast to Categorical."""
    ints = pl.Series([[-1, 1]])
    target = pl.List(pl.Categorical)
    with pytest.raises(
        InvalidOperationError,
        match=r"cannot cast List inner type: 'Int64' to Categorical",
    ):
        ints.cast(target)
653
654
655
@pytest.mark.parametrize(
    ("values", "result"),
    [
        ([[]], [b""]),
        ([[1, 2], [3, 4]], [b"\x01\x02", b"\x03\x04"]),
        ([[1, 2], None, [3, 4]], [b"\x01\x02", None, b"\x03\x04"]),
        (
            [None, [111, 110, 101], [12, None], [116, 119, 111], list(range(256))],
            [
                None,
                b"one",
                # A list with a null in it gets turned into a null:
                None,
                b"two",
                bytes(range(256)),
            ],
        ),
    ],
)
def test_list_uint8_to_bytes(
    values: list[list[int | None] | None], result: list[bytes | None]
) -> None:
    """A non-strict List(UInt8) -> Binary cast packs each list into one bytes value."""
    u8_lists = pl.Series(values, dtype=pl.List(pl.UInt8()))
    as_binary = u8_lists.cast(pl.Binary(), strict=False)
    assert as_binary.to_list() == result
682
683
684
def test_list_uint8_to_bytes_strict() -> None:
    """A strict List(UInt8) -> Binary cast succeeds without nulls and reports nulls otherwise."""
    u8_list = pl.List(pl.UInt8())

    clean = pl.Series([[1, 2], [3, 4]], dtype=u8_list)
    assert clean.cast(pl.Binary(), strict=True).to_list() == [b"\x01\x02", b"\x03\x04"]

    # The error names the column and lists the offending value.
    with_null = pl.Series("mycol", [[1, 2], [3, None]], dtype=u8_list)
    with pytest.raises(
        InvalidOperationError,
        match="conversion from `list\\[u8\\]` to `binary` failed in column 'mycol' for 1 out of 2 values: \\[\\[3, null\\]\\]",
    ):
        with_null.cast(pl.Binary(), strict=True)
701
702
703
def test_all_null_cast_5826() -> None:
    """An all-null String column casts to Boolean and stays null."""
    nulls = pl.Series("a", [None], dtype=pl.String)
    out = pl.DataFrame(data=[nulls]).with_columns(pl.col("a").cast(pl.Boolean))

    assert out.dtypes == [pl.Boolean]
    assert out.item() is None
708
709
710
@pytest.mark.parametrize("dtype", INTEGER_DTYPES)
def test_bool_numeric_supertype(dtype: PolarsDataType) -> None:
    """A boolean sum cast to any integer dtype divides by the row count correctly.

    Two of the six values satisfy ``v < 3``, so the result should be 1/3.
    """
    df = pl.DataFrame({"v": [1, 2, 3, 4, 5, 6]})
    result = df.select((pl.col("v") < 3).sum().cast(dtype) / pl.len())
    # Compare the absolute difference: a one-sided ``x - 1/3 <= eps`` check
    # would also accept any value below 1/3, including 0.
    assert abs(result.item() - 1 / 3) < 1e-5
715
716
717
@pytest.mark.parametrize("dtype", [pl.String(), pl.String, str])
def test_cast_consistency(dtype: PolarsDataType | PythonDataType) -> None:
    """Casting a float column and a float literal to String yields the same text."""
    df = pl.DataFrame().with_columns(a=pl.lit(0.0))
    df = df.with_columns(b=pl.col("a").cast(dtype), c=pl.lit(0.0).cast(dtype))
    assert df.to_dict(as_series=False) == {"a": [0.0], "b": ["0.0"], "c": ["0.0"]}
722
723
724
def test_cast_int_to_string_unsets_sorted_flag_19424() -> None:
    """Casting a sorted integer Series to String clears the ascending-sorted flag."""
    ints = pl.Series([1, 2]).set_sorted()
    assert ints.flags["SORTED_ASC"]

    strings = ints.cast(pl.String)
    assert not strings.flags["SORTED_ASC"]
728
729
730
def test_cast_integer_to_decimal() -> None:
    """Integers cast to Decimal(10, 2) gain two zero fractional digits."""
    result = pl.Series([1, 2, 3]).cast(pl.Decimal(10, 2))
    expected_values = [Decimal(f"{n}.00") for n in (1, 2, 3)]
    assert_series_equal(result, pl.Series("", expected_values, pl.Decimal(10, 2)))
737
738
739
def test_cast_python_dtypes() -> None:
    """Python builtin types are accepted as cast targets and map to Polars dtypes."""
    ints = pl.Series([0, 1])

    assert ints.cast(int).dtype == pl.Int64  # int -> Int64
    assert ints.cast(float).dtype == pl.Float64  # float -> Float64
    assert ints.cast(bool).dtype == pl.Boolean  # bool -> Boolean
    assert ints.cast(str).dtype == pl.String  # str -> String
745
746
747
def test_overflowing_cast_literals_21023() -> None:
    """``wrap_numerical=True`` wraps an overflowing literal with and without optimizations."""
    expected = pl.Series([-128], dtype=pl.Int8).to_frame()
    wrapped = pl.lit(pl.Series([128], dtype=pl.Int64)).cast(
        pl.Int8, wrap_numerical=True
    )

    for optimizations in [pl.QueryOptFlags(), pl.QueryOptFlags.none()]:
        result = pl.LazyFrame().select(wrapped).collect(optimizations=optimizations)
        assert_frame_equal(result, expected)
761
762
763
@pytest.mark.parametrize("value", [True, False])
@pytest.mark.parametrize(
    "dtype",
    [
        pl.Enum(["a", "b"]),
        pl.Series(["a", "b"], dtype=pl.Categorical).dtype,
    ],
)
def test_invalid_bool_to_cat(value: bool, dtype: PolarsDataType) -> None:
    """Booleans cannot be cast to an Enum or Categorical dtype.

    The test expects the same "Boolean to Categorical" message for both targets.
    """
    with pytest.raises(
        InvalidOperationError,
        match="cannot cast Boolean to Categorical",
    ):
        pl.Series([value]).cast(dtype)
778
779
780
@pytest.mark.parametrize(
    ("values", "from_dtype", "to_dtype", "pre_apply"),
    [
        ([["A"]], pl.List(pl.String), pl.List(pl.Int8), None),
        ([["A"]], pl.Array(pl.String, 1), pl.List(pl.Int8), None),
        ([[["A"]]], pl.List(pl.List(pl.String)), pl.List(pl.List(pl.Int8)), None),
        (
            [
                {"x": "1", "y": "2"},
                {"x": "A", "y": "B"},
                {"x": "3", "y": "4"},
                {"x": "X", "y": "Y"},
                {"x": "5", "y": "6"},
            ],
            pl.Struct({"x": pl.String, "y": pl.String}),
            pl.Struct({"x": pl.Int8, "y": pl.Int32}),
            None,
        ),
    ],
)
def test_nested_strict_casts_failing(
    values: list[Any],
    from_dtype: pl.DataType,
    to_dtype: pl.DataType,
    pre_apply: Callable[[pl.Series], pl.Series] | None,
) -> None:
    """Strict casts of nested types fail when any inner value cannot be converted."""
    s = pl.Series(values, dtype=from_dtype)
    s = s if pre_apply is None else pre_apply(s)

    with pytest.raises(pl.exceptions.InvalidOperationError, match=r"conversion from"):
        s.cast(to_dtype)
826
827
828
@pytest.mark.parametrize(
    ("values", "from_dtype", "pre_apply", "to"),
    [
        (
            [["A"], ["1"], ["2"]],
            pl.List(pl.String),
            lambda s: s.slice(1, 2),
            pl.Series([[1], [2]]),
        ),
        (
            [["1"], ["A"], ["2"], ["B"], ["3"]],
            pl.List(pl.String),
            lambda s: s.filter(pl.Series([True, False, True, False, True])),
            pl.Series([[1], [2], [3]]),
        ),
        (
            [
                {"x": "1", "y": "2"},
                {"x": "A", "y": "B"},
                {"x": "3", "y": "4"},
                {"x": "X", "y": "Y"},
                {"x": "5", "y": "6"},
            ],
            pl.Struct({"x": pl.String, "y": pl.String}),
            lambda s: s.filter(pl.Series([True, False, True, False, True])),
            pl.Series(
                [
                    {"x": 1, "y": 2},
                    {"x": 3, "y": 4},
                    {"x": 5, "y": 6},
                ]
            ),
        ),
        (
            [
                {"x": "1", "y": "2"},
                {"x": "A", "y": "B"},
                {"x": "3", "y": "4"},
                {"x": "X", "y": "Y"},
                {"x": "5", "y": "6"},
            ],
            pl.Struct({"x": pl.String, "y": pl.String}),
            lambda s: pl.select(
                pl.when(pl.Series([True, False, True, False, True])).then(s)
            ).to_series(),
            pl.Series(
                [
                    {"x": 1, "y": 2},
                    None,
                    {"x": 3, "y": 4},
                    None,
                    {"x": 5, "y": 6},
                ]
            ),
        ),
    ],
)
def test_nested_strict_casts_succeeds(
    values: list[Any],
    from_dtype: pl.DataType,
    pre_apply: Callable[[pl.Series], pl.Series] | None,
    to: pl.Series,
) -> None:
    """Strict nested casts succeed once unconvertible rows are sliced, filtered or nulled out."""
    s = pl.Series(values, dtype=from_dtype)
    s = s if pre_apply is None else pre_apply(s)

    result = s.cast(to.dtype)
    assert_series_equal(result, to)
910
911
912
def test_nested_struct_cast_22744() -> None:
    """Casting to a struct with an extra nested field adds that field as null."""
    s = pl.Series("x", [{"attrs": {"class": "a"}}])
    target = pl.Struct(
        {"attrs": pl.Struct({"class": pl.String, "other": pl.String})}
    )

    # Same shape built explicitly: keep "class", append a null String "other".
    expected = pl.select(
        pl.lit(s).struct.with_fields(
            pl.field("attrs").struct.with_fields(
                [pl.field("class"), pl.lit(None, dtype=pl.String()).alias("other")]
            )
        )
    )

    assert_series_equal(s.cast(target), expected.to_series())
    assert_frame_equal(pl.DataFrame([s]).cast({"x": target}), expected)
942
943
944
def test_cast_to_self_is_pruned() -> None:
    """A cast to the column's existing dtype is dropped from the query plan."""
    lf = pl.LazyFrame({"x": 1}, schema={"x": pl.Int64})
    q = lf.with_columns(y=pl.col("x").cast(pl.Int64))

    assert 'col("x").alias("y")' in q.explain()
    assert_frame_equal(q.collect(), pl.DataFrame({"x": 1, "y": 1}))
953
954
955
@pytest.mark.parametrize(
    ("s", "to", "should_fail"),
    [
        (
            pl.Series([datetime(2025, 1, 1)]),
            pl.Datetime("ns"),
            False,
        ),
        (
            pl.Series([datetime(9999, 1, 1)]),
            pl.Datetime("ns"),
            True,
        ),
        (
            pl.Series([datetime(2025, 1, 1), datetime(9999, 1, 1)]),
            pl.Datetime("ns"),
            True,
        ),
        (
            pl.Series([[datetime(2025, 1, 1)], [datetime(9999, 1, 1)]]),
            pl.List(pl.Datetime("ns")),
            True,
        ),
        # lower date limit for nanosecond
        (pl.Series([date(1677, 9, 22)]), pl.Datetime("ns"), False),
        (pl.Series([date(1677, 9, 21)]), pl.Datetime("ns"), True),
        # upper date limit for nanosecond
        (pl.Series([date(2262, 4, 11)]), pl.Datetime("ns"), False),
        (pl.Series([date(2262, 4, 12)]), pl.Datetime("ns"), True),
    ],
)
def test_cast_temporals_overflow_16039(
    s: pl.Series, to: pl.DataType, should_fail: bool
) -> None:
    """Temporal casts outside the nanosecond Datetime range raise; in-range ones succeed."""
    if should_fail:
        with pytest.raises(
            pl.exceptions.InvalidOperationError, match="conversion from"
        ):
            s.cast(to)
    else:
        # Beyond "does not raise", confirm the result has the requested dtype.
        assert s.cast(to).dtype == to
996
997
998
@pytest.mark.parametrize("dtype", NUMERIC_DTYPES)
def test_prune_superfluous_cast(dtype: PolarsDataType) -> None:
    """A cast to the column's own numeric dtype does not show up in the plan."""
    plan = (
        pl.LazyFrame({"a": [1, 2, 3]}, schema={"a": dtype})
        .select(pl.col("a").cast(dtype))
        .explain()
    )
    assert "strict_cast" not in plan
1003
1004
1005
def test_not_prune_necessary_cast() -> None:
    """A narrowing cast (UInt16 -> UInt8) is kept in the plan."""
    plan = (
        pl.LazyFrame({"a": [1, 2, 3]}, schema={"a": pl.UInt16})
        .select(pl.col("a").cast(pl.UInt8))
        .explain()
    )
    assert "strict_cast" in plan
1009
1010
1011
@pytest.mark.parametrize("target_dtype", NUMERIC_DTYPES)
@pytest.mark.parametrize("inner_dtype", NUMERIC_DTYPES)
@pytest.mark.parametrize("op", [operator.mul, operator.truediv])
def test_cast_optimizer_in_list_eval_23924(
    inner_dtype: PolarsDataType,
    target_dtype: PolarsDataType,
    op: Callable[[pl.Expr, pl.Expr], pl.Expr],
) -> None:
    """The schema predicted for casts inside ``list.eval`` matches the collected one."""
    # Seed integer list dtypes with an int and every other numeric dtype with a float.
    # (A leftover debug print of the parameters was removed.)
    seed = [1] if target_dtype in INTEGER_DTYPES else [1.0]
    df = pl.Series("a", [seed], dtype=pl.List(target_dtype)).to_frame()
    q = df.lazy().select(
        pl.col("a").list.eval(
            op(pl.element(), pl.element().cast(inner_dtype)).cast(target_dtype)
        )
    )
    assert q.collect_schema() == q.collect().schema
1030
1031
1032
def test_lit_cast_arithmetic_23677() -> None:
    """Dividing a Float32 column by an Int32 literal produces a Float64 column."""
    q = (
        pl.DataFrame({"a": [1]}, schema={"a": pl.Float32})
        .lazy()
        .select(pl.col("a") / pl.lit(1, pl.Int32))
    )
    assert q.collect().schema == pl.Schema({"a": pl.Float64})
1037
1038
1039
@pytest.mark.parametrize("col_dtype", NUMERIC_DTYPES)
@pytest.mark.parametrize("lit_dtype", NUMERIC_DTYPES)
@pytest.mark.parametrize("op", [operator.mul, operator.truediv])
def test_lit_cast_arithmetic_matrix_schema(
    col_dtype: PolarsDataType,
    lit_dtype: PolarsDataType,
    op: Callable[[pl.Expr, pl.Expr], pl.Expr],
) -> None:
    """For every numeric column/literal dtype pair, the predicted and actual schemas agree."""
    lf = pl.DataFrame({"a": [1]}, schema={"a": col_dtype}).lazy()
    q = lf.select(op(pl.col("a"), pl.lit(1, lit_dtype)))
    assert q.collect_schema() == q.collect().schema
1050
1051