Path: blob/main/crates/polars-expr/src/groups/single_key.rs
6940 views
use arrow::array::Array;1use arrow::bitmap::MutableBitmap;2use polars_utils::idx_map::total_idx_map::{Entry, TotalIndexMap};3use polars_utils::total_ord::{TotalEq, TotalHash};4use polars_utils::vec::PushUnchecked;56use super::*;7use crate::hash_keys::{HashKeys, for_each_hash_single};89#[derive(Default)]10pub struct SingleKeyHashGrouper<T: PolarsDataType> {11idx_map: TotalIndexMap<T::Physical<'static>, ()>,12null_idx: IdxSize,13}1415impl<K, T: PolarsDataType> SingleKeyHashGrouper<T>16where17for<'a> T: PolarsDataType<Physical<'a> = K>,18K: Default + TotalHash + TotalEq,19{20pub fn new() -> Self {21Self {22idx_map: TotalIndexMap::default(),23null_idx: IdxSize::MAX,24}25}2627#[inline(always)]28fn insert_key(&mut self, key: T::Physical<'static>) -> IdxSize {29match self.idx_map.entry(key) {30Entry::Occupied(o) => o.index(),31Entry::Vacant(v) => {32let index = v.index();33v.insert(());34index35},36}37}3839#[inline(always)]40fn insert_null(&mut self) -> IdxSize {41if self.null_idx == IdxSize::MAX {42self.null_idx = self.idx_map.push_unmapped_entry(T::Physical::default(), ());43}44self.null_idx45}4647#[inline(always)]48fn contains_key(&self, key: &T::Physical<'static>) -> bool {49self.idx_map.get(key).is_some()50}5152#[inline(always)]53fn contains_null(&self) -> bool {54self.null_idx < IdxSize::MAX55}5657fn finalize_keys(&self, schema: &Schema, keys: Vec<T::Physical<'static>>) -> DataFrame {58let (name, dtype) = schema.get_at_index(0).unwrap();59let mut keys =60T::Array::from_vec(keys, dtype.to_physical().to_arrow(CompatLevel::newest()));61if self.null_idx < IdxSize::MAX {62let mut validity = MutableBitmap::new();63validity.extend_constant(keys.len(), true);64validity.set(self.null_idx as usize, false);65keys = keys.with_validity_typed(Some(validity.freeze()));66}67unsafe {68let s =69Series::from_chunks_and_dtype_unchecked(name.clone(), vec![Box::new(keys)], dtype);70DataFrame::new(vec![Column::from(s)]).unwrap()71}72}73}7475impl<K, T: PolarsDataType> Grouper for SingleKeyHashGrouper<T>76where77for<'a> T: PolarsDataType<Physical<'a> = K>,78K: Default + TotalHash + TotalEq + Clone + Send + Sync + 'static,79{80fn new_empty(&self) -> Box<dyn Grouper> {81Box::new(Self::new())82}8384fn reserve(&mut self, additional: usize) {85self.idx_map.reserve(additional);86}8788fn num_groups(&self) -> IdxSize {89self.idx_map.len()90}9192unsafe fn insert_keys_subset(93&mut self,94hash_keys: &HashKeys,95subset: &[IdxSize],96group_idxs: Option<&mut Vec<IdxSize>>,97) {98let HashKeys::Single(hash_keys) = hash_keys else {99unreachable!()100};101let ca: &ChunkedArray<T> = hash_keys.keys.as_phys_any().downcast_ref().unwrap();102let arr = ca.downcast_as_array();103104unsafe {105if arr.has_nulls() {106if hash_keys.null_is_valid {107let groups = subset.iter().map(|idx| {108let opt_k = arr.get_unchecked(*idx as usize);109if let Some(k) = opt_k {110self.insert_key(k)111} else {112self.insert_null()113}114});115if let Some(group_idxs) = group_idxs {116group_idxs.reserve(subset.len());117group_idxs.extend(groups);118} else {119groups.for_each(drop);120}121} else {122let groups = subset.iter().filter_map(|idx| {123let opt_k = arr.get_unchecked(*idx as usize);124opt_k.map(|k| self.insert_key(k))125});126if let Some(group_idxs) = group_idxs {127group_idxs.reserve(subset.len());128group_idxs.extend(groups);129} else {130groups.for_each(drop);131}132}133} else {134let groups = subset.iter().map(|idx| {135let k = arr.value_unchecked(*idx as usize);136self.insert_key(k)137});138if let Some(group_idxs) = group_idxs {139group_idxs.reserve(subset.len());140group_idxs.extend(groups);141} else {142groups.for_each(drop);143}144}145}146}147148fn get_keys_in_group_order(&self, schema: &Schema) -> DataFrame {149unsafe {150let mut key_rows = Vec::with_capacity(self.idx_map.len() as usize);151for key in self.idx_map.iter_keys() {152key_rows.push_unchecked(key.clone());153}154self.finalize_keys(schema, key_rows)155}156}157158/// # Safety159/// All groupers must be a SingleKeyHashGrouper<T>.160unsafe fn probe_partitioned_groupers(161&self,162groupers: &[Box<dyn Grouper>],163hash_keys: &HashKeys,164partitioner: &HashPartitioner,165invert: bool,166probe_matches: &mut Vec<IdxSize>,167) {168let HashKeys::Single(hash_keys) = hash_keys else {169unreachable!()170};171let ca: &ChunkedArray<T> = hash_keys.keys.as_phys_any().downcast_ref().unwrap();172let arr = ca.downcast_as_array();173assert!(partitioner.num_partitions() == groupers.len());174175unsafe {176let null_p = partitioner.null_partition();177for_each_hash_single(ca, &hash_keys.random_state, |idx, opt_h| {178let has_group = if let Some(h) = opt_h {179let p = partitioner.hash_to_partition(h);180let dyn_grouper: &dyn Grouper = &**groupers.get_unchecked(p);181let grouper =182&*(dyn_grouper as *const dyn Grouper as *const SingleKeyHashGrouper<T>);183let key = arr.value_unchecked(idx as usize);184grouper.contains_key(&key)185} else {186let dyn_grouper: &dyn Grouper = &**groupers.get_unchecked(null_p);187let grouper =188&*(dyn_grouper as *const dyn Grouper as *const SingleKeyHashGrouper<T>);189grouper.contains_null()190};191192if has_group != invert {193probe_matches.push(idx);194}195});196}197}198199/// # Safety200/// All groupers must be a SingleKeyHashGrouper<T>.201unsafe fn contains_key_partitioned_groupers(202&self,203groupers: &[Box<dyn Grouper>],204hash_keys: &HashKeys,205partitioner: &HashPartitioner,206invert: bool,207contains_key: &mut BitmapBuilder,208) {209let HashKeys::Single(hash_keys) = hash_keys else {210unreachable!()211};212let ca: &ChunkedArray<T> = hash_keys.keys.as_phys_any().downcast_ref().unwrap();213let arr = ca.downcast_as_array();214assert!(partitioner.num_partitions() == groupers.len());215216unsafe {217let null_p = partitioner.null_partition();218for_each_hash_single(ca, &hash_keys.random_state, |idx, opt_h| {219let has_group = if let Some(h) = opt_h {220let p = partitioner.hash_to_partition(h);221let dyn_grouper: &dyn Grouper = &**groupers.get_unchecked(p);222let grouper =223&*(dyn_grouper as *const dyn Grouper as *const SingleKeyHashGrouper<T>);224let key = arr.value_unchecked(idx as usize);225grouper.contains_key(&key)226} else {227let dyn_grouper: &dyn Grouper = &**groupers.get_unchecked(null_p);228let grouper =229&*(dyn_grouper as *const dyn Grouper as *const SingleKeyHashGrouper<T>);230grouper.contains_null()231};232233contains_key.push(has_group != invert);234});235}236}237238fn as_any(&self) -> &dyn Any {239self240}241}242243244