Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/test_arity.py
6939 views
1
from __future__ import annotations
2
3
import pytest
4
5
import polars as pl
6
from polars.testing import assert_frame_equal
7
8
9
def test_expression_literal_series_order() -> None:
    """Check Series + column addition with the Series on either side."""
    values = pl.Series([1, 2, 3])
    frame = pl.DataFrame({"a": [1, 2, 3]})

    # Column expression on the left: the expected output column is "a".
    assert_frame_equal(
        frame.select(pl.col("a") + values),
        pl.DataFrame({"a": [2, 4, 6]}),
    )

    # Literal Series on the left: the expected output column is named "".
    assert_frame_equal(
        frame.select(pl.lit(values) + pl.col("a")),
        pl.DataFrame({"": [2, 4, 6]}),
    )
20
21
22
def test_when_then_broadcast_nulls_12665() -> None:
    """Check `when/then/otherwise` when the predicate compares against nulls."""
    frame = pl.DataFrame(
        {
            "val": [1, 2, 3, 4],
            "threshold": [4, None, None, 1],
        }
    )

    exceeds_threshold = pl.col("val") > pl.col("threshold")
    flagged = frame.select(when=pl.when(exceeds_threshold).then(1).otherwise(0))

    # Rows with a null threshold are expected to take the `otherwise` value.
    assert flagged.to_dict(as_series=False) == {"when": [0, 0, 0, 1]}
33
34
35
@pytest.mark.parametrize(
    ("needs_broadcast", "expect_contains"),
    [
        (pl.lit("a"), [True, False, False]),
        (pl.col("name").head(1), [True, False, False]),
        (pl.lit(None, dtype=pl.String), [None, None, None]),
        (pl.col("null_utf8").head(1), [None, None, None]),
    ],
)
@pytest.mark.parametrize("literal", [True, False])
@pytest.mark.parametrize(
    "df",
    [
        pl.DataFrame(
            {
                "name": ["a", "b", "c"],
                "null_utf8": pl.Series([None, None, None], dtype=pl.String),
            }
        )
    ],
)
def test_broadcast_string_ops_12632(
    df: pl.DataFrame,
    needs_broadcast: pl.Expr,
    expect_contains: list[bool],
    literal: bool,
) -> None:
    """
    Check string ops whose left-hand side is a literal or a one-row expression.

    `needs_broadcast` is matched against the full three-row "name" column, so
    every result is expected to have three rows.
    """
    pattern = pl.col("name")

    # Boolean-returning operations must reproduce the expected per-row values.
    contains = df.select(needs_broadcast.str.contains(pattern, literal=literal))
    assert contains.to_series().to_list() == expect_contains

    starts_with = df.select(needs_broadcast.str.starts_with(pattern))
    assert starts_with.to_series().to_list() == expect_contains

    ends_with = df.select(needs_broadcast.str.ends_with(pattern))
    assert ends_with.to_series().to_list() == expect_contains

    # For the strip variants only the output height is checked.
    for method in ("strip_chars", "strip_chars_start", "strip_chars_end"):
        stripped = df.select(getattr(needs_broadcast.str, method)(pattern))
        assert stripped.height == 3
82
83
84
def test_negate_inlined_14278() -> None:
    """
    Check `tail(2)` inside a lazy group-by aggregation.

    The number in the test name is presumably the upstream issue id.
    """
    frame = pl.DataFrame(
        {"group": ["A", "A", "B", "B", "B", "C", "C"], "value": [1, 2, 3, 4, 5, 6, 7]}
    )

    # Last two rows of each group as structs, and the count of the last two
    # sorted values of each group.
    last_two_rows = pl.struct("group", "value").tail(2).alias("list")
    last_two_count = pl.col("value").sort().tail(2).count().alias("count")

    query = frame.lazy().group_by("group").agg([last_two_rows, last_two_count])

    # Sort by key before comparing so the check has a fixed row order.
    out = query.collect().sort("group").to_dict(as_series=False)
    assert out == {
        "group": ["A", "B", "C"],
        "list": [
            [{"group": "A", "value": 1}, {"group": "A", "value": 2}],
            [{"group": "B", "value": 4}, {"group": "B", "value": 5}],
            [{"group": "C", "value": 6}, {"group": "C", "value": 7}],
        ],
        "count": [2, 2, 2],
    }
104
105
106
def test_nested_level_literals_17377() -> None:
    """Check the schema of a `when/then(None)/otherwise(mean)` aggregation."""
    lf = pl.LazyFrame({"group": [1, 2], "value": [1, 2]})

    all_negative = (pl.col("value") < 0).all()
    group_mean = pl.col("value").mean()
    res = pl.when(all_negative).then(None).otherwise(group_mean).alias("res")

    aggregated = lf.group_by("group").agg([res])

    # Only the schema is checked; the query is never collected.
    schema = aggregated.collect_schema()
    assert schema == pl.Schema({"group": pl.Int64(), "res": pl.Float64()})
119
120