diff --git a/vfio/src/vfio_device.rs b/vfio/src/vfio_device.rs index b324da82b..6230fe0de 100644 --- a/vfio/src/vfio_device.rs +++ b/vfio/src/vfio_device.rs @@ -103,7 +103,7 @@ struct vfio_region_info_with_cap { cap_info: __IncompleteArrayField, } -struct VfioContainer { +pub struct VfioContainer { container: File, } @@ -151,7 +151,7 @@ impl VfioContainer { Ok(()) } - fn vfio_dma_map(&self, iova: u64, size: u64, user_addr: u64) -> Result<()> { + pub fn vfio_dma_map(&self, iova: u64, size: u64, user_addr: u64) -> Result<()> { let dma_map = vfio_iommu_type1_dma_map { argsz: mem::size_of::() as u32, flags: VFIO_DMA_MAP_FLAG_READ | VFIO_DMA_MAP_FLAG_WRITE, @@ -170,7 +170,7 @@ impl VfioContainer { Ok(()) } - fn vfio_dma_unmap(&self, iova: u64, size: u64) -> Result<()> { + pub fn vfio_dma_unmap(&self, iova: u64, size: u64) -> Result<()> { let mut dma_unmap = vfio_iommu_type1_dma_unmap { argsz: mem::size_of::() as u32, flags: 0, @@ -198,7 +198,7 @@ impl AsRawFd for VfioContainer { struct VfioGroup { group: File, device: Arc, - container: VfioContainer, + container: Arc, } impl VfioGroup { @@ -225,7 +225,7 @@ impl VfioGroup { return Err(VfioError::GroupViable); } - let container = VfioContainer::new()?; + let container = Arc::new(VfioContainer::new()?); if container.get_api_version() as u32 != VFIO_API_VERSION { return Err(VfioError::VfioApiVersion); } @@ -762,6 +762,10 @@ impl VfioDevice { } } + pub fn get_container(&self) -> Arc { + self.group.container.clone() + } + fn vfio_dma_map(&self, iova: u64, size: u64, user_addr: u64) -> Result<()> { self.group.container.vfio_dma_map(iova, size, user_addr) } diff --git a/vmm/src/device_manager.rs b/vmm/src/device_manager.rs index b62d36912..0d00f0386 100644 --- a/vmm/src/device_manager.rs +++ b/vmm/src/device_manager.rs @@ -1006,6 +1006,8 @@ impl DeviceManager { VfioDevice::new(&device_cfg.path, device_fd.clone(), vm_info.memory.clone()) .map_err(DeviceManagerError::VfioCreate)?; + let _vfio_container = vfio_device.get_container(); + let mut vfio_pci_device = VfioPciDevice::new(vm_info.vm_fd, allocator, vfio_device) .map_err(DeviceManagerError::VfioPciCreate)?;