diff --git a/vm-virtio/src/net_util.rs b/vm-virtio/src/net_util.rs index 17289c612..19d9a52d9 100644 --- a/vm-virtio/src/net_util.rs +++ b/vm-virtio/src/net_util.rs @@ -14,7 +14,7 @@ use std::os::unix::io::{AsRawFd, RawFd}; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use virtio_bindings::bindings::virtio_net::*; -use vm_memory::{Bytes, GuestAddress, GuestMemoryError, GuestMemoryMmap}; +use vm_memory::{ByteValued, Bytes, GuestAddress, GuestMemoryError, GuestMemoryMmap}; use vmm_sys_util::eventfd::EventFd; type Result = std::result::Result; @@ -48,6 +48,8 @@ pub enum Error { FailedProcessMQ, /// Read queue failed. GuestMemory(GuestMemoryError), + /// Invalid ctrl class + InvalidCtlClass, /// Invalid ctrl command InvalidCtlCmd, /// Invalid descriptor @@ -91,15 +93,54 @@ impl CtrlVirtio { CtrlVirtio { queue_evt, queue } } + fn process_mq(&self, mem: &GuestMemoryMmap, avail_desc: DescriptorChain) -> Result<()> { + let mq_desc = if avail_desc.has_next() { + avail_desc.next_descriptor().unwrap() + } else { + return Err(Error::NoQueuePairsNum); + }; + let queue_pairs = mem + .read_obj::(mq_desc.addr) + .map_err(Error::GuestMemory)?; + if (queue_pairs < VIRTIO_NET_CTRL_MQ_VQ_PAIRS_MIN as u16) + || (queue_pairs > VIRTIO_NET_CTRL_MQ_VQ_PAIRS_MAX as u16) + { + return Err(Error::InvalidQueuePairsNum); + } + let status_desc = if mq_desc.has_next() { + mq_desc.next_descriptor().unwrap() + } else { + return Err(Error::NoQueuePairsNum); + }; + mem.write_obj::(0, status_desc.addr) + .map_err(Error::GuestMemory)?; + + Ok(()) + } + pub fn process_cvq(&mut self, mem: &GuestMemoryMmap) -> Result<()> { let mut used_desc_heads = [(0, 0); QUEUE_SIZE]; let mut used_count = 0; if let Some(avail_desc) = self.queue.iter(&mem).next() { used_desc_heads[used_count] = (avail_desc.index, avail_desc.len); used_count += 1; - let _ = mem - .read_obj::(avail_desc.addr) + let ctrl_hdr = mem + .read_obj::(avail_desc.addr) .map_err(Error::GuestMemory)?; + let ctrl_hdr_v = ctrl_hdr.as_slice(); + let class = ctrl_hdr_v[0]; + let cmd = ctrl_hdr_v[1]; + match u32::from(class) { + VIRTIO_NET_CTRL_MQ => { + if u32::from(cmd) != VIRTIO_NET_CTRL_MQ_VQ_PAIRS_SET { + return Err(Error::InvalidCtlCmd); + } + if let Err(_e) = self.process_mq(&mem, avail_desc) { + return Err(Error::FailedProcessMQ); + } + } + _ => return Err(Error::InvalidCtlClass), + } } else { return Err(Error::InvalidDesc); }