// Path: blob/main/crates/polars-stream/src/nodes/in_memory_source.rs
// 8420 views
use std::sync::Arc;1use std::sync::atomic::{AtomicU64, Ordering};23use super::compute_node_prelude::*;4use crate::async_primitives::wait_group::WaitGroup;5use crate::morsel::{MorselSeq, SourceToken, get_ideal_morsel_size};67#[derive(Debug)]8pub struct InMemorySourceNode {9source: Option<Arc<DataFrame>>,10morsel_size: usize,11seq: AtomicU64,12seq_offset: MorselSeq,13}1415impl InMemorySourceNode {16pub fn new(source: Arc<DataFrame>, seq_offset: MorselSeq) -> Self {17InMemorySourceNode {18source: Some(source),19morsel_size: 0,20seq: AtomicU64::new(0),21seq_offset,22}23}2425pub fn new_no_morsel_split(source: Arc<DataFrame>, seq_offset: MorselSeq) -> Self {26let morsel_size = source.height();2728InMemorySourceNode {29source: Some(source),30morsel_size,31seq: AtomicU64::new(0),32seq_offset,33}34}35}3637impl ComputeNode for InMemorySourceNode {38fn name(&self) -> &str {39"in-memory-source"40}4142fn update_state(43&mut self,44recv: &mut [PortState],45send: &mut [PortState],46state: &StreamingExecutionState,47) -> PolarsResult<()> {48assert!(recv.is_empty());49assert!(send.len() == 1);5051if self.morsel_size == 0 {52let len = self.source.as_ref().unwrap().height();53let ideal_morsel_count = (len / get_ideal_morsel_size()).max(1);54let morsel_count = ideal_morsel_count.next_multiple_of(state.num_pipelines);55self.morsel_size = len.div_ceil(morsel_count).max(1);56self.seq = AtomicU64::new(0);57}5859// As a temporary hack for some nodes (like the FunctionIR::FastCount)60// node that rely on an empty input, always ensure we send at least one61// morsel.62// TODO: remove this hack.63let exhausted = if let Some(src) = &self.source {64let seq = self.seq.load(Ordering::Relaxed);65seq > 0 && seq * self.morsel_size as u64 >= src.height() as u6466} else {67true68};69if send[0] == PortState::Done || exhausted {70send[0] = PortState::Done;71self.source = None;72} else {73send[0] = PortState::Ready;74}75Ok(())76}7778fn spawn<'env, 's>(79&'env mut self,80scope: &'s TaskScope<'s, 
'env>,81recv_ports: &mut [Option<RecvPort<'_>>],82send_ports: &mut [Option<SendPort<'_>>],83_state: &'s StreamingExecutionState,84join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,85) {86assert!(recv_ports.is_empty() && send_ports.len() == 1);87let senders = send_ports[0].take().unwrap().parallel();88let source = self.source.as_ref().unwrap();8990// TODO: can this just be serial, using the work distributor?91let source_token = SourceToken::new();92for mut send in senders {93let slf = &*self;94let source_token = source_token.clone();95join_handles.push(scope.spawn_task(TaskPriority::Low, async move {96let wait_group = WaitGroup::default();97loop {98let seq = slf.seq.fetch_add(1, Ordering::Relaxed);99let offset = (seq as usize * slf.morsel_size) as i64;100let df = source.slice(offset, slf.morsel_size);101102// TODO: remove this 'always sent at least one morsel'103// condition, see update_state.104if df.height() == 0 && seq > 0 {105break;106}107108let morsel_seq = MorselSeq::new(seq).offset_by(slf.seq_offset);109let mut morsel = Morsel::new(df, morsel_seq, source_token.clone());110morsel.set_consume_token(wait_group.token());111if send.send(morsel).await.is_err() {112break;113}114115wait_group.wait().await;116if source_token.stop_requested() {117break;118}119}120121Ok(())122}));123}124}125}126127128