// Path: blob/main/crates/polars-stream/src/async_executor/mod.rs
// 6939 views
#![allow(clippy::disallowed_types)]12mod park_group;3mod task;45use std::cell::{Cell, UnsafeCell};6use std::collections::HashMap;7use std::future::Future;8use std::marker::PhantomData;9use std::panic::{AssertUnwindSafe, Location};10use std::sync::atomic::{AtomicBool, Ordering};11use std::sync::{Arc, LazyLock, OnceLock, Weak};12use std::time::Duration;1314use crossbeam_deque::{Injector, Steal, Stealer, Worker as WorkQueue};15use crossbeam_utils::CachePadded;16use park_group::ParkGroup;17use parking_lot::Mutex;18use polars_utils::relaxed_cell::RelaxedCell;19use rand::rngs::SmallRng;20use rand::{Rng, SeedableRng};21use slotmap::SlotMap;22pub use task::{AbortOnDropHandle, JoinHandle};23use task::{CancelHandle, Runnable};2425static NUM_EXECUTOR_THREADS: RelaxedCell<usize> = RelaxedCell::new_usize(0);26pub fn set_num_threads(t: usize) {27NUM_EXECUTOR_THREADS.store(t);28}2930static GLOBAL_SCHEDULER: OnceLock<Executor> = OnceLock::new();3132thread_local!(33/// Used to store which executor thread this is.34static TLS_THREAD_ID: Cell<usize> = const { Cell::new(usize::MAX) };35);3637static NS_SPENT_BLOCKED: LazyLock<Mutex<HashMap<&'static Location<'static>, u64>>> =38LazyLock::new(Mutex::default);3940static TRACK_WAIT_STATISTICS: RelaxedCell<bool> = RelaxedCell::new_bool(false);4142pub fn track_task_wait_statistics(should_track: bool) {43TRACK_WAIT_STATISTICS.store(should_track);44}4546pub fn get_task_wait_statistics() -> Vec<(&'static Location<'static>, Duration)> {47NS_SPENT_BLOCKED48.lock()49.iter()50.map(|(l, ns)| (*l, Duration::from_nanos(*ns)))51.collect()52}5354pub fn clear_task_wait_statistics() {55NS_SPENT_BLOCKED.lock().clear()56}5758slotmap::new_key_type! 
{59struct TaskKey;60}6162/// High priority tasks are scheduled preferentially over low priority tasks.63#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]64pub enum TaskPriority {65Low,66High,67}6869/// Metadata associated with a task to help schedule it and clean it up.70struct ScopedTaskMetadata {71task_key: TaskKey,72completed_tasks: Weak<Mutex<Vec<TaskKey>>>,73}7475struct TaskMetadata {76spawn_location: &'static Location<'static>,77ns_spent_blocked: RelaxedCell<u64>,78priority: TaskPriority,79freshly_spawned: AtomicBool,80scoped: Option<ScopedTaskMetadata>,81}8283impl Drop for TaskMetadata {84fn drop(&mut self) {85*NS_SPENT_BLOCKED86.lock()87.entry(self.spawn_location)88.or_default() += self.ns_spent_blocked.load();89if let Some(scoped) = &self.scoped {90if let Some(completed_tasks) = scoped.completed_tasks.upgrade() {91completed_tasks.lock().push(scoped.task_key);92}93}94}95}9697/// A task ready to run.98type ReadyTask = Runnable<TaskMetadata>;99100/// A per-thread task list.101struct ThreadLocalTaskList {102// May be used from any thread.103high_prio_tasks_stealer: Stealer<ReadyTask>,104105// SAFETY: these may only be used on the thread this task list belongs to.106high_prio_tasks: WorkQueue<ReadyTask>,107local_slot: UnsafeCell<Option<ReadyTask>>,108}109110unsafe impl Sync for ThreadLocalTaskList {}111112struct Executor {113park_group: ParkGroup,114thread_task_lists: Vec<CachePadded<ThreadLocalTaskList>>,115global_high_prio_task_queue: Injector<ReadyTask>,116global_low_prio_task_queue: Injector<ReadyTask>,117}118119impl Executor {120fn schedule_task(&self, task: ReadyTask) {121let thread = TLS_THREAD_ID.get();122let meta = task.metadata();123let opt_ttl = self.thread_task_lists.get(thread);124125let mut use_global_queue = opt_ttl.is_none();126if meta.freshly_spawned.load(Ordering::Relaxed) {127use_global_queue = true;128meta.freshly_spawned.store(false, Ordering::Relaxed);129}130131if use_global_queue {132// Scheduled from an unknown thread, add to 
global queue.133if meta.priority == TaskPriority::High {134self.global_high_prio_task_queue.push(task);135} else {136self.global_low_prio_task_queue.push(task);137}138self.park_group.unpark_one();139} else {140let ttl = opt_ttl.unwrap();141// SAFETY: this slot may only be accessed from the local thread, which we are.142let slot = unsafe { &mut *ttl.local_slot.get() };143144if meta.priority == TaskPriority::High {145// Insert new task into thread local slot, taking out the old task.146let Some(task) = slot.replace(task) else {147// We pushed a task into our local slot which was empty. Since148// we are already awake, no need to notify anyone.149return;150};151152ttl.high_prio_tasks.push(task);153self.park_group.unpark_one();154} else {155// Optimization: while this is a low priority task we have no156// high priority tasks on this thread so we'll execute this one.157if ttl.high_prio_tasks.is_empty() && slot.is_none() {158*slot = Some(task);159} else {160self.global_low_prio_task_queue.push(task);161self.park_group.unpark_one();162}163}164}165}166167fn try_steal_task<R: Rng>(&self, thread: usize, rng: &mut R) -> Option<ReadyTask> {168// Try to get a global task.169loop {170match self.global_high_prio_task_queue.steal() {171Steal::Empty => break,172Steal::Success(task) => return Some(task),173Steal::Retry => std::hint::spin_loop(),174}175}176177loop {178match self.global_low_prio_task_queue.steal() {179Steal::Empty => break,180Steal::Success(task) => return Some(task),181Steal::Retry => std::hint::spin_loop(),182}183}184185// Try to steal tasks.186let ttl = &self.thread_task_lists[thread];187for _ in 0..4 {188let mut retry = true;189while retry {190retry = false;191192for idx in random_permutation(self.thread_task_lists.len() as u32, rng) {193let foreign_ttl = &self.thread_task_lists[idx as usize];194match foreign_ttl195.high_prio_tasks_stealer196.steal_batch_and_pop(&ttl.high_prio_tasks)197{198Steal::Empty => {},199Steal::Success(task) => return 
Some(task),200Steal::Retry => retry = true,201}202}203204std::hint::spin_loop()205}206}207208None209}210211fn runner(&self, thread: usize) {212TLS_THREAD_ID.set(thread);213214let mut rng = SmallRng::from_rng(&mut rand::rng());215let mut worker = self.park_group.new_worker();216let mut last_block_start = None;217218loop {219let ttl = &self.thread_task_lists[thread];220let task = (|| {221// Try to get a task from LIFO slot.222if let Some(task) = unsafe { (*ttl.local_slot.get()).take() } {223return Some(task);224}225226// Try to get a local high-priority task.227if let Some(task) = ttl.high_prio_tasks.pop() {228return Some(task);229}230231// Try to steal a task.232if let Some(task) = self.try_steal_task(thread, &mut rng) {233return Some(task);234}235236// Prepare to park, then try one more steal attempt.237let park = worker.prepare_park();238if let Some(task) = self.try_steal_task(thread, &mut rng) {239return Some(task);240}241242if last_block_start.is_none() && TRACK_WAIT_STATISTICS.load() {243last_block_start = Some(std::time::Instant::now());244}245park.park();246None247})();248249if let Some(task) = task {250if let Some(t) = last_block_start.take() {251if TRACK_WAIT_STATISTICS.load() {252let ns: u64 = t.elapsed().as_nanos().try_into().unwrap();253task.metadata().ns_spent_blocked.fetch_add(ns);254}255}256worker.recruit_next();257task.run();258}259}260}261262fn global() -> &'static Executor {263GLOBAL_SCHEDULER.get_or_init(|| {264let mut n_threads = NUM_EXECUTOR_THREADS.load();265if n_threads == 0 {266n_threads = std::thread::available_parallelism()267.map(|n| n.get())268.unwrap_or(4);269}270271let thread_task_lists = (0..n_threads)272.map(|t| {273std::thread::Builder::new()274.name(format!("async-executor-{t}"))275.spawn(move || Self::global().runner(t))276.unwrap();277278let high_prio_tasks = WorkQueue::new_lifo();279CachePadded::new(ThreadLocalTaskList {280high_prio_tasks_stealer: high_prio_tasks.stealer(),281high_prio_tasks,282local_slot: 
UnsafeCell::new(None),283})284})285.collect();286Self {287park_group: ParkGroup::new(),288thread_task_lists,289global_high_prio_task_queue: Injector::new(),290global_low_prio_task_queue: Injector::new(),291}292})293}294}295296pub struct TaskScope<'scope, 'env: 'scope> {297// Keep track of in-progress tasks so we can forcibly cancel them298// when the scope ends, to ensure the lifetimes are respected.299// Tasks add their own key to completed_tasks when done so we can300// reclaim the memory used by the cancel_handles.301cancel_handles: Mutex<SlotMap<TaskKey, CancelHandle>>,302completed_tasks: Arc<Mutex<Vec<TaskKey>>>,303304// Copied from std::thread::scope. Necessary to prevent unsoundness.305scope: PhantomData<&'scope mut &'scope ()>,306env: PhantomData<&'env mut &'env ()>,307}308309impl<'scope> TaskScope<'scope, '_> {310// Not Drop because that extends lifetimes.311fn destroy(&self) {312// Make sure all tasks are cancelled.313for (_, t) in self.cancel_handles.lock().drain() {314t.cancel();315}316}317318fn clear_completed_tasks(&self) {319let mut cancel_handles = self.cancel_handles.lock();320for t in self.completed_tasks.lock().drain(..) 
{321cancel_handles.remove(t);322}323}324325#[track_caller]326pub fn spawn_task<F: Future + Send + 'scope>(327&self,328priority: TaskPriority,329fut: F,330) -> JoinHandle<F::Output>331where332<F as Future>::Output: Send + 'static,333{334let spawn_location = Location::caller();335self.clear_completed_tasks();336337let mut runnable = None;338let mut join_handle = None;339self.cancel_handles.lock().insert_with_key(|task_key| {340let (run, jh) = unsafe {341// SAFETY: we make sure to cancel this task before 'scope ends.342let executor = Executor::global();343let on_wake = move |task| executor.schedule_task(task);344task::spawn_with_lifetime(345fut,346on_wake,347TaskMetadata {348spawn_location,349ns_spent_blocked: RelaxedCell::new_u64(0),350priority,351freshly_spawned: AtomicBool::new(true),352scoped: Some(ScopedTaskMetadata {353task_key,354completed_tasks: Arc::downgrade(&self.completed_tasks),355}),356},357)358};359let cancel_handle = jh.cancel_handle();360runnable = Some(run);361join_handle = Some(jh);362cancel_handle363});364runnable.unwrap().schedule();365join_handle.unwrap()366}367}368369pub fn task_scope<'env, F, T>(f: F) -> T370where371F: for<'scope> FnOnce(&'scope TaskScope<'scope, 'env>) -> T,372{373// By having this local variable inaccessible to anyone we guarantee374// that either abort is called killing the entire process, or that this375// executor is properly destroyed.376let scope = TaskScope {377cancel_handles: Mutex::default(),378completed_tasks: Arc::new(Mutex::default()),379scope: PhantomData,380env: PhantomData,381};382383let result = std::panic::catch_unwind(AssertUnwindSafe(|| f(&scope)));384385// Make sure all tasks are properly destroyed.386scope.destroy();387388match result {389Err(e) => std::panic::resume_unwind(e),390Ok(result) => result,391}392}393394#[track_caller]395pub fn spawn<F: Future + Send + 'static>(priority: TaskPriority, fut: F) -> JoinHandle<F::Output>396where397<F as Future>::Output: Send + 'static,398{399let spawn_location = 
Location::caller();400let executor = Executor::global();401let on_wake = move |task| executor.schedule_task(task);402let (runnable, join_handle) = task::spawn(403fut,404on_wake,405TaskMetadata {406spawn_location,407ns_spent_blocked: RelaxedCell::new_u64(0),408priority,409freshly_spawned: AtomicBool::new(true),410scoped: None,411},412);413runnable.schedule();414join_handle415}416417fn random_permutation<R: Rng>(len: u32, rng: &mut R) -> impl Iterator<Item = u32> {418let modulus = len.next_power_of_two();419let halfwidth = modulus.trailing_zeros() / 2;420let mask = modulus - 1;421let displace_zero = rng.random::<u32>();422let odd1 = rng.random::<u32>() | 1;423let odd2 = rng.random::<u32>() | 1;424let uniform_first = ((rng.random::<u32>() as u64 * len as u64) >> 32) as u32;425426(0..modulus)427.map(move |mut i| {428// Invertible permutation on [0, modulus).429i = i.wrapping_add(displace_zero);430i = i.wrapping_mul(odd1);431i ^= (i & mask) >> halfwidth;432i = i.wrapping_mul(odd2);433i & mask434})435.filter(move |i| *i < len)436.map(move |mut i| {437i += uniform_first;438if i >= len {439i -= len;440}441i442})443}444445446