Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/arity.rs
6939 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
use arrow::array::PrimitiveArray;
3
use arrow::compute::utils::combine_validities_and;
4
use arrow::types::NativeType;
5
6
/// To reduce codegen we use these helpers where the input and output arrays
/// may overlap. These are marked to never be inlined, this way only a single
/// unrolled kernel gets generated, even if we call it in multiple ways.
///
/// Writes `op(arr[i])` to `out[i]` for every `i` in `0..len`. Element `i` of
/// the input is always read before element `i` of the output is written.
///
/// # Safety
/// - `arr` must point to a readable slice of length `len`.
/// - `out` must point to a writable slice of length `len`.
/// - `arr` and `out` must either not overlap at all, or start at exactly the
///   same address (in-place use, where `I` and `O` must have the same size).
///   Other overlaps are not supported: if `out` starts after `arr`, a write
///   can clobber an input element that has not been read yet.
#[inline(never)]
unsafe fn ptr_apply_unary_kernel<I: Copy, O, F: Fn(I) -> O>(
    arr: *const I,
    out: *mut O,
    len: usize,
    op: F,
) {
    for i in 0..len {
        // SAFETY: `i < len`, so the offset stays inside the readable input
        // slice guaranteed by the caller.
        let x = unsafe { arr.add(i).read() };
        let ret = op(x);
        // SAFETY: `i < len`, so the offset stays inside the writable output
        // slice. Under exact aliasing, slot `i` has already been read above.
        unsafe { out.add(i).write(ret) };
    }
}
25
26
/// Writes `op(left[i], right[i])` to `out[i]` for every `i` in `0..len`.
/// Both inputs at index `i` are read before `out[i]` is written. Marked
/// `#[inline(never)]` so a single unrolled kernel is generated no matter how
/// many call sites use it.
///
/// # Safety
/// - `left` must point to a readable slice of length `len`.
/// - `right` must point to a readable slice of length `len`.
/// - `out` must point to a writable slice of length `len`.
/// - `out` must either be disjoint from each input or start at exactly the
///   same address as it (in-place use, where that input's element type and
///   `O` must have the same size). Other overlaps are not supported, since a
///   write could clobber input elements that have not been read yet.
#[inline(never)]
unsafe fn ptr_apply_binary_kernel<L: Copy, R: Copy, O, F: Fn(L, R) -> O>(
    left: *const L,
    right: *const R,
    out: *mut O,
    len: usize,
    op: F,
) {
    for i in 0..len {
        // SAFETY: `i < len`, so both offsets stay inside the readable input
        // slices guaranteed by the caller.
        let (l, r) = unsafe { (left.add(i).read(), right.add(i).read()) };
        let ret = op(l, r);
        // SAFETY: `i < len`, so the offset stays inside the writable output
        // slice. If `out` aliases an input exactly, slot `i` was read above.
        unsafe { out.add(i).write(ret) };
    }
}
43
44
/// Applies a function to all the values (regardless of nullability).
45
///
46
/// May reuse the memory of the array if possible.
47
pub fn prim_unary_values<I, O, F>(mut arr: PrimitiveArray<I>, op: F) -> PrimitiveArray<O>
48
where
49
I: NativeType,
50
O: NativeType,
51
F: Fn(I) -> O,
52
{
53
let len = arr.len();
54
55
// Reuse memory if possible.
56
if size_of::<I>() == size_of::<O>() && align_of::<I>() == align_of::<O>() {
57
if let Some(values) = arr.get_mut_values() {
58
let ptr = values.as_mut_ptr();
59
// SAFETY: checked same size & alignment I/O, NativeType is always Pod.
60
unsafe { ptr_apply_unary_kernel(ptr, ptr as *mut O, len, op) }
61
return arr.transmute::<O>();
62
}
63
}
64
65
let mut out = Vec::with_capacity(len);
66
unsafe {
67
// SAFETY: checked pointers point to slices of length len.
68
ptr_apply_unary_kernel(arr.values().as_ptr(), out.as_mut_ptr(), len, op);
69
out.set_len(len);
70
}
71
PrimitiveArray::from_vec(out).with_validity(arr.take_validity())
72
}
73
74
/// Apply a binary function to all the values (regardless of nullability)
75
/// in (lhs, rhs). Combines the validities with a bitand.
76
///
77
/// May reuse the memory of one of its arguments if possible.
78
pub fn prim_binary_values<L, R, O, F>(
79
mut lhs: PrimitiveArray<L>,
80
mut rhs: PrimitiveArray<R>,
81
op: F,
82
) -> PrimitiveArray<O>
83
where
84
L: NativeType,
85
R: NativeType,
86
O: NativeType,
87
F: Fn(L, R) -> O,
88
{
89
assert_eq!(lhs.len(), rhs.len());
90
let len = lhs.len();
91
92
let validity = combine_validities_and(lhs.validity(), rhs.validity());
93
94
// Reuse memory if possible.
95
if size_of::<L>() == size_of::<O>() && align_of::<L>() == align_of::<O>() {
96
if let Some(lv) = lhs.get_mut_values() {
97
let lp = lv.as_mut_ptr();
98
let rp = rhs.values().as_ptr();
99
unsafe {
100
// SAFETY: checked same size & alignment L/O, NativeType is always Pod.
101
ptr_apply_binary_kernel(lp, rp, lp as *mut O, len, op);
102
}
103
return lhs.transmute::<O>().with_validity(validity);
104
}
105
}
106
if size_of::<R>() == size_of::<O>() && align_of::<R>() == align_of::<O>() {
107
if let Some(rv) = rhs.get_mut_values() {
108
let lp = lhs.values().as_ptr();
109
let rp = rv.as_mut_ptr();
110
unsafe {
111
// SAFETY: checked same size & alignment R/O, NativeType is always Pod.
112
ptr_apply_binary_kernel(lp, rp, rp as *mut O, len, op);
113
}
114
return rhs.transmute::<O>().with_validity(validity);
115
}
116
}
117
118
let mut out = Vec::with_capacity(len);
119
unsafe {
120
// SAFETY: checked pointers point to slices of length len.
121
let lp = lhs.values().as_ptr();
122
let rp = rhs.values().as_ptr();
123
ptr_apply_binary_kernel(lp, rp, out.as_mut_ptr(), len, op);
124
out.set_len(len);
125
}
126
PrimitiveArray::from_vec(out).with_validity(validity)
127
}
128
129