Path: blob/main/crates/polars-compute/src/horizontal_flatten/mod.rs
6939 views
#![allow(unsafe_op_in_unsafe_fn)]1use arrow::array::{2Array, ArrayCollectIterExt, BinaryArray, BinaryViewArray, BooleanArray, FixedSizeListArray,3ListArray, NullArray, PrimitiveArray, StaticArray, StructArray, Utf8ViewArray,4};5use arrow::bitmap::Bitmap;6use arrow::datatypes::{ArrowDataType, PhysicalType};7use arrow::with_match_primitive_type_full;8use strength_reduce::StrengthReducedUsize;9mod struct_;1011/// Low-level operation used by `concat_arr`. This should be called with the inner values array of12/// every FixedSizeList array.13///14/// # Safety15/// * `arrays` is non-empty16/// * `arrays` and `widths` have equal length17/// * All widths in `widths` are non-zero18/// * Every array `arrays[i]` has a length of either19/// * `widths[i] * output_height`20/// * `widths[i]` (this would be broadcasted)21/// * All arrays in `arrays` have the same type22pub unsafe fn horizontal_flatten_unchecked(23arrays: &[Box<dyn Array>],24widths: &[usize],25output_height: usize,26) -> Box<dyn Array> {27use PhysicalType::*;2829let dtype = arrays[0].dtype();3031match dtype.to_physical_type() {32Null => Box::new(NullArray::new(33dtype.clone(),34output_height * widths.iter().copied().sum::<usize>(),35)),36Boolean => Box::new(horizontal_flatten_unchecked_impl_generic(37&arrays38.iter()39.map(|x| x.as_any().downcast_ref::<BooleanArray>().unwrap().clone())40.collect::<Vec<_>>(),41widths,42output_height,43dtype,44)),45Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| {46Box::new(horizontal_flatten_unchecked_impl_generic(47&arrays48.iter()49.map(|x| x.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap().clone())50.collect::<Vec<_>>(),51widths,52output_height,53dtype54))55}),56LargeBinary => Box::new(horizontal_flatten_unchecked_impl_generic(57&arrays58.iter()59.map(|x| {60x.as_any()61.downcast_ref::<BinaryArray<i64>>()62.unwrap()63.clone()64})65.collect::<Vec<_>>(),66widths,67output_height,68dtype,69)),70Struct => Box::new(struct_::horizontal_flatten_unchecked(71&arrays72.iter()73.map(|x| x.as_any().downcast_ref::<StructArray>().unwrap().clone())74.collect::<Vec<_>>(),75widths,76output_height,77)),78LargeList => Box::new(horizontal_flatten_unchecked_impl_generic(79&arrays80.iter()81.map(|x| x.as_any().downcast_ref::<ListArray<i64>>().unwrap().clone())82.collect::<Vec<_>>(),83widths,84output_height,85dtype,86)),87FixedSizeList => Box::new(horizontal_flatten_unchecked_impl_generic(88&arrays89.iter()90.map(|x| {91x.as_any()92.downcast_ref::<FixedSizeListArray>()93.unwrap()94.clone()95})96.collect::<Vec<_>>(),97widths,98output_height,99dtype,100)),101BinaryView => Box::new(horizontal_flatten_unchecked_impl_generic(102&arrays103.iter()104.map(|x| {105x.as_any()106.downcast_ref::<BinaryViewArray>()107.unwrap()108.clone()109})110.collect::<Vec<_>>(),111widths,112output_height,113dtype,114)),115Utf8View => Box::new(horizontal_flatten_unchecked_impl_generic(116&arrays117.iter()118.map(|x| x.as_any().downcast_ref::<Utf8ViewArray>().unwrap().clone())119.collect::<Vec<_>>(),120widths,121output_height,122dtype,123)),124t => unimplemented!("horizontal_flatten not supported for data type {:?}", t),125}126}127128unsafe fn horizontal_flatten_unchecked_impl_generic<T>(129arrays: &[T],130widths: &[usize],131output_height: usize,132dtype: &ArrowDataType,133) -> T134where135T: StaticArray,136{137assert!(!arrays.is_empty());138assert_eq!(widths.len(), arrays.len());139140debug_assert!(widths.iter().all(|x| *x > 0));141debug_assert!(142arrays143.iter()144.zip(widths)145.all(|(arr, width)| arr.len() == output_height * *width || arr.len() == *width)146);147148// We modulo the array length to support broadcasting.149let lengths = arrays150.iter()151.map(|x| StrengthReducedUsize::new(x.len()))152.collect::<Vec<_>>();153let out_row_width: usize = widths.iter().cloned().sum();154let out_len = out_row_width.checked_mul(output_height).unwrap();155156let mut col_idx = 0;157let mut row_idx = 0;158let mut until = widths[0];159let mut outer_row_idx = 0;160161// We do `0..out_len` to get an `ExactSizeIterator`.162(0..out_len)163.map(|_| {164let arr = arrays.get_unchecked(col_idx);165let out = arr.get_unchecked(row_idx % *lengths.get_unchecked(col_idx));166167row_idx += 1;168169if row_idx == until {170// Safety: All widths are non-zero so we only need to increment once.171col_idx = if 1 + col_idx == widths.len() {172outer_row_idx += 1;1730174} else {1751 + col_idx176};177row_idx = outer_row_idx * *widths.get_unchecked(col_idx);178until = (1 + outer_row_idx) * *widths.get_unchecked(col_idx)179}180181out182})183.collect_arr_trusted_with_dtype(dtype.clone())184}185186187