// Source: GitHub repository pola-rs/polars
// Path: blob/main/crates/polars-ops/src/series/ops/is_close.rs
1
use std::cmp::max_by;
2
3
use arrow::array::BooleanArray;
4
use arrow::compute::utils::combine_validities_and;
5
use num_traits::AsPrimitive;
6
use polars_core::prelude::arity::apply_binary_kernel_broadcast;
7
use polars_core::prelude::*;
8
9
pub fn is_close(
10
s: &Series,
11
other: &Series,
12
abs_tol: f64,
13
rel_tol: f64,
14
nans_equal: bool,
15
) -> PolarsResult<BooleanChunked> {
16
if abs_tol < 0.0 {
17
polars_bail!(ComputeError: "`abs_tol` must be non-negative but got {}", abs_tol);
18
}
19
if rel_tol < 0.0 {
20
polars_bail!(ComputeError: "`rel_tol` must be non-negative but got {}", rel_tol);
21
}
22
validate_numeric(s.dtype())?;
23
validate_numeric(other.dtype())?;
24
25
let ca = match (s.dtype(), other.dtype()) {
26
#[cfg(feature = "dtype-f16")]
27
(DataType::Float16, DataType::Float16) => apply_binary_kernel_broadcast(
28
s.f16().unwrap(),
29
other.f16().unwrap(),
30
|l, r| is_close_kernel::<Float16Type>(l, r, abs_tol, rel_tol, nans_equal),
31
|v, ca| is_close_kernel_unary::<Float16Type>(ca, v.as_(), abs_tol, rel_tol, nans_equal),
32
|ca, v| is_close_kernel_unary::<Float16Type>(ca, v.as_(), abs_tol, rel_tol, nans_equal),
33
),
34
(DataType::Float32, DataType::Float32) => apply_binary_kernel_broadcast(
35
s.f32().unwrap(),
36
other.f32().unwrap(),
37
|l, r| is_close_kernel::<Float32Type>(l, r, abs_tol, rel_tol, nans_equal),
38
|v, ca| is_close_kernel_unary::<Float32Type>(ca, v.as_(), abs_tol, rel_tol, nans_equal),
39
|ca, v| is_close_kernel_unary::<Float32Type>(ca, v.as_(), abs_tol, rel_tol, nans_equal),
40
),
41
(DataType::Float64, DataType::Float64) => apply_binary_kernel_broadcast(
42
s.f64().unwrap(),
43
other.f64().unwrap(),
44
|l, r| is_close_kernel::<Float64Type>(l, r, abs_tol, rel_tol, nans_equal),
45
|v, ca| is_close_kernel_unary::<Float64Type>(ca, v.as_(), abs_tol, rel_tol, nans_equal),
46
|ca, v| is_close_kernel_unary::<Float64Type>(ca, v.as_(), abs_tol, rel_tol, nans_equal),
47
),
48
_ => apply_binary_kernel_broadcast(
49
s.cast(&DataType::Float64)?.f64().unwrap(),
50
other.cast(&DataType::Float64)?.f64().unwrap(),
51
|l, r| is_close_kernel::<Float64Type>(l, r, abs_tol, rel_tol, nans_equal),
52
|v, ca| is_close_kernel_unary::<Float64Type>(ca, v.as_(), abs_tol, rel_tol, nans_equal),
53
|ca, v| is_close_kernel_unary::<Float64Type>(ca, v.as_(), abs_tol, rel_tol, nans_equal),
54
),
55
};
56
Ok(ca)
57
}
58
59
fn validate_numeric(dtype: &DataType) -> PolarsResult<()> {
60
if !dtype.is_primitive_numeric() && !dtype.is_decimal() {
61
polars_bail!(
62
op = "is_close",
63
dtype,
64
hint = "`is_close` is only supported for numeric types"
65
);
66
}
67
Ok(())
68
}
69
70
/* ------------------------------------------- KERNEL ------------------------------------------ */
71
72
fn is_close_kernel<T>(
73
lhs_arr: &T::Array,
74
rhs_arr: &T::Array,
75
abs_tol: f64,
76
rel_tol: f64,
77
nans_equal: bool,
78
) -> BooleanArray
79
where
80
T: PolarsNumericType,
81
{
82
let validity = combine_validities_and(lhs_arr.validity(), rhs_arr.validity());
83
let element_iter = lhs_arr
84
.values_iter()
85
.zip(rhs_arr.values_iter())
86
.map(|(x, y)| is_close_scalar(x.as_(), y.as_(), abs_tol, rel_tol, nans_equal));
87
let result: BooleanArray = element_iter.collect_arr();
88
result.with_validity_typed(validity)
89
}
90
91
fn is_close_kernel_unary<T>(
92
arr: &T::Array,
93
value: f64,
94
abs_tol: f64,
95
rel_tol: f64,
96
nans_equal: bool,
97
) -> BooleanArray
98
where
99
T: PolarsNumericType,
100
{
101
let validity = arr.validity().cloned();
102
let element_iter = arr
103
.values_iter()
104
.map(|x| is_close_scalar(x.as_(), value, abs_tol, rel_tol, nans_equal));
105
let result: BooleanArray = element_iter.collect_arr();
106
result.with_validity_typed(validity)
107
}
108
109
/* ---------------------------------------- SCALAR LOGIC --------------------------------------- */
110
111
/// Decides whether two `f64` values are approximately equal.
///
/// The rule follows <https://peps.python.org/pep-0485/>:
///
/// * If either value is NaN, the result is `true` only when both are NaN and
///   `nans_equal` is set.
/// * If either value is infinite, the result is `true` only when both are infinite with
///   the same sign; tolerances are ignored.
/// * Otherwise (both finite) the result is
///   `|x - y| <= max(rel_tol * max(|x|, |y|), abs_tol)`.
///
/// The maxima are taken with `f64::total_cmp`, so a (positive) NaN tolerance wins the
/// maximum and makes the finite comparison evaluate to `false`.
#[inline(always)]
fn is_close_scalar(x: f64, y: f64, abs_tol: f64, rel_tol: f64, nans_equal: bool) -> bool {
    // NaN never matches a non-NaN value, whatever the tolerances are.
    if x.is_nan() || y.is_nan() {
        return nans_equal && x.is_nan() && y.is_nan();
    }

    // From here on neither operand is NaN: infinities only match an identical infinity.
    if x.is_infinite() || y.is_infinite() {
        return x.is_infinite() && y.is_infinite() && x.signum() == y.signum();
    }

    // Both operands are finite: apply the PEP 485 tolerance test.
    let largest_magnitude = max_by(x.abs(), y.abs(), f64::total_cmp);
    let tolerance = max_by(rel_tol * largest_magnitude, abs_tol, f64::total_cmp);
    (x - y).abs() <= tolerance
}
124
125