Path: blob/main/crates/polars-ops/src/series/ops/is_close.rs
8460 views
use std::cmp::max_by;12use arrow::array::BooleanArray;3use arrow::compute::utils::combine_validities_and;4use num_traits::AsPrimitive;5use polars_core::prelude::arity::apply_binary_kernel_broadcast;6use polars_core::prelude::*;78pub fn is_close(9s: &Series,10other: &Series,11abs_tol: f64,12rel_tol: f64,13nans_equal: bool,14) -> PolarsResult<BooleanChunked> {15if abs_tol < 0.0 {16polars_bail!(ComputeError: "`abs_tol` must be non-negative but got {}", abs_tol);17}18if rel_tol < 0.0 {19polars_bail!(ComputeError: "`rel_tol` must be non-negative but got {}", rel_tol);20}21validate_numeric(s.dtype())?;22validate_numeric(other.dtype())?;2324let ca = match (s.dtype(), other.dtype()) {25#[cfg(feature = "dtype-f16")]26(DataType::Float16, DataType::Float16) => apply_binary_kernel_broadcast(27s.f16().unwrap(),28other.f16().unwrap(),29|l, r| is_close_kernel::<Float16Type>(l, r, abs_tol, rel_tol, nans_equal),30|v, ca| is_close_kernel_unary::<Float16Type>(ca, v.as_(), abs_tol, rel_tol, nans_equal),31|ca, v| is_close_kernel_unary::<Float16Type>(ca, v.as_(), abs_tol, rel_tol, nans_equal),32),33(DataType::Float32, DataType::Float32) => apply_binary_kernel_broadcast(34s.f32().unwrap(),35other.f32().unwrap(),36|l, r| is_close_kernel::<Float32Type>(l, r, abs_tol, rel_tol, nans_equal),37|v, ca| is_close_kernel_unary::<Float32Type>(ca, v.as_(), abs_tol, rel_tol, nans_equal),38|ca, v| is_close_kernel_unary::<Float32Type>(ca, v.as_(), abs_tol, rel_tol, nans_equal),39),40(DataType::Float64, DataType::Float64) => apply_binary_kernel_broadcast(41s.f64().unwrap(),42other.f64().unwrap(),43|l, r| is_close_kernel::<Float64Type>(l, r, abs_tol, rel_tol, nans_equal),44|v, ca| is_close_kernel_unary::<Float64Type>(ca, v.as_(), abs_tol, rel_tol, nans_equal),45|ca, v| is_close_kernel_unary::<Float64Type>(ca, v.as_(), abs_tol, rel_tol, nans_equal),46),47_ => apply_binary_kernel_broadcast(48s.cast(&DataType::Float64)?.f64().unwrap(),49other.cast(&DataType::Float64)?.f64().unwrap(),50|l, r| is_close_kernel::<Float64Type>(l, r, abs_tol, rel_tol, nans_equal),51|v, ca| is_close_kernel_unary::<Float64Type>(ca, v.as_(), abs_tol, rel_tol, nans_equal),52|ca, v| is_close_kernel_unary::<Float64Type>(ca, v.as_(), abs_tol, rel_tol, nans_equal),53),54};55Ok(ca)56}5758fn validate_numeric(dtype: &DataType) -> PolarsResult<()> {59if !dtype.is_primitive_numeric() && !dtype.is_decimal() {60polars_bail!(61op = "is_close",62dtype,63hint = "`is_close` is only supported for numeric types"64);65}66Ok(())67}6869/* ------------------------------------------- KERNEL ------------------------------------------ */7071fn is_close_kernel<T>(72lhs_arr: &T::Array,73rhs_arr: &T::Array,74abs_tol: f64,75rel_tol: f64,76nans_equal: bool,77) -> BooleanArray78where79T: PolarsNumericType,80{81let validity = combine_validities_and(lhs_arr.validity(), rhs_arr.validity());82let element_iter = lhs_arr83.values_iter()84.zip(rhs_arr.values_iter())85.map(|(x, y)| is_close_scalar(x.as_(), y.as_(), abs_tol, rel_tol, nans_equal));86let result: BooleanArray = element_iter.collect_arr();87result.with_validity_typed(validity)88}8990fn is_close_kernel_unary<T>(91arr: &T::Array,92value: f64,93abs_tol: f64,94rel_tol: f64,95nans_equal: bool,96) -> BooleanArray97where98T: PolarsNumericType,99{100let validity = arr.validity().cloned();101let element_iter = arr102.values_iter()103.map(|x| is_close_scalar(x.as_(), value, abs_tol, rel_tol, nans_equal));104let result: BooleanArray = element_iter.collect_arr();105result.with_validity_typed(validity)106}107108/* ---------------------------------------- SCALAR LOGIC --------------------------------------- */109110#[inline(always)]111fn is_close_scalar(x: f64, y: f64, abs_tol: f64, rel_tol: f64, nans_equal: bool) -> bool {112// The logic in this function is taken from https://peps.python.org/pep-0485/.113let cmp = (x - y).abs()114<= max_by(115rel_tol * max_by(x.abs(), y.abs(), f64::total_cmp),116abs_tol,117f64::total_cmp,118);119(x.is_finite() && y.is_finite() && cmp)120|| (x.is_nan() && y.is_nan() && nans_equal)121|| (x.is_infinite() && y.is_infinite() && x.signum() == y.signum())122}123124125