Path: blob/main/devices/src/virtio/vhost_user_backend/handler.rs
5394 views
// Copyright 2021 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

//! Library for implementing vhost-user device executables.
//!
//! This module provides
//! * `VhostUserDevice` trait, which is a collection of methods to handle vhost-user requests, and
//! * `DeviceRequestHandler` struct, which makes a connection to a VMM and starts an event loop.
//!
//! They are expected to be used as follows:
//!
//! 1. Define a struct and implement `VhostUserDevice` for it.
//! 2. Create a `DeviceRequestHandler` with the backend struct.
//! 3. Drive the `DeviceRequestHandler::run` async fn with an executor.
//!
//! ```ignore
//! struct MyBackend {
//!   /* fields */
//! }
//!
//! impl VhostUserDevice for MyBackend {
//!   /* implement methods */
//! }
//!
//! fn main() -> Result<(), Box<dyn Error>> {
//!   let backend = MyBackend { /* initialize fields */ };
//!   let handler = DeviceRequestHandler::new(backend);
//!   let socket = std::path::Path::new("/path/to/socket");
//!   let ex = cros_async::Executor::new()?;
//!
//!   if let Err(e) = ex.run_until(handler.run(socket, &ex)) {
//!     eprintln!("error happened: {}", e);
//!   }
//!   Ok(())
//! }
//! ```
// Implementation note:
// This code lets us take advantage of the vmm_vhost low level implementation of the vhost user
// protocol. DeviceRequestHandler implements the Backend trait from vmm_vhost, and includes some
// common code for setting up guest memory and managing partially configured vrings.
// DeviceRequestHandler::run watches the vhost-user socket and then calls handle_request() when it
// becomes readable. handle_request() reads and parses the message and then calls one of the
// Backend trait methods.
These dispatch back to the supplied VhostUserDevice implementation (this44// is what our devices implement).4546pub(super) mod sys;4748use std::collections::BTreeMap;49use std::convert::From;50use std::fs::File;51use std::io::BufReader;52use std::io::Write;53use std::num::Wrapping;54#[cfg(any(target_os = "android", target_os = "linux"))]55use std::os::unix::io::AsRawFd;56use std::sync::Arc;5758use anyhow::bail;59use anyhow::Context;60#[cfg(any(target_os = "android", target_os = "linux"))]61use base::clear_fd_flags;62use base::error;63use base::trace;64use base::warn;65use base::Event;66use base::Protection;67use base::SafeDescriptor;68use base::SharedMemory;69use base::WorkerThread;70use cros_async::TaskHandle;71use hypervisor::MemCacheType;72use serde::Deserialize;73use serde::Serialize;74use snapshot::AnySnapshot;75use sync::Mutex;76use thiserror::Error as ThisError;77use vm_control::VmMemorySource;78use vm_memory::GuestAddress;79use vm_memory::GuestMemory;80use vm_memory::MemoryRegion;81use vmm_vhost::message::VhostUserConfigFlags;82use vmm_vhost::message::VhostUserExternalMapMsg;83use vmm_vhost::message::VhostUserGpuMapMsg;84use vmm_vhost::message::VhostUserInflight;85use vmm_vhost::message::VhostUserMMap;86use vmm_vhost::message::VhostUserMMapFlags;87use vmm_vhost::message::VhostUserMemoryRegion;88use vmm_vhost::message::VhostUserMigrationPhase;89use vmm_vhost::message::VhostUserProtocolFeatures;90use vmm_vhost::message::VhostUserSingleMemoryRegion;91use vmm_vhost::message::VhostUserTransferDirection;92use vmm_vhost::message::VhostUserVringAddrFlags;93use vmm_vhost::message::VhostUserVringState;94use vmm_vhost::Connection;95use vmm_vhost::Error as VhostError;96use vmm_vhost::Frontend;97use vmm_vhost::FrontendClient;98use vmm_vhost::Result as VhostResult;99use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;100101use crate::virtio::Interrupt;102use crate::virtio::Queue;103use crate::virtio::QueueConfig;104use crate::virtio::SharedMemoryMapper;105use 
crate::virtio::SharedMemoryRegion;106107/// Keeps a mapping from the vmm's virtual addresses to guest addresses.108/// used to translate messages from the vmm to guest offsets.109#[derive(Default)]110pub struct MappingInfo {111pub vmm_addr: u64,112pub guest_phys: u64,113pub size: u64,114}115116pub fn vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress> {117for map in maps {118if vmm_va >= map.vmm_addr && vmm_va < map.vmm_addr + map.size {119return Ok(GuestAddress(vmm_va - map.vmm_addr + map.guest_phys));120}121}122Err(VhostError::InvalidMessage)123}124125/// Trait for vhost-user devices. Analogous to the `VirtioDevice` trait.126///127/// In contrast with [[vmm_vhost::Backend]], which closely matches the vhost-user spec, this trait128/// is designed to follow crosvm conventions for implementing devices.129pub trait VhostUserDevice {130/// The maximum number of queues that this backend can manage.131fn max_queue_num(&self) -> usize;132133/// The set of feature bits that this backend supports.134fn features(&self) -> u64;135136/// Acknowledges that this set of features should be enabled.137///138/// Implementations only need to handle device-specific feature bits; the `DeviceRequestHandler`139/// framework will manage generic vhost and vring features.140///141/// `DeviceRequestHandler` checks for valid features before calling this function, so the142/// features in `value` will always be a subset of those advertised by `features()`.143fn ack_features(&mut self, _value: u64) -> anyhow::Result<()> {144Ok(())145}146147/// The set of protocol feature bits that this backend supports.148fn protocol_features(&self) -> VhostUserProtocolFeatures;149150/// Reads this device configuration space at `offset`.151fn read_config(&self, offset: u64, dst: &mut [u8]);152153/// writes `data` to this device's configuration space at `offset`.154fn write_config(&self, _offset: u64, _data: &[u8]) {}155156/// Indicates that the backend should start processing requests 
for virtio queue number `idx`.157/// This method must not block the current thread so device backends should either spawn an158/// async task or another thread to handle messages from the Queue.159fn start_queue(&mut self, idx: usize, queue: Queue, mem: GuestMemory) -> anyhow::Result<()>;160161/// Indicates that the backend should stop processing requests for virtio queue number `idx`.162/// This method should return the queue passed to `start_queue` for the corresponding `idx`.163/// This method will only be called for queues that were previously started by `start_queue`.164fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue>;165166/// Resets the vhost-user backend.167fn reset(&mut self);168169/// Returns the device's shared memory region if present.170fn get_shared_memory_region(&self) -> Option<SharedMemoryRegion> {171None172}173174/// Accepts `VhostBackendReqConnection` to conduct Vhost backend to frontend message175/// handling.176///177/// This method will be called when `VhostUserProtocolFeatures::BACKEND_REQ` is178/// negotiated.179fn set_backend_req_connection(&mut self, _conn: VhostBackendReqConnection) {}180181/// Enter the "suspended device state" described in the vhost-user spec. See the spec for182/// requirements.183///184/// One reasonably foolproof way to satisfy the requirements is to stop all worker threads.185///186/// Called after a `stop_queue` call if there are no running queues left. 
Also called soon187/// after device creation to ensure the device is acting suspended immediately on construction.188///189/// The next `start_queue` call implicitly exits the "suspend device state".190///191/// * Ok(()) => device successfully suspended192/// * Err(_) => unrecoverable error193fn enter_suspended_state(&mut self) -> anyhow::Result<()>;194195/// Snapshot device and return serialized state.196fn snapshot(&mut self) -> anyhow::Result<AnySnapshot>;197198/// Restore device state from a snapshot.199fn restore(&mut self, data: AnySnapshot) -> anyhow::Result<()>;200201/// Whether guest memory should be unmapped in forked processes.202///203/// This is intended for use in combination with --protected-vm, where the guest memory can be204/// dangerous to access. Some systems, e.g. Android, have tools that fork processes and examine205/// their memory. This flag effectively hides the guest memory from those tools.206///207/// Not compatible with sandboxing.208fn unmap_guest_memory_on_fork(&self) -> bool {209false210}211}212213/// A virtio ring entry.214struct Vring {215// The queue config. This doesn't get mutated by the queue workers.216queue: QueueConfig,217doorbell: Option<Interrupt>,218enabled: bool,219}220221impl Vring {222fn new(max_size: u16, features: u64) -> Self {223Self {224queue: QueueConfig::new(max_size, features),225doorbell: None,226enabled: false,227}228}229230fn reset(&mut self) {231self.queue.reset();232self.doorbell = None;233self.enabled = false;234}235}236237/// Ops for running vhost-user over a stream (i.e. 
regular protocol).238pub(super) struct VhostUserRegularOps;239240impl VhostUserRegularOps {241pub fn set_mem_table(242contexts: &[VhostUserMemoryRegion],243files: Vec<File>,244) -> VhostResult<(GuestMemory, Vec<MappingInfo>)> {245if files.len() != contexts.len() {246return Err(VhostError::InvalidParam(247"number of files & contexts was not equal",248));249}250251let mut regions = Vec::with_capacity(files.len());252for (region, file) in contexts.iter().zip(files.into_iter()) {253let region = MemoryRegion::new_from_shm(254region.memory_size,255GuestAddress(region.guest_phys_addr),256region.mmap_offset,257Arc::new(258SharedMemory::from_safe_descriptor(259SafeDescriptor::from(file),260region.memory_size,261)262.unwrap(),263),264)265.map_err(|e| {266error!("failed to create a memory region: {}", e);267VhostError::InvalidOperation268})?;269regions.push(region);270}271let guest_mem = GuestMemory::from_regions(regions).map_err(|e| {272error!("failed to create guest memory: {}", e);273VhostError::InvalidOperation274})?;275276let vmm_maps = contexts277.iter()278.map(|region| MappingInfo {279vmm_addr: region.user_addr,280guest_phys: region.guest_phys_addr,281size: region.memory_size,282})283.collect();284Ok((guest_mem, vmm_maps))285}286}287288/// An adapter that implements `vmm_vhost::Backend` for any type implementing `VhostUserDevice`.289pub struct DeviceRequestHandler<T: VhostUserDevice> {290vrings: Vec<Vring>,291owned: bool,292vmm_maps: Option<Vec<MappingInfo>>,293mem: Option<GuestMemory>,294acked_features: u64,295acked_protocol_features: VhostUserProtocolFeatures,296backend: T,297backend_req_connection: Option<VhostBackendReqConnection>,298// Thread processing active device state FD.299device_state_thread: Option<DeviceStateThread>,300}301302enum DeviceStateThread {303Save(WorkerThread<Result<(), ciborium::ser::Error<std::io::Error>>>),304Load(WorkerThread<Result<DeviceRequestHandlerSnapshot, ciborium::de::Error<std::io::Error>>>),305}306307#[derive(Serialize, 
Deserialize)]308pub struct DeviceRequestHandlerSnapshot {309acked_features: u64,310acked_protocol_features: u64,311backend: AnySnapshot,312}313314impl<T: VhostUserDevice> DeviceRequestHandler<T> {315/// Creates a vhost-user handler instance for `backend`.316pub(crate) fn new(mut backend: T) -> Self {317let mut vrings = Vec::with_capacity(backend.max_queue_num());318for _ in 0..backend.max_queue_num() {319vrings.push(Vring::new(Queue::MAX_SIZE, backend.features()));320}321322// VhostUserDevice implementations must support `enter_suspended_state()`.323// Call it on startup to ensure it works and to initialize the device in a suspended state.324backend325.enter_suspended_state()326.expect("enter_suspended_state failed on device init");327328DeviceRequestHandler {329vrings,330owned: false,331vmm_maps: None,332mem: None,333acked_features: 0,334acked_protocol_features: VhostUserProtocolFeatures::empty(),335backend,336backend_req_connection: None,337device_state_thread: None,338}339}340341/// Check if all queues are stopped.342///343/// The device can be suspended with `enter_suspended_state()` only when all queues are stopped.344fn all_queues_stopped(&self) -> bool {345self.vrings.iter().all(|vring| !vring.queue.ready())346}347}348349impl<T: VhostUserDevice> Drop for DeviceRequestHandler<T> {350fn drop(&mut self) {351for (index, vring) in self.vrings.iter().enumerate() {352if vring.queue.ready() {353if let Err(e) = self.backend.stop_queue(index) {354error!("Failed to stop queue {} during drop: {:#}", index, e);355}356}357}358}359}360361impl<T: VhostUserDevice> AsRef<T> for DeviceRequestHandler<T> {362fn as_ref(&self) -> &T {363&self.backend364}365}366367impl<T: VhostUserDevice> AsMut<T> for DeviceRequestHandler<T> {368fn as_mut(&mut self) -> &mut T {369&mut self.backend370}371}372373impl<T: VhostUserDevice> vmm_vhost::Backend for DeviceRequestHandler<T> {374fn set_owner(&mut self) -> VhostResult<()> {375if self.owned {376return 
Err(VhostError::InvalidOperation);377}378self.owned = true;379Ok(())380}381382fn reset_owner(&mut self) -> VhostResult<()> {383self.owned = false;384self.acked_features = 0;385self.backend.reset();386Ok(())387}388389fn get_features(&mut self) -> VhostResult<u64> {390let features = self.backend.features();391Ok(features)392}393394fn set_features(&mut self, features: u64) -> VhostResult<()> {395if !self.owned {396return Err(VhostError::InvalidOperation);397}398399let unexpected_features = features & !self.backend.features();400if unexpected_features != 0 {401error!("unexpected set_features {:#x}", unexpected_features);402return Err(VhostError::InvalidParam("unexpected set_features"));403}404405if let Err(e) = self.backend.ack_features(features) {406error!("failed to acknowledge features 0x{:x}: {}", features, e);407return Err(VhostError::InvalidOperation);408}409410self.acked_features |= features;411412// If VHOST_USER_F_PROTOCOL_FEATURES has not been negotiated, the ring is initialized in an413// enabled state.414// If VHOST_USER_F_PROTOCOL_FEATURES has been negotiated, the ring is initialized in a415// disabled state.416// Client must not pass data to/from the backend until ring is enabled by417// VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by418// VHOST_USER_SET_VRING_ENABLE with parameter 0.419let vring_enabled = self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0;420for v in &mut self.vrings {421v.enabled = vring_enabled;422}423424Ok(())425}426427fn get_protocol_features(&mut self) -> VhostResult<VhostUserProtocolFeatures> {428Ok(self.backend.protocol_features() | VhostUserProtocolFeatures::REPLY_ACK)429}430431fn set_protocol_features(&mut self, features: u64) -> VhostResult<()> {432let features = match VhostUserProtocolFeatures::from_bits(features) {433Some(proto_features) => proto_features,434None => {435error!(436"unsupported bits in VHOST_USER_SET_PROTOCOL_FEATURES: {:#x}",437features438);439return 
Err(VhostError::InvalidOperation);440}441};442let supported = self.get_protocol_features()?;443self.acked_protocol_features = features & supported;444Ok(())445}446447fn set_mem_table(448&mut self,449contexts: &[VhostUserMemoryRegion],450files: Vec<File>,451) -> VhostResult<()> {452let (guest_mem, vmm_maps) = VhostUserRegularOps::set_mem_table(contexts, files)?;453if self.backend.unmap_guest_memory_on_fork() {454#[cfg(any(target_os = "android", target_os = "linux"))]455if let Err(e) = guest_mem.use_dontfork() {456error!("failed to set MADV_DONTFORK on guest memory: {e:#}");457}458#[cfg(not(any(target_os = "android", target_os = "linux")))]459error!("unmap_guest_memory_on_fork unsupported; skipping");460}461self.mem = Some(guest_mem);462self.vmm_maps = Some(vmm_maps);463Ok(())464}465466fn get_queue_num(&mut self) -> VhostResult<u64> {467Ok(self.vrings.len() as u64)468}469470fn set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()> {471if index as usize >= self.vrings.len() || num == 0 || num > Queue::MAX_SIZE.into() {472return Err(VhostError::InvalidParam(473"set_vring_num: invalid index or num",474));475}476self.vrings[index as usize].queue.set_size(num as u16);477478Ok(())479}480481fn set_vring_addr(482&mut self,483index: u32,484_flags: VhostUserVringAddrFlags,485descriptor: u64,486used: u64,487available: u64,488_log: u64,489) -> VhostResult<()> {490if index as usize >= self.vrings.len() {491return Err(VhostError::InvalidParam(492"set_vring_addr: index out of range",493));494}495496let vmm_maps = self497.vmm_maps498.as_ref()499.ok_or(VhostError::InvalidParam("set_vring_addr: missing vmm_maps"))?;500let vring = &mut self.vrings[index as usize];501vring502.queue503.set_desc_table(vmm_va_to_gpa(vmm_maps, descriptor)?);504vring505.queue506.set_avail_ring(vmm_va_to_gpa(vmm_maps, available)?);507vring.queue.set_used_ring(vmm_va_to_gpa(vmm_maps, used)?);508509Ok(())510}511512fn set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()> {513if index as 
usize >= self.vrings.len() {514return Err(VhostError::InvalidParam(515"set_vring_base: index out of range",516));517}518519let vring = &mut self.vrings[index as usize];520vring.queue.set_next_avail(Wrapping(base as u16));521vring.queue.set_next_used(Wrapping(base as u16));522523Ok(())524}525526fn get_vring_base(&mut self, index: u32) -> VhostResult<VhostUserVringState> {527let vring = self528.vrings529.get_mut(index as usize)530.ok_or(VhostError::InvalidParam(531"get_vring_base: index out of range",532))?;533534// Quotation from vhost-user spec:535// "The back-end must [...] stop ring upon receiving VHOST_USER_GET_VRING_BASE."536// We only call `queue.set_ready()` when starting the queue, so if the queue is ready, that537// means it is started and should be stopped.538let vring_base = if vring.queue.ready() {539let queue = match self.backend.stop_queue(index as usize) {540Ok(q) => q,541Err(e) => {542error!("Failed to stop queue in get_vring_base: {:#}", e);543return Err(VhostError::BackendInternalError);544}545};546547trace!("stopped queue {index}");548vring.reset();549550if self.all_queues_stopped() {551trace!("all queues stopped; entering suspended state");552self.backend553.enter_suspended_state()554.map_err(VhostError::EnterSuspendedState)?;555}556557queue.next_avail_to_process()558} else {5590560};561562Ok(VhostUserVringState::new(index, vring_base.into()))563}564565fn set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {566if index as usize >= self.vrings.len() {567return Err(VhostError::InvalidParam(568"set_vring_kick: index out of range",569));570}571572let vring = &mut self.vrings[index as usize];573if vring.queue.ready() {574error!("kick fd cannot replaced after queue is started");575return Err(VhostError::InvalidOperation);576}577578let file = file.ok_or(VhostError::InvalidParam("missing file for set_vring_kick"))?;579580// Remove O_NONBLOCK from kick_fd. 
Otherwise, uring_executor will fails when we read581// values via `next_val()` later.582// This is only required (and can only be done) on Unix platforms.583#[cfg(any(target_os = "android", target_os = "linux"))]584if let Err(e) = clear_fd_flags(file.as_raw_fd(), libc::O_NONBLOCK) {585error!("failed to remove O_NONBLOCK for kick fd: {}", e);586return Err(VhostError::InvalidParam(587"could not remove O_NONBLOCK from vring_kick",588));589}590591let kick_evt = Event::from(SafeDescriptor::from(file));592593// Enable any virtqueue features that were negotiated (like VIRTIO_RING_F_EVENT_IDX).594vring.queue.ack_features(self.acked_features);595vring.queue.set_ready(true);596597let mem = self598.mem599.as_ref()600.cloned()601.ok_or(VhostError::InvalidOperation)?;602603let doorbell = vring.doorbell.clone().ok_or(VhostError::InvalidOperation)?;604605let queue = match vring.queue.activate(&mem, kick_evt, doorbell) {606Ok(queue) => queue,607Err(e) => {608error!("failed to activate vring: {:#}", e);609return Err(VhostError::BackendInternalError);610}611};612613if let Err(e) = self.backend.start_queue(index as usize, queue, mem) {614error!("Failed to start queue {}: {}", index, e);615return Err(VhostError::BackendInternalError);616}617trace!("started queue {index}");618619Ok(())620}621622fn set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {623if index as usize >= self.vrings.len() {624return Err(VhostError::InvalidParam(625"set_vring_call: index out of range",626));627}628629let backend_req_conn = self.backend_req_connection.clone();630let signal_config_change_fn = Box::new(move || {631if let Some(frontend) = backend_req_conn.as_ref() {632if let Err(e) = frontend.send_config_changed() {633error!("Failed to notify config change: {:#}", e);634}635} else {636error!("No Backend request connection found");637}638});639640let file = file.ok_or(VhostError::InvalidParam("missing file for set_vring_call"))?;641self.vrings[index as usize].doorbell = 
Some(Interrupt::new_vhost_user(642Event::from(SafeDescriptor::from(file)),643signal_config_change_fn,644));645Ok(())646}647648fn set_vring_err(&mut self, _index: u8, _fd: Option<File>) -> VhostResult<()> {649// TODO650Ok(())651}652653fn set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()> {654if index as usize >= self.vrings.len() {655return Err(VhostError::InvalidParam(656"set_vring_enable: index out of range",657));658}659660// This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES661// has been negotiated.662if self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 {663return Err(VhostError::InvalidOperation);664}665666// Backend must not pass data to/from the ring until ring is enabled by667// VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by668// VHOST_USER_SET_VRING_ENABLE with parameter 0.669self.vrings[index as usize].enabled = enable;670671Ok(())672}673674fn get_config(675&mut self,676offset: u32,677size: u32,678_flags: VhostUserConfigFlags,679) -> VhostResult<Vec<u8>> {680let mut data = vec![0; size as usize];681self.backend.read_config(u64::from(offset), &mut data);682Ok(data)683}684685fn set_config(686&mut self,687offset: u32,688buf: &[u8],689_flags: VhostUserConfigFlags,690) -> VhostResult<()> {691self.backend.write_config(u64::from(offset), buf);692Ok(())693}694695fn set_backend_req_fd(&mut self, ep: Connection) {696let conn = VhostBackendReqConnection::new(697FrontendClient::new(698ep,699self.acked_protocol_features700.contains(VhostUserProtocolFeatures::REPLY_ACK),701),702self.backend.get_shared_memory_region().map(|r| r.id),703);704705if self.backend_req_connection.is_some() {706warn!("Backend Request Connection already established. 
Overwriting");707}708self.backend_req_connection = Some(conn.clone());709710self.backend.set_backend_req_connection(conn);711}712713fn get_inflight_fd(714&mut self,715_inflight: &VhostUserInflight,716) -> VhostResult<(VhostUserInflight, File)> {717unimplemented!("get_inflight_fd");718}719720fn set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> VhostResult<()> {721unimplemented!("set_inflight_fd");722}723724fn get_max_mem_slots(&mut self) -> VhostResult<u64> {725//TODO726Ok(0)727}728729fn add_mem_region(730&mut self,731_region: &VhostUserSingleMemoryRegion,732_fd: File,733) -> VhostResult<()> {734//TODO735Ok(())736}737738fn remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> VhostResult<()> {739//TODO740Ok(())741}742743fn set_device_state_fd(744&mut self,745transfer_direction: VhostUserTransferDirection,746migration_phase: VhostUserMigrationPhase,747fd: File,748) -> VhostResult<Option<File>> {749if migration_phase != VhostUserMigrationPhase::Stopped {750return Err(VhostError::InvalidOperation);751}752if !self.all_queues_stopped() {753return Err(VhostError::InvalidOperation);754}755if self.device_state_thread.is_some() {756error!("must call check_device_state before starting new state transfer");757return Err(VhostError::InvalidOperation);758}759// `set_device_state_fd` is designed to allow snapshot/restore concurrently with other760// methods, but, for simplicitly, we do those operations inline and only spawn a thread to761// handle the serialization and data transfer (the latter which seems necessary to762// implement the API correctly without, e.g., deadlocking because a pipe is full).763match transfer_direction {764VhostUserTransferDirection::Save => {765// Snapshot the state.766let snapshot = DeviceRequestHandlerSnapshot {767acked_features: self.acked_features,768acked_protocol_features: self.acked_protocol_features.bits(),769backend: self.backend.snapshot().map_err(VhostError::SnapshotError)?,770};771// Spawn thread to 
write the serialized bytes.772self.device_state_thread = Some(DeviceStateThread::Save(WorkerThread::start(773"device_state_save",774move |_kill_event| -> Result<(), ciborium::ser::Error<std::io::Error>> {775let mut w = std::io::BufWriter::new(fd);776ciborium::into_writer(&snapshot, &mut w)?;777w.flush()?;778Ok(())779},780)));781Ok(None)782}783VhostUserTransferDirection::Load => {784// Spawn a thread to read the bytes and deserialize. Restore will happen in785// `check_device_state`.786self.device_state_thread = Some(DeviceStateThread::Load(WorkerThread::start(787"device_state_load",788move |_kill_event| ciborium::from_reader(&mut BufReader::new(fd)),789)));790Ok(None)791}792}793}794795fn check_device_state(&mut self) -> VhostResult<()> {796let Some(thread) = self.device_state_thread.take() else {797error!("check_device_state: no active state transfer");798return Err(VhostError::InvalidOperation);799};800match thread {801DeviceStateThread::Save(worker) => {802worker.stop().map_err(|e| {803error!("device state save thread failed: {:#}", e);804VhostError::BackendInternalError805})?;806Ok(())807}808DeviceStateThread::Load(worker) => {809let snapshot = worker.stop().map_err(|e| {810error!("device state load thread failed: {:#}", e);811VhostError::BackendInternalError812})?;813self.acked_features = snapshot.acked_features;814self.acked_protocol_features =815VhostUserProtocolFeatures::from_bits(snapshot.acked_protocol_features)816.with_context(|| {817format!(818"unsupported bits in acked_protocol_features: {:#x}",819snapshot.acked_protocol_features820)821})822.map_err(VhostError::RestoreError)?;823self.backend824.restore(snapshot.backend)825.map_err(VhostError::RestoreError)?;826Ok(())827}828}829}830831fn get_shmem_config(&mut self) -> VhostResult<Vec<SharedMemoryRegion>> {832Ok(self833.backend834.get_shared_memory_region()835.into_iter()836.collect())837}838}839840/// Keeps track of Vhost user backend request connection.841#[derive(Clone)]842pub struct 
VhostBackendReqConnection {843shared: Arc<Mutex<VhostBackendReqConnectionShared>>,844shmid: Option<u8>,845}846847struct VhostBackendReqConnectionShared {848conn: FrontendClient,849mapped_regions: BTreeMap<u64 /* offset */, u64 /* size */>,850}851852impl VhostBackendReqConnection {853fn new(conn: FrontendClient, shmid: Option<u8>) -> Self {854Self {855shared: Arc::new(Mutex::new(VhostBackendReqConnectionShared {856conn,857mapped_regions: BTreeMap::new(),858})),859shmid,860}861}862863/// Send `VHOST_USER_CONFIG_CHANGE_MSG` to the frontend864fn send_config_changed(&self) -> anyhow::Result<()> {865let mut shared = self.shared.lock();866shared867.conn868.handle_config_change()869.context("Could not send config change message")?;870Ok(())871}872873/// Create a SharedMemoryMapper trait object using this backend request connection.874pub fn shmem_mapper(&self) -> Option<Box<dyn SharedMemoryMapper>> {875if let Some(shmid) = self.shmid {876Some(Box::new(VhostShmemMapper {877shared: self.shared.clone(),878shmid,879}))880} else {881None882}883}884}885886#[derive(Clone)]887struct VhostShmemMapper {888shared: Arc<Mutex<VhostBackendReqConnectionShared>>,889shmid: u8,890}891892impl SharedMemoryMapper for VhostShmemMapper {893fn add_mapping(894&mut self,895source: VmMemorySource,896offset: u64,897prot: Protection,898_cache: MemCacheType,899) -> anyhow::Result<()> {900let mut shared = self.shared.lock();901let size = match source {902VmMemorySource::Vulkan {903descriptor,904handle_type,905memory_idx,906device_uuid,907driver_uuid,908size,909} => {910let msg = VhostUserGpuMapMsg::new(911self.shmid,912offset,913size,914memory_idx,915handle_type,916device_uuid,917driver_uuid,918);919shared920.conn921.gpu_map(&msg, &descriptor)922.context("map GPU memory")?;923size924}925VmMemorySource::ExternalMapping { ptr, size } => {926let msg = VhostUserExternalMapMsg::new(self.shmid, offset, size, ptr);927shared928.conn929.external_map(&msg)930.context("create external 
mapping")?;931size932}933source => {934// The last two sources use the same VhostUserMMap, continue matching here935// on the aliased `source` above.936let (descriptor, fd_offset, size) = match source {937VmMemorySource::Descriptor {938descriptor,939offset,940size,941} => (descriptor, offset, size),942VmMemorySource::SharedMemory(shmem) => {943let size = shmem.size();944let descriptor = SafeDescriptor::from(shmem);945(descriptor, 0, size)946}947_ => bail!("unsupported source"),948};949let mut flags = VhostUserMMapFlags::empty();950anyhow::ensure!(prot.allows(&Protection::read()), "mapping must be readable");951if prot.allows(&Protection::write()) {952flags |= VhostUserMMapFlags::MAP_RW;953}954let msg = VhostUserMMap {955shmid: self.shmid,956padding: Default::default(),957fd_offset,958shm_offset: offset,959len: size,960flags,961};962shared963.conn964.shmem_map(&msg, &descriptor)965.context("map shmem")?;966size967}968};969970shared.mapped_regions.insert(offset, size);971Ok(())972}973974fn remove_mapping(&mut self, offset: u64) -> anyhow::Result<()> {975let mut shared = self.shared.lock();976let size = shared977.mapped_regions978.remove(&offset)979.context("unknown offset")?;980let msg = VhostUserMMap {981shmid: self.shmid,982padding: Default::default(),983fd_offset: 0,984shm_offset: offset,985len: size,986flags: VhostUserMMapFlags::empty(),987};988shared989.conn990.shmem_unmap(&msg)991.context("unmap shmem")992.map(|_| ())993}994}995996pub(crate) struct WorkerState<T, U> {997pub(crate) queue_task: TaskHandle<U>,998pub(crate) queue: T,999}10001001/// Errors for device operations1002#[derive(Debug, ThisError)]1003pub enum Error {1004#[error("worker not found when stopping queue")]1005WorkerNotFound,1006}10071008#[cfg(test)]1009mod tests {1010use std::sync::mpsc::channel;10111012use anyhow::bail;1013use base::Event;1014use virtio_sys::virtio_ring::VIRTIO_RING_F_EVENT_IDX;1015use vmm_vhost::BackendServer;1016use vmm_vhost::FrontendReq;1017use zerocopy::FromBytes;1018use 
zerocopy::FromZeros;1019use zerocopy::Immutable;1020use zerocopy::IntoBytes;1021use zerocopy::KnownLayout;10221023use super::*;1024use crate::virtio::vhost_user_frontend::VhostUserFrontend;1025use crate::virtio::DeviceType;1026use crate::virtio::VirtioDevice;10271028#[derive(Clone, Copy, Debug, PartialEq, Eq, FromBytes, Immutable, IntoBytes, KnownLayout)]1029#[repr(C, packed(4))]1030struct FakeConfig {1031x: u32,1032y: u64,1033}10341035const FAKE_CONFIG_DATA: FakeConfig = FakeConfig { x: 1, y: 2 };10361037pub(super) struct FakeBackend {1038avail_features: u64,1039acked_features: u64,1040active_queues: Vec<Option<Queue>>,1041allow_backend_req: bool,1042backend_conn: Option<VhostBackendReqConnection>,1043}10441045#[derive(Deserialize, Serialize)]1046struct FakeBackendSnapshot {1047data: Vec<u8>,1048}10491050impl FakeBackend {1051const MAX_QUEUE_NUM: usize = 16;10521053pub(super) fn new() -> Self {1054let mut active_queues = Vec::new();1055active_queues.resize_with(Self::MAX_QUEUE_NUM, Default::default);1056Self {1057avail_features: 1 << VHOST_USER_F_PROTOCOL_FEATURES | 1 << VIRTIO_RING_F_EVENT_IDX,1058acked_features: 0,1059active_queues,1060allow_backend_req: false,1061backend_conn: None,1062}1063}1064}10651066impl VhostUserDevice for FakeBackend {1067fn max_queue_num(&self) -> usize {1068Self::MAX_QUEUE_NUM1069}10701071fn features(&self) -> u64 {1072self.avail_features1073}10741075fn ack_features(&mut self, value: u64) -> anyhow::Result<()> {1076let unrequested_features = value & !self.avail_features;1077if unrequested_features != 0 {1078bail!(1079"invalid protocol features are given: 0x{:x}",1080unrequested_features1081);1082}1083self.acked_features |= value;1084Ok(())1085}10861087fn protocol_features(&self) -> VhostUserProtocolFeatures {1088let mut features =1089VhostUserProtocolFeatures::CONFIG | VhostUserProtocolFeatures::DEVICE_STATE;1090if self.allow_backend_req {1091features |= VhostUserProtocolFeatures::BACKEND_REQ;1092}1093features1094}10951096fn 
read_config(&self, offset: u64, dst: &mut [u8]) {1097dst.copy_from_slice(&FAKE_CONFIG_DATA.as_bytes()[offset as usize..]);1098}10991100fn reset(&mut self) {}11011102fn start_queue(1103&mut self,1104idx: usize,1105queue: Queue,1106_mem: GuestMemory,1107) -> anyhow::Result<()> {1108self.active_queues[idx] = Some(queue);1109Ok(())1110}11111112fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue> {1113Ok(self.active_queues[idx]1114.take()1115.ok_or(Error::WorkerNotFound)?)1116}11171118fn set_backend_req_connection(&mut self, conn: VhostBackendReqConnection) {1119self.backend_conn = Some(conn);1120}11211122fn enter_suspended_state(&mut self) -> anyhow::Result<()> {1123Ok(())1124}11251126fn snapshot(&mut self) -> anyhow::Result<AnySnapshot> {1127AnySnapshot::to_any(FakeBackendSnapshot {1128data: vec![1, 2, 3],1129})1130.context("failed to serialize snapshot")1131}11321133fn restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {1134let snapshot: FakeBackendSnapshot =1135AnySnapshot::from_any(data).context("failed to deserialize snapshot")?;1136assert_eq!(snapshot.data, vec![1, 2, 3], "bad snapshot data");1137Ok(())1138}1139}11401141fn create_queues(1142num: usize,1143mem: &GuestMemory,1144interrupt: &Interrupt,1145) -> BTreeMap<usize, Queue> {1146let mut queues = BTreeMap::new();1147for idx in 0..num {1148let mut queue = QueueConfig::new(0x10, 0);1149queue.set_ready(true);1150let queue = queue1151.activate(mem, Event::new().unwrap(), interrupt.clone())1152.expect("QueueConfig::activate");1153queues.insert(idx, queue);1154}1155queues1156}11571158#[test]1159fn test_vhost_user_lifecycle() {1160test_vhost_user_lifecycle_parameterized(false);1161}11621163#[test]1164#[cfg(not(windows))] // Windows requries more complex connection setup.1165fn test_vhost_user_lifecycle_with_backend_req() {1166test_vhost_user_lifecycle_parameterized(true);1167}11681169fn test_vhost_user_lifecycle_parameterized(allow_backend_req: bool) {1170const QUEUES_NUM: usize = 2;1171const 
BASE_FEATURES: u64 = 1 << VIRTIO_RING_F_EVENT_IDX;1172const EXPECTED_FEATURES: u64 =11731 << VHOST_USER_F_PROTOCOL_FEATURES | 1 << VIRTIO_RING_F_EVENT_IDX;11741175// First phase: Test normal usage, then take a snapshot and shutdown.1176let snapshot = {1177let (client_connection, server_connection) = vmm_vhost::Connection::pair().unwrap();1178let (shutdown_tx, shutdown_rx) = channel();1179let (vm_evt_wrtube, _vm_evt_rdtube) = base::Tube::directional_pair().unwrap();1180let vmm_thread = std::thread::spawn(move || {1181// VMM side1182let mut vmm_device = VhostUserFrontend::new(1183DeviceType::Console,1184BASE_FEATURES,1185client_connection,1186vm_evt_wrtube,1187None,1188None,1189)1190.unwrap();11911192vmm_device.ack_features(BASE_FEATURES);11931194let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();1195let interrupt = Interrupt::new_for_test_with_msix();11961197println!("read_config");1198let mut config = FakeConfig::new_zeroed();1199vmm_device.read_config(0, config.as_mut_bytes());1200// Check if the obtained config data is correct.1201assert_eq!(config, FAKE_CONFIG_DATA);12021203println!("activate");1204vmm_device1205.activate(1206mem.clone(),1207interrupt.clone(),1208create_queues(QUEUES_NUM, &mem, &interrupt),1209)1210.unwrap();12111212println!("reset");1213let reset_result = vmm_device.reset();1214assert!(1215reset_result.is_ok(),1216"reset failed: {:#}",1217reset_result.unwrap_err()1218);12191220println!("activate");1221vmm_device1222.activate(1223mem.clone(),1224interrupt.clone(),1225create_queues(QUEUES_NUM, &mem, &interrupt),1226)1227.unwrap();12281229println!("virtio_sleep");1230let queues = vmm_device1231.virtio_sleep()1232.unwrap()1233.expect("virtio_sleep unexpectedly returned None");12341235println!("virtio_snapshot");1236let snapshot = vmm_device1237.virtio_snapshot()1238.expect("virtio_snapshot failed");12391240println!("virtio_wake");1241vmm_device1242.virtio_wake(Some((mem.clone(), interrupt.clone(), 
queues)))1243.unwrap();12441245println!("wait for shutdown signal");1246shutdown_rx.recv().unwrap();12471248// The VMM side is supposed to stop before the device side.1249println!("drop");12501251snapshot1252});12531254// Device side1255let mut handler = DeviceRequestHandler::new(FakeBackend::new());1256handler.as_mut().allow_backend_req = allow_backend_req;12571258let mut req_handler = BackendServer::new(server_connection, handler);12591260// VhostUserFrontend::new()1261handle_request(&mut req_handler, FrontendReq::SET_OWNER).unwrap();1262handle_request(&mut req_handler, FrontendReq::GET_FEATURES).unwrap();1263handle_request(&mut req_handler, FrontendReq::GET_PROTOCOL_FEATURES).unwrap();1264handle_request(&mut req_handler, FrontendReq::SET_PROTOCOL_FEATURES).unwrap();1265if allow_backend_req {1266handle_request(&mut req_handler, FrontendReq::SET_BACKEND_REQ_FD).unwrap();1267}12681269// VhostUserFrontend::read_config()1270handle_request(&mut req_handler, FrontendReq::GET_CONFIG).unwrap();12711272// VhostUserFrontend::activate()1273handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();1274assert_eq!(req_handler.as_ref().acked_features, EXPECTED_FEATURES);1275handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();1276for _ in 0..QUEUES_NUM {1277handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();1278handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();1279handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();1280handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();1281handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();1282handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();1283}12841285// VhostUserFrontend::reset()1286for _ in 0..QUEUES_NUM {1287handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();1288handle_request(&mut req_handler, FrontendReq::GET_VRING_BASE).unwrap();1289}12901291// 
VhostUserFrontend::activate()1292handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();1293for _ in 0..QUEUES_NUM {1294handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();1295handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();1296handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();1297handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();1298handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();1299handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();1300}13011302if allow_backend_req {1303// Make sure the connection still works even after reset/reactivate.1304req_handler1305.as_ref()1306.as_ref()1307.backend_conn1308.as_ref()1309.expect("backend_conn missing")1310.send_config_changed()1311.expect("send_config_changed failed");1312}13131314// VhostUserFrontend::virtio_sleep()1315for _ in 0..QUEUES_NUM {1316handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();1317handle_request(&mut req_handler, FrontendReq::GET_VRING_BASE).unwrap();1318}13191320// VhostUserFrontend::virtio_snapshot()1321handle_request(&mut req_handler, FrontendReq::SET_DEVICE_STATE_FD).unwrap();1322handle_request(&mut req_handler, FrontendReq::CHECK_DEVICE_STATE).unwrap();13231324// VhostUserFrontend::virtio_wake()1325handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();1326for _ in 0..QUEUES_NUM {1327handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();1328handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();1329handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();1330handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();1331handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();1332handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();1333}13341335if allow_backend_req {1336// Make sure the connection still works even after 
sleep/wake.1337req_handler1338.as_ref()1339.as_ref()1340.backend_conn1341.as_ref()1342.expect("backend_conn missing")1343.send_config_changed()1344.expect("send_config_changed failed");1345}13461347// Ask the client to shutdown, then wait to it to finish.1348shutdown_tx.send(()).unwrap();13491350// Verify recv_header fails with `ClientExit` after the client has disconnected.1351match req_handler.recv_header() {1352Err(VhostError::ClientExit) => (),1353r => panic!("expected Err(ClientExit) but got {r:?}"),1354}13551356vmm_thread.join().unwrap()1357};13581359// Second phase: Restore the snapshot.1360{1361let (client_connection, server_connection) = vmm_vhost::Connection::pair().unwrap();1362let (shutdown_tx, shutdown_rx) = channel();1363let (vm_evt_wrtube, _vm_evt_rdtube) = base::Tube::directional_pair().unwrap();1364let vmm_thread = std::thread::spawn(move || {1365// VMM side1366let mut vmm_device = VhostUserFrontend::new(1367DeviceType::Console,1368BASE_FEATURES,1369client_connection,1370vm_evt_wrtube,1371None,1372None,1373)1374.unwrap();13751376let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();1377let interrupt = Interrupt::new_for_test_with_msix();13781379println!("virtio_sleep");1380assert!(vmm_device.virtio_sleep().unwrap().is_none());13811382println!("virtio_restore");1383vmm_device1384.virtio_restore(snapshot)1385.expect("virtio_restore failed");13861387println!("virtio_wake");1388vmm_device1389.virtio_wake(Some((1390mem.clone(),1391interrupt.clone(),1392create_queues(QUEUES_NUM, &mem, &interrupt),1393)))1394.unwrap();13951396println!("wait for shutdown signal");1397shutdown_rx.recv().unwrap();13981399// The VMM side is supposed to stop before the device side.1400println!("drop");1401});14021403// Device side1404let mut handler = DeviceRequestHandler::new(FakeBackend::new());1405handler.as_mut().allow_backend_req = allow_backend_req;14061407let mut req_handler = BackendServer::new(server_connection, handler);14081409// 
VhostUserFrontend::new()1410handle_request(&mut req_handler, FrontendReq::SET_OWNER).unwrap();1411handle_request(&mut req_handler, FrontendReq::GET_FEATURES).unwrap();1412handle_request(&mut req_handler, FrontendReq::GET_PROTOCOL_FEATURES).unwrap();1413handle_request(&mut req_handler, FrontendReq::SET_PROTOCOL_FEATURES).unwrap();1414if allow_backend_req {1415handle_request(&mut req_handler, FrontendReq::SET_BACKEND_REQ_FD).unwrap();1416}14171418// VhostUserFrontend::virtio_sleep()1419// (no-op)14201421// VhostUserFrontend::virtio_restore()1422handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();1423assert_eq!(req_handler.as_ref().acked_features, EXPECTED_FEATURES);1424handle_request(&mut req_handler, FrontendReq::SET_DEVICE_STATE_FD).unwrap();1425handle_request(&mut req_handler, FrontendReq::CHECK_DEVICE_STATE).unwrap();14261427// VhostUserFrontend::virtio_wake()1428handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();1429for _ in 0..QUEUES_NUM {1430handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();1431handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();1432handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();1433handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();1434handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();1435handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();1436}14371438if allow_backend_req {1439// Make sure the connection still works even after restore.1440req_handler1441.as_ref()1442.as_ref()1443.backend_conn1444.as_ref()1445.expect("backend_conn missing")1446.send_config_changed()1447.expect("send_config_changed failed");1448}14491450// Ask the client to shutdown, then wait to it to finish.1451shutdown_tx.send(()).unwrap();1452// Verify recv_header fails with `ClientExit` after the client has disconnected.1453match req_handler.recv_header() {1454Err(VhostError::ClientExit) => (),1455r => 
panic!("expected Err(ClientExit) but got {r:?}"),1456}1457vmm_thread.join().unwrap();1458}1459}14601461#[track_caller]1462fn handle_request<S: vmm_vhost::Backend>(1463handler: &mut BackendServer<S>,1464expected_message_type: FrontendReq,1465) -> Result<(), VhostError> {1466let (hdr, files) = handler.recv_header()?;1467assert_eq!(hdr.get_code(), Ok(expected_message_type));1468handler.process_message(hdr, files)1469}1470}147114721473