// Source: GitHub repository google/crosvm
// Path: blob/main/base/src/sys/linux/vsock.rs
1
// Copyright 2021 The ChromiumOS Authors
2
// Use of this source code is governed by a BSD-style license that can be
3
// found in the LICENSE file.
4
5
//! Support for virtual sockets.
6
use std::fmt;
7
use std::io;
8
use std::mem;
9
use std::mem::size_of;
10
use std::num::ParseIntError;
11
use std::os::raw::c_uchar;
12
use std::os::raw::c_uint;
13
use std::os::raw::c_ushort;
14
use std::os::unix::io::AsRawFd;
15
use std::os::unix::io::IntoRawFd;
16
use std::os::unix::io::RawFd;
17
use std::result;
18
use std::str::FromStr;
19
20
use libc::c_void;
21
use libc::sa_family_t;
22
use libc::size_t;
23
use libc::sockaddr;
24
use libc::socklen_t;
25
use libc::F_GETFL;
26
use libc::F_SETFL;
27
use libc::O_NONBLOCK;
28
use libc::VMADDR_CID_ANY;
29
use libc::VMADDR_CID_HOST;
30
use libc::VMADDR_CID_HYPERVISOR;
31
use thiserror::Error;
32
33
// The domain for vsock sockets.
34
const AF_VSOCK: sa_family_t = 40;
35
36
// Vsock loopback address.
37
const VMADDR_CID_LOCAL: c_uint = 1;
38
39
/// Vsock equivalent of binding on port 0. Binds to a random port.
40
pub const VMADDR_PORT_ANY: c_uint = c_uint::MAX;
41
42
// The number of bytes of padding to be added to the sockaddr_vm struct. Taken directly
43
// from linux/vm_sockets.h.
44
const PADDING: usize = size_of::<sockaddr>()
45
- size_of::<sa_family_t>()
46
- size_of::<c_ushort>()
47
- (2 * size_of::<c_uint>());
48
49
#[repr(C)]
50
#[derive(Default)]
51
struct sockaddr_vm {
52
svm_family: sa_family_t,
53
svm_reserved1: c_ushort,
54
svm_port: c_uint,
55
svm_cid: c_uint,
56
svm_zero: [c_uchar; PADDING],
57
}
58
59
#[derive(Error, Debug)]
60
#[error("failed to parse vsock address")]
61
pub struct AddrParseError;
62
63
/// A vsock context id: the vsock counterpart of an IP address.
///
/// The well-known context ids get their own variants; any other value is carried in
/// [`VsockCid::Cid`].
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum VsockCid {
    /// Vsock equivalent of INADDR_ANY. Indicates the context id of the current endpoint.
    Any,
    /// The bare-metal machine that serves as the hypervisor.
    Hypervisor,
    /// The loopback address.
    Local,
    /// The parent machine. For nested VMs this may not be the hypervisor.
    Host,
    /// An explicitly assigned CID that serves as the address for VSOCK.
    Cid(c_uint),
}
77
78
impl fmt::Display for VsockCid {
79
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
80
match &self {
81
VsockCid::Any => write!(fmt, "Any"),
82
VsockCid::Hypervisor => write!(fmt, "Hypervisor"),
83
VsockCid::Local => write!(fmt, "Local"),
84
VsockCid::Host => write!(fmt, "Host"),
85
VsockCid::Cid(c) => write!(fmt, "'{c}'"),
86
}
87
}
88
}
89
90
impl From<c_uint> for VsockCid {
91
fn from(c: c_uint) -> Self {
92
match c {
93
VMADDR_CID_ANY => VsockCid::Any,
94
VMADDR_CID_HYPERVISOR => VsockCid::Hypervisor,
95
VMADDR_CID_LOCAL => VsockCid::Local,
96
VMADDR_CID_HOST => VsockCid::Host,
97
_ => VsockCid::Cid(c),
98
}
99
}
100
}
101
102
impl FromStr for VsockCid {
103
type Err = ParseIntError;
104
105
fn from_str(s: &str) -> Result<Self, Self::Err> {
106
let c: c_uint = s.parse()?;
107
Ok(c.into())
108
}
109
}
110
111
impl From<VsockCid> for c_uint {
112
fn from(cid: VsockCid) -> c_uint {
113
match cid {
114
VsockCid::Any => VMADDR_CID_ANY,
115
VsockCid::Hypervisor => VMADDR_CID_HYPERVISOR,
116
VsockCid::Local => VMADDR_CID_LOCAL,
117
VsockCid::Host => VMADDR_CID_HOST,
118
VsockCid::Cid(c) => c,
119
}
120
}
121
}
122
123
/// An address associated with a virtual socket.
124
#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
125
pub struct SocketAddr {
126
pub cid: VsockCid,
127
pub port: c_uint,
128
}
129
130
pub trait ToSocketAddr {
131
fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>;
132
}
133
134
impl ToSocketAddr for SocketAddr {
135
fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
136
Ok(*self)
137
}
138
}
139
140
impl ToSocketAddr for str {
141
fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
142
self.parse()
143
}
144
}
145
146
impl ToSocketAddr for (VsockCid, c_uint) {
147
fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
148
let (cid, port) = *self;
149
Ok(SocketAddr { cid, port })
150
}
151
}
152
153
impl<T: ToSocketAddr + ?Sized> ToSocketAddr for &T {
154
fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
155
(**self).to_socket_addr()
156
}
157
}
158
159
impl FromStr for SocketAddr {
160
type Err = AddrParseError;
161
162
/// Parse a vsock SocketAddr from a string. vsock socket addresses are of the form
163
/// "vsock:cid:port".
164
fn from_str(s: &str) -> Result<SocketAddr, AddrParseError> {
165
let components: Vec<&str> = s.split(':').collect();
166
if components.len() != 3 || components[0] != "vsock" {
167
return Err(AddrParseError);
168
}
169
170
Ok(SocketAddr {
171
cid: components[1].parse().map_err(|_| AddrParseError)?,
172
port: components[2].parse().map_err(|_| AddrParseError)?,
173
})
174
}
175
}
176
177
impl fmt::Display for SocketAddr {
178
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
179
write!(fmt, "{}:{}", self.cid, self.port)
180
}
181
}
182
183
/// Sets `fd` to be blocking or nonblocking. `fd` must be a valid fd of a type that accepts the
184
/// `O_NONBLOCK` flag. This includes regular files, pipes, and sockets.
185
unsafe fn set_nonblocking(fd: RawFd, nonblocking: bool) -> io::Result<()> {
186
let flags = libc::fcntl(fd, F_GETFL, 0);
187
if flags < 0 {
188
return Err(io::Error::last_os_error());
189
}
190
191
let flags = if nonblocking {
192
flags | O_NONBLOCK
193
} else {
194
flags & !O_NONBLOCK
195
};
196
197
let ret = libc::fcntl(fd, F_SETFL, flags);
198
if ret < 0 {
199
return Err(io::Error::last_os_error());
200
}
201
202
Ok(())
203
}
204
205
/// A virtual socket that is neither listening nor connected yet.
///
/// Use this type only when socket options have to be changed or the socket state queried
/// before calling listen or connect. Otherwise prefer `VsockStream` or `VsockListener`.
#[derive(Debug)]
pub struct VsockSocket {
    /// Raw file descriptor of the socket.
    fd: RawFd,
}
214
215
impl VsockSocket {
216
pub fn new() -> io::Result<Self> {
217
// SAFETY: trivially safe
218
let fd = unsafe { libc::socket(libc::AF_VSOCK, libc::SOCK_STREAM | libc::SOCK_CLOEXEC, 0) };
219
if fd < 0 {
220
Err(io::Error::last_os_error())
221
} else {
222
Ok(VsockSocket { fd })
223
}
224
}
225
226
pub fn bind<A: ToSocketAddr>(&mut self, addr: A) -> io::Result<()> {
227
let sockaddr = addr
228
.to_socket_addr()
229
.map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
230
231
// The compiler should optimize this out since these are both compile-time constants.
232
assert_eq!(size_of::<sockaddr_vm>(), size_of::<sockaddr>());
233
234
let svm = sockaddr_vm {
235
svm_family: AF_VSOCK,
236
svm_cid: sockaddr.cid.into(),
237
svm_port: sockaddr.port,
238
..Default::default()
239
};
240
241
// SAFETY:
242
// Safe because this doesn't modify any memory and we check the return value.
243
let ret = unsafe {
244
libc::bind(
245
self.fd,
246
&svm as *const sockaddr_vm as *const sockaddr,
247
size_of::<sockaddr_vm>() as socklen_t,
248
)
249
};
250
if ret < 0 {
251
let bind_err = io::Error::last_os_error();
252
Err(bind_err)
253
} else {
254
Ok(())
255
}
256
}
257
258
pub fn connect<A: ToSocketAddr>(self, addr: A) -> io::Result<VsockStream> {
259
let sockaddr = addr
260
.to_socket_addr()
261
.map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
262
263
let svm = sockaddr_vm {
264
svm_family: AF_VSOCK,
265
svm_cid: sockaddr.cid.into(),
266
svm_port: sockaddr.port,
267
..Default::default()
268
};
269
270
// SAFETY:
271
// Safe because this just connects a vsock socket, and the return value is checked.
272
let ret = unsafe {
273
libc::connect(
274
self.fd,
275
&svm as *const sockaddr_vm as *const sockaddr,
276
size_of::<sockaddr_vm>() as socklen_t,
277
)
278
};
279
if ret < 0 {
280
let connect_err = io::Error::last_os_error();
281
Err(connect_err)
282
} else {
283
Ok(VsockStream { sock: self })
284
}
285
}
286
287
pub fn listen(self) -> io::Result<VsockListener> {
288
// SAFETY:
289
// Safe because this doesn't modify any memory and we check the return value.
290
let ret = unsafe { libc::listen(self.fd, 1) };
291
if ret < 0 {
292
let listen_err = io::Error::last_os_error();
293
return Err(listen_err);
294
}
295
Ok(VsockListener { sock: self })
296
}
297
298
/// Returns the port that this socket is bound to. This can only succeed after bind is called.
299
pub fn local_port(&self) -> io::Result<u32> {
300
let mut svm: sockaddr_vm = Default::default();
301
302
let mut addrlen = size_of::<sockaddr_vm>() as socklen_t;
303
// SAFETY:
304
// Safe because we give a valid pointer for addrlen and check the length.
305
let ret = unsafe {
306
// Get the socket address that was actually bound.
307
libc::getsockname(
308
self.fd,
309
&mut svm as *mut sockaddr_vm as *mut sockaddr,
310
&mut addrlen as *mut socklen_t,
311
)
312
};
313
if ret < 0 {
314
let getsockname_err = io::Error::last_os_error();
315
Err(getsockname_err)
316
} else {
317
// If this doesn't match, it's not safe to get the port out of the sockaddr.
318
assert_eq!(addrlen as usize, size_of::<sockaddr_vm>());
319
320
Ok(svm.svm_port)
321
}
322
}
323
324
pub fn try_clone(&self) -> io::Result<Self> {
325
// SAFETY:
326
// Safe because this doesn't modify any memory and we check the return value.
327
let dup_fd = unsafe { libc::fcntl(self.fd, libc::F_DUPFD_CLOEXEC, 0) };
328
if dup_fd < 0 {
329
Err(io::Error::last_os_error())
330
} else {
331
Ok(Self { fd: dup_fd })
332
}
333
}
334
335
pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
336
// SAFETY:
337
// Safe because the fd is valid and owned by this stream.
338
unsafe { set_nonblocking(self.fd, nonblocking) }
339
}
340
}
341
342
impl IntoRawFd for VsockSocket {
343
fn into_raw_fd(self) -> RawFd {
344
let fd = self.fd;
345
mem::forget(self);
346
fd
347
}
348
}
349
350
impl AsRawFd for VsockSocket {
351
fn as_raw_fd(&self) -> RawFd {
352
self.fd
353
}
354
}
355
356
impl Drop for VsockSocket {
357
fn drop(&mut self) {
358
// SAFETY:
359
// Safe because this doesn't modify any memory and we are the only
360
// owner of the file descriptor.
361
unsafe { libc::close(self.fd) };
362
}
363
}
364
365
/// A virtual stream socket.
366
#[derive(Debug)]
367
pub struct VsockStream {
368
sock: VsockSocket,
369
}
370
371
impl VsockStream {
372
pub fn connect<A: ToSocketAddr>(addr: A) -> io::Result<VsockStream> {
373
let sock = VsockSocket::new()?;
374
sock.connect(addr)
375
}
376
377
/// Returns the port that this stream is bound to.
378
pub fn local_port(&self) -> io::Result<u32> {
379
self.sock.local_port()
380
}
381
382
pub fn try_clone(&self) -> io::Result<VsockStream> {
383
self.sock.try_clone().map(|f| VsockStream { sock: f })
384
}
385
386
pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
387
self.sock.set_nonblocking(nonblocking)
388
}
389
}
390
391
impl io::Read for VsockStream {
392
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
393
// SAFETY:
394
// Safe because this will only modify the contents of |buf| and we check the return value.
395
let ret = unsafe {
396
libc::read(
397
self.sock.as_raw_fd(),
398
buf as *mut [u8] as *mut c_void,
399
buf.len() as size_t,
400
)
401
};
402
if ret < 0 {
403
return Err(io::Error::last_os_error());
404
}
405
406
Ok(ret as usize)
407
}
408
}
409
410
impl io::Write for VsockStream {
411
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
412
// SAFETY:
413
// Safe because this doesn't modify any memory and we check the return value.
414
let ret = unsafe {
415
libc::write(
416
self.sock.as_raw_fd(),
417
buf as *const [u8] as *const c_void,
418
buf.len() as size_t,
419
)
420
};
421
if ret < 0 {
422
return Err(io::Error::last_os_error());
423
}
424
425
Ok(ret as usize)
426
}
427
428
fn flush(&mut self) -> io::Result<()> {
429
// No buffered data so nothing to do.
430
Ok(())
431
}
432
}
433
434
impl AsRawFd for VsockStream {
435
fn as_raw_fd(&self) -> RawFd {
436
self.sock.as_raw_fd()
437
}
438
}
439
440
impl IntoRawFd for VsockStream {
441
fn into_raw_fd(self) -> RawFd {
442
self.sock.into_raw_fd()
443
}
444
}
445
446
/// Represents a virtual socket server.
447
#[derive(Debug)]
448
pub struct VsockListener {
449
sock: VsockSocket,
450
}
451
452
impl VsockListener {
453
/// Creates a new `VsockListener` bound to the specified port on the current virtual socket
454
/// endpoint.
455
pub fn bind<A: ToSocketAddr>(addr: A) -> io::Result<VsockListener> {
456
let mut sock = VsockSocket::new()?;
457
sock.bind(addr)?;
458
sock.listen()
459
}
460
461
/// Returns the port that this listener is bound to.
462
pub fn local_port(&self) -> io::Result<u32> {
463
self.sock.local_port()
464
}
465
466
/// Accepts a new incoming connection on this listener. Blocks the calling thread until a
467
/// new connection is established. When established, returns the corresponding `VsockStream`
468
/// and the remote peer's address.
469
pub fn accept(&self) -> io::Result<(VsockStream, SocketAddr)> {
470
let mut svm: sockaddr_vm = Default::default();
471
472
let mut socklen: socklen_t = size_of::<sockaddr_vm>() as socklen_t;
473
// SAFETY:
474
// Safe because this will only modify |svm| and we check the return value.
475
let fd = unsafe {
476
libc::accept4(
477
self.sock.as_raw_fd(),
478
&mut svm as *mut sockaddr_vm as *mut sockaddr,
479
&mut socklen as *mut socklen_t,
480
libc::SOCK_CLOEXEC,
481
)
482
};
483
if fd < 0 {
484
return Err(io::Error::last_os_error());
485
}
486
487
if svm.svm_family != AF_VSOCK {
488
return Err(io::Error::new(
489
io::ErrorKind::InvalidData,
490
format!("unexpected address family: {}", svm.svm_family),
491
));
492
}
493
494
Ok((
495
VsockStream {
496
sock: VsockSocket { fd },
497
},
498
SocketAddr {
499
cid: svm.svm_cid.into(),
500
port: svm.svm_port,
501
},
502
))
503
}
504
505
pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
506
self.sock.set_nonblocking(nonblocking)
507
}
508
}
509
510
impl AsRawFd for VsockListener {
511
fn as_raw_fd(&self) -> RawFd {
512
self.sock.as_raw_fd()
513
}
514
}
515
516