diff --git a/arch/src/x86_64/mod.rs b/arch/src/x86_64/mod.rs index f0cee7526..6827ec34d 100644 --- a/arch/src/x86_64/mod.rs +++ b/arch/src/x86_64/mod.rs @@ -37,6 +37,9 @@ const TSC_DEADLINE_TIMER_ECX_BIT: u8 = 24; // tsc deadline timer ecx bit. const HYPERVISOR_ECX_BIT: u8 = 31; // Hypervisor ecx bit. const MTRR_EDX_BIT: u8 = 12; // Hypervisor ecx bit. const INVARIANT_TSC_EDX_BIT: u8 = 8; // Invariant TSC bit on 0x8000_0007 EDX +const AMX_BF16: u8 = 22; // AMX tile computation on bfloat16 numbers +const AMX_TILE: u8 = 24; // AMX tile load/store instructions +const AMX_INT8: u8 = 25; // AMX tile computation on 8-bit integers // KVM feature bits #[cfg(feature = "tdx")] @@ -151,6 +154,7 @@ pub struct CpuidConfig { pub kvm_hyperv: bool, #[cfg(feature = "tdx")] pub tdx: bool, + pub amx: bool, } #[derive(Debug)] @@ -640,6 +644,12 @@ pub fn generate_common_cpuid( // Update some existing CPUID for entry in cpuid.as_mut_slice().iter_mut() { match entry.function { + // Clear AMX related bits if the AMX feature is not enabled + 0x7 => { + if !config.amx && entry.index == 0 { + entry.edx &= !(1 << AMX_BF16 | 1 << AMX_TILE | 1 << AMX_INT8) + } + } 0xd => { #[cfg(feature = "tdx")] diff --git a/vmm/src/cpu.rs b/vmm/src/cpu.rs index f0cbfe653..ba8f24667 100644 --- a/vmm/src/cpu.rs +++ b/vmm/src/cpu.rs @@ -745,6 +745,7 @@ impl CpuManager { kvm_hyperv: self.config.kvm_hyperv, #[cfg(feature = "tdx")] tdx, + amx: self.config.features.amx, }, ) .map_err(Error::CommonCpuId)? diff --git a/vmm/src/lib.rs b/vmm/src/lib.rs index c7e61f636..96425b599 100644 --- a/vmm/src/lib.rs +++ b/vmm/src/lib.rs @@ -1665,6 +1665,7 @@ impl Vmm { let common_cpuid = { #[cfg(feature = "tdx")] let tdx = vm_config.lock().unwrap().is_tdx_enabled(); + let amx = vm_config.lock().unwrap().cpus.features.amx; let phys_bits = vm::physical_bits(&hypervisor, vm_config.lock().unwrap().cpus.max_phys_bits); arch::generate_common_cpuid( @@ -1675,6 +1676,7 @@ impl Vmm { kvm_hyperv: vm_config.lock().unwrap().cpus.kvm_hyperv, #[cfg(feature = "tdx")] tdx, + amx, }, ) .map_err(|e| { @@ -1866,6 +1868,7 @@ impl Vmm { kvm_hyperv: vm_config.cpus.kvm_hyperv, #[cfg(feature = "tdx")] tdx: vm_config.is_tdx_enabled(), + amx: vm_config.cpus.features.amx, }, ) .map_err(|e| { diff --git a/vmm/src/vm.rs b/vmm/src/vm.rs index 6f474779a..74f7f9dd0 100644 --- a/vmm/src/vm.rs +++ b/vmm/src/vm.rs @@ -2409,6 +2409,7 @@ impl Snapshottable for Vm { #[cfg(all(feature = "kvm", target_arch = "x86_64"))] let common_cpuid = { + let amx = self.config.lock().unwrap().cpus.features.amx; let phys_bits = physical_bits( &self.hypervisor, self.config.lock().unwrap().cpus.max_phys_bits, @@ -2421,6 +2422,7 @@ impl Snapshottable for Vm { kvm_hyperv: self.config.lock().unwrap().cpus.kvm_hyperv, #[cfg(feature = "tdx")] tdx: tdx_enabled, + amx, }, ) .map_err(|e| {