diff --git a/virtio-devices/src/vhost_user/mod.rs b/virtio-devices/src/vhost_user/mod.rs index 5e65766ba..2cecb3e6f 100644 --- a/virtio-devices/src/vhost_user/mod.rs +++ b/virtio-devices/src/vhost_user/mod.rs @@ -41,6 +41,8 @@ pub enum Error { FailedSignalingUsedQueue(io::Error), /// Failed to read vhost eventfd. MemoryRegions(MmapError), + /// Failed removing socket path + RemoveSocketPath(io::Error), /// Failed to create master. VhostUserCreateMaster(VhostError), /// Failed to open vhost device. diff --git a/virtio-devices/src/vhost_user/net.rs b/virtio-devices/src/vhost_user/net.rs index 39575011f..ae91f50a3 100644 --- a/virtio-devices/src/vhost_user/net.rs +++ b/virtio-devices/src/vhost_user/net.rs @@ -6,8 +6,8 @@ use super::super::{ VirtioCommon, VirtioDevice, VirtioDeviceType, EPOLL_HELPER_EVENT_LAST, }; use super::vu_common_ctrl::{ - add_memory_region, negotiate_features_vhost_user, reinitialize_vhost_user, reset_vhost_user, - setup_vhost_user, update_mem_table, VhostUserConfig, + add_memory_region, connect_vhost_user, negotiate_features_vhost_user, reinitialize_vhost_user, + reset_vhost_user, setup_vhost_user, update_mem_table, VhostUserConfig, }; use super::{Error, Result}; use crate::seccomp_filters::{get_seccomp_filter, Thread}; @@ -16,7 +16,6 @@ use net_util::{build_net_config_space, CtrlQueue, MacAddr, VirtioNetConfig}; use seccomp::{SeccompAction, SeccompFilter}; use std::ops::Deref; use std::os::unix::io::AsRawFd; -use std::os::unix::net::UnixListener; use std::result; use std::sync::atomic::AtomicBool; use std::sync::{Arc, Barrier, Mutex}; @@ -135,25 +134,18 @@ impl ReconnectEpollHandler { epoll::Events::EPOLLHUP, )?; - let num_queues = self.queues.len() as u64; - - let mut vhost_user_net = if self.server { - std::fs::remove_file(&self.socket_path).map_err(EpollHelperError::IoError)?; - info!("Binding vhost-user-net listener..."); - let listener = - UnixListener::bind(&self.socket_path).map_err(EpollHelperError::IoError)?; - info!("Waiting for incoming vhost-user-net connection..."); - let (stream, _) = listener.accept().map_err(EpollHelperError::IoError)?; - - Master::from_stream(stream, num_queues) - } else { - Master::connect(&self.socket_path, num_queues).map_err(|e| { - EpollHelperError::IoError(std::io::Error::new( - std::io::ErrorKind::Other, - format!("failed connecting vhost-user backend{:?}", e), - )) - })? - }; + let mut vhost_user_net = connect_vhost_user( + self.server, + &self.socket_path, + self.queues.len() as u64, + true, + ) + .map_err(|e| { + EpollHelperError::IoError(std::io::Error::new( + std::io::ErrorKind::Other, + format!("failed connecting vhost-user backend{:?}", e), + )) + })?; // Initialize the backend reinitialize_vhost_user( @@ -254,17 +246,8 @@ impl Net { let mut config = VirtioNetConfig::default(); build_net_config_space(&mut config, mac_addr, num_queues, &mut avail_features); - let mut vhost_user_net = if server { - info!("Binding vhost-user-net listener..."); - let listener = UnixListener::bind(&vu_cfg.socket).map_err(Error::BindSocket)?; - info!("Waiting for incoming vhost-user-net connection..."); - let (stream, _) = listener.accept().map_err(Error::AcceptConnection)?; - - Master::from_stream(stream, num_queues as u64) - } else { - Master::connect(&vu_cfg.socket, num_queues as u64) - .map_err(Error::VhostUserCreateMaster)? - }; + let mut vhost_user_net = + connect_vhost_user(server, &vu_cfg.socket, num_queues as u64, false)?; let avail_protocol_features = VhostUserProtocolFeatures::MQ | VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS diff --git a/virtio-devices/src/vhost_user/vu_common_ctrl.rs b/virtio-devices/src/vhost_user/vu_common_ctrl.rs index a013ae9a7..e9cbed435 100644 --- a/virtio-devices/src/vhost_user/vu_common_ctrl.rs +++ b/virtio-devices/src/vhost_user/vu_common_ctrl.rs @@ -6,6 +6,7 @@ use super::{Error, Result}; use crate::{get_host_address_range, VirtioInterrupt, VirtioInterruptType}; use std::convert::TryInto; use std::os::unix::io::AsRawFd; +use std::os::unix::net::UnixListener; use std::sync::Arc; use std::vec::Vec; use vhost::vhost_user::message::{VhostUserProtocolFeatures, VhostUserVirtioFeatures}; @@ -204,3 +205,25 @@ pub fn reinitialize_vhost_user( acked_features, ) } + +pub fn connect_vhost_user( + server: bool, + socket_path: &str, + num_queues: u64, + unlink_socket: bool, +) -> Result { + Ok(if server { + if unlink_socket { + std::fs::remove_file(socket_path).map_err(Error::RemoveSocketPath)?; + } + + info!("Binding vhost-user listener..."); + let listener = UnixListener::bind(socket_path).map_err(Error::BindSocket)?; + info!("Waiting for incoming vhost-user connection..."); + let (stream, _) = listener.accept().map_err(Error::AcceptConnection)?; + + Master::from_stream(stream, num_queues) + } else { + Master::connect(socket_path, num_queues).map_err(Error::VhostUserConnect)? + }) +}