Path: blob/main/crates/polars-stream/src/nodes/gather_every.rs
7884 views
use polars_error::polars_ensure;12use super::compute_node_prelude::*;3use crate::DEFAULT_DISTRIBUTOR_BUFFER_SIZE;4use crate::async_primitives::distributor_channel::distributor_channel;5use crate::async_primitives::wait_group::WaitGroup;67pub struct GatherEveryNode {8n: usize,9offset: usize,10}1112impl GatherEveryNode {13pub fn new(n: usize, offset: usize) -> PolarsResult<Self> {14polars_ensure!(n > 0, InvalidOperation: "gather_every(n): n should be positive");1516assert!(i64::try_from(n).unwrap() > 0);17assert!(i64::try_from(offset).unwrap() >= 0);1819Ok(Self { n, offset })20}21}2223impl ComputeNode for GatherEveryNode {24fn name(&self) -> &str {25"gather_every"26}2728fn update_state(29&mut self,30recv: &mut [PortState],31send: &mut [PortState],32_state: &StreamingExecutionState,33) -> PolarsResult<()> {34assert!(recv.len() == 1 && send.len() == 1);35recv.swap_with_slice(send);36Ok(())37}3839fn spawn<'env, 's>(40&'env mut self,41scope: &'s TaskScope<'s, 'env>,42recv_ports: &mut [Option<RecvPort<'_>>],43send_ports: &mut [Option<SendPort<'_>>],44_state: &'s StreamingExecutionState,45join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,46) {47assert!(recv_ports.len() == 1 && send_ports.len() == 1);48let mut receiver = recv_ports[0].take().unwrap().serial();49let senders = send_ports[0].take().unwrap().parallel();5051let (mut distributor, distr_receivers) =52distributor_channel(senders.len(), *DEFAULT_DISTRIBUTOR_BUFFER_SIZE);5354let n = self.n;5556// To figure out the correct offsets we need to be serial.57join_handles.push(scope.spawn_task(TaskPriority::High, async move {58while let Ok(morsel) = receiver.recv().await {59let height = morsel.df().height();60if self.offset >= height {61self.offset -= height;62continue;63}6465if distributor.send((morsel, self.offset)).await.is_err() {66break;67}6869// Calculates `offset = (offset - height) mod n` without under- and overflow.70self.offset += height.next_multiple_of(self.n) - height;71self.offset %= self.n;72}7374Ok(())75}));7677// But gathering the column can be done in parallel.78for (mut send, mut recv) in senders.into_iter().zip(distr_receivers) {79join_handles.push(scope.spawn_task(TaskPriority::High, async move {80let wait_group = WaitGroup::default();81while let Ok((morsel, offset)) = recv.recv().await {82let mut morsel = morsel.try_map(|mut df| {83let column = &df.get_columns()[0];84let out = column85.gather_every(n, offset)?86.with_name(column.name().clone());87unsafe { df.get_columns_mut()[0] = out };88PolarsResult::Ok(df)89})?;90morsel.set_consume_token(wait_group.token());91if send.send(morsel).await.is_err() {92break;93}94wait_group.wait().await;95}9697Ok(())98}));99}100}101}102103104