Path: blob/main/crates/polars-expr/src/expressions/group_iter.rs
8424 views
#![allow(unsafe_op_in_unsafe_fn)]1use std::rc::Rc;23use polars_core::series::amortized_iter::AmortSeries;4use rayon::iter::IntoParallelIterator;5use rayon::prelude::*;67use super::*;89impl AggregationContext<'_> {10pub(super) fn iter_groups(11&mut self,12keep_names: bool,13) -> Box<dyn Iterator<Item = Option<AmortSeries>> + '_> {14match self.agg_state() {15AggState::LiteralScalar(_) => {16self.groups();17let c = self.get_values().rechunk();18let name = if keep_names {19c.name().clone()20} else {21PlSmallStr::EMPTY22};23// SAFETY: dtype is correct24unsafe {25Box::new(LitIter::new(26c.as_materialized_series().array_ref(0).clone(),27self.groups.len(),28c.dtype(),29name,30))31}32},33AggState::AggregatedScalar(_) => {34self.groups();35let c = self.get_values();36let name = if keep_names {37c.name().clone()38} else {39PlSmallStr::EMPTY40};41// SAFETY: dtype is correct42unsafe {43Box::new(FlatIter::new(44c.as_materialized_series().chunks(),45self.groups.len(),46c.dtype(),47name,48))49}50},51AggState::AggregatedList(_) => {52let c = self.get_values();53let list = c.list().unwrap();54let name = if keep_names {55c.name().clone()56} else {57PlSmallStr::EMPTY58};59Box::new(list.amortized_iter_with_name(name))60},61AggState::NotAggregated(_) => {62// we don't take the owned series as we want a reference63let _ = self.aggregated();64let c = self.get_values();65let list = c.list().unwrap();66let name = if keep_names {67c.name().clone()68} else {69PlSmallStr::EMPTY70};71Box::new(list.amortized_iter_with_name(name))72},73}74}75}7677impl AggregationContext<'_> {78/// Iterate over groups lazily, i.e., without greedy aggregation into an AggList.79pub(super) fn iter_groups_lazy(&mut self) -> impl Iterator<Item = Option<Series>> + '_ {80match self.agg_state() {81AggState::NotAggregated(_) => {82let groups = self.groups();83let len = groups.len();84let groups = Arc::new(groups.clone());8586let c = self.get_values().rechunk();8788let col = Arc::new(c);8990(0..len).map(move |idx| {91let g = groups.get(idx);92match g {93GroupsIndicator::Idx(_) => unreachable!(),94GroupsIndicator::Slice(s) => Some(95col.slice(s[0] as i64, s[1] as usize)96.into_materialized_series()97.clone(),98),99}100})101},102_ => unreachable!(),103}104}105106/// Iterate parallel over groups lazily, i.e., without greedy aggregation into an AggList.107pub(super) fn par_iter_groups_lazy(108&mut self,109) -> impl IndexedParallelIterator<Item = Option<Series>> + '_ {110match self.agg_state() {111AggState::NotAggregated(_) => {112let groups = self.groups();113let len = groups.len();114let groups = Arc::new(groups.clone());115116let c = self.get_values().rechunk();117118let col = Arc::new(c);119120(0..len).into_par_iter().map(move |idx| {121let g = groups.get(idx);122match g {123GroupsIndicator::Idx(_) => unreachable!(),124GroupsIndicator::Slice(s) => Some(125col.slice(s[0] as i64, s[1] as usize)126.into_materialized_series()127.clone(),128),129}130})131},132_ => unreachable!(),133}134}135}136137struct LitIter {138len: usize,139offset: usize,140// AmortSeries referenced that series141#[allow(dead_code)]142series_container: Rc<Series>,143item: AmortSeries,144}145146impl LitIter {147/// # Safety148/// Caller must ensure the given `logical` dtype belongs to `array`.149unsafe fn new(array: ArrayRef, len: usize, logical: &DataType, name: PlSmallStr) -> Self {150let series_container = Rc::new(Series::from_chunks_and_dtype_unchecked(151name,152vec![array],153logical,154));155156Self {157offset: 0,158len,159series_container: series_container.clone(),160// SAFETY: we pinned the series so the location is still valid161item: AmortSeries::new(series_container),162}163}164}165166impl Iterator for LitIter {167type Item = Option<AmortSeries>;168169fn next(&mut self) -> Option<Self::Item> {170if self.len == self.offset {171None172} else {173self.offset += 1;174Some(Some(self.item.clone()))175}176}177178fn size_hint(&self) -> (usize, Option<usize>) {179(self.len, Some(self.len))180}181}182183struct FlatIter {184current_array: ArrayRef,185chunks: Vec<ArrayRef>,186offset: usize,187chunk_offset: usize,188len: usize,189// AmortSeries referenced that series190#[allow(dead_code)]191series_container: Rc<Series>,192item: AmortSeries,193}194195impl FlatIter {196/// # Safety197/// Caller must ensure the given `logical` dtype belongs to `array`.198unsafe fn new(chunks: &[ArrayRef], len: usize, logical: &DataType, name: PlSmallStr) -> Self {199let mut stack = Vec::with_capacity(chunks.len());200for chunk in chunks.iter().rev() {201stack.push(chunk.clone())202}203let current_array = stack.pop().unwrap();204let series_container = Rc::new(Series::from_chunks_and_dtype_unchecked(205name,206vec![current_array.clone()],207logical,208));209Self {210current_array,211chunks: stack,212offset: 0,213chunk_offset: 0,214len,215series_container: series_container.clone(),216item: AmortSeries::new(series_container),217}218}219}220221impl Iterator for FlatIter {222type Item = Option<AmortSeries>;223224fn next(&mut self) -> Option<Self::Item> {225if self.len == self.offset {226None227} else {228if self.chunk_offset < self.current_array.len() {229let mut arr = unsafe { self.current_array.sliced_unchecked(self.chunk_offset, 1) };230unsafe { self.item.swap(&mut arr) };231} else {232match self.chunks.pop() {233Some(arr) => {234self.current_array = arr;235self.chunk_offset = 0;236return self.next();237},238None => return None,239}240}241self.offset += 1;242self.chunk_offset += 1;243Some(Some(self.item.clone()))244}245}246fn size_hint(&self) -> (usize, Option<usize>) {247(self.len - self.offset, Some(self.len - self.offset))248}249}250251252