Path: blob/main/crates/polars-ops/src/series/ops/concat_arr.rs
6939 views
use arrow::array::FixedSizeListArray;1use arrow::compute::utils::combine_validities_and;2use polars_compute::horizontal_flatten::horizontal_flatten_unchecked;3use polars_core::prelude::{ArrayChunked, Column, CompatLevel, DataType, IntoColumn};4use polars_core::series::Series;5use polars_error::{PolarsResult, polars_bail};6use polars_utils::pl_str::PlSmallStr;78/// Note: The caller must ensure all columns in `args` have the same type.9///10/// # Panics11/// Panics if12/// * `args` is empty13/// * `dtype` is not a `DataType::Array`14pub fn concat_arr(args: &[Column], dtype: &DataType) -> PolarsResult<Column> {15let DataType::Array(inner_dtype, width) = dtype else {16panic!("{}", dtype);17};1819let inner_dtype = inner_dtype.as_ref();20let width = *width;2122let mut output_height = args[0].len();23let mut calculated_width = 0;24let mut mismatch_height = (&PlSmallStr::EMPTY, output_height);25// If there is a `Array` column with a single NULL, the output will be entirely NULL.26let mut return_all_null = false;27// Indicates whether all `arrays` have unit length (excluding zero-width arrays)28let mut all_unit_len = true;29let mut validities = Vec::with_capacity(args.len());3031let (arrays, widths): (Vec<_>, Vec<_>) = args32.iter()33.map(|c| {34let len = c.len();3536// Handle broadcasting37if output_height == 1 {38output_height = len;39mismatch_height.1 = len;40}4142if len != output_height && len != 1 && mismatch_height.1 == output_height {43mismatch_height = (c.name(), len);44}4546// Don't expand scalars to height, this is handled by the `horizontal_flatten` kernel.47let s = c.as_materialized_series_maintain_scalar();4849match s.dtype() {50DataType::Array(inner, width) => {51debug_assert_eq!(inner.as_ref(), inner_dtype);5253let arr = s.array().unwrap().rechunk();54let validity = arr.rechunk_validity();5556return_all_null |= len == 1 && validity.as_ref().is_some_and(|x| !x.get_bit(0));5758// Ignore unit-length validities. If they are non-valid then `return_all_null` will59// cause an early return.60if let Some(v) = validity.filter(|_| len > 1) {61validities.push(v)62}6364(arr.downcast_as_array().values().clone(), *width)65},66dtype => {67debug_assert_eq!(dtype, inner_dtype);68// Note: We ignore the validity of non-array input columns, their outer is always valid after69// being reshaped to (-1, 1).70(s.rechunk().into_chunks()[0].clone(), 1)71},72}73})74// Filter out zero-width75.filter(|x| x.1 > 0)76.inspect(|x| {77calculated_width += x.1;78all_unit_len &= x.0.len() == 1;79})80.unzip();8182assert_eq!(calculated_width, width);8384if mismatch_height.1 != output_height {85polars_bail!(86ShapeMismatch:87"concat_arr: length of column '{}' (len={}) did not match length of \88first column '{}' (len={})",89mismatch_height.0, mismatch_height.1, args[0].name(), output_height,90)91}9293if return_all_null || output_height == 0 {94let arr =95FixedSizeListArray::new_null(dtype.to_arrow(CompatLevel::newest()), output_height);96return Ok(ArrayChunked::with_chunk(args[0].name().clone(), arr).into_column());97}9899// Combine validities100let outer_validity = validities.into_iter().fold(None, |a, b| {101debug_assert_eq!(b.len(), output_height);102combine_validities_and(a.as_ref(), Some(&b))103});104105// At this point the output height and all arrays should have non-zero length106let out = if all_unit_len && width > 0 {107// Fast-path for all scalars108let inner_arr = unsafe { horizontal_flatten_unchecked(&arrays, &widths, 1) };109110let arr = FixedSizeListArray::new(111dtype.to_arrow(CompatLevel::newest()),1121,113inner_arr,114outer_validity,115);116117return Ok(ArrayChunked::with_chunk(args[0].name().clone(), arr)118.into_column()119.new_from_index(0, output_height));120} else {121let inner_arr = if width == 0 {122Series::new_empty(PlSmallStr::EMPTY, inner_dtype)123.into_chunks()124.into_iter()125.next()126.unwrap()127} else {128unsafe { horizontal_flatten_unchecked(&arrays, &widths, output_height) }129};130131let arr = FixedSizeListArray::new(132dtype.to_arrow(CompatLevel::newest()),133output_height,134inner_arr,135outer_validity,136);137ArrayChunked::with_chunk(args[0].name().clone(), arr).into_column()138};139140Ok(out)141}142143144