Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/expr/test_exprs.py
6939 views
1
from __future__ import annotations
2
3
from datetime import date, datetime, timedelta, timezone
4
from itertools import permutations
5
from typing import TYPE_CHECKING, Any, cast
6
from zoneinfo import ZoneInfo
7
8
import numpy as np
9
import pytest
10
11
import polars as pl
12
from polars.testing import assert_frame_equal, assert_series_equal
13
from tests.unit.conftest import (
14
DATETIME_DTYPES,
15
DURATION_DTYPES,
16
FLOAT_DTYPES,
17
INTEGER_DTYPES,
18
NUMERIC_DTYPES,
19
TEMPORAL_DTYPES,
20
)
21
22
if TYPE_CHECKING:
23
from polars._typing import PolarsDataType
24
25
26
def test_arg_true() -> None:
    """`arg_true` returns the row indices (UInt32) where the predicate holds."""
    frame = pl.DataFrame({"a": [1, 1, 2, 1]})
    actual = frame.select((pl.col("a") == 1).arg_true())
    expected = pl.DataFrame([pl.Series("a", [0, 1, 3], dtype=pl.UInt32)])
    assert_frame_equal(actual, expected)
31
32
33
def test_suffix(fruits_cars: pl.DataFrame) -> None:
    """`name.suffix` appends the given text to every output column name."""
    out = fruits_cars.select(pl.all().name.suffix("_reverse"))
    expected_cols = [f"{c}_reverse" for c in ("A", "fruits", "B", "cars")]
    assert out.columns == expected_cols
37
38
39
def test_pipe() -> None:
    """`Expr.pipe` forwards the expression through a user-supplied callable."""
    df = pl.DataFrame({"foo": [1, 2, 3], "bar": [6, None, 8]})

    def _multiply(expr: pl.Expr, mul: int) -> pl.Expr:
        return expr * mul

    out = df.select(
        pl.col("foo").pipe(_multiply, mul=2),
        pl.col("bar").pipe(_multiply, mul=3),
    )
    # nulls propagate through the multiplication
    assert_frame_equal(out, pl.DataFrame({"foo": [2, 4, 6], "bar": [18, None, 24]}))
52
53
54
def test_prefix(fruits_cars: pl.DataFrame) -> None:
    """`name.prefix` prepends the given text to every output column name."""
    out = fruits_cars.select(pl.all().name.prefix("reverse_"))
    expected_cols = [f"reverse_{c}" for c in ("A", "fruits", "B", "cars")]
    assert out.columns == expected_cols
58
59
60
def test_filter_where() -> None:
    # `Expr.filter` restricts the values an aggregation sees within each group.
    df = pl.DataFrame({"a": [1, 2, 3, 1, 2, 3], "b": [4, 5, 6, 7, 8, 9]})
    result_filter = df.group_by("a", maintain_order=True).agg(
        pl.col("b").filter(pl.col("b") > 4).alias("c")
    )
    expected = pl.DataFrame({"a": [1, 2, 3], "c": [[7], [5, 8], [6, 9]]})
    assert_frame_equal(result_filter, expected)

    # `Expr.where` is a deprecated alias of `Expr.filter`; it must still
    # produce the same result while emitting a DeprecationWarning.
    with pytest.deprecated_call():
        result_where = df.group_by("a", maintain_order=True).agg(
            pl.col("b").where(pl.col("b") > 4).alias("c")
        )
        assert_frame_equal(result_where, expected)

    # apply filter constraints using kwargs
    df = pl.DataFrame(
        {
            "key": ["a", "a", "a", "a", "a", "a", "b", "b", "b", "b", "b", "b"],
            "n": [1, 4, 4, 2, 2, 3, 1, 3, 0, 2, 3, 4],
        },
        schema_overrides={"n": pl.UInt8},
    )
    # each kwarg (e.g. `n=0`) is shorthand for `pl.col("n") == 0`
    res = (
        df.group_by("key")
        .agg(
            n_0=pl.col("n").filter(n=0),
            n_1=pl.col("n").filter(n=1),
            n_2=pl.col("n").filter(n=2),
            n_3=pl.col("n").filter(n=3),
            n_4=pl.col("n").filter(n=4),
        )
        .sort(by="key")
    )
    assert res.rows() == [
        ("a", [], [1], [2, 2], [3], [4, 4]),
        ("b", [0], [1], [2], [3, 3], [4]),
    ]
97
98
99
def test_len_expr() -> None:
    """`pl.len` counts rows, both frame-wide and per group."""
    df = pl.DataFrame({"a": [1, 2, 3, 3, 3], "b": ["a", "a", "b", "a", "a"]})

    total = df.select(pl.len())
    assert total.shape == (1, 1)
    assert cast(int, total.item()) == 5

    per_group = df.group_by("b", maintain_order=True).agg(pl.len())
    assert per_group["b"].to_list() == ["a", "b"]
    assert per_group["len"].to_list() == [4, 1]
109
110
111
def test_map_alias() -> None:
    """`name.map` derives the output column name from a callable."""
    result = pl.DataFrame({"foo": [1, 2, 3]}).select(
        (pl.col("foo") * 2).name.map(lambda name: f"{name}{name}")
    )
    assert_frame_equal(result, pl.DataFrame({"foofoo": [2, 4, 6]}))
117
118
119
def test_unique_stable() -> None:
    # maintain_order=True keeps the first-seen order of unique values
    series = pl.Series("a", [1, 1, 1, 1, 2, 2, 2, 3, 3])
    assert_series_equal(series.unique(maintain_order=True), pl.Series("a", [1, 2, 3]))
123
124
125
def test_entropy() -> None:
    """Normalized entropy computed per group (a single-value group yields 0)."""
    df = pl.DataFrame(
        {
            "group": ["A", "A", "A", "B", "B", "B", "B", "C"],
            "id": [1, 2, 1, 4, 5, 4, 6, 7],
        }
    )
    out = df.group_by("group", maintain_order=True).agg(
        pl.col("id").entropy(normalize=True)
    )
    expected = pl.DataFrame(
        {"group": ["A", "B", "C"], "id": [1.0397207708399179, 1.371381017771811, 0.0]}
    )
    assert_frame_equal(out, expected)
139
140
141
@pytest.mark.parametrize(
    "dtype",
    [
        pl.Float64,
        pl.Float32,
    ],
)
def test_log_broadcast(dtype: pl.DataType) -> None:
    """`Series.log` with a Series base broadcasts a unit-length operand on either side."""
    values = pl.Series("a", [1, 3, 9, 27, 81], dtype=dtype)
    bases = pl.Series("a", [3, 3, 9, 3, 9], dtype=dtype)

    # element-wise: log of each value in its paired base
    assert_series_equal(values.log(bases), pl.Series("a", [0, 1, 1, 3, 2], dtype=dtype))
    # single-element base broadcast across all values
    assert_series_equal(
        values.log(pl.Series("a", [3], dtype=dtype)),
        pl.Series("a", [0, 1, 2, 3, 4], dtype=dtype),
    )
    # single-element value broadcast across all bases
    assert_series_equal(
        pl.Series("a", [81], dtype=dtype).log(bases),
        pl.Series("a", [4, 4, 2, 4, 2], dtype=dtype),
    )
161
162
163
@pytest.mark.parametrize(
    "dtype",
    [
        pl.Float32,
        pl.Int32,
        pl.Int64,
    ],
)
def test_log_broadcast_upcasting(dtype: pl.DataType) -> None:
    """Mixing any dtype with a Float64 operand upcasts the log result to Float64."""
    values = pl.Series("a", [1, 3, 9, 27, 81], dtype=dtype)
    bases = pl.Series("a", [3, 3, 9, 3, 9], dtype=dtype)
    expected = pl.Series("a", [0, 1, 1, 3, 2], dtype=pl.Float64)

    # Float64 on either side of the operation wins
    assert_series_equal(values.log(bases.cast(pl.Float64)), expected)
    assert_series_equal(values.cast(pl.Float64).log(bases), expected)
178
179
180
@pytest.mark.parametrize(
    ("dtype_a", "dtype_base", "dtype_out"),
    [
        (pl.UInt8, pl.UInt8, pl.Float64),
        (pl.Int32, pl.Int32, pl.Float64),
        (pl.Decimal(21, 3), pl.Decimal(21, 3), pl.Float64),
        (pl.Float32, pl.Float32, pl.Float32),
        (pl.Float32, pl.Float64, pl.Float64),
        (pl.Float64, pl.Float32, pl.Float64),
        (pl.Float64, pl.Float64, pl.Float64),
    ],
)
def test_log(
    dtype_a: PolarsDataType,
    dtype_base: PolarsDataType,
    dtype_out: PolarsDataType,
) -> None:
    """`log` output dtype: Float32 only survives when both operands are Float32."""
    lf = pl.LazyFrame(
        [
            pl.Series("a", [1, 3, 9, 27, 81], dtype=dtype_a),
            pl.Series("base", [3, 3, 9, 3, 9], dtype=dtype_base),
        ]
    )

    # log
    out = lf.select(pl.col("a").log("base"))
    expected = pl.DataFrame({"a": pl.Series([0, 1, 1, 3, 2], dtype=dtype_out)})

    assert_frame_equal(out.collect(), expected)
    # the lazy schema must agree with the materialized result
    assert out.collect_schema() == expected.schema
207
208
209
@pytest.mark.parametrize(
    ("dtype_in", "dtype_out"),
    [
        (pl.Int32, pl.Float64),
        (pl.Float32, pl.Float32),
        (pl.Float64, pl.Float64),
    ],
)
def test_exp_log1p(dtype_in: PolarsDataType, dtype_out: PolarsDataType) -> None:
    """`exp` and `log1p` match numpy and follow the same dtype promotion rules."""
    s = pl.Series("a", [1, 3, 9, 4, 10], dtype=dtype_in)
    lf = pl.LazyFrame([s])

    # run the same checks for both unary functions
    for expr, np_fn in ((pl.col("a").exp(), np.exp), (pl.col("a").log1p(), np.log1p)):
        out = lf.select(expr)
        expected = pl.Series("a", np_fn(s.to_numpy())).cast(dtype_out).to_frame()
        assert_frame_equal(out.collect(), expected)
        assert out.collect_schema() == expected.schema
232
233
234
def test_dot_in_group_by() -> None:
    """Dot product evaluated independently within each group."""
    df = pl.DataFrame(
        {
            "group": ["a", "a", "a", "b", "b", "b"],
            "x": [1, 1, 1, 1, 1, 1],
            "y": [1, 2, 3, 4, 5, 6],
        }
    )
    out = df.group_by("group", maintain_order=True).agg(
        pl.col("x").dot("y").alias("dot")
    )
    # x is all ones, so the dot product is just the per-group sum of y
    assert_frame_equal(out, pl.DataFrame({"group": ["a", "b"], "dot": [6, 15]}))
248
249
250
def test_dtype_col_selection() -> None:
    """`pl.col` accepts a dtype group and selects matching columns in schema order."""
    df = pl.DataFrame(
        data=[],
        schema={
            "a1": pl.Datetime,
            "a2": pl.Datetime("ms"),
            "a3": pl.Datetime("ms"),
            "a4": pl.Datetime("ns"),
            "b": pl.Date,
            "c": pl.Time,
            "d1": pl.Duration,
            "d2": pl.Duration("ms"),
            "d3": pl.Duration("us"),
            "d4": pl.Duration("ns"),
            "e": pl.Int8,
            "f": pl.Int16,
            "g": pl.Int32,
            "h": pl.Int64,
            "i": pl.Float32,
            "j": pl.Float64,
            "k": pl.UInt8,
            "l": pl.UInt16,
            "m": pl.UInt32,
            "n": pl.UInt64,
        },
    )
    # each dtype group selects exactly the matching schema entries
    for dtype_group, expected_cols in (
        (INTEGER_DTYPES, ["e", "f", "g", "h", "k", "l", "m", "n"]),
        (FLOAT_DTYPES, ["i", "j"]),
        (NUMERIC_DTYPES, ["e", "f", "g", "h", "i", "j", "k", "l", "m", "n"]),
        (TEMPORAL_DTYPES, ["a1", "a2", "a3", "a4", "b", "c", "d1", "d2", "d3", "d4"]),
        (DATETIME_DTYPES, ["a1", "a2", "a3", "a4"]),
        (DURATION_DTYPES, ["d1", "d2", "d3", "d4"]),
    ):
        assert df.select(pl.col(dtype_group)).columns == expected_cols
323
324
325
def test_list_eval_expression() -> None:
    """`list.eval` applies an expression to each list element, parallel or not."""
    df = pl.DataFrame({"a": [1, 8, 3], "b": [4, 5, 2]})

    for parallel in (True, False):
        ranked = df.with_columns(
            pl.concat_list(["a", "b"])
            .list.eval(pl.element().rank(), parallel=parallel)
            .alias("rank")
        )
        assert ranked.to_dict(as_series=False) == {
            "a": [1, 8, 3],
            "b": [4, 5, 2],
            "rank": [[1.0, 2.0], [2.0, 1.0], [2.0, 1.0]],
        }

        # identity eval through an array-to-list round trip
        assert df["a"].reshape((1, -1)).arr.to_list().list.eval(
            pl.element(), parallel=parallel
        ).to_list() == [[1, 8, 3]]
342
343
344
def test_null_count_expr() -> None:
    """`null_count` per column, collapsing each to a single row."""
    df = pl.DataFrame({"key": ["a", "b", "b", "a"], "val": [1, 2, None, 1]})
    counts = df.select(pl.all().null_count())
    assert counts.to_dict(as_series=False) == {"key": [0], "val": [1]}
351
352
353
def test_pos_neg() -> None:
    """Unary +/- on expressions; #11149: the output name must be preserved."""
    result = pl.DataFrame(
        {
            "x": [3, 2, 1],
            "y": [6, 7, 8],
        }
    ).with_columns(-pl.col("x"), +pl.col("y"), -pl.lit(1))

    expected = {
        "x": [-3, -2, -1],
        "y": [6, 7, 8],
        "literal": [-1, -1, -1],  # a negated literal keeps the "literal" name
    }
    assert result.to_dict(as_series=False) == expected
367
368
369
def test_power_by_expression() -> None:
    """`pow` via method, `**` operator, and reflected `**` agree on null handling."""
    out = pl.DataFrame(
        {"a": [1, None, None, 4, 5, 6], "b": [1, 2, None, 4, None, 6]}
    ).select(
        pl.col("a").pow(pl.col("b")).alias("pow_expr"),
        (pl.col("a") ** pl.col("b")).alias("pow_op"),
        (2 ** pl.col("b")).alias("pow_op_left"),
    )

    # method and operator forms are equivalent; null in either operand -> null
    expected = [1.0, None, None, 256.0, None, 46656.0]
    assert out["pow_expr"].to_list() == expected
    assert out["pow_op"].to_list() == expected
    assert out["pow_op_left"].to_list() == [2.0, 4.0, None, 16.0, None, 64.0]
381
382
383
@pytest.mark.may_fail_cloud  # reason: chunking
@pytest.mark.may_fail_auto_streaming
def test_expression_appends() -> None:
    """`append` keeps chunks separate until an explicit rechunk."""
    df = pl.DataFrame({"a": [1, 1, 2]})

    appended = pl.repeat(None, 3).append(pl.col("a"))
    assert df.select(appended).n_chunks() == 2
    assert df.select(appended.rechunk()).n_chunks() == 1

    # pl.concat with rechunk=True yields a single contiguous chunk
    out = df.select(pl.concat([pl.repeat(None, 3), pl.col("a")], rechunk=True))
    assert out.n_chunks() == 1
    assert out.to_series().to_list() == [None, None, None, 1, 1, 2]
395
396
397
def test_arr_contains() -> None:
    """Filtering rows via `list.contains`, including its negation."""
    df_groups = pl.DataFrame(
        {
            "animals": [
                ["cat", "mouse", "dog"],
                ["dog", "hedgehog", "mouse", "cat"],
                ["peacock", "mouse", "aardvark"],
            ],
        }
    )
    has_mouse = pl.col("animals").list.contains("mouse")

    # string array contains: every row has a mouse, so all rows survive
    assert df_groups.lazy().filter(has_mouse).collect().to_dict(as_series=False) == {
        "animals": [
            ["cat", "mouse", "dog"],
            ["dog", "hedgehog", "mouse", "cat"],
            ["peacock", "mouse", "aardvark"],
        ]
    }
    # string array contains and *not* contains
    no_hedgehog = ~pl.col("animals").list.contains("hedgehog")
    assert df_groups.filter(has_mouse, no_hedgehog).to_dict(as_series=False) == {
        "animals": [
            ["cat", "mouse", "dog"],
            ["peacock", "mouse", "aardvark"],
        ],
    }
427
428
429
def test_logical_boolean() -> None:
    # note, cannot use expressions in logical
    # boolean context (eg: and/or/not operators)
    # Expr.__bool__ raises TypeError("...ambiguous...") so that users are
    # pushed toward the bitwise `&`/`|` operators instead.
    with pytest.raises(TypeError, match="ambiguous"):
        pl.col("colx") and pl.col("coly")  # type: ignore[redundant-expr]

    with pytest.raises(TypeError, match="ambiguous"):
        pl.col("colx") or pl.col("coly")  # type: ignore[redundant-expr]

    df = pl.DataFrame({"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, 5]})

    # the same failure mode applies inside a select context
    with pytest.raises(TypeError, match="ambiguous"):
        df.select([(pl.col("a") > pl.col("b")) and (pl.col("b") > pl.col("b"))])

    with pytest.raises(TypeError, match="ambiguous"):
        df.select([(pl.col("a") > pl.col("b")) or (pl.col("b") > pl.col("b"))])
445
446
447
def test_lit_dtypes() -> None:
    """`pl.lit` honours explicit dtypes and infers sensible ones when dtype=None."""

    def lit_series(value: Any, dtype: PolarsDataType | None) -> pl.Series:
        # materialize a single literal into a one-row Series
        return pl.select(pl.lit(value, dtype=dtype)).to_series()

    d = datetime(2049, 10, 5, 1, 2, 3, 987654)
    d_ms = datetime(2049, 10, 5, 1, 2, 3, 987000)  # microseconds truncated to ms
    d_tz = datetime(2049, 10, 5, 1, 2, 3, 987654, tzinfo=ZoneInfo("Asia/Kathmandu"))

    td = timedelta(days=942, hours=6, microseconds=123456)
    td_ms = timedelta(days=942, seconds=21600, microseconds=123000)  # ms precision

    df = pl.DataFrame(
        {
            "dtm_ms": lit_series(d, pl.Datetime("ms")),
            "dtm_us": lit_series(d, pl.Datetime("us")),
            "dtm_ns": lit_series(d, pl.Datetime("ns")),
            # time-zone-aware variants: from dtype, from value, and inferred
            "dtm_aware_0": lit_series(d, pl.Datetime("us", "Asia/Kathmandu")),
            "dtm_aware_1": lit_series(d_tz, pl.Datetime("us")),
            "dtm_aware_2": lit_series(d_tz, None),
            "dtm_aware_3": lit_series(d, pl.Datetime(time_zone="Asia/Kathmandu")),
            "dur_ms": lit_series(td, pl.Duration("ms")),
            "dur_us": lit_series(td, pl.Duration("us")),
            "dur_ns": lit_series(td, pl.Duration("ns")),
            "f32": lit_series(0, pl.Float32),
            "u16": lit_series(0, pl.UInt16),
            "i16": lit_series(0, pl.Int16),
            # Series literals keep/infer their own dtype
            "i64": lit_series(pl.Series([8]), None),
            "list_i64": lit_series(pl.Series([[1, 2, 3]]), None),
        }
    )
    # dtypes, in column order declared above
    assert df.dtypes == [
        pl.Datetime("ms"),
        pl.Datetime("us"),
        pl.Datetime("ns"),
        pl.Datetime("us", "Asia/Kathmandu"),
        pl.Datetime("us", "Asia/Kathmandu"),
        pl.Datetime("us", "Asia/Kathmandu"),
        pl.Datetime("us", "Asia/Kathmandu"),
        pl.Duration("ms"),
        pl.Duration("us"),
        pl.Duration("ns"),
        pl.Float32,
        pl.UInt16,
        pl.Int16,
        pl.Int64,
        pl.List(pl.Int64),
    ]
    # values, in the same column order; ms columns lose sub-ms precision
    assert df.row(0) == (
        d_ms,
        d,
        d,
        d_tz,
        d_tz,
        d_tz,
        d_tz,
        td_ms,
        td,
        td,
        0,
        0,
        0,
        8,
        [1, 2, 3],
    )
511
512
513
def test_lit_empty_tu() -> None:
    """Duration/Datetime dtypes without a time unit default to microseconds."""
    td = timedelta(1)
    dur = pl.select(pl.lit(td, dtype=pl.Duration))
    assert dur.item() == td
    assert dur.dtypes[0].time_unit == "us"  # type: ignore[attr-defined]

    t = datetime(2023, 1, 1)
    dtm = pl.select(pl.lit(t, dtype=pl.Datetime))
    assert dtm.item() == t
    assert dtm.dtypes[0].time_unit == "us"  # type: ignore[attr-defined]
521
522
523
def test_incompatible_lit_dtype() -> None:
    # A tz-aware datetime value must carry the same time zone as the target
    # dtype; a mismatch is a TypeError rather than a silent conversion.
    with pytest.raises(
        TypeError,
        match=r"time zone of dtype \('Asia/Kathmandu'\) differs from time zone of value \(datetime.timezone.utc\)",
    ):
        pl.lit(
            datetime(2020, 1, 1, tzinfo=timezone.utc),
            dtype=pl.Datetime("us", "Asia/Kathmandu"),
        )
532
533
534
def test_lit_dtype_utc() -> None:
    """A tz-aware literal with a matching dtype stores the correct instant."""
    out = pl.select(
        pl.lit(
            datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kathmandu")),
            dtype=pl.Datetime("us", "Asia/Kathmandu"),
        )
    )
    # the same instant expressed in UTC, converted back to Kathmandu time
    expected = pl.DataFrame(
        {"literal": [datetime(2019, 12, 31, 18, 15, tzinfo=timezone.utc)]}
    ).select(pl.col("literal").dt.convert_time_zone("Asia/Kathmandu"))
    assert_frame_equal(out, expected)
545
546
547
@pytest.mark.parametrize(
    ("input", "expected"),
    [
        (("a",), ["b", "c"]),
        (("a", "b"), ["c"]),
        ((["a", "b"],), ["c"]),
        ((pl.Int64,), ["c"]),
        ((pl.String, pl.Float32), ["a", "b"]),
        (([pl.String, pl.Float32],), ["a", "b"]),
    ],
)
def test_exclude(input: tuple[Any, ...], expected: list[str]) -> None:
    """`exclude` accepts names, name sequences, dtypes, and dtype sequences."""
    frame = pl.DataFrame(schema={"a": pl.Int64, "b": pl.Int64, "c": pl.String})
    remaining = frame.select(pl.all().exclude(*input))
    assert remaining.columns == expected
561
562
563
@pytest.mark.parametrize("input", [(5,), (["a"], date.today()), (pl.Int64, "a")])
def test_exclude_invalid_input(input: tuple[Any, ...]) -> None:
    # arguments that are neither column names nor dtypes (or mix the two
    # across positional args) are rejected with a TypeError
    df = pl.DataFrame(schema=["a", "b", "c"])
    with pytest.raises(TypeError):
        df.select(pl.all().exclude(*input))
568
569
570
def test_operators_vs_expressions() -> None:
    """Each binary operator must be equivalent to its named Expr method."""
    df = pl.DataFrame(
        data={
            "x": [5, 6, 7, 4, 8],
            "y": [1.5, 2.5, 1.0, 4.0, -5.75],
            "z": [-9, 2, -1, 4, 8],
        }
    )
    # every ordered column pair, so both operand orders are exercised
    for c1, c2 in permutations("xyz", r=2):
        # operator forms; keys a..l pair positionally with df_expr below
        df_op = df.select(
            a=pl.col(c1) == pl.col(c2),
            b=pl.col(c1) // pl.col(c2),
            c=pl.col(c1) > pl.col(c2),
            d=pl.col(c1) >= pl.col(c2),
            e=pl.col(c1) < pl.col(c2),
            f=pl.col(c1) <= pl.col(c2),
            g=pl.col(c1) % pl.col(c2),
            h=pl.col(c1) != pl.col(c2),
            i=pl.col(c1) - pl.col(c2),
            j=pl.col(c1) / pl.col(c2),
            k=pl.col(c1) * pl.col(c2),
            l=pl.col(c1) + pl.col(c2),
        )
        # named-method forms, in the same a..l order
        df_expr = df.select(
            a=pl.col(c1).eq(pl.col(c2)),
            b=pl.col(c1).floordiv(pl.col(c2)),
            c=pl.col(c1).gt(pl.col(c2)),
            d=pl.col(c1).ge(pl.col(c2)),
            e=pl.col(c1).lt(pl.col(c2)),
            f=pl.col(c1).le(pl.col(c2)),
            g=pl.col(c1).mod(pl.col(c2)),
            h=pl.col(c1).ne(pl.col(c2)),
            i=pl.col(c1).sub(pl.col(c2)),
            j=pl.col(c1).truediv(pl.col(c2)),
            k=pl.col(c1).mul(pl.col(c2)),
            l=pl.col(c1).add(pl.col(c2)),
        )
        assert_frame_equal(df_op, df_expr)

    # xor - only int cols
    assert_frame_equal(
        df.select(pl.col("x") ^ pl.col("z")),
        df.select(pl.col("x").xor(pl.col("z"))),
    )

    # and (&) or (|) chains
    assert_frame_equal(
        df.select(
            all=(pl.col("x") >= pl.col("z")).and_(
                pl.col("y") >= pl.col("z"),
                pl.col("y") == pl.col("y"),
                pl.col("z") <= pl.col("x"),
                pl.col("y") != pl.col("x"),
            )
        ),
        df.select(
            all=(
                (pl.col("x") >= pl.col("z"))
                & (pl.col("y") >= pl.col("z"))
                & (pl.col("y") == pl.col("y"))
                & (pl.col("z") <= pl.col("x"))
                & (pl.col("y") != pl.col("x"))
            )
        ),
    )

    assert_frame_equal(
        df.select(
            any=(pl.col("x") == pl.col("y")).or_(
                pl.col("x") == pl.col("y"),
                pl.col("y") == pl.col("z"),
                pl.col("y").cast(int) == pl.col("z"),
            )
        ),
        df.select(
            any=(pl.col("x") == pl.col("y"))
            | (pl.col("x") == pl.col("y"))
            | (pl.col("y") == pl.col("z"))
            | (pl.col("y").cast(int) == pl.col("z"))
        ),
    )
651
652
653
def test_head() -> None:
    """`head` with zero, partial, over-long, and expression lengths."""
    df = pl.DataFrame({"a": [1, 2, 3, 4, 5]})
    for n, expected in ((0, []), (3, [1, 2, 3]), (10, [1, 2, 3, 4, 5])):
        assert df.select(pl.col("a").head(n)).to_dict(as_series=False) == {"a": expected}
    # the length may itself be an expression
    assert df.select(pl.col("a").head(pl.len() // 2)).to_dict(as_series=False) == {
        "a": [1, 2]
    }
663
664
665
def test_tail() -> None:
    """`tail` with zero, partial, over-long, and expression lengths."""
    df = pl.DataFrame({"a": [1, 2, 3, 4, 5]})
    for n, expected in ((0, []), (3, [3, 4, 5]), (10, [1, 2, 3, 4, 5])):
        assert df.select(pl.col("a").tail(n)).to_dict(as_series=False) == {"a": expected}
    # the length may itself be an expression
    assert df.select(pl.col("a").tail(pl.len() // 2)).to_dict(as_series=False) == {
        "a": [4, 5]
    }
675
676
677
def test_repr_short_expression() -> None:
    expr = pl.functions.all().len().name.prefix("length-long:")
    # we cut off the last ten characters because that includes the
    # memory location which will vary between runs
    shown = repr(expr).split("0x")[0]
    assert shown == "<Expr ['cs.all().len().prefix(length-l…'] at "
685
686
687
def test_repr_long_expression() -> None:
    expr = pl.functions.col(pl.String).str.count_matches("")
    # we cut off the last ten characters because that includes the
    # memory location which will vary between runs
    shown = repr(expr).split("0x")[0]
    # note the … denoting that there was truncated text
    assert shown == "<Expr ['cs.string().str.count_matches(…'] at "
    assert repr(expr).endswith(">")
698
699
700
def test_repr_gather() -> None:
    # gather/get render their dynamic-int argument in the repr
    assert 'col("a").gather(dyn int: 0)' in repr(pl.col("a").gather(0))
    assert 'col("a").get(dyn int: 0)' in repr(pl.col("a").get(0))
705
706
707
def test_replace_no_cse() -> None:
    """`replace` must not trigger common-subexpression elimination in the plan."""
    lf = pl.LazyFrame({"a": [1], "b": [2]})
    plan = lf.select([(pl.col("a") * pl.col("a")).sum().replace(1, None)]).explain()
    assert "POLARS_CSER" not in plan
714
715
716
def test_slice_rejects_non_integral() -> None:
    """`slice` offset/length must be integral; floats and strings are rejected."""
    lf = pl.LazyFrame({"a": [0, 1, 2, 3], "b": [1.5, 2, 3, 4]})

    # float offset
    with pytest.raises(pl.exceptions.InvalidOperationError):
        lf.select(pl.col("a").slice(pl.col("b").slice(0, 1), None)).collect()
    # float length
    with pytest.raises(pl.exceptions.InvalidOperationError):
        lf.select(pl.col("a").slice(0, pl.col("b").slice(1, 2))).collect()
    # string offset
    with pytest.raises(pl.exceptions.InvalidOperationError):
        lf.select(pl.col("a").slice(pl.lit("1"), None)).collect()
727
728
729
def test_slice() -> None:
    """`slice` with offset only, and with offset + length across all columns."""
    data = {"a": [0, 1, 2, 3], "b": [1, 2, 3, 4]}
    df = pl.DataFrame(data)

    # offset only: everything from index 1 onward
    assert_frame_equal(
        df.select(pl.col("a").slice(1)),
        pl.DataFrame({"a": data["a"][1:]}),
    )
    # offset + length, applied to every column
    assert_frame_equal(
        df.select(pl.all().slice(1, 1)),
        pl.DataFrame({"a": data["a"][1:2], "b": data["b"][1:2]}),
    )
740
741
742
@pytest.mark.may_fail_cloud  # reason: shrink_dtype
@pytest.mark.may_fail_auto_streaming
def test_function_expr_scalar_identification_18755() -> None:
    # The function uses `ApplyOptions::GroupWise`, however the input is scalar.
    # The scalar result must still broadcast to the frame's height.
    with pytest.warns(DeprecationWarning):  # shrink_dtype emits a deprecation warning
        assert_frame_equal(
            pl.DataFrame({"a": [1, 2]}).with_columns(
                pl.lit(5, pl.Int64).shrink_dtype().alias("b")
            ),
            pl.DataFrame({"a": [1, 2], "b": pl.Series([5, 5], dtype=pl.Int64)}),
        )
752
753
754
def test_concat_deprecation() -> None:
    # `str.concat` must warn on both the Series and the Expr API surfaces
    with pytest.deprecated_call(match="`str.concat` is deprecated."):
        pl.Series(["foo"]).str.concat()
    with pytest.deprecated_call(match="`str.concat` is deprecated."):
        pl.DataFrame({"foo": ["bar"]}).select(pl.all().str.concat())
759
760