Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/groups/row_encoded.rs
6940 views
1
use arrow::array::Array;
2
use polars_row::RowEncodingOptions;
3
use polars_utils::idx_map::bytes_idx_map::{BytesIndexMap, Entry};
4
use polars_utils::itertools::Itertools;
5
use polars_utils::vec::PushUnchecked;
6
7
use self::row_encode::get_row_encoding_context;
8
use super::*;
9
use crate::hash_keys::HashKeys;
10
11
#[derive(Default)]
12
pub struct RowEncodedHashGrouper {
13
idx_map: BytesIndexMap<()>,
14
}
15
16
impl RowEncodedHashGrouper {
17
pub fn new() -> Self {
18
Self {
19
idx_map: BytesIndexMap::new(),
20
}
21
}
22
23
fn insert_key(&mut self, hash: u64, key: &[u8]) -> IdxSize {
24
match self.idx_map.entry(hash, key) {
25
Entry::Occupied(o) => o.index(),
26
Entry::Vacant(v) => {
27
let index = v.index();
28
v.insert(());
29
index
30
},
31
}
32
}
33
34
fn contains_key(&self, hash: u64, key: &[u8]) -> bool {
35
self.idx_map.contains_key(hash, key)
36
}
37
38
fn finalize_keys(&self, key_schema: &Schema, mut key_rows: Vec<&[u8]>) -> DataFrame {
39
let key_dtypes = key_schema
40
.iter()
41
.map(|(_name, dt)| dt.to_physical().to_arrow(CompatLevel::newest()))
42
.collect::<Vec<_>>();
43
let ctxts = key_schema
44
.iter()
45
.map(|(_, dt)| get_row_encoding_context(dt))
46
.collect::<Vec<_>>();
47
let fields = vec![RowEncodingOptions::new_unsorted(); key_dtypes.len()];
48
let key_columns =
49
unsafe { polars_row::decode::decode_rows(&mut key_rows, &fields, &ctxts, &key_dtypes) };
50
51
let cols = key_schema
52
.iter()
53
.zip(key_columns)
54
.map(|((name, dt), col)| {
55
let s = Series::try_from((name.clone(), col)).unwrap();
56
unsafe { s.from_physical_unchecked(dt) }
57
.unwrap()
58
.into_column()
59
})
60
.collect();
61
unsafe { DataFrame::new_no_checks_height_from_first(cols) }
62
}
63
}
64
65
impl Grouper for RowEncodedHashGrouper {
66
fn new_empty(&self) -> Box<dyn Grouper> {
67
Box::new(Self::new())
68
}
69
70
fn reserve(&mut self, additional: usize) {
71
self.idx_map.reserve(additional);
72
}
73
74
fn num_groups(&self) -> IdxSize {
75
self.idx_map.len()
76
}
77
78
unsafe fn insert_keys_subset(
79
&mut self,
80
keys: &HashKeys,
81
subset: &[IdxSize],
82
group_idxs: Option<&mut Vec<IdxSize>>,
83
) {
84
let HashKeys::RowEncoded(keys) = keys else {
85
unreachable!()
86
};
87
88
unsafe {
89
if let Some(group_idxs) = group_idxs {
90
group_idxs.reserve(subset.len());
91
keys.for_each_hash_subset(subset, |idx, opt_hash| {
92
if let Some(hash) = opt_hash {
93
let key = keys.keys.value_unchecked(idx as usize);
94
group_idxs.push_unchecked(self.insert_key(hash, key));
95
}
96
});
97
} else {
98
keys.for_each_hash_subset(subset, |idx, opt_hash| {
99
if let Some(hash) = opt_hash {
100
let key = keys.keys.value_unchecked(idx as usize);
101
self.insert_key(hash, key);
102
}
103
});
104
}
105
}
106
}
107
108
fn get_keys_in_group_order(&self, schema: &Schema) -> DataFrame {
109
unsafe {
110
let mut key_rows: Vec<&[u8]> = Vec::with_capacity(self.idx_map.len() as usize);
111
for (_, key) in self.idx_map.iter_hash_keys() {
112
key_rows.push_unchecked(key);
113
}
114
self.finalize_keys(schema, key_rows)
115
}
116
}
117
118
/// # Safety
119
/// All groupers must be a RowEncodedHashGrouper.
120
unsafe fn probe_partitioned_groupers(
121
&self,
122
groupers: &[Box<dyn Grouper>],
123
keys: &HashKeys,
124
partitioner: &HashPartitioner,
125
invert: bool,
126
probe_matches: &mut Vec<IdxSize>,
127
) {
128
let HashKeys::RowEncoded(keys) = keys else {
129
unreachable!()
130
};
131
assert!(partitioner.num_partitions() == groupers.len());
132
133
unsafe {
134
if keys.keys.has_nulls() {
135
for (idx, hash) in keys.hashes.values_iter().enumerate_idx() {
136
let has_group = if let Some(key) = keys.keys.get_unchecked(idx as usize) {
137
let p = partitioner.hash_to_partition(*hash);
138
let dyn_grouper: &dyn Grouper = &**groupers.get_unchecked(p);
139
let grouper =
140
&*(dyn_grouper as *const dyn Grouper as *const RowEncodedHashGrouper);
141
grouper.contains_key(*hash, key)
142
} else {
143
false
144
};
145
146
if has_group != invert {
147
probe_matches.push(idx);
148
}
149
}
150
} else {
151
for (idx, (hash, key)) in keys
152
.hashes
153
.values_iter()
154
.zip(keys.keys.values_iter())
155
.enumerate_idx()
156
{
157
let p = partitioner.hash_to_partition(*hash);
158
let dyn_grouper: &dyn Grouper = &**groupers.get_unchecked(p);
159
let grouper =
160
&*(dyn_grouper as *const dyn Grouper as *const RowEncodedHashGrouper);
161
if grouper.contains_key(*hash, key) != invert {
162
probe_matches.push(idx);
163
}
164
}
165
}
166
}
167
}
168
169
/// # Safety
170
/// All groupers must be a RowEncodedHashGrouper.
171
unsafe fn contains_key_partitioned_groupers(
172
&self,
173
groupers: &[Box<dyn Grouper>],
174
keys: &HashKeys,
175
partitioner: &HashPartitioner,
176
invert: bool,
177
contains_key: &mut BitmapBuilder,
178
) {
179
let HashKeys::RowEncoded(keys) = keys else {
180
unreachable!()
181
};
182
assert!(partitioner.num_partitions() == groupers.len());
183
184
unsafe {
185
if keys.keys.has_nulls() {
186
for (idx, hash) in keys.hashes.values_iter().enumerate_idx() {
187
let has_group = if let Some(key) = keys.keys.get_unchecked(idx as usize) {
188
let p = partitioner.hash_to_partition(*hash);
189
let dyn_grouper: &dyn Grouper = &**groupers.get_unchecked(p);
190
let grouper =
191
&*(dyn_grouper as *const dyn Grouper as *const RowEncodedHashGrouper);
192
grouper.contains_key(*hash, key)
193
} else {
194
false
195
};
196
197
contains_key.push(has_group != invert);
198
}
199
} else {
200
for (hash, key) in keys.hashes.values_iter().zip(keys.keys.values_iter()) {
201
let p = partitioner.hash_to_partition(*hash);
202
let dyn_grouper: &dyn Grouper = &**groupers.get_unchecked(p);
203
let grouper =
204
&*(dyn_grouper as *const dyn Grouper as *const RowEncodedHashGrouper);
205
contains_key.push(grouper.contains_key(*hash, key) != invert);
206
}
207
}
208
}
209
}
210
211
fn as_any(&self) -> &dyn Any {
212
self
213
}
214
}
215
216