Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/state/execution_state.rs
8382 views
1
use std::borrow::Cow;
2
use std::sync::atomic::{AtomicI64, Ordering};
3
use std::sync::{Mutex, RwLock};
4
use std::time::Duration;
5
6
use arrow::bitmap::Bitmap;
7
use bitflags::bitflags;
8
use polars_core::config::verbose;
9
use polars_core::prelude::*;
10
use polars_ops::prelude::ChunkJoinOptIds;
11
use polars_utils::relaxed_cell::RelaxedCell;
12
use polars_utils::unique_id::UniqueId;
13
14
use super::NodeTimer;
15
use crate::prelude::AggregationContext;
16
17
/// Shared, mutex-guarded map from a cache key to the optional join indices of a
/// previously computed join. Presumably keyed per join expression — confirm at call sites.
pub type JoinTuplesCache = Arc<Mutex<PlHashMap<String, ChunkJoinOptIds>>>;
18
19
/// Per-query cache of intermediate window-expression state, shared across
/// threads. Each sub-cache is independently guarded by its own `RwLock`.
#[derive(Default)]
pub struct WindowCache {
    // Cached group positions, keyed by string.
    groups: RwLock<PlHashMap<String, GroupPositions>>,
    // Cached join indices, keyed by string; `Arc` so hits are cheap to hand out.
    join_tuples: RwLock<PlHashMap<String, Arc<ChunkJoinOptIds>>>,
    // Cached index mappings, keyed by string.
    map_idx: RwLock<PlHashMap<String, Arc<IdxCa>>>,
}
25
26
impl WindowCache {
27
pub(crate) fn clear(&self) {
28
let Self {
29
groups,
30
join_tuples,
31
map_idx,
32
} = self;
33
groups.write().unwrap().clear();
34
join_tuples.write().unwrap().clear();
35
map_idx.write().unwrap().clear();
36
}
37
38
pub fn get_groups(&self, key: &str) -> Option<GroupPositions> {
39
let g = self.groups.read().unwrap();
40
g.get(key).cloned()
41
}
42
43
pub fn insert_groups(&self, key: String, groups: GroupPositions) {
44
let mut g = self.groups.write().unwrap();
45
g.insert(key, groups);
46
}
47
48
pub fn get_join(&self, key: &str) -> Option<Arc<ChunkJoinOptIds>> {
49
let g = self.join_tuples.read().unwrap();
50
g.get(key).cloned()
51
}
52
53
pub fn insert_join(&self, key: String, join_tuples: Arc<ChunkJoinOptIds>) {
54
let mut g = self.join_tuples.write().unwrap();
55
g.insert(key, join_tuples);
56
}
57
58
pub fn get_map(&self, key: &str) -> Option<Arc<IdxCa>> {
59
let g = self.map_idx.read().unwrap();
60
g.get(key).cloned()
61
}
62
63
pub fn insert_map(&self, key: String, idx: Arc<IdxCa>) {
64
let mut g = self.map_idx.write().unwrap();
65
g.insert(key, idx);
66
}
67
}
68
69
bitflags! {
    /// Bit flags tracking per-`ExecutionState` toggles; stored as a raw `u8`
    /// inside `ExecutionState::flags`.
    #[repr(transparent)]
    #[derive(Copy, Clone)]
    pub(super) struct StateFlags: u8 {
        /// More verbose logging
        const VERBOSE = 0x01;
        /// Indicates that window expression's [`GroupTuples`] may be cached.
        const CACHE_WINDOW_EXPR = 0x02;
        /// Indicates the expression has a window function
        const HAS_WINDOW = 0x04;
    }
}
81
82
impl Default for StateFlags {
83
fn default() -> Self {
84
StateFlags::CACHE_WINDOW_EXPR
85
}
86
}
87
88
impl StateFlags {
89
fn init() -> Self {
90
let verbose = verbose();
91
let mut flags: StateFlags = Default::default();
92
if verbose {
93
flags |= StateFlags::VERBOSE;
94
}
95
flags
96
}
97
fn as_u8(self) -> u8 {
98
unsafe { std::mem::transmute(self) }
99
}
100
}
101
102
impl From<u8> for StateFlags {
103
fn from(value: u8) -> Self {
104
unsafe { std::mem::transmute(value) }
105
}
106
}
107
108
/// A `.cache`d DataFrame together with the number of reads it still expects.
struct CachedValue {
    /// The number of times the cache will still be read.
    /// Zero means that there will be no more reads and the cache can be dropped.
    remaining_hits: AtomicI64,
    /// The cached result itself.
    df: DataFrame,
}
114
115
/// State/ cache that is maintained during the Execution of the physical plan.
#[derive(Clone)]
pub struct ExecutionState {
    // cached by a `.cache` call and kept in memory for the duration of the plan.
    df_cache: Arc<RwLock<PlHashMap<UniqueId, Arc<CachedValue>>>>,
    /// Most recently set schema, if any; cleared via `clear_schema_cache`,
    /// typically at the end of a projection.
    pub schema_cache: Arc<RwLock<Option<SchemaRef>>>,
    /// Used by Window Expressions to cache intermediate state
    pub window_cache: Arc<WindowCache>,
    // every join/union split gets an increment to distinguish between schema state
    pub branch_idx: usize,
    /// Raw [`StateFlags`] bits, stored atomically so flags can be read and
    /// toggled through a shared reference.
    pub flags: RelaxedCell<u8>,
    #[cfg(feature = "dtype-struct")]
    pub with_fields: Option<Arc<StructChunked>>,
    #[cfg(feature = "dtype-struct")]
    pub with_fields_ac: Option<Arc<AggregationContext<'static>>>,
    /// Additional DataFrames available as expression contexts.
    pub ext_contexts: Arc<Vec<DataFrame>>,
    /// Input value (and optional validity mask) for `pl.element` in Eval context.
    pub element: Arc<Option<(Column, Option<Bitmap>)>>,
    // Timer populated by `time_nodes`; `None` when timing is disabled.
    node_timer: Option<NodeTimer>,
    // Cooperative cancellation flag, shared out via `cancel_token`.
    stop: Arc<RelaxedCell<bool>>,
}
135
136
impl ExecutionState {
137
pub fn new() -> Self {
138
let mut flags: StateFlags = Default::default();
139
if verbose() {
140
flags |= StateFlags::VERBOSE;
141
}
142
Self {
143
df_cache: Default::default(),
144
schema_cache: Default::default(),
145
window_cache: Default::default(),
146
branch_idx: 0,
147
flags: RelaxedCell::from(StateFlags::init().as_u8()),
148
#[cfg(feature = "dtype-struct")]
149
with_fields: Default::default(),
150
#[cfg(feature = "dtype-struct")]
151
with_fields_ac: Default::default(),
152
ext_contexts: Default::default(),
153
element: Default::default(),
154
node_timer: None,
155
stop: Arc::new(RelaxedCell::from(false)),
156
}
157
}
158
159
/// Toggle this to measure execution times.
160
pub fn time_nodes(&mut self, start: std::time::Instant) {
161
self.node_timer = Some(NodeTimer::new(start))
162
}
163
pub fn has_node_timer(&self) -> bool {
164
self.node_timer.is_some()
165
}
166
167
pub fn finish_timer(self) -> PolarsResult<DataFrame> {
168
self.node_timer.unwrap().finish()
169
}
170
171
// Timings should be a list of (start, end, name) where the start
172
// and end are raw durations since the query start as nanoseconds.
173
pub fn record_raw_timings(&self, timings: &[(u64, u64, String)]) {
174
for &(start, end, ref name) in timings {
175
self.node_timer.as_ref().unwrap().store_duration(
176
Duration::from_nanos(start),
177
Duration::from_nanos(end),
178
name.to_string(),
179
);
180
}
181
}
182
183
// This is wrong when the U64 overflows which will never happen.
184
pub fn should_stop(&self) -> PolarsResult<()> {
185
try_raise_keyboard_interrupt();
186
polars_ensure!(!self.stop.load(), ComputeError: "query interrupted");
187
Ok(())
188
}
189
190
pub fn cancel_token(&self) -> Arc<RelaxedCell<bool>> {
191
self.stop.clone()
192
}
193
194
pub fn record<T, F: FnOnce() -> T>(&self, func: F, name: Cow<'static, str>) -> T {
195
match &self.node_timer {
196
None => func(),
197
Some(timer) => {
198
let start = std::time::Instant::now();
199
let out = func();
200
let end = std::time::Instant::now();
201
202
timer.store(start, end, name.as_ref().to_string());
203
out
204
},
205
}
206
}
207
208
/// Partially clones and partially clears state
209
/// This should be used when splitting a node, like a join or union
210
pub fn split(&self) -> Self {
211
Self {
212
df_cache: self.df_cache.clone(),
213
schema_cache: Default::default(),
214
window_cache: Default::default(),
215
branch_idx: self.branch_idx,
216
flags: self.flags.clone(),
217
ext_contexts: self.ext_contexts.clone(),
218
// Retain input values for `pl.element` in Eval context
219
element: self.element.clone(),
220
#[cfg(feature = "dtype-struct")]
221
with_fields: self.with_fields.clone(),
222
#[cfg(feature = "dtype-struct")]
223
with_fields_ac: self.with_fields_ac.clone(),
224
node_timer: self.node_timer.clone(),
225
stop: self.stop.clone(),
226
}
227
}
228
229
pub fn set_schema(&self, schema: SchemaRef) {
230
let mut lock = self.schema_cache.write().unwrap();
231
*lock = Some(schema);
232
}
233
234
/// Clear the schema. Typically at the end of a projection.
235
pub fn clear_schema_cache(&self) {
236
let mut lock = self.schema_cache.write().unwrap();
237
*lock = None;
238
}
239
240
/// Get the schema.
241
pub fn get_schema(&self) -> Option<SchemaRef> {
242
let lock = self.schema_cache.read().unwrap();
243
lock.clone()
244
}
245
246
pub fn set_df_cache(&self, id: &UniqueId, df: DataFrame, cache_hits: u32) {
247
if self.verbose() {
248
eprintln!("CACHE SET: cache id: {id}");
249
}
250
251
let value = Arc::new(CachedValue {
252
remaining_hits: AtomicI64::new(cache_hits as i64),
253
df,
254
});
255
256
let prev = self.df_cache.write().unwrap().insert(*id, value);
257
assert!(prev.is_none(), "duplicate set cache: {id}");
258
}
259
260
pub fn get_df_cache(&self, id: &UniqueId) -> DataFrame {
261
let guard = self.df_cache.read().unwrap();
262
let value = guard.get(id).expect("cache not prefilled");
263
let remaining = value.remaining_hits.fetch_sub(1, Ordering::Relaxed);
264
if remaining < 0 {
265
panic!("cache used more times than expected: {id}");
266
}
267
if self.verbose() {
268
eprintln!("CACHE HIT: cache id: {id}");
269
}
270
if remaining == 1 {
271
drop(guard);
272
let value = self.df_cache.write().unwrap().remove(id).unwrap();
273
if self.verbose() {
274
eprintln!("CACHE DROP: cache id: {id}");
275
}
276
Arc::into_inner(value).unwrap().df
277
} else {
278
value.df.clone()
279
}
280
}
281
282
/// Clear the cache used by the Window expressions
283
pub fn clear_window_expr_cache(&self) {
284
self.window_cache.clear();
285
}
286
287
fn set_flags(&self, f: &dyn Fn(StateFlags) -> StateFlags) {
288
let flags: StateFlags = self.flags.load().into();
289
let flags = f(flags);
290
self.flags.store(flags.as_u8());
291
}
292
293
/// Indicates that window expression's [`GroupTuples`] may be cached.
294
pub fn cache_window(&self) -> bool {
295
let flags: StateFlags = self.flags.load().into();
296
flags.contains(StateFlags::CACHE_WINDOW_EXPR)
297
}
298
299
/// Indicates that window expression's [`GroupTuples`] may be cached.
300
pub fn has_window(&self) -> bool {
301
let flags: StateFlags = self.flags.load().into();
302
flags.contains(StateFlags::HAS_WINDOW)
303
}
304
305
/// More verbose logging
306
pub fn verbose(&self) -> bool {
307
let flags: StateFlags = self.flags.load().into();
308
flags.contains(StateFlags::VERBOSE)
309
}
310
311
pub fn remove_cache_window_flag(&mut self) {
312
self.set_flags(&|mut flags| {
313
flags.remove(StateFlags::CACHE_WINDOW_EXPR);
314
flags
315
});
316
}
317
318
pub fn insert_cache_window_flag(&mut self) {
319
self.set_flags(&|mut flags| {
320
flags.insert(StateFlags::CACHE_WINDOW_EXPR);
321
flags
322
});
323
}
324
// this will trigger some conservative
325
pub fn insert_has_window_function_flag(&mut self) {
326
self.set_flags(&|mut flags| {
327
flags.insert(StateFlags::HAS_WINDOW);
328
flags
329
});
330
}
331
}
332
333
impl Default for ExecutionState {
334
fn default() -> Self {
335
ExecutionState::new()
336
}
337
}
338
339