Path: blob/main/py-polars/tests/unit/interchange/test_utils.py
8420 views
"""Unit tests for the dtype helpers in ``polars.interchange.utils``."""

from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

import polars as pl
from polars.interchange.protocol import DtypeKind, Endianness
from polars.interchange.utils import (
    dtype_to_polars_dtype,
    get_buffer_length_in_elements,
    polars_dtype_to_data_buffer_dtype,
    polars_dtype_to_dtype,
)

if TYPE_CHECKING:
    from polars._typing import PolarsDataType
    from polars.interchange.protocol import Dtype

# Short alias: every interchange dtype tuple in this module uses native endianness.
NE = Endianness.NATIVE


@pytest.mark.parametrize(
    ("polars_dtype", "dtype"),
    [
        (pl.Int8, (DtypeKind.INT, 8, "c", NE)),
        (pl.Int16, (DtypeKind.INT, 16, "s", NE)),
        (pl.Int32, (DtypeKind.INT, 32, "i", NE)),
        (pl.Int64, (DtypeKind.INT, 64, "l", NE)),
        (pl.UInt8, (DtypeKind.UINT, 8, "C", NE)),
        (pl.UInt16, (DtypeKind.UINT, 16, "S", NE)),
        (pl.UInt32, (DtypeKind.UINT, 32, "I", NE)),
        (pl.UInt64, (DtypeKind.UINT, 64, "L", NE)),
        (pl.Float32, (DtypeKind.FLOAT, 32, "f", NE)),
        (pl.Float64, (DtypeKind.FLOAT, 64, "g", NE)),
        (pl.Boolean, (DtypeKind.BOOL, 1, "b", NE)),
        (pl.String, (DtypeKind.STRING, 8, "U", NE)),
        (pl.Date, (DtypeKind.DATETIME, 32, "tdD", NE)),
        (pl.Time, (DtypeKind.DATETIME, 64, "ttu", NE)),
        (pl.Duration, (DtypeKind.DATETIME, 64, "tDu", NE)),
        (pl.Duration(time_unit="ns"), (DtypeKind.DATETIME, 64, "tDn", NE)),
        (pl.Datetime, (DtypeKind.DATETIME, 64, "tsu:", NE)),
        (pl.Datetime(time_unit="ms"), (DtypeKind.DATETIME, 64, "tsm:", NE)),
        # Time zones are spelled as real IANA identifiers (Area/Location) so the
        # cases stay valid even if the time zone string is ever validated.
        (
            pl.Datetime(time_zone="Europe/Amsterdam"),
            (DtypeKind.DATETIME, 64, "tsu:Europe/Amsterdam", NE),
        ),
        (
            pl.Datetime(time_unit="ns", time_zone="Asia/Seoul"),
            (DtypeKind.DATETIME, 64, "tsn:Asia/Seoul", NE),
        ),
    ],
)
def test_dtype_conversions(polars_dtype: PolarsDataType, dtype: Dtype) -> None:
    """Check that each Polars dtype converts to its interchange dtype and back."""
    assert polars_dtype_to_dtype(polars_dtype) == dtype
    assert dtype_to_polars_dtype(dtype) == polars_dtype


@pytest.mark.parametrize(
    "dtype",
    [
        (DtypeKind.CATEGORICAL, 32, "I", NE),
        (DtypeKind.CATEGORICAL, 8, "C", NE),
    ],
)
def test_dtype_to_polars_dtype_categorical(dtype: Dtype) -> None:
    """Check that categorical interchange dtypes of either width map to ``pl.Enum``."""
    assert dtype_to_polars_dtype(dtype) == pl.Enum


@pytest.mark.parametrize(
    "polars_dtype",
    [
        pl.Categorical,
        pl.Categorical(),
        pl.Enum,
        pl.Enum(["a", "b"]),
    ],
)
def test_polars_dtype_to_dtype_categorical(polars_dtype: PolarsDataType) -> None:
    """Check that Categorical and Enum all map to the same 32-bit categorical dtype."""
    assert polars_dtype_to_dtype(polars_dtype) == (DtypeKind.CATEGORICAL, 32, "I", NE)


def test_polars_dtype_to_dtype_unsupported_type() -> None:
    """Check that converting a List dtype raises ``ValueError``."""
    polars_dtype = pl.List(pl.Int8)
    with pytest.raises(ValueError, match="not supported"):
        polars_dtype_to_dtype(polars_dtype)


def test_dtype_to_polars_dtype_unsupported_temporal_type() -> None:
    """Check that the temporal format string ``"tss:"`` is rejected."""
    dtype = (DtypeKind.DATETIME, 64, "tss:", NE)
    with pytest.raises(
        NotImplementedError,
        match=r"unsupported temporal data type: \(<DtypeKind\.DATETIME: 22>, 64, 'tss:', '='\)",
    ):
        dtype_to_polars_dtype(dtype)


@pytest.mark.parametrize(
    ("dtype", "expected"),
    [
        ((DtypeKind.INT, 64, "l", NE), 3),
        ((DtypeKind.UINT, 32, "I", NE), 6),
    ],
)
def test_get_buffer_length_in_elements(dtype: Dtype, expected: int) -> None:
    """Check the element count computed for a buffer size of 24."""
    # NOTE(review): 24 is presumably a size in bytes (24 / 8 = 3 elements of
    # 64 bits, 24 / 4 = 6 elements of 32 bits) - confirm against the helper.
    assert get_buffer_length_in_elements(24, dtype) == expected


def test_get_buffer_length_in_elements_unsupported_dtype() -> None:
    """Check that a 1-bit boolean dtype raises ``ValueError``."""
    dtype = (DtypeKind.BOOL, 1, "b", NE)
    with pytest.raises(
        ValueError,
        match=r"cannot get buffer length for buffer with dtype \(<DtypeKind\.BOOL: 20>, 1, 'b', '='\)",
    ):
        get_buffer_length_in_elements(24, dtype)


@pytest.mark.parametrize(
    ("dtype", "expected"),
    [
        (pl.Int8, pl.Int8),
        (pl.Date, pl.Int32),
        (pl.Time, pl.Int64),
        (pl.String, pl.UInt8),
        (pl.Enum, pl.UInt32),
    ],
)
def test_polars_dtype_to_data_buffer_dtype(
    dtype: PolarsDataType, expected: PolarsDataType
) -> None:
    """Check the data buffer dtype reported for each Polars dtype."""
    assert polars_dtype_to_data_buffer_dtype(dtype) == expected


def test_polars_dtype_to_data_buffer_dtype_unsupported_dtype() -> None:
    """Check that a List dtype raises ``NotImplementedError``."""
    dtype = pl.List(pl.Int8)
    with pytest.raises(NotImplementedError):
        polars_dtype_to_data_buffer_dtype(dtype)