diff --git a/vm-virtio/src/device.rs b/vm-virtio/src/device.rs
index 8b3b6042a..552f3cab0 100644
--- a/vm-virtio/src/device.rs
+++ b/vm-virtio/src/device.rs
@@ -22,6 +22,9 @@ pub type VirtioInterrupt = Box<
         + Sync,
 >;
 
+pub type VirtioIommuRemapping =
+    Box<dyn Fn(u64) -> std::result::Result<u64, std::io::Error> + Send + Sync>;
+
 #[derive(Clone)]
 pub struct VirtioSharedMemory {
     pub offset: u64,
@@ -83,4 +86,18 @@ pub trait VirtioDevice: Send {
     fn get_shm_regions(&self) -> Option<VirtioSharedMemoryList> {
         None
     }
+
+    fn iommu_translate(&self, addr: u64) -> u64 {
+        addr
+    }
+}
+
+/// Trait providing address translation the same way a physical DMA remapping
+/// table would provide translation between an IOVA and a physical address.
+/// The goal of this trait is to be used by virtio devices to perform address
+/// translation before they try to access the guest physical addresses.
+/// The implementation itself should be provided by the code emulating the
+/// IOMMU for the guest.
+pub trait DmaRemapping: Send + Sync {
+    fn translate(&self, id: u32, addr: u64) -> std::result::Result<u64, std::io::Error>;
 }
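The type alias and the trait above are two halves of one contract: the code emulating the IOMMU implements `DmaRemapping`, and each virtio endpoint is handed a boxed callback that closes over it. Here is a minimal sketch of how the two compose; `SimpleIommu`, its exact-match `BTreeMap` lookup, and the `mapping_cb` helper are hypothetical illustrations, not part of this patch:

```rust
use std::collections::BTreeMap;
use std::io;
use std::sync::{Arc, Mutex};

// Hypothetical remapper: a flat (endpoint id, IOVA) -> GPA map that only
// resolves addresses that were mapped exactly. Illustration only; a real
// IOMMU emulation would track ranges and page sizes.
struct SimpleIommu {
    mappings: Mutex<BTreeMap<(u32, u64), u64>>,
}

impl DmaRemapping for SimpleIommu {
    fn translate(&self, id: u32, addr: u64) -> std::result::Result<u64, io::Error> {
        self.mappings
            .lock()
            .unwrap()
            .get(&(id, addr))
            .copied()
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "unmapped IOVA"))
    }
}

// The per-endpoint callback then matches the VirtioIommuRemapping alias by
// closing over the shared remapper and the endpoint ID.
fn mapping_cb(iommu: Arc<SimpleIommu>, endpoint: u32) -> VirtioIommuRemapping {
    Box::new(move |addr| iommu.translate(endpoint, addr))
}
```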
diff --git a/vm-virtio/src/queue.rs b/vm-virtio/src/queue.rs
index ffe122c2d..7cee6a37d 100644
--- a/vm-virtio/src/queue.rs
+++ b/vm-virtio/src/queue.rs
@@ -11,7 +11,9 @@
 use std::cmp::min;
 use std::num::Wrapping;
 use std::sync::atomic::{fence, Ordering};
+use std::sync::Arc;
 
+use crate::device::VirtioIommuRemapping;
 use vm_memory::{
     Address, ByteValued, Bytes, GuestAddress, GuestMemory, GuestMemoryMmap, GuestUsize,
 };
@@ -44,6 +46,7 @@ pub struct DescriptorChain<'a> {
     desc_table: GuestAddress,
     queue_size: u16,
     ttl: u16, // used to prevent infinite chain cycles
+    iommu_mapping_cb: Option<Arc<VirtioIommuRemapping>>,
 
     /// Reference to guest memory
     pub mem: &'a GuestMemoryMmap,
@@ -71,6 +74,7 @@ impl<'a> DescriptorChain<'a> {
         desc_table: GuestAddress,
         queue_size: u16,
         index: u16,
+        iommu_mapping_cb: Option<Arc<VirtioIommuRemapping>>,
     ) -> Option<DescriptorChain<'a>> {
         if index >= queue_size {
             return None;
@@ -91,16 +95,25 @@ impl<'a> DescriptorChain<'a> {
                 return None;
             }
         };
+
+        // Translate address if necessary
+        let desc_addr = if let Some(iommu_mapping_cb) = &iommu_mapping_cb {
+            (iommu_mapping_cb)(desc.addr).unwrap()
+        } else {
+            desc.addr
+        };
+
         let chain = DescriptorChain {
             mem,
             desc_table,
             queue_size,
             ttl: queue_size,
             index,
-            addr: GuestAddress(desc.addr),
+            addr: GuestAddress(desc_addr),
             len: desc.len,
             flags: desc.flags,
             next: desc.next,
+            iommu_mapping_cb,
         };
 
         if chain.is_valid() {
@@ -137,12 +150,17 @@ impl<'a> DescriptorChain<'a> {
     /// the head of the next _available_ descriptor chain.
     pub fn next_descriptor(&self) -> Option<DescriptorChain<'a>> {
         if self.has_next() {
-            DescriptorChain::checked_new(self.mem, self.desc_table, self.queue_size, self.next).map(
-                |mut c| {
-                    c.ttl = self.ttl - 1;
-                    c
-                },
+            DescriptorChain::checked_new(
+                self.mem,
+                self.desc_table,
+                self.queue_size,
+                self.next,
+                self.iommu_mapping_cb.clone(),
             )
+            .map(|mut c| {
+                c.ttl = self.ttl - 1;
+                c
+            })
         } else {
             None
         }
@@ -158,6 +176,7 @@ pub struct AvailIter<'a, 'b> {
     last_index: Wrapping<u16>,
     queue_size: u16,
     next_avail: &'b mut Wrapping<u16>,
+    iommu_mapping_cb: Option<Arc<VirtioIommuRemapping>>,
 }
 
 impl<'a, 'b> AvailIter<'a, 'b> {
@@ -170,6 +189,7 @@ impl<'a, 'b> AvailIter<'a, 'b> {
             last_index: Wrapping(0),
             queue_size: 0,
             next_avail: q_next_avail,
+            iommu_mapping_cb: None,
         }
     }
 }
@@ -199,8 +219,13 @@ impl<'a, 'b> Iterator for AvailIter<'a, 'b> {
 
         self.next_index += Wrapping(1);
 
-        let ret =
-            DescriptorChain::checked_new(self.mem, self.desc_table, self.queue_size, desc_index);
+        let ret = DescriptorChain::checked_new(
+            self.mem,
+            self.desc_table,
+            self.queue_size,
+            desc_index,
+            self.iommu_mapping_cb.clone(),
+        );
         if ret.is_some() {
             *self.next_avail += Wrapping(1);
         }
@@ -234,6 +259,8 @@ pub struct Queue {
 
     pub next_avail: Wrapping<u16>,
     pub next_used: Wrapping<u16>,
+
+    pub iommu_mapping_cb: Option<Arc<VirtioIommuRemapping>>,
 }
 
 impl Queue {
@@ -249,6 +276,7 @@ impl Queue {
             used_ring: GuestAddress(0),
             next_avail: Wrapping(0),
             next_used: Wrapping(0),
+            iommu_mapping_cb: None,
         }
     }
 
@@ -256,6 +284,26 @@ pub fn get_max_size(&self) -> u16 {
         self.max_size
     }
 
+    pub fn enable(&mut self, set: bool) {
+        self.ready = set;
+
+        if set {
+            // Translate address of descriptor table and vrings.
+            if let Some(iommu_mapping_cb) = &self.iommu_mapping_cb {
+                self.desc_table =
+                    GuestAddress((iommu_mapping_cb)(self.desc_table.raw_value()).unwrap());
+                self.avail_ring =
+                    GuestAddress((iommu_mapping_cb)(self.avail_ring.raw_value()).unwrap());
+                self.used_ring =
+                    GuestAddress((iommu_mapping_cb)(self.used_ring.raw_value()).unwrap());
+            }
+        } else {
+            self.desc_table = GuestAddress(0);
+            self.avail_ring = GuestAddress(0);
+            self.used_ring = GuestAddress(0);
+        }
+    }
+
     /// Return the actual size of the queue, as the driver may not set up a
     /// queue as big as the device allows.
     pub fn actual_size(&self) -> u16 {
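Together with the `checked_new()` change above, this establishes the patch's two translation points: ring addresses are translated once, when the driver enables the queue, and each descriptor's buffer address is translated as chains are walked. A sketch of the `enable()` behavior, assuming the items from this patch are in scope and using a made-up mapping where every IOVA is the guest physical address plus 0x1000:

```rust
use std::sync::Arc;
use vm_memory::GuestAddress;

fn enable_translates_rings() {
    // Made-up IOVA -> GPA mapping: subtract a fixed 0x1000 offset.
    let cb: VirtioIommuRemapping = Box::new(|iova| Ok(iova - 0x1000));

    let mut queue = Queue::new(256);
    queue.iommu_mapping_cb = Some(Arc::new(cb));

    // Ring addresses as programmed by the driver, i.e. IOVAs.
    queue.desc_table = GuestAddress(0x11000);
    queue.avail_ring = GuestAddress(0x12000);
    queue.used_ring = GuestAddress(0x13000);

    // Enabling the queue rewrites them into guest physical addresses.
    queue.enable(true);
    assert_eq!(queue.desc_table, GuestAddress(0x10000));
    assert_eq!(queue.avail_ring, GuestAddress(0x11000));
    assert_eq!(queue.used_ring, GuestAddress(0x12000));
}
```

Note that both translation points funnel errors through `unwrap()`, so a driver programming an unmapped IOVA would panic the VMM rather than fail the device.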
@@ -354,6 +402,7 @@ impl Queue {
             last_index: Wrapping(last_index),
             queue_size,
             next_avail: &mut self.next_avail,
+            iommu_mapping_cb: self.iommu_mapping_cb.clone(),
         }
     }
 
@@ -647,14 +696,16 @@ pub(crate) mod tests {
         assert!(vq.end().0 < 0x1000);
 
         // index >= queue_size
-        assert!(DescriptorChain::checked_new(m, vq.start(), 16, 16).is_none());
+        assert!(DescriptorChain::checked_new(m, vq.start(), 16, 16, None).is_none());
 
         // desc_table address is way off
-        assert!(DescriptorChain::checked_new(m, GuestAddress(0x00ff_ffff_ffff), 16, 0).is_none());
+        assert!(
+            DescriptorChain::checked_new(m, GuestAddress(0x00ff_ffff_ffff), 16, 0, None).is_none()
+        );
 
         // the addr field of the descriptor is way off
         vq.dtable[0].addr.set(0x0fff_ffff_ffff);
-        assert!(DescriptorChain::checked_new(m, vq.start(), 16, 0).is_none());
+        assert!(DescriptorChain::checked_new(m, vq.start(), 16, 0, None).is_none());
 
         // let's create some invalid chains
 
@@ -663,7 +714,7 @@ pub(crate) mod tests {
             vq.dtable[0].addr.set(0x1000);
             // ...but the length is too large
             vq.dtable[0].len.set(0xffff_ffff);
-            assert!(DescriptorChain::checked_new(m, vq.start(), 16, 0).is_none());
+            assert!(DescriptorChain::checked_new(m, vq.start(), 16, 0, None).is_none());
         }
 
         {
@@ -673,7 +724,7 @@ pub(crate) mod tests {
             //..but the the index of the next descriptor is too large
             vq.dtable[0].next.set(16);
-            assert!(DescriptorChain::checked_new(m, vq.start(), 16, 0).is_none());
+            assert!(DescriptorChain::checked_new(m, vq.start(), 16, 0, None).is_none());
         }
 
         // finally, let's test an ok chain
 
@@ -682,7 +733,7 @@ pub(crate) mod tests {
         vq.dtable[0].next.set(1);
         vq.dtable[1].set(0x2000, 0x1000, 0, 0);
 
-        let c = DescriptorChain::checked_new(m, vq.start(), 16, 0).unwrap();
+        let c = DescriptorChain::checked_new(m, vq.start(), 16, 0, None).unwrap();
 
         assert_eq!(c.mem as *const GuestMemoryMmap, m as *const GuestMemoryMmap);
         assert_eq!(c.desc_table, vq.start());
diff --git a/vm-virtio/src/transport/pci_common_config.rs b/vm-virtio/src/transport/pci_common_config.rs
index e2258701d..835b1190d 100644
--- a/vm-virtio/src/transport/pci_common_config.rs
+++ b/vm-virtio/src/transport/pci_common_config.rs
@@ -149,7 +149,7 @@ impl VirtioPciCommonConfig {
             0x16 => self.queue_select = value,
             0x18 => self.with_queue_mut(queues, |q| q.size = value),
             0x1a => self.with_queue_mut(queues, |q| q.vector = value),
-            0x1c => self.with_queue_mut(queues, |q| q.ready = value == 1),
+            0x1c => self.with_queue_mut(queues, |q| q.enable(value == 1)),
             _ => {
                 warn!("invalid virtio register word write: 0x{:x}", offset);
             }
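This one-line change ties the translation into the virtio-pci lifecycle: the driver's write of 1 to `queue_enable` (offset 0x1c of the common configuration structure) is now exactly the moment the ring addresses get translated, rather than just flipping the `ready` flag. A standalone sketch of the dispatch, with a hypothetical free function standing in for the real `VirtioPciCommonConfig` method and `eprintln!` standing in for the `warn!` macro:

```rust
// Illustration only: mirrors the word-sized register arms above.
fn write_word(queue: &mut Queue, offset: u64, value: u16) {
    match offset {
        0x18 => queue.size = value,
        0x1a => queue.vector = value,
        // Previously `queue.ready = value == 1`; routing through enable()
        // also applies (or clears) the IOMMU translation of the rings.
        0x1c => queue.enable(value == 1),
        _ => eprintln!("invalid virtio register word write: 0x{:x}", offset),
    }
}
```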
diff --git a/vm-virtio/src/transport/pci_device.rs b/vm-virtio/src/transport/pci_device.rs
index 83440e55d..d25fc6f85 100755
--- a/vm-virtio/src/transport/pci_device.rs
+++ b/vm-virtio/src/transport/pci_device.rs
@@ -31,8 +31,8 @@ use vmm_sys_util::{errno::Result, eventfd::EventFd};
 use super::VirtioPciCommonConfig;
 use crate::{
     Queue, VirtioDevice, VirtioDeviceType, VirtioInterrupt, VirtioInterruptType,
-    DEVICE_ACKNOWLEDGE, DEVICE_DRIVER, DEVICE_DRIVER_OK, DEVICE_FAILED, DEVICE_FEATURES_OK,
-    DEVICE_INIT, INTERRUPT_STATUS_CONFIG_CHANGED, INTERRUPT_STATUS_USED_RING,
+    VirtioIommuRemapping, DEVICE_ACKNOWLEDGE, DEVICE_DRIVER, DEVICE_DRIVER_OK, DEVICE_FAILED,
+    DEVICE_FEATURES_OK, DEVICE_INIT, INTERRUPT_STATUS_CONFIG_CHANGED, INTERRUPT_STATUS_USED_RING,
 };
 
 #[allow(clippy::enum_variant_names)]
@@ -250,6 +250,7 @@ impl VirtioPciDevice {
         memory: Arc<RwLock<GuestMemoryMmap>>,
         device: Box<dyn VirtioDevice>,
         msix_num: u16,
+        iommu_mapping_cb: Option<Arc<VirtioIommuRemapping>>,
     ) -> Result<VirtioPciDevice> {
         let mut queue_evts = Vec::new();
         for _ in device.queue_max_sizes().iter() {
@@ -258,7 +259,11 @@ impl VirtioPciDevice {
         let queues = device
             .queue_max_sizes()
             .iter()
-            .map(|&s| Queue::new(s))
+            .map(|&s| {
+                let mut queue = Queue::new(s);
+                queue.iommu_mapping_cb = iommu_mapping_cb.clone();
+                queue
+            })
             .collect();
 
         let pci_device_id = VIRTIO_PCI_DEVICE_ID_BASE + device.device_type() as u16;
diff --git a/vmm/src/device_manager.rs b/vmm/src/device_manager.rs
index 3f53ee65e..5d5c7e9ca 100644
--- a/vmm/src/device_manager.rs
+++ b/vmm/src/device_manager.rs
@@ -943,8 +943,9 @@ impl DeviceManager {
             0
         };
 
-        let mut virtio_pci_device = VirtioPciDevice::new(memory.clone(), virtio_device, msix_num)
-            .map_err(DeviceManagerError::VirtioDevice)?;
+        let mut virtio_pci_device =
+            VirtioPciDevice::new(memory.clone(), virtio_device, msix_num, None)
+                .map_err(DeviceManagerError::VirtioDevice)?;
 
         let bars = virtio_pci_device
             .allocate_bars(allocator)
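`DeviceManager` passes `None` for now, so behavior is unchanged until a guest IOMMU device exists to provide mappings. A caller that does have a `DmaRemapping` implementation could opt in roughly as sketched below; the helper, the `endpoint_id`, and the exact import paths are assumptions, not code from this patch:

```rust
use std::sync::{Arc, RwLock};

use vm_memory::GuestMemoryMmap;
// Import paths assumed for illustration.
use vm_virtio::transport::VirtioPciDevice;
use vm_virtio::{DmaRemapping, VirtioDevice, VirtioIommuRemapping};
use vmm_sys_util::errno::Result;

// Hypothetical helper: `iommu` is any DmaRemapping implementation (e.g. the
// SimpleIommu sketch earlier) and `endpoint_id` identifies this device to it.
fn new_translated_device(
    iommu: Arc<dyn DmaRemapping>,
    endpoint_id: u32,
    memory: Arc<RwLock<GuestMemoryMmap>>,
    device: Box<dyn VirtioDevice>,
    msix_num: u16,
) -> Result<VirtioPciDevice> {
    // Wrap the remapper in a per-endpoint callback matching the alias, then
    // hand it to the transport so every queue picks it up.
    let mapping_cb: Arc<VirtioIommuRemapping> =
        Arc::new(Box::new(move |addr| iommu.translate(endpoint_id, addr)));
    VirtioPciDevice::new(memory, device, msix_num, Some(mapping_cb))
}
```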