// Source: pola-rs/polars — crates/polars-expr/src/groups/single_key.rs
// (web-viewer page header removed)
use arrow::array::Array;
use arrow::bitmap::MutableBitmap;
use polars_utils::idx_map::total_idx_map::{Entry, TotalIndexMap};
use polars_utils::total_ord::{TotalEq, TotalHash};
use polars_utils::vec::PushUnchecked;

use super::*;
use crate::hash_keys::{HashKeys, for_each_hash_single};
9
10
#[derive(Default)]
11
pub struct SingleKeyHashGrouper<T: PolarsDataType> {
12
idx_map: TotalIndexMap<T::Physical<'static>, ()>,
13
null_idx: IdxSize,
14
}
15
16
impl<K, T: PolarsDataType> SingleKeyHashGrouper<T>
17
where
18
for<'a> T: PolarsDataType<Physical<'a> = K>,
19
K: Default + TotalHash + TotalEq,
20
{
21
pub fn new() -> Self {
22
Self {
23
idx_map: TotalIndexMap::default(),
24
null_idx: IdxSize::MAX,
25
}
26
}
27
28
#[inline(always)]
29
fn insert_key(&mut self, key: T::Physical<'static>) -> IdxSize {
30
match self.idx_map.entry(key) {
31
Entry::Occupied(o) => o.index(),
32
Entry::Vacant(v) => {
33
let index = v.index();
34
v.insert(());
35
index
36
},
37
}
38
}
39
40
#[inline(always)]
41
fn insert_null(&mut self) -> IdxSize {
42
if self.null_idx == IdxSize::MAX {
43
self.null_idx = self.idx_map.push_unmapped_entry(T::Physical::default(), ());
44
}
45
self.null_idx
46
}
47
48
#[inline(always)]
49
fn contains_key(&self, key: &T::Physical<'static>) -> bool {
50
self.idx_map.get(key).is_some()
51
}
52
53
#[inline(always)]
54
fn contains_null(&self) -> bool {
55
self.null_idx < IdxSize::MAX
56
}
57
58
fn finalize_keys(&self, schema: &Schema, keys: Vec<T::Physical<'static>>) -> DataFrame {
59
let (name, dtype) = schema.get_at_index(0).unwrap();
60
let mut keys =
61
T::Array::from_vec(keys, dtype.to_physical().to_arrow(CompatLevel::newest()));
62
if self.null_idx < IdxSize::MAX {
63
let mut validity = MutableBitmap::new();
64
validity.extend_constant(keys.len(), true);
65
validity.set(self.null_idx as usize, false);
66
keys = keys.with_validity_typed(Some(validity.freeze()));
67
}
68
unsafe {
69
let s =
70
Series::from_chunks_and_dtype_unchecked(name.clone(), vec![Box::new(keys)], dtype);
71
DataFrame::new(vec![Column::from(s)]).unwrap()
72
}
73
}
74
}
impl<K, T: PolarsDataType> Grouper for SingleKeyHashGrouper<T>
where
    for<'a> T: PolarsDataType<Physical<'a> = K>,
    K: Default + TotalHash + TotalEq + Clone + Send + Sync + 'static,
{
    // Returns a fresh, empty grouper of the same concrete type.
    fn new_empty(&self) -> Box<dyn Grouper> {
        Box::new(Self::new())
    }

    fn reserve(&mut self, additional: usize) {
        self.idx_map.reserve(additional);
    }

    // The null group (if any) occupies a slot in idx_map, so it is counted.
    fn num_groups(&self) -> IdxSize {
        self.idx_map.len()
    }

    /// Inserts the keys at positions `subset` into this grouper, optionally
    /// appending each row's group index to `group_idxs`.
    ///
    /// # Safety
    /// `hash_keys` must be the `Single` variant wrapping a `ChunkedArray<T>`,
    /// and every index in `subset` must be in-bounds for it
    /// (`get_unchecked`/`value_unchecked` are used below).
    unsafe fn insert_keys_subset(
        &mut self,
        hash_keys: &HashKeys,
        subset: &[IdxSize],
        group_idxs: Option<&mut Vec<IdxSize>>,
    ) {
        let HashKeys::Single(hash_keys) = hash_keys else {
            unreachable!()
        };
        let ca: &ChunkedArray<T> = hash_keys.keys.as_phys_any().downcast_ref().unwrap();
        let arr = ca.downcast_as_array();

        unsafe {
            if arr.has_nulls() {
                if hash_keys.null_is_valid {
                    // Nulls participate as their own group.
                    let groups = subset.iter().map(|idx| {
                        let opt_k = arr.get_unchecked(*idx as usize);
                        if let Some(k) = opt_k {
                            self.insert_key(k)
                        } else {
                            self.insert_null()
                        }
                    });
                    if let Some(group_idxs) = group_idxs {
                        group_idxs.reserve(subset.len());
                        group_idxs.extend(groups);
                    } else {
                        // The map is lazy: drive it so the insertions still happen.
                        groups.for_each(drop);
                    }
                } else {
                    // Nulls are skipped entirely; no group index is emitted for them.
                    let groups = subset.iter().filter_map(|idx| {
                        let opt_k = arr.get_unchecked(*idx as usize);
                        opt_k.map(|k| self.insert_key(k))
                    });
                    if let Some(group_idxs) = group_idxs {
                        group_idxs.reserve(subset.len());
                        group_idxs.extend(groups);
                    } else {
                        groups.for_each(drop);
                    }
                }
            } else {
                // Fast path: no validity mask to consult per element.
                let groups = subset.iter().map(|idx| {
                    let k = arr.value_unchecked(*idx as usize);
                    self.insert_key(k)
                });
                if let Some(group_idxs) = group_idxs {
                    group_idxs.reserve(subset.len());
                    group_idxs.extend(groups);
                } else {
                    groups.for_each(drop);
                }
            }
        }
    }

    // Materializes every group's key, in group-index order, as a one-column
    // DataFrame matching `schema`.
    fn get_keys_in_group_order(&self, schema: &Schema) -> DataFrame {
        unsafe {
            let mut key_rows = Vec::with_capacity(self.idx_map.len() as usize);
            for key in self.idx_map.iter_keys() {
                // SAFETY: capacity was reserved for exactly idx_map.len() keys.
                key_rows.push_unchecked(key.clone());
            }
            self.finalize_keys(schema, key_rows)
        }
    }

    /// For every row in `hash_keys`, checks whether its partition's grouper
    /// contains the key, and pushes the row index into `probe_matches` when
    /// the result differs from `invert` (i.e. `invert` flips the match test).
    ///
    /// # Safety
    /// All groupers must be a SingleKeyHashGrouper<T>.
    unsafe fn probe_partitioned_groupers(
        &self,
        groupers: &[Box<dyn Grouper>],
        hash_keys: &HashKeys,
        partitioner: &HashPartitioner,
        invert: bool,
        probe_matches: &mut Vec<IdxSize>,
    ) {
        let HashKeys::Single(hash_keys) = hash_keys else {
            unreachable!()
        };
        let ca: &ChunkedArray<T> = hash_keys.keys.as_phys_any().downcast_ref().unwrap();
        let arr = ca.downcast_as_array();
        assert!(partitioner.num_partitions() == groupers.len());

        unsafe {
            let null_p = partitioner.null_partition();
            for_each_hash_single(ca, &hash_keys.random_state, |idx, opt_h| {
                let has_group = if let Some(h) = opt_h {
                    // Non-null key: route to the partition owning this hash.
                    let p = partitioner.hash_to_partition(h);
                    let dyn_grouper: &dyn Grouper = &**groupers.get_unchecked(p);
                    // SAFETY (caller contract): every grouper is a
                    // SingleKeyHashGrouper<T>, so this downcast via raw
                    // pointer is valid.
                    let grouper =
                        &*(dyn_grouper as *const dyn Grouper as *const SingleKeyHashGrouper<T>);
                    let key = arr.value_unchecked(idx as usize);
                    grouper.contains_key(&key)
                } else {
                    // Null key: all nulls live in the designated null partition.
                    let dyn_grouper: &dyn Grouper = &**groupers.get_unchecked(null_p);
                    // SAFETY: same caller contract as above.
                    let grouper =
                        &*(dyn_grouper as *const dyn Grouper as *const SingleKeyHashGrouper<T>);
                    grouper.contains_null()
                };

                if has_group != invert {
                    probe_matches.push(idx);
                }
            });
        }
    }

    /// Same probing logic as `probe_partitioned_groupers`, but emits one bit
    /// per row (match XOR `invert`) into `contains_key` instead of collecting
    /// matching row indices.
    ///
    /// # Safety
    /// All groupers must be a SingleKeyHashGrouper<T>.
    unsafe fn contains_key_partitioned_groupers(
        &self,
        groupers: &[Box<dyn Grouper>],
        hash_keys: &HashKeys,
        partitioner: &HashPartitioner,
        invert: bool,
        contains_key: &mut BitmapBuilder,
    ) {
        let HashKeys::Single(hash_keys) = hash_keys else {
            unreachable!()
        };
        let ca: &ChunkedArray<T> = hash_keys.keys.as_phys_any().downcast_ref().unwrap();
        let arr = ca.downcast_as_array();
        assert!(partitioner.num_partitions() == groupers.len());

        unsafe {
            let null_p = partitioner.null_partition();
            for_each_hash_single(ca, &hash_keys.random_state, |idx, opt_h| {
                let has_group = if let Some(h) = opt_h {
                    let p = partitioner.hash_to_partition(h);
                    let dyn_grouper: &dyn Grouper = &**groupers.get_unchecked(p);
                    // SAFETY (caller contract): every grouper is a
                    // SingleKeyHashGrouper<T>, so this downcast via raw
                    // pointer is valid.
                    let grouper =
                        &*(dyn_grouper as *const dyn Grouper as *const SingleKeyHashGrouper<T>);
                    let key = arr.value_unchecked(idx as usize);
                    grouper.contains_key(&key)
                } else {
                    let dyn_grouper: &dyn Grouper = &**groupers.get_unchecked(null_p);
                    // SAFETY: same caller contract as above.
                    let grouper =
                        &*(dyn_grouper as *const dyn Grouper as *const SingleKeyHashGrouper<T>);
                    grouper.contains_null()
                };

                contains_key.push(has_group != invert);
            });
        }
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}