# Source: pola-rs/polars — py-polars/tests/unit/io/test_serde.py
# Tests for (de)serialization of polars objects via pickle and LazyFrame serde.
1
from __future__ import annotations
2
3
import io
4
import pickle
5
import re
6
from datetime import datetime, timedelta
7
8
import pytest
9
10
import polars as pl
11
from polars.exceptions import SchemaError
12
from polars.testing import assert_frame_equal, assert_series_equal
13
14
15
def test_pickling_simple_expression() -> None:
    """A basic aggregation expression survives a pickle round trip."""
    expr = pl.col("foo").sum()
    restored = pickle.loads(pickle.dumps(expr))
    assert str(restored) == str(expr)
def test_pickling_as_struct_11100() -> None:
    """A struct-construction expression pickles cleanly (issue #11100)."""
    expr = pl.struct("a")
    restored = pickle.loads(pickle.dumps(expr))
    assert str(restored) == str(expr)
def test_serde_time_unit() -> None:
    """The nanosecond time unit of a Datetime series survives pickling."""
    value = datetime(2022, 1, 1) + timedelta(days=1)
    series = pl.Series([value] * 3).cast(pl.Datetime("ns"))
    restored = pickle.loads(pickle.dumps(series))
    assert restored.dtype == pl.Datetime("ns")
def test_serde_duration() -> None:
    """A Duration("ns") column produced by `diff` survives pickling intact."""
    frame = (
        pl.DataFrame(
            {
                "a": [
                    datetime(2021, 2, 1, 9, 20),
                    datetime(2021, 2, 2, 9, 20),
                ],
                "b": [4, 5],
            }
        )
        .with_columns([pl.col("a").cast(pl.Datetime("ns")).alias("a")])
        .select(pl.all())
    )
    # diff of consecutive ns-datetimes yields Duration("ns") with a leading null
    frame = frame.with_columns([pl.col("a").diff(n=1).alias("a_td")])
    restored = pickle.loads(pickle.dumps(frame))
    assert restored["a_td"].dtype == pl.Duration("ns")
    expected = pl.Series("a_td", [None, timedelta(days=1)], dtype=pl.Duration("ns"))
    assert_series_equal(restored["a_td"], expected)
def test_serde_expression_5461() -> None:
    """Expression metadata compares equal after pickling (issue #5461)."""
    expr = pl.col("a").sqrt() / pl.col("b").alias("c")
    restored = pickle.loads(pickle.dumps(expr))
    assert restored.meta == expr.meta
def test_serde_binary() -> None:
    """Arbitrary binary payloads in a series survive pickling byte-for-byte."""
    payload = [
        b"\xba\x9b\xca\xd3y\xcb\xc9#",
        b"9\x04\xab\xe2\x11\xf3\x85",
        b"\xb8\xcb\xc9^\\\xa9-\x94\xe0H\x9d ",
        b"S\xbc:\xcb\xf0\xf5r\xfe\x18\xfeH",
        b",\xf5)y\x00\xe5\xf7",
        b"\xfd\xf6\xf1\xc2X\x0cn\xb9#",
        b"\x06\xef\xa6\xa2\xb7",
        b"@\xff\x95\xda\xff\xd2\x18",
    ]
    original = pl.Series("binary_data", payload)
    restored = pickle.loads(pickle.dumps(original))
    assert_series_equal(original, restored)
def test_pickle_lazyframe() -> None:
    """A LazyFrame plan (sort) can be pickled and still collects correctly."""
    plan = pl.LazyFrame({"a": [1, 4, 3]}).sort("a")
    restored = pickle.loads(pickle.dumps(plan))
    assert_frame_equal(restored.collect(), pl.DataFrame({"a": [1, 3, 4]}))
def test_deser_empty_list() -> None:
    """An empty inner list keeps the nested List dtype after a round trip."""
    original = pl.Series([[[42.0]], []])
    restored = pickle.loads(pickle.dumps(original))
    assert restored.dtype == pl.List(pl.List(pl.Float64))
    assert restored.to_list() == [[[42.0]], []]
def times2(x: pl.Series) -> pl.Series:
    """Double every value; lives at module level so pickle can resolve it by name."""
    doubled = x * 2
    return doubled
@pytest.mark.may_fail_cloud  # reason: eager - return_dtype must be set
def test_pickle_udf_expression() -> None:
    """map_batches UDF expressions survive pickling, including the output dtype."""
    df = pl.DataFrame({"a": [1, 2, 3]})

    restored = pickle.loads(pickle.dumps(pl.col("a").map_batches(times2)))
    assert_frame_equal(df.select(restored), pl.DataFrame({"a": [2, 4, 6]}))

    restored = pickle.loads(
        pickle.dumps(pl.col("a").map_batches(times2, return_dtype=pl.String))
    )
    # the declared output dtype ('GetOutput') must also be deserialized,
    # so the mismatching actual dtype raises a SchemaError
    with pytest.raises(
        SchemaError,
        match=r"expected output type 'String', got 'Int64'; set `return_dtype` to the proper datatype",
    ):
        df.select(restored)
def test_pickle_small_integers() -> None:
    """Narrow integer dtypes (8/16-bit, signed and unsigned) survive pickling."""
    columns = [
        pl.Series([1, 2], dtype=pl.Int16),
        pl.Series([3, 2], dtype=pl.Int8),
        pl.Series([32, 2], dtype=pl.UInt8),
        pl.Series([3, 3], dtype=pl.UInt16),
    ]
    df = pl.DataFrame(columns)
    assert_frame_equal(pickle.loads(pickle.dumps(df)), df)
def df_times2(df: pl.DataFrame) -> pl.DataFrame:
    """Double all columns; module-level so pickle can resolve it by name."""
    doubled = df.select(pl.all() * 2)
    return doubled
def test_pickle_lazyframe_udf() -> None:
    """A LazyFrame carrying a module-level map_batches UDF pickles cleanly."""
    plan = pl.DataFrame({"a": [1, 2, 3]}).lazy().map_batches(df_times2)
    restored = pickle.loads(pickle.dumps(plan))
    assert restored.collect()["a"].to_list() == [2, 4, 6]
def test_pickle_lazyframe_nested_function_udf() -> None:
    """Round-trip a UDF defined inside the test body.

    NOTE: This is only possible when we're using cloudpickle.
    """

    def inner_df_times2(df: pl.DataFrame) -> pl.DataFrame:
        return df.select(pl.all() * 2)

    plan = pl.DataFrame({"a": [1, 2, 3]}).lazy().map_batches(inner_df_times2)
    restored = pickle.loads(pickle.dumps(plan))
    assert restored.collect()["a"].to_list() == [2, 4, 6]
def test_serde_categorical_series_10586() -> None:
    """Categorical series survive pickling (issue #10586)."""
    original = pl.Series(["a", "b", "b", "a", "c"], dtype=pl.Categorical)
    restored = pickle.loads(pickle.dumps(original))
    assert_series_equal(restored, original)
def test_serde_keep_dtype_empty_list() -> None:
    """A struct field typed List(String) keeps its dtype even when the value is null."""
    dtype = pl.Struct([pl.Field("a", pl.List(pl.String))])
    original = pl.Series([{"a": None}], dtype=dtype)
    assert original.dtype == pickle.loads(pickle.dumps(original)).dtype
def test_serde_array_dtype() -> None:
    """Fixed-size Array dtypes round-trip through pickle, flat and nested in a List."""
    flat = pl.Series(
        [[1, 2, 3], [None, None, None], [1, None, 3]],
        dtype=pl.Array(pl.Int32(), 3),
    )
    assert_series_equal(pickle.loads(pickle.dumps(flat)), flat)

    nested = pl.Series(
        [[[1, 2, 3], [4, None, 5]], None, [[None, None, 2]]],
        dtype=pl.List(pl.Array(pl.Int32(), 3)),
    )
    assert_series_equal(pickle.loads(pickle.dumps(nested)), nested)
def test_serde_data_type_class() -> None:
    """A DataType *class* (not an instance) pickles back to the same class object."""
    restored = pickle.loads(pickle.dumps(pl.Datetime))
    assert restored == pl.Datetime
    assert isinstance(restored, type)
def test_serde_data_type_instantiated() -> None:
    """An instantiated DataType pickles back to an equal DataType instance."""
    dtype = pl.Int8()
    restored = pickle.loads(pickle.dumps(dtype))
    assert restored == dtype
    assert isinstance(restored, pl.DataType)
def test_serde_data_type_instantiated_with_attributes() -> None:
    """A parameterized DataType (Enum with categories) keeps its attributes."""
    dtype = pl.Enum(["a", "b"])
    restored = pickle.loads(pickle.dumps(dtype))
    assert restored == dtype
    assert isinstance(restored, pl.DataType)
def test_serde_udf() -> None:
    """LazyFrame.serialize/deserialize round-trips a plan containing a map_elements UDF."""
    lf = pl.LazyFrame({"a": [[1, 2], [3, 4, 5]], "b": [3, 4]}).select(
        pl.col("a").map_elements(lambda x: sum(x), return_dtype=pl.Int32)
    )
    restored = pl.LazyFrame.deserialize(io.BytesIO(lf.serialize()))
    assert_frame_equal(lf, restored)
def test_serde_empty_df_lazy_frame() -> None:
    """An empty LazyFrame serializes, deserializes, and collects to shape (0, 0)."""
    buf = io.BytesIO()
    buf.write(pl.LazyFrame().serialize())
    buf.seek(0)
    restored = pl.LazyFrame.deserialize(buf)
    assert restored.collect().shape == (0, 0)
def test_pickle_class_objects_21021() -> None:
    """Top-level polars callables/classes are themselves picklable (issue #21021)."""
    cases = [
        (pl.col, ("A",), pl.Expr),
        (pl.DataFrame, (), pl.DataFrame),
        (pl.LazyFrame, (), pl.LazyFrame),
    ]
    for obj, args, expected_type in cases:
        restored = pickle.loads(pickle.dumps(obj))
        assert isinstance(restored(*args), expected_type)
@pytest.mark.slow
def test_serialize_does_not_overflow_stack() -> None:
    """Serializing a deeply nested concat plan must not overflow the stack."""
    n = 2000
    # build a left-deep chain of 2000 concat nodes
    lf = pl.DataFrame({"a": 0}).lazy()
    for i in range(1, n):
        lf = pl.concat([lf, pl.DataFrame({"a": i}).lazy()])

    restored = pl.LazyFrame.deserialize(io.BytesIO(lf.serialize()))
    expected = pl.DataFrame({"a": range(n)})
    assert_frame_equal(restored.collect(), expected)
def test_lf_cache_serde() -> None:
    """Cache-node identity is preserved by LazyFrame serialize/deserialize."""
    cached = pl.LazyFrame({"a": [1, 2, 3]}).cache()
    plan = pl.concat([cached, cached])

    restored = pl.LazyFrame.deserialize(io.BytesIO(plan.serialize()))

    cache_id_pattern = re.compile(r"CACHE\[id: (.*)\]")
    explained = [
        restored.explain(),
        restored.explain(optimizations=pl.QueryOptFlags.none()),
    ]
    for text in explained:
        ids = cache_id_pattern.findall(text)
        # there are only 2 caches, and both must refer to the same node
        assert len(ids) == 2
        assert ids[0] == ids[1]