diff --git a/virtio-devices/src/epoll_helper.rs b/virtio-devices/src/epoll_helper.rs index a0eb04519..21216c283 100644 --- a/virtio-devices/src/epoll_helper.rs +++ b/virtio-devices/src/epoll_helper.rs @@ -24,6 +24,7 @@ pub struct EpollHelper { pub enum EpollHelperError { CreateFd(std::io::Error), Ctl(std::io::Error), + IoError(std::io::Error), Wait(std::io::Error), } @@ -57,11 +58,35 @@ impl EpollHelper { } pub fn add_event(&mut self, fd: RawFd, id: u16) -> std::result::Result<(), EpollHelperError> { + self.add_event_custom(fd, id, epoll::Events::EPOLLIN) + } + + pub fn add_event_custom( + &mut self, + fd: RawFd, + id: u16, + evts: epoll::Events, + ) -> std::result::Result<(), EpollHelperError> { epoll::ctl( self.epoll_file.as_raw_fd(), epoll::ControlOptions::EPOLL_CTL_ADD, fd, - epoll::Event::new(epoll::Events::EPOLLIN, id.into()), + epoll::Event::new(evts, id.into()), + ) + .map_err(EpollHelperError::Ctl) + } + + pub fn del_event_custom( + &mut self, + fd: RawFd, + id: u16, + evts: epoll::Events, + ) -> std::result::Result<(), EpollHelperError> { + epoll::ctl( + self.epoll_file.as_raw_fd(), + epoll::ControlOptions::EPOLL_CTL_DEL, + fd, + epoll::Event::new(evts, id.into()), ) .map_err(EpollHelperError::Ctl) } diff --git a/virtio-devices/src/vhost_user/net.rs b/virtio-devices/src/vhost_user/net.rs index 0e2820dc2..39575011f 100644 --- a/virtio-devices/src/vhost_user/net.rs +++ b/virtio-devices/src/vhost_user/net.rs @@ -5,7 +5,10 @@ use super::super::{ ActivateError, ActivateResult, EpollHelper, EpollHelperError, EpollHelperHandler, Queue, VirtioCommon, VirtioDevice, VirtioDeviceType, EPOLL_HELPER_EVENT_LAST, }; -use super::vu_common_ctrl::*; +use super::vu_common_ctrl::{ + add_memory_region, negotiate_features_vhost_user, reinitialize_vhost_user, reset_vhost_user, + setup_vhost_user, update_mem_table, VhostUserConfig, +}; use super::{Error, Result}; use crate::seccomp_filters::{get_seccomp_filter, Thread}; use crate::{VirtioInterrupt, VIRTIO_F_RING_EVENT_IDX, VIRTIO_F_VERSION_1}; @@ -16,7 +19,7 @@ use std::os::unix::io::AsRawFd; use std::os::unix::net::UnixListener; use std::result; use std::sync::atomic::AtomicBool; -use std::sync::{Arc, Barrier}; +use std::sync::{Arc, Barrier, Mutex}; use std::thread; use std::vec::Vec; use vhost::vhost_user::message::{VhostUserProtocolFeatures, VhostUserVirtioFeatures}; @@ -90,15 +93,133 @@ impl EpollHelperHandler for NetCtrlEpollHandler { } } +/// Reconnection thread +// Event meaning the connection was closed. +const HUP_CONNECTION_EVENT: u16 = EPOLL_HELPER_EVENT_LAST + 1; + +pub struct ReconnectEpollHandler { + pub vu: Arc>, + pub mem: GuestMemoryAtomic, + pub kill_evt: EventFd, + pub pause_evt: EventFd, + pub queues: Vec, + pub queue_evts: Vec, + pub virtio_interrupt: Arc, + pub acked_features: u64, + pub acked_protocol_features: u64, + pub socket_path: String, + pub server: bool, +} + +impl ReconnectEpollHandler { + pub fn run( + &mut self, + paused: Arc, + paused_sync: Arc, + ) -> std::result::Result<(), EpollHelperError> { + let mut helper = EpollHelper::new(&self.kill_evt, &self.pause_evt)?; + helper.add_event_custom( + self.vu.lock().unwrap().as_raw_fd(), + HUP_CONNECTION_EVENT, + epoll::Events::EPOLLHUP, + )?; + helper.run(paused, paused_sync, self)?; + + Ok(()) + } + + fn reconnect(&mut self, helper: &mut EpollHelper) -> std::result::Result<(), EpollHelperError> { + helper.del_event_custom( + self.vu.lock().unwrap().as_raw_fd(), + HUP_CONNECTION_EVENT, + epoll::Events::EPOLLHUP, + )?; + + let num_queues = self.queues.len() as u64; + + let mut vhost_user_net = if self.server { + std::fs::remove_file(&self.socket_path).map_err(EpollHelperError::IoError)?; + info!("Binding vhost-user-net listener..."); + let listener = + UnixListener::bind(&self.socket_path).map_err(EpollHelperError::IoError)?; + info!("Waiting for incoming vhost-user-net connection..."); + let (stream, _) = listener.accept().map_err(EpollHelperError::IoError)?; + + Master::from_stream(stream, num_queues) + } else { + Master::connect(&self.socket_path, num_queues).map_err(|e| { + EpollHelperError::IoError(std::io::Error::new( + std::io::ErrorKind::Other, + format!("failed connecting vhost-user backend{:?}", e), + )) + })? + }; + + // Initialize the backend + reinitialize_vhost_user( + &mut vhost_user_net, + self.mem.memory().deref(), + self.queues.clone(), + self.queue_evts + .iter() + .map(|q| q.try_clone().unwrap()) + .collect(), + &self.virtio_interrupt, + self.acked_features, + self.acked_protocol_features, + ) + .map_err(|e| { + EpollHelperError::IoError(std::io::Error::new( + std::io::ErrorKind::Other, + format!("failed reconnecting vhost-user backend{:?}", e), + )) + })?; + + helper.add_event_custom( + vhost_user_net.as_raw_fd(), + HUP_CONNECTION_EVENT, + epoll::Events::EPOLLHUP, + )?; + + // Update vhost-user reference + let mut vu = self.vu.lock().unwrap(); + *vu = vhost_user_net; + + Ok(()) + } +} + +impl EpollHelperHandler for ReconnectEpollHandler { + fn handle_event(&mut self, helper: &mut EpollHelper, event: &epoll::Event) -> bool { + let ev_type = event.data as u16; + match ev_type { + HUP_CONNECTION_EVENT => { + if let Err(e) = self.reconnect(helper) { + error!("failed to reconnect vhost-user-net backend: {:?}", e); + return true; + } + } + _ => { + error!("Unknown event for vhost-user-net reconnection thread"); + return true; + } + } + + false + } +} + pub struct Net { common: VirtioCommon, id: String, - vhost_user_net: Master, + vhost_user_net: Arc>, config: VirtioNetConfig, guest_memory: Option>, acked_protocol_features: u64, - socket_path: Option, + socket_path: String, + server: bool, ctrl_queue_epoll_thread: Option>, + reconnect_epoll_thread: Option>, seccomp_action: SeccompAction, } @@ -111,8 +232,6 @@ impl Net { server: bool, seccomp_action: SeccompAction, ) -> Result { - let mut socket_path: Option = None; - let mut num_queues = vu_cfg.num_queues; // Filling device and vring features VMM supports. @@ -141,8 +260,6 @@ impl Net { info!("Waiting for incoming vhost-user-net connection..."); let (stream, _) = listener.accept().map_err(Error::AcceptConnection)?; - socket_path = Some(vu_cfg.socket.clone()); - Master::from_stream(stream, num_queues as u64) } else { Master::connect(&vu_cfg.socket, num_queues as u64) @@ -191,16 +308,18 @@ impl Net { queue_sizes: vec![vu_cfg.queue_size; num_queues], avail_features: acked_features, acked_features: 0, - paused_sync: Some(Arc::new(Barrier::new(1))), + paused_sync: Some(Arc::new(Barrier::new(2))), min_queues: DEFAULT_QUEUE_NUMBER as u16, ..Default::default() }, - vhost_user_net, + vhost_user_net: Arc::new(Mutex::new(vhost_user_net)), config, guest_memory: None, acked_protocol_features, - socket_path, + socket_path: vu_cfg.socket, + server, ctrl_queue_epoll_thread: None, + reconnect_epoll_thread: None, seccomp_action, }) } @@ -285,8 +404,9 @@ impl VirtioDevice for Net { let paused = self.common.paused.clone(); // Let's update the barrier as we need 1 for the control queue + 1 - // for the main thread signalling the pause. - self.common.paused_sync = Some(Arc::new(Barrier::new(2))); + // for the reconnect thread + 1 for the main thread signalling the + // pause. + self.common.paused_sync = Some(Arc::new(Barrier::new(3))); let paused_sync = self.common.paused_sync.clone(); // Retrieve seccomp filter for virtio_net_ctl thread @@ -317,15 +437,68 @@ impl VirtioDevice for Net { | (self.common.avail_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits()); setup_vhost_user( - &mut self.vhost_user_net, + &mut self.vhost_user_net.lock().unwrap(), &mem.memory(), - queues, - queue_evts, + queues.clone(), + queue_evts.iter().map(|q| q.try_clone().unwrap()).collect(), &interrupt_cb, backend_acked_features, ) .map_err(ActivateError::VhostUserNetSetup)?; + // Run a dedicated thread for handling potential reconnections with + // the backend. + let kill_evt = self + .common + .kill_evt + .as_ref() + .unwrap() + .try_clone() + .map_err(|e| { + error!("failed to clone kill_evt eventfd: {}", e); + ActivateError::BadActivate + })?; + let pause_evt = self + .common + .pause_evt + .as_ref() + .unwrap() + .try_clone() + .map_err(|e| { + error!("failed to clone pause_evt eventfd: {}", e); + ActivateError::BadActivate + })?; + + let mut reconnect_handler = ReconnectEpollHandler { + vu: self.vhost_user_net.clone(), + mem, + kill_evt, + pause_evt, + queues, + queue_evts, + virtio_interrupt: interrupt_cb, + acked_features: backend_acked_features, + acked_protocol_features: self.acked_protocol_features, + socket_path: self.socket_path.clone(), + server: self.server, + }; + + let paused = self.common.paused.clone(); + let paused_sync = self.common.paused_sync.clone(); + + thread::Builder::new() + .name(format!("{}_reconnect", self.id)) + .spawn(move || { + if let Err(e) = reconnect_handler.run(paused, paused_sync.unwrap()) { + error!("Error running reconnection worker: {:?}", e); + } + }) + .map(|thread| self.reconnect_epoll_thread = Some(thread)) + .map_err(|e| { + error!("failed to clone queue EventFd: {}", e); + ActivateError::BadActivate + })?; + Ok(()) } @@ -335,7 +508,10 @@ impl VirtioDevice for Net { self.common.resume().ok()?; } - if let Err(e) = reset_vhost_user(&mut self.vhost_user_net, self.common.queue_sizes.len()) { + if let Err(e) = reset_vhost_user( + &mut self.vhost_user_net.lock().unwrap(), + self.common.queue_sizes.len(), + ) { error!("Failed to reset vhost-user daemon: {:?}", e); return None; } @@ -352,11 +528,11 @@ impl VirtioDevice for Net { } fn shutdown(&mut self) { - let _ = unsafe { libc::close(self.vhost_user_net.as_raw_fd()) }; + let _ = unsafe { libc::close(self.vhost_user_net.lock().unwrap().as_raw_fd()) }; // Remove socket path if needed - if let Some(socket_path) = &self.socket_path { - let _ = std::fs::remove_file(socket_path); + if self.server { + let _ = std::fs::remove_file(&self.socket_path); } } @@ -366,11 +542,14 @@ impl VirtioDevice for Net { ) -> std::result::Result<(), crate::Error> { if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() != 0 { - add_memory_region(&mut self.vhost_user_net, region) + add_memory_region(&mut self.vhost_user_net.lock().unwrap(), region) .map_err(crate::Error::VhostUserAddMemoryRegion) } else if let Some(guest_memory) = &self.guest_memory { - update_mem_table(&mut self.vhost_user_net, guest_memory.memory().deref()) - .map_err(crate::Error::VhostUserUpdateMemory) + update_mem_table( + &mut self.vhost_user_net.lock().unwrap(), + guest_memory.memory().deref(), + ) + .map_err(crate::Error::VhostUserUpdateMemory) } else { Ok(()) } @@ -388,6 +567,11 @@ impl Pausable for Net { if let Some(ctrl_queue_epoll_thread) = &self.ctrl_queue_epoll_thread { ctrl_queue_epoll_thread.thread().unpark(); } + + if let Some(reconnect_epoll_thread) = &self.reconnect_epoll_thread { + reconnect_epoll_thread.thread().unpark(); + } + Ok(()) } } diff --git a/virtio-devices/src/vhost_user/vu_common_ctrl.rs b/virtio-devices/src/vhost_user/vu_common_ctrl.rs index e97907039..a013ae9a7 100644 --- a/virtio-devices/src/vhost_user/vu_common_ctrl.rs +++ b/virtio-devices/src/vhost_user/vu_common_ctrl.rs @@ -173,3 +173,34 @@ pub fn reset_vhost_user(vu: &mut Master, num_queues: usize) -> Result<()> { // Reset the owner. vu.reset_owner().map_err(Error::VhostUserResetOwner) } + +pub fn reinitialize_vhost_user( + vu: &mut Master, + mem: &GuestMemoryMmap, + queues: Vec, + queue_evts: Vec, + virtio_interrupt: &Arc, + acked_features: u64, + acked_protocol_features: u64, +) -> Result<()> { + vu.set_owner().map_err(Error::VhostUserSetOwner)?; + vu.get_features().map_err(Error::VhostUserGetFeatures)?; + + if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 { + if let Some(acked_protocol_features) = + VhostUserProtocolFeatures::from_bits(acked_protocol_features) + { + vu.set_protocol_features(acked_protocol_features) + .map_err(Error::VhostUserSetProtocolFeatures)?; + } + } + + setup_vhost_user( + vu, + mem, + queues, + queue_evts, + virtio_interrupt, + acked_features, + ) +}