608 lines
18 KiB
Rust
608 lines
18 KiB
Rust
//! Rate limiting for network requests using token bucket algorithm.
|
|
//!
|
|
//! This module provides rate limiting capabilities to prevent DoS attacks
|
|
//! and ensure fair resource allocation among peers.
|
|
|
|
use hashbrown::HashMap;
|
|
use libp2p::PeerId;
|
|
use parking_lot::RwLock;
|
|
use std::time::{Duration, Instant};
|
|
use tracing::warn;
|
|
|
|
/// Token-bucket parameters used by the limiters in this module.
///
/// `requests_per_second` is the refill rate and `burst_size` is the bucket
/// capacity (buckets start full, so this is how many requests can be served
/// back-to-back).
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Steady-state refill rate, in tokens (requests) per second.
    pub requests_per_second: f64,
    /// Bucket capacity: the largest number of requests that can be served in one burst.
    pub burst_size: u32,
    /// Time window in milliseconds for rate calculation.
    ///
    /// NOTE(review): no limiter in this module reads this field — confirm whether
    /// other code consumes it.
    pub window_ms: u64,
}

impl Default for RateLimitConfig {
    /// 100 requests/second, bursts of up to 200, 1000 ms window.
    fn default() -> Self {
        Self::new(100.0, 200, 1000)
    }
}

impl RateLimitConfig {
    /// Builds a configuration from explicit values. No validation is performed.
    pub fn new(requests_per_second: f64, burst_size: u32, window_ms: u64) -> Self {
        Self {
            requests_per_second,
            burst_size,
            window_ms,
        }
    }

    /// Tighter preset: 50 requests/second, bursts of up to 100, 1000 ms window.
    pub fn strict() -> Self {
        Self::new(50.0, 100, 1000)
    }

    /// Looser preset: 200 requests/second, bursts of up to 400, 1000 ms window.
    pub fn relaxed() -> Self {
        Self::new(200.0, 400, 1000)
    }
}
|
|
|
|
/// Token bucket state for a single peer (also used by the global limiter).
#[derive(Debug)]
struct TokenBucket {
    /// Current number of available tokens (may be fractional).
    tokens: f64,
    /// Last time tokens were refilled.
    last_refill: Instant,
    /// Number of rate limit violations (saturates at `u32::MAX`).
    violations: u32,
    /// Maximum number of tokens the bucket can hold: the burst size it was created with.
    capacity: f64,
}

impl TokenBucket {
    /// Creates a new token bucket holding `capacity` tokens (i.e. full).
    fn new(capacity: u32) -> Self {
        let capacity = f64::from(capacity);
        TokenBucket {
            tokens: capacity,
            last_refill: Instant::now(),
            violations: 0,
            capacity,
        }
    }

    /// Credits tokens for the time elapsed since the last refill at `rate`
    /// tokens/second, never exceeding the bucket's capacity.
    ///
    /// A negative or NaN `rate` adds nothing.
    fn refill(&mut self, rate: f64) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill);
        // `max` maps NaN and negative rates to 0 so a bad config cannot drain the bucket.
        let new_tokens = elapsed.as_secs_f64() * rate.max(0.0);
        // Cap at the bucket's own capacity (its burst size). Capping at `2 * rate`
        // truncated bursts larger than twice the rate after the first refill, and
        // let the balance grow past the burst size when the burst was smaller.
        self.tokens = (self.tokens + new_tokens).min(self.capacity);
        self.last_refill = now;
    }

    /// Refills, then tries to take one token.
    ///
    /// Returns `true` if a token was taken. Otherwise it records a violation and
    /// returns `false`. `burst` additionally caps the balance; callers pass the same
    /// burst size the bucket was created with.
    fn try_consume(&mut self, rate: f64, burst: u32) -> bool {
        self.refill(rate);

        // Cap tokens at burst size
        self.tokens = self.tokens.min(f64::from(burst));

        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            // Saturate so a persistently abusive peer cannot overflow the counter
            // (a plain `+= 1` panics on overflow in debug builds).
            self.violations = self.violations.saturating_add(1);
            false
        }
    }

    /// Returns the time until at least one token is available at `rate` tokens/second.
    ///
    /// Tokens accrued since the last refill are taken into account without mutating
    /// the bucket. Returns [`Duration::MAX`] when a token can never accrue: the rate
    /// is zero, negative, or NaN, or the capacity is below one token. It also returns
    /// [`Duration::MAX`] when the wait is too large to represent.
    fn time_until_available(&self, rate: f64) -> Duration {
        if self.tokens >= 1.0 {
            return Duration::ZERO;
        }
        // Dividing by a zero rate would yield infinity, which made
        // `Duration::from_secs_f64` panic.
        if rate.is_nan() || rate <= 0.0 || self.capacity < 1.0 {
            return Duration::MAX;
        }
        let accrued = self.last_refill.elapsed().as_secs_f64() * rate;
        let projected = (self.tokens + accrued).min(self.capacity);
        if projected >= 1.0 {
            return Duration::ZERO;
        }
        let seconds_needed = (1.0 - projected) / rate;
        Duration::try_from_secs_f64(seconds_needed).unwrap_or(Duration::MAX)
    }
}
|
|
|
|
/// A rate limiter using the token bucket algorithm.
|
|
pub struct RateLimiter {
|
|
/// Configuration for rate limiting.
|
|
config: RateLimitConfig,
|
|
/// Token bucket for tracking requests.
|
|
bucket: RwLock<TokenBucket>,
|
|
}
|
|
|
|
impl RateLimiter {
|
|
/// Creates a new rate limiter with the given configuration.
|
|
pub fn new(config: RateLimitConfig) -> Self {
|
|
let bucket = TokenBucket::new(config.burst_size);
|
|
RateLimiter {
|
|
config,
|
|
bucket: RwLock::new(bucket),
|
|
}
|
|
}
|
|
|
|
/// Creates a rate limiter with default configuration.
|
|
pub fn with_defaults() -> Self {
|
|
Self::new(RateLimitConfig::default())
|
|
}
|
|
|
|
/// Checks if a request is allowed and consumes a token if so.
|
|
pub fn check_and_consume(&self) -> bool {
|
|
let mut bucket = self.bucket.write();
|
|
bucket.try_consume(self.config.requests_per_second, self.config.burst_size)
|
|
}
|
|
|
|
/// Returns the current number of available tokens.
|
|
pub fn available_tokens(&self) -> f64 {
|
|
let mut bucket = self.bucket.write();
|
|
bucket.refill(self.config.requests_per_second);
|
|
bucket.tokens
|
|
}
|
|
|
|
/// Returns the time until the next token is available.
|
|
pub fn time_until_available(&self) -> Duration {
|
|
let bucket = self.bucket.read();
|
|
bucket.time_until_available(self.config.requests_per_second)
|
|
}
|
|
|
|
/// Returns the number of rate limit violations.
|
|
pub fn violations(&self) -> u32 {
|
|
self.bucket.read().violations
|
|
}
|
|
|
|
/// Resets the rate limiter state.
|
|
pub fn reset(&self) {
|
|
let mut bucket = self.bucket.write();
|
|
bucket.tokens = self.config.burst_size as f64;
|
|
bucket.last_refill = Instant::now();
|
|
bucket.violations = 0;
|
|
}
|
|
}
|
|
|
|
/// Per-peer rate limiter that tracks rate limits for each peer.
|
|
pub struct PerPeerLimiter {
|
|
/// Configuration for rate limiting.
|
|
config: RateLimitConfig,
|
|
/// Token buckets per peer.
|
|
peers: RwLock<HashMap<PeerId, TokenBucket>>,
|
|
/// Maximum violation threshold before cooldown.
|
|
max_violations: u32,
|
|
/// Cooldown duration after exceeding max violations.
|
|
cooldown_duration: Duration,
|
|
/// Peers currently in cooldown.
|
|
cooldowns: RwLock<HashMap<PeerId, Instant>>,
|
|
}
|
|
|
|
impl PerPeerLimiter {
|
|
/// Creates a new per-peer rate limiter with the given configuration.
|
|
pub fn new(config: RateLimitConfig) -> Self {
|
|
PerPeerLimiter {
|
|
config,
|
|
peers: RwLock::new(HashMap::new()),
|
|
max_violations: 10,
|
|
cooldown_duration: Duration::from_secs(60),
|
|
cooldowns: RwLock::new(HashMap::new()),
|
|
}
|
|
}
|
|
|
|
/// Creates a per-peer rate limiter with default configuration.
|
|
pub fn with_defaults() -> Self {
|
|
Self::new(RateLimitConfig::default())
|
|
}
|
|
|
|
/// Creates a per-peer rate limiter with custom violation thresholds.
|
|
pub fn with_violation_config(
|
|
config: RateLimitConfig,
|
|
max_violations: u32,
|
|
cooldown_duration: Duration,
|
|
) -> Self {
|
|
PerPeerLimiter {
|
|
config,
|
|
peers: RwLock::new(HashMap::new()),
|
|
max_violations,
|
|
cooldown_duration,
|
|
cooldowns: RwLock::new(HashMap::new()),
|
|
}
|
|
}
|
|
|
|
/// Checks if a request from the given peer is allowed.
|
|
/// Returns true if allowed, false if rate limited.
|
|
pub fn check_rate_limit(&self, peer_id: &PeerId) -> bool {
|
|
// Check if peer is in cooldown
|
|
{
|
|
let cooldowns = self.cooldowns.read();
|
|
if let Some(cooldown_until) = cooldowns.get(peer_id) {
|
|
if Instant::now() < *cooldown_until {
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Remove expired cooldown
|
|
{
|
|
let mut cooldowns = self.cooldowns.write();
|
|
if let Some(cooldown_until) = cooldowns.get(peer_id) {
|
|
if Instant::now() >= *cooldown_until {
|
|
cooldowns.remove(peer_id);
|
|
}
|
|
}
|
|
}
|
|
|
|
let mut peers = self.peers.write();
|
|
let bucket = peers
|
|
.entry(*peer_id)
|
|
.or_insert_with(|| TokenBucket::new(self.config.burst_size));
|
|
|
|
// Refill tokens before checking
|
|
bucket.refill(self.config.requests_per_second);
|
|
|
|
// Check if tokens are available (but don't consume yet)
|
|
bucket.tokens >= 1.0
|
|
}
|
|
|
|
/// Records a request from the given peer, consuming a token.
|
|
/// Should be called after check_rate_limit returns true.
|
|
pub fn record_request(&self, peer_id: &PeerId) {
|
|
let mut peers = self.peers.write();
|
|
let bucket = peers
|
|
.entry(*peer_id)
|
|
.or_insert_with(|| TokenBucket::new(self.config.burst_size));
|
|
|
|
if !bucket.try_consume(self.config.requests_per_second, self.config.burst_size) {
|
|
// Check if we need to put peer in cooldown
|
|
if bucket.violations >= self.max_violations {
|
|
warn!(
|
|
"Peer {} exceeded max violations ({}), entering cooldown",
|
|
peer_id, bucket.violations
|
|
);
|
|
drop(peers); // Release lock before acquiring cooldowns lock
|
|
let mut cooldowns = self.cooldowns.write();
|
|
cooldowns.insert(*peer_id, Instant::now() + self.cooldown_duration);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Checks and records a request in a single operation.
|
|
/// Returns true if the request was allowed.
|
|
pub fn check_and_record(&self, peer_id: &PeerId) -> bool {
|
|
// Check if peer is in cooldown
|
|
{
|
|
let cooldowns = self.cooldowns.read();
|
|
if let Some(cooldown_until) = cooldowns.get(peer_id) {
|
|
if Instant::now() < *cooldown_until {
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Remove expired cooldown and check/record request
|
|
{
|
|
let mut cooldowns = self.cooldowns.write();
|
|
if let Some(cooldown_until) = cooldowns.get(peer_id) {
|
|
if Instant::now() >= *cooldown_until {
|
|
cooldowns.remove(peer_id);
|
|
}
|
|
}
|
|
}
|
|
|
|
let mut peers = self.peers.write();
|
|
let bucket = peers
|
|
.entry(*peer_id)
|
|
.or_insert_with(|| TokenBucket::new(self.config.burst_size));
|
|
|
|
let allowed = bucket.try_consume(self.config.requests_per_second, self.config.burst_size);
|
|
|
|
if !allowed && bucket.violations >= self.max_violations {
|
|
warn!(
|
|
"Peer {} exceeded max violations ({}), entering cooldown",
|
|
peer_id, bucket.violations
|
|
);
|
|
drop(peers);
|
|
let mut cooldowns = self.cooldowns.write();
|
|
cooldowns.insert(*peer_id, Instant::now() + self.cooldown_duration);
|
|
}
|
|
|
|
allowed
|
|
}
|
|
|
|
/// Returns the number of available tokens for a peer.
|
|
pub fn available_tokens(&self, peer_id: &PeerId) -> f64 {
|
|
let mut peers = self.peers.write();
|
|
if let Some(bucket) = peers.get_mut(peer_id) {
|
|
bucket.refill(self.config.requests_per_second);
|
|
bucket.tokens
|
|
} else {
|
|
self.config.burst_size as f64
|
|
}
|
|
}
|
|
|
|
/// Returns the number of violations for a peer.
|
|
pub fn violations(&self, peer_id: &PeerId) -> u32 {
|
|
let peers = self.peers.read();
|
|
peers.get(peer_id).map(|b| b.violations).unwrap_or(0)
|
|
}
|
|
|
|
/// Returns true if the peer is currently in cooldown.
|
|
pub fn is_in_cooldown(&self, peer_id: &PeerId) -> bool {
|
|
let cooldowns = self.cooldowns.read();
|
|
if let Some(cooldown_until) = cooldowns.get(peer_id) {
|
|
Instant::now() < *cooldown_until
|
|
} else {
|
|
false
|
|
}
|
|
}
|
|
|
|
/// Returns the remaining cooldown duration for a peer, if any.
|
|
pub fn cooldown_remaining(&self, peer_id: &PeerId) -> Option<Duration> {
|
|
let cooldowns = self.cooldowns.read();
|
|
if let Some(cooldown_until) = cooldowns.get(peer_id) {
|
|
let now = Instant::now();
|
|
if now < *cooldown_until {
|
|
Some(*cooldown_until - now)
|
|
} else {
|
|
None
|
|
}
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
|
|
/// Resets the rate limit state for a specific peer.
|
|
pub fn reset_peer(&self, peer_id: &PeerId) {
|
|
self.peers.write().remove(peer_id);
|
|
self.cooldowns.write().remove(peer_id);
|
|
}
|
|
|
|
/// Cleans up state for disconnected peers.
|
|
pub fn cleanup(&self, connected_peers: &[PeerId]) {
|
|
let connected: std::collections::HashSet<_> = connected_peers.iter().collect();
|
|
self.peers.write().retain(|id, _| connected.contains(id));
|
|
self.cooldowns
|
|
.write()
|
|
.retain(|id, _| connected.contains(id));
|
|
}
|
|
|
|
/// Returns the number of tracked peers.
|
|
pub fn peer_count(&self) -> usize {
|
|
self.peers.read().len()
|
|
}
|
|
|
|
/// Returns the configuration.
|
|
pub fn config(&self) -> &RateLimitConfig {
|
|
&self.config
|
|
}
|
|
}
|
|
|
|
/// Error returned when a request is rate limited.
|
|
#[derive(Debug, Clone)]
|
|
pub struct RateLimitError {
|
|
/// The peer that was rate limited.
|
|
pub peer_id: PeerId,
|
|
/// Time until the rate limit resets.
|
|
pub retry_after: Duration,
|
|
/// Whether the peer is in cooldown due to excessive violations.
|
|
pub in_cooldown: bool,
|
|
}
|
|
|
|
impl std::fmt::Display for RateLimitError {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
if self.in_cooldown {
|
|
write!(
|
|
f,
|
|
"Peer {} is in cooldown, retry after {:?}",
|
|
self.peer_id, self.retry_after
|
|
)
|
|
} else {
|
|
write!(
|
|
f,
|
|
"Rate limited peer {}, retry after {:?}",
|
|
self.peer_id, self.retry_after
|
|
)
|
|
}
|
|
}
|
|
}
|
|
|
|
impl std::error::Error for RateLimitError {}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread::sleep;

    /// Builds a config with the given rate and burst and a 1000 ms window.
    fn config(requests_per_second: f64, burst_size: u32) -> RateLimitConfig {
        RateLimitConfig {
            requests_per_second,
            burst_size,
            window_ms: 1000,
        }
    }

    fn random_peer_id() -> PeerId {
        PeerId::random()
    }

    #[test]
    fn test_rate_limiter_allows_initial_burst() {
        let limiter = RateLimiter::new(config(10.0, 20));

        // Every request in the burst is granted...
        for attempt in 0..20 {
            assert!(limiter.check_and_consume(), "burst request {attempt} refused");
        }
        // ...and the one after it is refused.
        assert!(!limiter.check_and_consume());
    }

    #[test]
    fn test_rate_limiter_refills_over_time() {
        let limiter = RateLimiter::new(config(100.0, 10));

        // Drain the bucket.
        for _ in 0..10 {
            let _ = limiter.check_and_consume();
        }
        assert!(!limiter.check_and_consume());

        // At 100 tokens/s one token takes 10 ms; 15 ms leaves some slack.
        sleep(Duration::from_millis(15));
        assert!(limiter.check_and_consume());
    }

    #[test]
    fn test_per_peer_limiter_isolates_peers() {
        let limiter = PerPeerLimiter::new(config(10.0, 5));
        let (first, second) = (random_peer_id(), random_peer_id());

        for _ in 0..5 {
            assert!(limiter.check_and_record(&first));
        }
        assert!(!limiter.check_and_record(&first));

        // The second peer's budget is untouched by the first peer's usage.
        for _ in 0..5 {
            assert!(limiter.check_and_record(&second));
        }
    }

    #[test]
    fn test_per_peer_limiter_check_and_record() {
        let limiter = PerPeerLimiter::new(config(10.0, 3));
        let peer = random_peer_id();

        // Checking alone never spends tokens.
        for _ in 0..2 {
            assert!(limiter.check_rate_limit(&peer));
        }

        // Recording spends them.
        for _ in 0..3 {
            limiter.record_request(&peer);
        }

        assert!(!limiter.check_rate_limit(&peer));
    }

    #[test]
    fn test_per_peer_limiter_cooldown() {
        let limiter = PerPeerLimiter::with_violation_config(
            config(100.0, 2),
            3,                          // max_violations
            Duration::from_millis(100), // short cooldown for test
        );
        let peer = random_peer_id();

        // Two requests drain the bucket, the next three are violations.
        for _ in 0..5 {
            limiter.check_and_record(&peer);
        }

        assert!(limiter.is_in_cooldown(&peer));
        assert!(!limiter.check_rate_limit(&peer));

        sleep(Duration::from_millis(150));
        assert!(!limiter.is_in_cooldown(&peer));
    }

    #[test]
    fn test_per_peer_limiter_cleanup() {
        let limiter = PerPeerLimiter::new(RateLimitConfig::default());
        let peers = [random_peer_id(), random_peer_id(), random_peer_id()];

        for peer in &peers {
            limiter.record_request(peer);
        }
        assert_eq!(limiter.peer_count(), 3);

        // Keep only the first two peers.
        limiter.cleanup(&peers[..2]);
        assert_eq!(limiter.peer_count(), 2);
    }

    #[test]
    fn test_rate_limit_config_defaults() {
        let RateLimitConfig {
            requests_per_second,
            burst_size,
            window_ms,
        } = RateLimitConfig::default();
        assert_eq!(requests_per_second, 100.0);
        assert_eq!(burst_size, 200);
        assert_eq!(window_ms, 1000);
    }

    #[test]
    fn test_violations_tracking() {
        let limiter = PerPeerLimiter::new(config(10.0, 2));
        let peer = random_peer_id();

        assert_eq!(limiter.violations(&peer), 0);

        // Two requests drain the bucket, the next two are violations.
        for _ in 0..4 {
            limiter.check_and_record(&peer);
        }

        assert!(limiter.violations(&peer) >= 2);
    }
}
|