// Path: blob/main/crates/polars-stream/src/nodes/negative_slice.rs
// 6939 views
use std::collections::VecDeque;1use std::sync::Arc;23use polars_core::utils::accumulate_dataframes_vertical_unchecked;45use super::compute_node_prelude::*;6use crate::nodes::in_memory_source::InMemorySourceNode;78/// A node that will pass-through up to length rows, starting at start_offset.9/// Since start_offset must be non-negative this can be done in a streaming10/// manner.11enum NegativeSliceState {12Buffering(Buffer),13Source(InMemorySourceNode),14Done,15}1617#[derive(Default)]18struct Buffer {19frames: VecDeque<DataFrame>,20total_len: usize,21}2223pub struct NegativeSliceNode {24state: NegativeSliceState,25slice_offset: i64,26length: usize,27}2829impl NegativeSliceNode {30pub fn new(slice_offset: i64, length: usize) -> Self {31assert!(slice_offset < 0);32Self {33state: NegativeSliceState::Buffering(Buffer::default()),34slice_offset,35length,36}37}38}3940impl ComputeNode for NegativeSliceNode {41fn name(&self) -> &str {42"negative-slice"43}4445fn update_state(46&mut self,47recv: &mut [PortState],48send: &mut [PortState],49state: &StreamingExecutionState,50) -> PolarsResult<()> {51use NegativeSliceState::*;5253if send[0] == PortState::Done || self.length == 0 {54self.state = Done;55}5657if recv[0] == PortState::Done {58if let Buffering(buffer) = &mut self.state {59// These offsets are relative to the start of buffer.60let mut signed_start_offset = buffer.total_len as i64 + self.slice_offset;61let signed_stop_offset =62signed_start_offset.saturating_add_unsigned(self.length as u64);6364// Trim the frames in the buffer to just those that are relevant.65while buffer.total_len > 066&& signed_start_offset >= buffer.frames.front().unwrap().height() as i6467{68let len = buffer.frames.pop_front().unwrap().height();69buffer.total_len -= len;70signed_start_offset -= len as i64;71}7273while !buffer.frames.is_empty()74&& buffer.total_len as i64 - buffer.frames.back().unwrap().height() as i6475> signed_stop_offset76{77buffer.total_len -= 
buffer.frames.pop_back().unwrap().height();78}7980if buffer.total_len == 0 {81self.state = Done;82} else {83let mut df = accumulate_dataframes_vertical_unchecked(buffer.frames.drain(..));84let clamped_start = signed_start_offset.max(0);85let len = (signed_stop_offset - clamped_start).max(0) as usize;86df = df.slice(clamped_start, len);87self.state =88Source(InMemorySourceNode::new(Arc::new(df), MorselSeq::default()));89}90}91}9293match &mut self.state {94Buffering(_) => {95recv[0] = PortState::Ready;96send[0] = PortState::Blocked;97},98Source(node) => {99recv[0] = PortState::Done;100node.update_state(&mut [], send, state)?;101},102Done => {103recv[0] = PortState::Done;104send[0] = PortState::Done;105},106}107Ok(())108}109110fn spawn<'env, 's>(111&'env mut self,112scope: &'s TaskScope<'s, 'env>,113recv_ports: &mut [Option<RecvPort<'_>>],114send_ports: &mut [Option<SendPort<'_>>],115state: &'s StreamingExecutionState,116join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,117) {118assert!(recv_ports.len() == 1 && send_ports.len() == 1);119match &mut self.state {120NegativeSliceState::Buffering(buffer) => {121let mut recv = recv_ports[0].take().unwrap().serial();122assert!(send_ports[0].is_none());123let max_buffer_needed = self.slice_offset.unsigned_abs() as usize;124join_handles.push(scope.spawn_task(TaskPriority::High, async move {125while let Ok(morsel) = recv.recv().await {126buffer.total_len += morsel.df().height();127buffer.frames.push_back(morsel.into_df());128129if buffer.total_len - buffer.frames.front().unwrap().height()130>= max_buffer_needed131{132buffer.total_len -= buffer.frames.pop_front().unwrap().height();133}134}135136Ok(())137}));138},139NegativeSliceState::Source(in_memory_source_node) => {140assert!(recv_ports[0].is_none());141in_memory_source_node.spawn(scope, &mut [], send_ports, state, join_handles);142},143NegativeSliceState::Done => unreachable!(),144}145}146}147148149