Path: blob/main/crates/polars-ops/src/frame/join/hash_join/single_keys.rs
use polars_utils::hashing::{DirtyHash, hash_to_partition};
use polars_utils::idx_vec::IdxVec;
use polars_utils::nulls::IsNull;
use polars_utils::sync::SyncPtr;
use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash};
use polars_utils::unitvec;

use super::*;

// TODO: we should compute the number of threads / partition size we'll use.
// let avail_threads = POOL.current_num_threads();
// let n_threads = (num_keys / MIN_ELEMS_PER_THREAD).clamp(1, avail_threads);
// Use a small element per thread threshold for debugging/testing purposes.
const MIN_ELEMS_PER_THREAD: usize = if cfg!(debug_assertions) { 1 } else { 128 };

pub(crate) fn build_tables<T, I>(
    keys: Vec<I>,
    nulls_equal: bool,
) -> Vec<PlHashMap<<T as ToTotalOrd>::TotalOrdItem, IdxVec>>
where
    T: TotalHash + TotalEq + ToTotalOrd,
    <T as ToTotalOrd>::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull,
    I: IntoIterator<Item = T> + Send + Sync + Clone,
{
    // TODO: change interface to split the input here, instead of taking
    // pre-split input iterators.
    let n_partitions = keys.len();
    let n_threads = n_partitions;
    let num_keys_est: usize = keys
        .iter()
        .map(|k| k.clone().into_iter().size_hint().0)
        .sum();

    // Don't bother parallelizing anything for small inputs.
    if num_keys_est < 2 * MIN_ELEMS_PER_THREAD {
        let mut hm: PlHashMap<T::TotalOrdItem, IdxVec> = PlHashMap::new();
        let mut offset = 0;
        for it in keys {
            for k in it {
                let k = k.to_total_ord();
                if !k.is_null() || nulls_equal {
                    hm.entry(k).or_default().push(offset);
                }
                offset += 1;
            }
        }
        return vec![hm];
    }

    POOL.install(|| {
        // Compute the number of elements in each partition for each portion.
        let per_thread_partition_sizes: Vec<Vec<usize>> = keys
            .par_iter()
            .with_max_len(1)
            .map(|key_portion| {
                let mut partition_sizes = vec![0; n_partitions];
                for key in key_portion.clone() {
                    let key = key.to_total_ord();
                    let p = hash_to_partition(key.dirty_hash(), n_partitions);
                    unsafe {
                        *partition_sizes.get_unchecked_mut(p) += 1;
                    }
                }
                partition_sizes
            })
            .collect();

        // Compute output offsets with a cumulative sum.
        let mut per_thread_partition_offsets = vec![0; n_partitions * n_threads + 1];
        let mut partition_offsets = vec![0; n_partitions + 1];
        let mut cum_offset = 0;
        for p in 0..n_partitions {
            partition_offsets[p] = cum_offset;
            for t in 0..n_threads {
                per_thread_partition_offsets[t * n_partitions + p] = cum_offset;
                cum_offset += per_thread_partition_sizes[t][p];
            }
        }
        let num_keys = cum_offset;
        per_thread_partition_offsets[n_threads * n_partitions] = num_keys;
        partition_offsets[n_partitions] = num_keys;

        // TODO: we wouldn't need this if we changed our interface to split the
        // input in this function, instead of taking a vec of iterators.
        let mut per_thread_input_offsets = vec![0; n_partitions];
        cum_offset = 0;
        for t in 0..n_threads {
            per_thread_input_offsets[t] = cum_offset;
            cum_offset += per_thread_partition_sizes[t]
                .iter()
                .take(n_partitions)
                .sum::<usize>();
        }

        // Scatter values into partitions.
        let mut scatter_keys: Vec<T::TotalOrdItem> = Vec::with_capacity(num_keys);
        let mut scatter_idxs: Vec<IdxSize> = Vec::with_capacity(num_keys);
        let scatter_keys_ptr = unsafe { SyncPtr::new(scatter_keys.as_mut_ptr()) };
        let scatter_idxs_ptr = unsafe { SyncPtr::new(scatter_idxs.as_mut_ptr()) };
        keys.into_par_iter()
            .with_max_len(1)
            .enumerate()
            .for_each(|(t, key_portion)| {
                let mut partition_offsets =
                    per_thread_partition_offsets[t * n_partitions..(t + 1) * n_partitions].to_vec();
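                // Each (thread, partition) pair owns a disjoint output range,
                // sized by the per-thread partition counts computed above, so
                // the unchecked writes in this loop never alias across threads.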
                for (i, key) in key_portion.into_iter().enumerate() {
                    let key = key.to_total_ord();
                    unsafe {
                        let p = hash_to_partition(key.dirty_hash(), n_partitions);
                        let off = partition_offsets.get_unchecked_mut(p);
                        *scatter_keys_ptr.get().add(*off) = key;
                        *scatter_idxs_ptr.get().add(*off) =
                            (per_thread_input_offsets[t] + i) as IdxSize;
                        *off += 1;
                    }
                }
            });
        unsafe {
            scatter_keys.set_len(num_keys);
            scatter_idxs.set_len(num_keys);
        }

        // Build tables.
        (0..n_partitions)
            .into_par_iter()
            .with_max_len(1)
            .map(|p| {
                // Resizing the hash map is very, very expensive. That's why we
                // adopt a hybrid strategy: we assume an initially small hash
                // map, which would satisfy a highly skewed relation. If this
                // fills up we immediately reserve enough for a full cardinality
                // data set.
                let partition_range = partition_offsets[p]..partition_offsets[p + 1];
                let full_size = partition_range.len();
                let mut conservative_size = _HASHMAP_INIT_SIZE.max(full_size / 64);
                let mut hm: PlHashMap<T::TotalOrdItem, IdxVec> =
                    PlHashMap::with_capacity(conservative_size);

                unsafe {
                    for i in partition_range {
                        if hm.len() == conservative_size {
                            hm.reserve(full_size - conservative_size);
                            conservative_size = 0; // Hack to ensure we never hit this branch again.
                        }

                        let key = *scatter_keys.get_unchecked(i);

                        if !key.is_null() || nulls_equal {
                            let idx = *scatter_idxs.get_unchecked(i);
                            match hm.entry(key) {
                                Entry::Occupied(mut o) => {
                                    o.get_mut().push(idx as IdxSize);
                                },
                                Entry::Vacant(v) => {
                                    let iv = unitvec![idx as IdxSize];
                                    v.insert(iv);
                                },
                            };
                        }
                    }
                }

                hm
            })
            .collect()
    })
}

// We determine the offset so that we later know which index to store in the join tuples.
pub(super) fn probe_to_offsets<T, I>(probe: &[I]) -> Vec<usize>
where
    I: IntoIterator<Item = T> + Clone,
{
    probe
        .iter()
        .map(|ph| ph.clone().into_iter().size_hint().1.unwrap())
        .scan(0, |state, val| {
            let out = *state;
            *state += val;
            Some(out)
        })
        .collect()
}
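
// Hypothetical usage sketch, not part of the crate's test suite: it assumes
// `u32` meets the trait bounds of `build_tables` (primitive integers implement
// `ToTotalOrd` with `TotalOrdItem = u32`, `DirtyHash`, and `IsNull` in
// `polars_utils`) and that `IdxVec` derefs to `[IdxSize]`.
#[cfg(test)]
mod usage_sketch {
    use super::*;

    #[test]
    fn build_tables_groups_row_indices_by_key() {
        // A single pre-split portion; the duplicate key 1 sits at rows 0 and 2.
        let tables = build_tables::<u32, _>(vec![vec![1u32, 2, 1, 3]], false);
        // With one input portion there is one partition, hence one table.
        assert_eq!(tables.len(), 1);
        let idxs: Vec<IdxSize> = tables[0][&1].to_vec();
        assert_eq!(idxs, vec![0, 2]);
    }

    #[test]
    fn probe_to_offsets_is_an_exclusive_prefix_sum() {
        // Portion lengths 2, 1, 3 yield starting offsets 0, 2, 3.
        let probe: Vec<Vec<u32>> = vec![vec![1, 2], vec![3], vec![4, 5, 6]];
        assert_eq!(probe_to_offsets(&probe), vec![0, 2, 3]);
    }
}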