diff --git a/block_util/src/lib.rs b/block_util/src/lib.rs index bfac7287d..273cc48df 100644 --- a/block_util/src/lib.rs +++ b/block_util/src/lib.rs @@ -30,11 +30,12 @@ use std::io::{self, IoSlice, IoSliceMut, Read, Seek, SeekFrom, Write}; use std::os::linux::fs::MetadataExt; use std::path::Path; use std::result; +use std::sync::Arc; use std::sync::MutexGuard; use versionize::{VersionMap, Versionize, VersionizeResult}; use versionize_derive::Versionize; use virtio_bindings::bindings::virtio_blk::*; -use virtio_queue::DescriptorChain; +use virtio_queue::{AccessPlatform, DescriptorChain}; use vm_memory::{ bitmap::AtomicBitmap, bitmap::Bitmap, ByteValued, Bytes, GuestAddress, GuestMemory, GuestMemoryError, GuestMemoryLoadGuard, @@ -190,6 +191,7 @@ pub struct Request { impl Request { pub fn parse( desc_chain: &mut DescriptorChain>, + access_platform: Option<&Arc>, ) -> result::Result { let hdr_desc = desc_chain .next() @@ -204,9 +206,19 @@ impl Request { return Err(Error::UnexpectedWriteOnlyDescriptor); } + let hdr_desc_addr = if let Some(access_platform) = access_platform { + GuestAddress( + access_platform + .translate(hdr_desc.addr().0, u64::from(hdr_desc.len())) + .unwrap(), + ) + } else { + hdr_desc.addr() + }; + let mut req = Request { - request_type: request_type(desc_chain.memory(), hdr_desc.addr())?, - sector: sector(desc_chain.memory(), hdr_desc.addr())?, + request_type: request_type(desc_chain.memory(), hdr_desc_addr)?, + sector: sector(desc_chain.memory(), hdr_desc_addr)?, data_descriptors: Vec::new(), status_addr: GuestAddress(0), writeback: true, @@ -240,7 +252,18 @@ impl Request { if !desc.is_write_only() && req.request_type == RequestType::GetDeviceId { return Err(Error::UnexpectedReadOnlyDescriptor); } - req.data_descriptors.push((desc.addr(), desc.len())); + + let desc_addr = if let Some(access_platform) = access_platform { + GuestAddress( + access_platform + .translate(desc.addr().0, u64::from(desc.len())) + .unwrap(), + ) + } else { + desc.addr() + }; + + req.data_descriptors.push((desc_addr, desc.len())); desc = desc_chain .next() .ok_or(Error::DescriptorChainTooShort) @@ -261,7 +284,15 @@ impl Request { return Err(Error::DescriptorLengthTooSmall); } - req.status_addr = status_desc.addr(); + req.status_addr = if let Some(access_platform) = access_platform { + GuestAddress( + access_platform + .translate(status_desc.addr().0, u64::from(status_desc.len())) + .unwrap(), + ) + } else { + status_desc.addr() + }; Ok(req) } diff --git a/vhost_user_block/src/lib.rs b/vhost_user_block/src/lib.rs index d821d22da..eb23893b9 100644 --- a/vhost_user_block/src/lib.rs +++ b/vhost_user_block/src/lib.rs @@ -117,7 +117,7 @@ impl VhostUserBlkThread { while let Some(mut desc_chain) = vring.mut_queue().iter().unwrap().next() { debug!("got an element in the queue"); let len; - match Request::parse(&mut desc_chain) { + match Request::parse(&mut desc_chain, None) { Ok(mut request) => { debug!("element is a valid request"); request.set_writeback(self.writeback.load(Ordering::Acquire)); diff --git a/virtio-devices/src/block.rs b/virtio-devices/src/block.rs index 9d7de4f62..e5c17a6ff 100644 --- a/virtio-devices/src/block.rs +++ b/virtio-devices/src/block.rs @@ -35,7 +35,7 @@ use std::{collections::HashMap, convert::TryInto}; use versionize::{VersionMap, Versionize, VersionizeResult}; use versionize_derive::Versionize; use virtio_bindings::bindings::virtio_blk::*; -use virtio_queue::Queue; +use virtio_queue::{AccessPlatform, Queue}; use vm_memory::{ByteValued, Bytes, GuestAddressSpace, GuestMemoryAtomic}; use vm_migration::VersionMapped; use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable}; @@ -96,6 +96,7 @@ struct BlockEpollHandler { queue_evt: EventFd, request_list: HashMap, rate_limiter: Option, + access_platform: Option>, } impl BlockEpollHandler { @@ -107,7 +108,8 @@ impl BlockEpollHandler { let mut avail_iter = queue.iter().map_err(Error::QueueIterator)?; for mut desc_chain in &mut avail_iter { - let mut request = Request::parse(&mut desc_chain).map_err(Error::RequestParsing)?; + let mut request = Request::parse(&mut desc_chain, self.access_platform.as_ref()) + .map_err(Error::RequestParsing)?; if let Some(rate_limiter) = &mut self.rate_limiter { // If limiter.consume() fails it means there is no more TokenType::Ops @@ -626,6 +628,7 @@ impl VirtioDevice for Block { queue_evt, request_list: HashMap::with_capacity(queue_size.into()), rate_limiter, + access_platform: self.common.access_platform.clone(), }; let paused = self.common.paused.clone(); @@ -679,6 +682,10 @@ impl VirtioDevice for Block { Some(counters) } + + fn set_access_platform(&mut self, access_platform: Arc) { + self.common.set_access_platform(access_platform) + } } impl Pausable for Block {