Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/cast/primitive_to.rs
8446 views
1
use std::hash::Hash;
2
3
use arrow::array::*;
4
use arrow::bitmap::{Bitmap, BitmapBuilder};
5
use arrow::compute::arity::unary;
6
use arrow::datatypes::{ArrowDataType, TimeUnit};
7
use arrow::offset::{Offset, Offsets};
8
use arrow::types::NativeType;
9
use num_traits::AsPrimitive;
10
#[cfg(feature = "dtype-decimal")]
11
use num_traits::{Float, ToPrimitive};
12
use polars_error::PolarsResult;
13
use polars_utils::float16::pf16;
14
use polars_utils::pl_str::PlSmallStr;
15
use polars_utils::vec::PushUnchecked;
16
17
use super::CastOptionsImpl;
18
use super::temporal::*;
19
#[cfg(feature = "dtype-decimal")]
20
use crate::decimal::{dec128_verify_prec_scale, f64_to_dec128, i128_to_dec128};
21
22
/// Serialization of a primitive value into its textual (UTF-8) form.
pub trait SerPrimitive {
    /// Appends the textual representation of `val` to the end of `f`.
    ///
    /// Returns the number of bytes that were appended.
    fn write(f: &mut Vec<u8>, val: Self) -> usize
    where
        Self: Sized;
}
27
28
macro_rules! impl_ser_primitive {
29
($ptype:ident) => {
30
impl SerPrimitive for $ptype {
31
fn write(f: &mut Vec<u8>, val: Self) -> usize
32
where
33
Self: Sized,
34
{
35
let mut buffer = itoa::Buffer::new();
36
let value = buffer.format(val);
37
f.extend_from_slice(value.as_bytes());
38
value.len()
39
}
40
}
41
};
42
}
43
44
impl_ser_primitive!(i8);
45
impl_ser_primitive!(i16);
46
impl_ser_primitive!(i32);
47
impl_ser_primitive!(i64);
48
impl_ser_primitive!(i128);
49
impl_ser_primitive!(u8);
50
impl_ser_primitive!(u16);
51
impl_ser_primitive!(u32);
52
impl_ser_primitive!(u64);
53
impl_ser_primitive!(u128);
54
55
impl SerPrimitive for pf16 {
56
fn write(f: &mut Vec<u8>, val: Self) -> usize
57
where
58
Self: Sized,
59
{
60
f32::write(f, AsPrimitive::<f32>::as_(val))
61
}
62
}
63
64
impl SerPrimitive for f32 {
65
fn write(f: &mut Vec<u8>, val: Self) -> usize
66
where
67
Self: Sized,
68
{
69
let mut buffer = zmij::Buffer::new();
70
let value = buffer.format(val);
71
f.extend_from_slice(value.as_bytes());
72
value.len()
73
}
74
}
75
76
impl SerPrimitive for f64 {
77
fn write(f: &mut Vec<u8>, val: Self) -> usize
78
where
79
Self: Sized,
80
{
81
let mut buffer = zmij::Buffer::new();
82
let value = buffer.format(val);
83
f.extend_from_slice(value.as_bytes());
84
value.len()
85
}
86
}
87
88
fn fallible_unary<I, F, G, O>(
89
array: &PrimitiveArray<I>,
90
op: F,
91
fail: G,
92
dtype: ArrowDataType,
93
) -> PrimitiveArray<O>
94
where
95
I: NativeType,
96
O: NativeType,
97
F: Fn(I) -> O,
98
G: Fn(I) -> bool,
99
{
100
let values = array.values();
101
let mut out = Vec::with_capacity(array.len());
102
let mut i = 0;
103
104
while i < array.len() && !fail(values[i]) {
105
// SAFETY: We allocated enough before.
106
unsafe { out.push_unchecked(op(values[i])) };
107
i += 1;
108
}
109
110
if out.len() == array.len() {
111
return PrimitiveArray::<O>::new(dtype, out.into(), array.validity().cloned());
112
}
113
114
let mut validity = BitmapBuilder::with_capacity(array.len());
115
validity.extend_constant(out.len(), true);
116
117
for &value in &values[out.len()..] {
118
// SAFETY: We allocated enough before.
119
unsafe {
120
out.push_unchecked(op(value));
121
validity.push_unchecked(!fail(value));
122
}
123
}
124
125
debug_assert_eq!(out.len(), array.len());
126
debug_assert_eq!(validity.len(), array.len());
127
128
let validity = validity.freeze();
129
let validity = match array.validity() {
130
None => validity,
131
Some(arr_validity) => arrow::bitmap::and(&validity, arr_validity),
132
};
133
134
PrimitiveArray::<O>::new(dtype, out.into(), Some(validity))
135
}
136
137
fn primitive_to_values_and_offsets<T: NativeType + SerPrimitive, O: Offset>(
138
from: &PrimitiveArray<T>,
139
) -> (Vec<u8>, Offsets<O>) {
140
let mut values: Vec<u8> = Vec::with_capacity(from.len());
141
let mut offsets: Vec<O> = Vec::with_capacity(from.len() + 1);
142
offsets.push(O::default());
143
144
let mut offset: usize = 0;
145
146
unsafe {
147
for &x in from.values().iter() {
148
let len = T::write(&mut values, x);
149
150
offset += len;
151
offsets.push(O::from_as_usize(offset));
152
}
153
values.set_len(offset);
154
values.shrink_to_fit();
155
// SAFETY: offsets _are_ monotonically increasing
156
let offsets = Offsets::new_unchecked(offsets);
157
158
(values, offsets)
159
}
160
}
161
162
/// Returns a [`BooleanArray`] where every element is different from zero.
163
/// Validity is preserved.
164
pub fn primitive_to_boolean<T: NativeType>(
165
from: &PrimitiveArray<T>,
166
to_type: ArrowDataType,
167
) -> BooleanArray {
168
let iter = from.values().iter().map(|v| *v != T::default());
169
let values = Bitmap::from_trusted_len_iter(iter);
170
171
BooleanArray::new(to_type, values, from.validity().cloned())
172
}
173
174
pub(super) fn primitive_to_boolean_dyn<T>(
175
from: &dyn Array,
176
to_type: ArrowDataType,
177
) -> PolarsResult<Box<dyn Array>>
178
where
179
T: NativeType,
180
{
181
let from = from.as_any().downcast_ref().unwrap();
182
Ok(Box::new(primitive_to_boolean::<T>(from, to_type)))
183
}
184
185
/// Returns a [`Utf8Array`] where every element is the utf8 representation of the number.
186
pub(super) fn primitive_to_utf8<T: NativeType + SerPrimitive, O: Offset>(
187
from: &PrimitiveArray<T>,
188
) -> Utf8Array<O> {
189
let (values, offsets) = primitive_to_values_and_offsets(from);
190
unsafe {
191
Utf8Array::<O>::new_unchecked(
192
Utf8Array::<O>::default_dtype(),
193
offsets.into(),
194
values.into(),
195
from.validity().cloned(),
196
)
197
}
198
}
199
200
pub(super) fn primitive_to_utf8_dyn<T, O>(from: &dyn Array) -> PolarsResult<Box<dyn Array>>
201
where
202
O: Offset,
203
T: NativeType + SerPrimitive,
204
{
205
let from = from.as_any().downcast_ref().unwrap();
206
Ok(Box::new(primitive_to_utf8::<T, O>(from)))
207
}
208
209
pub(super) fn primitive_to_primitive_dyn<I, O>(
210
from: &dyn Array,
211
to_type: &ArrowDataType,
212
options: CastOptionsImpl,
213
) -> PolarsResult<Box<dyn Array>>
214
where
215
I: NativeType + num_traits::NumCast + num_traits::AsPrimitive<O>,
216
O: NativeType + num_traits::NumCast,
217
{
218
let from = from.as_any().downcast_ref::<PrimitiveArray<I>>().unwrap();
219
if options.wrapped {
220
Ok(Box::new(primitive_as_primitive::<I, O>(from, to_type)))
221
} else {
222
Ok(Box::new(primitive_to_primitive::<I, O>(from, to_type)))
223
}
224
}
225
226
/// Cast [`PrimitiveArray`] to a [`PrimitiveArray`] of another physical type via numeric conversion.
227
pub fn primitive_to_primitive<I, O>(
228
from: &PrimitiveArray<I>,
229
to_type: &ArrowDataType,
230
) -> PrimitiveArray<O>
231
where
232
I: NativeType + num_traits::NumCast,
233
O: NativeType + num_traits::NumCast,
234
{
235
let iter = from
236
.iter()
237
.map(|v| v.and_then(|x| num_traits::cast::cast::<I, O>(*x)));
238
PrimitiveArray::<O>::from_trusted_len_iter(iter).to(to_type.clone())
239
}
240
241
/// Casts integers to an `i128`-backed `Decimal(to_precision, to_scale)` array.
///
/// Values that cannot be converted (they overflow `i128` or the requested
/// precision) become `None`, and existing nulls stay null.
///
/// # Panics
/// Panics if `(to_precision, to_scale)` fails `dec128_verify_prec_scale`.
#[cfg(feature = "dtype-decimal")]
pub fn integer_to_decimal<T: NativeType + ToPrimitive>(
    from: &PrimitiveArray<T>,
    to_precision: usize,
    to_scale: usize,
) -> PrimitiveArray<i128> {
    assert!(dec128_verify_prec_scale(to_precision, to_scale).is_ok());

    let scaled = from.iter().map(|slot| {
        let wide = slot?.to_i128()?;
        i128_to_dec128(wide, to_precision, to_scale)
    });
    let dtype = ArrowDataType::Decimal(to_precision, to_scale);
    PrimitiveArray::<i128>::from_trusted_len_iter(scaled).to(dtype)
}
255
256
/// Type-erased wrapper around [`integer_to_decimal`].
///
/// # Panics
/// Panics if `from` is not a `PrimitiveArray<T>` or the precision/scale pair
/// is invalid.
#[cfg(feature = "dtype-decimal")]
pub(super) fn integer_to_decimal_dyn<T>(
    from: &dyn Array,
    precision: usize,
    scale: usize,
) -> PolarsResult<Box<dyn Array>>
where
    T: NativeType + ToPrimitive,
{
    let typed = from.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
    let out = integer_to_decimal::<T>(typed, precision, scale);
    Ok(Box::new(out))
}
268
269
/// Casts floats to an `i128`-backed `Decimal(to_precision, to_scale)` array.
///
/// Each value is widened to `f64` and converted with `f64_to_dec128`. Values
/// it cannot convert (e.g. on overflow) become `None`, and existing nulls stay
/// null.
///
/// # Panics
/// Panics if `(to_precision, to_scale)` fails `dec128_verify_prec_scale`.
#[cfg(feature = "dtype-decimal")]
pub fn float_to_decimal<T: NativeType + Float + AsPrimitive<f64>>(
    from: &PrimitiveArray<T>,
    to_precision: usize,
    to_scale: usize,
) -> PrimitiveArray<i128> {
    assert!(dec128_verify_prec_scale(to_precision, to_scale).is_ok());

    let scaled = from.iter().map(|slot| {
        let wide: f64 = slot?.as_();
        f64_to_dec128(wide, to_precision, to_scale)
    });
    let dtype = ArrowDataType::Decimal(to_precision, to_scale);
    PrimitiveArray::<i128>::from_trusted_len_iter(scaled).to(dtype)
}
283
284
/// Type-erased wrapper around [`float_to_decimal`].
///
/// # Panics
/// Panics if `from` is not a `PrimitiveArray<T>` or the precision/scale pair
/// is invalid.
#[cfg(feature = "dtype-decimal")]
pub(super) fn float_to_decimal_dyn<T: NativeType + Float + AsPrimitive<f64>>(
    from: &dyn Array,
    precision: usize,
    scale: usize,
) -> PolarsResult<Box<dyn Array>> {
    let typed = from.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
    let out = float_to_decimal::<T>(typed, precision, scale);
    Ok(Box::new(out))
}
293
294
/// Cast [`PrimitiveArray`] as a [`PrimitiveArray`]
295
/// Same as `number as to_number_type` in rust
296
pub fn primitive_as_primitive<I, O>(
297
from: &PrimitiveArray<I>,
298
to_type: &ArrowDataType,
299
) -> PrimitiveArray<O>
300
where
301
I: NativeType + num_traits::AsPrimitive<O>,
302
O: NativeType,
303
{
304
unary(from, num_traits::AsPrimitive::<O>::as_, to_type.clone())
305
}
306
307
/// Cast [`PrimitiveArray`] to a [`PrimitiveArray`] of the same physical type.
308
/// This is O(1).
309
pub fn primitive_to_same_primitive<T>(
310
from: &PrimitiveArray<T>,
311
to_type: &ArrowDataType,
312
) -> PrimitiveArray<T>
313
where
314
T: NativeType,
315
{
316
PrimitiveArray::<T>::new(
317
to_type.clone(),
318
from.values().clone(),
319
from.validity().cloned(),
320
)
321
}
322
323
/// Cast [`PrimitiveArray`] to a [`PrimitiveArray`] of the same physical type.
324
/// This is O(1).
325
pub(super) fn primitive_to_same_primitive_dyn<T>(
326
from: &dyn Array,
327
to_type: &ArrowDataType,
328
) -> PolarsResult<Box<dyn Array>>
329
where
330
T: NativeType,
331
{
332
let from = from.as_any().downcast_ref().unwrap();
333
Ok(Box::new(primitive_to_same_primitive::<T>(from, to_type)))
334
}
335
336
pub(super) fn primitive_to_dictionary_dyn<T: NativeType + Eq + Hash, K: DictionaryKey>(
337
from: &dyn Array,
338
) -> PolarsResult<Box<dyn Array>> {
339
let from = from.as_any().downcast_ref().unwrap();
340
primitive_to_dictionary::<T, K>(from).map(|x| Box::new(x) as Box<dyn Array>)
341
}
342
343
/// Cast [`PrimitiveArray`] to [`DictionaryArray`]. Also known as packing.
344
/// # Errors
345
/// This function errors if the maximum key is smaller than the number of distinct elements
346
/// in the array.
347
pub fn primitive_to_dictionary<T: NativeType + Eq + Hash, K: DictionaryKey>(
348
from: &PrimitiveArray<T>,
349
) -> PolarsResult<DictionaryArray<K>> {
350
let iter = from.iter().map(|x| x.copied());
351
let mut array = MutableDictionaryArray::<K, _>::try_empty(MutablePrimitiveArray::<T>::from(
352
from.dtype().clone(),
353
))?;
354
array.reserve(from.len());
355
array.try_extend(iter)?;
356
357
Ok(array.into())
358
}
359
360
/// # Safety
361
///
362
/// `dtype` should be valid for primitive.
363
pub unsafe fn primitive_map_is_valid<T: NativeType>(
364
from: &PrimitiveArray<T>,
365
f: impl Fn(T) -> bool,
366
dtype: ArrowDataType,
367
) -> PrimitiveArray<T> {
368
let values = from.values().clone();
369
370
let validity: Bitmap = values.iter().map(|&v| f(v)).collect();
371
372
let validity = if validity.unset_bits() > 0 {
373
let new_validity = match from.validity() {
374
None => validity,
375
Some(v) => v & &validity,
376
};
377
378
Some(new_validity)
379
} else {
380
from.validity().cloned()
381
};
382
383
// SAFETY:
384
// - Validity did not change length
385
// - dtype should be valid
386
unsafe { PrimitiveArray::new_unchecked(dtype, values, validity) }
387
}
388
389
/// Conversion of `Int32` to `Time32(TimeUnit::Second)`
390
pub fn int32_to_time32s(from: &PrimitiveArray<i32>) -> PrimitiveArray<i32> {
391
// SAFETY: Time32(TimeUnit::Second) is valid for Int32
392
unsafe {
393
primitive_map_is_valid(
394
from,
395
|v| (0..SECONDS_IN_DAY as i32).contains(&v),
396
ArrowDataType::Time32(TimeUnit::Second),
397
)
398
}
399
}
400
401
/// Conversion of `Int32` to `Time32(TimeUnit::Millisecond)`
402
pub fn int32_to_time32ms(from: &PrimitiveArray<i32>) -> PrimitiveArray<i32> {
403
// SAFETY: Time32(TimeUnit::Millisecond) is valid for Int32
404
unsafe {
405
primitive_map_is_valid(
406
from,
407
|v| (0..MILLISECONDS_IN_DAY as i32).contains(&v),
408
ArrowDataType::Time32(TimeUnit::Millisecond),
409
)
410
}
411
}
412
413
/// Conversion of `Int64` to `Time32(TimeUnit::Microsecond)`
414
pub fn int64_to_time64us(from: &PrimitiveArray<i64>) -> PrimitiveArray<i64> {
415
// SAFETY: Time64(TimeUnit::Microsecond) is valid for Int64
416
unsafe {
417
primitive_map_is_valid(
418
from,
419
|v| (0..MICROSECONDS_IN_DAY).contains(&v),
420
ArrowDataType::Time32(TimeUnit::Microsecond),
421
)
422
}
423
}
424
425
/// Conversion of `Int64` to `Time32(TimeUnit::Nanosecond)`
426
pub fn int64_to_time64ns(from: &PrimitiveArray<i64>) -> PrimitiveArray<i64> {
427
// SAFETY: Time64(TimeUnit::Nanosecond) is valid for Int64
428
unsafe {
429
primitive_map_is_valid(
430
from,
431
|v| (0..NANOSECONDS_IN_DAY).contains(&v),
432
ArrowDataType::Time64(TimeUnit::Nanosecond),
433
)
434
}
435
}
436
437
/// Conversion of dates
438
pub fn date32_to_date64(from: &PrimitiveArray<i32>) -> PrimitiveArray<i64> {
439
unary(
440
from,
441
|x| x as i64 * MILLISECONDS_IN_DAY,
442
ArrowDataType::Date64,
443
)
444
}
445
446
/// Conversion of dates
447
pub fn date64_to_date32(from: &PrimitiveArray<i64>) -> PrimitiveArray<i32> {
448
unary(
449
from,
450
|x| (x / MILLISECONDS_IN_DAY) as i32,
451
ArrowDataType::Date32,
452
)
453
}
454
455
/// Conversion of times
456
pub fn time32s_to_time32ms(from: &PrimitiveArray<i32>) -> PrimitiveArray<i32> {
457
fallible_unary(
458
from,
459
|x| x.wrapping_mul(1000),
460
|x| x.checked_mul(1000).is_none(),
461
ArrowDataType::Time32(TimeUnit::Millisecond),
462
)
463
}
464
465
/// Conversion of times
466
pub fn time32ms_to_time32s(from: &PrimitiveArray<i32>) -> PrimitiveArray<i32> {
467
unary(from, |x| x / 1000, ArrowDataType::Time32(TimeUnit::Second))
468
}
469
470
/// Conversion of times
471
pub fn time64us_to_time64ns(from: &PrimitiveArray<i64>) -> PrimitiveArray<i64> {
472
fallible_unary(
473
from,
474
|x| x.wrapping_mul(1000),
475
|x| x.checked_mul(1000).is_none(),
476
ArrowDataType::Time64(TimeUnit::Nanosecond),
477
)
478
}
479
480
/// Conversion of times
481
pub fn time64ns_to_time64us(from: &PrimitiveArray<i64>) -> PrimitiveArray<i64> {
482
unary(
483
from,
484
|x| x / 1000,
485
ArrowDataType::Time64(TimeUnit::Microsecond),
486
)
487
}
488
489
/// Conversion of timestamp
490
pub fn timestamp_to_date64(from: &PrimitiveArray<i64>, from_unit: TimeUnit) -> PrimitiveArray<i64> {
491
let from_size = time_unit_multiple(from_unit);
492
let to_size = MILLISECONDS;
493
let to_type = ArrowDataType::Date64;
494
495
// Scale time_array by (to_size / from_size) using a
496
// single integer operation, but need to avoid integer
497
// math rounding down to zero
498
499
match to_size.cmp(&from_size) {
500
std::cmp::Ordering::Less => unary(from, |x| x / (from_size / to_size), to_type),
501
std::cmp::Ordering::Equal => primitive_to_same_primitive(from, &to_type),
502
std::cmp::Ordering::Greater => fallible_unary(
503
from,
504
|x| x.wrapping_mul(to_size / from_size),
505
|x| x.checked_mul(to_size / from_size).is_none(),
506
to_type,
507
),
508
}
509
}
510
511
/// Conversion of timestamp
512
pub fn timestamp_to_date32(from: &PrimitiveArray<i64>, from_unit: TimeUnit) -> PrimitiveArray<i32> {
513
let from_size = time_unit_multiple(from_unit) * SECONDS_IN_DAY;
514
unary(from, |x| (x / from_size) as i32, ArrowDataType::Date32)
515
}
516
517
/// Conversion of time
518
pub fn time32_to_time64(
519
from: &PrimitiveArray<i32>,
520
from_unit: TimeUnit,
521
to_unit: TimeUnit,
522
) -> PrimitiveArray<i64> {
523
let from_size = time_unit_multiple(from_unit);
524
let to_size = time_unit_multiple(to_unit);
525
let divisor = to_size / from_size;
526
fallible_unary(
527
from,
528
|x| (x as i64).wrapping_mul(divisor),
529
|x| (x as i64).checked_mul(divisor).is_none(),
530
ArrowDataType::Time64(to_unit),
531
)
532
}
533
534
/// Conversion of time
535
pub fn time64_to_time32(
536
from: &PrimitiveArray<i64>,
537
from_unit: TimeUnit,
538
to_unit: TimeUnit,
539
) -> PrimitiveArray<i32> {
540
let from_size = time_unit_multiple(from_unit);
541
let to_size = time_unit_multiple(to_unit);
542
let divisor = from_size / to_size;
543
unary(
544
from,
545
|x| (x / divisor) as i32,
546
ArrowDataType::Time32(to_unit),
547
)
548
}
549
550
/// Conversion of timestamp
551
pub fn timestamp_to_timestamp(
552
from: &PrimitiveArray<i64>,
553
from_unit: TimeUnit,
554
to_unit: TimeUnit,
555
tz: &Option<PlSmallStr>,
556
) -> PrimitiveArray<i64> {
557
let from_size = time_unit_multiple(from_unit);
558
let to_size = time_unit_multiple(to_unit);
559
let to_type = ArrowDataType::Timestamp(to_unit, tz.clone());
560
// we either divide or multiply, depending on size of each unit
561
if from_size >= to_size {
562
unary(from, |x| x / (from_size / to_size), to_type)
563
} else {
564
fallible_unary(
565
from,
566
|x| x.wrapping_mul(to_size / from_size),
567
|x| x.checked_mul(to_size / from_size).is_none(),
568
to_type,
569
)
570
}
571
}
572
573
/// Returns a [`Utf8Array`] where every element is the utf8 representation of the number.
574
pub(super) fn primitive_to_binview<T: NativeType + SerPrimitive>(
575
from: &PrimitiveArray<T>,
576
) -> BinaryViewArray {
577
let mut mutable = MutableBinaryViewArray::with_capacity(from.len());
578
579
let mut scratch = vec![];
580
for &x in from.values().iter() {
581
unsafe { scratch.set_len(0) };
582
T::write(&mut scratch, x);
583
mutable.push_value_ignore_validity(&scratch)
584
}
585
586
mutable.freeze().with_validity(from.validity().cloned())
587
}
588
589
pub(super) fn primitive_to_binview_dyn<T>(from: &dyn Array) -> BinaryViewArray
590
where
591
T: NativeType + SerPrimitive,
592
{
593
let from = from.as_any().downcast_ref().unwrap();
594
primitive_to_binview::<T>(from)
595
}
596
597