// Source: pola-rs/polars (GitHub repository), file
// crates/polars-compute/src/arithmetic/signed.rs
use arrow::array::{PrimitiveArray as PArr, StaticArray};
2
use arrow::compute::utils::{combine_validities_and, combine_validities_and3};
3
use polars_utils::floor_divmod::FloorDivMod;
4
use strength_reduce::*;
5
6
use super::PrimitiveArithmeticKernelImpl;
7
use crate::arity::{prim_binary_values, prim_unary_values};
8
use crate::comparisons::TotalEqKernel;
9
10
/// Implements [`PrimitiveArithmeticKernelImpl`] for the signed integer type `$T`.
///
/// `$StrRed` is the `strength_reduce` divisor type for the *unsigned* counterpart of
/// `$T` (e.g. `StrengthReducedU8` for `i8`). The scalar-divisor kernels divide `|x|` by
/// `|rhs|` in the unsigned domain with a precomputed divisor and then fix up the sign.
///
/// Conventions shared by the kernels below:
/// - Integer kernels use the `wrapping_*` operations, so overflow wraps instead of
///   panicking.
/// - A zero divisor (array element or scalar) yields a null in that slot.
/// - `floor_div` rounds toward negative infinity, `trunc_div` rounds toward zero, and
///   `mod` is the floored modulo whose result carries the sign of the divisor.
/// - `true_div` always produces `f64`.
macro_rules! impl_signed_arith_kernel {
    ($T:ty, $StrRed:ty) => {
        impl PrimitiveArithmeticKernelImpl for $T {
            type TrueDivT = f64;

            fn prim_wrapping_abs(lhs: PArr<$T>) -> PArr<$T> {
                prim_unary_values(lhs, |x| x.wrapping_abs())
            }

            fn prim_wrapping_neg(lhs: PArr<$T>) -> PArr<$T> {
                prim_unary_values(lhs, |x| x.wrapping_neg())
            }

            fn prim_wrapping_add(lhs: PArr<$T>, other: PArr<$T>) -> PArr<$T> {
                prim_binary_values(lhs, other, |a, b| a.wrapping_add(b))
            }

            fn prim_wrapping_sub(lhs: PArr<$T>, other: PArr<$T>) -> PArr<$T> {
                prim_binary_values(lhs, other, |a, b| a.wrapping_sub(b))
            }

            fn prim_wrapping_mul(lhs: PArr<$T>, other: PArr<$T>) -> PArr<$T> {
                prim_binary_values(lhs, other, |a, b| a.wrapping_mul(b))
            }

            /// Element-wise floor division; slots where `other == 0` become null.
            fn prim_wrapping_floor_div(mut lhs: PArr<$T>, mut other: PArr<$T>) -> PArr<$T> {
                let mask = other.tot_ne_kernel_broadcast(&0);
                let valid = combine_validities_and3(
                    lhs.take_validity().as_ref(), // Take validity so we don't
                    other.take_validity().as_ref(), // compute combination twice.
                    Some(&mask),
                );
                // NOTE(review): relies on `wrapping_floor_div_mod` not panicking on a zero
                // divisor (such slots are masked to null above); `prim_wrapping_trunc_div`
                // guards explicitly instead — confirm against polars_utils::floor_divmod.
                let ret =
                    prim_binary_values(lhs, other, |lhs, rhs| lhs.wrapping_floor_div_mod(rhs).0);
                ret.with_validity(valid)
            }

            /// Element-wise division truncating toward zero; slots where `other == 0`
            /// become null.
            fn prim_wrapping_trunc_div(mut lhs: PArr<$T>, mut other: PArr<$T>) -> PArr<$T> {
                let mask = other.tot_ne_kernel_broadcast(&0);
                let valid = combine_validities_and3(
                    lhs.take_validity().as_ref(), // Take validity so we don't
                    other.take_validity().as_ref(), // compute combination twice.
                    Some(&mask),
                );
                let ret = prim_binary_values(lhs, other, |lhs, rhs| {
                    // Avoid the divide-by-zero panic; the slot is null anyway.
                    if rhs != 0 { lhs.wrapping_div(rhs) } else { 0 }
                });
                ret.with_validity(valid)
            }

            /// Element-wise floored modulo; slots where `other == 0` become null.
            fn prim_wrapping_mod(mut lhs: PArr<$T>, mut other: PArr<$T>) -> PArr<$T> {
                let mask = other.tot_ne_kernel_broadcast(&0);
                let valid = combine_validities_and3(
                    lhs.take_validity().as_ref(), // Take validity so we don't
                    other.take_validity().as_ref(), // compute combination twice.
                    Some(&mask),
                );
                let ret =
                    prim_binary_values(lhs, other, |lhs, rhs| lhs.wrapping_floor_div_mod(rhs).1);
                ret.with_validity(valid)
            }

            fn prim_wrapping_add_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> {
                prim_unary_values(lhs, |x| x.wrapping_add(rhs))
            }

            fn prim_wrapping_sub_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> {
                // x - rhs == x + (-rhs) under wrapping arithmetic, including rhs == MIN.
                Self::prim_wrapping_add_scalar(lhs, rhs.wrapping_neg())
            }

            fn prim_wrapping_sub_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> {
                prim_unary_values(rhs, |x| lhs.wrapping_sub(x))
            }

            /// Multiplies every element by `rhs`, with fast paths for 0, 1 and
            /// (negated) powers of two.
            fn prim_wrapping_mul_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> {
                let scalar_u = rhs.unsigned_abs();
                if rhs == 0 {
                    lhs.fill_with(0)
                } else if rhs == 1 {
                    lhs
                } else if scalar_u.is_power_of_two() {
                    // Power of two: a shift replaces the multiply. For rhs == MIN the shift
                    // is BITS - 1 and `wrapping_neg` leaves the wrapped product intact.
                    let shift = scalar_u.trailing_zeros();
                    if rhs > 0 {
                        prim_unary_values(lhs, |x| x << shift)
                    } else {
                        prim_unary_values(lhs, |x| (x << shift).wrapping_neg())
                    }
                } else {
                    prim_unary_values(lhs, |x| x.wrapping_mul(rhs))
                }
            }

            /// Floor division by a scalar; a zero divisor yields an all-null array.
            fn prim_wrapping_floor_div_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> {
                if rhs == 0 {
                    PArr::full_null(lhs.len(), lhs.dtype().clone())
                } else if rhs == -1 {
                    Self::prim_wrapping_neg(lhs)
                } else if rhs == 1 {
                    lhs
                } else {
                    let red = <$StrRed>::new(rhs.unsigned_abs());
                    prim_unary_values(lhs, |x| {
                        let (quot, rem) = <$StrRed>::div_rem(x.unsigned_abs(), red);
                        if (x < 0) != (rhs < 0) {
                            // Different signs: result should be negative.
                            // Since we handled rhs.abs() <= 1, quot fits.
                            let mut ret = -(quot as $T);
                            if rem != 0 {
                                // Division had remainder, subtract 1 to floor to
                                // negative infinity, as we truncated to zero.
                                ret -= 1;
                            }
                            ret
                        } else {
                            quot as $T
                        }
                    })
                }
            }

            /// Floor division of the scalar `lhs` by every element; zero divisors
            /// become null.
            fn prim_wrapping_floor_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> {
                let mask = rhs.tot_ne_kernel_broadcast(&0);
                let valid = combine_validities_and(rhs.validity(), Some(&mask));
                let ret = if lhs == 0 {
                    rhs.fill_with(0)
                } else {
                    prim_unary_values(rhs, |x| lhs.wrapping_floor_div_mod(x).0)
                };
                ret.with_validity(valid)
            }

            /// Division by a scalar truncating toward zero; a zero divisor yields an
            /// all-null array.
            fn prim_wrapping_trunc_div_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> {
                if rhs == 0 {
                    PArr::full_null(lhs.len(), lhs.dtype().clone())
                } else if rhs == -1 {
                    Self::prim_wrapping_neg(lhs)
                } else if rhs == 1 {
                    lhs
                } else {
                    let red = <$StrRed>::new(rhs.unsigned_abs());
                    prim_unary_values(lhs, |x| {
                        let quot = x.unsigned_abs() / red;
                        if (x < 0) != (rhs < 0) {
                            // Different signs: result should be negative.
                            -(quot as $T)
                        } else {
                            quot as $T
                        }
                    })
                }
            }

            /// Truncating division of the scalar `lhs` by every element; zero divisors
            /// become null.
            fn prim_wrapping_trunc_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> {
                let mask = rhs.tot_ne_kernel_broadcast(&0);
                let valid = combine_validities_and(rhs.validity(), Some(&mask));
                let ret = if lhs == 0 {
                    rhs.fill_with(0)
                } else {
                    prim_unary_values(rhs, |x| if x != 0 { lhs.wrapping_div(x) } else { 0 })
                };
                ret.with_validity(valid)
            }

            /// Floored modulo by a scalar (result has the sign of `rhs`); a zero divisor
            /// yields an all-null array.
            fn prim_wrapping_mod_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> {
                if rhs == 0 {
                    PArr::full_null(lhs.len(), lhs.dtype().clone())
                } else if rhs == -1 || rhs == 1 {
                    lhs.fill_with(0)
                } else {
                    let scalar_u = rhs.unsigned_abs();
                    let red = <$StrRed>::new(scalar_u);
                    prim_unary_values(lhs, |x| {
                        // Remainder fits in signed type after reduction.
                        // Largest possible modulo -I::MIN, with
                        // -I::MIN-1 == I::MAX as largest remainder.
                        let mut rem_u = x.unsigned_abs() % red;

                        // Mixed signs: swap direction of remainder.
                        if rem_u != 0 && (rhs < 0) != (x < 0) {
                            rem_u = scalar_u - rem_u;
                        }

                        // Remainder should have sign of RHS.
                        if rhs < 0 { -(rem_u as $T) } else { rem_u as $T }
                    })
                }
            }

            /// Floored modulo of the scalar `lhs` by every element; zero divisors
            /// become null.
            fn prim_wrapping_mod_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> {
                let mask = rhs.tot_ne_kernel_broadcast(&0);
                let valid = combine_validities_and(rhs.validity(), Some(&mask));
                let ret = if lhs == 0 {
                    rhs.fill_with(0)
                } else {
                    prim_unary_values(rhs, |x| lhs.wrapping_floor_div_mod(x).1)
                };
                ret.with_validity(valid)
            }

            fn prim_checked_mul_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> {
                super::prim_checked_mul_scalar(&lhs, rhs)
            }

            fn prim_true_div(lhs: PArr<$T>, other: PArr<$T>) -> PArr<Self::TrueDivT> {
                prim_binary_values(lhs, other, |a, b| a as f64 / b as f64)
            }

            /// Divides every element by the scalar `rhs`, producing `f64`.
            ///
            /// Uses a real floating-point division per element instead of multiplying
            /// by a precomputed `1.0 / rhs`. The reciprocal is itself rounded, so
            /// `x * (1.0 / rhs)` can be one ulp off: `49.0 * (1.0 / 49.0)` is
            /// `0.9999999999999999`. That would disagree with `prim_true_div` and
            /// `prim_true_div_scalar_lhs` for the same operands. A zero divisor follows
            /// IEEE 754 semantics (`±inf`, or NaN for `0 / 0`).
            fn prim_true_div_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<Self::TrueDivT> {
                let divisor = rhs as f64;
                prim_unary_values(lhs, |x| x as f64 / divisor)
            }

            fn prim_true_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<Self::TrueDivT> {
                prim_unary_values(rhs, |x| lhs as f64 / x as f64)
            }
        }
    };
}
229
230
impl_signed_arith_kernel!(i8, StrengthReducedU8);
231
impl_signed_arith_kernel!(i16, StrengthReducedU16);
232
impl_signed_arith_kernel!(i32, StrengthReducedU32);
233
impl_signed_arith_kernel!(i64, StrengthReducedU64);
234
impl_signed_arith_kernel!(i128, StrengthReducedU128);
235
236