# Path: blob/main/py-polars/tests/unit/dataframe/test_serde.py
# 6939 views
from __future__ import annotations12import io3import pickle4from datetime import date, datetime, timedelta5from decimal import Decimal as D6from multiprocessing.pool import ThreadPool7from typing import TYPE_CHECKING, Any89import pytest10from hypothesis import example, given1112import polars as pl13from polars.exceptions import ComputeError14from polars.testing import assert_frame_equal15from polars.testing.parametric import dataframes1617if TYPE_CHECKING:18from pathlib import Path1920from polars._typing import SerializationFormat212223def test_df_serde_roundtrip_binary(df: pl.DataFrame) -> None:24serialized = df.serialize()25result = pl.DataFrame.deserialize(io.BytesIO(serialized), format="binary")26assert_frame_equal(result, df, categorical_as_str=True)272829@given(df=dataframes())30@example(df=pl.DataFrame({"a": [None, None]}, schema={"a": pl.Null}))31@example(df=pl.DataFrame(schema={"a": pl.List(pl.String)}))32def test_df_serde_roundtrip_json(df: pl.DataFrame) -> None:33serialized = df.serialize(format="json")34result = pl.DataFrame.deserialize(io.StringIO(serialized), format="json")3536if isinstance(dt := df.to_series(0).dtype, pl.Decimal):37if dt.precision is None:38# This gets converted to precision 38 upon `to_arrow()`39pytest.skip("precision None")4041assert_frame_equal(result, df, categorical_as_str=True)424344def test_df_serde(df: pl.DataFrame) -> None:45serialized = df.serialize()46assert isinstance(serialized, bytes)47result = pl.DataFrame.deserialize(io.BytesIO(serialized))48assert_frame_equal(result, df)495051def test_df_serde_json_stringio(df: pl.DataFrame) -> None:52serialized = df.serialize(format="json")53assert isinstance(serialized, str)54result = pl.DataFrame.deserialize(io.StringIO(serialized), format="json")55assert_frame_equal(result, df)565758def test_df_serialize_json() -> None:59df = pl.DataFrame({"a": [1, 2, 3], "b": [9, 5, 6]}).sort("a")60result = df.serialize(format="json")6162assert isinstance(result, str)6364f = 
io.StringIO(result)6566assert_frame_equal(pl.DataFrame.deserialize(f, format="json"), df)676869@pytest.mark.parametrize(70("format", "buf"),71[72("binary", io.BytesIO()),73("json", io.StringIO()),74("json", io.BytesIO()),75],76)77def test_df_serde_to_from_buffer(78df: pl.DataFrame, format: SerializationFormat, buf: io.IOBase79) -> None:80df.serialize(buf, format=format)81buf.seek(0)82read_df = pl.DataFrame.deserialize(buf, format=format)83assert_frame_equal(df, read_df, categorical_as_str=True)848586@pytest.mark.write_disk87def test_df_serde_to_from_file(df: pl.DataFrame, tmp_path: Path) -> None:88tmp_path.mkdir(exist_ok=True)8990file_path = tmp_path / "small.bin"91df.serialize(file_path)92out = pl.DataFrame.deserialize(file_path)9394assert_frame_equal(df, out, categorical_as_str=True)959697def test_df_serde2(df: pl.DataFrame) -> None:98# Text-based conversion loses time info99df = df.select(pl.all().exclude(["cat", "time"]))100s = df.serialize()101f = io.BytesIO()102f.write(s)103f.seek(0)104out = pl.DataFrame.deserialize(f)105assert_frame_equal(out, df)106107file = io.BytesIO()108df.serialize(file)109file.seek(0)110out = pl.DataFrame.deserialize(file)111assert_frame_equal(out, df)112113114def test_df_serde_enum() -> None:115dtype = pl.Enum(["foo", "bar", "ham"])116df = pl.DataFrame([pl.Series("e", ["foo", "bar", "ham"], dtype=dtype)])117buf = io.BytesIO()118df.serialize(buf)119buf.seek(0)120df_in = pl.DataFrame.deserialize(buf)121assert df_in.schema["e"] == dtype122123124@pytest.mark.parametrize(125("data", "dtype"),126[127([[1, 2, 3], [None, None, None], [1, None, 3]], pl.Array(pl.Int32(), shape=3)),128([["a", "b"], [None, None]], pl.Array(pl.Utf8, shape=2)),129([[True, False, None], [None, None, None]], pl.Array(pl.Boolean, shape=3)),130(131[[[1, 2, 3], [4, None, 5]], None, [[None, None, 2]]],132pl.List(pl.Array(pl.Int32(), shape=3)),133),134(135[136[datetime(1991, 1, 1), datetime(1991, 1, 1), None],137[None, None, None],138],139pl.Array(pl.Datetime, 
shape=3),140),141(142[[D("1.0"), D("2.0"), D("3.0")], [None, None, None]],143# we have to specify precision, because `AnonymousListBuilder::finish`144# use `ArrowDataType` which will remap `None` precision to `38`145pl.Array(pl.Decimal(precision=38, scale=1), shape=3),146),147],148)149def test_df_serde_array(data: Any, dtype: pl.DataType) -> None:150df = pl.DataFrame({"foo": data}, schema={"foo": dtype})151buf = io.BytesIO()152df.serialize(buf)153buf.seek(0)154deserialized_df = pl.DataFrame.deserialize(buf)155assert_frame_equal(deserialized_df, df)156157158@pytest.mark.parametrize(159("data", "dtype"),160[161(162[163[164datetime(1997, 10, 1),165datetime(2000, 1, 2, 10, 30, 1),166],167[None, None],168],169pl.Array(pl.Datetime, shape=2),170),171(172[[date(1997, 10, 1), date(2000, 1, 1)], [None, None]],173pl.Array(pl.Date, shape=2),174),175(176[177[timedelta(seconds=1), timedelta(seconds=10)],178[None, None],179],180pl.Array(pl.Duration, shape=2),181),182],183)184def test_df_serde_array_logical_inner_type(data: Any, dtype: pl.DataType) -> None:185df = pl.DataFrame({"foo": data}, schema={"foo": dtype})186buf = io.BytesIO()187df.serialize(buf)188buf.seek(0)189result = pl.DataFrame.deserialize(buf)190assert_frame_equal(result, df)191192193def test_df_serde_float_inf_nan() -> None:194df = pl.DataFrame({"a": [1.0, float("inf"), float("-inf"), float("nan")]})195ser = df.serialize(format="json")196result = pl.DataFrame.deserialize(io.StringIO(ser), format="json")197assert_frame_equal(result, df)198199200def test_df_serialize_invalid_type() -> None:201df = pl.DataFrame({"a": [object()]})202with pytest.raises(203ComputeError, match="serializing data of type Object is not supported"204):205df.serialize()206207208def test_df_serde_list_of_null_17230() -> None:209df = pl.Series([[]], dtype=pl.List(pl.Null)).to_frame()210ser = df.serialize(format="json")211result = pl.DataFrame.deserialize(io.StringIO(ser), format="json")212assert_frame_equal(result, df)213214215def 
test_df_serialize_from_multiple_python_threads_22364() -> None:216df = pl.DataFrame({"A": [1, 2, 3, 4]})217218with ThreadPool(4) as tp:219tp.map(pickle.dumps, [df] * 1_000)220221222