Path: blob/main/crates/polars-expr/src/idx_table/single_key.rs
6940 views
#![allow(clippy::unnecessary_cast)] // Clippy doesn't recognize that IdxSize and u64 can be different.1#![allow(unsafe_op_in_unsafe_fn)]23use polars_utils::idx_map::total_idx_map::{Entry, TotalIndexMap};4use polars_utils::idx_vec::UnitVec;5use polars_utils::itertools::Itertools;6use polars_utils::relaxed_cell::RelaxedCell;7use polars_utils::total_ord::{TotalEq, TotalHash};8use polars_utils::unitvec;910use super::*;11use crate::hash_keys::HashKeys;1213pub struct SingleKeyIdxTable<T: PolarsDataType> {14// These AtomicU64s actually are IdxSizes, but we use the top bit of the15// first index in each to mark keys during probing.16idx_map: TotalIndexMap<T::Physical<'static>, UnitVec<RelaxedCell<u64>>>,17idx_offset: IdxSize,18null_keys: Vec<IdxSize>,19nulls_emitted: RelaxedCell<bool>,20}2122impl<T: PolarsDataType> SingleKeyIdxTable<T> {23pub fn new() -> Self {24Self {25idx_map: TotalIndexMap::default(),26idx_offset: 0,27null_keys: Vec::new(),28nulls_emitted: RelaxedCell::from(false),29}30}31}3233impl<T, K> SingleKeyIdxTable<T>34where35for<'a> T: PolarsDataType<Physical<'a> = K>,36K: TotalHash + TotalEq + Send + Sync + 'static,37{38#[inline(always)]39fn probe_one<const MARK_MATCHES: bool>(40&self,41key_idx: IdxSize,42key: &K,43table_match: &mut Vec<IdxSize>,44probe_match: &mut Vec<IdxSize>,45) -> bool {46if let Some(idxs) = self.idx_map.get(key) {47for idx in &idxs[..] {48// Create matches, making sure to clear top bit.49table_match.push((idx.load() & !(1 << 63)) as IdxSize);50probe_match.push(key_idx);51}5253// Mark if necessary. This action is idempotent so doesn't need54// atomic fetch_or to do it atomically.55if MARK_MATCHES {56let first_idx = unsafe { idxs.get_unchecked(0) };57let first_idx_val = first_idx.load();58if first_idx_val >> 63 == 0 {59first_idx.store(first_idx_val | (1 << 63));60}61}62true63} else {64false65}66}6768fn probe_impl<69const MARK_MATCHES: bool,70const EMIT_UNMATCHED: bool,71const NULL_IS_VALID: bool,72>(73&self,74keys: impl Iterator<Item = (IdxSize, Option<K>)>,75table_match: &mut Vec<IdxSize>,76probe_match: &mut Vec<IdxSize>,77limit: IdxSize,78) -> IdxSize {79let mut keys_processed = 0;80for (key_idx, key) in keys {81let found_match = if let Some(key) = key {82self.probe_one::<MARK_MATCHES>(key_idx, &key, table_match, probe_match)83} else if NULL_IS_VALID {84for idx in &self.null_keys {85table_match.push(*idx);86probe_match.push(key_idx);87}88if MARK_MATCHES && !self.nulls_emitted.load() {89self.nulls_emitted.store(true);90}91!self.null_keys.is_empty()92} else {93false94};9596if EMIT_UNMATCHED && !found_match {97table_match.push(IdxSize::MAX);98probe_match.push(key_idx);99}100101keys_processed += 1;102if table_match.len() >= limit as usize {103break;104}105}106keys_processed107}108109#[allow(clippy::too_many_arguments)]110fn probe_dispatch(111&self,112keys: impl Iterator<Item = (IdxSize, Option<K>)>,113table_match: &mut Vec<IdxSize>,114probe_match: &mut Vec<IdxSize>,115mark_matches: bool,116emit_unmatched: bool,117null_is_valid: bool,118limit: IdxSize,119) -> IdxSize {120match (mark_matches, emit_unmatched, null_is_valid) {121(false, false, false) => {122self.probe_impl::<false, false, false>(keys, table_match, probe_match, limit)123},124(false, false, true) => {125self.probe_impl::<false, false, true>(keys, table_match, probe_match, limit)126},127(false, true, false) => {128self.probe_impl::<false, true, false>(keys, table_match, probe_match, limit)129},130(false, true, true) => {131self.probe_impl::<false, true, true>(keys, table_match, probe_match, limit)132},133(true, false, false) => {134self.probe_impl::<true, false, false>(keys, table_match, probe_match, limit)135},136(true, false, true) => {137self.probe_impl::<true, false, true>(keys, table_match, probe_match, limit)138},139(true, true, false) => {140self.probe_impl::<true, true, false>(keys, table_match, probe_match, limit)141},142(true, true, true) => {143self.probe_impl::<true, true, true>(keys, table_match, probe_match, limit)144},145}146}147}148149impl<T, K> IdxTable for SingleKeyIdxTable<T>150where151for<'a> T: PolarsDataType<Physical<'a> = K>,152K: TotalHash + TotalEq + Send + Sync + 'static,153{154fn new_empty(&self) -> Box<dyn IdxTable> {155Box::new(Self::new())156}157158fn reserve(&mut self, additional: usize) {159self.idx_map.reserve(additional);160}161162fn num_keys(&self) -> IdxSize {163self.idx_map.len()164}165166fn insert_keys(&mut self, _hash_keys: &HashKeys, _track_unmatchable: bool) {167// Isn't needed anymore, but also don't want to remove the code from the other implementations.168unimplemented!()169}170171unsafe fn insert_keys_subset(172&mut self,173hash_keys: &HashKeys,174subset: &[IdxSize],175track_unmatchable: bool,176) {177let HashKeys::Single(hash_keys) = hash_keys else {178unreachable!()179};180let new_idx_offset = (self.idx_offset as usize)181.checked_add(subset.len())182.unwrap();183assert!(184new_idx_offset < IdxSize::MAX as usize,185"overly large index in SingleKeyIdxTable"186);187188let keys: &ChunkedArray<T> = hash_keys.keys.as_phys_any().downcast_ref().unwrap();189for (i, subset_idx) in subset.iter().enumerate_idx() {190let key = unsafe { keys.get_unchecked(*subset_idx as usize) };191let idx = self.idx_offset + i;192if let Some(key) = key {193match self.idx_map.entry(key) {194Entry::Occupied(o) => {195o.into_mut().push(RelaxedCell::from(idx as u64));196},197Entry::Vacant(v) => {198v.insert(unitvec![RelaxedCell::from(idx as u64)]);199},200}201} else if track_unmatchable | hash_keys.null_is_valid {202self.null_keys.push(idx);203}204}205206self.idx_offset = new_idx_offset as IdxSize;207}208209fn probe(210&self,211_hash_keys: &HashKeys,212_table_match: &mut Vec<IdxSize>,213_probe_match: &mut Vec<IdxSize>,214_mark_matches: bool,215_emit_unmatched: bool,216_limit: IdxSize,217) -> IdxSize {218// Isn't needed anymore, but also don't want to remove the code from the other implementations.219unimplemented!()220}221222unsafe fn probe_subset(223&self,224hash_keys: &HashKeys,225subset: &[IdxSize],226table_match: &mut Vec<IdxSize>,227probe_match: &mut Vec<IdxSize>,228mark_matches: bool,229emit_unmatched: bool,230limit: IdxSize,231) -> IdxSize {232let HashKeys::Single(hash_keys) = hash_keys else {233unreachable!()234};235236let keys: &ChunkedArray<T> = hash_keys.keys.as_phys_any().downcast_ref().unwrap();237if keys.has_nulls() {238let iter = subset.iter().map(|i| (*i, keys.get_unchecked(*i as usize)));239self.probe_dispatch(240iter,241table_match,242probe_match,243mark_matches,244emit_unmatched,245hash_keys.null_is_valid,246limit,247)248} else {249let iter = subset250.iter()251.map(|i| (*i, Some(keys.value_unchecked(*i as usize))));252self.probe_dispatch(253iter,254table_match,255probe_match,256mark_matches,257emit_unmatched,258false, // Whether or not nulls are valid doesn't matter.259limit,260)261}262}263264fn unmarked_keys(265&self,266out: &mut Vec<IdxSize>,267mut offset: IdxSize,268limit: IdxSize,269) -> IdxSize {270out.clear();271272let mut keys_processed = 0;273if !self.nulls_emitted.load() {274if (offset as usize) < self.null_keys.len() {275out.extend(276self.null_keys[offset as usize..]277.iter()278.copied()279.take(limit as usize),280);281keys_processed += out.len() as IdxSize;282offset += out.len() as IdxSize;283if out.len() >= limit as usize {284return keys_processed;285}286}287offset -= self.null_keys.len() as IdxSize;288}289290while let Some((_, idxs)) = self.idx_map.get_index(offset) {291let first_idx = unsafe { idxs.get_unchecked(0) };292let first_idx_val = first_idx.load();293if first_idx_val >> 63 == 0 {294for idx in &idxs[..] {295out.push((idx.load() & !(1 << 63)) as IdxSize);296}297}298299keys_processed += 1;300offset += 1;301if out.len() >= limit as usize {302break;303}304}305306keys_processed307}308}309310311