Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/horizontal_flatten/mod.rs
6939 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
use arrow::array::{
3
Array, ArrayCollectIterExt, BinaryArray, BinaryViewArray, BooleanArray, FixedSizeListArray,
4
ListArray, NullArray, PrimitiveArray, StaticArray, StructArray, Utf8ViewArray,
5
};
6
use arrow::bitmap::Bitmap;
7
use arrow::datatypes::{ArrowDataType, PhysicalType};
8
use arrow::with_match_primitive_type_full;
9
use strength_reduce::StrengthReducedUsize;
10
mod struct_;
11
12
/// Low-level operation used by `concat_arr`. This should be called with the inner values array of
13
/// every FixedSizeList array.
14
///
15
/// # Safety
16
/// * `arrays` is non-empty
17
/// * `arrays` and `widths` have equal length
18
/// * All widths in `widths` are non-zero
19
/// * Every array `arrays[i]` has a length of either
20
/// * `widths[i] * output_height`
21
/// * `widths[i]` (this would be broadcasted)
22
/// * All arrays in `arrays` have the same type
23
pub unsafe fn horizontal_flatten_unchecked(
24
arrays: &[Box<dyn Array>],
25
widths: &[usize],
26
output_height: usize,
27
) -> Box<dyn Array> {
28
use PhysicalType::*;
29
30
let dtype = arrays[0].dtype();
31
32
match dtype.to_physical_type() {
33
Null => Box::new(NullArray::new(
34
dtype.clone(),
35
output_height * widths.iter().copied().sum::<usize>(),
36
)),
37
Boolean => Box::new(horizontal_flatten_unchecked_impl_generic(
38
&arrays
39
.iter()
40
.map(|x| x.as_any().downcast_ref::<BooleanArray>().unwrap().clone())
41
.collect::<Vec<_>>(),
42
widths,
43
output_height,
44
dtype,
45
)),
46
Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| {
47
Box::new(horizontal_flatten_unchecked_impl_generic(
48
&arrays
49
.iter()
50
.map(|x| x.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap().clone())
51
.collect::<Vec<_>>(),
52
widths,
53
output_height,
54
dtype
55
))
56
}),
57
LargeBinary => Box::new(horizontal_flatten_unchecked_impl_generic(
58
&arrays
59
.iter()
60
.map(|x| {
61
x.as_any()
62
.downcast_ref::<BinaryArray<i64>>()
63
.unwrap()
64
.clone()
65
})
66
.collect::<Vec<_>>(),
67
widths,
68
output_height,
69
dtype,
70
)),
71
Struct => Box::new(struct_::horizontal_flatten_unchecked(
72
&arrays
73
.iter()
74
.map(|x| x.as_any().downcast_ref::<StructArray>().unwrap().clone())
75
.collect::<Vec<_>>(),
76
widths,
77
output_height,
78
)),
79
LargeList => Box::new(horizontal_flatten_unchecked_impl_generic(
80
&arrays
81
.iter()
82
.map(|x| x.as_any().downcast_ref::<ListArray<i64>>().unwrap().clone())
83
.collect::<Vec<_>>(),
84
widths,
85
output_height,
86
dtype,
87
)),
88
FixedSizeList => Box::new(horizontal_flatten_unchecked_impl_generic(
89
&arrays
90
.iter()
91
.map(|x| {
92
x.as_any()
93
.downcast_ref::<FixedSizeListArray>()
94
.unwrap()
95
.clone()
96
})
97
.collect::<Vec<_>>(),
98
widths,
99
output_height,
100
dtype,
101
)),
102
BinaryView => Box::new(horizontal_flatten_unchecked_impl_generic(
103
&arrays
104
.iter()
105
.map(|x| {
106
x.as_any()
107
.downcast_ref::<BinaryViewArray>()
108
.unwrap()
109
.clone()
110
})
111
.collect::<Vec<_>>(),
112
widths,
113
output_height,
114
dtype,
115
)),
116
Utf8View => Box::new(horizontal_flatten_unchecked_impl_generic(
117
&arrays
118
.iter()
119
.map(|x| x.as_any().downcast_ref::<Utf8ViewArray>().unwrap().clone())
120
.collect::<Vec<_>>(),
121
widths,
122
output_height,
123
dtype,
124
)),
125
t => unimplemented!("horizontal_flatten not supported for data type {:?}", t),
126
}
127
}
128
129
unsafe fn horizontal_flatten_unchecked_impl_generic<T>(
130
arrays: &[T],
131
widths: &[usize],
132
output_height: usize,
133
dtype: &ArrowDataType,
134
) -> T
135
where
136
T: StaticArray,
137
{
138
assert!(!arrays.is_empty());
139
assert_eq!(widths.len(), arrays.len());
140
141
debug_assert!(widths.iter().all(|x| *x > 0));
142
debug_assert!(
143
arrays
144
.iter()
145
.zip(widths)
146
.all(|(arr, width)| arr.len() == output_height * *width || arr.len() == *width)
147
);
148
149
// We modulo the array length to support broadcasting.
150
let lengths = arrays
151
.iter()
152
.map(|x| StrengthReducedUsize::new(x.len()))
153
.collect::<Vec<_>>();
154
let out_row_width: usize = widths.iter().cloned().sum();
155
let out_len = out_row_width.checked_mul(output_height).unwrap();
156
157
let mut col_idx = 0;
158
let mut row_idx = 0;
159
let mut until = widths[0];
160
let mut outer_row_idx = 0;
161
162
// We do `0..out_len` to get an `ExactSizeIterator`.
163
(0..out_len)
164
.map(|_| {
165
let arr = arrays.get_unchecked(col_idx);
166
let out = arr.get_unchecked(row_idx % *lengths.get_unchecked(col_idx));
167
168
row_idx += 1;
169
170
if row_idx == until {
171
// Safety: All widths are non-zero so we only need to increment once.
172
col_idx = if 1 + col_idx == widths.len() {
173
outer_row_idx += 1;
174
0
175
} else {
176
1 + col_idx
177
};
178
row_idx = outer_row_idx * *widths.get_unchecked(col_idx);
179
until = (1 + outer_row_idx) * *widths.get_unchecked(col_idx)
180
}
181
182
out
183
})
184
.collect_arr_trusted_with_dtype(dtype.clone())
185
}
186
187