Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
google
GitHub Repository: google/crosvm
Path: blob/main/base/src/sys/unix/stream_channel.rs
5394 views
1
// Copyright 2022 The ChromiumOS Authors
2
// Use of this source code is governed by a BSD-style license that can be
3
// found in the LICENSE file.
4
5
use std::io;
6
use std::io::Read;
7
use std::os::unix::io::AsRawFd;
8
use std::os::unix::io::RawFd;
9
use std::os::unix::net::UnixStream;
10
use std::time::Duration;
11
12
use libc::c_void;
13
use serde::Deserialize;
14
use serde::Serialize;
15
16
use super::super::net::UnixSeqpacket;
17
use crate::descriptor::AsRawDescriptor;
18
use crate::IntoRawDescriptor;
19
use crate::RawDescriptor;
20
use crate::ReadNotifier;
21
use crate::Result;
22
23
/// How data sent over a `StreamChannel` is delimited for the reader.
///
/// `Debug`/`PartialEq`/`Eq` are derived (matching `BlockingMode`'s comparability) so the mode can
/// be compared and printed, e.g. in assertions and logs.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum FramingMode {
    /// Each write is delivered as one discrete message (a `UnixSeqpacket` on unix).
    Message,
    /// Data is an unframed byte stream (a `UnixStream` on unix).
    Byte,
}
28
29
/// Whether the sockets of a `StreamChannel` pair are created in blocking or non-blocking mode.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum BlockingMode {
    /// Reads and writes wait until they can make progress.
    Blocking,
    /// The sockets are put into non-blocking mode right after creation.
    Nonblocking,
}
34
35
impl io::Read for StreamChannel {
36
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
37
self.inner_read(buf)
38
}
39
}
40
41
impl io::Read for &StreamChannel {
42
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
43
self.inner_read(buf)
44
}
45
}
46
47
impl AsRawDescriptor for StreamChannel {
48
fn as_raw_descriptor(&self) -> RawDescriptor {
49
(&self).as_raw_descriptor()
50
}
51
}
52
53
#[derive(Debug, Deserialize, Serialize)]
54
enum SocketType {
55
Message(UnixSeqpacket),
56
#[serde(with = "crate::with_as_descriptor")]
57
Byte(UnixStream),
58
}
59
60
/// An abstraction over named pipes and unix socketpairs. This abstraction can be used in a blocking
61
/// and non blocking mode.
62
///
63
/// WARNING: partial reads of messages behave differently depending on the platform.
64
/// See sys::unix::StreamChannel::inner_read for details.
65
#[derive(Debug, Deserialize, Serialize)]
66
pub struct StreamChannel {
67
stream: SocketType,
68
}
69
70
impl StreamChannel {
71
pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
72
match &mut self.stream {
73
SocketType::Byte(sock) => sock.set_nonblocking(nonblocking),
74
SocketType::Message(sock) => sock.set_nonblocking(nonblocking),
75
}
76
}
77
78
pub fn get_framing_mode(&self) -> FramingMode {
79
match &self.stream {
80
SocketType::Message(_) => FramingMode::Message,
81
SocketType::Byte(_) => FramingMode::Byte,
82
}
83
}
84
85
pub(super) fn inner_read(&self, buf: &mut [u8]) -> io::Result<usize> {
86
match &self.stream {
87
SocketType::Byte(sock) => (&mut &*sock).read(buf),
88
89
// On Windows, reading from SOCK_SEQPACKET with a buffer that is too small is an error,
90
// and the extra data will be preserved inside the named pipe.
91
//
92
// Linux though, will silently truncate unless MSG_TRUNC is passed. So we pass it, but
93
// even in that case, Linux will still throw away the extra data. This means there is a
94
// slight behavior difference between platforms from the consumer's perspective.
95
// In practice on Linux, intentional partial reads of messages are usually accomplished
96
// by also passing MSG_PEEK. While we could do this, and hide this rough edge from
97
// consumers, it would add complexity & turn every read into two read syscalls.
98
//
99
// So the compromise is this:
100
// * On Linux: a partial read of a message is an Err and loses data.
101
// * On Windows: a partial read of a message is Ok and does not lose data.
102
SocketType::Message(sock) => {
103
// SAFETY:
104
// Safe because buf is valid, we pass buf's size to recv to bound the return
105
// length, and we check the return code.
106
let retval = unsafe {
107
// TODO(nkgold|b/152067913): Move this into the UnixSeqpacket struct as a
108
// recv_with_flags method once that struct's tests are working.
109
libc::recv(
110
sock.as_raw_descriptor(),
111
buf.as_mut_ptr() as *mut c_void,
112
buf.len(),
113
libc::MSG_TRUNC,
114
)
115
};
116
let receive_len = if retval < 0 {
117
Err(std::io::Error::last_os_error())
118
} else {
119
Ok(retval)
120
}? as usize;
121
122
if receive_len > buf.len() {
123
Err(std::io::Error::other(format!(
124
"packet size {:?} encountered, but buffer was only of size {:?}",
125
receive_len,
126
buf.len()
127
)))
128
} else {
129
Ok(receive_len)
130
}
131
}
132
}
133
}
134
135
/// Creates a cross platform stream pair.
136
pub fn pair(
137
blocking_mode: BlockingMode,
138
framing_mode: FramingMode,
139
) -> Result<(StreamChannel, StreamChannel)> {
140
let (pipe_a, pipe_b) = match framing_mode {
141
FramingMode::Byte => {
142
let (pipe_a, pipe_b) = UnixStream::pair()?;
143
(SocketType::Byte(pipe_a), SocketType::Byte(pipe_b))
144
}
145
FramingMode::Message => {
146
let (pipe_a, pipe_b) = UnixSeqpacket::pair()?;
147
(SocketType::Message(pipe_a), SocketType::Message(pipe_b))
148
}
149
};
150
let mut stream_a = StreamChannel { stream: pipe_a };
151
let mut stream_b = StreamChannel { stream: pipe_b };
152
let is_non_blocking = blocking_mode == BlockingMode::Nonblocking;
153
stream_a.set_nonblocking(is_non_blocking)?;
154
stream_b.set_nonblocking(is_non_blocking)?;
155
Ok((stream_a, stream_b))
156
}
157
158
pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
159
match &self.stream {
160
SocketType::Byte(sock) => sock.set_read_timeout(timeout),
161
SocketType::Message(sock) => sock.set_read_timeout(timeout),
162
}
163
}
164
165
pub fn set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
166
match &self.stream {
167
SocketType::Byte(sock) => sock.set_write_timeout(timeout),
168
SocketType::Message(sock) => sock.set_write_timeout(timeout),
169
}
170
}
171
172
// WARNING: Generally, multiple StreamChannel ends are not wanted. StreamChannel behavior with
173
// > 1 reader per end is not defined.
174
pub fn try_clone(&self) -> io::Result<Self> {
175
Ok(StreamChannel {
176
stream: match &self.stream {
177
SocketType::Byte(sock) => SocketType::Byte(sock.try_clone()?),
178
SocketType::Message(sock) => SocketType::Message(sock.try_clone()?),
179
},
180
})
181
}
182
}
183
184
impl io::Write for StreamChannel {
185
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
186
match &mut self.stream {
187
SocketType::Byte(sock) => sock.write(buf),
188
SocketType::Message(sock) => sock.send(buf),
189
}
190
}
191
fn flush(&mut self) -> io::Result<()> {
192
match &mut self.stream {
193
SocketType::Byte(sock) => sock.flush(),
194
SocketType::Message(_) => Ok(()),
195
}
196
}
197
}
198
199
impl io::Write for &StreamChannel {
200
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
201
match &self.stream {
202
SocketType::Byte(sock) => (&mut &*sock).write(buf),
203
SocketType::Message(sock) => sock.send(buf),
204
}
205
}
206
fn flush(&mut self) -> io::Result<()> {
207
match &self.stream {
208
SocketType::Byte(sock) => (&mut &*sock).flush(),
209
SocketType::Message(_) => Ok(()),
210
}
211
}
212
}
213
214
impl AsRawFd for StreamChannel {
215
fn as_raw_fd(&self) -> RawFd {
216
match &self.stream {
217
SocketType::Byte(sock) => sock.as_raw_descriptor(),
218
SocketType::Message(sock) => sock.as_raw_descriptor(),
219
}
220
}
221
}
222
223
impl AsRawFd for &StreamChannel {
224
fn as_raw_fd(&self) -> RawFd {
225
self.as_raw_descriptor()
226
}
227
}
228
229
impl AsRawDescriptor for &StreamChannel {
230
fn as_raw_descriptor(&self) -> RawDescriptor {
231
match &self.stream {
232
SocketType::Byte(sock) => sock.as_raw_descriptor(),
233
SocketType::Message(sock) => sock.as_raw_descriptor(),
234
}
235
}
236
}
237
238
impl IntoRawDescriptor for StreamChannel {
239
fn into_raw_descriptor(self) -> RawFd {
240
match self.stream {
241
SocketType::Byte(sock) => sock.into_raw_descriptor(),
242
SocketType::Message(sock) => sock.into_raw_descriptor(),
243
}
244
}
245
}
246
247
impl ReadNotifier for StreamChannel {
248
/// Returns a RawDescriptor that can be polled for reads using PollContext.
249
fn get_read_notifier(&self) -> &dyn AsRawDescriptor {
250
self
251
}
252
}
253
254
#[cfg(test)]
mod test {
    use std::io::Read;
    use std::io::Write;

    use super::*;
    use crate::EventContext;
    use crate::EventToken;
    use crate::ReadNotifier;

    #[derive(EventToken, Debug, Eq, PartialEq, Copy, Clone)]
    enum Token {
        ReceivedData,
    }

    #[test]
    fn test_non_blocking_pair_byte() {
        let (mut sender, mut receiver) =
            StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Byte).unwrap();

        sender.write_all(&[75, 77, 54, 82, 76, 65]).unwrap();

        // Wait for the data to arrive.
        let event_ctx: EventContext<Token> =
            EventContext::build_with(&[(receiver.get_read_notifier(), Token::ReceivedData)])
                .unwrap();
        let events = event_ctx.wait().unwrap();
        let tokens: Vec<Token> = events
            .iter()
            .filter(|e| e.is_readable)
            .map(|e| e.token)
            .collect();
        assert_eq!(tokens, vec![Token::ReceivedData]);

        // Smaller than what we sent so we get multiple chunks
        let mut recv_buffer: [u8; 4] = [0; 4];

        let mut size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 4);
        assert_eq!(recv_buffer, [75, 77, 54, 82]);

        size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 2);
        assert_eq!(recv_buffer[0..2], [76, 65]);

        // Now that we've polled for & received all data, polling again should show no events.
        assert_eq!(
            event_ctx
                .wait_timeout(std::time::Duration::new(0, 0))
                .unwrap()
                .len(),
            0
        );
    }

    #[test]
    fn test_non_blocking_pair_message() {
        let (mut sender, mut receiver) =
            StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Message).unwrap();

        sender.write_all(&[75, 77, 54, 82, 76, 65]).unwrap();

        // Wait for the data to arrive.
        let event_ctx: EventContext<Token> =
            EventContext::build_with(&[(receiver.get_read_notifier(), Token::ReceivedData)])
                .unwrap();
        let events = event_ctx.wait().unwrap();
        let tokens: Vec<Token> = events
            .iter()
            .filter(|e| e.is_readable)
            .map(|e| e.token)
            .collect();
        assert_eq!(tokens, vec![Token::ReceivedData]);

        // Unlike Byte format, Message mode returns an error (and discards the packet) if the
        // buffer is smaller than the packet size; make the buffer the right size.
        let mut recv_buffer: [u8; 6] = [0; 6];

        let size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 6);
        assert_eq!(recv_buffer, [75, 77, 54, 82, 76, 65]);

        // Now that we've polled for & received all data, polling again should show no events.
        assert_eq!(
            event_ctx
                .wait_timeout(std::time::Duration::new(0, 0))
                .unwrap()
                .len(),
            0
        );
    }

    #[test]
    fn test_message_partial_read_is_error_and_loses_data() {
        let (mut sender, mut receiver) =
            StreamChannel::pair(BlockingMode::Blocking, FramingMode::Message).unwrap();

        sender.write_all(&[75, 77, 54, 82, 76, 65]).unwrap();

        // The packet is larger than the buffer, so the read fails instead of truncating.
        let mut recv_buffer: [u8; 4] = [0; 4];
        assert!(receiver.read(&mut recv_buffer).is_err());

        // The rest of the packet was dropped by the kernel: nothing is left to read.
        receiver
            .set_nonblocking(true)
            .expect("Failed to set receiver to nonblocking mode.");
        let err = receiver.read(&mut recv_buffer).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::WouldBlock);
    }

    #[test]
    fn test_non_blocking_pair_error_no_data() {
        let (mut sender, mut receiver) =
            StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Byte).unwrap();
        receiver
            .set_nonblocking(true)
            .expect("Failed to set receiver to nonblocking mode.");

        sender.write_all(&[75, 77]).unwrap();

        // Wait for the data to arrive.
        let event_ctx: EventContext<Token> =
            EventContext::build_with(&[(receiver.get_read_notifier(), Token::ReceivedData)])
                .unwrap();
        let events = event_ctx.wait().unwrap();
        let tokens: Vec<Token> = events
            .iter()
            .filter(|e| e.is_readable)
            .map(|e| e.token)
            .collect();
        assert_eq!(tokens, vec![Token::ReceivedData]);

        // We only read 2 bytes, even though we requested 4 bytes.
        let mut recv_buffer: [u8; 4] = [0; 4];
        let size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 2);
        assert_eq!(recv_buffer, [75, 77, 0, 0]);

        // Further reads should encounter an error since there is no available data and this is a
        // non blocking pipe.
        assert!(receiver.read(&mut recv_buffer).is_err());
    }
}
379
380