Path: blob/main/crates/polars-expr/src/groups/row_encoded.rs
6940 views
use arrow::array::Array;1use polars_row::RowEncodingOptions;2use polars_utils::idx_map::bytes_idx_map::{BytesIndexMap, Entry};3use polars_utils::itertools::Itertools;4use polars_utils::vec::PushUnchecked;56use self::row_encode::get_row_encoding_context;7use super::*;8use crate::hash_keys::HashKeys;910#[derive(Default)]11pub struct RowEncodedHashGrouper {12idx_map: BytesIndexMap<()>,13}1415impl RowEncodedHashGrouper {16pub fn new() -> Self {17Self {18idx_map: BytesIndexMap::new(),19}20}2122fn insert_key(&mut self, hash: u64, key: &[u8]) -> IdxSize {23match self.idx_map.entry(hash, key) {24Entry::Occupied(o) => o.index(),25Entry::Vacant(v) => {26let index = v.index();27v.insert(());28index29},30}31}3233fn contains_key(&self, hash: u64, key: &[u8]) -> bool {34self.idx_map.contains_key(hash, key)35}3637fn finalize_keys(&self, key_schema: &Schema, mut key_rows: Vec<&[u8]>) -> DataFrame {38let key_dtypes = key_schema39.iter()40.map(|(_name, dt)| dt.to_physical().to_arrow(CompatLevel::newest()))41.collect::<Vec<_>>();42let ctxts = key_schema43.iter()44.map(|(_, dt)| get_row_encoding_context(dt))45.collect::<Vec<_>>();46let fields = vec![RowEncodingOptions::new_unsorted(); key_dtypes.len()];47let key_columns =48unsafe { polars_row::decode::decode_rows(&mut key_rows, &fields, &ctxts, &key_dtypes) };4950let cols = key_schema51.iter()52.zip(key_columns)53.map(|((name, dt), col)| {54let s = Series::try_from((name.clone(), col)).unwrap();55unsafe { s.from_physical_unchecked(dt) }56.unwrap()57.into_column()58})59.collect();60unsafe { DataFrame::new_no_checks_height_from_first(cols) }61}62}6364impl Grouper for RowEncodedHashGrouper {65fn new_empty(&self) -> Box<dyn Grouper> {66Box::new(Self::new())67}6869fn reserve(&mut self, additional: usize) {70self.idx_map.reserve(additional);71}7273fn num_groups(&self) -> IdxSize {74self.idx_map.len()75}7677unsafe fn insert_keys_subset(78&mut self,79keys: &HashKeys,80subset: &[IdxSize],81group_idxs: Option<&mut Vec<IdxSize>>,82) {83let HashKeys::RowEncoded(keys) = keys else {84unreachable!()85};8687unsafe {88if let Some(group_idxs) = group_idxs {89group_idxs.reserve(subset.len());90keys.for_each_hash_subset(subset, |idx, opt_hash| {91if let Some(hash) = opt_hash {92let key = keys.keys.value_unchecked(idx as usize);93group_idxs.push_unchecked(self.insert_key(hash, key));94}95});96} else {97keys.for_each_hash_subset(subset, |idx, opt_hash| {98if let Some(hash) = opt_hash {99let key = keys.keys.value_unchecked(idx as usize);100self.insert_key(hash, key);101}102});103}104}105}106107fn get_keys_in_group_order(&self, schema: &Schema) -> DataFrame {108unsafe {109let mut key_rows: Vec<&[u8]> = Vec::with_capacity(self.idx_map.len() as usize);110for (_, key) in self.idx_map.iter_hash_keys() {111key_rows.push_unchecked(key);112}113self.finalize_keys(schema, key_rows)114}115}116117/// # Safety118/// All groupers must be a RowEncodedHashGrouper.119unsafe fn probe_partitioned_groupers(120&self,121groupers: &[Box<dyn Grouper>],122keys: &HashKeys,123partitioner: &HashPartitioner,124invert: bool,125probe_matches: &mut Vec<IdxSize>,126) {127let HashKeys::RowEncoded(keys) = keys else {128unreachable!()129};130assert!(partitioner.num_partitions() == groupers.len());131132unsafe {133if keys.keys.has_nulls() {134for (idx, hash) in keys.hashes.values_iter().enumerate_idx() {135let has_group = if let Some(key) = keys.keys.get_unchecked(idx as usize) {136let p = partitioner.hash_to_partition(*hash);137let dyn_grouper: &dyn Grouper = &**groupers.get_unchecked(p);138let grouper =139&*(dyn_grouper as *const dyn Grouper as *const RowEncodedHashGrouper);140grouper.contains_key(*hash, key)141} else {142false143};144145if has_group != invert {146probe_matches.push(idx);147}148}149} else {150for (idx, (hash, key)) in keys151.hashes152.values_iter()153.zip(keys.keys.values_iter())154.enumerate_idx()155{156let p = partitioner.hash_to_partition(*hash);157let dyn_grouper: &dyn Grouper = &**groupers.get_unchecked(p);158let grouper =159&*(dyn_grouper as *const dyn Grouper as *const RowEncodedHashGrouper);160if grouper.contains_key(*hash, key) != invert {161probe_matches.push(idx);162}163}164}165}166}167168/// # Safety169/// All groupers must be a RowEncodedHashGrouper.170unsafe fn contains_key_partitioned_groupers(171&self,172groupers: &[Box<dyn Grouper>],173keys: &HashKeys,174partitioner: &HashPartitioner,175invert: bool,176contains_key: &mut BitmapBuilder,177) {178let HashKeys::RowEncoded(keys) = keys else {179unreachable!()180};181assert!(partitioner.num_partitions() == groupers.len());182183unsafe {184if keys.keys.has_nulls() {185for (idx, hash) in keys.hashes.values_iter().enumerate_idx() {186let has_group = if let Some(key) = keys.keys.get_unchecked(idx as usize) {187let p = partitioner.hash_to_partition(*hash);188let dyn_grouper: &dyn Grouper = &**groupers.get_unchecked(p);189let grouper =190&*(dyn_grouper as *const dyn Grouper as *const RowEncodedHashGrouper);191grouper.contains_key(*hash, key)192} else {193false194};195196contains_key.push(has_group != invert);197}198} else {199for (hash, key) in keys.hashes.values_iter().zip(keys.keys.values_iter()) {200let p = partitioner.hash_to_partition(*hash);201let dyn_grouper: &dyn Grouper = &**groupers.get_unchecked(p);202let grouper =203&*(dyn_grouper as *const dyn Grouper as *const RowEncodedHashGrouper);204contains_key.push(grouper.contains_key(*hash, key) != invert);205}206}207}208}209210fn as_any(&self) -> &dyn Any {211self212}213}214215216