// Path: blob/main/crates/polars-ops/src/series/ops/round.rs
// (scraped page metadata: 8446 views)
use num_traits::AsPrimitive;1use polars_core::prelude::*;2use polars_core::with_match_physical_numeric_polars_type;3use polars_utils::float16::pf16;4#[cfg(feature = "serde")]5use serde::{Deserialize, Serialize};6use strum_macros::IntoStaticStr;78use crate::series::ops::SeriesSealed;910#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, IntoStaticStr)]11#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]12#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]13#[strum(serialize_all = "snake_case")]14#[derive(Default)]15pub enum RoundMode {16#[default]17HalfToEven,18HalfAwayFromZero,19}2021pub trait RoundSeries: SeriesSealed {22/// Round underlying floating point array to given decimal.23fn round(&self, decimals: u32, mode: RoundMode) -> PolarsResult<Series> {24let s = self.as_series();2526#[cfg(feature = "dtype-f16")]27if let Ok(ca) = s.f16() {28match mode {29RoundMode::HalfToEven => {30return if decimals == 0 {31let s = ca32.apply_values(|val| f32::from(val).round_ties_even().into())33.into_series();34Ok(s)35} else if decimals >= 11 {36// More precise than smallest denormal.37Ok(s.clone())38} else {39let multiplier = 10.0_f32.powi(decimals as i32);40let s = ca41.apply_values(|val| {42let val_f32: f32 = val.into();43let ret: pf16 =44((val_f32 * multiplier).round_ties_even() / multiplier).into();45if ret.is_finite() {46ret47} else {48// We return the original value which is correct both for overflows and non-finite inputs.49val50}51})52.into_series();53Ok(s)54};55},56RoundMode::HalfAwayFromZero => {57return if decimals == 0 {58let s = ca59.apply_values(|val| f32::from(val).round().into())60.into_series();61Ok(s)62} else if decimals >= 11 {63// More precise than smallest denormal.64Ok(s.clone())65} else {66let multiplier = 10.0_f32.powi(decimals as i32);67let s = ca68.apply_values(|val| {69let val_f32: f32 = val.into();70let ret: pf16 =71((val_f32 * multiplier).round() / multiplier).into();72if ret.is_finite() {73ret74} else {75// We return the 
original value which is correct both for overflows and non-finite inputs.76val77}78})79.into_series();80Ok(s)81};82},83}84}85if let Ok(ca) = s.f32() {86match mode {87RoundMode::HalfToEven => {88return if decimals == 0 {89let s = ca.apply_values(|val| val.round_ties_even()).into_series();90Ok(s)91} else if decimals >= 47 {92// More precise than smallest denormal.93Ok(s.clone())94} else {95// Note we do the computation on f64 floats to not lose precision96// when the computation is done, we cast to f3297let multiplier = 10.0_f64.powi(decimals as i32);98let s = ca99.apply_values(|val| {100let ret = ((val as f64 * multiplier).round_ties_even() / multiplier)101as f32;102if ret.is_finite() {103ret104} else {105// We return the original value which is correct both for overflows and non-finite inputs.106val107}108})109.into_series();110Ok(s)111};112},113RoundMode::HalfAwayFromZero => {114return if decimals == 0 {115let s = ca.apply_values(|val| val.round()).into_series();116Ok(s)117} else if decimals >= 47 {118// More precise than smallest denormal.119Ok(s.clone())120} else {121// Note we do the computation on f64 floats to not lose precision122// when the computation is done, we cast to f32123let multiplier = 10.0_f64.powi(decimals as i32);124let s = ca125.apply_values(|val| {126let ret = ((val as f64 * multiplier).round() / multiplier) as f32;127if ret.is_finite() {128ret129} else {130// We return the original value which is correct both for overflows and non-finite inputs.131val132}133})134.into_series();135Ok(s)136};137},138}139}140if let Ok(ca) = s.f64() {141match mode {142RoundMode::HalfToEven => {143return if decimals == 0 {144let s = ca.apply_values(|val| val.round_ties_even()).into_series();145Ok(s)146} else if decimals >= 326 {147// More precise than smallest denormal.148Ok(s.clone())149} else if decimals >= 300 {150// We're getting into unrepresentable territory for the multiplier151// here, split up the 10^n multiplier into 2^n and 5^n.152let mul2 = 
libm::scalbn(1.0, decimals as i32);153let invmul2 = 1.0 / mul2; // Still exact for any valid value of decimals.154let mul5 = 5.0_f64.powi(decimals as i32);155let s = ca156.apply_values(|val| {157let ret = (val * mul2 * mul5).round_ties_even() / mul5 * invmul2;158if ret.is_finite() {159ret160} else {161// We return the original value which is correct both for overflows and non-finite inputs.162val163}164})165.into_series();166Ok(s)167} else {168let multiplier = 10.0_f64.powi(decimals as i32);169let s = ca170.apply_values(|val| {171let ret = (val * multiplier).round_ties_even() / multiplier;172if ret.is_finite() {173ret174} else {175// We return the original value which is correct both for overflows and non-finite inputs.176val177}178})179.into_series();180Ok(s)181};182},183RoundMode::HalfAwayFromZero => {184return if decimals == 0 {185let s = ca.apply_values(|val| val.round()).into_series();186Ok(s)187} else if decimals >= 326 {188// More precise than smallest denormal.189Ok(s.clone())190} else if decimals >= 300 {191// We're getting into unrepresentable territory for the multiplier192// here, split up the 10^n multiplier into 2^n and 5^n.193let mul2 = libm::scalbn(1.0, decimals as i32);194let invmul2 = 1.0 / mul2; // Still exact for any valid value of decimals.195let mul5 = 5.0_f64.powi(decimals as i32);196let s = ca197.apply_values(|val| {198let ret = (val * mul2 * mul5).round() / mul5 * invmul2;199if ret.is_finite() {200ret201} else {202// We return the original value which is correct both for overflows and non-finite inputs.203val204}205})206.into_series();207Ok(s)208} else {209let multiplier = 10.0_f64.powi(decimals as i32);210let s = ca211.apply_values(|val| {212let ret = (val * multiplier).round() / multiplier;213if ret.is_finite() {214ret215} else {216// We return the original value which is correct both for overflows and non-finite inputs.217val218}219})220.into_series();221Ok(s)222};223},224}225}226#[cfg(feature = "dtype-decimal")]227if let Some(ca) = 
s.try_decimal() {228let scale = ca.scale() as u32;229230if scale <= decimals {231return Ok(ca.clone().into_series());232}233234let decimal_delta = scale - decimals;235let multiplier = 10i128.pow(decimal_delta);236let threshold = multiplier / 2;237238let res = match mode {239RoundMode::HalfToEven => ca.physical().apply_values(|v| {240let rem_big = v % (2 * multiplier);241let is_v_floor_even = rem_big.abs() < multiplier;242let rem = if is_v_floor_even {243rem_big244} else if rem_big > 0 {245rem_big - multiplier246} else {247rem_big + multiplier248};249250let threshold = threshold + i128::from(is_v_floor_even);251let round_offset = if rem.abs() >= threshold {252if v < 0 { -multiplier } else { multiplier }253} else {2540255};256v - rem + round_offset257}),258RoundMode::HalfAwayFromZero => ca.physical().apply_values(|v| {259let rem = v % multiplier;260let round_offset = if rem.abs() >= threshold {261if v < 0 { -multiplier } else { multiplier }262} else {2630264};265v - rem + round_offset266}),267};268return Ok(res269.into_decimal_unchecked(ca.precision(), scale as usize)270.into_series());271}272273polars_ensure!(s.dtype().is_integer(), InvalidOperation: "round can only be used on numeric types" );274Ok(s.clone())275}276277fn round_sig_figs(&self, digits: i32) -> PolarsResult<Series> {278let s = self.as_series();279polars_ensure!(digits >= 1, InvalidOperation: "digits must be an integer >= 1");280281#[cfg(feature = "dtype-decimal")]282if let Some(ca) = s.try_decimal() {283let precision = ca.precision();284let scale = ca.scale() as u32;285286let s = ca287.physical()288.apply_values(|v| {289if v == 0 {290return 0;291}292293let mut magnitude = v.abs().ilog10();294let magnitude_mult = 10i128.pow(magnitude); // @Q? It might be better to do this with a295// LUT.296if v.abs() > magnitude_mult {297magnitude += 1;298}299let decimals = magnitude.saturating_sub(digits as u32);300let multiplier = 10i128.pow(decimals); // @Q? 
It might be better to do this with a301// LUT.302let threshold = multiplier / 2;303304// We use rounding=ROUND_HALF_EVEN305let rem = v % multiplier;306let is_v_floor_even = decimals <= scale && ((v - rem) / multiplier) % 2 == 0;307let threshold = threshold + i128::from(is_v_floor_even);308let round_offset = if rem.abs() >= threshold {309multiplier310} else {3110312};313let round_offset = if v < 0 { -round_offset } else { round_offset };314v - rem + round_offset315})316.into_decimal_unchecked(precision, scale as usize)317.into_series();318319return Ok(s);320}321322polars_ensure!(s.dtype().is_primitive_numeric(), InvalidOperation: "round_sig_figs can only be used on numeric types" );323with_match_physical_numeric_polars_type!(s.dtype(), |$T| {324let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();325let s = ca.apply_values(|value| {326let value = AsPrimitive::<f64>::as_(value);327if value == 0.0 {328return AsPrimitive::<<$T as PolarsNumericType>::Native>::as_(value);329}330// To deal with very large/small numbers we split up 10^n in 5^n and 2^n.331// The scaling by 2^n is almost always lossless.332let exp = digits - 1 - value.abs().log10().floor() as i32;333let pow5 = 5.0_f64.powi(exp);334let scaled = libm::scalbn(value, exp) * pow5;335let descaled = libm::scalbn(scaled.round() / pow5, -exp);336AsPrimitive::<<$T as PolarsNumericType>::Native>::as_(337if descaled.is_finite() { descaled } else { value }338)339}).into_series();340return Ok(s);341});342}343344/// Floor underlying floating point array to the lowest integers smaller or equal to the float value.345fn floor(&self) -> PolarsResult<Series> {346let s = self.as_series();347348if let Ok(ca) = s.f32() {349let s = ca.apply_values(|val| val.floor()).into_series();350return Ok(s);351}352if let Ok(ca) = s.f64() {353let s = ca.apply_values(|val| val.floor()).into_series();354return Ok(s);355}356#[cfg(feature = "dtype-decimal")]357if let Some(ca) = s.try_decimal() {358let precision = ca.precision();359let scale = 
ca.scale() as u32;360if scale == 0 {361return Ok(ca.clone().into_series());362}363364let decimal_delta = scale;365let multiplier = 10i128.pow(decimal_delta);366367let ca = ca368.physical()369.apply_values(|v| {370let rem = v % multiplier;371let round_offset = if v < 0 { multiplier + rem } else { rem };372let round_offset = if rem == 0 { 0 } else { round_offset };373v - round_offset374})375.into_decimal_unchecked(precision, scale as usize);376377return Ok(ca.into_series());378}379380polars_ensure!(s.dtype().is_primitive_numeric(), InvalidOperation: "floor can only be used on numeric types" );381Ok(s.clone())382}383384/// Ceil underlying floating point array to the highest integers smaller or equal to the float value.385fn ceil(&self) -> PolarsResult<Series> {386let s = self.as_series();387388if let Ok(ca) = s.f32() {389let s = ca.apply_values(|val| val.ceil()).into_series();390return Ok(s);391}392if let Ok(ca) = s.f64() {393let s = ca.apply_values(|val| val.ceil()).into_series();394return Ok(s);395}396#[cfg(feature = "dtype-decimal")]397if let Some(ca) = s.try_decimal() {398let precision = ca.precision();399let scale = ca.scale() as u32;400if scale == 0 {401return Ok(ca.clone().into_series());402}403404let decimal_delta = scale;405let multiplier = 10i128.pow(decimal_delta);406407let ca = ca408.physical()409.apply_values(|v| {410let rem = v % multiplier;411let round_offset = if v < 0 { -rem } else { multiplier - rem };412let round_offset = if rem == 0 { 0 } else { round_offset };413v + round_offset414})415.into_decimal_unchecked(precision, scale as usize);416417return Ok(ca.into_series());418}419420polars_ensure!(s.dtype().is_primitive_numeric(), InvalidOperation: "ceil can only be used on numeric types" );421Ok(s.clone())422}423}424425impl RoundSeries for Series {}426427#[cfg(test)]428mod test {429use super::*;430431#[test]432fn test_round_series() {433let series = Series::new("a".into(), &[1.003, 2.23222, 3.4352]);434let out = series.round(2, 
RoundMode::default()).unwrap();435let ca = out.f64().unwrap();436assert_eq!(ca.get(0), Some(1.0));437}438}439440441