Path: blob/main/py-polars/tests/unit/interchange/test_utils.py
6939 views
"""Unit tests for the helper functions in ``polars.interchange.utils``."""

from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

import polars as pl
from polars.interchange.protocol import DtypeKind, Endianness
from polars.interchange.utils import (
    dtype_to_polars_dtype,
    get_buffer_length_in_elements,
    polars_dtype_to_data_buffer_dtype,
    polars_dtype_to_dtype,
)

if TYPE_CHECKING:
    from polars._typing import PolarsDataType
    from polars.interchange.protocol import Dtype

# Shorthand used in every expected interchange dtype tuple below.
NE = Endianness.NATIVE


@pytest.mark.parametrize(
    ("polars_dtype", "dtype"),
    [
        (pl.Int8, (DtypeKind.INT, 8, "c", NE)),
        (pl.Int16, (DtypeKind.INT, 16, "s", NE)),
        (pl.Int32, (DtypeKind.INT, 32, "i", NE)),
        (pl.Int64, (DtypeKind.INT, 64, "l", NE)),
        (pl.UInt8, (DtypeKind.UINT, 8, "C", NE)),
        (pl.UInt16, (DtypeKind.UINT, 16, "S", NE)),
        (pl.UInt32, (DtypeKind.UINT, 32, "I", NE)),
        (pl.UInt64, (DtypeKind.UINT, 64, "L", NE)),
        (pl.Float32, (DtypeKind.FLOAT, 32, "f", NE)),
        (pl.Float64, (DtypeKind.FLOAT, 64, "g", NE)),
        (pl.Boolean, (DtypeKind.BOOL, 1, "b", NE)),
        (pl.String, (DtypeKind.STRING, 8, "U", NE)),
        (pl.Date, (DtypeKind.DATETIME, 32, "tdD", NE)),
        (pl.Time, (DtypeKind.DATETIME, 64, "ttu", NE)),
        (pl.Duration, (DtypeKind.DATETIME, 64, "tDu", NE)),
        (pl.Duration(time_unit="ns"), (DtypeKind.DATETIME, 64, "tDn", NE)),
        (pl.Datetime, (DtypeKind.DATETIME, 64, "tsu:", NE)),
        (pl.Datetime(time_unit="ms"), (DtypeKind.DATETIME, 64, "tsm:", NE)),
        # Time zone names use the IANA "Area/Location" form.
        (
            pl.Datetime(time_zone="Europe/Amsterdam"),
            (DtypeKind.DATETIME, 64, "tsu:Europe/Amsterdam", NE),
        ),
        (
            pl.Datetime(time_unit="ns", time_zone="Asia/Seoul"),
            (DtypeKind.DATETIME, 64, "tsn:Asia/Seoul", NE),
        ),
    ],
)
def test_dtype_conversions(polars_dtype: PolarsDataType, dtype: Dtype) -> None:
    """Check that each Polars dtype converts to the given tuple and back."""
    assert polars_dtype_to_dtype(polars_dtype) == dtype
    assert dtype_to_polars_dtype(dtype) == polars_dtype


@pytest.mark.parametrize(
    "dtype",
    [
        (DtypeKind.CATEGORICAL, 32, "I", NE),
        (DtypeKind.CATEGORICAL, 8, "C", NE),
    ],
)
def test_dtype_to_polars_dtype_categorical(dtype: Dtype) -> None:
    """Check that categorical interchange dtypes convert to ``pl.Enum``."""
    assert dtype_to_polars_dtype(dtype) == pl.Enum


@pytest.mark.parametrize(
    "polars_dtype",
    [
        pl.Categorical,
        pl.Categorical("lexical"),
        pl.Enum,
        pl.Enum(["a", "b"]),
    ],
)
def test_polars_dtype_to_dtype_categorical(polars_dtype: PolarsDataType) -> None:
    """Check that Categorical and Enum dtypes share one interchange dtype."""
    assert polars_dtype_to_dtype(polars_dtype) == (DtypeKind.CATEGORICAL, 32, "I", NE)


def test_polars_dtype_to_dtype_unsupported_type() -> None:
    """Check that converting a List dtype raises ``ValueError``."""
    polars_dtype = pl.List(pl.Int8)
    with pytest.raises(ValueError, match="not supported"):
        polars_dtype_to_dtype(polars_dtype)


def test_dtype_to_polars_dtype_unsupported_type() -> None:
    """Check that a 16-bit float dtype raises ``NotImplementedError``."""
    dtype = (DtypeKind.FLOAT, 16, "e", NE)
    with pytest.raises(
        NotImplementedError,
        match="unsupported data type: \\(<DtypeKind.FLOAT: 2>, 16, 'e', '='\\)",
    ):
        dtype_to_polars_dtype(dtype)


def test_dtype_to_polars_dtype_unsupported_temporal_type() -> None:
    """Check that the ``"tss:"`` format string raises ``NotImplementedError``."""
    dtype = (DtypeKind.DATETIME, 64, "tss:", NE)
    with pytest.raises(
        NotImplementedError,
        match="unsupported temporal data type: \\(<DtypeKind.DATETIME: 22>, 64, 'tss:', '='\\)",
    ):
        dtype_to_polars_dtype(dtype)


@pytest.mark.parametrize(
    ("dtype", "expected"),
    [
        ((DtypeKind.INT, 64, "l", NE), 3),
        ((DtypeKind.UINT, 32, "I", NE), 6),
    ],
)
def test_get_buffer_length_in_elements(dtype: Dtype, expected: int) -> None:
    """Check the element count reported for a buffer size of 24."""
    assert get_buffer_length_in_elements(24, dtype) == expected


def test_get_buffer_length_in_elements_unsupported_dtype() -> None:
    """Check that a 1-bit boolean dtype raises ``ValueError``."""
    dtype = (DtypeKind.BOOL, 1, "b", NE)
    with pytest.raises(
        ValueError,
        match="cannot get buffer length for buffer with dtype \\(<DtypeKind.BOOL: 20>, 1, 'b', '='\\)",
    ):
        get_buffer_length_in_elements(24, dtype)


@pytest.mark.parametrize(
    ("dtype", "expected"),
    [
        (pl.Int8, pl.Int8),
        (pl.Date, pl.Int32),
        (pl.Time, pl.Int64),
        (pl.String, pl.UInt8),
        (pl.Enum, pl.UInt32),
    ],
)
def test_polars_dtype_to_data_buffer_dtype(
    dtype: PolarsDataType, expected: PolarsDataType
) -> None:
    """Check the data buffer dtype returned for each Polars dtype."""
    assert polars_dtype_to_data_buffer_dtype(dtype) == expected


def test_polars_dtype_to_data_buffer_dtype_unsupported_dtype() -> None:
    """Check that a List dtype raises ``NotImplementedError``."""
    dtype = pl.List(pl.Int8)
    with pytest.raises(NotImplementedError):
        polars_dtype_to_data_buffer_dtype(dtype)