Path: blob/main/crates/wasi/src/cli/locked_async.rs
1693 views
use crate::cli::{IsTerminal, StdinStream, StdoutStream};1use crate::p2;2use bytes::Bytes;3use std::mem;4use std::pin::Pin;5use std::sync::Arc;6use std::task::{Context, Poll, ready};7use tokio::io::{self, AsyncRead, AsyncWrite};8use tokio::sync::{Mutex, OwnedMutexGuard};9use wasmtime_wasi_io::streams::{InputStream, OutputStream};1011trait SharedHandleReady: Send + Sync + 'static {12fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()>;13}1415impl SharedHandleReady for p2::pipe::AsyncWriteStream {16fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {17<Self>::poll_ready(self, cx)18}19}2021impl SharedHandleReady for p2::pipe::AsyncReadStream {22fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {23<Self>::poll_ready(self, cx)24}25}2627/// An impl of [`StdinStream`] built on top of [`AsyncRead`].28//29// Note the usage of `tokio::sync::Mutex` here as opposed to a30// `std::sync::Mutex`. This is intentionally done to implement the `Pollable`31// variant of this trait. Note that in doing so we're left with the quandry of32// how to implement methods of `InputStream` since those methods are not33// `async`. They're currently implemented with `try_lock`, which then raises the34// question of what to do on contention. Currently traps are returned.35//36// Why should it be ok to return a trap? In general concurrency/contention37// shouldn't return a trap since it should be able to happen normally. The38// current assumption, though, is that WASI stdin/stdout streams are special39// enough that the contention case should never come up in practice. Currently40// in WASI there is no actually concurrency, there's just the items in a single41// `Store` and that store owns all of its I/O in a single Tokio task. There's no42// means to actually spawn multiple Tokio tasks that use the same store. This43// means at the very least that there's zero parallelism. Due to the lack of44// multiple tasks that also means that there's no concurrency either.45//46// This `AsyncStdinStream` wrapper is only intended to be used by the WASI47// bindings themselves. It's possible for the host to take this and work with it48// on its own task, but that's niche enough it's not designed for.49//50// Overall that means that the guest is either calling `Pollable` or51// `InputStream` methods. This means that there should never be contention52// between the two at this time. This may all change in the future with WASI53// 0.3, but perhaps we'll have a better story for stdio at that time (see the54// doc block on the `OutputStream` impl below)55pub struct AsyncStdinStream(Arc<Mutex<p2::pipe::AsyncReadStream>>);5657impl AsyncStdinStream {58pub fn new(s: impl AsyncRead + Send + Sync + 'static) -> Self {59Self(Arc::new(Mutex::new(p2::pipe::AsyncReadStream::new(s))))60}61}6263impl StdinStream for AsyncStdinStream {64fn p2_stream(&self) -> Box<dyn InputStream> {65Box::new(Self(self.0.clone()))66}67fn async_stream(&self) -> Box<dyn AsyncRead + Send + Sync> {68Box::new(StdioHandle::Ready(self.0.clone()))69}70}7172impl IsTerminal for AsyncStdinStream {73fn is_terminal(&self) -> bool {74false75}76}7778#[async_trait::async_trait]79impl InputStream for AsyncStdinStream {80fn read(&mut self, size: usize) -> Result<bytes::Bytes, p2::StreamError> {81match self.0.try_lock() {82Ok(mut stream) => stream.read(size),83Err(_) => Err(p2::StreamError::trap("concurrent reads are not supported")),84}85}86fn skip(&mut self, size: usize) -> Result<usize, p2::StreamError> {87match self.0.try_lock() {88Ok(mut stream) => stream.skip(size),89Err(_) => Err(p2::StreamError::trap("concurrent skips are not supported")),90}91}92async fn cancel(&mut self) {93// Cancel the inner stream if we're the last reference to it:94if let Some(mutex) = Arc::get_mut(&mut self.0) {95match mutex.try_lock() {96Ok(mut stream) => stream.cancel().await,97Err(_) => {}98}99}100}101}102103#[async_trait::async_trait]104impl p2::Pollable for AsyncStdinStream {105async fn ready(&mut self) {106self.0.lock().await.ready().await107}108}109110impl AsyncRead for StdioHandle<p2::pipe::AsyncReadStream> {111fn poll_read(112mut self: Pin<&mut Self>,113cx: &mut Context<'_>,114buf: &mut io::ReadBuf<'_>,115) -> Poll<io::Result<()>> {116match ready!(self.as_mut().poll(cx, |g| g.read(buf.remaining()))) {117Some(Ok(bytes)) => {118buf.put_slice(&bytes);119Poll::Ready(Ok(()))120}121Some(Err(e)) => Poll::Ready(Err(e)),122// If the guard can't be acquired that means that this stream is123// closed, so return that we're ready without filling in data.124None => Poll::Ready(Ok(())),125}126}127}128129/// A wrapper of [`crate::p2::pipe::AsyncWriteStream`] that implements130/// [`StdoutStream`]. Note that the [`OutputStream`] impl for this is not131/// correct when used for interleaved async IO.132//133// Note that the use of `tokio::sync::Mutex` here is intentional, in addition to134// the `try_lock()` calls below in the implementation of `OutputStream`. For135// more information see the documentation on `AsyncStdinStream`.136pub struct AsyncStdoutStream(Arc<Mutex<p2::pipe::AsyncWriteStream>>);137138impl AsyncStdoutStream {139pub fn new(budget: usize, s: impl AsyncWrite + Send + Sync + 'static) -> Self {140Self(Arc::new(Mutex::new(p2::pipe::AsyncWriteStream::new(141budget, s,142))))143}144}145146impl StdoutStream for AsyncStdoutStream {147fn p2_stream(&self) -> Box<dyn OutputStream> {148Box::new(Self(self.0.clone()))149}150fn async_stream(&self) -> Box<dyn AsyncWrite + Send + Sync> {151Box::new(StdioHandle::Ready(self.0.clone()))152}153}154155impl IsTerminal for AsyncStdoutStream {156fn is_terminal(&self) -> bool {157false158}159}160161// This implementation is known to be bogus. All check-writes and writes are162// directed at the same underlying stream. The check-write/write protocol does163// require the size returned by a check-write to be accepted by write, even if164// other side-effects happen between those calls, and this implementation165// permits another view (created by StdoutStream::stream()) of the same166// underlying stream to accept a write which will invalidate a prior167// check-write of another view.168// Ultimately, the Std{in,out}Stream::stream() methods exist because many169// different places in a linked component (which may itself contain many170// modules) may need to access stdio without any coordination to keep those171// accesses all using pointing to the same resource. So, we allow many172// resources to be created. We have the reasonable expectation that programs173// won't attempt to interleave async IO from these disparate uses of stdio.174// If that expectation doesn't turn out to be true, and you find yourself at175// this comment to correct it: sorry about that.176#[async_trait::async_trait]177impl OutputStream for AsyncStdoutStream {178fn check_write(&mut self) -> Result<usize, p2::StreamError> {179match self.0.try_lock() {180Ok(mut stream) => stream.check_write(),181Err(_) => Err(p2::StreamError::trap("concurrent writes are not supported")),182}183}184fn write(&mut self, bytes: Bytes) -> Result<(), p2::StreamError> {185match self.0.try_lock() {186Ok(mut stream) => stream.write(bytes),187Err(_) => Err(p2::StreamError::trap("concurrent writes not supported yet")),188}189}190fn flush(&mut self) -> Result<(), p2::StreamError> {191match self.0.try_lock() {192Ok(mut stream) => stream.flush(),193Err(_) => Err(p2::StreamError::trap(194"concurrent flushes not supported yet",195)),196}197}198async fn cancel(&mut self) {199// Cancel the inner stream if we're the last reference to it:200if let Some(mutex) = Arc::get_mut(&mut self.0) {201match mutex.try_lock() {202Ok(mut stream) => stream.cancel().await,203Err(_) => {}204}205}206}207}208209#[async_trait::async_trait]210impl p2::Pollable for AsyncStdoutStream {211async fn ready(&mut self) {212self.0.lock().await.ready().await213}214}215216impl AsyncWrite for StdioHandle<p2::pipe::AsyncWriteStream> {217fn poll_write(218self: Pin<&mut Self>,219cx: &mut Context<'_>,220buf: &[u8],221) -> Poll<io::Result<usize>> {222match ready!(self.poll(cx, |i| i.write(Bytes::copy_from_slice(buf)))) {223Some(Ok(())) => Poll::Ready(Ok(buf.len())),224Some(Err(e)) => Poll::Ready(Err(e)),225None => Poll::Ready(Ok(0)),226}227}228fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {229match ready!(self.poll(cx, |i| i.flush())) {230Some(result) => Poll::Ready(result),231None => Poll::Ready(Ok(())),232}233}234fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {235Poll::Ready(Ok(()))236}237}238239/// State necessary for effectively transforming `Arc<Mutex<dyn240/// {Input,Output}Stream>>` into `Async{Read,Write}`.241///242/// This is a beast and inefficient. It should get the job done in theory but243/// one must truly ask oneself at some point "but at what cost".244///245/// More seriously, it's unclear if this is the best way to transform a single246/// `AsyncRead` into a "multiple `AsyncRead`". This certainly is an attempt and247/// the hope is that everything here is private enough that we can refactor as248/// necessary in the future without causing much churn.249enum StdioHandle<S> {250Ready(Arc<Mutex<S>>),251Locking(Box<dyn Future<Output = OwnedMutexGuard<S>> + Send + Sync>),252Locked(OwnedMutexGuard<S>),253Closed,254}255256impl<S> StdioHandle<S>257where258S: SharedHandleReady,259{260fn poll<T>(261mut self: Pin<&mut Self>,262cx: &mut Context<'_>,263op: impl FnOnce(&mut S) -> p2::StreamResult<T>,264) -> Poll<Option<io::Result<T>>> {265// If we don't currently have the lock on this handle, initiate the266// lock acquisition.267if let StdioHandle::Ready(lock) = &*self {268self.set(StdioHandle::Locking(Box::new(lock.clone().lock_owned())));269}270271// If we're in the process of locking this handle, wait for that to272// finish.273if let Some(lock) = self.as_mut().as_locking() {274let guard = ready!(lock.poll(cx));275self.set(StdioHandle::Locked(guard));276}277278let mut guard = match self.as_mut().take_guard() {279Some(guard) => guard,280// If the guard can't be acquired that means that this stream is281// closed, so return that we're ready without filling in data.282None => return Poll::Ready(None),283};284285// Wait for our locked stream to be ready, resetting to the "locked"286// state if it's not quite ready yet.287match guard.poll_ready(cx) {288Poll::Ready(()) => {}289290// If the read isn't ready yet then restore our "locked" state291// since we haven't finished, then return pending.292Poll::Pending => {293self.set(StdioHandle::Locked(guard));294return Poll::Pending;295}296}297298// Perform the I/O and delegate on the result.299match op(&mut guard) {300// The I/O succeeded so relinquish the lock on this stream by301// transitioning back to the "Ready" state.302Ok(result) => {303self.set(StdioHandle::Ready(OwnedMutexGuard::mutex(&guard).clone()));304Poll::Ready(Some(Ok(result)))305}306307// The stream is closed, and `take_guard` above already set the308// closed state, so return nothing indicating the closure.309Err(p2::StreamError::Closed) => Poll::Ready(None),310311// The stream failed so propagate the error. Errors should only312// come from the underlying I/O object and thus should cast313// successfully. Additionally `take_guard` replaced our state314// with "closed" above which is the desired state at this point.315Err(p2::StreamError::LastOperationFailed(e)) => {316Poll::Ready(Some(Err(e.downcast().unwrap())))317}318319// Shouldn't be possible to produce a trap here.320Err(p2::StreamError::Trap(_)) => unreachable!(),321}322}323324fn as_locking(325self: Pin<&mut Self>,326) -> Option<Pin<&mut dyn Future<Output = OwnedMutexGuard<S>>>> {327// SAFETY: this is a pin-projection from `self` into the `Locking`328// field.329unsafe {330match self.get_unchecked_mut() {331StdioHandle::Locking(future) => Some(Pin::new_unchecked(&mut **future)),332_ => None,333}334}335}336337fn take_guard(self: Pin<&mut Self>) -> Option<OwnedMutexGuard<S>> {338if !matches!(*self, StdioHandle::Locked(_)) {339return None;340}341// SAFETY: the `Locked` arm is safe to move as it's an invariant of this342// type that it's not pinned.343unsafe {344match mem::replace(self.get_unchecked_mut(), StdioHandle::Closed) {345StdioHandle::Locked(guard) => Some(guard),346_ => unreachable!(),347}348}349}350}351352353