diff --git a/virtio-devices/src/mem.rs b/virtio-devices/src/mem.rs index 5be59e240..2b5910712 100644 --- a/virtio-devices/src/mem.rs +++ b/virtio-devices/src/mem.rs @@ -25,6 +25,7 @@ use crate::{VirtioInterrupt, VirtioInterruptType}; use anyhow::anyhow; use libc::EFD_NONBLOCK; use seccompiler::SeccompAction; +use std::collections::BTreeMap; use std::io; use std::mem::size_of; use std::os::unix::io::{AsRawFd, RawFd}; @@ -427,7 +428,7 @@ struct MemEpollHandler { kill_evt: EventFd, pause_evt: EventFd, hugepages: bool, - dma_mapping_handler: Option>, + dma_mapping_handlers: Arc>>>, } impl MemEpollHandler { @@ -504,10 +505,11 @@ impl MemEpollHandler { .unwrap() .set_range(first_block_index, nb_blocks, plug); + let handlers = self.dma_mapping_handlers.lock().unwrap(); if plug { let mut gpa = addr; for _ in 0..nb_blocks { - if let Some(handler) = &self.dma_mapping_handler { + for (_, handler) in handlers.iter() { if let Err(e) = handler.map(gpa, gpa, config.block_size) { error!( "failed DMA mapping addr 0x{:x} size 0x{:x}: {}", @@ -522,7 +524,7 @@ impl MemEpollHandler { config.plugged_size += size; } else { - if let Some(handler) = &self.dma_mapping_handler { + for (_, handler) in handlers.iter() { if let Err(e) = handler.unmap(addr, size) { error!( "failed DMA unmapping addr 0x{:x} size 0x{:x}: {}", @@ -547,10 +549,11 @@ impl MemEpollHandler { // Remaining plugged blocks are unmapped. if config.plugged_size > 0 { + let handlers = self.dma_mapping_handlers.lock().unwrap(); for (idx, plugged) in self.blocks_state.lock().unwrap().inner().iter().enumerate() { if *plugged { let gpa = config.addr + (idx as u64 * config.block_size); - if let Some(handler) = &self.dma_mapping_handler { + for (_, handler) in handlers.iter() { if let Err(e) = handler.unmap(gpa, config.block_size) { error!( "failed DMA unmapping addr 0x{:x} size 0x{:x}: {}", @@ -731,7 +734,12 @@ impl EpollHelperHandler for MemEpollHandler { } } -// Virtio device for exposing entropy to the guest OS through virtio. +#[derive(PartialEq, Eq, PartialOrd, Ord)] +pub enum VirtioMemMappingSource { + Container, + Device(u32), +} + pub struct Mem { common: VirtioCommon, id: String, @@ -741,7 +749,7 @@ pub struct Mem { config: Arc>, seccomp_action: SeccompAction, hugepages: bool, - dma_mapping_handler: Option>, + dma_mapping_handlers: Arc>>>, blocks_state: Arc>, exit_evt: EventFd, } @@ -829,7 +837,7 @@ impl Mem { config: Arc::new(Mutex::new(config)), seccomp_action, hugepages, - dma_mapping_handler: None, + dma_mapping_handlers: Arc::new(Mutex::new(BTreeMap::new())), blocks_state: Arc::new(Mutex::new(BlocksState(vec![ false; (config.region_size / config.block_size) @@ -841,6 +849,7 @@ impl Mem { pub fn add_dma_mapping_handler( &mut self, + source: VirtioMemMappingSource, handler: Arc, ) -> result::Result<(), Error> { let config = self.config.lock().unwrap(); @@ -856,7 +865,37 @@ impl Mem { } } - self.dma_mapping_handler = Some(handler); + self.dma_mapping_handlers + .lock() + .unwrap() + .insert(source, handler); + + Ok(()) + } + + pub fn remove_dma_mapping_handler( + &mut self, + source: VirtioMemMappingSource, + ) -> result::Result<(), Error> { + let handler = self + .dma_mapping_handlers + .lock() + .unwrap() + .remove(&source) + .ok_or(Error::InvalidDmaMappingHandler)?; + + let config = self.config.lock().unwrap(); + + if config.plugged_size > 0 { + for (idx, plugged) in self.blocks_state.lock().unwrap().inner().iter().enumerate() { + if *plugged { + let gpa = config.addr + (idx as u64 * config.block_size); + handler + .unmap(gpa, config.block_size) + .map_err(Error::DmaUnmap)?; + } + } + } Ok(()) } @@ -915,7 +954,7 @@ impl VirtioDevice for Mem { kill_evt, pause_evt, hugepages: self.hugepages, - dma_mapping_handler: self.dma_mapping_handler.clone(), + dma_mapping_handlers: Arc::clone(&self.dma_mapping_handlers), }; handler diff --git a/vmm/src/device_manager.rs b/vmm/src/device_manager.rs index 4677abbbf..6ce5fb23c 100644 --- a/vmm/src/device_manager.rs +++ b/vmm/src/device_manager.rs @@ -91,6 +91,7 @@ use vfio_ioctls::{VfioContainer, VfioDevice}; use virtio_devices::transport::VirtioPciDevice; use virtio_devices::transport::VirtioTransport; use virtio_devices::vhost_user::VhostUserConfig; +use virtio_devices::VirtioMemMappingSource; use virtio_devices::{DmaRemapping, Endpoint, IommuMapping}; use virtio_devices::{VirtioSharedMemory, VirtioSharedMemoryList}; use vm_allocator::SystemAllocator; @@ -2950,7 +2951,10 @@ impl DeviceManager { virtio_mem_device .lock() .unwrap() - .add_dma_mapping_handler(vfio_mapping.clone()) + .add_dma_mapping_handler( + VirtioMemMappingSource::Container, + vfio_mapping.clone(), + ) .map_err(DeviceManagerError::AddDmaMappingHandlerVirtioMem)?; } }