Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/series/ops/replace.rs
6939 views
1
use polars_core::prelude::*;
2
use polars_core::utils::try_get_supertype;
3
use polars_error::polars_ensure;
4
5
use crate::frame::join::*;
6
use crate::prelude::*;
7
8
fn find_output_length(
9
fnname: &str,
10
items: impl IntoIterator<Item = (&'static str, usize)>,
11
) -> PolarsResult<usize> {
12
let mut length = 1;
13
for (argument_idx, (argument, l)) in items.into_iter().enumerate() {
14
if l != 1 {
15
if l != length && length != 1 {
16
polars_bail!(
17
length_mismatch = fnname,
18
l,
19
length,
20
argument = argument,
21
argument_idx = argument_idx
22
);
23
}
24
length = l;
25
}
26
}
27
Ok(length)
28
}
29
30
/// Replace values by different values of the same data type.
31
pub fn replace(s: &Series, old: &ListChunked, new: &ListChunked) -> PolarsResult<Series> {
32
find_output_length(
33
"replace",
34
[("self", s.len()), ("old", old.len()), ("new", new.len())],
35
)?;
36
37
polars_ensure!(
38
old.len() == 1 && new.len() == 1,
39
nyi = "`replace` with a replacement pattern per row"
40
);
41
42
let old = old.explode(true)?;
43
let new = new.explode(true)?;
44
45
if old.is_empty() {
46
return Ok(s.clone());
47
}
48
validate_old(&old)?;
49
50
let dtype = s.dtype();
51
let old = old.strict_cast(dtype)?;
52
let new = new.strict_cast(dtype)?;
53
54
if new.len() == 1 {
55
replace_by_single(s, &old, &new, s)
56
} else {
57
replace_by_multiple(s, old, new, s)
58
}
59
}
60
61
/// Replace all values by different values.
62
///
63
/// Unmatched values are replaced by a default value.
64
pub fn replace_or_default(
65
s: &Series,
66
old: &ListChunked,
67
new: &ListChunked,
68
default: &Series,
69
return_dtype: Option<DataType>,
70
) -> PolarsResult<Series> {
71
find_output_length(
72
"replace_strict",
73
[
74
("self", s.len()),
75
("old", old.len()),
76
("new", new.len()),
77
("default", default.len()),
78
],
79
)?;
80
81
polars_ensure!(
82
old.len() == 1 && new.len() == 1,
83
nyi = "`replace_strict` with a replacement pattern per row"
84
);
85
86
let old = old.explode(true)?;
87
let new = new.explode(true)?;
88
89
polars_ensure!(
90
default.len() == s.len() || default.len() == 1,
91
InvalidOperation: "`default` input for `replace_strict` must have the same length as the input or have length 1"
92
);
93
validate_old(&old)?;
94
95
let return_dtype = match return_dtype {
96
Some(dtype) => dtype,
97
None => try_get_supertype(new.dtype(), default.dtype())?,
98
};
99
let default = default.cast(&return_dtype)?;
100
101
if old.is_empty() {
102
let out = if default.len() == 1 && s.len() != 1 {
103
default.new_from_index(0, s.len())
104
} else {
105
default
106
};
107
return Ok(out);
108
}
109
110
let old = old.strict_cast(s.dtype())?;
111
let new = new.cast(&return_dtype)?;
112
113
if new.len() == 1 {
114
replace_by_single(s, &old, &new, &default)
115
} else {
116
replace_by_multiple(s, old, new, &default)
117
}
118
}
119
120
/// Replace all values by different values.
121
///
122
/// Raises an error if not all values were replaced.
123
pub fn replace_strict(
124
s: &Series,
125
old: &ListChunked,
126
new: &ListChunked,
127
return_dtype: Option<DataType>,
128
) -> PolarsResult<Series> {
129
find_output_length(
130
"replace_strict",
131
[("self", s.len()), ("old", old.len()), ("new", new.len())],
132
)?;
133
134
polars_ensure!(
135
old.len() == 1 && new.len() == 1,
136
nyi = "`replace_strict` with a replacement pattern per row"
137
);
138
139
let old = old.explode(true)?;
140
let new = new.explode(true)?;
141
142
if old.is_empty() {
143
polars_ensure!(
144
s.len() == s.null_count(),
145
InvalidOperation: "must specify which values to replace"
146
);
147
return Ok(s.clone());
148
}
149
validate_old(&old)?;
150
151
let old = old.strict_cast(s.dtype())?;
152
let new = match return_dtype {
153
Some(dtype) => new.strict_cast(&dtype)?,
154
None => new,
155
};
156
157
if new.len() == 1 {
158
replace_by_single_strict(s, &old, &new)
159
} else {
160
replace_by_multiple_strict(s, old, new)
161
}
162
}
163
164
/// Validate the `old` input.
165
fn validate_old(old: &Series) -> PolarsResult<()> {
166
polars_ensure!(
167
old.n_unique()? == old.len(),
168
InvalidOperation: "`old` input for `replace` must not contain duplicates"
169
);
170
Ok(())
171
}
172
173
// Fast path for replacing by a single value
174
fn replace_by_single(
175
s: &Series,
176
old: &Series,
177
new: &Series,
178
default: &Series,
179
) -> PolarsResult<Series> {
180
let mut mask = get_replacement_mask(s, old)?;
181
if old.null_count() > 0 {
182
mask = mask.fill_null_with_values(true)?;
183
}
184
new.zip_with(&mask, default)
185
}
186
/// Fast path for replacing by a single value in strict mode
187
fn replace_by_single_strict(s: &Series, old: &Series, new: &Series) -> PolarsResult<Series> {
188
let mask = get_replacement_mask(s, old)?;
189
ensure_all_replaced(&mask, s, old.null_count() > 0, true)?;
190
191
let mut out = new.new_from_index(0, s.len());
192
193
// Transfer validity from `mask` to `out`.
194
if mask.null_count() > 0 {
195
out = out.zip_with(&mask, &Series::new_null(PlSmallStr::EMPTY, s.len()))?
196
}
197
Ok(out)
198
}
199
/// Get a boolean mask of which values in the original Series will be replaced.
200
///
201
/// Null values are propagated to the mask.
202
fn get_replacement_mask(s: &Series, old: &Series) -> PolarsResult<BooleanChunked> {
203
if old.null_count() == old.len() {
204
// Fast path for when users are using `replace(None, ...)` instead of `fill_null`.
205
Ok(s.is_null())
206
} else {
207
let old = old.implode()?;
208
is_in(s, &old.into_series(), false)
209
}
210
}
211
212
/// General case for replacing by multiple values
213
fn replace_by_multiple(
214
s: &Series,
215
old: Series,
216
new: Series,
217
default: &Series,
218
) -> PolarsResult<Series> {
219
validate_new(&new, &old)?;
220
221
let df = s.clone().into_frame();
222
let add_replacer_mask = new.null_count() > 0;
223
let replacer = create_replacer(old, new, add_replacer_mask)?;
224
225
let joined = df.join(
226
&replacer,
227
[s.name().as_str()],
228
["__POLARS_REPLACE_OLD"],
229
JoinArgs {
230
how: JoinType::Left,
231
coalesce: JoinCoalesce::CoalesceColumns,
232
nulls_equal: true,
233
..Default::default()
234
},
235
None,
236
)?;
237
238
let replaced = joined
239
.column("__POLARS_REPLACE_NEW")
240
.unwrap()
241
.as_materialized_series();
242
243
if replaced.null_count() == 0 {
244
return Ok(replaced.clone());
245
}
246
247
match joined.column("__POLARS_REPLACE_MASK") {
248
Ok(col) => {
249
let mask = col.bool().unwrap();
250
replaced.zip_with(mask, default)
251
},
252
Err(_) => {
253
let mask = &replaced.is_not_null();
254
replaced.zip_with(mask, default)
255
},
256
}
257
}
258
259
/// General case for replacing by multiple values in strict mode
260
fn replace_by_multiple_strict(s: &Series, old: Series, new: Series) -> PolarsResult<Series> {
261
validate_new(&new, &old)?;
262
263
let df = s.clone().into_frame();
264
let old_has_null = old.null_count() > 0;
265
let replacer = create_replacer(old, new, true)?;
266
267
let joined = df.join(
268
&replacer,
269
[s.name().as_str()],
270
["__POLARS_REPLACE_OLD"],
271
JoinArgs {
272
how: JoinType::Left,
273
coalesce: JoinCoalesce::CoalesceColumns,
274
nulls_equal: true,
275
..Default::default()
276
},
277
None,
278
)?;
279
280
let replaced = joined.column("__POLARS_REPLACE_NEW").unwrap();
281
282
let mask = joined
283
.column("__POLARS_REPLACE_MASK")
284
.unwrap()
285
.bool()
286
.unwrap();
287
ensure_all_replaced(mask, s, old_has_null, false)?;
288
289
Ok(replaced.as_materialized_series().clone())
290
}
291
292
// Build replacer dataframe.
293
fn create_replacer(mut old: Series, mut new: Series, add_mask: bool) -> PolarsResult<DataFrame> {
294
old.rename(PlSmallStr::from_static("__POLARS_REPLACE_OLD"));
295
new.rename(PlSmallStr::from_static("__POLARS_REPLACE_NEW"));
296
297
let len = old.len();
298
let cols = if add_mask {
299
let mask = Column::new_scalar(
300
PlSmallStr::from_static("__POLARS_REPLACE_MASK"),
301
true.into(),
302
new.len(),
303
);
304
vec![old.into(), new.into(), mask]
305
} else {
306
vec![old.into(), new.into()]
307
};
308
let out = unsafe { DataFrame::new_no_checks(len, cols) };
309
Ok(out)
310
}
311
312
/// Validate the `new` input.
313
fn validate_new(new: &Series, old: &Series) -> PolarsResult<()> {
314
polars_ensure!(
315
new.len() == old.len(),
316
InvalidOperation: "`new` input for `replace` must have the same length as `old` or have length 1"
317
);
318
Ok(())
319
}
320
321
/// Ensure that all values were replaced.
322
fn ensure_all_replaced(
323
mask: &BooleanChunked,
324
s: &Series,
325
old_has_null: bool,
326
check_all: bool,
327
) -> PolarsResult<()> {
328
let nulls_check = if old_has_null {
329
mask.null_count() == 0
330
} else {
331
mask.null_count() == s.null_count()
332
};
333
// Checking booleans is only relevant for the 'replace_by_single' path.
334
let bools_check = !check_all || mask.all();
335
336
let all_replaced = bools_check && nulls_check;
337
polars_ensure!(
338
all_replaced,
339
InvalidOperation: "incomplete mapping specified for `replace_strict`\n\nHint: Pass a `default` value to set unmapped values."
340
);
341
Ok(())
342
}
343
344