Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
bytecodealliance
GitHub Repository: bytecodealliance/wasmtime
Path: blob/main/crates/wasi/src/sockets/udp.rs
3137 views
1
use crate::runtime::with_ambient_tokio_runtime;
2
use crate::sockets::util::{
3
ErrorCode, get_unicast_hop_limit, is_valid_address_family, is_valid_remote_address,
4
receive_buffer_size, send_buffer_size, set_receive_buffer_size, set_send_buffer_size,
5
set_unicast_hop_limit, udp_bind, udp_disconnect, udp_socket,
6
};
7
use crate::sockets::{SocketAddrCheck, SocketAddressFamily, WasiSocketsCtx};
8
use cap_net_ext::AddressFamily;
9
use io_lifetimes::AsSocketlike as _;
10
use io_lifetimes::raw::{FromRawSocketlike as _, IntoRawSocketlike as _};
11
use rustix::io::Errno;
12
use rustix::net::connect;
13
use std::net::SocketAddr;
14
use std::sync::Arc;
15
use tracing::debug;
16
17
/// The state of a UDP socket.
///
/// This tracks where a socket is in its bind/connect lifecycle: a socket
/// starts in `Default`, may go through `BindStarted` to `Bound`, and can be
/// "connected" to (and later disconnected from) a single remote peer.
enum UdpState {
    /// The initial state for a newly-created socket.
    Default,

    /// A `bind` operation has started but has yet to complete with
    /// `finish_bind`.
    BindStarted,

    /// The socket has a local address but is not connected to a specific
    /// remote peer. UDP has no notion of "listening"; in this state the
    /// socket is simply bound.
    Bound,

    /// The socket is "connected" to a peer address. The stored address is
    /// the one most recently passed to `connect`.
    #[cfg_attr(
        not(feature = "p3"),
        expect(dead_code, reason = "p2 has its own way of managing sending/receiving")
    )]
    Connected(SocketAddr),
}
40
41
/// A host UDP socket, plus associated bookkeeping.
42
///
43
/// The inner state is wrapped in an Arc because the same underlying socket is
44
/// used for implementing the stream types.
45
pub struct UdpSocket {
46
socket: Arc<tokio::net::UdpSocket>,
47
48
/// The current state in the bind/connect progression.
49
udp_state: UdpState,
50
51
/// Socket address family.
52
family: SocketAddressFamily,
53
54
/// If set, use this custom check for addrs, otherwise use what's in
55
/// `WasiSocketsCtx`.
56
socket_addr_check: Option<SocketAddrCheck>,
57
}
58
59
impl UdpSocket {
60
/// Create a new socket in the given family.
61
pub(crate) fn new(cx: &WasiSocketsCtx, family: AddressFamily) -> Result<Self, ErrorCode> {
62
cx.allowed_network_uses.check_allowed_udp()?;
63
64
// Delegate socket creation to cap_net_ext. They handle a couple of things for us:
65
// - On Windows: call WSAStartup if not done before.
66
// - Set the NONBLOCK and CLOEXEC flags. Either immediately during socket creation,
67
// or afterwards using ioctl or fcntl. Exact method depends on the platform.
68
69
let fd = udp_socket(family)?;
70
71
let socket_address_family = match family {
72
AddressFamily::Ipv4 => SocketAddressFamily::Ipv4,
73
AddressFamily::Ipv6 => {
74
rustix::net::sockopt::set_ipv6_v6only(&fd, true)?;
75
SocketAddressFamily::Ipv6
76
}
77
};
78
79
let socket = with_ambient_tokio_runtime(|| {
80
tokio::net::UdpSocket::try_from(unsafe {
81
std::net::UdpSocket::from_raw_socketlike(fd.into_raw_socketlike())
82
})
83
})?;
84
85
Ok(Self {
86
socket: Arc::new(socket),
87
udp_state: UdpState::Default,
88
family: socket_address_family,
89
socket_addr_check: None,
90
})
91
}
92
93
pub(crate) fn bind(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
94
if !matches!(self.udp_state, UdpState::Default) {
95
return Err(ErrorCode::InvalidState);
96
}
97
if !is_valid_address_family(addr.ip(), self.family) {
98
return Err(ErrorCode::InvalidArgument);
99
}
100
udp_bind(&self.socket, addr)?;
101
self.udp_state = UdpState::BindStarted;
102
Ok(())
103
}
104
105
pub(crate) fn finish_bind(&mut self) -> Result<(), ErrorCode> {
106
match self.udp_state {
107
UdpState::BindStarted => {
108
self.udp_state = UdpState::Bound;
109
Ok(())
110
}
111
_ => Err(ErrorCode::NotInProgress),
112
}
113
}
114
115
pub(crate) fn is_connected(&self) -> bool {
116
matches!(self.udp_state, UdpState::Connected(..))
117
}
118
119
pub(crate) fn is_bound(&self) -> bool {
120
matches!(self.udp_state, UdpState::Connected(..) | UdpState::Bound)
121
}
122
123
pub(crate) fn disconnect(&mut self) -> Result<(), ErrorCode> {
124
if !self.is_connected() {
125
return Err(ErrorCode::InvalidState);
126
}
127
udp_disconnect(&self.socket)?;
128
self.udp_state = UdpState::Bound;
129
Ok(())
130
}
131
132
/// Connect using p2 semantics. (no implicit bind)
133
pub(crate) fn connect_p2(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
134
match self.udp_state {
135
UdpState::Bound | UdpState::Connected(_) => {}
136
_ => return Err(ErrorCode::InvalidState),
137
}
138
139
self.connect_common(addr)
140
}
141
142
/// Connect using p3 semantics. (with implicit bind)
143
#[cfg(feature = "p3")]
144
pub(crate) fn connect_p3(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
145
match self.udp_state {
146
UdpState::Default | UdpState::Bound | UdpState::Connected(_) => {}
147
_ => return Err(ErrorCode::InvalidState),
148
}
149
150
self.connect_common(addr)
151
}
152
153
fn connect_common(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
154
if !is_valid_address_family(addr.ip(), self.family) || !is_valid_remote_address(addr) {
155
return Err(ErrorCode::InvalidArgument);
156
}
157
158
// We disconnect & (re)connect in two distinct steps for two reasons:
159
// - To leave our socket instance in a consistent state in case the
160
// connect fails.
161
// - When reconnecting to a different address, Linux sometimes fails
162
// if there isn't a disconnect in between.
163
164
// Step #1: Disconnect
165
if let UdpState::Connected(..) = self.udp_state {
166
udp_disconnect(&self.socket)?;
167
self.udp_state = UdpState::Bound;
168
}
169
// Step #2: (Re)connect
170
connect(&self.socket, &addr).map_err(|error| match error {
171
Errno::AFNOSUPPORT => ErrorCode::InvalidArgument, // See `udp_bind` implementation.
172
Errno::INPROGRESS => {
173
debug!("UDP connect returned EINPROGRESS, which should never happen");
174
ErrorCode::Unknown
175
}
176
err => err.into(),
177
})?;
178
self.udp_state = UdpState::Connected(addr);
179
Ok(())
180
}
181
182
/// Send data using p3 semantics. (with implicit bind)
183
#[cfg(feature = "p3")]
184
pub(crate) fn send_p3(
185
&mut self,
186
buf: Vec<u8>,
187
addr: Option<SocketAddr>,
188
) -> impl Future<Output = Result<(), ErrorCode>> + use<> {
189
enum Mode {
190
Send(Arc<tokio::net::UdpSocket>),
191
SendTo(Arc<tokio::net::UdpSocket>, SocketAddr),
192
}
193
let mut socket = match (&self.udp_state, addr) {
194
(UdpState::BindStarted, _) => Err(ErrorCode::InvalidState),
195
(UdpState::Default | UdpState::Bound, None) => Err(ErrorCode::InvalidArgument),
196
(UdpState::Default | UdpState::Bound, Some(addr)) => {
197
Ok(Mode::SendTo(Arc::clone(&self.socket), addr))
198
}
199
(UdpState::Connected(..), None) => Ok(Mode::Send(Arc::clone(&self.socket))),
200
(UdpState::Connected(caddr), Some(addr)) => {
201
if addr == *caddr {
202
Ok(Mode::Send(Arc::clone(&self.socket)))
203
} else {
204
Err(ErrorCode::InvalidArgument)
205
}
206
}
207
};
208
209
// Send may be called without a prior bind or connect. In that case, the
210
// first send will automatically assign a free local port. This is
211
// normally performed by the OS itself. However, if the `send` syscall
212
// failed, we can't reliably know which state the socket is in at the
213
// kernel level and our own `udp_state` bookkeeping may have become
214
// out-of-sync.
215
// To avoid that, we perform the implicit bind ourselves here. This way,
216
// we always leave the socket in a consistent state: Bound.
217
if socket.is_ok()
218
&& let UdpState::Default = self.udp_state
219
{
220
let implicit_addr = crate::sockets::util::implicit_bind_addr(self.family);
221
match udp_bind(&self.socket, implicit_addr) {
222
Ok(()) => {
223
self.udp_state = UdpState::Bound;
224
}
225
Err(e) => {
226
socket = Err(e);
227
}
228
}
229
}
230
231
async move {
232
match socket? {
233
Mode::Send(socket) => send(&socket, &buf).await,
234
Mode::SendTo(socket, addr) => send_to(&socket, &buf, addr).await,
235
}
236
}
237
}
238
239
/// Receive data using p3 semantics.
240
#[cfg(feature = "p3")]
241
pub(crate) fn receive_p3(
242
&self,
243
) -> impl Future<Output = Result<(Vec<u8>, SocketAddr), ErrorCode>> + use<> {
244
enum Mode {
245
Recv(Arc<tokio::net::UdpSocket>, SocketAddr),
246
RecvFrom(Arc<tokio::net::UdpSocket>),
247
}
248
let socket = match self.udp_state {
249
UdpState::Default | UdpState::BindStarted => Err(ErrorCode::InvalidState),
250
UdpState::Bound => Ok(Mode::RecvFrom(Arc::clone(&self.socket))),
251
UdpState::Connected(addr) => Ok(Mode::Recv(Arc::clone(&self.socket), addr)),
252
};
253
async move {
254
let socket = socket?;
255
let mut buf = vec![0; super::MAX_UDP_DATAGRAM_SIZE];
256
let (n, addr) = match socket {
257
Mode::Recv(socket, addr) => {
258
let n = socket.recv(&mut buf).await?;
259
(n, addr)
260
}
261
Mode::RecvFrom(socket) => {
262
let (n, addr) = socket.recv_from(&mut buf).await?;
263
(n, addr)
264
}
265
};
266
buf.truncate(n);
267
Ok((buf, addr))
268
}
269
}
270
271
pub(crate) fn local_address(&self) -> Result<SocketAddr, ErrorCode> {
272
if matches!(self.udp_state, UdpState::Default | UdpState::BindStarted) {
273
return Err(ErrorCode::InvalidState);
274
}
275
let addr = self
276
.socket
277
.as_socketlike_view::<std::net::UdpSocket>()
278
.local_addr()?;
279
Ok(addr)
280
}
281
282
pub(crate) fn remote_address(&self) -> Result<SocketAddr, ErrorCode> {
283
if !matches!(self.udp_state, UdpState::Connected(..)) {
284
return Err(ErrorCode::InvalidState);
285
}
286
let addr = self
287
.socket
288
.as_socketlike_view::<std::net::UdpSocket>()
289
.peer_addr()?;
290
Ok(addr)
291
}
292
293
pub(crate) fn address_family(&self) -> SocketAddressFamily {
294
self.family
295
}
296
297
pub(crate) fn unicast_hop_limit(&self) -> Result<u8, ErrorCode> {
298
let n = get_unicast_hop_limit(&self.socket, self.family)?;
299
Ok(n)
300
}
301
302
pub(crate) fn set_unicast_hop_limit(&self, value: u8) -> Result<(), ErrorCode> {
303
set_unicast_hop_limit(&self.socket, self.family, value)?;
304
Ok(())
305
}
306
307
pub(crate) fn receive_buffer_size(&self) -> Result<u64, ErrorCode> {
308
let n = receive_buffer_size(&self.socket)?;
309
Ok(n)
310
}
311
312
pub(crate) fn set_receive_buffer_size(&self, value: u64) -> Result<(), ErrorCode> {
313
set_receive_buffer_size(&self.socket, value)?;
314
Ok(())
315
}
316
317
pub(crate) fn send_buffer_size(&self) -> Result<u64, ErrorCode> {
318
let n = send_buffer_size(&self.socket)?;
319
Ok(n)
320
}
321
322
pub(crate) fn set_send_buffer_size(&self, value: u64) -> Result<(), ErrorCode> {
323
set_send_buffer_size(&self.socket, value)?;
324
Ok(())
325
}
326
327
pub(crate) fn socket(&self) -> &Arc<tokio::net::UdpSocket> {
328
&self.socket
329
}
330
331
pub(crate) fn socket_addr_check(&self) -> Option<&SocketAddrCheck> {
332
self.socket_addr_check.as_ref()
333
}
334
335
pub(crate) fn set_socket_addr_check(&mut self, check: Option<SocketAddrCheck>) {
336
self.socket_addr_check = check;
337
}
338
}
339
340
/// Send `buf` as a single datagram on a connected socket.
///
/// From the Rust standard library documentation:
/// > Note that the operating system may refuse buffers larger than 65507.
/// > However, partial writes are not possible until buffer sizes above `i32::MAX`.
///
/// For example, on Windows, at most `i32::MAX` bytes will be written. Any
/// short write is therefore treated as a failure and reported as
/// [`ErrorCode::Unknown`].
#[cfg(feature = "p3")]
async fn send(socket: &tokio::net::UdpSocket, buf: &[u8]) -> Result<(), ErrorCode> {
    let written = socket.send(buf).await?;
    if written == buf.len() {
        Ok(())
    } else {
        Err(ErrorCode::Unknown)
    }
}
354
355
/// Send `buf` as a single datagram to `addr`.
///
/// Per the Rust standard library documentation, partial writes are not
/// possible until buffer sizes exceed `i32::MAX` (e.g. on Windows at most
/// `i32::MAX` bytes are written). Any short write is therefore treated as a
/// failure and reported as [`ErrorCode::Unknown`].
#[cfg(feature = "p3")]
async fn send_to(
    socket: &tokio::net::UdpSocket,
    buf: &[u8],
    addr: SocketAddr,
) -> Result<(), ErrorCode> {
    let written = socket.send_to(buf, addr).await?;
    if written == buf.len() {
        Ok(())
    } else {
        Err(ErrorCode::Unknown)
    }
}
}
369
370