use std::collections::VecDeque;
use std::future::Future;
use std::mem;
use std::sync::mpsc::channel;
use std::sync::mpsc::Receiver;
use std::sync::mpsc::Sender;
use std::sync::Arc;
use std::thread;
use std::thread::JoinHandle;
use std::time::Duration;
use std::time::Instant;
use base::error;
use base::warn;
use futures::channel::oneshot;
use slab::Slab;
use sync::Condvar;
use sync::Mutex;
/// Longest time dropping a `BlockingPool` waits for its worker threads to exit before
/// detaching the stragglers (ten seconds).
const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Mutable pool bookkeeping; only ever accessed through the `Inner::state` mutex.
struct State {
    /// Tasks waiting for a worker: pushed by `spawn`, popped by workers, and taken (dropped)
    /// by `shutdown`.
    tasks: VecDeque<Box<dyn FnOnce() + Send>>,
    /// Workers started by `spawn` that have not yet left their run loop.
    num_threads: usize,
    /// Workers parked waiting for work that `spawn` has not yet claimed.
    num_idle: usize,
    /// Wakeups issued by `spawn` that no parked worker has consumed yet.
    num_notified: usize,
    /// Join handles of worker threads, keyed by the slab index each worker is given and
    /// later reports on the exit channel.
    worker_threads: Slab<JoinHandle<()>>,
    /// Receiving side of the exit channel; `None` once `shutdown` has taken it.
    exited_threads: Option<Receiver<usize>>,
    /// Sending side of the exit channel; each new worker gets a clone.
    exit: Sender<usize>,
    /// Set by `shutdown`; afterwards `spawn` refuses new tasks and workers leave their loop.
    shutting_down: bool,
}
/// Body of every pool worker thread.
///
/// Repeatedly pops and runs queued tasks, releasing the state lock while a task executes.
/// When the queue is empty the thread registers itself as idle and waits on the condvar for
/// up to `inner.keepalive`. The wait ends when `spawn` issues a wakeup (`num_notified > 0`),
/// when shutdown begins, or when the keepalive expires. The loop ends when shutdown is
/// observed or the keepalive wait times out with no work.
///
/// On the way out the thread decrements `num_threads`. Unless shutdown has taken the exit
/// receiver, it also joins at most one previously exited worker whose index is waiting on the
/// exit channel. Finally it sends its own slab index `idx` on `exit`, so a later exiting worker
/// or `BlockingPool::shutdown` can join its handle.
fn run_blocking_thread(idx: usize, inner: Arc<Inner>, exit: Sender<usize>) {
    let mut state = inner.state.lock();
    while !state.shutting_down {
        if let Some(f) = state.tasks.pop_front() {
            // Run the task without holding the lock so `spawn` and other workers can proceed.
            drop(state);
            f();
            state = inner.state.lock();
            continue;
        }
        // The queue is empty: park as idle until notified, shut down, or the keepalive expires.
        state.num_idle += 1;
        let (guard, result) = inner
            .condvar
            .wait_timeout_while(state, inner.keepalive, |s| {
                !s.shutting_down && s.num_notified == 0
            });
        state = guard;
        // A pending wakeup means `spawn` queued work and already decremented `num_idle` on an
        // idle worker's behalf; consume it and go back to look for the task.
        if state.num_notified > 0 {
            state.num_notified -= 1;
            continue;
        }
        // No work arrived within the keepalive period: undo our own idle registration and exit.
        if result.timed_out() {
            state.num_idle = state
                .num_idle
                .checked_sub(1)
                .expect("`num_idle` underflow on timeout");
            break;
        }
        // Otherwise the wait ended because `shutting_down` was set; the loop condition exits.
        // NOTE(review): `num_idle` is not decremented on this path, so it can remain non-zero
        // after shutdown. Confirm nothing relies on it once `shutting_down` is set.
    }
    state.num_threads -= 1;
    // During shutdown `exited_threads` has been taken, and `BlockingPool::shutdown` joins the
    // remaining handles itself. Otherwise, reap one earlier exited worker (if any) so finished
    // handles do not pile up in `worker_threads`.
    let last_exited_thread = if let Some(exited_threads) = state.exited_threads.as_mut() {
        exited_threads
            .try_recv()
            .map(|idx| state.worker_threads.remove(idx))
            .ok()
    } else {
        None
    };
    // Release the lock before joining so other threads are not stuck on it during the join.
    drop(state);
    if let Some(handle) = last_exited_thread {
        let _ = handle.join();
    }
    // Announce our own exit last. Sending fails once the receiver is gone, e.g. after a
    // shutdown that gave up waiting at its deadline.
    if let Err(e) = exit.send(idx) {
        error!("Failed to send thread exit event on channel: {}", e);
    }
}
/// State shared between the `BlockingPool` handle and all of its worker threads.
struct Inner {
    /// Pool bookkeeping, guarded by this mutex.
    state: Mutex<State>,
    /// Notified once by `spawn` to wake an idle worker, and notified for all by `shutdown`.
    condvar: Condvar,
    /// Upper bound on `State::num_threads`; `spawn` starts no new worker once it is reached.
    max_threads: usize,
    /// How long an idle worker waits for new work before exiting.
    keepalive: Duration,
}
impl Inner {
    /// Queues `f` for a worker thread and returns a future resolving to `f`'s return value.
    ///
    /// If a worker is parked idle, one is claimed and woken through the condvar. Otherwise a
    /// new worker thread is started, provided fewer than `max_threads` exist. If neither is
    /// possible, the task waits in the queue until a busy worker frees up.
    ///
    /// After shutdown the task is not queued: an error is logged and the returned future
    /// panics when polled. If the task is dropped before it runs, awaiting the future panics
    /// with "BlockingThread task unexpectedly cancelled".
    pub fn spawn<F, R>(self: &Arc<Self>, f: F) -> impl Future<Output = R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        let mut guard = self.state.lock();

        // Nothing would ever run work submitted after shutdown, so return a future that
        // panics rather than one that hangs forever.
        if guard.shutting_down {
            error!("spawn called after shutdown");
            return futures::future::Either::Left(async {
                panic!("tried to poll BlockingPool task after shutdown")
            });
        }

        let (result_tx, result_rx) = oneshot::channel();
        let job = move || {
            // The receiver is gone if the caller dropped the future; ignoring that is fine.
            let _ = result_tx.send(f());
        };
        guard.tasks.push_back(Box::new(job));

        if guard.num_idle > 0 {
            // Claim a parked worker and record a pending wakeup for it.
            guard.num_idle -= 1;
            guard.num_notified += 1;
            self.condvar.notify_one();
        } else if guard.num_threads < self.max_threads {
            // Nobody is parked, but there is room for one more worker thread.
            guard.num_threads += 1;
            let exit_tx = guard.exit.clone();
            let slot = guard.worker_threads.vacant_entry();
            let idx = slot.key();
            let pool = Arc::clone(self);
            let handle = thread::Builder::new()
                .name(format!("blockingPool{idx}"))
                .spawn(move || run_blocking_thread(idx, pool, exit_tx))
                .unwrap();
            slot.insert(handle);
        }

        futures::future::Either::Right(async {
            result_rx
                .await
                .expect("BlockingThread task unexpectedly cancelled")
        })
    }
}
/// Error returned by `BlockingPool::shutdown` when worker threads are still running at the
/// deadline. Carries the number of threads that were detached instead of joined.
#[derive(Debug, thiserror::Error)]
#[error("{0} BlockingPool threads did not exit in time and will be detached")]
pub struct ShutdownTimedOut(usize);
/// A pool of threads for running blocking closures, whose results are awaited as futures.
///
/// Worker threads are started on demand up to a configured maximum and exit after an idle
/// keepalive period. Dropping the pool shuts it down, waiting up to `DEFAULT_SHUTDOWN_TIMEOUT`
/// for workers to exit.
pub struct BlockingPool {
    /// State shared with every worker thread.
    inner: Arc<Inner>,
}
impl BlockingPool {
pub fn new(max_threads: usize, keepalive: Duration) -> BlockingPool {
let (exit, exited_threads) = channel();
BlockingPool {
inner: Arc::new(Inner {
state: Mutex::new(State {
tasks: VecDeque::new(),
num_threads: 0,
num_idle: 0,
num_notified: 0,
worker_threads: Slab::new(),
exited_threads: Some(exited_threads),
exit,
shutting_down: false,
}),
condvar: Condvar::new(),
max_threads,
keepalive,
}),
}
}
pub fn with_capacity(max_threads: usize, keepalive: Duration) -> BlockingPool {
let (exit, exited_threads) = channel();
BlockingPool {
inner: Arc::new(Inner {
state: Mutex::new(State {
tasks: VecDeque::new(),
num_threads: 0,
num_idle: 0,
num_notified: 0,
worker_threads: Slab::with_capacity(max_threads),
exited_threads: Some(exited_threads),
exit,
shutting_down: false,
}),
condvar: Condvar::new(),
max_threads,
keepalive,
}),
}
}
pub fn spawn<F, R>(&self, f: F) -> impl Future<Output = R>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
self.inner.spawn(f)
}
pub fn shutdown(&self, deadline: Option<Instant>) -> Result<(), ShutdownTimedOut> {
let mut state = self.inner.state.lock();
if state.shutting_down {
return Ok(());
}
state.shutting_down = true;
let exited_threads = state.exited_threads.take().expect("exited_threads missing");
let unfinished_tasks = std::mem::take(&mut state.tasks);
let mut worker_threads = mem::replace(&mut state.worker_threads, Slab::new());
drop(state);
self.inner.condvar.notify_all();
drop(unfinished_tasks);
if let Some(deadline) = deadline {
let mut now = Instant::now();
while now < deadline && !worker_threads.is_empty() {
if let Ok(idx) = exited_threads.recv_timeout(deadline - now) {
let _ = worker_threads.remove(idx).join();
}
now = Instant::now();
}
if !worker_threads.is_empty() {
return Err(ShutdownTimedOut(worker_threads.len()));
}
Ok(())
} else {
for handle in worker_threads.drain() {
let _ = handle.join();
}
Ok(())
}
}
#[cfg(test)]
pub(crate) fn shutting_down(&self) -> bool {
self.inner.state.lock().shutting_down
}
}
impl Default for BlockingPool {
    /// Builds a pool with room for 256 worker threads, each allowed to sit idle for
    /// 10 seconds before exiting.
    fn default() -> BlockingPool {
        const MAX_THREADS: usize = 256;
        const KEEPALIVE: Duration = Duration::from_secs(10);
        BlockingPool::new(MAX_THREADS, KEEPALIVE)
    }
}
impl Drop for BlockingPool {
    /// Shuts the pool down, allowing workers up to `DEFAULT_SHUTDOWN_TIMEOUT` to exit.
    /// Threads still running after that are detached and a warning is logged.
    fn drop(&mut self) {
        let deadline = Instant::now() + DEFAULT_SHUTDOWN_TIMEOUT;
        match self.shutdown(Some(deadline)) {
            Ok(()) => {}
            Err(timed_out) => {
                warn!("{}", timed_out);
            }
        }
    }
}
#[cfg(test)]
mod test {
    // Tests for `BlockingPool`: result delivery, queueing beyond the thread limit, keepalive
    // expiry, and shutdown with pending or still-running work.
    use std::sync::Arc;
    use std::sync::Barrier;
    use std::thread;
    use std::time::Duration;
    use std::time::Instant;
    use futures::executor::block_on;
    use futures::stream::FuturesUnordered;
    use futures::StreamExt;
    use sync::Condvar;
    use sync::Mutex;
    use super::super::super::BlockingPool;

    // A single task's return value is delivered through the future returned by `spawn`.
    #[test]
    fn blocking_sleep() {
        let pool = BlockingPool::default();
        let res = block_on(pool.spawn(|| 42));
        assert_eq!(res, 42);
    }

    // The task still runs even though its future is dropped immediately. The zero-capacity
    // channel makes the test wait until the task has actually executed.
    #[test]
    fn drop_doesnt_block() {
        let pool = BlockingPool::default();
        let (tx, rx) = std::sync::mpsc::sync_channel(0);
        std::mem::drop(pool.spawn(move || tx.send(()).unwrap()));
        rx.recv().unwrap();
    }

    // Two bursts of trivial tasks with a 1 ms keepalive, so workers may time out and exit
    // while new work is arriving. Every task must still complete.
    #[test]
    fn fast_tasks_with_short_keepalive() {
        let pool = BlockingPool::new(256, Duration::from_millis(1));
        let streams = FuturesUnordered::new();
        for _ in 0..2 {
            for _ in 0..256 {
                let task = pool.spawn(|| ());
                streams.push(task);
            }
            thread::sleep(Duration::from_millis(1));
        }
        block_on(streams.collect::<Vec<_>>());
    }

    // 19 tasks on a pool capped at 4 threads: the excess waits in the queue and still completes.
    #[test]
    fn more_tasks_than_threads() {
        let pool = BlockingPool::new(4, Duration::from_secs(10));
        let stream = (0..19)
            .map(|_| pool.spawn(|| thread::sleep(Duration::from_millis(5))))
            .collect::<FuturesUnordered<_>>();
        let results = block_on(stream.collect::<Vec<_>>());
        assert_eq!(results.len(), 19);
    }

    // Once all work has finished, shutdown joins every worker and `num_threads` drops to zero.
    #[test]
    fn shutdown() {
        let pool = BlockingPool::default();
        let stream = (0..19)
            .map(|_| pool.spawn(|| thread::sleep(Duration::from_millis(5))))
            .collect::<FuturesUnordered<_>>();
        let results = block_on(stream.collect::<Vec<_>>());
        assert_eq!(results.len(), 19);
        pool.shutdown(Some(Instant::now() + Duration::from_secs(10)))
            .unwrap();
        let state = pool.inner.state.lock();
        assert_eq!(state.num_threads, 0);
    }

    // With a 1 ms keepalive, idle workers exit on their own. Poll for up to 10 s until the
    // thread count reaches zero, then check that the idle counter was unwound too.
    #[test]
    fn keepalive_timeout() {
        let pool = BlockingPool::new(7, Duration::from_millis(1));
        let stream = (0..19)
            .map(|_| pool.spawn(|| thread::sleep(Duration::from_millis(5))))
            .collect::<FuturesUnordered<_>>();
        let results = block_on(stream.collect::<Vec<_>>());
        assert_eq!(results.len(), 19);
        let deadline = Instant::now() + Duration::from_secs(10);
        while Instant::now() < deadline {
            thread::sleep(Duration::from_millis(100));
            let state = pool.inner.state.lock();
            if state.num_threads == 0 {
                break;
            }
        }
        {
            let state = pool.inner.state.lock();
            assert_eq!(state.num_threads, 0);
            assert_eq!(state.num_idle, 0);
        }
    }

    // Shutdown drops tasks that are still queued. The pool's only worker is blocked on the
    // first task, so the second task never starts. The helper thread releases the first task
    // only after it observes `shutting_down` (woken by the pool condvar's notify_all).
    // Awaiting the dropped task then panics, which `should_panic` expects.
    #[test]
    #[should_panic]
    fn shutdown_with_pending_work() {
        let pool = BlockingPool::new(1, Duration::from_secs(10));
        let mu = Arc::new(Mutex::new(false));
        let cv = Arc::new(Condvar::new());
        let task_mu = mu.clone();
        let task_cv = cv.clone();
        let _blocking_task = pool.spawn(move || {
            let mut ready = task_mu.lock();
            while !*ready {
                ready = task_cv.wait(ready);
            }
        });
        let unfinished = pool.spawn(|| 5);
        let inner = pool.inner.clone();
        thread::spawn(move || {
            let mut state = inner.state.lock();
            while !state.shutting_down {
                state = inner.condvar.wait(state);
            }
            *mu.lock() = true;
            cv.notify_all();
        });
        pool.shutdown(None).unwrap();
        assert_eq!(block_on(unfinished), 5);
    }

    // A task still running at the deadline makes shutdown return `ShutdownTimedOut`. Its
    // worker stays counted in `num_threads` until the task is released and finishes, after
    // which the worker leaves its loop and the counters return to zero.
    #[test]
    fn unfinished_worker_thread() {
        let pool = BlockingPool::default();
        let ready = Arc::new(Mutex::new(false));
        let cv = Arc::new(Condvar::new());
        let barrier = Arc::new(Barrier::new(2));
        let thread_ready = ready.clone();
        let thread_barrier = barrier.clone();
        let thread_cv = cv.clone();
        let task = pool.spawn(move || {
            thread_barrier.wait();
            let mut ready = thread_ready.lock();
            while !*ready {
                ready = thread_cv.wait(ready);
            }
        });
        // Ensure the task has started on a worker before shutting down.
        barrier.wait();
        pool.shutdown(Some(Instant::now() + Duration::from_millis(5)))
            .unwrap_err();
        let num_threads = pool.inner.state.lock().num_threads;
        assert_eq!(num_threads, 1);
        *ready.lock() = true;
        cv.notify_all();
        block_on(task);
        let deadline = Instant::now() + Duration::from_secs(10);
        while Instant::now() < deadline {
            thread::sleep(Duration::from_millis(100));
            let state = pool.inner.state.lock();
            if state.num_threads == 0 {
                break;
            }
        }
        {
            let state = pool.inner.state.lock();
            assert_eq!(state.num_threads, 0);
            assert_eq!(state.num_idle, 0);
        }
    }
}