Path: blob/main/crates/polars-stream/src/nodes/repeat.rs
6939 views
use std::sync::Arc;12use polars_core::schema::Schema;34use super::compute_node_prelude::*;5use crate::async_primitives::wait_group::WaitGroup;6use crate::morsel::{SourceToken, get_ideal_morsel_size};7use crate::nodes::in_memory_sink::InMemorySinkNode;8pub enum RepeatNode {9GatheringParams {10value: InMemorySinkNode,11repeats: InMemorySinkNode,12},13Repeating {14value: DataFrame,15seq: MorselSeq,16repeats_left: usize,17},18}1920impl RepeatNode {21pub fn new(value_schema: Arc<Schema>, repeats_schema: Arc<Schema>) -> Self {22assert!(value_schema.len() == 1);23assert!(repeats_schema.len() == 1);24Self::GatheringParams {25value: InMemorySinkNode::new(value_schema),26repeats: InMemorySinkNode::new(repeats_schema),27}28}29}3031impl ComputeNode for RepeatNode {32fn name(&self) -> &str {33"repeat"34}3536fn update_state(37&mut self,38recv: &mut [PortState],39send: &mut [PortState],40state: &StreamingExecutionState,41) -> PolarsResult<()> {42assert!(recv.len() == 2 && send.len() == 1);4344if recv[0] == PortState::Done && recv[1] == PortState::Done {45if let Self::GatheringParams { value, repeats } = self {46let repeats = repeats.get_output()?.unwrap();47let repeats_item = repeats.get_columns()[0].get(0)?;48let repeats_left = repeats_item.extract::<usize>().unwrap();4950let value = value.get_output()?.unwrap();51let seq = MorselSeq::default();52*self = Self::Repeating {53value,54seq,55repeats_left,56};57}58}5960match self {61Self::GatheringParams { value, repeats } => {62value.update_state(&mut recv[0..1], &mut [], state)?;63repeats.update_state(&mut recv[1..2], &mut [], state)?;64send[0] = PortState::Blocked;65},66Self::Repeating { repeats_left, .. } => {67recv[0] = PortState::Done;68recv[1] = PortState::Done;69send[0] = if *repeats_left > 0 {70PortState::Ready71} else {72PortState::Done73};74},75}76Ok(())77}7879fn spawn<'env, 's>(80&'env mut self,81scope: &'s TaskScope<'s, 'env>,82recv_ports: &mut [Option<RecvPort<'_>>],83send_ports: &mut [Option<SendPort<'_>>],84state: &'s StreamingExecutionState,85join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,86) {87assert!(recv_ports.len() == 2 && send_ports.len() == 1);88match self {89Self::GatheringParams { value, repeats } => {90assert!(send_ports[0].is_none());91if recv_ports[0].is_some() {92value.spawn(scope, &mut recv_ports[0..1], &mut [], state, join_handles);93}94if recv_ports[1].is_some() {95repeats.spawn(scope, &mut recv_ports[1..2], &mut [], state, join_handles);96}97},98Self::Repeating {99value,100seq,101repeats_left,102} => {103assert!(recv_ports[0].is_none());104assert!(recv_ports[1].is_none());105106let mut send = send_ports[0].take().unwrap().serial();107108let ideal_morsel_count = (*repeats_left / get_ideal_morsel_size()).max(1);109let morsel_count = ideal_morsel_count.next_multiple_of(state.num_pipelines);110let morsel_size = repeats_left.div_ceil(morsel_count).max(1);111112join_handles.push(scope.spawn_task(TaskPriority::Low, async move {113let source_token = SourceToken::new();114115let wait_group = WaitGroup::default();116while *repeats_left > 0 && !source_token.stop_requested() {117let height = morsel_size.min(*repeats_left);118let df = value.new_from_index(0, height);119let mut morsel = Morsel::new(df, *seq, source_token.clone());120morsel.set_consume_token(wait_group.token());121122*seq = seq.successor();123*repeats_left -= height;124125if send.send(morsel).await.is_err() {126break;127}128wait_group.wait().await;129}130131Ok(())132}));133},134}135}136}137138139