Path: blob/main/crates/wasi/src/p2/write_stream.rs
1692 views
use crate::p2::{OutputStream, Pollable, StreamError};1use anyhow::anyhow;2use bytes::Bytes;3use std::pin::pin;4use std::sync::{Arc, Mutex};5use std::task::{Context, Poll, Waker};67#[derive(Debug)]8struct WorkerState {9alive: bool,10items: std::collections::VecDeque<Bytes>,11write_budget: usize,12flush_pending: bool,13error: Option<anyhow::Error>,14write_ready_changed: Option<Waker>,15}1617impl WorkerState {18fn check_error(&mut self) -> Result<(), StreamError> {19if let Some(e) = self.error.take() {20return Err(StreamError::LastOperationFailed(e));21}22if !self.alive {23return Err(StreamError::Closed);24}25Ok(())26}27}2829struct Worker {30state: Mutex<WorkerState>,31new_work: tokio::sync::Notify,32}3334enum Job {35Flush,36Write(Bytes),37}3839impl Worker {40fn new(write_budget: usize) -> Self {41Self {42state: Mutex::new(WorkerState {43alive: true,44items: std::collections::VecDeque::new(),45write_budget,46flush_pending: false,47error: None,48write_ready_changed: None,49}),50new_work: tokio::sync::Notify::new(),51}52}53fn check_write(&self) -> Result<usize, StreamError> {54let mut state = self.state();55if let Err(e) = state.check_error() {56return Err(e);57}5859if state.flush_pending || state.write_budget == 0 {60return Ok(0);61}6263Ok(state.write_budget)64}65fn state(&self) -> std::sync::MutexGuard<'_, WorkerState> {66self.state.lock().unwrap()67}68fn pop(&self) -> Option<Job> {69let mut state = self.state();70if state.items.is_empty() {71if state.flush_pending {72return Some(Job::Flush);73}74} else if let Some(bytes) = state.items.pop_front() {75return Some(Job::Write(bytes));76}7778None79}80fn report_error(&self, e: std::io::Error) {81let waker = {82let mut state = self.state();83state.alive = false;84state.error = Some(e.into());85state.flush_pending = false;86state.write_ready_changed.take()87};88if let Some(waker) = waker {89waker.wake();90}91}92async fn work<T: tokio::io::AsyncWrite + Send + 'static>(&self, writer: T) {93use tokio::io::AsyncWriteExt;94let mut writer = pin!(writer);95loop {96while let Some(job) = self.pop() {97match job {98Job::Flush => {99if let Err(e) = writer.flush().await {100self.report_error(e);101return;102}103104tracing::debug!("worker marking flush complete");105self.state().flush_pending = false;106}107108Job::Write(mut bytes) => {109tracing::debug!("worker writing: {bytes:?}");110let len = bytes.len();111match writer.write_all_buf(&mut bytes).await {112Err(e) => {113self.report_error(e);114return;115}116Ok(_) => {117self.state().write_budget += len;118}119}120}121}122123let waker = self.state().write_ready_changed.take();124if let Some(waker) = waker {125waker.wake();126}127}128self.new_work.notified().await;129}130}131}132133/// Provides a [`OutputStream`] impl from a [`tokio::io::AsyncWrite`] impl134pub struct AsyncWriteStream {135worker: Arc<Worker>,136join_handle: Option<crate::runtime::AbortOnDropJoinHandle<()>>,137}138139impl AsyncWriteStream {140/// Create a [`AsyncWriteStream`]. In order to use the [`OutputStream`] impl141/// provided by this struct, the argument must impl [`tokio::io::AsyncWrite`].142pub fn new<T: tokio::io::AsyncWrite + Send + 'static>(write_budget: usize, writer: T) -> Self {143let worker = Arc::new(Worker::new(write_budget));144145let w = Arc::clone(&worker);146let join_handle = crate::runtime::spawn(async move { w.work(writer).await });147148AsyncWriteStream {149worker,150join_handle: Some(join_handle),151}152}153154pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {155let mut state = self.worker.state();156if state.error.is_some() || !state.alive || (!state.flush_pending && state.write_budget > 0)157{158return Poll::Ready(());159}160state.write_ready_changed = Some(cx.waker().clone());161Poll::Pending162}163}164165#[async_trait::async_trait]166impl OutputStream for AsyncWriteStream {167fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> {168let mut state = self.worker.state();169state.check_error()?;170if state.flush_pending {171return Err(StreamError::Trap(anyhow!(172"write not permitted while flush pending"173)));174}175match state.write_budget.checked_sub(bytes.len()) {176Some(remaining_budget) => {177state.write_budget = remaining_budget;178state.items.push_back(bytes);179}180None => return Err(StreamError::Trap(anyhow!("write exceeded budget"))),181}182drop(state);183self.worker.new_work.notify_one();184Ok(())185}186fn flush(&mut self) -> Result<(), StreamError> {187let mut state = self.worker.state();188state.check_error()?;189190state.flush_pending = true;191self.worker.new_work.notify_one();192193Ok(())194}195196fn check_write(&mut self) -> Result<usize, StreamError> {197self.worker.check_write()198}199200async fn cancel(&mut self) {201match self.join_handle.take() {202Some(task) => _ = task.cancel().await,203None => {}204}205}206}207#[async_trait::async_trait]208impl Pollable for AsyncWriteStream {209async fn ready(&mut self) {210std::future::poll_fn(|cx| self.poll_ready(cx)).await211}212}213214215