vhost_user_backend: Replace Mutex with RwLock when possible

Signed-off-by: Sebastien Boeuf <sebastien.boeuf@intel.com>
This commit is contained in:
Sebastien Boeuf 2019-09-16 11:02:30 -07:00
parent 2e2cad91ae
commit 4ed81894aa

View File

@ -9,7 +9,7 @@ use std::io;
use std::num::Wrapping; use std::num::Wrapping;
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::result; use std::result;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex, RwLock};
use std::thread; use std::thread;
use vhost_rs::vhost_user::message::{ use vhost_rs::vhost_user::message::{
VhostUserConfigFlags, VhostUserMemoryRegion, VhostUserProtocolFeatures, VhostUserConfigFlags, VhostUserMemoryRegion, VhostUserProtocolFeatures,
@ -90,7 +90,7 @@ pub struct VhostUserDaemon<S: VhostUserBackend> {
name: String, name: String,
sock_path: String, sock_path: String,
handler: Arc<Mutex<VhostUserHandler<S>>>, handler: Arc<Mutex<VhostUserHandler<S>>>,
vring_handler: Arc<Mutex<VringEpollHandler<S>>>, vring_handler: Arc<RwLock<VringEpollHandler<S>>>,
main_thread: Option<thread::JoinHandle<Result<()>>>, main_thread: Option<thread::JoinHandle<Result<()>>>,
} }
@ -158,7 +158,7 @@ impl<S: VhostUserBackend> VhostUserDaemon<S> {
data: u64, data: u64,
) -> result::Result<(), io::Error> { ) -> result::Result<(), io::Error> {
self.vring_handler self.vring_handler
.lock() .read()
.unwrap() .unwrap()
.register_listener(fd, ev_type, data) .register_listener(fd, ev_type, data)
} }
@ -173,7 +173,7 @@ impl<S: VhostUserBackend> VhostUserDaemon<S> {
data: u64, data: u64,
) -> result::Result<(), io::Error> { ) -> result::Result<(), io::Error> {
self.vring_handler self.vring_handler
.lock() .read()
.unwrap() .unwrap()
.unregister_listener(fd, ev_type, data) .unregister_listener(fd, ev_type, data)
} }
@ -189,7 +189,7 @@ impl<S: VhostUserBackend> VhostUserDaemon<S> {
/// of `process_queue`. With this twisted trick, all common parts related /// of `process_queue`. With this twisted trick, all common parts related
/// to the virtqueues can remain part of the library. /// to the virtqueues can remain part of the library.
pub fn process_queue(&self, q_idx: u16) -> Result<()> { pub fn process_queue(&self, q_idx: u16) -> Result<()> {
self.vring_handler.lock().unwrap().process_queue(q_idx) self.vring_handler.read().unwrap().process_queue(q_idx)
} }
} }
@ -257,7 +257,7 @@ impl Vring {
struct VringEpollHandler<S: VhostUserBackend> { struct VringEpollHandler<S: VhostUserBackend> {
backend: Arc<S>, backend: Arc<S>,
vrings: Arc<Mutex<Vec<Vring>>>, vrings: Arc<RwLock<Vec<Vring>>>,
mem: Option<GuestMemoryMmap>, mem: Option<GuestMemoryMmap>,
epoll_fd: RawFd, epoll_fd: RawFd,
} }
@ -268,7 +268,7 @@ impl<S: VhostUserBackend> VringEpollHandler<S> {
} }
fn process_queue(&self, q_idx: u16) -> Result<()> { fn process_queue(&self, q_idx: u16) -> Result<()> {
let vring = &mut self.vrings.lock().unwrap()[q_idx as usize]; let vring = &mut self.vrings.write().unwrap()[q_idx as usize];
let mut used_desc_heads = vec![(0, 0); vring.queue.size as usize]; let mut used_desc_heads = vec![(0, 0); vring.queue.size as usize];
let mut used_count = 0; let mut used_count = 0;
if let Some(mem) = &self.mem { if let Some(mem) = &self.mem {
@ -296,11 +296,11 @@ impl<S: VhostUserBackend> VringEpollHandler<S> {
Ok(()) Ok(())
} }
fn handle_event(&mut self, device_event: u16, evset: epoll::Events) -> Result<bool> { fn handle_event(&self, device_event: u16, evset: epoll::Events) -> Result<bool> {
let num_queues = self.vrings.lock().unwrap().len(); let num_queues = self.vrings.read().unwrap().len();
match device_event as usize { match device_event as usize {
x if x < num_queues => { x if x < num_queues => {
if let Some(kick) = &self.vrings.lock().unwrap()[device_event as usize].kick { if let Some(kick) = &self.vrings.read().unwrap()[device_event as usize].kick {
kick.read().unwrap(); kick.read().unwrap();
} }
@ -313,7 +313,7 @@ impl<S: VhostUserBackend> VringEpollHandler<S> {
} }
fn register_vring_listener(&self, q_idx: usize) -> result::Result<(), io::Error> { fn register_vring_listener(&self, q_idx: usize) -> result::Result<(), io::Error> {
if let Some(fd) = &self.vrings.lock().unwrap()[q_idx].kick { if let Some(fd) = &self.vrings.read().unwrap()[q_idx].kick {
self.register_listener(fd.as_raw_fd(), epoll::Events::EPOLLIN, q_idx as u64) self.register_listener(fd.as_raw_fd(), epoll::Events::EPOLLIN, q_idx as u64)
} else { } else {
Ok(()) Ok(())
@ -321,7 +321,7 @@ impl<S: VhostUserBackend> VringEpollHandler<S> {
} }
fn unregister_vring_listener(&self, q_idx: usize) -> result::Result<(), io::Error> { fn unregister_vring_listener(&self, q_idx: usize) -> result::Result<(), io::Error> {
if let Some(fd) = &self.vrings.lock().unwrap()[q_idx].kick { if let Some(fd) = &self.vrings.read().unwrap()[q_idx].kick {
self.unregister_listener(fd.as_raw_fd(), epoll::Events::EPOLLIN, q_idx as u64) self.unregister_listener(fd.as_raw_fd(), epoll::Events::EPOLLIN, q_idx as u64)
} else { } else {
Ok(()) Ok(())
@ -358,7 +358,7 @@ impl<S: VhostUserBackend> VringEpollHandler<S> {
} }
struct VringWorker<S: VhostUserBackend> { struct VringWorker<S: VhostUserBackend> {
handler: Arc<Mutex<VringEpollHandler<S>>>, handler: Arc<RwLock<VringEpollHandler<S>>>,
} }
impl<S: VhostUserBackend> VringWorker<S> { impl<S: VhostUserBackend> VringWorker<S> {
@ -396,7 +396,7 @@ impl<S: VhostUserBackend> VringWorker<S> {
let ev_type = event.data as u16; let ev_type = event.data as u16;
if self.handler.lock().unwrap().handle_event(ev_type, evset)? { if self.handler.read().unwrap().handle_event(ev_type, evset)? {
break 'epoll; break 'epoll;
} }
} }
@ -408,7 +408,7 @@ impl<S: VhostUserBackend> VringWorker<S> {
struct VhostUserHandler<S: VhostUserBackend> { struct VhostUserHandler<S: VhostUserBackend> {
backend: Arc<S>, backend: Arc<S>,
vring_handler: Arc<Mutex<VringEpollHandler<S>>>, vring_handler: Arc<RwLock<VringEpollHandler<S>>>,
owned: bool, owned: bool,
features_acked: bool, features_acked: bool,
acked_features: u64, acked_features: u64,
@ -416,7 +416,7 @@ struct VhostUserHandler<S: VhostUserBackend> {
num_queues: usize, num_queues: usize,
max_queue_size: usize, max_queue_size: usize,
memory: Option<Memory>, memory: Option<Memory>,
vrings: Arc<Mutex<Vec<Vring>>>, vrings: Arc<RwLock<Vec<Vring>>>,
} }
impl<S: VhostUserBackend> VhostUserHandler<S> { impl<S: VhostUserBackend> VhostUserHandler<S> {
@ -425,14 +425,14 @@ impl<S: VhostUserBackend> VhostUserHandler<S> {
let max_queue_size = backend.max_queue_size(); let max_queue_size = backend.max_queue_size();
let arc_backend = Arc::new(backend); let arc_backend = Arc::new(backend);
let vrings = Arc::new(Mutex::new(vec![ let vrings = Arc::new(RwLock::new(vec![
Vring::new(max_queue_size as u16); Vring::new(max_queue_size as u16);
num_queues num_queues
])); ]));
// Create the epoll file descriptor // Create the epoll file descriptor
let epoll_fd = epoll::create(true).unwrap(); let epoll_fd = epoll::create(true).unwrap();
let vring_handler = Arc::new(Mutex::new(VringEpollHandler { let vring_handler = Arc::new(RwLock::new(VringEpollHandler {
backend: arc_backend.clone(), backend: arc_backend.clone(),
vrings: vrings.clone(), vrings: vrings.clone(),
mem: None, mem: None,
@ -461,7 +461,7 @@ impl<S: VhostUserBackend> VhostUserHandler<S> {
} }
} }
fn get_vring_handler(&self) -> Arc<Mutex<VringEpollHandler<S>>> { fn get_vring_handler(&self) -> Arc<RwLock<VringEpollHandler<S>>> {
self.vring_handler.clone() self.vring_handler.clone()
} }
@ -518,7 +518,7 @@ impl<S: VhostUserBackend> VhostUserSlaveReqHandler for VhostUserHandler<S> {
// been disabled by VHOST_USER_SET_VRING_ENABLE with parameter 0. // been disabled by VHOST_USER_SET_VRING_ENABLE with parameter 0.
let vring_enabled = let vring_enabled =
self.acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0; self.acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0;
for vring in self.vrings.lock().unwrap().iter_mut() { for vring in self.vrings.write().unwrap().iter_mut() {
vring.enabled = vring_enabled; vring.enabled = vring_enabled;
} }
@ -561,7 +561,7 @@ impl<S: VhostUserBackend> VhostUserSlaveReqHandler for VhostUserHandler<S> {
} }
let mem = GuestMemoryMmap::with_files(regions).unwrap(); let mem = GuestMemoryMmap::with_files(regions).unwrap();
self.vring_handler.lock().unwrap().update_memory(Some(mem)); self.vring_handler.write().unwrap().update_memory(Some(mem));
self.memory = Some(Memory { mappings }); self.memory = Some(Memory { mappings });
Ok(()) Ok(())
@ -575,7 +575,7 @@ impl<S: VhostUserBackend> VhostUserSlaveReqHandler for VhostUserHandler<S> {
if index as usize >= self.num_queues || num == 0 || num as usize > self.max_queue_size { if index as usize >= self.num_queues || num == 0 || num as usize > self.max_queue_size {
return Err(VhostUserError::InvalidParam); return Err(VhostUserError::InvalidParam);
} }
self.vrings.lock().unwrap()[index as usize].queue.size = num as u16; self.vrings.write().unwrap()[index as usize].queue.size = num as u16;
Ok(()) Ok(())
} }
@ -596,9 +596,13 @@ impl<S: VhostUserBackend> VhostUserSlaveReqHandler for VhostUserHandler<S> {
let desc_table = self.vmm_va_to_gpa(descriptor).unwrap(); let desc_table = self.vmm_va_to_gpa(descriptor).unwrap();
let avail_ring = self.vmm_va_to_gpa(available).unwrap(); let avail_ring = self.vmm_va_to_gpa(available).unwrap();
let used_ring = self.vmm_va_to_gpa(used).unwrap(); let used_ring = self.vmm_va_to_gpa(used).unwrap();
self.vrings.lock().unwrap()[index as usize].queue.desc_table = GuestAddress(desc_table); self.vrings.write().unwrap()[index as usize]
self.vrings.lock().unwrap()[index as usize].queue.avail_ring = GuestAddress(avail_ring); .queue
self.vrings.lock().unwrap()[index as usize].queue.used_ring = GuestAddress(used_ring); .desc_table = GuestAddress(desc_table);
self.vrings.write().unwrap()[index as usize]
.queue
.avail_ring = GuestAddress(avail_ring);
self.vrings.write().unwrap()[index as usize].queue.used_ring = GuestAddress(used_ring);
Ok(()) Ok(())
} else { } else {
Err(VhostUserError::InvalidParam) Err(VhostUserError::InvalidParam)
@ -606,8 +610,10 @@ impl<S: VhostUserBackend> VhostUserSlaveReqHandler for VhostUserHandler<S> {
} }
fn set_vring_base(&mut self, index: u32, base: u32) -> VhostUserResult<()> { fn set_vring_base(&mut self, index: u32, base: u32) -> VhostUserResult<()> {
self.vrings.lock().unwrap()[index as usize].queue.next_avail = Wrapping(base as u16); self.vrings.write().unwrap()[index as usize]
self.vrings.lock().unwrap()[index as usize].queue.next_used = Wrapping(base as u16); .queue
.next_avail = Wrapping(base as u16);
self.vrings.write().unwrap()[index as usize].queue.next_used = Wrapping(base as u16);
Ok(()) Ok(())
} }
@ -620,14 +626,14 @@ impl<S: VhostUserBackend> VhostUserSlaveReqHandler for VhostUserHandler<S> {
// that file descriptor is readable) on the descriptor specified by // that file descriptor is readable) on the descriptor specified by
// VHOST_USER_SET_VRING_KICK, and stop ring upon receiving // VHOST_USER_SET_VRING_KICK, and stop ring upon receiving
// VHOST_USER_GET_VRING_BASE. // VHOST_USER_GET_VRING_BASE.
self.vrings.lock().unwrap()[index as usize].started = false; self.vrings.write().unwrap()[index as usize].started = false;
self.vring_handler self.vring_handler
.lock() .read()
.unwrap() .unwrap()
.unregister_vring_listener(index as usize) .unregister_vring_listener(index as usize)
.unwrap(); .unwrap();
let next_avail = self.vrings.lock().unwrap()[index as usize] let next_avail = self.vrings.read().unwrap()[index as usize]
.queue .queue
.next_avail .next_avail
.0 as u16; .0 as u16;
@ -640,11 +646,11 @@ impl<S: VhostUserBackend> VhostUserSlaveReqHandler for VhostUserHandler<S> {
return Err(VhostUserError::InvalidParam); return Err(VhostUserError::InvalidParam);
} }
if self.vrings.lock().unwrap()[index as usize].kick.is_some() { if self.vrings.read().unwrap()[index as usize].kick.is_some() {
// Close file descriptor set by previous operations. // Close file descriptor set by previous operations.
let _ = unsafe { let _ = unsafe {
libc::close( libc::close(
self.vrings.lock().unwrap()[index as usize] self.vrings.write().unwrap()[index as usize]
.kick .kick
.take() .take()
.unwrap() .unwrap()
@ -652,7 +658,7 @@ impl<S: VhostUserBackend> VhostUserSlaveReqHandler for VhostUserHandler<S> {
) )
}; };
} }
self.vrings.lock().unwrap()[index as usize].kick = self.vrings.write().unwrap()[index as usize].kick =
Some(unsafe { EventFd::from_raw_fd(fd.unwrap()) });; Some(unsafe { EventFd::from_raw_fd(fd.unwrap()) });;
// Quotation from vhost-user spec: // Quotation from vhost-user spec:
@ -662,9 +668,9 @@ impl<S: VhostUserBackend> VhostUserSlaveReqHandler for VhostUserHandler<S> {
// VHOST_USER_GET_VRING_BASE. // VHOST_USER_GET_VRING_BASE.
// //
// So we should add fd to event monitor(select, poll, epoll) here. // So we should add fd to event monitor(select, poll, epoll) here.
self.vrings.lock().unwrap()[index as usize].started = true; self.vrings.write().unwrap()[index as usize].started = true;
self.vring_handler self.vring_handler
.lock() .read()
.unwrap() .unwrap()
.register_vring_listener(index as usize) .register_vring_listener(index as usize)
.unwrap(); .unwrap();
@ -677,11 +683,11 @@ impl<S: VhostUserBackend> VhostUserSlaveReqHandler for VhostUserHandler<S> {
return Err(VhostUserError::InvalidParam); return Err(VhostUserError::InvalidParam);
} }
if self.vrings.lock().unwrap()[index as usize].call.is_some() { if self.vrings.write().unwrap()[index as usize].call.is_some() {
// Close file descriptor set by previous operations. // Close file descriptor set by previous operations.
let _ = unsafe { let _ = unsafe {
libc::close( libc::close(
self.vrings.lock().unwrap()[index as usize] self.vrings.write().unwrap()[index as usize]
.call .call
.take() .take()
.unwrap() .unwrap()
@ -689,7 +695,7 @@ impl<S: VhostUserBackend> VhostUserSlaveReqHandler for VhostUserHandler<S> {
) )
}; };
} }
self.vrings.lock().unwrap()[index as usize].call = self.vrings.write().unwrap()[index as usize].call =
Some(unsafe { EventFd::from_raw_fd(fd.unwrap()) }); Some(unsafe { EventFd::from_raw_fd(fd.unwrap()) });
Ok(()) Ok(())
@ -700,11 +706,11 @@ impl<S: VhostUserBackend> VhostUserSlaveReqHandler for VhostUserHandler<S> {
return Err(VhostUserError::InvalidParam); return Err(VhostUserError::InvalidParam);
} }
if self.vrings.lock().unwrap()[index as usize].err.is_some() { if self.vrings.read().unwrap()[index as usize].err.is_some() {
// Close file descriptor set by previous operations. // Close file descriptor set by previous operations.
let _ = unsafe { let _ = unsafe {
libc::close( libc::close(
self.vrings.lock().unwrap()[index as usize] self.vrings.write().unwrap()[index as usize]
.err .err
.take() .take()
.unwrap() .unwrap()
@ -712,7 +718,7 @@ impl<S: VhostUserBackend> VhostUserSlaveReqHandler for VhostUserHandler<S> {
) )
}; };
} }
self.vrings.lock().unwrap()[index as usize].err = self.vrings.write().unwrap()[index as usize].err =
Some(unsafe { EventFd::from_raw_fd(fd.unwrap()) }); Some(unsafe { EventFd::from_raw_fd(fd.unwrap()) });
Ok(()) Ok(())
@ -731,7 +737,7 @@ impl<S: VhostUserBackend> VhostUserSlaveReqHandler for VhostUserHandler<S> {
// enabled by VHOST_USER_SET_VRING_ENABLE with parameter 1, // enabled by VHOST_USER_SET_VRING_ENABLE with parameter 1,
// or after it has been disabled by VHOST_USER_SET_VRING_ENABLE // or after it has been disabled by VHOST_USER_SET_VRING_ENABLE
// with parameter 0. // with parameter 0.
self.vrings.lock().unwrap()[index as usize].enabled = enable; self.vrings.write().unwrap()[index as usize].enabled = enable;
Ok(()) Ok(())
} }