Path: blob/main/crates/polars-arrow/src/compute/decimal.rs
6939 views
use num_traits::Euclid;1use polars_utils::relaxed_cell::RelaxedCell;23static TRIM_DECIMAL_ZEROS: RelaxedCell<bool> = RelaxedCell::new_bool(false);45pub fn get_trim_decimal_zeros() -> bool {6TRIM_DECIMAL_ZEROS.load()7}8pub fn set_trim_decimal_zeros(trim: Option<bool>) {9TRIM_DECIMAL_ZEROS.store(trim.unwrap_or(false))10}1112/// Assuming bytes are a well-formed decimal number (with or without a separator),13/// infer the scale of the number. If no separator is present, the scale is 0.14pub fn infer_scale(bytes: &[u8]) -> u8 {15let Some(separator) = bytes.iter().position(|b| *b == b'.') else {16return 0;17};18(bytes.len() - (1 + separator)) as u819}2021/// Deserialize bytes to a single i128 representing a decimal, at a specified22/// precision (optional) and scale (required). The number is checked to ensure23/// it fits within the specified precision and scale. Consistent with float24/// parsing, no decimal separator is required (eg "500", "500.", and "500.0" are25/// all accepted); this allows mixed integer/decimal sequences to be parsed as26/// decimals. All trailing zeros are assumed to be significant, whether or not27/// a separator is present: 1200 requires precision >= 4, while 1200.20028/// requires precision >= 7 and scale >= 3. Returns None if the number is not29/// well-formed, or does not fit. Only b'.' is allowed as a decimal separator30/// (issue #6698).31#[inline]32pub fn deserialize_decimal(bytes: &[u8], precision: Option<u8>, scale: u8) -> Option<i128> {33let precision_digits = precision.unwrap_or(38).min(38) as usize;34if scale as usize > precision_digits {35return None;36}3738let separator = bytes.iter().position(|b| *b == b'.').unwrap_or(bytes.len());39let (mut int, mut frac) = bytes.split_at(separator);40if frac.len() <= 1 || scale == 0 {41// Only integer fast path.42let n: i128 = atoi_simd::parse(int).ok()?;43let ret = n.checked_mul(POW10[scale as usize] as i128)?;44if precision.is_some() && ret >= POW10[precision_digits] as i128 {45return None;46}47return Some(ret);48}4950// Skip period.51frac = &frac[1..];5253// Skip sign.54let negative = match bytes.first() {55Some(s @ (b'+' | b'-')) => {56int = &int[1..];57*s == b'-'58},59_ => false,60};6162// Truncate trailing digits that extend beyond the scale.63let frac_scale = if scale as usize <= frac.len() {64frac = &frac[..scale as usize];65066} else {67scale as usize - frac.len()68};6970// Parse and combine parts.71let pint: u128 = if int.is_empty() {72073} else {74atoi_simd::parse_pos(int).ok()?75};76let pfrac: u128 = atoi_simd::parse_pos(frac).ok()?;7778let ret = pint79.checked_mul(POW10[scale as usize])?80.checked_add(pfrac.checked_mul(POW10[frac_scale])?)?;81if precision.is_some() && ret >= POW10[precision_digits] {82return None;83}84if negative {85if ret > (1 << 127) {86None87} else {88Some(ret.wrapping_neg() as i128)89}90} else {91ret.try_into().ok()92}93}9495const MAX_DECIMAL_LEN: usize = 48;9697#[derive(Clone, Copy)]98pub struct DecimalFmtBuffer {99data: [u8; MAX_DECIMAL_LEN],100len: usize,101}102103impl Default for DecimalFmtBuffer {104fn default() -> Self {105Self::new()106}107}108109impl DecimalFmtBuffer {110#[inline]111pub const fn new() -> Self {112Self {113data: [0; MAX_DECIMAL_LEN],114len: 0,115}116}117118pub fn format(&mut self, x: i128, scale: usize, trim_zeros: bool) -> &str {119let factor = POW10[scale];120let mut itoa_buf = itoa::Buffer::new();121122self.len = 0;123let (div, rem) = x.unsigned_abs().div_rem_euclid(&factor);124if x < 0 {125self.data[0] = b'-';126self.len += 1;127}128129let div_fmt = itoa_buf.format(div);130self.data[self.len..self.len + div_fmt.len()].copy_from_slice(div_fmt.as_bytes());131self.len += div_fmt.len();132133if scale == 0 {134return unsafe { std::str::from_utf8_unchecked(&self.data[..self.len]) };135}136137self.data[self.len] = b'.';138self.len += 1;139140let rem_fmt = itoa_buf.format(rem + factor); // + factor adds leading 1 where period would be.141self.data[self.len..self.len + rem_fmt.len() - 1].copy_from_slice(&rem_fmt.as_bytes()[1..]);142self.len += rem_fmt.len() - 1;143144if trim_zeros {145while self.data.get(self.len - 1) == Some(&b'0') {146self.len -= 1;147}148if self.data.get(self.len - 1) == Some(&b'.') {149self.len -= 1;150}151}152153unsafe { std::str::from_utf8_unchecked(&self.data[..self.len]) }154}155}156157const POW10: [u128; 39] = [1581,15910,160100,1611000,16210000,163100000,1641000000,16510000000,166100000000,1671000000000,16810000000000,169100000000000,1701000000000000,17110000000000000,172100000000000000,1731000000000000000,17410000000000000000,175100000000000000000,1761000000000000000000,17710000000000000000000,178100000000000000000000,1791000000000000000000000,18010000000000000000000000,181100000000000000000000000,1821000000000000000000000000,18310000000000000000000000000,184100000000000000000000000000,1851000000000000000000000000000,18610000000000000000000000000000,187100000000000000000000000000000,1881000000000000000000000000000000,18910000000000000000000000000000000,190100000000000000000000000000000000,1911000000000000000000000000000000000,19210000000000000000000000000000000000,193100000000000000000000000000000000000,1941000000000000000000000000000000000000,19510000000000000000000000000000000000000,196100000000000000000000000000000000000000,197];198199#[cfg(test)]200mod test {201use super::*;202#[test]203fn test_decimal() {204let precision = Some(8);205let scale = 2;206207let val = "12.09";208assert_eq!(209deserialize_decimal(val.as_bytes(), precision, scale),210Some(1209)211);212213let val = "1200.90";214assert_eq!(215deserialize_decimal(val.as_bytes(), precision, scale),216Some(120090)217);218219let val = "143.9";220assert_eq!(221deserialize_decimal(val.as_bytes(), precision, scale),222Some(14390)223);224225let val = "+000000.5";226assert_eq!(227deserialize_decimal(val.as_bytes(), precision, scale),228Some(50)229);230231let val = "-0.5";232assert_eq!(233deserialize_decimal(val.as_bytes(), precision, scale),234Some(-50)235);236237let val = "-1.5";238assert_eq!(239deserialize_decimal(val.as_bytes(), precision, scale),240Some(-150)241);242243let scale = 20;244let val = "0.01";245assert_eq!(deserialize_decimal(val.as_bytes(), precision, scale), None);246assert_eq!(247deserialize_decimal(val.as_bytes(), None, scale),248Some(1000000000000000000)249);250251let scale = 5;252let val = "12ABC.34";253assert_eq!(deserialize_decimal(val.as_bytes(), precision, scale), None);254255let val = "1ABC2.34";256assert_eq!(deserialize_decimal(val.as_bytes(), precision, scale), None);257258let val = "12.3ABC4";259assert_eq!(deserialize_decimal(val.as_bytes(), precision, scale), None);260261let val = "12.3.ABC4";262assert_eq!(deserialize_decimal(val.as_bytes(), precision, scale), None);263264let val = "12.-3";265assert_eq!(deserialize_decimal(val.as_bytes(), precision, scale), None);266267let val = "";268assert_eq!(deserialize_decimal(val.as_bytes(), precision, scale), None);269270let val = "5.";271assert_eq!(272deserialize_decimal(val.as_bytes(), precision, scale),273Some(500000i128)274);275276let val = "5";277assert_eq!(278deserialize_decimal(val.as_bytes(), precision, scale),279Some(500000i128)280);281282let val = ".5";283assert_eq!(284deserialize_decimal(val.as_bytes(), precision, scale),285Some(50000i128)286);287288// Precision and scale fitting:289let val = b"1200";290assert_eq!(deserialize_decimal(val, None, 0), Some(1200));291assert_eq!(deserialize_decimal(val, Some(4), 0), Some(1200));292assert_eq!(deserialize_decimal(val, Some(3), 0), None);293assert_eq!(deserialize_decimal(val, Some(4), 1), None);294295let val = b"1200.010";296assert_eq!(deserialize_decimal(val, None, 0), Some(1200)); // truncate scale297assert_eq!(deserialize_decimal(val, None, 3), Some(1200010)); // exact scale298assert_eq!(deserialize_decimal(val, None, 6), Some(1200010000)); // excess scale299assert_eq!(deserialize_decimal(val, Some(7), 0), Some(1200)); // sufficient precision and truncate scale300assert_eq!(deserialize_decimal(val, Some(7), 3), Some(1200010)); // exact precision and scale301assert_eq!(deserialize_decimal(val, Some(10), 6), Some(1200010000)); // exact precision, excess scale302assert_eq!(deserialize_decimal(val, Some(5), 6), None); // insufficient precision, excess scale303assert_eq!(deserialize_decimal(val, Some(5), 3), None); // insufficient precision, exact scale304assert_eq!(deserialize_decimal(val, Some(12), 5), Some(120001000)); // excess precision, excess scale305assert_eq!(306deserialize_decimal(val, None, 35),307Some(120001000000000000000000000000000000000)308);309assert_eq!(deserialize_decimal(val, None, 36), None);310assert_eq!(deserialize_decimal(val, Some(38), 35), None); // scale causes insufficient precision311}312}313314315