Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/test_over.py
6939 views
1
import pytest
2
3
import polars as pl
4
from polars.testing import assert_frame_equal, assert_series_equal
5
6
7
def test_implode_explode_over_22188() -> None:
    """Multiplying by an imploded-then-exploded literal inside ``over`` keeps ``x``.

    Every ``y`` partition holds three rows, and the literal ``[1, 1, 1]`` is
    imploded and exploded back before the multiplication. The windowed product
    is expected to equal the original ``x`` column.
    """
    frame = pl.DataFrame(
        {
            "x": [1, 2, 3] * 3,
            "y": [2] * 3 + [3] * 3 + [4] * 3,
        }
    )
    ones = pl.lit(pl.Series([1, 1, 1])).implode().explode()
    windowed = (pl.col.x * ones).over(pl.col.y)

    out = frame.select(windowed).to_series()
    assert_series_equal(out, frame.get_column("x"))
19
20
21
def test_implode_in_over_22188() -> None:
    """An imploded literal can serve as the list operand of ``set_union`` in ``over``.

    Each row is its own ``y`` partition. The union with ``[1]`` leaves ``[1]``
    unchanged and appends ``1`` to the other lists.
    """
    extra = pl.lit(pl.Series([1])).implode()
    frame = pl.DataFrame({"x": [[1], [2], [3]], "y": [2, 3, 4]})

    out = frame.select(pl.col.x.list.set_union(extra).over(pl.col.y))

    expected = pl.Series("x", [[1], [2, 1], [3, 1]])
    assert_series_equal(out.to_series(), expected)
29
30
31
def test_over_no_partition_by() -> None:
    """``over`` with only ``order_by`` runs the window across the whole frame.

    Ordered by ``i``, the values of ``a`` are 1, 1, 2, so the running sums are
    1, 2, 4. Mapped back to the original row order this gives ``[2, 1, 4]``.
    """
    frame = pl.DataFrame({"a": [1, 1, 2], "i": [2, 1, 3]})
    running = pl.col("a").cum_sum().over(order_by="i")

    assert_frame_equal(
        frame.with_columns(b=running),
        pl.DataFrame({"a": [1, 1, 2], "i": [2, 1, 3], "b": [2, 1, 4]}),
    )
36
37
38
def test_over_no_partition_by_no_over() -> None:
    """``over()`` with neither partition keys nor ``order_by`` raises an error.

    The expression is built inside the ``raises`` block, so the error is
    accepted whether it comes from expression construction or from execution.
    """
    frame = pl.DataFrame({"a": [1, 1, 2], "i": [2, 1, 3]})
    expected_error = pl.exceptions.InvalidOperationError
    with pytest.raises(expected_error):
        frame.with_columns(b=pl.col("a").cum_sum().over())
42
43
44
def test_over_explode_22770() -> None:
    """Exploding a list column inside ``over(..., mapping_strategy="join")``.

    Applying ``list.diff`` to the windowed expression must produce the same
    frame as applying ``list.diff`` directly to ``x``.
    """
    frame = pl.DataFrame({"x": [[1.0], [2.0]], "idx": [1, 2]})
    windowed = pl.col("x").list.explode().over("idx", mapping_strategy="join")

    plain = frame.select(pl.col("x").list.diff())
    via_window = frame.select(windowed.list.diff())
    assert_frame_equal(plain, via_window)
52
53
54
def test_over_replace_strict_22870() -> None:
    """``replace_strict`` gives the same values with and without ``over("cat")``.

    The check runs twice. The first pass takes the old/new mapping from
    Series of a separate ``lookup`` frame. The second pass takes it from the
    ``a``/``b`` columns of the frame being queried.
    """
    lookup = pl.DataFrame({"cat": ["a", "b", "c"], "val": [102, 100, 101]})
    df = pl.DataFrame(
        {
            "cat": ["a", "b", "a", "a", "b"],
            "data": [2, 3, 4, 5, 6],
            "a": ["a", "b", "c", "d", "e"],
            "b": [102, 100, 101, 109, 110],
        }
    )

    # (old, new) pairs: literal Series first, then in-frame column expressions.
    mappings = [
        (lookup["cat"], lookup["val"]),
        (pl.col.a, pl.col.b),
    ]
    for old, new in mappings:
        replaced = pl.col("cat").replace_strict(old, new, default=-1)
        out = (
            df.lazy()
            .select(
                replaced.alias("val"),
                replaced.over("cat").alias("val_over"),
            )
            .collect()
        )
        assert_series_equal(
            out.get_column("val"), out.get_column("val_over"), check_names=False
        )
102
103