Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/interop/numpy/test_from_numpy_series.py
8420 views
1
from __future__ import annotations
2
3
from datetime import timedelta
4
from typing import TYPE_CHECKING
5
6
import numpy as np
7
import pytest
8
from numpy.testing import assert_array_equal
9
10
import polars as pl
11
12
if TYPE_CHECKING:
13
from polars._typing import TimeUnit
14
15
16
@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"])
def test_from_numpy_timedelta(time_unit: TimeUnit) -> None:
    """Build a Duration Series from a NumPy ``timedelta64`` array.

    The array is created with the parametrized ``time_unit``. The resulting
    Series must have the matching ``pl.Duration`` dtype and keep its name.
    Indexing it through the ``.dt`` namespace must give back the original
    Python ``timedelta`` values.
    """
    expected_values = [timedelta(days=1), timedelta(seconds=1)]
    np_durations = np.array(expected_values, dtype=f"timedelta64[{time_unit}]")

    series = pl.Series("name", np_durations)

    assert series.dtype == pl.Duration(time_unit)
    assert series.name == "name"
    for position, expected in enumerate(expected_values):
        assert series.dt[position] == expected
28
29
30
def test_from_numpy_records() -> None:
    """NumPy arrays nested inside dicts/records become List fields of a Struct.

    Two scenarios are covered:

    * A single record whose fields are 1-D int/float/bool arrays. After
      unnesting, the fields must round-trip back to arrays equal to the inputs.
    * Several records with an ``id`` scalar and a ``values`` array. Their
      contents are checked record by record, not just by length.
    """
    # numpy arrays in dicts/records
    arr_int = np.array([1, 2, 3], dtype=np.int64)
    arr_float = np.array([1.1, 2.2, 3.3], dtype=np.float64)
    arr_bool = np.array([True, False, True], dtype=np.bool_)

    s = pl.Series(
        name="data",
        values=[{"ints": arr_int, "floats": arr_float, "bools": arr_bool}],
    )
    assert s.dtype == pl.Struct(
        {
            "ints": pl.List(pl.Int64),
            "floats": pl.List(pl.Float64),
            "bools": pl.List(pl.Boolean),
        }
    )
    # Unnest the struct into columns and take the single row back out.
    round_trip_array = s.to_frame().unnest("data").row(0)
    assert_array_equal(
        round_trip_array,
        [arr_int, arr_float, arr_bool],
    )

    data = [
        {"id": 1, "values": np.array([1, 2, 3], dtype=np.int64)},
        {"id": 2, "values": np.array([4, 5, 6], dtype=np.int64)},
        {"id": 3, "values": np.array([7, 8, 9], dtype=np.int64)},
    ]
    s = pl.Series("data", data)
    assert s.dtype == pl.Struct({"id": pl.Int64, "values": pl.List(pl.Int64)})
    assert len(s) == 3
    # A length check alone would still pass if array contents were lost or
    # shuffled between records, so compare every record's values too.
    assert s.to_list() == [
        {"id": 1, "values": [1, 2, 3]},
        {"id": 2, "values": [4, 5, 6]},
        {"id": 3, "values": [7, 8, 9]},
    ]
61
62
63
@pytest.mark.parametrize(
    ("numpy_dtype", "polars_dtype"),
    [
        (np.int8, pl.Int8),
        (np.int16, pl.Int16),
        (np.int32, pl.Int32),
        (np.int64, pl.Int64),
        (np.uint8, pl.UInt8),
        (np.uint16, pl.UInt16),
        (np.uint32, pl.UInt32),
        (np.uint64, pl.UInt64),
        (np.float16, pl.Float16),
        (np.float32, pl.Float32),
        (np.float64, pl.Float64),
        (np.bool_, pl.Boolean),
    ],
)
def test_from_numpy_records_2d(
    numpy_dtype: type[np.generic], polars_dtype: pl.DataType
) -> None:
    """A 2-D NumPy array inside a record becomes ``List(Array(dtype, shape=(2,)))``.

    The record's ``values`` field holds a 2x2 array of the parametrized dtype.
    The Series must have the expected nested Struct dtype, and indexing it
    must yield the same Python values NumPy produces for the array. The data
    must also survive a round trip through ``to_numpy``.
    """
    arr2d = np.array([[0, 1], [2, 3]], dtype=numpy_dtype)
    s = pl.Series("data", [{"id": 1, "values": arr2d}])

    assert s.dtype == pl.Struct(
        {"id": pl.Int64, "values": pl.List(pl.Array(polars_dtype, shape=(2,)))}
    )
    # Take the expected Python values from NumPy's own conversion so every
    # dtype is handled uniformly (e.g. for bool, 0 -> False and non-zero ->
    # True). This avoids special-casing Boolean with a hand-written literal.
    assert s[0] == {"id": 1, "values": arr2d.tolist()}

    round_trip_array = s.to_numpy()[0][1]
    assert_array_equal(round_trip_array, arr2d)
98
99