Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/series/ops/is_close.rs
6939 views
1
use std::cmp::max_by;
2
3
use arrow::array::BooleanArray;
4
use arrow::compute::utils::combine_validities_and;
5
use num_traits::AsPrimitive;
6
use polars_core::prelude::arity::apply_binary_kernel_broadcast;
7
use polars_core::prelude::*;
8
9
pub fn is_close(
10
s: &Series,
11
other: &Series,
12
abs_tol: f64,
13
rel_tol: f64,
14
nans_equal: bool,
15
) -> PolarsResult<BooleanChunked> {
16
if abs_tol < 0.0 {
17
polars_bail!(ComputeError: "`abs_tol` must be non-negative but got {}", abs_tol);
18
}
19
if rel_tol < 0.0 {
20
polars_bail!(ComputeError: "`rel_tol` must be non-negative but got {}", rel_tol);
21
}
22
validate_numeric(s.dtype())?;
23
validate_numeric(other.dtype())?;
24
25
let ca = match (s.dtype(), other.dtype()) {
26
(DataType::Float32, DataType::Float32) => apply_binary_kernel_broadcast(
27
s.f32().unwrap(),
28
other.f32().unwrap(),
29
|l, r| is_close_kernel::<Float32Type>(l, r, abs_tol, rel_tol, nans_equal),
30
|v, ca| is_close_kernel_unary::<Float32Type>(ca, v.as_(), abs_tol, rel_tol, nans_equal),
31
|ca, v| is_close_kernel_unary::<Float32Type>(ca, v.as_(), abs_tol, rel_tol, nans_equal),
32
),
33
(DataType::Float64, DataType::Float64) => apply_binary_kernel_broadcast(
34
s.f64().unwrap(),
35
other.f64().unwrap(),
36
|l, r| is_close_kernel::<Float64Type>(l, r, abs_tol, rel_tol, nans_equal),
37
|v, ca| is_close_kernel_unary::<Float64Type>(ca, v.as_(), abs_tol, rel_tol, nans_equal),
38
|ca, v| is_close_kernel_unary::<Float64Type>(ca, v.as_(), abs_tol, rel_tol, nans_equal),
39
),
40
_ => apply_binary_kernel_broadcast(
41
s.cast(&DataType::Float64)?.f64().unwrap(),
42
other.cast(&DataType::Float64)?.f64().unwrap(),
43
|l, r| is_close_kernel::<Float64Type>(l, r, abs_tol, rel_tol, nans_equal),
44
|v, ca| is_close_kernel_unary::<Float64Type>(ca, v.as_(), abs_tol, rel_tol, nans_equal),
45
|ca, v| is_close_kernel_unary::<Float64Type>(ca, v.as_(), abs_tol, rel_tol, nans_equal),
46
),
47
};
48
Ok(ca)
49
}
50
51
fn validate_numeric(dtype: &DataType) -> PolarsResult<()> {
52
if !dtype.is_primitive_numeric() && !dtype.is_decimal() {
53
polars_bail!(
54
op = "is_close",
55
dtype,
56
hint = "`is_close` is only supported for numeric types"
57
);
58
}
59
Ok(())
60
}
61
62
/* ------------------------------------------- KERNEL ------------------------------------------ */
63
64
fn is_close_kernel<T>(
65
lhs_arr: &T::Array,
66
rhs_arr: &T::Array,
67
abs_tol: f64,
68
rel_tol: f64,
69
nans_equal: bool,
70
) -> BooleanArray
71
where
72
T: PolarsNumericType,
73
{
74
let validity = combine_validities_and(lhs_arr.validity(), rhs_arr.validity());
75
let element_iter = lhs_arr
76
.values_iter()
77
.zip(rhs_arr.values_iter())
78
.map(|(x, y)| is_close_scalar(x.as_(), y.as_(), abs_tol, rel_tol, nans_equal));
79
let result: BooleanArray = element_iter.collect_arr();
80
result.with_validity_typed(validity)
81
}
82
83
fn is_close_kernel_unary<T>(
84
arr: &T::Array,
85
value: f64,
86
abs_tol: f64,
87
rel_tol: f64,
88
nans_equal: bool,
89
) -> BooleanArray
90
where
91
T: PolarsNumericType,
92
{
93
let validity = arr.validity().cloned();
94
let element_iter = arr
95
.values_iter()
96
.map(|x| is_close_scalar(x.as_(), value, abs_tol, rel_tol, nans_equal));
97
let result: BooleanArray = element_iter.collect_arr();
98
result.with_validity_typed(validity)
99
}
100
101
/* ---------------------------------------- SCALAR LOGIC --------------------------------------- */
102
103
/// Scalar closeness test shared by both kernels.
///
/// Mirrors Python's `math.isclose` (https://peps.python.org/pep-0485/), extended with
/// a NaN policy. Returns `true` when:
/// - `x == y` (including equal infinities and `0.0 == -0.0`), or
/// - both are finite and `|x - y| <= max(rel_tol * max(|x|, |y|), abs_tol)`, or
/// - both are NaN and `nans_equal` is `true`.
///
/// Every other pairing involving an infinity or a NaN is not close.
#[inline(always)]
fn is_close_scalar(x: f64, y: f64, abs_tol: f64, rel_tol: f64, nans_equal: bool) -> bool {
    // Exact equality short-circuits, as in PEP 485. Without it, `x == y == 0.0` with
    // `rel_tol == inf` yields a tolerance of `inf * 0.0 = NaN`. That NaN's sign is
    // unspecified, and so is its position in `total_cmp` order. Identical values
    // could then be reported as not close.
    if x == y {
        return true;
    }
    if x.is_nan() || y.is_nan() {
        return nans_equal && x.is_nan() && y.is_nan();
    }
    if x.is_infinite() || y.is_infinite() {
        // Same-signed infinities were caught by the equality check above.
        return false;
    }
    let tolerance = max_by(
        rel_tol * max_by(x.abs(), y.abs(), f64::total_cmp),
        abs_tol,
        f64::total_cmp,
    );
    (x - y).abs() <= tolerance
}
116
117