Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/test_fill_null.py
6939 views
1
import datetime
2
3
import pytest
4
5
import polars as pl
6
from polars.testing import assert_frame_equal, assert_series_equal
7
8
9
def test_fill_null_minimal_upcast_4056() -> None:
    """Check that filling an Int8 column only widens the dtype when required.

    A fill value of -1 must leave the column as Int8, while -1000 must
    produce Int16.
    """
    frame = pl.DataFrame({"a": [-1, 2, None]}).with_columns(
        pl.col("a").cast(pl.Int8)
    )

    small_fill = frame.with_columns(pl.col(pl.Int8).fill_null(-1))
    assert small_fill.dtypes[0] == pl.Int8

    large_fill = frame.with_columns(pl.col(pl.Int8).fill_null(-1000))
    assert large_fill.dtypes[0] == pl.Int16
14
15
16
def test_fill_enum_upcast() -> None:
    """Check that filling an Enum series with one of its categories keeps the dtype."""
    enum_dtype = pl.Enum(["a", "b"])
    with_null = pl.Series(["a", "b", None], dtype=enum_dtype)

    filled = with_null.fill_null("b")

    assert filled.dtype == enum_dtype
    assert_series_equal(filled, pl.Series(["a", "b", "b"], dtype=enum_dtype))
23
24
25
def test_fill_null_static_schema_4843() -> None:
    """Check the lazy schema after a dtype-selected ``fill_null``.

    Both columns are filled through an Int64 selector and then selected by
    Int64 again; the collected schema must still list both as Int64.
    """
    data = {
        "a": [1, 2, None],
        "b": [1, None, 4],
    }
    int64_columns = pl.col(pl.Int64)

    filled = pl.DataFrame(data).lazy().select([int64_columns.fill_null(0)])
    reselected = filled.select(int64_columns)

    assert reselected.collect_schema() == {"a": pl.Int64, "b": pl.Int64}
36
37
38
def test_fill_null_non_lit() -> None:
    """Check that a plain Python ``0`` fills the nulls of every column.

    The frame mixes Int32, UInt32, Int64 and Decimal columns; after the fill
    the summed null count over all columns must be zero.
    """
    columns = {
        "a": pl.Series([1, None], dtype=pl.Int32),
        "b": pl.Series([None, 2], dtype=pl.UInt32),
        "c": pl.Series([None, 2], dtype=pl.Int64),
        "d": pl.Series([None, 2], dtype=pl.Decimal),
    }
    frame = pl.DataFrame(columns)

    null_counts = frame.fill_null(0).select(pl.all().null_count())
    total_nulls = null_counts.transpose().sum().item()

    assert total_nulls == 0
48
49
50
def test_fill_null_f32_with_lit() -> None:
    """Check that an integer fill value leaves a Float32 column as Float32."""
    frame = pl.DataFrame({"a": [1.1, 1.2]}, schema=[("a", pl.Float32)])

    # ensure the literal integer does not upcast the f32 to an f64
    filled = frame.fill_null(value=0)

    assert filled.dtypes == [pl.Float32]
54
55
56
def test_fill_null_lit_() -> None:
    """Check that ``pl.lit(0)`` fills the nulls of every column.

    The frame mixes Int32, UInt32 and Int64 columns; after the fill the
    summed null count over all columns must be zero.
    """
    columns = {
        "a": pl.Series([1, None], dtype=pl.Int32),
        "b": pl.Series([None, 2], dtype=pl.UInt32),
        "c": pl.Series([None, 2], dtype=pl.Int64),
    }
    frame = pl.DataFrame(columns)

    null_counts = frame.fill_null(pl.lit(0)).select(pl.all().null_count())
    total_nulls = null_counts.transpose().sum().item()

    assert total_nulls == 0
68
69
70
def test_fill_null_decimal_with_int_14331() -> None:
    """Check that an integer fill value works on a Decimal series.

    The null is expected to become ``"0.0"`` in a series of the same
    Decimal dtype (scale 5).
    """
    decimal_dtype = pl.Decimal(precision=None, scale=5)
    with_null = pl.Series("a", ["1.1", None], dtype=decimal_dtype)

    filled = with_null.fill_null(0)

    assert_series_equal(
        filled,
        pl.Series("a", ["1.1", "0.0"], dtype=decimal_dtype),
    )
75
76
77
def test_fill_null_date_with_int_11362() -> None:
    """Check that filling a Date series with an integer is rejected.

    Both a series holding an actual date and an all-null Date series must
    raise ``InvalidOperationError`` with the dtype-mismatch message.
    """
    expected_message = "got invalid or ambiguous dtypes"
    cases = [
        (pl.Series([datetime.date(2000, 1, 1)]), 0),
        (pl.Series([None], dtype=pl.Date), 1),
    ]

    for date_series, fill_value in cases:
        with pytest.raises(
            pl.exceptions.InvalidOperationError, match=expected_message
        ):
            date_series.fill_null(fill_value)
87
88
89
def test_fill_null_int_dtype_15546() -> None:
    """Check that a lazy ``fill_null(0)`` on an Int8 column yields an Int8 result."""
    lazy_frame = pl.Series("a", [1, 2, None], dtype=pl.Int8).to_frame().lazy()

    collected = lazy_frame.fill_null(0).collect()

    assert_frame_equal(
        collected,
        pl.Series("a", [1, 2, 0], dtype=pl.Int8).to_frame(),
    )
94
95
96
def test_fill_null_with_list_10869() -> None:
    """Check ``fill_null`` with a list value.

    A list fill replaces the null in a list series, while the same fill on
    an integer series must raise ``SchemaError``.
    """
    supertype_error = "failed to determine supertype"

    filled_lists = pl.Series([[1], None]).fill_null([2])
    assert_series_equal(filled_lists, pl.Series([[1], [2]]))

    with pytest.raises(pl.exceptions.SchemaError, match=supertype_error):
        pl.Series([1, None]).fill_null([2])
105
106
107
def test_unequal_lengths_22018() -> None:
    """Check that a length-3 fill series on a length-2 series raises ``ShapeError``.

    The error is expected whether or not the target series contains a null.
    """
    for target_values in ([1, None], [1, 2]):
        with pytest.raises(pl.exceptions.ShapeError):
            pl.Series(target_values).fill_null(pl.Series([1] * 3))
112
113
114
def test_self_broadcast() -> None:
    """Check that a length-1 series is broadcast to the length of the fill series.

    A non-null single value is repeated three times; a single null takes
    the three values of the fill series.
    """
    fill_source = range(3)

    from_value = pl.Series([1]).fill_null(pl.Series(fill_source))
    assert_series_equal(from_value, pl.Series([1] * 3))

    from_null = pl.Series([None]).fill_null(pl.Series(fill_source))
    assert_series_equal(from_null, pl.Series(fill_source))
124
125