diff --git a/Cargo.lock b/Cargo.lock
index b08dab092..e484d1dcd 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1780,8 +1780,10 @@ checksum = "8edc89eaa583cf6bc4c6ef16a219f0a60d342ca3bf0eae793560038ac8af1795"
 name = "rate_limiter"
 version = "0.1.0"
 dependencies = [
+ "epoll",
  "libc",
  "log",
+ "thiserror",
  "vmm-sys-util",
 ]
 
diff --git a/rate_limiter/Cargo.toml b/rate_limiter/Cargo.toml
index 6d9bfc9bc..51ec74dd8 100644
--- a/rate_limiter/Cargo.toml
+++ b/rate_limiter/Cargo.toml
@@ -4,6 +4,8 @@ version = "0.1.0"
 edition = "2021"
 
 [dependencies]
+epoll = "4.3.3"
 libc = "0.2.147"
 log = "0.4.20"
+thiserror = "1.0.40"
 vmm-sys-util = "0.11.0"
diff --git a/rate_limiter/src/group.rs b/rate_limiter/src/group.rs
new file mode 100644
index 000000000..24c775141
--- /dev/null
+++ b/rate_limiter/src/group.rs
@@ -0,0 +1,526 @@
+// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+//
+// Copyright 2023 Crusoe Energy Systems LLC
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{RateLimiter, TokenType};
+use core::panic::AssertUnwindSafe;
+use std::fs::File;
+use std::io;
+use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
+use std::result;
+use std::sync::{Arc, Mutex};
+use std::thread;
+use thiserror::Error;
+use vmm_sys_util::eventfd::EventFd;
+
+/// Errors associated with rate-limiter group.
+#[derive(Debug, Error)]
+pub enum Error {
+    /// Cannot create thread.
+    #[error("Error spawning rate-limiter thread: {0}")]
+    ThreadSpawn(#[source] io::Error),
+
+    /// Cannot create epoll context.
+    #[error("Error creating epoll context: {0}")]
+    Epoll(#[source] io::Error),
+
+    /// Cannot create EventFd.
+    #[error("Error creating EventFd: {0}")]
+    EventFd(#[source] io::Error),
+
+    /// Cannot create RateLimiter.
+    #[error("Error creating RateLimiter: {0}")]
+    RateLimiter(#[source] io::Error),
+
+    /// Cannot read from EventFd.
+    #[error("Error reading from EventFd: {0}")]
+    EventFdRead(#[source] io::Error),
+
+    /// Cannot write to EventFd.
+    #[error("Error writing to EventFd: {0}")]
+    EventFdWrite(#[source] io::Error),
+}
+
+/// The RateLimiterGroupHandle is a handle to a RateLimiterGroup that may be
+/// used in exactly the same way as the RateLimiter type. When the RateLimiter
+/// within a RateLimiterGroup is unblocked, each RateLimiterGroupHandle will
+/// be notified.
+pub struct RateLimiterGroupHandle {
+    eventfd: Arc<EventFd>,
+    inner: Arc<RateLimiterGroupInner>,
+}
+
+impl RateLimiterGroupHandle {
+    fn new(inner: Arc<RateLimiterGroupInner>) -> result::Result<Self, Error> {
+        let eventfd = Arc::new(EventFd::new(0).map_err(Error::EventFd)?);
+        inner.handles.lock().unwrap().push(eventfd.clone());
+        Ok(Self { eventfd, inner })
+    }
+
+    /// Attempts to consume tokens and returns whether that is possible.
+    ///
+    /// If rate limiting is disabled on the provided `token_type`, this function
+    /// will always succeed.
+    pub fn consume(&self, tokens: u64, token_type: TokenType) -> bool {
+        self.inner.rate_limiter.consume(tokens, token_type)
+    }
+
+    /// Adds tokens of `token_type` to their respective bucket.
+    ///
+    /// Can be used to *manually* add tokens to a bucket. Useful for reverting a
+    /// `consume()` if needed.
+    pub fn manual_replenish(&self, tokens: u64, token_type: TokenType) {
+        self.inner.rate_limiter.manual_replenish(tokens, token_type)
+    }
+
+    /// This function needs to be called every time there is an event on the
+    /// FD provided by this object's `AsRawFd` trait implementation.
+    ///
+    /// # Errors
+    ///
+    /// If the rate limiter is disabled or is not blocked, an error is returned.
+    pub fn event_handler(&self) -> Result<(), Error> {
+        self.eventfd.read().map_err(Error::EventFdRead).map(|_| ())
+    }
+
+    /// Returns whether this rate limiter is blocked.
+    ///
+    /// The limiter 'blocks' when a `consume()` operation fails because there was
+    /// not enough budget for it.
+    /// An event will be generated on the exported FD when the limiter 'unblocks'.
+    pub fn is_blocked(&self) -> bool {
+        self.inner.rate_limiter.is_blocked()
+    }
+}
+
+impl Clone for RateLimiterGroupHandle {
+    fn clone(&self) -> Self {
+        RateLimiterGroupHandle::new(self.inner.clone()).unwrap()
+    }
+}
+
+impl AsRawFd for RateLimiterGroupHandle {
+    fn as_raw_fd(&self) -> RawFd {
+        self.eventfd.as_raw_fd()
+    }
+}
+
+impl Drop for RateLimiterGroupHandle {
+    fn drop(&mut self) {
+        let mut handles = self.inner.handles.lock().unwrap();
+        let index = handles
+            .iter()
+            .position(|handle| handle.as_raw_fd() == self.eventfd.as_raw_fd())
+            .expect("RateLimiterGroupHandle must be subscribed to RateLimiterGroup");
+        handles.remove(index);
+    }
+}
+
+struct RateLimiterGroupInner {
+    id: String,
+    rate_limiter: RateLimiter,
+    handles: Mutex<Vec<Arc<EventFd>>>,
+}
+
+/// A RateLimiterGroup is an extension of RateLimiter that enables rate-limiting
+/// the aggregate I/O consumption of multiple consumers.
+pub struct RateLimiterGroup {
+    inner: Arc<RateLimiterGroupInner>,
+    epoll_file: File,
+    kill_evt: EventFd,
+    epoll_thread: Option<thread::JoinHandle<()>>,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+#[repr(u64)]
+enum EpollDispatch {
+    Kill = 1,
+    Unblocked = 2,
+    Unknown,
+}
+
+impl From<u64> for EpollDispatch {
+    fn from(v: u64) -> Self {
+        use EpollDispatch::*;
+        match v {
+            1 => Kill,
+            2 => Unblocked,
+            _ => Unknown,
+        }
+    }
+}
+
+impl RateLimiterGroup {
+    /// Create a new RateLimiterGroup.
+    pub fn new(
+        id: &str,
+        bytes_total_capacity: u64,
+        bytes_one_time_burst: u64,
+        bytes_complete_refill_time_ms: u64,
+        ops_total_capacity: u64,
+        ops_one_time_burst: u64,
+        ops_complete_refill_time_ms: u64,
+    ) -> result::Result<Self, Error> {
+        let rate_limiter = RateLimiter::new(
+            bytes_total_capacity,
+            bytes_one_time_burst,
+            bytes_complete_refill_time_ms,
+            ops_total_capacity,
+            ops_one_time_burst,
+            ops_complete_refill_time_ms,
+        )
+        .map_err(Error::RateLimiter)?;
+
+        let epoll_fd = epoll::create(true).map_err(Error::Epoll)?;
+        let kill_evt = EventFd::new(0).map_err(Error::EventFd)?;
+
+        epoll::ctl(
+            epoll_fd,
+            epoll::ControlOptions::EPOLL_CTL_ADD,
+            kill_evt.as_raw_fd(),
+            epoll::Event::new(epoll::Events::EPOLLIN, EpollDispatch::Kill as u64),
+        )
+        .map_err(Error::Epoll)?;
+
+        epoll::ctl(
+            epoll_fd,
+            epoll::ControlOptions::EPOLL_CTL_ADD,
+            rate_limiter.as_raw_fd(),
+            epoll::Event::new(epoll::Events::EPOLLIN, EpollDispatch::Unblocked as u64),
+        )
+        .map_err(Error::Epoll)?;
+
+        // Use 'File' to enforce closing on 'epoll_fd'
+        // SAFETY: epoll_fd is valid
+        let epoll_file = unsafe { File::from_raw_fd(epoll_fd) };
+
+        Ok(Self {
+            inner: Arc::new(RateLimiterGroupInner {
+                id: id.to_string(),
+                rate_limiter,
+                handles: Mutex::new(Vec::new()),
+            }),
+            epoll_file,
+            kill_evt,
+            epoll_thread: None,
+        })
+    }
+
+    /// Create a new RateLimiterGroupHandle.
+    pub fn new_handle(&self) -> result::Result<RateLimiterGroupHandle, Error> {
+        RateLimiterGroupHandle::new(self.inner.clone())
+    }
+
+    /// Start a worker thread to broadcast an event to each RateLimiterGroupHandle
+    /// when the RateLimiter becomes unblocked.
+    pub fn start_thread(&mut self, exit_evt: EventFd) -> result::Result<(), Error> {
+        let inner = self.inner.clone();
+        let epoll_fd = self.epoll_file.as_raw_fd();
+        thread::Builder::new()
+            .name(format!("rate-limit-group-{}", inner.id))
+            .spawn(move || {
+                let res = std::panic::catch_unwind(AssertUnwindSafe(move || {
+                    const EPOLL_EVENTS_LEN: usize = 2;
+
+                    let mut events =
+                        [epoll::Event::new(epoll::Events::empty(), 0); EPOLL_EVENTS_LEN];
+
+                    loop {
+                        let num_events = match epoll::wait(epoll_fd, -1, &mut events[..]) {
+                            Ok(res) => res,
+                            Err(e) => {
+                                if e.kind() == io::ErrorKind::Interrupted {
+                                    continue;
+                                } else {
+                                    return Err(Error::Epoll(e));
+                                }
+                            }
+                        };
+
+                        for event in events.iter().take(num_events) {
+                            let dispatch_event: EpollDispatch = event.data.into();
+                            match dispatch_event {
+                                EpollDispatch::Unknown => {
+                                    let event = event.data;
+                                    warn!("Unknown rate-limiter loop event: {}", event);
+                                }
+                                EpollDispatch::Unblocked => {
+                                    inner.rate_limiter.event_handler().unwrap();
+                                    let handles = inner.handles.lock().unwrap();
+                                    for handle in handles.iter() {
+                                        handle.write(1).map_err(Error::EventFdWrite)?
+                                    }
+                                }
+                                EpollDispatch::Kill => {
+                                    info!(
+                                        "KILL_EVENT received, stopping rate-limit-group epoll loop"
+                                    );
+                                    return Ok(());
+                                }
+                            }
+                        }
+                    }
+                }));
+
+                match res {
+                    Ok(res) => {
+                        if let Err(e) = res {
+                            error!("Error running rate-limit-group worker: {:?}", e);
+                            exit_evt.write(1).unwrap();
+                        }
+                    }
+                    Err(_) => {
+                        error!("rate-limit-group worker panicked");
+                        exit_evt.write(1).unwrap();
+                    }
+                };
+            })
+            .map(|thread| self.epoll_thread.insert(thread))
+            .map_err(Error::ThreadSpawn)?;
+
+        Ok(())
+    }
+}
+
+impl Drop for RateLimiterGroup {
+    fn drop(&mut self) {
+        self.kill_evt.write(1).unwrap();
+
+        if let Some(t) = self.epoll_thread.take() {
+            if let Err(e) = t.join() {
+                error!("Error joining thread: {:?}", e);
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod tests {
+    use super::RateLimiterGroupHandle;
+    use crate::{group::RateLimiterGroup, TokenBucket, TokenType, REFILL_TIMER_INTERVAL_MS};
+    use std::{os::fd::AsRawFd, thread, time::Duration};
+    use vmm_sys_util::eventfd::EventFd;
+
+    impl RateLimiterGroupHandle {
+        pub fn bandwidth(&self) -> Option<TokenBucket> {
+            let guard = self.inner.rate_limiter.inner.lock().unwrap();
+            guard.bandwidth.clone()
+        }
+
+        pub fn ops(&self) -> Option<TokenBucket> {
+            let guard = self.inner.rate_limiter.inner.lock().unwrap();
+            guard.ops.clone()
+        }
+    }
+
+    #[test]
+    fn test_rate_limiter_group_new() {
+        let l = RateLimiterGroup::new("test", 1000, 1001, 1002, 1003, 1004, 1005).unwrap();
+        let h = l.new_handle().unwrap();
+        let bw = h.bandwidth().unwrap();
+        assert_eq!(bw.capacity(), 1000);
+        assert_eq!(bw.one_time_burst(), 1001);
+        assert_eq!(bw.refill_time_ms(), 1002);
+        assert_eq!(bw.budget(), 1000);
+
+        let ops = h.ops().unwrap();
+        assert_eq!(ops.capacity(), 1003);
+        assert_eq!(ops.one_time_burst(), 1004);
+        assert_eq!(ops.refill_time_ms(), 1005);
+        assert_eq!(ops.budget(), 1003);
+    }
+
+    #[test]
+    fn test_rate_limiter_group_manual_replenish() {
+        // rate limiter with limit of 1000 bytes/s and 1000 ops/s
+        let l = RateLimiterGroup::new("test", 1000, 0, 1000, 1000, 0, 1000).unwrap();
+        let h = l.new_handle().unwrap();
+
+        // consume 123 bytes
+        assert!(h.consume(123, TokenType::Bytes));
+        h.manual_replenish(23, TokenType::Bytes);
+        {
+            let bytes_tb = h.bandwidth().unwrap();
+            assert_eq!(bytes_tb.budget(), 900);
+        }
+        // consume 123 ops
+        assert!(h.consume(123, TokenType::Ops));
+        h.manual_replenish(23, TokenType::Ops);
+        {
+            let ops_tb = h.ops().unwrap();
+            assert_eq!(ops_tb.budget(), 900);
+        }
+    }
+
+    #[test]
+    fn test_rate_limiter_group_bandwidth() {
+        // rate limiter with limit of 1000 bytes/s
+        let mut l = RateLimiterGroup::new("test", 1000, 0, 1000, 0, 0, 0).unwrap();
+        l.start_thread(EventFd::new(0).unwrap()).unwrap();
+
+        let h = l.new_handle().unwrap();
+
+        // limiter should not be blocked
+        assert!(!h.is_blocked());
+        // raw FD for this limiter should be valid
+        assert!(h.as_raw_fd() > 0);
+
+        // ops/s limiter should be disabled so consume(whatever) should work
+        assert!(h.consume(u64::max_value(), TokenType::Ops));
+
+        // do full 1000 bytes
+        assert!(h.consume(1000, TokenType::Bytes));
+        // try and fail on another 100
+        assert!(!h.consume(100, TokenType::Bytes));
+        // since consume failed, limiter should be blocked now
+        assert!(h.is_blocked());
+        // wait half the timer period
+        thread::sleep(Duration::from_millis(REFILL_TIMER_INTERVAL_MS / 2));
+        // limiter should still be blocked
+        assert!(h.is_blocked());
+        // wait the other half of the timer period
+        thread::sleep(Duration::from_millis(REFILL_TIMER_INTERVAL_MS / 2));
+        // the timer_fd should have an event on it by now
+        assert!(h.event_handler().is_ok());
+        // limiter should now be unblocked
+        assert!(!h.is_blocked());
+        // try and succeed on another 100 bytes this time
+        assert!(h.consume(100, TokenType::Bytes));
+    }
+
+    #[test]
+    fn test_rate_limiter_group_ops() {
+        // rate limiter with limit of 1000 ops/s
+        let mut l = RateLimiterGroup::new("test", 0, 0, 0, 1000, 0, 1000).unwrap();
+        l.start_thread(EventFd::new(0).unwrap()).unwrap();
+
+        let h = l.new_handle().unwrap();
+
+        // limiter should not be blocked
+        assert!(!h.is_blocked());
+        // raw FD for this limiter should be valid
+        assert!(h.as_raw_fd() > 0);
+
+        // bytes/s limiter should be disabled so consume(whatever) should work
+        assert!(h.consume(u64::max_value(), TokenType::Bytes));
+
+        // do full 1000 ops
+        assert!(h.consume(1000, TokenType::Ops));
+        // try and fail on another 100
+        assert!(!h.consume(100, TokenType::Ops));
+        // since consume failed, limiter should be blocked now
+        assert!(h.is_blocked());
+        // wait half the timer period
+        thread::sleep(Duration::from_millis(REFILL_TIMER_INTERVAL_MS / 2));
+        // limiter should still be blocked
+        assert!(h.is_blocked());
+        // wait the other half of the timer period
+        thread::sleep(Duration::from_millis(REFILL_TIMER_INTERVAL_MS / 2));
+        // the timer_fd should have an event on it by now
+        assert!(h.event_handler().is_ok());
+        // limiter should now be unblocked
+        assert!(!h.is_blocked());
+        // try and succeed on another 100 ops this time
+        assert!(h.consume(100, TokenType::Ops));
+    }
+
+    #[test]
+    fn test_rate_limiter_group_full() {
+        // rate limiter with limit of 1000 bytes/s and 1000 ops/s
+        let mut l = RateLimiterGroup::new("test", 1000, 0, 1000, 1000, 0, 1000).unwrap();
+        l.start_thread(EventFd::new(0).unwrap()).unwrap();
+
+        let h = l.new_handle().unwrap();
+
+        // limiter should not be blocked
+        assert!(!h.is_blocked());
+        // raw FD for this limiter should be valid
+        assert!(h.as_raw_fd() > 0);
+
+        // do full 1000 ops
+        assert!(h.consume(1000, TokenType::Ops));
+        // do full 1000 bytes
+        assert!(h.consume(1000, TokenType::Bytes));
+        // try and fail on another 100 ops
+        assert!(!h.consume(100, TokenType::Ops));
+        // try and fail on another 100 bytes
+        assert!(!h.consume(100, TokenType::Bytes));
+        // since consume failed, limiter should be blocked now
+        assert!(h.is_blocked());
+        // wait half the timer period
+        thread::sleep(Duration::from_millis(REFILL_TIMER_INTERVAL_MS / 2));
+        // limiter should still be blocked
+        assert!(h.is_blocked());
+        // wait the other half of the timer period
+        thread::sleep(Duration::from_millis(REFILL_TIMER_INTERVAL_MS / 2));
+        // the timer_fd should have an event on it by now
+        assert!(h.event_handler().is_ok());
+        // limiter should now be unblocked
+        assert!(!h.is_blocked());
+        // try and succeed on another 100 ops this time
+        assert!(h.consume(100, TokenType::Ops));
+        // try and succeed on another 100 bytes this time
+        assert!(h.consume(100, TokenType::Bytes));
+    }
+
+    #[test]
+    fn test_rate_limiter_group_overconsumption() {
+        // initialize the rate limiter
+        let mut l = RateLimiterGroup::new("test", 1000, 0, 1000, 1000, 0, 1000).unwrap();
+        l.start_thread(EventFd::new(0).unwrap()).unwrap();
+
+        let h = l.new_handle().unwrap();
+
+        // try to consume 2.5x the bucket size
+        // we are "borrowing" 1.5x the bucket size in tokens since
+        // the bucket is full
+        assert!(h.consume(2500, TokenType::Bytes));
+
+        // check that even after a whole second passes, the rate limiter
+        // is still blocked
+        thread::sleep(Duration::from_millis(1000));
+        assert!(h.is_blocked());
+
+        // after 1.5x the replenish time has passed, the rate limiter
+        // is available again
+        thread::sleep(Duration::from_millis(500));
+        assert!(h.event_handler().is_ok());
+        assert!(!h.is_blocked());
+
+        // reset the rate limiter
+        let mut l = RateLimiterGroup::new("test", 1000, 0, 1000, 1000, 0, 1000).unwrap();
+        l.start_thread(EventFd::new(0).unwrap()).unwrap();
+
+        let h = l.new_handle().unwrap();
+        // try to consume 1.5x the bucket size
+        // we are "borrowing" 0.5x the bucket size in tokens since
+        // the bucket is full, which should arm the timer to 0.5x the
+        // replenish time, i.e. 500 ms
+        assert!(h.consume(1500, TokenType::Bytes));
+
+        // check that after more than the minimum refill time,
+        // the rate limiter is still blocked
+        thread::sleep(Duration::from_millis(200));
+        assert!(h.is_blocked());
+
+        // try to consume some tokens, which should fail as the timer
+        // is still active
+        assert!(!h.consume(100, TokenType::Bytes));
+        assert!(h.is_blocked());
+
+        // check that after the minimum refill time, the timer was not
+        // overwritten and the rate limiter is still blocked from the
+        // borrowing we performed earlier
+        thread::sleep(Duration::from_millis(100));
+        assert!(h.is_blocked());
+        assert!(!h.consume(100, TokenType::Bytes));
+
+        // after waiting out the full duration, rate limiter should be
+        // available again
+        thread::sleep(Duration::from_millis(200));
+        assert!(h.event_handler().is_ok());
+        assert!(!h.is_blocked());
+        assert!(h.consume(100, TokenType::Bytes));
+    }
+}
diff --git a/rate_limiter/src/lib.rs b/rate_limiter/src/lib.rs
index 287fa3113..c844f9290 100644
--- a/rate_limiter/src/lib.rs
+++ b/rate_limiter/src/lib.rs
@@ -53,6 +53,9 @@ use std::sync::Mutex;
 use std::time::{Duration, Instant};
 use vmm_sys_util::timerfd::TimerFd;
 
+/// Module for group rate limiting.
+pub mod group;
+
 #[derive(Debug)]
 /// Describes the errors that may occur while handling rate limiter events.
 pub enum Error {
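
A minimal usage sketch of the new API, separate from the diff above: it builds a RateLimiterGroup, starts its worker thread, and drives two handles that share the group's byte budget. The group id, bucket sizes, and exit-eventfd wiring are illustrative assumptions; only the functions introduced by this change (RateLimiterGroup::new, start_thread, new_handle, consume, is_blocked) are exercised.

    use rate_limiter::group::RateLimiterGroup;
    use rate_limiter::TokenType;
    use vmm_sys_util::eventfd::EventFd;

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Shared budget: 1000 bytes/s; the all-zero ops bucket leaves ops limiting disabled.
        let mut group = RateLimiterGroup::new("disk-group", 1000, 0, 1000, 0, 0, 0)?;
        // The worker thread signals every handle's eventfd when the group unblocks.
        group.start_thread(EventFd::new(0)?)?;

        // Two consumers (e.g. two queues of the same device) draw from one bucket.
        let h1 = group.new_handle()?;
        let h2 = group.new_handle()?;

        assert!(h1.consume(600, TokenType::Bytes));
        // The second consumer sees the shared budget already partly drained.
        if !h2.consume(600, TokenType::Bytes) {
            // In a real device loop, h2.as_raw_fd() would be registered with epoll
            // and h2.event_handler() called once the unblock event fires.
            println!("group is blocked: {}", h2.is_blocked());
        }
        Ok(())
    }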