use crate::{bindings::http::types, types::FieldMap};
use anyhow::anyhow;
use bytes::Bytes;
use http_body::{Body, Frame};
use http_body_util::BodyExt;
use http_body_util::combinators::BoxBody;
use std::future::Future;
use std::mem;
use std::task::{Context, Poll};
use std::{pin::Pin, sync::Arc, time::Duration};
use tokio::sync::{mpsc, oneshot};
use wasmtime_wasi::p2::{InputStream, OutputStream, Pollable, StreamError};
use wasmtime_wasi::runtime::{AbortOnDropJoinHandle, poll_noop};
/// Boxed body type accepted by `HostIncomingBody::new`: yields `Bytes` frames
/// and reports failures as wasi-http `ErrorCode`s.
pub type HyperIncomingBody = BoxBody<Bytes, types::ErrorCode>;
/// Boxed body type returned by `HostOutgoingBody::new`; same frame and error
/// types as [`HyperIncomingBody`].
pub type HyperOutgoingBody = BoxBody<Bytes, types::ErrorCode>;
/// Host-side handle for an incoming HTTP body.
///
/// The body is either still owned here or has been handed to a
/// [`HostIncomingBodyStream`] by `take_stream`; see [`IncomingBodyState`].
#[derive(Debug)]
pub struct HostIncomingBody {
    /// Where the underlying body currently lives.
    body: IncomingBodyState,
    /// Task handle kept alive together with this body (set once via
    /// `retain_worker`). NOTE(review): presumably the task driving the
    /// connection this body comes from — confirm against callers.
    worker: Option<AbortOnDropJoinHandle<()>>,
}
impl HostIncomingBody {
    /// Wraps `body` so that waiting longer than `between_bytes_timeout` for a
    /// single frame fails with `ErrorCode::ConnectionReadTimeout`.
    pub fn new(body: HyperIncomingBody, between_bytes_timeout: Duration) -> HostIncomingBody {
        let timed = BodyWithTimeout::new(body, between_bytes_timeout);
        HostIncomingBody {
            worker: None,
            body: IncomingBodyState::Start(timed),
        }
    }

    /// Keeps `worker` alive for as long as this body exists.
    ///
    /// # Panics
    ///
    /// Panics if a worker was already retained.
    pub fn retain_worker(&mut self, worker: AbortOnDropJoinHandle<()>) {
        assert!(self.worker.is_none());
        self.worker = Some(worker);
    }

    /// Hands the body to a new [`HostIncomingBodyStream`].
    ///
    /// Returns `None` if a stream was already taken. Afterwards this handle only
    /// holds the receiving end of a oneshot channel on which the stream reports
    /// how it finished with the body.
    pub fn take_stream(&mut self) -> Option<HostIncomingBodyStream> {
        if !matches!(self.body, IncomingBodyState::Start(_)) {
            return None;
        }

        let (end_tx, end_rx) = oneshot::channel();
        let previous = mem::replace(&mut self.body, IncomingBodyState::InBodyStream(end_rx));
        let body = if let IncomingBodyState::Start(body) = previous {
            body
        } else {
            unreachable!()
        };

        Some(HostIncomingBodyStream {
            state: IncomingBodyStreamState::Open { body, tx: end_tx },
            buffer: Bytes::new(),
            error: None,
        })
    }

    /// Turns this body into a pending trailers future.
    pub fn into_future_trailers(self) -> HostFutureTrailers {
        HostFutureTrailers::Waiting(self)
    }
}
/// Where an incoming body's frames are currently read from.
#[derive(Debug)]
enum IncomingBodyState {
    /// The body is owned directly and has not been handed to a stream.
    Start(BodyWithTimeout),
    /// The body was handed to a [`HostIncomingBodyStream`]. The receiver gets a
    /// [`StreamEnd`] from that stream, or an error if the stream drops its
    /// sender without sending (which happens after end-of-body or a body error).
    InBodyStream(oneshot::Receiver<StreamEnd>),
}
/// Wrapper around an incoming body that fails with
/// `ErrorCode::ConnectionReadTimeout` when, after it starts waiting for a
/// frame, none is produced within `between_bytes_timeout`.
#[derive(Debug)]
struct BodyWithTimeout {
    /// The wrapped body.
    inner: HyperIncomingBody,
    /// Per-frame timer; re-armed lazily from `poll_frame`.
    timeout: Pin<Box<tokio::time::Sleep>>,
    /// When `true`, the next poll re-arms `timeout`. Initially `true`, and set
    /// again whenever a poll of `inner` completes.
    reset_sleep: bool,
    /// Maximum time allowed to wait for each frame.
    between_bytes_timeout: Duration,
}
impl BodyWithTimeout {
    /// Wraps `inner` with a per-frame timeout of `between_bytes_timeout`.
    ///
    /// The sleep starts with a zero duration (created through
    /// `with_ambient_tokio_runtime`). It is re-armed on the first poll because
    /// `reset_sleep` starts out `true`.
    fn new(inner: HyperIncomingBody, between_bytes_timeout: Duration) -> BodyWithTimeout {
        let sleep = wasmtime_wasi::runtime::with_ambient_tokio_runtime(|| {
            tokio::time::sleep(Duration::ZERO)
        });
        BodyWithTimeout {
            inner,
            timeout: Box::pin(sleep),
            reset_sleep: true,
            between_bytes_timeout,
        }
    }
}
impl Body for BodyWithTimeout {
    type Data = Bytes;
    type Error = types::ErrorCode;

    /// Polls the wrapped body, failing with `ConnectionReadTimeout` once the
    /// per-frame timer has expired.
    ///
    /// The timer is re-armed lazily: the first poll after a completed poll of
    /// the inner body resets it to `now + between_bytes_timeout`.
    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Bytes>, types::ErrorCode>>> {
        let this = self.get_mut();

        if this.reset_sleep {
            let deadline = tokio::time::Instant::now() + this.between_bytes_timeout;
            this.timeout.as_mut().reset(deadline);
            this.reset_sleep = false;
        }

        // Polling the timer also registers `cx` so the task wakes on expiry.
        if this.timeout.as_mut().poll(cx).is_ready() {
            return Poll::Ready(Some(Err(types::ErrorCode::ConnectionReadTimeout)));
        }

        let polled = Pin::new(&mut this.inner).poll_frame(cx);
        // Any completed poll (frame, error, or end) re-arms the timer next time.
        this.reset_sleep = polled.is_ready();
        polled
    }
}
/// Message a [`HostIncomingBodyStream`] sends back to the owning
/// [`HostIncomingBody`] when it is done with the body.
#[derive(Debug)]
enum StreamEnd {
    /// The stream was dropped while the body was still open; ownership of the
    /// unread remainder is returned.
    Remaining(BodyWithTimeout),
    /// The stream read a trailers frame; these are the trailers.
    Trailers(Option<FieldMap>),
}
/// `InputStream` that reads an incoming body frame by frame.
#[derive(Debug)]
pub struct HostIncomingBodyStream {
    /// The body while it is still being read, plus the channel back to the
    /// owning `HostIncomingBody`.
    state: IncomingBodyStreamState,
    /// Data from the most recent data frame not yet returned by `read`.
    buffer: Bytes,
    /// Body error to be returned by the next `read`.
    error: Option<anyhow::Error>,
}
impl HostIncomingBodyStream {
    /// Applies the outcome of polling the body for one frame:
    ///
    /// * end of body: the stream closes, dropping the body and the sender;
    /// * error: stored for the next `read`, and the stream closes;
    /// * data frame: stored in `buffer`, which must be empty beforehand;
    /// * trailers frame: the stream closes and the trailers are sent back to
    ///   the owning `HostIncomingBody` (a failed send is ignored).
    fn record_frame(&mut self, frame: Option<Result<Frame<Bytes>, types::ErrorCode>>) {
        let frame = match frame {
            None => {
                self.state = IncomingBodyStreamState::Closed;
                return;
            }
            Some(Err(err)) => {
                self.error = Some(err.into());
                self.state = IncomingBodyStreamState::Closed;
                return;
            }
            Some(Ok(frame)) => frame,
        };

        let trailers = match frame.into_data() {
            Ok(data) => {
                assert!(self.buffer.is_empty());
                self.buffer = data;
                return;
            }
            Err(not_data) => not_data.into_trailers().unwrap(),
        };

        // Matching on the temporary drops the body at the end of this statement,
        // before the trailers are sent.
        let sender = match mem::replace(&mut self.state, IncomingBodyStreamState::Closed) {
            IncomingBodyStreamState::Open { tx, .. } => tx,
            IncomingBodyStreamState::Closed => unreachable!(),
        };
        // A send failure only means nobody is waiting for the trailers.
        let _ = sender.send(StreamEnd::Trailers(Some(trailers)));
    }
}
/// Read state of a [`HostIncomingBodyStream`].
#[derive(Debug)]
enum IncomingBodyStreamState {
    /// Still reading `body`. `tx` reports back to the owning `HostIncomingBody`:
    /// trailers from `record_frame`, or the remaining body when the stream drops.
    Open {
        body: BodyWithTimeout,
        tx: oneshot::Sender<StreamEnd>,
    },
    /// The body ended, failed, or produced trailers; nothing more to read.
    Closed,
}
#[async_trait::async_trait]
impl InputStream for HostIncomingBodyStream {
    /// Non-blocking read of up to `size` bytes.
    ///
    /// Checks, in order:
    /// 1. Buffered data: returns at most `size` bytes of it.
    /// 2. A previously recorded body error.
    /// 3. Body no longer open: returns `StreamError::Closed`.
    ///
    /// If none of these apply, the body is polled once through `poll_noop`. A
    /// ready frame is recorded and the checks repeat. If nothing is ready, an
    /// empty chunk is returned.
    fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
        loop {
            if !self.buffer.is_empty() {
                let take = self.buffer.len().min(size);
                return Ok(self.buffer.split_to(take));
            }

            if let Some(err) = self.error.take() {
                return Err(StreamError::LastOperationFailed(err));
            }

            let body = if let IncomingBodyStreamState::Open { body, .. } = &mut self.state {
                body
            } else {
                return Err(StreamError::Closed);
            };

            let next_frame = body.frame();
            futures::pin_mut!(next_frame);
            match poll_noop(next_frame) {
                None => return Ok(Bytes::new()),
                Some(frame) => self.record_frame(frame),
            }
        }
    }
}
#[async_trait::async_trait]
impl Pollable for HostIncomingBodyStream {
    /// Returns immediately if data or an error is already buffered, or if the
    /// stream is closed. Otherwise waits for one frame from the body and
    /// records it with `record_frame`.
    async fn ready(&mut self) {
        let has_pending = !self.buffer.is_empty() || self.error.is_some();
        if has_pending {
            return;
        }

        let body = match &mut self.state {
            IncomingBodyStreamState::Open { body, .. } => body,
            IncomingBodyStreamState::Closed => return,
        };
        let frame = body.frame().await;
        self.record_frame(frame);
    }
}
impl Drop for HostIncomingBodyStream {
    /// If the body is still open, gives it back to the owning
    /// `HostIncomingBody` so its trailers can still be read from there.
    fn drop(&mut self) {
        match mem::replace(&mut self.state, IncomingBodyStreamState::Closed) {
            IncomingBodyStreamState::Open { body, tx } => {
                // The receiver may already be gone; nothing to do in that case.
                let _ = tx.send(StreamEnd::Remaining(body));
            }
            IncomingBodyStreamState::Closed => {}
        }
    }
}
/// State of a pending read of an incoming body's trailers.
#[derive(Debug)]
pub enum HostFutureTrailers {
    /// Not resolved yet; holds the body (or the channel from its stream).
    Waiting(HostIncomingBody),
    /// Resolved: trailers if present, `None` if there were none, or the body's
    /// error.
    Done(Result<Option<FieldMap>, types::ErrorCode>),
    /// NOTE(review): not set anywhere in this chunk; presumably marks that the
    /// `Done` result has been taken by the caller — confirm.
    Consumed,
}
#[async_trait::async_trait]
impl Pollable for HostFutureTrailers {
    /// Drives the trailers read until this value becomes `Done`.
    ///
    /// If a `HostIncomingBodyStream` currently owns the body, this first waits
    /// for that stream's `StreamEnd` message. If the body is (or becomes)
    /// owned here, it then reads frames until trailers, end-of-body, or an
    /// error, discarding data frames.
    async fn ready(&mut self) {
        // Nothing to do once a result exists or has been taken.
        let body = match self {
            HostFutureTrailers::Waiting(body) => body,
            HostFutureTrailers::Done(_) => return,
            HostFutureTrailers::Consumed => return,
        };
        // A stream owns the body: wait for it to report how it finished.
        if let IncomingBodyState::InBodyStream(rx) = &mut body.body {
            match rx.await {
                // The stream already read the trailers.
                Ok(StreamEnd::Trailers(t)) => *self = Self::Done(Ok(t)),
                // The stream was dropped early and returned the rest of the body.
                Ok(StreamEnd::Remaining(b)) => body.body = IncomingBodyState::Start(b),
                // The sender was dropped without a message. `record_frame` does this
                // on end-of-body and also on a body error; both become "no trailers".
                // NOTE(review): the body error is not propagated here — confirm intended.
                Err(_) => {
                    *self = HostFutureTrailers::Done(Ok(None));
                }
            }
        }
        // Re-inspect: either a result was stored above, or the body is owned here.
        let body = match self {
            HostFutureTrailers::Waiting(body) => body,
            HostFutureTrailers::Done(_) => return,
            HostFutureTrailers::Consumed => return,
        };
        let hyper_body = match &mut body.body {
            IncomingBodyState::Start(body) => body,
            // `InBodyStream` was resolved above into either `Done` or `Start`.
            IncomingBodyState::InBodyStream(_) => unreachable!(),
        };
        // Read until trailers, end-of-body, or an error; data frames are skipped.
        let result = loop {
            match hyper_body.frame().await {
                None => break Ok(None),
                Some(Err(e)) => break Err(e),
                Some(Ok(frame)) => {
                    if let Ok(headers) = frame.into_trailers() {
                        break Ok(Some(headers));
                    }
                }
            }
        };
        *self = HostFutureTrailers::Done(result);
    }
}
/// Byte counter for an outgoing body with a declared size.
///
/// Clones share the same counter through the `Arc`.
#[derive(Debug, Clone)]
struct WrittenState {
    /// Number of bytes the body is expected to contain.
    expected: u64,
    /// Shared running total of bytes recorded via `update`.
    written: Arc<std::sync::atomic::AtomicU64>,
}
impl WrittenState {
fn new(expected_size: u64) -> Self {
Self {
expected: expected_size,
written: Arc::new(std::sync::atomic::AtomicU64::new(0)),
}
}
fn written(&self) -> u64 {
self.written.load(std::sync::atomic::Ordering::Relaxed)
}
fn update(&self, len: usize) -> bool {
let len = len as u64;
let old = self
.written
.fetch_add(len, std::sync::atomic::Ordering::Relaxed);
old + len <= self.expected
}
}
/// Host-side handle for an outgoing HTTP body that is written through an
/// `OutputStream` and completed with `finish` or `abort`.
pub struct HostOutgoingBody {
    /// The stream that writes body data; `None` once taken by
    /// `take_output_stream`.
    body_output_stream: Option<Box<dyn OutputStream>>,
    /// Request vs. response, which selects the body-size error variant.
    context: StreamContext,
    /// Shared byte counter, present when `new` was given an expected size.
    written: Option<WrittenState>,
    /// Sends the single `FinishMessage` that ends the body; taken by
    /// `finish`/`abort`.
    finish_sender: Option<tokio::sync::oneshot::Sender<FinishMessage>>,
}
impl HostOutgoingBody {
    /// Creates an outgoing body handle plus the body that yields its contents.
    ///
    /// * `context`: request or response, used for body-size errors.
    /// * `size`: if given, the exact byte count the body must contain. It is
    ///   checked by the output stream's writes and by `finish`.
    /// * `buffer_chunks`: sizes the data channel, which gets
    ///   `buffer_chunks + 1` slots.
    /// * `chunk_size`: the write budget reported by the output stream.
    ///
    /// # Panics
    ///
    /// Panics if `buffer_chunks` is zero.
    pub fn new(
        context: StreamContext,
        size: Option<u64>,
        buffer_chunks: usize,
        chunk_size: usize,
    ) -> (Self, HyperOutgoingBody) {
        assert!(buffer_chunks >= 1);
        let written = size.map(WrittenState::new);
        use tokio::sync::oneshot::error::RecvError;
        // Body fed by the data channel. Once every data sender is dropped, it
        // ends according to the `FinishMessage` from `finish`/`abort`.
        struct BodyImpl {
            // Data chunks written through the output stream.
            body_receiver: mpsc::Receiver<Bytes>,
            // Consulted after the data channel closes; `None` once it resolved.
            finish_receiver: Option<oneshot::Receiver<FinishMessage>>,
        }
        impl Body for BodyImpl {
            type Data = Bytes;
            type Error = types::ErrorCode;
            fn poll_frame(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
                match self.as_mut().body_receiver.poll_recv(cx) {
                    Poll::Pending => Poll::Pending,
                    Poll::Ready(Some(frame)) => Poll::Ready(Some(Ok(Frame::data(frame)))),
                    // All data senders are gone: the finish message decides how we end.
                    Poll::Ready(None) => {
                        if let Some(mut finish_receiver) = self.as_mut().finish_receiver.take() {
                            match Pin::new(&mut finish_receiver).poll(cx) {
                                // Not decided yet: put the receiver back and wait.
                                Poll::Pending => {
                                    self.as_mut().finish_receiver = Some(finish_receiver);
                                    Poll::Pending
                                }
                                Poll::Ready(Ok(message)) => match message {
                                    FinishMessage::Finished => Poll::Ready(None),
                                    FinishMessage::Trailers(trailers) => {
                                        Poll::Ready(Some(Ok(Frame::trailers(trailers))))
                                    }
                                    FinishMessage::Abort => {
                                        Poll::Ready(Some(Err(types::ErrorCode::HttpProtocolError)))
                                    }
                                },
                                // Sender dropped without a message (e.g. the handle was dropped
                                // without `finish`/`abort`): end the body normally.
                                Poll::Ready(Err(RecvError { .. })) => Poll::Ready(None),
                            }
                        } else {
                            // The finish message was already consumed.
                            Poll::Ready(None)
                        }
                    }
                }
            }
        }
        let (body_sender, body_receiver) = mpsc::channel(buffer_chunks + 1);
        let (finish_sender, finish_receiver) = oneshot::channel();
        let body_impl = BodyImpl {
            body_receiver,
            finish_receiver: Some(finish_receiver),
        }
        .boxed();
        // The stream owns the data sender and shares the byte counter with us.
        let output_stream = BodyWriteStream::new(context, chunk_size, body_sender, written.clone());
        (
            Self {
                body_output_stream: Some(Box::new(output_stream)),
                context,
                written,
                finish_sender: Some(finish_sender),
            },
            body_impl,
        )
    }
    /// Takes the body's output stream; `None` if it was already taken.
    pub fn take_output_stream(&mut self) -> Option<Box<dyn OutputStream>> {
        self.body_output_stream.take()
    }
    /// Completes the body, ending it with `trailers` if given.
    ///
    /// If an expected size was set and the recorded byte count differs, the
    /// body is aborted instead.
    ///
    /// # Errors
    ///
    /// Returns the context's body-size error, carrying the recorded byte count,
    /// when it does not match the expected size.
    ///
    /// # Panics
    ///
    /// Panics if `finish_sender` is `None`.
    pub fn finish(mut self, trailers: Option<FieldMap>) -> Result<(), types::ErrorCode> {
        // Release our copy of the stream (if never taken) so its data sender is
        // dropped and the body can observe the channel closing.
        drop(self.body_output_stream);
        let sender = self
            .finish_sender
            .take()
            .expect("outgoing-body trailer_sender consumed by a non-owning function");
        if let Some(w) = self.written {
            let written = w.written();
            if written != w.expected {
                let _ = sender.send(FinishMessage::Abort);
                return Err(self.context.as_body_size_error(written));
            }
        }
        let message = if let Some(ts) = trailers {
            FinishMessage::Trailers(ts)
        } else {
            FinishMessage::Finished
        };
        // A send failure means the body was already dropped; nothing to report.
        let _ = sender.send(message);
        Ok(())
    }
    /// Aborts the body, which then yields `ErrorCode::HttpProtocolError`.
    ///
    /// # Panics
    ///
    /// Panics if `finish_sender` is `None`.
    pub fn abort(mut self) {
        drop(self.body_output_stream);
        let sender = self
            .finish_sender
            .take()
            .expect("outgoing-body trailer_sender consumed by a non-owning function");
        let _ = sender.send(FinishMessage::Abort);
    }
}
/// How an outgoing body ends, sent by `HostOutgoingBody::finish`/`abort` to
/// the body created in `HostOutgoingBody::new`.
#[derive(Debug)]
enum FinishMessage {
    /// End the body without trailers.
    Finished,
    /// End the body with this trailers frame.
    Trailers(hyper::HeaderMap),
    /// Fail the body with `ErrorCode::HttpProtocolError`.
    Abort,
}
/// Whether a body belongs to a request or a response; selects which
/// body-size error variant is reported.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StreamContext {
    /// The body of an HTTP request.
    Request,
    /// The body of an HTTP response.
    Response,
}
impl StreamContext {
    /// Builds the body-size error for this context, carrying `size`.
    ///
    /// Request bodies map to `HttpRequestBodySize`, response bodies to
    /// `HttpResponseBodySize`.
    pub fn as_body_size_error(&self, size: u64) -> types::ErrorCode {
        let size = Some(size);
        match *self {
            Self::Request => types::ErrorCode::HttpRequestBodySize(size),
            Self::Response => types::ErrorCode::HttpResponseBodySize(size),
        }
    }
}
/// `OutputStream` that forwards written chunks into an outgoing body's data
/// channel.
#[derive(Debug)]
struct BodyWriteStream {
    /// Request vs. response, for body-size errors.
    context: StreamContext,
    /// Sending half of the data channel read by the outgoing body.
    writer: mpsc::Sender<Bytes>,
    /// Byte count `check_write` reports while the channel has a free slot.
    write_budget: usize,
    /// Shared byte counter, present when an expected size is enforced.
    written: Option<WrittenState>,
}
impl BodyWriteStream {
    /// Creates a stream that forwards writes into `writer`.
    ///
    /// `write_budget` is what `check_write` reports while the channel has room.
    /// `written`, if present, counts bytes against an expected size.
    ///
    /// # Panics
    ///
    /// Panics if `writer.max_capacity()` is zero.
    fn new(
        context: StreamContext,
        write_budget: usize,
        writer: mpsc::Sender<Bytes>,
        written: Option<WrittenState>,
    ) -> Self {
        assert!(writer.max_capacity() >= 1);
        Self {
            writer,
            written,
            context,
            write_budget,
        }
    }
}
#[async_trait::async_trait]
impl OutputStream for BodyWriteStream {
    /// Queues `bytes` as the next data frame of the body.
    ///
    /// When an expected size is tracked, a write that would take the total past
    /// it is rejected *before* anything is queued, so bytes beyond the declared
    /// size never reach the HTTP implementation. The attempted length is still
    /// added to the counter. As a result, `HostOutgoingBody::finish` also fails,
    /// and it reports the same total as this error.
    ///
    /// # Errors
    ///
    /// * `LastOperationFailed` with a body-size error if the write would exceed
    ///   the expected size.
    /// * `Trap` if the channel is full, i.e. the caller ignored `check_write`.
    /// * `Closed` if the receiving body has been dropped.
    fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> {
        let len = bytes.len();

        if let Some(written) = self.written.as_ref() {
            let attempted = written.written().saturating_add(len as u64);
            if attempted > written.expected {
                // Record the attempt so the final size check in `finish` fails too.
                written.update(len);
                let total = written.written();
                return Err(StreamError::LastOperationFailed(anyhow!(
                    self.context.as_body_size_error(total)
                )));
            }
        }

        match self.writer.try_send(bytes) {
            Ok(()) => {
                if let Some(written) = self.written.as_ref() {
                    // Within the limit by the check above.
                    written.update(len);
                }
                Ok(())
            }
            Err(mpsc::error::TrySendError::Full(_)) => {
                Err(StreamError::Trap(anyhow!("write exceeded budget")))
            }
            Err(mpsc::error::TrySendError::Closed(_)) => Err(StreamError::Closed),
        }
    }

    /// Returns `Closed` if the receiving body is gone; otherwise succeeds
    /// without doing anything else.
    fn flush(&mut self) -> Result<(), StreamError> {
        if self.writer.is_closed() {
            Err(StreamError::Closed)
        } else {
            Ok(())
        }
    }

    /// Reports how many bytes may be written now.
    ///
    /// Returns `Closed` if the body is gone, `0` if the channel has no free
    /// slot, and the configured write budget otherwise.
    fn check_write(&mut self) -> Result<usize, StreamError> {
        if self.writer.is_closed() {
            Err(StreamError::Closed)
        } else if self.writer.capacity() == 0 {
            Ok(0)
        } else {
            Ok(self.write_budget)
        }
    }
}
#[async_trait::async_trait]
impl Pollable for BodyWriteStream {
    /// Resolves once the data channel has a free slot or has been closed.
    ///
    /// The reserved slot is released right away. This only signals readiness;
    /// the actual send happens later in `write` via `try_send`.
    async fn ready(&mut self) {
        let reservation = self.writer.reserve().await;
        drop(reservation);
    }
}