Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/dispatch/range/linear_space.rs
7884 views
1
use arrow::temporal_conversions::MICROSECONDS_IN_DAY;
2
use polars_core::prelude::*;
3
use polars_ops::series::{ClosedInterval, new_linear_space_f32, new_linear_space_f64};
4
5
use super::utils::{build_nulls, ensure_items_contain_exactly_one_value};
6
7
/// Estimated number of samples per row, used by `linear_spaces` to pre-size its list builder
/// when the sample counts come from a `num_samples` column (the builder is sized for
/// `len * CAPACITY_FACTOR` values). With a fixed array width, the width is used instead.
const CAPACITY_FACTOR: usize = 5;
8
9
pub(super) fn linear_space(s: &[Column], closed: ClosedInterval) -> PolarsResult<Column> {
10
let start = &s[0];
11
let end = &s[1];
12
let num_samples = &s[2];
13
let name = start.name();
14
15
ensure_items_contain_exactly_one_value(&[start, end], &["start", "end"])?;
16
polars_ensure!(
17
num_samples.len() == 1,
18
ComputeError: "`num_samples` must contain exactly one value, got {} values", num_samples.len()
19
);
20
21
let start = start.get(0).unwrap();
22
let end = end.get(0).unwrap();
23
let num_samples = num_samples.get(0).unwrap();
24
let num_samples = num_samples
25
.extract::<u64>()
26
.ok_or(PolarsError::ComputeError(
27
format!("'num_samples' must be non-negative integer, got {num_samples}").into(),
28
))?;
29
30
match (start.dtype(), end.dtype()) {
31
(DataType::Float32, DataType::Float32) => new_linear_space_f32(
32
start.extract::<f32>().unwrap(),
33
end.extract::<f32>().unwrap(),
34
num_samples,
35
closed,
36
name.clone(),
37
)
38
.map(|s| s.into_column()),
39
(mut dt, dt2) if dt.is_temporal() && dt == dt2 => {
40
let mut start = start.extract::<i64>().unwrap();
41
let mut end = end.extract::<i64>().unwrap();
42
43
// A linear space of a Date produces a sequence of Datetimes, so we must upcast.
44
if dt == DataType::Date {
45
start *= MICROSECONDS_IN_DAY;
46
end *= MICROSECONDS_IN_DAY;
47
dt = DataType::Datetime(TimeUnit::Microseconds, None);
48
}
49
new_linear_space_f64(start as f64, end as f64, num_samples, closed, name.clone())
50
.map(|s| s.cast(&dt).unwrap().into_column())
51
},
52
(dt1, dt2) if !dt1.is_primitive_numeric() || !dt2.is_primitive_numeric() => {
53
Err(PolarsError::ComputeError(
54
format!("'start' and 'end' have incompatible dtypes, got {dt1:?} and {dt2:?}")
55
.into(),
56
))
57
},
58
(_, _) => new_linear_space_f64(
59
start.extract::<f64>().unwrap(),
60
end.extract::<f64>().unwrap(),
61
num_samples,
62
closed,
63
name.clone(),
64
)
65
.map(|s| s.into_column()),
66
}
67
}
68
69
pub(super) fn linear_spaces(
70
s: &[Column],
71
closed: ClosedInterval,
72
array_width: Option<usize>,
73
) -> PolarsResult<Column> {
74
let start = &s[0];
75
let end = &s[1];
76
77
let (num_samples, capacity_factor) = match array_width {
78
Some(ns) => {
79
// An array width is provided instead of a column of `num_sample`s.
80
let scalar = Scalar::new(DataType::UInt64, AnyValue::UInt64(ns as u64));
81
(&Column::new_scalar(PlSmallStr::EMPTY, scalar, 1), ns)
82
},
83
None => (&s[2], CAPACITY_FACTOR),
84
};
85
let name = start.name().clone();
86
87
let num_samples = num_samples.strict_cast(&DataType::UInt64)?;
88
let num_samples = num_samples.u64()?;
89
let len = start.len().max(end.len()).max(num_samples.len());
90
91
match (start.dtype(), end.dtype()) {
92
(DataType::Float32, DataType::Float32) => {
93
let mut builder = ListPrimitiveChunkedBuilder::<Float32Type>::new(
94
name,
95
len,
96
len * capacity_factor,
97
DataType::Float32,
98
);
99
100
let linspace_impl =
101
|start,
102
end,
103
num_samples,
104
builder: &mut ListPrimitiveChunkedBuilder<Float32Type>| {
105
let ls =
106
new_linear_space_f32(start, end, num_samples, closed, PlSmallStr::EMPTY)?;
107
builder.append_slice(ls.cont_slice().unwrap());
108
Ok(())
109
};
110
111
let start = start.f32()?;
112
let end = end.f32()?;
113
let out =
114
linear_spaces_impl_broadcast(start, end, num_samples, linspace_impl, &mut builder)?;
115
116
let to_type = array_width.map_or_else(
117
|| DataType::List(Box::new(DataType::Float32)),
118
|width| DataType::Array(Box::new(DataType::Float32), width),
119
);
120
out.cast(&to_type)
121
},
122
(mut dt, dt2) if dt.is_temporal() && dt == dt2 => {
123
let mut start = start.to_physical_repr();
124
let mut end = end.to_physical_repr();
125
126
// A linear space of a Date produces a sequence of Datetimes, so we must upcast.
127
if dt == &DataType::Date {
128
start = start.cast(&DataType::Int64)? * MICROSECONDS_IN_DAY;
129
end = end.cast(&DataType::Int64)? * MICROSECONDS_IN_DAY;
130
dt = &DataType::Datetime(TimeUnit::Microseconds, None);
131
}
132
133
let start = start.cast(&DataType::Float64)?;
134
let start = start.f64()?;
135
let end = end.cast(&DataType::Float64)?;
136
let end = end.f64()?;
137
138
let mut builder = ListPrimitiveChunkedBuilder::<Float64Type>::new(
139
name,
140
len,
141
len * capacity_factor,
142
DataType::Float64,
143
);
144
145
let linspace_impl =
146
|start,
147
end,
148
num_samples,
149
builder: &mut ListPrimitiveChunkedBuilder<Float64Type>| {
150
let ls =
151
new_linear_space_f64(start, end, num_samples, closed, PlSmallStr::EMPTY)?;
152
builder.append_slice(ls.cont_slice().unwrap());
153
Ok(())
154
};
155
let out =
156
linear_spaces_impl_broadcast(start, end, num_samples, linspace_impl, &mut builder)?;
157
158
let to_type = array_width.map_or_else(
159
|| DataType::List(Box::new(dt.clone())),
160
|width| DataType::Array(Box::new(dt.clone()), width),
161
);
162
out.cast(&to_type)
163
},
164
(dt1, dt2) if !dt1.is_primitive_numeric() || !dt2.is_primitive_numeric() => {
165
Err(PolarsError::ComputeError(
166
format!("'start' and 'end' have incompatible dtypes, got {dt1:?} and {dt2:?}")
167
.into(),
168
))
169
},
170
(_, _) => {
171
let start = start.strict_cast(&DataType::Float64)?;
172
let end = end.strict_cast(&DataType::Float64)?;
173
let start = start.f64()?;
174
let end = end.f64()?;
175
176
let mut builder = ListPrimitiveChunkedBuilder::<Float64Type>::new(
177
name,
178
len,
179
len * capacity_factor,
180
DataType::Float64,
181
);
182
183
let linspace_impl =
184
|start,
185
end,
186
num_samples,
187
builder: &mut ListPrimitiveChunkedBuilder<Float64Type>| {
188
let ls =
189
new_linear_space_f64(start, end, num_samples, closed, PlSmallStr::EMPTY)?;
190
builder.append_slice(ls.cont_slice().unwrap());
191
Ok(())
192
};
193
let out =
194
linear_spaces_impl_broadcast(start, end, num_samples, linspace_impl, &mut builder)?;
195
196
let to_type = array_width.map_or_else(
197
|| DataType::List(Box::new(DataType::Float64)),
198
|width| DataType::Array(Box::new(DataType::Float64), width),
199
);
200
out.cast(&to_type)
201
},
202
}
203
}
204
205
/// Create a ranges column from the given start/end columns and a range function.
206
pub(super) fn linear_spaces_impl_broadcast<T, F>(
207
start: &ChunkedArray<T>,
208
end: &ChunkedArray<T>,
209
num_samples: &UInt64Chunked,
210
linear_space_impl: F,
211
builder: &mut ListPrimitiveChunkedBuilder<T>,
212
) -> PolarsResult<Column>
213
where
214
T: PolarsFloatType,
215
F: Fn(T::Native, T::Native, u64, &mut ListPrimitiveChunkedBuilder<T>) -> PolarsResult<()>,
216
ListPrimitiveChunkedBuilder<T>: ListBuilderTrait,
217
{
218
match (start.len(), end.len(), num_samples.len()) {
219
(len_start, len_end, len_samples) if len_start == len_end && len_start == len_samples => {
220
// (n, n, n)
221
build_linear_spaces::<_, _, _, T, F>(
222
start.iter(),
223
end.iter(),
224
num_samples.iter(),
225
linear_space_impl,
226
builder,
227
)?;
228
},
229
// (1, n, n)
230
(1, len_end, len_samples) if len_end == len_samples => {
231
let start_value = start.get(0);
232
if start_value.is_some() {
233
build_linear_spaces::<_, _, _, T, F>(
234
std::iter::repeat(start_value),
235
end.iter(),
236
num_samples.iter(),
237
linear_space_impl,
238
builder,
239
)?
240
} else {
241
build_nulls(builder, len_end)
242
}
243
},
244
// (n, 1, n)
245
(len_start, 1, len_samples) if len_start == len_samples => {
246
let end_value = end.get(0);
247
if end_value.is_some() {
248
build_linear_spaces::<_, _, _, T, F>(
249
start.iter(),
250
std::iter::repeat(end_value),
251
num_samples.iter(),
252
linear_space_impl,
253
builder,
254
)?
255
} else {
256
build_nulls(builder, len_start)
257
}
258
},
259
// (n, n, 1)
260
(len_start, len_end, 1) if len_start == len_end => {
261
let num_samples_value = num_samples.get(0);
262
if num_samples_value.is_some() {
263
build_linear_spaces::<_, _, _, T, F>(
264
start.iter(),
265
end.iter(),
266
std::iter::repeat(num_samples_value),
267
linear_space_impl,
268
builder,
269
)?
270
} else {
271
build_nulls(builder, len_start)
272
}
273
},
274
// (n, 1, 1)
275
(len_start, 1, 1) => {
276
let end_value = end.get(0);
277
let num_samples_value = num_samples.get(0);
278
match (end_value, num_samples_value) {
279
(Some(_), Some(_)) => build_linear_spaces::<_, _, _, T, F>(
280
start.iter(),
281
std::iter::repeat(end_value),
282
std::iter::repeat(num_samples_value),
283
linear_space_impl,
284
builder,
285
)?,
286
_ => build_nulls(builder, len_start),
287
}
288
},
289
// (1, n, 1)
290
(1, len_end, 1) => {
291
let start_value = start.get(0);
292
let num_samples_value = num_samples.get(0);
293
match (start_value, num_samples_value) {
294
(Some(_), Some(_)) => build_linear_spaces::<_, _, _, T, F>(
295
std::iter::repeat(start_value),
296
end.iter(),
297
std::iter::repeat(num_samples_value),
298
linear_space_impl,
299
builder,
300
)?,
301
_ => build_nulls(builder, len_end),
302
}
303
},
304
// (1, 1, n)
305
(1, 1, len_num_samples) => {
306
let start_value = start.get(0);
307
let end_value = end.get(0);
308
match (start_value, end_value) {
309
(Some(_), Some(_)) => build_linear_spaces::<_, _, _, T, F>(
310
std::iter::repeat(start_value),
311
std::iter::repeat(end_value),
312
num_samples.iter(),
313
linear_space_impl,
314
builder,
315
)?,
316
_ => build_nulls(builder, len_num_samples),
317
}
318
},
319
(len_start, len_end, len_num_samples) => {
320
polars_bail!(
321
ComputeError:
322
"lengths of `start` ({}), `end` ({}), and `num_samples` ({}) do not match",
323
len_start, len_end, len_num_samples
324
)
325
},
326
};
327
let out = builder.finish().into_column();
328
Ok(out)
329
}
330
331
/// Iterate over a start and end column and create a range for each entry.
332
fn build_linear_spaces<I, J, K, T, F>(
333
start: I,
334
end: J,
335
num_samples: K,
336
linear_space_impl: F,
337
builder: &mut ListPrimitiveChunkedBuilder<T>,
338
) -> PolarsResult<()>
339
where
340
I: Iterator<Item = Option<T::Native>>,
341
J: Iterator<Item = Option<T::Native>>,
342
K: Iterator<Item = Option<u64>>,
343
T: PolarsFloatType,
344
F: Fn(T::Native, T::Native, u64, &mut ListPrimitiveChunkedBuilder<T>) -> PolarsResult<()>,
345
ListPrimitiveChunkedBuilder<T>: ListBuilderTrait,
346
{
347
for ((start, end), num_samples) in start.zip(end).zip(num_samples) {
348
match (start, end, num_samples) {
349
(Some(start), Some(end), Some(num_samples)) => {
350
linear_space_impl(start, end, num_samples, builder)?
351
},
352
_ => builder.append_null(),
353
}
354
}
355
Ok(())
356
}
357
358