vm-virtio: Define and implement Translatable trait

This new trait simplifies the address translation of a GuestAddress by
having GuestAddress implementing it.

The three crates virtio-devices, block_util and net_util have been
updated accordingly to rely on this new trait, helping with code
readability and limiting the amount of duplicated code.

Signed-off-by: Sebastien Boeuf <sebastien.boeuf@intel.com>
This commit is contained in:
Sebastien Boeuf 2022-01-26 23:44:31 +01:00 committed by Rob Bradford
parent c99d637693
commit 77df4e6773
8 changed files with 106 additions and 205 deletions

View File

@ -40,7 +40,7 @@ use vm_memory::{
bitmap::AtomicBitmap, bitmap::Bitmap, ByteValued, Bytes, GuestAddress, GuestMemory, bitmap::AtomicBitmap, bitmap::Bitmap, ByteValued, Bytes, GuestAddress, GuestMemory,
GuestMemoryError, GuestMemoryLoadGuard, GuestMemoryError, GuestMemoryLoadGuard,
}; };
use vm_virtio::AccessPlatform; use vm_virtio::{AccessPlatform, Translatable};
use vmm_sys_util::eventfd::EventFd; use vmm_sys_util::eventfd::EventFd;
type GuestMemoryMmap = vm_memory::GuestMemoryMmap<AtomicBitmap>; type GuestMemoryMmap = vm_memory::GuestMemoryMmap<AtomicBitmap>;
@ -207,15 +207,9 @@ impl Request {
return Err(Error::UnexpectedWriteOnlyDescriptor); return Err(Error::UnexpectedWriteOnlyDescriptor);
} }
let hdr_desc_addr = if let Some(access_platform) = access_platform { let hdr_desc_addr = hdr_desc
GuestAddress( .addr()
access_platform .translate(access_platform, hdr_desc.len() as usize);
.translate(hdr_desc.addr().0, u64::from(hdr_desc.len()))
.unwrap(),
)
} else {
hdr_desc.addr()
};
let mut req = Request { let mut req = Request {
request_type: request_type(desc_chain.memory(), hdr_desc_addr)?, request_type: request_type(desc_chain.memory(), hdr_desc_addr)?,
@ -254,17 +248,10 @@ impl Request {
return Err(Error::UnexpectedReadOnlyDescriptor); return Err(Error::UnexpectedReadOnlyDescriptor);
} }
let desc_addr = if let Some(access_platform) = access_platform { req.data_descriptors.push((
GuestAddress( desc.addr().translate(access_platform, desc.len() as usize),
access_platform desc.len(),
.translate(desc.addr().0, u64::from(desc.len())) ));
.unwrap(),
)
} else {
desc.addr()
};
req.data_descriptors.push((desc_addr, desc.len()));
desc = desc_chain desc = desc_chain
.next() .next()
.ok_or(Error::DescriptorChainTooShort) .ok_or(Error::DescriptorChainTooShort)
@ -285,15 +272,9 @@ impl Request {
return Err(Error::DescriptorLengthTooSmall); return Err(Error::DescriptorLengthTooSmall);
} }
req.status_addr = if let Some(access_platform) = access_platform { req.status_addr = status_desc
GuestAddress( .addr()
access_platform .translate(access_platform, status_desc.len() as usize);
.translate(status_desc.addr().0, u64::from(status_desc.len()))
.unwrap(),
)
} else {
status_desc.addr()
};
Ok(req) Ok(req)
} }

View File

@ -14,8 +14,8 @@ use virtio_bindings::bindings::virtio_net::{
VIRTIO_NET_F_GUEST_UFO, VIRTIO_NET_OK, VIRTIO_NET_F_GUEST_UFO, VIRTIO_NET_OK,
}; };
use virtio_queue::Queue; use virtio_queue::Queue;
use vm_memory::{ByteValued, Bytes, GuestAddress, GuestMemoryAtomic, GuestMemoryError}; use vm_memory::{ByteValued, Bytes, GuestMemoryAtomic, GuestMemoryError};
use vm_virtio::AccessPlatform; use vm_virtio::{AccessPlatform, Translatable};
#[derive(Debug)] #[derive(Debug)]
pub enum Error { pub enum Error {
@ -65,44 +65,22 @@ impl CtrlQueue {
for mut desc_chain in queue.iter().map_err(Error::QueueIterator)? { for mut desc_chain in queue.iter().map_err(Error::QueueIterator)? {
let ctrl_desc = desc_chain.next().ok_or(Error::NoControlHeaderDescriptor)?; let ctrl_desc = desc_chain.next().ok_or(Error::NoControlHeaderDescriptor)?;
let ctrl_desc_addr = if let Some(access_platform) = access_platform {
GuestAddress(
access_platform
.translate(ctrl_desc.addr().0, u64::from(ctrl_desc.len()))
.unwrap(),
)
} else {
ctrl_desc.addr()
};
let ctrl_hdr: ControlHeader = desc_chain let ctrl_hdr: ControlHeader = desc_chain
.memory() .memory()
.read_obj(ctrl_desc_addr) .read_obj(
ctrl_desc
.addr()
.translate(access_platform, ctrl_desc.len() as usize),
)
.map_err(Error::GuestMemory)?; .map_err(Error::GuestMemory)?;
let data_desc = desc_chain.next().ok_or(Error::NoDataDescriptor)?; let data_desc = desc_chain.next().ok_or(Error::NoDataDescriptor)?;
let data_desc_addr = if let Some(access_platform) = access_platform { let data_desc_addr = data_desc
GuestAddress( .addr()
access_platform .translate(access_platform, data_desc.len() as usize);
.translate(data_desc.addr().0, u64::from(data_desc.len()))
.unwrap(),
)
} else {
data_desc.addr()
};
let status_desc = desc_chain.next().ok_or(Error::NoStatusDescriptor)?; let status_desc = desc_chain.next().ok_or(Error::NoStatusDescriptor)?;
let status_desc_addr = if let Some(access_platform) = access_platform {
GuestAddress(
access_platform
.translate(status_desc.addr().0, u64::from(status_desc.len()))
.unwrap(),
)
} else {
status_desc.addr()
};
let ok = match u32::from(ctrl_hdr.class) { let ok = match u32::from(ctrl_hdr.class) {
VIRTIO_NET_CTRL_MQ => { VIRTIO_NET_CTRL_MQ => {
let queue_pairs = desc_chain let queue_pairs = desc_chain
@ -154,7 +132,9 @@ impl CtrlQueue {
.memory() .memory()
.write_obj( .write_obj(
if ok { VIRTIO_NET_OK } else { VIRTIO_NET_ERR } as u8, if ok { VIRTIO_NET_OK } else { VIRTIO_NET_ERR } as u8,
status_desc_addr, status_desc
.addr()
.translate(access_platform, status_desc.len() as usize),
) )
.map_err(Error::GuestMemory)?; .map_err(Error::GuestMemory)?;
let len = ctrl_desc.len() + data_desc.len() + status_desc.len(); let len = ctrl_desc.len() + data_desc.len() + status_desc.len();

View File

@ -11,8 +11,8 @@ use std::os::unix::io::{AsRawFd, RawFd};
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc; use std::sync::Arc;
use virtio_queue::Queue; use virtio_queue::Queue;
use vm_memory::{Bytes, GuestAddress, GuestMemory, GuestMemoryAtomic}; use vm_memory::{Bytes, GuestMemory, GuestMemoryAtomic};
use vm_virtio::AccessPlatform; use vm_virtio::{AccessPlatform, Translatable};
#[derive(Clone)] #[derive(Clone)]
pub struct TxVirtio { pub struct TxVirtio {
@ -60,16 +60,7 @@ impl TxVirtio {
let mut iovecs = Vec::new(); let mut iovecs = Vec::new();
while let Some(desc) = next_desc { while let Some(desc) = next_desc {
let desc_addr = if let Some(access_platform) = access_platform { let desc_addr = desc.addr().translate(access_platform, desc.len() as usize);
GuestAddress(
access_platform
.translate(desc.addr().0, u64::from(desc.len()))
.unwrap(),
)
} else {
desc.addr()
};
if !desc.is_write_only() && desc.len() > 0 { if !desc.is_write_only() && desc.len() > 0 {
let buf = desc_chain let buf = desc_chain
.memory() .memory()
@ -195,31 +186,18 @@ impl RxVirtio {
.next() .next()
.ok_or(NetQueuePairError::DescriptorChainTooShort)?; .ok_or(NetQueuePairError::DescriptorChainTooShort)?;
let desc_addr = if let Some(access_platform) = access_platform { let num_buffers_addr = desc_chain
GuestAddress( .memory()
access_platform .checked_offset(
.translate(desc.addr().0, u64::from(desc.len())) desc.addr().translate(access_platform, desc.len() as usize),
.unwrap(), 10,
) )
} else { .unwrap();
desc.addr()
};
let num_buffers_addr = desc_chain.memory().checked_offset(desc_addr, 10).unwrap();
let mut next_desc = Some(desc); let mut next_desc = Some(desc);
let mut iovecs = Vec::new(); let mut iovecs = Vec::new();
while let Some(desc) = next_desc { while let Some(desc) = next_desc {
let desc_addr = if let Some(access_platform) = access_platform { let desc_addr = desc.addr().translate(access_platform, desc.len() as usize);
GuestAddress(
access_platform
.translate(desc.addr().0, u64::from(desc.len()))
.unwrap(),
)
} else {
desc.addr()
};
if desc.is_write_only() && desc.len() > 0 { if desc.is_write_only() && desc.len() > 0 {
let buf = desc_chain let buf = desc_chain
.memory() .memory()

View File

@ -25,10 +25,10 @@ use std::sync::{Arc, Barrier, Mutex};
use versionize::{VersionMap, Versionize, VersionizeResult}; use versionize::{VersionMap, Versionize, VersionizeResult};
use versionize_derive::Versionize; use versionize_derive::Versionize;
use virtio_queue::Queue; use virtio_queue::Queue;
use vm_memory::{ByteValued, Bytes, GuestAddress, GuestMemoryAtomic}; use vm_memory::{ByteValued, Bytes, GuestMemoryAtomic};
use vm_migration::VersionMapped; use vm_migration::VersionMapped;
use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable}; use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable};
use vm_virtio::AccessPlatform; use vm_virtio::{AccessPlatform, Translatable};
use vmm_sys_util::eventfd::EventFd; use vmm_sys_util::eventfd::EventFd;
const QUEUE_SIZE: u16 = 256; const QUEUE_SIZE: u16 = 256;
@ -148,20 +148,11 @@ impl ConsoleEpollHandler {
let len = cmp::min(desc.len() as u32, in_buffer.len() as u32); let len = cmp::min(desc.len() as u32, in_buffer.len() as u32);
let source_slice = in_buffer.drain(..len as usize).collect::<Vec<u8>>(); let source_slice = in_buffer.drain(..len as usize).collect::<Vec<u8>>();
let desc_addr = if let Some(access_platform) = &self.access_platform { if let Err(e) = desc_chain.memory().write_slice(
GuestAddress( &source_slice[..],
access_platform
.translate(desc.addr().0, u64::from(desc.len()))
.unwrap(),
)
} else {
desc.addr() desc.addr()
}; .translate(self.access_platform.as_ref(), desc.len() as usize),
) {
if let Err(e) = desc_chain
.memory()
.write_slice(&source_slice[..], desc_addr)
{
error!("Failed to write slice: {:?}", e); error!("Failed to write slice: {:?}", e);
avail_iter.go_to_previous_position(); avail_iter.go_to_previous_position();
break; break;
@ -197,19 +188,12 @@ impl ConsoleEpollHandler {
for mut desc_chain in trans_queue.iter().unwrap() { for mut desc_chain in trans_queue.iter().unwrap() {
let desc = desc_chain.next().unwrap(); let desc = desc_chain.next().unwrap();
if let Some(ref mut out) = self.endpoint.out_file() { if let Some(ref mut out) = self.endpoint.out_file() {
let desc_addr = if let Some(access_platform) = &self.access_platform { let _ = desc_chain.memory().write_to(
GuestAddress(
access_platform
.translate(desc.addr().0, u64::from(desc.len()))
.unwrap(),
)
} else {
desc.addr() desc.addr()
}; .translate(self.access_platform.as_ref(), desc.len() as usize),
out,
let _ = desc_chain desc.len() as usize,
.memory() );
.write_to(desc_addr, out, desc.len() as usize);
let _ = out.flush(); let _ = out.flush();
} }
used_desc_heads[used_count] = (desc_chain.head_index(), desc.len()); used_desc_heads[used_count] = (desc_chain.head_index(), desc.len());

View File

@ -34,7 +34,7 @@ use vm_memory::{
}; };
use vm_migration::VersionMapped; use vm_migration::VersionMapped;
use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable}; use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable};
use vm_virtio::AccessPlatform; use vm_virtio::{AccessPlatform, Translatable};
use vmm_sys_util::eventfd::EventFd; use vmm_sys_util::eventfd::EventFd;
const QUEUE_SIZE: u16 = 256; const QUEUE_SIZE: u16 = 256;
@ -131,19 +131,9 @@ impl Request {
return Err(Error::InvalidRequest); return Err(Error::InvalidRequest);
} }
let desc_addr = if let Some(access_platform) = access_platform {
GuestAddress(
access_platform
.translate(desc.addr().0, u64::from(desc.len()))
.unwrap(),
)
} else {
desc.addr()
};
let request: VirtioPmemReq = desc_chain let request: VirtioPmemReq = desc_chain
.memory() .memory()
.read_obj(desc_addr) .read_obj(desc.addr().translate(access_platform, desc.len() as usize))
.map_err(Error::GuestMemory)?; .map_err(Error::GuestMemory)?;
let request_type = match request.type_ { let request_type = match request.type_ {
@ -162,19 +152,11 @@ impl Request {
return Err(Error::BufferLengthTooSmall); return Err(Error::BufferLengthTooSmall);
} }
let status_desc_addr = if let Some(access_platform) = access_platform {
GuestAddress(
access_platform
.translate(status_desc.addr().0, u64::from(status_desc.len()))
.unwrap(),
)
} else {
status_desc.addr()
};
Ok(Request { Ok(Request {
type_: request_type, type_: request_type,
status_addr: status_desc_addr, status_addr: status_desc
.addr()
.translate(access_platform, status_desc.len() as usize),
}) })
} }
} }

View File

@ -22,10 +22,10 @@ use std::sync::{Arc, Barrier};
use versionize::{VersionMap, Versionize, VersionizeResult}; use versionize::{VersionMap, Versionize, VersionizeResult};
use versionize_derive::Versionize; use versionize_derive::Versionize;
use virtio_queue::Queue; use virtio_queue::Queue;
use vm_memory::{Bytes, GuestAddress, GuestMemoryAtomic}; use vm_memory::{Bytes, GuestMemoryAtomic};
use vm_migration::VersionMapped; use vm_migration::VersionMapped;
use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable}; use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable};
use vm_virtio::AccessPlatform; use vm_virtio::{AccessPlatform, Translatable};
use vmm_sys_util::eventfd::EventFd; use vmm_sys_util::eventfd::EventFd;
const QUEUE_SIZE: u16 = 256; const QUEUE_SIZE: u16 = 256;
@ -56,20 +56,15 @@ impl RngEpollHandler {
// Drivers can only read from the random device. // Drivers can only read from the random device.
if desc.is_write_only() { if desc.is_write_only() {
let desc_addr = if let Some(access_platform) = &self.access_platform {
GuestAddress(
access_platform
.translate(desc.addr().0, u64::from(desc.len()))
.unwrap(),
)
} else {
desc.addr()
};
// Fill the read with data from the random device on the host. // Fill the read with data from the random device on the host.
if desc_chain if desc_chain
.memory() .memory()
.read_from(desc_addr, &mut self.random_file, desc.len() as usize) .read_from(
desc.addr()
.translate(self.access_platform.as_ref(), desc.len() as usize),
&mut self.random_file,
desc.len() as usize,
)
.is_ok() .is_ok()
{ {
len = desc.len(); len = desc.len();

View File

@ -22,8 +22,8 @@ use super::defs;
use super::{Result, VsockError}; use super::{Result, VsockError};
use crate::{get_host_address_range, GuestMemoryMmap}; use crate::{get_host_address_range, GuestMemoryMmap};
use virtio_queue::DescriptorChain; use virtio_queue::DescriptorChain;
use vm_memory::{GuestAddress, GuestMemoryLoadGuard}; use vm_memory::GuestMemoryLoadGuard;
use vm_virtio::AccessPlatform; use vm_virtio::{AccessPlatform, Translatable};
// The vsock packet header is defined by the C struct: // The vsock packet header is defined by the C struct:
// //
@ -124,19 +124,13 @@ impl VsockPacket {
return Err(VsockError::HdrDescTooSmall(head.len())); return Err(VsockError::HdrDescTooSmall(head.len()));
} }
let head_addr = if let Some(access_platform) = access_platform {
GuestAddress(
access_platform
.translate(head.addr().0, u64::from(head.len()))
.unwrap(),
)
} else {
head.addr()
};
let mut pkt = Self { let mut pkt = Self {
hdr: get_host_address_range(desc_chain.memory(), head_addr, VSOCK_PKT_HDR_SIZE) hdr: get_host_address_range(
.ok_or(VsockError::GuestMemory)? as *mut u8, desc_chain.memory(),
head.addr().translate(access_platform, head.len() as usize),
VSOCK_PKT_HDR_SIZE,
)
.ok_or(VsockError::GuestMemory)? as *mut u8,
buf: None, buf: None,
buf_size: 0, buf_size: 0,
}; };
@ -166,20 +160,16 @@ impl VsockPacket {
return Err(VsockError::BufDescTooSmall); return Err(VsockError::BufDescTooSmall);
} }
let buf_desc_addr = if let Some(access_platform) = access_platform {
GuestAddress(
access_platform
.translate(buf_desc.addr().0, u64::from(buf_desc.len()))
.unwrap(),
)
} else {
buf_desc.addr()
};
pkt.buf_size = buf_desc.len() as usize; pkt.buf_size = buf_desc.len() as usize;
pkt.buf = Some( pkt.buf = Some(
get_host_address_range(desc_chain.memory(), buf_desc_addr, pkt.buf_size) get_host_address_range(
.ok_or(VsockError::GuestMemory)? as *mut u8, desc_chain.memory(),
buf_desc
.addr()
.translate(access_platform, buf_desc.len() as usize),
pkt.buf_size,
)
.ok_or(VsockError::GuestMemory)? as *mut u8,
); );
Ok(pkt) Ok(pkt)
@ -214,29 +204,22 @@ impl VsockPacket {
let buf_desc = desc_chain.next().ok_or(VsockError::BufDescMissing)?; let buf_desc = desc_chain.next().ok_or(VsockError::BufDescMissing)?;
let buf_size = buf_desc.len() as usize; let buf_size = buf_desc.len() as usize;
let (head_addr, buf_desc_addr) = if let Some(access_platform) = access_platform {
(
GuestAddress(
access_platform
.translate(head.addr().0, u64::from(head.len()))
.unwrap(),
),
GuestAddress(
access_platform
.translate(buf_desc.addr().0, u64::from(buf_desc.len()))
.unwrap(),
),
)
} else {
(head.addr(), buf_desc.addr())
};
Ok(Self { Ok(Self {
hdr: get_host_address_range(desc_chain.memory(), head_addr, VSOCK_PKT_HDR_SIZE) hdr: get_host_address_range(
.ok_or(VsockError::GuestMemory)? as *mut u8, desc_chain.memory(),
head.addr().translate(access_platform, head.len() as usize),
VSOCK_PKT_HDR_SIZE,
)
.ok_or(VsockError::GuestMemory)? as *mut u8,
buf: Some( buf: Some(
get_host_address_range(desc_chain.memory(), buf_desc_addr, buf_size) get_host_address_range(
.ok_or(VsockError::GuestMemory)? as *mut u8, desc_chain.memory(),
buf_desc
.addr()
.translate(access_platform, buf_desc.len() as usize),
buf_size,
)
.ok_or(VsockError::GuestMemory)? as *mut u8,
), ),
buf_size, buf_size,
}) })

View File

@ -11,6 +11,9 @@
//! Implements virtio queues //! Implements virtio queues
use std::fmt::{self, Debug}; use std::fmt::{self, Debug};
use std::sync::Arc;
use vm_memory::GuestAddress;
pub mod queue; pub mod queue;
pub use queue::*; pub use queue::*;
@ -94,3 +97,18 @@ pub trait AccessPlatform: Send + Sync + Debug {
/// Provide a way to translate address ranges. /// Provide a way to translate address ranges.
fn translate(&self, base: u64, size: u64) -> std::result::Result<u64, std::io::Error>; fn translate(&self, base: u64, size: u64) -> std::result::Result<u64, std::io::Error>;
} }
pub trait Translatable {
#[must_use]
fn translate(&self, access_platform: Option<&Arc<dyn AccessPlatform>>, len: usize) -> Self;
}
impl Translatable for GuestAddress {
fn translate(&self, access_platform: Option<&Arc<dyn AccessPlatform>>, len: usize) -> Self {
if let Some(access_platform) = access_platform {
GuestAddress(access_platform.translate(self.0, len as u64).unwrap())
} else {
*self
}
}
}