GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/top_k.rs

use std::any::Any;
use std::collections::BinaryHeap;
use std::sync::Arc;

use polars_core::prelude::row_encode::_get_rows_encoded;
use polars_core::prelude::*;
use polars_core::schema::Schema;
use polars_core::utils::accumulate_dataframes_vertical;
use polars_core::with_match_physical_numeric_polars_type;
use polars_utils::IdxSize;
use polars_utils::priority::Priority;
use polars_utils::sort::ReorderWithNulls;
use polars_utils::total_ord::TotalOrdWrap;
use slotmap::{SecondaryMap, SlotMap, new_key_type};

use super::compute_node_prelude::*;
use crate::expression::StreamExpr;
use crate::nodes::in_memory_sink::InMemorySinkNode;
use crate::nodes::in_memory_source::InMemorySourceNode;

new_key_type! {
    struct DfsKey;
    struct RowIdxKey;
}

/// Represents a subset of a dataframe.
struct DfSubset {
    df: DataFrame,
    rows: Vec<RowIdxKey>,
    subset_len: usize,
}

impl DfSubset {
    /// Gather this subset into a contiguous DataFrame, updating the relevant row indices.
    pub fn gather(
        &mut self,
        row_idxs: &mut SlotMap<RowIdxKey, IdxSize>,
        gather_idx_buf: &mut Vec<IdxSize>,
    ) {
        if self.subset_len == self.df.height() {
            return;
        }

        gather_idx_buf.clear();
        let mut new_idx = 0;
        self.rows.retain(|row_idx_key| {
            let row_idx = &mut row_idxs[*row_idx_key];
            if *row_idx != IdxSize::MAX {
                gather_idx_buf.push(*row_idx);
                *row_idx = new_idx;
                new_idx += 1;
                true
            } else {
                row_idxs.remove(*row_idx_key);
                false
            }
        });

        unsafe { self.df = self.df.take_slice_unchecked(gather_idx_buf) }
    }
}

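// The compaction trick used by `gather` in isolation: indices that were
// overwritten with the `IdxSize::MAX` sentinel are dropped, the survivors are
// renumbered to 0..n in encounter order, and the collected old positions form
// the gather list for a contiguous copy. A minimal sketch on plain vectors
// (the helper name is made up, illustration only):
#[allow(dead_code)]
fn compact_sketch(row_idxs: &mut Vec<IdxSize>) -> Vec<IdxSize> {
    let mut gather = Vec::new();
    let mut new_idx = 0;
    row_idxs.retain_mut(|idx| {
        if *idx != IdxSize::MAX {
            gather.push(*idx); // old position to fetch from
            *idx = new_idx; // new position after compaction
            new_idx += 1;
            true
        } else {
            false
        }
    });
    gather
}
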
pub struct BottomKWithPayload<P> {
    k: usize,
    heap: BinaryHeap<Priority<P, (DfsKey, RowIdxKey)>>,
    df_subsets: SlotMap<DfsKey, DfSubset>,
    row_idxs: SlotMap<RowIdxKey, IdxSize>,
    to_prune: Vec<DfsKey>,
    gather_idxs: Vec<IdxSize>,
}

impl<P: Ord + Clone> BottomKWithPayload<P> {
    pub fn new(k: usize) -> Self {
        Self {
            k,
            heap: BinaryHeap::with_capacity(k + 1),
            df_subsets: SlotMap::with_key(),
            row_idxs: SlotMap::with_key(),
            to_prune: Vec::new(),
            gather_idxs: Vec::new(),
        }
    }

    pub fn add_df<Q>(
        &mut self,
        df: DataFrame,
        keys: impl IntoIterator<Item = Q>,
        is_less: impl Fn(&Q, &P) -> bool,
        to_owned: impl Fn(Q) -> P,
    ) {
        let dfs_key = self.df_subsets.insert(DfSubset {
            df,
            rows: Vec::new(),
            subset_len: 0,
        });

        for (row_idx, key) in keys.into_iter().enumerate() {
            self.add_one(
                dfs_key,
                row_idx.try_into().unwrap(),
                key,
                &is_less,
                &to_owned,
            )
        }
        self.prune();
    }

    fn add_one<Q>(
        &mut self,
        dfs_key: DfsKey,
        row_idx: IdxSize,
        key: Q,
        is_less: impl Fn(&Q, &P) -> bool,
        to_owned: impl Fn(Q) -> P,
    ) {
        // We use a max-heap for our bottom k. This means the top element in our heap (peek())
        // is the first to be replaced.
        if self.heap.len() < self.k || is_less(&key, &self.heap.peek().unwrap().0) {
            let row_idx_key = self.row_idxs.insert(row_idx);
            let df_subset = &mut self.df_subsets[dfs_key];
            df_subset.subset_len += 1;
            df_subset.rows.push(row_idx_key);
            self.heap
                .push(Priority(to_owned(key), (dfs_key, row_idx_key)));
        }

        if self.heap.len() > self.k {
            let (dfs_key, row_idx_key) = self.heap.pop().unwrap().1;
            self.row_idxs[row_idx_key] = IdxSize::MAX;
            let df_subset = &mut self.df_subsets[dfs_key];
            df_subset.subset_len -= 1;
            // Once only half the rows of this dataframe are still referenced,
            // schedule it for pruning so the evicted rows can be reclaimed.
            if df_subset.subset_len == df_subset.df.height() / 2 {
                self.to_prune.push(dfs_key);
            }
        }
    }

    pub fn prune(&mut self) {
        for dfs_key in self.to_prune.drain(..) {
            if self.df_subsets[dfs_key].subset_len == 0 {
                let df_subset = self.df_subsets.remove(dfs_key).unwrap();
                for row_idx in df_subset.rows {
                    self.row_idxs.remove(row_idx);
                }
            } else {
                self.df_subsets[dfs_key].gather(&mut self.row_idxs, &mut self.gather_idxs);
            }
        }
    }

    pub fn combine(&mut self, other: &BottomKWithPayload<P>) {
        let mut new_df_keys =
            SecondaryMap::<DfsKey, DfsKey>::with_capacity(other.df_subsets.capacity());
        for (dfs_key, dfs) in &other.df_subsets {
            if dfs.subset_len > 0 {
                let subset = DfSubset {
                    df: dfs.df.clone(),
                    rows: Vec::new(),
                    subset_len: 0,
                };
                new_df_keys.insert(dfs_key, self.df_subsets.insert(subset));
            }
        }
        for prio in &other.heap {
            let (dfs_key, row_idx_key) = prio.1;
            self.add_one(
                new_df_keys[dfs_key],
                other.row_idxs[row_idx_key],
                prio.0.clone(),
                |l, r| l < r,
                |x| x,
            );
        }
        self.prune();
    }

    pub fn finalize(&mut self) -> Option<DataFrame> {
        let mut gather_idx_buf = Vec::new();
        if self.df_subsets.is_empty() {
            return None;
        }
        let ret = accumulate_dataframes_vertical(self.df_subsets.drain().map(|(_k, mut df)| {
            df.gather(&mut self.row_idxs, &mut gather_idx_buf);
            df.df
        }));
        self.heap.clear();
        self.row_idxs.clear();
        self.to_prune.clear();
        Some(ret.unwrap())
    }
}

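// The heap discipline in `add_one` above, in isolation: a max-heap capped at k
// elements exposes the largest retained key at `peek()`, a candidate is pushed
// only if the heap is not yet full or the candidate beats that worst key, and
// a push that overflows is resolved by popping the worst element. A minimal
// sketch with plain integers (the helper name is made up, illustration only):
#[allow(dead_code)]
fn bottom_k_sketch(values: impl IntoIterator<Item = u64>, k: usize) -> Vec<u64> {
    if k == 0 {
        return Vec::new();
    }
    let mut heap = BinaryHeap::with_capacity(k + 1);
    for v in values {
        if heap.len() < k || v < *heap.peek().unwrap() {
            heap.push(v);
        }
        if heap.len() > k {
            heap.pop(); // evict the largest of the k + 1 candidates
        }
    }
    heap.into_sorted_vec() // the k smallest values, ascending
}
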
/// A reducer that keeps a subset of the rows it is fed, selected by a separate
/// key dataframe. Per-pipeline partial results are built with `add`, merged
/// with `combine`, and turned into a single output dataframe by `finalize`.
trait DfByKeyReducer: Any + Send + 'static {
    fn new_empty(&self) -> Box<dyn DfByKeyReducer>;
    fn add(&mut self, df: DataFrame, keys: DataFrame);
    fn combine(&mut self, other: &dyn DfByKeyReducer);
    fn finalize(self: Box<Self>) -> Option<DataFrame>;
}

struct PrimitiveBottomK<T: PolarsNumericType, const REVERSE: bool, const NULLS_LAST: bool> {
    inner: BottomKWithPayload<
        ReorderWithNulls<TotalOrdWrap<T::Physical<'static>>, REVERSE, NULLS_LAST>,
    >,
}

impl<T: PolarsNumericType, const REVERSE: bool, const NULLS_LAST: bool>
    PrimitiveBottomK<T, REVERSE, NULLS_LAST>
{
    fn new(k: usize) -> Self {
        Self {
            inner: BottomKWithPayload::new(k),
        }
    }
}

impl<T: PolarsNumericType, const REVERSE: bool, const NULLS_LAST: bool> DfByKeyReducer
    for PrimitiveBottomK<T, REVERSE, NULLS_LAST>
{
    fn new_empty(&self) -> Box<dyn DfByKeyReducer> {
        Box::new(Self {
            inner: BottomKWithPayload::new(self.inner.k),
        })
    }

    fn add(&mut self, df: DataFrame, keys: DataFrame) {
        assert!(keys.width() == 1);
        let keys = keys.get_columns()[0].as_materialized_series();
        let key_ca: &ChunkedArray<T> = keys.as_phys_any().downcast_ref().unwrap();
        self.inner.add_df(
            df,
            key_ca
                .iter()
                .map(|opt_x| ReorderWithNulls(opt_x.map(TotalOrdWrap))),
            |l, r| l < r,
            |x| x,
        );
    }

    fn combine(&mut self, other: &dyn DfByKeyReducer) {
        let other: &Self = (other as &dyn Any).downcast_ref().unwrap();
        self.inner.combine(&other.inner);
    }

    fn finalize(mut self: Box<Self>) -> Option<DataFrame> {
        self.inner.finalize()
    }
}

struct BinaryBottomK<const REVERSE: bool, const NULLS_LAST: bool> {
    inner: BottomKWithPayload<ReorderWithNulls<Vec<u8>, REVERSE, NULLS_LAST>>,
}

impl<const REVERSE: bool, const NULLS_LAST: bool> BinaryBottomK<REVERSE, NULLS_LAST> {
    fn new(k: usize) -> Self {
        Self {
            inner: BottomKWithPayload::new(k),
        }
    }
}

impl<const REVERSE: bool, const NULLS_LAST: bool> DfByKeyReducer
    for BinaryBottomK<REVERSE, NULLS_LAST>
{
    fn new_empty(&self) -> Box<dyn DfByKeyReducer> {
        Box::new(Self {
            inner: BottomKWithPayload::new(self.inner.k),
        })
    }

    fn add(&mut self, df: DataFrame, keys: DataFrame) {
        assert!(keys.width() == 1);
        let key_ca = if let Ok(ca_str) = keys[0].str() {
            ca_str.as_binary()
        } else {
            keys[0].binary().unwrap().clone()
        };
        self.inner.add_df(
            df,
            key_ca
                .iter()
                .map(ReorderWithNulls::<_, REVERSE, NULLS_LAST>),
            |l, r| l < &ReorderWithNulls(r.0.as_deref()),
            |x| ReorderWithNulls(x.0.map(<[u8]>::to_vec)),
        );
    }

    fn combine(&mut self, other: &dyn DfByKeyReducer) {
        let other: &Self = (other as &dyn Any).downcast_ref().unwrap();
        self.inner.combine(&other.inner);
    }

    fn finalize(mut self: Box<Self>) -> Option<DataFrame> {
        self.inner.finalize()
    }
}

struct RowEncodedBottomK {
    inner: BottomKWithPayload<Vec<u8>>,
    reverse: Vec<bool>,
    nulls_last: Vec<bool>,
}

impl RowEncodedBottomK {
    fn new(k: usize, reverse: Vec<bool>, nulls_last: Vec<bool>) -> Self {
        Self {
            inner: BottomKWithPayload::new(k),
            reverse,
            nulls_last,
        }
    }
}

impl DfByKeyReducer for RowEncodedBottomK {
    fn new_empty(&self) -> Box<dyn DfByKeyReducer> {
        Box::new(Self {
            inner: BottomKWithPayload::new(self.inner.k),
            reverse: self.reverse.clone(),
            nulls_last: self.nulls_last.clone(),
        })
    }

    fn add(&mut self, df: DataFrame, keys: DataFrame) {
        let keys_encoded = _get_rows_encoded(keys.get_columns(), &self.reverse, &self.nulls_last)
            .unwrap()
            .into_array();
        self.inner.add_df(
            df,
            keys_encoded.values_iter(),
            |l, r| *l < r.as_slice(),
            |x| x.to_vec(),
        );
    }

    fn combine(&mut self, other: &dyn DfByKeyReducer) {
        let other: &Self = (other as &dyn Any).downcast_ref().unwrap();
        self.inner.combine(&other.inner);
    }

    fn finalize(mut self: Box<Self>) -> Option<DataFrame> {
        self.inner.finalize()
    }
}

fn new_top_k_reducer(
    k: usize,
    reverse: &[bool],
    nulls_last: &[bool],
    key_schema: &Schema,
) -> Box<dyn DfByKeyReducer> {
    if key_schema.len() == 1 {
        let (_name, dt) = key_schema.get_at_index(0).unwrap();
        match dt {
            dt if dt.is_primitive_numeric()
                || dt.is_temporal()
                || dt.is_decimal()
                || dt.is_enum() =>
            {
                return with_match_physical_numeric_polars_type!(dt.to_physical(), |$T| {
                    match (reverse[0], nulls_last[0]) {
                        (false, false) => Box::new(PrimitiveBottomK::<$T, true, false>::new(k)),
                        (false, true) => Box::new(PrimitiveBottomK::<$T, true, true>::new(k)),
                        (true, false) => Box::new(PrimitiveBottomK::<$T, false, false>::new(k)),
                        (true, true) => Box::new(PrimitiveBottomK::<$T, false, true>::new(k)),
                    }
                });
            },

            DataType::String | DataType::Binary => {
                return match (reverse[0], nulls_last[0]) {
                    (false, false) => Box::new(BinaryBottomK::<true, false>::new(k)),
                    (false, true) => Box::new(BinaryBottomK::<true, true>::new(k)),
                    (true, false) => Box::new(BinaryBottomK::<false, false>::new(k)),
                    (true, true) => Box::new(BinaryBottomK::<false, true>::new(k)),
                };
            },

            // TODO: categorical single-key.
            _ => {},
        }
    }

    // The generic fallback row-encodes the keys in the opposite order, so that
    // taking the bottom k of the encodings yields the top k of the requested order.
    let reverse = reverse.iter().map(|r| !r).collect();
    Box::new(RowEncodedBottomK::new(k, reverse, nulls_last.to_vec()))
}

enum TopKState {
    WaitingForK(InMemorySinkNode),

    Sink {
        key_selectors: Vec<StreamExpr>,
        reducers: Vec<Box<dyn DfByKeyReducer>>,
    },

    Source(InMemorySourceNode),

    Done,
}

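// The node moves through these states, as driven by `update_state` below:
//
//   WaitingForK --(k received, k > 0)--> Sink
//   WaitingForK --(k == 0)-------------> Done
//   Sink --(data input done)-----------> Source (or Done if the result is empty)
//   any state --(downstream is done)---> Done
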
pub struct TopKNode {
    reverse: Vec<bool>,
    nulls_last: Vec<bool>,
    key_schema: Arc<Schema>,
    key_selectors: Vec<StreamExpr>,
    state: TopKState,
}

impl TopKNode {
    pub fn new(
        k_schema: Arc<Schema>,
        reverse: Vec<bool>,
        nulls_last: Vec<bool>,
        key_schema: Arc<Schema>,
        key_selectors: Vec<StreamExpr>,
    ) -> Self {
        Self {
            reverse,
            nulls_last,
            key_schema,
            key_selectors,
            state: TopKState::WaitingForK(InMemorySinkNode::new(k_schema)),
        }
    }
}

impl ComputeNode for TopKNode {
    fn name(&self) -> &str {
        if self.reverse.iter().all(|r| *r) {
            "bottom-k"
        } else {
            "top-k"
        }
    }

    fn update_state(
        &mut self,
        recv: &mut [PortState],
        send: &mut [PortState],
        state: &StreamingExecutionState,
    ) -> PolarsResult<()> {
        assert!(recv.len() == 2 && send.len() == 1);

        // State transitions.
        match &mut self.state {
            // If the output doesn't want any more data, transition to being done.
            _ if send[0] == PortState::Done => {
                self.state = TopKState::Done;
            },
            // We've received k, transition to being a sink.
            TopKState::WaitingForK(inner) if recv[1] == PortState::Done => {
                let k_frame = inner.get_output()?.unwrap();
                polars_ensure!(k_frame.height() == 1, ComputeError: "got more than one value for 'k' in top_k");
                let k_item = k_frame.get_columns()[0].get(0)?;
                let k = k_item.extract::<usize>().ok_or_else(
                    || polars_err!(ComputeError: "invalid value of 'k' in top_k: {:?}", k_item),
                )?;

                if k > 0 {
                    let reducer =
                        new_top_k_reducer(k, &self.reverse, &self.nulls_last, &self.key_schema);
                    let reducers = (0..state.num_pipelines)
                        .map(|_| reducer.new_empty())
                        .collect();
                    self.state = TopKState::Sink {
                        key_selectors: core::mem::take(&mut self.key_selectors),
                        reducers,
                    };
                } else {
                    self.state = TopKState::Done;
                }
            },
            // Input is done, transition to being a source.
            TopKState::Sink { reducers, .. } if recv[0] == PortState::Done => {
                let mut reducer = reducers.pop().unwrap();
                for r in reducers {
                    reducer.combine(&**r);
                }
                if let Some(df) = reducer.finalize() {
                    self.state = TopKState::Source(InMemorySourceNode::new(
                        Arc::new(df),
                        MorselSeq::default(),
                    ));
                } else {
                    self.state = TopKState::Done;
                }
            },
            // Nothing to change.
            _ => {},
        }

        // Communicate our state.
        match &mut self.state {
            TopKState::WaitingForK(inner) => {
                send[0] = PortState::Blocked;
                recv[0] = PortState::Blocked;
                inner.update_state(&mut recv[1..2], &mut [], state)?;
            },
            TopKState::Sink { .. } => {
                send[0] = PortState::Blocked;
                recv[0] = PortState::Ready;
                recv[1] = PortState::Done;
            },
            TopKState::Source(src) => {
                src.update_state(&mut [], send, state)?;
                recv[0] = PortState::Done;
                recv[1] = PortState::Done;
            },
            TopKState::Done => {
                recv[0] = PortState::Done;
                recv[1] = PortState::Done;
                send[0] = PortState::Done;
            },
        }
        Ok(())
    }

    fn spawn<'env, 's>(
        &'env mut self,
        scope: &'s TaskScope<'s, 'env>,
        recv_ports: &mut [Option<RecvPort<'_>>],
        send_ports: &mut [Option<SendPort<'_>>],
        state: &'s StreamingExecutionState,
        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
    ) {
        assert!(recv_ports.len() == 2 && send_ports.len() == 1);
        match &mut self.state {
            TopKState::WaitingForK(inner) => {
                assert!(send_ports[0].is_none());
                assert!(recv_ports[0].is_none());
                inner.spawn(scope, &mut recv_ports[1..2], &mut [], state, join_handles);
            },
            TopKState::Sink {
                key_selectors,
                reducers,
            } => {
                assert!(send_ports[0].is_none());
                assert!(recv_ports[1].is_none());
                let receivers = recv_ports[0].take().unwrap().parallel();

                for (mut recv, reducer) in receivers.into_iter().zip(reducers) {
                    let key_selectors = &*key_selectors;
                    join_handles.push(scope.spawn_task(TaskPriority::High, async move {
                        while let Ok(morsel) = recv.recv().await {
                            let df = morsel.into_df();
                            let mut key_columns = Vec::new();
                            for selector in key_selectors {
                                let s = selector.evaluate(&df, &state.in_memory_exec_state).await?;
                                key_columns.push(s.into_column());
                            }
                            let keys = DataFrame::new_with_broadcast_len(key_columns, df.height())?;

                            reducer.add(df, keys);
                        }

                        Ok(())
                    }));
                }
            },

            TopKState::Source(src) => {
                assert!(recv_ports[0].is_none());
                assert!(recv_ports[1].is_none());
                src.spawn(scope, &mut [], send_ports, state, join_handles);
            },

            TopKState::Done => unreachable!(),
        }
    }
}
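
// A minimal sketch of how `BottomKWithPayload` behaves on its own, assuming
// only what is defined in this file plus the `polars_core::df!` macro. The
// column name and key values are made up for illustration: with k = 2, only
// the rows belonging to the two smallest keys survive `finalize`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bottom_k_with_payload_keeps_k_rows() -> PolarsResult<()> {
        let df = polars_core::df!["x" => &[10i32, 30, 20, 40]]?;
        let keys = vec![10i32, 30, 20, 40];

        let mut bk = BottomKWithPayload::<i32>::new(2);
        bk.add_df(df, keys, |l, r| l < r, |x| x);

        let out = bk.finalize().unwrap();
        assert_eq!(out.height(), 2);
        Ok(())
    }
}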