Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/functions/test_lit.py
8420 views
1
# mypy: disable-error-code="redundant-expr"
2
from __future__ import annotations
3
4
import enum
5
import sys
6
from datetime import date, datetime, time, timedelta
7
from decimal import Decimal
8
from typing import TYPE_CHECKING, Any
9
10
import numpy as np
11
import pytest
12
from hypothesis import given
13
14
import polars as pl
15
from polars.testing import assert_frame_equal
16
from polars.testing.parametric.strategies import series
17
from polars.testing.parametric.strategies.data import datetimes
18
19
if TYPE_CHECKING:
20
from polars._typing import PolarsDataType
21
22
23
if sys.version_info < (3, 11):
    # ``enum.StrEnum`` only exists on Python 3.11+, so older interpreters have
    # no StrEnum base class to parametrise the enum-literal tests with.
    PyStrEnum = None
else:
    from enum import StrEnum

    # Optional extra enum base class; ``None`` signals "not available".
    PyStrEnum: type[enum.Enum] | None = StrEnum
29
30
31
@pytest.mark.parametrize(
    "input",
    [
        [[1, 2], [3, 4, 5]],
        [1, 2, 3],
    ],
)
def test_lit_list_input(input: list[Any]) -> None:
    """A list passed to ``pl.lit`` yields a list-valued literal.

    ``.first()`` produces a single value, and the expected frame shows that
    whole list repeated on every row by ``with_columns``.
    """
    frame = pl.DataFrame({"a": [1, 2]})
    expected = pl.DataFrame({"a": [1, 2], "literal": [input] * 2})
    assert_frame_equal(frame.with_columns(pl.lit(input).first()), expected)
43
44
45
@pytest.mark.parametrize(
    "input",
    [
        ([1, 2], [3, 4, 5]),
        (1, 2, 3),
    ],
)
def test_lit_tuple_input(input: tuple[Any, ...]) -> None:
    """A tuple passed to ``pl.lit`` produces the same literal as a list would."""
    frame = pl.DataFrame({"a": [1, 2]})
    # The tuple is expected to come back as a list on each row.
    as_list = list(input)
    expected = pl.DataFrame({"a": [1, 2], "literal": [as_list, as_list]})
    assert_frame_equal(frame.with_columns(pl.lit(input).first()), expected)
58
59
60
def test_lit_numpy_array_input() -> None:
    """A NumPy array literal fills a column element-wise, one value per row."""
    df = pl.DataFrame({"a": [1, 2]})
    # Named ``values`` rather than ``input`` to avoid shadowing the builtin.
    values = np.array([3, 4])

    result = df.with_columns(pl.lit(values, dtype=pl.Int64))

    expected = pl.DataFrame({"a": [1, 2], "literal": [3, 4]})
    assert_frame_equal(result, expected)
68
69
70
def test_lit_ambiguous_datetimes_11379() -> None:
    """Filtering on a tz-aware literal near a DST change selects the right rows.

    The hourly Europe/London range on 2020-10-25 crosses the end of summer
    time. For every row ``i``, ``ts >= ts[i]`` must keep exactly the suffix
    of the frame that starts at row ``i``.
    """
    df = pl.DataFrame(
        {
            "ts": pl.datetime_range(
                datetime(2020, 10, 25),
                datetime(2020, 10, 25, 2),
                "1h",
                time_zone="Europe/London",
                eager=True,
            )
        }
    )
    for start in range(df.height):
        threshold = df["ts"][start]
        assert_frame_equal(df.filter(pl.col("ts") >= threshold), df[start:])
86
87
88
def test_list_datetime_11571() -> None:
    """NumPy ``timedelta64`` literals in ns and us units both equal one second."""
    one_second = timedelta(seconds=1)
    for np_delta in (
        np.timedelta64(1_000_000_000, "ns"),
        np.timedelta64(1_000_000, "us"),
    ):
        assert pl.select(pl.lit(np_delta))[0, 0] == one_second
93
94
95
@pytest.mark.parametrize(
    ("input", "dtype"),
    [
        pytest.param(-(2**31), pl.Int32, id="i32 min"),
        pytest.param(-(2**31) - 1, pl.Int64, id="below i32 min"),
        pytest.param(2**31 - 1, pl.Int32, id="i32 max"),
        pytest.param(2**31, pl.Int64, id="above i32 max"),
        pytest.param(2**63 - 1, pl.Int64, id="i64 max"),
        pytest.param(2**63, pl.UInt64, id="above i64 max"),
    ],
)
def test_lit_int_return_type(input: int, dtype: PolarsDataType) -> None:
    """Python ints infer Int32, Int64 or UInt64 depending on their magnitude.

    The cases sit just inside and just outside each type's range.
    """
    inferred = pl.select(pl.lit(input)).to_series().dtype
    assert inferred == dtype
108
109
110
def test_lit_unsupported_type() -> None:
    """``pl.lit`` raises ``TypeError`` when given a LazyFrame."""
    lazy = pl.LazyFrame({"a": [1, 2, 3]})
    with pytest.raises(
        TypeError,
        match="cannot create expression literal for value of type LazyFrame",
    ):
        pl.lit(lazy)
116
117
118
@pytest.mark.parametrize(
    "EnumBase",
    [
        (enum.Enum,),
        (str, enum.Enum),
        *([(PyStrEnum,)] if PyStrEnum is not None else []),
    ],
)
def test_lit_enum_input_16668(EnumBase: tuple[type, ...]) -> None:
    """String-valued Python enum members work as literals and filter values.

    Parametrised over the enum base classes: plain ``Enum``, the
    ``(str, Enum)`` mixin and, when available, ``StrEnum``.
    """
    # https://github.com/pola-rs/polars/issues/16668

    class State(*EnumBase):  # type: ignore[misc]
        NSW = "New South Wales"
        QLD = "Queensland"
        VIC = "Victoria"

    # validate that frame schema has inferred the enum
    df = pl.DataFrame({"state": [State.NSW, State.VIC]})
    assert df.schema == {
        "state": pl.Enum(["New South Wales", "Queensland", "Victoria"])
    }

    # check use of enum as lit/constraint
    value = State.VIC
    expected = "Victoria"

    # Filtering on the raw enum member does not depend on which literal form
    # is exercised below, so assert it once instead of on every iteration.
    assert df.filter(state=value).item() == expected

    for lit_value in (
        pl.lit(value),
        pl.lit(value.value),  # type: ignore[attr-defined]
    ):
        assert pl.select(lit_value).item() == expected
        assert df.filter(state=lit_value).item() == expected

    assert df.filter(pl.col("state") == State.QLD).is_empty()
    assert df.filter(pl.col("state") != State.QLD).height == 2
154
155
156
@pytest.mark.parametrize(
    "EnumBase",
    [
        (enum.Enum,),
        (enum.Flag,),
        (enum.IntEnum,),
        (enum.IntFlag,),
        (int, enum.Enum),
    ],
)
def test_lit_enum_input_non_string(EnumBase: tuple[type, ...]) -> None:
    """Integer-valued enum members become integer literals.

    Without a dtype the literal is expected to be Int32; an explicit
    ``dtype=pl.Int8`` must be honoured.
    """
    # https://github.com/pola-rs/polars/issues/16668

    class Number(*EnumBase):  # type: ignore[misc]
        ONE = 1
        TWO = 2

    value = Number.ONE

    for expr, expected_dtype in (
        (pl.lit(value), pl.Int32),
        (pl.lit(value, dtype=pl.Int8), pl.Int8),
    ):
        frame = pl.select(expr)
        assert frame.dtypes[0] == expected_dtype
        assert frame.item() == 1
182
183
184
@given(value=datetimes("ns"))
def test_datetime_ns(value: datetime) -> None:
    """A datetime round-trips exactly through a nanosecond-precision literal."""
    frame = pl.select(pl.lit(value, dtype=pl.Datetime("ns")))
    assert frame["literal"][0] == value
188
189
190
@given(value=datetimes("us"))
def test_datetime_us(value: datetime) -> None:
    """A datetime round-trips exactly via ``Datetime("us")`` and bare ``Datetime``."""
    for dtype in (pl.Datetime("us"), pl.Datetime):
        roundtripped = pl.select(pl.lit(value, dtype=dtype))["literal"][0]
        assert roundtripped == value
196
197
198
@given(value=datetimes("ms"))
def test_datetime_ms(value: datetime) -> None:
    """A millisecond literal drops the sub-millisecond part of the datetime."""
    result = pl.select(pl.lit(value, dtype=pl.Datetime("ms")))["literal"][0]
    # Zero out the last three microsecond digits to get whole milliseconds.
    truncated = value.replace(microsecond=(value.microsecond // 1000) * 1000)
    assert result == truncated
203
204
205
def test_np_datetime64_as_date_24521() -> None:
    """A day-precision ``np.datetime64`` literal is inferred as ``pl.Date``."""
    result = pl.select(pl.lit(np.datetime64("2020-12-27")))
    # Named ``s`` so the module-level ``series`` strategy import is not shadowed.
    s = result.get_column("literal")
    assert s.dtype == pl.Date
    assert s[0] == date(2020, 12, 27)
210
211
212
@pytest.mark.may_fail_cloud  # @cloud-decimal
def test_lit_decimal() -> None:
    """A ``Decimal`` literal keeps its value and its scale of one digit."""
    value = Decimal("0.1")
    df = pl.select(pl.lit(value))
    assert df.dtypes[0] == pl.Decimal(None, 1)
    assert df.item() == value
222
223
224
def test_lit_string_float() -> None:
    """A float literal with a string dtype holds the float's ``str()`` text.

    The literal is requested as ``pl.Utf8``; the resulting column reports
    ``pl.String``.
    """
    value = 3.2
    df = pl.select(pl.lit(value, dtype=pl.Utf8))
    assert df.dtypes[0] == pl.String
    assert df.item() == str(value)
233
234
235
@pytest.mark.may_fail_cloud  # @cloud-decimal
@given(s=series(min_size=1, max_size=1, allow_null=False, allowed_dtypes=pl.Decimal))
def test_lit_decimal_parametric(s: pl.Series) -> None:
    """A generated non-null Decimal round-trips through ``pl.lit`` with its scale."""
    value = s.item()
    df = pl.select(pl.lit(value))
    expected_dtype = pl.Decimal(None, s.dtype.scale)  # type: ignore[attr-defined]
    assert df.dtypes[0] == expected_dtype
    assert df.item() == value
247
248
249
@pytest.mark.parametrize(
    "item",
    [pytest.param({}, marks=pytest.mark.may_fail_cloud), {"foo": 1}],
)
def test_lit_structs(item: Any) -> None:
    """Dict literals, including the empty dict, round-trip unchanged."""
    out = pl.select(pl.lit(item))
    assert out.to_dict(as_series=False) == {"literal": [item]}
255
256
257
@pytest.mark.parametrize(
    ("value", "expected_dtype"),
    [
        (np.float32(1.2), pl.Float32),
        (np.float64(1.2), pl.Float64),
        (np.int8(1), pl.Int8),
        (np.uint8(1), pl.UInt8),
        (np.int16(1), pl.Int16),
        (np.uint16(1), pl.UInt16),
        (np.int32(1), pl.Int32),
        (np.uint32(1), pl.UInt32),
        (np.int64(1), pl.Int64),
        (np.uint64(1), pl.UInt64),
    ],
)
def test_numpy_lit(value: Any, expected_dtype: PolarsDataType) -> None:
    """Each NumPy scalar type maps onto the matching Polars numeric dtype."""
    dtype = pl.select(pl.lit(value))["literal"].dtype
    assert dtype == expected_dtype
275
276
277
def test_lit_object_type_25713() -> None:
    """An arbitrary Python object can be wrapped as a ``pl.Object`` literal."""
    obj = time(hour=1)
    expected = pl.DataFrame({"literal": [obj]}, schema={"literal": pl.Object})
    result = pl.select(pl.lit(obj, dtype=pl.Object))
    assert result.to_dict(as_series=False) == expected.to_dict(as_series=False)
282
283