rate-limiter: Allow RateLimiter to be shared between threads

Signed-off-by: Thomas Barrett <tbarrett@crusoeenergy.com>
This commit is contained in:
Thomas Barrett 2023-09-04 14:10:45 -07:00 committed by Bo Chen
parent e880aeed8d
commit 8d2e590886

View File

@ -46,9 +46,11 @@
#[macro_use]
extern crate log;
use std::io;
use std::os::unix::io::{AsRawFd, RawFd};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Mutex;
use std::time::{Duration, Instant};
use std::{fmt, io};
use vmm_sys_util::timerfd::TimerFd;
#[derive(Debug)]
@ -271,27 +273,27 @@ pub enum BucketUpdate {
/// implementation. These events are meant to be consumed by the user of this struct.
/// On each such event, the user must call the `event_handler()` method.
pub struct RateLimiter {
inner: Mutex<RateLimiterInner>,
// Internal flag that quickly determines timer state.
timer_active: AtomicBool,
}
struct RateLimiterInner {
bandwidth: Option<TokenBucket>,
ops: Option<TokenBucket>,
timer_fd: TimerFd,
// Internal flag that quickly determines timer state.
timer_active: bool,
}
impl PartialEq for RateLimiter {
fn eq(&self, other: &RateLimiter) -> bool {
self.bandwidth == other.bandwidth && self.ops == other.ops
}
}
impl fmt::Debug for RateLimiter {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"RateLimiter {{ bandwidth: {:?}, ops: {:?} }}",
self.bandwidth, self.ops
)
impl RateLimiterInner {
// Arm the timer of the rate limiter with the provided `Duration` (which will fire only once).
fn activate_timer(&mut self, dur: Duration, flag: &AtomicBool) {
// Panic when failing to arm the timer (same handling in crate TimerFd::set_state())
self.timer_fd
.reset(dur, None)
.expect("Can't arm the timer (unexpected 'timerfd_settime' failure).");
flag.store(true, Ordering::Relaxed)
}
}
@ -355,35 +357,28 @@ impl RateLimiter {
}
Ok(RateLimiter {
bandwidth: bytes_token_bucket,
ops: ops_token_bucket,
timer_fd,
timer_active: false,
inner: Mutex::new(RateLimiterInner {
bandwidth: bytes_token_bucket,
ops: ops_token_bucket,
timer_fd,
}),
timer_active: AtomicBool::new(false),
})
}
// Arm the timer of the rate limiter with the provided `Duration` (which will fire only once).
fn activate_timer(&mut self, dur: Duration) {
// Panic when failing to arm the timer (same handling in crate TimerFd::set_state())
self.timer_fd
.reset(dur, None)
.expect("Can't arm the timer (unexpected 'timerfd_settime' failure).");
self.timer_active = true;
}
/// Attempts to consume tokens and returns whether that is possible.
///
/// If rate limiting is disabled on provided `token_type`, this function will always succeed.
pub fn consume(&mut self, tokens: u64, token_type: TokenType) -> bool {
pub fn consume(&self, tokens: u64, token_type: TokenType) -> bool {
// If the timer is active, we can't consume tokens from any bucket and the function fails.
if self.timer_active {
if self.is_blocked() {
return false;
}
let mut guard = self.inner.lock().unwrap();
// Identify the required token bucket.
let token_bucket = match token_type {
TokenType::Bytes => self.bandwidth.as_mut(),
TokenType::Ops => self.ops.as_mut(),
TokenType::Bytes => guard.bandwidth.as_mut(),
TokenType::Ops => guard.ops.as_mut(),
};
// Try to consume from the token bucket.
if let Some(bucket) = token_bucket {
@ -393,8 +388,8 @@ impl RateLimiter {
// register a timer to replenish the bucket and resume processing;
// make sure there is only one running timer for this limiter.
BucketReduction::Failure => {
if !self.timer_active {
self.activate_timer(TIMER_REFILL_DUR);
if !self.is_blocked() {
guard.activate_timer(TIMER_REFILL_DUR, &self.timer_active);
}
false
}
@ -409,7 +404,10 @@ impl RateLimiter {
// order to enforce the bandwidth limit we need to prevent
// further calls to the rate limiter for
// `ratio * refill_time` milliseconds.
self.activate_timer(Duration::from_millis((ratio * refill_time as f64) as u64));
guard.activate_timer(
Duration::from_millis((ratio * refill_time as f64) as u64),
&self.timer_active,
);
true
}
}
@ -424,11 +422,12 @@ impl RateLimiter {
///
/// Can be used to *manually* add tokens to a bucket. Useful for reverting a
/// `consume()` if needed.
pub fn manual_replenish(&mut self, tokens: u64, token_type: TokenType) {
pub fn manual_replenish(&self, tokens: u64, token_type: TokenType) {
let mut guard = self.inner.lock().unwrap();
// Identify the required token bucket.
let token_bucket = match token_type {
TokenType::Bytes => self.bandwidth.as_mut(),
TokenType::Ops => self.ops.as_mut(),
TokenType::Bytes => guard.bandwidth.as_mut(),
TokenType::Ops => guard.ops.as_mut(),
};
// Add tokens to the token bucket.
if let Some(bucket) = token_bucket {
@ -442,7 +441,7 @@ impl RateLimiter {
/// budget for it.
/// An event will be generated on the exported FD when the limiter 'unblocks'.
pub fn is_blocked(&self) -> bool {
self.timer_active
self.timer_active.load(Ordering::Relaxed)
}
/// This function needs to be called every time there is an event on the
@ -451,11 +450,12 @@ impl RateLimiter {
/// # Errors
///
/// If the rate limiter is disabled or is not blocked, an error is returned.
pub fn event_handler(&mut self) -> Result<(), Error> {
pub fn event_handler(&self) -> Result<(), Error> {
let mut guard = self.inner.lock().unwrap();
loop {
// Note: As we manually added the `O_NONBLOCK` flag to the FD, the following
// `timer_fd::wait()` won't block (which is different from its default behavior.)
match self.timer_fd.wait() {
match guard.timer_fd.wait() {
Err(e) => {
let err: std::io::Error = e.into();
match err.kind() {
@ -469,7 +469,7 @@ impl RateLimiter {
}
}
_ => {
self.timer_active = false;
self.timer_active.store(false, Ordering::Relaxed);
return Ok(());
}
}
@ -479,27 +479,18 @@ impl RateLimiter {
/// Updates the parameters of the token buckets associated with this RateLimiter.
// TODO: Please note that, right now, the buckets become full after being updated.
pub fn update_buckets(&mut self, bytes: BucketUpdate, ops: BucketUpdate) {
let mut guard = self.inner.lock().unwrap();
match bytes {
BucketUpdate::Disabled => self.bandwidth = None,
BucketUpdate::Update(tb) => self.bandwidth = Some(tb),
BucketUpdate::Disabled => guard.bandwidth = None,
BucketUpdate::Update(tb) => guard.bandwidth = Some(tb),
BucketUpdate::None => (),
};
match ops {
BucketUpdate::Disabled => self.ops = None,
BucketUpdate::Update(tb) => self.ops = Some(tb),
BucketUpdate::Disabled => guard.ops = None,
BucketUpdate::Update(tb) => guard.ops = Some(tb),
BucketUpdate::None => (),
};
}
/// Returns an immutable view of the inner bandwidth token bucket.
pub fn bandwidth(&self) -> Option<&TokenBucket> {
self.bandwidth.as_ref()
}
/// Returns an immutable view of the inner ops token bucket.
pub fn ops(&self) -> Option<&TokenBucket> {
self.ops.as_ref()
}
}
impl AsRawFd for RateLimiter {
@ -510,7 +501,8 @@ impl AsRawFd for RateLimiter {
/// Will return a negative value if rate limiting is disabled on both
/// token types.
fn as_raw_fd(&self) -> RawFd {
self.timer_fd.as_raw_fd()
let guard = self.inner.lock().unwrap();
guard.timer_fd.as_raw_fd()
}
}
@ -525,6 +517,7 @@ impl Default for RateLimiter {
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use std::fmt;
use std::thread;
use std::time::Duration;
@ -557,11 +550,33 @@ pub(crate) mod tests {
}
impl RateLimiter {
fn get_token_bucket(&self, token_type: TokenType) -> Option<&TokenBucket> {
match token_type {
TokenType::Bytes => self.bandwidth.as_ref(),
TokenType::Ops => self.ops.as_ref(),
}
pub fn bandwidth(&self) -> Option<TokenBucket> {
let guard = self.inner.lock().unwrap();
guard.bandwidth.clone()
}
pub fn ops(&self) -> Option<TokenBucket> {
let guard = self.inner.lock().unwrap();
guard.ops.clone()
}
}
impl PartialEq for RateLimiter {
fn eq(&self, other: &RateLimiter) -> bool {
let self_guard = self.inner.lock().unwrap();
let other_guard = other.inner.lock().unwrap();
self_guard.bandwidth == other_guard.bandwidth && self_guard.ops == other_guard.ops
}
}
impl fmt::Debug for RateLimiter {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let guard = self.inner.lock().unwrap();
write!(
f,
"RateLimiter {{ bandwidth: {:?}, ops: {:?} }}",
guard.bandwidth, guard.ops
)
}
}
@ -640,7 +655,7 @@ pub(crate) mod tests {
#[test]
fn test_rate_limiter_default() {
let mut l = RateLimiter::default();
let l = RateLimiter::default();
// limiter should not be blocked
assert!(!l.is_blocked());
@ -659,14 +674,13 @@ pub(crate) mod tests {
#[test]
fn test_rate_limiter_new() {
let l = RateLimiter::new(1000, 1001, 1002, 1003, 1004, 1005).unwrap();
let bw = l.bandwidth.unwrap();
let bw = l.bandwidth().unwrap();
assert_eq!(bw.capacity(), 1000);
assert_eq!(bw.one_time_burst(), 1001);
assert_eq!(bw.refill_time_ms(), 1002);
assert_eq!(bw.budget(), 1000);
let ops = l.ops.unwrap();
let ops = l.ops().unwrap();
assert_eq!(ops.capacity(), 1003);
assert_eq!(ops.one_time_burst(), 1004);
assert_eq!(ops.refill_time_ms(), 1005);
@ -676,20 +690,20 @@ pub(crate) mod tests {
#[test]
fn test_rate_limiter_manual_replenish() {
// rate limiter with limit of 1000 bytes/s and 1000 ops/s
let mut l = RateLimiter::new(1000, 0, 1000, 1000, 0, 1000).unwrap();
let l = RateLimiter::new(1000, 0, 1000, 1000, 0, 1000).unwrap();
// consume 123 bytes
assert!(l.consume(123, TokenType::Bytes));
l.manual_replenish(23, TokenType::Bytes);
{
let bytes_tb = l.get_token_bucket(TokenType::Bytes).unwrap();
let bytes_tb = l.bandwidth().unwrap();
assert_eq!(bytes_tb.budget(), 900);
}
// consume 123 ops
assert!(l.consume(123, TokenType::Ops));
l.manual_replenish(23, TokenType::Ops);
{
let bytes_tb = l.get_token_bucket(TokenType::Ops).unwrap();
let bytes_tb = l.ops().unwrap();
assert_eq!(bytes_tb.budget(), 900);
}
}
@ -697,7 +711,7 @@ pub(crate) mod tests {
#[test]
fn test_rate_limiter_bandwidth() {
// rate limiter with limit of 1000 bytes/s
let mut l = RateLimiter::new(1000, 0, 1000, 0, 0, 0).unwrap();
let l = RateLimiter::new(1000, 0, 1000, 0, 0, 0).unwrap();
// limiter should not be blocked
assert!(!l.is_blocked());
@ -730,7 +744,7 @@ pub(crate) mod tests {
#[test]
fn test_rate_limiter_ops() {
// rate limiter with limit of 1000 ops/s
let mut l = RateLimiter::new(0, 0, 0, 1000, 0, 1000).unwrap();
let l = RateLimiter::new(0, 0, 0, 1000, 0, 1000).unwrap();
// limiter should not be blocked
assert!(!l.is_blocked());
@ -763,7 +777,7 @@ pub(crate) mod tests {
#[test]
fn test_rate_limiter_full() {
// rate limiter with limit of 1000 bytes/s and 1000 ops/s
let mut l = RateLimiter::new(1000, 0, 1000, 1000, 0, 1000).unwrap();
let l = RateLimiter::new(1000, 0, 1000, 1000, 0, 1000).unwrap();
// limiter should not be blocked
assert!(!l.is_blocked());
@ -799,7 +813,7 @@ pub(crate) mod tests {
#[test]
fn test_rate_limiter_overconsumption() {
// initialize the rate limiter
let mut l = RateLimiter::new(1000, 0, 1000, 1000, 0, 1000).unwrap();
let l = RateLimiter::new(1000, 0, 1000, 1000, 0, 1000).unwrap();
// try to consume 2.5x the bucket size
// we are "borrowing" 1.5x the bucket size in tokens since
// the bucket is full
@ -818,7 +832,7 @@ pub(crate) mod tests {
assert!(!l.is_blocked());
// reset the rate limiter
let mut l = RateLimiter::new(1000, 0, 1000, 1000, 0, 1000).unwrap();
let l = RateLimiter::new(1000, 0, 1000, 1000, 0, 1000).unwrap();
// try to consume 1.5x the bucket size
// we are "borrowing" 1.5x the bucket size in tokens since
// the bucket is full, should arm the timer to 0.5x replenish
@ -857,12 +871,12 @@ pub(crate) mod tests {
fn test_update_buckets() {
let mut x = RateLimiter::new(1000, 2000, 1000, 10, 20, 1000).unwrap();
let initial_bw = x.bandwidth.clone();
let initial_ops = x.ops.clone();
let initial_bw = x.bandwidth();
let initial_ops = x.ops();
x.update_buckets(BucketUpdate::None, BucketUpdate::None);
assert_eq!(x.bandwidth, initial_bw);
assert_eq!(x.ops, initial_ops);
assert_eq!(x.bandwidth(), initial_bw);
assert_eq!(x.ops(), initial_ops);
let new_bw = TokenBucket::new(123, 0, 57).unwrap();
let new_ops = TokenBucket::new(321, 12346, 89).unwrap();
@ -871,18 +885,21 @@ pub(crate) mod tests {
BucketUpdate::Update(new_ops.clone()),
);
// We have manually adjust the last_update field, because it changes when update_buckets()
// constructs new buckets (and thus gets a different value for last_update). We do this so
// it makes sense to test the following assertions.
x.bandwidth.as_mut().unwrap().last_update = new_bw.last_update;
x.ops.as_mut().unwrap().last_update = new_ops.last_update;
{
let mut guard = x.inner.lock().unwrap();
// We have manually adjust the last_update field, because it changes when update_buckets()
// constructs new buckets (and thus gets a different value for last_update). We do this so
// it makes sense to test the following assertions.
guard.bandwidth.as_mut().unwrap().last_update = new_bw.last_update;
guard.ops.as_mut().unwrap().last_update = new_ops.last_update;
}
assert_eq!(x.bandwidth, Some(new_bw));
assert_eq!(x.ops, Some(new_ops));
assert_eq!(x.bandwidth(), Some(new_bw));
assert_eq!(x.ops(), Some(new_ops));
x.update_buckets(BucketUpdate::Disabled, BucketUpdate::Disabled);
assert_eq!(x.bandwidth, None);
assert_eq!(x.ops, None);
assert_eq!(x.bandwidth(), None);
assert_eq!(x.ops(), None);
}
#[test]