Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/series/test_scatter.py
8424 views
1
from datetime import date, datetime
2
from typing import Any
3
4
import numpy as np
5
import pytest
6
7
import polars as pl
8
from polars._typing import PolarsDataType
9
from polars.exceptions import ComputeError, InvalidOperationError, OutOfBoundsError
10
from polars.testing import assert_series_equal
11
12
13
@pytest.mark.parametrize(
    "input",
    [
        (),
        [],
        pl.Series(),
        pl.Series(dtype=pl.Int8),
        np.array([]),
    ],
)
def test_scatter_noop(input: Any) -> None:
    """Scattering with an empty collection of indices leaves the values as-is."""
    untouched = [1, 2, 3]
    series = pl.Series("s", untouched)

    # The return value is ignored on purpose: the check is on ``series`` itself.
    series.scatter(input, 8)

    assert series.to_list() == untouched
27
28
29
def test_scatter() -> None:
    """Cover scatter and item assignment over several dtypes and index forms."""
    ints = pl.Series("s", [1, 2, 3])

    # One position per call: a bare integer index, then a one-element list.
    ints.scatter(0, 8)
    ints.scatter([1], None)
    assert ints.to_list() == [8, None, 3]

    # Several positions overwritten by a single call.
    ints.scatter([0, 2], None)
    assert ints.to_list() == [None, None, None]

    # Same exercise on a string series, using a tuple of indices.
    strings = pl.Series("s", ["a", "b", "c"])
    strings.scatter((1, 2), "x")
    assert strings.to_list() == ["a", "x", "x"]
    # A float value scattered into a string series shows up as its string form.
    after_float = strings.scatter([0, 2], 0.12345)
    assert after_float.to_list() == ["0.12345", "x", "0.12345"]

    # A list of values, one per index.
    letters = pl.Series(["z", "z", "z"])
    scattered_letters = letters.scatter([0, 1], ["a", "b"])
    assert scattered_letters.to_list() == ["a", "b", "z"]
    flags = pl.Series([True, False, True])
    scattered_flags = flags.scatter([0, 1], [False, True])
    assert scattered_flags.to_list() == [False, True, True]

    # Negative indices through item assignment.
    ramp = pl.Series("r", range(5))
    ramp[-2] = None
    ramp[-5] = None
    assert ramp.to_list() == [None, 1, 2, None, 4]

    # An out-of-range negative index raises and leaves the series unchanged.
    pair = pl.Series("x", [1, 2])
    with pytest.raises(OutOfBoundsError):
        pair[-100] = None
    assert_series_equal(pair, pl.Series("x", [1, 2]))
63
64
65
def test_index_with_None_errors_16905() -> None:
    """Item assignment with a null among the indices raises and keeps the data."""
    series = pl.Series("s", [1, 2, 3])
    indices_with_null = [1, None]

    with pytest.raises(ComputeError, match="index values should not be null"):
        series[indices_with_null] = 5

    # A failed assignment used to corrupt the series; it must stay intact now.
    unchanged = pl.Series("s", [1, 2, 3])
    assert_series_equal(series, unchanged)
71
72
73
def test_object_dtype_16905() -> None:
    """Item assignment on an Object series raises and leaves the series intact."""
    sentinel = object()
    contents = [sentinel, 27]
    series = pl.Series("s", contents, dtype=pl.Object)

    # Nothing is semantically wrong with this assignment and it may be supported
    # one day; today it is rejected.
    with pytest.raises(InvalidOperationError):
        series[0] = 5

    # A failed assignment used to corrupt the series; dtype, name and values
    # must all survive.
    assert series.dtype.is_object()
    assert series.name == "s"
    assert series.to_list() == [sentinel, 27]
84
85
86
def test_scatter_datetime() -> None:
    """A datetime scattered over a null slot of a datetime series is stored."""
    existing = datetime(2024, 1, 31)
    replacement = datetime(2022, 2, 2)

    series = pl.Series("dt", [None, existing])
    result = series.scatter(0, replacement)

    assert_series_equal(result, pl.Series("dt", [replacement, existing]))
91
92
93
def test_scatter_logical_all_null() -> None:
    """A date can be scattered into a Date series that holds only nulls."""
    new_value = date(2022, 2, 2)

    all_null = pl.Series("dt", [None, None], dtype=pl.Date)
    result = all_null.scatter(0, new_value)

    assert_series_equal(result, pl.Series("dt", [new_value, None]))
98
99
100
def test_scatter_categorical_21175() -> None:
    """Scatter on a Categorical series takes strings and categoricals, not ints."""
    cats = pl.Series(["a", "b", "c"], dtype=pl.Categorical)

    # Plain string value at a single index.
    after_string = cats.scatter(0, "b")
    assert_series_equal(
        after_string, pl.Series(["b", "b", "c"], dtype=pl.Categorical)
    )

    # One-element categorical series as the value for two indices. The expected
    # result builds on the previous scatter, so the call order matters.
    replacement = pl.Series(["v"], dtype=pl.Categorical)
    after_series = cats.scatter([0, 2], replacement)
    assert_series_equal(
        after_series, pl.Series(["v", "b", "v"], dtype=pl.Categorical)
    )

    # An integer value is rejected.
    with pytest.raises(InvalidOperationError):
        cats.scatter(1, 2)
112
113
114
def test_scatter_enum() -> None:
    """Scatter on an Enum series accepts known categories and rejects the rest."""
    enum_dtype = pl.Enum(["a", "b", "c", "v"])
    series = pl.Series(["a", "b", "c"], dtype=enum_dtype)

    # Plain string that is one of the enum's categories.
    after_string = series.scatter(0, "b")
    assert_series_equal(after_string, pl.Series(["b", "b", "c"], dtype=enum_dtype))

    # Categorical value series. The expected result builds on the previous
    # scatter, so the call order matters.
    replacement = pl.Series(["v"], dtype=pl.Categorical)
    after_series = series.scatter([0, 2], replacement)
    assert_series_equal(after_series, pl.Series(["v", "b", "v"], dtype=enum_dtype))

    # "d" is not among the enum's categories.
    with pytest.raises(InvalidOperationError):
        series.scatter(1, "d")

    # An integer value is rejected as well.
    with pytest.raises(InvalidOperationError):
        series.scatter(1, 2)
126
127
128
@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum(["a", "b"])])
def test_scatter_null(dtype: PolarsDataType) -> None:
    """Scattering ``None`` into a Categorical or Enum series produces a null entry."""
    series = pl.Series("a", ["a", "b"], dtype=dtype)

    with_null = series.scatter(0, None)

    assert_series_equal(with_null, pl.Series("a", [None, "b"], dtype=dtype))
134
135
136
def test_scatter_decimal_25869() -> None:
    """Scatter on a Decimal series accepts an integer and rejects a string."""
    decimal_dtype = pl.Decimal(scale=10)
    series = pl.Series([1, 2, 3], dtype=decimal_dtype)

    # An integer value lands in the decimal series with the dtype preserved.
    after_int = series.scatter(0, 15)
    assert_series_equal(after_int, pl.Series([15, 2, 3], dtype=pl.Decimal(scale=10)))

    # A string value is rejected.
    with pytest.raises(InvalidOperationError):
        series.scatter(1, "test")
144
145