# Path: blob/main/py-polars/tests/unit/operations/test_cast.py
# 8424 views
"""Tests for casting Series, expressions and literals between Polars dtypes."""

from __future__ import annotations

import operator
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from typing import TYPE_CHECKING, Any

import pytest

import polars as pl
from polars._utils.constants import MS_PER_SECOND, NS_PER_SECOND, US_PER_SECOND
from polars.exceptions import ComputeError, InvalidOperationError
from polars.testing import assert_frame_equal
from polars.testing.asserts.series import assert_series_equal
from tests.unit.conftest import INTEGER_DTYPES, NUMERIC_DTYPES

if TYPE_CHECKING:
    from collections.abc import Callable

    from polars._typing import PolarsDataType, PythonDataType


# ---------------------------------------------------------------------------
# String -> temporal / numeric parsing
# ---------------------------------------------------------------------------


@pytest.mark.parametrize("dtype", [pl.Date(), pl.Date, date])
def test_string_date(dtype: PolarsDataType | PythonDataType) -> None:
    """An ISO date string casts to Date via a dtype instance, class, or `date`."""
    df = pl.DataFrame({"x1": ["2021-01-01"]}).with_columns(
        **{"x1-date": pl.col("x1").cast(dtype)}
    )
    expected = pl.DataFrame({"x1-date": [date(2021, 1, 1)]})
    out = df.select(pl.col("x1-date"))
    assert_frame_equal(expected, out)


def test_invalid_string_date() -> None:
    """Casting a malformed date string to Date raises InvalidOperationError."""
    df = pl.DataFrame({"x1": ["2021-01-aa"]})

    with pytest.raises(InvalidOperationError):
        df.with_columns(**{"x1-date": pl.col("x1").cast(pl.Date)})


def test_string_datetime() -> None:
    """ISO datetime strings cast to Datetime in each of the ns/ms/us time units."""
    df = pl.DataFrame(
        {"x1": ["2021-12-19T00:39:57", "2022-12-19T16:39:57"]}
    ).with_columns(
        **{
            "x1-datetime-ns": pl.col("x1").cast(pl.Datetime(time_unit="ns")),
            "x1-datetime-ms": pl.col("x1").cast(pl.Datetime(time_unit="ms")),
            "x1-datetime-us": pl.col("x1").cast(pl.Datetime(time_unit="us")),
        }
    )
    first_row = datetime(year=2021, month=12, day=19, hour=0, minute=39, second=57)
    second_row = datetime(year=2022, month=12, day=19, hour=16, minute=39, second=57)
    expected = pl.DataFrame(
        {
            "x1-datetime-ns": [first_row, second_row],
            "x1-datetime-ms": [first_row, second_row],
            "x1-datetime-us": [first_row, second_row],
        }
    ).select(
        pl.col("x1-datetime-ns").dt.cast_time_unit("ns"),
        pl.col("x1-datetime-ms").dt.cast_time_unit("ms"),
        pl.col("x1-datetime-us").dt.cast_time_unit("us"),
    )

    out = df.select(
        pl.col("x1-datetime-ns"), pl.col("x1-datetime-ms"), pl.col("x1-datetime-us")
    )
    assert_frame_equal(expected, out)


def test_invalid_string_datetime() -> None:
    """Unlike test_string_datetime, these strings use a space separator; cast raises."""
    df = pl.DataFrame({"x1": ["2021-12-19 00:39:57", "2022-12-19 16:39:57"]})
    with pytest.raises(InvalidOperationError):
        df.with_columns(
            **{"x1-datetime-ns": pl.col("x1").cast(pl.Datetime(time_unit="ns"))}
        )


def test_string_datetime_timezone() -> None:
    """Offset-bearing strings are converted into the target Datetime time zone."""
    ccs_tz = "America/Caracas"
    stg_tz = "America/Santiago"
    utc_tz = "UTC"
    df = pl.DataFrame(
        {"x1": ["1996-12-19T16:39:57 +00:00", "2022-12-19T00:39:57 +00:00"]}
    ).with_columns(
        **{
            "x1-datetime-ns": pl.col("x1").cast(
                pl.Datetime(time_unit="ns", time_zone=ccs_tz)
            ),
            "x1-datetime-ms": pl.col("x1").cast(
                pl.Datetime(time_unit="ms", time_zone=stg_tz)
            ),
            "x1-datetime-us": pl.col("x1").cast(
                pl.Datetime(time_unit="us", time_zone=utc_tz)
            ),
        }
    )

    # Expected wall-clock times in each target zone, later tagged via
    # replace_time_zone (no further conversion).
    expected = pl.DataFrame(
        {
            "x1-datetime-ns": [
                datetime(year=1996, month=12, day=19, hour=12, minute=39, second=57),
                datetime(year=2022, month=12, day=18, hour=20, minute=39, second=57),
            ],
            "x1-datetime-ms": [
                datetime(year=1996, month=12, day=19, hour=13, minute=39, second=57),
                datetime(year=2022, month=12, day=18, hour=21, minute=39, second=57),
            ],
            "x1-datetime-us": [
                datetime(year=1996, month=12, day=19, hour=16, minute=39, second=57),
                datetime(year=2022, month=12, day=19, hour=0, minute=39, second=57),
            ],
        }
    ).select(
        pl.col("x1-datetime-ns").dt.cast_time_unit("ns").dt.replace_time_zone(ccs_tz),
        pl.col("x1-datetime-ms").dt.cast_time_unit("ms").dt.replace_time_zone(stg_tz),
        pl.col("x1-datetime-us").dt.cast_time_unit("us").dt.replace_time_zone(utc_tz),
    )

    out = df.select(
        pl.col("x1-datetime-ns"), pl.col("x1-datetime-ms"), pl.col("x1-datetime-us")
    )

    assert_frame_equal(expected, out)


@pytest.mark.parametrize("dtype", [pl.Int8, pl.Int16, pl.Int32, pl.Int64])
def test_leading_plus_zero_int(dtype: pl.DataType) -> None:
    """Signed-int parsing accepts explicit signs and arbitrarily many leading zeros."""
    s_int = pl.Series(
        [
            "-000000000000002",
            "-1",
            "-0",
            "0",
            "+0",
            "1",
            "+1",
            "0000000000000000000002",
            "+000000000000000000003",
        ]
    )
    assert_series_equal(
        s_int.cast(dtype), pl.Series([-2, -1, 0, 0, 0, 1, 1, 2, 3], dtype=dtype)
    )


@pytest.mark.parametrize("dtype", [pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64])
def test_leading_plus_zero_uint(dtype: pl.DataType) -> None:
    """Unsigned-int parsing accepts a leading '+' and leading zeros."""
    s_int = pl.Series(
        ["0", "+0", "1", "+1", "0000000000000000000002", "+000000000000000000003"]
    )
    assert_series_equal(s_int.cast(dtype), pl.Series([0, 0, 1, 1, 2, 3], dtype=dtype))


@pytest.mark.parametrize("dtype", [pl.Float32, pl.Float64])
def test_leading_plus_zero_float(dtype: pl.DataType) -> None:
    """Float parsing accepts signs, leading zeros, and bare '.5' / '0.' forms."""
    s_float = pl.Series(
        [
            "-000000000000002.0",
            "-1.0",
            "-.5",
            "-0.0",
            "0.",
            "+0",
            "+.5",
            "1",
            "+1",
            "0000000000000000000002",
            "+000000000000000000003",
        ]
    )
    assert_series_equal(
        s_float.cast(dtype),
        pl.Series(
            [-2.0, -1.0, -0.5, 0.0, 0.0, 0.0, 0.5, 1.0, 1.0, 2.0, 3.0], dtype=dtype
        ),
    )


# ---------------------------------------------------------------------------
# Helpers: run the same cast through the Series, expression and literal paths
# ---------------------------------------------------------------------------


def _cast_series(
    val: int | datetime | date | time | timedelta,
    dtype_in: PolarsDataType,
    dtype_out: PolarsDataType,
    strict: bool,
) -> int | datetime | date | time | timedelta | None:
    """Cast a one-element Series with `Series.cast` and return the scalar."""
    return pl.Series("a", [val], dtype=dtype_in).cast(dtype_out, strict=strict).item()  # type: ignore[no-any-return]


def _cast_expr(
    val: int | datetime | date | time | timedelta,
    dtype_in: PolarsDataType,
    dtype_out: PolarsDataType,
    strict: bool,
) -> int | datetime | date | time | timedelta | None:
    """Cast a one-row column with `Expr.cast` inside a select; return the scalar."""
    return (  # type: ignore[no-any-return]
        pl.Series("a", [val], dtype=dtype_in)
        .to_frame()
        .select(pl.col("a").cast(dtype_out, strict=strict))
        .item()
    )


def _cast_lit(
    val: int | datetime | date | time | timedelta,
    dtype_in: PolarsDataType,
    dtype_out: PolarsDataType,
    strict: bool,
) -> int | datetime | date | time | timedelta | None:
    """Cast a typed `pl.lit` value and return the scalar."""
    return pl.select(pl.lit(val, dtype=dtype_in).cast(dtype_out, strict=strict)).item()  # type: ignore[no-any-return]


@pytest.mark.parametrize(
    ("value", "from_dtype", "to_dtype", "should_succeed", "expected_value"),
    [
        (-1, pl.Int8, pl.UInt8, False, None),
        (-1, pl.Int16, pl.UInt16, False, None),
        (-1, pl.Int32, pl.UInt32, False, None),
        (-1, pl.Int64, pl.UInt64, False, None),
        (2**7, pl.UInt8, pl.Int8, False, None),
        (2**15, pl.UInt16, pl.Int16, False, None),
        (2**31, pl.UInt32, pl.Int32, False, None),
        (2**63, pl.UInt64, pl.Int64, False, None),
        (2**7 - 1, pl.UInt8, pl.Int8, True, 2**7 - 1),
        (2**15 - 1, pl.UInt16, pl.Int16, True, 2**15 - 1),
        (2**31 - 1, pl.UInt32, pl.Int32, True, 2**31 - 1),
        (2**63 - 1, pl.UInt64, pl.Int64, True, 2**63 - 1),
    ],
)
def test_strict_cast_int(
    value: int,
    from_dtype: PolarsDataType,
    to_dtype: PolarsDataType,
    should_succeed: bool,
    expected_value: Any,
) -> None:
    """Strict signed<->unsigned casts succeed in range and raise on overflow."""
    for cast_fn in (_cast_series, _cast_expr, _cast_lit):
        if should_succeed:
            assert cast_fn(value, from_dtype, to_dtype, True) == expected_value
        else:
            with pytest.raises(InvalidOperationError):
                cast_fn(value, from_dtype, to_dtype, True)


@pytest.mark.parametrize(
    ("value", "from_dtype", "to_dtype", "expected_value"),
    [
        (-1, pl.Int8, pl.UInt8, None),
        (-1, pl.Int16, pl.UInt16, None),
        (-1, pl.Int32, pl.UInt32, None),
        (-1, pl.Int64, pl.UInt64, None),
        (2**7, pl.UInt8, pl.Int8, None),
        (2**15, pl.UInt16, pl.Int16, None),
        (2**31, pl.UInt32, pl.Int32, None),
        (2**63, pl.UInt64, pl.Int64, None),
        (2**7 - 1, pl.UInt8, pl.Int8, 2**7 - 1),
        (2**15 - 1, pl.UInt16, pl.Int16, 2**15 - 1),
        (2**31 - 1, pl.UInt32, pl.Int32, 2**31 - 1),
        (2**63 - 1, pl.UInt64, pl.Int64, 2**63 - 1),
    ],
)
def test_cast_int(
    value: int,
    from_dtype: PolarsDataType,
    to_dtype: PolarsDataType,
    expected_value: Any,
) -> None:
    """Non-strict signed<->unsigned casts yield null (None) on overflow."""
    for cast_fn in (_cast_series, _cast_expr, _cast_lit):
        assert cast_fn(value, from_dtype, to_dtype, False) == expected_value


def _cast_series_t(
    val: int | datetime | date | time | timedelta,
    dtype_in: PolarsDataType,
    dtype_out: PolarsDataType,
    strict: bool,
) -> pl.Series:
    """Like `_cast_series`, but return the resulting Series (to inspect its dtype)."""
    return pl.Series("a", [val], dtype=dtype_in).cast(dtype_out, strict=strict)


def _cast_expr_t(
    val: int | datetime | date | time | timedelta,
    dtype_in: PolarsDataType,
    dtype_out: PolarsDataType,
    strict: bool,
) -> pl.Series:
    """Like `_cast_expr`, but return the resulting Series (to inspect its dtype)."""
    return (
        pl.Series("a", [val], dtype=dtype_in)
        .to_frame()
        .select(pl.col("a").cast(dtype_out, strict=strict))
        .to_series()
    )


def _cast_lit_t(
    val: int | datetime | date | time | timedelta,
    dtype_in: PolarsDataType,
    dtype_out: PolarsDataType,
    strict: bool,
) -> pl.Series:
    """Like `_cast_lit`, but return the resulting Series (to inspect its dtype)."""
    return pl.select(
        pl.lit(val, dtype=dtype_in).cast(dtype_out, strict=strict)
    ).to_series()


def _assert_cast_paths(
    value: Any,
    from_dtype: PolarsDataType,
    to_dtype: PolarsDataType,
    expected_value: Any,
    *,
    strict: bool,
) -> None:
    """
    Cast `value` via the Series, expression and literal paths and check each.

    Every path must produce a Series of dtype `to_dtype` whose single item equals
    `expected_value`, or is null when `expected_value` is None.
    """
    for cast_fn in (_cast_series_t, _cast_expr_t, _cast_lit_t):
        out = cast_fn(value, from_dtype, to_dtype, strict)
        if expected_value is None:
            assert out.item() is None
        else:
            assert out.item() == expected_value
        assert out.dtype == to_dtype


def _assert_strict_cast_paths_raise(
    value: Any,
    from_dtype: PolarsDataType,
    to_dtype: PolarsDataType,
) -> None:
    """Check that a strict cast of `value` raises on the Series/expr/lit paths."""
    for cast_fn in (_cast_series_t, _cast_expr_t, _cast_lit_t):
        with pytest.raises(InvalidOperationError):
            cast_fn(value, from_dtype, to_dtype, True)


# ---------------------------------------------------------------------------
# Temporal and string casts (strict and non-strict)
# ---------------------------------------------------------------------------


@pytest.mark.parametrize(
    (
        "value",
        "from_dtype",
        "to_dtype",
        "should_succeed",
        "expected_value",
    ),
    [
        # date to datetime
        (date(1970, 1, 1), pl.Date, pl.Datetime("ms"), True, datetime(1970, 1, 1)),
        (date(1970, 1, 1), pl.Date, pl.Datetime("us"), True, datetime(1970, 1, 1)),
        (date(1970, 1, 1), pl.Date, pl.Datetime("ns"), True, datetime(1970, 1, 1)),
        # datetime to date
        (datetime(1970, 1, 1), pl.Datetime("ms"), pl.Date, True, date(1970, 1, 1)),
        (datetime(1970, 1, 1), pl.Datetime("us"), pl.Date, True, date(1970, 1, 1)),
        (datetime(1970, 1, 1), pl.Datetime("ns"), pl.Date, True, date(1970, 1, 1)),
        # datetime to time
        (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("ms"), pl.Time, True, time(hour=1)),
        (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("us"), pl.Time, True, time(hour=1)),
        (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("ns"), pl.Time, True, time(hour=1)),
        # duration to int
        (timedelta(seconds=1), pl.Duration("ms"), pl.Int32, True, MS_PER_SECOND),
        (timedelta(seconds=1), pl.Duration("us"), pl.Int64, True, US_PER_SECOND),
        (timedelta(seconds=1), pl.Duration("ns"), pl.Int64, True, NS_PER_SECOND),
        # time to duration
        (time(hour=1), pl.Time, pl.Duration("ms"), True, timedelta(hours=1)),
        (time(hour=1), pl.Time, pl.Duration("us"), True, timedelta(hours=1)),
        (time(hour=1), pl.Time, pl.Duration("ns"), True, timedelta(hours=1)),
        # int to date
        (100, pl.UInt8, pl.Date, True, date(1970, 4, 11)),
        (100, pl.UInt16, pl.Date, True, date(1970, 4, 11)),
        (100, pl.UInt32, pl.Date, True, date(1970, 4, 11)),
        (100, pl.UInt64, pl.Date, True, date(1970, 4, 11)),
        (100, pl.Int8, pl.Date, True, date(1970, 4, 11)),
        (100, pl.Int16, pl.Date, True, date(1970, 4, 11)),
        (100, pl.Int32, pl.Date, True, date(1970, 4, 11)),
        (100, pl.Int64, pl.Date, True, date(1970, 4, 11)),
        # failures
        (2**63 - 1, pl.Int64, pl.Date, False, None),
        (-(2**62), pl.Int64, pl.Date, False, None),
        (date(1970, 5, 10), pl.Date, pl.Int8, False, None),
        (date(2149, 6, 7), pl.Date, pl.Int16, False, None),
        (datetime(9999, 12, 31), pl.Datetime, pl.Int8, False, None),
        (datetime(9999, 12, 31), pl.Datetime, pl.Int16, False, None),
    ],
)
def test_strict_cast_temporal(
    value: Any,
    from_dtype: PolarsDataType,
    to_dtype: PolarsDataType,
    should_succeed: bool,
    expected_value: Any,
) -> None:
    """Strict temporal casts return the expected value/dtype or raise on overflow."""
    if should_succeed:
        _assert_cast_paths(value, from_dtype, to_dtype, expected_value, strict=True)
    else:
        _assert_strict_cast_paths_raise(value, from_dtype, to_dtype)


@pytest.mark.parametrize(
    (
        "value",
        "from_dtype",
        "to_dtype",
        "expected_value",
    ),
    [
        # date to datetime
        (date(1970, 1, 1), pl.Date, pl.Datetime("ms"), datetime(1970, 1, 1)),
        (date(1970, 1, 1), pl.Date, pl.Datetime("us"), datetime(1970, 1, 1)),
        (date(1970, 1, 1), pl.Date, pl.Datetime("ns"), datetime(1970, 1, 1)),
        # datetime to date
        (datetime(1970, 1, 1), pl.Datetime("ms"), pl.Date, date(1970, 1, 1)),
        (datetime(1970, 1, 1), pl.Datetime("us"), pl.Date, date(1970, 1, 1)),
        (datetime(1970, 1, 1), pl.Datetime("ns"), pl.Date, date(1970, 1, 1)),
        # datetime to time
        (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("ms"), pl.Time, time(hour=1)),
        (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("us"), pl.Time, time(hour=1)),
        (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("ns"), pl.Time, time(hour=1)),
        # duration to int
        (timedelta(seconds=1), pl.Duration("ms"), pl.Int32, MS_PER_SECOND),
        (timedelta(seconds=1), pl.Duration("us"), pl.Int64, US_PER_SECOND),
        (timedelta(seconds=1), pl.Duration("ns"), pl.Int64, NS_PER_SECOND),
        # time to duration
        (time(hour=1), pl.Time, pl.Duration("ms"), timedelta(hours=1)),
        (time(hour=1), pl.Time, pl.Duration("us"), timedelta(hours=1)),
        (time(hour=1), pl.Time, pl.Duration("ns"), timedelta(hours=1)),
        # int to date
        (100, pl.UInt8, pl.Date, date(1970, 4, 11)),
        (100, pl.UInt16, pl.Date, date(1970, 4, 11)),
        (100, pl.UInt32, pl.Date, date(1970, 4, 11)),
        (100, pl.UInt64, pl.Date, date(1970, 4, 11)),
        (100, pl.Int8, pl.Date, date(1970, 4, 11)),
        (100, pl.Int16, pl.Date, date(1970, 4, 11)),
        (100, pl.Int32, pl.Date, date(1970, 4, 11)),
        (100, pl.Int64, pl.Date, date(1970, 4, 11)),
        # failures
        (2**63 - 1, pl.Int64, pl.Date, None),
        (-(2**62), pl.Int64, pl.Date, None),
        (date(1970, 5, 10), pl.Date, pl.Int8, None),
        (date(2149, 6, 7), pl.Date, pl.Int16, None),
        (datetime(9999, 12, 31), pl.Datetime, pl.Int8, None),
        (datetime(9999, 12, 31), pl.Datetime, pl.Int16, None),
    ],
)
def test_cast_temporal(
    value: Any,
    from_dtype: PolarsDataType,
    to_dtype: PolarsDataType,
    expected_value: Any,
) -> None:
    """Non-strict temporal casts yield null (with the target dtype) on overflow."""
    _assert_cast_paths(value, from_dtype, to_dtype, expected_value, strict=False)


@pytest.mark.parametrize(
    (
        "value",
        "from_dtype",
        "to_dtype",
        "expected_value",
    ),
    [
        (str(2**7 - 1), pl.String, pl.Int8, 2**7 - 1),
        (str(2**15 - 1), pl.String, pl.Int16, 2**15 - 1),
        (str(2**31 - 1), pl.String, pl.Int32, 2**31 - 1),
        (str(2**63 - 1), pl.String, pl.Int64, 2**63 - 1),
        ("1.0", pl.String, pl.Float32, 1.0),
        ("1.0", pl.String, pl.Float64, 1.0),
        # overflow
        (str(2**7), pl.String, pl.Int8, None),
        (str(2**15), pl.String, pl.Int16, None),
        (str(2**31), pl.String, pl.Int32, None),
        (str(2**63), pl.String, pl.Int64, None),
    ],
)
def test_cast_string(
    value: str,
    from_dtype: PolarsDataType,
    to_dtype: PolarsDataType,
    expected_value: Any,
) -> None:
    """Non-strict String->numeric casts yield null when the number overflows."""
    _assert_cast_paths(value, from_dtype, to_dtype, expected_value, strict=False)


@pytest.mark.parametrize(
    (
        "value",
        "from_dtype",
        "to_dtype",
        "should_succeed",
        "expected_value",
    ),
    [
        (str(2**7 - 1), pl.String, pl.Int8, True, 2**7 - 1),
        (str(2**15 - 1), pl.String, pl.Int16, True, 2**15 - 1),
        (str(2**31 - 1), pl.String, pl.Int32, True, 2**31 - 1),
        (str(2**63 - 1), pl.String, pl.Int64, True, 2**63 - 1),
        ("1.0", pl.String, pl.Float32, True, 1.0),
        ("1.0", pl.String, pl.Float64, True, 1.0),
        # overflow
        (str(2**7), pl.String, pl.Int8, False, None),
        (str(2**15), pl.String, pl.Int16, False, None),
        (str(2**31), pl.String, pl.Int32, False, None),
        (str(2**63), pl.String, pl.Int64, False, None),
    ],
)
def test_strict_cast_string(
    value: str,
    from_dtype: PolarsDataType,
    to_dtype: PolarsDataType,
    should_succeed: bool,
    expected_value: Any,
) -> None:
    """Strict String->numeric casts raise InvalidOperationError on overflow."""
    if should_succeed:
        _assert_cast_paths(value, from_dtype, to_dtype, expected_value, strict=True)
    else:
        _assert_strict_cast_paths_raise(value, from_dtype, to_dtype)


# ---------------------------------------------------------------------------
# Miscellaneous casts and error messages
# ---------------------------------------------------------------------------


@pytest.mark.parametrize(
    "dtype_in",
    [pl.Categorical, pl.Enum(["1"])],
)
@pytest.mark.parametrize(
    "dtype_out",
    [
        pl.String,
        pl.Categorical,
        pl.Enum(["1", "2"]),
    ],
)
def test_cast_categorical_name_retention(
    dtype_in: PolarsDataType, dtype_out: PolarsDataType
) -> None:
    """Casting from Categorical/Enum keeps the Series name."""
    assert pl.Series("a", ["1"], dtype=dtype_in).cast(dtype_out).name == "a"


def test_cast_date_to_time() -> None:
    """Date -> Time is rejected with a descriptive error."""
    s = pl.Series([date(1970, 1, 1), date(2000, 12, 31)])
    msg = "casting from Date to Time not supported"
    with pytest.raises(InvalidOperationError, match=msg):
        s.cast(pl.Time)


def test_cast_time_to_date() -> None:
    """Time -> Date is rejected with a descriptive error."""
    s = pl.Series([time(0, 0), time(20, 00)])
    msg = "casting from Time to Date not supported"
    with pytest.raises(InvalidOperationError, match=msg):
        s.cast(pl.Date)


def test_cast_decimal_to_boolean() -> None:
    """Decimal -> Boolean maps zero to False and any non-zero value to True."""
    s = pl.Series("s", [Decimal("0.0"), Decimal("1.5"), Decimal("-1.5")])
    assert_series_equal(s.cast(pl.Boolean), pl.Series("s", [False, True, True]))

    df = s.to_frame()
    assert_frame_equal(
        df.select(pl.col("s").cast(pl.Boolean)),
        pl.DataFrame({"s": [False, True, True]}),
    )


def test_cast_array_to_different_width() -> None:
    """Array casts must keep the same width."""
    s = pl.Series([[1, 2], [3, 4]], dtype=pl.Array(pl.Int8, 2))
    with pytest.raises(
        InvalidOperationError, match="cannot cast Array to a different width"
    ):
        s.cast(pl.Array(pl.Int16, 3))


def test_cast_decimal_to_decimal_high_precision() -> None:
    """A 22-digit decimal survives a cast to an explicit Decimal(22, 0)."""
    precision = 22
    values = [Decimal("9" * precision)]
    s = pl.Series(values, dtype=pl.Decimal(None, 0))

    target_dtype = pl.Decimal(precision, 0)
    result = s.cast(target_dtype)

    assert result.dtype == target_dtype
    assert result.to_list() == values


@pytest.mark.parametrize("value", [float("inf"), float("nan")])
def test_invalid_cast_float_to_decimal(value: float) -> None:
    """Non-finite floats cannot be cast to Decimal."""
    s = pl.Series([value], dtype=pl.Float64)
    with pytest.raises(
        InvalidOperationError,
        match=r"conversion from `f64` to `decimal\[10,2\]` failed",
    ):
        s.cast(pl.Decimal(10, 2))


def test_err_on_time_datetime_cast() -> None:
    """Time -> Datetime is rejected and the error suggests `dt.combine`."""
    s = pl.Series([time(10, 0, 0), time(11, 30, 59)])
    with pytest.raises(
        InvalidOperationError,
        match=r"casting from Time to Datetime\('μs'\) not supported; consider using `dt\.combine`",
    ):
        s.cast(pl.Datetime)


def test_err_on_invalid_time_zone_cast() -> None:
    """An unknown time zone name in the target dtype raises ComputeError."""
    s = pl.Series([datetime(2021, 1, 1)])
    with pytest.raises(ComputeError, match=r"unable to parse time zone: 'qwerty'"):
        s.cast(pl.Datetime("us", "qwerty"))


def test_invalid_inner_type_cast_list() -> None:
    """List[Int64] -> List[Categorical] reports the offending inner type."""
    s = pl.Series([[-1, 1]])
    with pytest.raises(
        InvalidOperationError,
        match=r"cannot cast List inner type: 'Int64' to Categorical",
    ):
        s.cast(pl.List(pl.Categorical))


@pytest.mark.parametrize(
    ("values", "result"),
    [
        ([[]], [b""]),
        ([[1, 2], [3, 4]], [b"\x01\x02", b"\x03\x04"]),
        ([[1, 2], None, [3, 4]], [b"\x01\x02", None, b"\x03\x04"]),
        (
            [None, [111, 110, 101], [12, None], [116, 119, 111], list(range(256))],
            [
                None,
                b"one",
                # A list with a null in it gets turned into a null:
                None,
                b"two",
                bytes(i for i in range(256)),
            ],
        ),
    ],
)
def test_list_uint8_to_bytes(
    values: list[list[int | None] | None], result: list[bytes | None]
) -> None:
    """Non-strict List[UInt8] -> Binary packs each list into a bytes value."""
    s = pl.Series(
        values,
        dtype=pl.List(pl.UInt8()),
    )
    assert s.cast(pl.Binary(), strict=False).to_list() == result


def test_list_uint8_to_bytes_strict() -> None:
    """Strict List[UInt8] -> Binary works, but raises if a list contains a null."""
    series = pl.Series(
        [[1, 2], [3, 4]],
        dtype=pl.List(pl.UInt8()),
    )
    assert series.cast(pl.Binary(), strict=True).to_list() == [b"\x01\x02", b"\x03\x04"]

    series = pl.Series(
        "mycol",
        [[1, 2], [3, None]],
        dtype=pl.List(pl.UInt8()),
    )
    with pytest.raises(
        InvalidOperationError,
        match="conversion from `list\\[u8\\]` to `binary` failed in column 'mycol' for 1 out of 2 values: \\[\\[3, null\\]\\]",
    ):
        series.cast(pl.Binary(), strict=True)


def test_all_null_cast_5826() -> None:
    """An all-null String column casts to Boolean as a null Boolean value."""
    df = pl.DataFrame(data=[pl.Series("a", [None], dtype=pl.String)])
    out = df.with_columns(pl.col("a").cast(pl.Boolean))
    assert out.dtypes == [pl.Boolean]
    assert out.item() is None


@pytest.mark.parametrize("dtype", INTEGER_DTYPES)
def test_bool_numeric_supertype(dtype: PolarsDataType) -> None:
    """Dividing a Boolean sum cast to any integer dtype by the row count works."""
    df = pl.DataFrame({"v": [1, 2, 3, 4, 5, 6]})
    result = df.select((pl.col("v") < 3).sum().cast(dtype) / pl.len())
    # Two of six values are < 3. The original one-sided check
    # (`item - 0.3333333 <= 1e-5`) would also accept any smaller value, e.g. 0.
    assert result.item() == pytest.approx(1 / 3, abs=1e-5)


@pytest.mark.parametrize("dtype", [pl.String(), pl.String, str])
def test_cast_consistency(dtype: PolarsDataType | PythonDataType) -> None:
    """Column and literal float->string casts format 0.0 identically."""
    assert pl.DataFrame().with_columns(a=pl.lit(0.0)).with_columns(
        b=pl.col("a").cast(dtype), c=pl.lit(0.0).cast(dtype)
    ).to_dict(as_series=False) == {"a": [0.0], "b": ["0.0"], "c": ["0.0"]}


def test_cast_int_to_string_unsets_sorted_flag_19424() -> None:
    """Casting a sorted Int Series to String must drop the SORTED_ASC flag."""
    s = pl.Series([1, 2]).set_sorted()
    assert s.flags["SORTED_ASC"]
    assert not s.cast(pl.String).flags["SORTED_ASC"]


def test_cast_integer_to_decimal() -> None:
    """Int64 -> Decimal(10, 2) scales values to two fractional digits."""
    s = pl.Series([1, 2, 3])
    result = s.cast(pl.Decimal(10, 2))
    expected = pl.Series(
        "", [Decimal("1.00"), Decimal("2.00"), Decimal("3.00")], pl.Decimal(10, 2)
    )
    assert_series_equal(result, expected)


def test_cast_python_dtypes() -> None:
    """Python builtin types are accepted as cast targets."""
    s = pl.Series([0, 1])
    assert s.cast(int).dtype == pl.Int64
    assert s.cast(float).dtype == pl.Float64
    assert s.cast(bool).dtype == pl.Boolean
    assert s.cast(str).dtype == pl.String


def test_overflowing_cast_literals_21023() -> None:
    """`wrap_numerical=True` wraps 128 to -128 with and without optimizations."""
    for optimizations in [pl.QueryOptFlags(), pl.QueryOptFlags.none()]:
        assert_frame_equal(
            (
                pl.LazyFrame()
                .select(
                    pl.lit(pl.Series([128], dtype=pl.Int64)).cast(
                        pl.Int8, wrap_numerical=True
                    )
                )
                .collect(optimizations=optimizations)
            ),
            pl.Series([-128], dtype=pl.Int8).to_frame(),
        )


@pytest.mark.parametrize("value", [True, False])
@pytest.mark.parametrize(
    "dtype",
    [
        pl.Enum(["a", "b"]),
        pl.Series(["a", "b"], dtype=pl.Categorical).dtype,
    ],
)
def test_invalid_bool_to_cat(value: bool, dtype: PolarsDataType) -> None:
    """Boolean cannot be cast to Enum or Categorical."""
    # Both parametrised targets (Enum and Categorical) are expected to report
    # the same "Categorical" error message.
    with pytest.raises(
        InvalidOperationError,
        match="cannot cast Boolean to Categorical",
    ):
        pl.Series([value]).cast(dtype)


# ---------------------------------------------------------------------------
# Nested (List / Array / Struct) casts
# ---------------------------------------------------------------------------


@pytest.mark.parametrize(
    ("values", "from_dtype", "to_dtype", "pre_apply"),
    [
        ([["A"]], pl.List(pl.String), pl.List(pl.Int8), None),
        ([["A"]], pl.Array(pl.String, 1), pl.List(pl.Int8), None),
        ([[["A"]]], pl.List(pl.List(pl.String)), pl.List(pl.List(pl.Int8)), None),
        (
            [
                {"x": "1", "y": "2"},
                {"x": "A", "y": "B"},
                {"x": "3", "y": "4"},
                {"x": "X", "y": "Y"},
                {"x": "5", "y": "6"},
            ],
            pl.Struct(
                {
                    "x": pl.String,
                    "y": pl.String,
                }
            ),
            pl.Struct(
                {
                    "x": pl.Int8,
                    "y": pl.Int32,
                }
            ),
            None,
        ),
    ],
)
def test_nested_strict_casts_failing(
    values: list[Any],
    from_dtype: pl.DataType,
    to_dtype: pl.DataType,
    pre_apply: Callable[[pl.Series], pl.Series] | None,
) -> None:
    """Strict nested casts raise when any inner value fails to convert."""
    s = pl.Series(values, dtype=from_dtype)

    if pre_apply is not None:
        s = pre_apply(s)

    with pytest.raises(
        pl.exceptions.InvalidOperationError,
        match=r"conversion from",
    ):
        s.cast(to_dtype)


@pytest.mark.parametrize(
    ("values", "from_dtype", "pre_apply", "to"),
    [
        (
            [["A"], ["1"], ["2"]],
            pl.List(pl.String),
            lambda s: s.slice(1, 2),
            pl.Series([[1], [2]]),
        ),
        (
            [["1"], ["A"], ["2"], ["B"], ["3"]],
            pl.List(pl.String),
            lambda s: s.filter(pl.Series([True, False, True, False, True])),
            pl.Series([[1], [2], [3]]),
        ),
        (
            [
                {"x": "1", "y": "2"},
                {"x": "A", "y": "B"},
                {"x": "3", "y": "4"},
                {"x": "X", "y": "Y"},
                {"x": "5", "y": "6"},
            ],
            pl.Struct(
                {
                    "x": pl.String,
                    "y": pl.String,
                }
            ),
            lambda s: s.filter(pl.Series([True, False, True, False, True])),
            pl.Series(
                [
                    {"x": 1, "y": 2},
                    {"x": 3, "y": 4},
                    {"x": 5, "y": 6},
                ]
            ),
        ),
        (
            [
                {"x": "1", "y": "2"},
                {"x": "A", "y": "B"},
                {"x": "3", "y": "4"},
                {"x": "X", "y": "Y"},
                {"x": "5", "y": "6"},
            ],
            pl.Struct(
                {
                    "x": pl.String,
                    "y": pl.String,
                }
            ),
            lambda s: pl.select(
                pl.when(pl.Series([True, False, True, False, True])).then(s)
            ).to_series(),
            pl.Series(
                [
                    {"x": 1, "y": 2},
                    None,
                    {"x": 3, "y": 4},
                    None,
                    {"x": 5, "y": 6},
                ]
            ),
        ),
    ],
)
def test_nested_strict_casts_succeeds(
    values: list[Any],
    from_dtype: pl.DataType,
    pre_apply: Callable[[pl.Series], pl.Series] | None,
    to: pl.Series,
) -> None:
    """
    Strict nested casts succeed once unparsable values are sliced/filtered/masked.

    The invalid values are still present in the original buffers before
    `pre_apply` runs, so this checks that only visible values are converted.
    """
    s = pl.Series(values, dtype=from_dtype)

    if pre_apply is not None:
        s = pre_apply(s)

    assert_series_equal(
        s.cast(to.dtype),
        to,
    )


def test_nested_struct_cast_22744() -> None:
    """Casting to a Struct with an extra inner field fills that field with null."""
    s = pl.Series(
        "x",
        [{"attrs": {"class": "a"}}],
    )

    expected = pl.select(
        pl.lit(s).struct.with_fields(
            pl.field("attrs").struct.with_fields(
                [pl.field("class"), pl.lit(None, dtype=pl.String()).alias("other")]
            )
        )
    )

    assert_series_equal(
        s.cast(
            pl.Struct({"attrs": pl.Struct({"class": pl.String, "other": pl.String})})
        ),
        expected.to_series(),
    )
    assert_frame_equal(
        pl.DataFrame([s]).cast(
            {
                "x": pl.Struct(
                    {"attrs": pl.Struct({"class": pl.String, "other": pl.String})}
                )
            }
        ),
        expected,
    )


# ---------------------------------------------------------------------------
# Query-plan / optimizer behaviour around casts
# ---------------------------------------------------------------------------


def test_cast_to_self_is_pruned() -> None:
    """A cast to the column's existing dtype is dropped from the query plan."""
    q = pl.LazyFrame({"x": 1}, schema={"x": pl.Int64}).with_columns(
        y=pl.col("x").cast(pl.Int64)
    )

    plan = q.explain()
    assert 'col("x").alias("y")' in plan

    assert_frame_equal(q.collect(), pl.DataFrame({"x": 1, "y": 1}))


@pytest.mark.parametrize(
    ("s", "to", "should_fail"),
    [
        (
            pl.Series([datetime(2025, 1, 1)]),
            pl.Datetime("ns"),
            False,
        ),
        (
            pl.Series([datetime(9999, 1, 1)]),
            pl.Datetime("ns"),
            True,
        ),
        (
            pl.Series([datetime(2025, 1, 1), datetime(9999, 1, 1)]),
            pl.Datetime("ns"),
            True,
        ),
        (
            pl.Series([[datetime(2025, 1, 1)], [datetime(9999, 1, 1)]]),
            pl.List(pl.Datetime("ns")),
            True,
        ),
        # lower date limit for nanosecond
        (pl.Series([date(1677, 9, 22)]), pl.Datetime("ns"), False),
        (pl.Series([date(1677, 9, 21)]), pl.Datetime("ns"), True),
        # upper date limit for nanosecond
        (pl.Series([date(2262, 4, 11)]), pl.Datetime("ns"), False),
        (pl.Series([date(2262, 4, 12)]), pl.Datetime("ns"), True),
    ],
)
def test_cast_temporals_overflow_16039(
    s: pl.Series, to: pl.DataType, should_fail: bool
) -> None:
    """Casts outside the representable nanosecond range raise."""
    if should_fail:
        with pytest.raises(
            pl.exceptions.InvalidOperationError, match="conversion from"
        ):
            s.cast(to)
    else:
        s.cast(to)


@pytest.mark.parametrize("dtype", NUMERIC_DTYPES)
def test_prune_superfluous_cast(dtype: PolarsDataType) -> None:
    """A same-dtype cast does not appear as `strict_cast` in the plan."""
    lf = pl.LazyFrame({"a": [1, 2, 3]}, schema={"a": dtype})
    result = lf.select(pl.col("a").cast(dtype))
    assert "strict_cast" not in result.explain()


def test_not_prune_necessary_cast() -> None:
    """A narrowing UInt16 -> UInt8 cast is kept as `strict_cast` in the plan."""
    lf = pl.LazyFrame({"a": [1, 2, 3]}, schema={"a": pl.UInt16})
    result = lf.select(pl.col("a").cast(pl.UInt8))
    assert "strict_cast" in result.explain()


@pytest.mark.parametrize("target_dtype", NUMERIC_DTYPES)
@pytest.mark.parametrize("inner_dtype", NUMERIC_DTYPES)
@pytest.mark.parametrize("op", [operator.mul, operator.truediv])
def test_cast_optimizer_in_list_eval_23924(
    inner_dtype: PolarsDataType,
    target_dtype: PolarsDataType,
    op: Callable[[pl.Expr, pl.Expr], pl.Expr],
) -> None:
    """The schema predicted for casts inside `list.eval` matches the collected one."""
    # Seed the list with an int for integer targets and a float otherwise.
    seed = 1 if target_dtype in INTEGER_DTYPES else 1.0
    df = pl.Series("a", [[seed]], dtype=pl.List(target_dtype)).to_frame()
    q = df.lazy().select(
        pl.col("a").list.eval(
            (op(pl.element(), pl.element().cast(inner_dtype))).cast(target_dtype)
        )
    )
    assert q.collect_schema() == q.collect().schema


def test_lit_cast_arithmetic_23677() -> None:
    """Float32 column divided by an Int32 literal produces a Float64 result."""
    df = pl.DataFrame({"a": [1]}, schema={"a": pl.Float32})
    q = df.lazy().select(pl.col("a") / pl.lit(1, pl.Int32))
    expected = pl.Schema({"a": pl.Float64})
    assert q.collect().schema == expected


@pytest.mark.parametrize("col_dtype", NUMERIC_DTYPES + [pl.Unknown])
@pytest.mark.parametrize("lit_dtype", NUMERIC_DTYPES + [pl.Unknown])
@pytest.mark.parametrize("op", [operator.mul, operator.truediv])
def test_lit_cast_arithmetic_matrix_schema(
    col_dtype: PolarsDataType,
    lit_dtype: PolarsDataType,
    op: Callable[[pl.Expr, pl.Expr], pl.Expr],
) -> None:
    """Predicted and collected schemas agree for every column/literal dtype pair."""
    # Note (hacky): simply casting to `pl.Unknown` would create
    # `Unknown(UnknownKind::Any())`, which is not what we want. The default
    # (untyped) construction maps to `Unknown(UnknownKind::Int(_))`, so for
    # the `pl.Unknown` case we omit the dtype instead of passing it.
    df = (
        pl.DataFrame({"a": [1]})
        if col_dtype == pl.Unknown
        else pl.DataFrame({"a": [1]}, schema={"a": col_dtype})
    )
    q = (
        df.lazy().select(op(pl.col("a"), pl.lit(1)))
        if lit_dtype == pl.Unknown
        else df.lazy().select(op(pl.col("a"), pl.lit(1, lit_dtype)))
    )
    assert q.collect_schema() == q.collect().schema


def test_strict_cast_nested() -> None:
    """String -> Struct{x: Int32}: strict raises on "10a", non-strict nulls it."""
    df = pl.DataFrame({"a": ["42", "10a"]})
    struct = pl.Struct({"x": pl.Int32})
    with pytest.raises(InvalidOperationError):
        df.cast(struct, strict=True)

    assert_frame_equal(
        df.cast(struct, strict=False),
        pl.DataFrame({"a": [{"x": 42}, {"x": None}]}, schema={"a": struct}),
    )