GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/frame/join/hash_join/single_keys.rs

use polars_utils::hashing::{DirtyHash, hash_to_partition};
use polars_utils::idx_vec::IdxVec;
use polars_utils::nulls::IsNull;
use polars_utils::sync::SyncPtr;
use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash};
use polars_utils::unitvec;

use super::*;

// TODO: we should compute the number of threads / partition size we'll use.
// let avail_threads = POOL.current_num_threads();
// let n_threads = (num_keys / MIN_ELEMS_PER_THREAD).clamp(1, avail_threads);
// Use a small element per thread threshold for debugging/testing purposes.
const MIN_ELEMS_PER_THREAD: usize = if cfg!(debug_assertions) { 1 } else { 128 };
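
/// Builds one hash table per hash partition, mapping each key to the list of
/// row indices at which it occurs. `keys` holds one pre-split key iterator per
/// thread; if `nulls_equal` is false, null keys are skipped entirely.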
pub(crate) fn build_tables<T, I>(
    keys: Vec<I>,
    nulls_equal: bool,
) -> Vec<PlHashMap<<T as ToTotalOrd>::TotalOrdItem, IdxVec>>
where
    T: TotalHash + TotalEq + ToTotalOrd,
    <T as ToTotalOrd>::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull,
    I: IntoIterator<Item = T> + Send + Sync + Clone,
{
    // TODO: change interface to split the input here, instead of taking
    // pre-split input iterators.
    let n_partitions = keys.len();
    let n_threads = n_partitions;
    let num_keys_est: usize = keys
        .iter()
        .map(|k| k.clone().into_iter().size_hint().0)
        .sum();

    // Don't bother parallelizing anything for small inputs.
    if num_keys_est < 2 * MIN_ELEMS_PER_THREAD {
        let mut hm: PlHashMap<T::TotalOrdItem, IdxVec> = PlHashMap::new();
        let mut offset = 0;
        for it in keys {
            for k in it {
                let k = k.to_total_ord();
                if !k.is_null() || nulls_equal {
                    hm.entry(k).or_default().push(offset);
                }
                offset += 1;
            }
        }
        return vec![hm];
    }
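    // NB: the serial path above returns a single table, whereas the parallel
    // path below returns one table per partition.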

    POOL.install(|| {
        // Compute the number of elements in each partition for each portion.
        let per_thread_partition_sizes: Vec<Vec<usize>> = keys
            .par_iter()
            .with_max_len(1)
            .map(|key_portion| {
                let mut partition_sizes = vec![0; n_partitions];
                for key in key_portion.clone() {
                    let key = key.to_total_ord();
                    let p = hash_to_partition(key.dirty_hash(), n_partitions);
                    unsafe {
                        *partition_sizes.get_unchecked_mut(p) += 1;
                    }
                }
                partition_sizes
            })
            .collect();

        // Compute output offsets with a cumulative sum.
        let mut per_thread_partition_offsets = vec![0; n_partitions * n_threads + 1];
        let mut partition_offsets = vec![0; n_partitions + 1];
        let mut cum_offset = 0;
        for p in 0..n_partitions {
            partition_offsets[p] = cum_offset;
            for t in 0..n_threads {
                per_thread_partition_offsets[t * n_partitions + p] = cum_offset;
                cum_offset += per_thread_partition_sizes[t][p];
            }
        }
        let num_keys = cum_offset;
        per_thread_partition_offsets[n_threads * n_partitions] = num_keys;
        partition_offsets[n_partitions] = num_keys;
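        // Worked sketch (sizes invented for illustration): with n_threads =
        // n_partitions = 2 and per_thread_partition_sizes = [[2, 3], [1, 4]],
        // the loop above yields
        //   partition_offsets            = [0, 3, 10]
        //   per_thread_partition_offsets = [0, 3, 2, 6, 10]
        // so thread t writes partition p starting at
        // per_thread_partition_offsets[t * n_partitions + p], and partition p
        // as a whole spans partition_offsets[p]..partition_offsets[p + 1].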

        // TODO: we wouldn't need this if we changed our interface to split the
        // input in this function, instead of taking a vec of iterators.
        let mut per_thread_input_offsets = vec![0; n_partitions];
        cum_offset = 0;
        for t in 0..n_threads {
            per_thread_input_offsets[t] = cum_offset;
            cum_offset += per_thread_partition_sizes[t]
                .iter()
                .take(n_partitions)
                .sum::<usize>();
        }
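        // (Note: the vec above is allocated with n_partitions elements but
        // indexed by thread; this is only correct because n_threads ==
        // n_partitions here.)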

        // Scatter values into partitions.
        let mut scatter_keys: Vec<T::TotalOrdItem> = Vec::with_capacity(num_keys);
        let mut scatter_idxs: Vec<IdxSize> = Vec::with_capacity(num_keys);
        let scatter_keys_ptr = unsafe { SyncPtr::new(scatter_keys.as_mut_ptr()) };
        let scatter_idxs_ptr = unsafe { SyncPtr::new(scatter_idxs.as_mut_ptr()) };
        keys.into_par_iter()
            .with_max_len(1)
            .enumerate()
            .for_each(|(t, key_portion)| {
                let mut partition_offsets =
                    per_thread_partition_offsets[t * n_partitions..(t + 1) * n_partitions].to_vec();
                for (i, key) in key_portion.into_iter().enumerate() {
                    let key = key.to_total_ord();
                    unsafe {
                        let p = hash_to_partition(key.dirty_hash(), n_partitions);
                        let off = partition_offsets.get_unchecked_mut(p);
                        *scatter_keys_ptr.get().add(*off) = key;
                        *scatter_idxs_ptr.get().add(*off) =
                            (per_thread_input_offsets[t] + i) as IdxSize;
                        *off += 1;
                    }
                }
            });
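        // SAFETY: every index in 0..num_keys was written exactly once above,
        // because the per-thread partition offset ranges tile the output; the
        // set_len calls therefore only expose initialized elements.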
        unsafe {
            scatter_keys.set_len(num_keys);
            scatter_idxs.set_len(num_keys);
        }

        // Build tables.
        (0..n_partitions)
            .into_par_iter()
            .with_max_len(1)
            .map(|p| {
                // Resizing the hash map is very, very expensive. That's why we
                // adopt a hybrid strategy: we start with a small hash map,
                // which suffices for a highly skewed relation. If it fills up,
                // we immediately reserve enough for a full-cardinality data
                // set.
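                // E.g. (numbers invented for illustration): a 1_000_000-row
                // partition starts at max(_HASHMAP_INIT_SIZE, 1_000_000 / 64)
                // = 15_625 slots (assuming _HASHMAP_INIT_SIZE is smaller) and
                // grows to the full size only if the keys prove to be high
                // cardinality.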
                let partition_range = partition_offsets[p]..partition_offsets[p + 1];
                let full_size = partition_range.len();
                let mut conservative_size = _HASHMAP_INIT_SIZE.max(full_size / 64);
                let mut hm: PlHashMap<T::TotalOrdItem, IdxVec> =
                    PlHashMap::with_capacity(conservative_size);

                unsafe {
                    for i in partition_range {
                        if hm.len() == conservative_size {
                            hm.reserve(full_size - conservative_size);
                            conservative_size = 0; // Hack to ensure we never hit this branch again.
                        }

                        let key = *scatter_keys.get_unchecked(i);

                        if !key.is_null() || nulls_equal {
                            let idx = *scatter_idxs.get_unchecked(i);
                            match hm.entry(key) {
                                Entry::Occupied(mut o) => {
                                    o.get_mut().push(idx as IdxSize);
                                },
                                Entry::Vacant(v) => {
                                    let iv = unitvec![idx as IdxSize];
                                    v.insert(iv);
                                },
                            };
                        }
                    }
                }

                hm
            })
            .collect()
    })
}

// We determine the offsets here so that we later know which index to store in
// the join tuples.
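// For example (sizes invented): probe portions whose upper size hints are
// [3, 5, 2] yield offsets [0, 3, 8]. (Note: this unwraps each iterator's upper
// size hint, so it must be present and exact.)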
pub(super) fn probe_to_offsets<T, I>(probe: &[I]) -> Vec<usize>
where
    I: IntoIterator<Item = T> + Clone,
{
    probe
        .iter()
        .map(|ph| ph.clone().into_iter().size_hint().1.unwrap())
        .scan(0, |state, val| {
            let out = *state;
            *state += val;
            Some(out)
        })
        .collect()
}