Path: blob/main/crates/polars-stream/src/nodes/joins/in_memory.rs
6939 views
use std::sync::Arc;12use polars_core::schema::Schema;34use crate::nodes::compute_node_prelude::*;5use crate::nodes::in_memory_sink::InMemorySinkNode;6use crate::nodes::in_memory_source::InMemorySourceNode;78enum InMemoryJoinState {9Sink {10left: InMemorySinkNode,11right: InMemorySinkNode,12},13Source(InMemorySourceNode),14Done,15}1617pub struct InMemoryJoinNode {18state: InMemoryJoinState,19joiner: Arc<dyn Fn(DataFrame, DataFrame) -> PolarsResult<DataFrame> + Send + Sync>,20}2122impl InMemoryJoinNode {23pub fn new(24left_input_schema: Arc<Schema>,25right_input_schema: Arc<Schema>,26joiner: Arc<dyn Fn(DataFrame, DataFrame) -> PolarsResult<DataFrame> + Send + Sync>,27) -> Self {28Self {29state: InMemoryJoinState::Sink {30left: InMemorySinkNode::new(left_input_schema),31right: InMemorySinkNode::new(right_input_schema),32},33joiner,34}35}36}3738impl ComputeNode for InMemoryJoinNode {39fn name(&self) -> &str {40"in-memory-join"41}4243fn update_state(44&mut self,45recv: &mut [PortState],46send: &mut [PortState],47state: &StreamingExecutionState,48) -> PolarsResult<()> {49assert!(recv.len() == 2 && send.len() == 1);5051// If the output doesn't want any more data, transition to being done.52if send[0] == PortState::Done && !matches!(self.state, InMemoryJoinState::Done) {53self.state = InMemoryJoinState::Done;54}5556// If the input is done, transition to being a source.57if let InMemoryJoinState::Sink { left, right } = &mut self.state {58if recv[0] == PortState::Done && recv[1] == PortState::Done {59let left_df = left.get_output()?.unwrap();60let right_df = right.get_output()?.unwrap();61let source_node = InMemorySourceNode::new(62Arc::new((self.joiner)(left_df, right_df)?),63MorselSeq::default(),64);65self.state = InMemoryJoinState::Source(source_node);66}67}6869match &mut self.state {70InMemoryJoinState::Sink { left, right, .. } => {71left.update_state(&mut recv[0..1], &mut [], state)?;72right.update_state(&mut recv[1..2], &mut [], state)?;73send[0] = PortState::Blocked;74},75InMemoryJoinState::Source(source_node) => {76recv[0] = PortState::Done;77recv[1] = PortState::Done;78source_node.update_state(&mut [], send, state)?;79},80InMemoryJoinState::Done => {81recv[0] = PortState::Done;82recv[1] = PortState::Done;83send[0] = PortState::Done;84},85}86Ok(())87}8889fn is_memory_intensive_pipeline_blocker(&self) -> bool {90matches!(self.state, InMemoryJoinState::Sink { .. })91}9293fn spawn<'env, 's>(94&'env mut self,95scope: &'s TaskScope<'s, 'env>,96recv_ports: &mut [Option<RecvPort<'_>>],97send_ports: &mut [Option<SendPort<'_>>],98state: &'s StreamingExecutionState,99join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,100) {101assert!(recv_ports.len() == 2);102assert!(send_ports.len() == 1);103match &mut self.state {104InMemoryJoinState::Sink { left, right, .. } => {105if recv_ports[0].is_some() {106left.spawn(scope, &mut recv_ports[0..1], &mut [], state, join_handles);107}108if recv_ports[1].is_some() {109right.spawn(scope, &mut recv_ports[1..2], &mut [], state, join_handles);110}111},112InMemoryJoinState::Source(source) => {113source.spawn(scope, &mut [], send_ports, state, join_handles)114},115InMemoryJoinState::Done => unreachable!(),116}117}118}119120121