Path: blob/main/crates/polars-expr/src/state/execution_state.rs
8382 views
use std::borrow::Cow;1use std::sync::atomic::{AtomicI64, Ordering};2use std::sync::{Mutex, RwLock};3use std::time::Duration;45use arrow::bitmap::Bitmap;6use bitflags::bitflags;7use polars_core::config::verbose;8use polars_core::prelude::*;9use polars_ops::prelude::ChunkJoinOptIds;10use polars_utils::relaxed_cell::RelaxedCell;11use polars_utils::unique_id::UniqueId;1213use super::NodeTimer;14use crate::prelude::AggregationContext;1516pub type JoinTuplesCache = Arc<Mutex<PlHashMap<String, ChunkJoinOptIds>>>;1718#[derive(Default)]19pub struct WindowCache {20groups: RwLock<PlHashMap<String, GroupPositions>>,21join_tuples: RwLock<PlHashMap<String, Arc<ChunkJoinOptIds>>>,22map_idx: RwLock<PlHashMap<String, Arc<IdxCa>>>,23}2425impl WindowCache {26pub(crate) fn clear(&self) {27let Self {28groups,29join_tuples,30map_idx,31} = self;32groups.write().unwrap().clear();33join_tuples.write().unwrap().clear();34map_idx.write().unwrap().clear();35}3637pub fn get_groups(&self, key: &str) -> Option<GroupPositions> {38let g = self.groups.read().unwrap();39g.get(key).cloned()40}4142pub fn insert_groups(&self, key: String, groups: GroupPositions) {43let mut g = self.groups.write().unwrap();44g.insert(key, groups);45}4647pub fn get_join(&self, key: &str) -> Option<Arc<ChunkJoinOptIds>> {48let g = self.join_tuples.read().unwrap();49g.get(key).cloned()50}5152pub fn insert_join(&self, key: String, join_tuples: Arc<ChunkJoinOptIds>) {53let mut g = self.join_tuples.write().unwrap();54g.insert(key, join_tuples);55}5657pub fn get_map(&self, key: &str) -> Option<Arc<IdxCa>> {58let g = self.map_idx.read().unwrap();59g.get(key).cloned()60}6162pub fn insert_map(&self, key: String, idx: Arc<IdxCa>) {63let mut g = self.map_idx.write().unwrap();64g.insert(key, idx);65}66}6768bitflags! {69#[repr(transparent)]70#[derive(Copy, Clone)]71pub(super) struct StateFlags: u8 {72/// More verbose logging73const VERBOSE = 0x01;74/// Indicates that window expression's [`GroupTuples`] may be cached.75const CACHE_WINDOW_EXPR = 0x02;76/// Indicates the expression has a window function77const HAS_WINDOW = 0x04;78}79}8081impl Default for StateFlags {82fn default() -> Self {83StateFlags::CACHE_WINDOW_EXPR84}85}8687impl StateFlags {88fn init() -> Self {89let verbose = verbose();90let mut flags: StateFlags = Default::default();91if verbose {92flags |= StateFlags::VERBOSE;93}94flags95}96fn as_u8(self) -> u8 {97unsafe { std::mem::transmute(self) }98}99}100101impl From<u8> for StateFlags {102fn from(value: u8) -> Self {103unsafe { std::mem::transmute(value) }104}105}106107struct CachedValue {108/// The number of times the cache will still be read.109/// Zero means that there will be no more reads and the cache can be dropped.110remaining_hits: AtomicI64,111df: DataFrame,112}113114/// State/ cache that is maintained during the Execution of the physical plan.115#[derive(Clone)]116pub struct ExecutionState {117// cached by a `.cache` call and kept in memory for the duration of the plan.118df_cache: Arc<RwLock<PlHashMap<UniqueId, Arc<CachedValue>>>>,119pub schema_cache: Arc<RwLock<Option<SchemaRef>>>,120/// Used by Window Expressions to cache intermediate state121pub window_cache: Arc<WindowCache>,122// every join/union split gets an increment to distinguish between schema state123pub branch_idx: usize,124pub flags: RelaxedCell<u8>,125#[cfg(feature = "dtype-struct")]126pub with_fields: Option<Arc<StructChunked>>,127#[cfg(feature = "dtype-struct")]128pub with_fields_ac: Option<Arc<AggregationContext<'static>>>,129pub ext_contexts: Arc<Vec<DataFrame>>,130pub element: Arc<Option<(Column, Option<Bitmap>)>>,131node_timer: Option<NodeTimer>,132stop: Arc<RelaxedCell<bool>>,133}134135impl ExecutionState {136pub fn new() -> Self {137let mut flags: StateFlags = Default::default();138if verbose() {139flags |= StateFlags::VERBOSE;140}141Self {142df_cache: Default::default(),143schema_cache: Default::default(),144window_cache: Default::default(),145branch_idx: 0,146flags: RelaxedCell::from(StateFlags::init().as_u8()),147#[cfg(feature = "dtype-struct")]148with_fields: Default::default(),149#[cfg(feature = "dtype-struct")]150with_fields_ac: Default::default(),151ext_contexts: Default::default(),152element: Default::default(),153node_timer: None,154stop: Arc::new(RelaxedCell::from(false)),155}156}157158/// Toggle this to measure execution times.159pub fn time_nodes(&mut self, start: std::time::Instant) {160self.node_timer = Some(NodeTimer::new(start))161}162pub fn has_node_timer(&self) -> bool {163self.node_timer.is_some()164}165166pub fn finish_timer(self) -> PolarsResult<DataFrame> {167self.node_timer.unwrap().finish()168}169170// Timings should be a list of (start, end, name) where the start171// and end are raw durations since the query start as nanoseconds.172pub fn record_raw_timings(&self, timings: &[(u64, u64, String)]) {173for &(start, end, ref name) in timings {174self.node_timer.as_ref().unwrap().store_duration(175Duration::from_nanos(start),176Duration::from_nanos(end),177name.to_string(),178);179}180}181182// This is wrong when the U64 overflows which will never happen.183pub fn should_stop(&self) -> PolarsResult<()> {184try_raise_keyboard_interrupt();185polars_ensure!(!self.stop.load(), ComputeError: "query interrupted");186Ok(())187}188189pub fn cancel_token(&self) -> Arc<RelaxedCell<bool>> {190self.stop.clone()191}192193pub fn record<T, F: FnOnce() -> T>(&self, func: F, name: Cow<'static, str>) -> T {194match &self.node_timer {195None => func(),196Some(timer) => {197let start = std::time::Instant::now();198let out = func();199let end = std::time::Instant::now();200201timer.store(start, end, name.as_ref().to_string());202out203},204}205}206207/// Partially clones and partially clears state208/// This should be used when splitting a node, like a join or union209pub fn split(&self) -> Self {210Self {211df_cache: self.df_cache.clone(),212schema_cache: Default::default(),213window_cache: Default::default(),214branch_idx: self.branch_idx,215flags: self.flags.clone(),216ext_contexts: self.ext_contexts.clone(),217// Retain input values for `pl.element` in Eval context218element: self.element.clone(),219#[cfg(feature = "dtype-struct")]220with_fields: self.with_fields.clone(),221#[cfg(feature = "dtype-struct")]222with_fields_ac: self.with_fields_ac.clone(),223node_timer: self.node_timer.clone(),224stop: self.stop.clone(),225}226}227228pub fn set_schema(&self, schema: SchemaRef) {229let mut lock = self.schema_cache.write().unwrap();230*lock = Some(schema);231}232233/// Clear the schema. Typically at the end of a projection.234pub fn clear_schema_cache(&self) {235let mut lock = self.schema_cache.write().unwrap();236*lock = None;237}238239/// Get the schema.240pub fn get_schema(&self) -> Option<SchemaRef> {241let lock = self.schema_cache.read().unwrap();242lock.clone()243}244245pub fn set_df_cache(&self, id: &UniqueId, df: DataFrame, cache_hits: u32) {246if self.verbose() {247eprintln!("CACHE SET: cache id: {id}");248}249250let value = Arc::new(CachedValue {251remaining_hits: AtomicI64::new(cache_hits as i64),252df,253});254255let prev = self.df_cache.write().unwrap().insert(*id, value);256assert!(prev.is_none(), "duplicate set cache: {id}");257}258259pub fn get_df_cache(&self, id: &UniqueId) -> DataFrame {260let guard = self.df_cache.read().unwrap();261let value = guard.get(id).expect("cache not prefilled");262let remaining = value.remaining_hits.fetch_sub(1, Ordering::Relaxed);263if remaining < 0 {264panic!("cache used more times than expected: {id}");265}266if self.verbose() {267eprintln!("CACHE HIT: cache id: {id}");268}269if remaining == 1 {270drop(guard);271let value = self.df_cache.write().unwrap().remove(id).unwrap();272if self.verbose() {273eprintln!("CACHE DROP: cache id: {id}");274}275Arc::into_inner(value).unwrap().df276} else {277value.df.clone()278}279}280281/// Clear the cache used by the Window expressions282pub fn clear_window_expr_cache(&self) {283self.window_cache.clear();284}285286fn set_flags(&self, f: &dyn Fn(StateFlags) -> StateFlags) {287let flags: StateFlags = self.flags.load().into();288let flags = f(flags);289self.flags.store(flags.as_u8());290}291292/// Indicates that window expression's [`GroupTuples`] may be cached.293pub fn cache_window(&self) -> bool {294let flags: StateFlags = self.flags.load().into();295flags.contains(StateFlags::CACHE_WINDOW_EXPR)296}297298/// Indicates that window expression's [`GroupTuples`] may be cached.299pub fn has_window(&self) -> bool {300let flags: StateFlags = self.flags.load().into();301flags.contains(StateFlags::HAS_WINDOW)302}303304/// More verbose logging305pub fn verbose(&self) -> bool {306let flags: StateFlags = self.flags.load().into();307flags.contains(StateFlags::VERBOSE)308}309310pub fn remove_cache_window_flag(&mut self) {311self.set_flags(&|mut flags| {312flags.remove(StateFlags::CACHE_WINDOW_EXPR);313flags314});315}316317pub fn insert_cache_window_flag(&mut self) {318self.set_flags(&|mut flags| {319flags.insert(StateFlags::CACHE_WINDOW_EXPR);320flags321});322}323// this will trigger some conservative324pub fn insert_has_window_function_flag(&mut self) {325self.set_flags(&|mut flags| {326flags.insert(StateFlags::HAS_WINDOW);327flags328});329}330}331332impl Default for ExecutionState {333fn default() -> Self {334ExecutionState::new()335}336}337338339