Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/chunked_array/top_k.rs
6939 views
1
use arrow::array::{BinaryViewArray, BooleanArray, PrimitiveArray, StaticArray, View};
2
use arrow::bitmap::{Bitmap, BitmapBuilder};
3
use polars_core::chunked_array::ops::sort::arg_bottom_k::_arg_bottom_k;
4
use polars_core::prelude::*;
5
use polars_core::series::IsSorted;
6
use polars_core::{POOL, downcast_as_macro_arg_physical};
7
use polars_utils::total_ord::TotalOrd;
8
9
fn first_n_valid_mask(num_valid: usize, out_len: usize) -> Option<Bitmap> {
10
if num_valid < out_len {
11
let mut bm = BitmapBuilder::with_capacity(out_len);
12
bm.extend_constant(num_valid, true);
13
bm.extend_constant(out_len - num_valid, false);
14
Some(bm.freeze())
15
} else {
16
None
17
}
18
}
19
20
fn top_k_bool_impl(
21
ca: &ChunkedArray<BooleanType>,
22
k: usize,
23
descending: bool,
24
) -> ChunkedArray<BooleanType> {
25
if k >= ca.len() && ca.null_count() == 0 {
26
return ca.clone();
27
}
28
29
let null_count = ca.null_count();
30
let non_null_count = ca.len() - ca.null_count();
31
let true_count = ca.sum().unwrap() as usize;
32
let false_count = non_null_count - true_count;
33
let mut out_len = k.min(ca.len());
34
let validity = first_n_valid_mask(non_null_count, out_len);
35
36
// Logical sequence of physical bits.
37
let sequence = if descending {
38
[
39
(false_count, false),
40
(true_count, true),
41
(null_count, false),
42
]
43
} else {
44
[
45
(true_count, true),
46
(false_count, false),
47
(null_count, false),
48
]
49
};
50
51
let mut bm = BitmapBuilder::with_capacity(out_len);
52
for (n, value) in sequence {
53
if out_len == 0 {
54
break;
55
}
56
let extra = out_len.min(n);
57
bm.extend_constant(extra, value);
58
out_len -= extra;
59
}
60
61
let arr = BooleanArray::from_data_default(bm.freeze(), validity);
62
ChunkedArray::with_chunk_like(ca, arr)
63
}
64
65
fn top_k_num_impl<T>(ca: &ChunkedArray<T>, k: usize, descending: bool) -> ChunkedArray<T>
66
where
67
T: PolarsNumericType,
68
{
69
if k >= ca.len() && ca.null_count() == 0 {
70
return ca.clone();
71
}
72
73
// Get rid of all the nulls and transform into Vec<T::Native>.
74
let mut nnca = ca.drop_nulls();
75
nnca.rechunk_mut();
76
let chunk = nnca.downcast_into_iter().next().unwrap();
77
let (_, buffer, _) = chunk.into_inner();
78
let mut vec = buffer.make_mut();
79
80
// Partition.
81
if k < vec.len() {
82
if descending {
83
vec.select_nth_unstable_by(k, TotalOrd::tot_cmp);
84
} else {
85
vec.select_nth_unstable_by(k, |a, b| TotalOrd::tot_cmp(b, a));
86
}
87
}
88
89
// Reconstruct output (with nulls at the end).
90
let out_len = k.min(ca.len());
91
let non_null_count = ca.len() - ca.null_count();
92
vec.resize(out_len, T::Native::default());
93
let validity = first_n_valid_mask(non_null_count, out_len);
94
95
let arr = PrimitiveArray::from_vec(vec).with_validity_typed(validity);
96
ChunkedArray::with_chunk_like(ca, arr)
97
}
98
99
fn top_k_binary_impl(
100
ca: &ChunkedArray<BinaryType>,
101
k: usize,
102
descending: bool,
103
) -> ChunkedArray<BinaryType> {
104
if k >= ca.len() && ca.null_count() == 0 {
105
return ca.clone();
106
}
107
108
// Get rid of all the nulls and transform into mutable views.
109
let mut nnca = ca.drop_nulls();
110
nnca.rechunk_mut();
111
let chunk = nnca.downcast_into_iter().next().unwrap();
112
let buffers = chunk.data_buffers().clone();
113
let mut views = chunk.into_views();
114
115
// Partition.
116
if k < views.len() {
117
if descending {
118
views.select_nth_unstable_by(k, |a, b| unsafe {
119
let a_sl = a.get_slice_unchecked(&buffers);
120
let b_sl = b.get_slice_unchecked(&buffers);
121
a_sl.cmp(b_sl)
122
});
123
} else {
124
views.select_nth_unstable_by(k, |a, b| unsafe {
125
let a_sl = a.get_slice_unchecked(&buffers);
126
let b_sl = b.get_slice_unchecked(&buffers);
127
b_sl.cmp(a_sl)
128
});
129
}
130
}
131
132
// Reconstruct output (with nulls at the end).
133
let out_len = k.min(ca.len());
134
let non_null_count = ca.len() - ca.null_count();
135
views.resize(out_len, View::default());
136
let validity = first_n_valid_mask(non_null_count, out_len);
137
138
let arr = unsafe {
139
BinaryViewArray::new_unchecked_unknown_md(
140
ArrowDataType::BinaryView,
141
views.into(),
142
buffers,
143
validity,
144
None,
145
)
146
};
147
ChunkedArray::with_chunk_like(ca, arr)
148
}
149
150
pub fn top_k(s: &[Column], descending: bool) -> PolarsResult<Column> {
151
fn extract_target_and_k(s: &[Column]) -> PolarsResult<(usize, &Column)> {
152
let k_s = &s[1];
153
polars_ensure!(
154
k_s.len() == 1,
155
ComputeError: "`k` must be a single value for `top_k`."
156
);
157
158
let Some(k) = k_s.cast(&IDX_DTYPE)?.idx()?.get(0) else {
159
polars_bail!(ComputeError: "`k` must be set for `top_k`")
160
};
161
162
let src = &s[0];
163
Ok((k as usize, src))
164
}
165
166
let (k, src) = extract_target_and_k(s)?;
167
168
if src.is_empty() {
169
return Ok(src.clone());
170
}
171
172
let sorted_flag = src.is_sorted_flag();
173
let is_sorted = match src.is_sorted_flag() {
174
IsSorted::Ascending => true,
175
IsSorted::Descending => true,
176
IsSorted::Not => false,
177
};
178
if is_sorted {
179
let out_len = k.min(src.len());
180
let ignored_len = src.len() - out_len;
181
let slice_at_start = (sorted_flag == IsSorted::Ascending) == descending;
182
let nulls_at_start = src.get(0).unwrap() == AnyValue::Null;
183
let offset = if nulls_at_start == slice_at_start {
184
src.null_count().min(ignored_len)
185
} else {
186
0
187
};
188
189
return if slice_at_start {
190
Ok(src.slice(offset as i64, out_len))
191
} else {
192
Ok(src.slice(-(offset as i64) - (out_len as i64), out_len))
193
};
194
}
195
196
let origin_dtype = src.dtype();
197
198
let s = src.to_physical_repr();
199
200
match s.dtype() {
201
DataType::Boolean => Ok(top_k_bool_impl(s.bool().unwrap(), k, descending).into_column()),
202
DataType::String => {
203
let ca = top_k_binary_impl(&s.str().unwrap().as_binary(), k, descending);
204
let ca = unsafe { ca.to_string_unchecked() };
205
Ok(ca.into_column())
206
},
207
DataType::Binary => Ok(top_k_binary_impl(s.binary().unwrap(), k, descending).into_column()),
208
DataType::Null => Ok(src.slice(0, k)),
209
dt if dt.is_primitive_numeric() => {
210
macro_rules! dispatch {
211
($ca:expr) => {{ top_k_num_impl($ca, k, descending).into_column() }};
212
}
213
unsafe {
214
downcast_as_macro_arg_physical!(&s, dispatch).from_physical_unchecked(origin_dtype)
215
}
216
},
217
_ => {
218
// Fallback to more generic impl.
219
top_k_by_impl(k, src, std::slice::from_ref(src), vec![descending])
220
},
221
}
222
}
223
224
pub fn top_k_by(s: &[Column], descending: Vec<bool>) -> PolarsResult<Column> {
225
/// Return (k, src, by)
226
fn extract_parameters(s: &[Column]) -> PolarsResult<(usize, &Column, &[Column])> {
227
let k_s = &s[1];
228
229
polars_ensure!(
230
k_s.len() == 1,
231
ComputeError: "`k` must be a single value for `top_k`."
232
);
233
234
let Some(k) = k_s.cast(&IDX_DTYPE)?.idx()?.get(0) else {
235
polars_bail!(ComputeError: "`k` must be set for `top_k`")
236
};
237
238
let src = &s[0];
239
240
let by = &s[2..];
241
242
Ok((k as usize, src, by))
243
}
244
245
let (k, src, by) = extract_parameters(s)?;
246
247
if src.is_empty() {
248
return Ok(src.clone());
249
}
250
251
if by.first().map(|x| x.is_empty()).unwrap_or(false) {
252
return Ok(src.clone());
253
}
254
255
for s in by {
256
if s.len() != src.len() {
257
polars_bail!(ComputeError: "`by` column's ({}) length ({}) should have the same length as the source column length ({}) in `top_k`", s.name(), s.len(), src.len())
258
}
259
}
260
261
top_k_by_impl(k, src, by, descending)
262
}
263
264
fn top_k_by_impl(
265
k: usize,
266
src: &Column,
267
by: &[Column],
268
descending: Vec<bool>,
269
) -> PolarsResult<Column> {
270
if src.is_empty() {
271
return Ok(src.clone());
272
}
273
274
let multithreaded = k >= 10000 && POOL.current_num_threads() > 1;
275
let mut sort_options = SortMultipleOptions {
276
descending: descending.into_iter().map(|x| !x).collect(),
277
nulls_last: vec![true; by.len()],
278
multithreaded,
279
maintain_order: false,
280
limit: None,
281
};
282
283
let idx = _arg_bottom_k(k, by, &mut sort_options)?;
284
285
let result = unsafe {
286
src.as_materialized_series()
287
.take_unchecked(&idx.into_inner())
288
};
289
Ok(result.into())
290
}
291
292