Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/expressions/group_iter.rs
6940 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
use std::rc::Rc;
3
4
use polars_core::series::amortized_iter::AmortSeries;
5
6
use super::*;
7
8
impl AggregationContext<'_> {
9
pub(super) fn iter_groups(
10
&mut self,
11
keep_names: bool,
12
) -> Box<dyn Iterator<Item = Option<AmortSeries>> + '_> {
13
match self.agg_state() {
14
AggState::LiteralScalar(_) => {
15
self.groups();
16
let c = self.get_values().rechunk();
17
let name = if keep_names {
18
c.name().clone()
19
} else {
20
PlSmallStr::EMPTY
21
};
22
// SAFETY: dtype is correct
23
unsafe {
24
Box::new(LitIter::new(
25
c.as_materialized_series().array_ref(0).clone(),
26
self.groups.len(),
27
c.dtype(),
28
name,
29
))
30
}
31
},
32
AggState::AggregatedScalar(_) => {
33
self.groups();
34
let c = self.get_values();
35
let name = if keep_names {
36
c.name().clone()
37
} else {
38
PlSmallStr::EMPTY
39
};
40
// SAFETY: dtype is correct
41
unsafe {
42
Box::new(FlatIter::new(
43
c.as_materialized_series().chunks(),
44
self.groups.len(),
45
c.dtype(),
46
name,
47
))
48
}
49
},
50
AggState::AggregatedList(_) => {
51
let c = self.get_values();
52
let list = c.list().unwrap();
53
let name = if keep_names {
54
c.name().clone()
55
} else {
56
PlSmallStr::EMPTY
57
};
58
Box::new(list.amortized_iter_with_name(name))
59
},
60
AggState::NotAggregated(_) => {
61
// we don't take the owned series as we want a reference
62
let _ = self.aggregated();
63
let c = self.get_values();
64
let list = c.list().unwrap();
65
let name = if keep_names {
66
c.name().clone()
67
} else {
68
PlSmallStr::EMPTY
69
};
70
Box::new(list.amortized_iter_with_name(name))
71
},
72
}
73
}
74
}
75
76
struct LitIter {
77
len: usize,
78
offset: usize,
79
// AmortSeries referenced that series
80
#[allow(dead_code)]
81
series_container: Rc<Series>,
82
item: AmortSeries,
83
}
84
85
impl LitIter {
86
/// # Safety
87
/// Caller must ensure the given `logical` dtype belongs to `array`.
88
unsafe fn new(array: ArrayRef, len: usize, logical: &DataType, name: PlSmallStr) -> Self {
89
let series_container = Rc::new(Series::from_chunks_and_dtype_unchecked(
90
name,
91
vec![array],
92
logical,
93
));
94
95
Self {
96
offset: 0,
97
len,
98
series_container: series_container.clone(),
99
// SAFETY: we pinned the series so the location is still valid
100
item: AmortSeries::new(series_container),
101
}
102
}
103
}
104
105
impl Iterator for LitIter {
106
type Item = Option<AmortSeries>;
107
108
fn next(&mut self) -> Option<Self::Item> {
109
if self.len == self.offset {
110
None
111
} else {
112
self.offset += 1;
113
Some(Some(self.item.clone()))
114
}
115
}
116
117
fn size_hint(&self) -> (usize, Option<usize>) {
118
(self.len, Some(self.len))
119
}
120
}
121
122
struct FlatIter {
123
current_array: ArrayRef,
124
chunks: Vec<ArrayRef>,
125
offset: usize,
126
chunk_offset: usize,
127
len: usize,
128
// AmortSeries referenced that series
129
#[allow(dead_code)]
130
series_container: Rc<Series>,
131
item: AmortSeries,
132
}
133
134
impl FlatIter {
135
/// # Safety
136
/// Caller must ensure the given `logical` dtype belongs to `array`.
137
unsafe fn new(chunks: &[ArrayRef], len: usize, logical: &DataType, name: PlSmallStr) -> Self {
138
let mut stack = Vec::with_capacity(chunks.len());
139
for chunk in chunks.iter().rev() {
140
stack.push(chunk.clone())
141
}
142
let current_array = stack.pop().unwrap();
143
let series_container = Rc::new(Series::from_chunks_and_dtype_unchecked(
144
name,
145
vec![current_array.clone()],
146
logical,
147
));
148
Self {
149
current_array,
150
chunks: stack,
151
offset: 0,
152
chunk_offset: 0,
153
len,
154
series_container: series_container.clone(),
155
item: AmortSeries::new(series_container),
156
}
157
}
158
}
159
160
impl Iterator for FlatIter {
161
type Item = Option<AmortSeries>;
162
163
fn next(&mut self) -> Option<Self::Item> {
164
if self.len == self.offset {
165
None
166
} else {
167
if self.chunk_offset < self.current_array.len() {
168
let mut arr = unsafe { self.current_array.sliced_unchecked(self.chunk_offset, 1) };
169
unsafe { self.item.swap(&mut arr) };
170
} else {
171
match self.chunks.pop() {
172
Some(arr) => {
173
self.current_array = arr;
174
self.chunk_offset = 0;
175
return self.next();
176
},
177
None => return None,
178
}
179
}
180
self.offset += 1;
181
self.chunk_offset += 1;
182
Some(Some(self.item.clone()))
183
}
184
}
185
fn size_hint(&self) -> (usize, Option<usize>) {
186
(self.len - self.offset, Some(self.len - self.offset))
187
}
188
}
189
190