// Path: blob/main/crates/wasi/src/sockets/udp.rs
// 1692 views
use crate::runtime::with_ambient_tokio_runtime;1use crate::sockets::util::{2ErrorCode, get_unicast_hop_limit, is_valid_address_family, is_valid_remote_address,3receive_buffer_size, send_buffer_size, set_receive_buffer_size, set_send_buffer_size,4set_unicast_hop_limit, udp_bind, udp_disconnect, udp_socket,5};6use crate::sockets::{SocketAddrCheck, SocketAddressFamily, WasiSocketsCtx};7use cap_net_ext::AddressFamily;8use io_lifetimes::AsSocketlike as _;9use io_lifetimes::raw::{FromRawSocketlike as _, IntoRawSocketlike as _};10use rustix::io::Errno;11use rustix::net::connect;12use std::net::SocketAddr;13use std::sync::Arc;14use tracing::debug;1516/// The state of a UDP socket.17///18/// This represents the various states a socket can be in during the19/// activities of binding, and connecting.20enum UdpState {21/// The initial state for a newly-created socket.22Default,2324/// A `bind` operation has started but has yet to complete with25/// `finish_bind`.26BindStarted,2728/// Binding finished via `finish_bind`. 
The socket has an address but29/// is not yet listening for connections.30Bound,3132/// The socket is "connected" to a peer address.33#[cfg_attr(34not(feature = "p3"),35expect(dead_code, reason = "p2 has its own way of managing sending/receiving")36)]37Connected(SocketAddr),38}3940/// A host UDP socket, plus associated bookkeeping.41///42/// The inner state is wrapped in an Arc because the same underlying socket is43/// used for implementing the stream types.44pub struct UdpSocket {45socket: Arc<tokio::net::UdpSocket>,4647/// The current state in the bind/connect progression.48udp_state: UdpState,4950/// Socket address family.51family: SocketAddressFamily,5253/// If set, use this custom check for addrs, otherwise use what's in54/// `WasiSocketsCtx`.55socket_addr_check: Option<SocketAddrCheck>,56}5758impl UdpSocket {59/// Create a new socket in the given family.60pub(crate) fn new(cx: &WasiSocketsCtx, family: AddressFamily) -> Result<Self, ErrorCode> {61cx.allowed_network_uses.check_allowed_udp()?;6263// Delegate socket creation to cap_net_ext. They handle a couple of things for us:64// - On Windows: call WSAStartup if not done before.65// - Set the NONBLOCK and CLOEXEC flags. Either immediately during socket creation,66// or afterwards using ioctl or fcntl. 
Exact method depends on the platform.6768let fd = udp_socket(family)?;6970let socket_address_family = match family {71AddressFamily::Ipv4 => SocketAddressFamily::Ipv4,72AddressFamily::Ipv6 => {73rustix::net::sockopt::set_ipv6_v6only(&fd, true)?;74SocketAddressFamily::Ipv675}76};7778let socket = with_ambient_tokio_runtime(|| {79tokio::net::UdpSocket::try_from(unsafe {80std::net::UdpSocket::from_raw_socketlike(fd.into_raw_socketlike())81})82})?;8384Ok(Self {85socket: Arc::new(socket),86udp_state: UdpState::Default,87family: socket_address_family,88socket_addr_check: None,89})90}9192pub(crate) fn bind(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {93if !matches!(self.udp_state, UdpState::Default) {94return Err(ErrorCode::InvalidState);95}96if !is_valid_address_family(addr.ip(), self.family) {97return Err(ErrorCode::InvalidArgument);98}99udp_bind(&self.socket, addr)?;100self.udp_state = UdpState::BindStarted;101Ok(())102}103104pub(crate) fn finish_bind(&mut self) -> Result<(), ErrorCode> {105match self.udp_state {106UdpState::BindStarted => {107self.udp_state = UdpState::Bound;108Ok(())109}110_ => Err(ErrorCode::NotInProgress),111}112}113114pub(crate) fn is_connected(&self) -> bool {115matches!(self.udp_state, UdpState::Connected(..))116}117118pub(crate) fn is_bound(&self) -> bool {119matches!(self.udp_state, UdpState::Connected(..) 
| UdpState::Bound)120}121122pub(crate) fn disconnect(&mut self) -> Result<(), ErrorCode> {123if !self.is_connected() {124return Err(ErrorCode::InvalidState);125}126udp_disconnect(&self.socket)?;127self.udp_state = UdpState::Bound;128Ok(())129}130131pub(crate) fn connect(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {132if !is_valid_address_family(addr.ip(), self.family) || !is_valid_remote_address(addr) {133return Err(ErrorCode::InvalidArgument);134}135136match self.udp_state {137UdpState::Bound | UdpState::Connected(_) => {}138_ => return Err(ErrorCode::InvalidState),139}140141// We disconnect & (re)connect in two distinct steps for two reasons:142// - To leave our socket instance in a consistent state in case the143// connect fails.144// - When reconnecting to a different address, Linux sometimes fails145// if there isn't a disconnect in between.146147// Step #1: Disconnect148if let UdpState::Connected(..) = self.udp_state {149udp_disconnect(&self.socket)?;150self.udp_state = UdpState::Bound;151}152// Step #2: (Re)connect153connect(&self.socket, &addr).map_err(|error| match error {154Errno::AFNOSUPPORT => ErrorCode::InvalidArgument, // See `udp_bind` implementation.155Errno::INPROGRESS => {156debug!("UDP connect returned EINPROGRESS, which should never happen");157ErrorCode::Unknown158}159err => err.into(),160})?;161self.udp_state = UdpState::Connected(addr);162Ok(())163}164165#[cfg(feature = "p3")]166pub(crate) fn send(&self, buf: Vec<u8>) -> impl Future<Output = Result<(), ErrorCode>> + use<> {167let socket = if let UdpState::Connected(..) 
= self.udp_state {168Ok(Arc::clone(&self.socket))169} else {170Err(ErrorCode::InvalidArgument)171};172async move {173let socket = socket?;174send(&socket, &buf).await175}176}177178#[cfg(feature = "p3")]179pub(crate) fn send_to(180&self,181buf: Vec<u8>,182addr: SocketAddr,183) -> impl Future<Output = Result<(), ErrorCode>> + use<> {184enum Mode {185Send(Arc<tokio::net::UdpSocket>),186SendTo(Arc<tokio::net::UdpSocket>, SocketAddr),187}188let socket = match &self.udp_state {189UdpState::BindStarted => Err(ErrorCode::InvalidState),190UdpState::Default | UdpState::Bound => Ok(Mode::SendTo(Arc::clone(&self.socket), addr)),191UdpState::Connected(caddr) if addr == *caddr => {192Ok(Mode::Send(Arc::clone(&self.socket)))193}194UdpState::Connected(..) => Err(ErrorCode::InvalidArgument),195};196async move {197match socket? {198Mode::Send(socket) => send(&socket, &buf).await,199Mode::SendTo(socket, addr) => send_to(&socket, &buf, addr).await,200}201}202}203204#[cfg(feature = "p3")]205pub(crate) fn receive(206&self,207) -> impl Future<Output = Result<(Vec<u8>, SocketAddr), ErrorCode>> + use<> {208enum Mode {209Recv(Arc<tokio::net::UdpSocket>, SocketAddr),210RecvFrom(Arc<tokio::net::UdpSocket>),211}212let socket = match self.udp_state {213UdpState::Default | UdpState::BindStarted => Err(ErrorCode::InvalidState),214UdpState::Bound => Ok(Mode::RecvFrom(Arc::clone(&self.socket))),215UdpState::Connected(addr) => Ok(Mode::Recv(Arc::clone(&self.socket), addr)),216};217async move {218let socket = socket?;219let mut buf = vec![0; super::MAX_UDP_DATAGRAM_SIZE];220let (n, addr) = match socket {221Mode::Recv(socket, addr) => {222let n = socket.recv(&mut buf).await?;223(n, addr)224}225Mode::RecvFrom(socket) => {226let (n, addr) = socket.recv_from(&mut buf).await?;227(n, addr)228}229};230buf.truncate(n);231Ok((buf, addr))232}233}234235pub(crate) fn local_address(&self) -> Result<SocketAddr, ErrorCode> {236if matches!(self.udp_state, UdpState::Default | UdpState::BindStarted) {237return 
Err(ErrorCode::InvalidState);238}239let addr = self240.socket241.as_socketlike_view::<std::net::UdpSocket>()242.local_addr()?;243Ok(addr)244}245246pub(crate) fn remote_address(&self) -> Result<SocketAddr, ErrorCode> {247if !matches!(self.udp_state, UdpState::Connected(..)) {248return Err(ErrorCode::InvalidState);249}250let addr = self251.socket252.as_socketlike_view::<std::net::UdpSocket>()253.peer_addr()?;254Ok(addr)255}256257pub(crate) fn address_family(&self) -> SocketAddressFamily {258self.family259}260261pub(crate) fn unicast_hop_limit(&self) -> Result<u8, ErrorCode> {262let n = get_unicast_hop_limit(&self.socket, self.family)?;263Ok(n)264}265266pub(crate) fn set_unicast_hop_limit(&self, value: u8) -> Result<(), ErrorCode> {267set_unicast_hop_limit(&self.socket, self.family, value)?;268Ok(())269}270271pub(crate) fn receive_buffer_size(&self) -> Result<u64, ErrorCode> {272let n = receive_buffer_size(&self.socket)?;273Ok(n)274}275276pub(crate) fn set_receive_buffer_size(&self, value: u64) -> Result<(), ErrorCode> {277set_receive_buffer_size(&self.socket, value)?;278Ok(())279}280281pub(crate) fn send_buffer_size(&self) -> Result<u64, ErrorCode> {282let n = send_buffer_size(&self.socket)?;283Ok(n)284}285286pub(crate) fn set_send_buffer_size(&self, value: u64) -> Result<(), ErrorCode> {287set_send_buffer_size(&self.socket, value)?;288Ok(())289}290291pub(crate) fn socket(&self) -> &Arc<tokio::net::UdpSocket> {292&self.socket293}294295pub(crate) fn socket_addr_check(&self) -> Option<&SocketAddrCheck> {296self.socket_addr_check.as_ref()297}298299pub(crate) fn set_socket_addr_check(&mut self, check: Option<SocketAddrCheck>) {300self.socket_addr_check = check;301}302}303304#[cfg(feature = "p3")]305async fn send(socket: &tokio::net::UdpSocket, buf: &[u8]) -> Result<(), ErrorCode> {306let n = socket.send(buf).await?;307// From Rust stdlib docs:308// > Note that the operating system may refuse buffers larger than 65507.309// > However, partial writes are not possible until 
buffer sizes above `i32::MAX`.310//311// For example, on Windows, at most `i32::MAX` bytes will be written312if n != buf.len() {313Err(ErrorCode::Unknown)314} else {315Ok(())316}317}318319#[cfg(feature = "p3")]320async fn send_to(321socket: &tokio::net::UdpSocket,322buf: &[u8],323addr: SocketAddr,324) -> Result<(), ErrorCode> {325let n = socket.send_to(buf, addr).await?;326// See [`send`] documentation327if n != buf.len() {328Err(ErrorCode::Unknown)329} else {330Ok(())331}332}333334335