Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/test_replace.py
6939 views
1
from __future__ import annotations
2
3
from typing import Any
4
5
import pytest
6
7
import polars as pl
8
from polars.exceptions import InvalidOperationError
9
from polars.testing import assert_frame_equal, assert_series_equal
10
11
12
@pytest.fixture(scope="module")
def str_mapping() -> dict[str | None, str]:
    """Provide a country-code -> country-name mapping that includes a null key."""
    pairs: list[tuple[str | None, str]] = [
        ("CA", "Canada"),
        ("DE", "Germany"),
        ("FR", "France"),
        (None, "Not specified"),
    ]
    return dict(pairs)
20
21
22
def test_replace_str_to_str(str_mapping: dict[str | None, str]) -> None:
    """Mapped codes and nulls are replaced; the unmapped "ES" is kept as-is."""
    codes = pl.DataFrame({"country_code": ["FR", None, "ES", "DE"]})
    replace_expr = pl.col("country_code").replace(str_mapping)

    actual = codes.select(replaced=replace_expr)

    assert_frame_equal(
        actual,
        pl.DataFrame({"replaced": ["France", "Not specified", "ES", "Germany"]}),
    )
27
28
29
def test_replace_enum() -> None:
    """Replacing Enum values with an Enum-typed `new` yields the same Enum dtype."""
    letters = pl.Enum(["a", "b", "c", "d"])
    series = pl.Series(["a", "b", "c"], dtype=letters)
    targets = ["a", "b"]
    replacements = pl.Series(["c", "d"], dtype=letters)

    actual = series.replace(targets, replacements)

    assert_series_equal(actual, pl.Series(["c", "d", "c"], dtype=letters))
39
40
41
def test_replace_enum_to_str() -> None:
    """A plain string mapping applied to an Enum series keeps the Enum dtype."""
    letters = pl.Enum(["a", "b", "c", "d"])
    series = pl.Series(["a", "b", "c"], dtype=letters)
    letter_map = {"a": "c", "b": "d"}

    actual = series.replace(letter_map)

    wanted = pl.Series(["c", "d", "c"], dtype=letters)
    assert_series_equal(actual, wanted)
49
50
51
def test_replace_cat_to_cat(str_mapping: dict[str | None, str]) -> None:
    """Categorical `old`/`new` series replace values in a Categorical column."""
    # NOTE: `str_mapping` is requested as a fixture but not used in the body.
    cat = pl.Categorical
    query = pl.LazyFrame(
        {"country_code": ["FR", None, "ES", "DE"]},
        schema={"country_code": cat},
    )
    codes = pl.Series(["CA", "DE", "FR", None], dtype=cat)
    names = pl.Series(["Canada", "Germany", "France", "Not specified"], dtype=cat)

    actual = query.select(replaced=pl.col("country_code").replace(codes, names))

    wanted = pl.LazyFrame(
        {"replaced": ["France", "Not specified", "ES", "Germany"]},
        schema_overrides={"replaced": cat},
    )
    assert_frame_equal(actual, wanted)
68
69
70
def test_replace_invalid_old_dtype() -> None:
    """String keys applied to an integer column raise a str -> i64 cast error."""
    query = pl.LazyFrame({"a": [1, 2, 3]})
    str_keyed = {"a": 10, "b": 20}
    expected_message = "conversion from `str` to `i64` failed"

    with pytest.raises(InvalidOperationError, match=expected_message):
        query.select(pl.col("a").replace(str_keyed)).collect()
77
78
79
def test_replace_int_to_int() -> None:
    """An int mapping replaces matches in an Int16 column; nulls stay null."""
    source = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16})
    replace_expr = pl.col("int").replace({1: 5, 3: 7})

    actual = source.select(replaced=replace_expr)

    wanted = pl.DataFrame(
        {"replaced": [None, 5, None, 7]},
        schema={"replaced": pl.Int16},
    )
    assert_frame_equal(actual, wanted)
87
88
89
def test_replace_int_to_int_keep_dtype() -> None:
    """With an Int16 `new` series, the replaced Int16 column stays Int16."""
    int16 = pl.Int16
    source = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": int16})
    targets = [1, 3]
    replacements = pl.Series([5, 7], dtype=int16)

    actual = source.select(replaced=pl.col("int").replace(targets, replacements))

    wanted = pl.DataFrame({"replaced": [None, 5, None, 7]}, schema={"replaced": int16})
    assert_frame_equal(actual, wanted)
99
100
101
def test_replace_int_to_str() -> None:
    """Mapping Int16 values to strings raises a str -> i16 conversion error."""
    frame = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16})
    int_to_letter = {1: "b", 3: "d"}
    expected_message = "conversion from `str` to `i16` failed"

    with pytest.raises(InvalidOperationError, match=expected_message):
        frame.select(replaced=pl.col("int").replace(int_to_letter))
108
109
110
def test_replace_int_to_str_with_null() -> None:
    """Same str -> i16 conversion error when the mapping also has a null key."""
    frame = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16})
    int_to_letter = {1: "b", 3: "d", None: "e"}
    expected_message = "conversion from `str` to `i16` failed"

    with pytest.raises(InvalidOperationError, match=expected_message):
        frame.select(replaced=pl.col("int").replace(int_to_letter))
117
118
119
def test_replace_empty_mapping() -> None:
    """Replacing with an empty mapping returns a frame equal to the input."""
    source = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16})
    nothing_to_replace: dict[Any, Any] = {}

    actual = source.select(pl.col("int").replace(nothing_to_replace))

    assert_frame_equal(actual, source)
124
125
126
def test_replace_mapping_different_dtype_str_int() -> None:
    """Integer keys 1 and 3 replace the string values "1" and "3"; nulls stay."""
    source = pl.DataFrame({"int": [None, "1", None, "3"]})
    int_keyed = {1: "b", 3: "d"}

    actual = source.select(pl.col("int").replace(int_keyed))

    assert_frame_equal(actual, pl.DataFrame({"int": [None, "b", None, "d"]}))
133
134
135
def test_replace_mapping_different_dtype_map_none() -> None:
    """Integer keys plus a null key replace both the strings and the nulls."""
    source = pl.DataFrame({"int": [None, "1", None, "3"]})
    int_and_null_keyed = {1: "b", 3: "d", None: "e"}

    actual = source.select(pl.col("int").replace(int_and_null_keyed))

    assert_frame_equal(actual, pl.DataFrame({"int": ["e", "b", "e", "d"]}))
141
142
143
def test_replace_mapping_different_dtype_str_float() -> None:
    """Float keys 1.0 and 3.0 leave the string column equal to the input."""
    source = pl.DataFrame({"int": [None, "1", None, "3"]})
    float_keyed = {1.0: "b", 3.0: "d"}

    actual = source.select(pl.col("int").replace(float_keyed))

    # The result is compared against the untouched input frame.
    assert_frame_equal(actual, source)
149
150
151
# https://github.com/pola-rs/polars/issues/7132
152
def test_replace_str_to_str_replace_all() -> None:
    """`replace` output can be chained into `str.replace_all`."""
    source = pl.DataFrame({"text": ["abc"]})
    replaced = pl.col("text").replace({"abc": "123"})
    chained = replaced.str.replace_all("1", "-")

    actual = source.select(chained)

    assert_frame_equal(actual, pl.DataFrame({"text": ["-23"]}))
158
159
160
@pytest.fixture(scope="module")
def int_mapping() -> dict[int, int]:
    """Provide a mapping of each integer 1..5 to eleven times itself."""
    return {key: key * 11 for key in range(1, 6)}
163
164
165
def test_replace_int_to_int1(int_mapping: dict[int, int]) -> None:
    """A series is compared against identical values after `replace`."""
    values = [-1, 22, None, 44, -5]
    series = pl.Series(values)

    actual = series.replace(int_mapping)

    assert_series_equal(actual, pl.Series(values))
170
171
172
def test_replace_int_to_int4(int_mapping: dict[int, int]) -> None:
    """Only values that are keys of the mapping are replaced.

    The input mixes mapping keys (1 and 4) with values that are not keys
    (22, -5) and a null, so both the replacement path and the pass-through
    path are exercised. This complements `test_replace_int_to_int1`, whose
    input contains no mapping keys.
    """
    s = pl.Series([1, 22, None, 4, -5])
    result = s.replace(int_mapping)
    # 1 -> 11 and 4 -> 44; 22, null and -5 are not keys and are left as-is.
    expected = pl.Series([11, 22, None, 44, -5])
    assert_series_equal(result, expected)
177
178
179
# https://github.com/pola-rs/polars/issues/12728
180
def test_replace_str_to_int2() -> None:
    """Integer replacement values in a string series come back as strings."""
    series = pl.Series(["a", "b"])
    letter_to_number = {"a": 1, "b": 2}

    actual = series.replace(letter_to_number)

    assert_series_equal(actual, pl.Series(["1", "2"]))
186
187
188
def test_replace_str_to_bool_without_default() -> None:
    """Boolean replacement values come back as "true"/"false" strings."""
    series = pl.Series(["True", "False", "False", None])
    text_to_bool = {"True": True, "False": False}

    actual = series.replace(text_to_bool)

    wanted = pl.Series(["true", "false", "false", None])
    assert_series_equal(actual, wanted)
194
195
196
def test_replace_old_new() -> None:
    """Scalar `old`/`new` arguments replace every occurrence of `old`."""
    s = pl.Series([1, 2, 2, 3])
    result = s.replace(2, 9)
    # Bind only `expected`; the original chained assignment also rebound `s`.
    expected = pl.Series([1, 9, 9, 3])
    assert_series_equal(result, expected)
201
202
203
def test_replace_old_new_many_to_one() -> None:
    """A list of `old` values can all be replaced by a single scalar `new`."""
    s = pl.Series([1, 2, 2, 3])
    result = s.replace([2, 3], 9)
    # Bind only `expected`; the original chained assignment also rebound `s`.
    expected = pl.Series([1, 9, 9, 9])
    assert_series_equal(result, expected)
208
209
210
def test_replace_old_new_mismatched_lengths() -> None:
    """Three `old` values paired with two `new` values raise an error."""
    series = pl.Series([1, 2, 2, 3, 4])
    three_old = [2, 3, 4]
    two_new = [8, 9]

    with pytest.raises(InvalidOperationError):
        series.replace(three_old, two_new)
214
215
216
def test_replace_fast_path_one_to_one() -> None:
    """A single scalar is replaced by a single scalar in a lazy query."""
    query = pl.LazyFrame({"a": [1, 2, 2, 3]})

    actual = query.select(pl.col("a").replace(2, 100))

    assert_frame_equal(actual, pl.LazyFrame({"a": [1, 100, 100, 3]}))
221
222
223
def test_replace_fast_path_one_null_to_one() -> None:
    """A null `old` value is replaced by a scalar in a lazy query."""
    # https://github.com/pola-rs/polars/issues/13391
    query = pl.LazyFrame({"a": [1, None]})

    actual = query.select(pl.col("a").replace(None, 100))

    assert_frame_equal(actual, pl.LazyFrame({"a": [1, 100]}))
229
230
231
def test_replace_fast_path_many_with_null_to_one() -> None:
    """An `old` list containing a null maps every entry to one scalar."""
    query = pl.LazyFrame({"a": [1, 2, None]})
    targets = [1, None]

    actual = query.select(pl.col("a").replace(targets, 100))

    assert_frame_equal(actual, pl.LazyFrame({"a": [100, 2, 100]}))
236
237
238
def test_replace_fast_path_many_to_one() -> None:
    """Several `old` values are all replaced by one scalar in a lazy query."""
    query = pl.LazyFrame({"a": [1, 2, 2, 3]})
    targets = [2, 3]

    actual = query.select(pl.col("a").replace(targets, 100))

    assert_frame_equal(actual, pl.LazyFrame({"a": [1, 100, 100, 100]}))
243
244
245
@pytest.mark.parametrize(
    ("old", "new"),
    [
        ([2, 2], 100),
        ([2, 2], [100, 200]),
        ([2, 2], [100, 100]),
    ],
)
def test_replace_duplicates_old(old: list[int], new: int | list[int]) -> None:
    """Duplicate entries in `old` are rejected for each parametrized `new`."""
    series = pl.Series([1, 2, 3, 2, 3])
    duplicate_error = "`old` input for `replace` must not contain duplicates"

    with pytest.raises(InvalidOperationError, match=duplicate_error):
        series.replace(old, new)
260
261
262
def test_replace_duplicates_new() -> None:
    """Duplicate values in `new` are allowed: 1 and 2 both become 100."""
    s = pl.Series([1, 2, 3, 2, 3])
    result = s.replace([1, 2], [100, 100])
    # Bind only `expected`; the original chained assignment also rebound `s`.
    expected = pl.Series([100, 100, 3, 100, 3])
    assert_series_equal(result, expected)
267
268
269
def test_replace_return_dtype_deprecated() -> None:
    """Passing `return_dtype` emits a deprecation warning and yields Int8."""
    series = pl.Series([1, 2, 3])

    with pytest.deprecated_call():
        actual = series.replace(1, 10, return_dtype=pl.Int8)

    assert_series_equal(actual, pl.Series([10, 2, 3], dtype=pl.Int8))
275
276
277
def test_replace_default_deprecated() -> None:
    """Passing `default` emits a deprecation warning; unmatched values are null."""
    series = pl.Series([1, 2, 3])

    with pytest.deprecated_call():
        actual = series.replace(1, 10, default=None)

    wanted = pl.Series([10, None, None], dtype=pl.Int32)
    assert_series_equal(actual, wanted)
283
284
285
def test_replace_single_argument_not_mapping() -> None:
    """A lone non-mapping `old` argument without `new` raises TypeError."""
    frame = pl.DataFrame({"a": ["a", "b", "c"]})
    missing_new = "`new` argument is required if `old` argument is not a Mapping type"

    with pytest.raises(TypeError, match=missing_new):
        frame.select(pl.col("a").replace("b"))
292
293