GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/io_sinks/partition/by_key.rs

use std::cmp::Reverse;
use std::pin::Pin;
use std::sync::{Arc, OnceLock};

use futures::StreamExt;
use futures::stream::FuturesUnordered;
use polars_core::config;
use polars_core::frame::DataFrame;
use polars_core::prelude::{Column, PlHashSet, PlIndexMap, row_encode};
use polars_core::schema::SchemaRef;
use polars_core::utils::arrow::buffer::Buffer;
use polars_error::PolarsResult;
use polars_plan::dsl::{PartitionTargetCallback, SinkFinishCallback, SinkOptions};
use polars_utils::pl_str::PlSmallStr;
use polars_utils::plpath::PlPath;
use polars_utils::priority::Priority;

use super::{CreateNewSinkFn, PerPartitionSortBy};
use crate::async_executor::{AbortOnDropHandle, spawn};
use crate::async_primitives::connector::connector;
use crate::execute::StreamingExecutionState;
use crate::morsel::SourceToken;
use crate::nodes::io_sinks::metrics::WriteMetrics;
use crate::nodes::io_sinks::partition::{SinkSender, open_new_sink};
use crate::nodes::io_sinks::phase::PhaseOutcome;
use crate::nodes::io_sinks::{SinkInputPort, SinkNode, parallelize_receive_task};
use crate::nodes::{JoinHandle, Morsel, MorselSeq, TaskPriority};

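// Payload passed from the per-pipeline partitioning tasks to the IO task: for every partition in
// a morsel, the row-encoded key (used to look up the partition's sink), the single-row key
// columns, and the partition's data.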
type Linearized =
    Priority<Reverse<MorselSeq>, (SourceToken, Vec<(Buffer<u8>, Vec<Column>, DataFrame)>)>;

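// Sink node that splits the stream on the values of `key_cols` and writes every unique key
// combination to its own sink. Sinks are created lazily as new keys arrive; once
// `max_open_partitions` sinks are open, further partitions are buffered in memory and written
// after all input has been received.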
pub struct PartitionByKeySinkNode {
    input_schema: SchemaRef,
    // This may not be the same as the input_schema, e.g. when include_key=false this will not
    // include the key columns.
    sink_input_schema: SchemaRef,

    key_cols: Arc<[PlSmallStr]>,

    max_open_partitions: usize,
    include_key: bool,

    base_path: Arc<PlPath>,
    file_path_cb: Option<PartitionTargetCallback>,
    create_new: CreateNewSinkFn,
    ext: PlSmallStr,

    sink_options: SinkOptions,

    per_partition_sort_by: Option<PerPartitionSortBy>,
    written_partitions: Arc<OnceLock<DataFrame>>,
    finish_callback: Option<SinkFinishCallback>,
}

impl PartitionByKeySinkNode {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        input_schema: SchemaRef,
        key_cols: Arc<[PlSmallStr]>,
        base_path: Arc<PlPath>,
        file_path_cb: Option<PartitionTargetCallback>,
        create_new: CreateNewSinkFn,
        ext: PlSmallStr,
        sink_options: SinkOptions,
        include_key: bool,
        per_partition_sort_by: Option<PerPartitionSortBy>,
        finish_callback: Option<SinkFinishCallback>,
    ) -> Self {
        assert!(!key_cols.is_empty());

        let mut sink_input_schema = input_schema.clone();
        if !include_key {
            let keys_col_hm = PlHashSet::from_iter(key_cols.iter().map(|s| s.as_str()));
            sink_input_schema = Arc::new(
                sink_input_schema
                    .try_project(
                        input_schema
                            .iter_names()
                            .filter(|n| !keys_col_hm.contains(n.as_str()))
                            .cloned(),
                    )
                    .unwrap(),
            );
        }

        const DEFAULT_MAX_OPEN_PARTITIONS: usize = 128;
        let max_open_partitions =
            std::env::var("POLARS_MAX_OPEN_PARTITIONS").map_or(DEFAULT_MAX_OPEN_PARTITIONS, |v| {
                v.parse::<usize>()
                    .expect("unable to parse POLARS_MAX_OPEN_PARTITIONS")
            });

        Self {
            input_schema,
            sink_input_schema,
            key_cols,
            max_open_partitions,
            include_key,
            base_path,
            file_path_cb,
            create_new,
            ext,
            sink_options,
            per_partition_sort_by,
            written_partitions: Arc::new(OnceLock::new()),
            finish_callback,
        }
    }
}

impl SinkNode for PartitionByKeySinkNode {
    fn name(&self) -> &str {
        "partition-by-key-sink"
    }

    fn is_sink_input_parallel(&self) -> bool {
        true
    }

    fn do_maintain_order(&self) -> bool {
        self.sink_options.maintain_order
    }

    fn initialize(&mut self, _state: &StreamingExecutionState) -> PolarsResult<()> {
        Ok(())
    }

    fn spawn_sink(
        &mut self,
        recv_port_rx: crate::async_primitives::connector::Receiver<(PhaseOutcome, SinkInputPort)>,
        state: &StreamingExecutionState,
        join_handles: &mut Vec<JoinHandle<polars_error::PolarsResult<()>>>,
    ) {
        let (io_tx, mut io_rx) = connector();
        let pass_rxs = parallelize_receive_task::<Linearized>(
            join_handles,
            recv_port_rx,
            state.num_pipelines,
            self.sink_options.maintain_order,
            io_tx,
        );

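        // One partitioning task per pipeline: split every morsel by the key columns, row-encode
        // the keys of each partition, and push the result into the linearizer so the IO task
        // below receives the partitions in morsel order.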
        join_handles.extend(pass_rxs.into_iter().map(|mut pass_rx| {
            let key_cols = self.key_cols.clone();
            let stable = self.sink_options.maintain_order;
            let include_key = self.include_key;

            spawn(TaskPriority::High, async move {
                while let Ok((mut rx, mut lin_tx)) = pass_rx.recv().await {
                    while let Ok(morsel) = rx.recv().await {
                        let (df, seq, source_token, consume_token) = morsel.into_inner();

                        // We need the keys to send to the appropriate sink.
                        let partition_include_key = true;
                        // We handle parallel processing in the streaming engine.
                        let parallel = false;
                        let partitions = df._partition_by_impl(
                            &key_cols,
                            stable,
                            partition_include_key,
                            parallel,
                        )?;

                        let partitions = partitions
                            .into_iter()
                            .map(|mut df| {
                                let keys = df.select_columns(key_cols.iter().cloned())?;
                                let keys = keys
                                    .into_iter()
                                    .map(|c| c.head(Some(1)))
                                    .collect::<Vec<_>>();

                                let row_encoded = row_encode::encode_rows_unordered(&keys)?
                                    .downcast_into_iter()
                                    .next()
                                    .unwrap();
                                let row_encoded = row_encoded.into_inner().2;

                                if !include_key {
                                    df = df.drop_many(key_cols.iter().cloned());
                                }

                                PolarsResult::Ok((row_encoded, keys, df))
                            })
                            .collect::<PolarsResult<Vec<(Buffer<u8>, Vec<Column>, DataFrame)>>>()?;

                        if lin_tx
                            .insert(Priority(Reverse(seq), (source_token, partitions)))
                            .await
                            .is_err()
                        {
                            return Ok(());
                        }
                        // It is important that we don't pass the consume
                        // token to the sinks, because that leads to
                        // deadlocks.
                        drop(consume_token);
                    }
                }

                Ok(())
            })
        }));

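        // The single IO task: pull the linearized partitions and route each one to the sink for
        // its key, opening new sinks on demand and falling back to in-memory buffering once
        // `max_open_partitions` sinks are open.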
        let state = state.clone();
        let input_schema = self.input_schema.clone();
        let key_cols = self.key_cols.clone();
        let sink_input_schema = self.sink_input_schema.clone();
        let max_open_partitions = self.max_open_partitions;
        let base_path = self.base_path.clone();
        let file_path_cb = self.file_path_cb.clone();
        let create_new_sink = self.create_new.clone();
        let ext = self.ext.clone();
        let per_partition_sort_by = self.per_partition_sort_by.clone();
        let output_written_partitions = self.written_partitions.clone();
        join_handles.push(spawn(TaskPriority::High, async move {
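            // An open partition is either a live sink with its writer tasks, or an in-memory
            // buffer used once the open-file limit has been reached.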
            enum OpenPartition {
                Sink {
                    sender: SinkSender,
                    join_handles: FuturesUnordered<AbortOnDropHandle<PolarsResult<()>>>,
                    node: Box<dyn SinkNode + Send>,
                    keys: Vec<Column>,
                },
                Buffer {
                    buffered: Vec<DataFrame>,
                    keys: Vec<Column>,
                },
            }

            let verbose = config::verbose();
            let mut file_idx = 0;
            let mut open_partitions: PlIndexMap<Buffer<u8>, OpenPartition> = PlIndexMap::default();

            // Wrap this in a closure so that a send failure (which signals that a downstream task
            // failed) can be caught while waiting for tasks.
            let mut receive_and_pass = async || {
                while let Ok(mut lin_rx) = io_rx.recv().await {
                    while let Some(Priority(Reverse(seq), (source_token, partitions))) =
                        lin_rx.get().await
                    {
                        for (row_encoded, keys, partition) in partitions {
                            let num_open_partitions = open_partitions.len();
                            let open_partition = match open_partitions.get_mut(&row_encoded) {
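                                // New key, but the open-file limit has been reached: buffer this
                                // partition in memory and write it after all input is consumed.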
                                None if num_open_partitions >= max_open_partitions => {
                                    if num_open_partitions == max_open_partitions && verbose {
                                        eprintln!(
                                            "[partition[by-key]]: Reached maximum open partitions. Buffering the rest to memory before writing.",
                                        );
                                    }

                                    let (idx, previous) = open_partitions.insert_full(
                                        row_encoded,
                                        OpenPartition::Buffer { buffered: Vec::new(), keys },
                                    );
                                    debug_assert!(previous.is_none());
                                    open_partitions.get_index_mut(idx).unwrap().1
                                },
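                                // New key and still below the limit: open a dedicated sink for it.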
                                None => {
                                    let result = open_new_sink(
                                        base_path.as_ref().as_ref(),
                                        file_path_cb.as_ref(),
                                        super::default_by_key_file_path_cb,
                                        file_idx,
                                        file_idx,
                                        0,
                                        Some(keys.as_slice()),
                                        &create_new_sink,
                                        sink_input_schema.clone(),
                                        "by-key",
                                        ext.as_str(),
                                        verbose,
                                        &state,
                                        per_partition_sort_by.as_ref(),
                                    ).await?;
                                    file_idx += 1;

                                    let Some((join_handles, sender, node)) = result else {
                                        return Ok(());
                                    };

                                    let (idx, previous) = open_partitions.insert_full(
                                        row_encoded,
                                        OpenPartition::Sink { sender, join_handles, node, keys },
                                    );
                                    debug_assert!(previous.is_none());
                                    open_partitions.get_index_mut(idx).unwrap().1
                                },
                                Some(open_partition) => open_partition,
                            };

                            match open_partition {
                                OpenPartition::Sink { sender, .. } => {
                                    let morsel = Morsel::new(partition, seq, source_token.clone());
                                    if sender.send(morsel).await.is_err() {
                                        return Ok(());
                                    }
                                },
                                OpenPartition::Buffer { buffered, .. } => buffered.push(partition),
                            }
                        }
                    }
                }

                PolarsResult::Ok(())
            };
            receive_and_pass().await?;

            let mut partition_metrics = Vec::with_capacity(file_idx);

            // At this point, we need to wait for all sinks to finish writing and close them. Also,
            // sinks that ended up buffering need to output their data.
            for open_partition in open_partitions.into_values() {
                let (sender, mut join_handles, mut node, keys) = match open_partition {
                    OpenPartition::Sink { sender, join_handles, node, keys } => {
                        (sender, join_handles, node, keys)
                    },
                    OpenPartition::Buffer { buffered, keys } => {
                        let result = open_new_sink(
                            base_path.as_ref().as_ref(),
                            file_path_cb.as_ref(),
                            super::default_by_key_file_path_cb,
                            file_idx,
                            file_idx,
                            0,
                            Some(keys.as_slice()),
                            &create_new_sink,
                            sink_input_schema.clone(),
                            "by-key",
                            ext.as_str(),
                            verbose,
                            &state,
                            per_partition_sort_by.as_ref(),
                        ).await?;
                        file_idx += 1;
                        let Some((join_handles, mut sender, node)) = result else {
                            return Ok(());
                        };

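                        // Replay the frames that were buffered in memory into the newly opened
                        // sink, assigning fresh morsel sequence numbers.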
                        let source_token = SourceToken::new();
                        let mut seq = MorselSeq::default();
                        for df in buffered {
                            let morsel = Morsel::new(df, seq, source_token.clone());
                            if sender.send(morsel).await.is_err() {
                                return Ok(());
                            }
                            seq = seq.successor();
                        }

                        (sender, join_handles, node, keys)
                    },
                };

                drop(sender); // Signal to the sink that nothing more is coming.
                while let Some(res) = join_handles.next().await {
                    res?;
                }

                if let Some(mut metrics) = node.get_metrics()? {
                    metrics.keys = Some(
                        keys.into_iter()
                            .map(|c| c.get(0).unwrap().into_static())
                            .collect(),
                    );
                    partition_metrics.push(metrics);
                }
                if let Some(finalize) = node.finalize(&state) {
                    finalize.await?;
                }
            }

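            // Collapse the per-partition write metrics into a single DataFrame; `finalize` below
            // hands it to the user's finish callback.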
            let df = WriteMetrics::collapse_to_df(
                partition_metrics,
                &sink_input_schema,
                Some(&input_schema.try_project(key_cols.iter()).unwrap()),
            );
            output_written_partitions.set(df).unwrap();
            Ok(())
        }));
    }

    fn finalize(
        &mut self,
        _state: &StreamingExecutionState,
    ) -> Option<Pin<Box<dyn Future<Output = PolarsResult<()>> + Send>>> {
        let finish_callback = self.finish_callback.clone();
        let written_partitions = self.written_partitions.clone();

        Some(Box::pin(async move {
            if let Some(finish_callback) = &finish_callback {
                let df = written_partitions.get().unwrap();
                finish_callback.call(df.clone())?;
            }
            Ok(())
        }))
    }
}