Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/chunked_array/scatter.rs
8412 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
use arrow::array::{Array, BinaryViewArrayGeneric, BooleanArray, PrimitiveArray, View, ViewType};
3
use polars_buffer::Buffer;
4
use polars_core::prelude::*;
5
use polars_core::series::IsSorted;
6
use polars_core::utils::arrow::bitmap::MutableBitmap;
7
use polars_core::utils::arrow::types::NativeType;
8
use polars_utils::index::check_bounds;
9
10
pub trait ChunkedSet<T: Copy> {
11
/// Invariant for implementations: if the scatter() fails, typically because
12
/// of bad indexes, then self should remain unmodified.
13
fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>
14
where
15
V: IntoIterator<Item = Option<T>>;
16
}
17
18
trait PolarsOpsNumericType: PolarsNumericType {}
19
20
impl PolarsOpsNumericType for UInt8Type {}
21
impl PolarsOpsNumericType for UInt16Type {}
22
impl PolarsOpsNumericType for UInt32Type {}
23
impl PolarsOpsNumericType for UInt64Type {}
24
#[cfg(feature = "dtype-u128")]
25
impl PolarsOpsNumericType for UInt128Type {}
26
impl PolarsOpsNumericType for Int8Type {}
27
impl PolarsOpsNumericType for Int16Type {}
28
impl PolarsOpsNumericType for Int32Type {}
29
impl PolarsOpsNumericType for Int64Type {}
30
#[cfg(feature = "dtype-i128")]
31
impl PolarsOpsNumericType for Int128Type {}
32
#[cfg(feature = "dtype-f16")]
33
impl PolarsOpsNumericType for Float16Type {}
34
impl PolarsOpsNumericType for Float32Type {}
35
impl PolarsOpsNumericType for Float64Type {}
36
37
unsafe fn scatter_primitive_impl<V, T: NativeType>(
38
set_values: V,
39
arr: &mut PrimitiveArray<T>,
40
idx: &[IdxSize],
41
) where
42
V: IntoIterator<Item = Option<T>>,
43
{
44
let mut values_iter = set_values.into_iter();
45
46
if let Some(validity) = arr.take_validity() {
47
let mut mut_validity = validity.make_mut();
48
arr.with_values_mut(|cur_values| {
49
for (idx, val) in idx.iter().zip(&mut values_iter) {
50
match val {
51
Some(value) => {
52
mut_validity.set_unchecked(*idx as usize, true);
53
*cur_values.get_unchecked_mut(*idx as usize) = value
54
},
55
None => mut_validity.set_unchecked(*idx as usize, false),
56
}
57
}
58
});
59
arr.set_validity(mut_validity.into())
60
} else {
61
let mut null_idx = vec![];
62
arr.with_values_mut(|cur_values| {
63
for (idx, val) in idx.iter().zip(values_iter) {
64
match val {
65
Some(value) => *cur_values.get_unchecked_mut(*idx as usize) = value,
66
None => {
67
null_idx.push(*idx);
68
},
69
}
70
}
71
});
72
73
// Only make a validity bitmap when null values are set.
74
if !null_idx.is_empty() {
75
let mut validity = MutableBitmap::with_capacity(arr.len());
76
validity.extend_constant(arr.len(), true);
77
for idx in null_idx {
78
validity.set_unchecked(idx as usize, false)
79
}
80
arr.set_validity(Some(validity.into()))
81
}
82
}
83
}
84
85
unsafe fn scatter_bool_impl<V>(set_values: V, arr: &mut BooleanArray, idx: &[IdxSize])
86
where
87
V: IntoIterator<Item = Option<bool>>,
88
{
89
let mut values_iter = set_values.into_iter();
90
91
if let Some(validity) = arr.take_validity() {
92
let mut mut_validity = validity.make_mut();
93
arr.apply_values_mut(|cur_values| {
94
for (idx, val) in idx.iter().zip(&mut values_iter) {
95
match val {
96
Some(value) => {
97
mut_validity.set_unchecked(*idx as usize, true);
98
cur_values.set_unchecked(*idx as usize, value);
99
},
100
None => mut_validity.set_unchecked(*idx as usize, false),
101
}
102
}
103
});
104
arr.set_validity(mut_validity.into())
105
} else {
106
let mut null_idx = vec![];
107
arr.apply_values_mut(|cur_values| {
108
for (idx, val) in idx.iter().zip(values_iter) {
109
match val {
110
Some(value) => cur_values.set_unchecked(*idx as usize, value),
111
None => {
112
null_idx.push(*idx);
113
},
114
}
115
}
116
});
117
118
// Only make a validity bitmap when null values are set.
119
if !null_idx.is_empty() {
120
let mut validity = MutableBitmap::with_capacity(arr.len());
121
validity.extend_constant(arr.len(), true);
122
for idx in null_idx {
123
validity.set_unchecked(idx as usize, false)
124
}
125
arr.set_validity(Some(validity.into()))
126
}
127
}
128
}
129
130
unsafe fn scatter_binview_impl<'a, V, T: ViewType + ?Sized>(
131
set_values: V,
132
arr: &mut BinaryViewArrayGeneric<T>,
133
idx: &[IdxSize],
134
) where
135
V: IntoIterator<Item = Option<&'a T>>,
136
{
137
let mut values_iter = set_values.into_iter();
138
let buffer_offset = arr.data_buffers().len() as u32;
139
let mut new_buffers = Vec::new();
140
141
if let Some(validity) = arr.take_validity() {
142
let mut mut_validity = validity.make_mut();
143
arr.with_views_mut(|views| {
144
for (idx, val) in idx.iter().zip(&mut values_iter) {
145
if let Some(v) = val {
146
let view =
147
View::new_with_buffers(v.to_bytes(), buffer_offset, &mut new_buffers);
148
*views.get_unchecked_mut(*idx as usize) = view;
149
mut_validity.set_unchecked(*idx as usize, true);
150
} else {
151
mut_validity.set_unchecked(*idx as usize, false);
152
}
153
}
154
});
155
arr.set_validity(mut_validity.into())
156
} else {
157
let mut null_idx = vec![];
158
arr.with_views_mut(|views| {
159
for (idx, val) in idx.iter().zip(values_iter) {
160
if let Some(v) = val {
161
let view =
162
View::new_with_buffers(v.to_bytes(), buffer_offset, &mut new_buffers);
163
*views.get_unchecked_mut(*idx as usize) = view;
164
} else {
165
null_idx.push(*idx);
166
}
167
}
168
});
169
170
// Only make a validity bitmap when null values are set.
171
if !null_idx.is_empty() {
172
let mut validity = MutableBitmap::with_capacity(arr.len());
173
validity.extend_constant(arr.len(), true);
174
for idx in null_idx {
175
validity.set_unchecked(idx as usize, false)
176
}
177
arr.set_validity(Some(validity.into()))
178
}
179
}
180
181
let mut buffers = Buffer::to_vec(core::mem::take(arr.data_buffers_mut()));
182
buffers.extend(new_buffers.into_iter().map(Buffer::from));
183
*arr.data_buffers_mut() = Buffer::from(buffers);
184
}
185
186
impl<T: PolarsOpsNumericType> ChunkedSet<T::Native> for &mut ChunkedArray<T> {
187
fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>
188
where
189
V: IntoIterator<Item = Option<T::Native>>,
190
{
191
check_bounds(idx, self.len() as IdxSize)?;
192
let mut ca = std::mem::take(self);
193
194
// SAFETY: we will not modify the length and we unset the sorted flag,
195
// making sure to update the null count as well.
196
unsafe {
197
ca.rechunk_mut();
198
let arr = ca.downcast_iter_mut().next().unwrap();
199
scatter_primitive_impl(values, arr, idx);
200
let null_count = arr.null_count();
201
ca.set_sorted_flag(IsSorted::Not);
202
ca.set_null_count(null_count);
203
}
204
205
Ok(ca.into_series())
206
}
207
}
208
209
impl<'a> ChunkedSet<&'a [u8]> for &mut BinaryChunked {
210
fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>
211
where
212
V: IntoIterator<Item = Option<&'a [u8]>>,
213
{
214
check_bounds(idx, self.len() as IdxSize)?;
215
let mut ca = std::mem::take(self);
216
217
unsafe {
218
ca.rechunk_mut();
219
let arr = ca.downcast_iter_mut().next().unwrap();
220
scatter_binview_impl(values, arr, idx);
221
let null_count = arr.null_count();
222
ca.set_sorted_flag(IsSorted::Not);
223
ca.set_null_count(null_count);
224
}
225
226
Ok(ca.into_series())
227
}
228
}
229
230
impl<'a> ChunkedSet<&'a str> for &mut StringChunked {
231
fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>
232
where
233
V: IntoIterator<Item = Option<&'a str>>,
234
{
235
check_bounds(idx, self.len() as IdxSize)?;
236
let mut ca = std::mem::take(self);
237
238
unsafe {
239
ca.rechunk_mut();
240
let arr = ca.downcast_iter_mut().next().unwrap();
241
scatter_binview_impl(values, arr, idx);
242
let null_count = arr.null_count();
243
ca.set_sorted_flag(IsSorted::Not);
244
ca.set_null_count(null_count);
245
}
246
247
Ok(ca.into_series())
248
}
249
}
250
impl ChunkedSet<bool> for &mut BooleanChunked {
251
fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>
252
where
253
V: IntoIterator<Item = Option<bool>>,
254
{
255
check_bounds(idx, self.len() as IdxSize)?;
256
let mut ca = std::mem::take(self);
257
258
unsafe {
259
ca.rechunk_mut();
260
let arr = ca.downcast_iter_mut().next().unwrap();
261
scatter_bool_impl(values, arr, idx);
262
let null_count = arr.null_count();
263
ca.set_sorted_flag(IsSorted::Not);
264
ca.set_null_count(null_count);
265
}
266
267
Ok(ca.into_series())
268
}
269
}
270
271