Path: blob/main/py-polars/tests/unit/interchange/test_from_dataframe.py
6939 views
"""Tests for ``pl.from_dataframe`` and the interchange-protocol helpers.

Covers conversion from Polars, pandas and PyArrow objects, as well as the
private helpers in ``polars.interchange.from_dataframe`` that rebuild Polars
Series from interchange-protocol columns and buffers.
"""

from __future__ import annotations

from datetime import date, datetime, time, timedelta
from typing import Any

import pandas as pd
import pyarrow as pa
import pytest

import polars as pl
from polars.interchange.buffer import PolarsBuffer
from polars.interchange.column import PolarsColumn
from polars.interchange.from_dataframe import (
    _categorical_column_to_series,
    _column_to_series,
    _construct_data_buffer,
    _construct_offsets_buffer,
    _construct_validity_buffer,
    _construct_validity_buffer_from_bitmask,
    _construct_validity_buffer_from_bytemask,
    _string_column_to_series,
)
from polars.interchange.protocol import (
    ColumnNullType,
    CopyNotAllowedError,
    DtypeKind,
    Endianness,
)
from polars.testing import assert_frame_equal, assert_series_equal

# Shorthand for the endianness field of the interchange dtype tuples built below.
NE = Endianness.NATIVE


def test_from_dataframe_polars() -> None:
    """A Polars DataFrame round-trips; the deprecated ``allow_copy`` warns."""
    df = pl.DataFrame({"a": [1, 2], "b": [3.0, 4.0], "c": ["foo", "bar"]})
    with pytest.deprecated_call(match="`allow_copy` is deprecated"):
        result = pl.from_dataframe(df, allow_copy=False)
    assert_frame_equal(result, df)


def test_from_dataframe_polars_interchange_fast_path() -> None:
    """A Polars ``__dataframe__()`` object, including a Categorical, round-trips."""
    df = pl.DataFrame(
        {"a": [1, 2], "b": [3.0, 4.0], "c": ["foo", "bar"]},
        schema_overrides={"c": pl.Categorical},
    )
    # Pass the interchange object itself rather than the DataFrame.
    dfi = df.__dataframe__()
    with pytest.deprecated_call(match="`allow_copy` is deprecated"):
        result = pl.from_dataframe(dfi, allow_copy=False)
    assert_frame_equal(result, df)


def test_from_dataframe_categorical() -> None:
    """A Categorical column exported to PyArrow converts back to Categorical."""
    df = pl.DataFrame({"a": ["foo", "bar"]}, schema={"a": pl.Categorical})
    df_pa = df.to_arrow()

    with pytest.deprecated_call(match="`allow_copy` is deprecated"):
        result = pl.from_dataframe(df_pa, allow_copy=True)
    expected = pl.DataFrame({"a": ["foo", "bar"]}, schema={"a": pl.Categorical})
    assert_frame_equal(result, expected)


def test_from_dataframe_empty_string_zero_copy() -> None:
    """An empty String column from PyArrow converts with ``allow_copy=False``."""
    df = pl.DataFrame({"a": []}, schema={"a": pl.String})
    df_pa = df.to_arrow()
    with pytest.deprecated_call(match="`allow_copy` is deprecated"):
        result = pl.from_dataframe(df_pa, allow_copy=False)
    assert_frame_equal(result, df)


def test_from_dataframe_empty_bool_zero_copy() -> None:
    """An empty Boolean column from pandas converts with ``allow_copy=False``."""
    df = pl.DataFrame(schema={"a": pl.Boolean})
    df_pd = df.to_pandas()
    with pytest.deprecated_call(match="`allow_copy` is deprecated"):
        result = pl.from_dataframe(df_pd, allow_copy=False)
    assert_frame_equal(result, df)


def test_from_dataframe_empty_categories_zero_copy() -> None:
    """An Enum column with no categories converts back from PyArrow."""
    df = pl.DataFrame(schema={"a": pl.Enum([])})
    df_pa = df.to_arrow()
    with pytest.deprecated_call(match="`allow_copy` is deprecated"):
        result = pl.from_dataframe(df_pa, allow_copy=False)
    assert_frame_equal(result, df)


def test_from_dataframe_pandas_zero_copy() -> None:
    """A numeric pandas DataFrame converts with ``allow_copy=False``."""
    data = {"a": [1, 2], "b": [3.0, 4.0]}

    df = pd.DataFrame(data)
    with pytest.deprecated_call(match="`allow_copy` is deprecated"):
        result = pl.from_dataframe(df, allow_copy=False)
    expected = pl.DataFrame(data)
    assert_frame_equal(result, expected)


def test_from_dataframe_pyarrow_table_zero_copy() -> None:
    """A numeric PyArrow Table converts with ``allow_copy=False``."""
    df = pl.DataFrame(
        {
            "a": [1, 2],
            "b": [3.0, 4.0],
        }
    )
    df_pa = df.to_arrow()

    with pytest.deprecated_call(match="`allow_copy` is deprecated"):
        result = pl.from_dataframe(df_pa, allow_copy=False)
    assert_frame_equal(result, df)


def test_from_dataframe_pyarrow_empty_table() -> None:
    """A zero-row Int8 PyArrow Table converts with ``allow_copy=False``."""
    df = pl.Series("a", dtype=pl.Int8).to_frame()
    df_pa = df.to_arrow()

    with pytest.deprecated_call(match="`allow_copy` is deprecated"):
        result = pl.from_dataframe(df_pa, allow_copy=False)
    assert_frame_equal(result, df)


def test_from_dataframe_pyarrow_recordbatch_zero_copy() -> None:
    """A PyArrow RecordBatch (not just a Table) is accepted."""
    a = pa.array([1, 2])
    b = pa.array([3.0, 4.0])

    batch = pa.record_batch([a, b], names=["a", "b"])
    with pytest.deprecated_call(match="`allow_copy` is deprecated"):
        result = pl.from_dataframe(batch, allow_copy=False)

    expected = pl.DataFrame({"a": [1, 2], "b": [3.0, 4.0]})
    assert_frame_equal(result, expected)


def test_from_dataframe_invalid_type() -> None:
    """Objects that are not dataframes (here a nested list) raise TypeError."""
    df = [[1, 2], [3, 4]]
    with pytest.raises(TypeError):
        pl.from_dataframe(df)  # type: ignore[arg-type]


def test_from_dataframe_pyarrow_boolean() -> None:
    """A Boolean PyArrow column converts with the default and with ``allow_copy=False``."""
    df = pl.Series("a", [True, False]).to_frame()
    df_pa = df.to_arrow()

    result = pl.from_dataframe(df_pa)
    assert_frame_equal(result, df)

    with pytest.deprecated_call(match="`allow_copy` is deprecated"):
        result = pl.from_dataframe(df_pa, allow_copy=False)
    assert_frame_equal(result, df)


def test_from_dataframe_chunked() -> None:
    """With ``rechunk=False`` the input's chunk layout (2 chunks) is preserved."""
    df = pl.Series("a", [0, 1], dtype=pl.Int8).to_frame()
    # Concatenate without rechunking so the frame really has two chunks.
    df_chunked = pl.concat([df[:1], df[1:]], rechunk=False)

    df_pa = df_chunked.to_arrow()
    result = pl.from_dataframe(df_pa, rechunk=False)

    assert_frame_equal(result, df_chunked)
    assert result.n_chunks() == 2


@pytest.mark.may_fail_auto_streaming
@pytest.mark.may_fail_cloud  # reason: chunking
def test_from_dataframe_chunked_string() -> None:
    """A chunked String column with nulls keeps its 3 chunks with ``rechunk=False``."""
    df = pl.Series("a", ["a", None, "bc", "d", None, "efg"]).to_frame()
    df_chunked = pl.concat([df[:1], df[1:3], df[3:]], rechunk=False)

    df_pa = df_chunked.to_arrow()
    result = pl.from_dataframe(df_pa, rechunk=False)

    assert_frame_equal(result, df_chunked)
    assert result.n_chunks() == 3


def test_from_dataframe_pandas_nan_as_null() -> None:
    """pandas NaN becomes null while inf is kept; the result is a single chunk."""
    df = pd.Series([1.0, float("nan"), float("inf")], name="a").to_frame()
    result = pl.from_dataframe(df)
    expected = pl.Series("a", [1.0, None, float("inf")]).to_frame()
    assert_frame_equal(result, expected)
    assert result.n_chunks() == 1


def test_from_dataframe_pandas_boolean_bytes() -> None:
    """A pandas ``bool`` column converts, with and without ``allow_copy=False``.

    The test name suggests pandas stores these booleans one per byte rather
    than as a bitmask; that storage detail is not asserted here.
    """
    df = pd.Series([True, False], name="a").to_frame()
    result = pl.from_dataframe(df)

    expected = pl.Series("a", [True, False]).to_frame()
    assert_frame_equal(result, expected)

    with pytest.deprecated_call(match="`allow_copy` is deprecated"):
        result = pl.from_dataframe(df, allow_copy=False)
    expected = pl.Series("a", [True, False]).to_frame()
    assert_frame_equal(result, expected)


def test_from_dataframe_categorical_pandas() -> None:
    """A pandas ``category`` column with a missing value becomes Categorical."""
    values = ["a", "b", None, "a"]

    df_pd = pd.Series(values, dtype="category", name="a").to_frame()

    result = pl.from_dataframe(df_pd)
    expected = pl.Series("a", values, dtype=pl.Categorical).to_frame()
    assert_frame_equal(result, expected)

    with pytest.deprecated_call(match="`allow_copy` is deprecated"):
        result = pl.from_dataframe(df_pd, allow_copy=False)
    expected = pl.Series("a", values, dtype=pl.Categorical).to_frame()
    assert_frame_equal(result, expected)


def test_from_dataframe_categorical_pyarrow() -> None:
    """A PyArrow dictionary column (int32 indices, utf8 values) becomes Categorical."""
    values = ["a", "b", None, "a"]

    dtype = pa.dictionary(pa.int32(), pa.utf8())
    arr = pa.array(values, dtype)
    df_pa = pa.Table.from_arrays([arr], names=["a"])

    result = pl.from_dataframe(df_pa)
    expected = pl.Series("a", values, dtype=pl.Categorical).to_frame()
    assert_frame_equal(result, expected)

    with pytest.deprecated_call(match="`allow_copy` is deprecated"):
        result = pl.from_dataframe(df_pa, allow_copy=False)
    assert_frame_equal(result, expected)


def test_from_dataframe_categorical_non_string_keys() -> None:
    """A dictionary column with int32 values decodes to a plain Int32 column.

    The dictionary uses uint32 indices and int32 dictionary values (the
    categories), so the result is not Categorical.
    """
    values = [1, 2, None, 1]

    dtype = pa.dictionary(pa.uint32(), pa.int32())
    arr = pa.array(values, dtype)
    df_pa = pa.Table.from_arrays([arr], names=["a"])
    result = pl.from_dataframe(df_pa)
    expected = pl.DataFrame({"a": [1, 2, None, 1]}, schema={"a": pl.Int32})
    assert_frame_equal(result, expected)


def test_from_dataframe_categorical_non_u32_values() -> None:
    """A dictionary column with int8 (not u32) indices converts to Categorical.

    All entries are null in this case.
    """
    values = [None, None]

    dtype = pa.dictionary(pa.int8(), pa.utf8())
    arr = pa.array(values, dtype)
    df_pa = pa.Table.from_arrays([arr], names=["a"])

    result = pl.from_dataframe(df_pa)
    expected = pl.Series("a", values, dtype=pl.Categorical).to_frame()
    assert_frame_equal(result, expected)

    with pytest.deprecated_call(match="`allow_copy` is deprecated"):
        result = pl.from_dataframe(df_pa, allow_copy=False)
    assert_frame_equal(result, expected)


class PatchableColumn(PolarsColumn):
    """Helper class that allows patching certain PolarsColumn properties."""

    # These class attributes replace the corresponding PolarsColumn members
    # (presumably read-only properties there), so each test can assign its own
    # per-instance value. The defaults are: a bitmask with null value 0, an
    # empty categorical description and no nulls.
    describe_null: tuple[ColumnNullType, Any] = (ColumnNullType.USE_BITMASK, 0)
    # Shared mutable default; tests rebind it per instance instead of mutating it.
    describe_categorical: dict[str, Any] = {}  # type: ignore[assignment]  # noqa: RUF012
    null_count = 0


def test_column_to_series_use_sentinel_i64_min() -> None:
    """A Datetime column whose sentinel is the Int64 minimum gets that value nulled."""
    I64_MIN = -9223372036854775808  # -2**63
    dtype = pl.Datetime("us")
    physical = pl.Series([0, I64_MIN])
    logical = physical.cast(dtype)

    col = PatchableColumn(logical)
    col.describe_null = (ColumnNullType.USE_SENTINEL, I64_MIN)
    col.null_count = 1

    result = _column_to_series(col, dtype, allow_copy=True)
    # Physical 0 is the epoch; the sentinel entry becomes null.
    expected = pl.Series([datetime(1970, 1, 1), None])
    assert_series_equal(result, expected)


def test_column_to_series_duration() -> None:
    """A Duration column with a null round-trips through PolarsColumn."""
    s = pl.Series([timedelta(seconds=10), timedelta(days=5), None])
    col = PolarsColumn(s)
    result = _column_to_series(col, s.dtype, allow_copy=True)
    assert_series_equal(result, s)


def test_column_to_series_time() -> None:
    """A Time column with a null round-trips through PolarsColumn."""
    s = pl.Series([time(10, 0), time(23, 59, 59), None])
    col = PolarsColumn(s)
    result = _column_to_series(col, s.dtype, allow_copy=True)
    assert_series_equal(result, s)


def test_column_to_series_use_sentinel_date() -> None:
    """Date values equal to a ``date`` sentinel become null."""
    mask_value = date(1900, 1, 1)

    s = pl.Series([date(1970, 1, 1), mask_value, date(2000, 1, 1)])

    col = PatchableColumn(s)
    col.describe_null = (ColumnNullType.USE_SENTINEL, mask_value)
    col.null_count = 1

    result = _column_to_series(col, pl.Date, allow_copy=True)
    expected = pl.Series([date(1970, 1, 1), None, date(2000, 1, 1)])
    assert_series_equal(result, expected)


def test_column_to_series_use_sentinel_datetime() -> None:
    """Datetime("ns") values equal to a ``datetime`` sentinel become null."""
    dtype = pl.Datetime("ns")
    mask_value = datetime(1900, 1, 1)

    s = pl.Series([datetime(1970, 1, 1), mask_value, datetime(2000, 1, 1)], dtype=dtype)

    col = PatchableColumn(s)
    col.describe_null = (ColumnNullType.USE_SENTINEL, mask_value)
    col.null_count = 1

    result = _column_to_series(col, dtype, allow_copy=True)
    expected = pl.Series(
        [datetime(1970, 1, 1), None, datetime(2000, 1, 1)], dtype=dtype
    )
    assert_series_equal(result, expected)


def test_column_to_series_use_sentinel_invalid_value() -> None:
    """A sentinel of the wrong type (a str for a Datetime column) raises TypeError."""
    dtype = pl.Datetime("ns")
    mask_value = "invalid"

    s = pl.Series([datetime(1970, 1, 1), None, datetime(2000, 1, 1)], dtype=dtype)

    col = PatchableColumn(s)
    col.describe_null = (ColumnNullType.USE_SENTINEL, mask_value)
    col.null_count = 1

    with pytest.raises(
        TypeError,
        match="invalid sentinel value for column of type Datetime\\(time_unit='ns', time_zone=None\\): 'invalid'",
    ):
        _column_to_series(col, dtype, allow_copy=True)


def test_string_column_to_series_no_offsets() -> None:
    """Building a String Series from a column without offsets raises RuntimeError."""
    # Integer codes (ASCII "a", "b", "c") in an Int64 column, which has no
    # offsets buffer.
    s = pl.Series([97, 98, 99])
    col = PolarsColumn(s)
    with pytest.raises(
        RuntimeError,
        match="cannot create String column without an offsets buffer",
    ):
        _string_column_to_series(col, allow_copy=True)


def test_categorical_column_to_series_non_dictionary() -> None:
    """Categoricals reported as non-dictionary raise NotImplementedError."""
    s = pl.Series(["a", "b", None, "a"], dtype=pl.Categorical)

    col = PatchableColumn(s)
    col.describe_categorical = {"is_dictionary": False}

    with pytest.raises(
        NotImplementedError, match="non-dictionary categoricals are not yet supported"
    ):
        _categorical_column_to_series(col, allow_copy=True)


def test_construct_data_buffer() -> None:
    """An Int64 data buffer is rebuilt into the same Int64 Series."""
    data = pl.Series([0, 1, 3, 3, 9], dtype=pl.Int64)
    buffer = PolarsBuffer(data)
    dtype = (DtypeKind.INT, 64, "l", NE)

    result = _construct_data_buffer(buffer, dtype, length=5, allow_copy=True)
    assert_series_equal(result, data)


def test_construct_data_buffer_boolean_sliced() -> None:
    """A sliced boolean buffer is read correctly given an explicit ``offset``."""
    data = pl.Series([False, True, True, False])
    data_sliced = data[2:]
    buffer = PolarsBuffer(data_sliced)
    dtype = (DtypeKind.BOOL, 1, "b", NE)

    result = _construct_data_buffer(buffer, dtype, length=2, offset=2, allow_copy=True)
    assert_series_equal(result, data_sliced)


def test_construct_data_buffer_logical_dtype() -> None:
    """A logical (DATETIME kind) dtype yields the physical Int32 data unchanged."""
    data = pl.Series([100, 200, 300], dtype=pl.Int32)
    buffer = PolarsBuffer(data)
    dtype = (DtypeKind.DATETIME, 32, "tdD", NE)

    result = _construct_data_buffer(buffer, dtype, length=3, allow_copy=True)
    assert_series_equal(result, data)


def test_construct_offsets_buffer() -> None:
    """An Int64 offsets buffer is returned as an identical Int64 Series."""
    data = pl.Series([0, 1, 3, 3, 9], dtype=pl.Int64)
    buffer = PolarsBuffer(data)
    dtype = (DtypeKind.INT, 64, "l", NE)

    result = _construct_offsets_buffer(buffer, dtype, offset=0, allow_copy=True)
    assert_series_equal(result, data)


def test_construct_offsets_buffer_offset() -> None:
    """A non-zero ``offset`` drops the leading offsets."""
    data = pl.Series([0, 1, 3, 3, 9], dtype=pl.Int64)
    buffer = PolarsBuffer(data)
    dtype = (DtypeKind.INT, 64, "l", NE)
    offset = 2

    result = _construct_offsets_buffer(buffer, dtype, offset=offset, allow_copy=True)
    assert_series_equal(result, data[offset:])


def test_construct_offsets_buffer_copy() -> None:
    """UInt32 offsets need a copy: refused without ``allow_copy``, cast to Int64 with it."""
    data = pl.Series([0, 1, 3, 3, 9], dtype=pl.UInt32)
    buffer = PolarsBuffer(data)
    dtype = (DtypeKind.UINT, 32, "I", NE)

    with pytest.raises(CopyNotAllowedError):
        _construct_offsets_buffer(buffer, dtype, offset=0, allow_copy=False)

    result = _construct_offsets_buffer(buffer, dtype, offset=0, allow_copy=True)
    expected = pl.Series([0, 1, 3, 3, 9], dtype=pl.Int64)
    assert_series_equal(result, expected)


@pytest.fixture
def bitmask() -> PolarsBuffer:
    """Return a 4-element boolean (bitmask) buffer: [False, True, True, False]."""
    data = pl.Series([False, True, True, False])
    return PolarsBuffer(data)


@pytest.fixture
def bytemask() -> PolarsBuffer:
    """Return a 4-element UInt8 (bytemask) buffer: [0, 1, 1, 0]."""
    data = pl.Series([0, 1, 1, 0], dtype=pl.UInt8)
    return PolarsBuffer(data)


def test_construct_validity_buffer_non_nullable() -> None:
    """NON_NULLABLE columns yield no validity buffer, even if null_count is set."""
    s = pl.Series([1, 2, 3])

    col = PatchableColumn(s)
    col.describe_null = (ColumnNullType.NON_NULLABLE, None)
    col.null_count = 1

    result = _construct_validity_buffer(None, col, s.dtype, s, allow_copy=True)
    assert result is None


def test_construct_validity_buffer_null_count() -> None:
    """A column with null_count == 0 yields no validity buffer."""
    s = pl.Series([1, 2, 3])

    col = PatchableColumn(s)
    col.describe_null = (ColumnNullType.USE_SENTINEL, -1)
    col.null_count = 0

    result = _construct_validity_buffer(None, col, s.dtype, s, allow_copy=True)
    assert result is None


def test_construct_validity_buffer_use_bitmask(bitmask: PolarsBuffer) -> None:
    """USE_BITMASK reads the given bitmask, and returns None when none is supplied."""
    s = pl.Series([1, 2, 3, 4])

    col = PatchableColumn(s)
    col.describe_null = (ColumnNullType.USE_BITMASK, 0)
    col.null_count = 2

    dtype = (DtypeKind.BOOL, 1, "b", NE)
    validity_buffer_info = (bitmask, dtype)

    result = _construct_validity_buffer(
        validity_buffer_info, col, s.dtype, s, allow_copy=True
    )
    expected = pl.Series([False, True, True, False])
    assert_series_equal(result, expected)  # type: ignore[arg-type]

    result = _construct_validity_buffer(None, col, s.dtype, s, allow_copy=True)
    assert result is None


def test_construct_validity_buffer_use_bytemask(bytemask: PolarsBuffer) -> None:
    """USE_BYTEMASK reads the given bytemask, and returns None when none is supplied."""
    s = pl.Series([1, 2, 3, 4])

    col = PatchableColumn(s)
    col.describe_null = (ColumnNullType.USE_BYTEMASK, 0)
    col.null_count = 2

    dtype = (DtypeKind.UINT, 8, "C", NE)
    validity_buffer_info = (bytemask, dtype)

    result = _construct_validity_buffer(
        validity_buffer_info, col, s.dtype, s, allow_copy=True
    )
    expected = pl.Series([False, True, True, False])
    assert_series_equal(result, expected)  # type: ignore[arg-type]

    result = _construct_validity_buffer(None, col, s.dtype, s, allow_copy=True)
    assert result is None


def test_construct_validity_buffer_use_nan() -> None:
    """USE_NAN derives validity from NaN values, which is refused without ``allow_copy``."""
    s = pl.Series([1.0, 2.0, float("nan")])

    col = PatchableColumn(s)
    col.describe_null = (ColumnNullType.USE_NAN, None)
    col.null_count = 1

    result = _construct_validity_buffer(None, col, s.dtype, s, allow_copy=True)
    expected = pl.Series([True, True, False])
    assert_series_equal(result, expected)  # type: ignore[arg-type]

    with pytest.raises(CopyNotAllowedError, match="bitmask must be constructed"):
        _construct_validity_buffer(None, col, s.dtype, s, allow_copy=False)


def test_construct_validity_buffer_use_sentinel() -> None:
    """USE_SENTINEL derives validity from a string sentinel; refused without ``allow_copy``."""
    s = pl.Series(["a", "bc", "NULL"])

    col = PatchableColumn(s)
    col.describe_null = (ColumnNullType.USE_SENTINEL, "NULL")
    col.null_count = 1

    result = _construct_validity_buffer(None, col, s.dtype, s, allow_copy=True)
    expected = pl.Series([True, True, False])
    assert_series_equal(result, expected)  # type: ignore[arg-type]

    with pytest.raises(CopyNotAllowedError, match="bitmask must be constructed"):
        _construct_validity_buffer(None, col, s.dtype, s, allow_copy=False)


def test_construct_validity_buffer_unsupported() -> None:
    """An unknown null type (100) raises NotImplementedError."""
    s = pl.Series([1, 2, 3])

    col = PatchableColumn(s)
    col.describe_null = (100, None)  # type: ignore[assignment]
    col.null_count = 1

    with pytest.raises(NotImplementedError, match="unsupported null type: 100"):
        _construct_validity_buffer(None, col, s.dtype, s, allow_copy=True)


@pytest.mark.parametrize("allow_copy", [True, False])
def test_construct_validity_buffer_from_bitmask(
    allow_copy: bool, bitmask: PolarsBuffer
) -> None:
    """With null_value=0 the bitmask is used as-is, whether or not copying is allowed."""
    result = _construct_validity_buffer_from_bitmask(
        bitmask, null_value=0, offset=0, length=4, allow_copy=allow_copy
    )
    expected = pl.Series([False, True, True, False])
    assert_series_equal(result, expected)


def test_construct_validity_buffer_from_bitmask_inverted(bitmask: PolarsBuffer) -> None:
    """With null_value=1 the bitmask is inverted."""
    result = _construct_validity_buffer_from_bitmask(
        bitmask, null_value=1, offset=0, length=4, allow_copy=True
    )
    expected = pl.Series([True, False, False, True])
    assert_series_equal(result, expected)


def test_construct_validity_buffer_from_bitmask_zero_copy_fails(
    bitmask: PolarsBuffer,
) -> None:
    """Inverting a bitmask (null_value=1) is refused when ``allow_copy=False``."""
    with pytest.raises(CopyNotAllowedError):
        _construct_validity_buffer_from_bitmask(
            bitmask, null_value=1, offset=0, length=4, allow_copy=False
        )


def test_construct_validity_buffer_from_bitmask_sliced() -> None:
    """A sliced bitmask is read correctly given an explicit ``offset``."""
    data = pl.Series([False, True, True, False])
    data_sliced = data[2:]
    bitmask = PolarsBuffer(data_sliced)

    result = _construct_validity_buffer_from_bitmask(
        bitmask, null_value=0, offset=2, length=2, allow_copy=True
    )
    assert_series_equal(result, data_sliced)


def test_construct_validity_buffer_from_bytemask(bytemask: PolarsBuffer) -> None:
    """A UInt8 bytemask with null_value=0 maps 0/1 to False/True."""
    result = _construct_validity_buffer_from_bytemask(
        bytemask, null_value=0, allow_copy=True
    )
    expected = pl.Series([False, True, True, False])
    assert_series_equal(result, expected)


def test_construct_validity_buffer_from_bytemask_inverted(
    bytemask: PolarsBuffer,
) -> None:
    """A UInt8 bytemask with null_value=1 is inverted."""
    result = _construct_validity_buffer_from_bytemask(
        bytemask, null_value=1, allow_copy=True
    )
    expected = pl.Series([True, False, False, True])
    assert_series_equal(result, expected)


def test_construct_validity_buffer_from_bytemask_zero_copy_fails(
    bytemask: PolarsBuffer,
) -> None:
    """Converting a bytemask is refused with ``allow_copy=False``, even for null_value=0."""
    with pytest.raises(CopyNotAllowedError):
        _construct_validity_buffer_from_bytemask(
            bytemask, null_value=0, allow_copy=False
        )


def test_interchange_protocol_fallback(monkeypatch: pytest.MonkeyPatch) -> None:
    """If ``__arrow_c_stream__`` fails, conversion warns and uses the interchange protocol."""
    df_pd = pd.DataFrame({"a": [1, 2, 3]})
    # Force the Arrow C stream path to fail (ZeroDivisionError) on any call.
    monkeypatch.setattr(df_pd, "__arrow_c_stream__", lambda *args, **kwargs: 1 / 0)
    with pytest.warns(
        UserWarning, match="Falling back to Dataframe Interchange Protocol"
    ):
        result = pl.from_dataframe(df_pd)
    expected = pl.DataFrame({"a": [1, 2, 3]})
    assert_frame_equal(result, expected)


def test_to_pandas_int8_20316() -> None:
    """An all-null Int8 column round-trips via pandas with PyArrow extension arrays.

    Regression test; 20316 presumably refers to the upstream issue number.
    """
    df = pl.Series("a", [None], pl.Int8).to_frame()
    df_pd = df.to_pandas(use_pyarrow_extension_array=True)
    result = pl.from_dataframe(df_pd)
    assert_frame_equal(result, df)