Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/chunked_array/repeat_by.rs
6939 views
1
use arrow::array::builder::{ArrayBuilder, ShareStrategy, make_builder};
2
use arrow::array::{Array, IntoBoxedArray, ListArray, NullArray};
3
use arrow::bitmap::BitmapBuilder;
4
use arrow::offset::Offsets;
5
use arrow::pushable::Pushable;
6
use polars_core::prelude::*;
7
use polars_core::with_match_physical_numeric_polars_type;
8
9
/// Arrow list array with 64-bit (`i64`) offsets; the list layout built by the
/// `repeat_by_*` kernels in this file.
type LargeListArray = ListArray<i64>;
10
11
fn check_lengths(length_srs: usize, length_by: usize) -> PolarsResult<()> {
12
polars_ensure!(
13
(length_srs == length_by) | (length_by == 1) | (length_srs == 1),
14
ShapeMismatch: "repeat_by argument and the Series should have equal length, or at least one of them should have length 1. Series length {}, by length {}",
15
length_srs, length_by
16
);
17
Ok(())
18
}
19
20
fn new_by(by: &IdxCa, len: usize) -> IdxCa {
21
if let Some(x) = by.get(0) {
22
let values = std::iter::repeat_n(x, len).collect::<Vec<IdxSize>>();
23
IdxCa::new(PlSmallStr::EMPTY, values)
24
} else {
25
IdxCa::full_null(PlSmallStr::EMPTY, len)
26
}
27
}
28
29
fn repeat_by_primitive<T>(ca: &ChunkedArray<T>, by: &IdxCa) -> PolarsResult<ListChunked>
30
where
31
T: PolarsNumericType,
32
{
33
check_lengths(ca.len(), by.len())?;
34
35
match (ca.len(), by.len()) {
36
(left_len, right_len) if left_len == right_len => {
37
Ok(arity::binary(ca, by, |arr, by| {
38
let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| {
39
opt_by.map(|by| std::iter::repeat_n(opt_v.copied(), *by as usize))
40
});
41
42
// SAFETY: length of iter is trusted.
43
unsafe {
44
LargeListArray::from_iter_primitive_trusted_len(
45
iter,
46
T::get_static_dtype().to_arrow(CompatLevel::newest()),
47
)
48
}
49
}))
50
},
51
(_, 1) => {
52
let by = new_by(by, ca.len());
53
repeat_by_primitive(ca, &by)
54
},
55
(1, _) => {
56
let new_array = ca.new_from_index(0, by.len());
57
repeat_by_primitive(&new_array, by)
58
},
59
// we have already checked the length
60
_ => unreachable!(),
61
}
62
}
63
64
fn repeat_by_bool(ca: &BooleanChunked, by: &IdxCa) -> PolarsResult<ListChunked> {
65
check_lengths(ca.len(), by.len())?;
66
67
match (ca.len(), by.len()) {
68
(left_len, right_len) if left_len == right_len => {
69
Ok(arity::binary(ca, by, |arr, by| {
70
let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| {
71
opt_by.map(|by| std::iter::repeat_n(opt_v, *by as usize))
72
});
73
74
// SAFETY: length of iter is trusted.
75
unsafe { LargeListArray::from_iter_bool_trusted_len(iter) }
76
}))
77
},
78
(_, 1) => {
79
let by = new_by(by, ca.len());
80
repeat_by_bool(ca, &by)
81
},
82
(1, _) => {
83
let new_array = ca.new_from_index(0, by.len());
84
repeat_by_bool(&new_array, by)
85
},
86
// we have already checked the length
87
_ => unreachable!(),
88
}
89
}
90
91
fn repeat_by_binary(ca: &BinaryChunked, by: &IdxCa) -> PolarsResult<ListChunked> {
92
check_lengths(ca.len(), by.len())?;
93
94
match (ca.len(), by.len()) {
95
(left_len, right_len) if left_len == right_len => {
96
Ok(arity::binary(ca, by, |arr, by| {
97
let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| {
98
opt_by.map(|by| std::iter::repeat_n(opt_v, *by as usize))
99
});
100
101
// SAFETY: length of iter is trusted.
102
unsafe { LargeListArray::from_iter_binary_trusted_len(iter, ca.len()) }
103
}))
104
},
105
(_, 1) => {
106
let by = new_by(by, ca.len());
107
repeat_by_binary(ca, &by)
108
},
109
(1, _) => {
110
let new_array = ca.new_from_index(0, by.len());
111
repeat_by_binary(&new_array, by)
112
},
113
// we have already checked the length
114
_ => unreachable!(),
115
}
116
}
117
118
fn repeat_by_list(ca: &ListChunked, by: &IdxCa) -> PolarsResult<ListChunked> {
119
check_lengths(ca.len(), by.len())?;
120
121
match (ca.len(), by.len()) {
122
(left_len, right_len) if left_len == right_len => Ok(repeat_by_generic_inner(ca, by)),
123
(_, 1) => {
124
let by = new_by(by, ca.len());
125
repeat_by_list(ca, &by)
126
},
127
(1, _) => {
128
let new_array = ca.new_from_index(0, by.len());
129
repeat_by_list(&new_array, by)
130
},
131
// we have already checked the length
132
_ => unreachable!(),
133
}
134
}
135
136
fn repeat_by_null(ca: &NullChunked, by: &IdxCa) -> PolarsResult<ListChunked> {
137
check_lengths(ca.len(), by.len())?;
138
139
match (ca.len(), by.len()) {
140
(left_len, right_len) if left_len == right_len => {
141
let arr_length = by.iter().flatten().map(|x| x as usize).sum();
142
let arr = NullArray::new(ArrowDataType::Null, arr_length);
143
144
let mut validity = BitmapBuilder::with_capacity(by.len());
145
let mut offsets = Offsets::<i64>::with_capacity(by.len());
146
for n_repeat in by.iter() {
147
validity.push(n_repeat.is_some());
148
if let Some(repeats) = n_repeat {
149
offsets.push(repeats as usize);
150
} else {
151
offsets.push_null();
152
}
153
}
154
155
let array = LargeListArray::new(
156
ListArray::<i64>::default_datatype(arr.dtype().clone()),
157
offsets.into(),
158
arr.into_boxed(),
159
validity.into_opt_validity(),
160
);
161
162
Ok(unsafe {
163
ListChunked::from_chunks_and_dtype(
164
ca.name().clone(),
165
vec![array.into_boxed()],
166
DataType::List(Box::new(DataType::Null)),
167
)
168
})
169
},
170
(_, 1) => {
171
let by = new_by(by, ca.len());
172
repeat_by_null(ca, &by)
173
},
174
(1, _) => {
175
let new_array = ca.new_from_index(0, by.len());
176
let new_array = new_array.null().unwrap();
177
repeat_by_null(new_array, by)
178
},
179
// we have already checked the length
180
_ => unreachable!(),
181
}
182
}
183
184
/// Repeats every row of a fixed-size array column `ca` by the matching count
/// in `by`, producing a list whose inner type is the array type.
///
/// The lengths must be equal, or one side must have length 1, in which case
/// it is broadcast. A null count yields a null list.
#[cfg(feature = "dtype-array")]
fn repeat_by_array(ca: &ArrayChunked, by: &IdxCa) -> PolarsResult<ListChunked> {
    check_lengths(ca.len(), by.len())?;

    if ca.len() == by.len() {
        return Ok(repeat_by_generic_inner(ca, by));
    }
    if by.len() == 1 {
        let broadcast_by = new_by(by, ca.len());
        return repeat_by_array(ca, &broadcast_by);
    }
    // Unequal lengths passed `check_lengths` with `by.len() != 1`, so `ca`
    // has exactly one row.
    let broadcast_ca = ca.new_from_index(0, by.len());
    repeat_by_array(&broadcast_ca, by)
}
202
203
/// Repeats every row of a struct column `ca` by the matching count in `by`,
/// producing a list of structs.
///
/// The lengths must be equal, or one side must have length 1, in which case
/// it is broadcast. A null count yields a null list.
#[cfg(feature = "dtype-struct")]
fn repeat_by_struct(ca: &StructChunked, by: &IdxCa) -> PolarsResult<ListChunked> {
    check_lengths(ca.len(), by.len())?;

    let n_values = ca.len();
    let n_counts = by.len();
    match (n_values == n_counts, n_counts == 1) {
        (true, _) => Ok(repeat_by_generic_inner(ca, by)),
        (false, true) => repeat_by_struct(ca, &new_by(by, n_values)),
        // Unequal lengths passed `check_lengths` with `by.len() != 1`, so
        // `ca` has exactly one row.
        (false, false) => repeat_by_struct(&ca.new_from_index(0, n_counts), by),
    }
}
221
222
fn repeat_by_generic_inner<T: PolarsDataType>(ca: &ChunkedArray<T>, by: &IdxCa) -> ListChunked {
223
let mut builder = make_builder(&ca.dtype().to_arrow(CompatLevel::newest()));
224
arity::binary(ca, by, |arr, by| {
225
let arr_length = by.iter().flatten().map(|x| *x as usize).sum();
226
builder.reserve(arr_length);
227
228
let mut validity = BitmapBuilder::with_capacity(by.len());
229
let mut offsets = Offsets::<i64>::with_capacity(by.len());
230
for (idx, n_repeat) in by.iter().enumerate() {
231
validity.push(n_repeat.is_some());
232
if let Some(repeats) = n_repeat {
233
offsets.push(*repeats as usize);
234
builder.subslice_extend_repeated(
235
arr,
236
idx,
237
1,
238
*repeats as usize,
239
ShareStrategy::Always,
240
);
241
} else {
242
offsets.push_null();
243
}
244
}
245
246
let repeated_values = builder.freeze_reset();
247
LargeListArray::new(
248
ListArray::<i64>::default_datatype(arr.dtype().clone()),
249
offsets.into(),
250
repeated_values,
251
validity.into_opt_validity(),
252
)
253
})
254
}
255
256
pub fn repeat_by(s: &Series, by: &IdxCa) -> PolarsResult<ListChunked> {
257
let s_phys = s.to_physical_repr();
258
use DataType as D;
259
let out = match s_phys.dtype() {
260
D::Null => repeat_by_null(s_phys.null().unwrap(), by),
261
D::Boolean => repeat_by_bool(s_phys.bool().unwrap(), by),
262
D::String => {
263
let ca = s_phys.str().unwrap();
264
repeat_by_binary(&ca.as_binary(), by)
265
.and_then(|ca| ca.apply_to_inner(&|s| unsafe { s.cast_unchecked(&D::String) }))
266
},
267
D::Binary => repeat_by_binary(s_phys.binary().unwrap(), by),
268
dt if dt.is_primitive_numeric() => {
269
with_match_physical_numeric_polars_type!(dt, |$T| {
270
let ca: &ChunkedArray<$T> = s_phys.as_ref().as_ref().as_ref();
271
repeat_by_primitive(ca, by)
272
})
273
},
274
D::List(_) => repeat_by_list(s_phys.list().unwrap(), by),
275
#[cfg(feature = "dtype-struct")]
276
D::Struct(_) => repeat_by_struct(s_phys.struct_().unwrap(), by),
277
#[cfg(feature = "dtype-array")]
278
D::Array(_, _) => repeat_by_array(s_phys.array().unwrap(), by),
279
_ => polars_bail!(opq = repeat_by, s.dtype()),
280
};
281
out.and_then(|ca| {
282
let logical_type = s.dtype();
283
ca.apply_to_inner(&|s| unsafe { s.from_physical_unchecked(logical_type) })
284
})
285
}
286
287