Path: blob/main/crates/wasi-common/src/sync/net.rs
1693 views
use crate::{1Error, ErrorExt,2file::{FdFlags, FileType, RiFlags, RoFlags, SdFlags, SiFlags, WasiFile},3};4#[cfg(windows)]5use io_extras::os::windows::{AsRawHandleOrSocket, RawHandleOrSocket};6use io_lifetimes::AsSocketlike;7#[cfg(unix)]8use io_lifetimes::{AsFd, BorrowedFd};9#[cfg(windows)]10use io_lifetimes::{AsSocket, BorrowedSocket};11use std::any::Any;12use std::io;13#[cfg(unix)]14use system_interface::fs::GetSetFdFlags;15use system_interface::io::IoExt;16use system_interface::io::IsReadWrite;17use system_interface::io::ReadReady;1819pub enum Socket {20TcpListener(cap_std::net::TcpListener),21TcpStream(cap_std::net::TcpStream),22#[cfg(unix)]23UnixStream(cap_std::os::unix::net::UnixStream),24#[cfg(unix)]25UnixListener(cap_std::os::unix::net::UnixListener),26}2728impl From<cap_std::net::TcpListener> for Socket {29fn from(listener: cap_std::net::TcpListener) -> Self {30Self::TcpListener(listener)31}32}3334impl From<cap_std::net::TcpStream> for Socket {35fn from(stream: cap_std::net::TcpStream) -> Self {36Self::TcpStream(stream)37}38}3940#[cfg(unix)]41impl From<cap_std::os::unix::net::UnixListener> for Socket {42fn from(listener: cap_std::os::unix::net::UnixListener) -> Self {43Self::UnixListener(listener)44}45}4647#[cfg(unix)]48impl From<cap_std::os::unix::net::UnixStream> for Socket {49fn from(stream: cap_std::os::unix::net::UnixStream) -> Self {50Self::UnixStream(stream)51}52}5354#[cfg(unix)]55impl From<Socket> for Box<dyn WasiFile> {56fn from(listener: Socket) -> Self {57match listener {58Socket::TcpListener(l) => Box::new(crate::sync::net::TcpListener::from_cap_std(l)),59Socket::UnixListener(l) => Box::new(crate::sync::net::UnixListener::from_cap_std(l)),60Socket::TcpStream(l) => Box::new(crate::sync::net::TcpStream::from_cap_std(l)),61Socket::UnixStream(l) => Box::new(crate::sync::net::UnixStream::from_cap_std(l)),62}63}64}6566#[cfg(windows)]67impl From<Socket> for Box<dyn WasiFile> {68fn from(listener: Socket) -> Self {69match listener {70Socket::TcpListener(l) => Box::new(crate::sync::net::TcpListener::from_cap_std(l)),71Socket::TcpStream(l) => Box::new(crate::sync::net::TcpStream::from_cap_std(l)),72}73}74}7576macro_rules! wasi_listen_write_impl {77($ty:ty, $stream:ty) => {78#[wiggle::async_trait]79impl WasiFile for $ty {80fn as_any(&self) -> &dyn Any {81self82}83#[cfg(unix)]84fn pollable(&self) -> Option<rustix::fd::BorrowedFd<'_>> {85Some(self.0.as_fd())86}87#[cfg(windows)]88fn pollable(&self) -> Option<io_extras::os::windows::RawHandleOrSocket> {89Some(self.0.as_raw_handle_or_socket())90}91async fn sock_accept(&self, fdflags: FdFlags) -> Result<Box<dyn WasiFile>, Error> {92let (stream, _) = self.0.accept()?;93let mut stream = <$stream>::from_cap_std(stream);94stream.set_fdflags(fdflags).await?;95Ok(Box::new(stream))96}97async fn get_filetype(&self) -> Result<FileType, Error> {98Ok(FileType::SocketStream)99}100#[cfg(unix)]101async fn get_fdflags(&self) -> Result<FdFlags, Error> {102let fdflags = get_fd_flags(&self.0)?;103Ok(fdflags)104}105async fn set_fdflags(&mut self, fdflags: FdFlags) -> Result<(), Error> {106if fdflags == crate::file::FdFlags::NONBLOCK {107self.0.set_nonblocking(true)?;108} else if fdflags.is_empty() {109self.0.set_nonblocking(false)?;110} else {111return Err(112Error::invalid_argument().context("cannot set anything else than NONBLOCK")113);114}115Ok(())116}117fn num_ready_bytes(&self) -> Result<u64, Error> {118Ok(1)119}120}121122#[cfg(windows)]123impl AsSocket for $ty {124#[inline]125fn as_socket(&self) -> BorrowedSocket<'_> {126self.0.as_socket()127}128}129130#[cfg(windows)]131impl AsRawHandleOrSocket for $ty {132#[inline]133fn as_raw_handle_or_socket(&self) -> RawHandleOrSocket {134self.0.as_raw_handle_or_socket()135}136}137138#[cfg(unix)]139impl AsFd for $ty {140fn as_fd(&self) -> BorrowedFd<'_> {141self.0.as_fd()142}143}144};145}146147pub struct TcpListener(cap_std::net::TcpListener);148149impl TcpListener {150pub fn from_cap_std(cap_std: cap_std::net::TcpListener) -> Self {151TcpListener(cap_std)152}153}154wasi_listen_write_impl!(TcpListener, TcpStream);155156#[cfg(unix)]157pub struct UnixListener(cap_std::os::unix::net::UnixListener);158159#[cfg(unix)]160impl UnixListener {161pub fn from_cap_std(cap_std: cap_std::os::unix::net::UnixListener) -> Self {162UnixListener(cap_std)163}164}165166#[cfg(unix)]167wasi_listen_write_impl!(UnixListener, UnixStream);168169macro_rules! wasi_stream_write_impl {170($ty:ty, $std_ty:ty) => {171#[wiggle::async_trait]172impl WasiFile for $ty {173fn as_any(&self) -> &dyn Any {174self175}176#[cfg(unix)]177fn pollable(&self) -> Option<rustix::fd::BorrowedFd<'_>> {178Some(self.0.as_fd())179}180#[cfg(windows)]181fn pollable(&self) -> Option<io_extras::os::windows::RawHandleOrSocket> {182Some(self.0.as_raw_handle_or_socket())183}184async fn get_filetype(&self) -> Result<FileType, Error> {185Ok(FileType::SocketStream)186}187#[cfg(unix)]188async fn get_fdflags(&self) -> Result<FdFlags, Error> {189let fdflags = get_fd_flags(&self.0)?;190Ok(fdflags)191}192async fn set_fdflags(&mut self, fdflags: FdFlags) -> Result<(), Error> {193if fdflags == crate::file::FdFlags::NONBLOCK {194self.0.set_nonblocking(true)?;195} else if fdflags.is_empty() {196self.0.set_nonblocking(false)?;197} else {198return Err(199Error::invalid_argument().context("cannot set anything else than NONBLOCK")200);201}202Ok(())203}204async fn read_vectored<'a>(205&self,206bufs: &mut [io::IoSliceMut<'a>],207) -> Result<u64, Error> {208use std::io::Read;209let n = Read::read_vectored(&mut &*self.as_socketlike_view::<$std_ty>(), bufs)?;210Ok(n.try_into()?)211}212async fn write_vectored<'a>(&self, bufs: &[io::IoSlice<'a>]) -> Result<u64, Error> {213use std::io::Write;214let n = Write::write_vectored(&mut &*self.as_socketlike_view::<$std_ty>(), bufs)?;215Ok(n.try_into()?)216}217async fn peek(&self, buf: &mut [u8]) -> Result<u64, Error> {218let n = self.0.peek(buf)?;219Ok(n.try_into()?)220}221fn num_ready_bytes(&self) -> Result<u64, Error> {222let val = self.as_socketlike_view::<$std_ty>().num_ready_bytes()?;223Ok(val)224}225async fn readable(&self) -> Result<(), Error> {226let (readable, _writeable) = is_read_write(&self.0)?;227if readable { Ok(()) } else { Err(Error::io()) }228}229async fn writable(&self) -> Result<(), Error> {230let (_readable, writeable) = is_read_write(&self.0)?;231if writeable { Ok(()) } else { Err(Error::io()) }232}233234async fn sock_recv<'a>(235&self,236ri_data: &mut [std::io::IoSliceMut<'a>],237ri_flags: RiFlags,238) -> Result<(u64, RoFlags), Error> {239if (ri_flags & !(RiFlags::RECV_PEEK | RiFlags::RECV_WAITALL)) != RiFlags::empty() {240return Err(Error::not_supported());241}242243if ri_flags.contains(RiFlags::RECV_PEEK) {244if let Some(first) = ri_data.iter_mut().next() {245let n = self.0.peek(first)?;246return Ok((n as u64, RoFlags::empty()));247} else {248return Ok((0, RoFlags::empty()));249}250}251252if ri_flags.contains(RiFlags::RECV_WAITALL) {253let n: usize = ri_data.iter().map(|buf| buf.len()).sum();254self.0.read_exact_vectored(ri_data)?;255return Ok((n as u64, RoFlags::empty()));256}257258let n = self.0.read_vectored(ri_data)?;259Ok((n as u64, RoFlags::empty()))260}261262async fn sock_send<'a>(263&self,264si_data: &[std::io::IoSlice<'a>],265si_flags: SiFlags,266) -> Result<u64, Error> {267if si_flags != SiFlags::empty() {268return Err(Error::not_supported());269}270271let n = self.0.write_vectored(si_data)?;272Ok(n as u64)273}274275async fn sock_shutdown(&self, how: SdFlags) -> Result<(), Error> {276let how = if how == SdFlags::RD | SdFlags::WR {277cap_std::net::Shutdown::Both278} else if how == SdFlags::RD {279cap_std::net::Shutdown::Read280} else if how == SdFlags::WR {281cap_std::net::Shutdown::Write282} else {283return Err(Error::invalid_argument());284};285self.0.shutdown(how)?;286Ok(())287}288}289#[cfg(unix)]290impl AsFd for $ty {291fn as_fd(&self) -> BorrowedFd<'_> {292self.0.as_fd()293}294}295296#[cfg(windows)]297impl AsSocket for $ty {298/// Borrows the socket.299fn as_socket(&self) -> BorrowedSocket<'_> {300self.0.as_socket()301}302}303304#[cfg(windows)]305impl AsRawHandleOrSocket for TcpStream {306#[inline]307fn as_raw_handle_or_socket(&self) -> RawHandleOrSocket {308self.0.as_raw_handle_or_socket()309}310}311};312}313314pub struct TcpStream(cap_std::net::TcpStream);315316impl TcpStream {317pub fn from_cap_std(socket: cap_std::net::TcpStream) -> Self {318TcpStream(socket)319}320}321322wasi_stream_write_impl!(TcpStream, std::net::TcpStream);323324#[cfg(unix)]325pub struct UnixStream(cap_std::os::unix::net::UnixStream);326327#[cfg(unix)]328impl UnixStream {329pub fn from_cap_std(socket: cap_std::os::unix::net::UnixStream) -> Self {330UnixStream(socket)331}332}333334#[cfg(unix)]335wasi_stream_write_impl!(UnixStream, std::os::unix::net::UnixStream);336337pub fn filetype_from(ft: &cap_std::fs::FileType) -> FileType {338use cap_fs_ext::FileTypeExt;339if ft.is_block_device() {340FileType::SocketDgram341} else {342FileType::SocketStream343}344}345346/// Return the file-descriptor flags for a given file-like object.347///348/// This returns the flags needed to implement [`WasiFile::get_fdflags`].349pub fn get_fd_flags<Socketlike: AsSocketlike>(f: Socketlike) -> io::Result<crate::file::FdFlags> {350// On Unix-family platforms, we can use the same system call that we'd use351// for files on sockets here.352#[cfg(not(windows))]353{354let mut out = crate::file::FdFlags::empty();355if f.get_fd_flags()?356.contains(system_interface::fs::FdFlags::NONBLOCK)357{358out |= crate::file::FdFlags::NONBLOCK;359}360Ok(out)361}362363// On Windows, sockets are different, and there is no direct way to364// query for the non-blocking flag. We can get a sufficient approximation365// by testing whether a zero-length `recv` appears to block.366#[cfg(windows)]367let buf: &mut [u8] = &mut [];368#[cfg(windows)]369match rustix::net::recv(f, buf, rustix::net::RecvFlags::empty()) {370Ok(_) => Ok(crate::file::FdFlags::empty()),371Err(rustix::io::Errno::WOULDBLOCK) => Ok(crate::file::FdFlags::NONBLOCK),372Err(e) => Err(e.into()),373}374}375376/// Return the file-descriptor flags for a given file-like object.377///378/// This returns the flags needed to implement [`WasiFile::get_fdflags`].379pub fn is_read_write<Socketlike: AsSocketlike>(f: Socketlike) -> io::Result<(bool, bool)> {380// On Unix-family platforms, we have an `IsReadWrite` impl.381#[cfg(not(windows))]382{383f.is_read_write()384}385386// On Windows, we only have a `TcpStream` impl, so make a view first.387#[cfg(windows)]388{389f.as_socketlike_view::<std::net::TcpStream>()390.is_read_write()391}392}393394395