Path: blob/main/crates/polars-ops/src/series/ops/is_close.rs
6939 views
use std::cmp::max_by;12use arrow::array::BooleanArray;3use arrow::compute::utils::combine_validities_and;4use num_traits::AsPrimitive;5use polars_core::prelude::arity::apply_binary_kernel_broadcast;6use polars_core::prelude::*;78pub fn is_close(9s: &Series,10other: &Series,11abs_tol: f64,12rel_tol: f64,13nans_equal: bool,14) -> PolarsResult<BooleanChunked> {15if abs_tol < 0.0 {16polars_bail!(ComputeError: "`abs_tol` must be non-negative but got {}", abs_tol);17}18if rel_tol < 0.0 {19polars_bail!(ComputeError: "`rel_tol` must be non-negative but got {}", rel_tol);20}21validate_numeric(s.dtype())?;22validate_numeric(other.dtype())?;2324let ca = match (s.dtype(), other.dtype()) {25(DataType::Float32, DataType::Float32) => apply_binary_kernel_broadcast(26s.f32().unwrap(),27other.f32().unwrap(),28|l, r| is_close_kernel::<Float32Type>(l, r, abs_tol, rel_tol, nans_equal),29|v, ca| is_close_kernel_unary::<Float32Type>(ca, v.as_(), abs_tol, rel_tol, nans_equal),30|ca, v| is_close_kernel_unary::<Float32Type>(ca, v.as_(), abs_tol, rel_tol, nans_equal),31),32(DataType::Float64, DataType::Float64) => apply_binary_kernel_broadcast(33s.f64().unwrap(),34other.f64().unwrap(),35|l, r| is_close_kernel::<Float64Type>(l, r, abs_tol, rel_tol, nans_equal),36|v, ca| is_close_kernel_unary::<Float64Type>(ca, v.as_(), abs_tol, rel_tol, nans_equal),37|ca, v| is_close_kernel_unary::<Float64Type>(ca, v.as_(), abs_tol, rel_tol, nans_equal),38),39_ => apply_binary_kernel_broadcast(40s.cast(&DataType::Float64)?.f64().unwrap(),41other.cast(&DataType::Float64)?.f64().unwrap(),42|l, r| is_close_kernel::<Float64Type>(l, r, abs_tol, rel_tol, nans_equal),43|v, ca| is_close_kernel_unary::<Float64Type>(ca, v.as_(), abs_tol, rel_tol, nans_equal),44|ca, v| is_close_kernel_unary::<Float64Type>(ca, v.as_(), abs_tol, rel_tol, nans_equal),45),46};47Ok(ca)48}4950fn validate_numeric(dtype: &DataType) -> PolarsResult<()> {51if !dtype.is_primitive_numeric() && !dtype.is_decimal() {52polars_bail!(53op = "is_close",54dtype,55hint = "`is_close` is only supported for numeric types"56);57}58Ok(())59}6061/* ------------------------------------------- KERNEL ------------------------------------------ */6263fn is_close_kernel<T>(64lhs_arr: &T::Array,65rhs_arr: &T::Array,66abs_tol: f64,67rel_tol: f64,68nans_equal: bool,69) -> BooleanArray70where71T: PolarsNumericType,72{73let validity = combine_validities_and(lhs_arr.validity(), rhs_arr.validity());74let element_iter = lhs_arr75.values_iter()76.zip(rhs_arr.values_iter())77.map(|(x, y)| is_close_scalar(x.as_(), y.as_(), abs_tol, rel_tol, nans_equal));78let result: BooleanArray = element_iter.collect_arr();79result.with_validity_typed(validity)80}8182fn is_close_kernel_unary<T>(83arr: &T::Array,84value: f64,85abs_tol: f64,86rel_tol: f64,87nans_equal: bool,88) -> BooleanArray89where90T: PolarsNumericType,91{92let validity = arr.validity().cloned();93let element_iter = arr94.values_iter()95.map(|x| is_close_scalar(x.as_(), value, abs_tol, rel_tol, nans_equal));96let result: BooleanArray = element_iter.collect_arr();97result.with_validity_typed(validity)98}99100/* ---------------------------------------- SCALAR LOGIC --------------------------------------- */101102#[inline(always)]103fn is_close_scalar(x: f64, y: f64, abs_tol: f64, rel_tol: f64, nans_equal: bool) -> bool {104// The logic in this function is taken from https://peps.python.org/pep-0485/.105let cmp = (x - y).abs()106<= max_by(107rel_tol * max_by(x.abs(), y.abs(), f64::total_cmp),108abs_tol,109f64::total_cmp,110);111(x.is_finite() && y.is_finite() && cmp)112|| (x.is_nan() && y.is_nan() && nans_equal)113|| (x.is_infinite() && y.is_infinite() && x.signum() == y.signum())114}115116117