Path: blob/main/crates/polars-ops/src/chunked_array/strings/extract.rs
6939 views
use std::iter::zip;12#[cfg(feature = "extract_groups")]3use arrow::array::{Array, StructArray};4use arrow::array::{MutablePlString, Utf8ViewArray};5use polars_core::prelude::arity::{try_binary_mut_with_options, try_unary_mut_with_options};6use regex::Regex;78use super::*;910#[cfg(feature = "extract_groups")]11fn extract_groups_array(12arr: &Utf8ViewArray,13reg: &Regex,14names: &[&str],15dtype: ArrowDataType,16) -> PolarsResult<ArrayRef> {17let mut builders = (0..names.len())18.map(|_| MutablePlString::with_capacity(arr.len()))19.collect::<Vec<_>>();2021let mut locs = reg.capture_locations();22for opt_v in arr {23if let Some(s) = opt_v {24if reg.captures_read(&mut locs, s).is_some() {25for (i, builder) in builders.iter_mut().enumerate() {26builder.push(locs.get(i + 1).map(|(start, stop)| &s[start..stop]));27}28continue;29}30}3132// Push nulls if either the string is null or there was no match. We33// distinguish later between the two by copying arr's validity mask.34builders.iter_mut().for_each(|arr| arr.push_null());35}3637let values = builders.into_iter().map(|a| a.freeze().boxed()).collect();38Ok(StructArray::new(dtype, arr.len(), values, arr.validity().cloned()).boxed())39}4041#[cfg(feature = "extract_groups")]42pub(super) fn extract_groups(43ca: &StringChunked,44pat: &str,45dtype: &DataType,46) -> PolarsResult<Series> {47let reg = polars_utils::regex_cache::compile_regex(pat)?;48let n_fields = reg.captures_len();49if n_fields == 1 {50return StructChunked::from_series(ca.name().clone(), ca.len(), [].iter())51.map(|ca| ca.into_series());52}5354let arrow_dtype = dtype.try_to_arrow(CompatLevel::newest())?;55let DataType::Struct(fields) = dtype else {56unreachable!() // Implementation error if it isn't a struct.57};58let names = fields59.iter()60.map(|fld| fld.name.as_str())61.collect::<Vec<_>>();6263let chunks = ca64.downcast_iter()65.map(|array| extract_groups_array(array, ®, &names, arrow_dtype.clone()))66.collect::<PolarsResult<Vec<_>>>()?;6768Series::try_from((ca.name().clone(), chunks))69}7071fn extract_group_reg_lit(72arr: &Utf8ViewArray,73reg: &Regex,74group_index: usize,75) -> PolarsResult<Utf8ViewArray> {76let mut builder = MutablePlString::with_capacity(arr.len());7778let mut locs = reg.capture_locations();79for opt_v in arr {80if let Some(s) = opt_v {81if reg.captures_read(&mut locs, s).is_some() {82builder.push(locs.get(group_index).map(|(start, stop)| &s[start..stop]));83continue;84}85}8687// Push null if either the string is null or there was no match.88builder.push_null();89}9091Ok(builder.into())92}9394fn extract_group_array_lit(95s: &str,96pat: &Utf8ViewArray,97group_index: usize,98) -> PolarsResult<Utf8ViewArray> {99let mut builder = MutablePlString::with_capacity(pat.len());100101for opt_pat in pat {102if let Some(pat) = opt_pat {103let reg = polars_utils::regex_cache::compile_regex(pat)?;104let mut locs = reg.capture_locations();105if reg.captures_read(&mut locs, s).is_some() {106builder.push(locs.get(group_index).map(|(start, stop)| &s[start..stop]));107continue;108}109}110111// Push null if either the pat is null or there was no match.112builder.push_null();113}114115Ok(builder.into())116}117118fn extract_group_binary(119arr: &Utf8ViewArray,120pat: &Utf8ViewArray,121group_index: usize,122) -> PolarsResult<Utf8ViewArray> {123let mut builder = MutablePlString::with_capacity(arr.len());124125for (opt_s, opt_pat) in zip(arr, pat) {126match (opt_s, opt_pat) {127(Some(s), Some(pat)) => {128let reg = polars_utils::regex_cache::compile_regex(pat)?;129let mut locs = reg.capture_locations();130if reg.captures_read(&mut locs, s).is_some() {131builder.push(locs.get(group_index).map(|(start, stop)| &s[start..stop]));132continue;133}134// Push null if there was no match.135builder.push_null()136},137_ => builder.push_null(),138}139}140141Ok(builder.into())142}143144pub(super) fn extract_group(145ca: &StringChunked,146pat: &StringChunked,147group_index: usize,148) -> PolarsResult<StringChunked> {149match (ca.len(), pat.len()) {150(_, 1) => {151if let Some(pat) = pat.get(0) {152let reg = polars_utils::regex_cache::compile_regex(pat)?;153try_unary_mut_with_options(ca, |arr| extract_group_reg_lit(arr, ®, group_index))154} else {155Ok(StringChunked::full_null(ca.name().clone(), ca.len()))156}157},158(1, _) => {159if let Some(s) = ca.get(0) {160try_unary_mut_with_options(pat, |pat| extract_group_array_lit(s, pat, group_index))161} else {162Ok(StringChunked::full_null(ca.name().clone(), pat.len()))163}164},165(len_ca, len_pat) if len_ca == len_pat => try_binary_mut_with_options(166ca,167pat,168|ca, pat| extract_group_binary(ca, pat, group_index),169ca.name().clone(),170),171_ => {172polars_bail!(ComputeError: "ca(len: {}) and pat(len: {}) should either broadcast or have the same length", ca.len(), pat.len())173},174}175}176177178