Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/sql/test_cast.py
6939 views
1
from __future__ import annotations
2
3
from typing import Any
4
5
import pytest
6
7
import polars as pl
8
import polars.selectors as cs
9
from polars.exceptions import InvalidOperationError, SQLInterfaceError
10
from polars.testing import assert_frame_equal
11
12
13
def test_cast() -> None:
    """Check SQL casts to float, integer, string/binary and boolean dtypes.

    Both the standard ``CAST(<col> AS <dtype>)`` form and the postgres-style
    ``<col>::<dtype>`` shorthand are exercised in a single query, after which
    the resulting schema and values are verified per dtype family. Finally, a
    ``CAST`` carrying a ``FORMAT`` clause is expected to be rejected.
    """
    frame = pl.DataFrame(
        {
            "a": [1, 2, 3, 4, 5],
            "b": [1.1, 2.2, 3.3, 4.4, 5.5],
            "c": ["a", "b", "c", "d", "e"],
            "d": [True, False, True, False, True],
            "e": [-1, 0, None, 1, 2],
        }
    )

    # a single query mixing the standard ("CAST <col> AS <dtype>") and the
    # postgres-specific ("<col>::<dtype>") syntax; every output column is
    # aliased as "<source column>_<expected dtype>"
    query = """
        SELECT
          -- float
          CAST(a AS DOUBLE PRECISION) AS a_f64,
          a::real AS a_f32,
          b::float(24) AS b_f32,
          b::float(25) AS b_f64,
          e::float8 AS e_f64,
          e::float4 AS e_f32,

          -- integer
          CAST(b AS TINYINT) AS b_i8,
          CAST(b AS SMALLINT) AS b_i16,
          b::bigint AS b_i64,
          d::tinyint AS d_i8,
          d::hugeint AS d_i128,
          a::int1 AS a_i8,
          a::int2 AS a_i16,
          a::int4 AS a_i32,
          a::int8 AS a_i64,

          -- unsigned integer
          CAST(a AS TINYINT UNSIGNED) AS a_u8,
          d::uint1 AS d_u8,
          a::uint2 AS a_u16,
          b::uint4 AS b_u32,
          b::uint8 AS b_u64,
          CAST(a AS BIGINT UNSIGNED) AS a_u64,
          b::utinyint AS b_u8,
          b::usmallint AS b_u16,
          a::uinteger AS a_u32,
          d::ubigint AS d_u64,

          -- string/binary
          CAST(a AS CHAR) AS a_char,
          CAST(b AS VARCHAR) AS b_varchar,
          c::blob AS c_blob,
          c::bytes AS c_bytes,
          c::VARBINARY AS c_varbinary,
          CAST(d AS CHARACTER VARYING) AS d_charvar,

          -- boolean
          e::bool AS e_bool,
          e::boolean AS e_boolean
        FROM df
        """
    with pl.SQLContext(df=frame, eager=True) as ctx:
        result = ctx.execute(query)

    # every alias must resolve to the dtype encoded in its suffix
    expected_schema = {
        "a_f64": pl.Float64,
        "a_f32": pl.Float32,
        "b_f32": pl.Float32,
        "b_f64": pl.Float64,
        "e_f64": pl.Float64,
        "e_f32": pl.Float32,
        "b_i8": pl.Int8,
        "b_i16": pl.Int16,
        "b_i64": pl.Int64,
        "d_i8": pl.Int8,
        "d_i128": pl.Int128,
        "a_i8": pl.Int8,
        "a_i16": pl.Int16,
        "a_i32": pl.Int32,
        "a_i64": pl.Int64,
        "a_u8": pl.UInt8,
        "d_u8": pl.UInt8,
        "a_u16": pl.UInt16,
        "b_u32": pl.UInt32,
        "b_u64": pl.UInt64,
        "a_u64": pl.UInt64,
        "b_u8": pl.UInt8,
        "b_u16": pl.UInt16,
        "a_u32": pl.UInt32,
        "d_u64": pl.UInt64,
        "a_char": pl.String,
        "b_varchar": pl.String,
        "c_blob": pl.Binary,
        "c_bytes": pl.Binary,
        "c_varbinary": pl.Binary,
        "d_charvar": pl.String,
        "e_bool": pl.Boolean,
        "e_boolean": pl.Boolean,
    }
    assert result.schema == expected_schema

    # 32-bit floats: columns a_f32, b_f32, e_f32 (in schema order)
    f32_rows = result.select(cs.by_dtype(pl.Float32)).rows()
    assert f32_rows == pytest.approx(
        [
            (1.0, 1.100000023841858, -1.0),
            (2.0, 2.200000047683716, 0.0),
            (3.0, 3.299999952316284, None),
            (4.0, 4.400000095367432, 1.0),
            (5.0, 5.5, 2.0),
        ]
    )

    # 64-bit floats: columns a_f64, b_f64, e_f64
    f64_rows = result.select(cs.by_dtype(pl.Float64)).rows()
    assert f64_rows == [
        (1.0, 1.1, -1.0),
        (2.0, 2.2, 0.0),
        (3.0, 3.3, None),
        (4.0, 4.4, 1.0),
        (5.0, 5.5, 2.0),
    ]

    # signed and unsigned integers together (19 columns)
    integer_rows = result.select(cs.integer()).rows()
    assert integer_rows == [
        (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
        (2, 2, 2, 0, 0, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 0),
        (3, 3, 3, 1, 1, 3, 3, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 1),
        (4, 4, 4, 0, 0, 4, 4, 4, 4, 4, 0, 4, 4, 4, 4, 4, 4, 4, 0),
        (5, 5, 5, 1, 1, 5, 5, 5, 5, 5, 1, 5, 5, 5, 5, 5, 5, 5, 1),
    ]

    # strings: columns a_char, b_varchar, d_charvar
    string_rows = result.select(cs.string()).rows()
    assert string_rows == [
        ("1", "1.1", "true"),
        ("2", "2.2", "false"),
        ("3", "3.3", "true"),
        ("4", "4.4", "false"),
        ("5", "5.5", "true"),
    ]

    # binary: columns c_blob, c_bytes, c_varbinary
    binary_rows = result.select(cs.binary()).rows()
    assert binary_rows == [
        (b"a", b"a", b"a"),
        (b"b", b"b", b"b"),
        (b"c", b"c", b"c"),
        (b"d", b"d", b"d"),
        (b"e", b"e", b"e"),
    ]

    # booleans: columns e_bool, e_boolean (the null in "e" stays null)
    boolean_rows = result.select(cs.boolean()).rows()
    assert boolean_rows == [
        (True, True),
        (False, False),
        (None, None),
        (True, True),
        (True, True),
    ]

    # a FORMAT clause inside CAST must be reported as unsupported
    format_query = "SELECT CAST(a AS STRING FORMAT 'HEX') FROM df"
    with pytest.raises(
        SQLInterfaceError,
        match="use of FORMAT is not currently supported in CAST",
    ):
        pl.SQLContext(df=frame, eager=True).execute(format_query)
162
163
164
@pytest.mark.parametrize(
    ("values", "cast_op", "error"),
    [
        ([1.0, -1.0], "values::uint8", "conversion from `f64` to `u64` failed"),
        ([10, 0, -1], "values::uint4", "conversion from `i64` to `u32` failed"),
        ([int(1e8)], "values::int1", "conversion from `i64` to `i8` failed"),
        (["a", "b"], "values::date", "conversion from `str` to `date` failed"),
        (["a", "b"], "values::time", "conversion from `str` to `time` failed"),
        (["a", "b"], "values::int4", "conversion from `str` to `i32` failed"),
    ],
)
def test_cast_errors(values: Any, cast_op: str, error: str) -> None:
    """Check that an invalid cast raises, while TRY_CAST produces nulls.

    Parameters
    ----------
    values
        Column data containing at least one value the cast cannot convert.
    cast_op
        Cast expression in ``<column>::<dtype>`` form.
    error
        Pattern the raised error message is matched against.
    """
    frame = pl.DataFrame({"values": values})

    # a plain cast of unconvertible data must raise an error...
    strict_query = f"SELECT {cast_op} FROM self"
    with pytest.raises(InvalidOperationError, match=error):
        frame.sql(strict_query)

    # ... whereas TRY_CAST to the same target type returns `null` values
    _column, target_type = cast_op.split("::")
    lenient_query = f"SELECT TRY_CAST(values AS {target_type}) AS cast_values FROM self"
    lenient_values = frame.sql(lenient_query).to_series()
    assert None in lenient_values
186
187
188
@pytest.mark.may_fail_cloud  # reason: eager construct to_struct
@pytest.mark.xfail(reason="JSON cast is a construct we cannot deal with anymore")
def test_cast_json() -> None:
    """Check that casting a JSON string column yields the matching struct.

    Both the ``<col>::json`` shorthand and the ``CAST(<col> AS JSON)`` form
    are run against the same input and compared with the same expectations.
    The test is currently marked as an expected failure; the reason is given
    on the ``xfail`` marker so that it shows up in the pytest report.
    """
    df = pl.DataFrame({"txt": ['{"a":[1,2,3],"b":["x","y","z"],"c":5.0}']})

    # the expectations do not depend on the cast syntax, so build them once
    expected_schema = {
        "j": pl.Struct(
            {
                "a": pl.List(pl.Int64),
                "b": pl.List(pl.String),
                "c": pl.Float64,
            },
        )
    }
    expected_unnested = pl.DataFrame(
        {
            "a": [[1, 2, 3]],
            "b": [["x", "y", "z"]],
            "c": [5.0],
        }
    )

    with pl.SQLContext(df=df, eager=True) as ctx:
        for json_cast in ("txt::json", "CAST(txt AS JSON)"):
            res = ctx.execute(f"SELECT {json_cast} AS j FROM df")

            # check each syntax variant, not only the last one executed
            assert res.schema == expected_schema
            assert_frame_equal(res.unnest("j"), expected_unnested)
216
217