diff --git a/vm-virtio/src/vhost_user/fs.rs b/vm-virtio/src/vhost_user/fs.rs index e8debdf5e..6dbde8afa 100644 --- a/vm-virtio/src/vhost_user/fs.rs +++ b/vm-virtio/src/vhost_user/fs.rs @@ -27,8 +27,8 @@ use vhost_rs::vhost_user::{ }; use vhost_rs::VhostBackend; use vm_memory::{ - Address, ByteValued, GuestAddress, GuestAddressSpace, GuestMemoryAtomic, GuestMemoryMmap, - MmapRegion, + Address, ByteValued, GuestAddress, GuestAddressSpace, GuestMemory, GuestMemoryAtomic, + GuestMemoryMmap, MmapRegion, }; use vm_migration::{Migratable, MigratableError, Pausable, Snapshottable, Transportable}; use vmm_sys_util::eventfd::EventFd; @@ -39,6 +39,7 @@ struct SlaveReqHandler { cache_offset: GuestAddress, cache_size: u64, mmap_cache_addr: u64, + mem: GuestMemoryAtomic, } impl SlaveReqHandler { @@ -183,18 +184,32 @@ impl VhostUserMasterReqHandler for SlaveReqHandler { let cache_end = self.cache_offset.raw_value() + self.cache_size; let efault = libc::EFAULT; - let offset = gpa - .checked_sub(self.cache_offset.raw_value()) - .ok_or_else(|| io::Error::from_raw_os_error(efault))?; - let end = gpa - .checked_add(fs.len[i]) - .ok_or_else(|| io::Error::from_raw_os_error(efault))?; + let mut ptr = if gpa >= self.cache_offset.raw_value() && gpa < cache_end { + let offset = gpa + .checked_sub(self.cache_offset.raw_value()) + .ok_or_else(|| io::Error::from_raw_os_error(efault))?; + let end = gpa + .checked_add(fs.len[i]) + .ok_or_else(|| io::Error::from_raw_os_error(efault))?; - if gpa < self.cache_offset.raw_value() || gpa >= cache_end || end >= cache_end { - return Err(io::Error::from_raw_os_error(efault)); - } + if end >= cache_end { + return Err(io::Error::from_raw_os_error(efault)); + } + + self.mmap_cache_addr + offset + } else { + self.mem + .memory() + .get_host_address(GuestAddress(gpa)) + .map_err(|e| { + error!( + "Failed to find RAM region associated with guest physical address 0x{:x}: {:?}", + gpa, e + ); + io::Error::from_raw_os_error(efault) + })? as u64 + }; - let mut ptr = self.mmap_cache_addr + offset; while len > 0 { let ret = if (fs.flags[i] & VhostUserFSSlaveMsgFlags::MAP_W) == VhostUserFSSlaveMsgFlags::MAP_W @@ -479,6 +494,7 @@ impl VirtioDevice for Fs { cache_offset: cache.0.addr, cache_size: cache.0.len, mmap_cache_addr: cache.0.host_addr, + mem, })); let req_handler = MasterReqHandler::new(vu_master_req_handler).map_err(|e| {