Path: blob/main/py-polars/tests/unit/dataframe/test_serde.py
8415 views
from __future__ import annotations12import io3import pickle4from datetime import date, datetime, timedelta5from decimal import Decimal as D6from multiprocessing.pool import ThreadPool7from typing import TYPE_CHECKING, Any89import pytest10from hypothesis import example, given1112import polars as pl13from polars.exceptions import ComputeError14from polars.testing import assert_frame_equal15from polars.testing.parametric import dataframes1617if TYPE_CHECKING:18from pathlib import Path1920from polars._typing import SerializationFormat212223def test_df_serde_roundtrip_binary(df: pl.DataFrame) -> None:24serialized = df.serialize()25result = pl.DataFrame.deserialize(serialized, format="binary")26assert_frame_equal(result, df, categorical_as_str=True)272829@given(df=dataframes())30@example(df=pl.DataFrame({"a": {"a": 1.0}}, schema={"a": pl.Struct({"a": pl.Float16})}))31@example(df=pl.DataFrame({"a": [None, None]}, schema={"a": pl.Null}))32@example(df=pl.DataFrame(schema={"a": pl.List(pl.String)}))33def test_df_serde_roundtrip_json(df: pl.DataFrame) -> None:34serialized = df.serialize(format="json")35result = pl.DataFrame.deserialize(io.StringIO(serialized), format="json")3637if isinstance(dt := df.to_series(0).dtype, pl.Decimal):38if dt.precision is None:39# This gets converted to precision 38 upon `to_arrow()`40pytest.skip("precision None")4142assert_frame_equal(result, df, categorical_as_str=True)434445def test_df_serde(df: pl.DataFrame) -> None:46serialized = df.serialize()47assert isinstance(serialized, bytes)48result = pl.DataFrame.deserialize(serialized)49assert_frame_equal(result, df)505152def test_df_serde_json_stringio(df: pl.DataFrame) -> None:53serialized = df.serialize(format="json")54assert isinstance(serialized, str)55result = pl.DataFrame.deserialize(io.StringIO(serialized), format="json")56assert_frame_equal(result, df)575859def test_df_serialize_json() -> None:60df = pl.DataFrame({"a": [1, 2, 3], "b": [9, 5, 6]}).sort("a")61result = df.serialize(format="json")6263assert isinstance(result, str)6465f = io.StringIO(result)6667assert_frame_equal(pl.DataFrame.deserialize(f, format="json"), df)686970@pytest.mark.parametrize(71("format", "buf"),72[73("binary", io.BytesIO()),74("json", io.StringIO()),75("json", io.BytesIO()),76],77)78def test_df_serde_to_from_buffer(79df: pl.DataFrame, format: SerializationFormat, buf: io.IOBase80) -> None:81df.serialize(buf, format=format)82buf.seek(0)83read_df = pl.DataFrame.deserialize(buf, format=format)84assert_frame_equal(df, read_df, categorical_as_str=True)858687@pytest.mark.write_disk88def test_df_serde_to_from_file(df: pl.DataFrame, tmp_path: Path) -> None:89tmp_path.mkdir(exist_ok=True)9091file_path = tmp_path / "small.bin"92df.serialize(file_path)93out = pl.DataFrame.deserialize(file_path)9495assert_frame_equal(df, out, categorical_as_str=True)969798def test_df_serde2(df: pl.DataFrame) -> None:99# Text-based conversion loses time info100df = df.select(pl.all().exclude(["cat", "time"]))101s = df.serialize()102f = io.BytesIO()103f.write(s)104f.seek(0)105out = pl.DataFrame.deserialize(f)106assert_frame_equal(out, df)107108file = io.BytesIO()109df.serialize(file)110file.seek(0)111out = pl.DataFrame.deserialize(file)112assert_frame_equal(out, df)113114115def test_df_serde_enum() -> None:116dtype = pl.Enum(["foo", "bar", "ham"])117df = pl.DataFrame([pl.Series("e", ["foo", "bar", "ham"], dtype=dtype)])118buf = io.BytesIO()119df.serialize(buf)120buf.seek(0)121df_in = pl.DataFrame.deserialize(buf)122assert df_in.schema["e"] == dtype123124125@pytest.mark.parametrize(126("data", "dtype"),127[128([[1, 2, 3], [None, None, None], [1, None, 3]], pl.Array(pl.Int32(), shape=3)),129([["a", "b"], [None, None]], pl.Array(pl.Utf8, shape=2)),130([[True, False, None], [None, None, None]], pl.Array(pl.Boolean, shape=3)),131(132[[[1, 2, 3], [4, None, 5]], None, [[None, None, 2]]],133pl.List(pl.Array(pl.Int32(), shape=3)),134),135(136[137[datetime(1991, 1, 1), datetime(1991, 1, 1), None],138[None, None, None],139],140pl.Array(pl.Datetime, shape=3),141),142(143[[D("1.0"), D("2.0"), D("3.0")], [None, None, None]],144# we have to specify precision, because `AnonymousListBuilder::finish`145# use `ArrowDataType` which will remap `None` precision to `38`146pl.Array(pl.Decimal(precision=38, scale=1), shape=3),147),148],149)150def test_df_serde_array(data: Any, dtype: pl.DataType) -> None:151df = pl.DataFrame({"foo": data}, schema={"foo": dtype})152buf = io.BytesIO()153df.serialize(buf)154buf.seek(0)155deserialized_df = pl.DataFrame.deserialize(buf)156assert_frame_equal(deserialized_df, df)157158159@pytest.mark.parametrize(160("data", "dtype"),161[162(163[164[165datetime(1997, 10, 1),166datetime(2000, 1, 2, 10, 30, 1),167],168[None, None],169],170pl.Array(pl.Datetime, shape=2),171),172(173[[date(1997, 10, 1), date(2000, 1, 1)], [None, None]],174pl.Array(pl.Date, shape=2),175),176(177[178[timedelta(seconds=1), timedelta(seconds=10)],179[None, None],180],181pl.Array(pl.Duration, shape=2),182),183],184)185def test_df_serde_array_logical_inner_type(data: Any, dtype: pl.DataType) -> None:186df = pl.DataFrame({"foo": data}, schema={"foo": dtype})187buf = io.BytesIO()188df.serialize(buf)189buf.seek(0)190result = pl.DataFrame.deserialize(buf)191assert_frame_equal(result, df)192193194def test_df_serde_float_inf_nan() -> None:195df = pl.DataFrame({"a": [1.0, float("inf"), float("-inf"), float("nan")]})196ser = df.serialize(format="json")197result = pl.DataFrame.deserialize(io.StringIO(ser), format="json")198assert_frame_equal(result, df)199200201def test_df_serialize_invalid_type() -> None:202df = pl.DataFrame({"a": [object()]})203with pytest.raises(204ComputeError, match="serializing data of type Object is not supported"205):206df.serialize()207208209def test_df_serde_list_of_null_17230() -> None:210df = pl.Series([[]], dtype=pl.List(pl.Null)).to_frame()211ser = df.serialize(format="json")212result = pl.DataFrame.deserialize(io.StringIO(ser), format="json")213assert_frame_equal(result, df)214215216def test_df_serialize_from_multiple_python_threads_22364() -> None:217df = pl.DataFrame({"A": [1, 2, 3, 4]})218219with ThreadPool(4) as tp:220tp.map(pickle.dumps, [df] * 1_000)221222223