// Path: crates/polars-ops/src/chunked_array/strings/concat.rs (main branch)
use arrow::array::{Utf8Array, ValueSize};1use polars_compute::cast::utf8_to_utf8view;2use polars_core::prelude::arity::unary_elementwise;3use polars_core::prelude::*;45// Vertically concatenate all strings in a StringChunked.6pub fn str_join(ca: &StringChunked, delimiter: &str, ignore_nulls: bool) -> StringChunked {7if ca.is_empty() {8return StringChunked::new(ca.name().clone(), &[""]);9}1011// Propagate null value.12if !ignore_nulls && ca.null_count() != 0 {13return StringChunked::full_null(ca.name().clone(), 1);14}1516// Fast path for all nulls.17if ignore_nulls && ca.null_count() == ca.len() {18return StringChunked::new(ca.name().clone(), &[""]);19}2021if ca.len() == 1 {22return ca.clone();23}2425// Calculate capacity.26let capacity = ca.get_values_size() + delimiter.len() * (ca.len() - 1);2728let mut buf = String::with_capacity(capacity);29let mut first = true;30ca.for_each(|val| {31if let Some(val) = val {32if !first {33buf.push_str(delimiter);34}35buf.push_str(val);36first = false;37}38});3940let buf = buf.into_bytes();41assert!(capacity >= buf.len());42let offsets = vec![0, buf.len() as i64];43let arr = unsafe { Utf8Array::from_data_unchecked_default(offsets.into(), buf.into(), None) };44// conversion is cheap with one value.45let arr = utf8_to_utf8view(&arr);46StringChunked::with_chunk(ca.name().clone(), arr)47}4849enum ColumnIter<I, T> {50Iter(I),51Broadcast(T),52}5354/// Horizontally concatenate all strings.55///56/// Each array should have length 1 or a length equal to the maximum length.57pub fn hor_str_concat(58cas: &[&StringChunked],59delimiter: &str,60ignore_nulls: bool,61) -> PolarsResult<StringChunked> {62if cas.is_empty() {63return Ok(StringChunked::full_null(PlSmallStr::EMPTY, 0));64}65if cas.len() == 1 {66let ca = cas[0];67return if !ignore_nulls || ca.null_count() == 0 {68Ok(ca.clone())69} else {70Ok(unary_elementwise(ca, |val| Some(val.unwrap_or(""))))71};72}7374// Calculate the post-broadcast length and ensure everything is consistent.75let 
len = cas76.iter()77.map(|ca| ca.len())78.filter(|l| *l != 1)79.max()80.unwrap_or(1);81polars_ensure!(82cas.iter().all(|ca| ca.len() == 1 || ca.len() == len),83ShapeMismatch: "all series in `hor_str_concat` should have equal or unit length"84);8586let mut builder = StringChunkedBuilder::new(cas[0].name().clone(), len);8788// Broadcast if appropriate.89let mut cols: Vec<_> = cas90.iter()91.map(|ca| match ca.len() {920 => ColumnIter::Broadcast(None),931 => ColumnIter::Broadcast(ca.get(0)),94_ => ColumnIter::Iter(ca.iter()),95})96.collect();9798// Build concatenated string.99let mut buf = String::with_capacity(1024);100for _row in 0..len {101let mut has_null = false;102let mut found_not_null_value = false;103for col in cols.iter_mut() {104let val = match col {105ColumnIter::Iter(i) => i.next().unwrap(),106ColumnIter::Broadcast(s) => *s,107};108109if has_null && !ignore_nulls {110// We know that the result must be null, but we can't just break out of the loop,111// because all cols iterator has to be moved correctly.112continue;113}114115if let Some(s) = val {116if found_not_null_value {117buf.push_str(delimiter);118}119buf.push_str(s);120found_not_null_value = true;121} else {122has_null = true;123}124}125126if !ignore_nulls && has_null {127builder.append_null();128} else {129builder.append_value(&buf)130}131buf.clear();132}133134Ok(builder.finish())135}136137#[cfg(test)]138mod test {139use super::*;140141#[test]142fn test_str_concat() {143let ca = Int32Chunked::new("foo".into(), &[Some(1), None, Some(3)]);144let ca_str = ca.cast(&DataType::String).unwrap();145let out = str_join(ca_str.str().unwrap(), "-", true);146147let out = out.get(0);148assert_eq!(out, Some("1-3"));149}150151#[test]152fn test_hor_str_concat() {153let a = StringChunked::new("a".into(), &["foo", "bar"]);154let b = StringChunked::new("b".into(), &["spam", "ham"]);155156let out = hor_str_concat(&[&a, &b], "_", true).unwrap();157assert_eq!(Vec::from(&out), &[Some("foo_spam"), 
Some("bar_ham")]);158159let c = StringChunked::new("b".into(), &["literal"]);160let out = hor_str_concat(&[&a, &b, &c], "_", true).unwrap();161assert_eq!(162Vec::from(&out),163&[Some("foo_spam_literal"), Some("bar_ham_literal")]164);165}166}167168169