// Source: pola-rs/polars — crates/polars-ops/src/series/ops/cut.rs
// (web-viewer page chrome removed from this scraped copy)
1
use polars_compute::rolling::QuantileMethod;
2
use polars_core::chunked_array::builder::CategoricalChunkedBuilder;
3
use polars_core::prelude::*;
4
use polars_utils::format_pl_smallstr;
5
6
/// Map each value of `s` into its bin and return the bin labels as a
/// categorical series.
///
/// `sorted_breaks` must be ascending and NaN-free (callers enforce this).
/// With `left_closed` the bins are `[lo, hi)`, otherwise `(lo, hi]`.
/// `labels` must have exactly `sorted_breaks.len() + 1` entries — one per
/// interval, including the two unbounded outer intervals (callers enforce
/// this; the `unsafe` indexing below relies on it).
///
/// When `include_breaks` is set, the result is a struct series with a
/// "breakpoint" field (the right endpoint of each row's bin) and a
/// "category" field; otherwise just the categorical series. Null and NaN
/// inputs map to null output rows.
fn map_cats(
    s: &Series,
    labels: &[PlSmallStr],
    sorted_breaks: &[f64],
    left_closed: bool,
    include_breaks: bool,
) -> PolarsResult<Series> {
    let out_name = PlSmallStr::from_static("category");

    // Breaks are f64, so do the comparisons in f64 space.
    let s2 = s.cast(&DataType::Float64)?;
    // It would be nice to parallelize this
    let s_iter = s2.f64()?.into_iter();

    // `partition_point(|v| op(&x, v))` counts the breaks below x (or <= x when
    // left-closed), which is exactly the index of x's bin in `labels`.
    let op = if left_closed {
        PartialOrd::ge
    } else {
        PartialOrd::gt
    };

    if include_breaks {
        // This is to replicate the behavior of the old buggy version that only worked on series and
        // returned a dataframe. That included a column of the right endpoint of the interval. So we
        // return a struct series instead which can be turned into a dataframe later.
        let right_ends = [sorted_breaks, &[f64::INFINITY]].concat();
        let mut bld = CategoricalChunkedBuilder::<Categorical32Type>::new(
            out_name.clone(),
            DataType::from_categories(Categories::global()),
        );
        let mut brk_vals = PrimitiveChunkedBuilder::<Float64Type>::new(
            PlSmallStr::from_static("breakpoint"),
            s.len(),
        );
        s_iter
            // NaN, like null, has no bin: filter it to None so it becomes a null row.
            .map(|opt| {
                opt.filter(|x| !x.is_nan())
                    .map(|x| sorted_breaks.partition_point(|v| op(&x, v)))
            })
            .for_each(|idx| match idx {
                None => {
                    // Keep both builders in lockstep so the struct fields stay aligned.
                    bld.append_null();
                    brk_vals.append_null();
                },
                // SAFETY: partition_point returns idx <= sorted_breaks.len(), and
                // `labels` and `right_ends` both have sorted_breaks.len() + 1 entries.
                Some(idx) => unsafe {
                    bld.append_str(labels.get_unchecked(idx)).unwrap();
                    brk_vals.append_value(*right_ends.get_unchecked(idx));
                },
            });

        let outvals = [brk_vals.finish().into_series(), bld.finish().into_series()];
        Ok(StructChunked::from_series(out_name, outvals[0].len(), outvals.iter())?.into_series())
    } else {
        Ok(CategoricalChunked::<Categorical32Type>::from_str_iter(
            out_name,
            DataType::from_categories(Categories::global()),
            s_iter.map(|opt| {
                opt.filter(|x| !x.is_nan()).map(|x| {
                    let pt = sorted_breaks.partition_point(|v| op(&x, v));
                    // SAFETY: pt <= sorted_breaks.len() < labels.len().
                    unsafe { labels.get_unchecked(pt).as_str() }
                })
            }),
        )?
        .into_series())
    }
}
70
71
pub fn compute_labels(breaks: &[f64], left_closed: bool) -> PolarsResult<Vec<PlSmallStr>> {
72
let lo = std::iter::once(&f64::NEG_INFINITY).chain(breaks.iter());
73
let hi = breaks.iter().chain(std::iter::once(&f64::INFINITY));
74
75
let ret = lo
76
.zip(hi)
77
.map(|(l, h)| {
78
if left_closed {
79
format_pl_smallstr!("[{}, {})", l, h)
80
} else {
81
format_pl_smallstr!("({}, {}]", l, h)
82
}
83
})
84
.collect();
85
Ok(ret)
86
}
87
88
pub fn cut(
89
s: &Series,
90
mut breaks: Vec<f64>,
91
labels: Option<Vec<PlSmallStr>>,
92
left_closed: bool,
93
include_breaks: bool,
94
) -> PolarsResult<Series> {
95
// Breaks must be sorted to cut inputs properly.
96
polars_ensure!(!breaks.iter().any(|x| x.is_nan()), ComputeError: "breaks cannot be NaN");
97
breaks.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
98
99
polars_ensure!(breaks.windows(2).all(|x| x[0] != x[1]), Duplicate: "breaks are not unique");
100
if !breaks.is_empty() {
101
polars_ensure!(breaks[0] > f64::NEG_INFINITY, ComputeError: "don't include -inf in breaks");
102
polars_ensure!(breaks[breaks.len() - 1] < f64::INFINITY, ComputeError: "don't include inf in breaks");
103
}
104
105
let cut_labels = if let Some(l) = labels {
106
polars_ensure!(l.len() == breaks.len() + 1, ShapeMismatch: "provide len(quantiles) + 1 labels");
107
l
108
} else {
109
compute_labels(&breaks, left_closed)?
110
};
111
map_cats(s, &cut_labels, &breaks, left_closed, include_breaks)
112
}
113
114
pub fn qcut(
115
s: &Series,
116
probs: Vec<f64>,
117
labels: Option<Vec<PlSmallStr>>,
118
left_closed: bool,
119
allow_duplicates: bool,
120
include_breaks: bool,
121
) -> PolarsResult<Series> {
122
polars_ensure!(!probs.iter().any(|x| x.is_nan()), ComputeError: "quantiles cannot be NaN");
123
124
if s.null_count() == s.len() {
125
// If we only have nulls we don't have any breakpoints.
126
return Ok(Series::full_null(
127
s.name().clone(),
128
s.len(),
129
&DataType::from_categories(Categories::global()),
130
));
131
}
132
133
let s = s.cast(&DataType::Float64)?;
134
let s2 = s.sort(SortOptions::default())?;
135
let ca = s2.f64()?;
136
137
let f = |&p| ca.quantile(p, QuantileMethod::Linear).unwrap().unwrap();
138
let mut qbreaks: Vec<_> = probs.iter().map(f).collect();
139
qbreaks.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
140
141
if !allow_duplicates {
142
polars_ensure!(qbreaks.windows(2).all(|x| x[0] != x[1]), Duplicate: "quantiles are not unique while allow_duplicates=False");
143
}
144
145
let cut_labels = if let Some(l) = labels {
146
polars_ensure!(l.len() == qbreaks.len() + 1, ShapeMismatch: "provide len(quantiles) + 1 labels");
147
l
148
} else {
149
compute_labels(&qbreaks, left_closed)?
150
};
151
152
map_cats(&s, &cut_labels, &qbreaks, left_closed, include_breaks)
153
}
154
155
// Fix: the module previously lacked `#[cfg(test)]`, so it (and its imports)
// were compiled into non-test builds of the crate.
#[cfg(test)]
mod test {
    // This needs metadata in fields
    #[ignore]
    #[test]
    fn test_map_cats_fast_unique() {
        // This test is here to check the fast unique flag is set when it can be
        // as it is not visible to Python.
        use polars_core::prelude::*;

        use super::map_cats;

        let s = Series::new("x".into(), &[1, 2, 3, 4, 5]);

        let labels = &["a", "b", "c"].map(PlSmallStr::from_static);
        let breaks = &[2.0, 4.0];
        let left_closed = false;

        // Plain categorical output.
        let include_breaks = false;
        let out = map_cats(&s, labels, breaks, left_closed, include_breaks).unwrap();
        out.cat32().unwrap();

        // Struct output: the category column is the second field.
        let include_breaks = true;
        let out = map_cats(&s, labels, breaks, left_closed, include_breaks).unwrap();
        let out = out.struct_().unwrap().fields_as_series()[1].clone();
        out.cat32().unwrap();
    }
}
182
183