diff --git a/vm-virtio/src/transport/pci_device.rs b/vm-virtio/src/transport/pci_device.rs index ad323122e..8ae396ff5 100755 --- a/vm-virtio/src/transport/pci_device.rs +++ b/vm-virtio/src/transport/pci_device.rs @@ -20,6 +20,7 @@ use crate::{ VirtioIommuRemapping, DEVICE_ACKNOWLEDGE, DEVICE_DRIVER, DEVICE_DRIVER_OK, DEVICE_FAILED, DEVICE_FEATURES_OK, DEVICE_INIT, VIRTIO_MSI_NO_VECTOR, }; +use anyhow::anyhow; use devices::BusDevice; use libc::EFD_NONBLOCK; use pci::{ @@ -30,6 +31,7 @@ use pci::{ use std::any::Any; use std::cmp; use std::io::Write; +use std::num::Wrapping; use std::result; use std::sync::atomic::{AtomicU16, AtomicUsize, Ordering}; use std::sync::{Arc, Mutex}; @@ -41,9 +43,18 @@ use vm_memory::{ Address, ByteValued, GuestAddress, GuestAddressSpace, GuestMemoryAtomic, GuestMemoryMmap, GuestUsize, Le32, }; -use vm_migration::{Migratable, MigratableError, Pausable, Snapshottable, Transportable}; +use vm_migration::{ + Migratable, MigratableError, Pausable, Snapshot, SnapshotDataSection, Snapshottable, + Transportable, +}; use vmm_sys_util::{errno::Result, eventfd::EventFd}; +#[derive(Debug)] +enum Error { + /// Failed to retrieve queue ring's index. + QueueRingIndex(crate::queue::Error), +} + #[allow(clippy::enum_variant_names)] enum PciCapabilityType { CommonConfig = 1, @@ -258,6 +269,13 @@ const NOTIFY_OFF_MULTIPLIER: u32 = 4; // A dword per notification address. const VIRTIO_PCI_VENDOR_ID: u16 = 0x1af4; const VIRTIO_PCI_DEVICE_ID_BASE: u16 = 0x1040; // Add to device type to get device ID. +#[derive(Serialize, Deserialize)] +struct VirtioPciDeviceState { + device_activated: bool, + queues: Vec, + interrupt_status: usize, +} + pub struct VirtioPciDevice { id: String, @@ -424,6 +442,40 @@ impl VirtioPciDevice { Ok(virtio_pci_device) } + fn state(&self) -> VirtioPciDeviceState { + VirtioPciDeviceState { + device_activated: self.device_activated, + interrupt_status: self.interrupt_status.load(Ordering::SeqCst), + queues: self.queues.clone(), + } + } + + fn set_state(&mut self, state: &VirtioPciDeviceState) -> std::result::Result<(), Error> { + self.device_activated = state.device_activated; + self.interrupt_status + .store(state.interrupt_status, Ordering::SeqCst); + self.queues = state.queues.clone(); + + // Update virtqueues indexes for both available and used rings. + if let Some(mem) = self.memory.as_ref() { + let mem = mem.memory(); + for queue in self.queues.iter_mut() { + queue.next_avail = Wrapping( + queue + .used_index_from_memory(&mem) + .map_err(Error::QueueRingIndex)?, + ); + queue.next_used = Wrapping( + queue + .used_index_from_memory(&mem) + .map_err(Error::QueueRingIndex)?, + ); + } + } + + Ok(()) + } + /// Gets the list of queue events that must be triggered whenever the VM writes to /// `virtio::NOTIFY_REG_OFFSET` past the MMIO base. Each event must be triggered when the /// value being written equals the index of the event in this list. @@ -970,6 +1022,108 @@ impl Snapshottable for VirtioPciDevice { fn id(&self) -> String { self.id.clone() } + + fn snapshot(&self) -> std::result::Result { + let snapshot = + serde_json::to_vec(&self.state()).map_err(|e| MigratableError::Snapshot(e.into()))?; + + let mut virtio_pci_dev_snapshot = Snapshot::new(self.id.as_str()); + virtio_pci_dev_snapshot.add_data_section(SnapshotDataSection { + id: format!("{}-section", self.id), + snapshot, + }); + + // Snapshot PciConfiguration + virtio_pci_dev_snapshot.add_snapshot(self.configuration.snapshot()?); + + // Snapshot VirtioPciCommonConfig + virtio_pci_dev_snapshot.add_snapshot(self.common_config.snapshot()?); + + // Snapshot MSI-X + if let Some(msix_config) = &self.msix_config { + virtio_pci_dev_snapshot.add_snapshot(msix_config.lock().unwrap().snapshot()?); + } + + Ok(virtio_pci_dev_snapshot) + } + + fn restore(&mut self, snapshot: Snapshot) -> std::result::Result<(), MigratableError> { + if let Some(virtio_pci_dev_section) = + snapshot.snapshot_data.get(&format!("{}-section", self.id)) + { + // Restore MSI-X + if let Some(msix_config) = &self.msix_config { + let id = msix_config.lock().unwrap().id(); + if let Some(msix_snapshot) = snapshot.snapshots.get(&id) { + msix_config + .lock() + .unwrap() + .restore(*msix_snapshot.clone())?; + } + } + + // Restore VirtioPciCommonConfig + if let Some(virtio_config_snapshot) = snapshot.snapshots.get(&self.common_config.id()) { + self.common_config + .restore(*virtio_config_snapshot.clone())?; + } + + // Restore PciConfiguration + if let Some(pci_config_snapshot) = snapshot.snapshots.get(&self.configuration.id()) { + self.configuration.restore(*pci_config_snapshot.clone())?; + } + + let virtio_pci_dev_state = + match serde_json::from_slice(&virtio_pci_dev_section.snapshot) { + Ok(state) => state, + Err(error) => { + return Err(MigratableError::Restore(anyhow!( + "Could not deserialize VIRTIO_PCI_DEVICE {}", + error + ))) + } + }; + + // First restore the status of the virtqueues. + self.set_state(&virtio_pci_dev_state).map_err(|e| { + MigratableError::Restore(anyhow!( + "Could not restore VIRTIO_PCI_DEVICE state {:?}", + e + )) + })?; + + // Then we can activate the device, as we know at this point that + // the virtqueues are in the right state and the device is ready + // to be activated, which will spawn each virtio worker thread. + if self.device_activated && self.is_driver_ready() && self.are_queues_valid() { + if let Some(virtio_interrupt) = self.virtio_interrupt.take() { + if self.memory.is_some() { + let mem = self.memory.as_ref().unwrap().clone(); + let mut device = self.device.lock().unwrap(); + device + .activate( + mem, + virtio_interrupt, + self.queues.clone(), + self.queue_evts.split_off(0), + ) + .map_err(|e| { + MigratableError::Restore(anyhow!( + "Failed activating the device: {:?}", + e + )) + })?; + } + } + } + + return Ok(()); + } + + Err(MigratableError::Restore(anyhow!( + "Could not find VIRTIO_PCI_DEVICE snapshot section" + ))) + } } impl Transportable for VirtioPciDevice {} impl Migratable for VirtioPciDevice {}