vhost_user_backend: Replace Mutex with RwLock when possible

Signed-off-by: Sebastien Boeuf <sebastien.boeuf@intel.com>
This commit is contained in:
Sebastien Boeuf 2019-09-16 11:02:30 -07:00
parent 2e2cad91ae
commit 4ed81894aa

View File

@ -9,7 +9,7 @@ use std::io;
use std::num::Wrapping;
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::result;
use std::sync::{Arc, Mutex};
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use vhost_rs::vhost_user::message::{
VhostUserConfigFlags, VhostUserMemoryRegion, VhostUserProtocolFeatures,
@ -90,7 +90,7 @@ pub struct VhostUserDaemon<S: VhostUserBackend> {
name: String,
sock_path: String,
handler: Arc<Mutex<VhostUserHandler<S>>>,
vring_handler: Arc<Mutex<VringEpollHandler<S>>>,
vring_handler: Arc<RwLock<VringEpollHandler<S>>>,
main_thread: Option<thread::JoinHandle<Result<()>>>,
}
@ -158,7 +158,7 @@ impl<S: VhostUserBackend> VhostUserDaemon<S> {
data: u64,
) -> result::Result<(), io::Error> {
self.vring_handler
.lock()
.read()
.unwrap()
.register_listener(fd, ev_type, data)
}
@ -173,7 +173,7 @@ impl<S: VhostUserBackend> VhostUserDaemon<S> {
data: u64,
) -> result::Result<(), io::Error> {
self.vring_handler
.lock()
.read()
.unwrap()
.unregister_listener(fd, ev_type, data)
}
@ -189,7 +189,7 @@ impl<S: VhostUserBackend> VhostUserDaemon<S> {
/// of `process_queue`. With this twisted trick, all common parts related
/// to the virtqueues can remain part of the library.
pub fn process_queue(&self, q_idx: u16) -> Result<()> {
self.vring_handler.lock().unwrap().process_queue(q_idx)
self.vring_handler.read().unwrap().process_queue(q_idx)
}
}
@ -257,7 +257,7 @@ impl Vring {
struct VringEpollHandler<S: VhostUserBackend> {
backend: Arc<S>,
vrings: Arc<Mutex<Vec<Vring>>>,
vrings: Arc<RwLock<Vec<Vring>>>,
mem: Option<GuestMemoryMmap>,
epoll_fd: RawFd,
}
@ -268,7 +268,7 @@ impl<S: VhostUserBackend> VringEpollHandler<S> {
}
fn process_queue(&self, q_idx: u16) -> Result<()> {
let vring = &mut self.vrings.lock().unwrap()[q_idx as usize];
let vring = &mut self.vrings.write().unwrap()[q_idx as usize];
let mut used_desc_heads = vec![(0, 0); vring.queue.size as usize];
let mut used_count = 0;
if let Some(mem) = &self.mem {
@ -296,11 +296,11 @@ impl<S: VhostUserBackend> VringEpollHandler<S> {
Ok(())
}
fn handle_event(&mut self, device_event: u16, evset: epoll::Events) -> Result<bool> {
let num_queues = self.vrings.lock().unwrap().len();
fn handle_event(&self, device_event: u16, evset: epoll::Events) -> Result<bool> {
let num_queues = self.vrings.read().unwrap().len();
match device_event as usize {
x if x < num_queues => {
if let Some(kick) = &self.vrings.lock().unwrap()[device_event as usize].kick {
if let Some(kick) = &self.vrings.read().unwrap()[device_event as usize].kick {
kick.read().unwrap();
}
@ -313,7 +313,7 @@ impl<S: VhostUserBackend> VringEpollHandler<S> {
}
fn register_vring_listener(&self, q_idx: usize) -> result::Result<(), io::Error> {
if let Some(fd) = &self.vrings.lock().unwrap()[q_idx].kick {
if let Some(fd) = &self.vrings.read().unwrap()[q_idx].kick {
self.register_listener(fd.as_raw_fd(), epoll::Events::EPOLLIN, q_idx as u64)
} else {
Ok(())
@ -321,7 +321,7 @@ impl<S: VhostUserBackend> VringEpollHandler<S> {
}
fn unregister_vring_listener(&self, q_idx: usize) -> result::Result<(), io::Error> {
if let Some(fd) = &self.vrings.lock().unwrap()[q_idx].kick {
if let Some(fd) = &self.vrings.read().unwrap()[q_idx].kick {
self.unregister_listener(fd.as_raw_fd(), epoll::Events::EPOLLIN, q_idx as u64)
} else {
Ok(())
@ -358,7 +358,7 @@ impl<S: VhostUserBackend> VringEpollHandler<S> {
}
struct VringWorker<S: VhostUserBackend> {
handler: Arc<Mutex<VringEpollHandler<S>>>,
handler: Arc<RwLock<VringEpollHandler<S>>>,
}
impl<S: VhostUserBackend> VringWorker<S> {
@ -396,7 +396,7 @@ impl<S: VhostUserBackend> VringWorker<S> {
let ev_type = event.data as u16;
if self.handler.lock().unwrap().handle_event(ev_type, evset)? {
if self.handler.read().unwrap().handle_event(ev_type, evset)? {
break 'epoll;
}
}
@ -408,7 +408,7 @@ impl<S: VhostUserBackend> VringWorker<S> {
struct VhostUserHandler<S: VhostUserBackend> {
backend: Arc<S>,
vring_handler: Arc<Mutex<VringEpollHandler<S>>>,
vring_handler: Arc<RwLock<VringEpollHandler<S>>>,
owned: bool,
features_acked: bool,
acked_features: u64,
@ -416,7 +416,7 @@ struct VhostUserHandler<S: VhostUserBackend> {
num_queues: usize,
max_queue_size: usize,
memory: Option<Memory>,
vrings: Arc<Mutex<Vec<Vring>>>,
vrings: Arc<RwLock<Vec<Vring>>>,
}
impl<S: VhostUserBackend> VhostUserHandler<S> {
@ -425,14 +425,14 @@ impl<S: VhostUserBackend> VhostUserHandler<S> {
let max_queue_size = backend.max_queue_size();
let arc_backend = Arc::new(backend);
let vrings = Arc::new(Mutex::new(vec![
let vrings = Arc::new(RwLock::new(vec![
Vring::new(max_queue_size as u16);
num_queues
]));
// Create the epoll file descriptor
let epoll_fd = epoll::create(true).unwrap();
let vring_handler = Arc::new(Mutex::new(VringEpollHandler {
let vring_handler = Arc::new(RwLock::new(VringEpollHandler {
backend: arc_backend.clone(),
vrings: vrings.clone(),
mem: None,
@ -461,7 +461,7 @@ impl<S: VhostUserBackend> VhostUserHandler<S> {
}
}
fn get_vring_handler(&self) -> Arc<Mutex<VringEpollHandler<S>>> {
fn get_vring_handler(&self) -> Arc<RwLock<VringEpollHandler<S>>> {
self.vring_handler.clone()
}
@ -518,7 +518,7 @@ impl<S: VhostUserBackend> VhostUserSlaveReqHandler for VhostUserHandler<S> {
// been disabled by VHOST_USER_SET_VRING_ENABLE with parameter 0.
let vring_enabled =
self.acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0;
for vring in self.vrings.lock().unwrap().iter_mut() {
for vring in self.vrings.write().unwrap().iter_mut() {
vring.enabled = vring_enabled;
}
@ -561,7 +561,7 @@ impl<S: VhostUserBackend> VhostUserSlaveReqHandler for VhostUserHandler<S> {
}
let mem = GuestMemoryMmap::with_files(regions).unwrap();
self.vring_handler.lock().unwrap().update_memory(Some(mem));
self.vring_handler.write().unwrap().update_memory(Some(mem));
self.memory = Some(Memory { mappings });
Ok(())
@ -575,7 +575,7 @@ impl<S: VhostUserBackend> VhostUserSlaveReqHandler for VhostUserHandler<S> {
if index as usize >= self.num_queues || num == 0 || num as usize > self.max_queue_size {
return Err(VhostUserError::InvalidParam);
}
self.vrings.lock().unwrap()[index as usize].queue.size = num as u16;
self.vrings.write().unwrap()[index as usize].queue.size = num as u16;
Ok(())
}
@ -596,9 +596,13 @@ impl<S: VhostUserBackend> VhostUserSlaveReqHandler for VhostUserHandler<S> {
let desc_table = self.vmm_va_to_gpa(descriptor).unwrap();
let avail_ring = self.vmm_va_to_gpa(available).unwrap();
let used_ring = self.vmm_va_to_gpa(used).unwrap();
self.vrings.lock().unwrap()[index as usize].queue.desc_table = GuestAddress(desc_table);
self.vrings.lock().unwrap()[index as usize].queue.avail_ring = GuestAddress(avail_ring);
self.vrings.lock().unwrap()[index as usize].queue.used_ring = GuestAddress(used_ring);
self.vrings.write().unwrap()[index as usize]
.queue
.desc_table = GuestAddress(desc_table);
self.vrings.write().unwrap()[index as usize]
.queue
.avail_ring = GuestAddress(avail_ring);
self.vrings.write().unwrap()[index as usize].queue.used_ring = GuestAddress(used_ring);
Ok(())
} else {
Err(VhostUserError::InvalidParam)
@ -606,8 +610,10 @@ impl<S: VhostUserBackend> VhostUserSlaveReqHandler for VhostUserHandler<S> {
}
fn set_vring_base(&mut self, index: u32, base: u32) -> VhostUserResult<()> {
self.vrings.lock().unwrap()[index as usize].queue.next_avail = Wrapping(base as u16);
self.vrings.lock().unwrap()[index as usize].queue.next_used = Wrapping(base as u16);
self.vrings.write().unwrap()[index as usize]
.queue
.next_avail = Wrapping(base as u16);
self.vrings.write().unwrap()[index as usize].queue.next_used = Wrapping(base as u16);
Ok(())
}
@ -620,14 +626,14 @@ impl<S: VhostUserBackend> VhostUserSlaveReqHandler for VhostUserHandler<S> {
// that file descriptor is readable) on the descriptor specified by
// VHOST_USER_SET_VRING_KICK, and stop ring upon receiving
// VHOST_USER_GET_VRING_BASE.
self.vrings.lock().unwrap()[index as usize].started = false;
self.vrings.write().unwrap()[index as usize].started = false;
self.vring_handler
.lock()
.read()
.unwrap()
.unregister_vring_listener(index as usize)
.unwrap();
let next_avail = self.vrings.lock().unwrap()[index as usize]
let next_avail = self.vrings.read().unwrap()[index as usize]
.queue
.next_avail
.0 as u16;
@ -640,11 +646,11 @@ impl<S: VhostUserBackend> VhostUserSlaveReqHandler for VhostUserHandler<S> {
return Err(VhostUserError::InvalidParam);
}
if self.vrings.lock().unwrap()[index as usize].kick.is_some() {
if self.vrings.read().unwrap()[index as usize].kick.is_some() {
// Close file descriptor set by previous operations.
let _ = unsafe {
libc::close(
self.vrings.lock().unwrap()[index as usize]
self.vrings.write().unwrap()[index as usize]
.kick
.take()
.unwrap()
@ -652,7 +658,7 @@ impl<S: VhostUserBackend> VhostUserSlaveReqHandler for VhostUserHandler<S> {
)
};
}
self.vrings.lock().unwrap()[index as usize].kick =
self.vrings.write().unwrap()[index as usize].kick =
Some(unsafe { EventFd::from_raw_fd(fd.unwrap()) });;
// Quotation from vhost-user spec:
@ -662,9 +668,9 @@ impl<S: VhostUserBackend> VhostUserSlaveReqHandler for VhostUserHandler<S> {
// VHOST_USER_GET_VRING_BASE.
//
// So we should add fd to event monitor(select, poll, epoll) here.
self.vrings.lock().unwrap()[index as usize].started = true;
self.vrings.write().unwrap()[index as usize].started = true;
self.vring_handler
.lock()
.read()
.unwrap()
.register_vring_listener(index as usize)
.unwrap();
@ -677,11 +683,11 @@ impl<S: VhostUserBackend> VhostUserSlaveReqHandler for VhostUserHandler<S> {
return Err(VhostUserError::InvalidParam);
}
if self.vrings.lock().unwrap()[index as usize].call.is_some() {
if self.vrings.write().unwrap()[index as usize].call.is_some() {
// Close file descriptor set by previous operations.
let _ = unsafe {
libc::close(
self.vrings.lock().unwrap()[index as usize]
self.vrings.write().unwrap()[index as usize]
.call
.take()
.unwrap()
@ -689,7 +695,7 @@ impl<S: VhostUserBackend> VhostUserSlaveReqHandler for VhostUserHandler<S> {
)
};
}
self.vrings.lock().unwrap()[index as usize].call =
self.vrings.write().unwrap()[index as usize].call =
Some(unsafe { EventFd::from_raw_fd(fd.unwrap()) });
Ok(())
@ -700,11 +706,11 @@ impl<S: VhostUserBackend> VhostUserSlaveReqHandler for VhostUserHandler<S> {
return Err(VhostUserError::InvalidParam);
}
if self.vrings.lock().unwrap()[index as usize].err.is_some() {
if self.vrings.read().unwrap()[index as usize].err.is_some() {
// Close file descriptor set by previous operations.
let _ = unsafe {
libc::close(
self.vrings.lock().unwrap()[index as usize]
self.vrings.write().unwrap()[index as usize]
.err
.take()
.unwrap()
@ -712,7 +718,7 @@ impl<S: VhostUserBackend> VhostUserSlaveReqHandler for VhostUserHandler<S> {
)
};
}
self.vrings.lock().unwrap()[index as usize].err =
self.vrings.write().unwrap()[index as usize].err =
Some(unsafe { EventFd::from_raw_fd(fd.unwrap()) });
Ok(())
@ -731,7 +737,7 @@ impl<S: VhostUserBackend> VhostUserSlaveReqHandler for VhostUserHandler<S> {
// enabled by VHOST_USER_SET_VRING_ENABLE with parameter 1,
// or after it has been disabled by VHOST_USER_SET_VRING_ENABLE
// with parameter 0.
self.vrings.lock().unwrap()[index as usize].enabled = enable;
self.vrings.write().unwrap()[index as usize].enabled = enable;
Ok(())
}