Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/state/execution_state.rs
6940 views
1
use std::borrow::Cow;
2
use std::sync::atomic::{AtomicI64, Ordering};
3
use std::sync::{Mutex, RwLock};
4
use std::time::Duration;
5
6
use bitflags::bitflags;
7
use polars_core::config::verbose;
8
use polars_core::prelude::*;
9
use polars_ops::prelude::ChunkJoinOptIds;
10
use polars_utils::relaxed_cell::RelaxedCell;
11
use polars_utils::unique_id::UniqueId;
12
13
use super::NodeTimer;
14
15
/// Shared, mutex-guarded map from a string cache key to `ChunkJoinOptIds`.
pub type JoinTuplesCache = Arc<std::sync::Mutex<PlHashMap<String, ChunkJoinOptIds>>>;
16
17
#[derive(Default)]
18
pub struct WindowCache {
19
groups: RwLock<PlHashMap<String, GroupPositions>>,
20
join_tuples: RwLock<PlHashMap<String, Arc<ChunkJoinOptIds>>>,
21
map_idx: RwLock<PlHashMap<String, Arc<IdxCa>>>,
22
}
23
24
impl WindowCache {
25
pub(crate) fn clear(&self) {
26
let mut g = self.groups.write().unwrap();
27
g.clear();
28
let mut g = self.join_tuples.write().unwrap();
29
g.clear();
30
}
31
32
pub fn get_groups(&self, key: &str) -> Option<GroupPositions> {
33
let g = self.groups.read().unwrap();
34
g.get(key).cloned()
35
}
36
37
pub fn insert_groups(&self, key: String, groups: GroupPositions) {
38
let mut g = self.groups.write().unwrap();
39
g.insert(key, groups);
40
}
41
42
pub fn get_join(&self, key: &str) -> Option<Arc<ChunkJoinOptIds>> {
43
let g = self.join_tuples.read().unwrap();
44
g.get(key).cloned()
45
}
46
47
pub fn insert_join(&self, key: String, join_tuples: Arc<ChunkJoinOptIds>) {
48
let mut g = self.join_tuples.write().unwrap();
49
g.insert(key, join_tuples);
50
}
51
52
pub fn get_map(&self, key: &str) -> Option<Arc<IdxCa>> {
53
let g = self.map_idx.read().unwrap();
54
g.get(key).cloned()
55
}
56
57
pub fn insert_map(&self, key: String, idx: Arc<IdxCa>) {
58
let mut g = self.map_idx.write().unwrap();
59
g.insert(key, idx);
60
}
61
}
62
63
bitflags! {
    /// Boolean switches stored in `ExecutionState::flags` as a raw `u8`.
    ///
    /// `#[repr(transparent)]` keeps the layout of this type identical to its `u8` bits.
    #[repr(transparent)]
    #[derive(Copy, Clone)]
    pub(super) struct StateFlags: u8 {
        /// Emit more verbose logging.
        const VERBOSE = 1 << 0;
        /// Window expressions may cache their computed groups.
        const CACHE_WINDOW_EXPR = 1 << 1;
        /// The expression contains a window function.
        const HAS_WINDOW = 1 << 2;
    }
}
75
76
impl Default for StateFlags {
77
fn default() -> Self {
78
StateFlags::CACHE_WINDOW_EXPR
79
}
80
}
81
82
impl StateFlags {
83
fn init() -> Self {
84
let verbose = verbose();
85
let mut flags: StateFlags = Default::default();
86
if verbose {
87
flags |= StateFlags::VERBOSE;
88
}
89
flags
90
}
91
fn as_u8(self) -> u8 {
92
unsafe { std::mem::transmute(self) }
93
}
94
}
95
96
impl From<u8> for StateFlags {
97
fn from(value: u8) -> Self {
98
unsafe { std::mem::transmute(value) }
99
}
100
}
101
102
struct CachedValue {
103
/// The number of times the cache will still be read.
104
/// Zero means that there will be no more reads and the cache can be dropped.
105
remaining_hits: AtomicI64,
106
df: DataFrame,
107
}
108
109
/// State/ cache that is maintained during the Execution of the physical plan.
110
#[derive(Clone)]
111
pub struct ExecutionState {
112
// cached by a `.cache` call and kept in memory for the duration of the plan.
113
df_cache: Arc<RwLock<PlHashMap<UniqueId, Arc<CachedValue>>>>,
114
pub schema_cache: Arc<RwLock<Option<SchemaRef>>>,
115
/// Used by Window Expressions to cache intermediate state
116
pub window_cache: Arc<WindowCache>,
117
// every join/union split gets an increment to distinguish between schema state
118
pub branch_idx: usize,
119
pub flags: RelaxedCell<u8>,
120
pub ext_contexts: Arc<Vec<DataFrame>>,
121
node_timer: Option<NodeTimer>,
122
stop: Arc<RelaxedCell<bool>>,
123
}
124
125
impl ExecutionState {
126
pub fn new() -> Self {
127
let mut flags: StateFlags = Default::default();
128
if verbose() {
129
flags |= StateFlags::VERBOSE;
130
}
131
Self {
132
df_cache: Default::default(),
133
schema_cache: Default::default(),
134
window_cache: Default::default(),
135
branch_idx: 0,
136
flags: RelaxedCell::from(StateFlags::init().as_u8()),
137
ext_contexts: Default::default(),
138
node_timer: None,
139
stop: Arc::new(RelaxedCell::from(false)),
140
}
141
}
142
143
/// Toggle this to measure execution times.
144
pub fn time_nodes(&mut self, start: std::time::Instant) {
145
self.node_timer = Some(NodeTimer::new(start))
146
}
147
pub fn has_node_timer(&self) -> bool {
148
self.node_timer.is_some()
149
}
150
151
pub fn finish_timer(self) -> PolarsResult<DataFrame> {
152
self.node_timer.unwrap().finish()
153
}
154
155
// Timings should be a list of (start, end, name) where the start
156
// and end are raw durations since the query start as nanoseconds.
157
pub fn record_raw_timings(&self, timings: &[(u64, u64, String)]) {
158
for &(start, end, ref name) in timings {
159
self.node_timer.as_ref().unwrap().store_duration(
160
Duration::from_nanos(start),
161
Duration::from_nanos(end),
162
name.to_string(),
163
);
164
}
165
}
166
167
// This is wrong when the U64 overflows which will never happen.
168
pub fn should_stop(&self) -> PolarsResult<()> {
169
try_raise_keyboard_interrupt();
170
polars_ensure!(!self.stop.load(), ComputeError: "query interrupted");
171
Ok(())
172
}
173
174
pub fn cancel_token(&self) -> Arc<RelaxedCell<bool>> {
175
self.stop.clone()
176
}
177
178
pub fn record<T, F: FnOnce() -> T>(&self, func: F, name: Cow<'static, str>) -> T {
179
match &self.node_timer {
180
None => func(),
181
Some(timer) => {
182
let start = std::time::Instant::now();
183
let out = func();
184
let end = std::time::Instant::now();
185
186
timer.store(start, end, name.as_ref().to_string());
187
out
188
},
189
}
190
}
191
192
/// Partially clones and partially clears state
193
/// This should be used when splitting a node, like a join or union
194
pub fn split(&self) -> Self {
195
Self {
196
df_cache: self.df_cache.clone(),
197
schema_cache: Default::default(),
198
window_cache: Default::default(),
199
branch_idx: self.branch_idx,
200
flags: self.flags.clone(),
201
ext_contexts: self.ext_contexts.clone(),
202
node_timer: self.node_timer.clone(),
203
stop: self.stop.clone(),
204
}
205
}
206
207
pub fn set_schema(&self, schema: SchemaRef) {
208
let mut lock = self.schema_cache.write().unwrap();
209
*lock = Some(schema);
210
}
211
212
/// Clear the schema. Typically at the end of a projection.
213
pub fn clear_schema_cache(&self) {
214
let mut lock = self.schema_cache.write().unwrap();
215
*lock = None;
216
}
217
218
/// Get the schema.
219
pub fn get_schema(&self) -> Option<SchemaRef> {
220
let lock = self.schema_cache.read().unwrap();
221
lock.clone()
222
}
223
224
pub fn set_df_cache(&self, id: &UniqueId, df: DataFrame, cache_hits: u32) {
225
if self.verbose() {
226
eprintln!("CACHE SET: cache id: {id}");
227
}
228
229
let value = Arc::new(CachedValue {
230
remaining_hits: AtomicI64::new(cache_hits as i64),
231
df,
232
});
233
234
let prev = self.df_cache.write().unwrap().insert(*id, value);
235
assert!(prev.is_none(), "duplicate set cache: {id}");
236
}
237
238
pub fn get_df_cache(&self, id: &UniqueId) -> DataFrame {
239
let guard = self.df_cache.read().unwrap();
240
let value = guard.get(id).expect("cache not prefilled");
241
let remaining = value.remaining_hits.fetch_sub(1, Ordering::Relaxed);
242
if remaining < 0 {
243
panic!("cache used more times than expected: {id}");
244
}
245
if self.verbose() {
246
eprintln!("CACHE HIT: cache id: {id}");
247
}
248
if remaining == 1 {
249
drop(guard);
250
let value = self.df_cache.write().unwrap().remove(id).unwrap();
251
if self.verbose() {
252
eprintln!("CACHE DROP: cache id: {id}");
253
}
254
Arc::into_inner(value).unwrap().df
255
} else {
256
value.df.clone()
257
}
258
}
259
260
/// Clear the cache used by the Window expressions
261
pub fn clear_window_expr_cache(&self) {
262
self.window_cache.clear();
263
}
264
265
fn set_flags(&self, f: &dyn Fn(StateFlags) -> StateFlags) {
266
let flags: StateFlags = self.flags.load().into();
267
let flags = f(flags);
268
self.flags.store(flags.as_u8());
269
}
270
271
/// Indicates that window expression's [`GroupTuples`] may be cached.
272
pub fn cache_window(&self) -> bool {
273
let flags: StateFlags = self.flags.load().into();
274
flags.contains(StateFlags::CACHE_WINDOW_EXPR)
275
}
276
277
/// Indicates that window expression's [`GroupTuples`] may be cached.
278
pub fn has_window(&self) -> bool {
279
let flags: StateFlags = self.flags.load().into();
280
flags.contains(StateFlags::HAS_WINDOW)
281
}
282
283
/// More verbose logging
284
pub fn verbose(&self) -> bool {
285
let flags: StateFlags = self.flags.load().into();
286
flags.contains(StateFlags::VERBOSE)
287
}
288
289
pub fn remove_cache_window_flag(&mut self) {
290
self.set_flags(&|mut flags| {
291
flags.remove(StateFlags::CACHE_WINDOW_EXPR);
292
flags
293
});
294
}
295
296
pub fn insert_cache_window_flag(&mut self) {
297
self.set_flags(&|mut flags| {
298
flags.insert(StateFlags::CACHE_WINDOW_EXPR);
299
flags
300
});
301
}
302
// this will trigger some conservative
303
pub fn insert_has_window_function_flag(&mut self) {
304
self.set_flags(&|mut flags| {
305
flags.insert(StateFlags::HAS_WINDOW);
306
flags
307
});
308
}
309
}
310
311
impl Default for ExecutionState {
312
fn default() -> Self {
313
ExecutionState::new()
314
}
315
}
316
317