Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/interop/numpy/test_ufunc_series.py
8424 views
1
from collections.abc import Callable
2
from typing import Any, cast
3
4
import numpy as np
5
import pytest
6
from numpy.testing import assert_array_equal
7
8
import polars as pl
9
from polars.exceptions import ComputeError
10
from polars.testing import assert_series_equal
11
12
13
def test_ufunc() -> None:
    """NumPy ufuncs on pl.Series: result dtypes, null handling, and chunked input."""

    def check(result: Any, expected: pl.Series) -> None:
        # NumPy's stubs don't know the ufunc hands back a pl.Series, hence the cast.
        assert_series_equal(cast("pl.Series", result), expected)

    # Float inputs keep their width when multiplied by a Python int.
    for float_dtype in (pl.Float32, pl.Float64):
        floats = pl.Series("a", [1.0, 2.0, 3.0, 4.0], dtype=float_dtype)
        check(
            np.multiply(floats, 4),
            pl.Series("a", [4.0, 8.0, 12.0, 16.0], dtype=float_dtype),
        )

    unsigned_values = [1, 2, 3, 4]
    signed_values = [1, -2, 3, -4]
    squares = [1, 4, 9, 16]
    float_squares = [1.0, 4.0, 9.0, 16.0]

    # (integer dtype, input values, optional (numpy `dtype=` override, expected dtype))
    integer_cases = [
        (pl.UInt8, unsigned_values, (np.uint16, pl.UInt16)),
        (pl.Int8, signed_values, (np.int16, pl.Int16)),
        (pl.UInt32, unsigned_values, None),
        (pl.Int32, signed_values, None),
        (pl.UInt64, unsigned_values, None),
        (pl.Int64, signed_values, None),
    ]
    for int_dtype, values, override in integer_cases:
        ints = pl.Series("a", values, dtype=int_dtype)
        # An integer exponent keeps the input's integer dtype...
        check(np.power(ints, 2), pl.Series("a", squares, dtype=int_dtype))
        # ...whereas a float exponent promotes the result to Float64.
        check(
            np.power(ints, 2.0),
            pl.Series("a", float_squares, dtype=pl.Float64),
        )
        if override is not None:
            np_dtype, expected_dtype = override
            check(
                np.power(ints, 2, dtype=np_dtype),
                pl.Series("a", squares, dtype=expected_dtype),
            )

    # The null mask survives a unary ufunc.
    with_null = pl.Series("a", [1.0, None, 3.0])
    exponentiated = cast("pl.Series", np.exp(with_null))
    assert exponentiated.null_count() == 1

    # A Series made of several chunks is handled as one array.
    chunked = pl.Series("a", [1.0, None, 3.0])
    chunked.append(pl.Series("b", [4.0, 5.0, None]))
    assert chunked.n_chunks() == 2
    check(
        np.multiply(chunked, 3),
        pl.Series("a", [3.0, None, 9.0, 12.0, 15.0, None]),
    )

    # A binary ufunc yields null wherever either operand is null.
    left = pl.Series("a", [None, None, 3, 3])
    right = pl.Series("b", [None, 3, None, 3])
    check(np.maximum(left, right), pl.Series("a", [None, None, None, 3]))
117
118
119
def test_numpy_string_array() -> None:
    """NumPy's ``np.char`` string functions accept a String pl.Series."""
    words = pl.Series("a", ["aa", "bb", "cc", "dd"], dtype=pl.String)
    capitalized = np.char.capitalize(words)
    expected = np.array(["Aa", "Bb", "Cc", "Dd"], dtype="<U2")
    assert_array_equal(capitalized, expected)
125
126
127
def make_add_one() -> Callable[[pl.Series], pl.Series]:
    """
    Build a numba generalized ufunc that adds 1.0 to every element.

    Calls ``pytest.importorskip``, so the calling test is skipped when numba
    cannot be imported.
    """
    numba = pytest.importorskip("numba", exc_type=ImportError)
    f64_vector = numba.float64[:]

    @numba.guvectorize([(f64_vector, f64_vector)], "(n)->(n)")  # type: ignore[misc, untyped-decorator]
    def add_one(arr: Any, result: Any) -> None:
        # Layout "(n)->(n)": one output slot per input element.
        for index, value in enumerate(arr):
            result[index] = value + 1.0

    return add_one  # type: ignore[no-any-return]
136
137
138
def test_generalized_ufunc() -> None:
    """A generalized ufunc can be called on a pl.Series."""
    add_one = make_add_one()
    expected = pl.Series("f", [2.0, 3.0, 4.0])
    assert_series_equal(add_one(pl.Series("f", [1.0, 2.0, 3.0])), expected)
144
145
146
def test_generalized_ufunc_missing_data() -> None:
    """
    A generalized ufunc is refused when the pl.Series contains nulls.

    Adding one would be harmless here. A gufunc such as an integer mean(),
    however, would silently produce wrong results, because NumPy arrays have
    no way to represent missing slots. An arbitrary gufunc can't be assumed
    to cope with missing data, so the call raises ComputeError instead.
    """
    add_one = make_add_one()
    with_missing = pl.Series("f", [1.0, 2.0, 3.0, None], dtype=pl.Float64)
    expected_message = "can't pass a Series with missing data to a generalized ufunc"
    with pytest.raises(ComputeError, match=expected_message):
        add_one(with_missing)
163
164
165
def make_divide_by_sum() -> Callable[[pl.Series, pl.Series], pl.Series]:
    """
    Build a numba gufunc that divides its second input by the sum of its first.

    The layout "(n),(m)->(m)" lets the two inputs differ in length; the output
    has the second input's length. Calls ``pytest.importorskip``, so the
    calling test is skipped when numba cannot be imported.
    """
    numba = pytest.importorskip("numba", exc_type=ImportError)
    vector = numba.float64[:]

    @numba.guvectorize([(vector, vector, vector)], "(n),(m)->(m)")  # type: ignore[misc, untyped-decorator]
    def divide_by_sum(arr: Any, arr2: Any, result: Any) -> None:
        denominator = arr.sum()
        for position, numerator in enumerate(arr2):
            result[position] = numerator / denominator

    return divide_by_sum  # type: ignore[no-any-return]
176
177
178
def test_generalized_ufunc_different_output_size() -> None:
    """
    A generalized ufunc may take pl.Series inputs of different lengths.

    The output length follows the gufunc's layout. In both cases checked here,
    the result carries the first input's name.
    """
    divide_by_sum = make_divide_by_sum()

    shorter = pl.Series("s", [1.0, 3.0], dtype=pl.Float64)  # sums to 4
    longer = pl.Series("s2", [8.0, 16.0, 32.0], dtype=pl.Float64)  # sums to 56

    cases = [
        (shorter, longer, pl.Series("s", [2.0, 4.0, 8.0], dtype=pl.Float64)),
        (longer, shorter, pl.Series("s2", [1.0 / 56, 3.0 / 56], dtype=pl.Float64)),
    ]
    for summed, divided, expected in cases:
        assert_series_equal(divide_by_sum(summed, divided), expected)
)
196
197