Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
google
GitHub Repository: google/crosvm
Path: blob/main/cros_async/src/tokio_executor.rs
5392 views
1
// Copyright 2023 The ChromiumOS Authors
2
// Use of this source code is governed by a BSD-style license that can be
3
// found in the LICENSE file.
4
5
use std::future::Future;
6
use std::pin::Pin;
7
use std::sync::Arc;
8
use std::sync::OnceLock;
9
10
use base::AsRawDescriptors;
11
use base::RawDescriptor;
12
use tokio::runtime::Runtime;
13
use tokio::task::LocalSet;
14
15
use crate::sys::platform::tokio_source::TokioSource;
16
use crate::AsyncError;
17
use crate::AsyncResult;
18
use crate::ExecutorTrait;
19
use crate::IntoAsync;
20
use crate::IoSource;
21
use crate::TaskHandle;
22
23
mod send_wrapper {
    use std::mem::ManuallyDrop;
    use std::thread;

    /// Lets a value that is not `Send`/`Sync` live inside a `Send + Sync` container.
    /// At runtime, it enforces that the value is only touched on the thread that created it.
    ///
    /// Accessing the value through `Deref` or `Clone` on any other thread panics.
    /// Dropping the wrapper on another thread never runs the inner value's destructor:
    /// - if that thread is not already panicking, the drop panics;
    /// - otherwise the value is leaked, because a second panic during unwinding would abort
    ///   the process.
    pub(super) struct SendWrapper<T> {
        // `ManuallyDrop` so that `Drop` decides whether destroying the value is allowed; a plain
        // field would be dropped by the drop glue even after `Drop::drop` panics.
        instance: ManuallyDrop<T>,
        // The only thread allowed to access or drop `instance`.
        thread_id: thread::ThreadId,
    }

    impl<T> SendWrapper<T> {
        /// Wraps `instance`, binding it to the calling thread.
        pub(super) fn new(instance: T) -> SendWrapper<T> {
            SendWrapper {
                instance: ManuallyDrop::new(instance),
                thread_id: thread::current().id(),
            }
        }

        /// Returns true when called on the thread that created this wrapper.
        fn on_owner_thread(&self) -> bool {
            self.thread_id == thread::current().id()
        }
    }

    impl<T: Clone> Clone for SendWrapper<T> {
        /// Clones the inner value. Goes through `Deref` so it panics off the owning thread,
        /// unlike a derived impl, which would clone `T` on whatever thread called it.
        fn clone(&self) -> Self {
            SendWrapper::new((**self).clone())
        }
    }

    // SAFETY: the inner value is only accessed (`Deref`, `Clone`) or dropped (`Drop`) on the
    // thread that created it; every other thread gets a panic, or a leak if already unwinding.
    unsafe impl<T> Send for SendWrapper<T> {}
    // SAFETY: shared references only reach the value through `Deref`/`Clone`, which panic when
    // used off the owning thread.
    unsafe impl<T> Sync for SendWrapper<T> {}

    impl<T> Drop for SendWrapper<T> {
        fn drop(&mut self) {
            if self.on_owner_thread() {
                // SAFETY: `instance` was initialized in `new`, this is the only place it is
                // dropped, and it is never accessed again after this point.
                unsafe { ManuallyDrop::drop(&mut self.instance) };
            } else if !thread::panicking() {
                panic!("SendWrapper value was dropped on the wrong thread");
            }
            // Otherwise we are already unwinding on the wrong thread: leak the value rather than
            // destroy it here (unsound) or panic again (which would abort the process).
        }
    }

    impl<T> std::ops::Deref for SendWrapper<T> {
        type Target = T;

        /// Returns the wrapped value.
        ///
        /// # Panics
        ///
        /// Panics if called on a thread other than the one that created the wrapper.
        fn deref(&self) -> &T {
            if !self.on_owner_thread() {
                panic!("SendWrapper value was accessed on the wrong thread");
            }
            &*self.instance
        }
    }
}
65
66
#[derive(Clone)]
67
pub struct TokioExecutor {
68
runtime: Arc<Runtime>,
69
local_set: Arc<OnceLock<send_wrapper::SendWrapper<LocalSet>>>,
70
}
71
72
impl TokioExecutor {
73
pub fn new() -> AsyncResult<Self> {
74
Ok(TokioExecutor {
75
runtime: Arc::new(Runtime::new().map_err(AsyncError::Io)?),
76
local_set: Arc::new(OnceLock::new()),
77
})
78
}
79
}
80
81
impl ExecutorTrait for TokioExecutor {
82
fn async_from<'a, F: IntoAsync + 'a>(&self, f: F) -> AsyncResult<IoSource<F>> {
83
Ok(IoSource::Tokio(TokioSource::new(
84
f,
85
self.runtime.handle().clone(),
86
)?))
87
}
88
89
fn run_until<F: Future>(&self, f: F) -> AsyncResult<F::Output> {
90
let local_set = self
91
.local_set
92
.get_or_init(|| send_wrapper::SendWrapper::new(LocalSet::new()));
93
Ok(self
94
.runtime
95
.block_on(async { local_set.run_until(f).await }))
96
}
97
98
fn spawn<F>(&self, f: F) -> TaskHandle<F::Output>
99
where
100
F: Future + Send + 'static,
101
F::Output: Send + 'static,
102
{
103
TaskHandle::Tokio(TokioTaskHandle {
104
join_handle: Some(self.runtime.spawn(f)),
105
})
106
}
107
108
fn spawn_blocking<F, R>(&self, f: F) -> TaskHandle<R>
109
where
110
F: FnOnce() -> R + Send + 'static,
111
R: Send + 'static,
112
{
113
TaskHandle::Tokio(TokioTaskHandle {
114
join_handle: Some(self.runtime.spawn_blocking(f)),
115
})
116
}
117
118
fn spawn_local<F>(&self, f: F) -> TaskHandle<F::Output>
119
where
120
F: Future + 'static,
121
F::Output: 'static,
122
{
123
let local_set = self
124
.local_set
125
.get_or_init(|| send_wrapper::SendWrapper::new(LocalSet::new()));
126
TaskHandle::Tokio(TokioTaskHandle {
127
join_handle: Some(local_set.spawn_local(f)),
128
})
129
}
130
}
131
132
impl AsRawDescriptors for TokioExecutor {
133
fn as_raw_descriptors(&self) -> Vec<RawDescriptor> {
134
todo!();
135
}
136
}
137
138
pub struct TokioTaskHandle<T> {
139
join_handle: Option<tokio::task::JoinHandle<T>>,
140
}
141
impl<R> TokioTaskHandle<R> {
142
pub async fn cancel(mut self) -> Option<R> {
143
match self.join_handle.take() {
144
Some(handle) => {
145
handle.abort();
146
handle.await.ok()
147
}
148
None => None,
149
}
150
}
151
pub fn detach(mut self) {
152
self.join_handle.take();
153
}
154
}
155
impl<R: 'static> Future for TokioTaskHandle<R> {
156
type Output = R;
157
fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> std::task::Poll<Self::Output> {
158
let self_mut = self.get_mut();
159
Pin::new(self_mut.join_handle.as_mut().unwrap())
160
.poll(cx)
161
.map(|v| v.unwrap())
162
}
163
}
164
impl<T> std::ops::Drop for TokioTaskHandle<T> {
165
fn drop(&mut self) {
166
if let Some(handle) = self.join_handle.take() {
167
handle.abort()
168
}
169
}
170
}
171
172