Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-arrow/src/compute/decimal.rs
6939 views
1
use num_traits::Euclid;
2
use polars_utils::relaxed_cell::RelaxedCell;
3
4
/// Global flag, initially `false`, read by [`get_trim_decimal_zeros`] and
/// written by [`set_trim_decimal_zeros`].
// NOTE(review): presumably consulted by callers to choose the `trim_zeros`
// argument of `DecimalFmtBuffer::format`; no such caller is visible in this file.
static TRIM_DECIMAL_ZEROS: RelaxedCell<bool> = RelaxedCell::new_bool(false);
5
6
pub fn get_trim_decimal_zeros() -> bool {
7
TRIM_DECIMAL_ZEROS.load()
8
}
9
pub fn set_trim_decimal_zeros(trim: Option<bool>) {
10
TRIM_DECIMAL_ZEROS.store(trim.unwrap_or(false))
11
}
12
13
/// Infers the scale of a decimal number given as ASCII bytes.
///
/// The input is assumed to be a well-formed decimal number, with or without a
/// `b'.'` separator; it is not validated. The scale is the number of bytes
/// after the first separator, or 0 when no separator is present
/// (`"12.09"` -> 2, `"5."` -> 0, `"1200"` -> 0).
///
/// The result saturates at `u8::MAX`: an input with more than 255 fractional
/// bytes reports 255 instead of wrapping around to a small, misleading scale.
pub fn infer_scale(bytes: &[u8]) -> u8 {
    let Some(separator) = bytes.iter().position(|&b| b == b'.') else {
        return 0;
    };
    // `separator < bytes.len()`, so this subtraction cannot underflow.
    let fractional_digits = bytes.len() - separator - 1;
    // Clamp before narrowing so the cast below can never truncate.
    fractional_digits.min(u8::MAX as usize) as u8
}
21
22
/// Deserialize bytes to a single i128 representing a decimal, at a specified
23
/// precision (optional) and scale (required). The number is checked to ensure
24
/// it fits within the specified precision and scale. Consistent with float
25
/// parsing, no decimal separator is required (eg "500", "500.", and "500.0" are
26
/// all accepted); this allows mixed integer/decimal sequences to be parsed as
27
/// decimals. All trailing zeros are assumed to be significant, whether or not
28
/// a separator is present: 1200 requires precision >= 4, while 1200.200
29
/// requires precision >= 7 and scale >= 3. Returns None if the number is not
30
/// well-formed, or does not fit. Only b'.' is allowed as a decimal separator
31
/// (issue #6698).
32
#[inline]
33
pub fn deserialize_decimal(bytes: &[u8], precision: Option<u8>, scale: u8) -> Option<i128> {
34
let precision_digits = precision.unwrap_or(38).min(38) as usize;
35
if scale as usize > precision_digits {
36
return None;
37
}
38
39
let separator = bytes.iter().position(|b| *b == b'.').unwrap_or(bytes.len());
40
let (mut int, mut frac) = bytes.split_at(separator);
41
if frac.len() <= 1 || scale == 0 {
42
// Only integer fast path.
43
let n: i128 = atoi_simd::parse(int).ok()?;
44
let ret = n.checked_mul(POW10[scale as usize] as i128)?;
45
if precision.is_some() && ret >= POW10[precision_digits] as i128 {
46
return None;
47
}
48
return Some(ret);
49
}
50
51
// Skip period.
52
frac = &frac[1..];
53
54
// Skip sign.
55
let negative = match bytes.first() {
56
Some(s @ (b'+' | b'-')) => {
57
int = &int[1..];
58
*s == b'-'
59
},
60
_ => false,
61
};
62
63
// Truncate trailing digits that extend beyond the scale.
64
let frac_scale = if scale as usize <= frac.len() {
65
frac = &frac[..scale as usize];
66
0
67
} else {
68
scale as usize - frac.len()
69
};
70
71
// Parse and combine parts.
72
let pint: u128 = if int.is_empty() {
73
0
74
} else {
75
atoi_simd::parse_pos(int).ok()?
76
};
77
let pfrac: u128 = atoi_simd::parse_pos(frac).ok()?;
78
79
let ret = pint
80
.checked_mul(POW10[scale as usize])?
81
.checked_add(pfrac.checked_mul(POW10[frac_scale])?)?;
82
if precision.is_some() && ret >= POW10[precision_digits] {
83
return None;
84
}
85
if negative {
86
if ret > (1 << 127) {
87
None
88
} else {
89
Some(ret.wrapping_neg() as i128)
90
}
91
} else {
92
ret.try_into().ok()
93
}
94
}
95
96
/// Size in bytes of the backing array of `DecimalFmtBuffer`.
// An i128 magnitude has at most 39 decimal digits, so a formatted value needs
// at most 41 bytes (sign, 39 digits, one separator); 48 leaves headroom.
const MAX_DECIMAL_LEN: usize = 48;
97
98
#[derive(Clone, Copy)]
99
pub struct DecimalFmtBuffer {
100
data: [u8; MAX_DECIMAL_LEN],
101
len: usize,
102
}
103
104
impl Default for DecimalFmtBuffer {
105
fn default() -> Self {
106
Self::new()
107
}
108
}
109
110
impl DecimalFmtBuffer {
111
#[inline]
112
pub const fn new() -> Self {
113
Self {
114
data: [0; MAX_DECIMAL_LEN],
115
len: 0,
116
}
117
}
118
119
pub fn format(&mut self, x: i128, scale: usize, trim_zeros: bool) -> &str {
120
let factor = POW10[scale];
121
let mut itoa_buf = itoa::Buffer::new();
122
123
self.len = 0;
124
let (div, rem) = x.unsigned_abs().div_rem_euclid(&factor);
125
if x < 0 {
126
self.data[0] = b'-';
127
self.len += 1;
128
}
129
130
let div_fmt = itoa_buf.format(div);
131
self.data[self.len..self.len + div_fmt.len()].copy_from_slice(div_fmt.as_bytes());
132
self.len += div_fmt.len();
133
134
if scale == 0 {
135
return unsafe { std::str::from_utf8_unchecked(&self.data[..self.len]) };
136
}
137
138
self.data[self.len] = b'.';
139
self.len += 1;
140
141
let rem_fmt = itoa_buf.format(rem + factor); // + factor adds leading 1 where period would be.
142
self.data[self.len..self.len + rem_fmt.len() - 1].copy_from_slice(&rem_fmt.as_bytes()[1..]);
143
self.len += rem_fmt.len() - 1;
144
145
if trim_zeros {
146
while self.data.get(self.len - 1) == Some(&b'0') {
147
self.len -= 1;
148
}
149
if self.data.get(self.len - 1) == Some(&b'.') {
150
self.len -= 1;
151
}
152
}
153
154
unsafe { std::str::from_utf8_unchecked(&self.data[..self.len]) }
155
}
156
}
157
158
/// Powers of ten as `u128`: `POW10[i] == 10^i` for `i` in `0..=38`.
///
/// `10^38` is the largest power of ten that fits in an `i128`.
const POW10: [u128; 39] = {
    let mut table = [1u128; 39];
    let mut i = 1;
    // Evaluated at compile time; each entry is ten times the previous one.
    while i < table.len() {
        table[i] = table[i - 1] * 10;
        i += 1;
    }
    table
};
199
200
#[cfg(test)]
mod test {
    use super::*;

    /// Table-driven checks of `deserialize_decimal`; each row is
    /// `(input, precision, scale, expected)`.
    #[test]
    fn test_decimal() {
        let cases: &[(&str, Option<u8>, u8, Option<i128>)] = &[
            // Well-formed values at precision 8, scale 2.
            ("12.09", Some(8), 2, Some(1209)),
            ("1200.90", Some(8), 2, Some(120090)),
            ("143.9", Some(8), 2, Some(14390)),
            ("+000000.5", Some(8), 2, Some(50)),
            ("-0.5", Some(8), 2, Some(-50)),
            ("-1.5", Some(8), 2, Some(-150)),
            // A scale larger than the precision is rejected; without a precision it is fine.
            ("0.01", Some(8), 20, None),
            ("0.01", None, 20, Some(1000000000000000000)),
            // Malformed input at precision 8, scale 5.
            ("12ABC.34", Some(8), 5, None),
            ("1ABC2.34", Some(8), 5, None),
            ("12.3ABC4", Some(8), 5, None),
            ("12.3.ABC4", Some(8), 5, None),
            ("12.-3", Some(8), 5, None),
            ("", Some(8), 5, None),
            // Missing fractional or integer digits are accepted.
            ("5.", Some(8), 5, Some(500000i128)),
            ("5", Some(8), 5, Some(500000i128)),
            (".5", Some(8), 5, Some(50000i128)),
            // Precision and scale fitting:
            ("1200", None, 0, Some(1200)),
            ("1200", Some(4), 0, Some(1200)),
            ("1200", Some(3), 0, None),
            ("1200", Some(4), 1, None),
            ("1200.010", None, 0, Some(1200)), // truncate scale
            ("1200.010", None, 3, Some(1200010)), // exact scale
            ("1200.010", None, 6, Some(1200010000)), // excess scale
            ("1200.010", Some(7), 0, Some(1200)), // sufficient precision and truncate scale
            ("1200.010", Some(7), 3, Some(1200010)), // exact precision and scale
            ("1200.010", Some(10), 6, Some(1200010000)), // exact precision, excess scale
            ("1200.010", Some(5), 6, None), // insufficient precision, excess scale
            ("1200.010", Some(5), 3, None), // insufficient precision, exact scale
            ("1200.010", Some(12), 5, Some(120001000)), // excess precision, excess scale
            (
                "1200.010",
                None,
                35,
                Some(120001000000000000000000000000000000000),
            ),
            ("1200.010", None, 36, None),
            ("1200.010", Some(38), 35, None), // scale causes insufficient precision
        ];

        for &(input, precision, scale, expected) in cases {
            assert_eq!(
                deserialize_decimal(input.as_bytes(), precision, scale),
                expected,
                "input={input:?}, precision={precision:?}, scale={scale}"
            );
        }
    }
}
314
315