Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
google
GitHub Repository: google/crosvm
Path: blob/main/devices/src/virtio/vhost_user_backend/handler.rs
5394 views
1
// Copyright 2021 The ChromiumOS Authors
2
// Use of this source code is governed by a BSD-style license that can be
3
// found in the LICENSE file.
4
5
//! Library for implementing vhost-user device executables.
6
//!
7
//! This module provides
8
//! * `VhostUserDevice` trait, which is a collection of methods to handle vhost-user requests, and
9
//! * `DeviceRequestHandler` struct, which makes a connection to a VMM and starts an event loop.
10
//!
11
//! They are expected to be used as follows:
12
//!
13
//! 1. Define a struct and implement `VhostUserDevice` for it.
14
//! 2. Create a `DeviceRequestHandler` with the backend struct.
15
//! 3. Drive the `DeviceRequestHandler::run` async fn with an executor.
16
//!
17
//! ```ignore
18
//! struct MyBackend {
19
//! /* fields */
20
//! }
21
//!
22
//! impl VhostUserDevice for MyBackend {
23
//! /* implement methods */
24
//! }
25
//!
26
//! fn main() -> Result<(), Box<dyn Error>> {
27
//! let backend = MyBackend { /* initialize fields */ };
28
//! let handler = DeviceRequestHandler::new(backend);
29
//! let socket = std::path::Path::new("/path/to/socket");
30
//! let ex = cros_async::Executor::new()?;
31
//!
32
//! if let Err(e) = ex.run_until(handler.run(socket, &ex)) {
33
//! eprintln!("error happened: {}", e);
34
//! }
35
//! Ok(())
36
//! }
37
//! ```
38
// Implementation note:
39
// This code lets us take advantage of the vmm_vhost low level implementation of the vhost user
40
// protocol. DeviceRequestHandler implements the Backend trait from vmm_vhost, and includes some
41
// common code for setting up guest memory and managing partially configured vrings.
42
// DeviceRequestHandler::run watches the vhost-user socket and then calls handle_request() when it
43
// becomes readable. handle_request() reads and parses the message and then calls one of the
44
// Backend trait methods. These dispatch back to the supplied VhostUserDevice implementation (this
45
// is what our devices implement).
46
47
pub(super) mod sys;
48
49
use std::collections::BTreeMap;
50
use std::convert::From;
51
use std::fs::File;
52
use std::io::BufReader;
53
use std::io::Write;
54
use std::num::Wrapping;
55
#[cfg(any(target_os = "android", target_os = "linux"))]
56
use std::os::unix::io::AsRawFd;
57
use std::sync::Arc;
58
59
use anyhow::bail;
60
use anyhow::Context;
61
#[cfg(any(target_os = "android", target_os = "linux"))]
62
use base::clear_fd_flags;
63
use base::error;
64
use base::trace;
65
use base::warn;
66
use base::Event;
67
use base::Protection;
68
use base::SafeDescriptor;
69
use base::SharedMemory;
70
use base::WorkerThread;
71
use cros_async::TaskHandle;
72
use hypervisor::MemCacheType;
73
use serde::Deserialize;
74
use serde::Serialize;
75
use snapshot::AnySnapshot;
76
use sync::Mutex;
77
use thiserror::Error as ThisError;
78
use vm_control::VmMemorySource;
79
use vm_memory::GuestAddress;
80
use vm_memory::GuestMemory;
81
use vm_memory::MemoryRegion;
82
use vmm_vhost::message::VhostUserConfigFlags;
83
use vmm_vhost::message::VhostUserExternalMapMsg;
84
use vmm_vhost::message::VhostUserGpuMapMsg;
85
use vmm_vhost::message::VhostUserInflight;
86
use vmm_vhost::message::VhostUserMMap;
87
use vmm_vhost::message::VhostUserMMapFlags;
88
use vmm_vhost::message::VhostUserMemoryRegion;
89
use vmm_vhost::message::VhostUserMigrationPhase;
90
use vmm_vhost::message::VhostUserProtocolFeatures;
91
use vmm_vhost::message::VhostUserSingleMemoryRegion;
92
use vmm_vhost::message::VhostUserTransferDirection;
93
use vmm_vhost::message::VhostUserVringAddrFlags;
94
use vmm_vhost::message::VhostUserVringState;
95
use vmm_vhost::Connection;
96
use vmm_vhost::Error as VhostError;
97
use vmm_vhost::Frontend;
98
use vmm_vhost::FrontendClient;
99
use vmm_vhost::Result as VhostResult;
100
use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;
101
102
use crate::virtio::Interrupt;
103
use crate::virtio::Queue;
104
use crate::virtio::QueueConfig;
105
use crate::virtio::SharedMemoryMapper;
106
use crate::virtio::SharedMemoryRegion;
107
108
/// Keeps a mapping from the vmm's virtual addresses to guest addresses.
109
/// used to translate messages from the vmm to guest offsets.
110
#[derive(Default)]
111
pub struct MappingInfo {
112
pub vmm_addr: u64,
113
pub guest_phys: u64,
114
pub size: u64,
115
}
116
117
pub fn vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress> {
118
for map in maps {
119
if vmm_va >= map.vmm_addr && vmm_va < map.vmm_addr + map.size {
120
return Ok(GuestAddress(vmm_va - map.vmm_addr + map.guest_phys));
121
}
122
}
123
Err(VhostError::InvalidMessage)
124
}
125
126
/// Trait for vhost-user devices. Analogous to the `VirtioDevice` trait.
///
/// In contrast with [`vmm_vhost::Backend`], which closely matches the vhost-user spec, this trait
/// is designed to follow crosvm conventions for implementing devices. `DeviceRequestHandler`
/// adapts an implementation of this trait to `vmm_vhost::Backend`.
pub trait VhostUserDevice {
    /// The maximum number of queues that this backend can manage.
    ///
    /// `DeviceRequestHandler::new` allocates one vring per queue from this value.
    fn max_queue_num(&self) -> usize;

    /// The set of feature bits that this backend supports.
    fn features(&self) -> u64;

    /// Acknowledges that this set of features should be enabled.
    ///
    /// Implementations only need to handle device-specific feature bits; the `DeviceRequestHandler`
    /// framework will manage generic vhost and vring features.
    ///
    /// `DeviceRequestHandler` checks for valid features before calling this function, so the
    /// features in `value` will always be a subset of those advertised by `features()`.
    fn ack_features(&mut self, _value: u64) -> anyhow::Result<()> {
        Ok(())
    }

    /// The set of protocol feature bits that this backend supports.
    ///
    /// `DeviceRequestHandler` always advertises `REPLY_ACK` in addition to these bits.
    fn protocol_features(&self) -> VhostUserProtocolFeatures;

    /// Reads this device's configuration space at `offset` into `dst`.
    fn read_config(&self, offset: u64, dst: &mut [u8]);

    /// Writes `data` to this device's configuration space at `offset`.
    fn write_config(&self, _offset: u64, _data: &[u8]) {}

    /// Indicates that the backend should start processing requests for virtio queue number `idx`.
    /// This method must not block the current thread so device backends should either spawn an
    /// async task or another thread to handle messages from the Queue.
    fn start_queue(&mut self, idx: usize, queue: Queue, mem: GuestMemory) -> anyhow::Result<()>;

    /// Indicates that the backend should stop processing requests for virtio queue number `idx`.
    /// This method should return the queue passed to `start_queue` for the corresponding `idx`.
    /// This method will only be called for queues that were previously started by `start_queue`.
    fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue>;

    /// Resets the vhost-user backend. Called when the frontend sends `reset_owner`.
    fn reset(&mut self);

    /// Returns the device's shared memory region if present.
    fn get_shared_memory_region(&self) -> Option<SharedMemoryRegion> {
        None
    }

    /// Accepts the `VhostBackendReqConnection` used to send backend-to-frontend messages, such as
    /// config-change notifications and shared memory mapping requests.
    ///
    /// This method will be called when `VhostUserProtocolFeatures::BACKEND_REQ` is
    /// negotiated.
    fn set_backend_req_connection(&mut self, _conn: VhostBackendReqConnection) {}

    /// Enter the "suspended device state" described in the vhost-user spec. See the spec for
    /// requirements.
    ///
    /// One reasonably foolproof way to satisfy the requirements is to stop all worker threads.
    ///
    /// Called after a `stop_queue` call if there are no running queues left. Also called soon
    /// after device creation to ensure the device is acting suspended immediately on construction.
    ///
    /// The next `start_queue` call implicitly exits the "suspend device state".
    ///
    /// * Ok(()) => device successfully suspended
    /// * Err(_) => unrecoverable error
    fn enter_suspended_state(&mut self) -> anyhow::Result<()>;

    /// Snapshot device and return serialized state.
    ///
    /// Called when the frontend starts a "save" device state transfer.
    fn snapshot(&mut self) -> anyhow::Result<AnySnapshot>;

    /// Restore device state from a snapshot.
    ///
    /// Called once a "load" device state transfer has been read and deserialized.
    fn restore(&mut self, data: AnySnapshot) -> anyhow::Result<()>;

    /// Whether guest memory should be unmapped in forked processes.
    ///
    /// This is intended for use in combination with --protected-vm, where the guest memory can be
    /// dangerous to access. Some systems, e.g. Android, have tools that fork processes and examine
    /// their memory. This flag effectively hides the guest memory from those tools.
    ///
    /// Only honored on Linux/Android (via `MADV_DONTFORK`); other platforms log an error.
    ///
    /// Not compatible with sandboxing.
    fn unmap_guest_memory_on_fork(&self) -> bool {
        false
    }
}
213
214
/// A virtio ring entry.
215
struct Vring {
216
// The queue config. This doesn't get mutated by the queue workers.
217
queue: QueueConfig,
218
doorbell: Option<Interrupt>,
219
enabled: bool,
220
}
221
222
impl Vring {
223
fn new(max_size: u16, features: u64) -> Self {
224
Self {
225
queue: QueueConfig::new(max_size, features),
226
doorbell: None,
227
enabled: false,
228
}
229
}
230
231
fn reset(&mut self) {
232
self.queue.reset();
233
self.doorbell = None;
234
self.enabled = false;
235
}
236
}
237
238
/// Ops for running vhost-user over a stream (i.e. regular protocol).
239
pub(super) struct VhostUserRegularOps;
240
241
impl VhostUserRegularOps {
242
pub fn set_mem_table(
243
contexts: &[VhostUserMemoryRegion],
244
files: Vec<File>,
245
) -> VhostResult<(GuestMemory, Vec<MappingInfo>)> {
246
if files.len() != contexts.len() {
247
return Err(VhostError::InvalidParam(
248
"number of files & contexts was not equal",
249
));
250
}
251
252
let mut regions = Vec::with_capacity(files.len());
253
for (region, file) in contexts.iter().zip(files.into_iter()) {
254
let region = MemoryRegion::new_from_shm(
255
region.memory_size,
256
GuestAddress(region.guest_phys_addr),
257
region.mmap_offset,
258
Arc::new(
259
SharedMemory::from_safe_descriptor(
260
SafeDescriptor::from(file),
261
region.memory_size,
262
)
263
.unwrap(),
264
),
265
)
266
.map_err(|e| {
267
error!("failed to create a memory region: {}", e);
268
VhostError::InvalidOperation
269
})?;
270
regions.push(region);
271
}
272
let guest_mem = GuestMemory::from_regions(regions).map_err(|e| {
273
error!("failed to create guest memory: {}", e);
274
VhostError::InvalidOperation
275
})?;
276
277
let vmm_maps = contexts
278
.iter()
279
.map(|region| MappingInfo {
280
vmm_addr: region.user_addr,
281
guest_phys: region.guest_phys_addr,
282
size: region.memory_size,
283
})
284
.collect();
285
Ok((guest_mem, vmm_maps))
286
}
287
}
288
289
/// An adapter that implements `vmm_vhost::Backend` for any type implementing `VhostUserDevice`.
pub struct DeviceRequestHandler<T: VhostUserDevice> {
    /// Per-queue state; `new` creates one entry per `backend.max_queue_num()`.
    vrings: Vec<Vring>,
    /// Set by `set_owner` and cleared by `reset_owner`; `set_features` requires it.
    owned: bool,
    /// Frontend-VA to guest-physical translations from the latest `set_mem_table`.
    vmm_maps: Option<Vec<MappingInfo>>,
    /// Guest memory from the latest `set_mem_table`; handed to each queue as it starts.
    mem: Option<GuestMemory>,
    /// Virtio feature bits acknowledged through `set_features` (accumulated with `|=`).
    acked_features: u64,
    /// Protocol features acked by the frontend, masked to the ones this handler advertises.
    acked_protocol_features: VhostUserProtocolFeatures,
    /// The device implementation that requests are dispatched to.
    backend: T,
    /// Backend-to-frontend request channel installed by `set_backend_req_fd`.
    backend_req_connection: Option<VhostBackendReqConnection>,
    // Thread processing active device state FD.
    device_state_thread: Option<DeviceStateThread>,
}

/// Background worker for an in-progress `set_device_state_fd` transfer. It is joined by
/// `check_device_state`.
enum DeviceStateThread {
    /// Serializes a `DeviceRequestHandlerSnapshot` (via `ciborium`) into the state fd.
    Save(WorkerThread<Result<(), ciborium::ser::Error<std::io::Error>>>),
    /// Reads and deserializes a `DeviceRequestHandlerSnapshot` from the state fd.
    Load(WorkerThread<Result<DeviceRequestHandlerSnapshot, ciborium::de::Error<std::io::Error>>>),
}

/// Serialized handler state exchanged through `set_device_state_fd`.
#[derive(Serialize, Deserialize)]
pub struct DeviceRequestHandlerSnapshot {
    /// Acked virtio feature bits.
    acked_features: u64,
    /// Raw bits of the acked `VhostUserProtocolFeatures`.
    acked_protocol_features: u64,
    /// State produced by `VhostUserDevice::snapshot`.
    backend: AnySnapshot,
}
314
315
impl<T: VhostUserDevice> DeviceRequestHandler<T> {
316
/// Creates a vhost-user handler instance for `backend`.
317
pub(crate) fn new(mut backend: T) -> Self {
318
let mut vrings = Vec::with_capacity(backend.max_queue_num());
319
for _ in 0..backend.max_queue_num() {
320
vrings.push(Vring::new(Queue::MAX_SIZE, backend.features()));
321
}
322
323
// VhostUserDevice implementations must support `enter_suspended_state()`.
324
// Call it on startup to ensure it works and to initialize the device in a suspended state.
325
backend
326
.enter_suspended_state()
327
.expect("enter_suspended_state failed on device init");
328
329
DeviceRequestHandler {
330
vrings,
331
owned: false,
332
vmm_maps: None,
333
mem: None,
334
acked_features: 0,
335
acked_protocol_features: VhostUserProtocolFeatures::empty(),
336
backend,
337
backend_req_connection: None,
338
device_state_thread: None,
339
}
340
}
341
342
/// Check if all queues are stopped.
343
///
344
/// The device can be suspended with `enter_suspended_state()` only when all queues are stopped.
345
fn all_queues_stopped(&self) -> bool {
346
self.vrings.iter().all(|vring| !vring.queue.ready())
347
}
348
}
349
350
impl<T: VhostUserDevice> Drop for DeviceRequestHandler<T> {
351
fn drop(&mut self) {
352
for (index, vring) in self.vrings.iter().enumerate() {
353
if vring.queue.ready() {
354
if let Err(e) = self.backend.stop_queue(index) {
355
error!("Failed to stop queue {} during drop: {:#}", index, e);
356
}
357
}
358
}
359
}
360
}
361
362
/// Gives shared access to the wrapped `VhostUserDevice` backend.
impl<T: VhostUserDevice> AsRef<T> for DeviceRequestHandler<T> {
    fn as_ref(&self) -> &T {
        &self.backend
    }
}

/// Gives mutable access to the wrapped `VhostUserDevice` backend.
impl<T: VhostUserDevice> AsMut<T> for DeviceRequestHandler<T> {
    fn as_mut(&mut self) -> &mut T {
        &mut self.backend
    }
}
373
374
impl<T: VhostUserDevice> vmm_vhost::Backend for DeviceRequestHandler<T> {
375
fn set_owner(&mut self) -> VhostResult<()> {
376
if self.owned {
377
return Err(VhostError::InvalidOperation);
378
}
379
self.owned = true;
380
Ok(())
381
}
382
383
fn reset_owner(&mut self) -> VhostResult<()> {
384
self.owned = false;
385
self.acked_features = 0;
386
self.backend.reset();
387
Ok(())
388
}
389
390
fn get_features(&mut self) -> VhostResult<u64> {
391
let features = self.backend.features();
392
Ok(features)
393
}
394
395
fn set_features(&mut self, features: u64) -> VhostResult<()> {
396
if !self.owned {
397
return Err(VhostError::InvalidOperation);
398
}
399
400
let unexpected_features = features & !self.backend.features();
401
if unexpected_features != 0 {
402
error!("unexpected set_features {:#x}", unexpected_features);
403
return Err(VhostError::InvalidParam("unexpected set_features"));
404
}
405
406
if let Err(e) = self.backend.ack_features(features) {
407
error!("failed to acknowledge features 0x{:x}: {}", features, e);
408
return Err(VhostError::InvalidOperation);
409
}
410
411
self.acked_features |= features;
412
413
// If VHOST_USER_F_PROTOCOL_FEATURES has not been negotiated, the ring is initialized in an
414
// enabled state.
415
// If VHOST_USER_F_PROTOCOL_FEATURES has been negotiated, the ring is initialized in a
416
// disabled state.
417
// Client must not pass data to/from the backend until ring is enabled by
418
// VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by
419
// VHOST_USER_SET_VRING_ENABLE with parameter 0.
420
let vring_enabled = self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0;
421
for v in &mut self.vrings {
422
v.enabled = vring_enabled;
423
}
424
425
Ok(())
426
}
427
428
fn get_protocol_features(&mut self) -> VhostResult<VhostUserProtocolFeatures> {
429
Ok(self.backend.protocol_features() | VhostUserProtocolFeatures::REPLY_ACK)
430
}
431
432
fn set_protocol_features(&mut self, features: u64) -> VhostResult<()> {
433
let features = match VhostUserProtocolFeatures::from_bits(features) {
434
Some(proto_features) => proto_features,
435
None => {
436
error!(
437
"unsupported bits in VHOST_USER_SET_PROTOCOL_FEATURES: {:#x}",
438
features
439
);
440
return Err(VhostError::InvalidOperation);
441
}
442
};
443
let supported = self.get_protocol_features()?;
444
self.acked_protocol_features = features & supported;
445
Ok(())
446
}
447
448
fn set_mem_table(
449
&mut self,
450
contexts: &[VhostUserMemoryRegion],
451
files: Vec<File>,
452
) -> VhostResult<()> {
453
let (guest_mem, vmm_maps) = VhostUserRegularOps::set_mem_table(contexts, files)?;
454
if self.backend.unmap_guest_memory_on_fork() {
455
#[cfg(any(target_os = "android", target_os = "linux"))]
456
if let Err(e) = guest_mem.use_dontfork() {
457
error!("failed to set MADV_DONTFORK on guest memory: {e:#}");
458
}
459
#[cfg(not(any(target_os = "android", target_os = "linux")))]
460
error!("unmap_guest_memory_on_fork unsupported; skipping");
461
}
462
self.mem = Some(guest_mem);
463
self.vmm_maps = Some(vmm_maps);
464
Ok(())
465
}
466
467
fn get_queue_num(&mut self) -> VhostResult<u64> {
468
Ok(self.vrings.len() as u64)
469
}
470
471
fn set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()> {
472
if index as usize >= self.vrings.len() || num == 0 || num > Queue::MAX_SIZE.into() {
473
return Err(VhostError::InvalidParam(
474
"set_vring_num: invalid index or num",
475
));
476
}
477
self.vrings[index as usize].queue.set_size(num as u16);
478
479
Ok(())
480
}
481
482
fn set_vring_addr(
483
&mut self,
484
index: u32,
485
_flags: VhostUserVringAddrFlags,
486
descriptor: u64,
487
used: u64,
488
available: u64,
489
_log: u64,
490
) -> VhostResult<()> {
491
if index as usize >= self.vrings.len() {
492
return Err(VhostError::InvalidParam(
493
"set_vring_addr: index out of range",
494
));
495
}
496
497
let vmm_maps = self
498
.vmm_maps
499
.as_ref()
500
.ok_or(VhostError::InvalidParam("set_vring_addr: missing vmm_maps"))?;
501
let vring = &mut self.vrings[index as usize];
502
vring
503
.queue
504
.set_desc_table(vmm_va_to_gpa(vmm_maps, descriptor)?);
505
vring
506
.queue
507
.set_avail_ring(vmm_va_to_gpa(vmm_maps, available)?);
508
vring.queue.set_used_ring(vmm_va_to_gpa(vmm_maps, used)?);
509
510
Ok(())
511
}
512
513
fn set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()> {
514
if index as usize >= self.vrings.len() {
515
return Err(VhostError::InvalidParam(
516
"set_vring_base: index out of range",
517
));
518
}
519
520
let vring = &mut self.vrings[index as usize];
521
vring.queue.set_next_avail(Wrapping(base as u16));
522
vring.queue.set_next_used(Wrapping(base as u16));
523
524
Ok(())
525
}
526
527
fn get_vring_base(&mut self, index: u32) -> VhostResult<VhostUserVringState> {
528
let vring = self
529
.vrings
530
.get_mut(index as usize)
531
.ok_or(VhostError::InvalidParam(
532
"get_vring_base: index out of range",
533
))?;
534
535
// Quotation from vhost-user spec:
536
// "The back-end must [...] stop ring upon receiving VHOST_USER_GET_VRING_BASE."
537
// We only call `queue.set_ready()` when starting the queue, so if the queue is ready, that
538
// means it is started and should be stopped.
539
let vring_base = if vring.queue.ready() {
540
let queue = match self.backend.stop_queue(index as usize) {
541
Ok(q) => q,
542
Err(e) => {
543
error!("Failed to stop queue in get_vring_base: {:#}", e);
544
return Err(VhostError::BackendInternalError);
545
}
546
};
547
548
trace!("stopped queue {index}");
549
vring.reset();
550
551
if self.all_queues_stopped() {
552
trace!("all queues stopped; entering suspended state");
553
self.backend
554
.enter_suspended_state()
555
.map_err(VhostError::EnterSuspendedState)?;
556
}
557
558
queue.next_avail_to_process()
559
} else {
560
0
561
};
562
563
Ok(VhostUserVringState::new(index, vring_base.into()))
564
}
565
566
fn set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
567
if index as usize >= self.vrings.len() {
568
return Err(VhostError::InvalidParam(
569
"set_vring_kick: index out of range",
570
));
571
}
572
573
let vring = &mut self.vrings[index as usize];
574
if vring.queue.ready() {
575
error!("kick fd cannot replaced after queue is started");
576
return Err(VhostError::InvalidOperation);
577
}
578
579
let file = file.ok_or(VhostError::InvalidParam("missing file for set_vring_kick"))?;
580
581
// Remove O_NONBLOCK from kick_fd. Otherwise, uring_executor will fails when we read
582
// values via `next_val()` later.
583
// This is only required (and can only be done) on Unix platforms.
584
#[cfg(any(target_os = "android", target_os = "linux"))]
585
if let Err(e) = clear_fd_flags(file.as_raw_fd(), libc::O_NONBLOCK) {
586
error!("failed to remove O_NONBLOCK for kick fd: {}", e);
587
return Err(VhostError::InvalidParam(
588
"could not remove O_NONBLOCK from vring_kick",
589
));
590
}
591
592
let kick_evt = Event::from(SafeDescriptor::from(file));
593
594
// Enable any virtqueue features that were negotiated (like VIRTIO_RING_F_EVENT_IDX).
595
vring.queue.ack_features(self.acked_features);
596
vring.queue.set_ready(true);
597
598
let mem = self
599
.mem
600
.as_ref()
601
.cloned()
602
.ok_or(VhostError::InvalidOperation)?;
603
604
let doorbell = vring.doorbell.clone().ok_or(VhostError::InvalidOperation)?;
605
606
let queue = match vring.queue.activate(&mem, kick_evt, doorbell) {
607
Ok(queue) => queue,
608
Err(e) => {
609
error!("failed to activate vring: {:#}", e);
610
return Err(VhostError::BackendInternalError);
611
}
612
};
613
614
if let Err(e) = self.backend.start_queue(index as usize, queue, mem) {
615
error!("Failed to start queue {}: {}", index, e);
616
return Err(VhostError::BackendInternalError);
617
}
618
trace!("started queue {index}");
619
620
Ok(())
621
}
622
623
fn set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
624
if index as usize >= self.vrings.len() {
625
return Err(VhostError::InvalidParam(
626
"set_vring_call: index out of range",
627
));
628
}
629
630
let backend_req_conn = self.backend_req_connection.clone();
631
let signal_config_change_fn = Box::new(move || {
632
if let Some(frontend) = backend_req_conn.as_ref() {
633
if let Err(e) = frontend.send_config_changed() {
634
error!("Failed to notify config change: {:#}", e);
635
}
636
} else {
637
error!("No Backend request connection found");
638
}
639
});
640
641
let file = file.ok_or(VhostError::InvalidParam("missing file for set_vring_call"))?;
642
self.vrings[index as usize].doorbell = Some(Interrupt::new_vhost_user(
643
Event::from(SafeDescriptor::from(file)),
644
signal_config_change_fn,
645
));
646
Ok(())
647
}
648
649
fn set_vring_err(&mut self, _index: u8, _fd: Option<File>) -> VhostResult<()> {
650
// TODO
651
Ok(())
652
}
653
654
fn set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()> {
655
if index as usize >= self.vrings.len() {
656
return Err(VhostError::InvalidParam(
657
"set_vring_enable: index out of range",
658
));
659
}
660
661
// This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES
662
// has been negotiated.
663
if self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 {
664
return Err(VhostError::InvalidOperation);
665
}
666
667
// Backend must not pass data to/from the ring until ring is enabled by
668
// VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by
669
// VHOST_USER_SET_VRING_ENABLE with parameter 0.
670
self.vrings[index as usize].enabled = enable;
671
672
Ok(())
673
}
674
675
fn get_config(
676
&mut self,
677
offset: u32,
678
size: u32,
679
_flags: VhostUserConfigFlags,
680
) -> VhostResult<Vec<u8>> {
681
let mut data = vec![0; size as usize];
682
self.backend.read_config(u64::from(offset), &mut data);
683
Ok(data)
684
}
685
686
fn set_config(
687
&mut self,
688
offset: u32,
689
buf: &[u8],
690
_flags: VhostUserConfigFlags,
691
) -> VhostResult<()> {
692
self.backend.write_config(u64::from(offset), buf);
693
Ok(())
694
}
695
696
fn set_backend_req_fd(&mut self, ep: Connection) {
697
let conn = VhostBackendReqConnection::new(
698
FrontendClient::new(
699
ep,
700
self.acked_protocol_features
701
.contains(VhostUserProtocolFeatures::REPLY_ACK),
702
),
703
self.backend.get_shared_memory_region().map(|r| r.id),
704
);
705
706
if self.backend_req_connection.is_some() {
707
warn!("Backend Request Connection already established. Overwriting");
708
}
709
self.backend_req_connection = Some(conn.clone());
710
711
self.backend.set_backend_req_connection(conn);
712
}
713
714
fn get_inflight_fd(
715
&mut self,
716
_inflight: &VhostUserInflight,
717
) -> VhostResult<(VhostUserInflight, File)> {
718
unimplemented!("get_inflight_fd");
719
}
720
721
fn set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> VhostResult<()> {
722
unimplemented!("set_inflight_fd");
723
}
724
725
fn get_max_mem_slots(&mut self) -> VhostResult<u64> {
726
//TODO
727
Ok(0)
728
}
729
730
fn add_mem_region(
731
&mut self,
732
_region: &VhostUserSingleMemoryRegion,
733
_fd: File,
734
) -> VhostResult<()> {
735
//TODO
736
Ok(())
737
}
738
739
fn remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> VhostResult<()> {
740
//TODO
741
Ok(())
742
}
743
744
fn set_device_state_fd(
745
&mut self,
746
transfer_direction: VhostUserTransferDirection,
747
migration_phase: VhostUserMigrationPhase,
748
fd: File,
749
) -> VhostResult<Option<File>> {
750
if migration_phase != VhostUserMigrationPhase::Stopped {
751
return Err(VhostError::InvalidOperation);
752
}
753
if !self.all_queues_stopped() {
754
return Err(VhostError::InvalidOperation);
755
}
756
if self.device_state_thread.is_some() {
757
error!("must call check_device_state before starting new state transfer");
758
return Err(VhostError::InvalidOperation);
759
}
760
// `set_device_state_fd` is designed to allow snapshot/restore concurrently with other
761
// methods, but, for simplicitly, we do those operations inline and only spawn a thread to
762
// handle the serialization and data transfer (the latter which seems necessary to
763
// implement the API correctly without, e.g., deadlocking because a pipe is full).
764
match transfer_direction {
765
VhostUserTransferDirection::Save => {
766
// Snapshot the state.
767
let snapshot = DeviceRequestHandlerSnapshot {
768
acked_features: self.acked_features,
769
acked_protocol_features: self.acked_protocol_features.bits(),
770
backend: self.backend.snapshot().map_err(VhostError::SnapshotError)?,
771
};
772
// Spawn thread to write the serialized bytes.
773
self.device_state_thread = Some(DeviceStateThread::Save(WorkerThread::start(
774
"device_state_save",
775
move |_kill_event| -> Result<(), ciborium::ser::Error<std::io::Error>> {
776
let mut w = std::io::BufWriter::new(fd);
777
ciborium::into_writer(&snapshot, &mut w)?;
778
w.flush()?;
779
Ok(())
780
},
781
)));
782
Ok(None)
783
}
784
VhostUserTransferDirection::Load => {
785
// Spawn a thread to read the bytes and deserialize. Restore will happen in
786
// `check_device_state`.
787
self.device_state_thread = Some(DeviceStateThread::Load(WorkerThread::start(
788
"device_state_load",
789
move |_kill_event| ciborium::from_reader(&mut BufReader::new(fd)),
790
)));
791
Ok(None)
792
}
793
}
794
}
795
796
fn check_device_state(&mut self) -> VhostResult<()> {
797
let Some(thread) = self.device_state_thread.take() else {
798
error!("check_device_state: no active state transfer");
799
return Err(VhostError::InvalidOperation);
800
};
801
match thread {
802
DeviceStateThread::Save(worker) => {
803
worker.stop().map_err(|e| {
804
error!("device state save thread failed: {:#}", e);
805
VhostError::BackendInternalError
806
})?;
807
Ok(())
808
}
809
DeviceStateThread::Load(worker) => {
810
let snapshot = worker.stop().map_err(|e| {
811
error!("device state load thread failed: {:#}", e);
812
VhostError::BackendInternalError
813
})?;
814
self.acked_features = snapshot.acked_features;
815
self.acked_protocol_features =
816
VhostUserProtocolFeatures::from_bits(snapshot.acked_protocol_features)
817
.with_context(|| {
818
format!(
819
"unsupported bits in acked_protocol_features: {:#x}",
820
snapshot.acked_protocol_features
821
)
822
})
823
.map_err(VhostError::RestoreError)?;
824
self.backend
825
.restore(snapshot.backend)
826
.map_err(VhostError::RestoreError)?;
827
Ok(())
828
}
829
}
830
}
831
832
fn get_shmem_config(&mut self) -> VhostResult<Vec<SharedMemoryRegion>> {
833
Ok(self
834
.backend
835
.get_shared_memory_region()
836
.into_iter()
837
.collect())
838
}
839
}
840
841
/// Keeps track of Vhost user backend request connection.
842
#[derive(Clone)]
843
pub struct VhostBackendReqConnection {
844
shared: Arc<Mutex<VhostBackendReqConnectionShared>>,
845
shmid: Option<u8>,
846
}
847
848
struct VhostBackendReqConnectionShared {
849
conn: FrontendClient,
850
mapped_regions: BTreeMap<u64 /* offset */, u64 /* size */>,
851
}
852
853
impl VhostBackendReqConnection {
854
fn new(conn: FrontendClient, shmid: Option<u8>) -> Self {
855
Self {
856
shared: Arc::new(Mutex::new(VhostBackendReqConnectionShared {
857
conn,
858
mapped_regions: BTreeMap::new(),
859
})),
860
shmid,
861
}
862
}
863
864
/// Send `VHOST_USER_CONFIG_CHANGE_MSG` to the frontend
865
fn send_config_changed(&self) -> anyhow::Result<()> {
866
let mut shared = self.shared.lock();
867
shared
868
.conn
869
.handle_config_change()
870
.context("Could not send config change message")?;
871
Ok(())
872
}
873
874
/// Create a SharedMemoryMapper trait object using this backend request connection.
875
pub fn shmem_mapper(&self) -> Option<Box<dyn SharedMemoryMapper>> {
876
if let Some(shmid) = self.shmid {
877
Some(Box::new(VhostShmemMapper {
878
shared: self.shared.clone(),
879
shmid,
880
}))
881
} else {
882
None
883
}
884
}
885
}
886
887
#[derive(Clone)]
888
struct VhostShmemMapper {
889
shared: Arc<Mutex<VhostBackendReqConnectionShared>>,
890
shmid: u8,
891
}
892
893
impl SharedMemoryMapper for VhostShmemMapper {
894
fn add_mapping(
895
&mut self,
896
source: VmMemorySource,
897
offset: u64,
898
prot: Protection,
899
_cache: MemCacheType,
900
) -> anyhow::Result<()> {
901
let mut shared = self.shared.lock();
902
let size = match source {
903
VmMemorySource::Vulkan {
904
descriptor,
905
handle_type,
906
memory_idx,
907
device_uuid,
908
driver_uuid,
909
size,
910
} => {
911
let msg = VhostUserGpuMapMsg::new(
912
self.shmid,
913
offset,
914
size,
915
memory_idx,
916
handle_type,
917
device_uuid,
918
driver_uuid,
919
);
920
shared
921
.conn
922
.gpu_map(&msg, &descriptor)
923
.context("map GPU memory")?;
924
size
925
}
926
VmMemorySource::ExternalMapping { ptr, size } => {
927
let msg = VhostUserExternalMapMsg::new(self.shmid, offset, size, ptr);
928
shared
929
.conn
930
.external_map(&msg)
931
.context("create external mapping")?;
932
size
933
}
934
source => {
935
// The last two sources use the same VhostUserMMap, continue matching here
936
// on the aliased `source` above.
937
let (descriptor, fd_offset, size) = match source {
938
VmMemorySource::Descriptor {
939
descriptor,
940
offset,
941
size,
942
} => (descriptor, offset, size),
943
VmMemorySource::SharedMemory(shmem) => {
944
let size = shmem.size();
945
let descriptor = SafeDescriptor::from(shmem);
946
(descriptor, 0, size)
947
}
948
_ => bail!("unsupported source"),
949
};
950
let mut flags = VhostUserMMapFlags::empty();
951
anyhow::ensure!(prot.allows(&Protection::read()), "mapping must be readable");
952
if prot.allows(&Protection::write()) {
953
flags |= VhostUserMMapFlags::MAP_RW;
954
}
955
let msg = VhostUserMMap {
956
shmid: self.shmid,
957
padding: Default::default(),
958
fd_offset,
959
shm_offset: offset,
960
len: size,
961
flags,
962
};
963
shared
964
.conn
965
.shmem_map(&msg, &descriptor)
966
.context("map shmem")?;
967
size
968
}
969
};
970
971
shared.mapped_regions.insert(offset, size);
972
Ok(())
973
}
974
975
fn remove_mapping(&mut self, offset: u64) -> anyhow::Result<()> {
976
let mut shared = self.shared.lock();
977
let size = shared
978
.mapped_regions
979
.remove(&offset)
980
.context("unknown offset")?;
981
let msg = VhostUserMMap {
982
shmid: self.shmid,
983
padding: Default::default(),
984
fd_offset: 0,
985
shm_offset: offset,
986
len: size,
987
flags: VhostUserMMapFlags::empty(),
988
};
989
shared
990
.conn
991
.shmem_unmap(&msg)
992
.context("unmap shmem")
993
.map(|_| ())
994
}
995
}
996
997
/// Pairs a spawned queue task with a queue value of type `T`.
///
/// NOTE(review): not used within this file; presumably device backends keep one of these per
/// started queue so the queue can be handed back from `stop_queue` — confirm against callers.
pub(crate) struct WorkerState<T, U> {
    /// Handle to the spawned task; `U` is the task's output type.
    pub(crate) queue_task: TaskHandle<U>,
    /// The queue (or queue wrapper) associated with the task.
    pub(crate) queue: T,
}

/// Errors for device operations
#[derive(Debug, ThisError)]
pub enum Error {
    /// Presumably returned by a backend's `stop_queue` when no worker is running for the queue
    /// (inferred from the message text; not raised in this file).
    #[error("worker not found when stopping queue")]
    WorkerNotFound,
}
1008
1009
#[cfg(test)]
1010
mod tests {
1011
use std::sync::mpsc::channel;
1012
1013
use anyhow::bail;
1014
use base::Event;
1015
use virtio_sys::virtio_ring::VIRTIO_RING_F_EVENT_IDX;
1016
use vmm_vhost::BackendServer;
1017
use vmm_vhost::FrontendReq;
1018
use zerocopy::FromBytes;
1019
use zerocopy::FromZeros;
1020
use zerocopy::Immutable;
1021
use zerocopy::IntoBytes;
1022
use zerocopy::KnownLayout;
1023
1024
use super::*;
1025
use crate::virtio::vhost_user_frontend::VhostUserFrontend;
1026
use crate::virtio::DeviceType;
1027
use crate::virtio::VirtioDevice;
1028
1029
#[derive(Clone, Copy, Debug, PartialEq, Eq, FromBytes, Immutable, IntoBytes, KnownLayout)]
1030
#[repr(C, packed(4))]
1031
struct FakeConfig {
1032
x: u32,
1033
y: u64,
1034
}
1035
1036
const FAKE_CONFIG_DATA: FakeConfig = FakeConfig { x: 1, y: 2 };
1037
1038
pub(super) struct FakeBackend {
1039
avail_features: u64,
1040
acked_features: u64,
1041
active_queues: Vec<Option<Queue>>,
1042
allow_backend_req: bool,
1043
backend_conn: Option<VhostBackendReqConnection>,
1044
}
1045
1046
#[derive(Deserialize, Serialize)]
1047
struct FakeBackendSnapshot {
1048
data: Vec<u8>,
1049
}
1050
1051
impl FakeBackend {
1052
const MAX_QUEUE_NUM: usize = 16;
1053
1054
pub(super) fn new() -> Self {
1055
let mut active_queues = Vec::new();
1056
active_queues.resize_with(Self::MAX_QUEUE_NUM, Default::default);
1057
Self {
1058
avail_features: 1 << VHOST_USER_F_PROTOCOL_FEATURES | 1 << VIRTIO_RING_F_EVENT_IDX,
1059
acked_features: 0,
1060
active_queues,
1061
allow_backend_req: false,
1062
backend_conn: None,
1063
}
1064
}
1065
}
1066
1067
impl VhostUserDevice for FakeBackend {
1068
fn max_queue_num(&self) -> usize {
1069
Self::MAX_QUEUE_NUM
1070
}
1071
1072
fn features(&self) -> u64 {
1073
self.avail_features
1074
}
1075
1076
fn ack_features(&mut self, value: u64) -> anyhow::Result<()> {
1077
let unrequested_features = value & !self.avail_features;
1078
if unrequested_features != 0 {
1079
bail!(
1080
"invalid protocol features are given: 0x{:x}",
1081
unrequested_features
1082
);
1083
}
1084
self.acked_features |= value;
1085
Ok(())
1086
}
1087
1088
fn protocol_features(&self) -> VhostUserProtocolFeatures {
1089
let mut features =
1090
VhostUserProtocolFeatures::CONFIG | VhostUserProtocolFeatures::DEVICE_STATE;
1091
if self.allow_backend_req {
1092
features |= VhostUserProtocolFeatures::BACKEND_REQ;
1093
}
1094
features
1095
}
1096
1097
fn read_config(&self, offset: u64, dst: &mut [u8]) {
1098
dst.copy_from_slice(&FAKE_CONFIG_DATA.as_bytes()[offset as usize..]);
1099
}
1100
1101
fn reset(&mut self) {}
1102
1103
fn start_queue(
1104
&mut self,
1105
idx: usize,
1106
queue: Queue,
1107
_mem: GuestMemory,
1108
) -> anyhow::Result<()> {
1109
self.active_queues[idx] = Some(queue);
1110
Ok(())
1111
}
1112
1113
fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue> {
1114
Ok(self.active_queues[idx]
1115
.take()
1116
.ok_or(Error::WorkerNotFound)?)
1117
}
1118
1119
fn set_backend_req_connection(&mut self, conn: VhostBackendReqConnection) {
1120
self.backend_conn = Some(conn);
1121
}
1122
1123
fn enter_suspended_state(&mut self) -> anyhow::Result<()> {
1124
Ok(())
1125
}
1126
1127
fn snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
1128
AnySnapshot::to_any(FakeBackendSnapshot {
1129
data: vec![1, 2, 3],
1130
})
1131
.context("failed to serialize snapshot")
1132
}
1133
1134
fn restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
1135
let snapshot: FakeBackendSnapshot =
1136
AnySnapshot::from_any(data).context("failed to deserialize snapshot")?;
1137
assert_eq!(snapshot.data, vec![1, 2, 3], "bad snapshot data");
1138
Ok(())
1139
}
1140
}
1141
1142
fn create_queues(
1143
num: usize,
1144
mem: &GuestMemory,
1145
interrupt: &Interrupt,
1146
) -> BTreeMap<usize, Queue> {
1147
let mut queues = BTreeMap::new();
1148
for idx in 0..num {
1149
let mut queue = QueueConfig::new(0x10, 0);
1150
queue.set_ready(true);
1151
let queue = queue
1152
.activate(mem, Event::new().unwrap(), interrupt.clone())
1153
.expect("QueueConfig::activate");
1154
queues.insert(idx, queue);
1155
}
1156
queues
1157
}
1158
1159
/// Lifecycle test with the backend-request channel disabled.
#[test]
fn test_vhost_user_lifecycle() {
    let allow_backend_req = false;
    test_vhost_user_lifecycle_parameterized(allow_backend_req);
}
1163
1164
/// Lifecycle test with the backend-request channel enabled.
///
/// Checks that the channel keeps working across reset, sleep/wake and restore.
#[test]
#[cfg(not(windows))] // Windows requires more complex connection setup.
fn test_vhost_user_lifecycle_with_backend_req() {
    test_vhost_user_lifecycle_parameterized(true);
}
1169
1170
fn test_vhost_user_lifecycle_parameterized(allow_backend_req: bool) {
1171
const QUEUES_NUM: usize = 2;
1172
const BASE_FEATURES: u64 = 1 << VIRTIO_RING_F_EVENT_IDX;
1173
const EXPECTED_FEATURES: u64 =
1174
1 << VHOST_USER_F_PROTOCOL_FEATURES | 1 << VIRTIO_RING_F_EVENT_IDX;
1175
1176
// First phase: Test normal usage, then take a snapshot and shutdown.
1177
let snapshot = {
1178
let (client_connection, server_connection) = vmm_vhost::Connection::pair().unwrap();
1179
let (shutdown_tx, shutdown_rx) = channel();
1180
let (vm_evt_wrtube, _vm_evt_rdtube) = base::Tube::directional_pair().unwrap();
1181
let vmm_thread = std::thread::spawn(move || {
1182
// VMM side
1183
let mut vmm_device = VhostUserFrontend::new(
1184
DeviceType::Console,
1185
BASE_FEATURES,
1186
client_connection,
1187
vm_evt_wrtube,
1188
None,
1189
None,
1190
)
1191
.unwrap();
1192
1193
vmm_device.ack_features(BASE_FEATURES);
1194
1195
let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
1196
let interrupt = Interrupt::new_for_test_with_msix();
1197
1198
println!("read_config");
1199
let mut config = FakeConfig::new_zeroed();
1200
vmm_device.read_config(0, config.as_mut_bytes());
1201
// Check if the obtained config data is correct.
1202
assert_eq!(config, FAKE_CONFIG_DATA);
1203
1204
println!("activate");
1205
vmm_device
1206
.activate(
1207
mem.clone(),
1208
interrupt.clone(),
1209
create_queues(QUEUES_NUM, &mem, &interrupt),
1210
)
1211
.unwrap();
1212
1213
println!("reset");
1214
let reset_result = vmm_device.reset();
1215
assert!(
1216
reset_result.is_ok(),
1217
"reset failed: {:#}",
1218
reset_result.unwrap_err()
1219
);
1220
1221
println!("activate");
1222
vmm_device
1223
.activate(
1224
mem.clone(),
1225
interrupt.clone(),
1226
create_queues(QUEUES_NUM, &mem, &interrupt),
1227
)
1228
.unwrap();
1229
1230
println!("virtio_sleep");
1231
let queues = vmm_device
1232
.virtio_sleep()
1233
.unwrap()
1234
.expect("virtio_sleep unexpectedly returned None");
1235
1236
println!("virtio_snapshot");
1237
let snapshot = vmm_device
1238
.virtio_snapshot()
1239
.expect("virtio_snapshot failed");
1240
1241
println!("virtio_wake");
1242
vmm_device
1243
.virtio_wake(Some((mem.clone(), interrupt.clone(), queues)))
1244
.unwrap();
1245
1246
println!("wait for shutdown signal");
1247
shutdown_rx.recv().unwrap();
1248
1249
// The VMM side is supposed to stop before the device side.
1250
println!("drop");
1251
1252
snapshot
1253
});
1254
1255
// Device side
1256
let mut handler = DeviceRequestHandler::new(FakeBackend::new());
1257
handler.as_mut().allow_backend_req = allow_backend_req;
1258
1259
let mut req_handler = BackendServer::new(server_connection, handler);
1260
1261
// VhostUserFrontend::new()
1262
handle_request(&mut req_handler, FrontendReq::SET_OWNER).unwrap();
1263
handle_request(&mut req_handler, FrontendReq::GET_FEATURES).unwrap();
1264
handle_request(&mut req_handler, FrontendReq::GET_PROTOCOL_FEATURES).unwrap();
1265
handle_request(&mut req_handler, FrontendReq::SET_PROTOCOL_FEATURES).unwrap();
1266
if allow_backend_req {
1267
handle_request(&mut req_handler, FrontendReq::SET_BACKEND_REQ_FD).unwrap();
1268
}
1269
1270
// VhostUserFrontend::read_config()
1271
handle_request(&mut req_handler, FrontendReq::GET_CONFIG).unwrap();
1272
1273
// VhostUserFrontend::activate()
1274
handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
1275
assert_eq!(req_handler.as_ref().acked_features, EXPECTED_FEATURES);
1276
handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
1277
for _ in 0..QUEUES_NUM {
1278
handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
1279
handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
1280
handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
1281
handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
1282
handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
1283
handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
1284
}
1285
1286
// VhostUserFrontend::reset()
1287
for _ in 0..QUEUES_NUM {
1288
handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
1289
handle_request(&mut req_handler, FrontendReq::GET_VRING_BASE).unwrap();
1290
}
1291
1292
// VhostUserFrontend::activate()
1293
handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
1294
for _ in 0..QUEUES_NUM {
1295
handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
1296
handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
1297
handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
1298
handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
1299
handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
1300
handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
1301
}
1302
1303
if allow_backend_req {
1304
// Make sure the connection still works even after reset/reactivate.
1305
req_handler
1306
.as_ref()
1307
.as_ref()
1308
.backend_conn
1309
.as_ref()
1310
.expect("backend_conn missing")
1311
.send_config_changed()
1312
.expect("send_config_changed failed");
1313
}
1314
1315
// VhostUserFrontend::virtio_sleep()
1316
for _ in 0..QUEUES_NUM {
1317
handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
1318
handle_request(&mut req_handler, FrontendReq::GET_VRING_BASE).unwrap();
1319
}
1320
1321
// VhostUserFrontend::virtio_snapshot()
1322
handle_request(&mut req_handler, FrontendReq::SET_DEVICE_STATE_FD).unwrap();
1323
handle_request(&mut req_handler, FrontendReq::CHECK_DEVICE_STATE).unwrap();
1324
1325
// VhostUserFrontend::virtio_wake()
1326
handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
1327
for _ in 0..QUEUES_NUM {
1328
handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
1329
handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
1330
handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
1331
handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
1332
handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
1333
handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
1334
}
1335
1336
if allow_backend_req {
1337
// Make sure the connection still works even after sleep/wake.
1338
req_handler
1339
.as_ref()
1340
.as_ref()
1341
.backend_conn
1342
.as_ref()
1343
.expect("backend_conn missing")
1344
.send_config_changed()
1345
.expect("send_config_changed failed");
1346
}
1347
1348
// Ask the client to shutdown, then wait to it to finish.
1349
shutdown_tx.send(()).unwrap();
1350
1351
// Verify recv_header fails with `ClientExit` after the client has disconnected.
1352
match req_handler.recv_header() {
1353
Err(VhostError::ClientExit) => (),
1354
r => panic!("expected Err(ClientExit) but got {r:?}"),
1355
}
1356
1357
vmm_thread.join().unwrap()
1358
};
1359
1360
// Second phase: Restore the snapshot.
1361
{
1362
let (client_connection, server_connection) = vmm_vhost::Connection::pair().unwrap();
1363
let (shutdown_tx, shutdown_rx) = channel();
1364
let (vm_evt_wrtube, _vm_evt_rdtube) = base::Tube::directional_pair().unwrap();
1365
let vmm_thread = std::thread::spawn(move || {
1366
// VMM side
1367
let mut vmm_device = VhostUserFrontend::new(
1368
DeviceType::Console,
1369
BASE_FEATURES,
1370
client_connection,
1371
vm_evt_wrtube,
1372
None,
1373
None,
1374
)
1375
.unwrap();
1376
1377
let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
1378
let interrupt = Interrupt::new_for_test_with_msix();
1379
1380
println!("virtio_sleep");
1381
assert!(vmm_device.virtio_sleep().unwrap().is_none());
1382
1383
println!("virtio_restore");
1384
vmm_device
1385
.virtio_restore(snapshot)
1386
.expect("virtio_restore failed");
1387
1388
println!("virtio_wake");
1389
vmm_device
1390
.virtio_wake(Some((
1391
mem.clone(),
1392
interrupt.clone(),
1393
create_queues(QUEUES_NUM, &mem, &interrupt),
1394
)))
1395
.unwrap();
1396
1397
println!("wait for shutdown signal");
1398
shutdown_rx.recv().unwrap();
1399
1400
// The VMM side is supposed to stop before the device side.
1401
println!("drop");
1402
});
1403
1404
// Device side
1405
let mut handler = DeviceRequestHandler::new(FakeBackend::new());
1406
handler.as_mut().allow_backend_req = allow_backend_req;
1407
1408
let mut req_handler = BackendServer::new(server_connection, handler);
1409
1410
// VhostUserFrontend::new()
1411
handle_request(&mut req_handler, FrontendReq::SET_OWNER).unwrap();
1412
handle_request(&mut req_handler, FrontendReq::GET_FEATURES).unwrap();
1413
handle_request(&mut req_handler, FrontendReq::GET_PROTOCOL_FEATURES).unwrap();
1414
handle_request(&mut req_handler, FrontendReq::SET_PROTOCOL_FEATURES).unwrap();
1415
if allow_backend_req {
1416
handle_request(&mut req_handler, FrontendReq::SET_BACKEND_REQ_FD).unwrap();
1417
}
1418
1419
// VhostUserFrontend::virtio_sleep()
1420
// (no-op)
1421
1422
// VhostUserFrontend::virtio_restore()
1423
handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
1424
assert_eq!(req_handler.as_ref().acked_features, EXPECTED_FEATURES);
1425
handle_request(&mut req_handler, FrontendReq::SET_DEVICE_STATE_FD).unwrap();
1426
handle_request(&mut req_handler, FrontendReq::CHECK_DEVICE_STATE).unwrap();
1427
1428
// VhostUserFrontend::virtio_wake()
1429
handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
1430
for _ in 0..QUEUES_NUM {
1431
handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
1432
handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
1433
handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
1434
handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
1435
handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
1436
handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
1437
}
1438
1439
if allow_backend_req {
1440
// Make sure the connection still works even after restore.
1441
req_handler
1442
.as_ref()
1443
.as_ref()
1444
.backend_conn
1445
.as_ref()
1446
.expect("backend_conn missing")
1447
.send_config_changed()
1448
.expect("send_config_changed failed");
1449
}
1450
1451
// Ask the client to shutdown, then wait to it to finish.
1452
shutdown_tx.send(()).unwrap();
1453
// Verify recv_header fails with `ClientExit` after the client has disconnected.
1454
match req_handler.recv_header() {
1455
Err(VhostError::ClientExit) => (),
1456
r => panic!("expected Err(ClientExit) but got {r:?}"),
1457
}
1458
vmm_thread.join().unwrap();
1459
}
1460
}
1461
1462
#[track_caller]
1463
fn handle_request<S: vmm_vhost::Backend>(
1464
handler: &mut BackendServer<S>,
1465
expected_message_type: FrontendReq,
1466
) -> Result<(), VhostError> {
1467
let (hdr, files) = handler.recv_header()?;
1468
assert_eq!(hdr.get_code(), Ok(expected_message_type));
1469
handler.process_message(hdr, files)
1470
}
1471
}
1472
1473