use anyhow::anyhow;
use bytes::Bytes;
use std::pin::{Pin, pin};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use tokio::io::{self, AsyncRead, AsyncWrite};
use tokio::sync::mpsc;
use wasmtime_wasi_io::{
poll::Pollable,
streams::{InputStream, OutputStream, StreamError},
};
pub use crate::p2::write_stream::AsyncWriteStream;
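/// An in-memory [`InputStream`]: reads drain a buffer supplied at construction,
/// and the stream reports [`StreamError::Closed`] once that buffer is empty.
/// Clones share the same underlying buffer.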
#[derive(Debug, Clone)]
pub struct MemoryInputPipe {
buffer: Arc<Mutex<Bytes>>,
}
impl MemoryInputPipe {
pub fn new(bytes: impl Into<Bytes>) -> Self {
Self {
buffer: Arc::new(Mutex::new(bytes.into())),
}
}
pub fn is_empty(&self) -> bool {
self.buffer.lock().unwrap().is_empty()
}
}
#[async_trait::async_trait]
impl InputStream for MemoryInputPipe {
fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
let mut buffer = self.buffer.lock().unwrap();
if buffer.is_empty() {
return Err(StreamError::Closed);
}
let size = size.min(buffer.len());
let read = buffer.split_to(size);
Ok(read)
}
}
#[async_trait::async_trait]
impl Pollable for MemoryInputPipe {
async fn ready(&mut self) {}
}
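// `AsyncRead` view over the same shared buffer. Once the buffer is drained, a
// read fills zero bytes, which tokio interprets as end-of-file.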
impl AsyncRead for MemoryInputPipe {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut io::ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let mut buffer = self.buffer.lock().unwrap();
let size = buf.remaining().min(buffer.len());
let read = buffer.split_to(size);
buf.put_slice(&read);
Poll::Ready(Ok(()))
}
}
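/// An in-memory [`OutputStream`] with a fixed capacity. Writes append to a shared
/// buffer and trap once they would exceed the capacity; the accumulated bytes can
/// be inspected with [`Self::contents`] or recovered with [`Self::try_into_inner`].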
#[derive(Debug, Clone)]
pub struct MemoryOutputPipe {
capacity: usize,
buffer: Arc<Mutex<bytes::BytesMut>>,
}
impl MemoryOutputPipe {
pub fn new(capacity: usize) -> Self {
MemoryOutputPipe {
capacity,
            buffer: Arc::new(Mutex::new(bytes::BytesMut::new())),
}
}
pub fn contents(&self) -> bytes::Bytes {
self.buffer.lock().unwrap().clone().freeze()
}
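    /// Consumes the pipe and returns the underlying buffer, or `None` if other
    /// clones still hold a reference to it.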
pub fn try_into_inner(self) -> Option<bytes::BytesMut> {
        Arc::into_inner(self.buffer).map(|m| m.into_inner().unwrap())
}
}
#[async_trait::async_trait]
impl OutputStream for MemoryOutputPipe {
fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> {
let mut buf = self.buffer.lock().unwrap();
if bytes.len() > self.capacity - buf.len() {
return Err(StreamError::Trap(anyhow!(
"write beyond capacity of MemoryOutputPipe"
)));
}
buf.extend_from_slice(bytes.as_ref());
Ok(())
}
fn flush(&mut self) -> Result<(), StreamError> {
Ok(())
}
fn check_write(&mut self) -> Result<usize, StreamError> {
let consumed = self.buffer.lock().unwrap().len();
if consumed < self.capacity {
Ok(self.capacity - consumed)
} else {
Err(StreamError::Closed)
}
}
}
#[async_trait::async_trait]
impl Pollable for MemoryOutputPipe {
async fn ready(&mut self) {}
}
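// `AsyncWrite` view over the same shared buffer. Writes are truncated to the
// remaining capacity rather than returning an error.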
impl AsyncWrite for MemoryOutputPipe {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let mut buffer = self.buffer.lock().unwrap();
let amt = buf.len().min(self.capacity - buffer.len());
buffer.extend_from_slice(&buf[..amt]);
Poll::Ready(Ok(amt))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
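/// Provides an [`InputStream`] for any [`AsyncRead`] by spawning a background task
/// that reads in 4096-byte chunks and forwards them over a bounded channel. The
/// task is aborted when the stream is dropped.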
pub struct AsyncReadStream {
closed: bool,
buffer: Option<Result<Bytes, StreamError>>,
receiver: mpsc::Receiver<Result<Bytes, StreamError>>,
join_handle: Option<crate::runtime::AbortOnDropJoinHandle<()>>,
}
impl AsyncReadStream {
pub fn new<T: AsyncRead + Send + 'static>(reader: T) -> Self {
let (sender, receiver) = mpsc::channel(1);
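        // Background task: read up to 4096 bytes at a time and forward each chunk
        // (or a close/error notification) through the channel. The capacity-1
        // channel provides backpressure; the task exits once the receiver is gone.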
let join_handle = crate::runtime::spawn(async move {
let mut reader = pin!(reader);
loop {
use tokio::io::AsyncReadExt;
let mut buf = bytes::BytesMut::with_capacity(4096);
let sent = match reader.read_buf(&mut buf).await {
                    Ok(0) => sender.send(Err(StreamError::Closed)).await,
Ok(_) => sender.send(Ok(buf.freeze())).await,
Err(e) => {
sender
.send(Err(StreamError::LastOperationFailed(e.into())))
.await
}
};
if sent.is_err() {
break;
}
}
});
AsyncReadStream {
closed: false,
buffer: None,
receiver,
join_handle: Some(join_handle),
}
}
pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
if self.buffer.is_some() || self.closed {
return Poll::Ready(());
}
match self.receiver.poll_recv(cx) {
Poll::Ready(Some(res)) => {
self.buffer = Some(res);
Poll::Ready(())
}
Poll::Ready(None) => {
panic!("no more sender for an open AsyncReadStream - should be impossible")
}
Poll::Pending => Poll::Pending,
}
}
}
#[async_trait::async_trait]
impl InputStream for AsyncReadStream {
fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
use mpsc::error::TryRecvError;
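        // Serve any chunk already buffered by a previous `ready()` poll before
        // consulting the channel.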
match self.buffer.take() {
Some(Ok(mut bytes)) => {
let len = bytes.len().min(size);
let rest = bytes.split_off(len);
if !rest.is_empty() {
self.buffer = Some(Ok(rest));
}
return Ok(bytes);
}
Some(Err(e)) => {
self.closed = true;
return Err(e);
}
None => {}
}
match self.receiver.try_recv() {
Ok(Ok(mut bytes)) => {
let len = bytes.len().min(size);
let rest = bytes.split_off(len);
if !rest.is_empty() {
self.buffer = Some(Ok(rest));
}
Ok(bytes)
}
Ok(Err(e)) => {
self.closed = true;
Err(e)
}
Err(TryRecvError::Empty) => Ok(Bytes::new()),
Err(TryRecvError::Disconnected) => Err(StreamError::Trap(anyhow!(
"AsyncReadStream sender died - should be impossible"
))),
}
}
async fn cancel(&mut self) {
match self.join_handle.take() {
Some(task) => _ = task.cancel().await,
None => {}
}
}
}
#[async_trait::async_trait]
impl Pollable for AsyncReadStream {
async fn ready(&mut self) {
std::future::poll_fn(|cx| self.poll_ready(cx)).await
}
}
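/// An [`OutputStream`] that discards everything written to it and always reports
/// unlimited write capacity.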
#[derive(Copy, Clone)]
pub struct SinkOutputStream;
#[async_trait::async_trait]
impl OutputStream for SinkOutputStream {
fn write(&mut self, _buf: Bytes) -> Result<(), StreamError> {
Ok(())
}
fn flush(&mut self) -> Result<(), StreamError> {
Ok(())
}
fn check_write(&mut self) -> Result<usize, StreamError> {
Ok(usize::MAX)
}
}
#[async_trait::async_trait]
impl Pollable for SinkOutputStream {
async fn ready(&mut self) {}
}
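/// An [`InputStream`] that is always ready and always reports [`StreamError::Closed`].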
#[derive(Copy, Clone)]
pub struct ClosedInputStream;
#[async_trait::async_trait]
impl InputStream for ClosedInputStream {
fn read(&mut self, _size: usize) -> Result<Bytes, StreamError> {
Err(StreamError::Closed)
}
}
#[async_trait::async_trait]
impl Pollable for ClosedInputStream {
async fn ready(&mut self) {}
}
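/// An [`OutputStream`] that rejects every operation with [`StreamError::Closed`].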
#[derive(Copy, Clone)]
pub struct ClosedOutputStream;
#[async_trait::async_trait]
impl OutputStream for ClosedOutputStream {
fn write(&mut self, _: Bytes) -> Result<(), StreamError> {
Err(StreamError::Closed)
}
fn flush(&mut self) -> Result<(), StreamError> {
Err(StreamError::Closed)
}
fn check_write(&mut self) -> Result<usize, StreamError> {
Err(StreamError::Closed)
}
}
#[async_trait::async_trait]
impl Pollable for ClosedOutputStream {
async fn ready(&mut self) {}
}
#[cfg(test)]
mod test {
use super::*;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[cfg(not(target_arch = "x86_64"))]
const TEST_ITERATIONS: usize = 10;
#[cfg(target_arch = "x86_64")]
const TEST_ITERATIONS: usize = 100;
async fn resolves_immediately<F, O>(fut: F) -> O
where
F: futures::Future<Output = O>,
{
tokio::time::timeout(Duration::from_secs(2), fut)
.await
.expect("operation timed out")
}
async fn never_resolves<F: futures::Future>(fut: F) {
tokio::time::timeout(Duration::from_millis(10), fut)
.await
.err()
.expect("operation should time out");
}
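    /// Builds a unidirectional in-memory pipe of the given capacity by discarding
    /// the unused halves of a `tokio::io::duplex` pair.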
pub fn simplex(size: usize) -> (impl AsyncRead, impl AsyncWrite) {
let (a, b) = tokio::io::duplex(size);
let (_read_half, write_half) = tokio::io::split(a);
let (read_half, _write_half) = tokio::io::split(b);
(read_half, write_half)
}
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn empty_read_stream() {
let mut reader = AsyncReadStream::new(tokio::io::empty());
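        // The background task may not have observed EOF yet, so the first read can
        // report either `Closed` or an empty chunk.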
match reader.read(10) {
Err(StreamError::Closed) => {}
Ok(bs) => {
assert!(bs.is_empty());
resolves_immediately(reader.ready()).await;
assert!(matches!(reader.read(0), Err(StreamError::Closed)));
}
res => panic!("unexpected: {res:?}"),
}
}
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn infinite_read_stream() {
let mut reader = AsyncReadStream::new(tokio::io::repeat(0));
let bs = reader.read(10).unwrap();
if bs.is_empty() {
resolves_immediately(reader.ready()).await;
let bs = reader.read(10).unwrap();
assert_eq!(bs.len(), 10);
} else {
assert_eq!(bs.len(), 10);
}
let bs = reader.read(10).unwrap();
assert_eq!(bs.len(), 10);
let bs = reader.read(0).unwrap();
assert_eq!(bs.len(), 0);
}
async fn finite_async_reader(contents: &[u8]) -> impl AsyncRead + Send + 'static + use<> {
let (r, mut w) = simplex(contents.len());
w.write_all(contents).await.unwrap();
r
}
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn finite_read_stream() {
let mut reader = AsyncReadStream::new(finite_async_reader(&[1; 123]).await);
let bs = reader.read(123).unwrap();
if bs.is_empty() {
resolves_immediately(reader.ready()).await;
let bs = reader.read(123).unwrap();
assert_eq!(bs.len(), 123);
} else {
assert_eq!(bs.len(), 123);
}
match reader.read(0) {
Err(StreamError::Closed) => {}
Ok(bs) => {
assert!(bs.is_empty());
resolves_immediately(reader.ready()).await;
assert!(matches!(reader.read(0), Err(StreamError::Closed)));
}
res => panic!("unexpected: {res:?}"),
}
}
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn multiple_chunks_read_stream() {
let (r, mut w) = simplex(1024);
let mut reader = AsyncReadStream::new(r);
w.write_all(&[123]).await.unwrap();
let bs = reader.read(1).unwrap();
if bs.is_empty() {
resolves_immediately(reader.ready()).await;
let bs = reader.read(1).unwrap();
assert_eq!(*bs, [123u8]);
} else {
assert_eq!(*bs, [123u8]);
}
let bs = reader.read(1).unwrap();
assert!(bs.is_empty());
never_resolves(reader.ready()).await;
let bs = reader.read(1).unwrap();
assert!(bs.is_empty());
w.write_all(&[45]).await.unwrap();
resolves_immediately(reader.ready()).await;
let bs = reader.read(1).unwrap();
assert_eq!(*bs, [45u8]);
let bs = reader.read(1).unwrap();
assert!(bs.is_empty());
never_resolves(reader.ready()).await;
let bs = reader.read(1).unwrap();
assert!(bs.is_empty());
drop(w);
resolves_immediately(reader.ready()).await;
assert!(matches!(reader.read(1), Err(StreamError::Closed)));
}
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn backpressure_read_stream() {
let (r, mut w) = simplex(16 * 1024);
let mut reader = AsyncReadStream::new(r);
let writer_task = tokio::task::spawn(async move {
w.write_all(&[123; 8192]).await.unwrap();
w
});
resolves_immediately(reader.ready()).await;
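        // AsyncReadStream's worker reads in 4096-byte chunks, so asking for 4097
        // bytes yields at most 4096.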
let bs = reader.read(4097).unwrap();
assert_eq!(bs.len(), 4096);
resolves_immediately(reader.ready()).await;
let bs = reader.read(4097).unwrap();
assert_eq!(bs.len(), 4096);
let w = resolves_immediately(writer_task).await;
drop(w);
resolves_immediately(reader.ready()).await;
assert!(matches!(reader.read(4097), Err(StreamError::Closed)));
}
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn sink_write_stream() {
let mut writer = AsyncWriteStream::new(2048, tokio::io::sink());
let chunk = Bytes::from_static(&[0; 1024]);
let readiness = resolves_immediately(writer.write_ready())
.await
.expect("write_ready does not trap");
assert_eq!(readiness, 2048);
writer.write(chunk.clone()).expect("write does not error");
let readiness = resolves_immediately(writer.write_ready())
.await
.expect("write_ready does not trap");
assert!(
readiness == 1024 || readiness == 2048,
"readiness should be 1024 or 2048, got {readiness}"
);
if readiness == 1024 {
writer.write(chunk.clone()).expect("write does not error");
let readiness = resolves_immediately(writer.write_ready())
.await
.expect("write_ready does not trap");
assert!(
readiness == 1024 || readiness == 2048,
"readiness should be 1024 or 2048, got {readiness}"
);
}
}
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn closed_write_stream() {
for n in 0..TEST_ITERATIONS {
closed_write_stream_(n).await
}
}
#[tracing::instrument]
async fn closed_write_stream_(n: usize) {
let (reader, writer) = simplex(1);
let mut writer = AsyncWriteStream::new(1024, writer);
drop(reader);
let mut should_be_closed = false;
let chunk = Bytes::from_static(&[0; 1]);
writer
.write(chunk.clone())
.expect("first write should succeed");
let mut write_ready_res = None;
if n % 2 == 0 {
let r = resolves_immediately(writer.write_ready()).await;
match r {
Ok(1023) => {}
Err(StreamError::LastOperationFailed(_)) => {
tracing::debug!("discovered stream failure in first write_ready");
should_be_closed = true;
}
r => panic!("unexpected write_ready: {r:?}"),
}
write_ready_res = Some(r);
}
let flush_res = writer.flush();
match flush_res {
Err(StreamError::LastOperationFailed(_)) => {
tracing::debug!("discovered stream failure trying to flush");
assert!(!should_be_closed);
should_be_closed = true;
}
Err(StreamError::Closed) => {
assert!(
should_be_closed,
"expected a LastOperationFailed before we see Closed. {write_ready_res:?}"
);
}
Ok(()) => {}
Err(e) => panic!("unexpected flush error: {e:?} {write_ready_res:?}"),
}
match resolves_immediately(writer.write_ready()).await {
Err(StreamError::LastOperationFailed(_)) => {
tracing::debug!("discovered stream failure trying to flush");
assert!(!should_be_closed);
}
Err(StreamError::Closed) => {
assert!(should_be_closed);
}
r => {
panic!(
"stream should be reported closed by the end of write_ready after flush, got {r:?}. {write_ready_res:?} {flush_res:?}"
)
}
}
}
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn multiple_chunks_write_stream() {
for n in 0..TEST_ITERATIONS {
multiple_chunks_write_stream_aux(n).await
}
}
#[tracing::instrument]
async fn multiple_chunks_write_stream_aux(_: usize) {
use std::ops::Deref;
let (mut reader, writer) = simplex(1024);
let mut writer = AsyncWriteStream::new(1024, writer);
let chunk = Bytes::from_static(&[123; 1]);
let permit = resolves_immediately(writer.write_ready())
.await
.expect("write should be ready");
assert_eq!(permit, 1024);
writer.write(chunk.clone()).expect("write does not trap");
let permit = resolves_immediately(writer.write_ready())
.await
.expect("write should be ready");
assert!(matches!(permit, 1023 | 1024));
let mut read_buf = vec![0; chunk.len()];
let read_len = reader.read_exact(&mut read_buf).await.unwrap();
assert_eq!(read_len, chunk.len());
assert_eq!(read_buf.as_slice(), chunk.deref());
let chunk2 = Bytes::from_static(&[45; 1]);
writer.flush().expect("channel is still alive");
let permit = resolves_immediately(writer.write_ready())
.await
.expect("write should be ready");
assert_eq!(permit, 1024);
writer.write(chunk2.clone()).expect("write does not trap");
let permit = resolves_immediately(writer.write_ready())
.await
.expect("write should be ready");
assert!(matches!(permit, 1023 | 1024));
let mut read2_buf = vec![0; chunk2.len()];
let read2_len = reader.read_exact(&mut read2_buf).await.unwrap();
assert_eq!(read2_len, chunk2.len());
assert_eq!(read2_buf.as_slice(), chunk2.deref());
writer.flush().expect("channel is still alive");
let permit = resolves_immediately(writer.write_ready())
.await
.expect("write should be ready");
assert_eq!(permit, 1024);
}
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn backpressure_write_stream() {
for n in 0..TEST_ITERATIONS {
backpressure_write_stream_aux(n).await
}
}
#[tracing::instrument]
async fn backpressure_write_stream_aux(_: usize) {
use futures::future::poll_immediate;
let (mut reader, writer) = simplex(1024);
let mut writer = AsyncWriteStream::new(1024, writer);
let chunk = Bytes::from_static(&[0; 1024]);
let permit = resolves_immediately(writer.write_ready())
.await
.expect("write should be ready");
assert_eq!(permit, 1024);
writer.write(chunk.clone()).expect("write succeeds");
let permit = poll_immediate(writer.write_ready()).await;
assert!(matches!(permit, None | Some(Ok(1024))));
let permit = resolves_immediately(writer.write_ready())
.await
.expect("write should be ready");
assert_eq!(permit, 1024);
writer.write(chunk.clone()).expect("write does not trap");
writer
.write(chunk.clone())
.err()
.expect("unpermitted write does trap");
never_resolves(writer.write_ready()).await;
let mut buf = [0; 2048];
reader.read_exact(&mut buf).await.unwrap();
never_resolves(reader.read(&mut buf)).await;
let permit = resolves_immediately(writer.write_ready())
.await
.expect("ready is ok");
assert_eq!(permit, 1024);
writer.write(chunk.clone()).expect("write does not trap");
}
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn backpressure_write_stream_with_flush() {
for n in 0..TEST_ITERATIONS {
backpressure_write_stream_with_flush_aux(n).await;
}
}
async fn backpressure_write_stream_with_flush_aux(_: usize) {
let (mut reader, writer) = simplex(1024);
let mut writer = AsyncWriteStream::new(1024, writer);
let chunk = Bytes::from_static(&[0; 1024]);
let permit = resolves_immediately(writer.write_ready())
.await
.expect("write should be ready");
assert_eq!(permit, 1024);
writer.write(chunk.clone()).expect("write succeeds");
writer.flush().expect("flush succeeds");
let permit = resolves_immediately(writer.write_ready())
.await
.expect("write_ready succeeds");
assert_eq!(permit, 1024);
writer.write(chunk.clone()).expect("write does not trap");
writer.flush().expect("flush succeeds");
writer
.write(chunk.clone())
.err()
.expect("unpermitted write does trap");
never_resolves(writer.write_ready()).await;
let mut buf = [0; 2048];
reader.read_exact(&mut buf).await.unwrap();
never_resolves(reader.read(&mut buf)).await;
let permit = resolves_immediately(writer.write_ready())
.await
.expect("ready is ok");
assert_eq!(permit, 1024);
writer.write(chunk.clone()).expect("write does not trap");
writer.flush().expect("flush succeeds");
let permit = resolves_immediately(writer.write_ready())
.await
.expect("ready is ok");
assert_eq!(permit, 1024);
}
}