Path: blob/main/py-polars/tests/unit/interop/numpy/test_to_numpy_series.py
6939 views
from __future__ import annotations12from datetime import date, datetime, time, timedelta3from decimal import Decimal as D4from pathlib import Path5from typing import TYPE_CHECKING, Any67import numpy as np8import pytest9from hypothesis import given, settings10from numpy.testing import assert_array_equal1112import polars as pl13from polars.testing import assert_series_equal14from polars.testing.parametric import series1516if TYPE_CHECKING:17import numpy.typing as npt1819from polars._typing import PolarsDataType202122def assert_zero_copy(s: pl.Series, arr: np.ndarray[Any, Any]) -> None:23if s.len() == 0:24return25s_ptr = s._get_buffers()["values"]._get_buffer_info()[0]26arr_ptr = arr.__array_interface__["data"][0]27assert s_ptr == arr_ptr282930def assert_allow_copy_false_raises(s: pl.Series) -> None:31with pytest.raises(RuntimeError, match="copy not allowed"):32s.to_numpy(allow_copy=False)333435@pytest.mark.parametrize(36("dtype", "expected_dtype"),37[38(pl.Int8, np.int8),39(pl.Int16, np.int16),40(pl.Int32, np.int32),41(pl.Int64, np.int64),42(pl.UInt8, np.uint8),43(pl.UInt16, np.uint16),44(pl.UInt32, np.uint32),45(pl.UInt64, np.uint64),46(pl.Float32, np.float32),47(pl.Float64, np.float64),48],49)50def test_series_to_numpy_numeric_zero_copy(51dtype: PolarsDataType, expected_dtype: npt.DTypeLike52) -> None:53s = pl.Series([1, 2, 3]).cast(dtype)54result: npt.NDArray[np.generic] = s.to_numpy(allow_copy=False)5556assert_zero_copy(s, result)57assert result.tolist() == s.to_list()58assert result.dtype == expected_dtype596061@pytest.mark.parametrize(62("dtype", "expected_dtype"),63[64(pl.Int8, np.float32),65(pl.Int16, np.float32),66(pl.Int32, np.float64),67(pl.Int64, np.float64),68(pl.UInt8, np.float32),69(pl.UInt16, np.float32),70(pl.UInt32, np.float64),71(pl.UInt64, np.float64),72(pl.Float32, np.float32),73(pl.Float64, np.float64),74],75)76def test_series_to_numpy_numeric_with_nulls(77dtype: PolarsDataType, expected_dtype: npt.DTypeLike78) -> None:79s = pl.Series([1, 2, None], dtype=dtype, strict=False)80result: npt.NDArray[np.generic] = s.to_numpy()8182assert result.tolist()[:-1] == s.to_list()[:-1]83assert np.isnan(result[-1])84assert result.dtype == expected_dtype85assert_allow_copy_false_raises(s)868788@pytest.mark.parametrize(89("dtype", "expected_dtype"),90[91(pl.Duration, np.dtype("timedelta64[us]")),92(pl.Duration("ms"), np.dtype("timedelta64[ms]")),93(pl.Duration("us"), np.dtype("timedelta64[us]")),94(pl.Duration("ns"), np.dtype("timedelta64[ns]")),95(pl.Datetime, np.dtype("datetime64[us]")),96(pl.Datetime("ms"), np.dtype("datetime64[ms]")),97(pl.Datetime("us"), np.dtype("datetime64[us]")),98(pl.Datetime("ns"), np.dtype("datetime64[ns]")),99],100)101def test_series_to_numpy_temporal_zero_copy(102dtype: PolarsDataType, expected_dtype: npt.DTypeLike103) -> None:104values = [0, 2_000, 1_000_000]105s = pl.Series(values, dtype=dtype, strict=False)106result: npt.NDArray[np.generic] = s.to_numpy(allow_copy=False)107108assert_zero_copy(s, result)109# NumPy tolist returns integers for ns precision110if s.dtype.time_unit == "ns": # type: ignore[attr-defined]111assert result.tolist() == values112else:113assert result.tolist() == s.to_list()114assert result.dtype == expected_dtype115116117def test_series_to_numpy_datetime_with_tz_zero_copy() -> None:118values = [datetime(1970, 1, 1), datetime(2024, 2, 28)]119s = pl.Series(values).dt.convert_time_zone("Europe/Amsterdam").rechunk()120result: npt.NDArray[np.generic] = s.to_numpy(allow_copy=False)121122assert_zero_copy(s, result)123assert result.tolist() == values124assert result.dtype == np.dtype("datetime64[us]")125126127def test_series_to_numpy_date() -> None:128values = [date(1970, 1, 1), date(2024, 2, 28)]129s = pl.Series(values)130131result: npt.NDArray[np.generic] = s.to_numpy()132133assert s.to_list() == result.tolist()134assert result.dtype == np.dtype("datetime64[D]")135assert result.flags.writeable is True136assert_allow_copy_false_raises(s)137138139def test_series_to_numpy_multi_dimensional_init() -> None:140s = pl.Series(np.atleast_3d(np.array([-10.5, 0.0, 10.5])))141assert_series_equal(142s,143pl.Series(144[[[-10.5], [0.0], [10.5]]],145dtype=pl.Array(pl.Float64, shape=(3, 1)),146),147)148s = pl.Series(np.array(0), dtype=pl.Int32)149assert_series_equal(s, pl.Series([0], dtype=pl.Int32))150151152@pytest.mark.parametrize(153("dtype", "expected_dtype"),154[155(pl.Date, np.dtype("datetime64[D]")),156(pl.Duration("ms"), np.dtype("timedelta64[ms]")),157(pl.Duration("us"), np.dtype("timedelta64[us]")),158(pl.Duration("ns"), np.dtype("timedelta64[ns]")),159(pl.Datetime, np.dtype("datetime64[us]")),160(pl.Datetime("ms"), np.dtype("datetime64[ms]")),161(pl.Datetime("us"), np.dtype("datetime64[us]")),162(pl.Datetime("ns"), np.dtype("datetime64[ns]")),163],164)165def test_series_to_numpy_temporal_with_nulls(166dtype: PolarsDataType, expected_dtype: npt.DTypeLike167) -> None:168values = [0, 2_000, 1_000_000, None]169s = pl.Series(values, dtype=dtype, strict=False)170result: npt.NDArray[np.generic] = s.to_numpy()171172# NumPy tolist returns integers for ns precision173if getattr(s.dtype, "time_unit", None) == "ns":174assert result.tolist() == values175else:176assert result.tolist() == s.to_list()177assert result.dtype == expected_dtype178assert_allow_copy_false_raises(s)179180181def test_series_to_numpy_datetime_with_tz_with_nulls() -> None:182values = [datetime(1970, 1, 1), datetime(2024, 2, 28), None]183s = pl.Series(values).dt.convert_time_zone("Europe/Amsterdam")184result: npt.NDArray[np.generic] = s.to_numpy()185186assert result.tolist() == values187assert result.dtype == np.dtype("datetime64[us]")188assert_allow_copy_false_raises(s)189190191@pytest.mark.parametrize(192("dtype", "values"),193[194(pl.Time, [time(10, 30, 45), time(23, 59, 59)]),195(pl.Categorical, ["a", "b", "a"]),196(pl.Enum(["a", "b", "c"]), ["a", "b", "a"]),197(pl.String, ["a", "bc", "def"]),198(pl.Binary, [b"a", b"bc", b"def"]),199(pl.Decimal, [D("1.234"), D("2.345"), D("-3.456")]),200(pl.Object, [Path(), Path("abc")]),201],202)203@pytest.mark.parametrize("with_nulls", [False, True])204def test_to_numpy_object_dtypes(205dtype: PolarsDataType, values: list[Any], with_nulls: bool206) -> None:207if with_nulls:208values.append(None)209210s = pl.Series(values, dtype=dtype)211result: npt.NDArray[np.generic] = s.to_numpy()212213assert result.tolist() == values214assert result.dtype == np.object_215assert_allow_copy_false_raises(s)216217218def test_series_to_numpy_bool() -> None:219s = pl.Series([True, False])220result: npt.NDArray[np.generic] = s.to_numpy()221222assert s.to_list() == result.tolist()223assert result.dtype == np.bool_224assert result.flags.writeable is True225assert_allow_copy_false_raises(s)226227228def test_series_to_numpy_bool_with_nulls() -> None:229s = pl.Series([True, False, None])230result: npt.NDArray[np.generic] = s.to_numpy()231232assert s.to_list() == result.tolist()233assert result.dtype == np.object_234assert_allow_copy_false_raises(s)235236237def test_series_to_numpy_array_of_int() -> None:238values = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]239s = pl.Series(values, dtype=pl.Array(pl.Array(pl.Int8, 3), 2))240result = s.to_numpy(allow_copy=False)241242expected = np.array(values)243assert_array_equal(result, expected)244assert result.dtype == np.int8245assert result.shape == (2, 2, 3)246247248def test_series_to_numpy_array_of_str() -> None:249values = [["1", "2", "3"], ["4", "5", "10000"]]250s = pl.Series(values, dtype=pl.Array(pl.String, 3))251result: npt.NDArray[np.generic] = s.to_numpy()252assert result.tolist() == values253assert result.dtype == np.object_254255256def test_series_to_numpy_array_with_nulls() -> None:257values = [[1, 2], [3, 4], None]258s = pl.Series(values, dtype=pl.Array(pl.Int64, 2))259result = s.to_numpy()260261expected = np.array([[1.0, 2.0], [3.0, 4.0], [np.nan, np.nan]])262assert_array_equal(result, expected)263assert result.dtype == np.float64264assert_allow_copy_false_raises(s)265266267def test_series_to_numpy_array_with_nested_nulls() -> None:268values = [[None, 2], [3, 4], [5, None]]269s = pl.Series(values, dtype=pl.Array(pl.Int64, 2))270result = s.to_numpy()271272expected = np.array([[np.nan, 2.0], [3.0, 4.0], [5.0, np.nan]])273assert_array_equal(result, expected)274assert result.dtype == np.float64275assert_allow_copy_false_raises(s)276277278def test_series_to_numpy_array_of_arrays() -> None:279values = [[[None, 2], [3, 4]], [None, [7, 8]]]280s = pl.Series(values, dtype=pl.Array(pl.Array(pl.Int64, 2), 2))281result = s.to_numpy()282283expected = np.array([[[np.nan, 2], [3, 4]], [[np.nan, np.nan], [7, 8]]])284assert_array_equal(result, expected)285assert result.dtype == np.float64286assert result.shape == (2, 2, 2)287assert_allow_copy_false_raises(s)288289290@pytest.mark.parametrize("chunked", [True, False])291def test_series_to_numpy_list(chunked: bool) -> None:292values = [[1, 2], [3, 4, 5], [6], []]293s = pl.Series(values)294if chunked:295s = pl.concat([s[:2], s[2:]])296result = s.to_numpy()297298expected = np.array([np.array(v, dtype=np.int64) for v in values], dtype=np.object_)299for res, exp in zip(result, expected):300assert_array_equal(res, exp)301assert result.dtype == expected.dtype302assert_allow_copy_false_raises(s)303304305def test_series_to_numpy_struct_numeric_supertype() -> None:306values = [{"a": 1, "b": 2.0}, {"a": 3, "b": 4.0}, {"a": 5, "b": None}]307s = pl.Series(values)308result = s.to_numpy()309310expected = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, np.nan]])311assert_array_equal(result, expected)312assert result.dtype == np.float64313assert_allow_copy_false_raises(s)314315316def test_to_numpy_null() -> None:317s = pl.Series([None, None], dtype=pl.Null)318result = s.to_numpy()319expected = np.array([np.nan, np.nan], dtype=np.float32)320assert_array_equal(result, expected)321assert result.dtype == np.float32322assert_allow_copy_false_raises(s)323324325def test_to_numpy_empty() -> None:326s = pl.Series(dtype=pl.String)327result = s.to_numpy(allow_copy=False)328assert result.dtype == np.object_329assert result.shape == (0,)330331332def test_to_numpy_empty_writable() -> None:333s = pl.Series(dtype=pl.Int64)334result = s.to_numpy(allow_copy=False, writable=True)335assert result.dtype == np.int64336assert result.shape == (0,)337assert result.flags.writeable is True338339340def test_to_numpy_chunked() -> None:341s1 = pl.Series([1, 2])342s2 = pl.Series([3, 4])343s = pl.concat([s1, s2], rechunk=False)344345result: npt.NDArray[np.generic] = s.to_numpy()346347assert result.tolist() == s.to_list()348assert result.dtype == np.int64349assert result.flags.writeable is True350assert_allow_copy_false_raises(s)351352# Check that writing to the array doesn't change the original data353result[0] = 10354assert result.tolist() == [10, 2, 3, 4]355assert s.to_list() == [1, 2, 3, 4]356357358def test_to_numpy_chunked_temporal_nested() -> None:359dtype = pl.Array(pl.Datetime("us"), 1)360s1 = pl.Series([[datetime(2020, 1, 1)], [datetime(2021, 1, 1)]], dtype=dtype)361s2 = pl.Series([[datetime(2022, 1, 1)], [datetime(2023, 1, 1)]], dtype=dtype)362s = pl.concat([s1, s2], rechunk=False)363364result: npt.NDArray[np.generic] = s.to_numpy()365366assert result.tolist() == s.to_list()367assert result.dtype == np.dtype("datetime64[us]")368assert result.shape == (4, 1)369assert result.flags.writeable is True370assert_allow_copy_false_raises(s)371372373def test_zero_copy_only_deprecated() -> None:374values = [1, 2]375s = pl.Series([1, 2])376with pytest.deprecated_call():377result: npt.NDArray[np.generic] = s.to_numpy(zero_copy_only=True)378assert result.tolist() == values379380381def test_series_to_numpy_temporal() -> None:382s0 = pl.Series("date", [123543, 283478, 1243]).cast(pl.Date)383s1 = pl.Series(384"datetime", [datetime(2021, 1, 2, 3, 4, 5), datetime(2021, 2, 3, 4, 5, 6)]385)386s2 = pl.datetime_range(387datetime(2021, 1, 1, 0),388datetime(2021, 1, 1, 1),389interval="1h",390time_unit="ms",391eager=True,392)393assert str(s0.to_numpy()) == "['2308-04-02' '2746-02-20' '1973-05-28']"394assert (395str(s1.to_numpy()[:2])396== "['2021-01-02T03:04:05.000000' '2021-02-03T04:05:06.000000']"397)398assert (399str(s2.to_numpy()[:2])400== "['2021-01-01T00:00:00.000' '2021-01-01T01:00:00.000']"401)402s3 = pl.Series([timedelta(hours=1), timedelta(hours=-2)])403out = np.array([3_600_000_000_000, -7_200_000_000_000], dtype="timedelta64[ns]")404assert (s3.to_numpy() == out).all()405406407@given(408s=series(409min_size=1,410max_size=10,411excluded_dtypes=[412pl.Categorical,413pl.List,414pl.Struct,415pl.Datetime("ms"),416pl.Duration("ms"),417],418allow_null=False,419allow_time_zones=False, # NumPy does not support parsing time zone aware data420).filter(421lambda s: (422not (s.dtype == pl.String and s.str.contains("\x00").any())423and not (s.dtype == pl.Binary and s.bin.contains(b"\x00").any())424)425),426)427@settings(max_examples=250)428def test_series_to_numpy(s: pl.Series) -> None:429result = s.to_numpy()430431values = s.to_list()432dtype_map = {433pl.Datetime("ns"): "datetime64[ns]",434pl.Datetime("us"): "datetime64[us]",435pl.Duration("ns"): "timedelta64[ns]",436pl.Duration("us"): "timedelta64[us]",437pl.Null(): "float32",438}439np_dtype = dtype_map.get(s.dtype)440expected = np.array(values, dtype=np_dtype)441442assert_array_equal(result, expected)443444445@pytest.mark.parametrize("writable", [False, True])446@pytest.mark.parametrize("pyarrow_available", [False, True])447def test_to_numpy2(448writable: bool, pyarrow_available: bool, monkeypatch: pytest.MonkeyPatch449) -> None:450monkeypatch.setattr(pl.series.series, "_PYARROW_AVAILABLE", pyarrow_available)451452np_array = pl.Series("a", [1, 2, 3], pl.UInt8).to_numpy(writable=writable)453454np.testing.assert_array_equal(np_array, np.array([1, 2, 3], dtype=np.uint8))455# Test if numpy array is readonly or writable.456assert np_array.flags.writeable == writable457458if writable:459np_array[1] += 10460np.testing.assert_array_equal(np_array, np.array([1, 12, 3], dtype=np.uint8))461462np_array_with_missing_values = pl.Series("a", [None, 2, 3], pl.UInt8).to_numpy(463writable=writable464)465466np.testing.assert_array_equal(467np_array_with_missing_values,468np.array(469[np.nan, 2.0, 3.0],470dtype=(np.float64 if pyarrow_available else np.float32),471),472)473474if writable:475# As Null values can't be encoded natively in a numpy array,476# this array will never be a view.477assert np_array_with_missing_values.flags.writeable == writable478479480def test_to_numpy_series_indexed_18986() -> None:481df = pl.DataFrame({"a": [[4, 5, 6], [7, 8, 9, 10], None]})482assert (df[1].to_numpy()[0, 0] == np.array([7, 8, 9, 10])).all()483assert (484df.to_numpy()[2] == np.array([None])485).all() # this one is strange, but only option in numpy?486487488