vm-virtio: Define and implement Translatable trait

This new trait simplifies the address translation of a GuestAddress by
having GuestAddress implementing it.

The three crates virtio-devices, block_util and net_util have been
updated accordingly to rely on this new trait, helping with code
readability and limiting the amount of duplicated code.

Signed-off-by: Sebastien Boeuf <sebastien.boeuf@intel.com>
This commit is contained in:
Sebastien Boeuf 2022-01-26 23:44:31 +01:00 committed by Rob Bradford
parent c99d637693
commit 77df4e6773
8 changed files with 106 additions and 205 deletions

View File

@ -40,7 +40,7 @@ use vm_memory::{
bitmap::AtomicBitmap, bitmap::Bitmap, ByteValued, Bytes, GuestAddress, GuestMemory,
GuestMemoryError, GuestMemoryLoadGuard,
};
use vm_virtio::AccessPlatform;
use vm_virtio::{AccessPlatform, Translatable};
use vmm_sys_util::eventfd::EventFd;
type GuestMemoryMmap = vm_memory::GuestMemoryMmap<AtomicBitmap>;
@ -207,15 +207,9 @@ impl Request {
return Err(Error::UnexpectedWriteOnlyDescriptor);
}
let hdr_desc_addr = if let Some(access_platform) = access_platform {
GuestAddress(
access_platform
.translate(hdr_desc.addr().0, u64::from(hdr_desc.len()))
.unwrap(),
)
} else {
hdr_desc.addr()
};
let hdr_desc_addr = hdr_desc
.addr()
.translate(access_platform, hdr_desc.len() as usize);
let mut req = Request {
request_type: request_type(desc_chain.memory(), hdr_desc_addr)?,
@ -254,17 +248,10 @@ impl Request {
return Err(Error::UnexpectedReadOnlyDescriptor);
}
let desc_addr = if let Some(access_platform) = access_platform {
GuestAddress(
access_platform
.translate(desc.addr().0, u64::from(desc.len()))
.unwrap(),
)
} else {
desc.addr()
};
req.data_descriptors.push((desc_addr, desc.len()));
req.data_descriptors.push((
desc.addr().translate(access_platform, desc.len() as usize),
desc.len(),
));
desc = desc_chain
.next()
.ok_or(Error::DescriptorChainTooShort)
@ -285,15 +272,9 @@ impl Request {
return Err(Error::DescriptorLengthTooSmall);
}
req.status_addr = if let Some(access_platform) = access_platform {
GuestAddress(
access_platform
.translate(status_desc.addr().0, u64::from(status_desc.len()))
.unwrap(),
)
} else {
status_desc.addr()
};
req.status_addr = status_desc
.addr()
.translate(access_platform, status_desc.len() as usize);
Ok(req)
}

View File

@ -14,8 +14,8 @@ use virtio_bindings::bindings::virtio_net::{
VIRTIO_NET_F_GUEST_UFO, VIRTIO_NET_OK,
};
use virtio_queue::Queue;
use vm_memory::{ByteValued, Bytes, GuestAddress, GuestMemoryAtomic, GuestMemoryError};
use vm_virtio::AccessPlatform;
use vm_memory::{ByteValued, Bytes, GuestMemoryAtomic, GuestMemoryError};
use vm_virtio::{AccessPlatform, Translatable};
#[derive(Debug)]
pub enum Error {
@ -65,44 +65,22 @@ impl CtrlQueue {
for mut desc_chain in queue.iter().map_err(Error::QueueIterator)? {
let ctrl_desc = desc_chain.next().ok_or(Error::NoControlHeaderDescriptor)?;
let ctrl_desc_addr = if let Some(access_platform) = access_platform {
GuestAddress(
access_platform
.translate(ctrl_desc.addr().0, u64::from(ctrl_desc.len()))
.unwrap(),
)
} else {
ctrl_desc.addr()
};
let ctrl_hdr: ControlHeader = desc_chain
.memory()
.read_obj(ctrl_desc_addr)
.read_obj(
ctrl_desc
.addr()
.translate(access_platform, ctrl_desc.len() as usize),
)
.map_err(Error::GuestMemory)?;
let data_desc = desc_chain.next().ok_or(Error::NoDataDescriptor)?;
let data_desc_addr = if let Some(access_platform) = access_platform {
GuestAddress(
access_platform
.translate(data_desc.addr().0, u64::from(data_desc.len()))
.unwrap(),
)
} else {
data_desc.addr()
};
let data_desc_addr = data_desc
.addr()
.translate(access_platform, data_desc.len() as usize);
let status_desc = desc_chain.next().ok_or(Error::NoStatusDescriptor)?;
let status_desc_addr = if let Some(access_platform) = access_platform {
GuestAddress(
access_platform
.translate(status_desc.addr().0, u64::from(status_desc.len()))
.unwrap(),
)
} else {
status_desc.addr()
};
let ok = match u32::from(ctrl_hdr.class) {
VIRTIO_NET_CTRL_MQ => {
let queue_pairs = desc_chain
@ -154,7 +132,9 @@ impl CtrlQueue {
.memory()
.write_obj(
if ok { VIRTIO_NET_OK } else { VIRTIO_NET_ERR } as u8,
status_desc_addr,
status_desc
.addr()
.translate(access_platform, status_desc.len() as usize),
)
.map_err(Error::GuestMemory)?;
let len = ctrl_desc.len() + data_desc.len() + status_desc.len();

View File

@ -11,8 +11,8 @@ use std::os::unix::io::{AsRawFd, RawFd};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use virtio_queue::Queue;
use vm_memory::{Bytes, GuestAddress, GuestMemory, GuestMemoryAtomic};
use vm_virtio::AccessPlatform;
use vm_memory::{Bytes, GuestMemory, GuestMemoryAtomic};
use vm_virtio::{AccessPlatform, Translatable};
#[derive(Clone)]
pub struct TxVirtio {
@ -60,16 +60,7 @@ impl TxVirtio {
let mut iovecs = Vec::new();
while let Some(desc) = next_desc {
let desc_addr = if let Some(access_platform) = access_platform {
GuestAddress(
access_platform
.translate(desc.addr().0, u64::from(desc.len()))
.unwrap(),
)
} else {
desc.addr()
};
let desc_addr = desc.addr().translate(access_platform, desc.len() as usize);
if !desc.is_write_only() && desc.len() > 0 {
let buf = desc_chain
.memory()
@ -195,31 +186,18 @@ impl RxVirtio {
.next()
.ok_or(NetQueuePairError::DescriptorChainTooShort)?;
let desc_addr = if let Some(access_platform) = access_platform {
GuestAddress(
access_platform
.translate(desc.addr().0, u64::from(desc.len()))
.unwrap(),
let num_buffers_addr = desc_chain
.memory()
.checked_offset(
desc.addr().translate(access_platform, desc.len() as usize),
10,
)
} else {
desc.addr()
};
let num_buffers_addr = desc_chain.memory().checked_offset(desc_addr, 10).unwrap();
.unwrap();
let mut next_desc = Some(desc);
let mut iovecs = Vec::new();
while let Some(desc) = next_desc {
let desc_addr = if let Some(access_platform) = access_platform {
GuestAddress(
access_platform
.translate(desc.addr().0, u64::from(desc.len()))
.unwrap(),
)
} else {
desc.addr()
};
let desc_addr = desc.addr().translate(access_platform, desc.len() as usize);
if desc.is_write_only() && desc.len() > 0 {
let buf = desc_chain
.memory()

View File

@ -25,10 +25,10 @@ use std::sync::{Arc, Barrier, Mutex};
use versionize::{VersionMap, Versionize, VersionizeResult};
use versionize_derive::Versionize;
use virtio_queue::Queue;
use vm_memory::{ByteValued, Bytes, GuestAddress, GuestMemoryAtomic};
use vm_memory::{ByteValued, Bytes, GuestMemoryAtomic};
use vm_migration::VersionMapped;
use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable};
use vm_virtio::AccessPlatform;
use vm_virtio::{AccessPlatform, Translatable};
use vmm_sys_util::eventfd::EventFd;
const QUEUE_SIZE: u16 = 256;
@ -148,20 +148,11 @@ impl ConsoleEpollHandler {
let len = cmp::min(desc.len() as u32, in_buffer.len() as u32);
let source_slice = in_buffer.drain(..len as usize).collect::<Vec<u8>>();
let desc_addr = if let Some(access_platform) = &self.access_platform {
GuestAddress(
access_platform
.translate(desc.addr().0, u64::from(desc.len()))
.unwrap(),
)
} else {
if let Err(e) = desc_chain.memory().write_slice(
&source_slice[..],
desc.addr()
};
if let Err(e) = desc_chain
.memory()
.write_slice(&source_slice[..], desc_addr)
{
.translate(self.access_platform.as_ref(), desc.len() as usize),
) {
error!("Failed to write slice: {:?}", e);
avail_iter.go_to_previous_position();
break;
@ -197,19 +188,12 @@ impl ConsoleEpollHandler {
for mut desc_chain in trans_queue.iter().unwrap() {
let desc = desc_chain.next().unwrap();
if let Some(ref mut out) = self.endpoint.out_file() {
let desc_addr = if let Some(access_platform) = &self.access_platform {
GuestAddress(
access_platform
.translate(desc.addr().0, u64::from(desc.len()))
.unwrap(),
)
} else {
let _ = desc_chain.memory().write_to(
desc.addr()
};
let _ = desc_chain
.memory()
.write_to(desc_addr, out, desc.len() as usize);
.translate(self.access_platform.as_ref(), desc.len() as usize),
out,
desc.len() as usize,
);
let _ = out.flush();
}
used_desc_heads[used_count] = (desc_chain.head_index(), desc.len());

View File

@ -34,7 +34,7 @@ use vm_memory::{
};
use vm_migration::VersionMapped;
use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable};
use vm_virtio::AccessPlatform;
use vm_virtio::{AccessPlatform, Translatable};
use vmm_sys_util::eventfd::EventFd;
const QUEUE_SIZE: u16 = 256;
@ -131,19 +131,9 @@ impl Request {
return Err(Error::InvalidRequest);
}
let desc_addr = if let Some(access_platform) = access_platform {
GuestAddress(
access_platform
.translate(desc.addr().0, u64::from(desc.len()))
.unwrap(),
)
} else {
desc.addr()
};
let request: VirtioPmemReq = desc_chain
.memory()
.read_obj(desc_addr)
.read_obj(desc.addr().translate(access_platform, desc.len() as usize))
.map_err(Error::GuestMemory)?;
let request_type = match request.type_ {
@ -162,19 +152,11 @@ impl Request {
return Err(Error::BufferLengthTooSmall);
}
let status_desc_addr = if let Some(access_platform) = access_platform {
GuestAddress(
access_platform
.translate(status_desc.addr().0, u64::from(status_desc.len()))
.unwrap(),
)
} else {
status_desc.addr()
};
Ok(Request {
type_: request_type,
status_addr: status_desc_addr,
status_addr: status_desc
.addr()
.translate(access_platform, status_desc.len() as usize),
})
}
}

View File

@ -22,10 +22,10 @@ use std::sync::{Arc, Barrier};
use versionize::{VersionMap, Versionize, VersionizeResult};
use versionize_derive::Versionize;
use virtio_queue::Queue;
use vm_memory::{Bytes, GuestAddress, GuestMemoryAtomic};
use vm_memory::{Bytes, GuestMemoryAtomic};
use vm_migration::VersionMapped;
use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable};
use vm_virtio::AccessPlatform;
use vm_virtio::{AccessPlatform, Translatable};
use vmm_sys_util::eventfd::EventFd;
const QUEUE_SIZE: u16 = 256;
@ -56,20 +56,15 @@ impl RngEpollHandler {
// Drivers can only read from the random device.
if desc.is_write_only() {
let desc_addr = if let Some(access_platform) = &self.access_platform {
GuestAddress(
access_platform
.translate(desc.addr().0, u64::from(desc.len()))
.unwrap(),
)
} else {
desc.addr()
};
// Fill the read with data from the random device on the host.
if desc_chain
.memory()
.read_from(desc_addr, &mut self.random_file, desc.len() as usize)
.read_from(
desc.addr()
.translate(self.access_platform.as_ref(), desc.len() as usize),
&mut self.random_file,
desc.len() as usize,
)
.is_ok()
{
len = desc.len();

View File

@ -22,8 +22,8 @@ use super::defs;
use super::{Result, VsockError};
use crate::{get_host_address_range, GuestMemoryMmap};
use virtio_queue::DescriptorChain;
use vm_memory::{GuestAddress, GuestMemoryLoadGuard};
use vm_virtio::AccessPlatform;
use vm_memory::GuestMemoryLoadGuard;
use vm_virtio::{AccessPlatform, Translatable};
// The vsock packet header is defined by the C struct:
//
@ -124,19 +124,13 @@ impl VsockPacket {
return Err(VsockError::HdrDescTooSmall(head.len()));
}
let head_addr = if let Some(access_platform) = access_platform {
GuestAddress(
access_platform
.translate(head.addr().0, u64::from(head.len()))
.unwrap(),
)
} else {
head.addr()
};
let mut pkt = Self {
hdr: get_host_address_range(desc_chain.memory(), head_addr, VSOCK_PKT_HDR_SIZE)
.ok_or(VsockError::GuestMemory)? as *mut u8,
hdr: get_host_address_range(
desc_chain.memory(),
head.addr().translate(access_platform, head.len() as usize),
VSOCK_PKT_HDR_SIZE,
)
.ok_or(VsockError::GuestMemory)? as *mut u8,
buf: None,
buf_size: 0,
};
@ -166,20 +160,16 @@ impl VsockPacket {
return Err(VsockError::BufDescTooSmall);
}
let buf_desc_addr = if let Some(access_platform) = access_platform {
GuestAddress(
access_platform
.translate(buf_desc.addr().0, u64::from(buf_desc.len()))
.unwrap(),
)
} else {
buf_desc.addr()
};
pkt.buf_size = buf_desc.len() as usize;
pkt.buf = Some(
get_host_address_range(desc_chain.memory(), buf_desc_addr, pkt.buf_size)
.ok_or(VsockError::GuestMemory)? as *mut u8,
get_host_address_range(
desc_chain.memory(),
buf_desc
.addr()
.translate(access_platform, buf_desc.len() as usize),
pkt.buf_size,
)
.ok_or(VsockError::GuestMemory)? as *mut u8,
);
Ok(pkt)
@ -214,29 +204,22 @@ impl VsockPacket {
let buf_desc = desc_chain.next().ok_or(VsockError::BufDescMissing)?;
let buf_size = buf_desc.len() as usize;
let (head_addr, buf_desc_addr) = if let Some(access_platform) = access_platform {
(
GuestAddress(
access_platform
.translate(head.addr().0, u64::from(head.len()))
.unwrap(),
),
GuestAddress(
access_platform
.translate(buf_desc.addr().0, u64::from(buf_desc.len()))
.unwrap(),
),
)
} else {
(head.addr(), buf_desc.addr())
};
Ok(Self {
hdr: get_host_address_range(desc_chain.memory(), head_addr, VSOCK_PKT_HDR_SIZE)
.ok_or(VsockError::GuestMemory)? as *mut u8,
hdr: get_host_address_range(
desc_chain.memory(),
head.addr().translate(access_platform, head.len() as usize),
VSOCK_PKT_HDR_SIZE,
)
.ok_or(VsockError::GuestMemory)? as *mut u8,
buf: Some(
get_host_address_range(desc_chain.memory(), buf_desc_addr, buf_size)
.ok_or(VsockError::GuestMemory)? as *mut u8,
get_host_address_range(
desc_chain.memory(),
buf_desc
.addr()
.translate(access_platform, buf_desc.len() as usize),
buf_size,
)
.ok_or(VsockError::GuestMemory)? as *mut u8,
),
buf_size,
})

View File

@ -11,6 +11,9 @@
//! Implements virtio queues
use std::fmt::{self, Debug};
use std::sync::Arc;
use vm_memory::GuestAddress;
pub mod queue;
pub use queue::*;
@ -94,3 +97,18 @@ pub trait AccessPlatform: Send + Sync + Debug {
/// Provide a way to translate address ranges.
fn translate(&self, base: u64, size: u64) -> std::result::Result<u64, std::io::Error>;
}
pub trait Translatable {
#[must_use]
fn translate(&self, access_platform: Option<&Arc<dyn AccessPlatform>>, len: usize) -> Self;
}
impl Translatable for GuestAddress {
fn translate(&self, access_platform: Option<&Arc<dyn AccessPlatform>>, len: usize) -> Self {
if let Some(access_platform) = access_platform {
GuestAddress(access_platform.translate(self.0, len as u64).unwrap())
} else {
*self
}
}
}