diff --git a/pci/src/vfio.rs b/pci/src/vfio.rs index 1f1efadfc..4502d1cb8 100644 --- a/pci/src/vfio.rs +++ b/pci/src/vfio.rs @@ -419,6 +419,7 @@ pub(crate) struct VfioCommon { pub(crate) legacy_interrupt_group: Option>, pub(crate) vfio_wrapper: Arc, pub(crate) patches: HashMap, + x_nv_gpudirect_clique: Option, } impl VfioCommon { @@ -429,6 +430,7 @@ impl VfioCommon { subclass: &dyn PciSubclass, bdf: PciBdf, snapshot: Option, + x_nv_gpudirect_clique: Option, ) -> Result { let pci_configuration_state = vm_migration::versioned_state_from_id(snapshot.as_ref(), PCI_CONFIGURATION_ID) @@ -465,6 +467,7 @@ impl VfioCommon { legacy_interrupt_group, vfio_wrapper, patches: HashMap::new(), + x_nv_gpudirect_clique, }; let state: Option = snapshot @@ -859,15 +862,15 @@ impl VfioCommon { } pub(crate) fn parse_capabilities(&mut self, bdf: PciBdf) { - let mut cap_next = self + let mut cap_iter = self .vfio_wrapper .read_config_byte(PCI_CONFIG_CAPABILITY_OFFSET); let mut pci_express_cap_found = false; let mut power_management_cap_found = false; - while cap_next != 0 { - let cap_id = self.vfio_wrapper.read_config_byte(cap_next.into()); + while cap_iter != 0 { + let cap_id = self.vfio_wrapper.read_config_byte(cap_iter.into()); match PciCapabilityId::from(cap_id) { PciCapabilityId::MessageSignalledInterrupts => { @@ -875,8 +878,8 @@ impl VfioCommon { if irq_info.count > 0 { // Parse capability only if the VFIO device // supports MSI. - let msg_ctl = self.parse_msi_capabilities(cap_next); - self.initialize_msi(msg_ctl, cap_next as u32, None); + let msg_ctl = self.parse_msi_capabilities(cap_iter); + self.initialize_msi(msg_ctl, cap_iter as u32, None); } } } @@ -886,8 +889,8 @@ impl VfioCommon { if irq_info.count > 0 { // Parse capability only if the VFIO device // supports MSI-X. - let msix_cap = self.parse_msix_capabilities(cap_next); - self.initialize_msix(msix_cap, cap_next as u32, bdf, None); + let msix_cap = self.parse_msix_capabilities(cap_iter); + self.initialize_msix(msix_cap, cap_iter as u32, bdf, None); } } } @@ -896,7 +899,16 @@ impl VfioCommon { _ => {} }; - cap_next = self.vfio_wrapper.read_config_byte((cap_next + 1).into()); + let cap_next = self.vfio_wrapper.read_config_byte((cap_iter + 1).into()); + if cap_next == 0 { + break; + } + + cap_iter = cap_next; + } + + if let Some(clique_id) = self.x_nv_gpudirect_clique { + self.add_nv_gpudirect_clique_cap(cap_iter, clique_id); } if pci_express_cap_found && power_management_cap_found { @@ -904,6 +916,37 @@ impl VfioCommon { } } + fn add_nv_gpudirect_clique_cap(&mut self, cap_iter: u8, clique_id: u8) { + // Turing, Ampere, Hopper, and Lovelace GPUs have dedicated space + // at 0xD4 for this capability. + let cap_offset = 0xd4u32; + + let reg_idx = (cap_iter / 4) as usize; + self.patches.insert( + reg_idx, + ConfigPatch { + mask: 0x0000_ff00, + patch: cap_offset << 8, + }, + ); + + let reg_idx = (cap_offset / 4) as usize; + self.patches.insert( + reg_idx, + ConfigPatch { + mask: 0xffff_ffff, + patch: 0x50080009u32, + }, + ); + self.patches.insert( + reg_idx + 1, + ConfigPatch { + mask: 0xffff_ffff, + patch: u32::from(clique_id) << 19 | 0x5032, + }, + ); + } + fn parse_extended_capabilities(&mut self) { let mut current_offset = PCI_CONFIG_EXTENDED_CAPABILITY_OFFSET; @@ -1351,6 +1394,7 @@ impl VfioPciDevice { bdf: PciBdf, memory_slot: Arc u32 + Send + Sync>, snapshot: Option, + x_nv_gpudirect_clique: Option, ) -> Result { let device = Arc::new(device); device.reset(); @@ -1364,6 +1408,7 @@ impl VfioPciDevice { &PciVfioSubclass::VfioSubclass, bdf, vm_migration::snapshot_from_id(snapshot.as_ref(), VFIO_COMMON_ID), + x_nv_gpudirect_clique, )?; let vfio_pci_device = VfioPciDevice { diff --git a/pci/src/vfio_user.rs b/pci/src/vfio_user.rs index c5e64b970..da4048ac1 100644 --- a/pci/src/vfio_user.rs +++ b/pci/src/vfio_user.rs @@ -94,6 +94,7 @@ impl VfioUserPciDevice { &PciVfioUserSubclass::VfioUserSubclass, bdf, vm_migration::snapshot_from_id(snapshot.as_ref(), VFIO_COMMON_ID), + None, ) .map_err(VfioUserPciDeviceError::CreateVfioCommon)?; diff --git a/vmm/src/api/openapi/cloud-hypervisor.yaml b/vmm/src/api/openapi/cloud-hypervisor.yaml index e82697713..81d5a6a81 100644 --- a/vmm/src/api/openapi/cloud-hypervisor.yaml +++ b/vmm/src/api/openapi/cloud-hypervisor.yaml @@ -1046,7 +1046,9 @@ components: format: int16 id: type: string - + x_nv_gpudirect_clique: + type: integer + format: int8 TpmConfig: required: - socket diff --git a/vmm/src/config.rs b/vmm/src/config.rs index 4848f4a6c..c41bf2b5e 100644 --- a/vmm/src/config.rs +++ b/vmm/src/config.rs @@ -1746,7 +1746,12 @@ impl DeviceConfig { pub fn parse(device: &str) -> Result { let mut parser = OptionParser::new(); - parser.add("path").add("id").add("iommu").add("pci_segment"); + parser + .add("path") + .add("id") + .add("iommu") + .add("pci_segment") + .add("x_nv_gpudirect_clique"); parser.parse(device).map_err(Error::ParseDevice)?; let path = parser @@ -1763,12 +1768,15 @@ impl DeviceConfig { .convert::("pci_segment") .map_err(Error::ParseDevice)? .unwrap_or_default(); - + let x_nv_gpudirect_clique = parser + .convert::("x_nv_gpudirect_clique") + .map_err(Error::ParseDevice)?; Ok(DeviceConfig { path, iommu, id, pci_segment, + x_nv_gpudirect_clique, }) } @@ -3324,6 +3332,7 @@ mod tests { id: None, iommu: false, pci_segment: 0, + x_nv_gpudirect_clique: None, } } diff --git a/vmm/src/device_manager.rs b/vmm/src/device_manager.rs index 6c8e5bc94..6f6e45c74 100644 --- a/vmm/src/device_manager.rs +++ b/vmm/src/device_manager.rs @@ -3510,6 +3510,7 @@ impl DeviceManager { pci_device_bdf, Arc::new(move || memory_manager.lock().unwrap().allocate_memory_slot()), vm_migration::snapshot_from_id(self.snapshot.as_ref(), vfio_name.as_str()), + device_cfg.x_nv_gpudirect_clique, ) .map_err(DeviceManagerError::VfioPciCreate)?; diff --git a/vmm/src/vm_config.rs b/vmm/src/vm_config.rs index 5c11b3a40..27233fb05 100644 --- a/vmm/src/vm_config.rs +++ b/vmm/src/vm_config.rs @@ -434,6 +434,8 @@ pub struct DeviceConfig { pub id: Option, #[serde(default)] pub pci_segment: u16, + #[serde(default)] + pub x_nv_gpudirect_clique: Option, } #[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]