Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
bytecodealliance
GitHub Repository: bytecodealliance/wasmtime
Path: blob/main/crates/wasi/src/p2/tcp.rs
1692 views
1
use crate::p2::{
2
DynInputStream, DynOutputStream, InputStream, OutputStream, Pollable, SocketError,
3
SocketResult, StreamError,
4
};
5
use crate::runtime::AbortOnDropJoinHandle;
6
use crate::sockets::TcpSocket;
7
use anyhow::Result;
8
use io_lifetimes::AsSocketlike;
9
use rustix::io::Errno;
10
use std::io;
11
use std::mem;
12
use std::net::Shutdown;
13
use std::sync::Arc;
14
use tokio::sync::Mutex;
15
16
impl TcpSocket {
17
pub(crate) fn p2_streams(&mut self) -> SocketResult<(DynInputStream, DynOutputStream)> {
18
let client = self.tcp_stream_arc()?;
19
let reader = Arc::new(Mutex::new(TcpReader::new(client.clone())));
20
let writer = Arc::new(Mutex::new(TcpWriter::new(client.clone())));
21
self.set_p2_streaming_state(P2TcpStreamingState {
22
stream: client.clone(),
23
reader: reader.clone(),
24
writer: writer.clone(),
25
})?;
26
let input: DynInputStream = Box::new(TcpReadStream(reader));
27
let output: DynOutputStream = Box::new(TcpWriteStream(writer));
28
Ok((input, output))
29
}
30
}
31
32
pub(crate) struct P2TcpStreamingState {
33
pub(crate) stream: Arc<tokio::net::TcpStream>,
34
reader: Arc<Mutex<TcpReader>>,
35
writer: Arc<Mutex<TcpWriter>>,
36
}
37
38
impl P2TcpStreamingState {
39
pub(crate) fn shutdown(&self, how: Shutdown) -> SocketResult<()> {
40
if let Shutdown::Both | Shutdown::Read = how {
41
try_lock_for_socket(&self.reader)?.shutdown();
42
}
43
44
if let Shutdown::Both | Shutdown::Write = how {
45
try_lock_for_socket(&self.writer)?.shutdown();
46
}
47
48
Ok(())
49
}
50
}
51
52
struct TcpReader {
53
stream: Arc<tokio::net::TcpStream>,
54
closed: bool,
55
}
56
57
impl TcpReader {
58
fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
59
Self {
60
stream,
61
closed: false,
62
}
63
}
64
fn read(&mut self, size: usize) -> Result<bytes::Bytes, StreamError> {
65
if self.closed {
66
return Err(StreamError::Closed);
67
}
68
if size == 0 {
69
return Ok(bytes::Bytes::new());
70
}
71
72
let mut buf = bytes::BytesMut::with_capacity(size);
73
let n = match self.stream.try_read_buf(&mut buf) {
74
// A 0-byte read indicates that the stream has closed.
75
Ok(0) => {
76
self.closed = true;
77
return Err(StreamError::Closed);
78
}
79
Ok(n) => n,
80
81
// Failing with `EWOULDBLOCK` is how we differentiate between a closed channel and no
82
// data to read right now.
83
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => 0,
84
85
Err(e) => {
86
self.closed = true;
87
return Err(StreamError::LastOperationFailed(e.into()));
88
}
89
};
90
91
buf.truncate(n);
92
Ok(buf.freeze())
93
}
94
95
fn shutdown(&mut self) {
96
native_shutdown(&self.stream, Shutdown::Read);
97
self.closed = true;
98
}
99
100
async fn ready(&mut self) {
101
if self.closed {
102
return;
103
}
104
105
self.stream.readable().await.unwrap();
106
}
107
}
108
109
/// Input-stream handle over a shared, mutex-guarded `TcpReader`.
struct TcpReadStream(Arc<Mutex<TcpReader>>);
110
111
#[async_trait::async_trait]
112
impl InputStream for TcpReadStream {
113
fn read(&mut self, size: usize) -> Result<bytes::Bytes, StreamError> {
114
try_lock_for_stream(&self.0)?.read(size)
115
}
116
}
117
118
#[async_trait::async_trait]
119
impl Pollable for TcpReadStream {
120
async fn ready(&mut self) {
121
self.0.lock().await.ready().await
122
}
123
}
124
125
/// Byte count (1 GiB) reported by `TcpWriter::check_write` when the socket is
/// writable and no background write is pending.
const SOCKET_READY_SIZE: usize = 1 << 30;
126
127
struct TcpWriter {
128
stream: Arc<tokio::net::TcpStream>,
129
state: WriteState,
130
}
131
132
enum WriteState {
133
Ready,
134
Writing(AbortOnDropJoinHandle<io::Result<()>>),
135
Closing(AbortOnDropJoinHandle<io::Result<()>>),
136
Closed,
137
Error(io::Error),
138
}
139
140
impl TcpWriter {
141
fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
142
Self {
143
stream,
144
state: WriteState::Ready,
145
}
146
}
147
148
fn try_write_portable(stream: &tokio::net::TcpStream, buf: &[u8]) -> io::Result<usize> {
149
stream.try_write(buf).map_err(|error| {
150
match Errno::from_io_error(&error) {
151
// Windows returns `WSAESHUTDOWN` when writing to a shut down socket.
152
// We normalize this to EPIPE, because that is what the other platforms return.
153
// See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-send#:~:text=WSAESHUTDOWN
154
#[cfg(windows)]
155
Some(Errno::SHUTDOWN) => io::Error::new(io::ErrorKind::BrokenPipe, error),
156
157
_ => error,
158
}
159
})
160
}
161
162
/// Write `bytes` in a background task, remembering the task handle for use in a future call to
163
/// `write_ready`
164
fn background_write(&mut self, mut bytes: bytes::Bytes) {
165
assert!(matches!(self.state, WriteState::Ready));
166
167
let stream = self.stream.clone();
168
self.state = WriteState::Writing(crate::runtime::spawn(async move {
169
// Note: we are not using the AsyncWrite impl here, and instead using the TcpStream
170
// primitive try_write, which goes directly to attempt a write with mio. This has
171
// two advantages: 1. this operation takes a &TcpStream instead of a &mut TcpStream
172
// required to AsyncWrite, and 2. it eliminates any buffering in tokio we may need
173
// to flush.
174
while !bytes.is_empty() {
175
stream.writable().await?;
176
match Self::try_write_portable(&stream, &bytes) {
177
Ok(n) => {
178
let _ = bytes.split_to(n);
179
}
180
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => continue,
181
Err(e) => return Err(e),
182
}
183
}
184
185
Ok(())
186
}));
187
}
188
189
fn write(&mut self, mut bytes: bytes::Bytes) -> Result<(), StreamError> {
190
match self.state {
191
WriteState::Ready => {}
192
WriteState::Closed => return Err(StreamError::Closed),
193
WriteState::Writing(_) | WriteState::Closing(_) | WriteState::Error(_) => {
194
return Err(StreamError::Trap(anyhow::anyhow!(
195
"unpermitted: must call check_write first"
196
)));
197
}
198
}
199
while !bytes.is_empty() {
200
match Self::try_write_portable(&self.stream, &bytes) {
201
Ok(n) => {
202
let _ = bytes.split_to(n);
203
}
204
205
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
206
// As `try_write` indicated that it would have blocked, we'll perform the write
207
// in the background to allow us to return immediately.
208
self.background_write(bytes);
209
210
return Ok(());
211
}
212
213
Err(e) if e.kind() == std::io::ErrorKind::BrokenPipe => {
214
self.state = WriteState::Closed;
215
return Err(StreamError::Closed);
216
}
217
218
Err(e) => return Err(StreamError::LastOperationFailed(e.into())),
219
}
220
}
221
222
Ok(())
223
}
224
225
fn flush(&mut self) -> Result<(), StreamError> {
226
// `flush` is a no-op here, as we're not managing any internal buffer. Additionally,
227
// `write_ready` will join the background write task if it's active, so following `flush`
228
// with `write_ready` will have the desired effect.
229
match self.state {
230
WriteState::Ready
231
| WriteState::Writing(_)
232
| WriteState::Closing(_)
233
| WriteState::Error(_) => Ok(()),
234
WriteState::Closed => Err(StreamError::Closed),
235
}
236
}
237
238
fn check_write(&mut self) -> Result<usize, StreamError> {
239
match mem::replace(&mut self.state, WriteState::Closed) {
240
WriteState::Writing(task) => {
241
self.state = WriteState::Writing(task);
242
return Ok(0);
243
}
244
WriteState::Closing(task) => {
245
self.state = WriteState::Closing(task);
246
return Ok(0);
247
}
248
WriteState::Ready => {
249
self.state = WriteState::Ready;
250
}
251
WriteState::Closed => return Err(StreamError::Closed),
252
WriteState::Error(e) => return Err(StreamError::LastOperationFailed(e.into())),
253
}
254
255
let writable = self.stream.writable();
256
futures::pin_mut!(writable);
257
if crate::runtime::poll_noop(writable).is_none() {
258
return Ok(0);
259
}
260
Ok(SOCKET_READY_SIZE)
261
}
262
263
fn shutdown(&mut self) {
264
self.state = match mem::replace(&mut self.state, WriteState::Closed) {
265
// No write in progress, immediately shut down:
266
WriteState::Ready => {
267
native_shutdown(&self.stream, Shutdown::Write);
268
WriteState::Closed
269
}
270
271
// Schedule the shutdown after the current write has finished:
272
WriteState::Writing(write) => {
273
let stream = self.stream.clone();
274
WriteState::Closing(crate::runtime::spawn(async move {
275
let result = write.await;
276
native_shutdown(&stream, Shutdown::Write);
277
result
278
}))
279
}
280
281
s => s,
282
};
283
}
284
285
async fn cancel(&mut self) {
286
match mem::replace(&mut self.state, WriteState::Closed) {
287
WriteState::Writing(task) | WriteState::Closing(task) => _ = task.cancel().await,
288
_ => {}
289
}
290
}
291
292
async fn ready(&mut self) {
293
match &mut self.state {
294
WriteState::Writing(task) => {
295
self.state = match task.await {
296
Ok(()) => WriteState::Ready,
297
Err(e) => WriteState::Error(e),
298
}
299
}
300
WriteState::Closing(task) => {
301
self.state = match task.await {
302
Ok(()) => WriteState::Closed,
303
Err(e) => WriteState::Error(e),
304
}
305
}
306
_ => {}
307
}
308
309
if let WriteState::Ready = self.state {
310
self.stream.writable().await.unwrap();
311
}
312
}
313
}
314
315
/// Output-stream handle over a shared, mutex-guarded `TcpWriter`.
struct TcpWriteStream(Arc<Mutex<TcpWriter>>);
316
317
#[async_trait::async_trait]
318
impl OutputStream for TcpWriteStream {
319
fn write(&mut self, bytes: bytes::Bytes) -> Result<(), StreamError> {
320
try_lock_for_stream(&self.0)?.write(bytes)
321
}
322
323
fn flush(&mut self) -> Result<(), StreamError> {
324
try_lock_for_stream(&self.0)?.flush()
325
}
326
327
fn check_write(&mut self) -> Result<usize, StreamError> {
328
try_lock_for_stream(&self.0)?.check_write()
329
}
330
331
async fn cancel(&mut self) {
332
self.0.lock().await.cancel().await
333
}
334
}
335
336
#[async_trait::async_trait]
337
impl Pollable for TcpWriteStream {
338
async fn ready(&mut self) {
339
self.0.lock().await.ready().await
340
}
341
}
342
343
fn native_shutdown(stream: &tokio::net::TcpStream, how: Shutdown) {
344
_ = stream
345
.as_socketlike_view::<std::net::TcpStream>()
346
.shutdown(how);
347
}
348
349
fn try_lock_for_stream<T>(mutex: &Mutex<T>) -> Result<tokio::sync::MutexGuard<'_, T>, StreamError> {
350
mutex
351
.try_lock()
352
.map_err(|_| StreamError::trap("concurrent access to resource not supported"))
353
}
354
355
fn try_lock_for_socket<T>(mutex: &Mutex<T>) -> SocketResult<tokio::sync::MutexGuard<'_, T>> {
356
mutex.try_lock().map_err(|_| {
357
SocketError::trap(anyhow::anyhow!(
358
"concurrent access to resource not supported"
359
))
360
})
361
}
362
363