Path: blob/main/crates/polars-stream/src/nodes/peak_minmax.rs
6939 views
use polars_core::frame::DataFrame;1use polars_core::prelude::{AnyValue, Column, IntoColumn};2use polars_error::PolarsResult;3use polars_ops::prelude::peaks;45use super::ComputeNode;6use crate::async_executor::{JoinHandle, TaskPriority, TaskScope};7use crate::async_primitives::wait_group::WaitGroup;8use crate::execute::StreamingExecutionState;9use crate::graph::PortState;10use crate::morsel::{Morsel, MorselSeq, SourceToken};11use crate::pipe::{RecvPort, SendPort};1213enum State {14/// No morsels seen yet.15Start,16/// We have seen one morsel. Wait until 1 more to start streaming out data.17One(MorselSeq, Column),18/// We have seen two morsels. We have saved the last value of 2 morsels ago and the last19/// morsel.20Two(AnyValue<'static>, MorselSeq, Column),21/// No more morsels will be received.22Done,23}2425pub struct PeakMinMaxNode {26state: State,2728/// Is the node the `peak_max`?29is_peak_max: bool,30}3132impl PeakMinMaxNode {33pub fn new(is_peak_max: bool) -> Self {34Self {35state: State::Start,36is_peak_max,37}38}39}4041impl ComputeNode for PeakMinMaxNode {42fn name(&self) -> &str {43if self.is_peak_max {44"peaks_max"45} else {46"peaks_min"47}48}4950fn update_state(51&mut self,52recv: &mut [PortState],53send: &mut [PortState],54_state: &StreamingExecutionState,55) -> PolarsResult<()> {56assert!(recv.len() == 1 && send.len() == 1);5758if matches!(self.state, State::Done) {59send[0] = PortState::Done;60recv[0] = PortState::Done;61} else if send[0] == PortState::Done {62recv[0] = PortState::Done;63self.state = State::Done;64} else if recv[0] == PortState::Done {65if matches!(self.state, State::Start) {66send[0] = PortState::Done;67} else {68send[0] = PortState::Ready;69}70} else {71recv.swap_with_slice(send);72}7374Ok(())75}7677fn spawn<'env, 's>(78&'env mut self,79scope: &'s TaskScope<'s, 'env>,80recv_ports: &mut [Option<RecvPort<'_>>],81send_ports: &mut [Option<SendPort<'_>>],82_state: &'s StreamingExecutionState,83join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,84) {85assert_eq!(recv_ports.len(), 1);86assert_eq!(send_ports.len(), 1);8788let recv = recv_ports[0].take();89let mut send = send_ports[0].take().unwrap().serial();9091match recv {92// No more morsels to receive. Flush out the remaining data.93None => {94if matches!(self.state, State::Start) {95return;96}9798join_handles.push(scope.spawn_task(TaskPriority::High, async move {99let (start, seq, prev_column) = match &self.state {100State::Start => unreachable!(),101State::One(seq, df) => (&AnyValue::Int8(0), *seq, df),102State::Two(av, seq, df) => (av, *seq, df),103State::Done => unreachable!(),104};105106let column = peaks::peak_min_max(107prev_column,108start,109&AnyValue::Int8(0),110self.is_peak_max,111)?112.into_column();113let df = DataFrame::new(vec![column]).unwrap();114_ = send.send(Morsel::new(df, seq, SourceToken::new())).await;115116self.state = State::Done;117Ok(())118}));119},120121Some(recv) => {122let mut recv = recv.serial();123join_handles.push(scope.spawn_task(TaskPriority::High, async move {124let source_token = SourceToken::new();125126while let Ok(m) = recv.recv().await {127let (df, seq, in_source_token, in_wait_token) = m.into_inner();128drop(in_wait_token);129if df.height() == 0 {130continue;131}132133assert_eq!(df.width(), 1);134let column = &df[0];135136let (start, prev_seq, prev_column) = match &self.state {137State::Start => {138self.state = State::One(seq, column.clone());139continue;140},141State::One(prev_seq, prev_column) => {142(&AnyValue::Int8(0), *prev_seq, prev_column)143},144State::Two(prev_start, prev_seq, prev_column) => {145(prev_start, *prev_seq, prev_column)146},147State::Done => unreachable!(),148};149let end = &column.get(0).unwrap();150let out = peaks::peak_min_max(prev_column, start, end, self.is_peak_max)?151.into_column();152153let wg = WaitGroup::default();154let mut m = Morsel::new(155DataFrame::new(vec![out]).unwrap(),156prev_seq,157source_token.clone(),158);159m.set_consume_token(wg.token());160161if send.send(m).await.is_err() {162self.state = State::Done;163break;164}165166wg.wait().await;167if source_token.stop_requested() {168in_source_token.stop();169}170171let prev_end = prev_column172.get(prev_column.len() - 1)173.unwrap()174.to_physical()175.into_static();176self.state = State::Two(prev_end, seq, column.clone());177}178Ok(())179}));180},181}182}183}184185186