# Path: blob/main/py-polars/tests/unit/operations/test_cast.py
# 6939 views
from __future__ import annotations

import operator
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Callable

import pytest

import polars as pl
from polars._utils.constants import MS_PER_SECOND, NS_PER_SECOND, US_PER_SECOND
from polars.exceptions import ComputeError, InvalidOperationError
from polars.testing import assert_frame_equal
from polars.testing.asserts.series import assert_series_equal
from tests.unit.conftest import INTEGER_DTYPES, NUMERIC_DTYPES

if TYPE_CHECKING:
    from polars._typing import PolarsDataType, PythonDataType


@pytest.mark.parametrize("dtype", [pl.Date(), pl.Date, date])
def test_string_date(dtype: PolarsDataType | PythonDataType) -> None:
    """Cast an ISO date string to Date via instance, class and Python type."""
    df = pl.DataFrame({"x1": ["2021-01-01"]}).with_columns(
        **{"x1-date": pl.col("x1").cast(dtype)}
    )
    expected = pl.DataFrame({"x1-date": [date(2021, 1, 1)]})
    out = df.select(pl.col("x1-date"))
    assert_frame_equal(expected, out)


def test_invalid_string_date() -> None:
    """A malformed date string raises when cast to Date."""
    df = pl.DataFrame({"x1": ["2021-01-aa"]})

    with pytest.raises(InvalidOperationError):
        df.with_columns(**{"x1-date": pl.col("x1").cast(pl.Date)})


def test_string_datetime() -> None:
    """Cast 'T'-separated datetime strings to Datetime in ns, ms and us."""
    df = pl.DataFrame(
        {"x1": ["2021-12-19T00:39:57", "2022-12-19T16:39:57"]}
    ).with_columns(
        **{
            "x1-datetime-ns": pl.col("x1").cast(pl.Datetime(time_unit="ns")),
            "x1-datetime-ms": pl.col("x1").cast(pl.Datetime(time_unit="ms")),
            "x1-datetime-us": pl.col("x1").cast(pl.Datetime(time_unit="us")),
        }
    )
    first_row = datetime(year=2021, month=12, day=19, hour=00, minute=39, second=57)
    second_row = datetime(year=2022, month=12, day=19, hour=16, minute=39, second=57)
    expected = pl.DataFrame(
        {
            "x1-datetime-ns": [first_row, second_row],
            "x1-datetime-ms": [first_row, second_row],
            "x1-datetime-us": [first_row, second_row],
        }
    ).select(
        pl.col("x1-datetime-ns").dt.cast_time_unit("ns"),
        pl.col("x1-datetime-ms").dt.cast_time_unit("ms"),
        pl.col("x1-datetime-us").dt.cast_time_unit("us"),
    )

    out = df.select(
        pl.col("x1-datetime-ns"), pl.col("x1-datetime-ms"), pl.col("x1-datetime-us")
    )
    assert_frame_equal(expected, out)


def test_invalid_string_datetime() -> None:
    """Space-separated date/time strings raise when cast to Datetime('ns')."""
    df = pl.DataFrame({"x1": ["2021-12-19 00:39:57", "2022-12-19 16:39:57"]})
    with pytest.raises(InvalidOperationError):
        df.with_columns(
            **{"x1-datetime-ns": pl.col("x1").cast(pl.Datetime(time_unit="ns"))}
        )


def test_string_datetime_timezone() -> None:
    """Cast '+00:00'-suffixed strings to Datetime columns in three time zones."""
    ccs_tz = "America/Caracas"
    stg_tz = "America/Santiago"
    utc_tz = "UTC"
    df = pl.DataFrame(
        {"x1": ["1996-12-19T16:39:57 +00:00", "2022-12-19T00:39:57 +00:00"]}
    ).with_columns(
        **{
            "x1-datetime-ns": pl.col("x1").cast(
                pl.Datetime(time_unit="ns", time_zone=ccs_tz)
            ),
            "x1-datetime-ms": pl.col("x1").cast(
                pl.Datetime(time_unit="ms", time_zone=stg_tz)
            ),
            "x1-datetime-us": pl.col("x1").cast(
                pl.Datetime(time_unit="us", time_zone=utc_tz)
            ),
        }
    )

    # Expected values are written as naive wall-clock times and then tagged
    # with the target zone through `replace_time_zone`.
    expected = pl.DataFrame(
        {
            "x1-datetime-ns": [
                datetime(year=1996, month=12, day=19, hour=12, minute=39, second=57),
                datetime(year=2022, month=12, day=18, hour=20, minute=39, second=57),
            ],
            "x1-datetime-ms": [
                datetime(year=1996, month=12, day=19, hour=13, minute=39, second=57),
                datetime(year=2022, month=12, day=18, hour=21, minute=39, second=57),
            ],
            "x1-datetime-us": [
                datetime(year=1996, month=12, day=19, hour=16, minute=39, second=57),
                datetime(year=2022, month=12, day=19, hour=00, minute=39, second=57),
            ],
        }
    ).select(
        pl.col("x1-datetime-ns").dt.cast_time_unit("ns").dt.replace_time_zone(ccs_tz),
        pl.col("x1-datetime-ms").dt.cast_time_unit("ms").dt.replace_time_zone(stg_tz),
        pl.col("x1-datetime-us").dt.cast_time_unit("us").dt.replace_time_zone(utc_tz),
    )

    out = df.select(
        pl.col("x1-datetime-ns"), pl.col("x1-datetime-ms"), pl.col("x1-datetime-us")
    )

    assert_frame_equal(expected, out)


@pytest.mark.parametrize("dtype", [pl.Int8, pl.Int16, pl.Int32, pl.Int64])
def test_leading_plus_zero_int(dtype: pl.DataType) -> None:
    """Signed-int parsing accepts leading '+', '-' and zero padding."""
    s_int = pl.Series(
        [
            "-000000000000002",
            "-1",
            "-0",
            "0",
            "+0",
            "1",
            "+1",
            "0000000000000000000002",
            "+000000000000000000003",
        ]
    )
    assert_series_equal(
        s_int.cast(dtype), pl.Series([-2, -1, 0, 0, 0, 1, 1, 2, 3], dtype=dtype)
    )


@pytest.mark.parametrize("dtype", [pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64])
def test_leading_plus_zero_uint(dtype: pl.DataType) -> None:
    """Unsigned-int parsing accepts a leading '+' and zero padding."""
    s_int = pl.Series(
        ["0", "+0", "1", "+1", "0000000000000000000002", "+000000000000000000003"]
    )
    assert_series_equal(s_int.cast(dtype), pl.Series([0, 0, 1, 1, 2, 3], dtype=dtype))


@pytest.mark.parametrize("dtype", [pl.Float32, pl.Float64])
def test_leading_plus_zero_float(dtype: pl.DataType) -> None:
    """Float parsing accepts signs, zero padding and a bare leading/trailing dot."""
    s_float = pl.Series(
        [
            "-000000000000002.0",
            "-1.0",
            "-.5",
            "-0.0",
            "0.",
            "+0",
            "+.5",
            "1",
            "+1",
            "0000000000000000000002",
            "+000000000000000000003",
        ]
    )
    assert_series_equal(
        s_float.cast(dtype),
        pl.Series(
            [-2.0, -1.0, -0.5, 0.0, 0.0, 0.0, 0.5, 1.0, 1.0, 2.0, 3.0], dtype=dtype
        ),
    )


def _cast_series(
    val: int | datetime | date | time | timedelta,
    dtype_in: PolarsDataType,
    dtype_out: PolarsDataType,
    strict: bool,
) -> int | datetime | date | time | timedelta | None:
    """Cast a one-element Series directly and return the resulting scalar."""
    return pl.Series("a", [val], dtype=dtype_in).cast(dtype_out, strict=strict).item()  # type: ignore[no-any-return]


def _cast_expr(
    val: int | datetime | date | time | timedelta,
    dtype_in: PolarsDataType,
    dtype_out: PolarsDataType,
    strict: bool,
) -> int | datetime | date | time | timedelta | None:
    """Cast a one-element column through an expression and return the scalar."""
    return (  # type: ignore[no-any-return]
        pl.Series("a", [val], dtype=dtype_in)
        .to_frame()
        .select(pl.col("a").cast(dtype_out, strict=strict))
        .item()
    )


def _cast_lit(
    val: int | datetime | date | time | timedelta,
    dtype_in: PolarsDataType,
    dtype_out: PolarsDataType,
    strict: bool,
) -> int | datetime | date | time | timedelta | None:
    """Cast a typed literal inside `pl.select` and return the scalar."""
    return pl.select(pl.lit(val, dtype=dtype_in).cast(dtype_out, strict=strict)).item()  # type: ignore[no-any-return]


@pytest.mark.parametrize(
    ("value", "from_dtype", "to_dtype", "should_succeed", "expected_value"),
    [
        (-1, pl.Int8, pl.UInt8, False, None),
        (-1, pl.Int16, pl.UInt16, False, None),
        (-1, pl.Int32, pl.UInt32, False, None),
        (-1, pl.Int64, pl.UInt64, False, None),
        (2**7, pl.UInt8, pl.Int8, False, None),
        (2**15, pl.UInt16, pl.Int16, False, None),
        (2**31, pl.UInt32, pl.Int32, False, None),
        (2**63, pl.UInt64, pl.Int64, False, None),
        (2**7 - 1, pl.UInt8, pl.Int8, True, 2**7 - 1),
        (2**15 - 1, pl.UInt16, pl.Int16, True, 2**15 - 1),
        (2**31 - 1, pl.UInt32, pl.Int32, True, 2**31 - 1),
        (2**63 - 1, pl.UInt64, pl.Int64, True, 2**63 - 1),
    ],
)
def test_strict_cast_int(
    value: int,
    from_dtype: PolarsDataType,
    to_dtype: PolarsDataType,
    should_succeed: bool,
    expected_value: Any,
) -> None:
    """Strict signed/unsigned casts succeed in range and raise out of range.

    Every case is exercised through the Series, expression and literal paths.
    """
    args = [value, from_dtype, to_dtype, True]
    for cast in (_cast_series, _cast_expr, _cast_lit):
        if should_succeed:
            assert cast(*args) == expected_value  # type: ignore[arg-type]
        else:
            with pytest.raises(InvalidOperationError):
                cast(*args)  # type: ignore[arg-type]


@pytest.mark.parametrize(
    ("value", "from_dtype", "to_dtype", "expected_value"),
    [
        (-1, pl.Int8, pl.UInt8, None),
        (-1, pl.Int16, pl.UInt16, None),
        (-1, pl.Int32, pl.UInt32, None),
        (-1, pl.Int64, pl.UInt64, None),
        (2**7, pl.UInt8, pl.Int8, None),
        (2**15, pl.UInt16, pl.Int16, None),
        (2**31, pl.UInt32, pl.Int32, None),
        (2**63, pl.UInt64, pl.Int64, None),
        (2**7 - 1, pl.UInt8, pl.Int8, 2**7 - 1),
        (2**15 - 1, pl.UInt16, pl.Int16, 2**15 - 1),
        (2**31 - 1, pl.UInt32, pl.Int32, 2**31 - 1),
        (2**63 - 1, pl.UInt64, pl.Int64, 2**63 - 1),
    ],
)
def test_cast_int(
    value: int,
    from_dtype: PolarsDataType,
    to_dtype: PolarsDataType,
    expected_value: Any,
) -> None:
    """Non-strict signed/unsigned casts yield null for out-of-range values."""
    args = [value, from_dtype, to_dtype, False]
    for cast in (_cast_series, _cast_expr, _cast_lit):
        assert cast(*args) == expected_value  # type: ignore[arg-type]


def _cast_series_t(
    val: int | datetime | date | time | timedelta,
    dtype_in: PolarsDataType,
    dtype_out: PolarsDataType,
    strict: bool,
) -> pl.Series:
    """Cast a one-element Series directly and return the resulting Series."""
    return pl.Series("a", [val], dtype=dtype_in).cast(dtype_out, strict=strict)


def _cast_expr_t(
    val: int | datetime | date | time | timedelta,
    dtype_in: PolarsDataType,
    dtype_out: PolarsDataType,
    strict: bool,
) -> pl.Series:
    """Cast a one-element column through an expression and return the Series."""
    return (
        pl.Series("a", [val], dtype=dtype_in)
        .to_frame()
        .select(pl.col("a").cast(dtype_out, strict=strict))
        .to_series()
    )


def _cast_lit_t(
    val: int | datetime | date | time | timedelta,
    dtype_in: PolarsDataType,
    dtype_out: PolarsDataType,
    strict: bool,
) -> pl.Series:
    """Cast a typed literal inside `pl.select` and return the Series."""
    return pl.select(
        pl.lit(val, dtype=dtype_in).cast(dtype_out, strict=strict)
    ).to_series()


def _assert_cast_output(
    out: pl.Series, expected_value: Any, to_dtype: PolarsDataType
) -> None:
    """Check the single item and dtype of a cast result.

    An `expected_value` of None means the cast must have produced a null.
    """
    if expected_value is None:
        assert out.item() is None
    else:
        assert out.item() == expected_value
    assert out.dtype == to_dtype


@pytest.mark.parametrize(
    (
        "value",
        "from_dtype",
        "to_dtype",
        "should_succeed",
        "expected_value",
    ),
    [
        # date to datetime
        (date(1970, 1, 1), pl.Date, pl.Datetime("ms"), True, datetime(1970, 1, 1)),
        (date(1970, 1, 1), pl.Date, pl.Datetime("us"), True, datetime(1970, 1, 1)),
        (date(1970, 1, 1), pl.Date, pl.Datetime("ns"), True, datetime(1970, 1, 1)),
        # datetime to date
        (datetime(1970, 1, 1), pl.Datetime("ms"), pl.Date, True, date(1970, 1, 1)),
        (datetime(1970, 1, 1), pl.Datetime("us"), pl.Date, True, date(1970, 1, 1)),
        (datetime(1970, 1, 1), pl.Datetime("ns"), pl.Date, True, date(1970, 1, 1)),
        # datetime to time
        (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("ms"), pl.Time, True, time(hour=1)),
        (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("us"), pl.Time, True, time(hour=1)),
        (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("ns"), pl.Time, True, time(hour=1)),
        # duration to int
        (timedelta(seconds=1), pl.Duration("ms"), pl.Int32, True, MS_PER_SECOND),
        (timedelta(seconds=1), pl.Duration("us"), pl.Int64, True, US_PER_SECOND),
        (timedelta(seconds=1), pl.Duration("ns"), pl.Int64, True, NS_PER_SECOND),
        # time to duration
        (time(hour=1), pl.Time, pl.Duration("ms"), True, timedelta(hours=1)),
        (time(hour=1), pl.Time, pl.Duration("us"), True, timedelta(hours=1)),
        (time(hour=1), pl.Time, pl.Duration("ns"), True, timedelta(hours=1)),
        # int to date
        (100, pl.UInt8, pl.Date, True, date(1970, 4, 11)),
        (100, pl.UInt16, pl.Date, True, date(1970, 4, 11)),
        (100, pl.UInt32, pl.Date, True, date(1970, 4, 11)),
        (100, pl.UInt64, pl.Date, True, date(1970, 4, 11)),
        (100, pl.Int8, pl.Date, True, date(1970, 4, 11)),
        (100, pl.Int16, pl.Date, True, date(1970, 4, 11)),
        (100, pl.Int32, pl.Date, True, date(1970, 4, 11)),
        (100, pl.Int64, pl.Date, True, date(1970, 4, 11)),
        # failures
        (2**63 - 1, pl.Int64, pl.Date, False, None),
        (-(2**62), pl.Int64, pl.Date, False, None),
        (date(1970, 5, 10), pl.Date, pl.Int8, False, None),
        (date(2149, 6, 7), pl.Date, pl.Int16, False, None),
        (datetime(9999, 12, 31), pl.Datetime, pl.Int8, False, None),
        (datetime(9999, 12, 31), pl.Datetime, pl.Int16, False, None),
    ],
)
def test_strict_cast_temporal(
    value: int | datetime | date | time | timedelta,
    from_dtype: PolarsDataType,
    to_dtype: PolarsDataType,
    should_succeed: bool,
    expected_value: Any,
) -> None:
    """Strict temporal/integer casts return the expected value or raise.

    Every case is exercised through the Series, expression and literal paths.
    """
    args = [value, from_dtype, to_dtype, True]
    for cast in (_cast_series_t, _cast_expr_t, _cast_lit_t):
        if should_succeed:
            out = cast(*args)  # type: ignore[arg-type]
            _assert_cast_output(out, expected_value, to_dtype)
        else:
            with pytest.raises(InvalidOperationError):
                cast(*args)  # type: ignore[arg-type]


@pytest.mark.parametrize(
    (
        "value",
        "from_dtype",
        "to_dtype",
        "expected_value",
    ),
    [
        # date to datetime
        (date(1970, 1, 1), pl.Date, pl.Datetime("ms"), datetime(1970, 1, 1)),
        (date(1970, 1, 1), pl.Date, pl.Datetime("us"), datetime(1970, 1, 1)),
        (date(1970, 1, 1), pl.Date, pl.Datetime("ns"), datetime(1970, 1, 1)),
        # datetime to date
        (datetime(1970, 1, 1), pl.Datetime("ms"), pl.Date, date(1970, 1, 1)),
        (datetime(1970, 1, 1), pl.Datetime("us"), pl.Date, date(1970, 1, 1)),
        (datetime(1970, 1, 1), pl.Datetime("ns"), pl.Date, date(1970, 1, 1)),
        # datetime to time
        (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("ms"), pl.Time, time(hour=1)),
        (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("us"), pl.Time, time(hour=1)),
        (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("ns"), pl.Time, time(hour=1)),
        # duration to int
        (timedelta(seconds=1), pl.Duration("ms"), pl.Int32, MS_PER_SECOND),
        (timedelta(seconds=1), pl.Duration("us"), pl.Int64, US_PER_SECOND),
        (timedelta(seconds=1), pl.Duration("ns"), pl.Int64, NS_PER_SECOND),
        # time to duration
        (time(hour=1), pl.Time, pl.Duration("ms"), timedelta(hours=1)),
        (time(hour=1), pl.Time, pl.Duration("us"), timedelta(hours=1)),
        (time(hour=1), pl.Time, pl.Duration("ns"), timedelta(hours=1)),
        # int to date
        (100, pl.UInt8, pl.Date, date(1970, 4, 11)),
        (100, pl.UInt16, pl.Date, date(1970, 4, 11)),
        (100, pl.UInt32, pl.Date, date(1970, 4, 11)),
        (100, pl.UInt64, pl.Date, date(1970, 4, 11)),
        (100, pl.Int8, pl.Date, date(1970, 4, 11)),
        (100, pl.Int16, pl.Date, date(1970, 4, 11)),
        (100, pl.Int32, pl.Date, date(1970, 4, 11)),
        (100, pl.Int64, pl.Date, date(1970, 4, 11)),
        # failures
        (2**63 - 1, pl.Int64, pl.Date, None),
        (-(2**62), pl.Int64, pl.Date, None),
        (date(1970, 5, 10), pl.Date, pl.Int8, None),
        (date(2149, 6, 7), pl.Date, pl.Int16, None),
        (datetime(9999, 12, 31), pl.Datetime, pl.Int8, None),
        (datetime(9999, 12, 31), pl.Datetime, pl.Int16, None),
    ],
)
def test_cast_temporal(
    value: int | datetime | date | time | timedelta,
    from_dtype: PolarsDataType,
    to_dtype: PolarsDataType,
    expected_value: Any,
) -> None:
    """Non-strict temporal/integer casts yield null where the cast fails."""
    args = [value, from_dtype, to_dtype, False]
    for cast in (_cast_series_t, _cast_expr_t, _cast_lit_t):
        out = cast(*args)  # type: ignore[arg-type]
        _assert_cast_output(out, expected_value, to_dtype)


@pytest.mark.parametrize(
    (
        "value",
        "from_dtype",
        "to_dtype",
        "expected_value",
    ),
    [
        (str(2**7 - 1), pl.String, pl.Int8, 2**7 - 1),
        (str(2**15 - 1), pl.String, pl.Int16, 2**15 - 1),
        (str(2**31 - 1), pl.String, pl.Int32, 2**31 - 1),
        (str(2**63 - 1), pl.String, pl.Int64, 2**63 - 1),
        ("1.0", pl.String, pl.Float32, 1.0),
        ("1.0", pl.String, pl.Float64, 1.0),
        # overflow
        (str(2**7), pl.String, pl.Int8, None),
        (str(2**15), pl.String, pl.Int16, None),
        (str(2**31), pl.String, pl.Int32, None),
        (str(2**63), pl.String, pl.Int64, None),
    ],
)
def test_cast_string(
    value: str,
    from_dtype: PolarsDataType,
    to_dtype: PolarsDataType,
    expected_value: Any,
) -> None:
    """Non-strict string-to-number casts yield null on overflow."""
    args = [value, from_dtype, to_dtype, False]
    for cast in (_cast_series_t, _cast_expr_t, _cast_lit_t):
        out = cast(*args)  # type: ignore[arg-type]
        _assert_cast_output(out, expected_value, to_dtype)


@pytest.mark.parametrize(
    (
        "value",
        "from_dtype",
        "to_dtype",
        "should_succeed",
        "expected_value",
    ),
    [
        (str(2**7 - 1), pl.String, pl.Int8, True, 2**7 - 1),
        (str(2**15 - 1), pl.String, pl.Int16, True, 2**15 - 1),
        (str(2**31 - 1), pl.String, pl.Int32, True, 2**31 - 1),
        (str(2**63 - 1), pl.String, pl.Int64, True, 2**63 - 1),
        ("1.0", pl.String, pl.Float32, True, 1.0),
        ("1.0", pl.String, pl.Float64, True, 1.0),
        # overflow
        (str(2**7), pl.String, pl.Int8, False, None),
        (str(2**15), pl.String, pl.Int16, False, None),
        (str(2**31), pl.String, pl.Int32, False, None),
        (str(2**63), pl.String, pl.Int64, False, None),
    ],
)
def test_strict_cast_string(
    value: str,
    from_dtype: PolarsDataType,
    to_dtype: PolarsDataType,
    should_succeed: bool,
    expected_value: Any,
) -> None:
    """Strict string-to-number casts raise on overflow."""
    args = [value, from_dtype, to_dtype, True]
    for cast in (_cast_series_t, _cast_expr_t, _cast_lit_t):
        if should_succeed:
            out = cast(*args)  # type: ignore[arg-type]
            _assert_cast_output(out, expected_value, to_dtype)
        else:
            with pytest.raises(InvalidOperationError):
                cast(*args)  # type: ignore[arg-type]


@pytest.mark.parametrize(
    "dtype_in",
    [(pl.Categorical), (pl.Enum(["1"]))],
)
@pytest.mark.parametrize(
    "dtype_out",
    [
        pl.String,
        pl.Categorical,
        pl.Enum(["1", "2"]),
    ],
)
def test_cast_categorical_name_retention(
    dtype_in: PolarsDataType, dtype_out: PolarsDataType
) -> None:
    """The Series name survives casts from Categorical/Enum inputs."""
    assert pl.Series("a", ["1"], dtype=dtype_in).cast(dtype_out).name == "a"


def test_cast_date_to_time() -> None:
    """Casting Date to Time raises with an explicit 'not supported' message."""
    s = pl.Series([date(1970, 1, 1), date(2000, 12, 31)])
    msg = "casting from Date to Time not supported"
    with pytest.raises(InvalidOperationError, match=msg):
        s.cast(pl.Time)


def test_cast_time_to_date() -> None:
    """Casting Time to Date raises with an explicit 'not supported' message."""
    s = pl.Series([time(0, 0), time(20, 00)])
    msg = "casting from Time to Date not supported"
    with pytest.raises(InvalidOperationError, match=msg):
        s.cast(pl.Date)


def test_cast_decimal_to_boolean() -> None:
    """Decimal to Boolean maps zero to False and non-zero values to True."""
    s = pl.Series("s", [Decimal("0.0"), Decimal("1.5"), Decimal("-1.5")])
    assert_series_equal(s.cast(pl.Boolean), pl.Series("s", [False, True, True]))

    df = s.to_frame()
    assert_frame_equal(
        df.select(pl.col("s").cast(pl.Boolean)),
        pl.DataFrame({"s": [False, True, True]}),
    )


def test_cast_array_to_different_width() -> None:
    """Casting an Array to an Array of a different width raises."""
    s = pl.Series([[1, 2], [3, 4]], dtype=pl.Array(pl.Int8, 2))
    with pytest.raises(
        InvalidOperationError, match="cannot cast Array to a different width"
    ):
        s.cast(pl.Array(pl.Int16, 3))


def test_cast_decimal_to_decimal_high_precision() -> None:
    """A 22-digit Decimal keeps its value when cast to Decimal(22, 0)."""
    precision = 22
    values = [Decimal("9" * precision)]
    s = pl.Series(values, dtype=pl.Decimal(None, 0))

    target_dtype = pl.Decimal(precision, 0)
    result = s.cast(target_dtype)

    assert result.dtype == target_dtype
    assert result.to_list() == values


@pytest.mark.parametrize("value", [float("inf"), float("nan")])
def test_invalid_cast_float_to_decimal(value: float) -> None:
    """Casting inf or NaN to Decimal raises a conversion error."""
    s = pl.Series([value], dtype=pl.Float64)
    with pytest.raises(
        InvalidOperationError,
        match=r"conversion from `f64` to `decimal\[\*,0\]` failed",
    ):
        s.cast(pl.Decimal)


def test_err_on_time_datetime_cast() -> None:
    """Casting Time to Datetime raises and the message points to `dt.combine`."""
    s = pl.Series([time(10, 0, 0), time(11, 30, 59)])
    with pytest.raises(
        InvalidOperationError,
        match="casting from Time to Datetime\\('μs'\\) not supported; consider using `dt.combine`",
    ):
        s.cast(pl.Datetime)


def test_err_on_invalid_time_zone_cast() -> None:
    """Casting to a Datetime with an unparseable time zone raises ComputeError."""
    s = pl.Series([datetime(2021, 1, 1)])
    with pytest.raises(ComputeError, match=r"unable to parse time zone: 'qwerty'"):
        s.cast(pl.Datetime("us", "qwerty"))


def test_invalid_inner_type_cast_list() -> None:
    """Casting List(Int64) to List(Categorical) raises."""
    s = pl.Series([[-1, 1]])
    with pytest.raises(
        InvalidOperationError,
        match=r"cannot cast List inner type: 'Int64' to Categorical",
    ):
        s.cast(pl.List(pl.Categorical))


@pytest.mark.parametrize(
    ("values", "result"),
    [
        ([[]], [b""]),
        ([[1, 2], [3, 4]], [b"\x01\x02", b"\x03\x04"]),
        ([[1, 2], None, [3, 4]], [b"\x01\x02", None, b"\x03\x04"]),
        (
            [None, [111, 110, 101], [12, None], [116, 119, 111], list(range(256))],
            [
                None,
                b"one",
                # A list with a null in it gets turned into a null:
                None,
                b"two",
                bytes(range(256)),
            ],
        ),
    ],
)
def test_list_uint8_to_bytes(
    values: list[list[int | None] | None], result: list[bytes | None]
) -> None:
    """Non-strict List(UInt8) to Binary casts produce the expected byte strings."""
    s = pl.Series(
        values,
        dtype=pl.List(pl.UInt8()),
    )
    assert s.cast(pl.Binary(), strict=False).to_list() == result


def test_list_uint8_to_bytes_strict() -> None:
    """Strict List(UInt8) to Binary works without nulls and raises with one."""
    series = pl.Series(
        [[1, 2], [3, 4]],
        dtype=pl.List(pl.UInt8()),
    )
    assert series.cast(pl.Binary(), strict=True).to_list() == [b"\x01\x02", b"\x03\x04"]

    series = pl.Series(
        "mycol",
        [[1, 2], [3, None]],
        dtype=pl.List(pl.UInt8()),
    )
    with pytest.raises(
        InvalidOperationError,
        match="conversion from `list\\[u8\\]` to `binary` failed in column 'mycol' for 1 out of 2 values: \\[\\[3, null\\]\\]",
    ):
        series.cast(pl.Binary(), strict=True)


def test_all_null_cast_5826() -> None:
    """An all-null String column casts to Boolean and stays null."""
    df = pl.DataFrame(data=[pl.Series("a", [None], dtype=pl.String)])
    out = df.with_columns(pl.col("a").cast(pl.Boolean))
    assert out.dtypes == [pl.Boolean]
    assert out.item() is None


@pytest.mark.parametrize("dtype", INTEGER_DTYPES)
def test_bool_numeric_supertype(dtype: PolarsDataType) -> None:
    """A boolean sum cast to an integer dtype and divided by `len` gives ~1/3."""
    df = pl.DataFrame({"v": [1, 2, 3, 4, 5, 6]})
    result = df.select((pl.col("v") < 3).sum().cast(dtype) / pl.len())
    # Compare the absolute difference; without `abs` any result that is too
    # small (e.g. 0) would pass the check.
    assert abs(result.item() - 0.3333333) <= 0.00001


@pytest.mark.parametrize("dtype", [pl.String(), pl.String, str])
def test_cast_consistency(dtype: PolarsDataType | PythonDataType) -> None:
    """Column and literal float-to-string casts produce the same text."""
    assert pl.DataFrame().with_columns(a=pl.lit(0.0)).with_columns(
        b=pl.col("a").cast(dtype), c=pl.lit(0.0).cast(dtype)
    ).to_dict(as_series=False) == {"a": [0.0], "b": ["0.0"], "c": ["0.0"]}


def test_cast_int_to_string_unsets_sorted_flag_19424() -> None:
    """Casting a sorted integer Series to String clears the SORTED_ASC flag."""
    s = pl.Series([1, 2]).set_sorted()
    assert s.flags["SORTED_ASC"]
    assert not s.cast(pl.String).flags["SORTED_ASC"]


def test_cast_integer_to_decimal() -> None:
    """Integers cast to Decimal(10, 2) gain two decimal places."""
    s = pl.Series([1, 2, 3])
    result = s.cast(pl.Decimal(10, 2))
    expected = pl.Series(
        "", [Decimal("1.00"), Decimal("2.00"), Decimal("3.00")], pl.Decimal(10, 2)
    )
    assert_series_equal(result, expected)


def test_cast_python_dtypes() -> None:
    """Python builtin types map to the corresponding polars dtypes."""
    s = pl.Series([0, 1])
    assert s.cast(int).dtype == pl.Int64
    assert s.cast(float).dtype == pl.Float64
    assert s.cast(bool).dtype == pl.Boolean
    assert s.cast(str).dtype == pl.String


def test_overflowing_cast_literals_21023() -> None:
    """A wrapping Int64-to-Int8 cast of 128 gives -128 with and without optimizations."""
    for optimizations in [pl.QueryOptFlags(), pl.QueryOptFlags.none()]:
        assert_frame_equal(
            (
                pl.LazyFrame()
                .select(
                    pl.lit(pl.Series([128], dtype=pl.Int64)).cast(
                        pl.Int8, wrap_numerical=True
                    )
                )
                .collect(optimizations=optimizations)
            ),
            pl.Series([-128], dtype=pl.Int8).to_frame(),
        )


@pytest.mark.parametrize("value", [True, False])
@pytest.mark.parametrize(
    "dtype",
    [
        pl.Enum(["a", "b"]),
        pl.Series(["a", "b"], dtype=pl.Categorical).dtype,
    ],
)
def test_invalid_bool_to_cat(value: bool, dtype: PolarsDataType) -> None:
    """Casting Boolean to an Enum or Categorical dtype raises."""
    # Both the Enum and the Categorical target are matched against the same
    # message.
    with pytest.raises(
        InvalidOperationError,
        match="cannot cast Boolean to Categorical",
    ):
        pl.Series([value]).cast(dtype)


@pytest.mark.parametrize(
    ("values", "from_dtype", "to_dtype", "pre_apply"),
    [
        ([["A"]], pl.List(pl.String), pl.List(pl.Int8), None),
        ([["A"]], pl.Array(pl.String, 1), pl.List(pl.Int8), None),
        ([[["A"]]], pl.List(pl.List(pl.String)), pl.List(pl.List(pl.Int8)), None),
        (
            [
                {"x": "1", "y": "2"},
                {"x": "A", "y": "B"},
                {"x": "3", "y": "4"},
                {"x": "X", "y": "Y"},
                {"x": "5", "y": "6"},
            ],
            pl.Struct(
                {
                    "x": pl.String,
                    "y": pl.String,
                }
            ),
            pl.Struct(
                {
                    "x": pl.Int8,
                    "y": pl.Int32,
                }
            ),
            None,
        ),
    ],
)
def test_nested_strict_casts_failing(
    values: list[Any],
    from_dtype: pl.DataType,
    to_dtype: pl.DataType,
    pre_apply: Callable[[pl.Series], pl.Series] | None,
) -> None:
    """Strict casts of nested data raise when an inner value cannot convert."""
    s = pl.Series(values, dtype=from_dtype)

    if pre_apply is not None:
        s = pre_apply(s)

    with pytest.raises(
        pl.exceptions.InvalidOperationError,
        match=r"conversion from",
    ):
        s.cast(to_dtype)


@pytest.mark.parametrize(
    ("values", "from_dtype", "pre_apply", "to"),
    [
        (
            [["A"], ["1"], ["2"]],
            pl.List(pl.String),
            lambda s: s.slice(1, 2),
            pl.Series([[1], [2]]),
        ),
        (
            [["1"], ["A"], ["2"], ["B"], ["3"]],
            pl.List(pl.String),
            lambda s: s.filter(pl.Series([True, False, True, False, True])),
            pl.Series([[1], [2], [3]]),
        ),
        (
            [
                {"x": "1", "y": "2"},
                {"x": "A", "y": "B"},
                {"x": "3", "y": "4"},
                {"x": "X", "y": "Y"},
                {"x": "5", "y": "6"},
            ],
            pl.Struct(
                {
                    "x": pl.String,
                    "y": pl.String,
                }
            ),
            lambda s: s.filter(pl.Series([True, False, True, False, True])),
            pl.Series(
                [
                    {"x": 1, "y": 2},
                    {"x": 3, "y": 4},
                    {"x": 5, "y": 6},
                ]
            ),
        ),
        (
            [
                {"x": "1", "y": "2"},
                {"x": "A", "y": "B"},
                {"x": "3", "y": "4"},
                {"x": "X", "y": "Y"},
                {"x": "5", "y": "6"},
            ],
            pl.Struct(
                {
                    "x": pl.String,
                    "y": pl.String,
                }
            ),
            lambda s: pl.select(
                pl.when(pl.Series([True, False, True, False, True])).then(s)
            ).to_series(),
            pl.Series(
                [
                    {"x": 1, "y": 2},
                    None,
                    {"x": 3, "y": 4},
                    None,
                    {"x": 5, "y": 6},
                ]
            ),
        ),
    ],
)
def test_nested_strict_casts_succeeds(
    values: list[Any],
    from_dtype: pl.DataType,
    pre_apply: Callable[[pl.Series], pl.Series] | None,
    to: pl.Series,
) -> None:
    """Strict nested casts succeed once unconvertible rows are removed or nulled.

    `pre_apply` slices, filters or masks the input before the cast.
    """
    s = pl.Series(values, dtype=from_dtype)

    if pre_apply is not None:
        s = pre_apply(s)

    assert_series_equal(
        s.cast(to.dtype),
        to,
    )


def test_nested_struct_cast_22744() -> None:
    """Casting a nested struct to a wider schema adds the new field as null."""
    s = pl.Series(
        "x",
        [{"attrs": {"class": "a"}}],
    )

    expected = pl.select(
        pl.lit(s).struct.with_fields(
            pl.field("attrs").struct.with_fields(
                [pl.field("class"), pl.lit(None, dtype=pl.String()).alias("other")]
            )
        )
    )

    assert_series_equal(
        s.cast(
            pl.Struct({"attrs": pl.Struct({"class": pl.String, "other": pl.String})})
        ),
        expected.to_series(),
    )
    assert_frame_equal(
        pl.DataFrame([s]).cast(
            {
                "x": pl.Struct(
                    {"attrs": pl.Struct({"class": pl.String, "other": pl.String})}
                )
            }
        ),
        expected,
    )


def test_cast_to_self_is_pruned() -> None:
    """The explained plan shows a plain alias for an Int64-to-Int64 cast."""
    q = pl.LazyFrame({"x": 1}, schema={"x": pl.Int64}).with_columns(
        y=pl.col("x").cast(pl.Int64)
    )

    plan = q.explain()
    assert 'col("x").alias("y")' in plan

    assert_frame_equal(q.collect(), pl.DataFrame({"x": 1, "y": 1}))


@pytest.mark.parametrize(
    ("s", "to", "should_fail"),
    [
        (
            pl.Series([datetime(2025, 1, 1)]),
            pl.Datetime("ns"),
            False,
        ),
        (
            pl.Series([datetime(9999, 1, 1)]),
            pl.Datetime("ns"),
            True,
        ),
        (
            pl.Series([datetime(2025, 1, 1), datetime(9999, 1, 1)]),
            pl.Datetime("ns"),
            True,
        ),
        (
            pl.Series([[datetime(2025, 1, 1)], [datetime(9999, 1, 1)]]),
            pl.List(pl.Datetime("ns")),
            True,
        ),
        # lower date limit for nanosecond
        (pl.Series([date(1677, 9, 22)]), pl.Datetime("ns"), False),
        (pl.Series([date(1677, 9, 21)]), pl.Datetime("ns"), True),
        # upper date limit for nanosecond
        (pl.Series([date(2262, 4, 11)]), pl.Datetime("ns"), False),
        (pl.Series([date(2262, 4, 12)]), pl.Datetime("ns"), True),
    ],
)
def test_cast_temporals_overflow_16039(
    s: pl.Series, to: pl.DataType, should_fail: bool
) -> None:
    """Casts to nanosecond Datetime raise for the out-of-range cases listed."""
    if should_fail:
        with pytest.raises(
            pl.exceptions.InvalidOperationError, match="conversion from"
        ):
            s.cast(to)
    else:
        # Only checks that the in-range cast does not raise.
        s.cast(to)


@pytest.mark.parametrize("dtype", NUMERIC_DTYPES)
def test_prune_superfluous_cast(dtype: PolarsDataType) -> None:
    """A cast to the column's own dtype leaves no `strict_cast` in the plan."""
    lf = pl.LazyFrame({"a": [1, 2, 3]}, schema={"a": dtype})
    result = lf.select(pl.col("a").cast(dtype))
    assert "strict_cast" not in result.explain()


def test_not_prune_necessary_cast() -> None:
    """A UInt16-to-UInt8 cast stays in the plan as `strict_cast`."""
    lf = pl.LazyFrame({"a": [1, 2, 3]}, schema={"a": pl.UInt16})
    result = lf.select(pl.col("a").cast(pl.UInt8))
    assert "strict_cast" in result.explain()


@pytest.mark.parametrize("target_dtype", NUMERIC_DTYPES)
@pytest.mark.parametrize("inner_dtype", NUMERIC_DTYPES)
@pytest.mark.parametrize("op", [operator.mul, operator.truediv])
def test_cast_optimizer_in_list_eval_23924(
    inner_dtype: PolarsDataType,
    target_dtype: PolarsDataType,
    op: Callable[[pl.Expr, pl.Expr], pl.Expr],
) -> None:
    """The lazy schema matches the collected schema for casts inside `list.eval`."""
    if target_dtype in INTEGER_DTYPES:
        df = pl.Series("a", [[1]], dtype=pl.List(target_dtype)).to_frame()
    else:
        df = pl.Series("a", [[1.0]], dtype=pl.List(target_dtype)).to_frame()
    q = df.lazy().select(
        pl.col("a").list.eval(
            (op(pl.element(), pl.element().cast(inner_dtype))).cast(target_dtype)
        )
    )
    assert q.collect_schema() == q.collect().schema


def test_lit_cast_arithmetic_23677() -> None:
    """Dividing a Float32 column by an Int32 literal collects as Float64."""
    df = pl.DataFrame({"a": [1]}, schema={"a": pl.Float32})
    q = df.lazy().select(pl.col("a") / pl.lit(1, pl.Int32))
    expected = pl.Schema({"a": pl.Float64})
    assert q.collect().schema == expected


@pytest.mark.parametrize("col_dtype", NUMERIC_DTYPES)
@pytest.mark.parametrize("lit_dtype", NUMERIC_DTYPES)
@pytest.mark.parametrize("op", [operator.mul, operator.truediv])
def test_lit_cast_arithmetic_matrix_schema(
    col_dtype: PolarsDataType,
    lit_dtype: PolarsDataType,
    op: Callable[[pl.Expr, pl.Expr], pl.Expr],
) -> None:
    """The lazy schema matches the collected schema for column-by-literal arithmetic."""
    df = pl.DataFrame({"a": [1]}, schema={"a": col_dtype})
    q = df.lazy().select(op(pl.col("a"), pl.lit(1, lit_dtype)))
    assert q.collect_schema() == q.collect().schema