Path: blob/main/crates/polars-ops/src/chunked_array/strings/substring.rs
8374 views
use arrow::array::View;1use polars_core::prelude::arity::{binary_elementwise, ternary_elementwise, unary_elementwise};2use polars_core::prelude::{ChunkFullNull, Int64Chunked, StringChunked, UInt64Chunked};3use polars_error::{PolarsResult, polars_ensure};45fn is_utf8_codepoint_start(b: u8) -> bool {6// The top two bits of a continuation byte are 10. Any other value is a7// starting byte. We can use signed comparison to test for this in one8// instruction, as the top bits 11, 00 and 01 are all more positive and thus9// larger in signed comparison.10(b as i8) >= (0b1100_0000_u8 as i8)11}1213/// Similar to char_to_byte_idx but if `char_idx` would be out-of-bounds the14/// number of codepoints in s is returned as an error.15pub fn char_to_byte_idx_or_cp_count(s: &str, char_idx: usize) -> Result<usize, usize> {16let bytes = s.as_bytes();17if char_idx == 0 {18return Ok(0);19}2021let mut offset = 0;22let mut num_chars_seen = 0;2324// Auto-vectorized bulk processing, but skip if index is small.25if char_idx >= 16 {26while let Some(chunk) = bytes.get(offset..offset + 16) {27let chunk_seen: usize = chunk28.iter()29.map(|b| is_utf8_codepoint_start(*b) as usize)30.sum();31if num_chars_seen + chunk_seen > char_idx {32break;33}34offset += 16;35num_chars_seen += chunk_seen;36}37}3839while let Some(b) = bytes.get(offset) {40num_chars_seen += is_utf8_codepoint_start(*b) as usize;41if num_chars_seen > char_idx {42return Ok(offset);43}44offset += 1;45}4647debug_assert!(offset == bytes.len());48Err(num_chars_seen)49}5051/// Given an offset to the start of the `char_idx`th codepoint, returns the52/// equivalent offset in bytes.53///54/// If `char_idx` would be out-of-bounds s.len() is returned.55pub fn char_to_byte_idx(s: &str, char_idx: usize) -> usize {56if char_idx >= s.len() {57// No need to even count.58s.len()59} else {60char_to_byte_idx_or_cp_count(s, char_idx).unwrap_or(s.len())61}62}6364/// Similar to rev_char_to_byte_idx but if `char_idx` would be out-of-bounds the65/// number of codepoints in s is returned as an error.66pub fn rev_char_to_byte_idx_or_cp_count(s: &str, rev_char_idx: usize) -> Result<usize, usize> {67let bytes = s.as_bytes();68if rev_char_idx == 0 {69return Ok(bytes.len());70}7172let mut offset = s.len();73let mut num_chars_seen = 0;7475// Auto-vectorized bulk processing, but skip if index is small.76if rev_char_idx >= 16 {77while offset >= 16 {78let chunk = unsafe { bytes.get_unchecked(offset - 16..offset) };79let chunk_seen: usize = chunk80.iter()81.map(|b| is_utf8_codepoint_start(*b) as usize)82.sum();83if num_chars_seen + chunk_seen >= rev_char_idx {84break;85}86offset -= 16;87num_chars_seen += chunk_seen;88}89}9091while offset > 0 {92offset -= 1;93let byte = unsafe { bytes.get_unchecked(offset) };94num_chars_seen += is_utf8_codepoint_start(*byte) as usize;95if num_chars_seen >= rev_char_idx {96return Ok(offset);97}98}99100debug_assert!(offset == 0);101Err(num_chars_seen)102}103104/// Counts rev_char_idx code points from *the end* of the string, returning an105/// offset in bytes where this codepoint ends.106///107/// For example, rev_char_to_byte_idx(0, s) returns s.len(), and108/// rev_char_to_byte_idx(1, s) returns s.len() - width(last_codepoint_in_s).109///110/// If rev_char_idx is large enough that we would go out of bounds, 0 is returned.111pub fn rev_char_to_byte_idx(s: &str, rev_char_idx: usize) -> usize {112if rev_char_idx >= s.len() {113// No need to even count.1140115} else {116rev_char_to_byte_idx_or_cp_count(s, rev_char_idx).unwrap_or(0)117}118}119120fn head_binary(opt_str_val: Option<&str>, opt_n: Option<i64>) -> Option<&str> {121if let (Some(str_val), Some(n)) = (opt_str_val, opt_n) {122let end_idx = head_binary_values(str_val, n);123Some(unsafe { str_val.get_unchecked(..end_idx) })124} else {125None126}127}128129fn head_binary_values(str_val: &str, n: i64) -> usize {130if n >= 0 {131char_to_byte_idx(str_val, n as usize)132} else {133rev_char_to_byte_idx(str_val, (-n) as usize)134}135}136137fn tail_binary(opt_str_val: Option<&str>, opt_n: Option<i64>) -> Option<&str> {138if let (Some(str_val), Some(n)) = (opt_str_val, opt_n) {139let start_idx = tail_binary_values(str_val, n);140Some(unsafe { str_val.get_unchecked(start_idx..) })141} else {142None143}144}145146fn tail_binary_values(str_val: &str, n: i64) -> usize {147if n >= 0 {148rev_char_to_byte_idx(str_val, n as usize)149} else {150char_to_byte_idx(str_val, (-n) as usize)151}152}153154fn substring_ternary_offsets(155opt_str_val: Option<&str>,156opt_offset: Option<i64>,157opt_length: Option<u64>,158) -> Option<(usize, usize)> {159let str_val = opt_str_val?;160let offset = opt_offset?;161Some(substring_ternary_offsets_value(162str_val,163offset,164opt_length.unwrap_or(u64::MAX),165))166}167168pub fn substring_ternary_offsets_value(169str_val: &str,170offset: i64,171mut length: u64,172) -> (usize, usize) {173// Fast-path: always empty string.174if length == 0 || offset >= str_val.len() as i64 {175return (0, 0);176}177178let start_byte_offset = if offset >= 0 {179char_to_byte_idx(str_val, offset as usize)180} else {181// Fast-path: always empty string.182let end_offset_upper_bound = offset183.saturating_add(str_val.len() as i64)184.saturating_add(length.try_into().unwrap_or(i64::MAX));185if end_offset_upper_bound < 0 {186return (0, 0);187}188189match rev_char_to_byte_idx_or_cp_count(str_val, (-offset) as usize) {190Ok(so) => so,191Err(n_cp) => {192// Our offset was so negative it is before the start of our string.193// This means our length must be reduced, assuming it is finite.194length = length.saturating_sub((-offset) as u64 - n_cp as u64);1950196},197}198};199200let stop_byte_offset = char_to_byte_idx(&str_val[start_byte_offset..], length as usize);201(start_byte_offset, start_byte_offset + stop_byte_offset)202}203204fn substring_ternary(205opt_str_val: Option<&str>,206opt_offset: Option<i64>,207opt_length: Option<u64>,208) -> Option<&str> {209let (start, end) = substring_ternary_offsets(opt_str_val, opt_offset, opt_length)?;210unsafe { opt_str_val.map(|str_val| str_val.get_unchecked(start..end)) }211}212213pub fn update_view(mut view: View, start: usize, end: usize, val: &str) -> View {214let length = (end - start) as u32;215view.length = length;216217// SAFETY: we just compute the start /end.218let subval = unsafe { val.get_unchecked(start..end).as_bytes() };219220if length <= 12 {221View::new_inline(subval)222} else {223view.offset += start as u32;224view.length = length;225view.prefix = u32::from_le_bytes(subval[0..4].try_into().unwrap());226view227}228}229230pub(super) fn substring(231ca: &StringChunked,232offset: &Int64Chunked,233length: &UInt64Chunked,234) -> StringChunked {235match (ca.len(), offset.len(), length.len()) {236(1, 1, _) => {237let str_val = ca.get(0);238let offset = offset.get(0);239unary_elementwise(length, |length| substring_ternary(str_val, offset, length))240.with_name(ca.name().clone())241},242(_, 1, 1) => {243let offset = offset.get(0);244let length = length.get(0).unwrap_or(u64::MAX);245246let Some(offset) = offset else {247return StringChunked::full_null(ca.name().clone(), ca.len());248};249250unsafe {251ca.apply_views(|view, val| {252let (start, end) = substring_ternary_offsets_value(val, offset, length);253update_view(view, start, end, val)254})255}256},257(1, _, 1) => {258let str_val = ca.get(0);259let length = length.get(0);260unary_elementwise(offset, |offset| substring_ternary(str_val, offset, length))261.with_name(ca.name().clone())262},263(1, len_b, len_c) if len_b == len_c => {264let str_val = ca.get(0);265binary_elementwise(offset, length, |offset, length| {266substring_ternary(str_val, offset, length)267})268},269(len_a, 1, len_c) if len_a == len_c => {270fn infer<F: for<'a> FnMut(Option<&'a str>, Option<u64>) -> Option<&'a str>>(f: F) -> F where271{272f273}274let offset = offset.get(0);275binary_elementwise(276ca,277length,278infer(|str_val, length| substring_ternary(str_val, offset, length)),279)280},281(len_a, len_b, 1) if len_a == len_b => {282fn infer<F: for<'a> FnMut(Option<&'a str>, Option<i64>) -> Option<&'a str>>(f: F) -> F where283{284f285}286let length = length.get(0);287binary_elementwise(288ca,289offset,290infer(|str_val, offset| substring_ternary(str_val, offset, length)),291)292},293_ => ternary_elementwise(ca, offset, length, substring_ternary),294}295}296297pub(super) fn head(ca: &StringChunked, n: &Int64Chunked) -> PolarsResult<StringChunked> {298match (ca.len(), n.len()) {299(len, 1) => {300let n = n.get(0);301let Some(n) = n else {302return Ok(StringChunked::full_null(ca.name().clone(), len));303};304305Ok(unsafe {306ca.apply_views(|view, val| {307let end = head_binary_values(val, n);308update_view(view, 0, end, val)309})310})311},312// TODO! below should also work on only views313(1, _) => {314let str_val = ca.get(0);315Ok(unary_elementwise(n, |n| head_binary(str_val, n)).with_name(ca.name().clone()))316},317(a, b) => {318polars_ensure!(a == b, ShapeMismatch: "lengths of arguments do not align in 'str.head' got length: {} for column: {}, got length: {} for argument 'n'", a, ca.name(), b);319Ok(binary_elementwise(ca, n, head_binary))320},321}322}323324pub(super) fn tail(ca: &StringChunked, n: &Int64Chunked) -> PolarsResult<StringChunked> {325Ok(match (ca.len(), n.len()) {326(len, 1) => {327let n = n.get(0);328let Some(n) = n else {329return Ok(StringChunked::full_null(ca.name().clone(), len));330};331unsafe {332ca.apply_views(|view, val| {333let start = tail_binary_values(val, n);334update_view(view, start, val.len(), val)335})336}337},338// TODO! below should also work on only views339(1, _) => {340let str_val = ca.get(0);341unary_elementwise(n, |n| tail_binary(str_val, n)).with_name(ca.name().clone())342},343(a, b) => {344polars_ensure!(a == b, ShapeMismatch: "lengths of arguments do not align in 'str.tail' got length: {} for column: {}, got length: {} for argument 'n'", a, ca.name(), b);345binary_elementwise(ca, n, tail_binary)346},347})348}349350351