Path: blob/main/crates/polars-ops/src/chunked_array/strings/substring.rs
6939 views
use std::cmp::Ordering;12use arrow::array::View;3use polars_core::prelude::arity::{binary_elementwise, ternary_elementwise, unary_elementwise};4use polars_core::prelude::{ChunkFullNull, Int64Chunked, StringChunked, UInt64Chunked};5use polars_error::{PolarsResult, polars_ensure};67fn head_binary(opt_str_val: Option<&str>, opt_n: Option<i64>) -> Option<&str> {8if let (Some(str_val), Some(n)) = (opt_str_val, opt_n) {9let end_idx = head_binary_values(str_val, n);10Some(unsafe { str_val.get_unchecked(..end_idx) })11} else {12None13}14}1516fn head_binary_values(str_val: &str, n: i64) -> usize {17match n.cmp(&0) {18Ordering::Equal => 0,19Ordering::Greater => {20if n as usize >= str_val.len() {21return str_val.len();22}23// End after the nth codepoint.24str_val25.char_indices()26.nth(n as usize)27.map(|(idx, _)| idx)28.unwrap_or(str_val.len())29},30_ => {31// End after the nth codepoint from the end.32str_val33.char_indices()34.rev()35.nth((-n - 1) as usize)36.map(|(idx, _)| idx)37.unwrap_or(0)38},39}40}4142fn tail_binary(opt_str_val: Option<&str>, opt_n: Option<i64>) -> Option<&str> {43if let (Some(str_val), Some(n)) = (opt_str_val, opt_n) {44let start_idx = tail_binary_values(str_val, n);45Some(unsafe { str_val.get_unchecked(start_idx..) })46} else {47None48}49}5051fn tail_binary_values(str_val: &str, n: i64) -> usize {52// `max_len` is guaranteed to be at least the total number of characters.53let max_len = str_val.len();5455match n.cmp(&0) {56Ordering::Equal => max_len,57Ordering::Greater => {58if n as usize >= max_len {59return 0;60}61// Start from nth codepoint from the end62str_val63.char_indices()64.rev()65.nth((n - 1) as usize)66.map(|(idx, _)| idx)67.unwrap_or(0)68},69_ => {70// Start after the nth codepoint71str_val72.char_indices()73.nth((-n) as usize)74.map(|(idx, _)| idx)75.unwrap_or(max_len)76},77}78}7980fn substring_ternary_offsets(81opt_str_val: Option<&str>,82opt_offset: Option<i64>,83opt_length: Option<u64>,84) -> Option<(usize, usize)> {85let str_val = opt_str_val?;86let offset = opt_offset?;87Some(substring_ternary_offsets_value(88str_val,89offset,90opt_length.unwrap_or(u64::MAX),91))92}9394pub fn substring_ternary_offsets_value(str_val: &str, offset: i64, length: u64) -> (usize, usize) {95// Fast-path: always empty string.96if length == 0 || offset >= str_val.len() as i64 {97return (0, 0);98}99100let mut indices = str_val.char_indices().map(|(o, _)| o);101let mut length_reduction = 0;102let start_byte_offset = if offset >= 0 {103indices.nth(offset as usize).unwrap_or(str_val.len())104} else {105// If `offset` is negative, it counts from the end of the string.106let mut chars_skipped = 0;107let found = indices108.inspect(|_| chars_skipped += 1)109.nth_back((-offset - 1) as usize);110111// If we didn't find our char that means our offset was so negative it112// is before the start of our string. This means our length must be113// reduced, assuming it is finite.114if let Some(off) = found {115off116} else {117length_reduction = (-offset) as usize - chars_skipped;1180119}120};121122let str_val = &str_val[start_byte_offset..];123let mut indices = str_val.char_indices().map(|(o, _)| o);124let stop_byte_offset = indices125.nth((length as usize).saturating_sub(length_reduction))126.unwrap_or(str_val.len());127(start_byte_offset, stop_byte_offset + start_byte_offset)128}129130fn substring_ternary(131opt_str_val: Option<&str>,132opt_offset: Option<i64>,133opt_length: Option<u64>,134) -> Option<&str> {135let (start, end) = substring_ternary_offsets(opt_str_val, opt_offset, opt_length)?;136unsafe { opt_str_val.map(|str_val| str_val.get_unchecked(start..end)) }137}138139pub fn update_view(mut view: View, start: usize, end: usize, val: &str) -> View {140let length = (end - start) as u32;141view.length = length;142143// SAFETY: we just compute the start /end.144let subval = unsafe { val.get_unchecked(start..end).as_bytes() };145146if length <= 12 {147View::new_inline(subval)148} else {149view.offset += start as u32;150view.length = length;151view.prefix = u32::from_le_bytes(subval[0..4].try_into().unwrap());152view153}154}155156pub(super) fn substring(157ca: &StringChunked,158offset: &Int64Chunked,159length: &UInt64Chunked,160) -> StringChunked {161match (ca.len(), offset.len(), length.len()) {162(1, 1, _) => {163let str_val = ca.get(0);164let offset = offset.get(0);165unary_elementwise(length, |length| substring_ternary(str_val, offset, length))166.with_name(ca.name().clone())167},168(_, 1, 1) => {169let offset = offset.get(0);170let length = length.get(0).unwrap_or(u64::MAX);171172let Some(offset) = offset else {173return StringChunked::full_null(ca.name().clone(), ca.len());174};175176unsafe {177ca.apply_views(|view, val| {178let (start, end) = substring_ternary_offsets_value(val, offset, length);179update_view(view, start, end, val)180})181}182},183(1, _, 1) => {184let str_val = ca.get(0);185let length = length.get(0);186unary_elementwise(offset, |offset| substring_ternary(str_val, offset, length))187.with_name(ca.name().clone())188},189(1, len_b, len_c) if len_b == len_c => {190let str_val = ca.get(0);191binary_elementwise(offset, length, |offset, length| {192substring_ternary(str_val, offset, length)193})194},195(len_a, 1, len_c) if len_a == len_c => {196fn infer<F: for<'a> FnMut(Option<&'a str>, Option<u64>) -> Option<&'a str>>(f: F) -> F where197{198f199}200let offset = offset.get(0);201binary_elementwise(202ca,203length,204infer(|str_val, length| substring_ternary(str_val, offset, length)),205)206},207(len_a, len_b, 1) if len_a == len_b => {208fn infer<F: for<'a> FnMut(Option<&'a str>, Option<i64>) -> Option<&'a str>>(f: F) -> F where209{210f211}212let length = length.get(0);213binary_elementwise(214ca,215offset,216infer(|str_val, offset| substring_ternary(str_val, offset, length)),217)218},219_ => ternary_elementwise(ca, offset, length, substring_ternary),220}221}222223pub(super) fn head(ca: &StringChunked, n: &Int64Chunked) -> PolarsResult<StringChunked> {224match (ca.len(), n.len()) {225(len, 1) => {226let n = n.get(0);227let Some(n) = n else {228return Ok(StringChunked::full_null(ca.name().clone(), len));229};230231Ok(unsafe {232ca.apply_views(|view, val| {233let end = head_binary_values(val, n);234update_view(view, 0, end, val)235})236})237},238// TODO! below should also work on only views239(1, _) => {240let str_val = ca.get(0);241Ok(unary_elementwise(n, |n| head_binary(str_val, n)).with_name(ca.name().clone()))242},243(a, b) => {244polars_ensure!(a == b, ShapeMismatch: "lengths of arguments do not align in 'str.head' got length: {} for column: {}, got length: {} for argument 'n'", a, ca.name(), b);245Ok(binary_elementwise(ca, n, head_binary))246},247}248}249250pub(super) fn tail(ca: &StringChunked, n: &Int64Chunked) -> PolarsResult<StringChunked> {251Ok(match (ca.len(), n.len()) {252(len, 1) => {253let n = n.get(0);254let Some(n) = n else {255return Ok(StringChunked::full_null(ca.name().clone(), len));256};257unsafe {258ca.apply_views(|view, val| {259let start = tail_binary_values(val, n);260update_view(view, start, val.len(), val)261})262}263},264// TODO! below should also work on only views265(1, _) => {266let str_val = ca.get(0);267unary_elementwise(n, |n| tail_binary(str_val, n)).with_name(ca.name().clone())268},269(a, b) => {270polars_ensure!(a == b, ShapeMismatch: "lengths of arguments do not align in 'str.tail' got length: {} for column: {}, got length: {} for argument 'n'", a, ca.name(), b);271binary_elementwise(ca, n, tail_binary)272},273})274}275276277