Path: blob/main/crates/polars-ops/src/chunked_array/hist.rs
use std::cmp;
use std::fmt::Write;

use num_traits::ToPrimitive;
use polars_core::prelude::*;
use polars_core::with_match_physical_numeric_polars_type;

const DEFAULT_BIN_COUNT: usize = 10;

fn get_breaks<T>(
    ca: &ChunkedArray<T>,
    bin_count: Option<usize>,
    bins: Option<&[f64]>,
) -> PolarsResult<(Vec<f64>, bool)>
where
    T: PolarsNumericType,
    ChunkedArray<T>: ChunkAgg<T::Native>,
{
    let (bins, uniform) = match (bin_count, bins) {
        (Some(_), Some(_)) => {
            return Err(PolarsError::ComputeError(
                "can only provide one of `bin_count` or `bins`".into(),
            ));
        },
        (None, Some(bins)) => {
            // User-supplied bins. Note these are actually bin edges. Check for monotonicity.
            // If we only have one edge, we have no bins.
            let bin_len = bins.len();
            if bin_len > 1 {
                for i in 1..bin_len {
                    if (bins[i] - bins[i - 1]) <= 0.0 {
                        return Err(PolarsError::ComputeError(
                            "bins must increase monotonically".into(),
                        ));
                    }
                }
                (bins.to_vec(), false)
            } else {
                (Vec::<f64>::new(), false)
            }
        },
        (bin_count, None) => {
            // User-supplied bin count, or 10 by default. Compute edges from the data.
            let bin_count = bin_count.unwrap_or(DEFAULT_BIN_COUNT);
            let n = ca.len() - ca.null_count();
            let (offset, width, upper_limit) = if n == 0 {
                // No non-null items; supply unit interval.
                (0.0, 1.0 / bin_count as f64, 1.0)
            } else if n == 1 {
                // Unit interval around single point
                let idx = ca.first_non_null().unwrap();
                // SAFETY: idx is guaranteed to contain an element.
                let center = unsafe { ca.get_unchecked(idx) }.unwrap().to_f64().unwrap();
                (center - 0.5, 1.0 / bin_count as f64, center + 0.5)
            } else {
                // Determine outer bin edges from the data itself
                let min_value = ca.min().unwrap().to_f64().unwrap();
                let max_value = ca.max().unwrap().to_f64().unwrap();

                // All data points are identical--use unit interval.
                if min_value == max_value {
                    (min_value - 0.5, 1.0 / bin_count as f64, max_value + 0.5)
                } else {
                    (
                        min_value,
                        (max_value - min_value) / bin_count as f64,
                        max_value,
                    )
                }
            };
            // Manually set the final value to the maximum value to ensure the final value isn't
            // missed due to floating-point precision.
            let out = (0..bin_count)
                .map(|x| (x as f64 * width) + offset)
                .chain(std::iter::once(upper_limit))
                .collect::<Vec<f64>>();
            (out, true)
        },
    };
    Ok((bins, uniform))
}

// O(n) implementation when buckets are fixed-size.
// We deposit items directly into their buckets.
fn uniform_hist_count<T>(breaks: &[f64], ca: &ChunkedArray<T>) -> Vec<IdxSize>
where
    T: PolarsNumericType,
    ChunkedArray<T>: ChunkAgg<T::Native>,
{
    let num_bins = breaks.len() - 1;
    let mut count: Vec<IdxSize> = vec![0; num_bins];
    let min_break: f64 = breaks[0];
    let max_break: f64 = breaks[num_bins];
    let scale = num_bins as f64 / (max_break - min_break);
    let max_idx = num_bins - 1;

    for chunk in ca.downcast_iter() {
        for item in chunk.non_null_values_iter() {
            let item = item.to_f64().unwrap();
            if item > min_break && item <= max_break {
                // idx > (num_bins - 1) may happen due to floating point representation imprecision
                let mut idx = cmp::min((scale * (item - min_break)) as usize, max_idx);

                // Adjust for float imprecision providing idx > 1 ULP of the breaks
                if item <= breaks[idx] {
                    idx -= 1;
                } else if item > breaks[idx + 1] {
                    idx += 1;
                }

                count[idx] += 1;
            } else if item == min_break {
                count[0] += 1;
            }
        }
    }
    count
}
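// A minimal sketch (names and inputs here are illustrative assumptions, not part of the
// original file) of how `uniform_hist_count` deposits values into right-closed bins.
// With breaks [0.0, 2.5, 5.0, 7.5, 10.0], scale = 4 / 10 = 0.4, so 7.5 first maps to
// index (0.4 * 7.5) as usize = 3; the correction step then moves it down to bin 2,
// i.e. the interval (5.0, 7.5].
#[cfg(test)]
mod uniform_hist_count_sketch {
    use super::*;

    #[test]
    fn boundary_values_fall_in_the_lower_bin() {
        let breaks = [0.0, 2.5, 5.0, 7.5, 10.0];
        let ca = Float64Chunked::from_slice(PlSmallStr::from_static("x"), &[0.0, 3.0, 7.5, 10.0]);
        // 0.0 == min_break -> bin 0; 3.0 -> (2.5, 5.0]; 7.5 -> (5.0, 7.5]; 10.0 -> (7.5, 10.0]
        let expected: Vec<IdxSize> = vec![1, 1, 1, 1];
        assert_eq!(uniform_hist_count(&breaks, &ca), expected);
    }
}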
// Variable-width bucketing. We sort the items and then move linearly through buckets.
fn hist_count<T>(breaks: &[f64], ca: &ChunkedArray<T>) -> Vec<IdxSize>
where
    T: PolarsNumericType,
    ChunkedArray<T>: ChunkAgg<T::Native>,
{
    let num_bins = breaks.len() - 1;
    let mut breaks_iter = breaks.iter().skip(1); // Skip the first lower bound
    let (min_break, max_break) = (breaks[0], breaks[breaks.len() - 1]);
    let mut upper_bound = *breaks_iter.next().unwrap();
    let mut sorted = ca.sort(false);
    sorted.rechunk_mut();
    let mut current_count: IdxSize = 0;
    let chunk = sorted.downcast_as_array();
    let mut count: Vec<IdxSize> = Vec::with_capacity(num_bins);

    'item: for item in chunk.non_null_values_iter() {
        let item = item.to_f64().unwrap();

        // Cycle through items until we hit the first bucket.
        if item.is_nan() || item < min_break {
            continue;
        }

        while item > upper_bound {
            if item > max_break {
                // No more items will fit in any buckets
                break 'item;
            }

            // Finished with prior bucket; push, reset, and move to next.
            count.push(current_count);
            current_count = 0;
            upper_bound = *breaks_iter.next().unwrap();
        }

        // Item is in bound.
        current_count += 1;
    }
    count.push(current_count);
    count.resize(num_bins, 0); // If we left early, fill remainder with 0.
    count
}
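// A minimal sketch (the edges and values below are illustrative assumptions) of the
// variable-width path: with explicit, non-uniform edges, `hist_count` walks the sorted
// values linearly through the buckets and drops anything beyond the last edge.
#[cfg(test)]
mod hist_count_sketch {
    use super::*;

    #[test]
    fn variable_width_buckets() {
        // Edges [0.0, 1.0, 10.0] define two right-closed bins: [0.0, 1.0] and (1.0, 10.0].
        let breaks = [0.0, 1.0, 10.0];
        let ca = Float64Chunked::from_slice(PlSmallStr::from_static("x"), &[2.0, 0.5, 9.0, 42.0]);
        // 0.5 -> first bin; 2.0 and 9.0 -> second bin; 42.0 exceeds max_break and is dropped.
        let expected: Vec<IdxSize> = vec![1, 2];
        assert_eq!(hist_count(&breaks, &ca), expected);
    }
}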
fn compute_hist<T>(
    ca: &ChunkedArray<T>,
    bin_count: Option<usize>,
    bins: Option<&[f64]>,
    include_category: bool,
    include_breakpoint: bool,
) -> PolarsResult<Series>
where
    T: PolarsNumericType,
    ChunkedArray<T>: ChunkAgg<T::Native>,
{
    let (breaks, uniform) = get_breaks(ca, bin_count, bins)?;
    let num_bins = std::cmp::max(breaks.len(), 1) - 1;
    let count = if num_bins > 0 && ca.len() > ca.null_count() {
        if uniform {
            uniform_hist_count(&breaks, ca)
        } else {
            hist_count(&breaks, ca)
        }
    } else {
        vec![0; num_bins]
    };

    // Generate output: breakpoint (optional), category (optional), count
    let mut fields = Vec::with_capacity(3);

    if include_breakpoint {
        let breakpoints = if num_bins > 0 {
            Series::new(PlSmallStr::from_static("breakpoint"), &breaks[1..])
        } else {
            let empty: &[f64; 0] = &[];
            Series::new(PlSmallStr::from_static("breakpoint"), empty)
        };
        fields.push(breakpoints)
    }

    if include_category {
        let mut categories =
            StringChunkedBuilder::new(PlSmallStr::from_static("category"), breaks.len());
        if num_bins > 0 {
            let mut lower = AnyValue::Float64(breaks[0]);
            let mut buf = String::new();
            let mut open_bracket = "[";
            for br in &breaks[1..] {
                let br = AnyValue::Float64(*br);
                buf.clear();
                write!(buf, "{open_bracket}{lower}, {br}]").unwrap();
                open_bracket = "(";
                categories.append_value(buf.as_str());
                lower = br;
            }
        }
        let categories = categories
            .finish()
            .cast(&DataType::from_categories(Categories::global()))
            .unwrap();
        fields.push(categories);
    };

    let count = Series::new(PlSmallStr::from_static("count"), count);
    fields.push(count);

    Ok(if fields.len() == 1 {
        fields.pop().unwrap().with_name(ca.name().clone())
    } else {
        StructChunked::from_series(ca.name().clone(), fields[0].len(), fields.iter())
            .unwrap()
            .into_series()
    })
}

pub fn hist_series(
    s: &Series,
    bin_count: Option<usize>,
    bins: Option<Series>,
    include_category: bool,
    include_breakpoint: bool,
) -> PolarsResult<Series> {
    let mut bins_arg = None;

    let owned_bins;
    if let Some(bins) = bins {
        polars_ensure!(bins.null_count() == 0, InvalidOperation: "nulls not supported in 'bins' argument");
        let bins = bins.cast(&DataType::Float64)?;
        let bins_s = bins.rechunk();
        owned_bins = bins_s;
        let bins = owned_bins.f64().unwrap();
        let bins = bins.cont_slice().unwrap();
        bins_arg = Some(bins);
    };
    polars_ensure!(s.dtype().is_primitive_numeric(), InvalidOperation: "'hist' is only supported for numeric data");

    let out = with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
        let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
        compute_hist(ca, bin_count, bins_arg, include_category, include_breakpoint)?
    });
    Ok(out)
}
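// A minimal end-to-end sketch (series name, values, and assertions are illustrative
// assumptions) of the public entry point: `hist_series` with a `bin_count` returns a
// struct Series with one row per bin, carrying the requested `breakpoint`, `category`,
// and `count` fields.
#[cfg(test)]
mod hist_series_sketch {
    use super::*;

    #[test]
    fn uniform_bin_count_produces_one_row_per_bin() {
        let s = Series::new(PlSmallStr::from_static("values"), &[1.0f64, 2.0, 2.5, 4.0]);
        // Four uniform bins spanning [1.0, 4.0]; ask for both optional output fields.
        let out = hist_series(&s, Some(4), None, true, true).unwrap();
        assert!(matches!(out.dtype(), DataType::Struct(_)));
        assert_eq!(out.len(), 4);
    }
}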