diff --git a/virtio-devices/src/vhost_user/blk.rs b/virtio-devices/src/vhost_user/blk.rs
index 231af10e4..c4d702711 100644
--- a/virtio-devices/src/vhost_user/blk.rs
+++ b/virtio-devices/src/vhost_user/blk.rs
@@ -4,10 +4,7 @@
 use super::super::{
     ActivateError, ActivateResult, Queue, VirtioCommon, VirtioDevice, VirtioDeviceType,
 };
-use super::vu_common_ctrl::{
-    add_memory_region, connect_vhost_user, negotiate_features_vhost_user, reset_vhost_user,
-    setup_vhost_user, update_mem_table, VhostUserConfig,
-};
+use super::vu_common_ctrl::{VhostUserConfig, VhostUserHandle};
 use super::{Error, Result, DEFAULT_VIRTIO_FEATURES};
 use crate::vhost_user::{Inflight, VhostUserEpollHandler};
 use crate::VirtioInterrupt;
@@ -23,7 +20,7 @@ use std::vec::Vec;
 use vhost::vhost_user::message::VhostUserConfigFlags;
 use vhost::vhost_user::message::VHOST_USER_CONFIG_OFFSET;
 use vhost::vhost_user::message::{VhostUserProtocolFeatures, VhostUserVirtioFeatures};
-use vhost::vhost_user::{Master, MasterReqHandler, VhostUserMaster, VhostUserMasterReqHandler};
+use vhost::vhost_user::{MasterReqHandler, VhostUserMaster, VhostUserMasterReqHandler};
 use vhost::VhostBackend;
 use virtio_bindings::bindings::virtio_blk::{
     VIRTIO_BLK_F_BLK_SIZE, VIRTIO_BLK_F_CONFIG_WCE, VIRTIO_BLK_F_DISCARD, VIRTIO_BLK_F_FLUSH,
@@ -42,7 +39,7 @@ impl VhostUserMasterReqHandler for SlaveReqHandler {}
 pub struct Blk {
     common: VirtioCommon,
     id: String,
-    vhost_user_blk: Arc<Mutex<Master>>,
+    vu: Arc<Mutex<VhostUserHandle>>,
     config: VirtioBlockConfig,
     guest_memory: Option<GuestMemoryAtomic<GuestMemoryMmap>>,
     acked_protocol_features: u64,
@@ -55,8 +52,8 @@ impl Blk {
     pub fn new(id: String, vu_cfg: VhostUserConfig) -> Result<Blk> {
         let num_queues = vu_cfg.num_queues;
 
-        let mut vhost_user_blk =
-            connect_vhost_user(false, &vu_cfg.socket, num_queues as u64, false)?;
+        let mut vu =
+            VhostUserHandle::connect_vhost_user(false, &vu_cfg.socket, num_queues as u64, false)?;
 
         // Filling device and vring features VMM supports.
         let mut avail_features = 1 << VIRTIO_BLK_F_SIZE_MAX
@@ -81,15 +78,12 @@
             | VhostUserProtocolFeatures::REPLY_ACK
             | VhostUserProtocolFeatures::INFLIGHT_SHMFD;
 
-        let (acked_features, acked_protocol_features) = negotiate_features_vhost_user(
-            &mut vhost_user_blk,
-            avail_features,
-            avail_protocol_features,
-        )?;
+        let (acked_features, acked_protocol_features) =
+            vu.negotiate_features_vhost_user(avail_features, avail_protocol_features)?;
 
         let backend_num_queues =
             if acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() != 0 {
-                vhost_user_blk
+                vu.socket_handle()
                     .get_queue_num()
                     .map_err(Error::VhostUserGetQueueMaxNum)? as usize
             } else {
@@ -104,7 +98,8 @@
         let config_len = mem::size_of::<VirtioBlockConfig>();
         let config_space: Vec<u8> = vec![0u8; config_len as usize];
-        let (_, config_space) = vhost_user_blk
+        let (_, config_space) = vu
+            .socket_handle()
             .get_config(
                 VHOST_USER_CONFIG_OFFSET,
                 config_len as u32,
@@ -122,7 +117,7 @@
         // how many virt queues to be handled, which backend required to know
         // at early stage.
         for i in 0..num_queues {
-            vhost_user_blk
+            vu.socket_handle()
                 .set_vring_base(i, 0)
                 .map_err(Error::VhostUserSetVringBase)?;
         }
@@ -138,7 +133,7 @@ impl Blk {
                 ..Default::default()
             },
             id,
-            vhost_user_blk: Arc::new(Mutex::new(vhost_user_blk)),
+            vu: Arc::new(Mutex::new(vu)),
             config,
             guest_memory: None,
             acked_protocol_features,
@@ -195,9 +190,10 @@ impl VirtioDevice for Blk {
         self.config.writeback = data[0];
 
         if let Err(e) = self
-            .vhost_user_blk
+            .vu
             .lock()
             .unwrap()
+            .socket_handle()
             .set_config(offset as u32, VhostUserConfigFlags::WRITABLE, data)
             .map_err(Error::VhostUserSetConfig)
         {
@@ -232,24 +228,26 @@
             None
         };
 
-        setup_vhost_user(
-            &mut self.vhost_user_blk.lock().unwrap(),
-            &mem.memory(),
-            queues.clone(),
-            queue_evts.iter().map(|q| q.try_clone().unwrap()).collect(),
-            &interrupt_cb,
-            backend_acked_features,
-            &slave_req_handler,
-            inflight.as_mut(),
-        )
-        .map_err(ActivateError::VhostUserBlkSetup)?;
+        self.vu
+            .lock()
+            .unwrap()
+            .setup_vhost_user(
+                &mem.memory(),
+                queues.clone(),
+                queue_evts.iter().map(|q| q.try_clone().unwrap()).collect(),
+                &interrupt_cb,
+                backend_acked_features,
+                &slave_req_handler,
+                inflight.as_mut(),
+            )
+            .map_err(ActivateError::VhostUserBlkSetup)?;
 
         // Run a dedicated thread for handling potential reconnections with
         // the backend.
         let (kill_evt, pause_evt) = self.common.dup_eventfds();
 
         let mut handler: VhostUserEpollHandler<SlaveReqHandler> = VhostUserEpollHandler {
-            vu: self.vhost_user_blk.clone(),
+            vu: self.vu.clone(),
             mem,
             kill_evt,
             pause_evt,
@@ -289,10 +287,12 @@
             self.common.resume().ok()?;
         }
 
-        if let Err(e) = reset_vhost_user(
-            &mut self.vhost_user_blk.lock().unwrap(),
-            self.common.queue_sizes.len(),
-        ) {
+        if let Err(e) = self
+            .vu
+            .lock()
+            .unwrap()
+            .reset_vhost_user(self.common.queue_sizes.len())
+        {
             error!("Failed to reset vhost-user daemon: {:?}", e);
             return None;
         }
@@ -309,7 +309,7 @@
     }
 
     fn shutdown(&mut self) {
-        let _ = unsafe { libc::close(self.vhost_user_blk.lock().unwrap().as_raw_fd()) };
+        let _ = unsafe { libc::close(self.vu.lock().unwrap().socket_handle().as_raw_fd()) };
     }
 
     fn add_memory_region(
@@ -318,14 +318,17 @@
     ) -> std::result::Result<(), crate::Error> {
         if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() != 0
         {
-            add_memory_region(&mut self.vhost_user_blk.lock().unwrap(), region)
+            self.vu
+                .lock()
+                .unwrap()
+                .add_memory_region(region)
                 .map_err(crate::Error::VhostUserAddMemoryRegion)
         } else if let Some(guest_memory) = &self.guest_memory {
-            update_mem_table(
-                &mut self.vhost_user_blk.lock().unwrap(),
-                guest_memory.memory().deref(),
-            )
-            .map_err(crate::Error::VhostUserUpdateMemory)
+            self.vu
+                .lock()
+                .unwrap()
+                .update_mem_table(guest_memory.memory().deref())
+                .map_err(crate::Error::VhostUserUpdateMemory)
         } else {
             Ok(())
         }
diff --git a/virtio-devices/src/vhost_user/fs.rs b/virtio-devices/src/vhost_user/fs.rs
index 47b8c1f8b..6832b6857 100644
--- a/virtio-devices/src/vhost_user/fs.rs
+++ b/virtio-devices/src/vhost_user/fs.rs
@@ -1,10 +1,7 @@
 // Copyright 2019 Intel Corporation. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 -use super::vu_common_ctrl::{ - add_memory_region, connect_vhost_user, negotiate_features_vhost_user, reset_vhost_user, - setup_vhost_user, update_mem_table, -}; +use super::vu_common_ctrl::VhostUserHandle; use super::{Error, Result, DEFAULT_VIRTIO_FEATURES}; use crate::seccomp_filters::{get_seccomp_filter, Thread}; use crate::vhost_user::{Inflight, VhostUserEpollHandler}; @@ -26,7 +23,7 @@ use vhost::vhost_user::message::{ VhostUserVirtioFeatures, VHOST_USER_FS_SLAVE_ENTRIES, }; use vhost::vhost_user::{ - HandlerResult, Master, MasterReqHandler, VhostUserMaster, VhostUserMasterReqHandler, + HandlerResult, MasterReqHandler, VhostUserMaster, VhostUserMasterReqHandler, }; use vm_memory::{ Address, ByteValued, GuestAddress, GuestAddressSpace, GuestMemory, GuestMemoryAtomic, @@ -286,7 +283,7 @@ unsafe impl ByteValued for VirtioFsConfig {} pub struct Fs { common: VirtioCommon, id: String, - vu: Arc>, + vu: Arc>, config: VirtioFsConfig, // Hold ownership of the memory that is allocated for the device // which will be automatically dropped when the device is dropped @@ -316,7 +313,7 @@ impl Fs { let num_queues = NUM_QUEUE_OFFSET + req_num_queues; // Connect to the vhost-user socket. - let mut vhost_user_fs = connect_vhost_user(false, path, num_queues as u64, false)?; + let mut vu = VhostUserHandle::connect_vhost_user(false, path, num_queues as u64, false)?; // Filling device and vring features VMM supports. let avail_features = DEFAULT_VIRTIO_FEATURES; @@ -331,15 +328,12 @@ impl Fs { avail_protocol_features |= slave_protocol_features; } - let (acked_features, acked_protocol_features) = negotiate_features_vhost_user( - &mut vhost_user_fs, - avail_features, - avail_protocol_features, - )?; + let (acked_features, acked_protocol_features) = + vu.negotiate_features_vhost_user(avail_features, avail_protocol_features)?; let backend_num_queues = if acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() != 0 { - vhost_user_fs + vu.socket_handle() .get_queue_num() .map_err(Error::VhostUserGetQueueMaxNum)? as usize } else { @@ -377,7 +371,7 @@ impl Fs { ..Default::default() }, id, - vu: Arc::new(Mutex::new(vhost_user_fs)), + vu: Arc::new(Mutex::new(vu)), config, cache, slave_req_support, @@ -467,17 +461,19 @@ impl VirtioDevice for Fs { None }; - setup_vhost_user( - &mut self.vu.lock().unwrap(), - &mem.memory(), - queues.clone(), - queue_evts.iter().map(|q| q.try_clone().unwrap()).collect(), - &interrupt_cb, - backend_acked_features, - &slave_req_handler, - inflight.as_mut(), - ) - .map_err(ActivateError::VhostUserFsSetup)?; + self.vu + .lock() + .unwrap() + .setup_vhost_user( + &mem.memory(), + queues.clone(), + queue_evts.iter().map(|q| q.try_clone().unwrap()).collect(), + &interrupt_cb, + backend_acked_features, + &slave_req_handler, + inflight.as_mut(), + ) + .map_err(ActivateError::VhostUserFsSetup)?; // Run a dedicated thread for handling potential reconnections with // the backend as well as requests initiated by the backend. 
@@ -530,8 +526,11 @@ impl VirtioDevice for Fs { self.common.resume().ok()?; } - if let Err(e) = - reset_vhost_user(&mut self.vu.lock().unwrap(), self.common.queue_sizes.len()) + if let Err(e) = self + .vu + .lock() + .unwrap() + .reset_vhost_user(self.common.queue_sizes.len()) { error!("Failed to reset vhost-user daemon: {:?}", e); return None; @@ -549,7 +548,7 @@ impl VirtioDevice for Fs { } fn shutdown(&mut self) { - let _ = unsafe { libc::close(self.vu.lock().unwrap().as_raw_fd()) }; + let _ = unsafe { libc::close(self.vu.lock().unwrap().socket_handle().as_raw_fd()) }; } fn get_shm_regions(&self) -> Option { @@ -574,10 +573,16 @@ impl VirtioDevice for Fs { ) -> std::result::Result<(), crate::Error> { if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() != 0 { - add_memory_region(&mut self.vu.lock().unwrap(), region) + self.vu + .lock() + .unwrap() + .add_memory_region(region) .map_err(crate::Error::VhostUserAddMemoryRegion) } else if let Some(guest_memory) = &self.guest_memory { - update_mem_table(&mut self.vu.lock().unwrap(), guest_memory.memory().deref()) + self.vu + .lock() + .unwrap() + .update_mem_table(guest_memory.memory().deref()) .map_err(crate::Error::VhostUserUpdateMemory) } else { Ok(()) diff --git a/virtio-devices/src/vhost_user/mod.rs b/virtio-devices/src/vhost_user/mod.rs index a8c8a43dc..5d7dbd67a 100644 --- a/virtio-devices/src/vhost_user/mod.rs +++ b/virtio-devices/src/vhost_user/mod.rs @@ -12,12 +12,12 @@ use std::ops::Deref; use std::os::unix::io::AsRawFd; use std::sync::{atomic::AtomicBool, Arc, Barrier, Mutex}; use vhost::vhost_user::message::{VhostUserInflight, VhostUserVirtioFeatures}; -use vhost::vhost_user::{Master, MasterReqHandler, VhostUserMasterReqHandler}; +use vhost::vhost_user::{MasterReqHandler, VhostUserMasterReqHandler}; use vhost::Error as VhostError; use vm_memory::{Error as MmapError, GuestAddressSpace, GuestMemoryAtomic}; use vm_virtio::Error as VirtioError; use vmm_sys_util::eventfd::EventFd; -use vu_common_ctrl::{connect_vhost_user, reinitialize_vhost_user}; +use vu_common_ctrl::VhostUserHandle; pub mod blk; pub mod fs; @@ -138,7 +138,7 @@ pub struct Inflight { } pub struct VhostUserEpollHandler { - pub vu: Arc>, + pub vu: Arc>, pub mem: GuestMemoryAtomic, pub kill_evt: EventFd, pub pause_evt: EventFd, @@ -161,7 +161,7 @@ impl VhostUserEpollHandler { ) -> std::result::Result<(), EpollHelperError> { let mut helper = EpollHelper::new(&self.kill_evt, &self.pause_evt)?; helper.add_event_custom( - self.vu.lock().unwrap().as_raw_fd(), + self.vu.lock().unwrap().socket_handle().as_raw_fd(), HUP_CONNECTION_EVENT, epoll::Events::EPOLLHUP, )?; @@ -177,12 +177,12 @@ impl VhostUserEpollHandler { fn reconnect(&mut self, helper: &mut EpollHelper) -> std::result::Result<(), EpollHelperError> { helper.del_event_custom( - self.vu.lock().unwrap().as_raw_fd(), + self.vu.lock().unwrap().socket_handle().as_raw_fd(), HUP_CONNECTION_EVENT, epoll::Events::EPOLLHUP, )?; - let mut vhost_user = connect_vhost_user( + let mut vhost_user = VhostUserHandle::connect_vhost_user( self.server, &self.socket_path, self.queues.len() as u64, @@ -196,29 +196,29 @@ impl VhostUserEpollHandler { })?; // Initialize the backend - reinitialize_vhost_user( - &mut vhost_user, - self.mem.memory().deref(), - self.queues.clone(), - self.queue_evts - .iter() - .map(|q| q.try_clone().unwrap()) - .collect(), - &self.virtio_interrupt, - self.acked_features, - self.acked_protocol_features, - &self.slave_req_handler, - self.inflight.as_mut(), - ) - .map_err(|e| { - 
EpollHelperError::IoError(std::io::Error::new( - std::io::ErrorKind::Other, - format!("failed reconnecting vhost-user backend{:?}", e), - )) - })?; + vhost_user + .reinitialize_vhost_user( + self.mem.memory().deref(), + self.queues.clone(), + self.queue_evts + .iter() + .map(|q| q.try_clone().unwrap()) + .collect(), + &self.virtio_interrupt, + self.acked_features, + self.acked_protocol_features, + &self.slave_req_handler, + self.inflight.as_mut(), + ) + .map_err(|e| { + EpollHelperError::IoError(std::io::Error::new( + std::io::ErrorKind::Other, + format!("failed reconnecting vhost-user backend{:?}", e), + )) + })?; helper.add_event_custom( - vhost_user.as_raw_fd(), + vhost_user.socket_handle().as_raw_fd(), HUP_CONNECTION_EVENT, epoll::Events::EPOLLHUP, )?; diff --git a/virtio-devices/src/vhost_user/net.rs b/virtio-devices/src/vhost_user/net.rs index 2f269e3f9..6d31309c5 100644 --- a/virtio-devices/src/vhost_user/net.rs +++ b/virtio-devices/src/vhost_user/net.rs @@ -2,10 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 use crate::seccomp_filters::{get_seccomp_filter, Thread}; -use crate::vhost_user::vu_common_ctrl::{ - add_memory_region, connect_vhost_user, negotiate_features_vhost_user, reset_vhost_user, - setup_vhost_user, update_mem_table, VhostUserConfig, -}; +use crate::vhost_user::vu_common_ctrl::{VhostUserConfig, VhostUserHandle}; use crate::vhost_user::{Error, Inflight, Result, VhostUserEpollHandler}; use crate::{ ActivateError, ActivateResult, EpollHelper, EpollHelperError, EpollHelperHandler, Queue, @@ -23,7 +20,7 @@ use std::sync::{Arc, Barrier, Mutex}; use std::thread; use std::vec::Vec; use vhost::vhost_user::message::{VhostUserProtocolFeatures, VhostUserVirtioFeatures}; -use vhost::vhost_user::{Master, MasterReqHandler, VhostUserMaster, VhostUserMasterReqHandler}; +use vhost::vhost_user::{MasterReqHandler, VhostUserMaster, VhostUserMasterReqHandler}; use virtio_bindings::bindings::virtio_net::{ VIRTIO_NET_F_CSUM, VIRTIO_NET_F_CTRL_VQ, VIRTIO_NET_F_GUEST_CSUM, VIRTIO_NET_F_GUEST_ECN, VIRTIO_NET_F_GUEST_TSO4, VIRTIO_NET_F_GUEST_TSO6, VIRTIO_NET_F_GUEST_UFO, @@ -94,7 +91,7 @@ impl EpollHelperHandler for NetCtrlEpollHandler { pub struct Net { common: VirtioCommon, id: String, - vhost_user_net: Arc>, + vu: Arc>, config: VirtioNetConfig, guest_memory: Option>, acked_protocol_features: u64, @@ -136,23 +133,20 @@ impl Net { let mut config = VirtioNetConfig::default(); build_net_config_space(&mut config, mac_addr, num_queues, &mut avail_features); - let mut vhost_user_net = - connect_vhost_user(server, &vu_cfg.socket, num_queues as u64, false)?; + let mut vu = + VhostUserHandle::connect_vhost_user(server, &vu_cfg.socket, num_queues as u64, false)?; let avail_protocol_features = VhostUserProtocolFeatures::MQ | VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS | VhostUserProtocolFeatures::REPLY_ACK | VhostUserProtocolFeatures::INFLIGHT_SHMFD; - let (mut acked_features, acked_protocol_features) = negotiate_features_vhost_user( - &mut vhost_user_net, - avail_features, - avail_protocol_features, - )?; + let (mut acked_features, acked_protocol_features) = + vu.negotiate_features_vhost_user(avail_features, avail_protocol_features)?; let backend_num_queues = if acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() != 0 { - vhost_user_net + vu.socket_handle() .get_queue_num() .map_err(Error::VhostUserGetQueueMaxNum)? 
as usize } else { @@ -186,7 +180,7 @@ impl Net { min_queues: DEFAULT_QUEUE_NUMBER as u16, ..Default::default() }, - vhost_user_net: Arc::new(Mutex::new(vhost_user_net)), + vu: Arc::new(Mutex::new(vu)), config, guest_memory: None, acked_protocol_features, @@ -301,24 +295,26 @@ impl VirtioDevice for Net { None }; - setup_vhost_user( - &mut self.vhost_user_net.lock().unwrap(), - &mem.memory(), - queues.clone(), - queue_evts.iter().map(|q| q.try_clone().unwrap()).collect(), - &interrupt_cb, - backend_acked_features, - &slave_req_handler, - inflight.as_mut(), - ) - .map_err(ActivateError::VhostUserNetSetup)?; + self.vu + .lock() + .unwrap() + .setup_vhost_user( + &mem.memory(), + queues.clone(), + queue_evts.iter().map(|q| q.try_clone().unwrap()).collect(), + &interrupt_cb, + backend_acked_features, + &slave_req_handler, + inflight.as_mut(), + ) + .map_err(ActivateError::VhostUserNetSetup)?; // Run a dedicated thread for handling potential reconnections with // the backend. let (kill_evt, pause_evt) = self.common.dup_eventfds(); let mut handler: VhostUserEpollHandler = VhostUserEpollHandler { - vu: self.vhost_user_net.clone(), + vu: self.vu.clone(), mem, kill_evt, pause_evt, @@ -358,10 +354,12 @@ impl VirtioDevice for Net { self.common.resume().ok()?; } - if let Err(e) = reset_vhost_user( - &mut self.vhost_user_net.lock().unwrap(), - self.common.queue_sizes.len(), - ) { + if let Err(e) = self + .vu + .lock() + .unwrap() + .reset_vhost_user(self.common.queue_sizes.len()) + { error!("Failed to reset vhost-user daemon: {:?}", e); return None; } @@ -378,7 +376,7 @@ impl VirtioDevice for Net { } fn shutdown(&mut self) { - let _ = unsafe { libc::close(self.vhost_user_net.lock().unwrap().as_raw_fd()) }; + let _ = unsafe { libc::close(self.vu.lock().unwrap().socket_handle().as_raw_fd()) }; // Remove socket path if needed if self.server { @@ -392,14 +390,17 @@ impl VirtioDevice for Net { ) -> std::result::Result<(), crate::Error> { if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() != 0 { - add_memory_region(&mut self.vhost_user_net.lock().unwrap(), region) + self.vu + .lock() + .unwrap() + .add_memory_region(region) .map_err(crate::Error::VhostUserAddMemoryRegion) } else if let Some(guest_memory) = &self.guest_memory { - update_mem_table( - &mut self.vhost_user_net.lock().unwrap(), - guest_memory.memory().deref(), - ) - .map_err(crate::Error::VhostUserUpdateMemory) + self.vu + .lock() + .unwrap() + .update_mem_table(guest_memory.memory().deref()) + .map_err(crate::Error::VhostUserUpdateMemory) } else { Ok(()) } diff --git a/virtio-devices/src/vhost_user/vu_common_ctrl.rs b/virtio-devices/src/vhost_user/vu_common_ctrl.rs index b866dbe97..cb62505e1 100644 --- a/virtio-devices/src/vhost_user/vu_common_ctrl.rs +++ b/virtio-devices/src/vhost_user/vu_common_ctrl.rs @@ -28,15 +28,45 @@ pub struct VhostUserConfig { pub queue_size: u16, } -pub fn update_mem_table(vu: &mut Master, mem: &GuestMemoryMmap) -> Result<()> { - let mut regions: Vec = Vec::new(); - for region in mem.iter() { +#[derive(Clone)] +pub struct VhostUserHandle { + vu: Master, +} + +impl VhostUserHandle { + pub fn update_mem_table(&mut self, mem: &GuestMemoryMmap) -> Result<()> { + let mut regions: Vec = Vec::new(); + for region in mem.iter() { + let (mmap_handle, mmap_offset) = match region.file_offset() { + Some(_file_offset) => (_file_offset.file().as_raw_fd(), _file_offset.start()), + None => return Err(Error::VhostUserMemoryRegion(MmapError::NoMemoryRegion)), + }; + + let vhost_user_net_reg = 
VhostUserMemoryRegionInfo { + guest_phys_addr: region.start_addr().raw_value(), + memory_size: region.len() as u64, + userspace_addr: region.as_ptr() as u64, + mmap_offset, + mmap_handle, + }; + + regions.push(vhost_user_net_reg); + } + + self.vu + .set_mem_table(regions.as_slice()) + .map_err(Error::VhostUserSetMemTable)?; + + Ok(()) + } + + pub fn add_memory_region(&mut self, region: &Arc) -> Result<()> { let (mmap_handle, mmap_offset) = match region.file_offset() { - Some(_file_offset) => (_file_offset.file().as_raw_fd(), _file_offset.start()), - None => return Err(Error::VhostUserMemoryRegion(MmapError::NoMemoryRegion)), + Some(file_offset) => (file_offset.file().as_raw_fd(), file_offset.start()), + None => return Err(Error::MissingRegionFd), }; - let vhost_user_net_reg = VhostUserMemoryRegionInfo { + let region = VhostUserMemoryRegionInfo { guest_phys_addr: region.start_addr().raw_value(), memory_size: region.len() as u64, userspace_addr: region.as_ptr() as u64, @@ -44,254 +74,258 @@ pub fn update_mem_table(vu: &mut Master, mem: &GuestMemoryMmap) -> Result<()> { mmap_handle, }; - regions.push(vhost_user_net_reg); + self.vu + .add_mem_region(®ion) + .map_err(Error::VhostUserAddMemReg) } - vu.set_mem_table(regions.as_slice()) - .map_err(Error::VhostUserSetMemTable)?; + pub fn negotiate_features_vhost_user( + &mut self, + avail_features: u64, + avail_protocol_features: VhostUserProtocolFeatures, + ) -> Result<(u64, u64)> { + // Set vhost-user owner. + self.vu.set_owner().map_err(Error::VhostUserSetOwner)?; - Ok(()) -} + // Get features from backend, do negotiation to get a feature collection which + // both VMM and backend support. + let backend_features = self + .vu + .get_features() + .map_err(Error::VhostUserGetFeatures)?; + let acked_features = avail_features & backend_features; -pub fn add_memory_region(vu: &mut Master, region: &Arc) -> Result<()> { - let (mmap_handle, mmap_offset) = match region.file_offset() { - Some(file_offset) => (file_offset.file().as_raw_fd(), file_offset.start()), - None => return Err(Error::MissingRegionFd), - }; + let acked_protocol_features = + if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 { + let backend_protocol_features = self + .vu + .get_protocol_features() + .map_err(Error::VhostUserGetProtocolFeatures)?; - let region = VhostUserMemoryRegionInfo { - guest_phys_addr: region.start_addr().raw_value(), - memory_size: region.len() as u64, - userspace_addr: region.as_ptr() as u64, - mmap_offset, - mmap_handle, - }; + let acked_protocol_features = avail_protocol_features & backend_protocol_features; - vu.add_mem_region(®ion) - .map_err(Error::VhostUserAddMemReg) -} + self.vu + .set_protocol_features(acked_protocol_features) + .map_err(Error::VhostUserSetProtocolFeatures)?; -pub fn negotiate_features_vhost_user( - vu: &mut Master, - avail_features: u64, - avail_protocol_features: VhostUserProtocolFeatures, -) -> Result<(u64, u64)> { - // Set vhost-user owner. - vu.set_owner().map_err(Error::VhostUserSetOwner)?; + acked_protocol_features + } else { + VhostUserProtocolFeatures::empty() + }; - // Get features from backend, do negotiation to get a feature collection which - // both VMM and backend support. 
- let backend_features = vu.get_features().map_err(Error::VhostUserGetFeatures)?; - let acked_features = avail_features & backend_features; + if avail_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) + && acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) + { + self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY); + } - let acked_protocol_features = - if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 { - let backend_protocol_features = vu - .get_protocol_features() - .map_err(Error::VhostUserGetProtocolFeatures)?; + Ok((acked_features, acked_protocol_features.bits())) + } - let acked_protocol_features = avail_protocol_features & backend_protocol_features; + #[allow(clippy::too_many_arguments)] + pub fn setup_vhost_user( + &mut self, + mem: &GuestMemoryMmap, + queues: Vec, + queue_evts: Vec, + virtio_interrupt: &Arc, + acked_features: u64, + slave_req_handler: &Option>, + inflight: Option<&mut Inflight>, + ) -> Result<()> { + self.vu + .set_features(acked_features) + .map_err(Error::VhostUserSetFeatures)?; - vu.set_protocol_features(acked_protocol_features) - .map_err(Error::VhostUserSetProtocolFeatures)?; + // Let's first provide the memory table to the backend. + self.update_mem_table(mem)?; - acked_protocol_features + // Setup for inflight I/O tracking shared memory. + if let Some(inflight) = inflight { + if inflight.fd.is_none() { + let inflight_req_info = VhostUserInflight { + mmap_size: 0, + mmap_offset: 0, + num_queues: queues.len() as u16, + queue_size: queues[0].actual_size(), + }; + let (info, fd) = self + .vu + .get_inflight_fd(&inflight_req_info) + .map_err(Error::VhostUserGetInflight)?; + inflight.info = info; + inflight.fd = Some(fd); + } + // Unwrapping the inflight fd is safe here since we know it can't be None. + self.vu + .set_inflight_fd(&inflight.info, inflight.fd.as_ref().unwrap().as_raw_fd()) + .map_err(Error::VhostUserSetInflight)?; + } + + for (queue_index, queue) in queues.into_iter().enumerate() { + let actual_size: usize = queue.actual_size().try_into().unwrap(); + + self.vu + .set_vring_num(queue_index, queue.actual_size()) + .map_err(Error::VhostUserSetVringNum)?; + + let config_data = VringConfigData { + queue_max_size: queue.get_max_size(), + queue_size: queue.actual_size(), + flags: 0u32, + desc_table_addr: get_host_address_range( + mem, + queue.desc_table, + actual_size * std::mem::size_of::(), + ) + .ok_or(Error::DescriptorTableAddress)? as u64, + // The used ring is {flags: u16; idx: u16; virtq_used_elem [{id: u16, len: u16}; actual_size]}, + // i.e. 4 + (4 + 4) * actual_size. + used_ring_addr: get_host_address_range(mem, queue.used_ring, 4 + actual_size * 8) + .ok_or(Error::UsedAddress)? as u64, + // The used ring is {flags: u16; idx: u16; elem [u16; actual_size]}, + // i.e. 4 + (2) * actual_size. + avail_ring_addr: get_host_address_range(mem, queue.avail_ring, 4 + actual_size * 2) + .ok_or(Error::AvailAddress)? 
as u64, + log_addr: None, + }; + + self.vu + .set_vring_addr(queue_index, &config_data) + .map_err(Error::VhostUserSetVringAddr)?; + self.vu + .set_vring_base( + queue_index, + queue + .avail_index_from_memory(mem) + .map_err(Error::GetAvailableIndex)?, + ) + .map_err(Error::VhostUserSetVringBase)?; + + if let Some(eventfd) = + virtio_interrupt.notifier(&VirtioInterruptType::Queue, Some(&queue)) + { + self.vu + .set_vring_call(queue_index, &eventfd) + .map_err(Error::VhostUserSetVringCall)?; + } + + self.vu + .set_vring_kick(queue_index, &queue_evts[queue_index]) + .map_err(Error::VhostUserSetVringKick)?; + + self.vu + .set_vring_enable(queue_index, true) + .map_err(Error::VhostUserSetVringEnable)?; + } + + if let Some(slave_req_handler) = slave_req_handler { + self.vu + .set_slave_request_fd(&slave_req_handler.get_tx_raw_fd()) + .map_err(Error::VhostUserSetSlaveRequestFd) } else { - VhostUserProtocolFeatures::empty() - }; - - if avail_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) - && acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) - { - vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY); - } - - Ok((acked_features, acked_protocol_features.bits())) -} - -#[allow(clippy::too_many_arguments)] -pub fn setup_vhost_user( - vu: &mut Master, - mem: &GuestMemoryMmap, - queues: Vec, - queue_evts: Vec, - virtio_interrupt: &Arc, - acked_features: u64, - slave_req_handler: &Option>, - inflight: Option<&mut Inflight>, -) -> Result<()> { - vu.set_features(acked_features) - .map_err(Error::VhostUserSetFeatures)?; - - // Let's first provide the memory table to the backend. - update_mem_table(vu, mem)?; - - // Setup for inflight I/O tracking shared memory. - if let Some(inflight) = inflight { - if inflight.fd.is_none() { - let inflight_req_info = VhostUserInflight { - mmap_size: 0, - mmap_offset: 0, - num_queues: queues.len() as u16, - queue_size: queues[0].actual_size(), - }; - let (info, fd) = vu - .get_inflight_fd(&inflight_req_info) - .map_err(Error::VhostUserGetInflight)?; - inflight.info = info; - inflight.fd = Some(fd); + Ok(()) } - // Unwrapping the inflight fd is safe here since we know it can't be None. - vu.set_inflight_fd(&inflight.info, inflight.fd.as_ref().unwrap().as_raw_fd()) - .map_err(Error::VhostUserSetInflight)?; } - for (queue_index, queue) in queues.into_iter().enumerate() { - let actual_size: usize = queue.actual_size().try_into().unwrap(); + pub fn reset_vhost_user(&mut self, num_queues: usize) -> Result<()> { + for queue_index in 0..num_queues { + // Disable the vrings. + self.vu + .set_vring_enable(queue_index, false) + .map_err(Error::VhostUserSetVringEnable)?; + } - vu.set_vring_num(queue_index, queue.actual_size()) - .map_err(Error::VhostUserSetVringNum)?; + // Reset the owner. + self.vu.reset_owner().map_err(Error::VhostUserResetOwner) + } - let config_data = VringConfigData { - queue_max_size: queue.get_max_size(), - queue_size: queue.actual_size(), - flags: 0u32, - desc_table_addr: get_host_address_range( - mem, - queue.desc_table, - actual_size * std::mem::size_of::(), - ) - .ok_or(Error::DescriptorTableAddress)? as u64, - // The used ring is {flags: u16; idx: u16; virtq_used_elem [{id: u16, len: u16}; actual_size]}, - // i.e. 4 + (4 + 4) * actual_size. - used_ring_addr: get_host_address_range(mem, queue.used_ring, 4 + actual_size * 8) - .ok_or(Error::UsedAddress)? as u64, - // The used ring is {flags: u16; idx: u16; elem [u16; actual_size]}, - // i.e. 4 + (2) * actual_size. 
- avail_ring_addr: get_host_address_range(mem, queue.avail_ring, 4 + actual_size * 2) - .ok_or(Error::AvailAddress)? as u64, - log_addr: None, - }; + #[allow(clippy::too_many_arguments)] + pub fn reinitialize_vhost_user( + &mut self, + mem: &GuestMemoryMmap, + queues: Vec, + queue_evts: Vec, + virtio_interrupt: &Arc, + acked_features: u64, + acked_protocol_features: u64, + slave_req_handler: &Option>, + inflight: Option<&mut Inflight>, + ) -> Result<()> { + self.vu.set_owner().map_err(Error::VhostUserSetOwner)?; + self.vu + .get_features() + .map_err(Error::VhostUserGetFeatures)?; - vu.set_vring_addr(queue_index, &config_data) - .map_err(Error::VhostUserSetVringAddr)?; - vu.set_vring_base( - queue_index, - queue - .avail_index_from_memory(mem) - .map_err(Error::GetAvailableIndex)?, + if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 { + if let Some(acked_protocol_features) = + VhostUserProtocolFeatures::from_bits(acked_protocol_features) + { + self.vu + .set_protocol_features(acked_protocol_features) + .map_err(Error::VhostUserSetProtocolFeatures)?; + + if acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) { + self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY); + } + } + } + + self.setup_vhost_user( + mem, + queues, + queue_evts, + virtio_interrupt, + acked_features, + slave_req_handler, + inflight, ) - .map_err(Error::VhostUserSetVringBase)?; - - if let Some(eventfd) = virtio_interrupt.notifier(&VirtioInterruptType::Queue, Some(&queue)) - { - vu.set_vring_call(queue_index, &eventfd) - .map_err(Error::VhostUserSetVringCall)?; - } - - vu.set_vring_kick(queue_index, &queue_evts[queue_index]) - .map_err(Error::VhostUserSetVringKick)?; - - vu.set_vring_enable(queue_index, true) - .map_err(Error::VhostUserSetVringEnable)?; } - if let Some(slave_req_handler) = slave_req_handler { - vu.set_slave_request_fd(&slave_req_handler.get_tx_raw_fd()) - .map_err(Error::VhostUserSetSlaveRequestFd) - } else { - Ok(()) - } -} - -pub fn reset_vhost_user(vu: &mut Master, num_queues: usize) -> Result<()> { - for queue_index in 0..num_queues { - // Disable the vrings. - vu.set_vring_enable(queue_index, false) - .map_err(Error::VhostUserSetVringEnable)?; - } - - // Reset the owner. 
- vu.reset_owner().map_err(Error::VhostUserResetOwner) -} - -#[allow(clippy::too_many_arguments)] -pub fn reinitialize_vhost_user( - vu: &mut Master, - mem: &GuestMemoryMmap, - queues: Vec, - queue_evts: Vec, - virtio_interrupt: &Arc, - acked_features: u64, - acked_protocol_features: u64, - slave_req_handler: &Option>, - inflight: Option<&mut Inflight>, -) -> Result<()> { - vu.set_owner().map_err(Error::VhostUserSetOwner)?; - vu.get_features().map_err(Error::VhostUserGetFeatures)?; - - if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 { - if let Some(acked_protocol_features) = - VhostUserProtocolFeatures::from_bits(acked_protocol_features) - { - vu.set_protocol_features(acked_protocol_features) - .map_err(Error::VhostUserSetProtocolFeatures)?; - - if acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) { - vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY); + pub fn connect_vhost_user( + server: bool, + socket_path: &str, + num_queues: u64, + unlink_socket: bool, + ) -> Result { + if server { + if unlink_socket { + std::fs::remove_file(socket_path).map_err(Error::RemoveSocketPath)?; } - } - } - setup_vhost_user( - vu, - mem, - queues, - queue_evts, - virtio_interrupt, - acked_features, - slave_req_handler, - inflight, - ) -} + info!("Binding vhost-user listener..."); + let listener = UnixListener::bind(socket_path).map_err(Error::BindSocket)?; + info!("Waiting for incoming vhost-user connection..."); + let (stream, _) = listener.accept().map_err(Error::AcceptConnection)?; -pub fn connect_vhost_user( - server: bool, - socket_path: &str, - num_queues: u64, - unlink_socket: bool, -) -> Result { - if server { - if unlink_socket { - std::fs::remove_file(socket_path).map_err(Error::RemoveSocketPath)?; - } + Ok(VhostUserHandle { + vu: Master::from_stream(stream, num_queues), + }) + } else { + let now = Instant::now(); - info!("Binding vhost-user listener..."); - let listener = UnixListener::bind(socket_path).map_err(Error::BindSocket)?; - info!("Waiting for incoming vhost-user connection..."); - let (stream, _) = listener.accept().map_err(Error::AcceptConnection)?; + // Retry connecting for a full minute + let err = loop { + let err = match Master::connect(socket_path, num_queues) { + Ok(m) => return Ok(VhostUserHandle { vu: m }), + Err(e) => e, + }; + sleep(Duration::from_millis(100)); - Ok(Master::from_stream(stream, num_queues)) - } else { - let now = Instant::now(); - - // Retry connecting for a full minute - let err = loop { - let err = match Master::connect(socket_path, num_queues) { - Ok(m) => return Ok(m), - Err(e) => e, + if now.elapsed().as_secs() >= 60 { + break err; + } }; - sleep(Duration::from_millis(100)); - if now.elapsed().as_secs() >= 60 { - break err; - } - }; + error!( + "Failed connecting the backend after trying for 1 minute: {:?}", + err + ); + Err(Error::VhostUserConnect) + } + } - error!( - "Failed connecting the backend after trying for 1 minute: {:?}", - err - ); - Err(Error::VhostUserConnect) + pub fn socket_handle(&mut self) -> &mut Master { + &mut self.vu } }
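For reference, a minimal sketch of how a device model is expected to drive the refactored handle. The bring_up_backend function name is hypothetical; the sketch assumes the crate-internal items used above (VhostUserHandle, VhostUserConfig, VhostUserProtocolFeatures, DEFAULT_VIRTIO_FEATURES, Error, Result) and elides the guest memory, queue, eventfd and interrupt plumbing that setup_vhost_user takes:

fn bring_up_backend(vu_cfg: &VhostUserConfig) -> Result<()> {
    // Connect as a client (server = false) without unlinking the socket path.
    let mut vu = VhostUserHandle::connect_vhost_user(
        false,
        &vu_cfg.socket,
        vu_cfg.num_queues as u64,
        false,
    )?;

    // Negotiate virtio features and vhost-user protocol features in one call.
    let avail_protocol_features =
        VhostUserProtocolFeatures::MQ | VhostUserProtocolFeatures::REPLY_ACK;
    let (_acked_features, acked_protocol_features) =
        vu.negotiate_features_vhost_user(DEFAULT_VIRTIO_FEATURES, avail_protocol_features)?;

    // Anything the wrapper does not expose is still reachable through the
    // underlying Master via socket_handle().
    if acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() != 0 {
        let _backend_num_queues = vu
            .socket_handle()
            .get_queue_num()
            .map_err(Error::VhostUserGetQueueMaxNum)?;
    }

    // At activation time the device calls vu.setup_vhost_user(...) with its
    // guest memory, queues, eventfds and interrupt callback; on reset it calls
    // vu.reset_vhost_user(num_queues), and on reconnection the epoll handler
    // calls vu.reinitialize_vhost_user(...).
    Ok(())
}

The socket_handle() escape hatch keeps device-specific requests (get_config, set_config, set_vring_base, raw fd access for epoll and shutdown) possible without growing the wrapper's method surface.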