Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/lazyframe/test_serde.py
6939 views
1
from __future__ import annotations
2
3
import io
4
from typing import TYPE_CHECKING
5
6
import pytest
7
from hypothesis import example, given
8
9
import polars as pl
10
from polars.exceptions import ComputeError
11
from polars.testing import assert_frame_equal
12
from polars.testing.parametric import dataframes
13
14
if TYPE_CHECKING:
15
from pathlib import Path
16
17
from polars._typing import SerializationFormat
18
19
20
@given(
    lf=dataframes(
        lazy=True,
        excluded_dtypes=[pl.Struct],
    )
)
@example(lf=pl.LazyFrame({"foo": ["a", "b", "a"]}, schema={"foo": pl.Enum(["b", "a"])}))
def test_lf_serde_roundtrip_binary(lf: pl.LazyFrame) -> None:
    """Check that binary serialization round-trips a LazyFrame unchanged.

    Frames are generated by hypothesis with Struct columns excluded. One
    explicit example pins an Enum column whose declared category order
    (``["b", "a"]``) differs from the order in which the values appear.
    """
    serialized = lf.serialize(format="binary")
    result = pl.LazyFrame.deserialize(io.BytesIO(serialized), format="binary")
    assert_frame_equal(result, lf, categorical_as_str=True)
31
32
33
@given(
    lf=dataframes(
        lazy=True,
        excluded_dtypes=[
            pl.Float32,  # Bug, see: https://github.com/pola-rs/polars/issues/17211
            pl.Float64,  # Bug, see: https://github.com/pola-rs/polars/issues/17211
            pl.Struct,  # Outer nullability not supported
        ],
    )
)
@pytest.mark.filterwarnings("ignore")
def test_lf_serde_roundtrip_json(lf: pl.LazyFrame) -> None:
    """Check that JSON serialization round-trips a LazyFrame unchanged.

    Frames are generated by hypothesis; float and Struct columns are excluded
    for the reasons noted next to each excluded dtype. All warnings are
    silenced for this test via the ``filterwarnings`` mark.
    """
    serialized = lf.serialize(format="json")
    result = pl.LazyFrame.deserialize(io.StringIO(serialized), format="json")
    assert_frame_equal(result, lf, categorical_as_str=True)
48
49
50
@pytest.fixture
def lf() -> pl.LazyFrame:
    """Sample LazyFrame for testing serialization/deserialization.

    The frame is built from two columns, then selects ``a`` and sums it, so
    the serialized plan contains more than just the in-memory data.
    """
    return pl.LazyFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}).select("a").sum()
54
55
56
@pytest.mark.filterwarnings("ignore")
def test_lf_serde_json_stringio(lf: pl.LazyFrame) -> None:
    """Check that JSON serialization returns ``str`` and reads back via StringIO."""
    serialized = lf.serialize(format="json")
    # With no target given, the JSON format is expected to come back as text.
    assert isinstance(serialized, str)
    result = pl.LazyFrame.deserialize(io.StringIO(serialized), format="json")
    assert_frame_equal(result, lf)
62
63
64
def test_lf_serde(lf: pl.LazyFrame) -> None:
    """Check that default-format serialization returns ``bytes`` and round-trips."""
    serialized = lf.serialize()
    # With no format and no target given, the result is expected to be bytes.
    assert isinstance(serialized, bytes)
    result = pl.LazyFrame.deserialize(io.BytesIO(serialized))
    assert_frame_equal(result, lf)
69
70
71
# The buffer objects below are created once, when the decorator is evaluated;
# each parameter set carries its own buffer.
@pytest.mark.parametrize(
    ("format", "buf"),
    [
        ("binary", io.BytesIO()),
        ("json", io.StringIO()),
        ("json", io.BytesIO()),
    ],
)
@pytest.mark.filterwarnings("ignore")
def test_lf_serde_to_from_buffer(
    lf: pl.LazyFrame, format: SerializationFormat, buf: io.IOBase
) -> None:
    """Serialize into an in-memory buffer and deserialize from that same buffer.

    Covers the binary format with a bytes buffer, and the JSON format with
    both a text buffer and a bytes buffer.
    """
    lf.serialize(buf, format=format)
    # Rewind so that deserialization reads from the start of the buffer.
    buf.seek(0)
    result = pl.LazyFrame.deserialize(buf, format=format)
    assert_frame_equal(lf, result)
87
88
89
@pytest.mark.write_disk
def test_lf_serde_to_from_file(lf: pl.LazyFrame, tmp_path: Path) -> None:
    """Serialize to a file path and deserialize from it using the default format."""
    tmp_path.mkdir(exist_ok=True)

    file_path = tmp_path / "small.bin"
    lf.serialize(file_path)
    result = pl.LazyFrame.deserialize(file_path)

    assert_frame_equal(lf, result)
98
99
100
def test_lf_deserialize_validation() -> None:
    """Check that deserializing non-JSON bytes as JSON raises ``ComputeError``."""
    f = io.BytesIO(b"hello world!")
    # The payload is not valid JSON, so parsing must fail at the first character.
    with pytest.raises(ComputeError, match="expected value at line 1 column 1"):
        pl.LazyFrame.deserialize(f, format="json")
104
105
106
@pytest.mark.write_disk
def test_lf_serde_scan(tmp_path: Path) -> None:
    """Round-trip a LazyFrame that scans a Parquet file written to disk.

    The deserialized frame is compared against the original lazy frame, and
    its collected result against the DataFrame that was written to the file.
    """
    tmp_path.mkdir(exist_ok=True)
    path = tmp_path / "dataset.parquet"

    df = pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})
    df.write_parquet(path)
    lf = pl.scan_parquet(path)

    ser = lf.serialize()
    result = pl.LazyFrame.deserialize(io.BytesIO(ser))
    assert_frame_equal(result, lf)
    assert_frame_equal(result.collect(), df)
119
120
121
@pytest.mark.filterwarnings("ignore::polars.exceptions.PolarsInefficientMapWarning")
def test_lf_serde_version_specific_lambda() -> None:
    """Round-trip a LazyFrame whose plan holds a lambda passed to ``map_elements``.

    After deserialization the lambda must still be applied, so each value of
    column ``a`` is incremented by one.
    """
    lf = pl.LazyFrame({"a": [1, 2, 3]}).select(
        pl.col("a").map_elements(lambda x: x + 1, return_dtype=pl.Int64)
    )
    ser = lf.serialize()

    result = pl.LazyFrame.deserialize(io.BytesIO(ser))
    expected = pl.LazyFrame({"a": [2, 3, 4]})
    assert_frame_equal(result, expected)
131
132
133
def custom_function(x: pl.Series) -> pl.Series:
    """Return ``x + 1``.

    Kept as a named module-level function so it can be passed to
    ``map_batches`` in ``test_lf_serde_version_specific_named_function``.
    """
    return x + 1
135
136
137
@pytest.mark.may_fail_cloud  # reason: cloud does not have access to this scope
@pytest.mark.filterwarnings("ignore::polars.exceptions.PolarsInefficientMapWarning")
def test_lf_serde_version_specific_named_function() -> None:
    """Round-trip a LazyFrame whose plan holds a named function in ``map_batches``.

    Uses the module-level ``custom_function``; after deserialization each value
    of column ``a`` must be incremented by one.
    """
    lf = pl.LazyFrame({"a": [1, 2, 3]}).select(
        pl.col("a").map_batches(custom_function, return_dtype=pl.Int64)
    )
    ser = lf.serialize()

    result = pl.LazyFrame.deserialize(io.BytesIO(ser))
    expected = pl.LazyFrame({"a": [2, 3, 4]})
    assert_frame_equal(result, expected)
148
149
150
@pytest.mark.filterwarnings("ignore::polars.exceptions.PolarsInefficientMapWarning")
def test_lf_serde_map_batches_on_lazyframe() -> None:
    """Round-trip a LazyFrame built with ``LazyFrame.map_batches`` and a lambda.

    After deserialization the lambda must still be applied to the frame, so
    each value of column ``a`` is incremented by one.
    """
    lf = pl.LazyFrame({"a": [1, 2, 3]}).map_batches(lambda x: x + 1)
    ser = lf.serialize()

    result = pl.LazyFrame.deserialize(io.BytesIO(ser))
    expected = pl.LazyFrame({"a": [2, 3, 4]})
    assert_frame_equal(result, expected)
158
159