Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
bytecodealliance
GitHub Repository: bytecodealliance/wasmtime
Path: blob/main/crates/wasi/src/cli/locked_async.rs
1693 views
1
use crate::cli::{IsTerminal, StdinStream, StdoutStream};
2
use crate::p2;
3
use bytes::Bytes;
4
use std::mem;
5
use std::pin::Pin;
6
use std::sync::Arc;
7
use std::task::{Context, Poll, ready};
8
use tokio::io::{self, AsyncRead, AsyncWrite};
9
use tokio::sync::{Mutex, OwnedMutexGuard};
10
use wasmtime_wasi_io::streams::{InputStream, OutputStream};
11
12
trait SharedHandleReady: Send + Sync + 'static {
13
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()>;
14
}
15
16
impl SharedHandleReady for p2::pipe::AsyncWriteStream {
17
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
18
<Self>::poll_ready(self, cx)
19
}
20
}
21
22
impl SharedHandleReady for p2::pipe::AsyncReadStream {
23
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
24
<Self>::poll_ready(self, cx)
25
}
26
}
27
28
/// An impl of [`StdinStream`] built on top of [`AsyncRead`].
29
//
30
// Note the usage of `tokio::sync::Mutex` here as opposed to a
31
// `std::sync::Mutex`. This is intentionally done to implement the `Pollable`
32
// variant of this trait. Note that in doing so we're left with the quandry of
33
// how to implement methods of `InputStream` since those methods are not
34
// `async`. They're currently implemented with `try_lock`, which then raises the
35
// question of what to do on contention. Currently traps are returned.
36
//
37
// Why should it be ok to return a trap? In general concurrency/contention
38
// shouldn't return a trap since it should be able to happen normally. The
39
// current assumption, though, is that WASI stdin/stdout streams are special
40
// enough that the contention case should never come up in practice. Currently
41
// in WASI there is no actually concurrency, there's just the items in a single
42
// `Store` and that store owns all of its I/O in a single Tokio task. There's no
43
// means to actually spawn multiple Tokio tasks that use the same store. This
44
// means at the very least that there's zero parallelism. Due to the lack of
45
// multiple tasks that also means that there's no concurrency either.
46
//
47
// This `AsyncStdinStream` wrapper is only intended to be used by the WASI
48
// bindings themselves. It's possible for the host to take this and work with it
49
// on its own task, but that's niche enough it's not designed for.
50
//
51
// Overall that means that the guest is either calling `Pollable` or
52
// `InputStream` methods. This means that there should never be contention
53
// between the two at this time. This may all change in the future with WASI
54
// 0.3, but perhaps we'll have a better story for stdio at that time (see the
55
// doc block on the `OutputStream` impl below)
56
pub struct AsyncStdinStream(Arc<Mutex<p2::pipe::AsyncReadStream>>);
57
58
impl AsyncStdinStream {
59
pub fn new(s: impl AsyncRead + Send + Sync + 'static) -> Self {
60
Self(Arc::new(Mutex::new(p2::pipe::AsyncReadStream::new(s))))
61
}
62
}
63
64
impl StdinStream for AsyncStdinStream {
65
fn p2_stream(&self) -> Box<dyn InputStream> {
66
Box::new(Self(self.0.clone()))
67
}
68
fn async_stream(&self) -> Box<dyn AsyncRead + Send + Sync> {
69
Box::new(StdioHandle::Ready(self.0.clone()))
70
}
71
}
72
73
impl IsTerminal for AsyncStdinStream {
74
fn is_terminal(&self) -> bool {
75
false
76
}
77
}
78
79
#[async_trait::async_trait]
80
impl InputStream for AsyncStdinStream {
81
fn read(&mut self, size: usize) -> Result<bytes::Bytes, p2::StreamError> {
82
match self.0.try_lock() {
83
Ok(mut stream) => stream.read(size),
84
Err(_) => Err(p2::StreamError::trap("concurrent reads are not supported")),
85
}
86
}
87
fn skip(&mut self, size: usize) -> Result<usize, p2::StreamError> {
88
match self.0.try_lock() {
89
Ok(mut stream) => stream.skip(size),
90
Err(_) => Err(p2::StreamError::trap("concurrent skips are not supported")),
91
}
92
}
93
async fn cancel(&mut self) {
94
// Cancel the inner stream if we're the last reference to it:
95
if let Some(mutex) = Arc::get_mut(&mut self.0) {
96
match mutex.try_lock() {
97
Ok(mut stream) => stream.cancel().await,
98
Err(_) => {}
99
}
100
}
101
}
102
}
103
104
#[async_trait::async_trait]
105
impl p2::Pollable for AsyncStdinStream {
106
async fn ready(&mut self) {
107
self.0.lock().await.ready().await
108
}
109
}
110
111
impl AsyncRead for StdioHandle<p2::pipe::AsyncReadStream> {
112
fn poll_read(
113
mut self: Pin<&mut Self>,
114
cx: &mut Context<'_>,
115
buf: &mut io::ReadBuf<'_>,
116
) -> Poll<io::Result<()>> {
117
match ready!(self.as_mut().poll(cx, |g| g.read(buf.remaining()))) {
118
Some(Ok(bytes)) => {
119
buf.put_slice(&bytes);
120
Poll::Ready(Ok(()))
121
}
122
Some(Err(e)) => Poll::Ready(Err(e)),
123
// If the guard can't be acquired that means that this stream is
124
// closed, so return that we're ready without filling in data.
125
None => Poll::Ready(Ok(())),
126
}
127
}
128
}
129
130
/// A wrapper of [`crate::p2::pipe::AsyncWriteStream`] that implements
131
/// [`StdoutStream`]. Note that the [`OutputStream`] impl for this is not
132
/// correct when used for interleaved async IO.
133
//
134
// Note that the use of `tokio::sync::Mutex` here is intentional, in addition to
135
// the `try_lock()` calls below in the implementation of `OutputStream`. For
136
// more information see the documentation on `AsyncStdinStream`.
137
pub struct AsyncStdoutStream(Arc<Mutex<p2::pipe::AsyncWriteStream>>);
138
139
impl AsyncStdoutStream {
140
pub fn new(budget: usize, s: impl AsyncWrite + Send + Sync + 'static) -> Self {
141
Self(Arc::new(Mutex::new(p2::pipe::AsyncWriteStream::new(
142
budget, s,
143
))))
144
}
145
}
146
147
impl StdoutStream for AsyncStdoutStream {
148
fn p2_stream(&self) -> Box<dyn OutputStream> {
149
Box::new(Self(self.0.clone()))
150
}
151
fn async_stream(&self) -> Box<dyn AsyncWrite + Send + Sync> {
152
Box::new(StdioHandle::Ready(self.0.clone()))
153
}
154
}
155
156
impl IsTerminal for AsyncStdoutStream {
157
fn is_terminal(&self) -> bool {
158
false
159
}
160
}
161
162
// This implementation is known to be bogus. All check-writes and writes are
163
// directed at the same underlying stream. The check-write/write protocol does
164
// require the size returned by a check-write to be accepted by write, even if
165
// other side-effects happen between those calls, and this implementation
166
// permits another view (created by StdoutStream::stream()) of the same
167
// underlying stream to accept a write which will invalidate a prior
168
// check-write of another view.
169
// Ultimately, the Std{in,out}Stream::stream() methods exist because many
170
// different places in a linked component (which may itself contain many
171
// modules) may need to access stdio without any coordination to keep those
172
// accesses all using pointing to the same resource. So, we allow many
173
// resources to be created. We have the reasonable expectation that programs
174
// won't attempt to interleave async IO from these disparate uses of stdio.
175
// If that expectation doesn't turn out to be true, and you find yourself at
176
// this comment to correct it: sorry about that.
177
#[async_trait::async_trait]
178
impl OutputStream for AsyncStdoutStream {
179
fn check_write(&mut self) -> Result<usize, p2::StreamError> {
180
match self.0.try_lock() {
181
Ok(mut stream) => stream.check_write(),
182
Err(_) => Err(p2::StreamError::trap("concurrent writes are not supported")),
183
}
184
}
185
fn write(&mut self, bytes: Bytes) -> Result<(), p2::StreamError> {
186
match self.0.try_lock() {
187
Ok(mut stream) => stream.write(bytes),
188
Err(_) => Err(p2::StreamError::trap("concurrent writes not supported yet")),
189
}
190
}
191
fn flush(&mut self) -> Result<(), p2::StreamError> {
192
match self.0.try_lock() {
193
Ok(mut stream) => stream.flush(),
194
Err(_) => Err(p2::StreamError::trap(
195
"concurrent flushes not supported yet",
196
)),
197
}
198
}
199
async fn cancel(&mut self) {
200
// Cancel the inner stream if we're the last reference to it:
201
if let Some(mutex) = Arc::get_mut(&mut self.0) {
202
match mutex.try_lock() {
203
Ok(mut stream) => stream.cancel().await,
204
Err(_) => {}
205
}
206
}
207
}
208
}
209
210
#[async_trait::async_trait]
211
impl p2::Pollable for AsyncStdoutStream {
212
async fn ready(&mut self) {
213
self.0.lock().await.ready().await
214
}
215
}
216
217
impl AsyncWrite for StdioHandle<p2::pipe::AsyncWriteStream> {
218
fn poll_write(
219
self: Pin<&mut Self>,
220
cx: &mut Context<'_>,
221
buf: &[u8],
222
) -> Poll<io::Result<usize>> {
223
match ready!(self.poll(cx, |i| i.write(Bytes::copy_from_slice(buf)))) {
224
Some(Ok(())) => Poll::Ready(Ok(buf.len())),
225
Some(Err(e)) => Poll::Ready(Err(e)),
226
None => Poll::Ready(Ok(0)),
227
}
228
}
229
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
230
match ready!(self.poll(cx, |i| i.flush())) {
231
Some(result) => Poll::Ready(result),
232
None => Poll::Ready(Ok(())),
233
}
234
}
235
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
236
Poll::Ready(Ok(()))
237
}
238
}
239
240
/// State necessary for effectively transforming `Arc<Mutex<dyn
241
/// {Input,Output}Stream>>` into `Async{Read,Write}`.
242
///
243
/// This is a beast and inefficient. It should get the job done in theory but
244
/// one must truly ask oneself at some point "but at what cost".
245
///
246
/// More seriously, it's unclear if this is the best way to transform a single
247
/// `AsyncRead` into a "multiple `AsyncRead`". This certainly is an attempt and
248
/// the hope is that everything here is private enough that we can refactor as
249
/// necessary in the future without causing much churn.
250
enum StdioHandle<S> {
251
Ready(Arc<Mutex<S>>),
252
Locking(Box<dyn Future<Output = OwnedMutexGuard<S>> + Send + Sync>),
253
Locked(OwnedMutexGuard<S>),
254
Closed,
255
}
256
257
impl<S> StdioHandle<S>
258
where
259
S: SharedHandleReady,
260
{
261
fn poll<T>(
262
mut self: Pin<&mut Self>,
263
cx: &mut Context<'_>,
264
op: impl FnOnce(&mut S) -> p2::StreamResult<T>,
265
) -> Poll<Option<io::Result<T>>> {
266
// If we don't currently have the lock on this handle, initiate the
267
// lock acquisition.
268
if let StdioHandle::Ready(lock) = &*self {
269
self.set(StdioHandle::Locking(Box::new(lock.clone().lock_owned())));
270
}
271
272
// If we're in the process of locking this handle, wait for that to
273
// finish.
274
if let Some(lock) = self.as_mut().as_locking() {
275
let guard = ready!(lock.poll(cx));
276
self.set(StdioHandle::Locked(guard));
277
}
278
279
let mut guard = match self.as_mut().take_guard() {
280
Some(guard) => guard,
281
// If the guard can't be acquired that means that this stream is
282
// closed, so return that we're ready without filling in data.
283
None => return Poll::Ready(None),
284
};
285
286
// Wait for our locked stream to be ready, resetting to the "locked"
287
// state if it's not quite ready yet.
288
match guard.poll_ready(cx) {
289
Poll::Ready(()) => {}
290
291
// If the read isn't ready yet then restore our "locked" state
292
// since we haven't finished, then return pending.
293
Poll::Pending => {
294
self.set(StdioHandle::Locked(guard));
295
return Poll::Pending;
296
}
297
}
298
299
// Perform the I/O and delegate on the result.
300
match op(&mut guard) {
301
// The I/O succeeded so relinquish the lock on this stream by
302
// transitioning back to the "Ready" state.
303
Ok(result) => {
304
self.set(StdioHandle::Ready(OwnedMutexGuard::mutex(&guard).clone()));
305
Poll::Ready(Some(Ok(result)))
306
}
307
308
// The stream is closed, and `take_guard` above already set the
309
// closed state, so return nothing indicating the closure.
310
Err(p2::StreamError::Closed) => Poll::Ready(None),
311
312
// The stream failed so propagate the error. Errors should only
313
// come from the underlying I/O object and thus should cast
314
// successfully. Additionally `take_guard` replaced our state
315
// with "closed" above which is the desired state at this point.
316
Err(p2::StreamError::LastOperationFailed(e)) => {
317
Poll::Ready(Some(Err(e.downcast().unwrap())))
318
}
319
320
// Shouldn't be possible to produce a trap here.
321
Err(p2::StreamError::Trap(_)) => unreachable!(),
322
}
323
}
324
325
fn as_locking(
326
self: Pin<&mut Self>,
327
) -> Option<Pin<&mut dyn Future<Output = OwnedMutexGuard<S>>>> {
328
// SAFETY: this is a pin-projection from `self` into the `Locking`
329
// field.
330
unsafe {
331
match self.get_unchecked_mut() {
332
StdioHandle::Locking(future) => Some(Pin::new_unchecked(&mut **future)),
333
_ => None,
334
}
335
}
336
}
337
338
fn take_guard(self: Pin<&mut Self>) -> Option<OwnedMutexGuard<S>> {
339
if !matches!(*self, StdioHandle::Locked(_)) {
340
return None;
341
}
342
// SAFETY: the `Locked` arm is safe to move as it's an invariant of this
343
// type that it's not pinned.
344
unsafe {
345
match mem::replace(self.get_unchecked_mut(), StdioHandle::Closed) {
346
StdioHandle::Locked(guard) => Some(guard),
347
_ => unreachable!(),
348
}
349
}
350
}
351
}
352
353