Path: blob/main/crates/polars-ops/src/chunked_array/strings/find_many.rs
6939 views
use aho_corasick::{AhoCorasick, AhoCorasickBuilder};1use arrow::array::Utf8ViewArray;2use polars_core::prelude::arity::unary_elementwise;3use polars_core::prelude::*;4use polars_core::utils::align_chunks_binary;56fn build_ac(patterns: &StringChunked, ascii_case_insensitive: bool) -> PolarsResult<AhoCorasick> {7AhoCorasickBuilder::new()8.ascii_case_insensitive(ascii_case_insensitive)9.build(patterns.downcast_iter().flatten().flatten())10.map_err(|e| polars_err!(ComputeError: "could not build aho corasick automaton {}", e))11}1213fn build_ac_arr(14patterns: &Utf8ViewArray,15ascii_case_insensitive: bool,16) -> PolarsResult<AhoCorasick> {17AhoCorasickBuilder::new()18.ascii_case_insensitive(ascii_case_insensitive)19.build(patterns.into_iter().flatten())20.map_err(|e| polars_err!(ComputeError: "could not build aho corasick automaton {}", e))21}2223pub fn contains_any(24ca: &StringChunked,25patterns: &ListChunked,26ascii_case_insensitive: bool,27) -> PolarsResult<BooleanChunked> {28polars_ensure!(29ca.len() == patterns.len() || ca.len() == 1 || patterns.len() == 1,30length_mismatch = "str.contains_any",31ca.len(),32patterns.len()33);34polars_ensure!(35patterns.len() == 1,36nyi = "`str.contains_any` with a pattern per row"37);3839if patterns.has_nulls() {40return Ok(BooleanChunked::full_null(ca.name().clone(), ca.len()));41}4243let patterns = patterns.explode(true)?;44let patterns = patterns.str()?;45let ac = build_ac(patterns, ascii_case_insensitive)?;4647Ok(unary_elementwise(ca, |opt_val| {48opt_val.map(|val| ac.find(val).is_some())49}))50}5152pub fn replace_all(53ca: &StringChunked,54patterns: &ListChunked,55replace_with: &ListChunked,56ascii_case_insensitive: bool,57) -> PolarsResult<StringChunked> {58let mut length = 1;59for (argument_idx, (argument, l)) in [60("self", ca.len()),61("patterns", patterns.len()),62("replace_with", replace_with.len()),63]64.into_iter()65.enumerate()66{67if l != 1 {68if l != length && length != 1 {69polars_bail!(70length_mismatch = "str.replace_many",71l,72length,73argument = argument,74argument_idx = argument_idx75);76}77length = l;78}79}8081polars_ensure!(82patterns.len() == 1 && replace_with.len() == 1,83nyi = "`str.replace_many` with a pattern per row"84);8586if patterns.has_nulls() || replace_with.has_nulls() {87return Ok(StringChunked::full_null(ca.name().clone(), ca.len()));88}8990let patterns = patterns.explode(true)?;91let patterns = patterns.str()?;92let replace_with = replace_with.explode(true)?;93let replace_with = replace_with.str()?;9495let replace_with = if replace_with.len() == 1 && patterns.len() > 1 {96replace_with.new_from_index(0, patterns.len())97} else {98replace_with.clone()99};100101polars_ensure!(patterns.len() == replace_with.len(), InvalidOperation: "expected the same amount of patterns as replacement strings");102polars_ensure!(patterns.null_count() == 0 && replace_with.null_count() == 0, InvalidOperation: "'patterns'/'replace_with' should not have nulls");103let replace_with = replace_with104.downcast_iter()105.flatten()106.flatten()107.collect::<Vec<_>>();108109let ac = build_ac(patterns, ascii_case_insensitive)?;110111Ok(unary_elementwise(ca, |opt_val| {112opt_val.map(|val| ac.replace_all(val, replace_with.as_slice()))113}))114}115116fn push_str(117val: &str,118builder: &mut ListStringChunkedBuilder,119ac: &AhoCorasick,120overlapping: bool,121) {122if overlapping {123let iter = ac.find_overlapping_iter(val);124let iter = iter.map(|m| &val[m.start()..m.end()]);125builder.append_values_iter(iter);126} else {127let iter = ac.find_iter(val);128let iter = iter.map(|m| &val[m.start()..m.end()]);129builder.append_values_iter(iter);130}131}132133pub fn extract_many(134ca: &StringChunked,135patterns: &ListChunked,136ascii_case_insensitive: bool,137overlapping: bool,138) -> PolarsResult<ListChunked> {139match (ca.len(), patterns.len()) {140(1, _) => match ca.get(0) {141None => Ok(ListChunked::full_null_with_dtype(142ca.name().clone(),143ca.len(),144&DataType::String,145)),146Some(val) => {147let mut builder =148ListStringChunkedBuilder::new(ca.name().clone(), ca.len(), ca.len() * 2);149150for pat in patterns.amortized_iter() {151match pat {152None => builder.append_null(),153Some(pat) => {154let pat = pat.as_ref();155let pat = pat.str()?;156let pat = pat.rechunk();157let pat = pat.downcast_as_array();158let ac = build_ac_arr(pat, ascii_case_insensitive)?;159push_str(val, &mut builder, &ac, overlapping);160},161}162}163Ok(builder.finish())164},165},166(_, 1) => {167let patterns = patterns.explode(true)?;168let patterns = patterns.str()?;169let ac = build_ac(patterns, ascii_case_insensitive)?;170let mut builder =171ListStringChunkedBuilder::new(ca.name().clone(), ca.len(), ca.len() * 2);172173for arr in ca.downcast_iter() {174for opt_val in arr.into_iter() {175if let Some(val) = opt_val {176push_str(val, &mut builder, &ac, overlapping);177} else {178builder.append_null();179}180}181}182Ok(builder.finish())183},184(a, b) if a == b => {185let mut builder =186ListStringChunkedBuilder::new(ca.name().clone(), ca.len(), ca.len() * 2);187let (ca, patterns) = align_chunks_binary(ca, patterns);188189for (arr, pat_arr) in ca.downcast_iter().zip(patterns.downcast_iter()) {190for z in arr.into_iter().zip(pat_arr.into_iter()) {191match z {192(None, _) | (_, None) => builder.append_null(),193(Some(val), Some(pat)) => {194let pat = pat.as_any().downcast_ref::<Utf8ViewArray>().unwrap();195let ac = build_ac_arr(pat, ascii_case_insensitive)?;196push_str(val, &mut builder, &ac, overlapping);197},198}199}200}201Ok(builder.finish())202},203(a, b) => polars_bail!(length_mismatch = "str.extract_many", a, b),204}205}206207type B = ListPrimitiveChunkedBuilder<UInt32Type>;208fn push_idx(val: &str, builder: &mut B, ac: &AhoCorasick, overlapping: bool) {209if overlapping {210let iter = ac.find_overlapping_iter(val);211let iter = iter.map(|m| m.start() as u32);212builder.append_values_iter(iter);213} else {214let iter = ac.find_iter(val);215let iter = iter.map(|m| m.start() as u32);216builder.append_values_iter(iter);217}218}219220pub fn find_many(221ca: &StringChunked,222patterns: &ListChunked,223ascii_case_insensitive: bool,224overlapping: bool,225) -> PolarsResult<ListChunked> {226type B = ListPrimitiveChunkedBuilder<UInt32Type>;227match (ca.len(), patterns.len()) {228(1, _) => match ca.get(0) {229None => Ok(ListChunked::full_null_with_dtype(230ca.name().clone(),231patterns.len(),232&DataType::UInt32,233)),234Some(val) => {235let mut builder = B::new(236ca.name().clone(),237patterns.len(),238patterns.len() * 2,239DataType::UInt32,240);241for pat in patterns.amortized_iter() {242match pat {243None => builder.append_null(),244Some(pat) => {245let pat = pat.as_ref();246let pat = pat.str()?;247let pat = pat.rechunk();248let pat = pat.downcast_as_array();249let ac = build_ac_arr(pat, ascii_case_insensitive)?;250push_idx(val, &mut builder, &ac, overlapping);251},252}253}254Ok(builder.finish())255},256},257(_, 1) => {258let patterns = patterns.explode(true)?;259let patterns = patterns.str()?;260let ac = build_ac(patterns, ascii_case_insensitive)?;261let mut builder = B::new(ca.name().clone(), ca.len(), ca.len() * 2, DataType::UInt32);262263for opt_val in ca.iter() {264if let Some(val) = opt_val {265push_idx(val, &mut builder, &ac, overlapping);266} else {267builder.append_null();268}269}270Ok(builder.finish())271},272(a, b) if a == b => {273let mut builder = B::new(ca.name().clone(), ca.len(), ca.len() * 2, DataType::UInt32);274let (ca, patterns) = align_chunks_binary(ca, patterns);275276for (arr, pat_arr) in ca.downcast_iter().zip(patterns.downcast_iter()) {277for z in arr.into_iter().zip(pat_arr.into_iter()) {278match z {279(None, _) | (_, None) => builder.append_null(),280(Some(val), Some(pat)) => {281let pat = pat.as_any().downcast_ref::<Utf8ViewArray>().unwrap();282let ac = build_ac_arr(pat, ascii_case_insensitive)?;283push_idx(val, &mut builder, &ac, overlapping);284},285}286}287}288Ok(builder.finish())289},290(a, b) => polars_bail!(length_mismatch = "str.find_many", a, b),291}292}293294295