synor/crates/synor-network/src/rate_limit.rs
2026-01-08 05:22:24 +05:30

608 lines
18 KiB
Rust

//! Rate limiting for network requests using the token bucket algorithm.
//!
//! This module provides rate limiting capabilities to prevent DoS attacks
//! and to ensure fair resource allocation among peers.
use hashbrown::HashMap;
use libp2p::PeerId;
use parking_lot::RwLock;
use std::time::{Duration, Instant};
use tracing::warn;
/// Configuration for rate limiting.
///
/// Limits are enforced with a token bucket: `burst_size` is the bucket capacity and
/// `requests_per_second` is the rate at which spent tokens are refilled.
#[derive(Clone, Debug, PartialEq)]
pub struct RateLimitConfig {
    /// Maximum requests allowed per second (steady-state refill rate).
    pub requests_per_second: f64,
    /// Maximum burst size (bucket capacity).
    pub burst_size: u32,
    /// Nominal time window in milliseconds.
    ///
    /// NOTE(review): not consulted by the limiters in this module; their buckets refill
    /// continuously from `requests_per_second`.
    pub window_ms: u64,
}

impl Default for RateLimitConfig {
    /// 100 requests/second with a burst of 200 and a 1000 ms window.
    fn default() -> Self {
        Self::new(100.0, 200, 1000)
    }
}

impl RateLimitConfig {
    /// Creates a new rate limit configuration.
    pub fn new(requests_per_second: f64, burst_size: u32, window_ms: u64) -> Self {
        Self {
            requests_per_second,
            burst_size,
            window_ms,
        }
    }

    /// Creates a strict rate limit configuration (50 requests/second, burst of 100).
    pub fn strict() -> Self {
        Self::new(50.0, 100, 1000)
    }

    /// Creates a relaxed rate limit configuration (200 requests/second, burst of 400).
    pub fn relaxed() -> Self {
        Self::new(200.0, 400, 1000)
    }
}
/// Token bucket state for a single peer.
#[derive(Debug)]
struct TokenBucket {
    /// Current number of available tokens (kept at or below `capacity` by refills).
    tokens: f64,
    /// Last time tokens were refilled.
    last_refill: Instant,
    /// Number of rate limit violations (saturates at `u32::MAX`).
    violations: u32,
    /// Maximum number of tokens the bucket holds: the capacity passed to `new`.
    capacity: f64,
}

impl TokenBucket {
    /// Creates a new token bucket that starts full with `capacity` tokens.
    fn new(capacity: u32) -> Self {
        let capacity = f64::from(capacity);
        TokenBucket {
            tokens: capacity,
            last_refill: Instant::now(),
            violations: 0,
            capacity,
        }
    }

    /// Token balance at `now`, assuming continuous refill at `rate` since `last_refill`.
    ///
    /// The balance is capped at the bucket's own capacity, not at a multiple of `rate`,
    /// so a bucket whose burst differs from `2 * rate` keeps its configured size.
    /// Non-positive or NaN accruals (a zero, negative or NaN `rate`) are ignored, so a
    /// misconfigured rate can only stop refilling. It can never drain or poison the balance.
    fn projected_tokens(&self, rate: f64, now: Instant) -> f64 {
        let accrued = now.duration_since(self.last_refill).as_secs_f64() * rate;
        if accrued > 0.0 {
            (self.tokens + accrued).min(self.capacity)
        } else {
            self.tokens
        }
    }

    /// Refills tokens based on the time elapsed since the previous refill.
    fn refill(&mut self, rate: f64) {
        let now = Instant::now();
        self.tokens = self.projected_tokens(rate, now);
        self.last_refill = now;
    }

    /// Refills, then attempts to take one token.
    ///
    /// Returns `true` on success. On failure it records a violation and returns `false`.
    /// `burst` additionally caps the balance before the check.
    fn try_consume(&mut self, rate: f64, burst: u32) -> bool {
        self.refill(rate);
        self.tokens = self.tokens.min(f64::from(burst));
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            // Saturate: a caller hammering a limiter without a cooldown must not overflow
            // the counter (a panic in debug builds, a wrap to 0 in release builds).
            self.violations = self.violations.saturating_add(1);
            false
        }
    }

    /// Returns the time until the next token is available.
    ///
    /// Tokens accrued since the last refill are included without mutating the bucket, so
    /// the answer is current even when the caller only holds a read lock. Returns
    /// `Duration::MAX` when no token can ever become available: a zero, negative or NaN
    /// rate, or a capacity below one token.
    fn time_until_available(&self, rate: f64) -> Duration {
        let available = self.projected_tokens(rate, Instant::now());
        if available >= 1.0 {
            return Duration::ZERO;
        }
        if self.capacity < 1.0 {
            return Duration::MAX;
        }
        let seconds_needed = (1.0 - available) / rate;
        // `Duration::from_secs_f64` would panic on infinite, negative, NaN or huge inputs.
        Duration::try_from_secs_f64(seconds_needed).unwrap_or(Duration::MAX)
    }
}
/// A rate limiter using the token bucket algorithm.
pub struct RateLimiter {
/// Configuration for rate limiting.
config: RateLimitConfig,
/// Token bucket for tracking requests.
bucket: RwLock<TokenBucket>,
}
impl RateLimiter {
/// Creates a new rate limiter with the given configuration.
pub fn new(config: RateLimitConfig) -> Self {
let bucket = TokenBucket::new(config.burst_size);
RateLimiter {
config,
bucket: RwLock::new(bucket),
}
}
/// Creates a rate limiter with default configuration.
pub fn with_defaults() -> Self {
Self::new(RateLimitConfig::default())
}
/// Checks if a request is allowed and consumes a token if so.
pub fn check_and_consume(&self) -> bool {
let mut bucket = self.bucket.write();
bucket.try_consume(self.config.requests_per_second, self.config.burst_size)
}
/// Returns the current number of available tokens.
pub fn available_tokens(&self) -> f64 {
let mut bucket = self.bucket.write();
bucket.refill(self.config.requests_per_second);
bucket.tokens
}
/// Returns the time until the next token is available.
pub fn time_until_available(&self) -> Duration {
let bucket = self.bucket.read();
bucket.time_until_available(self.config.requests_per_second)
}
/// Returns the number of rate limit violations.
pub fn violations(&self) -> u32 {
self.bucket.read().violations
}
/// Resets the rate limiter state.
pub fn reset(&self) {
let mut bucket = self.bucket.write();
bucket.tokens = self.config.burst_size as f64;
bucket.last_refill = Instant::now();
bucket.violations = 0;
}
}
/// Per-peer rate limiter that tracks rate limits for each peer.
pub struct PerPeerLimiter {
/// Configuration for rate limiting.
config: RateLimitConfig,
/// Token buckets per peer.
peers: RwLock<HashMap<PeerId, TokenBucket>>,
/// Maximum violation threshold before cooldown.
max_violations: u32,
/// Cooldown duration after exceeding max violations.
cooldown_duration: Duration,
/// Peers currently in cooldown.
cooldowns: RwLock<HashMap<PeerId, Instant>>,
}
impl PerPeerLimiter {
/// Creates a new per-peer rate limiter with the given configuration.
pub fn new(config: RateLimitConfig) -> Self {
PerPeerLimiter {
config,
peers: RwLock::new(HashMap::new()),
max_violations: 10,
cooldown_duration: Duration::from_secs(60),
cooldowns: RwLock::new(HashMap::new()),
}
}
/// Creates a per-peer rate limiter with default configuration.
pub fn with_defaults() -> Self {
Self::new(RateLimitConfig::default())
}
/// Creates a per-peer rate limiter with custom violation thresholds.
pub fn with_violation_config(
config: RateLimitConfig,
max_violations: u32,
cooldown_duration: Duration,
) -> Self {
PerPeerLimiter {
config,
peers: RwLock::new(HashMap::new()),
max_violations,
cooldown_duration,
cooldowns: RwLock::new(HashMap::new()),
}
}
/// Checks if a request from the given peer is allowed.
/// Returns true if allowed, false if rate limited.
pub fn check_rate_limit(&self, peer_id: &PeerId) -> bool {
// Check if peer is in cooldown
{
let cooldowns = self.cooldowns.read();
if let Some(cooldown_until) = cooldowns.get(peer_id) {
if Instant::now() < *cooldown_until {
return false;
}
}
}
// Remove expired cooldown
{
let mut cooldowns = self.cooldowns.write();
if let Some(cooldown_until) = cooldowns.get(peer_id) {
if Instant::now() >= *cooldown_until {
cooldowns.remove(peer_id);
}
}
}
let mut peers = self.peers.write();
let bucket = peers
.entry(*peer_id)
.or_insert_with(|| TokenBucket::new(self.config.burst_size));
// Refill tokens before checking
bucket.refill(self.config.requests_per_second);
// Check if tokens are available (but don't consume yet)
bucket.tokens >= 1.0
}
/// Records a request from the given peer, consuming a token.
/// Should be called after check_rate_limit returns true.
pub fn record_request(&self, peer_id: &PeerId) {
let mut peers = self.peers.write();
let bucket = peers
.entry(*peer_id)
.or_insert_with(|| TokenBucket::new(self.config.burst_size));
if !bucket.try_consume(self.config.requests_per_second, self.config.burst_size) {
// Check if we need to put peer in cooldown
if bucket.violations >= self.max_violations {
warn!(
"Peer {} exceeded max violations ({}), entering cooldown",
peer_id, bucket.violations
);
drop(peers); // Release lock before acquiring cooldowns lock
let mut cooldowns = self.cooldowns.write();
cooldowns.insert(*peer_id, Instant::now() + self.cooldown_duration);
}
}
}
/// Checks and records a request in a single operation.
/// Returns true if the request was allowed.
pub fn check_and_record(&self, peer_id: &PeerId) -> bool {
// Check if peer is in cooldown
{
let cooldowns = self.cooldowns.read();
if let Some(cooldown_until) = cooldowns.get(peer_id) {
if Instant::now() < *cooldown_until {
return false;
}
}
}
// Remove expired cooldown and check/record request
{
let mut cooldowns = self.cooldowns.write();
if let Some(cooldown_until) = cooldowns.get(peer_id) {
if Instant::now() >= *cooldown_until {
cooldowns.remove(peer_id);
}
}
}
let mut peers = self.peers.write();
let bucket = peers
.entry(*peer_id)
.or_insert_with(|| TokenBucket::new(self.config.burst_size));
let allowed = bucket.try_consume(self.config.requests_per_second, self.config.burst_size);
if !allowed && bucket.violations >= self.max_violations {
warn!(
"Peer {} exceeded max violations ({}), entering cooldown",
peer_id, bucket.violations
);
drop(peers);
let mut cooldowns = self.cooldowns.write();
cooldowns.insert(*peer_id, Instant::now() + self.cooldown_duration);
}
allowed
}
/// Returns the number of available tokens for a peer.
pub fn available_tokens(&self, peer_id: &PeerId) -> f64 {
let mut peers = self.peers.write();
if let Some(bucket) = peers.get_mut(peer_id) {
bucket.refill(self.config.requests_per_second);
bucket.tokens
} else {
self.config.burst_size as f64
}
}
/// Returns the number of violations for a peer.
pub fn violations(&self, peer_id: &PeerId) -> u32 {
let peers = self.peers.read();
peers.get(peer_id).map(|b| b.violations).unwrap_or(0)
}
/// Returns true if the peer is currently in cooldown.
pub fn is_in_cooldown(&self, peer_id: &PeerId) -> bool {
let cooldowns = self.cooldowns.read();
if let Some(cooldown_until) = cooldowns.get(peer_id) {
Instant::now() < *cooldown_until
} else {
false
}
}
/// Returns the remaining cooldown duration for a peer, if any.
pub fn cooldown_remaining(&self, peer_id: &PeerId) -> Option<Duration> {
let cooldowns = self.cooldowns.read();
if let Some(cooldown_until) = cooldowns.get(peer_id) {
let now = Instant::now();
if now < *cooldown_until {
Some(*cooldown_until - now)
} else {
None
}
} else {
None
}
}
/// Resets the rate limit state for a specific peer.
pub fn reset_peer(&self, peer_id: &PeerId) {
self.peers.write().remove(peer_id);
self.cooldowns.write().remove(peer_id);
}
/// Cleans up state for disconnected peers.
pub fn cleanup(&self, connected_peers: &[PeerId]) {
let connected: std::collections::HashSet<_> = connected_peers.iter().collect();
self.peers.write().retain(|id, _| connected.contains(id));
self.cooldowns
.write()
.retain(|id, _| connected.contains(id));
}
/// Returns the number of tracked peers.
pub fn peer_count(&self) -> usize {
self.peers.read().len()
}
/// Returns the configuration.
pub fn config(&self) -> &RateLimitConfig {
&self.config
}
}
/// Error describing why a peer's request was refused by the rate limiter.
#[derive(Debug, Clone)]
pub struct RateLimitError {
    /// The peer whose request was refused.
    pub peer_id: PeerId,
    /// How long the caller should wait before retrying.
    pub retry_after: Duration,
    /// `true` when the refusal is due to a violation cooldown rather than an empty bucket.
    pub in_cooldown: bool,
}

impl std::fmt::Display for RateLimitError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            peer_id,
            retry_after,
            in_cooldown,
        } = self;
        match *in_cooldown {
            true => write!(
                f,
                "Peer {} is in cooldown, retry after {:?}",
                peer_id, retry_after
            ),
            false => write!(
                f,
                "Rate limited peer {}, retry after {:?}",
                peer_id, retry_after
            ),
        }
    }
}

impl std::error::Error for RateLimitError {}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread::sleep;

    fn random_peer_id() -> PeerId {
        PeerId::random()
    }

    #[test]
    fn test_rate_limiter_allows_initial_burst() {
        let config = RateLimitConfig {
            requests_per_second: 10.0,
            burst_size: 20,
            window_ms: 1000,
        };
        let limiter = RateLimiter::new(config);
        // Should allow burst_size requests
        for _ in 0..20 {
            assert!(limiter.check_and_consume());
        }
        // 21st should be denied
        assert!(!limiter.check_and_consume());
    }

    #[test]
    fn test_rate_limiter_refills_over_time() {
        let config = RateLimitConfig {
            requests_per_second: 100.0,
            burst_size: 10,
            window_ms: 1000,
        };
        let limiter = RateLimiter::new(config);
        // Exhaust all tokens
        for _ in 0..10 {
            limiter.check_and_consume();
        }
        assert!(!limiter.check_and_consume());
        // Wait for tokens to refill (at 100 req/s, 10ms = 1 token)
        sleep(Duration::from_millis(15));
        // Should have at least one token now
        assert!(limiter.check_and_consume());
    }

    #[test]
    fn test_rate_limiter_reset_restores_burst() {
        let limiter = RateLimiter::new(RateLimitConfig::new(1.0, 2, 1000));
        assert!(limiter.check_and_consume());
        assert!(limiter.check_and_consume());
        assert!(!limiter.check_and_consume());
        assert_eq!(limiter.violations(), 1);
        assert!(limiter.time_until_available() > Duration::ZERO);

        limiter.reset();
        assert_eq!(limiter.violations(), 0);
        assert!(limiter.check_and_consume());
    }

    #[test]
    fn test_per_peer_limiter_isolates_peers() {
        let config = RateLimitConfig {
            requests_per_second: 10.0,
            burst_size: 5,
            window_ms: 1000,
        };
        let limiter = PerPeerLimiter::new(config);
        let peer1 = random_peer_id();
        let peer2 = random_peer_id();
        // Exhaust peer1's tokens
        for _ in 0..5 {
            assert!(limiter.check_and_record(&peer1));
        }
        assert!(!limiter.check_and_record(&peer1));
        // peer2 should still have tokens
        for _ in 0..5 {
            assert!(limiter.check_and_record(&peer2));
        }
    }

    #[test]
    fn test_per_peer_limiter_check_and_record() {
        let config = RateLimitConfig {
            requests_per_second: 10.0,
            burst_size: 3,
            window_ms: 1000,
        };
        let limiter = PerPeerLimiter::new(config);
        let peer = random_peer_id();
        // check_rate_limit should return true without consuming
        assert!(limiter.check_rate_limit(&peer));
        assert!(limiter.check_rate_limit(&peer));
        // record_request should consume tokens
        limiter.record_request(&peer);
        limiter.record_request(&peer);
        limiter.record_request(&peer);
        // Now should be rate limited
        assert!(!limiter.check_rate_limit(&peer));
    }

    #[test]
    fn test_per_peer_limiter_cooldown() {
        // Use a slow refill rate so the intended violations are deterministic. At
        // 100 req/s a token returns every 10 ms, so a slow machine could turn an intended
        // violation into an allowed request and never reach the cooldown threshold.
        let config = RateLimitConfig {
            requests_per_second: 1.0,
            burst_size: 2,
            window_ms: 1000,
        };
        let limiter = PerPeerLimiter::with_violation_config(
            config,
            3,                          // max_violations
            Duration::from_millis(100), // short cooldown for test
        );
        let peer = random_peer_id();
        // Exhaust tokens
        assert!(limiter.check_and_record(&peer));
        assert!(limiter.check_and_record(&peer));
        // These should trigger violations; the third one starts the cooldown
        assert!(!limiter.check_and_record(&peer));
        assert!(!limiter.check_and_record(&peer));
        assert!(!limiter.check_and_record(&peer));
        // Should be in cooldown now
        assert!(limiter.is_in_cooldown(&peer));
        assert!(!limiter.check_rate_limit(&peer));
        // Wait for cooldown to expire
        sleep(Duration::from_millis(150));
        // Should be out of cooldown
        assert!(!limiter.is_in_cooldown(&peer));
    }

    #[test]
    fn test_reset_peer_clears_state() {
        let limiter = PerPeerLimiter::with_violation_config(
            RateLimitConfig::new(1.0, 1, 1000),
            1,
            Duration::from_secs(60),
        );
        let peer = random_peer_id();
        assert!(limiter.check_and_record(&peer));
        // A single violation reaches max_violations = 1 and starts the cooldown.
        assert!(!limiter.check_and_record(&peer));
        assert!(limiter.is_in_cooldown(&peer));
        let remaining = limiter
            .cooldown_remaining(&peer)
            .expect("peer should be in cooldown");
        assert!(remaining <= Duration::from_secs(60));

        limiter.reset_peer(&peer);
        assert!(!limiter.is_in_cooldown(&peer));
        assert_eq!(limiter.cooldown_remaining(&peer), None);
        assert_eq!(limiter.violations(&peer), 0);
        assert!(limiter.check_and_record(&peer));
    }

    #[test]
    fn test_per_peer_limiter_cleanup() {
        let config = RateLimitConfig::default();
        let limiter = PerPeerLimiter::new(config);
        let peer1 = random_peer_id();
        let peer2 = random_peer_id();
        let peer3 = random_peer_id();
        // Add some peers
        limiter.record_request(&peer1);
        limiter.record_request(&peer2);
        limiter.record_request(&peer3);
        assert_eq!(limiter.peer_count(), 3);
        // Cleanup keeping only peer1 and peer2
        limiter.cleanup(&[peer1, peer2]);
        assert_eq!(limiter.peer_count(), 2);
    }

    #[test]
    fn test_cleanup_drops_cooldowns() {
        let limiter = PerPeerLimiter::with_violation_config(
            RateLimitConfig::new(1.0, 1, 1000),
            1,
            Duration::from_secs(60),
        );
        let kept = random_peer_id();
        let dropped = random_peer_id();
        assert!(limiter.check_and_record(&dropped));
        assert!(!limiter.check_and_record(&dropped));
        assert!(limiter.is_in_cooldown(&dropped));

        limiter.cleanup(&[kept]);
        assert_eq!(limiter.peer_count(), 0);
        assert!(!limiter.is_in_cooldown(&dropped));
    }

    #[test]
    fn test_rate_limit_config_defaults() {
        let config = RateLimitConfig::default();
        assert_eq!(config.requests_per_second, 100.0);
        assert_eq!(config.burst_size, 200);
        assert_eq!(config.window_ms, 1000);
    }

    #[test]
    fn test_violations_tracking() {
        let config = RateLimitConfig {
            requests_per_second: 10.0,
            burst_size: 2,
            window_ms: 1000,
        };
        let limiter = PerPeerLimiter::new(config);
        let peer = random_peer_id();
        // Initial violations should be 0
        assert_eq!(limiter.violations(&peer), 0);
        // Exhaust tokens
        limiter.check_and_record(&peer);
        limiter.check_and_record(&peer);
        // These should cause violations
        limiter.check_and_record(&peer);
        limiter.check_and_record(&peer);
        // Should have violations now
        assert!(limiter.violations(&peer) >= 2);
    }

    #[test]
    fn test_rate_limit_error_display() {
        let peer = random_peer_id();
        let err = RateLimitError {
            peer_id: peer,
            retry_after: Duration::from_secs(1),
            in_cooldown: true,
        };
        assert_eq!(
            err.to_string(),
            format!("Peer {} is in cooldown, retry after 1s", peer)
        );
        let err = RateLimitError {
            in_cooldown: false,
            ..err
        };
        assert_eq!(
            err.to_string(),
            format!("Rate limited peer {}, retry after 1s", peer)
        );
    }
}