Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/joins/asof_join.rs
8480 views
1
use std::collections::VecDeque;
2
3
use polars_core::prelude::*;
4
use polars_core::utils::Container;
5
use polars_ops::frame::{AsOfOptions, AsofStrategy, JoinArgs, JoinType};
6
use polars_utils::format_pl_smallstr;
7
8
use crate::DEFAULT_DISTRIBUTOR_BUFFER_SIZE;
9
use crate::async_executor::{JoinHandle, TaskPriority, TaskScope};
10
use crate::async_primitives::distributor_channel as dc;
11
use crate::execute::StreamingExecutionState;
12
use crate::graph::PortState;
13
use crate::morsel::{Morsel, MorselSeq, SourceToken};
14
use crate::nodes::ComputeNode;
15
use crate::nodes::joins::utils::DataFrameSearchBuffer;
16
use crate::pipe::{PortReceiver, PortSender, RecvPort, SendPort};
17
18
#[derive(Debug)]
19
pub struct AsOfJoinSideParams {
20
pub on: PlSmallStr,
21
pub tmp_key_col: Option<PlSmallStr>,
22
}
23
24
impl AsOfJoinSideParams {
25
fn key_col(&self) -> &PlSmallStr {
26
self.tmp_key_col.as_ref().unwrap_or(&self.on)
27
}
28
}
29
30
#[derive(Debug)]
31
struct AsOfJoinParams {
32
left: AsOfJoinSideParams,
33
right: AsOfJoinSideParams,
34
args: JoinArgs,
35
}
36
37
impl AsOfJoinParams {
38
fn as_of_options(&self) -> &AsOfOptions {
39
let JoinType::AsOf(ref options) = self.args.how else {
40
unreachable!("incorrect join type");
41
};
42
options
43
}
44
}
45
46
/// Lifecycle of the AsOf join node.
#[derive(Debug, Default, PartialEq)]
enum AsOfJoinState {
    /// Initial state: the left input is still being received.
    #[default]
    Running,
    /// The left input port is done; only the stashed left morsels remain to
    /// be joined.
    FlushInputBuffer,
    /// Nothing left to do: either the output port is done, or the left input
    /// is done and no stashed left morsels remain.
    Done,
}
53
54
#[derive(Debug)]
55
pub struct AsOfJoinNode {
56
params: AsOfJoinParams,
57
state: AsOfJoinState,
58
/// We may need to stash a morsel on the left side whenever we do not
59
/// have enough data on the right side, but the right side is empty.
60
/// In these cases, we stash that morsel here.
61
left_buffer: VecDeque<(DataFrame, MorselSeq)>,
62
/// Buffer of the live range of right AsOf join rows.
63
right_buffer: DataFrameSearchBuffer,
64
}
65
66
impl AsOfJoinNode {
67
pub fn new(
68
left_input_schema: SchemaRef,
69
right_input_schema: SchemaRef,
70
left_on: PlSmallStr,
71
right_on: PlSmallStr,
72
tmp_left_key_col: Option<PlSmallStr>,
73
tmp_right_key_col: Option<PlSmallStr>,
74
args: JoinArgs,
75
) -> Self {
76
let left_key_col = tmp_left_key_col.as_ref().unwrap_or(&left_on);
77
let right_key_col = tmp_right_key_col.as_ref().unwrap_or(&right_on);
78
let left_key_dtype = left_input_schema.get(left_key_col).unwrap();
79
let right_key_dtype = right_input_schema.get(right_key_col).unwrap();
80
assert_eq!(left_key_dtype, right_key_dtype);
81
let left = AsOfJoinSideParams {
82
on: left_on,
83
tmp_key_col: tmp_left_key_col,
84
};
85
let right = AsOfJoinSideParams {
86
on: right_on,
87
tmp_key_col: tmp_right_key_col,
88
};
89
90
let params = AsOfJoinParams { left, right, args };
91
AsOfJoinNode {
92
params,
93
state: AsOfJoinState::default(),
94
left_buffer: Default::default(),
95
right_buffer: DataFrameSearchBuffer::empty_with_schema(right_input_schema),
96
}
97
}
98
}
99
100
impl ComputeNode for AsOfJoinNode {
101
fn name(&self) -> &str {
102
"asof-join"
103
}
104
105
fn update_state(
106
&mut self,
107
recv: &mut [PortState],
108
send: &mut [PortState],
109
_state: &StreamingExecutionState,
110
) -> PolarsResult<()> {
111
assert!(recv.len() == 2 && send.len() == 1);
112
113
if send[0] == PortState::Done {
114
self.state = AsOfJoinState::Done;
115
}
116
117
if self.state == AsOfJoinState::Running && recv[0] == PortState::Done {
118
self.state = AsOfJoinState::FlushInputBuffer;
119
}
120
121
if self.state == AsOfJoinState::FlushInputBuffer && self.left_buffer.is_empty() {
122
self.state = AsOfJoinState::Done;
123
}
124
125
let recv0_blocked = recv[0] == PortState::Blocked;
126
let recv1_blocked = recv[1] == PortState::Blocked;
127
let send_blocked = send[0] == PortState::Blocked;
128
match self.state {
129
AsOfJoinState::Running => {
130
recv[0] = PortState::Ready;
131
recv[1] = PortState::Ready;
132
send[0] = PortState::Ready;
133
if recv0_blocked {
134
recv[1] = PortState::Blocked;
135
send[0] = PortState::Blocked;
136
}
137
if recv1_blocked {
138
recv[0] = PortState::Blocked;
139
send[0] = PortState::Blocked;
140
}
141
if send_blocked {
142
recv[0] = PortState::Blocked;
143
recv[1] = PortState::Blocked;
144
}
145
},
146
AsOfJoinState::FlushInputBuffer => {
147
recv[0] = PortState::Done;
148
recv[1] = PortState::Ready;
149
send[0] = PortState::Ready;
150
if recv1_blocked {
151
send[0] = PortState::Blocked;
152
}
153
if send_blocked {
154
recv[1] = PortState::Blocked;
155
}
156
},
157
AsOfJoinState::Done => {
158
recv.fill(PortState::Done);
159
send[0] = PortState::Done;
160
},
161
}
162
163
Ok(())
164
}
165
166
fn spawn<'env, 's>(
167
&'env mut self,
168
scope: &'s TaskScope<'s, 'env>,
169
recv_ports: &mut [Option<RecvPort<'_>>],
170
send_ports: &mut [Option<SendPort<'_>>],
171
_state: &'s StreamingExecutionState,
172
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
173
) {
174
assert!(recv_ports.len() == 2 && send_ports.len() == 1);
175
176
match &self.state {
177
AsOfJoinState::Running | AsOfJoinState::FlushInputBuffer => {
178
let params = &self.params;
179
let recv_left = match self.state {
180
AsOfJoinState::Running => Some(recv_ports[0].take().unwrap().serial()),
181
_ => None,
182
};
183
let recv_right = recv_ports[1].take().map(RecvPort::serial);
184
let send = send_ports[0].take().unwrap().parallel();
185
let (distributor, dist_recv) =
186
dc::distributor_channel(send.len(), *DEFAULT_DISTRIBUTOR_BUFFER_SIZE);
187
let left_buffer = &mut self.left_buffer;
188
let right_buffer = &mut self.right_buffer;
189
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
190
distribute_work_task(
191
recv_left,
192
recv_right,
193
distributor,
194
left_buffer,
195
right_buffer,
196
params,
197
)
198
.await
199
}));
200
201
join_handles.extend(dist_recv.into_iter().zip(send).map(|(recv, send)| {
202
scope.spawn_task(TaskPriority::High, async move {
203
compute_and_emit_task(recv, send, params).await
204
})
205
}));
206
},
207
AsOfJoinState::Done => {
208
unreachable!();
209
},
210
}
211
}
212
}
213
214
/// Tell the sender to this port to stop, and buffer everything that is still in the pipe.
215
async fn stop_and_buffer_pipe_contents<F>(port: Option<&mut PortReceiver>, buffer_morsel: &mut F)
216
where
217
F: FnMut(DataFrame, MorselSeq),
218
{
219
let Some(port) = port else {
220
return;
221
};
222
223
while let Ok(morsel) = port.recv().await {
224
morsel.source_token().stop();
225
let (df, seq, _, _) = morsel.into_inner();
226
buffer_morsel(df, seq);
227
}
228
}
229
230
async fn distribute_work_task(
231
mut recv_left: Option<PortReceiver>,
232
mut recv_right: Option<PortReceiver>,
233
mut distributor: dc::Sender<(DataFrame, DataFrameSearchBuffer, MorselSeq, SourceToken)>,
234
left_buffer: &mut VecDeque<(DataFrame, MorselSeq)>,
235
right_buffer: &mut DataFrameSearchBuffer,
236
params: &AsOfJoinParams,
237
) -> PolarsResult<()> {
238
let source_token = SourceToken::new();
239
let right_done = recv_right.is_none();
240
241
loop {
242
if source_token.stop_requested() {
243
stop_and_buffer_pipe_contents(recv_left.as_mut(), &mut |df, seq| {
244
left_buffer.push_back((df, seq))
245
})
246
.await;
247
stop_and_buffer_pipe_contents(recv_right.as_mut(), &mut |df, _| {
248
right_buffer.push_df(df)
249
})
250
.await;
251
return Ok(());
252
}
253
254
let (left_df, seq, st) = if let Some((df, seq)) = left_buffer.pop_front() {
255
(df, seq, source_token.clone())
256
} else if let Some(ref mut recv) = recv_left
257
&& let Ok(m) = recv.recv().await
258
{
259
let (df, seq, st, _) = m.into_inner();
260
(df, seq, st)
261
} else {
262
stop_and_buffer_pipe_contents(recv_right.as_mut(), &mut |df, _| {
263
right_buffer.push_df(df)
264
})
265
.await;
266
return Ok(());
267
};
268
269
while need_more_right_side(&left_df, right_buffer, params)? && !right_done {
270
if let Some(ref mut recv) = recv_right
271
&& let Ok(morsel_right) = recv.recv().await
272
{
273
right_buffer.push_df(morsel_right.into_df());
274
} else {
275
// The right pipe is empty at this stage, we will need to wait for
276
// a new stage and try again.
277
left_buffer.push_back((left_df, seq));
278
stop_and_buffer_pipe_contents(recv_left.as_mut(), &mut |df, seq| {
279
left_buffer.push_back((df, seq))
280
})
281
.await;
282
return Ok(());
283
}
284
}
285
286
distributor
287
.send((left_df.clone(), right_buffer.clone(), seq, st))
288
.await
289
.unwrap();
290
prune_right_side(&left_df, right_buffer, params)?;
291
}
292
}
293
294
/// Do we need more values on the right side before we can compute the AsOf join
295
/// between the right side and the complete left side?
296
fn need_more_right_side(
297
left: &DataFrame,
298
right: &DataFrameSearchBuffer,
299
params: &AsOfJoinParams,
300
) -> PolarsResult<bool> {
301
let options = params.as_of_options();
302
let left_key = left.column(params.left.key_col())?.as_materialized_series();
303
if left_key.is_empty() {
304
return Ok(false);
305
}
306
// SAFETY: We just checked that left_key is not empty
307
let left_last_val = unsafe { left_key.get_unchecked(left_key.len() - 1) };
308
let right_range_end = match (options.strategy, options.allow_eq) {
309
(AsofStrategy::Forward, true) => {
310
right.binary_search(|x| *x >= left_last_val, params.right.key_col(), false)
311
},
312
(AsofStrategy::Forward, false) | (AsofStrategy::Backward, true) => {
313
right.binary_search(|x| *x > left_last_val, params.right.key_col(), false)
314
},
315
(AsofStrategy::Backward, false) | (AsofStrategy::Nearest, _) => {
316
let first_greater =
317
right.binary_search(|x| *x > left_last_val, params.right.key_col(), false);
318
if first_greater >= right.height() {
319
return Ok(true);
320
}
321
// In the Backward/Nearest cases, there may be a chunk of consecutive equal
322
// values following the match value on the left side. In this case, the AsOf
323
// join is greedy and should until the *end* of that chunk.
324
325
// SAFETY: We just checked that right_range_end is in bounds
326
let fst_greater_val =
327
unsafe { right.get_bypass_validity(params.right.key_col(), first_greater, false) };
328
right.binary_search(|x| *x > fst_greater_val, params.right.key_col(), false)
329
},
330
};
331
Ok(right_range_end >= right.height())
332
}
333
334
fn prune_right_side(
335
left: &DataFrame,
336
right: &mut DataFrameSearchBuffer,
337
params: &AsOfJoinParams,
338
) -> PolarsResult<()> {
339
let left_key = left.column(params.left.key_col())?.as_materialized_series();
340
if left.len() == 0 {
341
return Ok(());
342
}
343
// SAFETY: We just checked that left_key is not empty
344
let left_first_val = unsafe { left_key.get_unchecked(0) };
345
let right_range_start = right
346
.binary_search(|x| *x >= left_first_val, params.right.key_col(), false)
347
.saturating_sub(1);
348
right.split_at(right_range_start);
349
Ok(())
350
}
351
352
async fn compute_and_emit_task(
353
mut dist_recv: dc::Receiver<(DataFrame, DataFrameSearchBuffer, MorselSeq, SourceToken)>,
354
mut send: PortSender,
355
params: &AsOfJoinParams,
356
) -> PolarsResult<()> {
357
let options = params.as_of_options();
358
while let Ok((left_df, right_buffer, seq, st)) = dist_recv.recv().await {
359
let right_df = right_buffer.into_df();
360
361
let left_key = left_df.column(params.left.key_col())?;
362
let right_key = right_df.column(params.right.key_col())?;
363
let any_key_is_temporary_col =
364
params.left.tmp_key_col.is_some() || params.right.tmp_key_col.is_some();
365
let mut out = polars_ops::frame::AsofJoin::_join_asof(
366
&left_df,
367
&right_df,
368
left_key.as_materialized_series(),
369
right_key.as_materialized_series(),
370
options.strategy,
371
options.tolerance.clone().map(Scalar::into_value),
372
params.args.suffix.clone(),
373
None,
374
any_key_is_temporary_col || params.args.should_coalesce(),
375
options.allow_eq,
376
options.check_sortedness,
377
)?;
378
379
// Drop any temporary key columns that were added
380
for tmp_key_col in [&params.left.tmp_key_col, &params.right.tmp_key_col] {
381
if let Some(tmp_col) = tmp_key_col
382
&& out.schema().contains(tmp_col)
383
{
384
out.drop_in_place(tmp_col)?;
385
}
386
}
387
388
// If the join key passed to _join_asof() was a temporary key column,
389
// we still need to coalesce the real 'on' columns ourselves.
390
if any_key_is_temporary_col
391
&& params.args.should_coalesce()
392
&& params.left.on == params.right.on
393
{
394
let right_on_name = format_pl_smallstr!("{}{}", params.right.on, params.args.suffix());
395
out.drop_in_place(&right_on_name)?;
396
}
397
398
let morsel = Morsel::new(out, seq, st);
399
if send.send(morsel).await.is_err() {
400
return Ok(());
401
}
402
}
403
Ok(())
404
}
405
406