Path: blob/main/crates/test-programs/src/sockets.rs
1693 views
use crate::wasi::clocks::monotonic_clock;1use crate::wasi::io::poll::{self, Pollable};2use crate::wasi::io::streams::{InputStream, OutputStream, StreamError};3use crate::wasi::random;4use crate::wasi::sockets::instance_network;5use crate::wasi::sockets::ip_name_lookup;6use crate::wasi::sockets::network::{7ErrorCode, IpAddress, IpAddressFamily, IpSocketAddress, Ipv4SocketAddress, Ipv6SocketAddress,8Network,9};10use crate::wasi::sockets::tcp::TcpSocket;11use crate::wasi::sockets::udp::{12IncomingDatagram, IncomingDatagramStream, OutgoingDatagram, OutgoingDatagramStream, UdpSocket,13};14use crate::wasi::sockets::{tcp_create_socket, udp_create_socket};15use std::ops::Range;1617const TIMEOUT_NS: u64 = 1_000_000_000;1819impl Pollable {20pub fn block_until(&self, timeout: &Pollable) -> Result<(), ErrorCode> {21let ready = poll::poll(&[self, timeout]);22assert!(ready.len() > 0);23match ready[0] {240 => Ok(()),251 => Err(ErrorCode::Timeout),26_ => unreachable!(),27}28}29}3031impl InputStream {32pub fn blocking_read_to_end(&self) -> Result<Vec<u8>, crate::wasi::io::error::Error> {33let mut data = vec![];34loop {35match self.blocking_read(1024 * 1024) {36Ok(chunk) => data.extend(chunk),37Err(StreamError::Closed) => return Ok(data),38Err(StreamError::LastOperationFailed(e)) => return Err(e),39}40}41}42}4344impl OutputStream {45pub fn blocking_write_util(&self, mut bytes: &[u8]) -> Result<(), StreamError> {46let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS);47let pollable = self.subscribe();4849while !bytes.is_empty() {50pollable.block_until(&timeout).expect("write timed out");5152let permit = self.check_write()?;5354let len = bytes.len().min(permit as usize);55let (chunk, rest) = bytes.split_at(len);5657self.write(chunk)?;5859self.blocking_flush()?;6061bytes = rest;62}63Ok(())64}65}6667impl Network {68pub fn default() -> Network {69instance_network::instance_network()70}7172pub fn blocking_resolve_addresses(&self, name: &str) -> Result<Vec<IpAddress>, ErrorCode> {73let stream = ip_name_lookup::resolve_addresses(&self, name)?;7475let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS);76let pollable = stream.subscribe();7778let mut addresses = vec![];7980loop {81match stream.resolve_next_address() {82Ok(Some(addr)) => {83addresses.push(addr);84}85Ok(None) => match addresses[..] {86[] => return Err(ErrorCode::NameUnresolvable),87_ => return Ok(addresses),88},89Err(ErrorCode::WouldBlock) => {90pollable.block_until(&timeout)?;91}92Err(err) => return Err(err),93}94}95}9697/// Same as `Network::blocking_resolve_addresses` but ignores post validation errors98///99/// The ignored error codes signal that the input passed validation100/// and a lookup was actually attempted, but failed. These are ignored to101/// make the CI tests less flaky.102pub fn permissive_blocking_resolve_addresses(103&self,104name: &str,105) -> Result<Vec<IpAddress>, ErrorCode> {106match self.blocking_resolve_addresses(name) {107Err(ErrorCode::NameUnresolvable | ErrorCode::TemporaryResolverFailure) => Ok(vec![]),108r => r,109}110}111}112113impl TcpSocket {114pub fn new(address_family: IpAddressFamily) -> Result<TcpSocket, ErrorCode> {115tcp_create_socket::create_tcp_socket(address_family)116}117118pub fn blocking_bind(119&self,120network: &Network,121local_address: IpSocketAddress,122) -> Result<(), ErrorCode> {123let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS);124let sub = self.subscribe();125126self.start_bind(&network, local_address)?;127128loop {129match self.finish_bind() {130Err(ErrorCode::WouldBlock) => sub.block_until(&timeout)?,131result => return result,132}133}134}135136pub fn blocking_listen(&self) -> Result<(), ErrorCode> {137let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS);138let sub = self.subscribe();139140self.start_listen()?;141142loop {143match self.finish_listen() {144Err(ErrorCode::WouldBlock) => sub.block_until(&timeout)?,145result => return result,146}147}148}149150pub fn blocking_connect(151&self,152network: &Network,153remote_address: IpSocketAddress,154) -> Result<(InputStream, OutputStream), ErrorCode> {155let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS);156let sub = self.subscribe();157158self.start_connect(&network, remote_address)?;159160loop {161match self.finish_connect() {162Err(ErrorCode::WouldBlock) => sub.block_until(&timeout)?,163result => return result,164}165}166}167168pub fn blocking_accept(&self) -> Result<(TcpSocket, InputStream, OutputStream), ErrorCode> {169let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS);170let sub = self.subscribe();171172loop {173match self.accept() {174Err(ErrorCode::WouldBlock) => sub.block_until(&timeout)?,175result => return result,176}177}178}179}180181impl UdpSocket {182pub fn new(address_family: IpAddressFamily) -> Result<UdpSocket, ErrorCode> {183udp_create_socket::create_udp_socket(address_family)184}185186pub fn blocking_bind(187&self,188network: &Network,189local_address: IpSocketAddress,190) -> Result<(), ErrorCode> {191let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS);192let sub = self.subscribe();193194self.start_bind(&network, local_address)?;195196loop {197match self.finish_bind() {198Err(ErrorCode::WouldBlock) => sub.block_until(&timeout)?,199result => return result,200}201}202}203204pub fn blocking_bind_unspecified(&self, network: &Network) -> Result<(), ErrorCode> {205let ip = IpAddress::new_unspecified(self.address_family());206let port = 0;207208self.blocking_bind(network, IpSocketAddress::new(ip, port))209}210}211212impl OutgoingDatagramStream {213fn blocking_check_send(&self, timeout: &Pollable) -> Result<u64, ErrorCode> {214let sub = self.subscribe();215216loop {217match self.check_send() {218Ok(0) => sub.block_until(timeout)?,219result => return result,220}221}222}223224pub fn blocking_send(&self, mut datagrams: &[OutgoingDatagram]) -> Result<(), ErrorCode> {225let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS);226227while !datagrams.is_empty() {228let permit = self.blocking_check_send(&timeout)?;229let chunk_len = datagrams.len().min(permit as usize);230match self.send(&datagrams[..chunk_len]) {231Ok(0) => {}232Ok(packets_sent) => {233let packets_sent = packets_sent as usize;234datagrams = &datagrams[packets_sent..];235}236Err(err) => return Err(err),237}238}239240Ok(())241}242}243244impl IncomingDatagramStream {245pub fn blocking_receive(&self, count: Range<u64>) -> Result<Vec<IncomingDatagram>, ErrorCode> {246let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS);247let pollable = self.subscribe();248let mut datagrams = vec![];249250loop {251match self.receive(count.end - datagrams.len() as u64) {252Ok(mut chunk) => {253datagrams.append(&mut chunk);254255if datagrams.len() >= count.start as usize {256return Ok(datagrams);257} else {258pollable.block_until(&timeout)?;259}260}261Err(err) => return Err(err),262}263}264}265}266267impl IpAddress {268pub const IPV4_BROADCAST: IpAddress = IpAddress::Ipv4((255, 255, 255, 255));269270pub const IPV4_LOOPBACK: IpAddress = IpAddress::Ipv4((127, 0, 0, 1));271pub const IPV6_LOOPBACK: IpAddress = IpAddress::Ipv6((0, 0, 0, 0, 0, 0, 0, 1));272273pub const IPV4_UNSPECIFIED: IpAddress = IpAddress::Ipv4((0, 0, 0, 0));274pub const IPV6_UNSPECIFIED: IpAddress = IpAddress::Ipv6((0, 0, 0, 0, 0, 0, 0, 0));275276pub const IPV4_MAPPED_LOOPBACK: IpAddress =277IpAddress::Ipv6((0, 0, 0, 0, 0, 0xFFFF, 0x7F00, 0x0001));278279pub const fn new_loopback(family: IpAddressFamily) -> IpAddress {280match family {281IpAddressFamily::Ipv4 => Self::IPV4_LOOPBACK,282IpAddressFamily::Ipv6 => Self::IPV6_LOOPBACK,283}284}285286pub const fn new_unspecified(family: IpAddressFamily) -> IpAddress {287match family {288IpAddressFamily::Ipv4 => Self::IPV4_UNSPECIFIED,289IpAddressFamily::Ipv6 => Self::IPV6_UNSPECIFIED,290}291}292293pub const fn family(&self) -> IpAddressFamily {294match self {295IpAddress::Ipv4(_) => IpAddressFamily::Ipv4,296IpAddress::Ipv6(_) => IpAddressFamily::Ipv6,297}298}299}300301impl PartialEq for IpAddress {302fn eq(&self, other: &Self) -> bool {303match (self, other) {304(Self::Ipv4(left), Self::Ipv4(right)) => left == right,305(Self::Ipv6(left), Self::Ipv6(right)) => left == right,306_ => false,307}308}309}310311impl IpSocketAddress {312pub const fn new(ip: IpAddress, port: u16) -> IpSocketAddress {313match ip {314IpAddress::Ipv4(addr) => IpSocketAddress::Ipv4(Ipv4SocketAddress {315port,316address: addr,317}),318IpAddress::Ipv6(addr) => IpSocketAddress::Ipv6(Ipv6SocketAddress {319port,320address: addr,321flow_info: 0,322scope_id: 0,323}),324}325}326327pub const fn ip(&self) -> IpAddress {328match self {329IpSocketAddress::Ipv4(addr) => IpAddress::Ipv4(addr.address),330IpSocketAddress::Ipv6(addr) => IpAddress::Ipv6(addr.address),331}332}333334pub const fn port(&self) -> u16 {335match self {336IpSocketAddress::Ipv4(addr) => addr.port,337IpSocketAddress::Ipv6(addr) => addr.port,338}339}340341pub const fn family(&self) -> IpAddressFamily {342match self {343IpSocketAddress::Ipv4(_) => IpAddressFamily::Ipv4,344IpSocketAddress::Ipv6(_) => IpAddressFamily::Ipv6,345}346}347}348349impl PartialEq for Ipv4SocketAddress {350fn eq(&self, other: &Self) -> bool {351self.port == other.port && self.address == other.address352}353}354355impl PartialEq for Ipv6SocketAddress {356fn eq(&self, other: &Self) -> bool {357self.port == other.port358&& self.flow_info == other.flow_info359&& self.address == other.address360&& self.scope_id == other.scope_id361}362}363364impl PartialEq for IpSocketAddress {365fn eq(&self, other: &Self) -> bool {366match (self, other) {367(Self::Ipv4(l0), Self::Ipv4(r0)) => l0 == r0,368(Self::Ipv6(l0), Self::Ipv6(r0)) => l0 == r0,369_ => false,370}371}372}373374fn generate_random_u16(range: Range<u16>) -> u16 {375let start = range.start as u64;376let end = range.end as u64;377let port = start + (random::random::get_random_u64() % (end - start));378port as u16379}380381/// Execute the inner function with a randomly generated port.382/// To prevent random failures, we make a few attempts before giving up.383pub fn attempt_random_port<F>(384local_address: IpAddress,385mut f: F,386) -> Result<IpSocketAddress, ErrorCode>387where388F: FnMut(IpSocketAddress) -> Result<(), ErrorCode>,389{390const MAX_ATTEMPTS: u32 = 10;391let mut i = 0;392loop {393i += 1;394395let port: u16 = generate_random_u16(1024..u16::MAX);396let sock_addr = IpSocketAddress::new(local_address, port);397398match f(sock_addr) {399Ok(_) => return Ok(sock_addr),400Err(e) if i >= MAX_ATTEMPTS => return Err(e),401// Try again if the port is already taken. This can sometimes show up as `AccessDenied` on Windows.402Err(ErrorCode::AddressInUse | ErrorCode::AccessDenied) => {}403Err(e) => return Err(e),404}405}406}407408409