diff --git a/pci/src/lib.rs b/pci/src/lib.rs index 3d4d7c40c..1e852f91b 100644 --- a/pci/src/lib.rs +++ b/pci/src/lib.rs @@ -26,7 +26,7 @@ pub use self::device::{ }; pub use self::msi::{msi_num_enabled_vectors, MsiCap, MsiConfig}; pub use self::msix::{MsixCap, MsixConfig, MsixTableEntry, MSIX_CONFIG_ID, MSIX_TABLE_ENTRY_SIZE}; -pub use self::vfio::{VfioDmaMapping, VfioPciDevice, VfioPciError}; +pub use self::vfio::{MmioRegion, VfioDmaMapping, VfioPciDevice, VfioPciError}; pub use self::vfio_user::{VfioUserDmaMapping, VfioUserPciDevice, VfioUserPciDeviceError}; use serde::de::Visitor; use std::fmt::{self, Display}; diff --git a/pci/src/vfio.rs b/pci/src/vfio.rs index 8c4c0d9b1..ba3350c0a 100644 --- a/pci/src/vfio.rs +++ b/pci/src/vfio.rs @@ -275,6 +275,48 @@ pub struct MmioRegion { pub(crate) index: u32, pub(crate) user_memory_regions: Vec, } + +trait MmioRegionRange { + fn check_range(&self, guest_addr: u64, size: u64) -> bool; + fn find_user_address(&self, guest_addr: u64) -> Result; +} + +impl MmioRegionRange for Vec { + // Check if a guest address is within the range of mmio regions + fn check_range(&self, guest_addr: u64, size: u64) -> bool { + for region in self.iter() { + let Some(guest_addr_end) = guest_addr.checked_add(size) else { + return false; + }; + let Some(region_end) = region.start.raw_value().checked_add(region.length) else { + return false; + }; + if guest_addr >= region.start.raw_value() && guest_addr_end <= region_end { + return true; + } + } + false + } + + // Locate the user region address for a guest address within all mmio regions + fn find_user_address(&self, guest_addr: u64) -> Result { + for region in self.iter() { + for user_region in region.user_memory_regions.iter() { + if guest_addr >= user_region.start + && guest_addr < user_region.start + user_region.size + { + return Ok(user_region.host_addr + (guest_addr - user_region.start)); + } + } + } + + Err(io::Error::new( + io::ErrorKind::Other, + format!("unable to find user address: 0x{guest_addr:x}"), + )) + } +} + #[derive(Debug, Error)] pub enum VfioError { #[error("Kernel VFIO error: {0}")] @@ -1893,16 +1935,25 @@ impl Migratable for VfioPciDevice {} pub struct VfioDmaMapping { container: Arc, memory: Arc, + mmio_regions: Arc>>, } impl VfioDmaMapping { /// Create a DmaMapping object. - /// /// # Parameters /// * `container`: VFIO container object. - /// * `memory·: guest memory to mmap. - pub fn new(container: Arc, memory: Arc) -> Self { - VfioDmaMapping { container, memory } + /// * `memory`: guest memory to mmap. + /// * `mmio_regions`: mmio_regions to mmap. + pub fn new( + container: Arc, + memory: Arc, + mmio_regions: Arc>>, + ) -> Self { + VfioDmaMapping { + container, + memory, + mmio_regions, + } } } @@ -1911,14 +1962,21 @@ impl ExternalDmaMapping for VfioDmaMapping t as u64, + Err(e) => { + return Err(io::Error::new( + io::ErrorKind::Other, + format!("unable to retrieve user address for gpa 0x{gpa:x} from guest memory region: {e}") + )); + } + } + } else if self.mmio_regions.lock().unwrap().check_range(gpa, size) { + self.mmio_regions.lock().unwrap().find_user_address(gpa)? } else { return Err(io::Error::new( io::ErrorKind::Other, - format!( - "failed to convert guest address 0x{gpa:x} into \ - host user virtual address" - ), + format!("failed to locate guest address 0x{gpa:x} in guest memory"), )); }; diff --git a/vmm/src/device_manager.rs b/vmm/src/device_manager.rs index bf31dc840..61ff76115 100644 --- a/vmm/src/device_manager.rs +++ b/vmm/src/device_manager.rs @@ -58,7 +58,7 @@ use libc::{ O_TMPFILE, PROT_READ, PROT_WRITE, TCSANOW, }; use pci::{ - DeviceRelocation, PciBarRegionType, PciBdf, PciDevice, VfioDmaMapping, + DeviceRelocation, MmioRegion, PciBarRegionType, PciBdf, PciDevice, VfioDmaMapping, VfioPciDevice, VfioUserDmaMapping, VfioUserPciDevice, VfioUserPciDeviceError, }; use rate_limiter::group::RateLimiterGroup; @@ -965,6 +965,8 @@ pub struct DeviceManager { snapshot: Option, rate_limit_groups: HashMap>, + + mmio_regions: Arc>>, } impl DeviceManager { @@ -1195,6 +1197,7 @@ impl DeviceManager { acpi_platform_addresses: AcpiPlatformAddresses::default(), snapshot, rate_limit_groups, + mmio_regions: Arc::new(Mutex::new(Vec::new())), }; let device_manager = Arc::new(Mutex::new(device_manager)); @@ -3423,6 +3426,7 @@ impl DeviceManager { let vfio_mapping = Arc::new(VfioDmaMapping::new( Arc::clone(&vfio_container), Arc::new(self.memory_manager.lock().unwrap().guest_memory()), + Arc::clone(&self.mmio_regions), )); if let Some(iommu) = &self.iommu_device { @@ -3467,6 +3471,7 @@ impl DeviceManager { let vfio_mapping = Arc::new(VfioDmaMapping::new( Arc::clone(&vfio_container), Arc::new(self.memory_manager.lock().unwrap().guest_memory()), + Arc::clone(&self.mmio_regions), )); for virtio_mem_device in self.virtio_mem_devices.iter() { @@ -3529,6 +3534,10 @@ impl DeviceManager { .map_mmio_regions() .map_err(DeviceManagerError::VfioMapRegion)?; + for mmio_region in vfio_pci_device.lock().unwrap().mmio_regions() { + self.mmio_regions.lock().unwrap().push(mmio_region); + } + let mut node = device_node!(vfio_name, vfio_pci_device); // Update the device tree with correct resource information. @@ -4200,12 +4209,21 @@ impl DeviceManager { let (pci_device, bus_device, virtio_device, remove_dma_handler) = match pci_device_handle { // No need to remove any virtio-mem mapping here as the container outlives all devices - PciDeviceHandle::Vfio(vfio_pci_device) => ( - Arc::clone(&vfio_pci_device) as Arc>, - Arc::clone(&vfio_pci_device) as Arc>, - None as Option>>, - false, - ), + PciDeviceHandle::Vfio(vfio_pci_device) => { + for mmio_region in vfio_pci_device.lock().unwrap().mmio_regions() { + self.mmio_regions + .lock() + .unwrap() + .retain(|x| x.start != mmio_region.start) + } + + ( + Arc::clone(&vfio_pci_device) as Arc>, + Arc::clone(&vfio_pci_device) as Arc>, + None as Option>>, + false, + ) + } PciDeviceHandle::Virtio(virtio_pci_device) => { let dev = virtio_pci_device.lock().unwrap(); let bar_addr = dev.config_bar_addr();