Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/interop/numpy/test_ufunc_series.py
6939 views
1
from typing import Any, Callable, cast
2
3
import numpy as np
4
import pytest
5
from numpy.testing import assert_array_equal
6
7
import polars as pl
8
from polars.exceptions import ComputeError
9
from polars.testing import assert_series_equal
10
11
12
def test_ufunc() -> None:
    """NumPy ufuncs on a pl.Series give the right dtype and respect nulls/chunks."""
    # Output dtype: multiplying a float Series by an int keeps the float width.
    for float_dtype in (pl.Float32, pl.Float64):
        floats = pl.Series("a", [1.0, 2.0, 3.0, 4.0], dtype=float_dtype)
        assert_series_equal(
            cast(pl.Series, np.multiply(floats, 4)),
            pl.Series("a", [4.0, 8.0, 12.0, 16.0], dtype=float_dtype),
        )

    # Output dtype: np.power on integer Series.
    # Each case is (polars dtype, input values, optional explicit
    # (numpy dtype, expected polars dtype) pair passed through `dtype=`).
    squares = [1, 4, 9, 16]
    unsigned_values = [1, 2, 3, 4]
    signed_values = [1, -2, 3, -4]
    integer_cases: list[tuple[Any, list[int], Any]] = [
        (pl.UInt8, unsigned_values, (np.uint16, pl.UInt16)),
        (pl.Int8, signed_values, (np.int16, pl.Int16)),
        (pl.UInt32, unsigned_values, None),
        (pl.Int32, signed_values, None),
        (pl.UInt64, unsigned_values, None),
        (pl.Int64, signed_values, None),
    ]
    for int_dtype, values, explicit in integer_cases:
        ints = pl.Series("a", values, dtype=int_dtype)

        # An integer exponent keeps the input's integer dtype.
        assert_series_equal(
            cast(pl.Series, np.power(ints, 2)),
            pl.Series("a", squares, dtype=int_dtype),
        )
        # A float exponent promotes the result to Float64.
        assert_series_equal(
            cast(pl.Series, np.power(ints, 2.0)),
            pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64),
        )
        # An explicit `dtype=` argument decides the output dtype.
        if explicit is not None:
            numpy_dtype, polars_dtype = explicit
            assert_series_equal(
                cast(pl.Series, np.power(ints, 2, dtype=numpy_dtype)),
                pl.Series("a", squares, dtype=polars_dtype),
            )

    # The null bitmask must survive the ufunc call.
    with_null = pl.Series("a", [1.0, None, 3.0])
    exponentiated = cast(pl.Series, np.exp(with_null))
    assert exponentiated.null_count() == 1

    # A Series made of more than one chunk must work as well.
    chunked = pl.Series("a", [1.0, None, 3.0])
    extra_chunk = pl.Series("b", [4.0, 5.0, None])
    chunked.append(extra_chunk)
    assert chunked.n_chunks() == 2
    tripled = cast(pl.Series, np.multiply(chunked, 3))
    assert_series_equal(
        tripled,
        pl.Series("a", [3.0, None, 9.0, 12.0, 15.0, None]),
    )

    # Nulls on either side of a binary ufunc propagate to the output.
    left = pl.Series("a", [None, None, 3, 3])
    right = pl.Series("b", [None, 3, None, 3])
    elementwise_max = cast(pl.Series, np.maximum(left, right))
    assert_series_equal(elementwise_max, pl.Series("a", [None, None, None, 3]))
116
117
118
def test_numpy_string_array() -> None:
    """np.char.capitalize on a String Series matches the expected NumPy array."""
    lowercase = pl.Series("a", ["aa", "bb", "cc", "dd"], dtype=pl.String)
    capitalized = np.char.capitalize(lowercase)
    expected = np.array(["Aa", "Bb", "Cc", "Dd"], dtype="<U2")
    assert_array_equal(capitalized, expected)
124
125
126
def make_add_one() -> Callable[[pl.Series], pl.Series]:
    """
    Build a numba generalized ufunc that adds 1.0 to every element.

    The calling test is skipped when numba cannot be imported.
    """
    numba = pytest.importorskip("numba")
    float64_1d = numba.float64[:]

    @numba.guvectorize([(float64_1d, float64_1d)], "(n)->(n)")  # type: ignore[misc]
    def add_one(arr: Any, result: Any) -> None:
        # Fill the output buffer element by element.
        for idx in range(arr.shape[0]):
            result[idx] = arr[idx] + 1.0

    return add_one  # type: ignore[no-any-return]
135
136
137
def test_generalized_ufunc() -> None:
    """Calling a generalized ufunc directly on a pl.Series works."""
    increment = make_add_one()
    source = pl.Series("f", [1.0, 2.0, 3.0])
    expected = pl.Series("f", [2.0, 3.0, 4.0])
    assert_series_equal(increment(source), expected)
143
144
145
def test_generalized_ufunc_missing_data() -> None:
    """
    A generalized ufunc rejects a pl.Series that contains missing data.

    Adding one is arguably harmless for this example, but NumPy has no way
    to represent missing slots, so a function such as mean() over integers
    would return a wrong answer. Since we cannot assume an arbitrary function
    copes with missing data, passing such a Series is not allowed.
    """
    increment = make_add_one()
    with_null = pl.Series("f", [1.0, 2.0, 3.0, None], dtype=pl.Float64)
    expected_message = "can't pass a Series with missing data to a generalized ufunc"
    with pytest.raises(ComputeError, match=expected_message):
        increment(with_null)
162
163
164
def make_divide_by_sum() -> Callable[[pl.Series, pl.Series], pl.Series]:
    """
    Build a numba generalized ufunc dividing its second input by the first's sum.

    The output has the length of the second input. The calling test is skipped
    when numba cannot be imported.
    """
    numba = pytest.importorskip("numba")
    float64_1d = numba.float64[:]

    @numba.guvectorize([(float64_1d, float64_1d, float64_1d)], "(n),(m)->(m)")  # type: ignore[misc]
    def divide_by_sum(arr: Any, arr2: Any, result: Any) -> None:
        # The divisor depends only on the first input, so compute it once.
        divisor = arr.sum()
        for idx in range(arr2.shape[0]):
            result[idx] = arr2[idx] / divisor

    return divide_by_sum  # type: ignore[no-any-return]
175
176
177
def test_generalized_ufunc_different_output_size() -> None:
    """
    A generalized ufunc may be given pl.Series inputs of different lengths.

    The output must come back with the correct length.
    """
    divide_by_sum = make_divide_by_sum()

    short = pl.Series("s", [1.0, 3.0], dtype=pl.Float64)
    long = pl.Series("s2", [8.0, 16.0, 32.0], dtype=pl.Float64)

    # sum(short) == 4.0, so each element of `long` is divided by 4.
    long_over_short_sum = pl.Series("s", [2.0, 4.0, 8.0], dtype=pl.Float64)
    assert_series_equal(divide_by_sum(short, long), long_over_short_sum)

    # sum(long) == 56.0, so each element of `short` is divided by 56.
    short_over_long_sum = pl.Series("s2", [1.0 / 56, 3.0 / 56], dtype=pl.Float64)
    assert_series_equal(divide_by_sum(long, short), short_over_long_sum)
195
196