Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/joins/utils.rs
8430 views
1
use std::collections::BTreeMap;
2
3
use polars_core::frame::DataFrame;
4
use polars_core::prelude::*;
5
use polars_core::schema::SchemaRef;
6
use polars_core::series::Series;
7
8
#[derive(Clone, Debug)]
9
pub(super) struct DataFrameSearchBuffer {
10
schema: SchemaRef,
11
dfs_at_offsets: BTreeMap<usize, DataFrame>,
12
total_rows: usize,
13
skip_rows: usize,
14
frozen: bool,
15
}
16
17
impl DataFrameSearchBuffer {
18
pub(super) fn empty_with_schema(schema: SchemaRef) -> Self {
19
DataFrameSearchBuffer {
20
schema,
21
dfs_at_offsets: BTreeMap::new(),
22
total_rows: 0,
23
skip_rows: 0,
24
frozen: false,
25
}
26
}
27
28
pub(super) fn height(&self) -> usize {
29
self.total_rows
30
}
31
32
/// Get the `row_index`th value from the `column` bypassing its validity bitmap.
33
///
34
/// SAFETY: Caller must ensure that `row_index` is within bounds.
35
pub(super) unsafe fn get_bypass_validity(
36
&self,
37
column: &str,
38
row_index: usize,
39
bypass_validity: bool,
40
) -> AnyValue<'_> {
41
debug_assert!(row_index < self.total_rows);
42
let first_offset = match self.dfs_at_offsets.first_key_value() {
43
Some((offset, _)) => *offset,
44
None => 0,
45
};
46
let buf_index = self.skip_rows + first_offset + row_index;
47
let (df_offset, df) = self.dfs_at_offsets.range(..=buf_index).next_back().unwrap();
48
let series_index = buf_index - df_offset;
49
let series = df.column(column).unwrap().as_materialized_series();
50
unsafe { series_get_bypass_validity(series, series_index, bypass_validity) }
51
}
52
53
pub(super) fn push_df(&mut self, df: DataFrame) {
54
assert!(!self.frozen);
55
let added_rows = df.height();
56
let offset = match self.dfs_at_offsets.last_key_value() {
57
Some((last_key, last_df)) => last_key + last_df.height(),
58
None => 0,
59
};
60
self.dfs_at_offsets.insert(offset, df);
61
self.total_rows += added_rows;
62
}
63
64
pub(super) fn split_at(&mut self, mut at: usize) -> Self {
65
at = at.clamp(0, self.total_rows);
66
let mut top = self.clone();
67
top.total_rows = at;
68
top.frozen = true;
69
self.skip_rows += at;
70
self.total_rows -= at;
71
self.gc();
72
top
73
}
74
75
pub(super) fn slice(mut self, offset: usize, len: usize) -> Self {
76
self.skip_rows += offset;
77
self.total_rows -= offset;
78
self.total_rows = usize::min(self.total_rows, len);
79
self.frozen = true;
80
self
81
}
82
83
pub(super) fn into_df(self) -> DataFrame {
84
let mut acc = DataFrame::empty_with_schema(&self.schema);
85
for df in self.dfs_at_offsets.into_values() {
86
acc.vstack_mut_owned(df).unwrap();
87
}
88
acc.slice(self.skip_rows as i64, self.total_rows)
89
}
90
91
fn gc(&mut self) {
92
while let Some((_, df)) = self.dfs_at_offsets.first_key_value() {
93
if self.skip_rows > df.height() {
94
let (_, df) = self.dfs_at_offsets.pop_first().unwrap();
95
self.skip_rows -= df.height();
96
} else {
97
break;
98
}
99
}
100
}
101
102
pub(super) fn is_empty(&self) -> bool {
103
self.total_rows == 0
104
}
105
106
/// Find the index of the first item in the buffer that satisfies `predicate`,
107
/// assuming it is first always false and then always true.
108
pub(super) fn binary_search<P>(
109
&self,
110
predicate: P,
111
key_col_name: &str,
112
binary_offset_bypass_validity: bool,
113
) -> usize
114
where
115
P: Fn(&AnyValue<'_>) -> bool,
116
{
117
let mut lower = 0;
118
let mut upper = self.height();
119
while lower < upper {
120
let mid = (lower + upper) / 2;
121
let mid_val = unsafe {
122
self.get_bypass_validity(key_col_name, mid, binary_offset_bypass_validity)
123
};
124
if predicate(&mid_val) {
125
upper = mid;
126
} else {
127
lower = mid + 1;
128
}
129
}
130
lower
131
}
132
}
133
134
/// Get value from series bypassing the validity bitmap.
135
///
136
/// SAFETY: Caller must ensure that `index` is within bounds of `s`.
137
unsafe fn series_get_bypass_validity<'a>(
138
s: &'a Series,
139
index: usize,
140
binary_offset_bypass_validity: bool,
141
) -> AnyValue<'a> {
142
debug_assert!(index < s.len());
143
if binary_offset_bypass_validity {
144
let arr = s.binary_offset().unwrap();
145
unsafe { arr.get_any_value_bypass_validity(index) }
146
} else {
147
unsafe { s.get_unchecked(index) }
148
}
149
}
150
151