Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/dataframe/test_serde.py
6939 views
1
from __future__ import annotations
2
3
import io
4
import pickle
5
from datetime import date, datetime, timedelta
6
from decimal import Decimal as D
7
from multiprocessing.pool import ThreadPool
8
from typing import TYPE_CHECKING, Any
9
10
import pytest
11
from hypothesis import example, given
12
13
import polars as pl
14
from polars.exceptions import ComputeError
15
from polars.testing import assert_frame_equal
16
from polars.testing.parametric import dataframes
17
18
if TYPE_CHECKING:
19
from pathlib import Path
20
21
from polars._typing import SerializationFormat
22
23
24
def test_df_serde_roundtrip_binary(df: pl.DataFrame) -> None:
    """Serialize ``df`` with the default format and read it back as binary.

    The restored frame must equal the input, comparing categoricals as
    strings.
    """
    payload = df.serialize()
    stream = io.BytesIO(payload)
    restored = pl.DataFrame.deserialize(stream, format="binary")
    assert_frame_equal(restored, df, categorical_as_str=True)
28
29
30
@given(df=dataframes())
@example(df=pl.DataFrame({"a": [None, None]}, schema={"a": pl.Null}))
@example(df=pl.DataFrame(schema={"a": pl.List(pl.String)}))
def test_df_serde_roundtrip_json(df: pl.DataFrame) -> None:
    """Round-trip generated frames through the JSON format.

    The round-trip itself always runs; the final comparison is skipped when
    the first column is a ``Decimal`` without an explicit precision.
    """
    as_json = df.serialize(format="json")
    restored = pl.DataFrame.deserialize(io.StringIO(as_json), format="json")

    first_dtype = df.to_series(0).dtype
    if isinstance(first_dtype, pl.Decimal) and first_dtype.precision is None:
        # This gets converted to precision 38 upon `to_arrow()`
        pytest.skip("precision None")

    assert_frame_equal(restored, df, categorical_as_str=True)
43
44
45
def test_df_serde(df: pl.DataFrame) -> None:
    """Check that ``serialize()`` returns ``bytes`` that deserialize to ``df``."""
    payload = df.serialize()
    assert isinstance(payload, bytes)

    stream = io.BytesIO(payload)
    restored = pl.DataFrame.deserialize(stream)
    assert_frame_equal(restored, df)
50
51
52
def test_df_serde_json_stringio(df: pl.DataFrame) -> None:
    """Check that JSON serialization yields a ``str`` readable via ``StringIO``."""
    as_json = df.serialize(format="json")
    assert isinstance(as_json, str)

    stream = io.StringIO(as_json)
    restored = pl.DataFrame.deserialize(stream, format="json")
    assert_frame_equal(restored, df)
57
58
59
def test_df_serialize_json() -> None:
    """Round-trip a small frame, sorted by ``"a"``, through the JSON format."""
    frame = pl.DataFrame({"a": [1, 2, 3], "b": [9, 5, 6]}).sort("a")

    as_json = frame.serialize(format="json")
    assert isinstance(as_json, str)

    restored = pl.DataFrame.deserialize(io.StringIO(as_json), format="json")
    assert_frame_equal(restored, frame)
68
69
70
@pytest.mark.parametrize(
    ("format", "buf"),
    [
        ("binary", io.BytesIO()),
        ("json", io.StringIO()),
        ("json", io.BytesIO()),
    ],
)
def test_df_serde_to_from_buffer(
    df: pl.DataFrame, format: SerializationFormat, buf: io.IOBase
) -> None:
    """Serialize ``df`` into an in-memory buffer and read it back.

    Covers binary output into ``BytesIO`` and JSON output into both
    ``StringIO`` and ``BytesIO``.
    """
    # The buffers in the parametrize list are created once, when the module
    # is imported, so the same objects are handed to every invocation of this
    # test. Rewind and clear the buffer first so a repeated run in the same
    # process cannot pick up data left behind by an earlier one.
    buf.seek(0)
    buf.truncate()

    df.serialize(buf, format=format)
    buf.seek(0)  # rewind so deserialization starts at the first byte written
    read_df = pl.DataFrame.deserialize(buf, format=format)
    assert_frame_equal(df, read_df, categorical_as_str=True)
85
86
87
@pytest.mark.write_disk
def test_df_serde_to_from_file(df: pl.DataFrame, tmp_path: Path) -> None:
    """Serialize ``df`` to a file under ``tmp_path`` and read it back."""
    tmp_path.mkdir(exist_ok=True)
    target = tmp_path / "small.bin"

    df.serialize(target)
    restored = pl.DataFrame.deserialize(target)

    assert_frame_equal(df, restored, categorical_as_str=True)
96
97
98
def test_df_serde2(df: pl.DataFrame) -> None:
    """Round-trip ``df`` (minus ``"cat"`` and ``"time"``) two different ways.

    First the serialized bytes are copied into a buffer by hand; then the
    frame is serialized straight into a buffer. Both must restore ``df``.
    """
    # Text-based conversion loses time info
    df = df.select(pl.all().exclude(["cat", "time"]))

    # Route 1: obtain the bytes, then write them into a buffer manually.
    payload = df.serialize()
    manual_buf = io.BytesIO()
    manual_buf.write(payload)
    manual_buf.seek(0)
    from_manual = pl.DataFrame.deserialize(manual_buf)
    assert_frame_equal(from_manual, df)

    # Route 2: let `serialize` write into the buffer itself.
    direct_buf = io.BytesIO()
    df.serialize(direct_buf)
    direct_buf.seek(0)
    from_direct = pl.DataFrame.deserialize(direct_buf)
    assert_frame_equal(from_direct, df)
113
114
115
def test_df_serde_enum() -> None:
    """Check that an ``Enum`` column keeps its dtype across a round-trip."""
    enum_dtype = pl.Enum(["foo", "bar", "ham"])
    column = pl.Series("e", ["foo", "bar", "ham"], dtype=enum_dtype)
    original = pl.DataFrame([column])

    stream = io.BytesIO()
    original.serialize(stream)
    stream.seek(0)

    restored = pl.DataFrame.deserialize(stream)
    assert restored.schema["e"] == enum_dtype
123
124
125
@pytest.mark.parametrize(
    ("data", "dtype"),
    [
        ([[1, 2, 3], [None, None, None], [1, None, 3]], pl.Array(pl.Int32(), shape=3)),
        ([["a", "b"], [None, None]], pl.Array(pl.Utf8, shape=2)),
        ([[True, False, None], [None, None, None]], pl.Array(pl.Boolean, shape=3)),
        (
            [[[1, 2, 3], [4, None, 5]], None, [[None, None, 2]]],
            pl.List(pl.Array(pl.Int32(), shape=3)),
        ),
        (
            [
                [datetime(1991, 1, 1), datetime(1991, 1, 1), None],
                [None, None, None],
            ],
            pl.Array(pl.Datetime, shape=3),
        ),
        (
            [[D("1.0"), D("2.0"), D("3.0")], [None, None, None]],
            # The precision has to be given explicitly, because
            # `AnonymousListBuilder::finish` uses `ArrowDataType`, which
            # remaps a `None` precision to `38`.
            pl.Array(pl.Decimal(precision=38, scale=1), shape=3),
        ),
    ],
)
def test_df_serde_array(data: Any, dtype: pl.DataType) -> None:
    """Round-trip a single ``"foo"`` column of the given array-based dtype."""
    original = pl.DataFrame({"foo": data}, schema={"foo": dtype})

    stream = io.BytesIO()
    original.serialize(stream)
    stream.seek(0)

    restored = pl.DataFrame.deserialize(stream)
    assert_frame_equal(restored, original)
157
158
159
@pytest.mark.parametrize(
    ("data", "dtype"),
    [
        (
            [
                [
                    datetime(1997, 10, 1),
                    datetime(2000, 1, 2, 10, 30, 1),
                ],
                [None, None],
            ],
            pl.Array(pl.Datetime, shape=2),
        ),
        (
            [[date(1997, 10, 1), date(2000, 1, 1)], [None, None]],
            pl.Array(pl.Date, shape=2),
        ),
        (
            [
                [timedelta(seconds=1), timedelta(seconds=10)],
                [None, None],
            ],
            pl.Array(pl.Duration, shape=2),
        ),
    ],
)
def test_df_serde_array_logical_inner_type(data: Any, dtype: pl.DataType) -> None:
    """Round-trip ``Array`` columns whose inner type is Datetime, Date or Duration."""
    original = pl.DataFrame({"foo": data}, schema={"foo": dtype})

    stream = io.BytesIO()
    original.serialize(stream)
    stream.seek(0)

    restored = pl.DataFrame.deserialize(stream)
    assert_frame_equal(restored, original)
192
193
194
def test_df_serde_float_inf_nan() -> None:
    """Check that +inf, -inf and NaN survive a JSON round-trip."""
    values = [1.0, float("inf"), float("-inf"), float("nan")]
    original = pl.DataFrame({"a": values})

    as_json = original.serialize(format="json")
    restored = pl.DataFrame.deserialize(io.StringIO(as_json), format="json")

    assert_frame_equal(restored, original)
199
200
201
def test_df_serialize_invalid_type() -> None:
    """Check that serializing a frame holding a plain Python object raises."""
    frame = pl.DataFrame({"a": [object()]})
    expected_message = "serializing data of type Object is not supported"

    with pytest.raises(ComputeError, match=expected_message):
        frame.serialize()
207
208
209
def test_df_serde_list_of_null_17230() -> None:
    """Round-trip a ``List(Null)`` column holding one empty list through JSON."""
    series = pl.Series([[]], dtype=pl.List(pl.Null))
    original = series.to_frame()

    as_json = original.serialize(format="json")
    restored = pl.DataFrame.deserialize(io.StringIO(as_json), format="json")

    assert_frame_equal(restored, original)
214
215
216
def test_df_serialize_from_multiple_python_threads_22364() -> None:
    """Pickle one frame 1000 times across a pool of four threads.

    There is no explicit assertion; the test passes when every
    ``pickle.dumps`` call completes without raising.
    """
    frame = pl.DataFrame({"A": [1, 2, 3, 4]})
    jobs = [frame] * 1_000

    with ThreadPool(4) as pool:
        pool.map(pickle.dumps, jobs)
221
222