diff --git a/pci/src/bus.rs b/pci/src/bus.rs index ff033ddab..2d65b2bf5 100644 --- a/pci/src/bus.rs +++ b/pci/src/bus.rs @@ -62,6 +62,7 @@ impl PciRoot { 0, 0, None, + None, ), } } diff --git a/pci/src/configuration.rs b/pci/src/configuration.rs index 2542e7072..95379d9a6 100644 --- a/pci/src/configuration.rs +++ b/pci/src/configuration.rs @@ -32,6 +32,8 @@ const CAPABILITY_MAX_OFFSET: usize = 192; const INTERRUPT_LINE_PIN_REG: usize = 15; +pub const PCI_CONFIGURATION_ID: &str = "pci_configuration"; + /// Represents the types of PCI headers allowed in the configuration registers. #[derive(Copy, Clone)] pub enum PciHeaderType { @@ -394,7 +396,7 @@ fn decode_64_bits_bar_size(bar_size_hi: u32, bar_size_lo: u32) -> Option { None } -#[derive(Default, Clone, Copy, Versionize)] +#[derive(Debug, Default, Clone, Copy, Versionize)] struct PciBar { addr: u32, size: u32, @@ -403,7 +405,7 @@ struct PciBar { } #[derive(Versionize)] -struct PciConfigurationState { +pub struct PciConfigurationState { registers: Vec, writable_bits: Vec, bars: Vec, @@ -553,46 +555,78 @@ impl PciConfiguration { subsystem_vendor_id: u16, subsystem_id: u16, msix_config: Option>>, + state: Option, ) -> Self { - let mut registers = [0u32; NUM_CONFIGURATION_REGISTERS]; - let mut writable_bits = [0u32; NUM_CONFIGURATION_REGISTERS]; - registers[0] = u32::from(device_id) << 16 | u32::from(vendor_id); - // TODO(dverkamp): Status should be write-1-to-clear - writable_bits[1] = 0x0000_ffff; // Status (r/o), command (r/w) - let pi = if let Some(pi) = programming_interface { - pi.get_register_value() + let ( + registers, + writable_bits, + bars, + rom_bar_addr, + rom_bar_size, + rom_bar_used, + last_capability, + msix_cap_reg_idx, + ) = if let Some(state) = state { + ( + state.registers.try_into().unwrap(), + state.writable_bits.try_into().unwrap(), + state.bars.try_into().unwrap(), + state.rom_bar_addr, + state.rom_bar_size, + state.rom_bar_used, + state.last_capability, + state.msix_cap_reg_idx, + ) } else { - 0 - }; - registers[2] = u32::from(class_code.get_register_value()) << 24 - | u32::from(subclass.get_register_value()) << 16 - | u32::from(pi) << 8 - | u32::from(revision_id); - writable_bits[3] = 0x0000_00ff; // Cacheline size (r/w) - match header_type { - PciHeaderType::Device => { - registers[3] = 0x0000_0000; // Header type 0 (device) - writable_bits[15] = 0x0000_00ff; // Interrupt line (r/w) - } - PciHeaderType::Bridge => { - registers[3] = 0x0001_0000; // Header type 1 (bridge) - writable_bits[9] = 0xfff0_fff0; // Memory base and limit - writable_bits[15] = 0xffff_00ff; // Bridge control (r/w), interrupt line (r/w) - } - }; - registers[11] = u32::from(subsystem_id) << 16 | u32::from(subsystem_vendor_id); + let mut registers = [0u32; NUM_CONFIGURATION_REGISTERS]; + let mut writable_bits = [0u32; NUM_CONFIGURATION_REGISTERS]; + registers[0] = u32::from(device_id) << 16 | u32::from(vendor_id); + // TODO(dverkamp): Status should be write-1-to-clear + writable_bits[1] = 0x0000_ffff; // Status (r/o), command (r/w) + let pi = if let Some(pi) = programming_interface { + pi.get_register_value() + } else { + 0 + }; + registers[2] = u32::from(class_code.get_register_value()) << 24 + | u32::from(subclass.get_register_value()) << 16 + | u32::from(pi) << 8 + | u32::from(revision_id); + writable_bits[3] = 0x0000_00ff; // Cacheline size (r/w) + match header_type { + PciHeaderType::Device => { + registers[3] = 0x0000_0000; // Header type 0 (device) + writable_bits[15] = 0x0000_00ff; // Interrupt line (r/w) + } + PciHeaderType::Bridge => { + registers[3] = 0x0001_0000; // Header type 1 (bridge) + writable_bits[9] = 0xfff0_fff0; // Memory base and limit + writable_bits[15] = 0xffff_00ff; // Bridge control (r/w), interrupt line (r/w) + } + }; + registers[11] = u32::from(subsystem_id) << 16 | u32::from(subsystem_vendor_id); - let bars = [PciBar::default(); NUM_BAR_REGS]; + ( + registers, + writable_bits, + [PciBar::default(); NUM_BAR_REGS], + 0, + 0, + false, + None, + None, + ) + }; PciConfiguration { registers, writable_bits, bars, - rom_bar_addr: 0, - rom_bar_size: 0, - rom_bar_used: false, - last_capability: None, - msix_cap_reg_idx: None, + rom_bar_addr, + rom_bar_size, + rom_bar_used, + last_capability, + msix_cap_reg_idx, msix_config, } } @@ -1046,7 +1080,7 @@ impl Pausable for PciConfiguration {} impl Snapshottable for PciConfiguration { fn id(&self) -> String { - String::from("pci_configuration") + String::from(PCI_CONFIGURATION_ID) } fn snapshot(&mut self) -> std::result::Result { @@ -1178,6 +1212,7 @@ mod tests { 0xABCD, 0x2468, None, + None, ); // Add two capabilities with different contents. @@ -1234,6 +1269,7 @@ mod tests { 0xABCD, 0x2468, None, + None, ); let class_reg = cfg.read_reg(2); diff --git a/pci/src/lib.rs b/pci/src/lib.rs index 0758d17df..4b24fcd03 100644 --- a/pci/src/lib.rs +++ b/pci/src/lib.rs @@ -19,12 +19,13 @@ pub use self::configuration::{ PciBarConfiguration, PciBarPrefetchable, PciBarRegionType, PciCapability, PciCapabilityId, PciClassCode, PciConfiguration, PciExpressCapabilityId, PciHeaderType, PciMassStorageSubclass, PciNetworkControllerSubclass, PciProgrammingInterface, PciSerialBusSubClass, PciSubclass, + PCI_CONFIGURATION_ID, }; pub use self::device::{ BarReprogrammingParams, DeviceRelocation, Error as PciDeviceError, PciDevice, }; pub use self::msi::{msi_num_enabled_vectors, MsiCap, MsiConfig}; -pub use self::msix::{MsixCap, MsixConfig, MsixTableEntry, MSIX_TABLE_ENTRY_SIZE}; +pub use self::msix::{MsixCap, MsixConfig, MsixTableEntry, MSIX_CONFIG_ID, MSIX_TABLE_ENTRY_SIZE}; pub use self::vfio::{VfioPciDevice, VfioPciError}; pub use self::vfio_user::{VfioUserDmaMapping, VfioUserPciDevice, VfioUserPciDeviceError}; use serde::de::Visitor; diff --git a/pci/src/msix.rs b/pci/src/msix.rs index d66d41272..fe5b95d11 100644 --- a/pci/src/msix.rs +++ b/pci/src/msix.rs @@ -26,9 +26,10 @@ const MSIX_ENABLE_BIT: u8 = 15; const FUNCTION_MASK_MASK: u16 = (1 << FUNCTION_MASK_BIT) as u16; const MSIX_ENABLE_MASK: u16 = (1 << MSIX_ENABLE_BIT) as u16; pub const MSIX_TABLE_ENTRY_SIZE: usize = 16; +pub const MSIX_CONFIG_ID: &str = "msix_config"; #[derive(Debug)] -enum Error { +pub enum Error { /// Failed enabling the interrupt route. EnableInterruptRoute(io::Error), /// Failed updating the interrupt route. @@ -61,7 +62,7 @@ impl Default for MsixTableEntry { } #[derive(Versionize)] -struct MsixConfigState { +pub struct MsixConfigState { table_entries: Vec, pba_entries: Vec, masked: bool, @@ -84,23 +85,62 @@ impl MsixConfig { msix_vectors: u16, interrupt_source_group: Arc, devid: u32, - ) -> Self { + state: Option, + ) -> result::Result { assert!(msix_vectors <= MAX_MSIX_VECTORS_PER_DEVICE); - let mut table_entries: Vec = Vec::new(); - table_entries.resize_with(msix_vectors as usize, Default::default); - let mut pba_entries: Vec = Vec::new(); - let num_pba_entries: usize = ((msix_vectors as usize) / BITS_PER_PBA_ENTRY) + 1; - pba_entries.resize_with(num_pba_entries, Default::default); + let (table_entries, pba_entries, masked, enabled) = if let Some(state) = state { + if state.enabled && !state.masked { + for (idx, table_entry) in state.table_entries.iter().enumerate() { + if table_entry.masked() { + continue; + } - MsixConfig { + let config = MsiIrqSourceConfig { + high_addr: table_entry.msg_addr_hi, + low_addr: table_entry.msg_addr_lo, + data: table_entry.msg_data, + devid, + }; + + interrupt_source_group + .update( + idx as InterruptIndex, + InterruptSourceConfig::MsiIrq(config), + state.masked, + ) + .map_err(Error::UpdateInterruptRoute)?; + + interrupt_source_group + .enable() + .map_err(Error::EnableInterruptRoute)?; + } + } + + ( + state.table_entries, + state.pba_entries, + state.masked, + state.enabled, + ) + } else { + let mut table_entries: Vec = Vec::new(); + table_entries.resize_with(msix_vectors as usize, Default::default); + let mut pba_entries: Vec = Vec::new(); + let num_pba_entries: usize = ((msix_vectors as usize) / BITS_PER_PBA_ENTRY) + 1; + pba_entries.resize_with(num_pba_entries, Default::default); + + (table_entries, pba_entries, true, false) + }; + + Ok(MsixConfig { table_entries, pba_entries, devid, interrupt_source_group, - masked: true, - enabled: false, - } + masked, + enabled, + }) } fn state(&self) -> MsixConfigState { @@ -426,7 +466,7 @@ impl Pausable for MsixConfig {} impl Snapshottable for MsixConfig { fn id(&self) -> String { - String::from("msix_config") + String::from(MSIX_CONFIG_ID) } fn snapshot(&mut self) -> std::result::Result { diff --git a/pci/src/vfio.rs b/pci/src/vfio.rs index 5c2a7327e..7c59ef6e5 100644 --- a/pci/src/vfio.rs +++ b/pci/src/vfio.rs @@ -663,7 +663,9 @@ impl VfioCommon { msix_cap.table_size(), interrupt_source_group.clone(), bdf.into(), - ); + None, + ) + .unwrap(); self.interrupt.msix = Some(VfioMsix { bar: msix_config, @@ -1235,6 +1237,7 @@ impl VfioPciDevice { 0, 0, None, + None, ); let vfio_wrapper = VfioDeviceWrapper::new(Arc::clone(&device)); diff --git a/pci/src/vfio_user.rs b/pci/src/vfio_user.rs index 6715c03f5..473476dc0 100644 --- a/pci/src/vfio_user.rs +++ b/pci/src/vfio_user.rs @@ -88,6 +88,7 @@ impl VfioUserPciDevice { 0, 0, None, + None, ); let resettable = client.lock().unwrap().resettable(); if resettable { diff --git a/virtio-devices/src/transport/mod.rs b/virtio-devices/src/transport/mod.rs index cb38c6fb6..038d4cb98 100644 --- a/virtio-devices/src/transport/mod.rs +++ b/virtio-devices/src/transport/mod.rs @@ -5,7 +5,7 @@ use vmm_sys_util::eventfd::EventFd; mod pci_common_config; mod pci_device; -pub use pci_common_config::VirtioPciCommonConfig; +pub use pci_common_config::{VirtioPciCommonConfig, VIRTIO_PCI_COMMON_CONFIG_ID}; pub use pci_device::{VirtioPciDevice, VirtioPciDeviceActivator}; pub trait VirtioTransport { diff --git a/virtio-devices/src/transport/pci_common_config.rs b/virtio-devices/src/transport/pci_common_config.rs index 399fdadbf..6cbb992ff 100644 --- a/virtio-devices/src/transport/pci_common_config.rs +++ b/virtio-devices/src/transport/pci_common_config.rs @@ -16,6 +16,8 @@ use virtio_queue::{Queue, QueueT}; use vm_migration::{MigratableError, Pausable, Snapshot, Snapshottable, VersionMapped}; use vm_virtio::AccessPlatform; +pub const VIRTIO_PCI_COMMON_CONFIG_ID: &str = "virtio_pci_common_config"; + #[derive(Clone, Versionize)] pub struct VirtioPciCommonConfigState { pub driver_status: u8, @@ -63,6 +65,22 @@ pub struct VirtioPciCommonConfig { } impl VirtioPciCommonConfig { + pub fn new( + state: VirtioPciCommonConfigState, + access_platform: Option>, + ) -> Self { + VirtioPciCommonConfig { + access_platform, + driver_status: state.driver_status, + config_generation: state.config_generation, + device_feature_select: state.device_feature_select, + driver_feature_select: state.driver_feature_select, + queue_select: state.queue_select, + msix_config: Arc::new(AtomicU16::new(state.msix_config)), + msix_queues: Arc::new(Mutex::new(state.msix_queues)), + } + } + fn state(&self) -> VirtioPciCommonConfigState { VirtioPciCommonConfigState { driver_status: self.driver_status, @@ -75,16 +93,6 @@ impl VirtioPciCommonConfig { } } - fn set_state(&mut self, state: &VirtioPciCommonConfigState) { - self.driver_status = state.driver_status; - self.config_generation = state.config_generation; - self.device_feature_select = state.device_feature_select; - self.driver_feature_select = state.driver_feature_select; - self.queue_select = state.queue_select; - self.msix_config.store(state.msix_config, Ordering::Release); - *(self.msix_queues.lock().unwrap()) = state.msix_queues.clone(); - } - pub fn read( &mut self, offset: u64, @@ -309,17 +317,12 @@ impl Pausable for VirtioPciCommonConfig {} impl Snapshottable for VirtioPciCommonConfig { fn id(&self) -> String { - String::from("virtio_pci_common_config") + String::from(VIRTIO_PCI_COMMON_CONFIG_ID) } fn snapshot(&mut self) -> std::result::Result { Snapshot::new_from_versioned_state(&self.id(), &self.state()) } - - fn restore(&mut self, snapshot: Snapshot) -> std::result::Result<(), MigratableError> { - self.set_state(&snapshot.to_versioned_state(&self.id())?); - Ok(()) - } } #[cfg(test)] diff --git a/virtio-devices/src/transport/pci_device.rs b/virtio-devices/src/transport/pci_device.rs index c08019c4e..e92876b9d 100644 --- a/virtio-devices/src/transport/pci_device.rs +++ b/virtio-devices/src/transport/pci_device.rs @@ -6,15 +6,13 @@ // // SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause -use super::VirtioPciCommonConfig; -use crate::transport::VirtioTransport; +use crate::transport::{VirtioPciCommonConfig, VirtioTransport, VIRTIO_PCI_COMMON_CONFIG_ID}; use crate::GuestMemoryMmap; use crate::{ ActivateResult, VirtioDevice, VirtioDeviceType, VirtioInterrupt, VirtioInterruptType, DEVICE_ACKNOWLEDGE, DEVICE_DRIVER, DEVICE_DRIVER_OK, DEVICE_FAILED, DEVICE_FEATURES_OK, DEVICE_INIT, }; -use anyhow::anyhow; use libc::EFD_NONBLOCK; use pci::{ BarReprogrammingParams, MsixCap, MsixConfig, PciBarConfiguration, PciBarRegionType, @@ -30,7 +28,7 @@ use std::sync::atomic::{AtomicBool, AtomicU16, AtomicUsize, Ordering}; use std::sync::{Arc, Barrier, Mutex}; use versionize::{VersionMap, Versionize, VersionizeResult}; use versionize_derive::Versionize; -use virtio_queue::{Error as QueueError, Queue, QueueT}; +use virtio_queue::{Queue, QueueT}; use vm_allocator::{AddressAllocator, SystemAllocator}; use vm_device::dma_mapping::ExternalDmaMapping; use vm_device::interrupt::{ @@ -44,15 +42,11 @@ use vm_migration::{ use vm_virtio::AccessPlatform; use vmm_sys_util::{errno::Result, eventfd::EventFd}; +use super::pci_common_config::VirtioPciCommonConfigState; + /// Vector value used to disable MSI for a queue. const VIRTQ_MSI_NO_VECTOR: u16 = 0xffff; -#[derive(Debug)] -enum Error { - /// Failed to retrieve queue ring's index. - QueueRingIndex(QueueError), -} - #[allow(clippy::enum_variant_names)] enum PciCapabilityType { CommonConfig = 1, @@ -280,7 +274,7 @@ struct QueueState { } #[derive(Versionize)] -struct VirtioPciDeviceState { +pub struct VirtioPciDeviceState { device_activated: bool, queues: Vec, interrupt_status: usize, @@ -390,9 +384,9 @@ impl VirtioPciDevice { use_64bit_bar: bool, dma_handler: Option>, pending_activations: Arc>>, + snapshot: Option, ) -> Result { - let device_clone = device.clone(); - let mut locked_device = device_clone.lock().unwrap(); + let mut locked_device = device.lock().unwrap(); let mut queue_evts = Vec::new(); for _ in locked_device.queue_max_sizes().iter() { queue_evts.push(EventFd::new(EFD_NONBLOCK)?) @@ -403,7 +397,7 @@ impl VirtioPciDevice { locked_device.set_access_platform(access_platform.clone()); } - let queues = locked_device + let mut queues: Vec = locked_device .queue_max_sizes() .iter() .map(|&s| Queue::new(s).unwrap()) @@ -416,12 +410,26 @@ impl VirtioPciDevice { count: msix_num as InterruptIndex, })?; + let msix_state = + vm_migration::versioned_state_from_id(snapshot.as_ref(), pci::MSIX_CONFIG_ID).map_err( + |e| { + std::io::Error::new( + std::io::ErrorKind::Other, + format!("Failed to get MsixConfigState from Snapshot: {}", e), + ) + }, + )?; + let (msix_config, msix_config_clone) = if msix_num > 0 { - let msix_config = Arc::new(Mutex::new(MsixConfig::new( - msix_num, - interrupt_source_group.clone(), - pci_device_bdf, - ))); + let msix_config = Arc::new(Mutex::new( + MsixConfig::new( + msix_num, + interrupt_source_group.clone(), + pci_device_bdf, + msix_state, + ) + .unwrap(), + )); let msix_config_clone = msix_config.clone(); (Some(msix_config), Some(msix_config_clone)) } else { @@ -443,6 +451,15 @@ impl VirtioPciDevice { ), }; + let pci_configuration_state = + vm_migration::versioned_state_from_id(snapshot.as_ref(), pci::PCI_CONFIGURATION_ID) + .map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::Other, + format!("Failed to get PciConfigurationState from Snapshot: {}", e), + ) + })?; + let configuration = PciConfiguration::new( VIRTIO_PCI_VENDOR_ID, pci_device_id, @@ -454,26 +471,97 @@ impl VirtioPciDevice { VIRTIO_PCI_VENDOR_ID, pci_device_id, msix_config_clone, + pci_configuration_state, ); + let common_config_state = + vm_migration::versioned_state_from_id(snapshot.as_ref(), VIRTIO_PCI_COMMON_CONFIG_ID) + .map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::Other, + format!( + "Failed to get VirtioPciCommonConfigState from Snapshot: {}", + e + ), + ) + })?; + + let common_config = if let Some(common_config_state) = common_config_state { + VirtioPciCommonConfig::new(common_config_state, access_platform) + } else { + VirtioPciCommonConfig::new( + VirtioPciCommonConfigState { + driver_status: 0, + config_generation: 0, + device_feature_select: 0, + driver_feature_select: 0, + queue_select: 0, + msix_config: VIRTQ_MSI_NO_VECTOR, + msix_queues: vec![VIRTQ_MSI_NO_VECTOR; num_queues], + }, + access_platform, + ) + }; + + let state: Option = snapshot + .as_ref() + .map(|s| s.to_versioned_state(&id)) + .transpose() + .map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::Other, + format!("Failed to get VirtioPciDeviceState from Snapshot: {}", e), + ) + })?; + + let (device_activated, interrupt_status) = if let Some(state) = state { + // Update virtqueues indexes for both available and used rings. + for (i, queue) in queues.iter_mut().enumerate() { + queue.set_size(state.queues[i].size); + queue.set_ready(state.queues[i].ready); + queue + .try_set_desc_table_address(GuestAddress(state.queues[i].desc_table)) + .unwrap(); + queue + .try_set_avail_ring_address(GuestAddress(state.queues[i].avail_ring)) + .unwrap(); + queue + .try_set_used_ring_address(GuestAddress(state.queues[i].used_ring)) + .unwrap(); + queue.set_next_avail( + queue + .used_idx(memory.memory().deref(), Ordering::Acquire) + .unwrap() + .0, + ); + queue.set_next_used( + queue + .used_idx(memory.memory().deref(), Ordering::Acquire) + .unwrap() + .0, + ); + } + + (state.device_activated, state.interrupt_status) + } else { + (false, 0) + }; + + // Dropping the MutexGuard to unlock the VirtioDevice. This is required + // in the context of a restore given the device might require some + // activation, meaning it will require locking. Dropping the lock + // prevents from a subtle deadlock. + std::mem::drop(locked_device); + let mut virtio_pci_device = VirtioPciDevice { id, configuration, - common_config: VirtioPciCommonConfig { - access_platform, - driver_status: 0, - config_generation: 0, - device_feature_select: 0, - driver_feature_select: 0, - queue_select: 0, - msix_config: Arc::new(AtomicU16::new(VIRTQ_MSI_NO_VECTOR)), - msix_queues: Arc::new(Mutex::new(vec![VIRTQ_MSI_NO_VECTOR; num_queues])), - }, + common_config, msix_config, msix_num, device, - device_activated: Arc::new(AtomicBool::new(false)), - interrupt_status: Arc::new(AtomicUsize::new(0)), + device_activated: Arc::new(AtomicBool::new(device_activated)), + interrupt_status: Arc::new(AtomicUsize::new(interrupt_status)), virtio_interrupt: None, queues, queue_evts, @@ -497,6 +585,20 @@ impl VirtioPciDevice { ))); } + // In case of a restore, we can activate the device, as we know at + // this point the virtqueues are in the right state and the device is + // ready to be activated, which will spawn each virtio worker thread. + if virtio_pci_device.device_activated.load(Ordering::SeqCst) + && virtio_pci_device.is_driver_ready() + { + virtio_pci_device.activate().map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::Other, + format!("Failed activating the device: {}", e), + ) + })?; + } + Ok(virtio_pci_device) } @@ -519,42 +621,6 @@ impl VirtioPciDevice { } } - fn set_state(&mut self, state: &VirtioPciDeviceState) -> std::result::Result<(), Error> { - self.device_activated - .store(state.device_activated, Ordering::Release); - self.interrupt_status - .store(state.interrupt_status, Ordering::Release); - - // Update virtqueues indexes for both available and used rings. - for (i, queue) in self.queues.iter_mut().enumerate() { - queue.set_size(state.queues[i].size); - queue.set_ready(state.queues[i].ready); - queue - .try_set_desc_table_address(GuestAddress(state.queues[i].desc_table)) - .unwrap(); - queue - .try_set_avail_ring_address(GuestAddress(state.queues[i].avail_ring)) - .unwrap(); - queue - .try_set_used_ring_address(GuestAddress(state.queues[i].used_ring)) - .unwrap(); - queue.set_next_avail( - queue - .used_idx(self.memory.memory().deref(), Ordering::Acquire) - .map_err(Error::QueueRingIndex)? - .0, - ); - queue.set_next_used( - queue - .used_idx(self.memory.memory().deref(), Ordering::Acquire) - .map_err(Error::QueueRingIndex)? - .0, - ); - } - - Ok(()) - } - /// Gets the list of queue events that must be triggered whenever the VM writes to /// `virtio::NOTIFY_REG_OFFSET` past the MMIO base. Each event must be triggered when the /// value being written equals the index of the event in this list. @@ -883,6 +949,7 @@ impl PciDevice for VirtioPciDevice { let mut settings_bar_addr = None; let mut use_64bit_bar = self.use_64bit_bar; + let restoring = resources.is_some(); if let Some(resources) = resources { for resource in resources { if let Resource::PciBar { @@ -939,39 +1006,53 @@ impl PciDevice for VirtioPciDevice { .set_address(virtio_pci_bar_addr.raw_value()) .set_size(CAPABILITY_BAR_SIZE) .set_region_type(region_type); - self.configuration.add_pci_bar(&bar).map_err(|e| { - PciDeviceError::IoRegistrationFailed(virtio_pci_bar_addr.raw_value(), e) - })?; + + // The creation of the PCI BAR and its associated capabilities must + // happen only during the creation of a brand new VM. When a VM is + // restored from a known state, the BARs are already created with the + // right content, therefore we don't need to go through this codepath. + if !restoring { + self.configuration.add_pci_bar(&bar).map_err(|e| { + PciDeviceError::IoRegistrationFailed(virtio_pci_bar_addr.raw_value(), e) + })?; + + // Once the BARs are allocated, the capabilities can be added to the PCI configuration. + self.add_pci_capabilities(VIRTIO_COMMON_BAR_INDEX as u8)?; + } bars.push(bar); - // Once the BARs are allocated, the capabilities can be added to the PCI configuration. - self.add_pci_capabilities(VIRTIO_COMMON_BAR_INDEX as u8)?; - // Allocate a dedicated BAR if there are some shared memory regions. if let Some(shm_list) = device.get_shm_regions() { let bar = PciBarConfiguration::default() .set_index(VIRTIO_SHM_BAR_INDEX) .set_address(shm_list.addr.raw_value()) .set_size(shm_list.len); - self.configuration - .add_pci_bar(&bar) - .map_err(|e| PciDeviceError::IoRegistrationFailed(shm_list.addr.raw_value(), e))?; + + // The creation of the PCI BAR and its associated capabilities must + // happen only during the creation of a brand new VM. When a VM is + // restored from a known state, the BARs are already created with the + // right content, therefore we don't need to go through this codepath. + if !restoring { + self.configuration.add_pci_bar(&bar).map_err(|e| { + PciDeviceError::IoRegistrationFailed(shm_list.addr.raw_value(), e) + })?; + + for (idx, shm) in shm_list.region_list.iter().enumerate() { + let shm_cap = VirtioPciCap64::new( + PciCapabilityType::SharedMemoryConfig, + VIRTIO_SHM_BAR_INDEX as u8, + idx as u8, + shm.offset, + shm.len, + ); + self.configuration + .add_capability(&shm_cap) + .map_err(PciDeviceError::CapabilitiesSetup)?; + } + } bars.push(bar); - - for (idx, shm) in shm_list.region_list.iter().enumerate() { - let shm_cap = VirtioPciCap64::new( - PciCapabilityType::SharedMemoryConfig, - VIRTIO_SHM_BAR_INDEX as u8, - idx as u8, - shm.offset, - shm.len, - ); - self.configuration - .add_capability(&shm_cap) - .map_err(PciDeviceError::CapabilitiesSetup)?; - } } self.bar_regions = bars.clone(); @@ -1186,58 +1267,6 @@ impl Snapshottable for VirtioPciDevice { Ok(virtio_pci_dev_snapshot) } - - fn restore(&mut self, snapshot: Snapshot) -> std::result::Result<(), MigratableError> { - if let Some(virtio_pci_dev_section) = - snapshot.snapshot_data.get(&format!("{}-section", self.id)) - { - // Restore MSI-X - if let Some(msix_config) = &self.msix_config { - let id = msix_config.lock().unwrap().id(); - if let Some(msix_snapshot) = snapshot.snapshots.get(&id) { - msix_config - .lock() - .unwrap() - .restore(*msix_snapshot.clone())?; - } - } - - // Restore VirtioPciCommonConfig - if let Some(virtio_config_snapshot) = snapshot.snapshots.get(&self.common_config.id()) { - self.common_config - .restore(*virtio_config_snapshot.clone())?; - } - - // Restore PciConfiguration - if let Some(pci_config_snapshot) = snapshot.snapshots.get(&self.configuration.id()) { - self.configuration.restore(*pci_config_snapshot.clone())?; - } - - // First restore the status of the virtqueues. - self.set_state(&virtio_pci_dev_section.to_versioned_state()?) - .map_err(|e| { - MigratableError::Restore(anyhow!( - "Could not restore VIRTIO_PCI_DEVICE state {:?}", - e - )) - })?; - - // Then we can activate the device, as we know at this point that - // the virtqueues are in the right state and the device is ready - // to be activated, which will spawn each virtio worker thread. - if self.device_activated.load(Ordering::SeqCst) && self.is_driver_ready() { - self.activate().map_err(|e| { - MigratableError::Restore(anyhow!("Failed activating the device: {:?}", e)) - })?; - } - - return Ok(()); - } - - Err(MigratableError::Restore(anyhow!( - "Could not find VIRTIO_PCI_DEVICE snapshot section" - ))) - } } impl Transportable for VirtioPciDevice {} impl Migratable for VirtioPciDevice {} diff --git a/vmm/src/device_manager.rs b/vmm/src/device_manager.rs index 45640b82a..dcedabbea 100644 --- a/vmm/src/device_manager.rs +++ b/vmm/src/device_manager.rs @@ -3515,6 +3515,7 @@ impl DeviceManager { pci_segment_id > 0 || device_type != VirtioDeviceType::Block as u32, dma_handler, self.pending_activations.clone(), + vm_migration::snapshot_from_id(self.snapshot.as_ref(), id.as_str()), ) .map_err(DeviceManagerError::VirtioDevice)?, ));