Path: blob/main/crates/polars-ops/src/chunked_array/strings/find_many.rs
8420 views
use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind};1use arrow::array::Utf8ViewArray;2use polars_core::prelude::arity::unary_elementwise;3use polars_core::prelude::*;4use polars_core::utils::align_chunks_binary;56fn build_ac(7patterns: &StringChunked,8ascii_case_insensitive: bool,9leftmost: bool,10) -> PolarsResult<AhoCorasick> {11AhoCorasickBuilder::new()12.match_kind(if leftmost {13MatchKind::LeftmostFirst14} else {15MatchKind::Standard16})17.ascii_case_insensitive(ascii_case_insensitive)18.build(patterns.downcast_iter().flatten().flatten())19.map_err(|e| polars_err!(ComputeError: "could not build aho corasick automaton {}", e))20}2122fn build_ac_arr(23patterns: &Utf8ViewArray,24ascii_case_insensitive: bool,25leftmost: bool,26) -> PolarsResult<AhoCorasick> {27AhoCorasickBuilder::new()28.match_kind(if leftmost {29MatchKind::LeftmostFirst30} else {31MatchKind::Standard32})33.ascii_case_insensitive(ascii_case_insensitive)34.build(patterns.into_iter().flatten())35.map_err(|e| polars_err!(ComputeError: "could not build aho corasick automaton {}", e))36}3738pub fn contains_any(39ca: &StringChunked,40patterns: &ListChunked,41ascii_case_insensitive: bool,42) -> PolarsResult<BooleanChunked> {43polars_ensure!(44ca.len() == patterns.len() || ca.len() == 1 || patterns.len() == 1,45length_mismatch = "str.contains_any",46ca.len(),47patterns.len()48);49polars_ensure!(50patterns.len() == 1,51nyi = "`str.contains_any` with a pattern per row"52);5354if patterns.has_nulls() {55return Ok(BooleanChunked::full_null(ca.name().clone(), ca.len()));56}5758let patterns = patterns.explode(ExplodeOptions {59empty_as_null: false,60keep_nulls: true,61})?;62let patterns = patterns.str()?;63let ac = build_ac(patterns, ascii_case_insensitive, false)?;6465Ok(unary_elementwise(ca, |opt_val| {66opt_val.map(|val| ac.find(val).is_some())67}))68}6970pub fn replace_all(71ca: &StringChunked,72patterns: &ListChunked,73replace_with: &ListChunked,74ascii_case_insensitive: bool,75leftmost: bool,76) -> PolarsResult<StringChunked> {77let mut length = 1;78for (argument_idx, (argument, l)) in [79("self", ca.len()),80("patterns", patterns.len()),81("replace_with", replace_with.len()),82]83.into_iter()84.enumerate()85{86if l != 1 {87if l != length && length != 1 {88polars_bail!(89length_mismatch = "str.replace_many",90l,91length,92argument = argument,93argument_idx = argument_idx94);95}96length = l;97}98}99100polars_ensure!(101patterns.len() == 1 && replace_with.len() == 1,102nyi = "`str.replace_many` with a pattern per row"103);104105if patterns.has_nulls() || replace_with.has_nulls() {106return Ok(StringChunked::full_null(ca.name().clone(), ca.len()));107}108109let patterns = patterns.explode(ExplodeOptions {110empty_as_null: false,111keep_nulls: true,112})?;113let patterns = patterns.str()?;114let replace_with = replace_with.explode(ExplodeOptions {115empty_as_null: false,116keep_nulls: true,117})?;118let replace_with = replace_with.str()?;119120let replace_with = if replace_with.len() == 1 && patterns.len() > 1 {121replace_with.new_from_index(0, patterns.len())122} else {123replace_with.clone()124};125126polars_ensure!(patterns.len() == replace_with.len(), InvalidOperation: "expected the same amount of patterns as replacement strings");127polars_ensure!(patterns.null_count() == 0 && replace_with.null_count() == 0, InvalidOperation: "'patterns'/'replace_with' should not have nulls");128let replace_with = replace_with129.downcast_iter()130.flatten()131.flatten()132.collect::<Vec<_>>();133134let ac = build_ac(patterns, ascii_case_insensitive, leftmost)?;135136Ok(unary_elementwise(ca, |opt_val| {137opt_val.map(|val| ac.replace_all(val, replace_with.as_slice()))138}))139}140141fn push_str(142val: &str,143builder: &mut ListStringChunkedBuilder,144ac: &AhoCorasick,145overlapping: bool,146) {147if overlapping {148let iter = ac.find_overlapping_iter(val);149let iter = iter.map(|m| &val[m.start()..m.end()]);150builder.append_values_iter(iter);151} else {152let iter = ac.find_iter(val);153let iter = iter.map(|m| &val[m.start()..m.end()]);154builder.append_values_iter(iter);155}156}157158pub fn extract_many(159ca: &StringChunked,160patterns: &ListChunked,161ascii_case_insensitive: bool,162overlapping: bool,163leftmost: bool,164) -> PolarsResult<ListChunked> {165// ensure that either overlapping == false, or overlapping == true and leftmost == false166polars_ensure!(!overlapping | !leftmost, InvalidOperation: "can not match overlapping patterns when leftmost == True");167match (ca.len(), patterns.len()) {168(1, _) => match ca.get(0) {169None => Ok(ListChunked::full_null_with_dtype(170ca.name().clone(),171ca.len(),172&DataType::String,173)),174Some(val) => {175let mut builder =176ListStringChunkedBuilder::new(ca.name().clone(), ca.len(), ca.len() * 2);177178for pat in patterns.amortized_iter() {179match pat {180None => builder.append_null(),181Some(pat) => {182let pat = pat.as_ref();183let pat = pat.str()?;184let pat = pat.rechunk();185let pat = pat.downcast_as_array();186let ac = build_ac_arr(pat, ascii_case_insensitive, leftmost)?;187push_str(val, &mut builder, &ac, overlapping);188},189}190}191Ok(builder.finish())192},193},194(_, 1) => {195let patterns = patterns.explode(ExplodeOptions {196empty_as_null: false,197keep_nulls: true,198})?;199let patterns = patterns.str()?;200let ac = build_ac(patterns, ascii_case_insensitive, leftmost)?;201let mut builder =202ListStringChunkedBuilder::new(ca.name().clone(), ca.len(), ca.len() * 2);203204for arr in ca.downcast_iter() {205for opt_val in arr.into_iter() {206if let Some(val) = opt_val {207push_str(val, &mut builder, &ac, overlapping);208} else {209builder.append_null();210}211}212}213Ok(builder.finish())214},215(a, b) if a == b => {216let mut builder =217ListStringChunkedBuilder::new(ca.name().clone(), ca.len(), ca.len() * 2);218let (ca, patterns) = align_chunks_binary(ca, patterns);219220for (arr, pat_arr) in ca.downcast_iter().zip(patterns.downcast_iter()) {221for z in arr.into_iter().zip(pat_arr.into_iter()) {222match z {223(None, _) | (_, None) => builder.append_null(),224(Some(val), Some(pat)) => {225let pat = pat.as_any().downcast_ref::<Utf8ViewArray>().unwrap();226let ac = build_ac_arr(pat, ascii_case_insensitive, leftmost)?;227push_str(val, &mut builder, &ac, overlapping);228},229}230}231}232Ok(builder.finish())233},234(a, b) => polars_bail!(length_mismatch = "str.extract_many", a, b),235}236}237238type B = ListPrimitiveChunkedBuilder<UInt32Type>;239fn push_idx(val: &str, builder: &mut B, ac: &AhoCorasick, overlapping: bool) {240if overlapping {241let iter = ac.find_overlapping_iter(val);242let iter = iter.map(|m| m.start() as u32);243builder.append_values_iter(iter);244} else {245let iter = ac.find_iter(val);246let iter = iter.map(|m| m.start() as u32);247builder.append_values_iter(iter);248}249}250251pub fn find_many(252ca: &StringChunked,253patterns: &ListChunked,254ascii_case_insensitive: bool,255overlapping: bool,256leftmost: bool,257) -> PolarsResult<ListChunked> {258polars_ensure!(!overlapping | !leftmost, InvalidOperation: "can not match overlapping patterns when leftmost == True");259type B = ListPrimitiveChunkedBuilder<UInt32Type>;260match (ca.len(), patterns.len()) {261(1, _) => match ca.get(0) {262None => Ok(ListChunked::full_null_with_dtype(263ca.name().clone(),264patterns.len(),265&DataType::UInt32,266)),267Some(val) => {268let mut builder = B::new(269ca.name().clone(),270patterns.len(),271patterns.len() * 2,272DataType::UInt32,273);274for pat in patterns.amortized_iter() {275match pat {276None => builder.append_null(),277Some(pat) => {278let pat = pat.as_ref();279let pat = pat.str()?;280let pat = pat.rechunk();281let pat = pat.downcast_as_array();282let ac = build_ac_arr(pat, ascii_case_insensitive, leftmost)?;283push_idx(val, &mut builder, &ac, overlapping);284},285}286}287Ok(builder.finish())288},289},290(_, 1) => {291let patterns = patterns.explode(ExplodeOptions {292empty_as_null: false,293keep_nulls: true,294})?;295let patterns = patterns.str()?;296let ac = build_ac(patterns, ascii_case_insensitive, leftmost)?;297let mut builder = B::new(ca.name().clone(), ca.len(), ca.len() * 2, DataType::UInt32);298299for opt_val in ca.iter() {300if let Some(val) = opt_val {301push_idx(val, &mut builder, &ac, overlapping);302} else {303builder.append_null();304}305}306Ok(builder.finish())307},308(a, b) if a == b => {309let mut builder = B::new(ca.name().clone(), ca.len(), ca.len() * 2, DataType::UInt32);310let (ca, patterns) = align_chunks_binary(ca, patterns);311312for (arr, pat_arr) in ca.downcast_iter().zip(patterns.downcast_iter()) {313for z in arr.into_iter().zip(pat_arr.into_iter()) {314match z {315(None, _) | (_, None) => builder.append_null(),316(Some(val), Some(pat)) => {317let pat = pat.as_any().downcast_ref::<Utf8ViewArray>().unwrap();318let ac = build_ac_arr(pat, ascii_case_insensitive, leftmost)?;319push_idx(val, &mut builder, &ac, overlapping);320},321}322}323}324Ok(builder.finish())325},326(a, b) => polars_bail!(length_mismatch = "str.find_many", a, b),327}328}329330331