diff --git a/vm-virtio/src/iommu.rs b/vm-virtio/src/iommu.rs index 1a5399132..a7c8ec636 100644 --- a/vm-virtio/src/iommu.rs +++ b/vm-virtio/src/iommu.rs @@ -8,6 +8,7 @@ use super::{ VirtioDeviceType, VIRTIO_F_VERSION_1, }; use crate::{DmaRemapping, VirtioInterrupt, VirtioInterruptType}; +use anyhow::anyhow; use epoll; use libc::EFD_NONBLOCK; use std::cmp; @@ -27,7 +28,10 @@ use vm_memory::{ Address, ByteValued, Bytes, GuestAddress, GuestAddressSpace, GuestMemoryAtomic, GuestMemoryError, GuestMemoryMmap, }; -use vm_migration::{Migratable, MigratableError, Pausable, Snapshottable, Transportable}; +use vm_migration::{ + Migratable, MigratableError, Pausable, Snapshot, SnapshotDataSection, Snapshottable, + Transportable, +}; use vmm_sys_util::eventfd::EventFd; /// Queues sizes @@ -757,7 +761,7 @@ impl IommuEpollHandler { } } -#[derive(Clone, Copy)] +#[derive(Clone, Copy, Serialize, Deserialize)] struct Mapping { gpa: u64, size: u64, @@ -811,6 +815,14 @@ pub struct Iommu { paused: Arc, } +#[derive(Serialize, Deserialize)] +struct IommuState { + avail_features: u64, + acked_features: u64, + endpoints: BTreeMap, + mappings: BTreeMap>, +} + impl Iommu { pub fn new(id: String) -> io::Result<(Self, Arc)> { let config = VirtioIommuConfig { @@ -846,6 +858,24 @@ impl Iommu { )) } + fn state(&self) -> IommuState { + IommuState { + avail_features: self.avail_features, + acked_features: self.acked_features, + endpoints: self.mapping.endpoints.read().unwrap().clone(), + mappings: self.mapping.mappings.read().unwrap().clone(), + } + } + + fn set_state(&mut self, state: &IommuState) -> io::Result<()> { + self.avail_features = state.avail_features; + self.acked_features = state.acked_features; + *(self.mapping.endpoints.write().unwrap()) = state.endpoints.clone(); + *(self.mapping.mappings.write().unwrap()) = state.mappings.clone(); + + Ok(()) + } + // This function lets the caller specify a list of devices attached to the // virtual IOMMU. This list is translated into a virtio-iommu configuration // topology, so that it can be understood by the guest driver. @@ -1048,6 +1078,41 @@ impl Snapshottable for Iommu { fn id(&self) -> String { self.id.clone() } + + fn snapshot(&self) -> std::result::Result { + let snapshot = + serde_json::to_vec(&self.state()).map_err(|e| MigratableError::Snapshot(e.into()))?; + + let mut iommu_snapshot = Snapshot::new(self.id.as_str()); + iommu_snapshot.add_data_section(SnapshotDataSection { + id: format!("{}-section", self.id), + snapshot, + }); + + Ok(iommu_snapshot) + } + + fn restore(&mut self, snapshot: Snapshot) -> std::result::Result<(), MigratableError> { + if let Some(iommu_section) = snapshot.snapshot_data.get(&format!("{}-section", self.id)) { + let iommu_state = match serde_json::from_slice(&iommu_section.snapshot) { + Ok(state) => state, + Err(error) => { + return Err(MigratableError::Restore(anyhow!( + "Could not deserialize IOMMU {}", + error + ))) + } + }; + + return self.set_state(&iommu_state).map_err(|e| { + MigratableError::Restore(anyhow!("Could not restore IOMMU state {:?}", e)) + }); + } + + Err(MigratableError::Restore(anyhow!( + "Could not find IOMMU snapshot section" + ))) + } } impl Transportable for Iommu {} impl Migratable for Iommu {}