Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/series/ops/replace.rs
8421 views
1
use polars_core::prelude::*;
2
use polars_core::utils::try_get_supertype;
3
use polars_error::polars_ensure;
4
5
use crate::frame::join::*;
6
use crate::prelude::*;
7
8
fn find_output_length(
9
fnname: &str,
10
items: impl IntoIterator<Item = (&'static str, usize)>,
11
) -> PolarsResult<usize> {
12
let mut length = 1;
13
for (argument_idx, (argument, l)) in items.into_iter().enumerate() {
14
if l != 1 {
15
if l != length && length != 1 {
16
polars_bail!(
17
length_mismatch = fnname,
18
l,
19
length,
20
argument = argument,
21
argument_idx = argument_idx
22
);
23
}
24
length = l;
25
}
26
}
27
Ok(length)
28
}
29
30
/// Replace values by different values of the same data type.
31
pub fn replace(s: &Series, old: &ListChunked, new: &ListChunked) -> PolarsResult<Series> {
32
find_output_length(
33
"replace",
34
[("self", s.len()), ("old", old.len()), ("new", new.len())],
35
)?;
36
37
polars_ensure!(
38
old.len() == 1 && new.len() == 1,
39
nyi = "`replace` with a replacement pattern per row"
40
);
41
42
let old = old.explode(ExplodeOptions {
43
empty_as_null: false,
44
keep_nulls: true,
45
})?;
46
let new = new.explode(ExplodeOptions {
47
empty_as_null: false,
48
keep_nulls: true,
49
})?;
50
51
if old.is_empty() {
52
return Ok(s.clone());
53
}
54
validate_old(&old)?;
55
56
let dtype = s.dtype();
57
let old = old.strict_cast(dtype)?;
58
let new = new.strict_cast(dtype)?;
59
60
if new.len() == 1 {
61
replace_by_single(s, &old, &new, s)
62
} else {
63
replace_by_multiple(s, old, new, s)
64
}
65
}
66
67
/// Replace all values by different values.
68
///
69
/// Unmatched values are replaced by a default value.
70
pub fn replace_or_default(
71
s: &Series,
72
old: &ListChunked,
73
new: &ListChunked,
74
default: &Series,
75
return_dtype: Option<DataType>,
76
) -> PolarsResult<Series> {
77
find_output_length(
78
"replace_strict",
79
[
80
("self", s.len()),
81
("old", old.len()),
82
("new", new.len()),
83
("default", default.len()),
84
],
85
)?;
86
87
polars_ensure!(
88
old.len() == 1 && new.len() == 1,
89
nyi = "`replace_strict` with a replacement pattern per row"
90
);
91
92
let old = old.explode(ExplodeOptions {
93
empty_as_null: false,
94
keep_nulls: true,
95
})?;
96
let new = new.explode(ExplodeOptions {
97
empty_as_null: false,
98
keep_nulls: true,
99
})?;
100
101
polars_ensure!(
102
default.len() == s.len() || default.len() == 1,
103
InvalidOperation: "`default` input for `replace_strict` must have the same length as the input or have length 1"
104
);
105
validate_old(&old)?;
106
107
let return_dtype = match return_dtype {
108
Some(dtype) => dtype,
109
None => try_get_supertype(new.dtype(), default.dtype())?,
110
};
111
let default = default.cast(&return_dtype)?;
112
113
if old.is_empty() {
114
let out = if default.len() == 1 && s.len() != 1 {
115
default.new_from_index(0, s.len())
116
} else {
117
default
118
};
119
return Ok(out);
120
}
121
122
let old = old.strict_cast(s.dtype())?;
123
let new = new.cast(&return_dtype)?;
124
125
if new.len() == 1 {
126
replace_by_single(s, &old, &new, &default)
127
} else {
128
replace_by_multiple(s, old, new, &default)
129
}
130
}
131
132
/// Replace all values by different values.
133
///
134
/// Raises an error if not all values were replaced.
135
pub fn replace_strict(
136
s: &Series,
137
old: &ListChunked,
138
new: &ListChunked,
139
return_dtype: Option<DataType>,
140
) -> PolarsResult<Series> {
141
find_output_length(
142
"replace_strict",
143
[("self", s.len()), ("old", old.len()), ("new", new.len())],
144
)?;
145
146
polars_ensure!(
147
old.len() == 1 && new.len() == 1,
148
nyi = "`replace_strict` with a replacement pattern per row"
149
);
150
151
let old = old.explode(ExplodeOptions {
152
empty_as_null: false,
153
keep_nulls: true,
154
})?;
155
let new = new.explode(ExplodeOptions {
156
empty_as_null: false,
157
keep_nulls: true,
158
})?;
159
160
if old.is_empty() {
161
polars_ensure!(
162
s.len() == s.null_count(),
163
InvalidOperation: "must specify which values to replace"
164
);
165
return Ok(s.clone());
166
}
167
validate_old(&old)?;
168
169
// Extra check because strict_cast is too permissive, e.g. allows string -> struct cast.
170
if old.dtype().can_cast_to(s.dtype()) != Some(true) {
171
polars_bail!(
172
InvalidOperation: "cannot use values of type `{}` to replace values in a column of type `{}`",
173
old.dtype(),
174
s.dtype()
175
)
176
}
177
178
let old = old.strict_cast(s.dtype())?;
179
180
let new = match return_dtype {
181
Some(dtype) => new.strict_cast(&dtype)?,
182
None => new,
183
};
184
185
if new.len() == 1 {
186
replace_by_single_strict(s, &old, &new)
187
} else {
188
replace_by_multiple_strict(s, old, new)
189
}
190
}
191
192
/// Validate the `old` input.
193
fn validate_old(old: &Series) -> PolarsResult<()> {
194
polars_ensure!(
195
old.n_unique()? == old.len(),
196
InvalidOperation: "`old` input for `replace` must not contain duplicates"
197
);
198
Ok(())
199
}
200
201
// Fast path for replacing by a single value
202
fn replace_by_single(
203
s: &Series,
204
old: &Series,
205
new: &Series,
206
default: &Series,
207
) -> PolarsResult<Series> {
208
let mut mask = get_replacement_mask(s, old)?;
209
if old.null_count() > 0 {
210
mask = mask.fill_null_with_values(true)?;
211
}
212
new.zip_with(&mask, default)
213
}
214
/// Fast path for replacing by a single value in strict mode
215
fn replace_by_single_strict(s: &Series, old: &Series, new: &Series) -> PolarsResult<Series> {
216
let mask = get_replacement_mask(s, old)?;
217
ensure_all_replaced(&mask, s, old.null_count() > 0, true)?;
218
219
let mut out = new.new_from_index(0, s.len());
220
221
// Transfer validity from `mask` to `out`.
222
if mask.null_count() > 0 {
223
out = out.zip_with(&mask, &Series::new_null(PlSmallStr::EMPTY, s.len()))?
224
}
225
Ok(out)
226
}
227
/// Get a boolean mask of which values in the original Series will be replaced.
228
///
229
/// Null values are propagated to the mask.
230
fn get_replacement_mask(s: &Series, old: &Series) -> PolarsResult<BooleanChunked> {
231
if old.null_count() == old.len() {
232
// Fast path for when users are using `replace(None, ...)` instead of `fill_null`.
233
Ok(s.is_null())
234
} else {
235
let old = old.implode()?;
236
is_in(s, &old.into_series(), false)
237
}
238
}
239
240
/// General case for replacing by multiple values
241
fn replace_by_multiple(
242
s: &Series,
243
old: Series,
244
new: Series,
245
default: &Series,
246
) -> PolarsResult<Series> {
247
validate_new(&new, &old)?;
248
249
let df = s.clone().into_frame();
250
let add_replacer_mask = new.null_count() > 0;
251
let replacer = create_replacer(old, new, add_replacer_mask)?;
252
253
let joined = df.join(
254
&replacer,
255
[s.name().as_str()],
256
["__POLARS_REPLACE_OLD"],
257
JoinArgs {
258
how: JoinType::Left,
259
coalesce: JoinCoalesce::CoalesceColumns,
260
nulls_equal: true,
261
..Default::default()
262
},
263
None,
264
)?;
265
266
let replaced = joined
267
.column("__POLARS_REPLACE_NEW")
268
.unwrap()
269
.as_materialized_series();
270
271
if replaced.null_count() == 0 {
272
return Ok(replaced.clone());
273
}
274
275
match joined.column("__POLARS_REPLACE_MASK") {
276
Ok(col) => {
277
let mask = col.bool().unwrap();
278
replaced.zip_with(mask, default)
279
},
280
Err(_) => {
281
let mask = &replaced.is_not_null();
282
replaced.zip_with(mask, default)
283
},
284
}
285
}
286
287
/// General case for replacing by multiple values in strict mode
288
fn replace_by_multiple_strict(s: &Series, old: Series, new: Series) -> PolarsResult<Series> {
289
validate_new(&new, &old)?;
290
291
let df = s.clone().into_frame();
292
let old_has_null = old.null_count() > 0;
293
let replacer = create_replacer(old, new, true)?;
294
295
let joined = df.join(
296
&replacer,
297
[s.name().as_str()],
298
["__POLARS_REPLACE_OLD"],
299
JoinArgs {
300
how: JoinType::Left,
301
coalesce: JoinCoalesce::CoalesceColumns,
302
nulls_equal: true,
303
..Default::default()
304
},
305
None,
306
)?;
307
308
let replaced = joined.column("__POLARS_REPLACE_NEW").unwrap();
309
310
let mask = joined
311
.column("__POLARS_REPLACE_MASK")
312
.unwrap()
313
.bool()
314
.unwrap();
315
ensure_all_replaced(mask, s, old_has_null, false)?;
316
317
Ok(replaced.as_materialized_series().clone())
318
}
319
320
// Build replacer dataframe.
321
fn create_replacer(mut old: Series, mut new: Series, add_mask: bool) -> PolarsResult<DataFrame> {
322
old.rename(PlSmallStr::from_static("__POLARS_REPLACE_OLD"));
323
new.rename(PlSmallStr::from_static("__POLARS_REPLACE_NEW"));
324
325
let len = old.len();
326
let cols = if add_mask {
327
let mask = Column::new_scalar(
328
PlSmallStr::from_static("__POLARS_REPLACE_MASK"),
329
true.into(),
330
new.len(),
331
);
332
vec![old.into(), new.into(), mask]
333
} else {
334
vec![old.into(), new.into()]
335
};
336
let out = unsafe { DataFrame::new_unchecked(len, cols) };
337
Ok(out)
338
}
339
340
/// Validate the `new` input.
341
fn validate_new(new: &Series, old: &Series) -> PolarsResult<()> {
342
polars_ensure!(
343
new.len() == old.len(),
344
InvalidOperation: "`new` input for `replace` must have the same length as `old` or have length 1"
345
);
346
Ok(())
347
}
348
349
/// Ensure that all values were replaced.
350
fn ensure_all_replaced(
351
mask: &BooleanChunked,
352
s: &Series,
353
old_has_null: bool,
354
check_all: bool,
355
) -> PolarsResult<()> {
356
let nulls_check = if old_has_null {
357
mask.null_count() == 0
358
} else {
359
mask.null_count() == s.null_count()
360
};
361
// Checking booleans is only relevant for the 'replace_by_single' path.
362
let bools_check = !check_all || mask.all();
363
364
let all_replaced = bools_check && nulls_check;
365
polars_ensure!(
366
all_replaced,
367
InvalidOperation: "incomplete mapping specified for `replace_strict`\n\nHint: Pass a `default` value to set unmapped values."
368
);
369
Ok(())
370
}
371
372