virtio-devices: vsock: Handle descriptor address translation

Since we're trying to move away from the translation happening in the
virtio-queue crate, the device itself is performing the address
translation when needed.

Signed-off-by: Sebastien Boeuf <sebastien.boeuf@intel.com>
This commit is contained in:
Sebastien Boeuf 2022-01-26 17:15:50 +01:00 committed by Rob Bradford
parent 09f5b82fd7
commit e2225bb4b0
5 changed files with 70 additions and 9 deletions

View File

@ -823,6 +823,7 @@ mod tests {
.unwrap()
.next()
.unwrap(),
None,
)
.unwrap();
let conn = match conn_state {

View File

@ -47,7 +47,7 @@ use std::sync::atomic::AtomicBool;
use std::sync::{Arc, Barrier, RwLock};
use versionize::{VersionMap, Versionize, VersionizeResult};
use versionize_derive::Versionize;
use virtio_queue::Queue;
use virtio_queue::{AccessPlatform, Queue};
use vm_memory::GuestMemoryAtomic;
use vm_migration::{
Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable, VersionMapped,
@ -93,6 +93,7 @@ pub struct VsockEpollHandler<B: VsockBackend> {
pub pause_evt: EventFd,
pub interrupt_cb: Arc<dyn VirtioInterrupt>,
pub backend: Arc<RwLock<B>>,
pub access_platform: Option<Arc<dyn AccessPlatform>>,
}
impl<B> VsockEpollHandler<B>
@ -124,7 +125,10 @@ where
let mut avail_iter = self.queues[0].iter().map_err(DeviceError::QueueIterator)?;
for mut desc_chain in &mut avail_iter {
let used_len = match VsockPacket::from_rx_virtq_head(&mut desc_chain) {
let used_len = match VsockPacket::from_rx_virtq_head(
&mut desc_chain,
self.access_platform.as_ref(),
) {
Ok(mut pkt) => {
if self.backend.write().unwrap().recv_pkt(&mut pkt).is_ok() {
pkt.hdr().len() as u32 + pkt.len()
@ -169,7 +173,10 @@ where
let mut avail_iter = self.queues[1].iter().map_err(DeviceError::QueueIterator)?;
for mut desc_chain in &mut avail_iter {
let pkt = match VsockPacket::from_tx_virtq_head(&mut desc_chain) {
let pkt = match VsockPacket::from_tx_virtq_head(
&mut desc_chain,
self.access_platform.as_ref(),
) {
Ok(pkt) => pkt,
Err(e) => {
error!("vsock: error reading TX packet: {:?}", e);
@ -438,6 +445,7 @@ where
pause_evt,
interrupt_cb,
backend: self.backend.clone(),
access_platform: self.common.access_platform.clone(),
};
let paused = self.common.paused.clone();
@ -472,6 +480,10 @@ where
fn shutdown(&mut self) {
std::fs::remove_file(&self.path).ok();
}
fn set_access_platform(&mut self, access_platform: Arc<dyn AccessPlatform>) {
self.common.set_access_platform(access_platform)
}
}
impl<B> Pausable for Vsock<B>

View File

@ -328,6 +328,7 @@ mod tests {
pause_evt: EventFd::new(EFD_NONBLOCK).unwrap(),
interrupt_cb,
backend: Arc::new(RwLock::new(TestBackend::new())),
access_platform: None,
},
}
}

View File

@ -16,12 +16,13 @@
/// to temporary buffers, before passing it on to the vsock backend.
///
use byteorder::{ByteOrder, LittleEndian};
use std::sync::Arc;
use super::defs;
use super::{Result, VsockError};
use crate::{get_host_address_range, GuestMemoryMmap};
use virtio_queue::DescriptorChain;
use vm_memory::GuestMemoryLoadGuard;
use virtio_queue::{AccessPlatform, DescriptorChain};
use vm_memory::{GuestAddress, GuestMemoryLoadGuard};
// The vsock packet header is defined by the C struct:
//
@ -107,6 +108,7 @@ impl VsockPacket {
///
pub fn from_tx_virtq_head(
desc_chain: &mut DescriptorChain<GuestMemoryLoadGuard<GuestMemoryMmap>>,
access_platform: Option<&Arc<dyn AccessPlatform>>,
) -> Result<Self> {
let head = desc_chain.next().ok_or(VsockError::HdrDescMissing)?;
@ -121,8 +123,18 @@ impl VsockPacket {
return Err(VsockError::HdrDescTooSmall(head.len()));
}
let head_addr = if let Some(access_platform) = access_platform {
GuestAddress(
access_platform
.translate(head.addr().0, u64::from(head.len()))
.unwrap(),
)
} else {
head.addr()
};
let mut pkt = Self {
hdr: get_host_address_range(desc_chain.memory(), head.addr(), VSOCK_PKT_HDR_SIZE)
hdr: get_host_address_range(desc_chain.memory(), head_addr, VSOCK_PKT_HDR_SIZE)
.ok_or(VsockError::GuestMemory)? as *mut u8,
buf: None,
buf_size: 0,
@ -153,9 +165,19 @@ impl VsockPacket {
return Err(VsockError::BufDescTooSmall);
}
let buf_desc_addr = if let Some(access_platform) = access_platform {
GuestAddress(
access_platform
.translate(buf_desc.addr().0, u64::from(buf_desc.len()))
.unwrap(),
)
} else {
buf_desc.addr()
};
pkt.buf_size = buf_desc.len() as usize;
pkt.buf = Some(
get_host_address_range(desc_chain.memory(), buf_desc.addr(), pkt.buf_size)
get_host_address_range(desc_chain.memory(), buf_desc_addr, pkt.buf_size)
.ok_or(VsockError::GuestMemory)? as *mut u8,
);
@ -169,6 +191,7 @@ impl VsockPacket {
///
pub fn from_rx_virtq_head(
desc_chain: &mut DescriptorChain<GuestMemoryLoadGuard<GuestMemoryMmap>>,
access_platform: Option<&Arc<dyn AccessPlatform>>,
) -> Result<Self> {
let head = desc_chain.next().ok_or(VsockError::HdrDescMissing)?;
@ -190,11 +213,28 @@ impl VsockPacket {
let buf_desc = desc_chain.next().ok_or(VsockError::BufDescMissing)?;
let buf_size = buf_desc.len() as usize;
let (head_addr, buf_desc_addr) = if let Some(access_platform) = access_platform {
(
GuestAddress(
access_platform
.translate(head.addr().0, u64::from(head.len()))
.unwrap(),
),
GuestAddress(
access_platform
.translate(buf_desc.addr().0, u64::from(buf_desc.len()))
.unwrap(),
),
)
} else {
(head.addr(), buf_desc.addr())
};
Ok(Self {
hdr: get_host_address_range(desc_chain.memory(), head.addr(), VSOCK_PKT_HDR_SIZE)
hdr: get_host_address_range(desc_chain.memory(), head_addr, VSOCK_PKT_HDR_SIZE)
.ok_or(VsockError::GuestMemory)? as *mut u8,
buf: Some(
get_host_address_range(desc_chain.memory(), buf_desc.addr(), buf_size)
get_host_address_range(desc_chain.memory(), buf_desc_addr, buf_size)
.ok_or(VsockError::GuestMemory)? as *mut u8,
),
buf_size,
@ -380,6 +420,7 @@ mod tests {
.unwrap()
.next()
.unwrap(),
None,
) {
Err($err) => (),
Ok(_) => panic!("Packet assembly should've failed!"),
@ -410,6 +451,7 @@ mod tests {
.unwrap()
.next()
.unwrap(),
None,
)
.unwrap();
assert_eq!(pkt.hdr().len(), VSOCK_PKT_HDR_SIZE);
@ -447,6 +489,7 @@ mod tests {
.unwrap()
.next()
.unwrap(),
None,
)
.unwrap();
assert!(pkt.buf().is_none());
@ -504,6 +547,7 @@ mod tests {
.unwrap()
.next()
.unwrap(),
None,
)
.unwrap();
assert_eq!(pkt.hdr().len(), VSOCK_PKT_HDR_SIZE);
@ -560,6 +604,7 @@ mod tests {
.unwrap()
.next()
.unwrap(),
None,
)
.unwrap();
@ -650,6 +695,7 @@ mod tests {
.unwrap()
.next()
.unwrap(),
None,
)
.unwrap();

View File

@ -846,6 +846,7 @@ mod tests {
.unwrap()
.next()
.unwrap(),
None,
)
.unwrap();
let uds_path = format!("test_vsock_{}.sock", name);