Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/series/ops/interpolation/interpolate.rs
6939 views
1
use std::ops::{Add, Div, Mul, Sub};
2
3
use arrow::array::PrimitiveArray;
4
use arrow::bitmap::MutableBitmap;
5
use num_traits::{NumCast, Zero};
6
use polars_core::downcast_as_macro_arg_physical;
7
use polars_core::prelude::*;
8
#[cfg(feature = "serde")]
9
use serde::{Deserialize, Serialize};
10
11
use super::{linear_itp, nearest_itp};
12
13
fn near_interp<T>(low: T, high: T, steps: IdxSize, steps_n: T, out: &mut Vec<T>)
14
where
15
T: Sub<Output = T>
16
+ Mul<Output = T>
17
+ Add<Output = T>
18
+ Div<Output = T>
19
+ NumCast
20
+ Copy
21
+ PartialOrd,
22
{
23
let diff = high - low;
24
for step_i in 1..steps {
25
let step_i: T = NumCast::from(step_i).unwrap();
26
let v = nearest_itp(low, step_i, diff, steps_n);
27
out.push(v)
28
}
29
}
30
31
#[inline]
32
fn signed_interp<T>(low: T, high: T, steps: IdxSize, steps_n: T, out: &mut Vec<T>)
33
where
34
T: Sub<Output = T> + Mul<Output = T> + Add<Output = T> + Div<Output = T> + NumCast + Copy,
35
{
36
let slope = (high - low) / steps_n;
37
for step_i in 1..steps {
38
let step_i: T = NumCast::from(step_i).unwrap();
39
let v = linear_itp(low, step_i, slope);
40
out.push(v)
41
}
42
}
43
44
fn interpolate_impl<T, I>(chunked_arr: &ChunkedArray<T>, interpolation_branch: I) -> ChunkedArray<T>
45
where
46
T: PolarsNumericType,
47
I: Fn(T::Native, T::Native, IdxSize, T::Native, &mut Vec<T::Native>),
48
{
49
// This implementation differs from pandas as that boundary None's are not removed.
50
// This prevents a lot of errors due to expressions leading to different lengths.
51
if !chunked_arr.has_nulls() || chunked_arr.null_count() == chunked_arr.len() {
52
return chunked_arr.clone();
53
}
54
55
// We first find the first and last so that we can set the null buffer.
56
let first = chunked_arr.first_non_null().unwrap();
57
let last = chunked_arr.last_non_null().unwrap() + 1;
58
59
// Fill out with `first` nulls.
60
let mut out = Vec::with_capacity(chunked_arr.len());
61
let mut iter = chunked_arr.iter().skip(first);
62
for _ in 0..first {
63
out.push(Zero::zero());
64
}
65
66
// The next element of `iter` is definitely `Some(Some(v))`, because we skipped the first
67
// elements `first` and if all values were missing we'd have done an early return.
68
let mut low = iter.next().unwrap().unwrap();
69
out.push(low);
70
while let Some(next) = iter.next() {
71
if let Some(v) = next {
72
out.push(v);
73
low = v;
74
} else {
75
let mut steps = 1 as IdxSize;
76
for next in iter.by_ref() {
77
steps += 1;
78
if let Some(high) = next {
79
let steps_n: T::Native = NumCast::from(steps).unwrap();
80
interpolation_branch(low, high, steps, steps_n, &mut out);
81
out.push(high);
82
low = high;
83
break;
84
}
85
}
86
}
87
}
88
if first != 0 || last != chunked_arr.len() {
89
let mut validity = MutableBitmap::with_capacity(chunked_arr.len());
90
validity.extend_constant(chunked_arr.len(), true);
91
92
for i in 0..first {
93
unsafe { validity.set_unchecked(i, false) };
94
}
95
96
for i in last..chunked_arr.len() {
97
unsafe { validity.set_unchecked(i, false) };
98
out.push(Zero::zero())
99
}
100
101
let array = PrimitiveArray::new(
102
T::get_static_dtype().to_arrow(CompatLevel::newest()),
103
out.into(),
104
Some(validity.into()),
105
);
106
ChunkedArray::with_chunk(chunked_arr.name().clone(), array)
107
} else {
108
ChunkedArray::from_vec(chunked_arr.name().clone(), out)
109
}
110
}
111
112
fn interpolate_nearest(s: &Series) -> Series {
113
match s.dtype() {
114
#[cfg(feature = "dtype-categorical")]
115
DataType::Categorical(_, _) | DataType::Enum(_, _) => s.clone(),
116
DataType::Binary => s.clone(),
117
#[cfg(feature = "dtype-struct")]
118
DataType::Struct(_) => s.clone(),
119
DataType::List(_) => s.clone(),
120
_ => {
121
let logical = s.dtype();
122
let s = s.to_physical_repr();
123
124
macro_rules! dispatch {
125
($ca:expr) => {{ interpolate_impl($ca, near_interp).into_series() }};
126
}
127
let out = downcast_as_macro_arg_physical!(s, dispatch);
128
match logical {
129
#[cfg(feature = "dtype-decimal")]
130
DataType::Decimal(_, _) => unsafe { out.from_physical_unchecked(logical).unwrap() },
131
_ => out.cast(logical).unwrap(),
132
}
133
},
134
}
135
}
136
137
fn interpolate_linear(s: &Series) -> Series {
138
match s.dtype() {
139
#[cfg(feature = "dtype-categorical")]
140
DataType::Categorical(_, _) | DataType::Enum(_, _) => s.clone(),
141
DataType::Binary => s.clone(),
142
#[cfg(feature = "dtype-struct")]
143
DataType::Struct(_) => s.clone(),
144
DataType::List(_) => s.clone(),
145
_ => {
146
let logical = s.dtype();
147
148
let s = s.to_physical_repr();
149
150
#[cfg(feature = "dtype-decimal")]
151
{
152
if matches!(logical, DataType::Decimal(_, _)) {
153
let out = linear_interp_signed(s.i128().unwrap());
154
return unsafe { out.from_physical_unchecked(logical).unwrap() };
155
}
156
}
157
158
let out = if matches!(
159
logical,
160
DataType::Date | DataType::Datetime(_, _) | DataType::Duration(_) | DataType::Time
161
) {
162
match s.dtype() {
163
// Datetime, Time, or Duration
164
DataType::Int64 => linear_interp_signed(s.i64().unwrap()),
165
// Date
166
DataType::Int32 => linear_interp_signed(s.i32().unwrap()),
167
_ => unreachable!(),
168
}
169
} else {
170
match s.dtype() {
171
DataType::Float32 => linear_interp_signed(s.f32().unwrap()),
172
DataType::Float64 => linear_interp_signed(s.f64().unwrap()),
173
DataType::Int8
174
| DataType::Int16
175
| DataType::Int32
176
| DataType::Int64
177
| DataType::Int128
178
| DataType::UInt8
179
| DataType::UInt16
180
| DataType::UInt32
181
| DataType::UInt64 => {
182
linear_interp_signed(s.cast(&DataType::Float64).unwrap().f64().unwrap())
183
},
184
_ => s.as_ref().clone(),
185
}
186
};
187
match logical {
188
DataType::Date
189
| DataType::Datetime(_, _)
190
| DataType::Duration(_)
191
| DataType::Time => out.cast(logical).unwrap(),
192
_ => out,
193
}
194
},
195
}
196
}
197
198
fn linear_interp_signed<T: PolarsNumericType>(ca: &ChunkedArray<T>) -> Series {
199
interpolate_impl(ca, signed_interp::<T::Native>).into_series()
200
}
201
202
// Strategy used by `interpolate` to fill nulls that lie between two non-null values.
// Plain `//` comments are used instead of doc comments: with `dsl-schema`, `schemars` turns
// doc comments into schema descriptions, which would change the generated schema.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum InterpolationMethod {
    // Values on the straight line between the surrounding non-null values.
    Linear,
    // One of the surrounding non-null values, picked by position.
    Nearest,
}
209
210
pub fn interpolate(s: &Series, method: InterpolationMethod) -> Series {
211
match method {
212
InterpolationMethod::Linear => interpolate_linear(s),
213
InterpolationMethod::Nearest => interpolate_nearest(s),
214
}
215
}
216
217
#[cfg(test)]
mod test {
    use super::*;

    /// Builds an unnamed `UInt32` series from `values`.
    fn u32_series(values: &[Option<u32>]) -> Series {
        UInt32Chunked::new("".into(), values).into_series()
    }

    /// Linearly interpolates `values` and collects the (`Float64`) result.
    fn linear_f64(values: &[Option<u32>]) -> Vec<Option<f64>> {
        let out = interpolate(&u32_series(values), InterpolationMethod::Linear);
        Vec::from(out.f64().unwrap())
    }

    #[test]
    fn test_interpolate() {
        assert_eq!(
            linear_f64(&[Some(1), None, None, Some(4), Some(5)]),
            [1.0, 2.0, 3.0, 4.0, 5.0].map(Some)
        );

        assert_eq!(
            linear_f64(&[None, Some(1), None, None, Some(4), Some(5)]),
            [None, Some(1.0), Some(2.0), Some(3.0), Some(4.0), Some(5.0)]
        );

        let padded = [None, Some(1), None, None, Some(4), Some(5), None];
        assert_eq!(
            linear_f64(&padded),
            [None, Some(1.0), Some(2.0), Some(3.0), Some(4.0), Some(5.0), None]
        );

        let nearest = interpolate(&u32_series(&padded), InterpolationMethod::Nearest);
        assert_eq!(
            Vec::from(nearest.u32().unwrap()),
            [None, Some(1), Some(1), Some(4), Some(4), Some(5), None]
        );
    }

    #[test]
    fn test_interpolate_decreasing_unsigned() {
        assert_eq!(
            linear_f64(&[Some(4), None, None, Some(1)]),
            [4.0, 3.0, 2.0, 1.0].map(Some)
        );
    }

    #[test]
    fn test_interpolate2() {
        let ca = Float32Chunked::new(
            "".into(),
            &[
                Some(4653f32),
                None,
                None,
                None,
                Some(4657f32),
                None,
                None,
                Some(4657f32),
                None,
                Some(4657f32),
                None,
                None,
                Some(4660f32),
            ],
        );
        let out = interpolate(&ca.into_series(), InterpolationMethod::Linear);

        let expected = [
            4653.0f32, 4654.0, 4655.0, 4656.0, 4657.0, 4657.0, 4657.0, 4657.0, 4657.0, 4657.0,
            4658.0, 4659.0, 4660.0,
        ]
        .map(Some);
        assert_eq!(Vec::from(out.f32().unwrap()), expected);
    }
}
323
324