Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/chunked_array/scatter.rs
6939 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
use arrow::array::{Array, PrimitiveArray};
3
use polars_core::prelude::*;
4
use polars_core::series::IsSorted;
5
use polars_core::utils::arrow::bitmap::MutableBitmap;
6
use polars_core::utils::arrow::types::NativeType;
7
use polars_utils::index::check_bounds;
8
9
pub trait ChunkedSet<T: Copy> {
10
/// Invariant for implementations: if the scatter() fails, typically because
11
/// of bad indexes, then self should remain unmodified.
12
fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>
13
where
14
V: IntoIterator<Item = Option<T>>;
15
}
16
fn check_sorted(idx: &[IdxSize]) -> PolarsResult<()> {
17
if idx.is_empty() {
18
return Ok(());
19
}
20
let mut sorted = true;
21
let mut previous = idx[0];
22
for &i in &idx[1..] {
23
if i < previous {
24
// we will not break here as that prevents SIMD
25
sorted = false;
26
}
27
previous = i;
28
}
29
polars_ensure!(sorted, ComputeError: "set indices must be sorted");
30
Ok(())
31
}
32
33
trait PolarsOpsNumericType: PolarsNumericType {}
34
35
impl PolarsOpsNumericType for UInt8Type {}
36
impl PolarsOpsNumericType for UInt16Type {}
37
impl PolarsOpsNumericType for UInt32Type {}
38
impl PolarsOpsNumericType for UInt64Type {}
39
impl PolarsOpsNumericType for Int8Type {}
40
impl PolarsOpsNumericType for Int16Type {}
41
impl PolarsOpsNumericType for Int32Type {}
42
impl PolarsOpsNumericType for Int64Type {}
43
#[cfg(feature = "dtype-i128")]
44
impl PolarsOpsNumericType for Int128Type {}
45
impl PolarsOpsNumericType for Float32Type {}
46
impl PolarsOpsNumericType for Float64Type {}
47
48
unsafe fn scatter_impl<V, T: NativeType>(
49
new_values_slice: &mut [T],
50
set_values: V,
51
arr: &mut PrimitiveArray<T>,
52
idx: &[IdxSize],
53
len: usize,
54
) where
55
V: IntoIterator<Item = Option<T>>,
56
{
57
let mut values_iter = set_values.into_iter();
58
59
if arr.null_count() > 0 {
60
arr.apply_validity(|v| {
61
let mut mut_validity = v.make_mut();
62
63
for (idx, val) in idx.iter().zip(&mut values_iter) {
64
match val {
65
Some(value) => {
66
mut_validity.set_unchecked(*idx as usize, true);
67
*new_values_slice.get_unchecked_mut(*idx as usize) = value
68
},
69
None => mut_validity.set_unchecked(*idx as usize, false),
70
}
71
}
72
mut_validity.into()
73
})
74
} else {
75
let mut null_idx = vec![];
76
for (idx, val) in idx.iter().zip(values_iter) {
77
match val {
78
Some(value) => *new_values_slice.get_unchecked_mut(*idx as usize) = value,
79
None => {
80
null_idx.push(*idx);
81
},
82
}
83
}
84
// only make a validity bitmap when null values are set
85
if !null_idx.is_empty() {
86
let mut validity = MutableBitmap::with_capacity(len);
87
validity.extend_constant(len, true);
88
for idx in null_idx {
89
validity.set_unchecked(idx as usize, false)
90
}
91
arr.set_validity(Some(validity.into()))
92
}
93
}
94
}
95
96
impl<T: PolarsOpsNumericType> ChunkedSet<T::Native> for &mut ChunkedArray<T> {
97
fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>
98
where
99
V: IntoIterator<Item = Option<T::Native>>,
100
{
101
check_bounds(idx, self.len() as IdxSize)?;
102
let mut ca = std::mem::take(self);
103
ca.rechunk_mut();
104
105
// SAFETY:
106
// we will not modify the length
107
// and we unset the sorted flag.
108
ca.set_sorted_flag(IsSorted::Not);
109
let arr = unsafe { ca.downcast_iter_mut() }.next().unwrap();
110
let len = arr.len();
111
112
match arr.get_mut_values() {
113
Some(current_values) => {
114
let ptr = current_values.as_mut_ptr();
115
116
// reborrow because the bck does not allow it
117
let current_values = unsafe { &mut *std::slice::from_raw_parts_mut(ptr, len) };
118
// SAFETY:
119
// we checked bounds
120
unsafe { scatter_impl(current_values, values, arr, idx, len) };
121
},
122
None => {
123
let mut new_values = arr.values().as_slice().to_vec();
124
// SAFETY:
125
// we checked bounds
126
unsafe { scatter_impl(&mut new_values, values, arr, idx, len) };
127
arr.set_values(new_values.into());
128
},
129
};
130
131
// The null count may have changed - make sure to update the ChunkedArray
132
let new_null_count = arr.null_count();
133
unsafe { ca.set_null_count(new_null_count) };
134
135
Ok(ca.into_series())
136
}
137
}
138
139
impl<'a> ChunkedSet<&'a str> for &'a StringChunked {
140
fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>
141
where
142
V: IntoIterator<Item = Option<&'a str>>,
143
{
144
check_bounds(idx, self.len() as IdxSize)?;
145
check_sorted(idx)?;
146
let mut ca_iter = self.into_iter().enumerate();
147
let mut builder = StringChunkedBuilder::new(self.name().clone(), self.len());
148
149
for (current_idx, current_value) in idx.iter().zip(values) {
150
for (cnt_idx, opt_val_self) in &mut ca_iter {
151
if cnt_idx == *current_idx as usize {
152
builder.append_option(current_value);
153
break;
154
} else {
155
builder.append_option(opt_val_self);
156
}
157
}
158
}
159
// the last idx is probably not the last value so we finish the iterator
160
for (_, opt_val_self) in ca_iter {
161
builder.append_option(opt_val_self);
162
}
163
164
let ca = builder.finish();
165
Ok(ca.into_series())
166
}
167
}
168
impl ChunkedSet<bool> for &BooleanChunked {
169
fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>
170
where
171
V: IntoIterator<Item = Option<bool>>,
172
{
173
check_bounds(idx, self.len() as IdxSize)?;
174
check_sorted(idx)?;
175
let mut ca_iter = self.into_iter().enumerate();
176
let mut builder = BooleanChunkedBuilder::new(self.name().clone(), self.len());
177
178
for (current_idx, current_value) in idx.iter().zip(values) {
179
for (cnt_idx, opt_val_self) in &mut ca_iter {
180
if cnt_idx == *current_idx as usize {
181
builder.append_option(current_value);
182
break;
183
} else {
184
builder.append_option(opt_val_self);
185
}
186
}
187
}
188
// the last idx is probably not the last value so we finish the iterator
189
for (_, opt_val_self) in ca_iter {
190
builder.append_option(opt_val_self);
191
}
192
193
let ca = builder.finish();
194
Ok(ca.into_series())
195
}
196
}
197
198