// Path: blob/main/crates/test-programs/src/bin/api_proxy_streaming.rs
// 1693 views
use anyhow::{Result, anyhow, bail};1use futures::{Future, SinkExt, StreamExt, TryStreamExt, future, stream};2use test_programs::wasi::http::types::{3Fields, IncomingRequest, IncomingResponse, Method, OutgoingBody, OutgoingRequest,4OutgoingResponse, ResponseOutparam, Scheme,5};6use url::Url;78const MAX_CONCURRENCY: usize = 16;910struct Handler;1112test_programs::proxy::export!(Handler);1314impl test_programs::proxy::exports::wasi::http::incoming_handler::Guest for Handler {15fn handle(request: IncomingRequest, response_out: ResponseOutparam) {16executor::run(async move {17handle_request(request, response_out).await;18})19}20}2122async fn handle_request(request: IncomingRequest, response_out: ResponseOutparam) {23let headers = request.headers().entries();2425assert!(request.authority().is_some());2627match (request.method(), request.path_with_query().as_deref()) {28(Method::Get, Some("/hash-all")) => {29// Send outgoing GET requests to the specified URLs and stream the hashes of the response bodies as30// they arrive.3132let urls = headers.iter().filter_map(|(k, v)| {33(k == "url")34.then_some(v)35.and_then(|v| std::str::from_utf8(v).ok())36.and_then(|v| Url::parse(v).ok())37});3839let results = urls.map(|url| async move {40let result = hash(&url).await;41(url, result)42});4344let mut results = stream::iter(results).buffer_unordered(MAX_CONCURRENCY);4546let response = OutgoingResponse::new(47Fields::from_list(&[("content-type".to_string(), b"text/plain".to_vec())]).unwrap(),48);4950let mut body =51executor::outgoing_body(response.body().expect("response should be writable"));5253ResponseOutparam::set(response_out, Ok(response));5455while let Some((url, result)) = results.next().await {56let payload = match result {57Ok(hash) => format!("{url}: {hash}\n"),58Err(e) => format!("{url}: {e:?}\n"),59}60.into_bytes();6162if let Err(e) = body.send(payload).await {63eprintln!("Error sending payload: {e}");64}65}66}6768(Method::Post, Some("/echo")) => {69// Echo the request 
body without buffering it.7071let response = OutgoingResponse::new(72Fields::from_list(73&headers74.into_iter()75.filter_map(|(k, v)| (k == "content-type").then_some((k, v)))76.collect::<Vec<_>>(),77)78.unwrap(),79);8081let mut body =82executor::outgoing_body(response.body().expect("response should be writable"));8384ResponseOutparam::set(response_out, Ok(response));8586let mut stream =87executor::incoming_body(request.consume().expect("request should be readable"));8889while let Some(chunk) = stream.next().await {90match chunk {91Ok(chunk) => {92if let Err(e) = body.send(chunk).await {93eprintln!("Error sending body: {e}");94break;95}96}97Err(e) => {98eprintln!("Error receiving body: {e}");99break;100}101}102}103}104105(Method::Post, Some("/double-echo")) => {106// Pipe the request body to an outgoing request and stream the response back to the client.107108if let Some(url) = headers.iter().find_map(|(k, v)| {109(k == "url")110.then_some(v)111.and_then(|v| std::str::from_utf8(v).ok())112.and_then(|v| Url::parse(v).ok())113}) {114match double_echo(request, &url).await {115Ok((request_copy, response)) => {116let mut stream = executor::incoming_body(117response.consume().expect("response should be consumable"),118);119120let response = OutgoingResponse::new(121Fields::from_list(122&headers123.into_iter()124.filter_map(|(k, v)| (k == "content-type").then_some((k, v)))125.collect::<Vec<_>>(),126)127.unwrap(),128);129130let mut body = executor::outgoing_body(131response.body().expect("response should be writable"),132);133134ResponseOutparam::set(response_out, Ok(response));135136let response_copy = async move {137while let Some(chunk) = stream.next().await {138body.send(chunk?).await?;139}140Ok::<_, anyhow::Error>(())141};142143let (request_copy, response_copy) =144future::join(request_copy, response_copy).await;145if let Err(e) = request_copy.and(response_copy) {146eprintln!("error piping to and from {url}: {e}");147}148}149150Err(e) => {151eprintln!("Error sending 
outgoing request to {url}: {e}");152server_error(response_out);153}154}155} else {156bad_request(response_out);157}158}159160_ => method_not_allowed(response_out),161}162}163164async fn double_echo(165incoming_request: IncomingRequest,166url: &Url,167) -> Result<(impl Future<Output = Result<()>> + use<>, IncomingResponse)> {168let outgoing_request = OutgoingRequest::new(Fields::new());169170outgoing_request171.set_method(&Method::Post)172.map_err(|()| anyhow!("failed to set method"))?;173174outgoing_request175.set_path_with_query(Some(url.path()))176.map_err(|()| anyhow!("failed to set path_with_query"))?;177178outgoing_request179.set_scheme(Some(&match url.scheme() {180"http" => Scheme::Http,181"https" => Scheme::Https,182scheme => Scheme::Other(scheme.into()),183}))184.map_err(|()| anyhow!("failed to set scheme"))?;185186outgoing_request187.set_authority(Some(&format!(188"{}{}",189url.host_str().unwrap_or(""),190if let Some(port) = url.port() {191format!(":{port}")192} else {193String::new()194}195)))196.map_err(|()| anyhow!("failed to set authority"))?;197198let mut body = executor::outgoing_body(199outgoing_request200.body()201.expect("request body should be writable"),202);203204let response = executor::outgoing_request_send(outgoing_request);205206let mut stream = executor::incoming_body(207incoming_request208.consume()209.expect("request should be consumable"),210);211212let copy = async move {213while let Some(chunk) = stream.next().await {214body.send(chunk?).await?;215}216Ok::<_, anyhow::Error>(())217};218219let response = response.await?;220221let status = response.status();222223if !(200..300).contains(&status) {224bail!("unexpected status: {status}");225}226227Ok((copy, response))228}229230fn server_error(response_out: ResponseOutparam) {231respond(500, response_out)232}233234fn bad_request(response_out: ResponseOutparam) {235respond(400, response_out)236}237238fn method_not_allowed(response_out: ResponseOutparam) {239respond(405, 
response_out)240}241242fn respond(status: u16, response_out: ResponseOutparam) {243let response = OutgoingResponse::new(Fields::new());244response245.set_status_code(status)246.expect("setting status code");247248let body = response.body().expect("response should be writable");249250ResponseOutparam::set(response_out, Ok(response));251252OutgoingBody::finish(body, None).expect("outgoing-body.finish");253}254255async fn hash(url: &Url) -> Result<String> {256let request = OutgoingRequest::new(Fields::new());257258request259.set_path_with_query(Some(url.path()))260.map_err(|()| anyhow!("failed to set path_with_query"))?;261request262.set_scheme(Some(&match url.scheme() {263"http" => Scheme::Http,264"https" => Scheme::Https,265scheme => Scheme::Other(scheme.into()),266}))267.map_err(|()| anyhow!("failed to set scheme"))?;268request269.set_authority(Some(&format!(270"{}{}",271url.host_str().unwrap_or(""),272if let Some(port) = url.port() {273format!(":{port}")274} else {275String::new()276}277)))278.map_err(|()| anyhow!("failed to set authority"))?;279280let response = executor::outgoing_request_send(request).await?;281282let status = response.status();283284if !(200..300).contains(&status) {285bail!("unexpected status: {status}");286}287288let mut body =289executor::incoming_body(response.consume().expect("response should be readable"));290291use sha2::Digest;292let mut hasher = sha2::Sha256::new();293while let Some(chunk) = body.try_next().await? 
{294hasher.update(&chunk);295}296297use base64::Engine;298Ok(base64::engine::general_purpose::STANDARD_NO_PAD.encode(hasher.finalize()))299}300301// Technically this should not be here for a proxy, but given the current302// framework for tests it's required since this file is built as a `bin`303fn main() {}304305mod executor {306use anyhow::{Error, Result, anyhow};307use futures::{Sink, Stream, future, sink, stream};308use std::{309cell::RefCell,310future::Future,311mem,312rc::Rc,313sync::{Arc, Mutex},314task::{Context, Poll, Wake, Waker},315};316use test_programs::wasi::{317http::{318outgoing_handler,319types::{320self, FutureTrailers, IncomingBody, IncomingResponse, InputStream, OutgoingBody,321OutgoingRequest, OutputStream,322},323},324io::{self, streams::StreamError},325};326327const READ_SIZE: u64 = 16 * 1024;328329static WAKERS: Mutex<Vec<(io::poll::Pollable, Waker)>> = Mutex::new(Vec::new());330331pub fn run<T>(future: impl Future<Output = T>) -> T {332futures::pin_mut!(future);333334struct DummyWaker;335336impl Wake for DummyWaker {337fn wake(self: Arc<Self>) {}338}339340let waker = Arc::new(DummyWaker).into();341342loop {343match future.as_mut().poll(&mut Context::from_waker(&waker)) {344Poll::Pending => {345let mut new_wakers = Vec::new();346347let wakers = mem::take::<Vec<_>>(&mut WAKERS.lock().unwrap());348349assert!(!wakers.is_empty());350351let pollables = wakers352.iter()353.map(|(pollable, _)| pollable)354.collect::<Vec<_>>();355356let mut ready = vec![false; wakers.len()];357358for index in io::poll::poll(&pollables) {359ready[usize::try_from(index).unwrap()] = true;360}361362for (ready, (pollable, waker)) in ready.into_iter().zip(wakers) {363if ready {364waker.wake()365} else {366new_wakers.push((pollable, waker));367}368}369370*WAKERS.lock().unwrap() = new_wakers;371}372Poll::Ready(result) => break result,373}374}375}376377pub fn outgoing_body(body: OutgoingBody) -> impl Sink<Vec<u8>, Error = Error> {378struct Outgoing(Option<(OutputStream, 
OutgoingBody)>);379380impl Drop for Outgoing {381fn drop(&mut self) {382if let Some((stream, body)) = self.0.take() {383drop(stream);384OutgoingBody::finish(body, None).expect("outgoing-body.finish");385}386}387}388389let stream = body.write().expect("response body should be writable");390let pair = Rc::new(RefCell::new(Outgoing(Some((stream, body)))));391392sink::unfold((), {393move |(), chunk: Vec<u8>| {394future::poll_fn({395let mut offset = 0;396let mut flushing = false;397let pair = pair.clone();398399move |context| {400let pair = pair.borrow();401let (stream, _) = &pair.0.as_ref().unwrap();402403loop {404match stream.check_write() {405Ok(0) => {406WAKERS407.lock()408.unwrap()409.push((stream.subscribe(), context.waker().clone()));410411break Poll::Pending;412}413Ok(count) => {414if offset == chunk.len() {415if flushing {416break Poll::Ready(Ok(()));417} else {418stream.flush().expect("stream should be flushable");419flushing = true;420}421} else {422let count = usize::try_from(count)423.unwrap()424.min(chunk.len() - offset);425426match stream.write(&chunk[offset..][..count]) {427Ok(()) => {428offset += count;429}430Err(_) => break Poll::Ready(Err(anyhow!("I/O error"))),431}432}433}434Err(_) => break Poll::Ready(Err(anyhow!("I/O error"))),435}436}437}438})439}440})441}442443pub fn outgoing_request_send(444request: OutgoingRequest,445) -> impl Future<Output = Result<IncomingResponse, types::ErrorCode>> {446future::poll_fn({447let response = outgoing_handler::handle(request, None);448449move |context| match &response {450Ok(response) => {451if let Some(response) = response.get() {452Poll::Ready(response.unwrap())453} else {454WAKERS455.lock()456.unwrap()457.push((response.subscribe(), context.waker().clone()));458Poll::Pending459}460}461Err(error) => Poll::Ready(Err(error.clone())),462}463})464}465466pub fn incoming_body(body: IncomingBody) -> impl Stream<Item = Result<Vec<u8>>> {467enum Inner {468Stream {469stream: InputStream,470body: 
IncomingBody,471},472Trailers(FutureTrailers),473Closed,474}475476struct Incoming(Inner);477478impl Drop for Incoming {479fn drop(&mut self) {480match mem::replace(&mut self.0, Inner::Closed) {481Inner::Stream { stream, body } => {482drop(stream);483IncomingBody::finish(body);484}485Inner::Trailers(_) | Inner::Closed => {}486}487}488}489490stream::poll_fn({491let stream = body.stream().expect("response body should be readable");492let mut incoming = Incoming(Inner::Stream { stream, body });493494move |context| {495loop {496match &incoming.0 {497Inner::Stream { stream, .. } => match stream.read(READ_SIZE) {498Ok(buffer) => {499return if buffer.is_empty() {500WAKERS501.lock()502.unwrap()503.push((stream.subscribe(), context.waker().clone()));504Poll::Pending505} else {506Poll::Ready(Some(Ok(buffer)))507};508}509Err(StreamError::Closed) => {510let Inner::Stream { stream, body } =511mem::replace(&mut incoming.0, Inner::Closed)512else {513unreachable!();514};515drop(stream);516incoming.0 = Inner::Trailers(IncomingBody::finish(body));517}518Err(StreamError::LastOperationFailed(error)) => {519return Poll::Ready(Some(Err(anyhow!(520"{}",521error.to_debug_string()522))));523}524},525526Inner::Trailers(trailers) => {527match trailers.get() {528Some(Ok(trailers)) => {529incoming.0 = Inner::Closed;530match trailers {531Ok(Some(_)) => {532// Currently, we just ignore any trailers. TODO: Add a test that533// expects trailers and verify they match the expected contents.534}535Ok(None) => {536// No trailers; nothing else to do.537}538Err(error) => {539// Error reading the trailers: pass it on to the application.540return Poll::Ready(Some(Err(anyhow!("{error:?}"))));541}542}543}544Some(Err(_)) => {545// Should only happen if we try to retrieve the trailers twice, i.e. 
a bug in546// this code.547unreachable!();548}549None => {550WAKERS551.lock()552.unwrap()553.push((trailers.subscribe(), context.waker().clone()));554return Poll::Pending;555}556}557}558559Inner::Closed => {560return Poll::Ready(None);561}562}563}564}565})566}567}568569570