Path: blob/main/crates/polars-stream/src/nodes/in_memory_source.rs
6939 views
use std::sync::Arc;1use std::sync::atomic::{AtomicU64, Ordering};23use super::compute_node_prelude::*;4use crate::async_primitives::wait_group::WaitGroup;5use crate::morsel::{MorselSeq, SourceToken, get_ideal_morsel_size};67pub struct InMemorySourceNode {8source: Option<Arc<DataFrame>>,9morsel_size: usize,10seq: AtomicU64,11seq_offset: MorselSeq,12}1314impl InMemorySourceNode {15pub fn new(source: Arc<DataFrame>, seq_offset: MorselSeq) -> Self {16InMemorySourceNode {17source: Some(source),18morsel_size: 0,19seq: AtomicU64::new(0),20seq_offset,21}22}23}2425impl ComputeNode for InMemorySourceNode {26fn name(&self) -> &str {27"in-memory-source"28}2930fn update_state(31&mut self,32recv: &mut [PortState],33send: &mut [PortState],34state: &StreamingExecutionState,35) -> PolarsResult<()> {36assert!(recv.is_empty());37assert!(send.len() == 1);3839if self.morsel_size == 0 {40let len = self.source.as_ref().unwrap().height();41let ideal_morsel_count = (len / get_ideal_morsel_size()).max(1);42let morsel_count = ideal_morsel_count.next_multiple_of(state.num_pipelines);43self.morsel_size = len.div_ceil(morsel_count).max(1);44self.seq = AtomicU64::new(0);45}4647// As a temporary hack for some nodes (like the FunctionIR::FastCount)48// node that rely on an empty input, always ensure we send at least one49// morsel.50// TODO: remove this hack.51let exhausted = if let Some(src) = &self.source {52let seq = self.seq.load(Ordering::Relaxed);53seq > 0 && seq * self.morsel_size as u64 >= src.height() as u6454} else {55true56};57if send[0] == PortState::Done || exhausted {58send[0] = PortState::Done;59self.source = None;60} else {61send[0] = PortState::Ready;62}63Ok(())64}6566fn spawn<'env, 's>(67&'env mut self,68scope: &'s TaskScope<'s, 'env>,69recv_ports: &mut [Option<RecvPort<'_>>],70send_ports: &mut [Option<SendPort<'_>>],71_state: &'s StreamingExecutionState,72join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,73) {74assert!(recv_ports.is_empty() && send_ports.len() == 1);75let senders = send_ports[0].take().unwrap().parallel();76let source = self.source.as_ref().unwrap();7778// TODO: can this just be serial, using the work distributor?79let source_token = SourceToken::new();80for mut send in senders {81let slf = &*self;82let source_token = source_token.clone();83join_handles.push(scope.spawn_task(TaskPriority::Low, async move {84let wait_group = WaitGroup::default();85loop {86let seq = slf.seq.fetch_add(1, Ordering::Relaxed);87let offset = (seq as usize * slf.morsel_size) as i64;88let df = source.slice(offset, slf.morsel_size);8990// TODO: remove this 'always sent at least one morsel'91// condition, see update_state.92if df.height() == 0 && seq > 0 {93break;94}9596let morsel_seq = MorselSeq::new(seq).offset_by(slf.seq_offset);97let mut morsel = Morsel::new(df, morsel_seq, source_token.clone());98morsel.set_consume_token(wait_group.token());99if send.send(morsel).await.is_err() {100break;101}102103wait_group.wait().await;104if source_token.stop_requested() {105break;106}107}108109Ok(())110}));111}112}113}114115116