//! Rate limiting for network requests using token bucket algorithm. //! //! This module provides rate limiting capabilities to prevent DoS attacks //! and ensure fair resource allocation among peers. use hashbrown::HashMap; use libp2p::PeerId; use parking_lot::RwLock; use std::time::{Duration, Instant}; use tracing::warn; /// Configuration for rate limiting. #[derive(Clone, Debug)] pub struct RateLimitConfig { /// Maximum requests allowed per second (steady-state rate). pub requests_per_second: f64, /// Maximum burst size (bucket capacity). pub burst_size: u32, /// Time window in milliseconds for rate calculation. pub window_ms: u64, } impl Default for RateLimitConfig { fn default() -> Self { RateLimitConfig { requests_per_second: 100.0, burst_size: 200, window_ms: 1000, } } } impl RateLimitConfig { /// Creates a new rate limit configuration. pub fn new(requests_per_second: f64, burst_size: u32, window_ms: u64) -> Self { RateLimitConfig { requests_per_second, burst_size, window_ms, } } /// Creates a strict rate limit configuration. pub fn strict() -> Self { RateLimitConfig { requests_per_second: 50.0, burst_size: 100, window_ms: 1000, } } /// Creates a relaxed rate limit configuration. pub fn relaxed() -> Self { RateLimitConfig { requests_per_second: 200.0, burst_size: 400, window_ms: 1000, } } } /// Token bucket state for a single peer. #[derive(Debug)] struct TokenBucket { /// Current number of available tokens. tokens: f64, /// Last time tokens were refilled. last_refill: Instant, /// Number of rate limit violations. violations: u32, } impl TokenBucket { /// Creates a new token bucket with full capacity. fn new(capacity: u32) -> Self { TokenBucket { tokens: capacity as f64, last_refill: Instant::now(), violations: 0, } } /// Refills tokens based on elapsed time. fn refill(&mut self, rate: f64) { let now = Instant::now(); let elapsed = now.duration_since(self.last_refill); let new_tokens = elapsed.as_secs_f64() * rate; self.tokens = (self.tokens + new_tokens).min(rate * 2.0); // Cap at 2x rate (burst capacity) self.last_refill = now; } /// Attempts to consume a token. Returns true if successful. fn try_consume(&mut self, rate: f64, burst: u32) -> bool { self.refill(rate); // Cap tokens at burst size self.tokens = self.tokens.min(burst as f64); if self.tokens >= 1.0 { self.tokens -= 1.0; true } else { self.violations += 1; false } } /// Returns the time until the next token is available. fn time_until_available(&self, rate: f64) -> Duration { if self.tokens >= 1.0 { Duration::ZERO } else { let tokens_needed = 1.0 - self.tokens; let seconds_needed = tokens_needed / rate; Duration::from_secs_f64(seconds_needed) } } } /// A rate limiter using the token bucket algorithm. pub struct RateLimiter { /// Configuration for rate limiting. config: RateLimitConfig, /// Token bucket for tracking requests. bucket: RwLock, } impl RateLimiter { /// Creates a new rate limiter with the given configuration. pub fn new(config: RateLimitConfig) -> Self { let bucket = TokenBucket::new(config.burst_size); RateLimiter { config, bucket: RwLock::new(bucket), } } /// Creates a rate limiter with default configuration. pub fn with_defaults() -> Self { Self::new(RateLimitConfig::default()) } /// Checks if a request is allowed and consumes a token if so. pub fn check_and_consume(&self) -> bool { let mut bucket = self.bucket.write(); bucket.try_consume(self.config.requests_per_second, self.config.burst_size) } /// Returns the current number of available tokens. pub fn available_tokens(&self) -> f64 { let mut bucket = self.bucket.write(); bucket.refill(self.config.requests_per_second); bucket.tokens } /// Returns the time until the next token is available. pub fn time_until_available(&self) -> Duration { let bucket = self.bucket.read(); bucket.time_until_available(self.config.requests_per_second) } /// Returns the number of rate limit violations. pub fn violations(&self) -> u32 { self.bucket.read().violations } /// Resets the rate limiter state. pub fn reset(&self) { let mut bucket = self.bucket.write(); bucket.tokens = self.config.burst_size as f64; bucket.last_refill = Instant::now(); bucket.violations = 0; } } /// Per-peer rate limiter that tracks rate limits for each peer. pub struct PerPeerLimiter { /// Configuration for rate limiting. config: RateLimitConfig, /// Token buckets per peer. peers: RwLock>, /// Maximum violation threshold before cooldown. max_violations: u32, /// Cooldown duration after exceeding max violations. cooldown_duration: Duration, /// Peers currently in cooldown. cooldowns: RwLock>, } impl PerPeerLimiter { /// Creates a new per-peer rate limiter with the given configuration. pub fn new(config: RateLimitConfig) -> Self { PerPeerLimiter { config, peers: RwLock::new(HashMap::new()), max_violations: 10, cooldown_duration: Duration::from_secs(60), cooldowns: RwLock::new(HashMap::new()), } } /// Creates a per-peer rate limiter with default configuration. pub fn with_defaults() -> Self { Self::new(RateLimitConfig::default()) } /// Creates a per-peer rate limiter with custom violation thresholds. pub fn with_violation_config( config: RateLimitConfig, max_violations: u32, cooldown_duration: Duration, ) -> Self { PerPeerLimiter { config, peers: RwLock::new(HashMap::new()), max_violations, cooldown_duration, cooldowns: RwLock::new(HashMap::new()), } } /// Checks if a request from the given peer is allowed. /// Returns true if allowed, false if rate limited. pub fn check_rate_limit(&self, peer_id: &PeerId) -> bool { // Check if peer is in cooldown { let cooldowns = self.cooldowns.read(); if let Some(cooldown_until) = cooldowns.get(peer_id) { if Instant::now() < *cooldown_until { return false; } } } // Remove expired cooldown { let mut cooldowns = self.cooldowns.write(); if let Some(cooldown_until) = cooldowns.get(peer_id) { if Instant::now() >= *cooldown_until { cooldowns.remove(peer_id); } } } let mut peers = self.peers.write(); let bucket = peers .entry(*peer_id) .or_insert_with(|| TokenBucket::new(self.config.burst_size)); // Refill tokens before checking bucket.refill(self.config.requests_per_second); // Check if tokens are available (but don't consume yet) bucket.tokens >= 1.0 } /// Records a request from the given peer, consuming a token. /// Should be called after check_rate_limit returns true. pub fn record_request(&self, peer_id: &PeerId) { let mut peers = self.peers.write(); let bucket = peers .entry(*peer_id) .or_insert_with(|| TokenBucket::new(self.config.burst_size)); if !bucket.try_consume(self.config.requests_per_second, self.config.burst_size) { // Check if we need to put peer in cooldown if bucket.violations >= self.max_violations { warn!( "Peer {} exceeded max violations ({}), entering cooldown", peer_id, bucket.violations ); drop(peers); // Release lock before acquiring cooldowns lock let mut cooldowns = self.cooldowns.write(); cooldowns.insert(*peer_id, Instant::now() + self.cooldown_duration); } } } /// Checks and records a request in a single operation. /// Returns true if the request was allowed. pub fn check_and_record(&self, peer_id: &PeerId) -> bool { // Check if peer is in cooldown { let cooldowns = self.cooldowns.read(); if let Some(cooldown_until) = cooldowns.get(peer_id) { if Instant::now() < *cooldown_until { return false; } } } // Remove expired cooldown and check/record request { let mut cooldowns = self.cooldowns.write(); if let Some(cooldown_until) = cooldowns.get(peer_id) { if Instant::now() >= *cooldown_until { cooldowns.remove(peer_id); } } } let mut peers = self.peers.write(); let bucket = peers .entry(*peer_id) .or_insert_with(|| TokenBucket::new(self.config.burst_size)); let allowed = bucket.try_consume(self.config.requests_per_second, self.config.burst_size); if !allowed && bucket.violations >= self.max_violations { warn!( "Peer {} exceeded max violations ({}), entering cooldown", peer_id, bucket.violations ); drop(peers); let mut cooldowns = self.cooldowns.write(); cooldowns.insert(*peer_id, Instant::now() + self.cooldown_duration); } allowed } /// Returns the number of available tokens for a peer. pub fn available_tokens(&self, peer_id: &PeerId) -> f64 { let mut peers = self.peers.write(); if let Some(bucket) = peers.get_mut(peer_id) { bucket.refill(self.config.requests_per_second); bucket.tokens } else { self.config.burst_size as f64 } } /// Returns the number of violations for a peer. pub fn violations(&self, peer_id: &PeerId) -> u32 { let peers = self.peers.read(); peers.get(peer_id).map(|b| b.violations).unwrap_or(0) } /// Returns true if the peer is currently in cooldown. pub fn is_in_cooldown(&self, peer_id: &PeerId) -> bool { let cooldowns = self.cooldowns.read(); if let Some(cooldown_until) = cooldowns.get(peer_id) { Instant::now() < *cooldown_until } else { false } } /// Returns the remaining cooldown duration for a peer, if any. pub fn cooldown_remaining(&self, peer_id: &PeerId) -> Option { let cooldowns = self.cooldowns.read(); if let Some(cooldown_until) = cooldowns.get(peer_id) { let now = Instant::now(); if now < *cooldown_until { Some(*cooldown_until - now) } else { None } } else { None } } /// Resets the rate limit state for a specific peer. pub fn reset_peer(&self, peer_id: &PeerId) { self.peers.write().remove(peer_id); self.cooldowns.write().remove(peer_id); } /// Cleans up state for disconnected peers. pub fn cleanup(&self, connected_peers: &[PeerId]) { let connected: std::collections::HashSet<_> = connected_peers.iter().collect(); self.peers.write().retain(|id, _| connected.contains(id)); self.cooldowns .write() .retain(|id, _| connected.contains(id)); } /// Returns the number of tracked peers. pub fn peer_count(&self) -> usize { self.peers.read().len() } /// Returns the configuration. pub fn config(&self) -> &RateLimitConfig { &self.config } } /// Error returned when a request is rate limited. #[derive(Debug, Clone)] pub struct RateLimitError { /// The peer that was rate limited. pub peer_id: PeerId, /// Time until the rate limit resets. pub retry_after: Duration, /// Whether the peer is in cooldown due to excessive violations. pub in_cooldown: bool, } impl std::fmt::Display for RateLimitError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { if self.in_cooldown { write!( f, "Peer {} is in cooldown, retry after {:?}", self.peer_id, self.retry_after ) } else { write!( f, "Rate limited peer {}, retry after {:?}", self.peer_id, self.retry_after ) } } } impl std::error::Error for RateLimitError {} #[cfg(test)] mod tests { use super::*; use std::thread::sleep; fn random_peer_id() -> PeerId { PeerId::random() } #[test] fn test_rate_limiter_allows_initial_burst() { let config = RateLimitConfig { requests_per_second: 10.0, burst_size: 20, window_ms: 1000, }; let limiter = RateLimiter::new(config); // Should allow burst_size requests for _ in 0..20 { assert!(limiter.check_and_consume()); } // 21st should be denied assert!(!limiter.check_and_consume()); } #[test] fn test_rate_limiter_refills_over_time() { let config = RateLimitConfig { requests_per_second: 100.0, burst_size: 10, window_ms: 1000, }; let limiter = RateLimiter::new(config); // Exhaust all tokens for _ in 0..10 { limiter.check_and_consume(); } assert!(!limiter.check_and_consume()); // Wait for tokens to refill (at 100 req/s, 10ms = 1 token) sleep(Duration::from_millis(15)); // Should have at least one token now assert!(limiter.check_and_consume()); } #[test] fn test_per_peer_limiter_isolates_peers() { let config = RateLimitConfig { requests_per_second: 10.0, burst_size: 5, window_ms: 1000, }; let limiter = PerPeerLimiter::new(config); let peer1 = random_peer_id(); let peer2 = random_peer_id(); // Exhaust peer1's tokens for _ in 0..5 { assert!(limiter.check_and_record(&peer1)); } assert!(!limiter.check_and_record(&peer1)); // peer2 should still have tokens for _ in 0..5 { assert!(limiter.check_and_record(&peer2)); } } #[test] fn test_per_peer_limiter_check_and_record() { let config = RateLimitConfig { requests_per_second: 10.0, burst_size: 3, window_ms: 1000, }; let limiter = PerPeerLimiter::new(config); let peer = random_peer_id(); // check_rate_limit should return true without consuming assert!(limiter.check_rate_limit(&peer)); assert!(limiter.check_rate_limit(&peer)); // record_request should consume tokens limiter.record_request(&peer); limiter.record_request(&peer); limiter.record_request(&peer); // Now should be rate limited assert!(!limiter.check_rate_limit(&peer)); } #[test] fn test_per_peer_limiter_cooldown() { let config = RateLimitConfig { requests_per_second: 100.0, burst_size: 2, window_ms: 1000, }; let limiter = PerPeerLimiter::with_violation_config( config, 3, // max_violations Duration::from_millis(100), // short cooldown for test ); let peer = random_peer_id(); // Exhaust tokens and trigger violations limiter.check_and_record(&peer); limiter.check_and_record(&peer); // These should trigger violations limiter.check_and_record(&peer); limiter.check_and_record(&peer); limiter.check_and_record(&peer); // Should be in cooldown now assert!(limiter.is_in_cooldown(&peer)); assert!(!limiter.check_rate_limit(&peer)); // Wait for cooldown to expire sleep(Duration::from_millis(150)); // Should be out of cooldown assert!(!limiter.is_in_cooldown(&peer)); } #[test] fn test_per_peer_limiter_cleanup() { let config = RateLimitConfig::default(); let limiter = PerPeerLimiter::new(config); let peer1 = random_peer_id(); let peer2 = random_peer_id(); let peer3 = random_peer_id(); // Add some peers limiter.record_request(&peer1); limiter.record_request(&peer2); limiter.record_request(&peer3); assert_eq!(limiter.peer_count(), 3); // Cleanup keeping only peer1 and peer2 limiter.cleanup(&[peer1, peer2]); assert_eq!(limiter.peer_count(), 2); } #[test] fn test_rate_limit_config_defaults() { let config = RateLimitConfig::default(); assert_eq!(config.requests_per_second, 100.0); assert_eq!(config.burst_size, 200); assert_eq!(config.window_ms, 1000); } #[test] fn test_violations_tracking() { let config = RateLimitConfig { requests_per_second: 10.0, burst_size: 2, window_ms: 1000, }; let limiter = PerPeerLimiter::new(config); let peer = random_peer_id(); // Initial violations should be 0 assert_eq!(limiter.violations(&peer), 0); // Exhaust tokens limiter.check_and_record(&peer); limiter.check_and_record(&peer); // These should cause violations limiter.check_and_record(&peer); limiter.check_and_record(&peer); // Should have violations now assert!(limiter.violations(&peer) >= 2); } }