diff --git a/virtio-devices/src/pmem.rs b/virtio-devices/src/pmem.rs index dab7a2ddb..1227e93d2 100644 --- a/virtio-devices/src/pmem.rs +++ b/virtio-devices/src/pmem.rs @@ -27,7 +27,7 @@ use std::sync::atomic::AtomicBool; use std::sync::{Arc, Barrier}; use versionize::{VersionMap, Versionize, VersionizeResult}; use versionize_derive::Versionize; -use virtio_queue::{DescriptorChain, Queue}; +use virtio_queue::{AccessPlatform, DescriptorChain, Queue}; use vm_memory::{ Address, ByteValued, Bytes, GuestAddress, GuestMemoryAtomic, GuestMemoryError, GuestMemoryLoadGuard, @@ -118,6 +118,7 @@ struct Request { impl Request { fn parse( desc_chain: &mut DescriptorChain>, + access_platform: Option<&Arc>, ) -> result::Result { let desc = desc_chain.next().ok_or(Error::DescriptorChainTooShort)?; // The descriptor contains the request type which MUST be readable. @@ -129,9 +130,19 @@ impl Request { return Err(Error::InvalidRequest); } + let desc_addr = if let Some(access_platform) = access_platform { + GuestAddress( + access_platform + .translate(desc.addr().0, u64::from(desc.len())) + .unwrap(), + ) + } else { + desc.addr() + }; + let request: VirtioPmemReq = desc_chain .memory() - .read_obj(desc.addr()) + .read_obj(desc_addr) .map_err(Error::GuestMemory)?; let request_type = match request.type_ { @@ -150,9 +161,19 @@ impl Request { return Err(Error::BufferLengthTooSmall); } + let status_desc_addr = if let Some(access_platform) = access_platform { + GuestAddress( + access_platform + .translate(status_desc.addr().0, u64::from(status_desc.len())) + .unwrap(), + ) + } else { + status_desc.addr() + }; + Ok(Request { type_: request_type, - status_addr: status_desc.addr(), + status_addr: status_desc_addr, }) } } @@ -164,6 +185,7 @@ struct PmemEpollHandler { queue_evt: EventFd, kill_evt: EventFd, pause_evt: EventFd, + access_platform: Option>, } impl PmemEpollHandler { @@ -171,7 +193,7 @@ impl PmemEpollHandler { let mut used_desc_heads = [(0, 0); QUEUE_SIZE as usize]; let mut used_count = 0; for mut desc_chain in self.queue.iter().unwrap() { - let len = match Request::parse(&mut desc_chain) { + let len = match Request::parse(&mut desc_chain, self.access_platform.as_ref()) { Ok(ref req) if (req.type_ == RequestType::Flush) => { let status_code = match self.disk.sync_all() { Ok(()) => VIRTIO_PMEM_RESP_TYPE_OK, @@ -388,6 +410,7 @@ impl VirtioDevice for Pmem { queue_evt: queue_evts.remove(0), kill_evt, pause_evt, + access_platform: self.common.access_platform.clone(), }; let paused = self.common.paused.clone(); @@ -424,6 +447,10 @@ impl VirtioDevice for Pmem { fn userspace_mappings(&self) -> Vec { vec![self.mapping.clone()] } + + fn set_access_platform(&mut self, access_platform: Arc) { + self.common.set_access_platform(access_platform) + } } impl Pausable for Pmem {