diff --git a/net_util/src/queue_pair.rs b/net_util/src/queue_pair.rs index 806a9fa8a..5195b1db5 100644 --- a/net_util/src/queue_pair.rs +++ b/net_util/src/queue_pair.rs @@ -41,33 +41,16 @@ impl TxVirtio { rate_limiter: &mut Option, ) -> Result { let mut retry_write = false; + let mut rate_limit_reached = false; while let Some(avail_desc) = queue.iter(mem).next() { + if rate_limit_reached { + queue.go_to_previous_position(); + break; + } + let head_index = avail_desc.index; let mut next_desc = Some(avail_desc); - if let Some(rate_limiter) = rate_limiter { - if !rate_limiter.consume(1, TokenType::Ops) { - queue.go_to_previous_position(); - break; - } - - let mut bytes = Wrapping(0); - let mut tmp_next_desc = next_desc.clone(); - while let Some(desc) = tmp_next_desc { - if !desc.is_write_only() { - bytes += Wrapping(desc.len as u64); - } - tmp_next_desc = desc.next_descriptor(); - } - bytes -= Wrapping(vnet_hdr_len() as u64); - if !rate_limiter.consume(bytes.0, TokenType::Bytes) { - // Revert the OPS consume(). - rate_limiter.manual_replenish(1, TokenType::Ops); - queue.go_to_previous_position(); - break; - } - } - let mut iovecs = Vec::new(); while let Some(desc) = next_desc { if !desc.is_write_only() && desc.len > 0 { @@ -84,7 +67,7 @@ impl TxVirtio { next_desc = desc.next_descriptor(); } - if !iovecs.is_empty() { + let len = if !iovecs.is_empty() { let result = unsafe { libc::writev( tap.as_raw_fd() as libc::c_int, @@ -108,10 +91,22 @@ impl TxVirtio { self.counter_bytes += Wrapping(result as u64 - vnet_hdr_len() as u64); self.counter_frames += Wrapping(1); - } + + result as u32 + } else { + 0 + }; queue.add_used(mem, head_index, 0); queue.update_avail_event(mem); + + // For the sake of simplicity (similar to the RX rate limiting), we always + // let the 'last' descriptor chain go-through even if it was over the rate + // limit, and simply stop processing oncoming `avail_desc` if any. + if let Some(rate_limiter) = rate_limiter { + rate_limit_reached = !rate_limiter.consume(1, TokenType::Ops) + || !rate_limiter.consume(len as u64, TokenType::Bytes); + } } Ok(retry_write)