Path: blob/main/crates/polars-ops/src/series/ops/replace.rs
6939 views
use polars_core::prelude::*;1use polars_core::utils::try_get_supertype;2use polars_error::polars_ensure;34use crate::frame::join::*;5use crate::prelude::*;67fn find_output_length(8fnname: &str,9items: impl IntoIterator<Item = (&'static str, usize)>,10) -> PolarsResult<usize> {11let mut length = 1;12for (argument_idx, (argument, l)) in items.into_iter().enumerate() {13if l != 1 {14if l != length && length != 1 {15polars_bail!(16length_mismatch = fnname,17l,18length,19argument = argument,20argument_idx = argument_idx21);22}23length = l;24}25}26Ok(length)27}2829/// Replace values by different values of the same data type.30pub fn replace(s: &Series, old: &ListChunked, new: &ListChunked) -> PolarsResult<Series> {31find_output_length(32"replace",33[("self", s.len()), ("old", old.len()), ("new", new.len())],34)?;3536polars_ensure!(37old.len() == 1 && new.len() == 1,38nyi = "`replace` with a replacement pattern per row"39);4041let old = old.explode(true)?;42let new = new.explode(true)?;4344if old.is_empty() {45return Ok(s.clone());46}47validate_old(&old)?;4849let dtype = s.dtype();50let old = old.strict_cast(dtype)?;51let new = new.strict_cast(dtype)?;5253if new.len() == 1 {54replace_by_single(s, &old, &new, s)55} else {56replace_by_multiple(s, old, new, s)57}58}5960/// Replace all values by different values.61///62/// Unmatched values are replaced by a default value.63pub fn replace_or_default(64s: &Series,65old: &ListChunked,66new: &ListChunked,67default: &Series,68return_dtype: Option<DataType>,69) -> PolarsResult<Series> {70find_output_length(71"replace_strict",72[73("self", s.len()),74("old", old.len()),75("new", new.len()),76("default", default.len()),77],78)?;7980polars_ensure!(81old.len() == 1 && new.len() == 1,82nyi = "`replace_strict` with a replacement pattern per row"83);8485let old = old.explode(true)?;86let new = new.explode(true)?;8788polars_ensure!(89default.len() == s.len() || default.len() == 1,90InvalidOperation: "`default` input for `replace_strict` must have the same length as the input or have length 1"91);92validate_old(&old)?;9394let return_dtype = match return_dtype {95Some(dtype) => dtype,96None => try_get_supertype(new.dtype(), default.dtype())?,97};98let default = default.cast(&return_dtype)?;99100if old.is_empty() {101let out = if default.len() == 1 && s.len() != 1 {102default.new_from_index(0, s.len())103} else {104default105};106return Ok(out);107}108109let old = old.strict_cast(s.dtype())?;110let new = new.cast(&return_dtype)?;111112if new.len() == 1 {113replace_by_single(s, &old, &new, &default)114} else {115replace_by_multiple(s, old, new, &default)116}117}118119/// Replace all values by different values.120///121/// Raises an error if not all values were replaced.122pub fn replace_strict(123s: &Series,124old: &ListChunked,125new: &ListChunked,126return_dtype: Option<DataType>,127) -> PolarsResult<Series> {128find_output_length(129"replace_strict",130[("self", s.len()), ("old", old.len()), ("new", new.len())],131)?;132133polars_ensure!(134old.len() == 1 && new.len() == 1,135nyi = "`replace_strict` with a replacement pattern per row"136);137138let old = old.explode(true)?;139let new = new.explode(true)?;140141if old.is_empty() {142polars_ensure!(143s.len() == s.null_count(),144InvalidOperation: "must specify which values to replace"145);146return Ok(s.clone());147}148validate_old(&old)?;149150let old = old.strict_cast(s.dtype())?;151let new = match return_dtype {152Some(dtype) => new.strict_cast(&dtype)?,153None => new,154};155156if new.len() == 1 {157replace_by_single_strict(s, &old, &new)158} else {159replace_by_multiple_strict(s, old, new)160}161}162163/// Validate the `old` input.164fn validate_old(old: &Series) -> PolarsResult<()> {165polars_ensure!(166old.n_unique()? == old.len(),167InvalidOperation: "`old` input for `replace` must not contain duplicates"168);169Ok(())170}171172// Fast path for replacing by a single value173fn replace_by_single(174s: &Series,175old: &Series,176new: &Series,177default: &Series,178) -> PolarsResult<Series> {179let mut mask = get_replacement_mask(s, old)?;180if old.null_count() > 0 {181mask = mask.fill_null_with_values(true)?;182}183new.zip_with(&mask, default)184}185/// Fast path for replacing by a single value in strict mode186fn replace_by_single_strict(s: &Series, old: &Series, new: &Series) -> PolarsResult<Series> {187let mask = get_replacement_mask(s, old)?;188ensure_all_replaced(&mask, s, old.null_count() > 0, true)?;189190let mut out = new.new_from_index(0, s.len());191192// Transfer validity from `mask` to `out`.193if mask.null_count() > 0 {194out = out.zip_with(&mask, &Series::new_null(PlSmallStr::EMPTY, s.len()))?195}196Ok(out)197}198/// Get a boolean mask of which values in the original Series will be replaced.199///200/// Null values are propagated to the mask.201fn get_replacement_mask(s: &Series, old: &Series) -> PolarsResult<BooleanChunked> {202if old.null_count() == old.len() {203// Fast path for when users are using `replace(None, ...)` instead of `fill_null`.204Ok(s.is_null())205} else {206let old = old.implode()?;207is_in(s, &old.into_series(), false)208}209}210211/// General case for replacing by multiple values212fn replace_by_multiple(213s: &Series,214old: Series,215new: Series,216default: &Series,217) -> PolarsResult<Series> {218validate_new(&new, &old)?;219220let df = s.clone().into_frame();221let add_replacer_mask = new.null_count() > 0;222let replacer = create_replacer(old, new, add_replacer_mask)?;223224let joined = df.join(225&replacer,226[s.name().as_str()],227["__POLARS_REPLACE_OLD"],228JoinArgs {229how: JoinType::Left,230coalesce: JoinCoalesce::CoalesceColumns,231nulls_equal: true,232..Default::default()233},234None,235)?;236237let replaced = joined238.column("__POLARS_REPLACE_NEW")239.unwrap()240.as_materialized_series();241242if replaced.null_count() == 0 {243return Ok(replaced.clone());244}245246match joined.column("__POLARS_REPLACE_MASK") {247Ok(col) => {248let mask = col.bool().unwrap();249replaced.zip_with(mask, default)250},251Err(_) => {252let mask = &replaced.is_not_null();253replaced.zip_with(mask, default)254},255}256}257258/// General case for replacing by multiple values in strict mode259fn replace_by_multiple_strict(s: &Series, old: Series, new: Series) -> PolarsResult<Series> {260validate_new(&new, &old)?;261262let df = s.clone().into_frame();263let old_has_null = old.null_count() > 0;264let replacer = create_replacer(old, new, true)?;265266let joined = df.join(267&replacer,268[s.name().as_str()],269["__POLARS_REPLACE_OLD"],270JoinArgs {271how: JoinType::Left,272coalesce: JoinCoalesce::CoalesceColumns,273nulls_equal: true,274..Default::default()275},276None,277)?;278279let replaced = joined.column("__POLARS_REPLACE_NEW").unwrap();280281let mask = joined282.column("__POLARS_REPLACE_MASK")283.unwrap()284.bool()285.unwrap();286ensure_all_replaced(mask, s, old_has_null, false)?;287288Ok(replaced.as_materialized_series().clone())289}290291// Build replacer dataframe.292fn create_replacer(mut old: Series, mut new: Series, add_mask: bool) -> PolarsResult<DataFrame> {293old.rename(PlSmallStr::from_static("__POLARS_REPLACE_OLD"));294new.rename(PlSmallStr::from_static("__POLARS_REPLACE_NEW"));295296let len = old.len();297let cols = if add_mask {298let mask = Column::new_scalar(299PlSmallStr::from_static("__POLARS_REPLACE_MASK"),300true.into(),301new.len(),302);303vec![old.into(), new.into(), mask]304} else {305vec![old.into(), new.into()]306};307let out = unsafe { DataFrame::new_no_checks(len, cols) };308Ok(out)309}310311/// Validate the `new` input.312fn validate_new(new: &Series, old: &Series) -> PolarsResult<()> {313polars_ensure!(314new.len() == old.len(),315InvalidOperation: "`new` input for `replace` must have the same length as `old` or have length 1"316);317Ok(())318}319320/// Ensure that all values were replaced.321fn ensure_all_replaced(322mask: &BooleanChunked,323s: &Series,324old_has_null: bool,325check_all: bool,326) -> PolarsResult<()> {327let nulls_check = if old_has_null {328mask.null_count() == 0329} else {330mask.null_count() == s.null_count()331};332// Checking booleans is only relevant for the 'replace_by_single' path.333let bools_check = !check_all || mask.all();334335let all_replaced = bools_check && nulls_check;336polars_ensure!(337all_replaced,338InvalidOperation: "incomplete mapping specified for `replace_strict`\n\nHint: Pass a `default` value to set unmapped values."339);340Ok(())341}342343344