Path: blob/main/crates/polars-stream/src/nodes/select.rs
6939 views
use std::sync::Arc;12use polars_core::prelude::IntoColumn;3use polars_core::schema::Schema;45use super::compute_node_prelude::*;6use crate::expression::StreamExpr;78pub struct SelectNode {9selectors: Vec<StreamExpr>,10schema: Arc<Schema>,11extend_original: bool,12}1314impl SelectNode {15pub fn new(selectors: Vec<StreamExpr>, schema: Arc<Schema>, extend_original: bool) -> Self {16Self {17selectors,18schema,19extend_original,20}21}22}2324impl ComputeNode for SelectNode {25fn name(&self) -> &str {26if self.extend_original {27"with-columns"28} else {29"select"30}31}3233fn update_state(34&mut self,35recv: &mut [PortState],36send: &mut [PortState],37_state: &StreamingExecutionState,38) -> PolarsResult<()> {39assert!(recv.len() == 1 && send.len() == 1);40recv.swap_with_slice(send);41Ok(())42}4344fn spawn<'env, 's>(45&'env mut self,46scope: &'s TaskScope<'s, 'env>,47recv_ports: &mut [Option<RecvPort<'_>>],48send_ports: &mut [Option<SendPort<'_>>],49state: &'s StreamingExecutionState,50join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,51) {52assert!(recv_ports.len() == 1 && send_ports.len() == 1);53let receivers = recv_ports[0].take().unwrap().parallel();54let senders = send_ports[0].take().unwrap().parallel();5556for (mut recv, mut send) in receivers.into_iter().zip(senders) {57let slf = &*self;58join_handles.push(scope.spawn_task(TaskPriority::High, async move {59while let Ok(morsel) = recv.recv().await {60let (df, seq, source_token, consume_token) = morsel.into_inner();61let mut selected = Vec::new();62for selector in slf.selectors.iter() {63let s = selector.evaluate(&df, &state.in_memory_exec_state).await?;64selected.push(s.into_column());65}6667let ret = if slf.extend_original {68let mut out = df;69out._add_columns(selected, &slf.schema)?;70out71} else {72DataFrame::new_with_broadcast(selected)?73};7475let mut morsel = Morsel::new(ret, seq, source_token);76if let Some(token) = consume_token {77morsel.set_consume_token(token);78}7980if send.send(morsel).await.is_err() {81break;82}83}8485Ok(())86}));87}88}89}909192