Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/interchange/test_utils.py
6939 views
1
from __future__ import annotations
2
3
from typing import TYPE_CHECKING
4
5
import pytest
6
7
import polars as pl
8
from polars.interchange.protocol import DtypeKind, Endianness
9
from polars.interchange.utils import (
10
dtype_to_polars_dtype,
11
get_buffer_length_in_elements,
12
polars_dtype_to_data_buffer_dtype,
13
polars_dtype_to_dtype,
14
)
15
16
if TYPE_CHECKING:
17
from polars._typing import PolarsDataType
18
from polars.interchange.protocol import Dtype
19
20
# Short alias for native endianness; used as the last entry of the dtype tuples in this file.
NE = Endianness.NATIVE
21
22
23
@pytest.mark.parametrize(
    ("polars_dtype", "dtype"),
    [
        (pl.Int8, (DtypeKind.INT, 8, "c", NE)),
        (pl.Int16, (DtypeKind.INT, 16, "s", NE)),
        (pl.Int32, (DtypeKind.INT, 32, "i", NE)),
        (pl.Int64, (DtypeKind.INT, 64, "l", NE)),
        (pl.UInt8, (DtypeKind.UINT, 8, "C", NE)),
        (pl.UInt16, (DtypeKind.UINT, 16, "S", NE)),
        (pl.UInt32, (DtypeKind.UINT, 32, "I", NE)),
        (pl.UInt64, (DtypeKind.UINT, 64, "L", NE)),
        (pl.Float32, (DtypeKind.FLOAT, 32, "f", NE)),
        (pl.Float64, (DtypeKind.FLOAT, 64, "g", NE)),
        (pl.Boolean, (DtypeKind.BOOL, 1, "b", NE)),
        (pl.String, (DtypeKind.STRING, 8, "U", NE)),
        (pl.Date, (DtypeKind.DATETIME, 32, "tdD", NE)),
        (pl.Time, (DtypeKind.DATETIME, 64, "ttu", NE)),
        (pl.Duration, (DtypeKind.DATETIME, 64, "tDu", NE)),
        (pl.Duration(time_unit="ns"), (DtypeKind.DATETIME, 64, "tDn", NE)),
        (pl.Datetime, (DtypeKind.DATETIME, 64, "tsu:", NE)),
        (pl.Datetime(time_unit="ms"), (DtypeKind.DATETIME, 64, "tsm:", NE)),
        (
            # IANA time zone keys are written "Area/Location".
            pl.Datetime(time_zone="Europe/Amsterdam"),
            (DtypeKind.DATETIME, 64, "tsu:Europe/Amsterdam", NE),
        ),
        (
            pl.Datetime(time_unit="ns", time_zone="Asia/Seoul"),
            (DtypeKind.DATETIME, 64, "tsn:Asia/Seoul", NE),
        ),
    ],
)
def test_dtype_conversions(polars_dtype: PolarsDataType, dtype: Dtype) -> None:
    """Check that each Polars dtype and its interchange dtype tuple convert both ways."""
    # Polars dtype -> interchange dtype tuple.
    assert polars_dtype_to_dtype(polars_dtype) == dtype
    # Interchange dtype tuple -> Polars dtype (round trip).
    assert dtype_to_polars_dtype(dtype) == polars_dtype
57
58
59
@pytest.mark.parametrize(
    "dtype",
    [
        (DtypeKind.CATEGORICAL, bit_width, format_str, NE)
        for bit_width, format_str in ((32, "I"), (8, "C"))
    ],
)
def test_dtype_to_polars_dtype_categorical(dtype: Dtype) -> None:
    """Check that each listed categorical dtype tuple converts to ``pl.Enum``."""
    converted = dtype_to_polars_dtype(dtype)
    assert converted == pl.Enum
68
69
70
@pytest.mark.parametrize(
    "polars_dtype",
    [
        pl.Categorical,
        pl.Categorical("lexical"),
        pl.Enum,
        pl.Enum(["a", "b"]),
    ],
)
def test_polars_dtype_to_dtype_categorical(polars_dtype: PolarsDataType) -> None:
    """Check that every listed Categorical/Enum dtype yields the same dtype tuple."""
    expected_dtype = (DtypeKind.CATEGORICAL, 32, "I", NE)
    converted = polars_dtype_to_dtype(polars_dtype)
    assert converted == expected_dtype
81
82
83
def test_polars_dtype_to_dtype_unsupported_type() -> None:
    """Check that converting ``pl.List(pl.Int8)`` raises ``ValueError``."""
    list_dtype = pl.List(pl.Int8)
    expected_failure = pytest.raises(ValueError, match="not supported")
    with expected_failure:
        polars_dtype_to_dtype(list_dtype)
87
88
89
def test_dtype_to_polars_dtype_unsupported_type() -> None:
    """Check that a 16-bit float dtype tuple raises ``NotImplementedError``."""
    float16_dtype = (DtypeKind.FLOAT, 16, "e", NE)
    # Regex matched against the error message; parentheses are escaped.
    message_pattern = "unsupported data type: \\(<DtypeKind.FLOAT: 2>, 16, 'e', '='\\)"
    with pytest.raises(NotImplementedError, match=message_pattern):
        dtype_to_polars_dtype(float16_dtype)
96
97
98
def test_dtype_to_polars_dtype_unsupported_temporal_type() -> None:
    """Check that the ``"tss:"`` datetime dtype tuple raises ``NotImplementedError``."""
    temporal_dtype = (DtypeKind.DATETIME, 64, "tss:", NE)
    # Regex matched against the error message; parentheses are escaped.
    message_pattern = "unsupported temporal data type: \\(<DtypeKind.DATETIME: 22>, 64, 'tss:', '='\\)"
    with pytest.raises(NotImplementedError, match=message_pattern):
        dtype_to_polars_dtype(temporal_dtype)
105
106
107
@pytest.mark.parametrize(
    ("dtype", "expected"),
    [
        ((DtypeKind.INT, 64, "l", NE), 3),
        ((DtypeKind.UINT, 32, "I", NE), 6),
    ],
)
def test_get_buffer_length_in_elements(dtype: Dtype, expected: int) -> None:
    """Check the element count reported for a buffer of size 24 with each dtype."""
    # The expected counts (3 for 64-bit, 6 for 32-bit) are consistent with the
    # size being given in bytes.
    buffer_size = 24
    element_count = get_buffer_length_in_elements(buffer_size, dtype)
    assert element_count == expected
116
117
118
def test_get_buffer_length_in_elements_unsupported_dtype() -> None:
    """Check that a 1-bit boolean dtype tuple raises ``ValueError``."""
    bool_dtype = (DtypeKind.BOOL, 1, "b", NE)
    # Regex matched against the error message; parentheses are escaped.
    message_pattern = "cannot get buffer length for buffer with dtype \\(<DtypeKind.BOOL: 20>, 1, 'b', '='\\)"
    with pytest.raises(ValueError, match=message_pattern):
        get_buffer_length_in_elements(24, bool_dtype)
125
126
127
@pytest.mark.parametrize(
    ("dtype", "expected"),
    [
        (pl.Int8, pl.Int8),
        (pl.Date, pl.Int32),
        (pl.Time, pl.Int64),
        (pl.String, pl.UInt8),
        (pl.Enum, pl.UInt32),
    ],
)
def test_polars_dtype_to_data_buffer_dtype(
    dtype: PolarsDataType, expected: PolarsDataType
) -> None:
    """Check the data-buffer dtype returned for each listed Polars dtype."""
    buffer_dtype = polars_dtype_to_data_buffer_dtype(dtype)
    assert buffer_dtype == expected
141
142
143
def test_polars_dtype_to_data_buffer_dtype_unsupported_dtype() -> None:
    """Check that ``pl.List(pl.Int8)`` raises ``NotImplementedError``."""
    list_dtype = pl.List(pl.Int8)
    expected_failure = pytest.raises(NotImplementedError)
    with expected_failure:
        polars_dtype_to_data_buffer_dtype(list_dtype)
147
148