diff --git a/vmm/src/interrupt.rs b/vmm/src/interrupt.rs index e2d12718c..8c76db84e 100644 --- a/vmm/src/interrupt.rs +++ b/vmm/src/interrupt.rs @@ -82,8 +82,8 @@ pub struct MsiInterruptGroup { irq_routes: HashMap, } -pub trait MsiInterruptGroupOps { - fn set_gsi_routes(&self) -> Result<()>; +pub trait MsiInterruptGroupOps { + fn set_gsi_routes(&self, routes: &HashMap>) -> Result<()>; } pub trait RoutingEntryExt { @@ -112,7 +112,7 @@ impl InterruptSourceGroup for MsiInterruptGroup where E: Send + Sync, RoutingEntry: RoutingEntryExt, - MsiInterruptGroup: MsiInterruptGroupOps, + MsiInterruptGroup: MsiInterruptGroupOps, { fn enable(&self) -> Result<()> { for (_, route) in self.irq_routes.iter() { @@ -152,12 +152,9 @@ where fn update(&self, index: InterruptIndex, config: InterruptSourceConfig) -> Result<()> { if let Some(route) = self.irq_routes.get(&index) { let entry = RoutingEntry::<_>::make_entry(&self.vm, route.gsi, &config)?; - self.gsi_msi_routes - .lock() - .unwrap() - .insert(route.gsi, *entry); - - return self.set_gsi_routes(); + let mut routes = self.gsi_msi_routes.lock().unwrap(); + routes.insert(route.gsi, *entry); + return self.set_gsi_routes(&routes); } Err(io::Error::new( @@ -168,8 +165,8 @@ where fn mask(&self, index: InterruptIndex) -> Result<()> { if let Some(route) = self.irq_routes.get(&index) { - let mut gsi_msi_routes = self.gsi_msi_routes.lock().unwrap(); - if let Some(entry) = gsi_msi_routes.get_mut(&route.gsi) { + let mut routes = self.gsi_msi_routes.lock().unwrap(); + if let Some(entry) = routes.get_mut(&route.gsi) { entry.masked = true; } else { return Err(io::Error::new( @@ -177,9 +174,7 @@ where format!("mask: No existing route for interrupt index {}", index), )); } - // Drop the guard because set_gsi_routes will try to take the lock again. - drop(gsi_msi_routes); - self.set_gsi_routes()?; + self.set_gsi_routes(&routes)?; return route.disable(&self.vm); } @@ -191,8 +186,8 @@ where fn unmask(&self, index: InterruptIndex) -> Result<()> { if let Some(route) = self.irq_routes.get(&index) { - let mut gsi_msi_routes = self.gsi_msi_routes.lock().unwrap(); - if let Some(entry) = gsi_msi_routes.get_mut(&route.gsi) { + let mut routes = self.gsi_msi_routes.lock().unwrap(); + if let Some(entry) = routes.get_mut(&route.gsi) { entry.masked = false; } else { return Err(io::Error::new( @@ -200,9 +195,7 @@ where format!("mask: No existing route for interrupt index {}", index), )); } - // Drop the guard because set_gsi_routes will try to take the lock again. - drop(gsi_msi_routes); - self.set_gsi_routes()?; + self.set_gsi_routes(&routes)?; return route.enable(&self.vm); } @@ -297,7 +290,7 @@ impl InterruptManager for MsiInterruptManager where E: Send + Sync + 'static, RoutingEntry: RoutingEntryExt, - MsiInterruptGroup: MsiInterruptGroupOps, + MsiInterruptGroup: MsiInterruptGroupOps, { type GroupConfig = MsiIrqGroupConfig; @@ -370,11 +363,13 @@ pub mod kvm { } } - impl MsiInterruptGroupOps for KvmMsiInterruptGroup { - fn set_gsi_routes(&self) -> Result<()> { - let gsi_msi_routes = self.gsi_msi_routes.lock().unwrap(); + impl MsiInterruptGroupOps for KvmMsiInterruptGroup { + fn set_gsi_routes( + &self, + routes: &HashMap>, + ) -> Result<()> { let mut entry_vec: Vec = Vec::new(); - for (_, entry) in gsi_msi_routes.iter() { + for (_, entry) in routes.iter() { if entry.masked { continue; }