Path: blob/main/crates/polars-ops/src/series/ops/replace.rs
8421 views
use polars_core::prelude::*;1use polars_core::utils::try_get_supertype;2use polars_error::polars_ensure;34use crate::frame::join::*;5use crate::prelude::*;67fn find_output_length(8fnname: &str,9items: impl IntoIterator<Item = (&'static str, usize)>,10) -> PolarsResult<usize> {11let mut length = 1;12for (argument_idx, (argument, l)) in items.into_iter().enumerate() {13if l != 1 {14if l != length && length != 1 {15polars_bail!(16length_mismatch = fnname,17l,18length,19argument = argument,20argument_idx = argument_idx21);22}23length = l;24}25}26Ok(length)27}2829/// Replace values by different values of the same data type.30pub fn replace(s: &Series, old: &ListChunked, new: &ListChunked) -> PolarsResult<Series> {31find_output_length(32"replace",33[("self", s.len()), ("old", old.len()), ("new", new.len())],34)?;3536polars_ensure!(37old.len() == 1 && new.len() == 1,38nyi = "`replace` with a replacement pattern per row"39);4041let old = old.explode(ExplodeOptions {42empty_as_null: false,43keep_nulls: true,44})?;45let new = new.explode(ExplodeOptions {46empty_as_null: false,47keep_nulls: true,48})?;4950if old.is_empty() {51return Ok(s.clone());52}53validate_old(&old)?;5455let dtype = s.dtype();56let old = old.strict_cast(dtype)?;57let new = new.strict_cast(dtype)?;5859if new.len() == 1 {60replace_by_single(s, &old, &new, s)61} else {62replace_by_multiple(s, old, new, s)63}64}6566/// Replace all values by different values.67///68/// Unmatched values are replaced by a default value.69pub fn replace_or_default(70s: &Series,71old: &ListChunked,72new: &ListChunked,73default: &Series,74return_dtype: Option<DataType>,75) -> PolarsResult<Series> {76find_output_length(77"replace_strict",78[79("self", s.len()),80("old", old.len()),81("new", new.len()),82("default", default.len()),83],84)?;8586polars_ensure!(87old.len() == 1 && new.len() == 1,88nyi = "`replace_strict` with a replacement pattern per row"89);9091let old = old.explode(ExplodeOptions {92empty_as_null: false,93keep_nulls: true,94})?;95let new = new.explode(ExplodeOptions {96empty_as_null: false,97keep_nulls: true,98})?;99100polars_ensure!(101default.len() == s.len() || default.len() == 1,102InvalidOperation: "`default` input for `replace_strict` must have the same length as the input or have length 1"103);104validate_old(&old)?;105106let return_dtype = match return_dtype {107Some(dtype) => dtype,108None => try_get_supertype(new.dtype(), default.dtype())?,109};110let default = default.cast(&return_dtype)?;111112if old.is_empty() {113let out = if default.len() == 1 && s.len() != 1 {114default.new_from_index(0, s.len())115} else {116default117};118return Ok(out);119}120121let old = old.strict_cast(s.dtype())?;122let new = new.cast(&return_dtype)?;123124if new.len() == 1 {125replace_by_single(s, &old, &new, &default)126} else {127replace_by_multiple(s, old, new, &default)128}129}130131/// Replace all values by different values.132///133/// Raises an error if not all values were replaced.134pub fn replace_strict(135s: &Series,136old: &ListChunked,137new: &ListChunked,138return_dtype: Option<DataType>,139) -> PolarsResult<Series> {140find_output_length(141"replace_strict",142[("self", s.len()), ("old", old.len()), ("new", new.len())],143)?;144145polars_ensure!(146old.len() == 1 && new.len() == 1,147nyi = "`replace_strict` with a replacement pattern per row"148);149150let old = old.explode(ExplodeOptions {151empty_as_null: false,152keep_nulls: true,153})?;154let new = new.explode(ExplodeOptions {155empty_as_null: false,156keep_nulls: true,157})?;158159if old.is_empty() {160polars_ensure!(161s.len() == s.null_count(),162InvalidOperation: "must specify which values to replace"163);164return Ok(s.clone());165}166validate_old(&old)?;167168// Extra check because strict_cast is too permissive, e.g. allows string -> struct cast.169if old.dtype().can_cast_to(s.dtype()) != Some(true) {170polars_bail!(171InvalidOperation: "cannot use values of type `{}` to replace values in a column of type `{}`",172old.dtype(),173s.dtype()174)175}176177let old = old.strict_cast(s.dtype())?;178179let new = match return_dtype {180Some(dtype) => new.strict_cast(&dtype)?,181None => new,182};183184if new.len() == 1 {185replace_by_single_strict(s, &old, &new)186} else {187replace_by_multiple_strict(s, old, new)188}189}190191/// Validate the `old` input.192fn validate_old(old: &Series) -> PolarsResult<()> {193polars_ensure!(194old.n_unique()? == old.len(),195InvalidOperation: "`old` input for `replace` must not contain duplicates"196);197Ok(())198}199200// Fast path for replacing by a single value201fn replace_by_single(202s: &Series,203old: &Series,204new: &Series,205default: &Series,206) -> PolarsResult<Series> {207let mut mask = get_replacement_mask(s, old)?;208if old.null_count() > 0 {209mask = mask.fill_null_with_values(true)?;210}211new.zip_with(&mask, default)212}213/// Fast path for replacing by a single value in strict mode214fn replace_by_single_strict(s: &Series, old: &Series, new: &Series) -> PolarsResult<Series> {215let mask = get_replacement_mask(s, old)?;216ensure_all_replaced(&mask, s, old.null_count() > 0, true)?;217218let mut out = new.new_from_index(0, s.len());219220// Transfer validity from `mask` to `out`.221if mask.null_count() > 0 {222out = out.zip_with(&mask, &Series::new_null(PlSmallStr::EMPTY, s.len()))?223}224Ok(out)225}226/// Get a boolean mask of which values in the original Series will be replaced.227///228/// Null values are propagated to the mask.229fn get_replacement_mask(s: &Series, old: &Series) -> PolarsResult<BooleanChunked> {230if old.null_count() == old.len() {231// Fast path for when users are using `replace(None, ...)` instead of `fill_null`.232Ok(s.is_null())233} else {234let old = old.implode()?;235is_in(s, &old.into_series(), false)236}237}238239/// General case for replacing by multiple values240fn replace_by_multiple(241s: &Series,242old: Series,243new: Series,244default: &Series,245) -> PolarsResult<Series> {246validate_new(&new, &old)?;247248let df = s.clone().into_frame();249let add_replacer_mask = new.null_count() > 0;250let replacer = create_replacer(old, new, add_replacer_mask)?;251252let joined = df.join(253&replacer,254[s.name().as_str()],255["__POLARS_REPLACE_OLD"],256JoinArgs {257how: JoinType::Left,258coalesce: JoinCoalesce::CoalesceColumns,259nulls_equal: true,260..Default::default()261},262None,263)?;264265let replaced = joined266.column("__POLARS_REPLACE_NEW")267.unwrap()268.as_materialized_series();269270if replaced.null_count() == 0 {271return Ok(replaced.clone());272}273274match joined.column("__POLARS_REPLACE_MASK") {275Ok(col) => {276let mask = col.bool().unwrap();277replaced.zip_with(mask, default)278},279Err(_) => {280let mask = &replaced.is_not_null();281replaced.zip_with(mask, default)282},283}284}285286/// General case for replacing by multiple values in strict mode287fn replace_by_multiple_strict(s: &Series, old: Series, new: Series) -> PolarsResult<Series> {288validate_new(&new, &old)?;289290let df = s.clone().into_frame();291let old_has_null = old.null_count() > 0;292let replacer = create_replacer(old, new, true)?;293294let joined = df.join(295&replacer,296[s.name().as_str()],297["__POLARS_REPLACE_OLD"],298JoinArgs {299how: JoinType::Left,300coalesce: JoinCoalesce::CoalesceColumns,301nulls_equal: true,302..Default::default()303},304None,305)?;306307let replaced = joined.column("__POLARS_REPLACE_NEW").unwrap();308309let mask = joined310.column("__POLARS_REPLACE_MASK")311.unwrap()312.bool()313.unwrap();314ensure_all_replaced(mask, s, old_has_null, false)?;315316Ok(replaced.as_materialized_series().clone())317}318319// Build replacer dataframe.320fn create_replacer(mut old: Series, mut new: Series, add_mask: bool) -> PolarsResult<DataFrame> {321old.rename(PlSmallStr::from_static("__POLARS_REPLACE_OLD"));322new.rename(PlSmallStr::from_static("__POLARS_REPLACE_NEW"));323324let len = old.len();325let cols = if add_mask {326let mask = Column::new_scalar(327PlSmallStr::from_static("__POLARS_REPLACE_MASK"),328true.into(),329new.len(),330);331vec![old.into(), new.into(), mask]332} else {333vec![old.into(), new.into()]334};335let out = unsafe { DataFrame::new_unchecked(len, cols) };336Ok(out)337}338339/// Validate the `new` input.340fn validate_new(new: &Series, old: &Series) -> PolarsResult<()> {341polars_ensure!(342new.len() == old.len(),343InvalidOperation: "`new` input for `replace` must have the same length as `old` or have length 1"344);345Ok(())346}347348/// Ensure that all values were replaced.349fn ensure_all_replaced(350mask: &BooleanChunked,351s: &Series,352old_has_null: bool,353check_all: bool,354) -> PolarsResult<()> {355let nulls_check = if old_has_null {356mask.null_count() == 0357} else {358mask.null_count() == s.null_count()359};360// Checking booleans is only relevant for the 'replace_by_single' path.361let bools_check = !check_all || mask.all();362363let all_replaced = bools_check && nulls_check;364polars_ensure!(365all_replaced,366InvalidOperation: "incomplete mapping specified for `replace_strict`\n\nHint: Pass a `default` value to set unmapped values."367);368Ok(())369}370371372