diff --git a/vm-virtio/src/transport/pci_common_config.rs b/vm-virtio/src/transport/pci_common_config.rs index e411f9950..6d8504057 100644 --- a/vm-virtio/src/transport/pci_common_config.rs +++ b/vm-virtio/src/transport/pci_common_config.rs @@ -8,10 +8,22 @@ extern crate byteorder; use crate::{Queue, VirtioDevice}; +use anyhow::anyhow; use byteorder::{ByteOrder, LittleEndian}; use std::sync::atomic::{AtomicU16, Ordering}; use std::sync::{Arc, Mutex}; use vm_memory::GuestAddress; +use vm_migration::{MigratableError, Pausable, Snapshot, SnapshotDataSection, Snapshottable}; + +#[derive(Clone, Serialize, Deserialize)] +pub struct VirtioPciCommonConfigState { + pub driver_status: u8, + pub config_generation: u8, + pub device_feature_select: u32, + pub driver_feature_select: u32, + pub queue_select: u16, + pub msix_config: u16, +} /// Contains the data for reading and writing the common configuration structure of a virtio PCI /// device. @@ -45,6 +57,26 @@ pub struct VirtioPciCommonConfig { } impl VirtioPciCommonConfig { + fn state(&self) -> VirtioPciCommonConfigState { + VirtioPciCommonConfigState { + driver_status: self.driver_status, + config_generation: self.config_generation, + device_feature_select: self.device_feature_select, + driver_feature_select: self.driver_feature_select, + queue_select: self.queue_select, + msix_config: self.msix_config.load(Ordering::SeqCst), + } + } + + fn set_state(&mut self, state: &VirtioPciCommonConfigState) { + self.driver_status = state.driver_status; + self.config_generation = state.config_generation; + self.device_feature_select = state.device_feature_select; + self.driver_feature_select = state.driver_feature_select; + self.queue_select = state.queue_select; + self.msix_config.store(state.msix_config, Ordering::SeqCst); + } + pub fn read( &mut self, offset: u64, @@ -251,6 +283,54 @@ impl VirtioPciCommonConfig { } } +impl Pausable for VirtioPciCommonConfig {} + +impl Snapshottable for VirtioPciCommonConfig { + fn id(&self) -> String { + String::from("virtio_pci_common_config") + } + + fn snapshot(&self) -> std::result::Result { + let snapshot = + serde_json::to_vec(&self.state()).map_err(|e| MigratableError::Snapshot(e.into()))?; + + let mut config_snapshot = Snapshot::new(self.id().as_str()); + config_snapshot.add_data_section(SnapshotDataSection { + id: format!("{}-section", self.id()), + snapshot, + }); + + Ok(config_snapshot) + } + + fn restore(&mut self, snapshot: Snapshot) -> std::result::Result<(), MigratableError> { + if let Some(config_section) = snapshot + .snapshot_data + .get(&format!("{}-section", self.id())) + { + let config_state = match serde_json::from_slice(&config_section.snapshot) { + Ok(state) => state, + Err(error) => { + return Err(MigratableError::Restore(anyhow!( + "Could not deserialize {}: {}", + self.id(), + error + ))) + } + }; + + self.set_state(&config_state); + + return Ok(()); + } + + Err(MigratableError::Restore(anyhow!( + "Could not find {} snapshot section", + self.id() + ))) + } +} + #[cfg(test)] mod tests { use super::*;