// Path: blob/main/crates/polars-compute/src/arithmetic/mod.rs
// 6939 views
use std::any::TypeId;12use arrow::array::{Array, PrimitiveArray};3use arrow::bitmap::BitmapBuilder;4use arrow::types::NativeType;56// Low-level comparison kernel.7pub trait ArithmeticKernel: Sized + Array {8type Scalar;9type TrueDivT: NativeType;1011fn wrapping_abs(self) -> Self;12fn wrapping_neg(self) -> Self;13fn wrapping_add(self, rhs: Self) -> Self;14fn wrapping_sub(self, rhs: Self) -> Self;15fn wrapping_mul(self, rhs: Self) -> Self;16fn wrapping_floor_div(self, rhs: Self) -> Self;17fn wrapping_trunc_div(self, rhs: Self) -> Self;18fn wrapping_mod(self, rhs: Self) -> Self;1920fn wrapping_add_scalar(self, rhs: Self::Scalar) -> Self;21fn wrapping_sub_scalar(self, rhs: Self::Scalar) -> Self;22fn wrapping_sub_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self;23fn wrapping_mul_scalar(self, rhs: Self::Scalar) -> Self;24fn wrapping_floor_div_scalar(self, rhs: Self::Scalar) -> Self;25fn wrapping_floor_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self;26fn wrapping_trunc_div_scalar(self, rhs: Self::Scalar) -> Self;27fn wrapping_trunc_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self;28fn wrapping_mod_scalar(self, rhs: Self::Scalar) -> Self;29fn wrapping_mod_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self;3031fn checked_mul_scalar(self, rhs: Self::Scalar) -> Self;3233fn true_div(self, rhs: Self) -> PrimitiveArray<Self::TrueDivT>;34fn true_div_scalar(self, rhs: Self::Scalar) -> PrimitiveArray<Self::TrueDivT>;35fn true_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> PrimitiveArray<Self::TrueDivT>;3637// TODO: remove these.38// These are flooring division for integer types, true division for floating point types.39fn legacy_div(self, rhs: Self) -> Self {40if TypeId::of::<Self>() == TypeId::of::<PrimitiveArray<Self::TrueDivT>>() {41let ret = self.true_div(rhs);42unsafe {43let cast_ret = std::mem::transmute_copy(&ret);44std::mem::forget(ret);45cast_ret46}47} else {48self.wrapping_floor_div(rhs)49}50}51fn legacy_div_scalar(self, rhs: Self::Scalar) -> Self {52if 
TypeId::of::<Self>() == TypeId::of::<PrimitiveArray<Self::TrueDivT>>() {53let ret = self.true_div_scalar(rhs);54unsafe {55let cast_ret = std::mem::transmute_copy(&ret);56std::mem::forget(ret);57cast_ret58}59} else {60self.wrapping_floor_div_scalar(rhs)61}62}6364fn legacy_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self {65if TypeId::of::<Self>() == TypeId::of::<PrimitiveArray<Self::TrueDivT>>() {66let ret = ArithmeticKernel::true_div_scalar_lhs(lhs, rhs);67unsafe {68let cast_ret = std::mem::transmute_copy(&ret);69std::mem::forget(ret);70cast_ret71}72} else {73ArithmeticKernel::wrapping_floor_div_scalar_lhs(lhs, rhs)74}75}76}7778// Proxy trait so one can bound T: HasPrimitiveArithmeticKernel. Sadly Rust79// doesn't support adding supertraits for other types.80#[allow(private_bounds)]81pub trait HasPrimitiveArithmeticKernel: NativeType + PrimitiveArithmeticKernelImpl {}82impl<T: NativeType + PrimitiveArithmeticKernelImpl> HasPrimitiveArithmeticKernel for T {}8384use PrimitiveArray as PArr;85use num_traits::{CheckedMul, WrappingMul};86use polars_utils::vec::PushUnchecked;8788#[doc(hidden)]89pub trait PrimitiveArithmeticKernelImpl: NativeType {90type TrueDivT: NativeType;9192fn prim_wrapping_abs(lhs: PArr<Self>) -> PArr<Self>;93fn prim_wrapping_neg(lhs: PArr<Self>) -> PArr<Self>;94fn prim_wrapping_add(lhs: PArr<Self>, rhs: PArr<Self>) -> PArr<Self>;95fn prim_wrapping_sub(lhs: PArr<Self>, rhs: PArr<Self>) -> PArr<Self>;96fn prim_wrapping_mul(lhs: PArr<Self>, rhs: PArr<Self>) -> PArr<Self>;97fn prim_wrapping_floor_div(lhs: PArr<Self>, rhs: PArr<Self>) -> PArr<Self>;98fn prim_wrapping_trunc_div(lhs: PArr<Self>, rhs: PArr<Self>) -> PArr<Self>;99fn prim_wrapping_mod(lhs: PArr<Self>, rhs: PArr<Self>) -> PArr<Self>;100101fn prim_wrapping_add_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self>;102fn prim_wrapping_sub_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self>;103fn prim_wrapping_sub_scalar_lhs(lhs: Self, rhs: PArr<Self>) -> PArr<Self>;104fn prim_wrapping_mul_scalar(lhs: 
PArr<Self>, rhs: Self) -> PArr<Self>;105fn prim_wrapping_floor_div_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self>;106fn prim_wrapping_floor_div_scalar_lhs(lhs: Self, rhs: PArr<Self>) -> PArr<Self>;107fn prim_wrapping_trunc_div_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self>;108fn prim_wrapping_trunc_div_scalar_lhs(lhs: Self, rhs: PArr<Self>) -> PArr<Self>;109fn prim_wrapping_mod_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self>;110fn prim_wrapping_mod_scalar_lhs(lhs: Self, rhs: PArr<Self>) -> PArr<Self>;111112fn prim_checked_mul_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self>;113114fn prim_true_div(lhs: PArr<Self>, rhs: PArr<Self>) -> PArr<Self::TrueDivT>;115fn prim_true_div_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self::TrueDivT>;116fn prim_true_div_scalar_lhs(lhs: Self, rhs: PArr<Self>) -> PArr<Self::TrueDivT>;117}118119#[rustfmt::skip]120impl<T: HasPrimitiveArithmeticKernel> ArithmeticKernel for PrimitiveArray<T> {121type Scalar = T;122type TrueDivT = T::TrueDivT;123124fn wrapping_abs(self) -> Self { T::prim_wrapping_abs(self) }125fn wrapping_neg(self) -> Self { T::prim_wrapping_neg(self) }126fn wrapping_add(self, rhs: Self) -> Self { T::prim_wrapping_add(self, rhs) }127fn wrapping_sub(self, rhs: Self) -> Self { T::prim_wrapping_sub(self, rhs) }128fn wrapping_mul(self, rhs: Self) -> Self { T::prim_wrapping_mul(self, rhs) }129fn wrapping_floor_div(self, rhs: Self) -> Self { T::prim_wrapping_floor_div(self, rhs) }130fn wrapping_trunc_div(self, rhs: Self) -> Self { T::prim_wrapping_trunc_div(self, rhs) }131fn wrapping_mod(self, rhs: Self) -> Self { T::prim_wrapping_mod(self, rhs) }132133fn wrapping_add_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_add_scalar(self, rhs) }134fn wrapping_sub_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_sub_scalar(self, rhs) }135fn wrapping_sub_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self { T::prim_wrapping_sub_scalar_lhs(lhs, rhs) }136fn wrapping_mul_scalar(self, rhs: Self::Scalar) -> Self { 
T::prim_wrapping_mul_scalar(self, rhs) }137fn wrapping_floor_div_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_floor_div_scalar(self, rhs) }138fn wrapping_floor_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self { T::prim_wrapping_floor_div_scalar_lhs(lhs, rhs) }139fn wrapping_trunc_div_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_trunc_div_scalar(self, rhs) }140fn wrapping_trunc_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self { T::prim_wrapping_trunc_div_scalar_lhs(lhs, rhs) }141fn wrapping_mod_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_mod_scalar(self, rhs) }142fn wrapping_mod_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self { T::prim_wrapping_mod_scalar_lhs(lhs, rhs) }143144fn checked_mul_scalar(self, rhs: Self::Scalar) -> Self { T::prim_checked_mul_scalar(self, rhs) }145146fn true_div(self, rhs: Self) -> PrimitiveArray<Self::TrueDivT> { T::prim_true_div(self, rhs) }147fn true_div_scalar(self, rhs: Self::Scalar) -> PrimitiveArray<Self::TrueDivT> { T::prim_true_div_scalar(self, rhs) }148fn true_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> PrimitiveArray<Self::TrueDivT> { T::prim_true_div_scalar_lhs(lhs, rhs) }149}150151mod float;152pub mod pl_num;153mod signed;154mod unsigned;155156fn prim_checked_mul_scalar<I: NativeType + CheckedMul + WrappingMul>(157array: &PrimitiveArray<I>,158factor: I,159) -> PrimitiveArray<I> {160let values = array.values();161let mut out = Vec::with_capacity(array.len());162let mut i = 0;163164while i < array.len() && values[i].checked_mul(&factor).is_some() {165// SAFETY: We allocated enough before.166unsafe { out.push_unchecked(values[i].wrapping_mul(&factor)) };167i += 1;168}169170if out.len() == array.len() {171return PrimitiveArray::<I>::new(172I::PRIMITIVE.into(),173out.into(),174array.validity().cloned(),175);176}177178let mut validity = BitmapBuilder::with_capacity(array.len());179validity.extend_constant(out.len(), true);180181for &value in &values[out.len()..] 
{182// SAFETY: We allocated enough before.183unsafe {184out.push_unchecked(value.wrapping_mul(&factor));185validity.push_unchecked(value.checked_mul(&factor).is_some());186}187}188189debug_assert_eq!(out.len(), array.len());190debug_assert_eq!(validity.len(), array.len());191192let validity = validity.freeze();193let validity = match array.validity() {194None => validity,195Some(arr_validity) => arrow::bitmap::and(&validity, arr_validity),196};197198PrimitiveArray::<I>::new(I::PRIMITIVE.into(), out.into(), Some(validity))199}200201202