Path: crates/polars-stream/src/async_executor/mod.rs
#![allow(clippy::disallowed_types)]

mod park_group;
mod task;

use std::cell::{Cell, UnsafeCell};
use std::future::Future;
use std::marker::PhantomData;
use std::panic::{AssertUnwindSafe, Location};
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, OnceLock, Weak};
use std::task::{Context, Poll};
use std::time::Instant;

use crossbeam_deque::{Injector, Steal, Stealer, Worker as WorkQueue};
use crossbeam_utils::CachePadded;
use park_group::ParkGroup;
use parking_lot::Mutex;
use polars_core::ALLOW_RAYON_THREADS;
use polars_utils::relaxed_cell::RelaxedCell;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use slotmap::SlotMap;
use task::{Cancellable, DynTask, Runnable};

static NUM_EXECUTOR_THREADS: RelaxedCell<usize> = RelaxedCell::new_usize(0);
pub fn set_num_threads(t: usize) {
    NUM_EXECUTOR_THREADS.store(t);
}

static TRACK_METRICS: RelaxedCell<bool> = RelaxedCell::new_bool(false);

pub fn track_task_metrics(should_track: bool) {
    TRACK_METRICS.store(should_track);
}

static GLOBAL_SCHEDULER: OnceLock<Executor> = OnceLock::new();

thread_local!(
    /// Used to store which executor thread this is.
    static TLS_THREAD_ID: Cell<usize> = const { Cell::new(usize::MAX) };
);

slotmap::new_key_type! {
    struct TaskKey;
}

/// High priority tasks are scheduled preferentially over low priority tasks.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum TaskPriority {
    Low,
    High,
}

/// Metadata associated with a task to help schedule it and clean it up.
struct ScopedTaskMetadata {
    task_key: TaskKey,
    completed_tasks: Weak<Mutex<Vec<TaskKey>>>,
}

#[derive(Default)]
#[repr(align(128))]
pub struct TaskMetrics {
    pub total_polls: RelaxedCell<u64>,
    pub total_stolen_polls: RelaxedCell<u64>,
    pub total_poll_time_ns: RelaxedCell<u64>,
    pub max_poll_time_ns: RelaxedCell<u64>,
    pub done: RelaxedCell<bool>,
}

struct TaskMetadata {
    spawn_location: &'static Location<'static>,
    priority: TaskPriority,
    freshly_spawned: AtomicBool,
    scoped: Option<ScopedTaskMetadata>,
    metrics: Option<Arc<TaskMetrics>>,
}

impl Drop for TaskMetadata {
    fn drop(&mut self) {
        if let Some(metrics) = self.metrics.as_ref() {
            metrics.done.store(true);
        }

        if let Some(scoped) = &self.scoped {
            if let Some(completed_tasks) = scoped.completed_tasks.upgrade() {
                completed_tasks.lock().push(scoped.task_key);
            }
        }
    }
}

pub struct JoinHandle<T>(Arc<dyn DynTask<T, TaskMetadata>>);
pub struct CancelHandle(Weak<dyn Cancellable>);

impl<T> JoinHandle<T> {
    pub fn metrics(&self) -> Option<&Arc<TaskMetrics>> {
        self.0.metadata().metrics.as_ref()
    }

    #[allow(unused)]
    pub fn spawn_location(&self) -> &'static Location<'static> {
        self.0.metadata().spawn_location
    }

    pub fn cancel_handle(&self) -> CancelHandle {
        let coerce: Weak<dyn DynTask<T, TaskMetadata>> = Arc::downgrade(&self.0);
        CancelHandle(coerce)
    }
}

impl<T> Future for JoinHandle<T> {
    type Output = T;

    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        self.0.poll_join(ctx)
    }
}

impl CancelHandle {
    pub fn cancel(&self) {
        if let Some(t) = self.0.upgrade() {
            t.cancel();
        }
    }
}

pub struct AbortOnDropHandle<T> {
    join_handle: JoinHandle<T>,
    cancel_handle: CancelHandle,
}

impl<T> AbortOnDropHandle<T> {
    pub fn new(join_handle: JoinHandle<T>) -> Self {
        let cancel_handle = join_handle.cancel_handle();
        Self {
            join_handle,
            cancel_handle,
        }
    }
}

impl<T> Future for AbortOnDropHandle<T> {
    type Output = T;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        Pin::new(&mut self.join_handle).poll(cx)
    }
}

impl<T> Drop for AbortOnDropHandle<T> {
    fn drop(&mut self) {
        self.cancel_handle.cancel();
    }
}
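
// Usage sketch (illustrative, not part of the original module): tying a
// task's lifetime to its handles. `fetch_data` is a hypothetical
// caller-defined future.
//
//     let handle = spawn(TaskPriority::High, async { fetch_data().await });
//     // Either cancel explicitly through a detached handle...
//     handle.cancel_handle().cancel();
//     // ...or wrap the JoinHandle so the task is cancelled on drop. Note
//     // that dropping a plain JoinHandle does not cancel the task.
//     let auto = AbortOnDropHandle::new(spawn(TaskPriority::Low, async {}));
//     drop(auto); // cancels the task if it has not yet completed
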
/// A task ready to run.
type ReadyTask = Arc<dyn Runnable<TaskMetadata>>;

/// A per-thread task list.
struct ThreadLocalTaskList {
    // May be used from any thread.
    high_prio_tasks_stealer: Stealer<ReadyTask>,

    // SAFETY: these may only be used on the thread this task list belongs to.
    high_prio_tasks: WorkQueue<ReadyTask>,
    local_slot: UnsafeCell<Option<ReadyTask>>,
}

unsafe impl Sync for ThreadLocalTaskList {}
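
// Scheduling overview (a summary of the code below): each worker thread in
// `runner` looks for work in this order:
//
//     1. its own LIFO `local_slot` (best cache locality),
//     2. its own `high_prio_tasks` deque,
//     3. the global high-priority, then low-priority `Injector`,
//     4. batches stolen from other threads' high-priority deques,
//        visited in random order via `random_permutation`.
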
struct Executor {
    park_group: ParkGroup,
    thread_task_lists: Vec<CachePadded<ThreadLocalTaskList>>,
    global_high_prio_task_queue: Injector<ReadyTask>,
    global_low_prio_task_queue: Injector<ReadyTask>,
}

impl Executor {
    fn schedule_task(&self, task: ReadyTask) {
        let thread = TLS_THREAD_ID.get();
        let meta = task.metadata();
        let opt_ttl = self.thread_task_lists.get(thread);

        let mut use_global_queue = opt_ttl.is_none();
        if meta.freshly_spawned.load(Ordering::Relaxed) {
            use_global_queue = true;
            meta.freshly_spawned.store(false, Ordering::Relaxed);
        }

        if use_global_queue {
            // Scheduled from an unknown thread, add to global queue.
            if meta.priority == TaskPriority::High {
                self.global_high_prio_task_queue.push(task);
            } else {
                self.global_low_prio_task_queue.push(task);
            }
            self.park_group.unpark_one();
        } else {
            let ttl = opt_ttl.unwrap();
            // SAFETY: this slot may only be accessed from the local thread, which we are.
            let slot = unsafe { &mut *ttl.local_slot.get() };

            if meta.priority == TaskPriority::High {
                // Insert new task into the thread-local slot, taking out the old task.
                let Some(task) = slot.replace(task) else {
                    // We pushed a task into our local slot which was empty. Since
                    // we are already awake, no need to notify anyone.
                    return;
                };

                ttl.high_prio_tasks.push(task);
                self.park_group.unpark_one();
            } else {
                // Optimization: this is a low-priority task, but we have no
                // high-priority tasks on this thread, so we'll execute it here.
                if ttl.high_prio_tasks.is_empty() && slot.is_none() {
                    *slot = Some(task);
                } else {
                    self.global_low_prio_task_queue.push(task);
                    self.park_group.unpark_one();
                }
            }
        }
    }

    fn try_steal_task<R: Rng>(&self, thread: usize, rng: &mut R) -> Option<ReadyTask> {
        // Try to get a global task.
        loop {
            match self.global_high_prio_task_queue.steal() {
                Steal::Empty => break,
                Steal::Success(task) => return Some(task),
                Steal::Retry => std::hint::spin_loop(),
            }
        }

        loop {
            match self.global_low_prio_task_queue.steal() {
                Steal::Empty => break,
                Steal::Success(task) => return Some(task),
                Steal::Retry => std::hint::spin_loop(),
            }
        }

        // Try to steal tasks.
        let ttl = &self.thread_task_lists[thread];
        for _ in 0..4 {
            let mut retry = true;
            while retry {
                retry = false;

                for idx in random_permutation(self.thread_task_lists.len() as u32, rng) {
                    let foreign_ttl = &self.thread_task_lists[idx as usize];
                    match foreign_ttl
                        .high_prio_tasks_stealer
                        .steal_batch_and_pop(&ttl.high_prio_tasks)
                    {
                        Steal::Empty => {},
                        Steal::Success(task) => return Some(task),
                        Steal::Retry => retry = true,
                    }
                }

                std::hint::spin_loop()
            }
        }

        None
    }

    fn runner(&self, thread: usize) {
        TLS_THREAD_ID.set(thread);
        ALLOW_RAYON_THREADS.set(false);

        let mut rng = SmallRng::from_rng(&mut rand::rng());
        let mut worker = self.park_group.new_worker();

        loop {
            let ttl = &self.thread_task_lists[thread];
            let mut local = true;
            let task = (|| {
                // Try to get a task from the LIFO slot.
                if let Some(task) = unsafe { (*ttl.local_slot.get()).take() } {
                    return Some(task);
                }

                // Try to get a local high-priority task.
                if let Some(task) = ttl.high_prio_tasks.pop() {
                    return Some(task);
                }

                // Try to steal a task.
                local = false;
                if let Some(task) = self.try_steal_task(thread, &mut rng) {
                    return Some(task);
                }

                // Prepare to park, then try one more steal attempt.
                let park = worker.prepare_park();
                if let Some(task) = self.try_steal_task(thread, &mut rng) {
                    return Some(task);
                }

                park.park();
                None
            })();

            if let Some(task) = task {
                worker.recruit_next();
                if let Some(metrics) = task.metadata().metrics.clone() {
                    let start = Instant::now();
                    task.run();
                    let elapsed_ns = start.elapsed().as_nanos() as u64;
                    metrics.total_polls.fetch_add(1);
                    if !local {
                        metrics.total_stolen_polls.fetch_add(1);
                    }
                    metrics.total_poll_time_ns.fetch_add(elapsed_ns);
                    metrics.max_poll_time_ns.fetch_max(elapsed_ns);
                } else {
                    task.run();
                }
            }
        }
    }

    fn global() -> &'static Executor {
        GLOBAL_SCHEDULER.get_or_init(|| {
            let mut n_threads = NUM_EXECUTOR_THREADS.load();
            if n_threads == 0 {
                n_threads = std::thread::available_parallelism()
                    .map(|n| n.get())
                    .unwrap_or(4);
            }

            let thread_task_lists = (0..n_threads)
                .map(|t| {
                    std::thread::Builder::new()
                        .name(format!("async-executor-{t}"))
                        .spawn(move || Self::global().runner(t))
                        .unwrap();

                    let high_prio_tasks = WorkQueue::new_lifo();
                    CachePadded::new(ThreadLocalTaskList {
                        high_prio_tasks_stealer: high_prio_tasks.stealer(),
                        high_prio_tasks,
                        local_slot: UnsafeCell::new(None),
                    })
                })
                .collect();
            Self {
                park_group: ParkGroup::new(),
                thread_task_lists,
                global_high_prio_task_queue: Injector::new(),
                global_low_prio_task_queue: Injector::new(),
            }
        })
    }
}
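
// Usage sketch (illustrative): opting into per-task metrics. Tracking must
// be enabled before the task is spawned; the counters are updated by the
// worker loop in `runner` above.
//
//     track_task_metrics(true);
//     let handle = spawn(TaskPriority::Low, async { /* ... */ });
//     if let Some(metrics) = handle.metrics() {
//         let polls = metrics.total_polls.load();
//         let ns = metrics.total_poll_time_ns.load();
//         eprintln!("{polls} polls, {ns} ns total poll time");
//     }
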
pub struct TaskScope<'scope, 'env: 'scope> {
    // Keep track of in-progress tasks so we can forcibly cancel them
    // when the scope ends, to ensure the lifetimes are respected.
    // Tasks add their own key to completed_tasks when done so we can
    // reclaim the memory used by the cancel_handles.
    cancel_handles: Mutex<SlotMap<TaskKey, CancelHandle>>,
    completed_tasks: Arc<Mutex<Vec<TaskKey>>>,

    // Copied from std::thread::scope. Necessary to prevent unsoundness.
    scope: PhantomData<&'scope mut &'scope ()>,
    env: PhantomData<&'env mut &'env ()>,
}

impl<'scope> TaskScope<'scope, '_> {
    // Not Drop because that extends lifetimes.
    fn destroy(&self) {
        // Make sure all tasks are cancelled.
        for (_, t) in self.cancel_handles.lock().drain() {
            t.cancel();
        }
    }

    fn clear_completed_tasks(&self) {
        let mut cancel_handles = self.cancel_handles.lock();
        for t in self.completed_tasks.lock().drain(..) {
            cancel_handles.remove(t);
        }
    }

    #[track_caller]
    pub fn spawn_task<F: Future + Send + 'scope>(
        &self,
        priority: TaskPriority,
        fut: F,
    ) -> JoinHandle<F::Output>
    where
        <F as Future>::Output: Send + 'static,
    {
        let spawn_location = Location::caller();
        self.clear_completed_tasks();

        let mut runnable = None;
        let mut join_handle = None;
        self.cancel_handles.lock().insert_with_key(|task_key| {
            let metrics = TRACK_METRICS.load().then(Arc::default);
            let dyn_task = unsafe {
                // SAFETY: we make sure to cancel this task before 'scope ends.
                let executor = Executor::global();
                let on_wake = move |task| executor.schedule_task(task);
                task::spawn_with_lifetime(
                    fut,
                    on_wake,
                    TaskMetadata {
                        spawn_location,
                        priority,
                        freshly_spawned: AtomicBool::new(true),
                        scoped: Some(ScopedTaskMetadata {
                            task_key,
                            completed_tasks: Arc::downgrade(&self.completed_tasks),
                        }),
                        metrics,
                    },
                )
            };
            runnable = Some(Arc::clone(&dyn_task));
            let jh = JoinHandle(dyn_task);
            let cancel_handle = jh.cancel_handle();
            join_handle = Some(jh);
            cancel_handle
        });
        runnable.unwrap().schedule();
        join_handle.unwrap()
    }
}

pub fn task_scope<'env, F, T>(f: F) -> T
where
    F: for<'scope> FnOnce(&'scope TaskScope<'scope, 'env>) -> T,
{
    // By having this local variable inaccessible to anyone we guarantee
    // that either abort is called, killing the entire process, or that this
    // executor is properly destroyed.
    let scope = TaskScope {
        cancel_handles: Mutex::default(),
        completed_tasks: Arc::new(Mutex::default()),
        scope: PhantomData,
        env: PhantomData,
    };

    let result = std::panic::catch_unwind(AssertUnwindSafe(|| f(&scope)));

    // Make sure all tasks are properly destroyed.
    scope.destroy();

    match result {
        Err(e) => std::panic::resume_unwind(e),
        Ok(result) => result,
    }
}
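
// Usage sketch (illustrative): spawning tasks that borrow from the
// enclosing stack frame. Assumes `std::sync::atomic::AtomicU64` is in
// scope. Any task still running when the closure returns is cancelled
// by `TaskScope::destroy`.
//
//     let total = AtomicU64::new(0);
//     task_scope(|scope| {
//         for _ in 0..4 {
//             // The future may borrow `total` because it only has to
//             // live for 'scope, not 'static.
//             scope.spawn_task(TaskPriority::Low, async {
//                 total.fetch_add(1, Ordering::Relaxed);
//             });
//         }
//     });
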
#[track_caller]
pub fn spawn<F: Future + Send + 'static>(priority: TaskPriority, fut: F) -> JoinHandle<F::Output>
where
    <F as Future>::Output: Send + 'static,
{
    let spawn_location = Location::caller();
    let executor = Executor::global();
    let on_wake = move |task| executor.schedule_task(task);
    let metrics = TRACK_METRICS.load().then(Arc::default);
    let dyn_task = task::spawn(
        fut,
        on_wake,
        TaskMetadata {
            spawn_location,
            priority,
            freshly_spawned: AtomicBool::new(true),
            scoped: None,
            metrics,
        },
    );
    Arc::clone(&dyn_task).schedule();
    JoinHandle(dyn_task)
}

fn random_permutation<R: Rng>(len: u32, rng: &mut R) -> impl Iterator<Item = u32> {
    let modulus = len.next_power_of_two();
    let halfwidth = modulus.trailing_zeros() / 2;
    let mask = modulus - 1;
    let displace_zero = rng.random::<u32>();
    let odd1 = rng.random::<u32>() | 1;
    let odd2 = rng.random::<u32>() | 1;
    let uniform_first = ((rng.random::<u32>() as u64 * len as u64) >> 32) as u32;

    (0..modulus)
        .map(move |mut i| {
            // Invertible permutation on [0, modulus).
            i = i.wrapping_add(displace_zero);
            i = i.wrapping_mul(odd1);
            i ^= (i & mask) >> halfwidth;
            i = i.wrapping_mul(odd2);
            i & mask
        })
        .filter(move |i| *i < len)
        .map(move |mut i| {
            i += uniform_first;
            if i >= len {
                i -= len;
            }
            i
        })
}
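
// A sketch test (an addition, not in the original file) illustrating the
// property `try_steal_task` relies on: `random_permutation` yields every
// index in [0, len) exactly once, in randomized order.
#[cfg(test)]
mod random_permutation_tests {
    use super::*;

    #[test]
    fn yields_each_index_exactly_once() {
        let mut rng = SmallRng::seed_from_u64(0xdead_beef);
        for len in [0u32, 1, 2, 3, 5, 16, 17, 1000] {
            let mut seen = vec![false; len as usize];
            for i in random_permutation(len, &mut rng) {
                assert!(!seen[i as usize], "index {i} produced twice");
                seen[i as usize] = true;
            }
            assert!(seen.iter().all(|&s| s), "some index was never produced");
        }
    }
}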