Path: blob/main/crates/polars-expr/src/expressions/group_iter.rs
6940 views
#![allow(unsafe_op_in_unsafe_fn)]1use std::rc::Rc;23use polars_core::series::amortized_iter::AmortSeries;45use super::*;67impl AggregationContext<'_> {8pub(super) fn iter_groups(9&mut self,10keep_names: bool,11) -> Box<dyn Iterator<Item = Option<AmortSeries>> + '_> {12match self.agg_state() {13AggState::LiteralScalar(_) => {14self.groups();15let c = self.get_values().rechunk();16let name = if keep_names {17c.name().clone()18} else {19PlSmallStr::EMPTY20};21// SAFETY: dtype is correct22unsafe {23Box::new(LitIter::new(24c.as_materialized_series().array_ref(0).clone(),25self.groups.len(),26c.dtype(),27name,28))29}30},31AggState::AggregatedScalar(_) => {32self.groups();33let c = self.get_values();34let name = if keep_names {35c.name().clone()36} else {37PlSmallStr::EMPTY38};39// SAFETY: dtype is correct40unsafe {41Box::new(FlatIter::new(42c.as_materialized_series().chunks(),43self.groups.len(),44c.dtype(),45name,46))47}48},49AggState::AggregatedList(_) => {50let c = self.get_values();51let list = c.list().unwrap();52let name = if keep_names {53c.name().clone()54} else {55PlSmallStr::EMPTY56};57Box::new(list.amortized_iter_with_name(name))58},59AggState::NotAggregated(_) => {60// we don't take the owned series as we want a reference61let _ = self.aggregated();62let c = self.get_values();63let list = c.list().unwrap();64let name = if keep_names {65c.name().clone()66} else {67PlSmallStr::EMPTY68};69Box::new(list.amortized_iter_with_name(name))70},71}72}73}7475struct LitIter {76len: usize,77offset: usize,78// AmortSeries referenced that series79#[allow(dead_code)]80series_container: Rc<Series>,81item: AmortSeries,82}8384impl LitIter {85/// # Safety86/// Caller must ensure the given `logical` dtype belongs to `array`.87unsafe fn new(array: ArrayRef, len: usize, logical: &DataType, name: PlSmallStr) -> Self {88let series_container = Rc::new(Series::from_chunks_and_dtype_unchecked(89name,90vec![array],91logical,92));9394Self {95offset: 0,96len,97series_container: series_container.clone(),98// SAFETY: we pinned the series so the location is still valid99item: AmortSeries::new(series_container),100}101}102}103104impl Iterator for LitIter {105type Item = Option<AmortSeries>;106107fn next(&mut self) -> Option<Self::Item> {108if self.len == self.offset {109None110} else {111self.offset += 1;112Some(Some(self.item.clone()))113}114}115116fn size_hint(&self) -> (usize, Option<usize>) {117(self.len, Some(self.len))118}119}120121struct FlatIter {122current_array: ArrayRef,123chunks: Vec<ArrayRef>,124offset: usize,125chunk_offset: usize,126len: usize,127// AmortSeries referenced that series128#[allow(dead_code)]129series_container: Rc<Series>,130item: AmortSeries,131}132133impl FlatIter {134/// # Safety135/// Caller must ensure the given `logical` dtype belongs to `array`.136unsafe fn new(chunks: &[ArrayRef], len: usize, logical: &DataType, name: PlSmallStr) -> Self {137let mut stack = Vec::with_capacity(chunks.len());138for chunk in chunks.iter().rev() {139stack.push(chunk.clone())140}141let current_array = stack.pop().unwrap();142let series_container = Rc::new(Series::from_chunks_and_dtype_unchecked(143name,144vec![current_array.clone()],145logical,146));147Self {148current_array,149chunks: stack,150offset: 0,151chunk_offset: 0,152len,153series_container: series_container.clone(),154item: AmortSeries::new(series_container),155}156}157}158159impl Iterator for FlatIter {160type Item = Option<AmortSeries>;161162fn next(&mut self) -> Option<Self::Item> {163if self.len == self.offset {164None165} else {166if self.chunk_offset < self.current_array.len() {167let mut arr = unsafe { self.current_array.sliced_unchecked(self.chunk_offset, 1) };168unsafe { self.item.swap(&mut arr) };169} else {170match self.chunks.pop() {171Some(arr) => {172self.current_array = arr;173self.chunk_offset = 0;174return self.next();175},176None => return None,177}178}179self.offset += 1;180self.chunk_offset += 1;181Some(Some(self.item.clone()))182}183}184fn size_hint(&self) -> (usize, Option<usize>) {185(self.len - self.offset, Some(self.len - self.offset))186}187}188189190