Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/ml/test_to_jax.py
6940 views
1
from __future__ import annotations
2
3
from typing import TYPE_CHECKING, Any
4
5
import pytest
6
7
import polars as pl
8
import polars.selectors as cs
9
from polars.dependencies import _lazy_import
10
11
# don't import jax until an actual test is triggered (the decorator already
12
# ensures the tests aren't run locally; this avoids premature local import)
13
jx, _ = _lazy_import("jax")
14
jxn, _ = _lazy_import("jax.numpy")
15
16
pytestmark = pytest.mark.ci_only
17
18
if TYPE_CHECKING:
19
from polars._typing import PolarsDataType
20
21
22
@pytest.fixture
def df() -> pl.DataFrame:
    """
    Provide a small three-column frame for the conversion tests.

    Column "x" is forced to Int8 and "z" to Float32; "y" keeps whatever
    dtype polars infers for its integer values.
    """
    columns = {
        "x": [1, 2, 2, 3],
        "y": [1, 0, 1, 0],
        "z": [1.5, -0.5, 0.0, -2.0],
    }
    dtype_overrides = {"x": pl.Int8, "z": pl.Float32}
    return pl.DataFrame(columns, schema_overrides=dtype_overrides)
32
33
34
def assert_array_equal(actual: Any, expected: Any, nans_equal: bool = True) -> None:
    """
    Assert that `actual` is a Jax array that compares equal to `expected`.

    Parameters
    ----------
    actual
        Object under test; must be a `jax.Array` instance.
    expected
        Reference array that `actual` is compared against.
    nans_equal
        Forwarded to `array_equal` as `equal_nan`.

    Raises
    ------
    AssertionError
        If `actual` is not a `jax.Array`, or if `array_equal` reports
        that the two arrays differ.
    """
    assert isinstance(actual, jx.Array)
    # `array_equal` only *returns* a boolean; the result has to be asserted,
    # otherwise a mismatch goes unnoticed and the test passes regardless.
    assert jxn.array_equal(actual, expected, equal_nan=nans_equal), (
        f"arrays are not equal:\nactual: {actual!r}\nexpected: {expected!r}"
    )
37
38
39
@pytest.mark.parametrize(
    ("dtype", "expected_jax_dtype"),
    [
        (pl.Int8, "int8"),
        (pl.Int16, "int16"),
        (pl.Int32, "int32"),
        (pl.Int64, "int32"),
        (pl.UInt8, "uint8"),
        (pl.UInt16, "uint16"),
        (pl.UInt32, "uint32"),
        (pl.UInt64, "uint32"),
    ],
)
def test_to_jax_from_series(
    dtype: PolarsDataType,
    expected_jax_dtype: str,
) -> None:
    """
    Convert an integer Series to Jax for each supported way of naming a device.

    The 64-bit polars dtypes are paired with 32-bit Jax dtype names in the
    parametrization (presumably reflecting Jax's default precision - confirm).
    """
    values = [1, 2, 3, 4]
    series = pl.Series("x", values, dtype=dtype)

    # the device may be omitted, given by name, or given as a device object
    device_options = (None, "cpu", jx.devices("cpu")[0])
    for device in device_options:
        converted = series.to_jax(device=device)
        reference = jxn.array(values, dtype=getattr(jxn, expected_jax_dtype))
        assert_array_equal(converted, reference)
62
63
64
def test_to_jax_array(df: pl.DataFrame) -> None:
    """Check that every "array" call variant yields the same float32 matrix."""
    # default call, explicit return type, device by name, device by object
    converted = [
        df.to_jax(),
        df.to_jax("array"),
        df.to_jax("array", device="cpu"),
        df.to_jax("array", device=jx.devices("cpu")[0]),
    ]

    rows = [
        [1.0, 1.0, 1.5],
        [2.0, 0.0, -0.5],
        [2.0, 1.0, 0.0],
        [3.0, 0.0, -2.0],
    ]
    expected = jxn.array(rows, dtype=jxn.float32)

    for result in converted:
        assert_array_equal(result, expected)
81
82
83
def test_2D_array_cols_to_jax() -> None:
    """Exercise `to_jax` on Array-typed columns, and its rejection of List columns."""
    flat_rows = [[1, 1], [1, 2], [2, 2]]
    nested_rows = [[[1, 1], [1, 2]], [[2, 2], [2, 3]]]
    inner_dtype = pl.Array(pl.Int32, shape=(2,))

    # 2D array
    flat_frame = pl.DataFrame(
        {"data": flat_rows},
        schema_overrides={"data": inner_dtype},
    )
    assert_array_equal(
        flat_frame.to_jax(),
        jxn.array(flat_rows, dtype=jxn.int32),
    )

    # nested 2D array
    nested_frame = pl.DataFrame(
        {"data": nested_rows},
        schema_overrides={"data": pl.Array(inner_dtype, shape=(2,))},
    )
    assert_array_equal(
        nested_frame.to_jax(),
        jxn.array(nested_rows, dtype=jxn.int32),
    )

    # dict with 2D array
    labelled_frame = nested_frame.insert_column(0, pl.Series("lbl", [0, 1]))
    per_column = labelled_frame.to_jax("dict")
    assert_array_equal(per_column["lbl"], jxn.array([0, 1], dtype=jxn.int32))
    assert_array_equal(per_column["data"], jxn.array(nested_rows, dtype=jxn.int32))

    # no support for list (yet? could add if ragged arrays are valid)
    list_frame = pl.DataFrame({"data": [[1, 1], [1, 2], [2, 2]]})
    with pytest.raises(
        TypeError,
        match=r"cannot convert List column 'data' to Jax Array \(use Array dtype instead\)",
    ):
        list_frame.to_jax()
124
125
126
def test_to_jax_dict(df: pl.DataFrame) -> None:
    """Check the "dict" return type, with and without a `dtype` override."""
    by_column = df.to_jax("dict")
    assert list(by_column.keys()) == ["x", "y", "z"]

    # without an override, each column is compared against its own dtype
    expected_x = jxn.array([1, 2, 2, 3], dtype=jxn.int8)
    expected_y = jxn.array([1, 0, 1, 0], dtype=jxn.int32)
    expected_z = jxn.array([1.5, -0.5, 0.0, -2.0], dtype=jxn.float32)
    assert_array_equal(by_column["x"], expected_x)
    assert_array_equal(by_column["y"], expected_y)
    assert_array_equal(by_column["z"], expected_z)

    # with `dtype=pl.Float32`, every column is compared as float32
    as_float = df.to_jax("dict", dtype=pl.Float32)
    float_columns = (
        [1.0, 2.0, 2.0, 3.0],
        [1.0, 0.0, 1.0, 0.0],
        [1.5, -0.5, 0.0, -2.0],
    )
    for converted, column_values in zip(as_float.values(), float_columns):
        assert_array_equal(converted, jxn.array(column_values, dtype=jxn.float32))
143
144
145
def test_to_jax_feature_label_dict(df: pl.DataFrame) -> None:
    """
    Check the label/features split produced by `to_jax("dict", label=...)`.

    The `df` fixture argument is not used: the test builds its own frame.
    """
    raw = pl.DataFrame(
        {
            "age": [25, 32, 45, 22, 34],
            "income": [50000, 75000, 60000, 58000, 120000],
            "education": ["bachelor", "master", "phd", "bachelor", "phd"],
            "purchased": [False, True, True, False, True],
        }
    )
    # one-hot encode the only string column
    frame = raw.to_dummies("education", separator=":")

    result = frame.to_jax(return_type="dict", label="purchased")
    assert list(result.keys()) == ["label", "features"]

    expected_label = jxn.array(
        [[False], [True], [True], [False], [True]],
        dtype=jxn.bool,
    )
    # columns: age, income, then the three education indicator columns
    expected_features = jxn.array(
        [
            [25, 50000, 1, 0, 0],
            [32, 75000, 0, 1, 0],
            [45, 60000, 0, 0, 1],
            [22, 58000, 1, 0, 0],
            [34, 120000, 0, 0, 1],
        ],
        dtype=jxn.int32,
    )

    assert_array_equal(result["label"], expected_label)
    assert_array_equal(result["features"], expected_features)
175
176
177
def test_misc_errors(df: pl.DataFrame) -> None:
    """Check the `ValueError` messages raised for invalid `to_jax` arguments."""
    # unknown return type
    with pytest.raises(
        ValueError,
        match="invalid `return_type`: 'stroopwafel'",
    ):
        df.to_jax("stroopwafel")  # type: ignore[call-overload]

    # `features` given without `label` for a dict return
    with pytest.raises(
        ValueError,
        match="`label` is required if setting `features` when `return_type='dict'",
    ):
        df.to_jax("dict", features=cs.float())

    # `label` given while using the default return type
    with pytest.raises(
        ValueError,
        match="`label` and `features` only apply when `return_type` is 'dict'",
    ):
        df.to_jax(label="stroopwafel")
195
196