diff --git a/block_util/src/lib.rs b/block_util/src/lib.rs index 8f3db17e2..3247ff124 100644 --- a/block_util/src/lib.rs +++ b/block_util/src/lib.rs @@ -40,7 +40,7 @@ use vm_memory::{ bitmap::AtomicBitmap, bitmap::Bitmap, ByteValued, Bytes, GuestAddress, GuestMemory, GuestMemoryError, GuestMemoryLoadGuard, }; -use vm_virtio::AccessPlatform; +use vm_virtio::{AccessPlatform, Translatable}; use vmm_sys_util::eventfd::EventFd; type GuestMemoryMmap = vm_memory::GuestMemoryMmap; @@ -207,15 +207,9 @@ impl Request { return Err(Error::UnexpectedWriteOnlyDescriptor); } - let hdr_desc_addr = if let Some(access_platform) = access_platform { - GuestAddress( - access_platform - .translate(hdr_desc.addr().0, u64::from(hdr_desc.len())) - .unwrap(), - ) - } else { - hdr_desc.addr() - }; + let hdr_desc_addr = hdr_desc + .addr() + .translate(access_platform, hdr_desc.len() as usize); let mut req = Request { request_type: request_type(desc_chain.memory(), hdr_desc_addr)?, @@ -254,17 +248,10 @@ impl Request { return Err(Error::UnexpectedReadOnlyDescriptor); } - let desc_addr = if let Some(access_platform) = access_platform { - GuestAddress( - access_platform - .translate(desc.addr().0, u64::from(desc.len())) - .unwrap(), - ) - } else { - desc.addr() - }; - - req.data_descriptors.push((desc_addr, desc.len())); + req.data_descriptors.push(( + desc.addr().translate(access_platform, desc.len() as usize), + desc.len(), + )); desc = desc_chain .next() .ok_or(Error::DescriptorChainTooShort) @@ -285,15 +272,9 @@ impl Request { return Err(Error::DescriptorLengthTooSmall); } - req.status_addr = if let Some(access_platform) = access_platform { - GuestAddress( - access_platform - .translate(status_desc.addr().0, u64::from(status_desc.len())) - .unwrap(), - ) - } else { - status_desc.addr() - }; + req.status_addr = status_desc + .addr() + .translate(access_platform, status_desc.len() as usize); Ok(req) } diff --git a/net_util/src/ctrl_queue.rs b/net_util/src/ctrl_queue.rs index 7ccd6b9c7..647c2f39a 100644 --- a/net_util/src/ctrl_queue.rs +++ b/net_util/src/ctrl_queue.rs @@ -14,8 +14,8 @@ use virtio_bindings::bindings::virtio_net::{ VIRTIO_NET_F_GUEST_UFO, VIRTIO_NET_OK, }; use virtio_queue::Queue; -use vm_memory::{ByteValued, Bytes, GuestAddress, GuestMemoryAtomic, GuestMemoryError}; -use vm_virtio::AccessPlatform; +use vm_memory::{ByteValued, Bytes, GuestMemoryAtomic, GuestMemoryError}; +use vm_virtio::{AccessPlatform, Translatable}; #[derive(Debug)] pub enum Error { @@ -65,44 +65,22 @@ impl CtrlQueue { for mut desc_chain in queue.iter().map_err(Error::QueueIterator)? { let ctrl_desc = desc_chain.next().ok_or(Error::NoControlHeaderDescriptor)?; - let ctrl_desc_addr = if let Some(access_platform) = access_platform { - GuestAddress( - access_platform - .translate(ctrl_desc.addr().0, u64::from(ctrl_desc.len())) - .unwrap(), - ) - } else { - ctrl_desc.addr() - }; - let ctrl_hdr: ControlHeader = desc_chain .memory() - .read_obj(ctrl_desc_addr) + .read_obj( + ctrl_desc + .addr() + .translate(access_platform, ctrl_desc.len() as usize), + ) .map_err(Error::GuestMemory)?; let data_desc = desc_chain.next().ok_or(Error::NoDataDescriptor)?; - let data_desc_addr = if let Some(access_platform) = access_platform { - GuestAddress( - access_platform - .translate(data_desc.addr().0, u64::from(data_desc.len())) - .unwrap(), - ) - } else { - data_desc.addr() - }; + let data_desc_addr = data_desc + .addr() + .translate(access_platform, data_desc.len() as usize); let status_desc = desc_chain.next().ok_or(Error::NoStatusDescriptor)?; - let status_desc_addr = if let Some(access_platform) = access_platform { - GuestAddress( - access_platform - .translate(status_desc.addr().0, u64::from(status_desc.len())) - .unwrap(), - ) - } else { - status_desc.addr() - }; - let ok = match u32::from(ctrl_hdr.class) { VIRTIO_NET_CTRL_MQ => { let queue_pairs = desc_chain @@ -154,7 +132,9 @@ impl CtrlQueue { .memory() .write_obj( if ok { VIRTIO_NET_OK } else { VIRTIO_NET_ERR } as u8, - status_desc_addr, + status_desc + .addr() + .translate(access_platform, status_desc.len() as usize), ) .map_err(Error::GuestMemory)?; let len = ctrl_desc.len() + data_desc.len() + status_desc.len(); diff --git a/net_util/src/queue_pair.rs b/net_util/src/queue_pair.rs index cb904e014..b51761dab 100644 --- a/net_util/src/queue_pair.rs +++ b/net_util/src/queue_pair.rs @@ -11,8 +11,8 @@ use std::os::unix::io::{AsRawFd, RawFd}; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use virtio_queue::Queue; -use vm_memory::{Bytes, GuestAddress, GuestMemory, GuestMemoryAtomic}; -use vm_virtio::AccessPlatform; +use vm_memory::{Bytes, GuestMemory, GuestMemoryAtomic}; +use vm_virtio::{AccessPlatform, Translatable}; #[derive(Clone)] pub struct TxVirtio { @@ -60,16 +60,7 @@ impl TxVirtio { let mut iovecs = Vec::new(); while let Some(desc) = next_desc { - let desc_addr = if let Some(access_platform) = access_platform { - GuestAddress( - access_platform - .translate(desc.addr().0, u64::from(desc.len())) - .unwrap(), - ) - } else { - desc.addr() - }; - + let desc_addr = desc.addr().translate(access_platform, desc.len() as usize); if !desc.is_write_only() && desc.len() > 0 { let buf = desc_chain .memory() @@ -195,31 +186,18 @@ impl RxVirtio { .next() .ok_or(NetQueuePairError::DescriptorChainTooShort)?; - let desc_addr = if let Some(access_platform) = access_platform { - GuestAddress( - access_platform - .translate(desc.addr().0, u64::from(desc.len())) - .unwrap(), + let num_buffers_addr = desc_chain + .memory() + .checked_offset( + desc.addr().translate(access_platform, desc.len() as usize), + 10, ) - } else { - desc.addr() - }; - - let num_buffers_addr = desc_chain.memory().checked_offset(desc_addr, 10).unwrap(); + .unwrap(); let mut next_desc = Some(desc); let mut iovecs = Vec::new(); while let Some(desc) = next_desc { - let desc_addr = if let Some(access_platform) = access_platform { - GuestAddress( - access_platform - .translate(desc.addr().0, u64::from(desc.len())) - .unwrap(), - ) - } else { - desc.addr() - }; - + let desc_addr = desc.addr().translate(access_platform, desc.len() as usize); if desc.is_write_only() && desc.len() > 0 { let buf = desc_chain .memory() diff --git a/virtio-devices/src/console.rs b/virtio-devices/src/console.rs index 0a259f1e2..b45f64e7f 100644 --- a/virtio-devices/src/console.rs +++ b/virtio-devices/src/console.rs @@ -25,10 +25,10 @@ use std::sync::{Arc, Barrier, Mutex}; use versionize::{VersionMap, Versionize, VersionizeResult}; use versionize_derive::Versionize; use virtio_queue::Queue; -use vm_memory::{ByteValued, Bytes, GuestAddress, GuestMemoryAtomic}; +use vm_memory::{ByteValued, Bytes, GuestMemoryAtomic}; use vm_migration::VersionMapped; use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable}; -use vm_virtio::AccessPlatform; +use vm_virtio::{AccessPlatform, Translatable}; use vmm_sys_util::eventfd::EventFd; const QUEUE_SIZE: u16 = 256; @@ -148,20 +148,11 @@ impl ConsoleEpollHandler { let len = cmp::min(desc.len() as u32, in_buffer.len() as u32); let source_slice = in_buffer.drain(..len as usize).collect::>(); - let desc_addr = if let Some(access_platform) = &self.access_platform { - GuestAddress( - access_platform - .translate(desc.addr().0, u64::from(desc.len())) - .unwrap(), - ) - } else { + if let Err(e) = desc_chain.memory().write_slice( + &source_slice[..], desc.addr() - }; - - if let Err(e) = desc_chain - .memory() - .write_slice(&source_slice[..], desc_addr) - { + .translate(self.access_platform.as_ref(), desc.len() as usize), + ) { error!("Failed to write slice: {:?}", e); avail_iter.go_to_previous_position(); break; @@ -197,19 +188,12 @@ impl ConsoleEpollHandler { for mut desc_chain in trans_queue.iter().unwrap() { let desc = desc_chain.next().unwrap(); if let Some(ref mut out) = self.endpoint.out_file() { - let desc_addr = if let Some(access_platform) = &self.access_platform { - GuestAddress( - access_platform - .translate(desc.addr().0, u64::from(desc.len())) - .unwrap(), - ) - } else { + let _ = desc_chain.memory().write_to( desc.addr() - }; - - let _ = desc_chain - .memory() - .write_to(desc_addr, out, desc.len() as usize); + .translate(self.access_platform.as_ref(), desc.len() as usize), + out, + desc.len() as usize, + ); let _ = out.flush(); } used_desc_heads[used_count] = (desc_chain.head_index(), desc.len()); diff --git a/virtio-devices/src/pmem.rs b/virtio-devices/src/pmem.rs index 2e205c0dd..b8c598794 100644 --- a/virtio-devices/src/pmem.rs +++ b/virtio-devices/src/pmem.rs @@ -34,7 +34,7 @@ use vm_memory::{ }; use vm_migration::VersionMapped; use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable}; -use vm_virtio::AccessPlatform; +use vm_virtio::{AccessPlatform, Translatable}; use vmm_sys_util::eventfd::EventFd; const QUEUE_SIZE: u16 = 256; @@ -131,19 +131,9 @@ impl Request { return Err(Error::InvalidRequest); } - let desc_addr = if let Some(access_platform) = access_platform { - GuestAddress( - access_platform - .translate(desc.addr().0, u64::from(desc.len())) - .unwrap(), - ) - } else { - desc.addr() - }; - let request: VirtioPmemReq = desc_chain .memory() - .read_obj(desc_addr) + .read_obj(desc.addr().translate(access_platform, desc.len() as usize)) .map_err(Error::GuestMemory)?; let request_type = match request.type_ { @@ -162,19 +152,11 @@ impl Request { return Err(Error::BufferLengthTooSmall); } - let status_desc_addr = if let Some(access_platform) = access_platform { - GuestAddress( - access_platform - .translate(status_desc.addr().0, u64::from(status_desc.len())) - .unwrap(), - ) - } else { - status_desc.addr() - }; - Ok(Request { type_: request_type, - status_addr: status_desc_addr, + status_addr: status_desc + .addr() + .translate(access_platform, status_desc.len() as usize), }) } } diff --git a/virtio-devices/src/rng.rs b/virtio-devices/src/rng.rs index 2370005ce..4bb9ebb1f 100644 --- a/virtio-devices/src/rng.rs +++ b/virtio-devices/src/rng.rs @@ -22,10 +22,10 @@ use std::sync::{Arc, Barrier}; use versionize::{VersionMap, Versionize, VersionizeResult}; use versionize_derive::Versionize; use virtio_queue::Queue; -use vm_memory::{Bytes, GuestAddress, GuestMemoryAtomic}; +use vm_memory::{Bytes, GuestMemoryAtomic}; use vm_migration::VersionMapped; use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable}; -use vm_virtio::AccessPlatform; +use vm_virtio::{AccessPlatform, Translatable}; use vmm_sys_util::eventfd::EventFd; const QUEUE_SIZE: u16 = 256; @@ -56,20 +56,15 @@ impl RngEpollHandler { // Drivers can only read from the random device. if desc.is_write_only() { - let desc_addr = if let Some(access_platform) = &self.access_platform { - GuestAddress( - access_platform - .translate(desc.addr().0, u64::from(desc.len())) - .unwrap(), - ) - } else { - desc.addr() - }; - // Fill the read with data from the random device on the host. if desc_chain .memory() - .read_from(desc_addr, &mut self.random_file, desc.len() as usize) + .read_from( + desc.addr() + .translate(self.access_platform.as_ref(), desc.len() as usize), + &mut self.random_file, + desc.len() as usize, + ) .is_ok() { len = desc.len(); diff --git a/virtio-devices/src/vsock/packet.rs b/virtio-devices/src/vsock/packet.rs index 6e0b8b177..c15cc4f18 100644 --- a/virtio-devices/src/vsock/packet.rs +++ b/virtio-devices/src/vsock/packet.rs @@ -22,8 +22,8 @@ use super::defs; use super::{Result, VsockError}; use crate::{get_host_address_range, GuestMemoryMmap}; use virtio_queue::DescriptorChain; -use vm_memory::{GuestAddress, GuestMemoryLoadGuard}; -use vm_virtio::AccessPlatform; +use vm_memory::GuestMemoryLoadGuard; +use vm_virtio::{AccessPlatform, Translatable}; // The vsock packet header is defined by the C struct: // @@ -124,19 +124,13 @@ impl VsockPacket { return Err(VsockError::HdrDescTooSmall(head.len())); } - let head_addr = if let Some(access_platform) = access_platform { - GuestAddress( - access_platform - .translate(head.addr().0, u64::from(head.len())) - .unwrap(), - ) - } else { - head.addr() - }; - let mut pkt = Self { - hdr: get_host_address_range(desc_chain.memory(), head_addr, VSOCK_PKT_HDR_SIZE) - .ok_or(VsockError::GuestMemory)? as *mut u8, + hdr: get_host_address_range( + desc_chain.memory(), + head.addr().translate(access_platform, head.len() as usize), + VSOCK_PKT_HDR_SIZE, + ) + .ok_or(VsockError::GuestMemory)? as *mut u8, buf: None, buf_size: 0, }; @@ -166,20 +160,16 @@ impl VsockPacket { return Err(VsockError::BufDescTooSmall); } - let buf_desc_addr = if let Some(access_platform) = access_platform { - GuestAddress( - access_platform - .translate(buf_desc.addr().0, u64::from(buf_desc.len())) - .unwrap(), - ) - } else { - buf_desc.addr() - }; - pkt.buf_size = buf_desc.len() as usize; pkt.buf = Some( - get_host_address_range(desc_chain.memory(), buf_desc_addr, pkt.buf_size) - .ok_or(VsockError::GuestMemory)? as *mut u8, + get_host_address_range( + desc_chain.memory(), + buf_desc + .addr() + .translate(access_platform, buf_desc.len() as usize), + pkt.buf_size, + ) + .ok_or(VsockError::GuestMemory)? as *mut u8, ); Ok(pkt) @@ -214,29 +204,22 @@ impl VsockPacket { let buf_desc = desc_chain.next().ok_or(VsockError::BufDescMissing)?; let buf_size = buf_desc.len() as usize; - let (head_addr, buf_desc_addr) = if let Some(access_platform) = access_platform { - ( - GuestAddress( - access_platform - .translate(head.addr().0, u64::from(head.len())) - .unwrap(), - ), - GuestAddress( - access_platform - .translate(buf_desc.addr().0, u64::from(buf_desc.len())) - .unwrap(), - ), - ) - } else { - (head.addr(), buf_desc.addr()) - }; - Ok(Self { - hdr: get_host_address_range(desc_chain.memory(), head_addr, VSOCK_PKT_HDR_SIZE) - .ok_or(VsockError::GuestMemory)? as *mut u8, + hdr: get_host_address_range( + desc_chain.memory(), + head.addr().translate(access_platform, head.len() as usize), + VSOCK_PKT_HDR_SIZE, + ) + .ok_or(VsockError::GuestMemory)? as *mut u8, buf: Some( - get_host_address_range(desc_chain.memory(), buf_desc_addr, buf_size) - .ok_or(VsockError::GuestMemory)? as *mut u8, + get_host_address_range( + desc_chain.memory(), + buf_desc + .addr() + .translate(access_platform, buf_desc.len() as usize), + buf_size, + ) + .ok_or(VsockError::GuestMemory)? as *mut u8, ), buf_size, }) diff --git a/vm-virtio/src/lib.rs b/vm-virtio/src/lib.rs index 5cebb3e80..74f1a0c45 100644 --- a/vm-virtio/src/lib.rs +++ b/vm-virtio/src/lib.rs @@ -11,6 +11,9 @@ //! Implements virtio queues use std::fmt::{self, Debug}; +use std::sync::Arc; + +use vm_memory::GuestAddress; pub mod queue; pub use queue::*; @@ -94,3 +97,18 @@ pub trait AccessPlatform: Send + Sync + Debug { /// Provide a way to translate address ranges. fn translate(&self, base: u64, size: u64) -> std::result::Result; } + +pub trait Translatable { + #[must_use] + fn translate(&self, access_platform: Option<&Arc>, len: usize) -> Self; +} + +impl Translatable for GuestAddress { + fn translate(&self, access_platform: Option<&Arc>, len: usize) -> Self { + if let Some(access_platform) = access_platform { + GuestAddress(access_platform.translate(self.0, len as u64).unwrap()) + } else { + *self + } + } +}