Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/dataframe/test_serde.py
8415 views
1
from __future__ import annotations
2
3
import io
4
import pickle
5
from datetime import date, datetime, timedelta
6
from decimal import Decimal as D
7
from multiprocessing.pool import ThreadPool
8
from typing import TYPE_CHECKING, Any
9
10
import pytest
11
from hypothesis import example, given
12
13
import polars as pl
14
from polars.exceptions import ComputeError
15
from polars.testing import assert_frame_equal
16
from polars.testing.parametric import dataframes
17
18
if TYPE_CHECKING:
19
from pathlib import Path
20
21
from polars._typing import SerializationFormat
22
23
24
def test_df_serde_roundtrip_binary(df: pl.DataFrame) -> None:
    """Round-trip ``df`` through the binary serialization format."""
    # "binary" is also what `serialize()` emits when no format is given.
    payload = df.serialize(format="binary")
    restored = pl.DataFrame.deserialize(payload, format="binary")
    assert_frame_equal(restored, df, categorical_as_str=True)
28
29
30
@given(df=dataframes())
@example(df=pl.DataFrame({"a": {"a": 1.0}}, schema={"a": pl.Struct({"a": pl.Float16})}))
@example(df=pl.DataFrame({"a": [None, None]}, schema={"a": pl.Null}))
@example(df=pl.DataFrame(schema={"a": pl.List(pl.String)}))
def test_df_serde_roundtrip_json(df: pl.DataFrame) -> None:
    """Round-trip a generated ``df`` through the JSON serialization format.

    The serialize/deserialize step always runs, so it must never raise. The
    final frame comparison is skipped when any top-level column is a
    ``Decimal`` without a precision (see the note below).
    """
    serialized = df.serialize(format="json")
    result = pl.DataFrame.deserialize(io.StringIO(serialized), format="json")

    # Inspect every column's dtype. `df.to_series(0)` only looked at the first
    # column and raises on a frame that has no columns at all.
    if any(
        isinstance(dt, pl.Decimal) and dt.precision is None
        for dt in df.schema.values()
    ):
        # This gets converted to precision 38 upon `to_arrow()`
        pytest.skip("precision None")

    assert_frame_equal(result, df, categorical_as_str=True)
44
45
46
def test_df_serde(df: pl.DataFrame) -> None:
    """Default ``serialize()`` yields bytes that deserialize back to ``df``."""
    blob = df.serialize()
    assert isinstance(blob, bytes)
    assert_frame_equal(pl.DataFrame.deserialize(blob), df)
51
52
53
def test_df_serde_json_stringio(df: pl.DataFrame) -> None:
    """JSON output is a ``str`` that can be read back through a ``StringIO``."""
    text = df.serialize(format="json")
    assert isinstance(text, str)

    with io.StringIO(text) as source:
        roundtripped = pl.DataFrame.deserialize(source, format="json")
    assert_frame_equal(roundtripped, df)
58
59
60
def test_df_serialize_json() -> None:
    """A small sorted frame survives a JSON serialize/deserialize cycle."""
    expected = pl.DataFrame({"a": [1, 2, 3], "b": [9, 5, 6]}).sort("a")

    as_json = expected.serialize(format="json")
    assert isinstance(as_json, str)

    restored = pl.DataFrame.deserialize(io.StringIO(as_json), format="json")
    assert_frame_equal(restored, expected)
69
70
71
@pytest.mark.parametrize(
    ("format", "buf_type"),
    [
        ("binary", io.BytesIO),
        ("json", io.StringIO),
        ("json", io.BytesIO),
    ],
)
def test_df_serde_to_from_buffer(
    df: pl.DataFrame, format: SerializationFormat, buf_type: type[io.IOBase]
) -> None:
    """Serialize ``df`` into an in-memory buffer of ``buf_type`` and read it back.

    The buffer is created inside the test rather than in the parametrize list.
    Parametrize values are evaluated once at collection time, so a buffer
    instance placed there would be shared by every rerun of the same case.
    Each rerun would then keep the previous run's contents and position.
    """
    buf = buf_type()
    df.serialize(buf, format=format)
    buf.seek(0)
    read_df = pl.DataFrame.deserialize(buf, format=format)
    assert_frame_equal(df, read_df, categorical_as_str=True)
86
87
88
@pytest.mark.write_disk
def test_df_serde_to_from_file(df: pl.DataFrame, tmp_path: Path) -> None:
    """Write ``df`` to a file with ``serialize`` and load it back from that path."""
    tmp_path.mkdir(exist_ok=True)
    target = tmp_path.joinpath("small.bin")

    df.serialize(target)
    loaded = pl.DataFrame.deserialize(target)
    assert_frame_equal(df, loaded, categorical_as_str=True)
97
98
99
def test_df_serde2(df: pl.DataFrame) -> None:
    """Binary round-trip of ``df`` through ``io.BytesIO`` buffers.

    Two paths are covered:

    1. Bytes returned by ``serialize()`` are copied into a buffer by hand.
    2. ``serialize()`` writes directly into the buffer.
    """
    # NOTE(review): "cat" and "time" were excluded because text-based
    # conversion used to lose time info. This test uses the default (binary)
    # format, and `test_df_serde` round-trips the full fixture in that format.
    # The exclusion may no longer be needed; confirm and drop it.
    df = df.select(pl.all().exclude(["cat", "time"]))

    # Path 1: serialize to bytes, then copy them into a buffer.
    s = df.serialize()
    f = io.BytesIO()
    f.write(s)
    f.seek(0)
    out = pl.DataFrame.deserialize(f)
    assert_frame_equal(out, df)

    # Path 2: let `serialize` write into the buffer itself.
    file = io.BytesIO()
    df.serialize(file)
    file.seek(0)
    out = pl.DataFrame.deserialize(file)
    assert_frame_equal(out, df)
114
115
116
def test_df_serde_enum() -> None:
    """An ``Enum`` column keeps its categories and its values through serde."""
    dtype = pl.Enum(["foo", "bar", "ham"])
    df = pl.DataFrame([pl.Series("e", ["foo", "bar", "ham"], dtype=dtype)])
    buf = io.BytesIO()
    df.serialize(buf)
    buf.seek(0)
    df_in = pl.DataFrame.deserialize(buf)
    assert df_in.schema["e"] == dtype
    # The schema check alone would pass even if the values were reordered or
    # nulled out, so compare the whole frame as well.
    assert_frame_equal(df_in, df)
124
125
126
@pytest.mark.parametrize(
    ("data", "dtype"),
    [
        ([[1, 2, 3], [None, None, None], [1, None, 3]], pl.Array(pl.Int32(), shape=3)),
        ([["a", "b"], [None, None]], pl.Array(pl.Utf8, shape=2)),
        ([[True, False, None], [None, None, None]], pl.Array(pl.Boolean, shape=3)),
        (
            [[[1, 2, 3], [4, None, 5]], None, [[None, None, 2]]],
            pl.List(pl.Array(pl.Int32(), shape=3)),
        ),
        (
            [
                [datetime(1991, 1, 1), datetime(1991, 1, 1), None],
                [None, None, None],
            ],
            pl.Array(pl.Datetime, shape=3),
        ),
        (
            [[D("1.0"), D("2.0"), D("3.0")], [None, None, None]],
            # Precision must be given explicitly: `AnonymousListBuilder::finish`
            # uses `ArrowDataType`, which remaps a `None` precision to `38`.
            pl.Array(pl.Decimal(precision=38, scale=1), shape=3),
        ),
    ],
)
def test_df_serde_array(data: Any, dtype: pl.DataType) -> None:
    """Fixed-size ``Array`` columns (and lists of them) survive binary serde."""
    original = pl.DataFrame({"foo": data}, schema={"foo": dtype})

    sink = io.BytesIO()
    original.serialize(sink)
    sink.seek(0)

    assert_frame_equal(pl.DataFrame.deserialize(sink), original)
158
159
160
@pytest.mark.parametrize(
    ("data", "dtype"),
    [
        (
            [
                [
                    datetime(1997, 10, 1),
                    datetime(2000, 1, 2, 10, 30, 1),
                ],
                [None, None],
            ],
            pl.Array(pl.Datetime, shape=2),
        ),
        (
            [[date(1997, 10, 1), date(2000, 1, 1)], [None, None]],
            pl.Array(pl.Date, shape=2),
        ),
        (
            [
                [timedelta(seconds=1), timedelta(seconds=10)],
                [None, None],
            ],
            pl.Array(pl.Duration, shape=2),
        ),
    ],
)
def test_df_serde_array_logical_inner_type(data: Any, dtype: pl.DataType) -> None:
    """``Array`` columns with a Datetime/Date/Duration inner type round-trip."""
    frame = pl.DataFrame({"foo": data}, schema={"foo": dtype})
    with io.BytesIO() as sink:
        frame.serialize(sink)
        sink.seek(0)
        result = pl.DataFrame.deserialize(sink)
    assert_frame_equal(result, frame)
193
194
195
def test_df_serde_float_inf_nan() -> None:
    """Non-finite floats (``inf``, ``-inf``, ``nan``) survive JSON serde."""
    values = [1.0, float("inf"), float("-inf"), float("nan")]
    df = pl.DataFrame({"a": values})

    payload = io.StringIO(df.serialize(format="json"))
    assert_frame_equal(pl.DataFrame.deserialize(payload, format="json"), df)
200
201
202
def test_df_serialize_invalid_type() -> None:
    """Serializing an ``Object`` column is rejected with a ``ComputeError``."""
    df = pl.DataFrame({"a": [object()]})
    expected_msg = "serializing data of type Object is not supported"
    with pytest.raises(ComputeError, match=expected_msg):
        df.serialize()
208
209
210
def test_df_serde_list_of_null_17230() -> None:
    """A ``List(Null)`` column holding one empty list round-trips via JSON.

    The ``17230`` suffix is presumably the tracking issue number.
    """
    frame = pl.Series([[]], dtype=pl.List(pl.Null)).to_frame()
    restored = pl.DataFrame.deserialize(
        io.StringIO(frame.serialize(format="json")), format="json"
    )
    assert_frame_equal(restored, frame)
215
216
217
def test_df_serialize_from_multiple_python_threads_22364() -> None:
    """Pickling one frame concurrently from several threads must work.

    Every payload is loaded back and compared with the source frame. This
    checks the output of the concurrent serialization, not just that no
    exception escaped the pool.
    """
    df = pl.DataFrame({"A": [1, 2, 3, 4]})

    with ThreadPool(4) as tp:
        payloads = tp.map(pickle.dumps, [df] * 1_000)

    assert len(payloads) == 1_000
    # The payloads were produced by this test, so unpickling them is safe.
    assert all(pickle.loads(p).equals(df) for p in payloads)
222
223