Path: blob/main/crates/wasi/src/sockets/udp.rs
3137 views
use crate::runtime::with_ambient_tokio_runtime;1use crate::sockets::util::{2ErrorCode, get_unicast_hop_limit, is_valid_address_family, is_valid_remote_address,3receive_buffer_size, send_buffer_size, set_receive_buffer_size, set_send_buffer_size,4set_unicast_hop_limit, udp_bind, udp_disconnect, udp_socket,5};6use crate::sockets::{SocketAddrCheck, SocketAddressFamily, WasiSocketsCtx};7use cap_net_ext::AddressFamily;8use io_lifetimes::AsSocketlike as _;9use io_lifetimes::raw::{FromRawSocketlike as _, IntoRawSocketlike as _};10use rustix::io::Errno;11use rustix::net::connect;12use std::net::SocketAddr;13use std::sync::Arc;14use tracing::debug;1516/// The state of a UDP socket.17///18/// This represents the various states a socket can be in during the19/// activities of binding, and connecting.20enum UdpState {21/// The initial state for a newly-created socket.22Default,2324/// A `bind` operation has started but has yet to complete with25/// `finish_bind`.26BindStarted,2728/// Binding finished via `finish_bind`. The socket has an address but29/// is not yet listening for connections.30Bound,3132/// The socket is "connected" to a peer address.33#[cfg_attr(34not(feature = "p3"),35expect(dead_code, reason = "p2 has its own way of managing sending/receiving")36)]37Connected(SocketAddr),38}3940/// A host UDP socket, plus associated bookkeeping.41///42/// The inner state is wrapped in an Arc because the same underlying socket is43/// used for implementing the stream types.44pub struct UdpSocket {45socket: Arc<tokio::net::UdpSocket>,4647/// The current state in the bind/connect progression.48udp_state: UdpState,4950/// Socket address family.51family: SocketAddressFamily,5253/// If set, use this custom check for addrs, otherwise use what's in54/// `WasiSocketsCtx`.55socket_addr_check: Option<SocketAddrCheck>,56}5758impl UdpSocket {59/// Create a new socket in the given family.60pub(crate) fn new(cx: &WasiSocketsCtx, family: AddressFamily) -> Result<Self, ErrorCode> {61cx.allowed_network_uses.check_allowed_udp()?;6263// Delegate socket creation to cap_net_ext. They handle a couple of things for us:64// - On Windows: call WSAStartup if not done before.65// - Set the NONBLOCK and CLOEXEC flags. Either immediately during socket creation,66// or afterwards using ioctl or fcntl. Exact method depends on the platform.6768let fd = udp_socket(family)?;6970let socket_address_family = match family {71AddressFamily::Ipv4 => SocketAddressFamily::Ipv4,72AddressFamily::Ipv6 => {73rustix::net::sockopt::set_ipv6_v6only(&fd, true)?;74SocketAddressFamily::Ipv675}76};7778let socket = with_ambient_tokio_runtime(|| {79tokio::net::UdpSocket::try_from(unsafe {80std::net::UdpSocket::from_raw_socketlike(fd.into_raw_socketlike())81})82})?;8384Ok(Self {85socket: Arc::new(socket),86udp_state: UdpState::Default,87family: socket_address_family,88socket_addr_check: None,89})90}9192pub(crate) fn bind(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {93if !matches!(self.udp_state, UdpState::Default) {94return Err(ErrorCode::InvalidState);95}96if !is_valid_address_family(addr.ip(), self.family) {97return Err(ErrorCode::InvalidArgument);98}99udp_bind(&self.socket, addr)?;100self.udp_state = UdpState::BindStarted;101Ok(())102}103104pub(crate) fn finish_bind(&mut self) -> Result<(), ErrorCode> {105match self.udp_state {106UdpState::BindStarted => {107self.udp_state = UdpState::Bound;108Ok(())109}110_ => Err(ErrorCode::NotInProgress),111}112}113114pub(crate) fn is_connected(&self) -> bool {115matches!(self.udp_state, UdpState::Connected(..))116}117118pub(crate) fn is_bound(&self) -> bool {119matches!(self.udp_state, UdpState::Connected(..) | UdpState::Bound)120}121122pub(crate) fn disconnect(&mut self) -> Result<(), ErrorCode> {123if !self.is_connected() {124return Err(ErrorCode::InvalidState);125}126udp_disconnect(&self.socket)?;127self.udp_state = UdpState::Bound;128Ok(())129}130131/// Connect using p2 semantics. (no implicit bind)132pub(crate) fn connect_p2(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {133match self.udp_state {134UdpState::Bound | UdpState::Connected(_) => {}135_ => return Err(ErrorCode::InvalidState),136}137138self.connect_common(addr)139}140141/// Connect using p3 semantics. (with implicit bind)142#[cfg(feature = "p3")]143pub(crate) fn connect_p3(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {144match self.udp_state {145UdpState::Default | UdpState::Bound | UdpState::Connected(_) => {}146_ => return Err(ErrorCode::InvalidState),147}148149self.connect_common(addr)150}151152fn connect_common(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {153if !is_valid_address_family(addr.ip(), self.family) || !is_valid_remote_address(addr) {154return Err(ErrorCode::InvalidArgument);155}156157// We disconnect & (re)connect in two distinct steps for two reasons:158// - To leave our socket instance in a consistent state in case the159// connect fails.160// - When reconnecting to a different address, Linux sometimes fails161// if there isn't a disconnect in between.162163// Step #1: Disconnect164if let UdpState::Connected(..) = self.udp_state {165udp_disconnect(&self.socket)?;166self.udp_state = UdpState::Bound;167}168// Step #2: (Re)connect169connect(&self.socket, &addr).map_err(|error| match error {170Errno::AFNOSUPPORT => ErrorCode::InvalidArgument, // See `udp_bind` implementation.171Errno::INPROGRESS => {172debug!("UDP connect returned EINPROGRESS, which should never happen");173ErrorCode::Unknown174}175err => err.into(),176})?;177self.udp_state = UdpState::Connected(addr);178Ok(())179}180181/// Send data using p3 semantics. (with implicit bind)182#[cfg(feature = "p3")]183pub(crate) fn send_p3(184&mut self,185buf: Vec<u8>,186addr: Option<SocketAddr>,187) -> impl Future<Output = Result<(), ErrorCode>> + use<> {188enum Mode {189Send(Arc<tokio::net::UdpSocket>),190SendTo(Arc<tokio::net::UdpSocket>, SocketAddr),191}192let mut socket = match (&self.udp_state, addr) {193(UdpState::BindStarted, _) => Err(ErrorCode::InvalidState),194(UdpState::Default | UdpState::Bound, None) => Err(ErrorCode::InvalidArgument),195(UdpState::Default | UdpState::Bound, Some(addr)) => {196Ok(Mode::SendTo(Arc::clone(&self.socket), addr))197}198(UdpState::Connected(..), None) => Ok(Mode::Send(Arc::clone(&self.socket))),199(UdpState::Connected(caddr), Some(addr)) => {200if addr == *caddr {201Ok(Mode::Send(Arc::clone(&self.socket)))202} else {203Err(ErrorCode::InvalidArgument)204}205}206};207208// Send may be called without a prior bind or connect. In that case, the209// first send will automatically assign a free local port. This is210// normally performed by the OS itself. However, if the `send` syscall211// failed, we can't reliably know which state the socket is in at the212// kernel level and our own `udp_state` bookkeeping may have become213// out-of-sync.214// To avoid that, we perform the implicit bind ourselves here. This way,215// we always leave the socket in a consistent state: Bound.216if socket.is_ok()217&& let UdpState::Default = self.udp_state218{219let implicit_addr = crate::sockets::util::implicit_bind_addr(self.family);220match udp_bind(&self.socket, implicit_addr) {221Ok(()) => {222self.udp_state = UdpState::Bound;223}224Err(e) => {225socket = Err(e);226}227}228}229230async move {231match socket? {232Mode::Send(socket) => send(&socket, &buf).await,233Mode::SendTo(socket, addr) => send_to(&socket, &buf, addr).await,234}235}236}237238/// Receive data using p3 semantics.239#[cfg(feature = "p3")]240pub(crate) fn receive_p3(241&self,242) -> impl Future<Output = Result<(Vec<u8>, SocketAddr), ErrorCode>> + use<> {243enum Mode {244Recv(Arc<tokio::net::UdpSocket>, SocketAddr),245RecvFrom(Arc<tokio::net::UdpSocket>),246}247let socket = match self.udp_state {248UdpState::Default | UdpState::BindStarted => Err(ErrorCode::InvalidState),249UdpState::Bound => Ok(Mode::RecvFrom(Arc::clone(&self.socket))),250UdpState::Connected(addr) => Ok(Mode::Recv(Arc::clone(&self.socket), addr)),251};252async move {253let socket = socket?;254let mut buf = vec![0; super::MAX_UDP_DATAGRAM_SIZE];255let (n, addr) = match socket {256Mode::Recv(socket, addr) => {257let n = socket.recv(&mut buf).await?;258(n, addr)259}260Mode::RecvFrom(socket) => {261let (n, addr) = socket.recv_from(&mut buf).await?;262(n, addr)263}264};265buf.truncate(n);266Ok((buf, addr))267}268}269270pub(crate) fn local_address(&self) -> Result<SocketAddr, ErrorCode> {271if matches!(self.udp_state, UdpState::Default | UdpState::BindStarted) {272return Err(ErrorCode::InvalidState);273}274let addr = self275.socket276.as_socketlike_view::<std::net::UdpSocket>()277.local_addr()?;278Ok(addr)279}280281pub(crate) fn remote_address(&self) -> Result<SocketAddr, ErrorCode> {282if !matches!(self.udp_state, UdpState::Connected(..)) {283return Err(ErrorCode::InvalidState);284}285let addr = self286.socket287.as_socketlike_view::<std::net::UdpSocket>()288.peer_addr()?;289Ok(addr)290}291292pub(crate) fn address_family(&self) -> SocketAddressFamily {293self.family294}295296pub(crate) fn unicast_hop_limit(&self) -> Result<u8, ErrorCode> {297let n = get_unicast_hop_limit(&self.socket, self.family)?;298Ok(n)299}300301pub(crate) fn set_unicast_hop_limit(&self, value: u8) -> Result<(), ErrorCode> {302set_unicast_hop_limit(&self.socket, self.family, value)?;303Ok(())304}305306pub(crate) fn receive_buffer_size(&self) -> Result<u64, ErrorCode> {307let n = receive_buffer_size(&self.socket)?;308Ok(n)309}310311pub(crate) fn set_receive_buffer_size(&self, value: u64) -> Result<(), ErrorCode> {312set_receive_buffer_size(&self.socket, value)?;313Ok(())314}315316pub(crate) fn send_buffer_size(&self) -> Result<u64, ErrorCode> {317let n = send_buffer_size(&self.socket)?;318Ok(n)319}320321pub(crate) fn set_send_buffer_size(&self, value: u64) -> Result<(), ErrorCode> {322set_send_buffer_size(&self.socket, value)?;323Ok(())324}325326pub(crate) fn socket(&self) -> &Arc<tokio::net::UdpSocket> {327&self.socket328}329330pub(crate) fn socket_addr_check(&self) -> Option<&SocketAddrCheck> {331self.socket_addr_check.as_ref()332}333334pub(crate) fn set_socket_addr_check(&mut self, check: Option<SocketAddrCheck>) {335self.socket_addr_check = check;336}337}338339#[cfg(feature = "p3")]340async fn send(socket: &tokio::net::UdpSocket, buf: &[u8]) -> Result<(), ErrorCode> {341let n = socket.send(buf).await?;342// From Rust stdlib docs:343// > Note that the operating system may refuse buffers larger than 65507.344// > However, partial writes are not possible until buffer sizes above `i32::MAX`.345//346// For example, on Windows, at most `i32::MAX` bytes will be written347if n != buf.len() {348Err(ErrorCode::Unknown)349} else {350Ok(())351}352}353354#[cfg(feature = "p3")]355async fn send_to(356socket: &tokio::net::UdpSocket,357buf: &[u8],358addr: SocketAddr,359) -> Result<(), ErrorCode> {360let n = socket.send_to(buf, addr).await?;361// See [`send`] documentation362if n != buf.len() {363Err(ErrorCode::Unknown)364} else {365Ok(())366}367}368369370