Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
bytecodealliance
GitHub Repository: bytecodealliance/wasmtime
Path: blob/main/crates/wasi/src/p2/write_stream.rs
1692 views
1
use crate::p2::{OutputStream, Pollable, StreamError};
2
use anyhow::anyhow;
3
use bytes::Bytes;
4
use std::pin::pin;
5
use std::sync::{Arc, Mutex};
6
use std::task::{Context, Poll, Waker};
7
8
#[derive(Debug)]
9
struct WorkerState {
10
alive: bool,
11
items: std::collections::VecDeque<Bytes>,
12
write_budget: usize,
13
flush_pending: bool,
14
error: Option<anyhow::Error>,
15
write_ready_changed: Option<Waker>,
16
}
17
18
impl WorkerState {
19
fn check_error(&mut self) -> Result<(), StreamError> {
20
if let Some(e) = self.error.take() {
21
return Err(StreamError::LastOperationFailed(e));
22
}
23
if !self.alive {
24
return Err(StreamError::Closed);
25
}
26
Ok(())
27
}
28
}
29
30
struct Worker {
31
state: Mutex<WorkerState>,
32
new_work: tokio::sync::Notify,
33
}
34
35
enum Job {
36
Flush,
37
Write(Bytes),
38
}
39
40
impl Worker {
41
fn new(write_budget: usize) -> Self {
42
Self {
43
state: Mutex::new(WorkerState {
44
alive: true,
45
items: std::collections::VecDeque::new(),
46
write_budget,
47
flush_pending: false,
48
error: None,
49
write_ready_changed: None,
50
}),
51
new_work: tokio::sync::Notify::new(),
52
}
53
}
54
fn check_write(&self) -> Result<usize, StreamError> {
55
let mut state = self.state();
56
if let Err(e) = state.check_error() {
57
return Err(e);
58
}
59
60
if state.flush_pending || state.write_budget == 0 {
61
return Ok(0);
62
}
63
64
Ok(state.write_budget)
65
}
66
fn state(&self) -> std::sync::MutexGuard<'_, WorkerState> {
67
self.state.lock().unwrap()
68
}
69
fn pop(&self) -> Option<Job> {
70
let mut state = self.state();
71
if state.items.is_empty() {
72
if state.flush_pending {
73
return Some(Job::Flush);
74
}
75
} else if let Some(bytes) = state.items.pop_front() {
76
return Some(Job::Write(bytes));
77
}
78
79
None
80
}
81
fn report_error(&self, e: std::io::Error) {
82
let waker = {
83
let mut state = self.state();
84
state.alive = false;
85
state.error = Some(e.into());
86
state.flush_pending = false;
87
state.write_ready_changed.take()
88
};
89
if let Some(waker) = waker {
90
waker.wake();
91
}
92
}
93
async fn work<T: tokio::io::AsyncWrite + Send + 'static>(&self, writer: T) {
94
use tokio::io::AsyncWriteExt;
95
let mut writer = pin!(writer);
96
loop {
97
while let Some(job) = self.pop() {
98
match job {
99
Job::Flush => {
100
if let Err(e) = writer.flush().await {
101
self.report_error(e);
102
return;
103
}
104
105
tracing::debug!("worker marking flush complete");
106
self.state().flush_pending = false;
107
}
108
109
Job::Write(mut bytes) => {
110
tracing::debug!("worker writing: {bytes:?}");
111
let len = bytes.len();
112
match writer.write_all_buf(&mut bytes).await {
113
Err(e) => {
114
self.report_error(e);
115
return;
116
}
117
Ok(_) => {
118
self.state().write_budget += len;
119
}
120
}
121
}
122
}
123
124
let waker = self.state().write_ready_changed.take();
125
if let Some(waker) = waker {
126
waker.wake();
127
}
128
}
129
self.new_work.notified().await;
130
}
131
}
132
}
133
134
/// Provides a [`OutputStream`] impl from a [`tokio::io::AsyncWrite`] impl
135
pub struct AsyncWriteStream {
136
worker: Arc<Worker>,
137
join_handle: Option<crate::runtime::AbortOnDropJoinHandle<()>>,
138
}
139
140
impl AsyncWriteStream {
141
/// Create a [`AsyncWriteStream`]. In order to use the [`OutputStream`] impl
142
/// provided by this struct, the argument must impl [`tokio::io::AsyncWrite`].
143
pub fn new<T: tokio::io::AsyncWrite + Send + 'static>(write_budget: usize, writer: T) -> Self {
144
let worker = Arc::new(Worker::new(write_budget));
145
146
let w = Arc::clone(&worker);
147
let join_handle = crate::runtime::spawn(async move { w.work(writer).await });
148
149
AsyncWriteStream {
150
worker,
151
join_handle: Some(join_handle),
152
}
153
}
154
155
pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
156
let mut state = self.worker.state();
157
if state.error.is_some() || !state.alive || (!state.flush_pending && state.write_budget > 0)
158
{
159
return Poll::Ready(());
160
}
161
state.write_ready_changed = Some(cx.waker().clone());
162
Poll::Pending
163
}
164
}
165
166
#[async_trait::async_trait]
167
impl OutputStream for AsyncWriteStream {
168
fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> {
169
let mut state = self.worker.state();
170
state.check_error()?;
171
if state.flush_pending {
172
return Err(StreamError::Trap(anyhow!(
173
"write not permitted while flush pending"
174
)));
175
}
176
match state.write_budget.checked_sub(bytes.len()) {
177
Some(remaining_budget) => {
178
state.write_budget = remaining_budget;
179
state.items.push_back(bytes);
180
}
181
None => return Err(StreamError::Trap(anyhow!("write exceeded budget"))),
182
}
183
drop(state);
184
self.worker.new_work.notify_one();
185
Ok(())
186
}
187
fn flush(&mut self) -> Result<(), StreamError> {
188
let mut state = self.worker.state();
189
state.check_error()?;
190
191
state.flush_pending = true;
192
self.worker.new_work.notify_one();
193
194
Ok(())
195
}
196
197
fn check_write(&mut self) -> Result<usize, StreamError> {
198
self.worker.check_write()
199
}
200
201
async fn cancel(&mut self) {
202
match self.join_handle.take() {
203
Some(task) => _ = task.cancel().await,
204
None => {}
205
}
206
}
207
}
208
#[async_trait::async_trait]
209
impl Pollable for AsyncWriteStream {
210
async fn ready(&mut self) {
211
std::future::poll_fn(|cx| self.poll_ready(cx)).await
212
}
213
}
214
215