Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/expressions/group_iter.rs
8424 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
use std::rc::Rc;
3
4
use polars_core::series::amortized_iter::AmortSeries;
5
use rayon::iter::IntoParallelIterator;
6
use rayon::prelude::*;
7
8
use super::*;
9
10
impl AggregationContext<'_> {
11
pub(super) fn iter_groups(
12
&mut self,
13
keep_names: bool,
14
) -> Box<dyn Iterator<Item = Option<AmortSeries>> + '_> {
15
match self.agg_state() {
16
AggState::LiteralScalar(_) => {
17
self.groups();
18
let c = self.get_values().rechunk();
19
let name = if keep_names {
20
c.name().clone()
21
} else {
22
PlSmallStr::EMPTY
23
};
24
// SAFETY: dtype is correct
25
unsafe {
26
Box::new(LitIter::new(
27
c.as_materialized_series().array_ref(0).clone(),
28
self.groups.len(),
29
c.dtype(),
30
name,
31
))
32
}
33
},
34
AggState::AggregatedScalar(_) => {
35
self.groups();
36
let c = self.get_values();
37
let name = if keep_names {
38
c.name().clone()
39
} else {
40
PlSmallStr::EMPTY
41
};
42
// SAFETY: dtype is correct
43
unsafe {
44
Box::new(FlatIter::new(
45
c.as_materialized_series().chunks(),
46
self.groups.len(),
47
c.dtype(),
48
name,
49
))
50
}
51
},
52
AggState::AggregatedList(_) => {
53
let c = self.get_values();
54
let list = c.list().unwrap();
55
let name = if keep_names {
56
c.name().clone()
57
} else {
58
PlSmallStr::EMPTY
59
};
60
Box::new(list.amortized_iter_with_name(name))
61
},
62
AggState::NotAggregated(_) => {
63
// we don't take the owned series as we want a reference
64
let _ = self.aggregated();
65
let c = self.get_values();
66
let list = c.list().unwrap();
67
let name = if keep_names {
68
c.name().clone()
69
} else {
70
PlSmallStr::EMPTY
71
};
72
Box::new(list.amortized_iter_with_name(name))
73
},
74
}
75
}
76
}
77
78
impl AggregationContext<'_> {
79
/// Iterate over groups lazily, i.e., without greedy aggregation into an AggList.
80
pub(super) fn iter_groups_lazy(&mut self) -> impl Iterator<Item = Option<Series>> + '_ {
81
match self.agg_state() {
82
AggState::NotAggregated(_) => {
83
let groups = self.groups();
84
let len = groups.len();
85
let groups = Arc::new(groups.clone());
86
87
let c = self.get_values().rechunk();
88
89
let col = Arc::new(c);
90
91
(0..len).map(move |idx| {
92
let g = groups.get(idx);
93
match g {
94
GroupsIndicator::Idx(_) => unreachable!(),
95
GroupsIndicator::Slice(s) => Some(
96
col.slice(s[0] as i64, s[1] as usize)
97
.into_materialized_series()
98
.clone(),
99
),
100
}
101
})
102
},
103
_ => unreachable!(),
104
}
105
}
106
107
/// Iterate parallel over groups lazily, i.e., without greedy aggregation into an AggList.
108
pub(super) fn par_iter_groups_lazy(
109
&mut self,
110
) -> impl IndexedParallelIterator<Item = Option<Series>> + '_ {
111
match self.agg_state() {
112
AggState::NotAggregated(_) => {
113
let groups = self.groups();
114
let len = groups.len();
115
let groups = Arc::new(groups.clone());
116
117
let c = self.get_values().rechunk();
118
119
let col = Arc::new(c);
120
121
(0..len).into_par_iter().map(move |idx| {
122
let g = groups.get(idx);
123
match g {
124
GroupsIndicator::Idx(_) => unreachable!(),
125
GroupsIndicator::Slice(s) => Some(
126
col.slice(s[0] as i64, s[1] as usize)
127
.into_materialized_series()
128
.clone(),
129
),
130
}
131
})
132
},
133
_ => unreachable!(),
134
}
135
}
136
}
137
138
struct LitIter {
139
len: usize,
140
offset: usize,
141
// AmortSeries referenced that series
142
#[allow(dead_code)]
143
series_container: Rc<Series>,
144
item: AmortSeries,
145
}
146
147
impl LitIter {
148
/// # Safety
149
/// Caller must ensure the given `logical` dtype belongs to `array`.
150
unsafe fn new(array: ArrayRef, len: usize, logical: &DataType, name: PlSmallStr) -> Self {
151
let series_container = Rc::new(Series::from_chunks_and_dtype_unchecked(
152
name,
153
vec![array],
154
logical,
155
));
156
157
Self {
158
offset: 0,
159
len,
160
series_container: series_container.clone(),
161
// SAFETY: we pinned the series so the location is still valid
162
item: AmortSeries::new(series_container),
163
}
164
}
165
}
166
167
impl Iterator for LitIter {
168
type Item = Option<AmortSeries>;
169
170
fn next(&mut self) -> Option<Self::Item> {
171
if self.len == self.offset {
172
None
173
} else {
174
self.offset += 1;
175
Some(Some(self.item.clone()))
176
}
177
}
178
179
fn size_hint(&self) -> (usize, Option<usize>) {
180
(self.len, Some(self.len))
181
}
182
}
183
184
struct FlatIter {
185
current_array: ArrayRef,
186
chunks: Vec<ArrayRef>,
187
offset: usize,
188
chunk_offset: usize,
189
len: usize,
190
// AmortSeries referenced that series
191
#[allow(dead_code)]
192
series_container: Rc<Series>,
193
item: AmortSeries,
194
}
195
196
impl FlatIter {
197
/// # Safety
198
/// Caller must ensure the given `logical` dtype belongs to `array`.
199
unsafe fn new(chunks: &[ArrayRef], len: usize, logical: &DataType, name: PlSmallStr) -> Self {
200
let mut stack = Vec::with_capacity(chunks.len());
201
for chunk in chunks.iter().rev() {
202
stack.push(chunk.clone())
203
}
204
let current_array = stack.pop().unwrap();
205
let series_container = Rc::new(Series::from_chunks_and_dtype_unchecked(
206
name,
207
vec![current_array.clone()],
208
logical,
209
));
210
Self {
211
current_array,
212
chunks: stack,
213
offset: 0,
214
chunk_offset: 0,
215
len,
216
series_container: series_container.clone(),
217
item: AmortSeries::new(series_container),
218
}
219
}
220
}
221
222
impl Iterator for FlatIter {
223
type Item = Option<AmortSeries>;
224
225
fn next(&mut self) -> Option<Self::Item> {
226
if self.len == self.offset {
227
None
228
} else {
229
if self.chunk_offset < self.current_array.len() {
230
let mut arr = unsafe { self.current_array.sliced_unchecked(self.chunk_offset, 1) };
231
unsafe { self.item.swap(&mut arr) };
232
} else {
233
match self.chunks.pop() {
234
Some(arr) => {
235
self.current_array = arr;
236
self.chunk_offset = 0;
237
return self.next();
238
},
239
None => return None,
240
}
241
}
242
self.offset += 1;
243
self.chunk_offset += 1;
244
Some(Some(self.item.clone()))
245
}
246
}
247
fn size_hint(&self) -> (usize, Option<usize>) {
248
(self.len - self.offset, Some(self.len - self.offset))
249
}
250
}
251
252