# Path: blob/main/py-polars/tests/unit/interchange/test_column.py
# 6939 views
"""Unit tests for ``PolarsColumn``, the interchange-protocol wrapper of a Series."""

from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING

import pytest

import polars as pl
from polars.interchange.column import PolarsColumn
from polars.interchange.protocol import ColumnNullType, CopyNotAllowedError, DtypeKind
from polars.testing import assert_series_equal

if TYPE_CHECKING:
    from polars.interchange.protocol import Dtype


def test_size() -> None:
    """``size()`` equals the number of elements in the wrapped Series."""
    s = pl.Series([1, 2, 3])
    col = PolarsColumn(s)
    assert col.size() == 3


def test_offset() -> None:
    """A freshly wrapped Series reports an offset of zero."""
    s = pl.Series([1, 2, 3])
    col = PolarsColumn(s)
    assert col.offset == 0


def test_dtype_int() -> None:
    """An Int32 Series is described as a 32-bit INT with format string ``"i"``."""
    s = pl.Series([1, 2, 3], dtype=pl.Int32)
    col = PolarsColumn(s)
    assert col.dtype == (DtypeKind.INT, 32, "i", "=")


def test_dtype_categorical() -> None:
    """A Categorical Series is described as CATEGORICAL with 32-bit ``"I"`` codes."""
    s = pl.Series(["a", "b", "a"], dtype=pl.Categorical)
    col = PolarsColumn(s)
    assert col.dtype == (DtypeKind.CATEGORICAL, 32, "I", "=")


def test_describe_categorical() -> None:
    """A Categorical column is an unordered dictionary exposing its categories."""
    s = pl.Series(["b", "a", "a", "c", None, "b"], dtype=pl.Categorical)
    col = PolarsColumn(s)

    out = col.describe_categorical

    assert out["is_ordered"] is False
    assert out["is_dictionary"] is True
    # Superset check: the categories column may hold more than this Series' values.
    assert set(out["categories"]._col) >= {"b", "a", "c"}


def test_describe_categorical_enum() -> None:
    """An Enum column is an ordered dictionary whose categories match the Enum."""
    s = pl.Series(["b", "a", "a", "c", None, "b"], dtype=pl.Enum(["a", "b", "c"]))
    col = PolarsColumn(s)

    out = col.describe_categorical

    assert out["is_ordered"] is True
    assert out["is_dictionary"] is True

    expected_categories = pl.Series("category", ["a", "b", "c"])
    assert_series_equal(out["categories"]._col, expected_categories)


def test_describe_categorical_other_dtype() -> None:
    """Accessing ``describe_categorical`` on a String column raises ``TypeError``."""
    s = pl.Series(["a", "b", "a"], dtype=pl.String)
    col = PolarsColumn(s)
    with pytest.raises(TypeError):
        # The bare attribute access is the operation under test.
        col.describe_categorical


def test_describe_null() -> None:
    """A column containing a null reports ``(USE_BITMASK, 0)``."""
    s = pl.Series([1, 2, None])
    col = PolarsColumn(s)
    assert col.describe_null == (ColumnNullType.USE_BITMASK, 0)


def test_describe_null_no_null_values() -> None:
    """A column without nulls reports ``(NON_NULLABLE, None)``."""
    s = pl.Series([1, 2, 3])
    col = PolarsColumn(s)
    assert col.describe_null == (ColumnNullType.NON_NULLABLE, None)


def test_null_count() -> None:
    """``null_count`` equals the number of null entries."""
    s = pl.Series([None, 2, None])
    col = PolarsColumn(s)
    assert col.null_count == 2


def test_metadata() -> None:
    """``metadata`` is an empty dict for a plain Series."""
    s = pl.Series([1, 2])
    col = PolarsColumn(s)
    assert col.metadata == {}


def test_num_chunks() -> None:
    """``num_chunks()`` is 1 for a plain Series and 2 after a no-rechunk concat."""
    s = pl.Series([1, 2])
    col = PolarsColumn(s)
    assert col.num_chunks() == 1

    # rechunk=False keeps the two inputs as separate chunks.
    s2 = pl.concat([s, s], rechunk=False)
    col2 = s2.to_frame().__dataframe__().get_column(0)
    assert col2.num_chunks() == 2


@pytest.mark.parametrize("n_chunks", [None, 2])
def test_get_chunks(n_chunks: int | None) -> None:
    """``get_chunks`` yields the two underlying chunks for ``None`` and for ``2``."""
    s1 = pl.Series([1, 2, 3])
    s2 = pl.Series([4, 5])
    s = pl.concat([s1, s2], rechunk=False)
    col = PolarsColumn(s)

    chunks = list(col.get_chunks(n_chunks))

    expected = [s1, s2]
    # Check the count explicitly: zip() stops at the shorter input, so a result
    # with missing chunks would otherwise pass unnoticed.
    assert len(chunks) == len(expected)
    for chunk, expected_chunk in zip(chunks, expected):
        assert_series_equal(chunk._col, expected_chunk)


def test_get_chunks_invalid_input() -> None:
    """Requesting 0 or 3 chunks from a two-chunk column raises ``ValueError``."""
    s1 = pl.Series([1, 2, 3])
    s2 = pl.Series([4, 5])
    s = pl.concat([s1, s2], rechunk=False)
    col = PolarsColumn(s)

    # get_chunks is consumed lazily, so next() is needed to trigger the error.
    with pytest.raises(ValueError):
        next(col.get_chunks(0))

    with pytest.raises(ValueError):
        next(col.get_chunks(3))


def test_get_chunks_subdivided_chunks() -> None:
    """Requesting 4 chunks from a two-chunk column splits each chunk in two."""
    s1 = pl.Series([1, 2, 3])
    s2 = pl.Series([4, 5])
    s = pl.concat([s1, s2], rechunk=False)
    col = PolarsColumn(s)

    out = col.get_chunks(4)

    chunk1 = next(out)
    expected1 = pl.Series([1, 2])
    assert_series_equal(chunk1._col, expected1)

    chunk2 = next(out)
    expected2 = pl.Series([3])
    assert_series_equal(chunk2._col, expected2)

    chunk3 = next(out)
    expected3 = pl.Series([4])
    assert_series_equal(chunk3._col, expected3)

    chunk4 = next(out)
    expected4 = pl.Series([5])
    assert_series_equal(chunk4._col, expected4)

    # Exactly four chunks are produced.
    with pytest.raises(StopIteration):
        next(out)


@pytest.mark.parametrize(
    ("series", "expected_data", "expected_dtype"),
    [
        # Null slots show up as 0 in the expected data buffer.
        (
            pl.Series([1, None, 3], dtype=pl.Int16),
            pl.Series([1, 0, 3], dtype=pl.Int16),
            (DtypeKind.INT, 16, "s", "="),
        ),
        (
            pl.Series([-1.5, 3.0, None], dtype=pl.Float64),
            pl.Series([-1.5, 3.0, 0.0], dtype=pl.Float64),
            (DtypeKind.FLOAT, 64, "g", "="),
        ),
        # Strings: the data buffer holds the concatenated UTF-8 bytes.
        (
            pl.Series(["a", "bc", None, "éâç"], dtype=pl.String),
            pl.Series([97, 98, 99, 195, 169, 195, 162, 195, 167], dtype=pl.UInt8),
            (DtypeKind.UINT, 8, "C", "="),
        ),
        # Datetimes: the data buffer holds 64-bit integers.
        (
            pl.Series(
                [datetime(1988, 1, 2), None, datetime(2022, 12, 3)], dtype=pl.Datetime
            ),
            pl.Series([568080000000000, 0, 1670025600000000], dtype=pl.Int64),
            (DtypeKind.INT, 64, "l", "="),
        ),
        # TODO: cat-rework: re-enable this with a unique named categorical.
        # (
        #     pl.Series(["a", "b", None, "a"], dtype=pl.Categorical),
        #     pl.Series([0, 1, 0, 0], dtype=pl.UInt32),
        #     (DtypeKind.UINT, 32, "I", "="),
        # ),
    ],
)
def test_get_buffers_data(
    series: pl.Series,
    expected_data: pl.Series,
    expected_dtype: Dtype,
) -> None:
    """The ``"data"`` entry of ``get_buffers()`` matches the expected data and dtype."""
    col = PolarsColumn(series)

    out = col.get_buffers()

    data_buffer, data_dtype = out["data"]
    assert_series_equal(data_buffer._data, expected_data)
    assert data_dtype == expected_dtype


def test_get_buffers_int() -> None:
    """A null-free Int8 column has a data buffer but no validity or offsets."""
    s = pl.Series([1, 2, 3], dtype=pl.Int8)
    col = PolarsColumn(s)

    out = col.get_buffers()

    data_buffer, data_dtype = out["data"]
    assert_series_equal(data_buffer._data, s)
    assert data_dtype == (DtypeKind.INT, 8, "c", "=")

    assert out["validity"] is None
    assert out["offsets"] is None


def test_get_buffers_with_validity_and_offsets() -> None:
    """A String column with a null exposes data, validity and offsets buffers."""
    s = pl.Series(["a", "bc", None, "éâç"])
    col = PolarsColumn(s)

    out = col.get_buffers()

    data_buffer, data_dtype = out["data"]
    expected = pl.Series([97, 98, 99, 195, 169, 195, 162, 195, 167], dtype=pl.UInt8)
    assert_series_equal(data_buffer._data, expected)
    assert data_dtype == (DtypeKind.UINT, 8, "C", "=")

    validity = out["validity"]
    assert validity is not None
    val_buffer, val_dtype = validity
    expected = pl.Series([True, True, False, True])
    assert_series_equal(val_buffer._data, expected)
    assert val_dtype == (DtypeKind.BOOL, 1, "b", "=")

    offsets = out["offsets"]
    assert offsets is not None
    offsets_buffer, offsets_dtype = offsets
    # The null entry is zero-length: its start and end offsets are both 3.
    expected = pl.Series([0, 1, 3, 3, 9], dtype=pl.Int64)
    assert_series_equal(offsets_buffer._data, expected)
    assert offsets_dtype == (DtypeKind.INT, 64, "l", "=")


def test_get_buffers_chunked_bitmask() -> None:
    """Each chunk of a chunked Boolean column exposes its own value."""
    s = pl.Series([True, False], dtype=pl.Boolean)
    s_chunked = pl.concat([s[:1], s[1:]], rechunk=False)
    col = PolarsColumn(s_chunked)

    chunks = list(col.get_chunks())
    assert chunks[0].get_buffers()["data"][0]._data.item() is True
    assert chunks[1].get_buffers()["data"][0]._data.item() is False


def test_get_buffers_string_zero_copy_fails() -> None:
    """``get_buffers()`` on a String column raises when copies are disallowed."""
    s = pl.Series("a", ["a", "bc"], dtype=pl.String)

    col = PolarsColumn(s, allow_copy=False)

    msg = "string buffers must be converted"
    with pytest.raises(CopyNotAllowedError, match=msg):
        col.get_buffers()


@pytest.mark.parametrize("allow_copy", [False, True])
def test_get_buffers_categorical(allow_copy: bool) -> None:
    """A Categorical data buffer holds distinct UInt32 codes for distinct values."""
    s = pl.Series("a", ["c", "b"], dtype=pl.Categorical)
    col = PolarsColumn(s, allow_copy=allow_copy)
    result = col.get_buffers()

    data_buffer, _ = result["data"]
    assert len(data_buffer._data) == 2
    # Only distinctness is asserted; the concrete code values are not pinned.
    assert data_buffer._data[0] != data_buffer._data[1]
    assert data_buffer._data.dtype == pl.UInt32


def test_get_buffers_chunked_zero_copy_fails() -> None:
    """``get_buffers()`` on a multi-chunk column raises when copies are disallowed."""
    s1 = pl.Series([1, 2, 3])
    s = pl.concat([s1, s1], rechunk=False)
    col = PolarsColumn(s, allow_copy=False)

    with pytest.raises(
        CopyNotAllowedError, match="non-contiguous buffer must be made contiguous"
    ):
        col.get_buffers()


def test_wrap_data_buffer() -> None:
    """``_wrap_data_buffer`` returns the buffer together with its dtype tuple."""
    values = pl.Series([1, 2, 3])
    col = PolarsColumn(pl.Series())

    result_buffer, result_dtype = col._wrap_data_buffer(values)

    assert_series_equal(result_buffer._data, values)
    assert result_dtype == (DtypeKind.INT, 64, "l", "=")


def test_wrap_validity_buffer() -> None:
    """``_wrap_validity_buffer`` returns the mask with a 1-bit BOOL dtype."""
    validity = pl.Series([True, False, True])
    col = PolarsColumn(pl.Series())

    result = col._wrap_validity_buffer(validity)

    assert result is not None

    result_buffer, result_dtype = result
    assert_series_equal(result_buffer._data, validity)
    assert result_dtype == (DtypeKind.BOOL, 1, "b", "=")


def test_wrap_validity_buffer_no_nulls() -> None:
    """``_wrap_validity_buffer(None)`` returns ``None``."""
    col = PolarsColumn(pl.Series())
    assert col._wrap_validity_buffer(None) is None


def test_wrap_offsets_buffer() -> None:
    """``_wrap_offsets_buffer`` returns the offsets with a 64-bit INT dtype."""
    offsets = pl.Series([0, 1, 3, 3, 9], dtype=pl.Int64)
    col = PolarsColumn(pl.Series())

    result = col._wrap_offsets_buffer(offsets)

    assert result is not None

    result_buffer, result_dtype = result
    assert_series_equal(result_buffer._data, offsets)
    assert result_dtype == (DtypeKind.INT, 64, "l", "=")


def test_wrap_offsets_buffer_none() -> None:
    """``_wrap_offsets_buffer(None)`` returns ``None``."""
    col = PolarsColumn(pl.Series())
    # Must call the offsets wrapper: the validity wrapper's None path is already
    # covered by test_wrap_validity_buffer_no_nulls.
    assert col._wrap_offsets_buffer(None) is None


def test_column_unsupported_type() -> None:
    """A List column supports basic queries but raises when its dtype is read."""
    s = pl.Series("a", [[4], [5, 6]])
    col = PolarsColumn(s)

    # Certain column operations work
    assert col.num_chunks() == 1
    assert col.null_count == 0

    # Error is raised when unsupported operations are requested
    with pytest.raises(ValueError, match="not supported"):
        col.dtype