# Path: blob/main/py-polars/tests/unit/interop/numpy/test_to_numpy_series.py
# 8422 views
"""Tests for converting a polars Series to a NumPy array via ``Series.to_numpy``.

Covers zero-copy conversion (``allow_copy=False``), null handling, temporal,
nested (Array / List / Struct) and object-backed dtypes, chunked input, and the
``writable`` flag.
"""

from __future__ import annotations

from datetime import date, datetime, time, timedelta
from decimal import Decimal as D
from pathlib import Path
from typing import TYPE_CHECKING, Any

import numpy as np
import pytest
from hypothesis import given, settings
from numpy.testing import assert_array_equal

import polars as pl
from polars.testing import assert_series_equal
from polars.testing.parametric import series

if TYPE_CHECKING:
    import numpy.typing as npt

    from polars._typing import PolarsDataType
    from tests.conftest import PlMonkeyPatch


def assert_zero_copy(s: pl.Series, arr: np.ndarray[Any, Any]) -> None:
    """Assert that ``arr`` points at the same memory as the values buffer of ``s``.

    Empty series are skipped without checking.
    NOTE(review): presumably because an empty buffer's address is not
    meaningful to compare — confirm.
    """
    if s.len() == 0:
        return
    s_ptr = s._get_buffers()["values"]._get_buffer_info()[0]
    arr_ptr = arr.__array_interface__["data"][0]
    assert s_ptr == arr_ptr


def assert_allow_copy_false_raises(s: pl.Series) -> None:
    """Assert that ``s.to_numpy(allow_copy=False)`` raises "copy not allowed"."""
    with pytest.raises(RuntimeError, match="copy not allowed"):
        s.to_numpy(allow_copy=False)


@pytest.mark.parametrize(
    ("dtype", "expected_dtype"),
    [
        (pl.Int8, np.int8),
        (pl.Int16, np.int16),
        (pl.Int32, np.int32),
        (pl.Int64, np.int64),
        (pl.UInt8, np.uint8),
        (pl.UInt16, np.uint16),
        (pl.UInt32, np.uint32),
        (pl.UInt64, np.uint64),
        (pl.Float32, np.float32),
        (pl.Float64, np.float64),
    ],
)
def test_series_to_numpy_numeric_zero_copy(
    dtype: PolarsDataType, expected_dtype: npt.DTypeLike
) -> None:
    """Null-free numeric series convert without copying to the matching dtype."""
    s = pl.Series([1, 2, 3]).cast(dtype)
    result: npt.NDArray[np.generic] = s.to_numpy(allow_copy=False)

    assert_zero_copy(s, result)
    assert result.tolist() == s.to_list()
    assert result.dtype == expected_dtype


@pytest.mark.parametrize(
    ("dtype", "expected_dtype"),
    [
        (pl.Int8, np.float32),
        (pl.Int16, np.float32),
        (pl.Int32, np.float64),
        (pl.Int64, np.float64),
        (pl.UInt8, np.float32),
        (pl.UInt16, np.float32),
        (pl.UInt32, np.float64),
        (pl.UInt64, np.float64),
        (pl.Float32, np.float32),
        (pl.Float64, np.float64),
    ],
)
def test_series_to_numpy_numeric_with_nulls(
    dtype: PolarsDataType, expected_dtype: npt.DTypeLike
) -> None:
    """Numeric series with nulls become float arrays with NaN; a copy is required."""
    s = pl.Series([1, 2, None], dtype=dtype, strict=False)
    result: npt.NDArray[np.generic] = s.to_numpy()

    assert result.tolist()[:-1] == s.to_list()[:-1]
    assert np.isnan(result[-1])
    assert result.dtype == expected_dtype
    assert_allow_copy_false_raises(s)


@pytest.mark.parametrize(
    ("dtype", "expected_dtype"),
    [
        (pl.Duration, np.dtype("timedelta64[us]")),
        (pl.Duration("ms"), np.dtype("timedelta64[ms]")),
        (pl.Duration("us"), np.dtype("timedelta64[us]")),
        (pl.Duration("ns"), np.dtype("timedelta64[ns]")),
        (pl.Datetime, np.dtype("datetime64[us]")),
        (pl.Datetime("ms"), np.dtype("datetime64[ms]")),
        (pl.Datetime("us"), np.dtype("datetime64[us]")),
        (pl.Datetime("ns"), np.dtype("datetime64[ns]")),
    ],
)
def test_series_to_numpy_temporal_zero_copy(
    dtype: PolarsDataType, expected_dtype: npt.DTypeLike
) -> None:
    """Null-free Duration/Datetime series convert without copying, keeping the unit."""
    values = [0, 2_000, 1_000_000]
    s = pl.Series(values, dtype=dtype, strict=False)
    result: npt.NDArray[np.generic] = s.to_numpy(allow_copy=False)

    assert_zero_copy(s, result)
    # NumPy tolist returns integers for ns precision
    if s.dtype.time_unit == "ns":  # type: ignore[attr-defined]
        assert result.tolist() == values
    else:
        assert result.tolist() == s.to_list()
    assert result.dtype == expected_dtype


def test_series_to_numpy_datetime_with_tz_zero_copy() -> None:
    """Time-zone-aware datetimes convert without copying.

    The array values equal the naive datetimes the series was built from.
    """
    values = [datetime(1970, 1, 1), datetime(2024, 2, 28)]
    s = pl.Series(values).dt.convert_time_zone("Europe/Amsterdam").rechunk()
    result: npt.NDArray[np.generic] = s.to_numpy(allow_copy=False)

    assert_zero_copy(s, result)
    assert result.tolist() == values
    assert result.dtype == np.dtype("datetime64[us]")


def test_series_to_numpy_date() -> None:
    """Date series convert, with a copy, to a writable datetime64[D] array."""
    values = [date(1970, 1, 1), date(2024, 2, 28)]
    s = pl.Series(values)

    result: npt.NDArray[np.generic] = s.to_numpy()

    assert s.to_list() == result.tolist()
    assert result.dtype == np.dtype("datetime64[D]")
    assert result.flags.writeable is True
    assert_allow_copy_false_raises(s)


def test_series_to_numpy_multi_dimensional_init() -> None:
    """Series built from 3-d and 0-d NumPy arrays get the expected shape and dtype."""
    s = pl.Series(np.atleast_3d(np.array([-10.5, 0.0, 10.5])))
    assert_series_equal(
        s,
        pl.Series(
            [[[-10.5], [0.0], [10.5]]],
            dtype=pl.Array(pl.Float64, shape=(3, 1)),
        ),
    )
    s = pl.Series(np.array(0), dtype=pl.Int32)
    assert_series_equal(s, pl.Series([0], dtype=pl.Int32))


@pytest.mark.parametrize(
    ("dtype", "expected_dtype"),
    [
        (pl.Date, np.dtype("datetime64[D]")),
        (pl.Duration("ms"), np.dtype("timedelta64[ms]")),
        (pl.Duration("us"), np.dtype("timedelta64[us]")),
        (pl.Duration("ns"), np.dtype("timedelta64[ns]")),
        (pl.Datetime, np.dtype("datetime64[us]")),
        (pl.Datetime("ms"), np.dtype("datetime64[ms]")),
        (pl.Datetime("us"), np.dtype("datetime64[us]")),
        (pl.Datetime("ns"), np.dtype("datetime64[ns]")),
    ],
)
def test_series_to_numpy_temporal_with_nulls(
    dtype: PolarsDataType, expected_dtype: npt.DTypeLike
) -> None:
    """Temporal series with nulls convert, with a copy, keeping the matching unit."""
    values = [0, 2_000, 1_000_000, None]
    s = pl.Series(values, dtype=dtype, strict=False)
    result: npt.NDArray[np.generic] = s.to_numpy()

    # NumPy tolist returns integers for ns precision; Date has no time_unit.
    if getattr(s.dtype, "time_unit", None) == "ns":
        assert result.tolist() == values
    else:
        assert result.tolist() == s.to_list()
    assert result.dtype == expected_dtype
    assert_allow_copy_false_raises(s)


def test_series_to_numpy_datetime_with_tz_with_nulls() -> None:
    """Time-zone-aware datetimes with nulls convert, with a copy, to datetime64[us]."""
    values = [datetime(1970, 1, 1), datetime(2024, 2, 28), None]
    s = pl.Series(values).dt.convert_time_zone("Europe/Amsterdam")
    result: npt.NDArray[np.generic] = s.to_numpy()

    assert result.tolist() == values
    assert result.dtype == np.dtype("datetime64[us]")
    assert_allow_copy_false_raises(s)


@pytest.mark.parametrize(
    ("dtype", "values"),
    [
        (pl.Time, [time(10, 30, 45), time(23, 59, 59)]),
        (pl.Categorical, ["a", "b", "a"]),
        (pl.Enum(["a", "b", "c"]), ["a", "b", "a"]),
        (pl.String, ["a", "bc", "def"]),
        (pl.Binary, [b"a", b"bc", b"def"]),
        (pl.Decimal, [D("1.234"), D("2.345"), D("-3.456")]),
        (pl.Object, [Path(), Path("abc")]),
    ],
)
@pytest.mark.parametrize("with_nulls", [False, True])
def test_to_numpy_object_dtypes(
    dtype: PolarsDataType, values: list[Any], with_nulls: bool
) -> None:
    """These dtypes convert, with a copy, to an object array of Python values."""
    if with_nulls:
        # Build a new list instead of appending: pytest hands the same
        # parametrized list object to every `with_nulls` case, so mutating it
        # would leak the null into the `with_nulls=False` run as well.
        values = [*values, None]

    s = pl.Series(values, dtype=dtype)
    result: npt.NDArray[np.generic] = s.to_numpy()

    assert result.tolist() == values
    assert result.dtype == np.object_
    assert_allow_copy_false_raises(s)


def test_series_to_numpy_bool() -> None:
    """Boolean series convert, with a copy, to a writable bool array."""
    s = pl.Series([True, False])
    result: npt.NDArray[np.generic] = s.to_numpy()

    assert s.to_list() == result.tolist()
    assert result.dtype == np.bool_
    assert result.flags.writeable is True
    assert_allow_copy_false_raises(s)


def test_series_to_numpy_bool_with_nulls() -> None:
    """Boolean series with nulls convert, with a copy, to an object array."""
    s = pl.Series([True, False, None])
    result: npt.NDArray[np.generic] = s.to_numpy()

    assert s.to_list() == result.tolist()
    assert result.dtype == np.object_
    assert_allow_copy_false_raises(s)


def test_series_to_numpy_array_of_int() -> None:
    """Nested fixed-size Int8 arrays convert without copying to a 3-d int8 array."""
    values = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
    s = pl.Series(values, dtype=pl.Array(pl.Array(pl.Int8, 3), 2))
    result = s.to_numpy(allow_copy=False)

    expected = np.array(values)
    assert_array_equal(result, expected)
    assert result.dtype == np.int8
    assert result.shape == (2, 2, 3)


def test_series_to_numpy_array_of_str() -> None:
    """Fixed-size String arrays convert to an object array."""
    values = [["1", "2", "3"], ["4", "5", "10000"]]
    s = pl.Series(values, dtype=pl.Array(pl.String, 3))
    result: npt.NDArray[np.generic] = s.to_numpy()
    assert result.tolist() == values
    assert result.dtype == np.object_


def test_series_to_numpy_array_with_nulls() -> None:
    """A null outer Array element becomes a row of NaN in a float64 array."""
    values = [[1, 2], [3, 4], None]
    s = pl.Series(values, dtype=pl.Array(pl.Int64, 2))
    result = s.to_numpy()

    expected = np.array([[1.0, 2.0], [3.0, 4.0], [np.nan, np.nan]])
    assert_array_equal(result, expected)
    assert result.dtype == np.float64
    assert_allow_copy_false_raises(s)


def test_series_to_numpy_array_with_nested_nulls() -> None:
    """Null inner Array elements become NaN in a float64 array."""
    values = [[None, 2], [3, 4], [5, None]]
    s = pl.Series(values, dtype=pl.Array(pl.Int64, 2))
    result = s.to_numpy()

    expected = np.array([[np.nan, 2.0], [3.0, 4.0], [5.0, np.nan]])
    assert_array_equal(result, expected)
    assert result.dtype == np.float64
    assert_allow_copy_false_raises(s)


def test_series_to_numpy_array_of_arrays() -> None:
    """Nulls at any level of Array(Array(Int64)) become NaN in a 3-d float64 array."""
    values = [[[None, 2], [3, 4]], [None, [7, 8]]]
    s = pl.Series(values, dtype=pl.Array(pl.Array(pl.Int64, 2), 2))
    result = s.to_numpy()

    expected = np.array([[[np.nan, 2], [3, 4]], [[np.nan, np.nan], [7, 8]]])
    assert_array_equal(result, expected)
    assert result.dtype == np.float64
    assert result.shape == (2, 2, 2)
    assert_allow_copy_false_raises(s)


@pytest.mark.parametrize("chunked", [True, False])
def test_series_to_numpy_list(chunked: bool) -> None:
    """List series convert to an object array whose elements are int64 arrays.

    Both single-chunk and multi-chunk input are covered.
    """
    values = [[1, 2], [3, 4, 5], [6], []]
    s = pl.Series(values)
    if chunked:
        s = pl.concat([s[:2], s[2:]])
    result = s.to_numpy()

    expected = np.array([np.array(v, dtype=np.int64) for v in values], dtype=np.object_)
    for res, exp in zip(result, expected, strict=True):
        assert_array_equal(res, exp)
    assert result.dtype == expected.dtype
    assert_allow_copy_false_raises(s)


def test_series_to_numpy_struct_numeric_supertype() -> None:
    """Struct fields with a common numeric supertype convert to a 2-d float64 array."""
    values = [{"a": 1, "b": 2.0}, {"a": 3, "b": 4.0}, {"a": 5, "b": None}]
    s = pl.Series(values)
    result = s.to_numpy()

    expected = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, np.nan]])
    assert_array_equal(result, expected)
    assert result.dtype == np.float64
    assert_allow_copy_false_raises(s)


def test_to_numpy_null() -> None:
    """A Null-dtype series converts, with a copy, to a float32 array of NaN."""
    s = pl.Series([None, None], dtype=pl.Null)
    result = s.to_numpy()
    expected = np.array([np.nan, np.nan], dtype=np.float32)
    assert_array_equal(result, expected)
    assert result.dtype == np.float32
    assert_allow_copy_false_raises(s)


def test_to_numpy_empty() -> None:
    """An empty String series converts to an empty object array with allow_copy=False."""
    s = pl.Series(dtype=pl.String)
    result = s.to_numpy(allow_copy=False)
    assert result.dtype == np.object_
    assert result.shape == (0,)


def test_to_numpy_empty_writable() -> None:
    """An empty Int64 series converts with allow_copy=False and writable=True."""
    s = pl.Series(dtype=pl.Int64)
    result = s.to_numpy(allow_copy=False, writable=True)
    assert result.dtype == np.int64
    assert result.shape == (0,)
    assert result.flags.writeable is True


def test_to_numpy_chunked() -> None:
    """A multi-chunk series is copied into a writable array independent of the series."""
    s1 = pl.Series([1, 2])
    s2 = pl.Series([3, 4])
    s = pl.concat([s1, s2], rechunk=False)

    result: npt.NDArray[np.generic] = s.to_numpy()

    assert result.tolist() == s.to_list()
    assert result.dtype == np.int64
    assert result.flags.writeable is True
    assert_allow_copy_false_raises(s)

    # Check that writing to the array doesn't change the original data
    result[0] = 10
    assert result.tolist() == [10, 2, 3, 4]
    assert s.to_list() == [1, 2, 3, 4]


def test_to_numpy_chunked_temporal_nested() -> None:
    """A multi-chunk Array(Datetime) series converts to a writable 2-d datetime64[us] array."""
    dtype = pl.Array(pl.Datetime("us"), 1)
    s1 = pl.Series([[datetime(2020, 1, 1)], [datetime(2021, 1, 1)]], dtype=dtype)
    s2 = pl.Series([[datetime(2022, 1, 1)], [datetime(2023, 1, 1)]], dtype=dtype)
    s = pl.concat([s1, s2], rechunk=False)

    result: npt.NDArray[np.generic] = s.to_numpy()

    assert result.tolist() == s.to_list()
    assert result.dtype == np.dtype("datetime64[us]")
    assert result.shape == (4, 1)
    assert result.flags.writeable is True
    assert_allow_copy_false_raises(s)


def test_zero_copy_only_deprecated() -> None:
    """The ``zero_copy_only`` argument still works but emits a deprecation warning."""
    values = [1, 2]
    s = pl.Series(values)
    with pytest.deprecated_call():
        result: npt.NDArray[np.generic] = s.to_numpy(zero_copy_only=True)
    assert result.tolist() == values


def test_series_to_numpy_temporal() -> None:
    """Spot-check converted Date/Datetime values by their string form.

    Also compares Duration values against an explicit timedelta64[ns] array.
    """
    s0 = pl.Series("date", [123543, 283478, 1243]).cast(pl.Date)
    s1 = pl.Series(
        "datetime", [datetime(2021, 1, 2, 3, 4, 5), datetime(2021, 2, 3, 4, 5, 6)]
    )
    s2 = pl.datetime_range(
        datetime(2021, 1, 1, 0),
        datetime(2021, 1, 1, 1),
        interval="1h",
        time_unit="ms",
        eager=True,
    )
    assert str(s0.to_numpy()) == "['2308-04-02' '2746-02-20' '1973-05-28']"
    assert (
        str(s1.to_numpy()[:2])
        == "['2021-01-02T03:04:05.000000' '2021-02-03T04:05:06.000000']"
    )
    assert (
        str(s2.to_numpy()[:2])
        == "['2021-01-01T00:00:00.000' '2021-01-01T01:00:00.000']"
    )
    s3 = pl.Series([timedelta(hours=1), timedelta(hours=-2)])
    out = np.array([3_600_000_000_000, -7_200_000_000_000], dtype="timedelta64[ns]")
    assert (s3.to_numpy() == out).all()


@given(
    s=series(
        min_size=1,
        max_size=10,
        excluded_dtypes=[
            pl.Float16,
            pl.Int128,
            pl.UInt128,
            pl.Categorical,
            pl.List,
            pl.Struct,
            pl.Datetime("ms"),
            pl.Duration("ms"),
        ],
        allow_null=False,
        allow_time_zones=False,  # NumPy does not support parsing time zone aware data
    ).filter(
        lambda s: (
            not (s.dtype == pl.String and s.str.contains("\x00").any())
            and not (s.dtype == pl.Binary and s.bin.contains(b"\x00").any())
        )
    ),
)
@settings(max_examples=250)
def test_series_to_numpy(s: pl.Series) -> None:
    """Property test: ``to_numpy`` agrees with ``np.array`` built from ``to_list``.

    For the dtypes listed in ``dtype_map``, the expected array is built with an
    explicit NumPy dtype; for all others NumPy infers the dtype.
    """
    result = s.to_numpy()

    values = s.to_list()
    dtype_map = {
        pl.Datetime("ns"): "datetime64[ns]",
        pl.Datetime("us"): "datetime64[us]",
        pl.Duration("ns"): "timedelta64[ns]",
        pl.Duration("us"): "timedelta64[us]",
        pl.Null(): "float32",
    }
    np_dtype = dtype_map.get(s.dtype)
    expected = np.array(values, dtype=np_dtype)

    assert_array_equal(result, expected)


@pytest.mark.parametrize("writable", [False, True])
@pytest.mark.parametrize("pyarrow_available", [False, True])
def test_to_numpy2(
    writable: bool, pyarrow_available: bool, plmonkeypatch: PlMonkeyPatch
) -> None:
    """``writable`` sets the array's writeable flag, with and without pyarrow."""
    plmonkeypatch.setattr(pl.series.series, "_PYARROW_AVAILABLE", pyarrow_available)

    np_array = pl.Series("a", [1, 2, 3], pl.UInt8).to_numpy(writable=writable)

    np.testing.assert_array_equal(np_array, np.array([1, 2, 3], dtype=np.uint8))
    # Test if numpy array is readonly or writable.
    assert np_array.flags.writeable == writable

    if writable:
        np_array[1] += 10
        np.testing.assert_array_equal(np_array, np.array([1, 12, 3], dtype=np.uint8))

    np_array_with_missing_values = pl.Series("a", [None, 2, 3], pl.UInt8).to_numpy(
        writable=writable
    )

    np.testing.assert_array_equal(
        np_array_with_missing_values,
        np.array(
            [np.nan, 2.0, 3.0],
            dtype=(np.float64 if pyarrow_available else np.float32),
        ),
    )

    if writable:
        # As Null values can't be encoded natively in a numpy array,
        # this array will never be a view.
        assert np_array_with_missing_values.flags.writeable == writable


def test_to_numpy_series_indexed_18986() -> None:
    """Regression test (presumably GitHub issue #18986).

    Converts a frame whose column holds list values, including a null row.
    """
    df = pl.DataFrame({"a": [[4, 5, 6], [7, 8, 9, 10], None]})
    assert (df[1].to_numpy()[0, 0] == np.array([7, 8, 9, 10])).all()
    # The null list row compares equal to an object array holding None.
    # NOTE(review): original author questioned this representation ("strange,
    # but only option in numpy?") — confirm it is the intended behaviour.
    assert (df.to_numpy()[2] == np.array([None])).all()