Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/arithmetic/mod.rs
6939 views
1
use std::any::TypeId;
2
3
use arrow::array::{Array, PrimitiveArray};
4
use arrow::bitmap::BitmapBuilder;
5
use arrow::types::NativeType;
6
7
// Low-level comparison kernel.
8
pub trait ArithmeticKernel: Sized + Array {
9
type Scalar;
10
type TrueDivT: NativeType;
11
12
fn wrapping_abs(self) -> Self;
13
fn wrapping_neg(self) -> Self;
14
fn wrapping_add(self, rhs: Self) -> Self;
15
fn wrapping_sub(self, rhs: Self) -> Self;
16
fn wrapping_mul(self, rhs: Self) -> Self;
17
fn wrapping_floor_div(self, rhs: Self) -> Self;
18
fn wrapping_trunc_div(self, rhs: Self) -> Self;
19
fn wrapping_mod(self, rhs: Self) -> Self;
20
21
fn wrapping_add_scalar(self, rhs: Self::Scalar) -> Self;
22
fn wrapping_sub_scalar(self, rhs: Self::Scalar) -> Self;
23
fn wrapping_sub_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self;
24
fn wrapping_mul_scalar(self, rhs: Self::Scalar) -> Self;
25
fn wrapping_floor_div_scalar(self, rhs: Self::Scalar) -> Self;
26
fn wrapping_floor_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self;
27
fn wrapping_trunc_div_scalar(self, rhs: Self::Scalar) -> Self;
28
fn wrapping_trunc_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self;
29
fn wrapping_mod_scalar(self, rhs: Self::Scalar) -> Self;
30
fn wrapping_mod_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self;
31
32
fn checked_mul_scalar(self, rhs: Self::Scalar) -> Self;
33
34
fn true_div(self, rhs: Self) -> PrimitiveArray<Self::TrueDivT>;
35
fn true_div_scalar(self, rhs: Self::Scalar) -> PrimitiveArray<Self::TrueDivT>;
36
fn true_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> PrimitiveArray<Self::TrueDivT>;
37
38
// TODO: remove these.
39
// These are flooring division for integer types, true division for floating point types.
40
fn legacy_div(self, rhs: Self) -> Self {
41
if TypeId::of::<Self>() == TypeId::of::<PrimitiveArray<Self::TrueDivT>>() {
42
let ret = self.true_div(rhs);
43
unsafe {
44
let cast_ret = std::mem::transmute_copy(&ret);
45
std::mem::forget(ret);
46
cast_ret
47
}
48
} else {
49
self.wrapping_floor_div(rhs)
50
}
51
}
52
fn legacy_div_scalar(self, rhs: Self::Scalar) -> Self {
53
if TypeId::of::<Self>() == TypeId::of::<PrimitiveArray<Self::TrueDivT>>() {
54
let ret = self.true_div_scalar(rhs);
55
unsafe {
56
let cast_ret = std::mem::transmute_copy(&ret);
57
std::mem::forget(ret);
58
cast_ret
59
}
60
} else {
61
self.wrapping_floor_div_scalar(rhs)
62
}
63
}
64
65
fn legacy_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self {
66
if TypeId::of::<Self>() == TypeId::of::<PrimitiveArray<Self::TrueDivT>>() {
67
let ret = ArithmeticKernel::true_div_scalar_lhs(lhs, rhs);
68
unsafe {
69
let cast_ret = std::mem::transmute_copy(&ret);
70
std::mem::forget(ret);
71
cast_ret
72
}
73
} else {
74
ArithmeticKernel::wrapping_floor_div_scalar_lhs(lhs, rhs)
75
}
76
}
77
}
78
79
// Proxy trait so one can bound T: HasPrimitiveArithmeticKernel. Sadly Rust
80
// doesn't support adding supertraits for other types.
81
#[allow(private_bounds)]
82
pub trait HasPrimitiveArithmeticKernel: NativeType + PrimitiveArithmeticKernelImpl {}
83
impl<T: NativeType + PrimitiveArithmeticKernelImpl> HasPrimitiveArithmeticKernel for T {}
84
85
use PrimitiveArray as PArr;
86
use num_traits::{CheckedMul, WrappingMul};
87
use polars_utils::vec::PushUnchecked;
88
89
#[doc(hidden)]
90
pub trait PrimitiveArithmeticKernelImpl: NativeType {
91
type TrueDivT: NativeType;
92
93
fn prim_wrapping_abs(lhs: PArr<Self>) -> PArr<Self>;
94
fn prim_wrapping_neg(lhs: PArr<Self>) -> PArr<Self>;
95
fn prim_wrapping_add(lhs: PArr<Self>, rhs: PArr<Self>) -> PArr<Self>;
96
fn prim_wrapping_sub(lhs: PArr<Self>, rhs: PArr<Self>) -> PArr<Self>;
97
fn prim_wrapping_mul(lhs: PArr<Self>, rhs: PArr<Self>) -> PArr<Self>;
98
fn prim_wrapping_floor_div(lhs: PArr<Self>, rhs: PArr<Self>) -> PArr<Self>;
99
fn prim_wrapping_trunc_div(lhs: PArr<Self>, rhs: PArr<Self>) -> PArr<Self>;
100
fn prim_wrapping_mod(lhs: PArr<Self>, rhs: PArr<Self>) -> PArr<Self>;
101
102
fn prim_wrapping_add_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self>;
103
fn prim_wrapping_sub_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self>;
104
fn prim_wrapping_sub_scalar_lhs(lhs: Self, rhs: PArr<Self>) -> PArr<Self>;
105
fn prim_wrapping_mul_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self>;
106
fn prim_wrapping_floor_div_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self>;
107
fn prim_wrapping_floor_div_scalar_lhs(lhs: Self, rhs: PArr<Self>) -> PArr<Self>;
108
fn prim_wrapping_trunc_div_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self>;
109
fn prim_wrapping_trunc_div_scalar_lhs(lhs: Self, rhs: PArr<Self>) -> PArr<Self>;
110
fn prim_wrapping_mod_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self>;
111
fn prim_wrapping_mod_scalar_lhs(lhs: Self, rhs: PArr<Self>) -> PArr<Self>;
112
113
fn prim_checked_mul_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self>;
114
115
fn prim_true_div(lhs: PArr<Self>, rhs: PArr<Self>) -> PArr<Self::TrueDivT>;
116
fn prim_true_div_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self::TrueDivT>;
117
fn prim_true_div_scalar_lhs(lhs: Self, rhs: PArr<Self>) -> PArr<Self::TrueDivT>;
118
}
119
120
#[rustfmt::skip]
121
impl<T: HasPrimitiveArithmeticKernel> ArithmeticKernel for PrimitiveArray<T> {
122
type Scalar = T;
123
type TrueDivT = T::TrueDivT;
124
125
fn wrapping_abs(self) -> Self { T::prim_wrapping_abs(self) }
126
fn wrapping_neg(self) -> Self { T::prim_wrapping_neg(self) }
127
fn wrapping_add(self, rhs: Self) -> Self { T::prim_wrapping_add(self, rhs) }
128
fn wrapping_sub(self, rhs: Self) -> Self { T::prim_wrapping_sub(self, rhs) }
129
fn wrapping_mul(self, rhs: Self) -> Self { T::prim_wrapping_mul(self, rhs) }
130
fn wrapping_floor_div(self, rhs: Self) -> Self { T::prim_wrapping_floor_div(self, rhs) }
131
fn wrapping_trunc_div(self, rhs: Self) -> Self { T::prim_wrapping_trunc_div(self, rhs) }
132
fn wrapping_mod(self, rhs: Self) -> Self { T::prim_wrapping_mod(self, rhs) }
133
134
fn wrapping_add_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_add_scalar(self, rhs) }
135
fn wrapping_sub_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_sub_scalar(self, rhs) }
136
fn wrapping_sub_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self { T::prim_wrapping_sub_scalar_lhs(lhs, rhs) }
137
fn wrapping_mul_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_mul_scalar(self, rhs) }
138
fn wrapping_floor_div_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_floor_div_scalar(self, rhs) }
139
fn wrapping_floor_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self { T::prim_wrapping_floor_div_scalar_lhs(lhs, rhs) }
140
fn wrapping_trunc_div_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_trunc_div_scalar(self, rhs) }
141
fn wrapping_trunc_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self { T::prim_wrapping_trunc_div_scalar_lhs(lhs, rhs) }
142
fn wrapping_mod_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_mod_scalar(self, rhs) }
143
fn wrapping_mod_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self { T::prim_wrapping_mod_scalar_lhs(lhs, rhs) }
144
145
fn checked_mul_scalar(self, rhs: Self::Scalar) -> Self { T::prim_checked_mul_scalar(self, rhs) }
146
147
fn true_div(self, rhs: Self) -> PrimitiveArray<Self::TrueDivT> { T::prim_true_div(self, rhs) }
148
fn true_div_scalar(self, rhs: Self::Scalar) -> PrimitiveArray<Self::TrueDivT> { T::prim_true_div_scalar(self, rhs) }
149
fn true_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> PrimitiveArray<Self::TrueDivT> { T::prim_true_div_scalar_lhs(lhs, rhs) }
150
}
151
152
mod float;
153
pub mod pl_num;
154
mod signed;
155
mod unsigned;
156
157
fn prim_checked_mul_scalar<I: NativeType + CheckedMul + WrappingMul>(
158
array: &PrimitiveArray<I>,
159
factor: I,
160
) -> PrimitiveArray<I> {
161
let values = array.values();
162
let mut out = Vec::with_capacity(array.len());
163
let mut i = 0;
164
165
while i < array.len() && values[i].checked_mul(&factor).is_some() {
166
// SAFETY: We allocated enough before.
167
unsafe { out.push_unchecked(values[i].wrapping_mul(&factor)) };
168
i += 1;
169
}
170
171
if out.len() == array.len() {
172
return PrimitiveArray::<I>::new(
173
I::PRIMITIVE.into(),
174
out.into(),
175
array.validity().cloned(),
176
);
177
}
178
179
let mut validity = BitmapBuilder::with_capacity(array.len());
180
validity.extend_constant(out.len(), true);
181
182
for &value in &values[out.len()..] {
183
// SAFETY: We allocated enough before.
184
unsafe {
185
out.push_unchecked(value.wrapping_mul(&factor));
186
validity.push_unchecked(value.checked_mul(&factor).is_some());
187
}
188
}
189
190
debug_assert_eq!(out.len(), array.len());
191
debug_assert_eq!(validity.len(), array.len());
192
193
let validity = validity.freeze();
194
let validity = match array.validity() {
195
None => validity,
196
Some(arr_validity) => arrow::bitmap::and(&validity, arr_validity),
197
};
198
199
PrimitiveArray::<I>::new(I::PRIMITIVE.into(), out.into(), Some(validity))
200
}
201
202