// Copyright (c) 2020 Intel Corporation. All rights reserved. // // SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause use super::{register_listener, unregister_listener, vnet_hdr_len, Tap}; use crate::GuestMemoryMmap; use rate_limiter::{RateLimiter, TokenType}; use std::io; use std::num::Wrapping; use std::os::unix::io::{AsRawFd, RawFd}; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use virtio_queue::Queue; use vm_memory::{Bytes, GuestMemory, GuestMemoryAtomic}; use vm_virtio::{AccessPlatform, Translatable}; #[derive(Clone)] pub struct TxVirtio { pub counter_bytes: Wrapping, pub counter_frames: Wrapping, } impl Default for TxVirtio { fn default() -> Self { Self::new() } } impl TxVirtio { pub fn new() -> Self { TxVirtio { counter_bytes: Wrapping(0), counter_frames: Wrapping(0), } } pub fn process_desc_chain( &mut self, tap: &mut Tap, queue: &mut Queue>, rate_limiter: &mut Option, access_platform: Option<&Arc>, ) -> Result { let mut retry_write = false; let mut rate_limit_reached = false; loop { let used_desc_head: (u16, u32); let mut avail_iter = queue .iter() .map_err(NetQueuePairError::QueueIteratorFailed)?; if let Some(mut desc_chain) = avail_iter.next() { if rate_limit_reached { avail_iter.go_to_previous_position(); break; } let mut next_desc = desc_chain.next(); let mut iovecs = Vec::new(); while let Some(desc) = next_desc { let desc_addr = desc.addr().translate(access_platform, desc.len() as usize); if !desc.is_write_only() && desc.len() > 0 { let buf = desc_chain .memory() .get_slice(desc_addr, desc.len() as usize) .map_err(NetQueuePairError::GuestMemory)? .as_ptr(); let iovec = libc::iovec { iov_base: buf as *mut libc::c_void, iov_len: desc.len() as libc::size_t, }; iovecs.push(iovec); } else { error!( "Invalid descriptor chain: address = 0x{:x} length = {} write_only = {}", desc_addr.0, desc.len(), desc.is_write_only() ); return Err(NetQueuePairError::DescriptorChainInvalid); } next_desc = desc_chain.next(); } let len = if !iovecs.is_empty() { let result = unsafe { libc::writev( tap.as_raw_fd() as libc::c_int, iovecs.as_ptr() as *const libc::iovec, iovecs.len() as libc::c_int, ) }; if result < 0 { let e = std::io::Error::last_os_error(); /* EAGAIN */ if e.kind() == std::io::ErrorKind::WouldBlock { avail_iter.go_to_previous_position(); retry_write = true; break; } error!("net: tx: failed writing to tap: {}", e); return Err(NetQueuePairError::WriteTap(e)); } self.counter_bytes += Wrapping(result as u64 - vnet_hdr_len() as u64); self.counter_frames += Wrapping(1); result as u32 } else { 0 }; used_desc_head = (desc_chain.head_index(), len); // For the sake of simplicity (similar to the RX rate limiting), we always // let the 'last' descriptor chain go-through even if it was over the rate // limit, and simply stop processing oncoming `avail_desc` if any. if let Some(rate_limiter) = rate_limiter { rate_limit_reached = !rate_limiter.consume(1, TokenType::Ops) || !rate_limiter.consume(len as u64, TokenType::Bytes); } } else { break; } queue .add_used(used_desc_head.0, used_desc_head.1) .map_err(NetQueuePairError::QueueAddUsed)?; queue .enable_notification() .map_err(NetQueuePairError::QueueEnableNotification)?; } Ok(retry_write) } } #[derive(Clone)] pub struct RxVirtio { pub counter_bytes: Wrapping, pub counter_frames: Wrapping, } impl Default for RxVirtio { fn default() -> Self { Self::new() } } impl RxVirtio { pub fn new() -> Self { RxVirtio { counter_bytes: Wrapping(0), counter_frames: Wrapping(0), } } pub fn process_desc_chain( &mut self, tap: &mut Tap, queue: &mut Queue>, rate_limiter: &mut Option, access_platform: Option<&Arc>, ) -> Result { let mut exhausted_descs = true; let mut rate_limit_reached = false; loop { let used_desc_head: (u16, u32); let mut avail_iter = queue .iter() .map_err(NetQueuePairError::QueueIteratorFailed)?; if let Some(mut desc_chain) = avail_iter.next() { if rate_limit_reached { exhausted_descs = false; avail_iter.go_to_previous_position(); break; } let desc = desc_chain .next() .ok_or(NetQueuePairError::DescriptorChainTooShort)?; let num_buffers_addr = desc_chain .memory() .checked_offset( desc.addr().translate(access_platform, desc.len() as usize), 10, ) .unwrap(); let mut next_desc = Some(desc); let mut iovecs = Vec::new(); while let Some(desc) = next_desc { let desc_addr = desc.addr().translate(access_platform, desc.len() as usize); if desc.is_write_only() && desc.len() > 0 { let buf = desc_chain .memory() .get_slice(desc_addr, desc.len() as usize) .map_err(NetQueuePairError::GuestMemory)? .as_ptr(); let iovec = libc::iovec { iov_base: buf as *mut libc::c_void, iov_len: desc.len() as libc::size_t, }; iovecs.push(iovec); } else { error!( "Invalid descriptor chain: address = 0x{:x} length = {} write_only = {}", desc_addr.0, desc.len(), desc.is_write_only() ); return Err(NetQueuePairError::DescriptorChainInvalid); } next_desc = desc_chain.next(); } let len = if !iovecs.is_empty() { let result = unsafe { libc::readv( tap.as_raw_fd() as libc::c_int, iovecs.as_ptr() as *const libc::iovec, iovecs.len() as libc::c_int, ) }; if result < 0 { let e = std::io::Error::last_os_error(); exhausted_descs = false; avail_iter.go_to_previous_position(); /* EAGAIN */ if e.kind() == std::io::ErrorKind::WouldBlock { break; } error!("net: rx: failed reading from tap: {}", e); return Err(NetQueuePairError::ReadTap(e)); } // Write num_buffers to guest memory. We simply write 1 as we // never spread the frame over more than one descriptor chain. desc_chain .memory() .write_obj(1u16, num_buffers_addr) .map_err(NetQueuePairError::GuestMemory)?; self.counter_bytes += Wrapping(result as u64 - vnet_hdr_len() as u64); self.counter_frames += Wrapping(1); result as u32 } else { 0 }; used_desc_head = (desc_chain.head_index(), len); // For the sake of simplicity (keeping the handling of RX_QUEUE_EVENT and // RX_TAP_EVENT totally asynchronous), we always let the 'last' descriptor // chain go-through even if it was over the rate limit, and simply stop // processing oncoming `avail_desc` if any. if let Some(rate_limiter) = rate_limiter { rate_limit_reached = !rate_limiter.consume(1, TokenType::Ops) || !rate_limiter.consume(len as u64, TokenType::Bytes); } } else { break; } queue .add_used(used_desc_head.0, used_desc_head.1) .map_err(NetQueuePairError::QueueAddUsed)?; queue .enable_notification() .map_err(NetQueuePairError::QueueEnableNotification)?; } Ok(exhausted_descs) } } #[derive(Default, Clone)] pub struct NetCounters { pub tx_bytes: Arc, pub tx_frames: Arc, pub rx_bytes: Arc, pub rx_frames: Arc, } #[derive(Debug)] pub enum NetQueuePairError { /// No memory configured NoMemoryConfigured, /// Error registering listener RegisterListener(io::Error), /// Error unregistering listener UnregisterListener(io::Error), /// Error writing to the TAP device WriteTap(io::Error), /// Error reading from the TAP device ReadTap(io::Error), /// Error related to guest memory GuestMemory(vm_memory::GuestMemoryError), /// Returned an error while iterating through the queue QueueIteratorFailed(virtio_queue::Error), /// Descriptor chain is too short DescriptorChainTooShort, /// Descriptor chain does not contain valid descriptors DescriptorChainInvalid, /// Failed to determine if queue needed notification QueueNeedsNotification(virtio_queue::Error), /// Failed to enable notification on the queue QueueEnableNotification(virtio_queue::Error), /// Failed to add used index to the queue QueueAddUsed(virtio_queue::Error), } pub struct NetQueuePair { pub tap: Tap, // With epoll each FD must be unique. So in order to filter the // events we need to get a second FD responding to the original // device so that we can send EPOLLOUT and EPOLLIN to separate // events. pub tap_for_write_epoll: Tap, pub rx: RxVirtio, pub tx: TxVirtio, pub epoll_fd: Option, pub rx_tap_listening: bool, pub tx_tap_listening: bool, pub counters: NetCounters, pub tap_rx_event_id: u16, pub tap_tx_event_id: u16, pub rx_desc_avail: bool, pub rx_rate_limiter: Option, pub tx_rate_limiter: Option, pub access_platform: Option>, } impl NetQueuePair { pub fn process_tx( &mut self, queue: &mut Queue>, ) -> Result { let tx_tap_retry = self.tx.process_desc_chain( &mut self.tap, queue, &mut self.tx_rate_limiter, self.access_platform.as_ref(), )?; // We got told to try again when writing to the tap. Wait for the TAP to be writable if tx_tap_retry && !self.tx_tap_listening { register_listener( self.epoll_fd.unwrap(), self.tap_for_write_epoll.as_raw_fd(), epoll::Events::EPOLLOUT, u64::from(self.tap_tx_event_id), ) .map_err(NetQueuePairError::RegisterListener)?; self.tx_tap_listening = true; info!("Writing to TAP returned EAGAIN. Listening for TAP to become writable."); } else if !tx_tap_retry && self.tx_tap_listening { unregister_listener( self.epoll_fd.unwrap(), self.tap_for_write_epoll.as_raw_fd(), epoll::Events::EPOLLOUT, u64::from(self.tap_tx_event_id), ) .map_err(NetQueuePairError::UnregisterListener)?; self.tx_tap_listening = false; info!("Writing to TAP succeeded. No longer listening for TAP to become writable."); } self.counters .tx_bytes .fetch_add(self.tx.counter_bytes.0, Ordering::AcqRel); self.counters .tx_frames .fetch_add(self.tx.counter_frames.0, Ordering::AcqRel); self.tx.counter_bytes = Wrapping(0); self.tx.counter_frames = Wrapping(0); queue .needs_notification() .map_err(NetQueuePairError::QueueNeedsNotification) } pub fn process_rx( &mut self, queue: &mut Queue>, ) -> Result { self.rx_desc_avail = !self.rx.process_desc_chain( &mut self.tap, queue, &mut self.rx_rate_limiter, self.access_platform.as_ref(), )?; let rate_limit_reached = self .rx_rate_limiter .as_ref() .map_or(false, |r| r.is_blocked()); // Stop listening on the `RX_TAP_EVENT` when: // 1) there is no available describles, or // 2) the RX rate limit is reached. if self.rx_tap_listening && (!self.rx_desc_avail || rate_limit_reached) { unregister_listener( self.epoll_fd.unwrap(), self.tap.as_raw_fd(), epoll::Events::EPOLLIN, u64::from(self.tap_rx_event_id), ) .map_err(NetQueuePairError::UnregisterListener)?; self.rx_tap_listening = false; } self.counters .rx_bytes .fetch_add(self.rx.counter_bytes.0, Ordering::AcqRel); self.counters .rx_frames .fetch_add(self.rx.counter_frames.0, Ordering::AcqRel); self.rx.counter_bytes = Wrapping(0); self.rx.counter_frames = Wrapping(0); queue .needs_notification() .map_err(NetQueuePairError::QueueNeedsNotification) } }