Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/functions/test_lit.py
6939 views
1
# mypy: disable-error-code="redundant-expr"
2
from __future__ import annotations
3
4
import enum
5
import sys
6
from datetime import datetime, timedelta
7
from decimal import Decimal
8
from typing import TYPE_CHECKING, Any
9
10
import numpy as np
11
import pytest
12
from hypothesis import given
13
14
import polars as pl
15
from polars.testing import assert_frame_equal
16
from polars.testing.parametric.strategies import series
17
from polars.testing.parametric.strategies.data import datetimes
18
19
if TYPE_CHECKING:
20
from polars._typing import PolarsDataType
21
22
23
# `enum.StrEnum` only exists on Python 3.11+. On older interpreters expose
# None instead, so parametrized tests can leave out the StrEnum variant.
PyStrEnum: type[enum.Enum] | None
if sys.version_info < (3, 11):
    PyStrEnum = None
else:
    from enum import StrEnum

    PyStrEnum = StrEnum
29
30
31
@pytest.mark.parametrize(
    "input",
    [
        [[1, 2], [3, 4, 5]],
        [1, 2, 3],
    ],
)
def test_lit_list_input(input: list[Any]) -> None:
    """`pl.lit(list).first()` yields the whole list, repeated on every row."""
    frame = pl.DataFrame({"a": [1, 2]})
    expected = pl.DataFrame({"a": [1, 2], "literal": [input, input]})

    first_of_literal = pl.lit(input).first()
    assert_frame_equal(frame.with_columns(first_of_literal), expected)
43
44
45
@pytest.mark.parametrize(
    "input",
    [
        ([1, 2], [3, 4, 5]),
        (1, 2, 3),
    ],
)
def test_lit_tuple_input(input: tuple[Any, ...]) -> None:
    """Tuples passed to `pl.lit` come back as list values on every row."""
    as_list = list(input)
    expected = pl.DataFrame({"a": [1, 2], "literal": [as_list, as_list]})

    frame = pl.DataFrame({"a": [1, 2]})
    assert_frame_equal(frame.with_columns(pl.lit(input).first()), expected)
58
59
60
def test_lit_numpy_array_input() -> None:
    """A NumPy array literal is added as a column, one array element per row."""
    arr = np.array([3, 4])
    frame = pl.DataFrame({"a": [1, 2]})
    expected = pl.DataFrame({"a": [1, 2], "literal": [3, 4]})

    with_literal = frame.with_columns(pl.lit(arr, dtype=pl.Int64))
    assert_frame_equal(with_literal, expected)
68
69
70
def test_lit_ambiguous_datetimes_11379() -> None:
    """Filtering against tz-aware datetime literals respects every row.

    The hourly range covers 2020-10-25 in Europe/London, the date named by
    the "ambiguous datetimes" issue. Each row's own timestamp is used as a
    `>=` cutoff, which must keep exactly that row and all rows after it.
    """
    hourly = pl.datetime_range(
        datetime(2020, 10, 25),
        datetime(2020, 10, 25, 2),
        "1h",
        time_zone="Europe/London",
        eager=True,
    )
    df = pl.DataFrame({"ts": hourly})

    timestamps = df["ts"]
    for offset in range(df.height):
        cutoff = timestamps[offset]
        assert_frame_equal(df.filter(pl.col("ts") >= cutoff), df[offset:])
86
87
88
def test_list_datetime_11571() -> None:
    """`np.timedelta64` literals in ns and us units both read back as one second."""
    # NOTE(review): the name says `list` but the test exercises `pl.lit`;
    # it was probably meant to be `test_lit_datetime_11571`.
    one_second = timedelta(seconds=1)
    for np_duration in (
        np.timedelta64(1_000_000_000, "ns"),
        np.timedelta64(1_000_000, "us"),
    ):
        assert pl.select(pl.lit(np_duration))[0, 0] == one_second
93
94
95
@pytest.mark.parametrize(
    ("input", "dtype"),
    [
        pytest.param(-(2**31), pl.Int32, id="i32 min"),
        pytest.param(-(2**31) - 1, pl.Int64, id="below i32 min"),
        pytest.param(2**31 - 1, pl.Int32, id="i32 max"),
        pytest.param(2**31, pl.Int64, id="above i32 max"),
        pytest.param(2**63 - 1, pl.Int64, id="i64 max"),
        pytest.param(2**63, pl.UInt64, id="above i64 max"),
    ],
)
def test_lit_int_return_type(input: int, dtype: PolarsDataType) -> None:
    """Python int literals are typed by magnitude: Int32, then Int64, then UInt64."""
    inferred = pl.select(pl.lit(input)).to_series().dtype
    assert inferred == dtype
108
109
110
def test_lit_unsupported_type() -> None:
    """`pl.lit` raises TypeError for values it cannot turn into a literal."""
    lazy = pl.LazyFrame({"a": [1, 2, 3]})
    with pytest.raises(
        TypeError,
        match="cannot create expression literal for value of type LazyFrame",
    ):
        pl.lit(lazy)
116
117
118
@pytest.mark.parametrize(
    "EnumBase",
    [
        (enum.Enum,),
        (str, enum.Enum),
        *([(PyStrEnum,)] if PyStrEnum is not None else []),
    ],
)
def test_lit_enum_input_16668(EnumBase: tuple[type, ...]) -> None:
    """String-valued Python enums work as literals and as filter constraints.

    Covers plain ``Enum``, ``str``-mixin enums and, on Python 3.11+,
    ``StrEnum``.
    """
    # https://github.com/pola-rs/polars/issues/16668

    class State(*EnumBase):  # type: ignore[misc]
        NSW = "New South Wales"
        QLD = "Queensland"
        VIC = "Victoria"

    # validate that frame schema has inferred the enum; the inferred pl.Enum
    # lists every member value, including "Queensland", which has no rows
    df = pl.DataFrame({"state": [State.NSW, State.VIC]})
    assert df.schema == {
        "state": pl.Enum(["New South Wales", "Queensland", "Victoria"])
    }

    # check use of enum as lit/constraint
    value = State.VIC
    expected = "Victoria"

    # Filtering directly on the enum member does not depend on the literal
    # form below, so check it once rather than on every loop iteration.
    assert df.filter(state=value).item() == expected

    # Both the enum member and its raw string value must behave the same
    # when wrapped in pl.lit.
    for lit_value in (
        pl.lit(value),
        pl.lit(value.value),  # type: ignore[attr-defined]
    ):
        assert pl.select(lit_value).item() == expected
        assert df.filter(state=lit_value).item() == expected

    assert df.filter(pl.col("state") == State.QLD).is_empty()
    assert df.filter(pl.col("state") != State.QLD).height == 2
154
155
156
@pytest.mark.parametrize(
    "EnumBase",
    [
        (enum.Enum,),
        (enum.Flag,),
        (enum.IntEnum,),
        (enum.IntFlag,),
        (int, enum.Enum),
    ],
)
def test_lit_enum_input_non_string(EnumBase: tuple[type, ...]) -> None:
    """Integer-valued enums/flags become integer literals.

    Without a dtype the literal is Int32; an explicit dtype (Int8 here) is
    honoured. Either way the value is the member's integer value.
    """
    # https://github.com/pola-rs/polars/issues/16668

    class Number(*EnumBase):  # type: ignore[misc]
        ONE = 1
        TWO = 2

    member = Number.ONE

    default_frame = pl.select(pl.lit(member))
    assert default_frame.dtypes[0] == pl.Int32
    assert default_frame.item() == 1

    int8_frame = pl.select(pl.lit(member, dtype=pl.Int8))
    assert int8_frame.dtypes[0] == pl.Int8
    assert int8_frame.item() == 1
182
183
184
@given(value=datetimes("ns"))
def test_datetime_ns(value: datetime) -> None:
    """A datetime literal with ns precision reads back unchanged."""
    frame = pl.select(pl.lit(value, dtype=pl.Datetime("ns")))
    assert frame["literal"][0] == value
188
189
190
@given(value=datetimes("us"))
def test_datetime_us(value: datetime) -> None:
    """A us-precision datetime literal reads back unchanged.

    Checked with both ``pl.Datetime("us")`` and the unparametrised
    ``pl.Datetime``.
    """
    for dtype in (pl.Datetime("us"), pl.Datetime):
        round_tripped = pl.select(pl.lit(value, dtype=dtype))["literal"][0]
        assert round_tripped == value
196
197
198
@given(value=datetimes("ms"))
def test_datetime_ms(value: datetime) -> None:
    """A datetime literal cast to ms precision loses only its sub-ms part."""
    result = pl.select(pl.lit(value, dtype=pl.Datetime("ms")))["literal"][0]

    # `microsecond` is in [0, 999999], so removing the remainder mod 1000
    # rounds it down to a whole millisecond.
    sub_ms = value.microsecond % 1000
    expected = value.replace(microsecond=value.microsecond - sub_ms)
    assert result == expected
203
204
205
@pytest.mark.may_fail_cloud  # @cloud-decimal
def test_lit_decimal() -> None:
    """`Decimal("0.1")` becomes a Decimal literal with scale 1 and no precision."""
    value = Decimal("0.1")
    df = pl.select(pl.lit(value))

    assert df.dtypes[0] == pl.Decimal(None, 1)
    assert df.item() == value
215
216
217
def test_lit_string_float() -> None:
    """A float literal requested as a string dtype holds the float's `str()`."""
    number = 3.2
    df = pl.select(pl.lit(number, dtype=pl.Utf8))

    assert df.dtypes[0] == pl.String
    assert df.item() == str(number)
226
227
228
@pytest.mark.may_fail_cloud  # @cloud-decimal
@given(s=series(min_size=1, max_size=1, allow_null=False, allowed_dtypes=pl.Decimal))
def test_lit_decimal_parametric(s: pl.Series) -> None:
    """Any single non-null Decimal round-trips through `pl.lit`, keeping its scale."""
    value = s.item()
    expected_dtype = pl.Decimal(None, s.dtype.scale)  # type: ignore[attr-defined]

    df = pl.select(pl.lit(value))
    assert df.dtypes[0] == expected_dtype
    assert df.item() == value
240
241
242
@pytest.mark.parametrize(
    "item",
    [pytest.param({}, marks=pytest.mark.may_fail_cloud), {"foo": 1}],
)
def test_lit_structs(item: Any) -> None:
    """Dict literals, including the empty dict, read back as the same dict."""
    as_dict = pl.select(pl.lit(item)).to_dict(as_series=False)
    assert as_dict == {"literal": [item]}
248
249
250
@pytest.mark.parametrize(
    ("value", "expected_dtype"),
    [
        (np.float32(1.2), pl.Float32),
        (np.float64(1.2), pl.Float64),
        (np.int8(1), pl.Int8),
        (np.uint8(1), pl.UInt8),
        (np.int16(1), pl.Int16),
        (np.uint16(1), pl.UInt16),
        (np.int32(1), pl.Int32),
        (np.uint32(1), pl.UInt32),
        (np.int64(1), pl.Int64),
        (np.uint64(1), pl.UInt64),
    ],
)
def test_numpy_lit(value: Any, expected_dtype: PolarsDataType) -> None:
    """Each NumPy scalar maps to the Polars dtype of the same width and signedness."""
    literal_dtype = pl.select(pl.lit(value))["literal"].dtype
    assert literal_dtype == expected_dtype
268
269