Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/chunked_array/hist.rs
6939 views
1
use std::cmp;
2
use std::fmt::Write;
3
4
use num_traits::ToPrimitive;
5
use polars_core::prelude::*;
6
use polars_core::with_match_physical_numeric_polars_type;
7
8
/// Number of equal-width bins used when neither `bin_count` nor explicit `bins` is supplied.
const DEFAULT_BIN_COUNT: usize = 10;
9
10
fn get_breaks<T>(
11
ca: &ChunkedArray<T>,
12
bin_count: Option<usize>,
13
bins: Option<&[f64]>,
14
) -> PolarsResult<(Vec<f64>, bool)>
15
where
16
T: PolarsNumericType,
17
ChunkedArray<T>: ChunkAgg<T::Native>,
18
{
19
let (bins, uniform) = match (bin_count, bins) {
20
(Some(_), Some(_)) => {
21
return Err(PolarsError::ComputeError(
22
"can only provide one of `bin_count` or `bins`".into(),
23
));
24
},
25
(None, Some(bins)) => {
26
// User-supplied bins. Note these are actually bin edges. Check for monotonicity.
27
// If we only have one edge, we have no bins.
28
let bin_len = bins.len();
29
if bin_len > 1 {
30
for i in 1..bin_len {
31
if (bins[i] - bins[i - 1]) <= 0.0 {
32
return Err(PolarsError::ComputeError(
33
"bins must increase monotonically".into(),
34
));
35
}
36
}
37
(bins.to_vec(), false)
38
} else {
39
(Vec::<f64>::new(), false)
40
}
41
},
42
(bin_count, None) => {
43
// User-supplied bin count, or 10 by default. Compute edges from the data.
44
let bin_count = bin_count.unwrap_or(DEFAULT_BIN_COUNT);
45
let n = ca.len() - ca.null_count();
46
let (offset, width, upper_limit) = if n == 0 {
47
// No non-null items; supply unit interval.
48
(0.0, 1.0 / bin_count as f64, 1.0)
49
} else if n == 1 {
50
// Unit interval around single point
51
let idx = ca.first_non_null().unwrap();
52
// SAFETY: idx is guaranteed to contain an element.
53
let center = unsafe { ca.get_unchecked(idx) }.unwrap().to_f64().unwrap();
54
(center - 0.5, 1.0 / bin_count as f64, center + 0.5)
55
} else {
56
// Determine outer bin edges from the data itself
57
let min_value = ca.min().unwrap().to_f64().unwrap();
58
let max_value = ca.max().unwrap().to_f64().unwrap();
59
60
// All data points are identical--use unit interval.
61
if min_value == max_value {
62
(min_value - 0.5, 1.0 / bin_count as f64, max_value + 0.5)
63
} else {
64
(
65
min_value,
66
(max_value - min_value) / bin_count as f64,
67
max_value,
68
)
69
}
70
};
71
// Manually set the final value to the maximum value to ensure the final value isn't
72
// missed due to floating-point precision.
73
let out = (0..bin_count)
74
.map(|x| (x as f64 * width) + offset)
75
.chain(std::iter::once(upper_limit))
76
.collect::<Vec<f64>>();
77
(out, true)
78
},
79
};
80
Ok((bins, uniform))
81
}
82
83
// O(n) implementation when buckets are fixed-size.
84
// We deposit items directly into their buckets.
85
fn uniform_hist_count<T>(breaks: &[f64], ca: &ChunkedArray<T>) -> Vec<IdxSize>
86
where
87
T: PolarsNumericType,
88
ChunkedArray<T>: ChunkAgg<T::Native>,
89
{
90
let num_bins = breaks.len() - 1;
91
let mut count: Vec<IdxSize> = vec![0; num_bins];
92
let min_break: f64 = breaks[0];
93
let max_break: f64 = breaks[num_bins];
94
let scale = num_bins as f64 / (max_break - min_break);
95
let max_idx = num_bins - 1;
96
97
for chunk in ca.downcast_iter() {
98
for item in chunk.non_null_values_iter() {
99
let item = item.to_f64().unwrap();
100
if item > min_break && item <= max_break {
101
// idx > (num_bins - 1) may happen due to floating point representation imprecision
102
let mut idx = cmp::min((scale * (item - min_break)) as usize, max_idx);
103
104
// Adjust for float imprecision providing idx > 1 ULP of the breaks
105
if item <= breaks[idx] {
106
idx -= 1;
107
} else if item > breaks[idx + 1] {
108
idx += 1;
109
}
110
111
count[idx] += 1;
112
} else if item == min_break {
113
count[0] += 1;
114
}
115
}
116
}
117
count
118
}
119
120
// Variable-width bucketing. We sort the items and then move linearly through buckets.
121
fn hist_count<T>(breaks: &[f64], ca: &ChunkedArray<T>) -> Vec<IdxSize>
122
where
123
T: PolarsNumericType,
124
ChunkedArray<T>: ChunkAgg<T::Native>,
125
{
126
let num_bins = breaks.len() - 1;
127
let mut breaks_iter = breaks.iter().skip(1); // Skip the first lower bound
128
let (min_break, max_break) = (breaks[0], breaks[breaks.len() - 1]);
129
let mut upper_bound = *breaks_iter.next().unwrap();
130
let mut sorted = ca.sort(false);
131
sorted.rechunk_mut();
132
let mut current_count: IdxSize = 0;
133
let chunk = sorted.downcast_as_array();
134
let mut count: Vec<IdxSize> = Vec::with_capacity(num_bins);
135
136
'item: for item in chunk.non_null_values_iter() {
137
let item = item.to_f64().unwrap();
138
139
// Cycle through items until we hit the first bucket.
140
if item.is_nan() || item < min_break {
141
continue;
142
}
143
144
while item > upper_bound {
145
if item > max_break {
146
// No more items will fit in any buckets
147
break 'item;
148
}
149
150
// Finished with prior bucket; push, reset, and move to next.
151
count.push(current_count);
152
current_count = 0;
153
upper_bound = *breaks_iter.next().unwrap();
154
}
155
156
// Item is in bound.
157
current_count += 1;
158
}
159
count.push(current_count);
160
count.resize(num_bins, 0); // If we left early, fill remainder with 0.
161
count
162
}
163
164
fn compute_hist<T>(
165
ca: &ChunkedArray<T>,
166
bin_count: Option<usize>,
167
bins: Option<&[f64]>,
168
include_category: bool,
169
include_breakpoint: bool,
170
) -> PolarsResult<Series>
171
where
172
T: PolarsNumericType,
173
ChunkedArray<T>: ChunkAgg<T::Native>,
174
{
175
let (breaks, uniform) = get_breaks(ca, bin_count, bins)?;
176
let num_bins = std::cmp::max(breaks.len(), 1) - 1;
177
let count = if num_bins > 0 && ca.len() > ca.null_count() {
178
if uniform {
179
uniform_hist_count(&breaks, ca)
180
} else {
181
hist_count(&breaks, ca)
182
}
183
} else {
184
vec![0; num_bins]
185
};
186
187
// Generate output: breakpoint (optional), breaks (optional), count
188
let mut fields = Vec::with_capacity(3);
189
190
if include_breakpoint {
191
let breakpoints = if num_bins > 0 {
192
Series::new(PlSmallStr::from_static("breakpoint"), &breaks[1..])
193
} else {
194
let empty: &[f64; 0] = &[];
195
Series::new(PlSmallStr::from_static("breakpoint"), empty)
196
};
197
fields.push(breakpoints)
198
}
199
200
if include_category {
201
let mut categories =
202
StringChunkedBuilder::new(PlSmallStr::from_static("category"), breaks.len());
203
if num_bins > 0 {
204
let mut lower = AnyValue::Float64(breaks[0]);
205
let mut buf = String::new();
206
let mut open_bracket = "[";
207
for br in &breaks[1..] {
208
let br = AnyValue::Float64(*br);
209
buf.clear();
210
write!(buf, "{open_bracket}{lower}, {br}]").unwrap();
211
open_bracket = "(";
212
categories.append_value(buf.as_str());
213
lower = br;
214
}
215
}
216
let categories = categories
217
.finish()
218
.cast(&DataType::from_categories(Categories::global()))
219
.unwrap();
220
fields.push(categories);
221
};
222
223
let count = Series::new(PlSmallStr::from_static("count"), count);
224
fields.push(count);
225
226
Ok(if fields.len() == 1 {
227
fields.pop().unwrap().with_name(ca.name().clone())
228
} else {
229
StructChunked::from_series(ca.name().clone(), fields[0].len(), fields.iter())
230
.unwrap()
231
.into_series()
232
})
233
}
234
235
pub fn hist_series(
236
s: &Series,
237
bin_count: Option<usize>,
238
bins: Option<Series>,
239
include_category: bool,
240
include_breakpoint: bool,
241
) -> PolarsResult<Series> {
242
let mut bins_arg = None;
243
244
let owned_bins;
245
if let Some(bins) = bins {
246
polars_ensure!(bins.null_count() == 0, InvalidOperation: "nulls not supported in 'bins' argument");
247
let bins = bins.cast(&DataType::Float64)?;
248
let bins_s = bins.rechunk();
249
owned_bins = bins_s;
250
let bins = owned_bins.f64().unwrap();
251
let bins = bins.cont_slice().unwrap();
252
bins_arg = Some(bins);
253
};
254
polars_ensure!(s.dtype().is_primitive_numeric(), InvalidOperation: "'hist' is only supported for numeric data");
255
256
let out = with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
257
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
258
compute_hist(ca, bin_count, bins_arg, include_category, include_breakpoint)?
259
});
260
Ok(out)
261
}
262
263