Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
bytecodealliance
GitHub Repository: bytecodealliance/wasmtime
Path: blob/main/crates/wasi/src/sockets/tcp.rs
1692 views
1
use crate::p2::P2TcpStreamingState;
2
use crate::runtime::with_ambient_tokio_runtime;
3
use crate::sockets::util::{
4
ErrorCode, get_unicast_hop_limit, is_valid_address_family, is_valid_remote_address,
5
is_valid_unicast_address, receive_buffer_size, send_buffer_size, set_keep_alive_count,
6
set_keep_alive_idle_time, set_keep_alive_interval, set_receive_buffer_size,
7
set_send_buffer_size, set_unicast_hop_limit, tcp_bind,
8
};
9
use crate::sockets::{DEFAULT_TCP_BACKLOG, SocketAddressFamily, WasiSocketsCtx};
10
use io_lifetimes::AsSocketlike as _;
11
use io_lifetimes::views::SocketlikeView;
12
use rustix::io::Errno;
13
use rustix::net::sockopt;
14
use std::fmt::Debug;
15
use std::io;
16
use std::mem;
17
use std::net::SocketAddr;
18
use std::pin::Pin;
19
use std::sync::Arc;
20
use std::task::{Context, Poll, Waker};
21
use std::time::Duration;
22
23
/// The state of a TCP socket.
24
///
25
/// This represents the various states a socket can be in during the
26
/// activities of binding, listening, accepting, and connecting. Note that this
27
/// state machine encompasses both WASIp2 and WASIp3.
28
enum TcpState {
29
/// The initial state for a newly-created socket.
30
///
31
/// From here a socket can transition to `BindStarted`, `ListenStarted`, or
32
/// `Connecting`.
33
Default(tokio::net::TcpSocket),
34
35
/// A state indicating that a bind has been started and must be finished
36
/// subsequently with `finish_bind`.
37
///
38
/// From here a socket can transition to `Bound`.
39
BindStarted(tokio::net::TcpSocket),
40
41
/// Binding finished. The socket has an address but is not yet listening for
42
/// connections.
43
///
44
/// From here a socket can transition to `ListenStarted`, or `Connecting`.
45
Bound(tokio::net::TcpSocket),
46
47
/// Listening on a socket has started and must be completed with
48
/// `finish_listen`.
49
///
50
/// From here a socket can transition to `Listening`.
51
ListenStarted(tokio::net::TcpSocket),
52
53
/// The socket is now listening and waiting for an incoming connection.
54
///
55
/// Sockets will not leave this state.
56
Listening {
57
/// The raw tokio-basd TCP listener managing the underyling socket.
58
listener: Arc<tokio::net::TcpListener>,
59
60
/// The last-accepted connection, set during the `ready` method and read
61
/// during the `accept` method. Note that this is only used for WASIp2
62
/// at this time.
63
pending_accept: Option<io::Result<tokio::net::TcpStream>>,
64
},
65
66
/// An outgoing connection is started.
67
///
68
/// This is created via the `start_connect` method. The payload here is an
69
/// optionally-specified owned future for the result of the connect. In
70
/// WASIp2 the future lives here, but in WASIp3 it lives on the event loop
71
/// so this is `None`.
72
///
73
/// From here a socket can transition to `ConnectReady` or `Connected`.
74
Connecting(Option<Pin<Box<dyn Future<Output = io::Result<tokio::net::TcpStream>> + Send>>>),
75
76
/// A connection via `Connecting` has completed.
77
///
78
/// This is present for WASIp2 where the `Connecting` state stores `Some` of
79
/// a future, and the result of that future is recorded here when it
80
/// finishes as part of the `ready` method.
81
///
82
/// From here a socket can transition to `Connected`.
83
ConnectReady(io::Result<tokio::net::TcpStream>),
84
85
/// A connection has been established.
86
///
87
/// This is created either via `finish_connect` or for freshly accepted
88
/// sockets from a TCP listener.
89
///
90
/// From here a socket can transition to `Receiving` or `P2Streaming`.
91
Connected(Arc<tokio::net::TcpStream>),
92
93
/// A connection has been established and `receive` has been called.
94
///
95
/// A socket will not transition out of this state.
96
#[cfg(feature = "p3")]
97
Receiving(Arc<tokio::net::TcpStream>),
98
99
/// This is a WASIp2-bound socket which stores some extra state for
100
/// read/write streams to handle TCP shutdown.
101
///
102
/// A socket will not transition out of this state.
103
P2Streaming(Box<P2TcpStreamingState>),
104
105
/// This is not actually a socket but a deferred error.
106
///
107
/// This error came out of `accept` and is deferred until the socket is
108
/// operated on.
109
#[cfg(feature = "p3")]
110
Error(io::Error),
111
112
/// The socket is closed and no more operations can be performed.
113
Closed,
114
}
115
116
impl Debug for TcpState {
117
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118
match self {
119
Self::Default(_) => f.debug_tuple("Default").finish(),
120
Self::BindStarted(_) => f.debug_tuple("BindStarted").finish(),
121
Self::Bound(_) => f.debug_tuple("Bound").finish(),
122
Self::ListenStarted { .. } => f.debug_tuple("ListenStarted").finish(),
123
Self::Listening { .. } => f.debug_tuple("Listening").finish(),
124
Self::Connecting(..) => f.debug_tuple("Connecting").finish(),
125
Self::ConnectReady(..) => f.debug_tuple("ConnectReady").finish(),
126
Self::Connected { .. } => f.debug_tuple("Connected").finish(),
127
#[cfg(feature = "p3")]
128
Self::Receiving { .. } => f.debug_tuple("Receiving").finish(),
129
Self::P2Streaming(_) => f.debug_tuple("P2Streaming").finish(),
130
#[cfg(feature = "p3")]
131
Self::Error(..) => f.debug_tuple("Error").finish(),
132
Self::Closed => write!(f, "Closed"),
133
}
134
}
135
}
136
137
/// A host TCP socket, plus associated bookkeeping.
138
pub struct TcpSocket {
139
/// The current state in the bind/listen/accept/connect progression.
140
tcp_state: TcpState,
141
142
/// The desired listen queue size.
143
listen_backlog_size: u32,
144
145
family: SocketAddressFamily,
146
147
options: NonInheritedOptions,
148
}
149
150
impl TcpSocket {
151
/// Create a new socket in the given family.
152
pub(crate) fn new(
153
ctx: &WasiSocketsCtx,
154
family: SocketAddressFamily,
155
) -> Result<Self, ErrorCode> {
156
ctx.allowed_network_uses.check_allowed_tcp()?;
157
158
with_ambient_tokio_runtime(|| {
159
let socket = match family {
160
SocketAddressFamily::Ipv4 => tokio::net::TcpSocket::new_v4()?,
161
SocketAddressFamily::Ipv6 => {
162
let socket = tokio::net::TcpSocket::new_v6()?;
163
sockopt::set_ipv6_v6only(&socket, true)?;
164
socket
165
}
166
};
167
168
Ok(Self::from_state(TcpState::Default(socket), family))
169
})
170
}
171
172
#[cfg(feature = "p3")]
173
pub(crate) fn new_error(err: io::Error, family: SocketAddressFamily) -> Self {
174
TcpSocket::from_state(TcpState::Error(err), family)
175
}
176
177
/// Creates a new socket with the `result` of an accepted socket from a
178
/// `TcpListener`.
179
///
180
/// This will handle the `result` internally and `result` should be the raw
181
/// result from a TCP listen operation.
182
pub(crate) fn new_accept(
183
result: io::Result<tokio::net::TcpStream>,
184
options: &NonInheritedOptions,
185
family: SocketAddressFamily,
186
) -> io::Result<Self> {
187
let client = result.map_err(|err| match Errno::from_io_error(&err) {
188
// From: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-accept#:~:text=WSAEINPROGRESS
189
// > WSAEINPROGRESS: A blocking Windows Sockets 1.1 call is in progress,
190
// > or the service provider is still processing a callback function.
191
//
192
// wasi-sockets doesn't have an equivalent to the EINPROGRESS error,
193
// because in POSIX this error is only returned by a non-blocking
194
// `connect` and wasi-sockets has a different solution for that.
195
#[cfg(windows)]
196
Some(Errno::INPROGRESS) => Errno::INTR.into(),
197
198
// Normalize Linux' non-standard behavior.
199
//
200
// From https://man7.org/linux/man-pages/man2/accept.2.html:
201
// > Linux accept() passes already-pending network errors on the
202
// > new socket as an error code from accept(). This behavior
203
// > differs from other BSD socket implementations. (...)
204
#[cfg(target_os = "linux")]
205
Some(
206
Errno::CONNRESET
207
| Errno::NETRESET
208
| Errno::HOSTUNREACH
209
| Errno::HOSTDOWN
210
| Errno::NETDOWN
211
| Errno::NETUNREACH
212
| Errno::PROTO
213
| Errno::NOPROTOOPT
214
| Errno::NONET
215
| Errno::OPNOTSUPP,
216
) => Errno::CONNABORTED.into(),
217
218
_ => err,
219
})?;
220
options.apply(family, &client);
221
Ok(Self::from_state(
222
TcpState::Connected(Arc::new(client)),
223
family,
224
))
225
}
226
227
/// Create a `TcpSocket` from an existing socket.
228
fn from_state(state: TcpState, family: SocketAddressFamily) -> Self {
229
Self {
230
tcp_state: state,
231
listen_backlog_size: DEFAULT_TCP_BACKLOG,
232
family,
233
options: Default::default(),
234
}
235
}
236
237
pub(crate) fn as_std_view(&self) -> Result<SocketlikeView<'_, std::net::TcpStream>, ErrorCode> {
238
match &self.tcp_state {
239
TcpState::Default(socket)
240
| TcpState::BindStarted(socket)
241
| TcpState::Bound(socket)
242
| TcpState::ListenStarted(socket) => Ok(socket.as_socketlike_view()),
243
TcpState::Connected(stream) => Ok(stream.as_socketlike_view()),
244
#[cfg(feature = "p3")]
245
TcpState::Receiving(stream) => Ok(stream.as_socketlike_view()),
246
TcpState::Listening { listener, .. } => Ok(listener.as_socketlike_view()),
247
TcpState::P2Streaming(state) => Ok(state.stream.as_socketlike_view()),
248
TcpState::Connecting(..) | TcpState::ConnectReady(_) | TcpState::Closed => {
249
Err(ErrorCode::InvalidState)
250
}
251
#[cfg(feature = "p3")]
252
TcpState::Error(err) => Err(err.into()),
253
}
254
}
255
256
pub(crate) fn start_bind(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
257
let ip = addr.ip();
258
if !is_valid_unicast_address(ip) || !is_valid_address_family(ip, self.family) {
259
return Err(ErrorCode::InvalidArgument);
260
}
261
match mem::replace(&mut self.tcp_state, TcpState::Closed) {
262
TcpState::Default(sock) => {
263
if let Err(err) = tcp_bind(&sock, addr) {
264
self.tcp_state = TcpState::Default(sock);
265
Err(err)
266
} else {
267
self.tcp_state = TcpState::BindStarted(sock);
268
Ok(())
269
}
270
}
271
tcp_state => {
272
self.tcp_state = tcp_state;
273
Err(ErrorCode::InvalidState)
274
}
275
}
276
}
277
278
pub(crate) fn finish_bind(&mut self) -> Result<(), ErrorCode> {
279
match mem::replace(&mut self.tcp_state, TcpState::Closed) {
280
TcpState::BindStarted(socket) => {
281
self.tcp_state = TcpState::Bound(socket);
282
Ok(())
283
}
284
current_state => {
285
// Reset the state so that the outside world doesn't see this socket as closed
286
self.tcp_state = current_state;
287
Err(ErrorCode::NotInProgress)
288
}
289
}
290
}
291
292
pub(crate) fn start_connect(
293
&mut self,
294
addr: &SocketAddr,
295
) -> Result<tokio::net::TcpSocket, ErrorCode> {
296
match self.tcp_state {
297
TcpState::Default(..) | TcpState::Bound(..) => {}
298
TcpState::Connecting(..) => {
299
return Err(ErrorCode::ConcurrencyConflict);
300
}
301
_ => return Err(ErrorCode::InvalidState),
302
};
303
304
if !is_valid_unicast_address(addr.ip())
305
|| !is_valid_remote_address(*addr)
306
|| !is_valid_address_family(addr.ip(), self.family)
307
{
308
return Err(ErrorCode::InvalidArgument);
309
};
310
311
let (TcpState::Default(tokio_socket) | TcpState::Bound(tokio_socket)) =
312
mem::replace(&mut self.tcp_state, TcpState::Connecting(None))
313
else {
314
unreachable!();
315
};
316
317
Ok(tokio_socket)
318
}
319
320
/// For WASIp2 this is used to record the actual connection future as part
321
/// of `start_connect` within this socket state.
322
pub(crate) fn set_pending_connect(
323
&mut self,
324
future: impl Future<Output = io::Result<tokio::net::TcpStream>> + Send + 'static,
325
) -> Result<(), ErrorCode> {
326
match &mut self.tcp_state {
327
TcpState::Connecting(slot @ None) => {
328
*slot = Some(Box::pin(future));
329
Ok(())
330
}
331
_ => Err(ErrorCode::InvalidState),
332
}
333
}
334
335
/// For WASIp2 this retreives the result from the future passed to
336
/// `set_pending_connect`.
337
///
338
/// Return states here are:
339
///
340
/// * `Ok(Some(res))` - where `res` is the result of the connect operation.
341
/// * `Ok(None)` - the connect operation isn't ready yet.
342
/// * `Err(e)` - a connect operation is not in progress.
343
pub(crate) fn take_pending_connect(
344
&mut self,
345
) -> Result<Option<io::Result<tokio::net::TcpStream>>, ErrorCode> {
346
match mem::replace(&mut self.tcp_state, TcpState::Connecting(None)) {
347
TcpState::ConnectReady(result) => Ok(Some(result)),
348
TcpState::Connecting(Some(mut future)) => {
349
let mut cx = Context::from_waker(Waker::noop());
350
match with_ambient_tokio_runtime(|| future.as_mut().poll(&mut cx)) {
351
Poll::Ready(result) => Ok(Some(result)),
352
Poll::Pending => {
353
self.tcp_state = TcpState::Connecting(Some(future));
354
Ok(None)
355
}
356
}
357
}
358
current_state => {
359
self.tcp_state = current_state;
360
Err(ErrorCode::NotInProgress)
361
}
362
}
363
}
364
365
pub(crate) fn finish_connect(
366
&mut self,
367
result: io::Result<tokio::net::TcpStream>,
368
) -> Result<(), ErrorCode> {
369
if !matches!(self.tcp_state, TcpState::Connecting(None)) {
370
return Err(ErrorCode::InvalidState);
371
}
372
match result {
373
Ok(stream) => {
374
self.tcp_state = TcpState::Connected(Arc::new(stream));
375
Ok(())
376
}
377
Err(err) => {
378
self.tcp_state = TcpState::Closed;
379
Err(ErrorCode::from(err))
380
}
381
}
382
}
383
384
pub(crate) fn start_listen(&mut self) -> Result<(), ErrorCode> {
385
match mem::replace(&mut self.tcp_state, TcpState::Closed) {
386
TcpState::Bound(tokio_socket) => {
387
self.tcp_state = TcpState::ListenStarted(tokio_socket);
388
Ok(())
389
}
390
previous_state => {
391
self.tcp_state = previous_state;
392
Err(ErrorCode::InvalidState)
393
}
394
}
395
}
396
397
pub(crate) fn finish_listen(&mut self) -> Result<(), ErrorCode> {
398
let tokio_socket = match mem::replace(&mut self.tcp_state, TcpState::Closed) {
399
TcpState::ListenStarted(tokio_socket) => tokio_socket,
400
previous_state => {
401
self.tcp_state = previous_state;
402
return Err(ErrorCode::NotInProgress);
403
}
404
};
405
406
match with_ambient_tokio_runtime(|| tokio_socket.listen(self.listen_backlog_size)) {
407
Ok(listener) => {
408
self.tcp_state = TcpState::Listening {
409
listener: Arc::new(listener),
410
pending_accept: None,
411
};
412
Ok(())
413
}
414
Err(err) => {
415
self.tcp_state = TcpState::Closed;
416
417
Err(match Errno::from_io_error(&err) {
418
// See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-listen#:~:text=WSAEMFILE
419
// According to the docs, `listen` can return EMFILE on Windows.
420
// This is odd, because we're not trying to create a new socket
421
// or file descriptor of any kind. So we rewrite it to less
422
// surprising error code.
423
//
424
// At the time of writing, this behavior has never been experimentally
425
// observed by any of the wasmtime authors, so we're relying fully
426
// on Microsoft's documentation here.
427
#[cfg(windows)]
428
Some(Errno::MFILE) => Errno::NOBUFS.into(),
429
430
_ => err.into(),
431
})
432
}
433
}
434
}
435
436
pub(crate) fn accept(&mut self) -> Result<Option<Self>, ErrorCode> {
437
let TcpState::Listening {
438
listener,
439
pending_accept,
440
} = &mut self.tcp_state
441
else {
442
return Err(ErrorCode::InvalidState);
443
};
444
445
let result = match pending_accept.take() {
446
Some(result) => result,
447
None => {
448
let mut cx = std::task::Context::from_waker(Waker::noop());
449
match with_ambient_tokio_runtime(|| listener.poll_accept(&mut cx))
450
.map_ok(|(stream, _)| stream)
451
{
452
Poll::Ready(result) => result,
453
Poll::Pending => return Ok(None),
454
}
455
}
456
};
457
458
Ok(Some(Self::new_accept(result, &self.options, self.family)?))
459
}
460
461
#[cfg(feature = "p3")]
462
pub(crate) fn start_receive(&mut self) -> Option<&Arc<tokio::net::TcpStream>> {
463
match mem::replace(&mut self.tcp_state, TcpState::Closed) {
464
TcpState::Connected(stream) => {
465
self.tcp_state = TcpState::Receiving(stream);
466
Some(self.tcp_stream_arc().unwrap())
467
}
468
prev => {
469
self.tcp_state = prev;
470
None
471
}
472
}
473
}
474
475
pub(crate) fn local_address(&self) -> Result<SocketAddr, ErrorCode> {
476
match &self.tcp_state {
477
TcpState::Bound(socket) => Ok(socket.local_addr()?),
478
TcpState::Connected(stream) => Ok(stream.local_addr()?),
479
#[cfg(feature = "p3")]
480
TcpState::Receiving(stream) => Ok(stream.local_addr()?),
481
TcpState::P2Streaming(state) => Ok(state.stream.local_addr()?),
482
TcpState::Listening { listener, .. } => Ok(listener.local_addr()?),
483
#[cfg(feature = "p3")]
484
TcpState::Error(err) => Err(err.into()),
485
_ => Err(ErrorCode::InvalidState),
486
}
487
}
488
489
pub(crate) fn remote_address(&self) -> Result<SocketAddr, ErrorCode> {
490
let stream = self.tcp_stream_arc()?;
491
let addr = stream.peer_addr()?;
492
Ok(addr)
493
}
494
495
pub(crate) fn is_listening(&self) -> bool {
496
matches!(self.tcp_state, TcpState::Listening { .. })
497
}
498
499
pub(crate) fn address_family(&self) -> SocketAddressFamily {
500
self.family
501
}
502
503
pub(crate) fn set_listen_backlog_size(&mut self, value: u64) -> Result<(), ErrorCode> {
504
const MIN_BACKLOG: u32 = 1;
505
const MAX_BACKLOG: u32 = i32::MAX as u32; // OS'es will most likely limit it down even further.
506
507
if value == 0 {
508
return Err(ErrorCode::InvalidArgument);
509
}
510
// Silently clamp backlog size. This is OK for us to do, because operating systems do this too.
511
let value = value
512
.try_into()
513
.unwrap_or(MAX_BACKLOG)
514
.clamp(MIN_BACKLOG, MAX_BACKLOG);
515
match &self.tcp_state {
516
TcpState::Default(..) | TcpState::Bound(..) => {
517
// Socket not listening yet. Stash value for first invocation to `listen`.
518
self.listen_backlog_size = value;
519
Ok(())
520
}
521
TcpState::Listening { listener, .. } => {
522
// Try to update the backlog by calling `listen` again.
523
// Not all platforms support this. We'll only update our own value if the OS supports changing the backlog size after the fact.
524
if rustix::net::listen(&listener, value.try_into().unwrap_or(i32::MAX)).is_err() {
525
return Err(ErrorCode::NotSupported);
526
}
527
self.listen_backlog_size = value;
528
Ok(())
529
}
530
#[cfg(feature = "p3")]
531
TcpState::Error(err) => Err(err.into()),
532
_ => Err(ErrorCode::InvalidState),
533
}
534
}
535
536
pub(crate) fn keep_alive_enabled(&self) -> Result<bool, ErrorCode> {
537
let fd = &*self.as_std_view()?;
538
let v = sockopt::socket_keepalive(fd)?;
539
Ok(v)
540
}
541
542
pub(crate) fn set_keep_alive_enabled(&self, value: bool) -> Result<(), ErrorCode> {
543
let fd = &*self.as_std_view()?;
544
sockopt::set_socket_keepalive(fd, value)?;
545
Ok(())
546
}
547
548
pub(crate) fn keep_alive_idle_time(&self) -> Result<u64, ErrorCode> {
549
let fd = &*self.as_std_view()?;
550
let v = sockopt::tcp_keepidle(fd)?;
551
Ok(v.as_nanos().try_into().unwrap_or(u64::MAX))
552
}
553
554
pub(crate) fn set_keep_alive_idle_time(&mut self, value: u64) -> Result<(), ErrorCode> {
555
let value = {
556
let fd = self.as_std_view()?;
557
set_keep_alive_idle_time(&*fd, value)?
558
};
559
self.options.set_keep_alive_idle_time(value);
560
Ok(())
561
}
562
563
pub(crate) fn keep_alive_interval(&self) -> Result<u64, ErrorCode> {
564
let fd = &*self.as_std_view()?;
565
let v = sockopt::tcp_keepintvl(fd)?;
566
Ok(v.as_nanos().try_into().unwrap_or(u64::MAX))
567
}
568
569
pub(crate) fn set_keep_alive_interval(&self, value: u64) -> Result<(), ErrorCode> {
570
let fd = &*self.as_std_view()?;
571
set_keep_alive_interval(fd, Duration::from_nanos(value))?;
572
Ok(())
573
}
574
575
pub(crate) fn keep_alive_count(&self) -> Result<u32, ErrorCode> {
576
let fd = &*self.as_std_view()?;
577
let v = sockopt::tcp_keepcnt(fd)?;
578
Ok(v)
579
}
580
581
pub(crate) fn set_keep_alive_count(&self, value: u32) -> Result<(), ErrorCode> {
582
let fd = &*self.as_std_view()?;
583
set_keep_alive_count(fd, value)?;
584
Ok(())
585
}
586
587
pub(crate) fn hop_limit(&self) -> Result<u8, ErrorCode> {
588
let fd = &*self.as_std_view()?;
589
let n = get_unicast_hop_limit(fd, self.family)?;
590
Ok(n)
591
}
592
593
pub(crate) fn set_hop_limit(&mut self, value: u8) -> Result<(), ErrorCode> {
594
{
595
let fd = &*self.as_std_view()?;
596
set_unicast_hop_limit(fd, self.family, value)?;
597
}
598
self.options.set_hop_limit(value);
599
Ok(())
600
}
601
602
pub(crate) fn receive_buffer_size(&self) -> Result<u64, ErrorCode> {
603
let fd = &*self.as_std_view()?;
604
let n = receive_buffer_size(fd)?;
605
Ok(n)
606
}
607
608
pub(crate) fn set_receive_buffer_size(&mut self, value: u64) -> Result<(), ErrorCode> {
609
let res = {
610
let fd = &*self.as_std_view()?;
611
set_receive_buffer_size(fd, value)?
612
};
613
self.options.set_receive_buffer_size(res);
614
Ok(())
615
}
616
617
pub(crate) fn send_buffer_size(&self) -> Result<u64, ErrorCode> {
618
let fd = &*self.as_std_view()?;
619
let n = send_buffer_size(fd)?;
620
Ok(n)
621
}
622
623
pub(crate) fn set_send_buffer_size(&mut self, value: u64) -> Result<(), ErrorCode> {
624
let res = {
625
let fd = &*self.as_std_view()?;
626
set_send_buffer_size(fd, value)?
627
};
628
self.options.set_send_buffer_size(res);
629
Ok(())
630
}
631
632
#[cfg(feature = "p3")]
633
pub(crate) fn non_inherited_options(&self) -> &NonInheritedOptions {
634
&self.options
635
}
636
637
#[cfg(feature = "p3")]
638
pub(crate) fn tcp_listener_arc(&self) -> Result<&Arc<tokio::net::TcpListener>, ErrorCode> {
639
match &self.tcp_state {
640
TcpState::Listening { listener, .. } => Ok(listener),
641
#[cfg(feature = "p3")]
642
TcpState::Error(err) => Err(err.into()),
643
_ => Err(ErrorCode::InvalidState),
644
}
645
}
646
647
pub(crate) fn tcp_stream_arc(&self) -> Result<&Arc<tokio::net::TcpStream>, ErrorCode> {
648
match &self.tcp_state {
649
TcpState::Connected(socket) => Ok(socket),
650
#[cfg(feature = "p3")]
651
TcpState::Receiving(socket) => Ok(socket),
652
TcpState::P2Streaming(state) => Ok(&state.stream),
653
#[cfg(feature = "p3")]
654
TcpState::Error(err) => Err(err.into()),
655
_ => Err(ErrorCode::InvalidState),
656
}
657
}
658
659
pub(crate) fn p2_streaming_state(&self) -> Result<&P2TcpStreamingState, ErrorCode> {
660
match &self.tcp_state {
661
TcpState::P2Streaming(state) => Ok(state),
662
#[cfg(feature = "p3")]
663
TcpState::Error(err) => Err(err.into()),
664
_ => Err(ErrorCode::InvalidState),
665
}
666
}
667
668
pub(crate) fn set_p2_streaming_state(
669
&mut self,
670
state: P2TcpStreamingState,
671
) -> Result<(), ErrorCode> {
672
if !matches!(self.tcp_state, TcpState::Connected(_)) {
673
return Err(ErrorCode::InvalidState);
674
}
675
self.tcp_state = TcpState::P2Streaming(Box::new(state));
676
Ok(())
677
}
678
679
/// Used for `Pollable` in the WASIp2 implementation this awaits the socket
680
/// to be connected, if in the connecting state, or for a TCP accept to be
681
/// ready, if this is in the listening state.
682
///
683
/// For all other states this method immediately returns.
684
pub(crate) async fn ready(&mut self) {
685
match &mut self.tcp_state {
686
TcpState::Default(..)
687
| TcpState::BindStarted(..)
688
| TcpState::Bound(..)
689
| TcpState::ListenStarted(..)
690
| TcpState::ConnectReady(..)
691
| TcpState::Closed
692
| TcpState::Connected { .. }
693
| TcpState::Connecting(None)
694
| TcpState::Listening {
695
pending_accept: Some(_),
696
..
697
}
698
| TcpState::P2Streaming(_) => {}
699
700
#[cfg(feature = "p3")]
701
TcpState::Receiving(_) | TcpState::Error(_) => {}
702
703
TcpState::Connecting(Some(future)) => {
704
self.tcp_state = TcpState::ConnectReady(future.as_mut().await);
705
}
706
707
TcpState::Listening {
708
listener,
709
pending_accept: slot @ None,
710
} => {
711
let result = futures::future::poll_fn(|cx| {
712
listener.poll_accept(cx).map_ok(|(stream, _)| stream)
713
})
714
.await;
715
*slot = Some(result);
716
}
717
}
718
}
719
}
720
721
#[cfg(not(target_os = "macos"))]
722
pub use inherits_option::*;
723
#[cfg(not(target_os = "macos"))]
724
mod inherits_option {
725
use crate::sockets::SocketAddressFamily;
726
use tokio::net::TcpStream;
727
728
#[derive(Default, Clone)]
729
pub struct NonInheritedOptions;
730
731
impl NonInheritedOptions {
732
pub fn set_keep_alive_idle_time(&mut self, _value: u64) {}
733
734
pub fn set_hop_limit(&mut self, _value: u8) {}
735
736
pub fn set_receive_buffer_size(&mut self, _value: usize) {}
737
738
pub fn set_send_buffer_size(&mut self, _value: usize) {}
739
740
pub(crate) fn apply(&self, _family: SocketAddressFamily, _stream: &TcpStream) {}
741
}
742
}
743
744
#[cfg(target_os = "macos")]
pub use does_not_inherit_options::*;
#[cfg(target_os = "macos")]
mod does_not_inherit_options {
    use crate::sockets::SocketAddressFamily;
    use rustix::net::sockopt;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU8, AtomicU64, AtomicUsize, Ordering::Relaxed};
    use std::time::Duration;
    use tokio::net::TcpStream;

    /// Records option values explicitly set on a (listening) socket so they can
    /// be copied onto accepted sockets by hand. These options are not
    /// automatically inherited from the listener on all platforms.
    ///
    /// Cloning shares the same underlying storage (`Arc`).
    #[derive(Default, Clone)]
    pub struct NonInheritedOptions(Arc<Inner>);

    /// Recorded values. A stored `0` means "never explicitly set".
    #[derive(Default)]
    struct Inner {
        keep_alive_idle_time: AtomicU64, // nanoseconds
        hop_limit: AtomicU8,
        receive_buffer_size: AtomicUsize,
        send_buffer_size: AtomicUsize,
    }

    impl NonInheritedOptions {
        /// Record the keep-alive idle time, in nanoseconds.
        pub fn set_keep_alive_idle_time(&mut self, value: u64) {
            let inner = &self.0;
            inner.keep_alive_idle_time.store(value, Relaxed);
        }

        /// Record the unicast hop limit.
        pub fn set_hop_limit(&mut self, value: u8) {
            let inner = &self.0;
            inner.hop_limit.store(value, Relaxed);
        }

        /// Record the receive buffer size.
        pub fn set_receive_buffer_size(&mut self, value: usize) {
            let inner = &self.0;
            inner.receive_buffer_size.store(value, Relaxed);
        }

        /// Record the send buffer size.
        pub fn set_send_buffer_size(&mut self, value: usize) {
            let inner = &self.0;
            inner.send_buffer_size.store(value, Relaxed);
        }

        /// Copy every explicitly-recorded option onto a freshly accepted
        /// `stream`. This is best-effort: failures from the individual
        /// setsockopt calls are deliberately ignored.
        pub(crate) fn apply(&self, family: SocketAddressFamily, stream: &TcpStream) {
            let inner = &*self.0;

            let recv_size = inner.receive_buffer_size.load(Relaxed);
            if recv_size != 0 {
                let _ = sockopt::set_socket_recv_buffer_size(stream, recv_size);
            }

            let send_size = inner.send_buffer_size.load(Relaxed);
            if send_size != 0 {
                let _ = sockopt::set_socket_send_buffer_size(stream, send_size);
            }

            // For some reason, IP_TTL is inherited, but IPV6_UNICAST_HOPS isn't.
            let hops = inner.hop_limit.load(Relaxed);
            if family == SocketAddressFamily::Ipv6 && hops != 0 {
                let _ = sockopt::set_ipv6_unicast_hops(stream, Some(hops));
            }

            let idle_nanos = inner.keep_alive_idle_time.load(Relaxed);
            if idle_nanos != 0 {
                let _ = sockopt::set_tcp_keepidle(stream, Duration::from_nanos(idle_nanos));
            }
        }
    }
}
820
821