Path: crates/polars-ops/src/frame/join/hash_join/single_keys.rs
use polars_utils::hashing::{DirtyHash, hash_to_partition};
use polars_utils::idx_vec::IdxVec;
use polars_utils::nulls::IsNull;
use polars_utils::sync::SyncPtr;
use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash};
use polars_utils::unitvec;

use super::*;

// FIXME: we should compute the number of threads / partition size we'll use.
// let avail_threads = POOL.current_num_threads();
// let n_threads = (num_keys / MIN_ELEMS_PER_THREAD).clamp(1, avail_threads);
// Use a small element per thread threshold for debugging/testing purposes.
const MIN_ELEMS_PER_THREAD: usize = if cfg!(debug_assertions) { 1 } else { 128 };

pub(crate) fn build_tables<T, I>(
    keys: Vec<I>,
    nulls_equal: bool,
) -> Vec<PlHashMap<<T as ToTotalOrd>::TotalOrdItem, IdxVec>>
where
    T: TotalHash + TotalEq + ToTotalOrd,
    <T as ToTotalOrd>::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull,
    I: IntoIterator<Item = T> + Send + Sync + Clone,
{
    // FIXME: change interface to split the input here, instead of taking
    // pre-split input iterators.
    let n_partitions = keys.len();
    let n_threads = n_partitions;
    let num_keys_est: usize = keys
        .iter()
        .map(|k| k.clone().into_iter().size_hint().0)
        .sum();

    // Don't bother parallelizing anything for small inputs.
    if num_keys_est < 2 * MIN_ELEMS_PER_THREAD {
        let mut hm: PlHashMap<T::TotalOrdItem, IdxVec> = PlHashMap::new();
        let mut offset = 0;
        for it in keys {
            for k in it {
                let k = k.to_total_ord();
                if !k.is_null() || nulls_equal {
                    hm.entry(k).or_default().push(offset);
                }
                offset += 1;
            }
        }
        return vec![hm];
    }

    POOL.install(|| {
        // Compute the number of elements in each partition for each portion.
        let per_thread_partition_sizes: Vec<Vec<usize>> = keys
            .par_iter()
            .with_max_len(1)
            .map(|key_portion| {
                let mut partition_sizes = vec![0; n_partitions];
                for key in key_portion.clone() {
                    let key = key.to_total_ord();
                    let p = hash_to_partition(key.dirty_hash(), n_partitions);
                    unsafe {
                        *partition_sizes.get_unchecked_mut(p) += 1;
                    }
                }
                partition_sizes
            })
            .collect();

        // Compute output offsets with a cumulative sum.
        let mut per_thread_partition_offsets = vec![0; n_partitions * n_threads + 1];
        let mut partition_offsets = vec![0; n_partitions + 1];
        let mut cum_offset = 0;
        for p in 0..n_partitions {
            partition_offsets[p] = cum_offset;
            for t in 0..n_threads {
                per_thread_partition_offsets[t * n_partitions + p] = cum_offset;
                cum_offset += per_thread_partition_sizes[t][p];
            }
        }
        let num_keys = cum_offset;
        per_thread_partition_offsets[n_threads * n_partitions] = num_keys;
        partition_offsets[n_partitions] = num_keys;

        // FIXME: we wouldn't need this if we changed our interface to split the
        // input in this function, instead of taking a vec of iterators.
        let mut per_thread_input_offsets = vec![0; n_partitions];
        cum_offset = 0;
        for t in 0..n_threads {
            per_thread_input_offsets[t] = cum_offset;
            for p in 0..n_partitions {
                cum_offset += per_thread_partition_sizes[t][p];
            }
        }

        // Scatter values into partitions.
        let mut scatter_keys: Vec<T::TotalOrdItem> = Vec::with_capacity(num_keys);
        let mut scatter_idxs: Vec<IdxSize> = Vec::with_capacity(num_keys);
        let scatter_keys_ptr = unsafe { SyncPtr::new(scatter_keys.as_mut_ptr()) };
        let scatter_idxs_ptr = unsafe { SyncPtr::new(scatter_idxs.as_mut_ptr()) };
        keys.into_par_iter()
            .with_max_len(1)
            .enumerate()
            .for_each(|(t, key_portion)| {
                let mut partition_offsets =
                    per_thread_partition_offsets[t * n_partitions..(t + 1) * n_partitions].to_vec();
                for (i, key) in key_portion.into_iter().enumerate() {
                    let key = key.to_total_ord();
                    unsafe {
                        let p = hash_to_partition(key.dirty_hash(), n_partitions);
                        let off = partition_offsets.get_unchecked_mut(p);
                        *scatter_keys_ptr.get().add(*off) = key;
                        *scatter_idxs_ptr.get().add(*off) =
                            (per_thread_input_offsets[t] + i) as IdxSize;
                        *off += 1;
                    }
                }
            });
        unsafe {
            scatter_keys.set_len(num_keys);
            scatter_idxs.set_len(num_keys);
        }

        // Build tables.
        (0..n_partitions)
            .into_par_iter()
            .with_max_len(1)
            .map(|p| {
                // Resizing the hash map is very, very expensive. That's why we
                // adopt a hybrid strategy: we assume an initially small hash
                // map, which would satisfy a highly skewed relation. If this
                // fills up we immediately reserve enough for a full cardinality
                // data set.
                let partition_range = partition_offsets[p]..partition_offsets[p + 1];
                let full_size = partition_range.len();
                let mut conservative_size = _HASHMAP_INIT_SIZE.max(full_size / 64);
                let mut hm: PlHashMap<T::TotalOrdItem, IdxVec> =
                    PlHashMap::with_capacity(conservative_size);

                unsafe {
                    for i in partition_range {
                        if hm.len() == conservative_size {
                            hm.reserve(full_size - conservative_size);
                            conservative_size = 0; // Hack to ensure we never hit this branch again.
                        }

                        let key = *scatter_keys.get_unchecked(i);

                        if !key.is_null() || nulls_equal {
                            let idx = *scatter_idxs.get_unchecked(i);
                            match hm.entry(key) {
                                Entry::Occupied(mut o) => {
                                    o.get_mut().push(idx as IdxSize);
                                },
                                Entry::Vacant(v) => {
                                    let iv = unitvec![idx as IdxSize];
                                    v.insert(iv);
                                },
                            };
                        }
                    }
                }

                hm
            })
            .collect()
    })
}

// We determine the offset so that we later know which index to store in the join tuples.
pub(super) fn probe_to_offsets<T, I>(probe: &[I]) -> Vec<usize>
where
    I: IntoIterator<Item = T> + Clone,
{
    probe
        .iter()
        .map(|ph| ph.clone().into_iter().size_hint().1.unwrap())
        .scan(0, |state, val| {
            let out = *state;
            *state += val;
            Some(out)
        })
        .collect()
}
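// ---------------------------------------------------------------------------
// Editorial sketches, not part of the original file: two small, self-contained
// tests illustrating the index computations above. The module name
// `sketch_tests` and the sample sizes are made up for illustration.
// ---------------------------------------------------------------------------
#[cfg(test)]
mod sketch_tests {
    use super::*;

    // `probe_to_offsets` yields the exclusive prefix sum of the chunk lengths,
    // i.e. the global index of each chunk's first element.
    #[test]
    fn probe_to_offsets_is_exclusive_prefix_sum() {
        // Chunks of lengths 3, 2 and 1 start at global indices 0, 3 and 5.
        let probe: Vec<Vec<u32>> = vec![vec![1, 2, 3], vec![4, 5], vec![6]];
        assert_eq!(probe_to_offsets(&probe), vec![0, 3, 5]);
    }

    // A hand-worked instance of the offset layout `build_tables` computes:
    // partitions are laid out contiguously in the scatter buffers, and within
    // each partition the threads' elements appear in thread order.
    #[test]
    fn per_thread_partition_offset_layout() {
        // 2 threads, 2 partitions; sizes[t][p] = thread t's count for partition p.
        let sizes = [[3usize, 1], [2, 4]];
        let (n_threads, n_partitions) = (2, 2);

        let mut per_thread = vec![0; n_partitions * n_threads + 1];
        let mut per_partition = vec![0; n_partitions + 1];
        let mut cum = 0;
        for p in 0..n_partitions {
            per_partition[p] = cum;
            for t in 0..n_threads {
                per_thread[t * n_partitions + p] = cum;
                cum += sizes[t][p];
            }
        }
        per_thread[n_threads * n_partitions] = cum;
        per_partition[n_partitions] = cum;

        // Partition 0 occupies [0, 5): thread 0 writes [0, 3), thread 1 [3, 5).
        // Partition 1 occupies [5, 10): thread 0 writes [5, 6), thread 1 [6, 10).
        assert_eq!(per_partition, vec![0, 5, 10]);
        assert_eq!(per_thread, vec![0, 5, 3, 6, 10]);
    }
}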