diff --git a/virtio-devices/src/block.rs b/virtio-devices/src/block.rs index 70a9ad1ce..7d17cc176 100644 --- a/virtio-devices/src/block.rs +++ b/virtio-devices/src/block.rs @@ -17,7 +17,6 @@ use crate::seccomp_filters::{get_seccomp_filter, Thread}; use crate::VirtioInterrupt; use anyhow::anyhow; use block_util::{build_disk_image_id, Request, RequestType, VirtioBlockConfig}; -use libc::EFD_NONBLOCK; use seccomp::{SeccompAction, SeccompFilter}; use std::collections::HashMap; use std::io::{self, Read, Seek, SeekFrom, Write}; @@ -261,18 +260,10 @@ impl EpollHelperHandler for BlockEpollHandler { pub struct Block { common: VirtioCommon, id: String, - kill_evt: Option, disk_image: Arc>, disk_path: PathBuf, disk_nsectors: u64, config: VirtioBlockConfig, - queue_evts: Option>, - interrupt_cb: Option>, - epoll_threads: Option>>, - pause_evt: Option, - paused: Arc, - paused_sync: Arc, - queue_size: Vec, writeback: Arc, counters: BlockCounters, seccomp_action: SeccompAction, @@ -338,22 +329,17 @@ impl Block { Ok(Block { common: VirtioCommon { + device_type: VirtioDeviceType::TYPE_BLOCK as u32, avail_features, + paused_sync: Some(Arc::new(Barrier::new(num_queues + 1))), + queue_sizes: vec![queue_size; num_queues], ..Default::default() }, id, - kill_evt: None, disk_image: Arc::new(Mutex::new(disk_image)), disk_path, disk_nsectors, config, - queue_evts: None, - interrupt_cb: None, - epoll_threads: None, - pause_evt: None, - paused: Arc::new(AtomicBool::new(false)), - paused_sync: Arc::new(Barrier::new(num_queues + 1)), - queue_size: vec![queue_size; num_queues], writeback: Arc::new(AtomicBool::new(true)), counters: BlockCounters::default(), seccomp_action, @@ -401,22 +387,13 @@ impl Block { } } -impl Drop for Block { - fn drop(&mut self) { - if let Some(kill_evt) = self.kill_evt.take() { - // Ignore the result because there is nothing we can do about it. - let _ = kill_evt.write(1); - } - } -} - impl VirtioDevice for Block { fn device_type(&self) -> u32 { - VirtioDeviceType::TYPE_BLOCK as u32 + self.common.device_type } fn queue_max_sizes(&self) -> &[u16] { - self.queue_size.as_slice() + self.common.queue_sizes.as_slice() } fn features(&self) -> u64 { @@ -456,54 +433,35 @@ impl VirtioDevice for Block { mut queues: Vec, mut queue_evts: Vec, ) -> ActivateResult { - if queues.len() != self.queue_size.len() || queue_evts.len() != self.queue_size.len() { - error!( - "Cannot perform activate. Expected {} queue(s), got {}", - self.queue_size.len(), - queues.len() - ); - return Err(ActivateError::BadActivate); - } - - let (self_kill_evt, kill_evt) = EventFd::new(EFD_NONBLOCK) - .and_then(|e| Ok((e.try_clone()?, e))) - .map_err(|e| { - error!("failed creating kill EventFd pair: {}", e); - ActivateError::BadActivate - })?; - - self.kill_evt = Some(self_kill_evt); - - let (self_pause_evt, pause_evt) = EventFd::new(EFD_NONBLOCK) - .and_then(|e| Ok((e.try_clone()?, e))) - .map_err(|e| { - error!("failed creating pause EventFd pair: {}", e); - ActivateError::BadActivate - })?; - self.pause_evt = Some(self_pause_evt); - - // Save the interrupt EventFD as we need to return it on reset - // but clone it to pass into the thread. - self.interrupt_cb = Some(interrupt_cb.clone()); - - let mut tmp_queue_evts: Vec = Vec::new(); - for queue_evt in queue_evts.iter() { - // Save the queue EventFD as we need to return it on reset - // but clone it to pass into the thread. - tmp_queue_evts.push(queue_evt.try_clone().map_err(|e| { - error!("failed to clone queue EventFd: {}", e); - ActivateError::BadActivate - })?); - } - self.queue_evts = Some(tmp_queue_evts); + self.common.activate(&queues, &queue_evts, &interrupt_cb)?; let disk_image_id = build_disk_image_id(&self.disk_path); let event_idx = self.common.feature_acked(VIRTIO_RING_F_EVENT_IDX.into()); self.update_writeback(); let mut epoll_threads = Vec::new(); - for _ in 0..self.queue_size.len() { + for _ in 0..self.common.queue_sizes.len() { let queue_evt = queue_evts.remove(0); + let kill_evt = self + .common + .kill_evt + .as_ref() + .unwrap() + .try_clone() + .map_err(|e| { + error!("failed to clone kill_evt eventfd: {}", e); + ActivateError::BadActivate + })?; + let pause_evt = self + .common + .pause_evt + .as_ref() + .unwrap() + .try_clone() + .map_err(|e| { + error!("failed to clone pause_evt eventfd: {}", e); + ActivateError::BadActivate + })?; let mut handler = BlockEpollHandler { queue: queues.remove(0), mem: mem.clone(), @@ -511,8 +469,8 @@ impl VirtioDevice for Block { disk_nsectors: self.disk_nsectors, interrupt_cb: interrupt_cb.clone(), disk_image_id: disk_image_id.clone(), - kill_evt: kill_evt.try_clone().unwrap(), - pause_evt: pause_evt.try_clone().unwrap(), + kill_evt, + pause_evt, event_idx, writeback: self.writeback.clone(), counters: self.counters.clone(), @@ -521,8 +479,8 @@ impl VirtioDevice for Block { handler.queue.set_event_idx(event_idx); - let paused = self.paused.clone(); - let paused_sync = self.paused_sync.clone(); + let paused = self.common.paused.clone(); + let paused_sync = self.common.paused_sync.clone(); // Retrieve seccomp filter for virtio_blk thread let virtio_blk_seccomp_filter = @@ -534,7 +492,7 @@ impl VirtioDevice for Block { .spawn(move || { if let Err(e) = SeccompFilter::apply(virtio_blk_seccomp_filter) { error!("Error applying seccomp filter: {:?}", e); - } else if let Err(e) = handler.run(paused, paused_sync) { + } else if let Err(e) = handler.run(paused, paused_sync.unwrap()) { error!("Error running worker: {:?}", e); } }) @@ -545,27 +503,13 @@ impl VirtioDevice for Block { })?; } - self.epoll_threads = Some(epoll_threads); + self.common.epoll_threads = Some(epoll_threads); Ok(()) } fn reset(&mut self) -> Option<(Arc, Vec)> { - // We first must resume the virtio thread if it was paused. - if self.pause_evt.take().is_some() { - self.resume().ok()?; - } - - if let Some(kill_evt) = self.kill_evt.take() { - // Ignore the result because there is nothing we can do about it. - let _ = kill_evt.write(1); - } - - // Return the interrupt and queue EventFDs - Some(( - self.interrupt_cb.take().unwrap(), - self.queue_evts.take().unwrap(), - )) + self.common.reset() } fn counters(&self) -> Option>> { @@ -592,7 +536,25 @@ impl VirtioDevice for Block { } } -virtio_pausable!(Block, T: 'static + DiskFile + Send); +impl Drop for Block { + fn drop(&mut self) { + if let Some(kill_evt) = self.common.kill_evt.take() { + // Ignore the result because there is nothing we can do about it. + let _ = kill_evt.write(1); + } + } +} + +impl Pausable for Block { + fn pause(&mut self) -> result::Result<(), MigratableError> { + self.common.pause() + } + + fn resume(&mut self) -> result::Result<(), MigratableError> { + self.common.resume() + } +} + impl Snapshottable for Block { fn id(&self) -> String { self.id.clone() diff --git a/virtio-devices/src/device.rs b/virtio-devices/src/device.rs index 9e6d1f0d2..80d801f89 100644 --- a/virtio-devices/src/device.rs +++ b/virtio-devices/src/device.rs @@ -6,12 +6,18 @@ // // SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause -use crate::{ActivateResult, Error, Queue}; +use crate::{ActivateError, ActivateResult, Error, Queue}; +use libc::EFD_NONBLOCK; use std::collections::HashMap; use std::io::Write; use std::num::Wrapping; -use std::sync::Arc; +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, Barrier, +}; +use std::thread; use vm_memory::{GuestAddress, GuestMemoryAtomic, GuestMemoryMmap, GuestUsize}; +use vm_migration::{MigratableError, Pausable}; use vm_virtio::VirtioDeviceType; use vmm_sys_util::eventfd::EventFd; @@ -187,6 +193,15 @@ pub trait DmaRemapping: Send + Sync { pub struct VirtioCommon { pub avail_features: u64, pub acked_features: u64, + pub kill_evt: Option, + pub interrupt_cb: Option>, + pub queue_evts: Option>, + pub pause_evt: Option, + pub paused: Arc, + pub paused_sync: Option>, + pub epoll_threads: Option>>, + pub queue_sizes: Vec, + pub device_type: u32, } impl VirtioCommon { @@ -206,6 +221,107 @@ impl VirtioCommon { } self.acked_features |= v; } + + pub fn activate( + &mut self, + queues: &[Queue], + queue_evts: &[EventFd], + interrupt_cb: &Arc, + ) -> ActivateResult { + if queues.len() != self.queue_sizes.len() || queue_evts.len() != self.queue_sizes.len() { + error!( + "Cannot perform activate. Expected {} queue(s), got {}", + self.queue_sizes.len(), + queues.len() + ); + return Err(ActivateError::BadActivate); + } + + let kill_evt = EventFd::new(EFD_NONBLOCK).map_err(|e| { + error!("failed creating kill EventFd: {}", e); + ActivateError::BadActivate + })?; + self.kill_evt = Some(kill_evt); + + let pause_evt = EventFd::new(EFD_NONBLOCK).map_err(|e| { + error!("failed creating pause EventFd: {}", e); + ActivateError::BadActivate + })?; + self.pause_evt = Some(pause_evt); + + // Save the interrupt EventFD as we need to return it on reset + // but clone it to pass into the thread. + self.interrupt_cb = Some(interrupt_cb.clone()); + + let mut tmp_queue_evts: Vec = Vec::new(); + for queue_evt in queue_evts.iter() { + // Save the queue EventFD as we need to return it on reset + // but clone it to pass into the thread. + tmp_queue_evts.push(queue_evt.try_clone().map_err(|e| { + error!("failed to clone queue EventFd: {}", e); + ActivateError::BadActivate + })?); + } + self.queue_evts = Some(tmp_queue_evts); + Ok(()) + } + + pub fn reset(&mut self) -> Option<(Arc, Vec)> { + // We first must resume the virtio thread if it was paused. + if self.pause_evt.take().is_some() { + self.resume().ok()?; + } + + if let Some(kill_evt) = self.kill_evt.take() { + // Ignore the result because there is nothing we can do about it. + let _ = kill_evt.write(1); + } + + // Return the interrupt and queue EventFDs + Some(( + self.interrupt_cb.take().unwrap(), + self.queue_evts.take().unwrap(), + )) + } +} + +impl Pausable for VirtioCommon { + fn pause(&mut self) -> std::result::Result<(), MigratableError> { + debug!( + "Pausing virtio-{}", + VirtioDeviceType::from(self.device_type) + ); + self.paused.store(true, Ordering::SeqCst); + if let Some(pause_evt) = &self.pause_evt { + pause_evt + .write(1) + .map_err(|e| MigratableError::Pause(e.into()))?; + + // Wait for all threads to acknowledge the pause before going + // any further. This is exclusively performed when pause_evt + // eventfd is Some(), as this means the virtio device has been + // activated. One specific case where the device can be paused + // while it hasn't been yet activated is snapshot/restore. + self.paused_sync.as_ref().unwrap().wait(); + } + + Ok(()) + } + + fn resume(&mut self) -> std::result::Result<(), MigratableError> { + debug!( + "Resuming virtio-{}", + VirtioDeviceType::from(self.device_type) + ); + self.paused.store(false, Ordering::SeqCst); + if let Some(epoll_threads) = &self.epoll_threads { + for t in epoll_threads.iter() { + t.thread().unpark(); + } + } + + Ok(()) + } } #[macro_export]