Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/lazyframe/test_serde.py
8424 views
1
from __future__ import annotations
2
3
import io
4
from typing import TYPE_CHECKING
5
6
import pytest
7
from hypothesis import example, given
8
9
import polars as pl
10
from polars.exceptions import ComputeError
11
from polars.testing import assert_frame_equal
12
from polars.testing.parametric import dataframes
13
14
if TYPE_CHECKING:
15
from pathlib import Path
16
17
from polars._typing import SerializationFormat
18
from tests.conftest import PlMonkeyPatch
19
20
21
@given(lf=dataframes(lazy=True, excluded_dtypes=[pl.Struct]))
@example(lf=pl.LazyFrame({"foo": ["a", "b", "a"]}, schema={"foo": pl.Enum(["b", "a"])}))
def test_lf_serde_roundtrip_binary(lf: pl.LazyFrame) -> None:
    """Round-trip generated LazyFrames through the binary serialization format.

    Struct columns are excluded from generation. The pinned example uses an
    Enum column whose declared category order (``b``, ``a``) differs from the
    order in which values appear in the data.
    """
    payload = io.BytesIO(lf.serialize(format="binary"))
    roundtripped = pl.LazyFrame.deserialize(payload, format="binary")
    # Compare categoricals by their string values rather than physical codes.
    assert_frame_equal(roundtripped, lf, categorical_as_str=True)
32
33
34
@given(
    lf=dataframes(
        lazy=True,
        excluded_dtypes=[
            pl.Float32,  # Bug, see: https://github.com/pola-rs/polars/issues/17211
            pl.Float64,  # Bug, see: https://github.com/pola-rs/polars/issues/17211
            pl.Struct,  # Outer nullability not supported
        ],
    )
)
@pytest.mark.filterwarnings("ignore")
def test_lf_serde_roundtrip_json(lf: pl.LazyFrame) -> None:
    """Round-trip generated LazyFrames through the JSON serialization format.

    Float and Struct dtypes are excluded from generation (see the inline notes
    on the exclusion list). All warnings are ignored for this test.
    """
    text = lf.serialize(format="json")
    roundtripped = pl.LazyFrame.deserialize(io.StringIO(text), format="json")
    # Compare categoricals by their string values rather than physical codes.
    assert_frame_equal(roundtripped, lf, categorical_as_str=True)
49
50
51
@pytest.fixture
def lf() -> pl.LazyFrame:
    """Provide a small LazyFrame used by the serialization/deserialization tests.

    The plan projects column ``a`` out of a two-column frame and sums it, so it
    contains a projection and an aggregation on top of the source data.
    """
    source = pl.LazyFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]})
    projected = source.select("a")
    return projected.sum()
55
56
57
@pytest.mark.filterwarnings("ignore")
def test_lf_serde_json_stringio(lf: pl.LazyFrame) -> None:
    """JSON serialization yields ``str`` that deserializes back via ``StringIO``."""
    as_json = lf.serialize(format="json")
    assert isinstance(as_json, str)

    restored = pl.LazyFrame.deserialize(io.StringIO(as_json), format="json")
    assert_frame_equal(restored, lf)
63
64
65
def test_lf_serde(lf: pl.LazyFrame) -> None:
    """Default serialization yields ``bytes`` that deserializes via ``BytesIO``."""
    blob = lf.serialize()
    assert isinstance(blob, bytes)

    restored = pl.LazyFrame.deserialize(io.BytesIO(blob))
    assert_frame_equal(restored, lf)
70
71
72
@pytest.mark.parametrize(
    ("format", "buf_cls"),
    [
        ("binary", io.BytesIO),
        ("json", io.StringIO),
        ("json", io.BytesIO),
    ],
)
@pytest.mark.filterwarnings("ignore")
def test_lf_serde_to_from_buffer(
    lf: pl.LazyFrame, format: SerializationFormat, buf_cls: type[io.IOBase]
) -> None:
    """Serialize into an in-memory buffer, rewind it, and deserialize from it.

    The parametrization supplies buffer *classes*, not instances. Instances
    built inside the decorator would be created once at import time and reused
    by any rerun of the same case. A reused buffer would already contain data
    and have a non-zero position.
    """
    buf = buf_cls()  # fresh, empty buffer for every run
    lf.serialize(buf, format=format)
    buf.seek(0)  # rewind so deserialization reads from the start
    result = pl.LazyFrame.deserialize(buf, format=format)
    assert_frame_equal(lf, result)
88
89
90
@pytest.mark.write_disk
def test_lf_serde_to_from_file(lf: pl.LazyFrame, tmp_path: Path) -> None:
    """Serialize to a file path and deserialize from that same path."""
    tmp_path.mkdir(exist_ok=True)
    target = tmp_path / "small.bin"

    lf.serialize(target)
    restored = pl.LazyFrame.deserialize(target)

    assert_frame_equal(lf, restored)
99
100
101
def test_lf_deserialize_validation() -> None:
    """Deserializing non-JSON bytes with ``format="json"`` raises ComputeError."""
    garbage = io.BytesIO(b"hello world!")
    with pytest.raises(ComputeError, match="expected value at line 1 column 1"):
        pl.LazyFrame.deserialize(garbage, format="json")
105
106
107
@pytest.mark.write_disk
def test_lf_serde_scan(tmp_path: Path) -> None:
    """A plan that scans a Parquet file survives a serialize/deserialize round trip."""
    tmp_path.mkdir(exist_ok=True)
    parquet_file = tmp_path / "dataset.parquet"

    expected_df = pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})
    expected_df.write_parquet(parquet_file)
    scan = pl.scan_parquet(parquet_file)

    restored = pl.LazyFrame.deserialize(io.BytesIO(scan.serialize()))
    # The restored plan must match the original lazily and the written data eagerly.
    assert_frame_equal(restored, scan)
    assert_frame_equal(restored.collect(), expected_df)
120
121
122
@pytest.mark.filterwarnings("ignore::polars.exceptions.PolarsInefficientMapWarning")
def test_lf_serde_version_specific_lambda() -> None:
    """A plan containing a ``map_elements`` lambda can be serialized and run."""
    add_one = pl.col("a").map_elements(lambda x: x + 1, return_dtype=pl.Int64)
    query = pl.LazyFrame({"a": [1, 2, 3]}).select(add_one)
    blob = query.serialize()

    restored = pl.LazyFrame.deserialize(io.BytesIO(blob))
    assert_frame_equal(restored, pl.LazyFrame({"a": [2, 3, 4]}))
132
133
134
def custom_function(x: pl.Series) -> pl.Series:
    """Return ``x`` plus one; used as a named (non-lambda) UDF in the serde tests."""
    incremented = x + 1
    return incremented
136
137
138
@pytest.mark.may_fail_cloud  # reason: cloud does not have access to this scope
@pytest.mark.filterwarnings("ignore::polars.exceptions.PolarsInefficientMapWarning")
def test_lf_serde_version_specific_named_function() -> None:
    """A plan calling a module-level function via ``map_batches`` round-trips."""
    batched = pl.col("a").map_batches(custom_function, return_dtype=pl.Int64)
    query = pl.LazyFrame({"a": [1, 2, 3]}).select(batched)
    blob = query.serialize()

    restored = pl.LazyFrame.deserialize(io.BytesIO(blob))
    assert_frame_equal(restored, pl.LazyFrame({"a": [2, 3, 4]}))
149
150
151
@pytest.mark.filterwarnings("ignore::polars.exceptions.PolarsInefficientMapWarning")
def test_lf_serde_map_batches_on_lazyframe() -> None:
    """A frame-level ``LazyFrame.map_batches`` lambda survives serialization."""
    query = pl.LazyFrame({"a": [1, 2, 3]}).map_batches(lambda x: x + 1)
    blob = query.serialize()

    restored = pl.LazyFrame.deserialize(io.BytesIO(blob))
    assert_frame_equal(restored, pl.LazyFrame({"a": [2, 3, 4]}))
159
160
161
@pytest.mark.parametrize("max_byte_slice_len", [1, 2, 3, 100, 2**32 - 1])
def test_lf_serde_chunked_bytes(
    plmonkeypatch: PlMonkeyPatch, max_byte_slice_len: int
) -> None:
    """Round-trip a LazyFrame under various maximum byte-slice lengths.

    ``POLARS_SERIALIZE_LAZYFRAME_MAX_BYTE_SLICE_LEN`` is set through
    ``plmonkeypatch`` before serializing. The values range from a single byte
    up to ``2**32 - 1``.
    """
    plmonkeypatch.setenv(
        "POLARS_SERIALIZE_LAZYFRAME_MAX_BYTE_SLICE_LEN", str(max_byte_slice_len)
    )
    frame = pl.LazyFrame({"a": range(5000)})
    blob = frame.serialize()

    restored = pl.LazyFrame.deserialize(io.BytesIO(blob))
    assert_frame_equal(restored.collect(), frame.collect())
173
174
175
def test_lf_collect_schema_does_not_change_serialize_25719() -> None:
    """Calling ``collect_schema`` must not change a LazyFrame's serialized bytes.

    This is a regression test, presumably for GitHub issue 25719 (per the test
    name). Bytes are compared against a freshly built, schema-unresolved
    equivalent. Raw ``bytes`` are deliberately passed straight to
    ``deserialize`` at the end.
    """
    base_df = pl.DataFrame({"x": [1, 2, 3]})

    scanned = base_df.lazy()
    scanned.collect_schema()
    after_schema = scanned.serialize()
    fresh = base_df.lazy().serialize()
    assert after_schema == fresh

    summed = scanned.sum()
    summed.collect_schema()
    summed_after_schema = summed.serialize()
    summed_fresh = base_df.lazy().sum().serialize()
    assert summed_after_schema == summed_fresh

    doubled = pl.concat([summed, summed])
    restored = pl.LazyFrame.deserialize(doubled.serialize())
    assert_frame_equal(restored.collect(), pl.DataFrame({"x": [6, 6]}))
192
193