Path: blob/main/crates/polars-compute/src/arithmetic/signed.rs
6939 views
use arrow::array::{PrimitiveArray as PArr, StaticArray};1use arrow::compute::utils::{combine_validities_and, combine_validities_and3};2use polars_utils::floor_divmod::FloorDivMod;3use strength_reduce::*;45use super::PrimitiveArithmeticKernelImpl;6use crate::arity::{prim_binary_values, prim_unary_values};7use crate::comparisons::TotalEqKernel;89macro_rules! impl_signed_arith_kernel {10($T:ty, $StrRed:ty) => {11impl PrimitiveArithmeticKernelImpl for $T {12type TrueDivT = f64;1314fn prim_wrapping_abs(lhs: PArr<$T>) -> PArr<$T> {15prim_unary_values(lhs, |x| x.wrapping_abs())16}1718fn prim_wrapping_neg(lhs: PArr<$T>) -> PArr<$T> {19prim_unary_values(lhs, |x| x.wrapping_neg())20}2122fn prim_wrapping_add(lhs: PArr<$T>, other: PArr<$T>) -> PArr<$T> {23prim_binary_values(lhs, other, |a, b| a.wrapping_add(b))24}2526fn prim_wrapping_sub(lhs: PArr<$T>, other: PArr<$T>) -> PArr<$T> {27prim_binary_values(lhs, other, |a, b| a.wrapping_sub(b))28}2930fn prim_wrapping_mul(lhs: PArr<$T>, other: PArr<$T>) -> PArr<$T> {31prim_binary_values(lhs, other, |a, b| a.wrapping_mul(b))32}3334fn prim_wrapping_floor_div(mut lhs: PArr<$T>, mut other: PArr<$T>) -> PArr<$T> {35let mask = other.tot_ne_kernel_broadcast(&0);36let valid = combine_validities_and3(37lhs.take_validity().as_ref(), // Take validity so we don't38other.take_validity().as_ref(), // compute combination twice.39Some(&mask),40);41let ret =42prim_binary_values(lhs, other, |lhs, rhs| lhs.wrapping_floor_div_mod(rhs).0);43ret.with_validity(valid)44}4546fn prim_wrapping_trunc_div(mut lhs: PArr<$T>, mut other: PArr<$T>) -> PArr<$T> {47let mask = other.tot_ne_kernel_broadcast(&0);48let valid = combine_validities_and3(49lhs.take_validity().as_ref(), // Take validity so we don't50other.take_validity().as_ref(), // compute combination twice.51Some(&mask),52);53let ret = prim_binary_values(lhs, other, |lhs, rhs| {54if rhs != 0 { lhs.wrapping_div(rhs) } else { 0 }55});56ret.with_validity(valid)57}5859fn prim_wrapping_mod(mut lhs: PArr<$T>, mut other: PArr<$T>) -> PArr<$T> {60let mask = other.tot_ne_kernel_broadcast(&0);61let valid = combine_validities_and3(62lhs.take_validity().as_ref(), // Take validity so we don't63other.take_validity().as_ref(), // compute combination twice.64Some(&mask),65);66let ret =67prim_binary_values(lhs, other, |lhs, rhs| lhs.wrapping_floor_div_mod(rhs).1);68ret.with_validity(valid)69}7071fn prim_wrapping_add_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> {72prim_unary_values(lhs, |x| x.wrapping_add(rhs))73}7475fn prim_wrapping_sub_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> {76Self::prim_wrapping_add_scalar(lhs, rhs.wrapping_neg())77}7879fn prim_wrapping_sub_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> {80prim_unary_values(rhs, |x| lhs.wrapping_sub(x))81}8283fn prim_wrapping_mul_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> {84let scalar_u = rhs.unsigned_abs();85if rhs == 0 {86lhs.fill_with(0)87} else if rhs == 1 {88lhs89} else if scalar_u.is_power_of_two() {90// Power of two.91let shift = scalar_u.trailing_zeros();92if rhs > 0 {93prim_unary_values(lhs, |x| x << shift)94} else {95prim_unary_values(lhs, |x| (x << shift).wrapping_neg())96}97} else {98prim_unary_values(lhs, |x| x.wrapping_mul(rhs))99}100}101102fn prim_wrapping_floor_div_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> {103if rhs == 0 {104PArr::full_null(lhs.len(), lhs.dtype().clone())105} else if rhs == -1 {106Self::prim_wrapping_neg(lhs)107} else if rhs == 1 {108lhs109} else {110let red = <$StrRed>::new(rhs.unsigned_abs());111prim_unary_values(lhs, |x| {112let (quot, rem) = <$StrRed>::div_rem(x.unsigned_abs(), red);113if (x < 0) != (rhs < 0) {114// Different signs: result should be negative.115// Since we handled rhs.abs() <= 1, quot fits.116let mut ret = -(quot as $T);117if rem != 0 {118// Division had remainder, subtract 1 to floor to119// negative infinity, as we truncated to zero.120ret -= 1;121}122ret123} else {124quot as $T125}126})127}128}129130fn prim_wrapping_floor_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> {131let mask = rhs.tot_ne_kernel_broadcast(&0);132let valid = combine_validities_and(rhs.validity(), Some(&mask));133let ret = if lhs == 0 {134rhs.fill_with(0)135} else {136prim_unary_values(rhs, |x| lhs.wrapping_floor_div_mod(x).0)137};138ret.with_validity(valid)139}140141fn prim_wrapping_trunc_div_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> {142if rhs == 0 {143PArr::full_null(lhs.len(), lhs.dtype().clone())144} else if rhs == -1 {145Self::prim_wrapping_neg(lhs)146} else if rhs == 1 {147lhs148} else {149let red = <$StrRed>::new(rhs.unsigned_abs());150prim_unary_values(lhs, |x| {151let quot = x.unsigned_abs() / red;152if (x < 0) != (rhs < 0) {153// Different signs: result should be negative.154-(quot as $T)155} else {156quot as $T157}158})159}160}161162fn prim_wrapping_trunc_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> {163let mask = rhs.tot_ne_kernel_broadcast(&0);164let valid = combine_validities_and(rhs.validity(), Some(&mask));165let ret = if lhs == 0 {166rhs.fill_with(0)167} else {168prim_unary_values(rhs, |x| if x != 0 { lhs.wrapping_div(x) } else { 0 })169};170ret.with_validity(valid)171}172173fn prim_wrapping_mod_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> {174if rhs == 0 {175PArr::full_null(lhs.len(), lhs.dtype().clone())176} else if rhs == -1 || rhs == 1 {177lhs.fill_with(0)178} else {179let scalar_u = rhs.unsigned_abs();180let red = <$StrRed>::new(scalar_u);181prim_unary_values(lhs, |x| {182// Remainder fits in signed type after reduction.183// Largest possible modulo -I::MIN, with184// -I::MIN-1 == I::MAX as largest remainder.185let mut rem_u = x.unsigned_abs() % red;186187// Mixed signs: swap direction of remainder.188if rem_u != 0 && (rhs < 0) != (x < 0) {189rem_u = scalar_u - rem_u;190}191192// Remainder should have sign of RHS.193if rhs < 0 { -(rem_u as $T) } else { rem_u as $T }194})195}196}197198fn prim_wrapping_mod_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> {199let mask = rhs.tot_ne_kernel_broadcast(&0);200let valid = combine_validities_and(rhs.validity(), Some(&mask));201let ret = if lhs == 0 {202rhs.fill_with(0)203} else {204prim_unary_values(rhs, |x| lhs.wrapping_floor_div_mod(x).1)205};206ret.with_validity(valid)207}208209fn prim_checked_mul_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> {210super::prim_checked_mul_scalar(&lhs, rhs)211}212213fn prim_true_div(lhs: PArr<$T>, other: PArr<$T>) -> PArr<Self::TrueDivT> {214prim_binary_values(lhs, other, |a, b| a as f64 / b as f64)215}216217fn prim_true_div_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<Self::TrueDivT> {218let inv = 1.0 / rhs as f64;219prim_unary_values(lhs, |x| x as f64 * inv)220}221222fn prim_true_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<Self::TrueDivT> {223prim_unary_values(rhs, |x| lhs as f64 / x as f64)224}225}226};227}228229impl_signed_arith_kernel!(i8, StrengthReducedU8);230impl_signed_arith_kernel!(i16, StrengthReducedU16);231impl_signed_arith_kernel!(i32, StrengthReducedU32);232impl_signed_arith_kernel!(i64, StrengthReducedU64);233impl_signed_arith_kernel!(i128, StrengthReducedU128);234235236