Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/series/ops/interpolation/interpolate_by.rs
6939 views
1
use std::ops::{Add, Div, Mul, Sub};
2
3
use arrow::array::PrimitiveArray;
4
use arrow::bitmap::MutableBitmap;
5
use bytemuck::allocation::zeroed_vec;
6
use num_traits::{NumCast, Zero};
7
use polars_core::prelude::*;
8
use polars_utils::slice::SliceAble;
9
10
use super::linear_itp;
11
12
/// # Safety
13
/// - `x` must be non-empty.
14
#[inline]
15
unsafe fn signed_interp_by_sorted<T, F>(y_start: T, y_end: T, x: &[F], out: &mut Vec<T>)
16
where
17
T: Sub<Output = T>
18
+ Mul<Output = T>
19
+ Add<Output = T>
20
+ Div<Output = T>
21
+ NumCast
22
+ Copy
23
+ Zero,
24
F: Sub<Output = F> + NumCast + Copy,
25
{
26
let range_y = y_end - y_start;
27
let x_start;
28
let range_x;
29
let iter;
30
unsafe {
31
x_start = x.get_unchecked(0);
32
range_x = NumCast::from(*x.get_unchecked(x.len() - 1) - *x_start).unwrap();
33
iter = x.slice_unchecked(1..x.len() - 1).iter();
34
}
35
let slope = range_y / range_x;
36
for x_i in iter {
37
let x_delta = NumCast::from(*x_i - *x_start).unwrap();
38
let v = linear_itp(y_start, x_delta, slope);
39
out.push(v)
40
}
41
}
42
43
/// # Safety
44
/// - `x` must be non-empty.
45
/// - `sorting_indices` must be the same size as `x`
46
#[inline]
47
unsafe fn signed_interp_by<T, F>(
48
y_start: T,
49
y_end: T,
50
x: &[F],
51
out: &mut [T],
52
sorting_indices: &[IdxSize],
53
) where
54
T: Sub<Output = T>
55
+ Mul<Output = T>
56
+ Add<Output = T>
57
+ Div<Output = T>
58
+ NumCast
59
+ Copy
60
+ Zero,
61
F: Sub<Output = F> + NumCast + Copy,
62
{
63
let range_y = y_end - y_start;
64
let x_start;
65
let range_x;
66
let iter;
67
unsafe {
68
x_start = x.get_unchecked(0);
69
range_x = NumCast::from(*x.get_unchecked(x.len() - 1) - *x_start).unwrap();
70
iter = x.slice_unchecked(1..x.len() - 1).iter();
71
}
72
let slope = range_y / range_x;
73
for (idx, x_i) in iter.enumerate() {
74
let x_delta = NumCast::from(*x_i - *x_start).unwrap();
75
let v = linear_itp(y_start, x_delta, slope);
76
unsafe {
77
let out_idx = sorting_indices.get_unchecked(idx + 1);
78
*out.get_unchecked_mut(*out_idx as usize) = v;
79
}
80
}
81
}
82
83
fn interpolate_impl_by_sorted<T, F, I>(
84
chunked_arr: &ChunkedArray<T>,
85
by: &ChunkedArray<F>,
86
interpolation_branch: I,
87
) -> PolarsResult<ChunkedArray<T>>
88
where
89
T: PolarsNumericType,
90
F: PolarsNumericType,
91
I: Fn(T::Native, T::Native, &[F::Native], &mut Vec<T::Native>),
92
{
93
// This implementation differs from pandas as that boundary None's are not removed.
94
// This prevents a lot of errors due to expressions leading to different lengths.
95
if !chunked_arr.has_nulls() || chunked_arr.null_count() == chunked_arr.len() {
96
return Ok(chunked_arr.clone());
97
}
98
99
polars_ensure!(by.null_count() == 0, InvalidOperation: "null values in `by` column are not yet supported in 'interpolate_by' expression");
100
let by = by.rechunk();
101
let by_values = by.cont_slice().unwrap();
102
103
// We first find the first and last so that we can set the null buffer.
104
let first = chunked_arr.first_non_null().unwrap();
105
let last = chunked_arr.last_non_null().unwrap() + 1;
106
107
// Fill out with `first` nulls.
108
let mut out = Vec::with_capacity(chunked_arr.len());
109
let mut iter = chunked_arr.iter().enumerate().skip(first);
110
for _ in 0..first {
111
out.push(Zero::zero());
112
}
113
114
// The next element of `iter` is definitely `Some(idx, Some(v))`, because we skipped the first
115
// `first` elements and if all values were missing we'd have done an early return.
116
let (mut low_idx, opt_low) = iter.next().unwrap();
117
let mut low = opt_low.unwrap();
118
out.push(low);
119
while let Some((idx, next)) = iter.next() {
120
if let Some(v) = next {
121
out.push(v);
122
low = v;
123
low_idx = idx;
124
} else {
125
for (high_idx, next) in iter.by_ref() {
126
if let Some(high) = next {
127
// SAFETY: we are in bounds, and `x` is non-empty.
128
unsafe {
129
let x = &by_values.slice_unchecked(low_idx..high_idx + 1);
130
interpolation_branch(low, high, x, &mut out);
131
}
132
out.push(high);
133
low = high;
134
low_idx = high_idx;
135
break;
136
}
137
}
138
}
139
}
140
if first != 0 || last != chunked_arr.len() {
141
let mut validity = MutableBitmap::with_capacity(chunked_arr.len());
142
validity.extend_constant(chunked_arr.len(), true);
143
144
for i in 0..first {
145
unsafe { validity.set_unchecked(i, false) };
146
}
147
148
for i in last..chunked_arr.len() {
149
unsafe { validity.set_unchecked(i, false) }
150
out.push(Zero::zero());
151
}
152
153
let array = PrimitiveArray::new(
154
T::get_static_dtype().to_arrow(CompatLevel::newest()),
155
out.into(),
156
Some(validity.into()),
157
);
158
Ok(ChunkedArray::with_chunk(chunked_arr.name().clone(), array))
159
} else {
160
Ok(ChunkedArray::from_vec(chunked_arr.name().clone(), out))
161
}
162
}
163
164
// Sort on behalf of user
165
fn interpolate_impl_by<T, F, I>(
166
ca: &ChunkedArray<T>,
167
by: &ChunkedArray<F>,
168
interpolation_branch: I,
169
) -> PolarsResult<ChunkedArray<T>>
170
where
171
T: PolarsNumericType,
172
F: PolarsNumericType,
173
I: Fn(T::Native, T::Native, &[F::Native], &mut [T::Native], &[IdxSize]),
174
{
175
// This implementation differs from pandas as that boundary None's are not removed.
176
// This prevents a lot of errors due to expressions leading to different lengths.
177
if !ca.has_nulls() || ca.null_count() == ca.len() {
178
return Ok(ca.clone());
179
}
180
181
polars_ensure!(by.null_count() == 0, InvalidOperation: "null values in `by` column are not yet supported in 'interpolate_by' expression");
182
let sorting_indices = by.arg_sort(Default::default());
183
let sorting_indices = sorting_indices
184
.cont_slice()
185
.expect("arg sort produces single chunk");
186
let by_sorted = unsafe { by.take_unchecked(sorting_indices) };
187
let ca_sorted = unsafe { ca.take_unchecked(sorting_indices) };
188
let by_sorted_values = by_sorted
189
.cont_slice()
190
.expect("We already checked for nulls, and `take_unchecked` produces single chunk");
191
192
// We first find the first and last so that we can set the null buffer.
193
let first = ca_sorted.first_non_null().unwrap();
194
let last = ca_sorted.last_non_null().unwrap() + 1;
195
196
let mut out = zeroed_vec(ca_sorted.len());
197
let mut iter = ca_sorted.iter().enumerate().skip(first);
198
199
// The next element of `iter` is definitely `Some(idx, Some(v))`, because we skipped the first
200
// `first` elements and if all values were missing we'd have done an early return.
201
let (mut low_idx, opt_low) = iter.next().unwrap();
202
let mut low = opt_low.unwrap();
203
unsafe {
204
let out_idx = sorting_indices.get_unchecked(low_idx);
205
*out.get_unchecked_mut(*out_idx as usize) = low;
206
}
207
while let Some((idx, next)) = iter.next() {
208
if let Some(v) = next {
209
unsafe {
210
let out_idx = sorting_indices.get_unchecked(idx);
211
*out.get_unchecked_mut(*out_idx as usize) = v;
212
}
213
low = v;
214
low_idx = idx;
215
} else {
216
for (high_idx, next) in iter.by_ref() {
217
if let Some(high) = next {
218
// SAFETY: we are in bounds, and the slices are the same length (and non-empty).
219
unsafe {
220
interpolation_branch(
221
low,
222
high,
223
by_sorted_values.slice_unchecked(low_idx..high_idx + 1),
224
&mut out,
225
sorting_indices.slice_unchecked(low_idx..high_idx + 1),
226
);
227
let out_idx = sorting_indices.get_unchecked(high_idx);
228
*out.get_unchecked_mut(*out_idx as usize) = high;
229
}
230
low = high;
231
low_idx = high_idx;
232
break;
233
}
234
}
235
}
236
}
237
if first != 0 || last != ca_sorted.len() {
238
let mut validity = MutableBitmap::with_capacity(ca_sorted.len());
239
validity.extend_constant(ca_sorted.len(), true);
240
241
for i in 0..first {
242
unsafe {
243
let out_idx = sorting_indices.get_unchecked(i);
244
validity.set_unchecked(*out_idx as usize, false);
245
}
246
}
247
248
for i in last..ca_sorted.len() {
249
unsafe {
250
let out_idx = sorting_indices.get_unchecked(i);
251
validity.set_unchecked(*out_idx as usize, false);
252
}
253
}
254
255
let array = PrimitiveArray::new(
256
T::get_static_dtype().to_arrow(CompatLevel::newest()),
257
out.into(),
258
Some(validity.into()),
259
);
260
Ok(ChunkedArray::with_chunk(ca_sorted.name().clone(), array))
261
} else {
262
Ok(ChunkedArray::from_vec(ca_sorted.name().clone(), out))
263
}
264
}
265
266
pub fn interpolate_by(s: &Column, by: &Column, by_is_sorted: bool) -> PolarsResult<Column> {
267
polars_ensure!(s.len() == by.len(), InvalidOperation: "`by` column must be the same length as Series ({}), got {}", s.len(), by.len());
268
269
fn func<T, F>(
270
ca: &ChunkedArray<T>,
271
by: &ChunkedArray<F>,
272
is_sorted: bool,
273
) -> PolarsResult<Column>
274
where
275
T: PolarsNumericType,
276
F: PolarsNumericType,
277
ChunkedArray<T>: IntoColumn,
278
{
279
if is_sorted {
280
interpolate_impl_by_sorted(ca, by, |y_start, y_end, x, out| unsafe {
281
signed_interp_by_sorted(y_start, y_end, x, out)
282
})
283
.map(|x| x.into_column())
284
} else {
285
interpolate_impl_by(ca, by, |y_start, y_end, x, out, sorting_indices| unsafe {
286
signed_interp_by(y_start, y_end, x, out, sorting_indices)
287
})
288
.map(|x| x.into_column())
289
}
290
}
291
292
match (s.dtype(), by.dtype()) {
293
(DataType::Float64, DataType::Float64) => {
294
func(s.f64().unwrap(), by.f64().unwrap(), by_is_sorted)
295
},
296
(DataType::Float64, DataType::Float32) => {
297
func(s.f64().unwrap(), by.f32().unwrap(), by_is_sorted)
298
},
299
(DataType::Float32, DataType::Float64) => {
300
func(s.f32().unwrap(), by.f64().unwrap(), by_is_sorted)
301
},
302
(DataType::Float32, DataType::Float32) => {
303
func(s.f32().unwrap(), by.f32().unwrap(), by_is_sorted)
304
},
305
(DataType::Float64, DataType::Int64) => {
306
func(s.f64().unwrap(), by.i64().unwrap(), by_is_sorted)
307
},
308
(DataType::Float64, DataType::Int32) => {
309
func(s.f64().unwrap(), by.i32().unwrap(), by_is_sorted)
310
},
311
(DataType::Float64, DataType::UInt64) => {
312
func(s.f64().unwrap(), by.u64().unwrap(), by_is_sorted)
313
},
314
(DataType::Float64, DataType::UInt32) => {
315
func(s.f64().unwrap(), by.u32().unwrap(), by_is_sorted)
316
},
317
(DataType::Float32, DataType::Int64) => {
318
func(s.f32().unwrap(), by.i64().unwrap(), by_is_sorted)
319
},
320
(DataType::Float32, DataType::Int32) => {
321
func(s.f32().unwrap(), by.i32().unwrap(), by_is_sorted)
322
},
323
(DataType::Float32, DataType::UInt64) => {
324
func(s.f32().unwrap(), by.u64().unwrap(), by_is_sorted)
325
},
326
(DataType::Float32, DataType::UInt32) => {
327
func(s.f32().unwrap(), by.u32().unwrap(), by_is_sorted)
328
},
329
#[cfg(feature = "dtype-date")]
330
(_, DataType::Date) => interpolate_by(s, &by.cast(&DataType::Int32).unwrap(), by_is_sorted),
331
#[cfg(feature = "dtype-datetime")]
332
(_, DataType::Datetime(_, _)) => {
333
interpolate_by(s, &by.cast(&DataType::Int64).unwrap(), by_is_sorted)
334
},
335
(DataType::UInt64 | DataType::UInt32 | DataType::Int64 | DataType::Int32, _) => {
336
interpolate_by(&s.cast(&DataType::Float64).unwrap(), by, by_is_sorted)
337
},
338
_ => {
339
polars_bail!(InvalidOperation: "expected series to be Float64, Float32, \
340
Int64, Int32, UInt64, UInt32, and `by` to be Date, Datetime, Int64, Int32, \
341
UInt64, UInt32, Float32 or Float64")
342
},
343
}
344
}
345
346