// Path: blob/main/crates/polars-stream/src/nodes/joins/utils.rs
// (8430 views)
use std::collections::BTreeMap;12use polars_core::frame::DataFrame;3use polars_core::prelude::*;4use polars_core::schema::SchemaRef;5use polars_core::series::Series;67#[derive(Clone, Debug)]8pub(super) struct DataFrameSearchBuffer {9schema: SchemaRef,10dfs_at_offsets: BTreeMap<usize, DataFrame>,11total_rows: usize,12skip_rows: usize,13frozen: bool,14}1516impl DataFrameSearchBuffer {17pub(super) fn empty_with_schema(schema: SchemaRef) -> Self {18DataFrameSearchBuffer {19schema,20dfs_at_offsets: BTreeMap::new(),21total_rows: 0,22skip_rows: 0,23frozen: false,24}25}2627pub(super) fn height(&self) -> usize {28self.total_rows29}3031/// Get the `row_index`th value from the `column` bypassing its validity bitmap.32///33/// SAFETY: Caller must ensure that `row_index` is within bounds.34pub(super) unsafe fn get_bypass_validity(35&self,36column: &str,37row_index: usize,38bypass_validity: bool,39) -> AnyValue<'_> {40debug_assert!(row_index < self.total_rows);41let first_offset = match self.dfs_at_offsets.first_key_value() {42Some((offset, _)) => *offset,43None => 0,44};45let buf_index = self.skip_rows + first_offset + row_index;46let (df_offset, df) = self.dfs_at_offsets.range(..=buf_index).next_back().unwrap();47let series_index = buf_index - df_offset;48let series = df.column(column).unwrap().as_materialized_series();49unsafe { series_get_bypass_validity(series, series_index, bypass_validity) }50}5152pub(super) fn push_df(&mut self, df: DataFrame) {53assert!(!self.frozen);54let added_rows = df.height();55let offset = match self.dfs_at_offsets.last_key_value() {56Some((last_key, last_df)) => last_key + last_df.height(),57None => 0,58};59self.dfs_at_offsets.insert(offset, df);60self.total_rows += added_rows;61}6263pub(super) fn split_at(&mut self, mut at: usize) -> Self {64at = at.clamp(0, self.total_rows);65let mut top = self.clone();66top.total_rows = at;67top.frozen = true;68self.skip_rows += at;69self.total_rows -= at;70self.gc();71top72}7374pub(super) fn slice(mut self, 
offset: usize, len: usize) -> Self {75self.skip_rows += offset;76self.total_rows -= offset;77self.total_rows = usize::min(self.total_rows, len);78self.frozen = true;79self80}8182pub(super) fn into_df(self) -> DataFrame {83let mut acc = DataFrame::empty_with_schema(&self.schema);84for df in self.dfs_at_offsets.into_values() {85acc.vstack_mut_owned(df).unwrap();86}87acc.slice(self.skip_rows as i64, self.total_rows)88}8990fn gc(&mut self) {91while let Some((_, df)) = self.dfs_at_offsets.first_key_value() {92if self.skip_rows > df.height() {93let (_, df) = self.dfs_at_offsets.pop_first().unwrap();94self.skip_rows -= df.height();95} else {96break;97}98}99}100101pub(super) fn is_empty(&self) -> bool {102self.total_rows == 0103}104105/// Find the index of the first item in the buffer that satisfies `predicate`,106/// assuming it is first always false and then always true.107pub(super) fn binary_search<P>(108&self,109predicate: P,110key_col_name: &str,111binary_offset_bypass_validity: bool,112) -> usize113where114P: Fn(&AnyValue<'_>) -> bool,115{116let mut lower = 0;117let mut upper = self.height();118while lower < upper {119let mid = (lower + upper) / 2;120let mid_val = unsafe {121self.get_bypass_validity(key_col_name, mid, binary_offset_bypass_validity)122};123if predicate(&mid_val) {124upper = mid;125} else {126lower = mid + 1;127}128}129lower130}131}132133/// Get value from series bypassing the validity bitmap.134///135/// SAFETY: Caller must ensure that `index` is within bounds of `s`.136unsafe fn series_get_bypass_validity<'a>(137s: &'a Series,138index: usize,139binary_offset_bypass_validity: bool,140) -> AnyValue<'a> {141debug_assert!(index < s.len());142if binary_offset_bypass_validity {143let arr = s.binary_offset().unwrap();144unsafe { arr.get_any_value_bypass_validity(index) }145} else {146unsafe { s.get_unchecked(index) }147}148}149150151