Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/expr/test_exprs.py
8362 views
1
from __future__ import annotations
2
3
import re
4
from datetime import date, datetime, timedelta, timezone
5
from itertools import permutations
6
from typing import TYPE_CHECKING, Any, cast
7
from zoneinfo import ZoneInfo
8
9
import numpy as np
10
import pytest
11
12
import polars as pl
13
from polars.testing import assert_frame_equal, assert_series_equal
14
from tests.unit.conftest import (
15
DATETIME_DTYPES,
16
DURATION_DTYPES,
17
FLOAT_DTYPES,
18
INTEGER_DTYPES,
19
NUMERIC_DTYPES,
20
TEMPORAL_DTYPES,
21
)
22
23
if TYPE_CHECKING:
24
from polars._typing import PolarsDataType
25
26
27
def assert_expression_repr_matches(expression: pl.Expr, match: str) -> None:
    """Check that `repr(expression)` ends with `<Expr ['<match>'] at 0x...>`."""
    expr_repr = repr(expression)
    expected = rf"<Expr \['{re.escape(match)}'\] at 0x[\dA-Fa-f]+>$"
    assert re.search(expected, expr_repr) is not None
31
32
33
def test_arg_true() -> None:
    """`arg_true` yields the indices where the predicate holds, as the index dtype."""
    df = pl.DataFrame({"a": [1, 1, 2, 1]})
    out = df.select((pl.col("a") == 1).arg_true())
    idx_dtype = pl.get_index_type()
    expected = pl.DataFrame([pl.Series("a", [0, 1, 3], dtype=idx_dtype)])
    assert_frame_equal(out, expected)
38
39
40
def test_suffix(fruits_cars: pl.DataFrame) -> None:
    """`name.suffix` appends the given string to every output column name."""
    out = fruits_cars.select(pl.all().name.suffix("_reverse"))
    expected_cols = ["A_reverse", "fruits_reverse", "B_reverse", "cars_reverse"]
    assert out.columns == expected_cols
44
45
46
def test_pipe() -> None:
    """`Expr.pipe` forwards the expression (plus kwargs) to the given callable."""

    def _multiply(expr: pl.Expr, mul: int) -> pl.Expr:
        return expr * mul

    df = pl.DataFrame({"foo": [1, 2, 3], "bar": [6, None, 8]})
    out = df.select(
        pl.col("foo").pipe(_multiply, mul=2),
        pl.col("bar").pipe(_multiply, mul=3),
    )
    assert_frame_equal(out, pl.DataFrame({"foo": [2, 4, 6], "bar": [18, None, 24]}))
59
60
61
def test_prefix(fruits_cars: pl.DataFrame) -> None:
    """`name.prefix` prepends the given string to every output column name."""
    out = fruits_cars.select(pl.all().name.prefix("reverse_"))
    expected_cols = ["reverse_A", "reverse_fruits", "reverse_B", "reverse_cars"]
    assert out.columns == expected_cols
65
66
67
def test_filter_where() -> None:
    """`Expr.filter` restricts values inside an aggregation; `where` is its deprecated alias."""
    df = pl.DataFrame({"a": [1, 2, 3, 1, 2, 3], "b": [4, 5, 6, 7, 8, 9]})
    expected = pl.DataFrame({"a": [1, 2, 3], "c": [[7], [5, 8], [6, 9]]})

    result_filter = df.group_by("a", maintain_order=True).agg(
        pl.col("b").filter(pl.col("b") > 4).alias("c")
    )
    assert_frame_equal(result_filter, expected)

    # `where` still works but emits a deprecation warning
    with pytest.deprecated_call():
        result_where = df.group_by("a", maintain_order=True).agg(
            pl.col("b").where(pl.col("b") > 4).alias("c")
        )
        assert_frame_equal(result_where, expected)

    # apply filter constraints using kwargs
    df = pl.DataFrame(
        {
            "key": ["a", "a", "a", "a", "a", "a", "b", "b", "b", "b", "b", "b"],
            "n": [1, 4, 4, 2, 2, 3, 1, 3, 0, 2, 3, 4],
        },
        schema_overrides={"n": pl.UInt8},
    )
    res = (
        df.group_by("key")
        .agg(**{f"n_{value}": pl.col("n").filter(n=value) for value in range(5)})
        .sort(by="key")
    )
    assert res.rows() == [
        ("a", [], [1], [2, 2], [3], [4, 4]),
        ("b", [0], [1], [2], [3, 3], [4]),
    ]
104
105
106
def test_len_expr() -> None:
    """`pl.len` counts rows, both for the whole frame and per group."""
    df = pl.DataFrame({"a": [1, 2, 3, 3, 3], "b": ["a", "a", "b", "a", "a"]})

    total = df.select(pl.len())
    assert total.shape == (1, 1)
    assert cast("int", total.item()) == 5

    grouped = df.group_by("b", maintain_order=True).agg(pl.len())
    assert grouped["b"].to_list() == ["a", "b"]
    assert grouped["len"].to_list() == [4, 1]
116
117
118
def test_map_alias() -> None:
    """`name.map` renames the output via a callable applied to the root name."""

    def _doubled(name: str) -> str:
        return f"{name}{name}"

    out = pl.DataFrame({"foo": [1, 2, 3]}).select(
        (pl.col("foo") * 2).name.map(_doubled)
    )
    assert_frame_equal(out, pl.DataFrame({"foofoo": [2, 4, 6]}))
124
125
126
def test_unique_stable() -> None:
    """`unique(maintain_order=True)` keeps first-seen order of the distinct values."""
    result = pl.Series("a", [1, 1, 1, 1, 2, 2, 2, 3, 3]).unique(maintain_order=True)
    assert_series_equal(result, pl.Series("a", [1, 2, 3]))
130
131
132
def test_entropy() -> None:
    """Normalized per-group entropy matches precomputed reference values."""
    df = pl.DataFrame(
        {
            "group": ["A", "A", "A", "B", "B", "B", "B", "C"],
            "id": [1, 2, 1, 4, 5, 4, 6, 7],
        }
    )
    out = df.group_by("group", maintain_order=True).agg(
        pl.col("id").entropy(normalize=True)
    )
    expected = pl.DataFrame(
        {"group": ["A", "B", "C"], "id": [1.0397207708399179, 1.371381017771811, 0.0]}
    )
    assert_frame_equal(out, expected)
146
147
148
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_log_broadcast(dtype: pl.DataType) -> None:
    """`Series.log` broadcasts a unit-length operand against a longer series."""
    values = pl.Series("a", [1, 3, 9, 27, 81], dtype=dtype)
    bases = pl.Series("a", [3, 3, 9, 3, 9], dtype=dtype)

    # equal lengths: element-wise
    assert_series_equal(
        values.log(bases), pl.Series("a", [0, 1, 1, 3, 2], dtype=dtype)
    )
    # unit-length base broadcast over the values
    assert_series_equal(
        values.log(pl.Series("a", [3], dtype=dtype)),
        pl.Series("a", [0, 1, 2, 3, 4], dtype=dtype),
    )
    # unit-length value broadcast over the bases
    assert_series_equal(
        pl.Series("a", [81], dtype=dtype).log(bases),
        pl.Series("a", [4, 4, 2, 4, 2], dtype=dtype),
    )
165
166
167
@pytest.mark.parametrize(
    ("dtype", "expected_dtype"),
    [
        (pl.Float64, pl.Float64),
        (pl.Float32, pl.Float32),
        (pl.Float16, pl.Float16),
        (pl.Int32, pl.Float64),
        (pl.Int64, pl.Float64),
    ],
)
def test_log_broadcast_upcasting(
    dtype: pl.DataType, expected_dtype: pl.DataType
) -> None:
    """Floats keep their precision under `log`; integers upcast to Float64."""
    values = pl.Series("a", [1, 3, 9, 27, 81], dtype=dtype)
    bases = pl.Series("a", [3, 3, 9, 3, 9], dtype=dtype)

    result = values.log(bases)
    assert_series_equal(result, pl.Series("a", [0, 1, 1, 3, 2], dtype=expected_dtype))
185
186
187
@pytest.mark.parametrize(
    ("dtype_a", "dtype_base", "dtype_out"),
    [
        (pl.Float16, pl.Float16, pl.Float16),
        (pl.Float16, pl.Float32, pl.Float16),
        (pl.Float16, pl.Float64, pl.Float16),
        (pl.Float32, pl.Float16, pl.Float32),
        (pl.Float32, pl.Float32, pl.Float32),
        (pl.Float32, pl.Float64, pl.Float32),
        (pl.Float64, pl.Float16, pl.Float64),
        (pl.Float64, pl.Float32, pl.Float64),
        (pl.Float64, pl.Float64, pl.Float64),
        (pl.Float16, pl.Int32, pl.Float16),
        (pl.Float32, pl.Int32, pl.Float32),
        (pl.Float64, pl.Int32, pl.Float64),
        (pl.Int32, pl.Float16, pl.Float16),
        (pl.Int32, pl.Float32, pl.Float32),
        (pl.Int32, pl.Float64, pl.Float64),
        (pl.Decimal(21, 3), pl.Decimal(21, 3), pl.Float64),
        (pl.Int32, pl.Int32, pl.Float64),
    ],
)
def test_log(
    dtype_a: PolarsDataType,
    dtype_base: PolarsDataType,
    dtype_out: PolarsDataType,
) -> None:
    """`log` with an expression base produces the expected output dtype and values."""
    lf = pl.LazyFrame(
        [
            pl.Series("a", [1, 3, 9, 27, 81], dtype=dtype_a),
            pl.Series("base", [3, 3, 9, 3, 9], dtype=dtype_base),
        ]
    )

    out = lf.select(pl.col("a").log(pl.col("base")))
    expected = pl.DataFrame({"a": pl.Series([0, 1, 1, 3, 2], dtype=dtype_out)})

    assert_frame_equal(out.collect(), expected)
    # the lazy schema must agree with the materialized result
    schema = out.collect_schema()
    assert schema == expected.schema, f"{schema} != {expected.schema}"
225
226
227
@pytest.mark.parametrize(
    ("dtype_in", "dtype_out"),
    [
        (pl.Int32, pl.Float64),
        (pl.Float16, pl.Float16),
        (pl.Float32, pl.Float32),
        (pl.Float64, pl.Float64),
    ],
)
def test_exp_log1p(dtype_in: PolarsDataType, dtype_out: PolarsDataType) -> None:
    """`exp` and `log1p` match numpy and upcast integer input to Float64."""
    s = pl.Series("a", [1, 3, 9, 4, 10], dtype=dtype_in)
    lf = pl.LazyFrame([s])

    # exercise exp and log1p with the same input/expected-dtype machinery
    for expr, np_fn in (
        (pl.col("a").exp(), np.exp),
        (pl.col("a").log1p(), np.log1p),
    ):
        result = lf.select(expr)
        expected = pl.Series("a", np_fn(s.to_numpy())).cast(dtype_out).to_frame()
        assert_frame_equal(result.collect(), expected)
        assert result.collect_schema() == expected.schema
251
252
253
def test_dot_in_group_by() -> None:
    """`dot` inside an aggregation computes a per-group dot product."""
    df = pl.DataFrame(
        {
            "group": ["a", "a", "a", "b", "b", "b"],
            "x": [1, 1, 1, 1, 1, 1],
            "y": [1, 2, 3, 4, 5, 6],
        }
    )
    out = df.group_by("group", maintain_order=True).agg(
        pl.col("x").dot("y").alias("dot")
    )
    assert_frame_equal(out, pl.DataFrame({"group": ["a", "b"], "dot": [6, 15]}))
267
268
269
def test_dtype_col_selection() -> None:
    """`pl.col(<dtype group>)` selects exactly the columns of those dtypes."""
    df = pl.DataFrame(
        data=[],
        schema={
            "a1": pl.Datetime,
            "a2": pl.Datetime("ms"),
            "a3": pl.Datetime("ms"),
            "a4": pl.Datetime("ns"),
            "b": pl.Date,
            "c": pl.Time,
            "d1": pl.Duration,
            "d2": pl.Duration("ms"),
            "d3": pl.Duration("us"),
            "d4": pl.Duration("ns"),
            "e": pl.Int8,
            "f": pl.Int16,
            "g": pl.Int32,
            "h": pl.Int64,
            "i": pl.Float32,
            "j": pl.Float64,
            "k": pl.UInt8,
            "l": pl.UInt16,
            "m": pl.UInt32,
            "n": pl.UInt64,
            "o": pl.Float16,
        },
    )
    # (dtype group, expected selected column names, in schema order)
    cases = [
        (INTEGER_DTYPES, ["e", "f", "g", "h", "k", "l", "m", "n"]),
        (FLOAT_DTYPES, ["i", "j", "o"]),
        (NUMERIC_DTYPES, ["e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o"]),
        (TEMPORAL_DTYPES, ["a1", "a2", "a3", "a4", "b", "c", "d1", "d2", "d3", "d4"]),
        (DATETIME_DTYPES, ["a1", "a2", "a3", "a4"]),
        (DURATION_DTYPES, ["d1", "d2", "d3", "d4"]),
    ]
    for dtype_group, expected_cols in cases:
        assert df.select(pl.col(dtype_group)).columns == expected_cols
344
345
346
def test_list_eval_expression() -> None:
    """`list.eval` applies an expression per list element, parallel or not."""
    df = pl.DataFrame({"a": [1, 8, 3], "b": [4, 5, 2]})

    for parallel in (True, False):
        ranked = df.with_columns(
            pl.concat_list(["a", "b"])
            .list.eval(pl.element().rank(), parallel=parallel)
            .alias("rank")
        )
        assert ranked.to_dict(as_series=False) == {
            "a": [1, 8, 3],
            "b": [4, 5, 2],
            "rank": [[1.0, 2.0], [2.0, 1.0], [2.0, 1.0]],
        }

        # identity eval on a reshaped array column round-trips the values
        identity = df["a"].reshape((1, -1)).arr.to_list().list.eval(
            pl.element(), parallel=parallel
        )
        assert identity.to_list() == [[1, 8, 3]]
363
364
365
def test_null_count_expr() -> None:
    """`null_count` reports the number of nulls per column."""
    df = pl.DataFrame({"key": ["a", "b", "b", "a"], "val": [1, 2, None, 1]})
    counts = df.select(pl.all().null_count())
    assert counts.to_dict(as_series=False) == {"key": [0], "val": [1]}
372
373
374
def test_pos_neg() -> None:
    """Unary `-` negates, unary `+` is a no-op, and output names are preserved."""
    df = pl.DataFrame({"x": [3, 2, 1], "y": [6, 7, 8]}).with_columns(
        -pl.col("x"), +pl.col("y"), -pl.lit(1)
    )

    # #11149: ensure that we preserve the output name (where available)
    assert df.to_dict(as_series=False) == {
        "x": [-3, -2, -1],
        "y": [6, 7, 8],
        "literal": [-1, -1, -1],
    }
388
389
390
def test_power_by_expression() -> None:
    """`pow`, the `**` operator, and reflected `**` propagate nulls element-wise."""
    df = pl.DataFrame({"a": [1, None, None, 4, 5, 6], "b": [1, 2, None, 4, None, 6]})
    out = df.select(
        pl.col("a").pow(pl.col("b")).alias("pow_expr"),
        (pl.col("a") ** pl.col("b")).alias("pow_op"),
        (2 ** pl.col("b")).alias("pow_op_left"),
    )

    expected = [1.0, None, None, 256.0, None, 46656.0]
    assert out["pow_expr"].to_list() == expected
    assert out["pow_op"].to_list() == expected
    assert out["pow_op_left"].to_list() == [2.0, 4.0, None, 16.0, None, 64.0]
402
403
404
@pytest.mark.may_fail_cloud  # reason: chunking
@pytest.mark.may_fail_auto_streaming
def test_expression_appends() -> None:
    """`Expr.append` concatenates as chunks; `rechunk` merges them back into one."""
    df = pl.DataFrame({"a": [1, 1, 2]})

    appended = df.select(pl.repeat(None, 3).append(pl.col("a")))
    assert appended.n_chunks() == 2
    assert df.select(pl.repeat(None, 3).append(pl.col("a")).rechunk()).n_chunks() == 1

    out = df.select(pl.concat([pl.repeat(None, 3), pl.col("a")], rechunk=True))
    assert out.n_chunks() == 1
    assert out.to_series().to_list() == [None, None, None, 1, 1, 2]
416
417
418
def test_arr_contains() -> None:
    """`list.contains` filters rows by list membership; `~` negates the check."""
    df_groups = pl.DataFrame(
        {
            "animals": [
                ["cat", "mouse", "dog"],
                ["dog", "hedgehog", "mouse", "cat"],
                ["peacock", "mouse", "aardvark"],
            ],
        }
    )
    # every row contains "mouse" -> all rows survive the lazy filter
    has_mouse = (
        df_groups.lazy()
        .filter(pl.col("animals").list.contains("mouse"))
        .collect()
    )
    assert has_mouse.to_dict(as_series=False) == {
        "animals": [
            ["cat", "mouse", "dog"],
            ["dog", "hedgehog", "mouse", "cat"],
            ["peacock", "mouse", "aardvark"],
        ]
    }
    # contains combined with a negated contains drops the hedgehog row
    no_hedgehog = df_groups.filter(
        pl.col("animals").list.contains("mouse"),
        ~pl.col("animals").list.contains("hedgehog"),
    )
    assert no_hedgehog.to_dict(as_series=False) == {
        "animals": [
            ["cat", "mouse", "dog"],
            ["peacock", "mouse", "aardvark"],
        ],
    }
448
449
450
def test_logical_boolean() -> None:
    """Python `and`/`or` on expressions is ambiguous and raises TypeError."""
    # bare column expressions
    with pytest.raises(TypeError, match="ambiguous"):
        pl.col("colx") and pl.col("coly")  # type: ignore[redundant-expr]
    with pytest.raises(TypeError, match="ambiguous"):
        pl.col("colx") or pl.col("coly")  # type: ignore[redundant-expr]

    # comparison expressions inside a select behave the same way
    frame = pl.DataFrame({"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, 5]})
    with pytest.raises(TypeError, match="ambiguous"):
        frame.select([(pl.col("a") > pl.col("b")) and (pl.col("b") > pl.col("b"))])
    with pytest.raises(TypeError, match="ambiguous"):
        frame.select([(pl.col("a") > pl.col("b")) or (pl.col("b") > pl.col("b"))])
466
467
468
def test_lit_dtypes() -> None:
    """`pl.lit` honours explicit dtypes for temporal, numeric and list values."""

    def lit_series(value: Any, dtype: PolarsDataType | None) -> pl.Series:
        return pl.select(pl.lit(value, dtype=dtype)).to_series()

    d = datetime(2049, 10, 5, 1, 2, 3, 987654)
    d_ms = datetime(2049, 10, 5, 1, 2, 3, 987000)  # truncated to ms precision
    d_tz = datetime(2049, 10, 5, 1, 2, 3, 987654, tzinfo=ZoneInfo("Asia/Kathmandu"))

    td = timedelta(days=942, hours=6, microseconds=123456)
    td_ms = timedelta(days=942, seconds=21600, microseconds=123000)  # ms precision

    kathmandu_dtype = pl.Datetime("us", "Asia/Kathmandu")

    # (column name, literal input, requested dtype, expected dtype, expected row value)
    cases = [
        ("dtm_ms", d, pl.Datetime("ms"), pl.Datetime("ms"), d_ms),
        ("dtm_us", d, pl.Datetime("us"), pl.Datetime("us"), d),
        ("dtm_ns", d, pl.Datetime("ns"), pl.Datetime("ns"), d),
        ("dtm_aware_0", d, pl.Datetime("us", "Asia/Kathmandu"), kathmandu_dtype, d_tz),
        ("dtm_aware_1", d_tz, pl.Datetime("us"), kathmandu_dtype, d_tz),
        ("dtm_aware_2", d_tz, None, kathmandu_dtype, d_tz),
        ("dtm_aware_3", d, pl.Datetime(time_zone="Asia/Kathmandu"), kathmandu_dtype, d_tz),
        ("dur_ms", td, pl.Duration("ms"), pl.Duration("ms"), td_ms),
        ("dur_us", td, pl.Duration("us"), pl.Duration("us"), td),
        ("dur_ns", td, pl.Duration("ns"), pl.Duration("ns"), td),
        ("f16", 0, pl.Float16, pl.Float16, 0),
        ("f32", 0, pl.Float32, pl.Float32, 0),
        ("f64", 0, pl.Float64, pl.Float64, 0),
        ("u16", 0, pl.UInt16, pl.UInt16, 0),
        ("i16", 0, pl.Int16, pl.Int16, 0),
        ("i64", pl.Series([8]), None, pl.Int64, 8),
        ("list_i64", pl.Series([[1, 2, 3]]), None, pl.List(pl.Int64), [1, 2, 3]),
    ]

    df = pl.DataFrame(
        {name: lit_series(value, dtype) for name, value, dtype, _, _ in cases}
    )
    assert df.dtypes == [expected_dtype for *_, expected_dtype, _ in cases]
    assert df.row(0) == tuple(expected_value for *_, expected_value in cases)
538
539
540
def test_lit_empty_tu() -> None:
    """Temporal dtype classes without an explicit time unit default to 'us'."""
    td = timedelta(1)
    duration_frame = pl.select(pl.lit(td, dtype=pl.Duration))
    assert duration_frame.item() == td
    assert duration_frame.dtypes[0].time_unit == "us"  # type: ignore[attr-defined]

    dt = datetime(2023, 1, 1)
    datetime_frame = pl.select(pl.lit(dt, dtype=pl.Datetime))
    assert datetime_frame.item() == dt
    assert datetime_frame.dtypes[0].time_unit == "us"  # type: ignore[attr-defined]
548
549
550
def test_incompatible_lit_dtype() -> None:
    """A tz-aware value with a mismatched dtype time zone raises TypeError."""
    value = datetime(2020, 1, 1, tzinfo=timezone.utc)
    msg = r"time zone of dtype \('Asia/Kathmandu'\) differs from time zone of value \(datetime.timezone.utc\)"
    with pytest.raises(TypeError, match=msg):
        pl.lit(value, dtype=pl.Datetime("us", "Asia/Kathmandu"))
559
560
561
def test_lit_dtype_utc() -> None:
    """A tz-aware literal matching the dtype's zone converts to the correct instant."""
    aware_value = datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kathmandu"))
    result = pl.select(
        pl.lit(aware_value, dtype=pl.Datetime("us", "Asia/Kathmandu"))
    )
    # the same instant expressed in UTC, then converted back to Kathmandu
    expected = pl.DataFrame(
        {"literal": [datetime(2019, 12, 31, 18, 15, tzinfo=timezone.utc)]}
    ).select(pl.col("literal").dt.convert_time_zone("Asia/Kathmandu"))
    assert_frame_equal(result, expected)
572
573
574
@pytest.mark.parametrize(
    ("input", "expected"),
    [
        (("a",), ["b", "c"]),
        (("a", "b"), ["c"]),
        ((["a", "b"],), ["c"]),
        ((pl.Int64,), ["c"]),
        ((pl.String, pl.Float32), ["a", "b"]),
        (([pl.String, pl.Float32],), ["a", "b"]),
    ],
)
def test_exclude(input: tuple[Any, ...], expected: list[str]) -> None:
    """`exclude` drops columns by name or dtype, given singly or in lists."""
    df = pl.DataFrame(schema={"a": pl.Int64, "b": pl.Int64, "c": pl.String})
    remaining = df.select(pl.all().exclude(*input)).columns
    assert remaining == expected
588
589
590
@pytest.mark.parametrize("input", [(5,), (["a"], date.today()), (pl.Int64, "a")])
def test_exclude_invalid_input(input: tuple[Any, ...]) -> None:
    """Exclusions that are neither names nor dtypes raise TypeError."""
    frame = pl.DataFrame(schema=["a", "b", "c"])
    with pytest.raises(TypeError):
        frame.select(pl.all().exclude(*input))
595
596
597
def test_operators_vs_expressions() -> None:
    """Named expression methods (`eq`, `add`, ...) match their operator forms."""
    df = pl.DataFrame(
        data={
            "x": [5, 6, 7, 4, 8],
            "y": [1.5, 2.5, 1.0, 4.0, -5.75],
            "z": [-9, 2, -1, 4, 8],
        }
    )
    for lhs, rhs in permutations("xyz", r=2):
        via_operators = df.select(
            a=pl.col(lhs) == pl.col(rhs),
            b=pl.col(lhs) // pl.col(rhs),
            c=pl.col(lhs) > pl.col(rhs),
            d=pl.col(lhs) >= pl.col(rhs),
            e=pl.col(lhs) < pl.col(rhs),
            f=pl.col(lhs) <= pl.col(rhs),
            g=pl.col(lhs) % pl.col(rhs),
            h=pl.col(lhs) != pl.col(rhs),
            i=pl.col(lhs) - pl.col(rhs),
            j=pl.col(lhs) / pl.col(rhs),
            k=pl.col(lhs) * pl.col(rhs),
            l=pl.col(lhs) + pl.col(rhs),
        )
        via_methods = df.select(
            a=pl.col(lhs).eq(pl.col(rhs)),
            b=pl.col(lhs).floordiv(pl.col(rhs)),
            c=pl.col(lhs).gt(pl.col(rhs)),
            d=pl.col(lhs).ge(pl.col(rhs)),
            e=pl.col(lhs).lt(pl.col(rhs)),
            f=pl.col(lhs).le(pl.col(rhs)),
            g=pl.col(lhs).mod(pl.col(rhs)),
            h=pl.col(lhs).ne(pl.col(rhs)),
            i=pl.col(lhs).sub(pl.col(rhs)),
            j=pl.col(lhs).truediv(pl.col(rhs)),
            k=pl.col(lhs).mul(pl.col(rhs)),
            l=pl.col(lhs).add(pl.col(rhs)),
        )
        assert_frame_equal(via_operators, via_methods)

    # xor - only int cols
    assert_frame_equal(
        df.select(pl.col("x") ^ pl.col("z")),
        df.select(pl.col("x").xor(pl.col("z"))),
    )

    # variadic `and_` is equivalent to a chain of `&`
    assert_frame_equal(
        df.select(
            all=(pl.col("x") >= pl.col("z")).and_(
                pl.col("y") >= pl.col("z"),
                pl.col("y") == pl.col("y"),
                pl.col("z") <= pl.col("x"),
                pl.col("y") != pl.col("x"),
            )
        ),
        df.select(
            all=(
                (pl.col("x") >= pl.col("z"))
                & (pl.col("y") >= pl.col("z"))
                & (pl.col("y") == pl.col("y"))
                & (pl.col("z") <= pl.col("x"))
                & (pl.col("y") != pl.col("x"))
            )
        ),
    )

    # variadic `or_` is equivalent to a chain of `|`
    assert_frame_equal(
        df.select(
            any=(pl.col("x") == pl.col("y")).or_(
                pl.col("x") == pl.col("y"),
                pl.col("y") == pl.col("z"),
                pl.col("y").cast(int) == pl.col("z"),
            )
        ),
        df.select(
            any=(pl.col("x") == pl.col("y"))
            | (pl.col("x") == pl.col("y"))
            | (pl.col("y") == pl.col("z"))
            | (pl.col("y").cast(int) == pl.col("z"))
        ),
    )
678
679
680
def test_head() -> None:
    """`head` truncates; over-length requests return everything; exprs work as length."""
    df = pl.DataFrame({"a": [1, 2, 3, 4, 5]})
    for n, expected in ((0, []), (3, [1, 2, 3]), (10, [1, 2, 3, 4, 5])):
        assert df.select(pl.col("a").head(n)).to_dict(as_series=False) == {
            "a": expected
        }
    # the length may itself be an expression
    half = df.select(pl.col("a").head(pl.len() // 2))
    assert half.to_dict(as_series=False) == {"a": [1, 2]}
690
691
692
def test_tail() -> None:
    """`tail` keeps the last n; over-length requests return everything; exprs work as length."""
    df = pl.DataFrame({"a": [1, 2, 3, 4, 5]})
    for n, expected in ((0, []), (3, [3, 4, 5]), (10, [1, 2, 3, 4, 5])):
        assert df.select(pl.col("a").tail(n)).to_dict(as_series=False) == {
            "a": expected
        }
    # the length may itself be an expression
    half = df.select(pl.col("a").tail(pl.len() // 2))
    assert half.to_dict(as_series=False) == {"a": [4, 5]}
702
703
704
def test_repr_short_expression() -> None:
    """The repr of a moderately long expression is truncated with an ellipsis."""
    assert_expression_repr_matches(
        expression=pl.functions.all().len().name.prefix("length-long:"),
        match="cs.all().len().name.prefix(len…",
    )
710
711
712
def test_repr_long_expression() -> None:
    """A dtype-based column selection reprs as a selector, truncated with an ellipsis."""
    assert_expression_repr_matches(
        expression=pl.functions.col(pl.String).str.count_matches(""),
        match="cs.string().str.count_matches(…",
    )
718
719
720
def test_repr_gather() -> None:
    """`gather` and `get` show their index argument as a dynamic int in the repr."""
    for method, expected in (("gather", "gather"), ("get", "get")):
        expr = getattr(pl.col("a"), method)(0)
        assert_expression_repr_matches(
            expression=expr,
            match=f'col("a").{expected}(dyn int: 0)',
        )
729
730
731
def test_repr_miscellaneous() -> None:
    """Long struct exprs are truncated in repr; `_pyexpr.to_str` gives the full form."""
    expr = pl.struct("x", pl.col("y"), pl.col("z").sin() ** 2)
    assert_expression_repr_matches(
        expression=expr,
        match='as_struct("x", "y", col("z").s…',
    )
    assert (
        expr._pyexpr.to_str()
        == 'as_struct("x", "y", col("z").sin().pow([dyn int: 2]))'
    )
739
740
741
def test_replace_no_cse() -> None:
    """`replace` on an aggregated expression must not trigger CSE in the plan."""
    lf = pl.LazyFrame({"a": [1], "b": [2]})
    plan = lf.select([(pl.col("a") * pl.col("a")).sum().replace(1, None)]).explain()
    assert "POLARS_CSER" not in plan
748
749
750
def test_slice_rejects_non_integral() -> None:
    """`slice` raises when offset or length is not an integer expression."""
    lf = pl.LazyFrame({"a": [0, 1, 2, 3], "b": [1.5, 2, 3, 4]})

    # float-typed offset
    with pytest.raises(pl.exceptions.InvalidOperationError):
        lf.select(pl.col("a").slice(pl.col("b").slice(0, 1), None)).collect()
    # float-typed length
    with pytest.raises(pl.exceptions.InvalidOperationError):
        lf.select(pl.col("a").slice(0, pl.col("b").slice(1, 2))).collect()
    # string-typed offset
    with pytest.raises(pl.exceptions.InvalidOperationError):
        lf.select(pl.col("a").slice(pl.lit("1"), None)).collect()
761
762
763
def test_slice() -> None:
    """`slice` on an expression mirrors Python list slicing semantics."""
    data = {"a": [0, 1, 2, 3], "b": [1, 2, 3, 4]}
    df = pl.DataFrame(data)

    # offset only: everything from index 1 onwards
    assert_frame_equal(
        df.select(pl.col("a").slice(1)),
        pl.DataFrame({"a": data["a"][1:]}),
    )
    # offset + length, applied across all columns
    assert_frame_equal(
        df.select(pl.all().slice(1, 1)),
        pl.DataFrame({"a": data["a"][1:2], "b": data["b"][1:2]}),
    )
774
775
776
@pytest.mark.may_fail_cloud  # reason: shrink_dtype
def test_function_expr_scalar_identification_18755() -> None:
    """A scalar input to a GroupWise function must broadcast across the frame."""
    # The function uses `ApplyOptions::GroupWise`, however the input is scalar.
    with pytest.warns(
        DeprecationWarning,
        match=r"use `Series\.shrink_dtype` instead",
    ):
        result = pl.DataFrame({"a": [1, 2]}).with_columns(
            pl.lit(5, pl.Int64).shrink_dtype().alias("b")
        )
    assert_frame_equal(
        result,
        pl.DataFrame({"a": [1, 2], "b": pl.Series([5, 5], dtype=pl.Int64)}),
    )
789
790
791
def test_concat_deprecation() -> None:
    """`str.concat` is deprecated on both Series and expressions."""
    deprecation_msg = r"`str\.concat` is deprecated."
    with pytest.deprecated_call(match=deprecation_msg):
        pl.Series(["foo"]).str.concat()
    with pytest.deprecated_call(match=deprecation_msg):
        pl.DataFrame({"foo": ["bar"]}).select(pl.all().str.concat())
796
797