diff --git a/pci/src/vfio.rs b/pci/src/vfio.rs index c0e064081..da43700c5 100644 --- a/pci/src/vfio.rs +++ b/pci/src/vfio.rs @@ -354,6 +354,10 @@ impl VfioPciDevice { Ok(vfio_pci_device) } + pub fn iommu_attached(&self) -> bool { + self.iommu_attached + } + fn enable_intx(&mut self) -> Result<()> { if let Some(intx) = &mut self.interrupt.intx { if !intx.enabled { diff --git a/vmm/src/device_manager.rs b/vmm/src/device_manager.rs index 65877f948..af666d182 100644 --- a/vmm/src/device_manager.rs +++ b/vmm/src/device_manager.rs @@ -407,6 +407,12 @@ pub enum DeviceManagerError { /// Failed to create FixedVhdDiskSync CreateFixedVhdDiskSync(io::Error), + + /// Failed adding DMA mapping handler to virtio-mem device. + AddDmaMappingHandlerVirtioMem(virtio_devices::mem::Error), + + /// Failed removing DMA mapping handler from virtio-mem device. + RemoveDmaMappingHandlerVirtioMem(virtio_devices::mem::Error), } pub type DeviceManagerResult = result::Result; @@ -898,6 +904,9 @@ pub struct DeviceManager { #[cfg(feature = "acpi")] acpi_address: GuestAddress, + + // Possible handle to the virtio-balloon device + virtio_mem_devices: Vec>>, } impl DeviceManager { @@ -983,6 +992,7 @@ impl DeviceManager { acpi_address, serial_pty: None, console_pty: None, + virtio_mem_devices: Vec::new(), }; let device_manager = Arc::new(Mutex::new(device_manager)); @@ -2532,6 +2542,8 @@ impl DeviceManager { .map_err(DeviceManagerError::CreateVirtioMem)?, )); + self.virtio_mem_devices.push(Arc::clone(&virtio_mem_device)); + devices.push(( Arc::clone(&virtio_mem_device) as VirtioDeviceArc, false, @@ -2702,18 +2714,25 @@ impl DeviceManager { let vfio_device = VfioDevice::new(&device_cfg.path, Arc::clone(&vfio_container)) .map_err(DeviceManagerError::VfioCreate)?; + let vfio_mapping = Arc::new(VfioDmaMapping::new( + Arc::clone(&vfio_container), + Arc::new(memory), + )); if device_cfg.iommu { if let Some(iommu) = &self.iommu_device { - let vfio_mapping = Arc::new(VfioDmaMapping::new( - Arc::clone(&vfio_container), - Arc::new(memory), - )); - iommu .lock() .unwrap() .add_external_mapping(pci_device_bdf, vfio_mapping); } + } else { + for virtio_mem_device in self.virtio_mem_devices.iter() { + virtio_mem_device + .lock() + .unwrap() + .add_dma_mapping_handler(pci_device_bdf, vfio_mapping.clone()) + .map_err(DeviceManagerError::AddDmaMappingHandlerVirtioMem)?; + } } let legacy_interrupt_group = if let (Some(irq), Some(legacy_interrupt_manager)) = ( @@ -3198,10 +3217,10 @@ impl DeviceManager { let (pci_device, bus_device, virtio_device) = if let Ok(vfio_pci_device) = any_device.clone().downcast::>() { - // Unregister DMA mapping in IOMMU. - // Do not unregister the virtio-mem region, as it is directly - // handled by the virtio-mem device. { + // Unregister DMA mapping in IOMMU. + // Do not unregister the virtio-mem region, as it is + // directly handled by the virtio-mem device. let dev = vfio_pci_device.lock().unwrap(); for (_, zone) in self.memory_manager.lock().unwrap().memory_zones().iter() { for region in zone.regions() { @@ -3209,6 +3228,18 @@ impl DeviceManager { .map_err(DeviceManagerError::VfioDmaUnmap)?; } } + + // Unregister the VFIO mapping handler from all virtio-mem + // devices. + if !dev.iommu_attached() { + for virtio_mem_device in self.virtio_mem_devices.iter() { + virtio_mem_device + .lock() + .unwrap() + .remove_dma_mapping_handler(pci_device_bdf) + .map_err(DeviceManagerError::RemoveDmaMappingHandlerVirtioMem)?; + } + } } (