Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/dispatch/pow.rs
7884 views
1
use num_traits::pow::Pow;
2
use num_traits::{Float, One, ToPrimitive, Zero};
3
use polars_core::error::{PolarsResult, polars_bail, polars_ensure, polars_err};
4
use polars_core::prelude::arity::{broadcast_binary_elementwise, unary_elementwise_values};
5
use polars_core::prelude::{
6
ChunkApply, ChunkedArray, Column, DataType, IntoColumn, PolarsFloatType, PolarsIntegerType,
7
PolarsNumericType,
8
};
9
use polars_core::with_match_physical_integer_type;
10
11
fn pow_on_chunked_arrays<T, F>(
12
base: &ChunkedArray<T>,
13
exponent: &ChunkedArray<F>,
14
) -> ChunkedArray<T>
15
where
16
T: PolarsNumericType,
17
F: PolarsNumericType,
18
T::Native: Pow<F::Native, Output = T::Native> + ToPrimitive,
19
{
20
if exponent.len() == 1 {
21
if let Some(e) = exponent.get(0) {
22
if e == F::Native::zero() {
23
return unary_elementwise_values(base, |_| T::Native::one());
24
}
25
if e == F::Native::one() {
26
return base.clone();
27
}
28
if e == F::Native::one() + F::Native::one() {
29
return base * base;
30
}
31
}
32
}
33
34
broadcast_binary_elementwise(base, exponent, |b, e| Some(Pow::pow(b?, e?)))
35
}
36
37
fn pow_on_floats<T>(base: &ChunkedArray<T>, exponent: &ChunkedArray<T>) -> PolarsResult<Column>
38
where
39
T: PolarsFloatType,
40
T::Native: Pow<T::Native, Output = T::Native> + ToPrimitive + Float,
41
ChunkedArray<T>: IntoColumn,
42
{
43
let dtype = T::get_static_dtype();
44
45
if exponent.len() == 1 {
46
let Some(exponent_value) = exponent.get(0) else {
47
return Ok(Column::full_null(base.name().clone(), base.len(), &dtype));
48
};
49
let s = match exponent_value.to_f64().unwrap() {
50
1.0 => base.clone().into_column(),
51
// specialized sqrt will ensure (-inf)^0.5 = NaN
52
// and will likely be faster as well.
53
0.5 => base.apply_values(|v| v.sqrt()).into_column(),
54
a if a.fract() == 0.0 && a < 10.0 && a > 1.0 => {
55
let mut out = base.clone();
56
57
for _ in 1..exponent_value.to_u8().unwrap() {
58
out = out * base.clone()
59
}
60
out.into_column()
61
},
62
_ => base
63
.apply_values(|v| Pow::pow(v, exponent_value))
64
.into_column(),
65
};
66
Ok(s)
67
} else {
68
Ok(pow_on_chunked_arrays(base, exponent).into_column())
69
}
70
}
71
72
fn pow_to_uint_dtype<T, F>(
73
base: &ChunkedArray<T>,
74
exponent: &ChunkedArray<F>,
75
) -> PolarsResult<Column>
76
where
77
T: PolarsIntegerType,
78
F: PolarsIntegerType,
79
T::Native: Pow<F::Native, Output = T::Native> + ToPrimitive,
80
ChunkedArray<T>: IntoColumn,
81
{
82
let dtype = T::get_static_dtype();
83
84
if exponent.len() == 1 {
85
let Some(exponent_value) = exponent.get(0) else {
86
return Ok(Column::full_null(base.name().clone(), base.len(), &dtype));
87
};
88
let s = match exponent_value.to_u64().unwrap() {
89
1 => base.clone().into_column(),
90
2..=10 => {
91
let mut out = base.clone();
92
93
for _ in 1..exponent_value.to_u8().unwrap() {
94
out = out * base.clone()
95
}
96
out.into_column()
97
},
98
_ => base
99
.apply_values(|v| Pow::pow(v, exponent_value))
100
.into_column(),
101
};
102
Ok(s)
103
} else {
104
Ok(pow_on_chunked_arrays(base, exponent).into_column())
105
}
106
}
107
108
fn pow_on_series(base: &Column, exponent: &Column) -> PolarsResult<Column> {
109
let base_dtype = base.dtype();
110
polars_ensure!(
111
base_dtype.is_primitive_numeric(),
112
InvalidOperation: "`pow` operation not supported for dtype `{}` as base", base_dtype
113
);
114
let exponent_dtype = exponent.dtype();
115
polars_ensure!(
116
exponent_dtype.is_primitive_numeric(),
117
InvalidOperation: "`pow` operation not supported for dtype `{}` as exponent", exponent_dtype
118
);
119
120
// if false, dtype is float
121
if base_dtype.is_integer() {
122
with_match_physical_integer_type!(base_dtype, |$native_type| {
123
if exponent_dtype.is_float() {
124
match exponent_dtype {
125
#[cfg(feature = "dtype-f16")]
126
Float16 => {
127
let ca = base.cast(&DataType::Float16)?;
128
let exponent = exponent.strict_cast(&DataType::Float16)?;
129
pow_on_floats(ca.f16().unwrap(), exponent.f16().unwrap())
130
},
131
Float32 => {
132
let ca = base.cast(&DataType::Float32)?;
133
pow_on_floats(ca.f32().unwrap(), exponent.f32().unwrap())
134
},
135
Float64 => {
136
let ca = base.cast(&DataType::Float64)?;
137
pow_on_floats(ca.f64().unwrap(), exponent.f64().unwrap())
138
},
139
_ => unreachable!(),
140
}
141
} else {
142
let ca = base.$native_type().unwrap();
143
let exponent = exponent.strict_cast(&DataType::UInt32).map_err(|err| polars_err!(
144
InvalidOperation:
145
"{}\n\nHint: if you were trying to raise an integer to a negative integer power, please cast your base or exponent to float first.",
146
err
147
))?;
148
pow_to_uint_dtype(ca, exponent.u32().unwrap())
149
}
150
})
151
} else {
152
match base_dtype {
153
#[cfg(feature = "dtype-f16")]
154
DataType::Float16 => {
155
let ca = base.f16().unwrap();
156
let exponent = exponent.strict_cast(&DataType::Float16)?;
157
pow_on_floats(ca, exponent.f16().unwrap())
158
},
159
DataType::Float32 => {
160
let ca = base.f32().unwrap();
161
let exponent = exponent.strict_cast(&DataType::Float32)?;
162
pow_on_floats(ca, exponent.f32().unwrap())
163
},
164
DataType::Float64 => {
165
let ca = base.f64().unwrap();
166
let exponent = exponent.strict_cast(&DataType::Float64)?;
167
pow_on_floats(ca, exponent.f64().unwrap())
168
},
169
_ => unreachable!(),
170
}
171
}
172
}
173
174
pub(super) fn pow(s: &mut [Column]) -> PolarsResult<Column> {
175
let base = &s[0];
176
let exponent = &s[1];
177
178
let base_len = base.len();
179
let exp_len = exponent.len();
180
match (base_len, exp_len) {
181
(1, _) | (_, 1) => pow_on_series(base, exponent),
182
(len_a, len_b) if len_a == len_b => pow_on_series(base, exponent),
183
_ => polars_bail!(
184
ComputeError:
185
"exponent shape: {} in `pow` expression does not match that of the base: {}",
186
exp_len, base_len,
187
),
188
}
189
}
190
191
pub(super) fn sqrt(base: &Column) -> PolarsResult<Column> {
192
match base.dtype() {
193
#[cfg(feature = "dtype-f16")]
194
DataType::Float16 => {
195
let ca = base.f16().unwrap();
196
sqrt_on_floats(ca)
197
},
198
DataType::Float32 => {
199
let ca = base.f32().unwrap();
200
sqrt_on_floats(ca)
201
},
202
DataType::Float64 => {
203
let ca = base.f64().unwrap();
204
sqrt_on_floats(ca)
205
},
206
_ => {
207
let base = base.cast(&DataType::Float64)?;
208
sqrt(&base)
209
},
210
}
211
}
212
213
fn sqrt_on_floats<T>(base: &ChunkedArray<T>) -> PolarsResult<Column>
214
where
215
T: PolarsFloatType,
216
T::Native: Pow<T::Native, Output = T::Native> + ToPrimitive + Float,
217
ChunkedArray<T>: IntoColumn,
218
{
219
Ok(base.apply_values(|v| v.sqrt()).into_column())
220
}
221
222
pub(super) fn cbrt(base: &Column) -> PolarsResult<Column> {
223
match base.dtype() {
224
#[cfg(feature = "dtype-f16")]
225
DataType::Float16 => {
226
let ca = base.f16().unwrap();
227
cbrt_on_floats(ca)
228
},
229
DataType::Float32 => {
230
let ca = base.f32().unwrap();
231
cbrt_on_floats(ca)
232
},
233
DataType::Float64 => {
234
let ca = base.f64().unwrap();
235
cbrt_on_floats(ca)
236
},
237
_ => {
238
let base = base.cast(&DataType::Float64)?;
239
cbrt(&base)
240
},
241
}
242
}
243
244
fn cbrt_on_floats<T>(base: &ChunkedArray<T>) -> PolarsResult<Column>
245
where
246
T: PolarsFloatType,
247
T::Native: Pow<T::Native, Output = T::Native> + ToPrimitive + Float,
248
ChunkedArray<T>: IntoColumn,
249
{
250
Ok(base.apply_values(|v| v.cbrt()).into_column())
251
}
252
253