Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/series/ops/concat_arr.rs
6939 views
1
use arrow::array::FixedSizeListArray;
2
use arrow::compute::utils::combine_validities_and;
3
use polars_compute::horizontal_flatten::horizontal_flatten_unchecked;
4
use polars_core::prelude::{ArrayChunked, Column, CompatLevel, DataType, IntoColumn};
5
use polars_core::series::Series;
6
use polars_error::{PolarsResult, polars_bail};
7
use polars_utils::pl_str::PlSmallStr;
8
9
/// Note: The caller must ensure all columns in `args` have the same type.
10
///
11
/// # Panics
12
/// Panics if
13
/// * `args` is empty
14
/// * `dtype` is not a `DataType::Array`
15
pub fn concat_arr(args: &[Column], dtype: &DataType) -> PolarsResult<Column> {
16
let DataType::Array(inner_dtype, width) = dtype else {
17
panic!("{}", dtype);
18
};
19
20
let inner_dtype = inner_dtype.as_ref();
21
let width = *width;
22
23
let mut output_height = args[0].len();
24
let mut calculated_width = 0;
25
let mut mismatch_height = (&PlSmallStr::EMPTY, output_height);
26
// If there is a `Array` column with a single NULL, the output will be entirely NULL.
27
let mut return_all_null = false;
28
// Indicates whether all `arrays` have unit length (excluding zero-width arrays)
29
let mut all_unit_len = true;
30
let mut validities = Vec::with_capacity(args.len());
31
32
let (arrays, widths): (Vec<_>, Vec<_>) = args
33
.iter()
34
.map(|c| {
35
let len = c.len();
36
37
// Handle broadcasting
38
if output_height == 1 {
39
output_height = len;
40
mismatch_height.1 = len;
41
}
42
43
if len != output_height && len != 1 && mismatch_height.1 == output_height {
44
mismatch_height = (c.name(), len);
45
}
46
47
// Don't expand scalars to height, this is handled by the `horizontal_flatten` kernel.
48
let s = c.as_materialized_series_maintain_scalar();
49
50
match s.dtype() {
51
DataType::Array(inner, width) => {
52
debug_assert_eq!(inner.as_ref(), inner_dtype);
53
54
let arr = s.array().unwrap().rechunk();
55
let validity = arr.rechunk_validity();
56
57
return_all_null |= len == 1 && validity.as_ref().is_some_and(|x| !x.get_bit(0));
58
59
// Ignore unit-length validities. If they are non-valid then `return_all_null` will
60
// cause an early return.
61
if let Some(v) = validity.filter(|_| len > 1) {
62
validities.push(v)
63
}
64
65
(arr.downcast_as_array().values().clone(), *width)
66
},
67
dtype => {
68
debug_assert_eq!(dtype, inner_dtype);
69
// Note: We ignore the validity of non-array input columns, their outer is always valid after
70
// being reshaped to (-1, 1).
71
(s.rechunk().into_chunks()[0].clone(), 1)
72
},
73
}
74
})
75
// Filter out zero-width
76
.filter(|x| x.1 > 0)
77
.inspect(|x| {
78
calculated_width += x.1;
79
all_unit_len &= x.0.len() == 1;
80
})
81
.unzip();
82
83
assert_eq!(calculated_width, width);
84
85
if mismatch_height.1 != output_height {
86
polars_bail!(
87
ShapeMismatch:
88
"concat_arr: length of column '{}' (len={}) did not match length of \
89
first column '{}' (len={})",
90
mismatch_height.0, mismatch_height.1, args[0].name(), output_height,
91
)
92
}
93
94
if return_all_null || output_height == 0 {
95
let arr =
96
FixedSizeListArray::new_null(dtype.to_arrow(CompatLevel::newest()), output_height);
97
return Ok(ArrayChunked::with_chunk(args[0].name().clone(), arr).into_column());
98
}
99
100
// Combine validities
101
let outer_validity = validities.into_iter().fold(None, |a, b| {
102
debug_assert_eq!(b.len(), output_height);
103
combine_validities_and(a.as_ref(), Some(&b))
104
});
105
106
// At this point the output height and all arrays should have non-zero length
107
let out = if all_unit_len && width > 0 {
108
// Fast-path for all scalars
109
let inner_arr = unsafe { horizontal_flatten_unchecked(&arrays, &widths, 1) };
110
111
let arr = FixedSizeListArray::new(
112
dtype.to_arrow(CompatLevel::newest()),
113
1,
114
inner_arr,
115
outer_validity,
116
);
117
118
return Ok(ArrayChunked::with_chunk(args[0].name().clone(), arr)
119
.into_column()
120
.new_from_index(0, output_height));
121
} else {
122
let inner_arr = if width == 0 {
123
Series::new_empty(PlSmallStr::EMPTY, inner_dtype)
124
.into_chunks()
125
.into_iter()
126
.next()
127
.unwrap()
128
} else {
129
unsafe { horizontal_flatten_unchecked(&arrays, &widths, output_height) }
130
};
131
132
let arr = FixedSizeListArray::new(
133
dtype.to_arrow(CompatLevel::newest()),
134
output_height,
135
inner_arr,
136
outer_validity,
137
);
138
ArrayChunked::with_chunk(args[0].name().clone(), arr).into_column()
139
};
140
141
Ok(out)
142
}
143
144