diff --git a/virtio-devices/src/mem.rs b/virtio-devices/src/mem.rs
index 679ead7e4..7d15a598d 100644
--- a/virtio-devices/src/mem.rs
+++ b/virtio-devices/src/mem.rs
@@ -24,6 +24,7 @@ use crate::{VirtioInterrupt, VirtioInterruptType};
 use anyhow::anyhow;
 use libc::EFD_NONBLOCK;
 use seccomp::{SeccompAction, SeccompFilter};
+use std::collections::BTreeMap;
 use std::io;
 use std::mem::size_of;
 use std::os::unix::io::{AsRawFd, RawFd};
@@ -32,6 +33,7 @@ use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
 use std::sync::mpsc;
 use std::sync::{Arc, Barrier, Mutex};
 use std::thread;
+use vm_device::dma_mapping::ExternalDmaMapping;
 use vm_memory::{
     Address, ByteValued, Bytes, GuestAddress, GuestAddressSpace, GuestMemoryAtomic,
     GuestMemoryError, GuestMemoryMmap, GuestMemoryRegion, GuestRegionMmap,
@@ -124,6 +126,12 @@ pub enum Error {
     ValidateError(anyhow::Error),
     // Failed discarding memory range
     DiscardMemoryRange(std::io::Error),
+    // Failed DMA mapping.
+    DmaMap(std::io::Error),
+    // Failed DMA unmapping.
+    DmaUnmap(std::io::Error),
+    // Invalid DMA mapping handler
+    InvalidDmaMappingHandler,
 }
 
 #[repr(C)]
@@ -401,12 +409,16 @@ impl BlocksState {
             *state = plug;
         }
     }
+
+    fn inner(&self) -> &Vec<bool> {
+        &self.0
+    }
 }
 
 struct MemEpollHandler {
     host_addr: u64,
     host_fd: Option<RawFd>,
-    blocks_state: BlocksState,
+    blocks_state: Arc<Mutex<BlocksState>>,
     config: Arc<Mutex<VirtioMemConfig>>,
     resize: ResizeSender,
     queue: Queue,
@@ -416,6 +428,7 @@ struct MemEpollHandler {
     kill_evt: EventFd,
     pause_evt: EventFd,
     hugepages: bool,
+    dma_mapping_handlers: Arc<Mutex<BTreeMap<u32, Arc<dyn ExternalDmaMapping>>>>,
 }
 
 impl MemEpollHandler {
@@ -473,6 +486,8 @@ impl MemEpollHandler {
         let first_block_index = (offset / config.block_size) as usize;
         if !self
             .blocks_state
+            .lock()
+            .unwrap()
             .is_range_state(first_block_index, nb_blocks, !plug)
         {
             return VIRTIO_MEM_RESP_ERROR;
@@ -486,11 +501,40 @@
         }
 
         self.blocks_state
+            .lock()
+            .unwrap()
             .set_range(first_block_index, nb_blocks, plug);
 
         if plug {
+            let handlers = self.dma_mapping_handlers.lock().unwrap();
+            let mut gpa = addr;
+            for _ in 0..nb_blocks {
+                for (_, handler) in handlers.iter() {
+                    if let Err(e) = handler.map(gpa, gpa, config.block_size) {
+                        error!(
+                            "failed DMA mapping addr 0x{:x} size 0x{:x}: {}",
+                            gpa, config.block_size, e
+                        );
+                        return VIRTIO_MEM_RESP_ERROR;
+                    }
+                }
+
+                gpa += config.block_size;
+            }
+
             config.plugged_size += size;
         } else {
+            let handlers = self.dma_mapping_handlers.lock().unwrap();
+            for (_, handler) in handlers.iter() {
+                if let Err(e) = handler.unmap(addr, size) {
+                    error!(
+                        "failed DMA unmapping addr 0x{:x} size 0x{:x}: {}",
+                        addr, size, e
+                    );
+                    return VIRTIO_MEM_RESP_ERROR;
+                }
+            }
+
             config.plugged_size -= size;
         }
 
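(Aside, not part of the patch: the map()/unmap() calls above come from the ExternalDmaMapping trait imported from vm_device::dma_mapping. A rough sketch of that trait, inferred from how this patch calls it:

pub trait ExternalDmaMapping: Send + Sync {
    /// Map a memory range: `iova` is the address to translate through the
    /// IOMMU, `gpa` the guest physical address backing it. virtio-mem maps
    /// each plugged block with an identity mapping, i.e. iova == gpa.
    fn map(&self, iova: u64, gpa: u64, size: u64) -> std::io::Result<()>;

    /// Unmap a previously mapped range starting at `iova`.
    fn unmap(&self, iova: u64, size: u64) -> std::io::Result<()>;
}

Any handler failure is reported to the guest as VIRTIO_MEM_RESP_ERROR, so a plug or unplug request only succeeds once every registered handler has updated its mappings.)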
@@ -504,8 +548,30 @@
             return VIRTIO_MEM_RESP_ERROR;
         }
 
-        self.blocks_state
-            .set_range(0, (config.region_size / config.block_size) as u16, false);
+        // Remaining plugged blocks are unmapped.
+        if config.plugged_size > 0 {
+            let handlers = self.dma_mapping_handlers.lock().unwrap();
+            for (idx, plugged) in self.blocks_state.lock().unwrap().inner().iter().enumerate() {
+                if *plugged {
+                    let gpa = config.addr + (idx as u64 * config.block_size);
+                    for (_, handler) in handlers.iter() {
+                        if let Err(e) = handler.unmap(gpa, config.block_size) {
+                            error!(
+                                "failed DMA unmapping addr 0x{:x} size 0x{:x}: {}",
+                                gpa, config.block_size, e
+                            );
+                            return VIRTIO_MEM_RESP_ERROR;
+                        }
+                    }
+                }
+            }
+        }
+
+        self.blocks_state.lock().unwrap().set_range(
+            0,
+            (config.region_size / config.block_size) as u16,
+            false,
+        );
 
         config.plugged_size = 0;
 
@@ -524,19 +590,23 @@
 
         let offset = addr - config.addr;
         let first_block_index = (offset / config.block_size) as usize;
-        let resp_state = if self
-            .blocks_state
-            .is_range_state(first_block_index, nb_blocks, true)
-        {
-            VIRTIO_MEM_STATE_PLUGGED
-        } else if self
-            .blocks_state
-            .is_range_state(first_block_index, nb_blocks, false)
-        {
-            VIRTIO_MEM_STATE_UNPLUGGED
-        } else {
-            VIRTIO_MEM_STATE_MIXED
-        };
+        let resp_state =
+            if self
+                .blocks_state
+                .lock()
+                .unwrap()
+                .is_range_state(first_block_index, nb_blocks, true)
+            {
+                VIRTIO_MEM_STATE_PLUGGED
+            } else if self.blocks_state.lock().unwrap().is_range_state(
+                first_block_index,
+                nb_blocks,
+                false,
+            ) {
+                VIRTIO_MEM_STATE_UNPLUGGED
+            } else {
+                VIRTIO_MEM_STATE_MIXED
+            };
 
         (resp_type, resp_state)
     }
@@ -675,6 +745,8 @@ pub struct Mem {
     config: Arc<Mutex<VirtioMemConfig>>,
     seccomp_action: SeccompAction,
     hugepages: bool,
+    dma_mapping_handlers: Arc<Mutex<BTreeMap<u32, Arc<dyn ExternalDmaMapping>>>>,
+    blocks_state: Arc<Mutex<BlocksState>>,
 }
 
 impl Mem {
@@ -758,8 +830,64 @@ impl Mem {
             config: Arc::new(Mutex::new(config)),
             seccomp_action,
             hugepages,
+            dma_mapping_handlers: Arc::new(Mutex::new(BTreeMap::new())),
+            blocks_state: Arc::new(Mutex::new(BlocksState(vec![
+                false;
+                (config.region_size / config.block_size)
+                    as usize
+            ]))),
         })
     }
+
+    pub fn add_dma_mapping_handler(
+        &mut self,
+        device_id: u32,
+        handler: Arc<dyn ExternalDmaMapping>,
+    ) -> result::Result<(), Error> {
+        let config = self.config.lock().unwrap();
+
+        if config.plugged_size > 0 {
+            for (idx, plugged) in self.blocks_state.lock().unwrap().inner().iter().enumerate() {
+                if *plugged {
+                    let gpa = config.addr + (idx as u64 * config.block_size);
+                    handler
+                        .map(gpa, gpa, config.block_size)
+                        .map_err(Error::DmaMap)?;
+                }
+            }
+        }
+
+        self.dma_mapping_handlers
+            .lock()
+            .unwrap()
+            .insert(device_id, handler);
+
+        Ok(())
+    }
+
+    pub fn remove_dma_mapping_handler(&mut self, device_id: u32) -> result::Result<(), Error> {
+        let handler = self
+            .dma_mapping_handlers
+            .lock()
+            .unwrap()
+            .remove(&device_id)
+            .ok_or(Error::InvalidDmaMappingHandler)?;
+
+        let config = self.config.lock().unwrap();
+
+        if config.plugged_size > 0 {
+            for (idx, plugged) in self.blocks_state.lock().unwrap().inner().iter().enumerate() {
+                if *plugged {
+                    let gpa = config.addr + (idx as u64 * config.block_size);
+                    handler
+                        .unmap(gpa, config.block_size)
+                        .map_err(Error::DmaUnmap)?;
+                }
+            }
+        }
+
+        Ok(())
+    }
 }
 
 impl Drop for Mem {
@@ -825,10 +953,7 @@ impl VirtioDevice for Mem {
         let mut handler = MemEpollHandler {
             host_addr: self.host_addr,
             host_fd: self.host_fd,
-            blocks_state: BlocksState(vec![
-                false;
-                (config.region_size / config.block_size) as usize
-            ]),
+            blocks_state: Arc::clone(&self.blocks_state),
             config: self.config.clone(),
             resize: self.resize.clone(),
             queue: queues.remove(0),
@@ -838,6 +963,7 @@
             kill_evt,
             pause_evt,
             hugepages: self.hugepages,
+            dma_mapping_handlers: Arc::clone(&self.dma_mapping_handlers),
         };
 
         handler
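(Aside, not part of the patch: add_dma_mapping_handler() and remove_dma_mapping_handler() are meant to be driven by the VMM when a DMA-capable device is attached to or detached from the virtual IOMMU. A minimal sketch of a handler and its registration; the NoopDmaMapping type and the device id 0 are illustrative only:

use std::io;
use std::sync::Arc;
use vm_device::dma_mapping::ExternalDmaMapping;

struct NoopDmaMapping;

impl ExternalDmaMapping for NoopDmaMapping {
    fn map(&self, iova: u64, gpa: u64, size: u64) -> io::Result<()> {
        // Called once per plugged block, with iova == gpa.
        println!("map iova 0x{:x} gpa 0x{:x} size 0x{:x}", iova, gpa, size);
        Ok(())
    }

    fn unmap(&self, iova: u64, size: u64) -> io::Result<()> {
        println!("unmap iova 0x{:x} size 0x{:x}", iova, size);
        Ok(())
    }
}

// Given `mem: &mut Mem` built through Mem::new() above:
// mem.add_dma_mapping_handler(0, Arc::new(NoopDmaMapping))?;
// ...
// mem.remove_dma_mapping_handler(0)?;

Note that add_dma_mapping_handler() replays map() for every already-plugged block, so a handler registered after some memory was plugged still ends up with a complete view, and remove_dma_mapping_handler() symmetrically unmaps those blocks before dropping the handler.)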
diff --git a/virtio-devices/src/seccomp_filters.rs b/virtio-devices/src/seccomp_filters.rs
index a4a591c51..93c1ded54 100644
--- a/virtio-devices/src/seccomp_filters.rs
+++ b/virtio-devices/src/seccomp_filters.rs
@@ -55,6 +55,24 @@ const SYS_IO_URING_ENTER: i64 = 426;
 // See include/uapi/asm-generic/ioctls.h in the kernel code.
 const FIONBIO: u64 = 0x5421;
 
+// See include/uapi/linux/vfio.h in the kernel code.
+const VFIO_IOMMU_MAP_DMA: u64 = 0x3b71;
+const VFIO_IOMMU_UNMAP_DMA: u64 = 0x3b72;
+
+fn create_virtio_iommu_ioctl_seccomp_rule() -> Vec<SeccompRule> {
+    or![
+        and![Cond::new(1, ArgLen::DWORD, Eq, VFIO_IOMMU_MAP_DMA).unwrap()],
+        and![Cond::new(1, ArgLen::DWORD, Eq, VFIO_IOMMU_UNMAP_DMA).unwrap()],
+    ]
+}
+
+fn create_virtio_mem_ioctl_seccomp_rule() -> Vec<SeccompRule> {
+    or![
+        and![Cond::new(1, ArgLen::DWORD, Eq, VFIO_IOMMU_MAP_DMA).unwrap()],
+        and![Cond::new(1, ArgLen::DWORD, Eq, VFIO_IOMMU_UNMAP_DMA).unwrap()],
+    ]
+}
+
 fn virtio_balloon_thread_rules() -> Vec<SyscallRuleSet> {
     vec![
         allow_syscall(libc::SYS_brk),
@@ -156,6 +174,7 @@ fn virtio_iommu_thread_rules() -> Vec<SyscallRuleSet> {
         allow_syscall(libc::SYS_epoll_wait),
         allow_syscall(libc::SYS_exit),
         allow_syscall(libc::SYS_futex),
+        allow_syscall_if(libc::SYS_ioctl, create_virtio_iommu_ioctl_seccomp_rule()),
         allow_syscall(libc::SYS_madvise),
         allow_syscall(libc::SYS_mmap),
         allow_syscall(libc::SYS_mprotect),
@@ -179,6 +198,7 @@ fn virtio_mem_thread_rules() -> Vec<SyscallRuleSet> {
         allow_syscall(libc::SYS_exit),
         allow_syscall(libc::SYS_fallocate),
         allow_syscall(libc::SYS_futex),
+        allow_syscall_if(libc::SYS_ioctl, create_virtio_mem_ioctl_seccomp_rule()),
         allow_syscall(libc::SYS_madvise),
         allow_syscall(libc::SYS_munmap),
         allow_syscall(libc::SYS_read),
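(Aside, not part of the patch: the two ioctl numbers whitelisted above follow from include/uapi/linux/vfio.h, where VFIO_IOMMU_MAP_DMA = _IO(VFIO_TYPE, VFIO_BASE + 13) and VFIO_IOMMU_UNMAP_DMA = _IO(VFIO_TYPE, VFIO_BASE + 14), with VFIO_TYPE = ';' (0x3b) and VFIO_BASE = 100. A standalone check:

const VFIO_TYPE: u64 = b';' as u64; // 0x3b
const VFIO_BASE: u64 = 100;

// _IO(type, nr): no direction bits and zero size, so only the type
// (bits 8..16) and the command number (bits 0..8) are encoded.
const fn io(ty: u64, nr: u64) -> u64 {
    (ty << 8) | nr
}

fn main() {
    assert_eq!(io(VFIO_TYPE, VFIO_BASE + 13), 0x3b71); // VFIO_IOMMU_MAP_DMA
    assert_eq!(io(VFIO_TYPE, VFIO_BASE + 14), 0x3b72); // VFIO_IOMMU_UNMAP_DMA
}

Each rule matches on argument index 1 of ioctl(2), i.e. the request code, so the virtio-iommu and virtio-mem threads may now issue exactly these two VFIO ioctls and no other ioctl.)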