Path: blob/main/py-polars/tests/unit/interchange/test_from_dataframe.py
8415 views
"""Tests for converting interchange-protocol objects to Polars via ``pl.from_dataframe``."""

from __future__ import annotations

from datetime import date, datetime, time, timedelta
from typing import TYPE_CHECKING, Any

import pandas as pd
import pyarrow as pa
import pytest

import polars as pl
from polars.interchange.buffer import PolarsBuffer
from polars.interchange.column import PolarsColumn
from polars.interchange.from_dataframe import (
    _categorical_column_to_series,
    _column_to_series,
    _construct_data_buffer,
    _construct_offsets_buffer,
    _construct_validity_buffer,
    _construct_validity_buffer_from_bitmask,
    _construct_validity_buffer_from_bytemask,
    _string_column_to_series,
)
from polars.interchange.protocol import (
    ColumnNullType,
    CopyNotAllowedError,
    DtypeKind,
    Endianness,
)
from polars.testing import assert_frame_equal, assert_series_equal

# Native byte order, used in the interchange dtype tuples spelled out below.
NE = Endianness.NATIVE

# Message fragment the tests expect in the deprecation warning raised when
# `allow_copy` is passed to `pl.from_dataframe`.
_ALLOW_COPY_DEPRECATION = "`allow_copy` is deprecated"

if TYPE_CHECKING:
    from tests.conftest import PlMonkeyPatch


def _from_dataframe_with_allow_copy(obj: Any, *, allow_copy: bool) -> pl.DataFrame:
    """Run ``pl.from_dataframe(obj, allow_copy=...)`` expecting a deprecation warning.

    The conversion happens inside ``pytest.deprecated_call``; the resulting
    frame is returned after the warning check has passed.
    """
    with pytest.deprecated_call(match=_ALLOW_COPY_DEPRECATION):
        return pl.from_dataframe(obj, allow_copy=allow_copy)


def test_from_dataframe_polars() -> None:
    frame = pl.DataFrame({"a": [1, 2], "b": [3.0, 4.0], "c": ["foo", "bar"]})
    out = _from_dataframe_with_allow_copy(frame, allow_copy=False)
    assert_frame_equal(out, frame)


def test_from_dataframe_polars_interchange_fast_path() -> None:
    frame = pl.DataFrame(
        {"a": [1, 2], "b": [3.0, 4.0], "c": ["foo", "bar"]},
        schema_overrides={"c": pl.Categorical},
    )
    interchange_obj = frame.__dataframe__()
    out = _from_dataframe_with_allow_copy(interchange_obj, allow_copy=False)
    assert_frame_equal(out, frame)


def test_from_dataframe_categorical() -> None:
    arrow_table = pl.DataFrame(
        {"a": ["foo", "bar"]}, schema={"a": pl.Categorical}
    ).to_arrow()

    out = _from_dataframe_with_allow_copy(arrow_table, allow_copy=True)
    expected = pl.DataFrame({"a": ["foo", "bar"]}, schema={"a": pl.Categorical})
    assert_frame_equal(out, expected)


def test_from_dataframe_empty_string_zero_copy() -> None:
    frame = pl.DataFrame({"a": []}, schema={"a": pl.String})
    out = _from_dataframe_with_allow_copy(frame.to_arrow(), allow_copy=False)
    assert_frame_equal(out, frame)


def test_from_dataframe_empty_bool_zero_copy() -> None:
    frame = pl.DataFrame(schema={"a": pl.Boolean})
    out = _from_dataframe_with_allow_copy(frame.to_pandas(), allow_copy=False)
    assert_frame_equal(out, frame)


def test_from_dataframe_empty_categories_zero_copy() -> None:
    frame = pl.DataFrame(schema={"a": pl.Enum([])})
    out = _from_dataframe_with_allow_copy(frame.to_arrow(), allow_copy=False)
    assert_frame_equal(out, frame)


def test_from_dataframe_pandas_zero_copy() -> None:
    data = {"a": [1, 2], "b": [3.0, 4.0]}

    out = _from_dataframe_with_allow_copy(pd.DataFrame(data), allow_copy=False)
    assert_frame_equal(out, pl.DataFrame(data))


def test_from_dataframe_pyarrow_table_zero_copy() -> None:
    frame = pl.DataFrame({"a": [1, 2], "b": [3.0, 4.0]})
    out = _from_dataframe_with_allow_copy(frame.to_arrow(), allow_copy=False)
    assert_frame_equal(out, frame)


def test_from_dataframe_pyarrow_empty_table() -> None:
    frame = pl.Series("a", dtype=pl.Int8).to_frame()
    out = _from_dataframe_with_allow_copy(frame.to_arrow(), allow_copy=False)
    assert_frame_equal(out, frame)


def test_from_dataframe_pyarrow_recordbatch_zero_copy() -> None:
    batch = pa.record_batch(
        [pa.array([1, 2]), pa.array([3.0, 4.0])], names=["a", "b"]
    )
    out = _from_dataframe_with_allow_copy(batch, allow_copy=False)

    assert_frame_equal(out, pl.DataFrame({"a": [1, 2], "b": [3.0, 4.0]}))


def test_from_dataframe_invalid_type() -> None:
    not_a_frame = [[1, 2], [3, 4]]
    with pytest.raises(TypeError):
        pl.from_dataframe(not_a_frame)  # type: ignore[arg-type]


def test_from_dataframe_pyarrow_boolean() -> None:
    frame = pl.Series("a", [True, False]).to_frame()
    arrow_table = frame.to_arrow()

    assert_frame_equal(pl.from_dataframe(arrow_table), frame)
    assert_frame_equal(
        _from_dataframe_with_allow_copy(arrow_table, allow_copy=False), frame
    )


def test_from_dataframe_chunked() -> None:
    frame = pl.Series("a", [0, 1], dtype=pl.Int8).to_frame()
    chunked = pl.concat([frame[:1], frame[1:]], rechunk=False)

    out = pl.from_dataframe(chunked.to_arrow(), rechunk=False)

    assert_frame_equal(out, chunked)
    assert out.n_chunks() == 2


@pytest.mark.may_fail_auto_streaming
@pytest.mark.may_fail_cloud  # reason: chunking
def test_from_dataframe_chunked_string() -> None:
    frame = pl.Series("a", ["a", None, "bc", "d", None, "efg"]).to_frame()
    chunked = pl.concat([frame[:1], frame[1:3], frame[3:]], rechunk=False)

    out = pl.from_dataframe(chunked.to_arrow(), rechunk=False)

    assert_frame_equal(out, chunked)
    assert out.n_chunks() == 3


def test_from_dataframe_pandas_nan_as_null() -> None:
    source = pd.Series([1.0, float("nan"), float("inf")], name="a").to_frame()
    out = pl.from_dataframe(source)
    assert_frame_equal(out, pl.Series("a", [1.0, None, float("inf")]).to_frame())
    assert out.n_chunks() == 1


def test_from_dataframe_pandas_boolean_bytes() -> None:
    source = pd.Series([True, False], name="a").to_frame()
    expected = pl.Series("a", [True, False]).to_frame()

    assert_frame_equal(pl.from_dataframe(source), expected)
    assert_frame_equal(
        _from_dataframe_with_allow_copy(source, allow_copy=False), expected
    )


def test_from_dataframe_categorical_pandas() -> None:
    values = ["a", "b", None, "a"]
    source = pd.Series(values, dtype="category", name="a").to_frame()
    expected = pl.Series("a", values, dtype=pl.Categorical).to_frame()

    assert_frame_equal(pl.from_dataframe(source), expected)
    assert_frame_equal(
        _from_dataframe_with_allow_copy(source, allow_copy=False), expected
    )


def test_from_dataframe_categorical_pyarrow() -> None:
    values = ["a", "b", None, "a"]
    dictionary_array = pa.array(values, pa.dictionary(pa.int32(), pa.utf8()))
    arrow_table = pa.Table.from_arrays([dictionary_array], names=["a"])
    expected = pl.Series("a", values, dtype=pl.Categorical).to_frame()

    assert_frame_equal(pl.from_dataframe(arrow_table), expected)
    assert_frame_equal(
        _from_dataframe_with_allow_copy(arrow_table, allow_copy=False), expected
    )


def test_from_dataframe_categorical_non_string_keys() -> None:
    values = [1, 2, None, 1]
    dictionary_array = pa.array(values, pa.dictionary(pa.uint32(), pa.int32()))
    arrow_table = pa.Table.from_arrays([dictionary_array], names=["a"])

    out = pl.from_dataframe(arrow_table)
    expected = pl.DataFrame({"a": [1, 2, None, 1]}, schema={"a": pl.Int32})
    assert_frame_equal(out, expected)


def test_from_dataframe_categorical_non_u32_values() -> None:
    values = [None, None]
    dictionary_array = pa.array(values, pa.dictionary(pa.int8(), pa.utf8()))
    arrow_table = pa.Table.from_arrays([dictionary_array], names=["a"])
    expected = pl.Series("a", values, dtype=pl.Categorical).to_frame()

    assert_frame_equal(pl.from_dataframe(arrow_table), expected)
    assert_frame_equal(
        _from_dataframe_with_allow_copy(arrow_table, allow_copy=False), expected
    )


class PatchableColumn(PolarsColumn):
    """`PolarsColumn` subclass whose null/categorical descriptors tests can overwrite.

    The class attributes below take the place of the corresponding
    `PolarsColumn` members, so a test can assign per-instance values.
    """

    # Defaults used until a test assigns its own values on an instance.
    describe_null: tuple[ColumnNullType, Any] = (ColumnNullType.USE_BITMASK, 0)
    describe_categorical: dict[str, Any] = {}  # type: ignore[assignment]  # noqa: RUF012
    null_count = 0


def _patched_column(
    series: pl.Series, null_kind: Any, null_value: Any, null_count: int
) -> PatchableColumn:
    """Wrap ``series`` in a `PatchableColumn` with the given null description."""
    column = PatchableColumn(series)
    column.describe_null = (null_kind, null_value)
    column.null_count = null_count
    return column


def test_column_to_series_use_sentinel_i64_min() -> None:
    i64_min = -(2**63)
    dtype = pl.Datetime("us")
    logical = pl.Series([0, i64_min]).cast(dtype)

    column = _patched_column(logical, ColumnNullType.USE_SENTINEL, i64_min, 1)

    out = _column_to_series(column, dtype, allow_copy=True)
    assert_series_equal(out, pl.Series([datetime(1970, 1, 1), None]))


def test_column_to_series_duration() -> None:
    series = pl.Series([timedelta(seconds=10), timedelta(days=5), None])
    out = _column_to_series(PolarsColumn(series), series.dtype, allow_copy=True)
    assert_series_equal(out, series)


def test_column_to_series_time() -> None:
    series = pl.Series([time(10, 0), time(23, 59, 59), None])
    out = _column_to_series(PolarsColumn(series), series.dtype, allow_copy=True)
    assert_series_equal(out, series)


def test_column_to_series_use_sentinel_date() -> None:
    sentinel = date(1900, 1, 1)
    series = pl.Series([date(1970, 1, 1), sentinel, date(2000, 1, 1)])
    column = _patched_column(series, ColumnNullType.USE_SENTINEL, sentinel, 1)

    out = _column_to_series(column, pl.Date, allow_copy=True)
    assert_series_equal(out, pl.Series([date(1970, 1, 1), None, date(2000, 1, 1)]))


def test_column_to_series_use_sentinel_datetime() -> None:
    dtype = pl.Datetime("ns")
    sentinel = datetime(1900, 1, 1)
    series = pl.Series(
        [datetime(1970, 1, 1), sentinel, datetime(2000, 1, 1)], dtype=dtype
    )
    column = _patched_column(series, ColumnNullType.USE_SENTINEL, sentinel, 1)

    out = _column_to_series(column, dtype, allow_copy=True)
    expected = pl.Series(
        [datetime(1970, 1, 1), None, datetime(2000, 1, 1)], dtype=dtype
    )
    assert_series_equal(out, expected)


def test_column_to_series_use_sentinel_invalid_value() -> None:
    dtype = pl.Datetime("ns")
    series = pl.Series(
        [datetime(1970, 1, 1), None, datetime(2000, 1, 1)], dtype=dtype
    )
    column = _patched_column(series, ColumnNullType.USE_SENTINEL, "invalid", 1)

    with pytest.raises(
        TypeError,
        match=r"invalid sentinel value for column of type Datetime\(time_unit='ns', time_zone=None\): 'invalid'",
    ):
        _column_to_series(column, dtype, allow_copy=True)


def test_string_column_to_series_no_offsets() -> None:
    # An integer column has no offsets buffer to read string data from.
    column = PolarsColumn(pl.Series([97, 98, 99]))
    with pytest.raises(
        RuntimeError,
        match="cannot create String column without an offsets buffer",
    ):
        _string_column_to_series(column, allow_copy=True)


def test_categorical_column_to_series_non_dictionary() -> None:
    column = PatchableColumn(
        pl.Series(["a", "b", None, "a"], dtype=pl.Categorical)
    )
    column.describe_categorical = {"is_dictionary": False}

    with pytest.raises(
        NotImplementedError, match="non-dictionary categoricals are not yet supported"
    ):
        _categorical_column_to_series(column, allow_copy=True)


def test_construct_data_buffer() -> None:
    values = pl.Series([0, 1, 3, 3, 9], dtype=pl.Int64)
    int64_dtype = (DtypeKind.INT, 64, "l", NE)

    out = _construct_data_buffer(
        PolarsBuffer(values), int64_dtype, length=5, allow_copy=True
    )
    assert_series_equal(out, values)


def test_construct_data_buffer_boolean_sliced() -> None:
    full = pl.Series([False, True, True, False])
    tail = full[2:]
    bool_dtype = (DtypeKind.BOOL, 1, "b", NE)

    # The buffer of the slice is read starting at bit offset 2.
    out = _construct_data_buffer(
        PolarsBuffer(tail), bool_dtype, length=2, offset=2, allow_copy=True
    )
    assert_series_equal(out, tail)


def test_construct_data_buffer_logical_dtype() -> None:
    values = pl.Series([100, 200, 300], dtype=pl.Int32)
    date_dtype = (DtypeKind.DATETIME, 32, "tdD", NE)

    out = _construct_data_buffer(
        PolarsBuffer(values), date_dtype, length=3, allow_copy=True
    )
    assert_series_equal(out, values)


def test_construct_offsets_buffer() -> None:
    offsets = pl.Series([0, 1, 3, 3, 9], dtype=pl.Int64)
    int64_dtype = (DtypeKind.INT, 64, "l", NE)

    out = _construct_offsets_buffer(
        PolarsBuffer(offsets), int64_dtype, offset=0, allow_copy=True
    )
    assert_series_equal(out, offsets)


def test_construct_offsets_buffer_offset() -> None:
    offsets = pl.Series([0, 1, 3, 3, 9], dtype=pl.Int64)
    int64_dtype = (DtypeKind.INT, 64, "l", NE)
    start = 2

    out = _construct_offsets_buffer(
        PolarsBuffer(offsets), int64_dtype, offset=start, allow_copy=True
    )
    assert_series_equal(out, offsets[start:])


def test_construct_offsets_buffer_copy() -> None:
    offsets = pl.Series([0, 1, 3, 3, 9], dtype=pl.UInt32)
    buffer = PolarsBuffer(offsets)
    uint32_dtype = (DtypeKind.UINT, 32, "I", NE)

    with pytest.raises(CopyNotAllowedError):
        _construct_offsets_buffer(buffer, uint32_dtype, offset=0, allow_copy=False)

    out = _construct_offsets_buffer(buffer, uint32_dtype, offset=0, allow_copy=True)
    assert_series_equal(out, pl.Series([0, 1, 3, 3, 9], dtype=pl.Int64))


@pytest.fixture
def bitmask() -> PolarsBuffer:
    return PolarsBuffer(pl.Series([False, True, True, False]))


@pytest.fixture
def bytemask() -> PolarsBuffer:
    return PolarsBuffer(pl.Series([0, 1, 1, 0], dtype=pl.UInt8))


def test_construct_validity_buffer_non_nullable() -> None:
    series = pl.Series([1, 2, 3])
    column = _patched_column(series, ColumnNullType.NON_NULLABLE, None, 1)

    out = _construct_validity_buffer(
        None, column, series.dtype, series, allow_copy=True
    )
    assert out is None


def test_construct_validity_buffer_null_count() -> None:
    series = pl.Series([1, 2, 3])
    column = _patched_column(series, ColumnNullType.USE_SENTINEL, -1, 0)

    out = _construct_validity_buffer(
        None, column, series.dtype, series, allow_copy=True
    )
    assert out is None


def test_construct_validity_buffer_use_bitmask(bitmask: PolarsBuffer) -> None:
    series = pl.Series([1, 2, 3, 4])
    column = _patched_column(series, ColumnNullType.USE_BITMASK, 0, 2)
    buffer_info = (bitmask, (DtypeKind.BOOL, 1, "b", NE))

    out = _construct_validity_buffer(
        buffer_info, column, series.dtype, series, allow_copy=True
    )
    expected = pl.Series([False, True, True, False])
    assert_series_equal(out, expected)  # type: ignore[arg-type]

    missing = _construct_validity_buffer(
        None, column, series.dtype, series, allow_copy=True
    )
    assert missing is None


def test_construct_validity_buffer_use_bytemask(bytemask: PolarsBuffer) -> None:
    series = pl.Series([1, 2, 3, 4])
    column = _patched_column(series, ColumnNullType.USE_BYTEMASK, 0, 2)
    buffer_info = (bytemask, (DtypeKind.UINT, 8, "C", NE))

    out = _construct_validity_buffer(
        buffer_info, column, series.dtype, series, allow_copy=True
    )
    expected = pl.Series([False, True, True, False])
    assert_series_equal(out, expected)  # type: ignore[arg-type]

    missing = _construct_validity_buffer(
        None, column, series.dtype, series, allow_copy=True
    )
    assert missing is None


def test_construct_validity_buffer_use_nan() -> None:
    series = pl.Series([1.0, 2.0, float("nan")])
    column = _patched_column(series, ColumnNullType.USE_NAN, None, 1)

    out = _construct_validity_buffer(
        None, column, series.dtype, series, allow_copy=True
    )
    expected = pl.Series([True, True, False])
    assert_series_equal(out, expected)  # type: ignore[arg-type]

    with pytest.raises(CopyNotAllowedError, match="bitmask must be constructed"):
        _construct_validity_buffer(None, column, series.dtype, series, allow_copy=False)


def test_construct_validity_buffer_use_sentinel() -> None:
    series = pl.Series(["a", "bc", "NULL"])
    column = _patched_column(series, ColumnNullType.USE_SENTINEL, "NULL", 1)

    out = _construct_validity_buffer(
        None, column, series.dtype, series, allow_copy=True
    )
    expected = pl.Series([True, True, False])
    assert_series_equal(out, expected)  # type: ignore[arg-type]

    with pytest.raises(CopyNotAllowedError, match="bitmask must be constructed"):
        _construct_validity_buffer(None, column, series.dtype, series, allow_copy=False)


def test_construct_validity_buffer_unsupported() -> None:
    series = pl.Series([1, 2, 3])
    # 100 is not a ColumnNullType member.
    column = _patched_column(series, 100, None, 1)

    with pytest.raises(NotImplementedError, match="unsupported null type: 100"):
        _construct_validity_buffer(None, column, series.dtype, series, allow_copy=True)


@pytest.mark.parametrize("allow_copy", [True, False])
def test_construct_validity_buffer_from_bitmask(
    allow_copy: bool, bitmask: PolarsBuffer
) -> None:
    out = _construct_validity_buffer_from_bitmask(
        bitmask, null_value=0, offset=0, length=4, allow_copy=allow_copy
    )
    assert_series_equal(out, pl.Series([False, True, True, False]))


def test_construct_validity_buffer_from_bitmask_inverted(bitmask: PolarsBuffer) -> None:
    out = _construct_validity_buffer_from_bitmask(
        bitmask, null_value=1, offset=0, length=4, allow_copy=True
    )
    assert_series_equal(out, pl.Series([True, False, False, True]))


def test_construct_validity_buffer_from_bitmask_zero_copy_fails(
    bitmask: PolarsBuffer,
) -> None:
    with pytest.raises(CopyNotAllowedError):
        _construct_validity_buffer_from_bitmask(
            bitmask, null_value=1, offset=0, length=4, allow_copy=False
        )


def test_construct_validity_buffer_from_bitmask_sliced() -> None:
    full = pl.Series([False, True, True, False])
    tail = full[2:]

    out = _construct_validity_buffer_from_bitmask(
        PolarsBuffer(tail), null_value=0, offset=2, length=2, allow_copy=True
    )
    assert_series_equal(out, tail)


def test_construct_validity_buffer_from_bytemask(bytemask: PolarsBuffer) -> None:
    out = _construct_validity_buffer_from_bytemask(
        bytemask, null_value=0, allow_copy=True
    )
    assert_series_equal(out, pl.Series([False, True, True, False]))


def test_construct_validity_buffer_from_bytemask_inverted(
    bytemask: PolarsBuffer,
) -> None:
    out = _construct_validity_buffer_from_bytemask(
        bytemask, null_value=1, allow_copy=True
    )
    assert_series_equal(out, pl.Series([True, False, False, True]))


def test_construct_validity_buffer_from_bytemask_zero_copy_fails(
    bytemask: PolarsBuffer,
) -> None:
    with pytest.raises(CopyNotAllowedError):
        _construct_validity_buffer_from_bytemask(
            bytemask, null_value=0, allow_copy=False
        )


def test_interchange_protocol_fallback(plmonkeypatch: PlMonkeyPatch) -> None:
    source = pd.DataFrame({"a": [1, 2, 3]})

    def _broken_arrow_stream(*args: Any, **kwargs: Any) -> float:
        # Raises ZeroDivisionError when called; the test expects
        # `from_dataframe` to warn and fall back instead of failing.
        return 1 / 0

    plmonkeypatch.setattr(source, "__arrow_c_stream__", _broken_arrow_stream)
    with pytest.warns(
        UserWarning, match="Falling back to Dataframe Interchange Protocol"
    ):
        out = pl.from_dataframe(source)
    assert_frame_equal(out, pl.DataFrame({"a": [1, 2, 3]}))


def test_to_pandas_int8_20316() -> None:
    frame = pl.Series("a", [None], pl.Int8).to_frame()
    out = pl.from_dataframe(frame.to_pandas(use_pyarrow_extension_array=True))
    assert_frame_equal(out, frame)