Path: blob/main/crates/polars-ops/src/series/ops/round.rs
6939 views
use polars_core::prelude::*;1use polars_core::with_match_physical_numeric_polars_type;2#[cfg(feature = "serde")]3use serde::{Deserialize, Serialize};4use strum_macros::IntoStaticStr;56use crate::series::ops::SeriesSealed;78#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, IntoStaticStr)]9#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]10#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]11#[strum(serialize_all = "snake_case")]12#[derive(Default)]13pub enum RoundMode {14#[default]15HalfToEven,16HalfAwayFromZero,17}1819pub trait RoundSeries: SeriesSealed {20/// Round underlying floating point array to given decimal.21fn round(&self, decimals: u32, mode: RoundMode) -> PolarsResult<Series> {22let s = self.as_series();2324if let Ok(ca) = s.f32() {25match mode {26RoundMode::HalfToEven => {27return if decimals == 0 {28let s = ca.apply_values(|val| val.round_ties_even()).into_series();29Ok(s)30} else if decimals >= 326 {31// More precise than smallest denormal.32Ok(s.clone())33} else {34// Note we do the computation on f64 floats to not lose precision35// when the computation is done, we cast to f3236let multiplier = 10.0_f64.powi(decimals as i32);37let s = ca38.apply_values(|val| {39let ret = ((val as f64 * multiplier).round_ties_even() / multiplier)40as f32;41if ret.is_finite() {42ret43} else {44// We return the original value which is correct both for overflows and non-finite inputs.45val46}47})48.into_series();49Ok(s)50};51},52RoundMode::HalfAwayFromZero => {53return if decimals == 0 {54let s = ca.apply_values(|val| val.round()).into_series();55Ok(s)56} else if decimals >= 326 {57// More precise than smallest denormal.58Ok(s.clone())59} else {60// Note we do the computation on f64 floats to not lose precision61// when the computation is done, we cast to f3262let multiplier = 10.0_f64.powi(decimals as i32);63let s = ca64.apply_values(|val| {65let ret = ((val as f64 * multiplier).round_ties_even() / multiplier)66as f32;67if ret.is_finite() {68ret69} else {70// We return the original value which is correct both for overflows and non-finite inputs.71val72}73})74.into_series();75Ok(s)76};77},78}79}80if let Ok(ca) = s.f64() {81match mode {82RoundMode::HalfToEven => {83return if decimals == 0 {84let s = ca.apply_values(|val| val.round_ties_even()).into_series();85Ok(s)86} else if decimals >= 326 {87// More precise than smallest denormal.88Ok(s.clone())89} else if decimals >= 300 {90// We're getting into unrepresentable territory for the multiplier91// here, split up the 10^n multiplier into 2^n and 5^n.92let mul2 = libm::scalbn(1.0, decimals as i32);93let invmul2 = 1.0 / mul2; // Still exact for any valid value of decimals.94let mul5 = 5.0_f64.powi(decimals as i32);95let s = ca96.apply_values(|val| {97let ret = (val * mul2 * mul5).round_ties_even() / mul5 * invmul2;98if ret.is_finite() {99ret100} else {101// We return the original value which is correct both for overflows and non-finite inputs.102val103}104})105.into_series();106Ok(s)107} else {108let multiplier = 10.0_f64.powi(decimals as i32);109let s = ca110.apply_values(|val| {111let ret = (val * multiplier).round_ties_even() / multiplier;112if ret.is_finite() {113ret114} else {115// We return the original value which is correct both for overflows and non-finite inputs.116val117}118})119.into_series();120Ok(s)121};122},123RoundMode::HalfAwayFromZero => {124return if decimals == 0 {125let s = ca.apply_values(|val| val.round()).into_series();126Ok(s)127} else if decimals >= 326 {128// More precise than smallest denormal.129Ok(s.clone())130} else if decimals >= 300 {131// We're getting into unrepresentable territory for the multiplier132// here, split up the 10^n multiplier into 2^n and 5^n.133let mul2 = libm::scalbn(1.0, decimals as i32);134let invmul2 = 1.0 / mul2; // Still exact for any valid value of decimals.135let mul5 = 5.0_f64.powi(decimals as i32);136let s = ca137.apply_values(|val| {138let ret = (val * mul2 * mul5).round() / mul5 * invmul2;139if ret.is_finite() {140ret141} else {142// We return the original value which is correct both for overflows and non-finite inputs.143val144}145})146.into_series();147Ok(s)148} else {149let multiplier = 10.0_f64.powi(decimals as i32);150let s = ca151.apply_values(|val| {152let ret = (val * multiplier).round() / multiplier;153if ret.is_finite() {154ret155} else {156// We return the original value which is correct both for overflows and non-finite inputs.157val158}159})160.into_series();161Ok(s)162};163},164}165}166#[cfg(feature = "dtype-decimal")]167if let Some(ca) = s.try_decimal() {168let scale = ca.scale() as u32;169170if scale <= decimals {171return Ok(ca.clone().into_series());172}173174let decimal_delta = scale - decimals;175let multiplier = 10i128.pow(decimal_delta);176let threshold = multiplier / 2;177178let res = match mode {179RoundMode::HalfToEven => ca.physical().apply_values(|v| {180let rem_big = v % (2 * multiplier);181let is_v_floor_even = rem_big.abs() < multiplier;182let rem = if is_v_floor_even {183rem_big184} else if rem_big > 0 {185rem_big - multiplier186} else {187rem_big + multiplier188};189190let threshold = threshold + i128::from(is_v_floor_even);191let round_offset = if rem.abs() >= threshold {192if v < 0 { -multiplier } else { multiplier }193} else {1940195};196v - rem + round_offset197}),198RoundMode::HalfAwayFromZero => ca.physical().apply_values(|v| {199let rem = v % multiplier;200let round_offset = if rem.abs() >= threshold {201if v < 0 { -multiplier } else { multiplier }202} else {2030204};205v - rem + round_offset206}),207};208return Ok(res209.into_decimal_unchecked(ca.precision(), scale as usize)210.into_series());211}212213polars_ensure!(s.dtype().is_integer(), InvalidOperation: "round can only be used on numeric types" );214Ok(s.clone())215}216217fn round_sig_figs(&self, digits: i32) -> PolarsResult<Series> {218let s = self.as_series();219polars_ensure!(digits >= 1, InvalidOperation: "digits must be an integer >= 1");220221#[cfg(feature = "dtype-decimal")]222if let Some(ca) = s.try_decimal() {223let precision = ca.precision();224let scale = ca.scale() as u32;225226let s = ca227.physical()228.apply_values(|v| {229if v == 0 {230return 0;231}232233let mut magnitude = v.abs().ilog10();234let magnitude_mult = 10i128.pow(magnitude); // @Q? It might be better to do this with a235// LUT.236if v.abs() > magnitude_mult {237magnitude += 1;238}239let decimals = magnitude.saturating_sub(digits as u32);240let multiplier = 10i128.pow(decimals); // @Q? It might be better to do this with a241// LUT.242let threshold = multiplier / 2;243244// We use rounding=ROUND_HALF_EVEN245let rem = v % multiplier;246let is_v_floor_even = decimals <= scale && ((v - rem) / multiplier) % 2 == 0;247let threshold = threshold + i128::from(is_v_floor_even);248let round_offset = if rem.abs() >= threshold {249multiplier250} else {2510252};253let round_offset = if v < 0 { -round_offset } else { round_offset };254v - rem + round_offset255})256.into_decimal_unchecked(precision, scale as usize)257.into_series();258259return Ok(s);260}261262polars_ensure!(s.dtype().is_primitive_numeric(), InvalidOperation: "round_sig_figs can only be used on numeric types" );263with_match_physical_numeric_polars_type!(s.dtype(), |$T| {264let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();265let s = ca.apply_values(|value| {266let value = value as f64;267if value == 0.0 {268return value as <$T as PolarsNumericType>::Native;269}270// To deal with very large/small numbers we split up 10^n in 5^n and 2^n.271// The scaling by 2^n is almost always lossless.272let exp = digits - 1 - value.abs().log10().floor() as i32;273let pow5 = 5.0_f64.powi(exp);274let scaled = libm::scalbn(value, exp) * pow5;275let descaled = libm::scalbn(scaled.round() / pow5, -exp);276if descaled.is_finite() {277descaled as <$T as PolarsNumericType>::Native278} else {279value as <$T as PolarsNumericType>::Native280}281}).into_series();282return Ok(s);283});284}285286/// Floor underlying floating point array to the lowest integers smaller or equal to the float value.287fn floor(&self) -> PolarsResult<Series> {288let s = self.as_series();289290if let Ok(ca) = s.f32() {291let s = ca.apply_values(|val| val.floor()).into_series();292return Ok(s);293}294if let Ok(ca) = s.f64() {295let s = ca.apply_values(|val| val.floor()).into_series();296return Ok(s);297}298#[cfg(feature = "dtype-decimal")]299if let Some(ca) = s.try_decimal() {300let precision = ca.precision();301let scale = ca.scale() as u32;302if scale == 0 {303return Ok(ca.clone().into_series());304}305306let decimal_delta = scale;307let multiplier = 10i128.pow(decimal_delta);308309let ca = ca310.physical()311.apply_values(|v| {312let rem = v % multiplier;313let round_offset = if v < 0 { multiplier + rem } else { rem };314let round_offset = if rem == 0 { 0 } else { round_offset };315v - round_offset316})317.into_decimal_unchecked(precision, scale as usize);318319return Ok(ca.into_series());320}321322polars_ensure!(s.dtype().is_primitive_numeric(), InvalidOperation: "floor can only be used on numeric types" );323Ok(s.clone())324}325326/// Ceil underlying floating point array to the highest integers smaller or equal to the float value.327fn ceil(&self) -> PolarsResult<Series> {328let s = self.as_series();329330if let Ok(ca) = s.f32() {331let s = ca.apply_values(|val| val.ceil()).into_series();332return Ok(s);333}334if let Ok(ca) = s.f64() {335let s = ca.apply_values(|val| val.ceil()).into_series();336return Ok(s);337}338#[cfg(feature = "dtype-decimal")]339if let Some(ca) = s.try_decimal() {340let precision = ca.precision();341let scale = ca.scale() as u32;342if scale == 0 {343return Ok(ca.clone().into_series());344}345346let decimal_delta = scale;347let multiplier = 10i128.pow(decimal_delta);348349let ca = ca350.physical()351.apply_values(|v| {352let rem = v % multiplier;353let round_offset = if v < 0 { -rem } else { multiplier - rem };354let round_offset = if rem == 0 { 0 } else { round_offset };355v + round_offset356})357.into_decimal_unchecked(precision, scale as usize);358359return Ok(ca.into_series());360}361362polars_ensure!(s.dtype().is_primitive_numeric(), InvalidOperation: "ceil can only be used on numeric types" );363Ok(s.clone())364}365}366367impl RoundSeries for Series {}368369#[cfg(test)]370mod test {371use super::*;372373#[test]374fn test_round_series() {375let series = Series::new("a".into(), &[1.003, 2.23222, 3.4352]);376let out = series.round(2, RoundMode::default()).unwrap();377let ca = out.f64().unwrap();378assert_eq!(ca.get(0), Some(1.0));379}380}381382383