Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
bytecodealliance
GitHub Repository: bytecodealliance/wasmtime
Path: blob/main/crates/wasi-tls/src/io.rs
1691 views
1
//! Utility types for converting Rust & Tokio I/O types into WASI I/O types,
2
//! and vice versa.
3
4
use anyhow::Result;
5
use bytes::Bytes;
6
use std::io;
7
use std::sync::Arc;
8
use std::task::{Poll, ready};
9
use std::{future::Future, mem, pin::Pin};
10
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
11
use tokio::sync::Mutex;
12
use wasmtime_wasi::async_trait;
13
use wasmtime_wasi::p2::{
14
DynInputStream, DynOutputStream, OutputStream, Pollable, StreamError, StreamResult,
15
};
16
use wasmtime_wasi::runtime::AbortOnDropJoinHandle;
17
18
/// State of a value that is produced asynchronously and handed out at most once.
enum FutureState<T> {
    /// The value is still being produced by this boxed future.
    Pending(Pin<Box<dyn Future<Output = T> + Send>>),
    /// The value is available and has not been taken yet.
    Ready(T),
    /// No value is held any more: it was moved out, or its owner finished with it.
    Consumed,
}
23
24
/// Result of asking a `WasiFuture` for its value.
pub(crate) enum FutureOutput<T> {
    /// The value has not been produced yet.
    Pending,
    /// The value, which is now moved out to the caller.
    Ready(T),
    /// The value was already handed out by an earlier call.
    Consumed,
}
29
30
/// A value computed by a spawned task that can be polled for readiness and then
/// taken exactly once.
pub(crate) struct WasiFuture<T>(FutureState<T>);
31
32
impl<T> WasiFuture<T>
33
where
34
T: Send + 'static,
35
{
36
pub(crate) fn spawn<F>(fut: F) -> Self
37
where
38
F: Future<Output = T> + Send + 'static,
39
{
40
Self(FutureState::Pending(Box::pin(
41
wasmtime_wasi::runtime::spawn(async move { fut.await }),
42
)))
43
}
44
45
pub(crate) fn get(&mut self) -> FutureOutput<T> {
46
match &self.0 {
47
FutureState::Pending(_) => return FutureOutput::Pending,
48
FutureState::Consumed => return FutureOutput::Consumed,
49
FutureState::Ready(_) => (),
50
}
51
52
let FutureState::Ready(value) = mem::replace(&mut self.0, FutureState::Consumed) else {
53
unreachable!()
54
};
55
56
FutureOutput::Ready(value)
57
}
58
}
59
60
#[async_trait]
61
impl<T> Pollable for WasiFuture<T>
62
where
63
T: Send + 'static,
64
{
65
async fn ready(&mut self) {
66
match &mut self.0 {
67
FutureState::Ready(_) | FutureState::Consumed => return,
68
FutureState::Pending(task) => self.0 = FutureState::Ready(task.as_mut().await),
69
}
70
}
71
}
72
73
pub(crate) struct WasiStreamReader(FutureState<DynInputStream>);
74
impl WasiStreamReader {
75
pub(crate) fn new(stream: DynInputStream) -> Self {
76
Self(FutureState::Ready(stream))
77
}
78
}
79
impl AsyncRead for WasiStreamReader {
80
fn poll_read(
81
mut self: Pin<&mut Self>,
82
cx: &mut std::task::Context<'_>,
83
buf: &mut tokio::io::ReadBuf<'_>,
84
) -> Poll<std::io::Result<()>> {
85
loop {
86
let stream = match &mut self.0 {
87
FutureState::Ready(stream) => stream,
88
FutureState::Pending(fut) => {
89
let stream = ready!(fut.as_mut().poll(cx));
90
self.0 = FutureState::Ready(stream);
91
if let FutureState::Ready(stream) = &mut self.0 {
92
stream
93
} else {
94
unreachable!()
95
}
96
}
97
FutureState::Consumed => {
98
return Poll::Ready(Ok(()));
99
}
100
};
101
match stream.read(buf.remaining()) {
102
Ok(bytes) if bytes.is_empty() => {
103
let FutureState::Ready(mut stream) =
104
std::mem::replace(&mut self.0, FutureState::Consumed)
105
else {
106
unreachable!()
107
};
108
109
self.0 = FutureState::Pending(Box::pin(async move {
110
stream.ready().await;
111
stream
112
}));
113
}
114
Ok(bytes) => {
115
buf.put_slice(&bytes);
116
117
return Poll::Ready(Ok(()));
118
}
119
Err(StreamError::Closed) => {
120
self.0 = FutureState::Consumed;
121
return Poll::Ready(Ok(()));
122
}
123
Err(e) => {
124
self.0 = FutureState::Consumed;
125
return Poll::Ready(Err(std::io::Error::other(e)));
126
}
127
}
128
}
129
}
130
}
131
132
pub(crate) struct WasiStreamWriter(FutureState<DynOutputStream>);
133
impl WasiStreamWriter {
134
pub(crate) fn new(stream: DynOutputStream) -> Self {
135
Self(FutureState::Ready(stream))
136
}
137
}
138
impl AsyncWrite for WasiStreamWriter {
139
fn poll_write(
140
mut self: Pin<&mut Self>,
141
cx: &mut std::task::Context<'_>,
142
buf: &[u8],
143
) -> std::task::Poll<std::result::Result<usize, std::io::Error>> {
144
loop {
145
match &mut self.as_mut().0 {
146
FutureState::Consumed => unreachable!(),
147
FutureState::Pending(future) => {
148
let value = ready!(future.as_mut().poll(cx));
149
self.as_mut().0 = FutureState::Ready(value);
150
}
151
FutureState::Ready(output) => {
152
match output.check_write() {
153
Ok(0) => {
154
let FutureState::Ready(mut output) =
155
mem::replace(&mut self.as_mut().0, FutureState::Consumed)
156
else {
157
unreachable!()
158
};
159
self.as_mut().0 = FutureState::Pending(Box::pin(async move {
160
output.ready().await;
161
output
162
}));
163
}
164
Ok(count) => {
165
let count = count.min(buf.len());
166
return match output.write(Bytes::copy_from_slice(&buf[..count])) {
167
Ok(()) => Poll::Ready(Ok(count)),
168
Err(StreamError::Closed) => Poll::Ready(Ok(0)),
169
Err(e) => Poll::Ready(Err(std::io::Error::other(e))),
170
};
171
}
172
Err(StreamError::Closed) => {
173
// Our current version of tokio-rustls does not handle returning `Ok(0)` well.
174
// See: https://github.com/rustls/tokio-rustls/issues/92
175
return Poll::Ready(Err(std::io::ErrorKind::WriteZero.into()));
176
}
177
Err(e) => return Poll::Ready(Err(std::io::Error::other(e))),
178
};
179
}
180
}
181
}
182
}
183
184
fn poll_flush(
185
self: Pin<&mut Self>,
186
cx: &mut std::task::Context<'_>,
187
) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
188
self.poll_write(cx, &[]).map(|v| v.map(drop))
189
}
190
191
fn poll_shutdown(
192
self: Pin<&mut Self>,
193
cx: &mut std::task::Context<'_>,
194
) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
195
self.poll_flush(cx)
196
}
197
}
198
199
pub(crate) use wasmtime_wasi::p2::pipe::AsyncReadStream;
200
201
/// A WASI output stream that forwards writes to a Tokio `AsyncWrite` on background tasks.
///
/// The write state is shared through `Arc<Mutex<..>>`, so clones of this handle observe
/// the same stream.
pub(crate) struct AsyncWriteStream<IO>(Arc<Mutex<WriteState<IO>>>);
202
203
impl<IO> AsyncWriteStream<IO>
204
where
205
IO: AsyncWrite + Send + Unpin + 'static,
206
{
207
pub(crate) fn new(io: IO) -> Self {
208
AsyncWriteStream(Arc::new(Mutex::new(WriteState::new(io))))
209
}
210
211
pub(crate) fn close(&mut self) -> wasmtime::Result<()> {
212
self.try_lock()?.close();
213
Ok(())
214
}
215
216
async fn lock(&self) -> tokio::sync::MutexGuard<'_, WriteState<IO>> {
217
self.0.lock().await
218
}
219
220
fn try_lock(&self) -> Result<tokio::sync::MutexGuard<'_, WriteState<IO>>, StreamError> {
221
self.0
222
.try_lock()
223
.map_err(|_| StreamError::trap("concurrent access to resource not supported"))
224
}
225
}
226
impl<IO> Clone for AsyncWriteStream<IO> {
227
fn clone(&self) -> Self {
228
Self(Arc::clone(&self.0))
229
}
230
}
231
232
#[async_trait]
233
impl<IO> OutputStream for AsyncWriteStream<IO>
234
where
235
IO: AsyncWrite + Send + Unpin + 'static,
236
{
237
fn write(&mut self, bytes: bytes::Bytes) -> StreamResult<()> {
238
self.try_lock()?.write(bytes)
239
}
240
241
fn flush(&mut self) -> StreamResult<()> {
242
self.try_lock()?.flush()
243
}
244
245
fn check_write(&mut self) -> StreamResult<usize> {
246
self.try_lock()?.check_write()
247
}
248
249
async fn cancel(&mut self) {
250
self.lock().await.cancel().await
251
}
252
}
253
254
#[async_trait]
255
impl<IO> Pollable for AsyncWriteStream<IO>
256
where
257
IO: AsyncWrite + Send + Unpin + 'static,
258
{
259
async fn ready(&mut self) {
260
self.lock().await.ready().await
261
}
262
}
263
264
enum WriteState<IO> {
265
Ready(IO),
266
Writing(AbortOnDropJoinHandle<io::Result<IO>>),
267
Flushing(AbortOnDropJoinHandle<io::Result<IO>>),
268
Closing(AbortOnDropJoinHandle<io::Result<()>>),
269
Closed,
270
Error(io::Error),
271
}
272
/// Write budget (1 GiB) advertised by `check_write` while the stream is idle.
const READY_SIZE: usize = 1 << 30;
273
274
impl<IO> WriteState<IO>
275
where
276
IO: AsyncWrite + Send + Unpin + 'static,
277
{
278
fn new(stream: IO) -> Self {
279
Self::Ready(stream)
280
}
281
282
fn write(&mut self, mut bytes: bytes::Bytes) -> StreamResult<()> {
283
let WriteState::Ready(_) = self else {
284
return Err(StreamError::Trap(anyhow::anyhow!(
285
"unpermitted: must call check_write first"
286
)));
287
};
288
289
if bytes.is_empty() {
290
return Ok(());
291
}
292
293
let WriteState::Ready(mut stream) = std::mem::replace(self, WriteState::Closed) else {
294
unreachable!()
295
};
296
297
*self = WriteState::Writing(wasmtime_wasi::runtime::spawn(async move {
298
while !bytes.is_empty() {
299
let n = stream.write(&bytes).await?;
300
let _ = bytes.split_to(n);
301
}
302
303
Ok(stream)
304
}));
305
306
Ok(())
307
}
308
309
fn flush(&mut self) -> StreamResult<()> {
310
match self {
311
// Immediately flush:
312
WriteState::Ready(_) => {
313
let WriteState::Ready(mut stream) = std::mem::replace(self, WriteState::Closed)
314
else {
315
unreachable!()
316
};
317
*self = WriteState::Flushing(wasmtime_wasi::runtime::spawn(async move {
318
stream.flush().await?;
319
Ok(stream)
320
}));
321
}
322
323
// Schedule the flush after the current write has finished:
324
WriteState::Writing(_) => {
325
let WriteState::Writing(write) = std::mem::replace(self, WriteState::Closed) else {
326
unreachable!()
327
};
328
*self = WriteState::Flushing(wasmtime_wasi::runtime::spawn(async move {
329
let mut stream = write.await?;
330
stream.flush().await?;
331
Ok(stream)
332
}));
333
}
334
335
WriteState::Flushing(_) | WriteState::Closing(_) | WriteState::Error(_) => {}
336
WriteState::Closed => return Err(StreamError::Closed),
337
}
338
339
Ok(())
340
}
341
342
fn check_write(&mut self) -> StreamResult<usize> {
343
match self {
344
WriteState::Ready(_) => Ok(READY_SIZE),
345
WriteState::Writing(_) => Ok(0),
346
WriteState::Flushing(_) => Ok(0),
347
WriteState::Closing(_) => Ok(0),
348
WriteState::Closed => Err(StreamError::Closed),
349
WriteState::Error(_) => {
350
let WriteState::Error(e) = std::mem::replace(self, WriteState::Closed) else {
351
unreachable!()
352
};
353
354
Err(StreamError::LastOperationFailed(e.into()))
355
}
356
}
357
}
358
359
fn close(&mut self) {
360
match std::mem::replace(self, WriteState::Closed) {
361
// No write in progress, immediately shut down:
362
WriteState::Ready(mut stream) => {
363
*self = WriteState::Closing(wasmtime_wasi::runtime::spawn(async move {
364
stream.shutdown().await
365
}));
366
}
367
368
// Schedule the shutdown after the current operation has finished:
369
WriteState::Writing(op) | WriteState::Flushing(op) => {
370
*self = WriteState::Closing(wasmtime_wasi::runtime::spawn(async move {
371
let mut stream = op.await?;
372
stream.shutdown().await
373
}));
374
}
375
376
WriteState::Closing(t) => {
377
*self = WriteState::Closing(t);
378
}
379
WriteState::Closed | WriteState::Error(_) => {}
380
}
381
}
382
383
async fn cancel(&mut self) {
384
match std::mem::replace(self, WriteState::Closed) {
385
WriteState::Writing(task) | WriteState::Flushing(task) => _ = task.cancel().await,
386
WriteState::Closing(task) => _ = task.cancel().await,
387
_ => {}
388
}
389
}
390
391
async fn ready(&mut self) {
392
match self {
393
WriteState::Writing(task) | WriteState::Flushing(task) => {
394
*self = match task.await {
395
Ok(s) => WriteState::Ready(s),
396
Err(e) => WriteState::Error(e),
397
}
398
}
399
WriteState::Closing(task) => {
400
*self = match task.await {
401
Ok(()) => WriteState::Closed,
402
Err(e) => WriteState::Error(e),
403
}
404
}
405
_ => {}
406
}
407
}
408
}
409
410