894 lines
28 KiB
Rust
894 lines
28 KiB
Rust
//! RPC connection pooling for outbound client connections.
//!
//! Provides efficient reuse of connections when making RPC calls to other nodes.
//! Uses semaphores to bound concurrent acquisitions, tracks per-endpoint health, and
//! retries failed connection attempts with exponential backoff.
|
|
|
|
use std::collections::HashMap;
|
|
use std::sync::atomic::{AtomicU64, Ordering};
|
|
use std::sync::Arc;
|
|
use std::time::{Duration, Instant};
|
|
|
|
use jsonrpsee::http_client::{HttpClient, HttpClientBuilder};
|
|
use jsonrpsee::ws_client::{WsClient, WsClientBuilder};
|
|
use parking_lot::{Mutex, RwLock};
|
|
use tokio::sync::Semaphore;
|
|
|
|
/// Tuning knobs for [`ConnectionPool`].
///
/// Use `PoolConfig::default()` for balanced settings, or one of the presets
/// (`high_throughput`, `low_latency`) as a starting point.
#[derive(Clone, Debug)]
pub struct PoolConfig {
    /// Upper bound on pooled clients kept per endpoint; also sizes each endpoint's semaphore.
    pub max_connections_per_endpoint: usize,
    /// Size of the pool-wide semaphore shared by all endpoints.
    pub max_total_connections: usize,
    /// Time allowed for establishing a connection.
    pub connect_timeout: Duration,
    /// Time allowed for a single RPC request.
    pub request_timeout: Duration,
    /// Pooled clients unused for at least this long are dropped instead of reused.
    pub idle_timeout: Duration,
    /// Intended interval between endpoint health checks.
    pub health_check_interval: Duration,
    /// Extra connection attempts after the first failure; also the failure streak
    /// that marks an endpoint unhealthy.
    pub max_retries: u32,
    /// Base delay for exponential retry backoff (doubled on each further attempt).
    pub retry_backoff: Duration,
}
|
|
|
|
impl Default for PoolConfig {
|
|
fn default() -> Self {
|
|
PoolConfig {
|
|
max_connections_per_endpoint: 4,
|
|
max_total_connections: 64,
|
|
connect_timeout: Duration::from_secs(10),
|
|
request_timeout: Duration::from_secs(30),
|
|
idle_timeout: Duration::from_secs(300), // 5 minutes
|
|
health_check_interval: Duration::from_secs(60),
|
|
max_retries: 3,
|
|
retry_backoff: Duration::from_millis(100),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl PoolConfig {
|
|
/// Configuration optimized for high-throughput scenarios.
|
|
pub fn high_throughput() -> Self {
|
|
PoolConfig {
|
|
max_connections_per_endpoint: 8,
|
|
max_total_connections: 128,
|
|
connect_timeout: Duration::from_secs(5),
|
|
request_timeout: Duration::from_secs(15),
|
|
idle_timeout: Duration::from_secs(120),
|
|
health_check_interval: Duration::from_secs(30),
|
|
max_retries: 2,
|
|
retry_backoff: Duration::from_millis(50),
|
|
}
|
|
}
|
|
|
|
/// Configuration optimized for low-latency scenarios.
|
|
pub fn low_latency() -> Self {
|
|
PoolConfig {
|
|
max_connections_per_endpoint: 2,
|
|
max_total_connections: 32,
|
|
connect_timeout: Duration::from_secs(3),
|
|
request_timeout: Duration::from_secs(10),
|
|
idle_timeout: Duration::from_secs(60),
|
|
health_check_interval: Duration::from_secs(15),
|
|
max_retries: 1,
|
|
retry_backoff: Duration::from_millis(25),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Lock-free counters describing pool activity.
///
/// Every field is an independent atomic; use [`PoolStats::snapshot`] to read them
/// into plain integers.
#[derive(Debug, Default)]
pub struct PoolStats {
    /// Number of clients successfully built.
    pub connections_created: AtomicU64,
    /// Number of clients dropped from or rejected by the pool.
    pub connections_closed: AtomicU64,
    /// Number of requests recorded.
    pub requests_total: AtomicU64,
    /// Number of requests recorded as failed.
    pub requests_failed: AtomicU64,
    /// Number of acquisitions satisfied by an already-pooled client.
    pub connections_reused: AtomicU64,
    /// Number of times an endpoint was marked unhealthy.
    pub health_check_failures: AtomicU64,
}
|
|
|
|
impl PoolStats {
|
|
/// Returns a snapshot of current statistics.
|
|
pub fn snapshot(&self) -> PoolStatsSnapshot {
|
|
PoolStatsSnapshot {
|
|
connections_created: self.connections_created.load(Ordering::Relaxed),
|
|
connections_closed: self.connections_closed.load(Ordering::Relaxed),
|
|
connections_active: self
|
|
.connections_created
|
|
.load(Ordering::Relaxed)
|
|
.saturating_sub(self.connections_closed.load(Ordering::Relaxed)),
|
|
requests_total: self.requests_total.load(Ordering::Relaxed),
|
|
requests_failed: self.requests_failed.load(Ordering::Relaxed),
|
|
connections_reused: self.connections_reused.load(Ordering::Relaxed),
|
|
health_check_failures: self.health_check_failures.load(Ordering::Relaxed),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Plain-integer copy of [`PoolStats`] taken at one moment.
///
/// Being a small value type, it is `Copy` and comparable, so callers can diff or
/// assert on snapshots directly.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct PoolStatsSnapshot {
    /// Clients successfully built.
    pub connections_created: u64,
    /// Clients dropped from or rejected by the pool.
    pub connections_closed: u64,
    /// `connections_created - connections_closed`, saturating at zero.
    pub connections_active: u64,
    /// Requests recorded.
    pub requests_total: u64,
    /// Requests recorded as failed.
    pub requests_failed: u64,
    /// Acquisitions satisfied by an already-pooled client.
    pub connections_reused: u64,
    /// Times an endpoint was marked unhealthy.
    pub health_check_failures: u64,
}
|
|
|
|
/// A pooled HTTP client connection.
|
|
pub struct PooledHttpClient {
|
|
client: HttpClient,
|
|
endpoint: String,
|
|
created_at: Instant,
|
|
last_used: Mutex<Instant>,
|
|
request_count: AtomicU64,
|
|
}
|
|
|
|
impl PooledHttpClient {
|
|
/// Returns the underlying HTTP client.
|
|
pub fn client(&self) -> &HttpClient {
|
|
&self.client
|
|
}
|
|
|
|
/// Updates the last used timestamp.
|
|
pub fn touch(&self) {
|
|
*self.last_used.lock() = Instant::now();
|
|
self.request_count.fetch_add(1, Ordering::Relaxed);
|
|
}
|
|
|
|
/// Returns the endpoint URL.
|
|
pub fn endpoint(&self) -> &str {
|
|
&self.endpoint
|
|
}
|
|
|
|
/// Returns the connection age.
|
|
pub fn age(&self) -> Duration {
|
|
self.created_at.elapsed()
|
|
}
|
|
|
|
/// Returns time since last use.
|
|
pub fn idle_time(&self) -> Duration {
|
|
self.last_used.lock().elapsed()
|
|
}
|
|
|
|
/// Returns total requests made on this connection.
|
|
pub fn request_count(&self) -> u64 {
|
|
self.request_count.load(Ordering::Relaxed)
|
|
}
|
|
}
|
|
|
|
/// A pooled WebSocket client connection.
|
|
pub struct PooledWsClient {
|
|
client: WsClient,
|
|
endpoint: String,
|
|
created_at: Instant,
|
|
last_used: Mutex<Instant>,
|
|
request_count: AtomicU64,
|
|
}
|
|
|
|
impl PooledWsClient {
|
|
/// Returns the underlying WebSocket client.
|
|
pub fn client(&self) -> &WsClient {
|
|
&self.client
|
|
}
|
|
|
|
/// Updates the last used timestamp.
|
|
pub fn touch(&self) {
|
|
*self.last_used.lock() = Instant::now();
|
|
self.request_count.fetch_add(1, Ordering::Relaxed);
|
|
}
|
|
|
|
/// Returns the endpoint URL.
|
|
pub fn endpoint(&self) -> &str {
|
|
&self.endpoint
|
|
}
|
|
|
|
/// Returns the connection age.
|
|
pub fn age(&self) -> Duration {
|
|
self.created_at.elapsed()
|
|
}
|
|
|
|
/// Returns time since last use.
|
|
pub fn idle_time(&self) -> Duration {
|
|
self.last_used.lock().elapsed()
|
|
}
|
|
}
|
|
|
|
/// Endpoint entry in the pool.
|
|
struct EndpointEntry {
|
|
/// Available HTTP clients.
|
|
http_clients: Vec<Arc<PooledHttpClient>>,
|
|
/// Available WebSocket clients.
|
|
ws_clients: Vec<Arc<PooledWsClient>>,
|
|
/// Semaphore for limiting connections to this endpoint.
|
|
semaphore: Arc<Semaphore>,
|
|
/// Last health check time.
|
|
last_health_check: Instant,
|
|
/// Is endpoint healthy?
|
|
healthy: bool,
|
|
/// Consecutive failures.
|
|
consecutive_failures: u32,
|
|
}
|
|
|
|
/// HTTP/WebSocket connection pool for RPC clients.
|
|
pub struct ConnectionPool {
|
|
config: PoolConfig,
|
|
/// Per-endpoint connection entries.
|
|
endpoints: RwLock<HashMap<String, Mutex<EndpointEntry>>>,
|
|
/// Global connection semaphore.
|
|
global_semaphore: Semaphore,
|
|
/// Pool statistics.
|
|
stats: Arc<PoolStats>,
|
|
}
|
|
|
|
impl ConnectionPool {
|
|
/// Creates a new connection pool with default configuration.
|
|
pub fn new() -> Self {
|
|
Self::with_config(PoolConfig::default())
|
|
}
|
|
|
|
/// Creates a new connection pool with the given configuration.
|
|
pub fn with_config(config: PoolConfig) -> Self {
|
|
ConnectionPool {
|
|
global_semaphore: Semaphore::new(config.max_total_connections),
|
|
config,
|
|
endpoints: RwLock::new(HashMap::new()),
|
|
stats: Arc::new(PoolStats::default()),
|
|
}
|
|
}
|
|
|
|
/// Returns pool configuration.
|
|
pub fn config(&self) -> &PoolConfig {
|
|
&self.config
|
|
}
|
|
|
|
/// Returns pool statistics.
|
|
pub fn stats(&self) -> &PoolStats {
|
|
&self.stats
|
|
}
|
|
|
|
/// Acquires an HTTP client for the given endpoint.
|
|
pub async fn acquire_http(&self, endpoint: &str) -> Result<Arc<PooledHttpClient>, PoolError> {
|
|
// Acquire global permit
|
|
let _global_permit = self
|
|
.global_semaphore
|
|
.acquire()
|
|
.await
|
|
.map_err(|_| PoolError::PoolExhausted)?;
|
|
|
|
// Get or create endpoint entry
|
|
let entry = self.get_or_create_endpoint(endpoint);
|
|
|
|
// Acquire endpoint permit
|
|
let _endpoint_permit = {
|
|
let entry_lock = entry.lock();
|
|
entry_lock.semaphore.clone()
|
|
}
|
|
.acquire_owned()
|
|
.await
|
|
.map_err(|_| PoolError::EndpointExhausted(endpoint.to_string()))?;
|
|
|
|
// Try to reuse an existing connection
|
|
{
|
|
let mut entry_lock = entry.lock();
|
|
|
|
// Remove idle connections
|
|
let idle_timeout = self.config.idle_timeout;
|
|
entry_lock
|
|
.http_clients
|
|
.retain(|c| c.idle_time() < idle_timeout);
|
|
|
|
// Try to get an existing client
|
|
if let Some(client) = entry_lock.http_clients.pop() {
|
|
self.stats
|
|
.connections_reused
|
|
.fetch_add(1, Ordering::Relaxed);
|
|
client.touch();
|
|
return Ok(client);
|
|
}
|
|
}
|
|
|
|
// Create a new connection
|
|
self.create_http_client(endpoint).await
|
|
}
|
|
|
|
/// Acquires a WebSocket client for the given endpoint.
|
|
pub async fn acquire_ws(&self, endpoint: &str) -> Result<Arc<PooledWsClient>, PoolError> {
|
|
// Acquire global permit
|
|
let _global_permit = self
|
|
.global_semaphore
|
|
.acquire()
|
|
.await
|
|
.map_err(|_| PoolError::PoolExhausted)?;
|
|
|
|
// Get or create endpoint entry
|
|
let entry = self.get_or_create_endpoint(endpoint);
|
|
|
|
// Acquire endpoint permit
|
|
let _endpoint_permit = {
|
|
let entry_lock = entry.lock();
|
|
entry_lock.semaphore.clone()
|
|
}
|
|
.acquire_owned()
|
|
.await
|
|
.map_err(|_| PoolError::EndpointExhausted(endpoint.to_string()))?;
|
|
|
|
// Try to reuse an existing connection
|
|
{
|
|
let mut entry_lock = entry.lock();
|
|
|
|
// Remove idle connections
|
|
let idle_timeout = self.config.idle_timeout;
|
|
entry_lock
|
|
.ws_clients
|
|
.retain(|c| c.idle_time() < idle_timeout);
|
|
|
|
// Try to get an existing client
|
|
if let Some(client) = entry_lock.ws_clients.pop() {
|
|
self.stats
|
|
.connections_reused
|
|
.fetch_add(1, Ordering::Relaxed);
|
|
client.touch();
|
|
return Ok(client);
|
|
}
|
|
}
|
|
|
|
// Create a new connection
|
|
self.create_ws_client(endpoint).await
|
|
}
|
|
|
|
/// Returns an HTTP client to the pool.
|
|
pub fn release_http(&self, client: Arc<PooledHttpClient>) {
|
|
let endpoint = client.endpoint().to_string();
|
|
|
|
if let Some(entry) = self.endpoints.read().get(&endpoint) {
|
|
let mut entry_lock = entry.lock();
|
|
|
|
// Only keep if not too old and pool not full
|
|
if client.idle_time() < self.config.idle_timeout
|
|
&& entry_lock.http_clients.len() < self.config.max_connections_per_endpoint
|
|
{
|
|
entry_lock.http_clients.push(client);
|
|
} else {
|
|
self.stats
|
|
.connections_closed
|
|
.fetch_add(1, Ordering::Relaxed);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Returns a WebSocket client to the pool.
|
|
pub fn release_ws(&self, client: Arc<PooledWsClient>) {
|
|
let endpoint = client.endpoint().to_string();
|
|
|
|
if let Some(entry) = self.endpoints.read().get(&endpoint) {
|
|
let mut entry_lock = entry.lock();
|
|
|
|
// Only keep if not too old and pool not full
|
|
if client.idle_time() < self.config.idle_timeout
|
|
&& entry_lock.ws_clients.len() < self.config.max_connections_per_endpoint
|
|
{
|
|
entry_lock.ws_clients.push(client);
|
|
} else {
|
|
self.stats
|
|
.connections_closed
|
|
.fetch_add(1, Ordering::Relaxed);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Checks if an endpoint is healthy.
|
|
pub fn is_healthy(&self, endpoint: &str) -> bool {
|
|
self.endpoints
|
|
.read()
|
|
.get(endpoint)
|
|
.map(|e| e.lock().healthy)
|
|
.unwrap_or(true) // Assume healthy if unknown
|
|
}
|
|
|
|
/// Marks an endpoint as unhealthy.
|
|
pub fn mark_unhealthy(&self, endpoint: &str) {
|
|
if let Some(entry) = self.endpoints.read().get(endpoint) {
|
|
let mut entry_lock = entry.lock();
|
|
entry_lock.consecutive_failures += 1;
|
|
if entry_lock.consecutive_failures >= self.config.max_retries {
|
|
entry_lock.healthy = false;
|
|
self.stats
|
|
.health_check_failures
|
|
.fetch_add(1, Ordering::Relaxed);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Marks an endpoint as healthy.
|
|
pub fn mark_healthy(&self, endpoint: &str) {
|
|
if let Some(entry) = self.endpoints.read().get(endpoint) {
|
|
let mut entry_lock = entry.lock();
|
|
entry_lock.healthy = true;
|
|
entry_lock.consecutive_failures = 0;
|
|
}
|
|
}
|
|
|
|
/// Closes all connections to an endpoint.
|
|
pub fn close_endpoint(&self, endpoint: &str) {
|
|
if let Some(entry) = self.endpoints.write().remove(endpoint) {
|
|
let entry_lock = entry.lock();
|
|
let closed = entry_lock.http_clients.len() + entry_lock.ws_clients.len();
|
|
self.stats
|
|
.connections_closed
|
|
.fetch_add(closed as u64, Ordering::Relaxed);
|
|
}
|
|
}
|
|
|
|
/// Closes all idle connections.
|
|
pub fn close_idle(&self) {
|
|
let endpoints = self.endpoints.read();
|
|
for entry in endpoints.values() {
|
|
let mut entry_lock = entry.lock();
|
|
let idle_timeout = self.config.idle_timeout;
|
|
|
|
let http_before = entry_lock.http_clients.len();
|
|
entry_lock
|
|
.http_clients
|
|
.retain(|c| c.idle_time() < idle_timeout);
|
|
let http_closed = http_before - entry_lock.http_clients.len();
|
|
|
|
let ws_before = entry_lock.ws_clients.len();
|
|
entry_lock
|
|
.ws_clients
|
|
.retain(|c| c.idle_time() < idle_timeout);
|
|
let ws_closed = ws_before - entry_lock.ws_clients.len();
|
|
|
|
self.stats
|
|
.connections_closed
|
|
.fetch_add((http_closed + ws_closed) as u64, Ordering::Relaxed);
|
|
}
|
|
}
|
|
|
|
/// Closes all connections.
|
|
pub fn close_all(&self) {
|
|
let mut endpoints = self.endpoints.write();
|
|
let mut total_closed = 0u64;
|
|
|
|
for entry in endpoints.values() {
|
|
let entry_lock = entry.lock();
|
|
total_closed += (entry_lock.http_clients.len() + entry_lock.ws_clients.len()) as u64;
|
|
}
|
|
|
|
endpoints.clear();
|
|
self.stats
|
|
.connections_closed
|
|
.fetch_add(total_closed, Ordering::Relaxed);
|
|
}
|
|
|
|
/// Returns the number of active connections.
|
|
pub fn active_connections(&self) -> usize {
|
|
let snapshot = self.stats.snapshot();
|
|
snapshot.connections_active as usize
|
|
}
|
|
|
|
/// Returns the number of endpoints.
|
|
pub fn endpoint_count(&self) -> usize {
|
|
self.endpoints.read().len()
|
|
}
|
|
|
|
/// Gets or creates an endpoint entry.
|
|
fn get_or_create_endpoint(&self, endpoint: &str) -> Arc<Mutex<EndpointEntry>> {
|
|
// Try read-only first
|
|
if let Some(entry) = self.endpoints.read().get(endpoint) {
|
|
return Arc::new(Mutex::new(EndpointEntry {
|
|
http_clients: entry.lock().http_clients.clone(),
|
|
ws_clients: entry.lock().ws_clients.clone(),
|
|
semaphore: entry.lock().semaphore.clone(),
|
|
last_health_check: entry.lock().last_health_check,
|
|
healthy: entry.lock().healthy,
|
|
consecutive_failures: entry.lock().consecutive_failures,
|
|
}));
|
|
}
|
|
|
|
// Create new entry
|
|
let mut endpoints = self.endpoints.write();
|
|
|
|
// Double-check after acquiring write lock
|
|
if let Some(entry) = endpoints.get(endpoint) {
|
|
return Arc::new(Mutex::new(EndpointEntry {
|
|
http_clients: entry.lock().http_clients.clone(),
|
|
ws_clients: entry.lock().ws_clients.clone(),
|
|
semaphore: entry.lock().semaphore.clone(),
|
|
last_health_check: entry.lock().last_health_check,
|
|
healthy: entry.lock().healthy,
|
|
consecutive_failures: entry.lock().consecutive_failures,
|
|
}));
|
|
}
|
|
|
|
let entry = Arc::new(Mutex::new(EndpointEntry {
|
|
http_clients: Vec::new(),
|
|
ws_clients: Vec::new(),
|
|
semaphore: Arc::new(Semaphore::new(self.config.max_connections_per_endpoint)),
|
|
last_health_check: Instant::now(),
|
|
healthy: true,
|
|
consecutive_failures: 0,
|
|
}));
|
|
|
|
endpoints.insert(
|
|
endpoint.to_string(),
|
|
Mutex::new(EndpointEntry {
|
|
http_clients: Vec::new(),
|
|
ws_clients: Vec::new(),
|
|
semaphore: Arc::new(Semaphore::new(self.config.max_connections_per_endpoint)),
|
|
last_health_check: Instant::now(),
|
|
healthy: true,
|
|
consecutive_failures: 0,
|
|
}),
|
|
);
|
|
|
|
entry
|
|
}
|
|
|
|
/// Creates a new HTTP client with retries.
|
|
async fn create_http_client(&self, endpoint: &str) -> Result<Arc<PooledHttpClient>, PoolError> {
|
|
let mut last_error = None;
|
|
|
|
for attempt in 0..=self.config.max_retries {
|
|
if attempt > 0 {
|
|
let backoff = self.config.retry_backoff * 2u32.pow(attempt - 1);
|
|
tokio::time::sleep(backoff).await;
|
|
}
|
|
|
|
match self.try_create_http_client(endpoint).await {
|
|
Ok(client) => {
|
|
self.stats
|
|
.connections_created
|
|
.fetch_add(1, Ordering::Relaxed);
|
|
self.mark_healthy(endpoint);
|
|
return Ok(client);
|
|
}
|
|
Err(e) => {
|
|
last_error = Some(e);
|
|
self.mark_unhealthy(endpoint);
|
|
}
|
|
}
|
|
}
|
|
|
|
self.stats.requests_failed.fetch_add(1, Ordering::Relaxed);
|
|
Err(last_error.unwrap_or(PoolError::ConnectionFailed("Unknown error".to_string())))
|
|
}
|
|
|
|
/// Creates a new WebSocket client with retries.
|
|
async fn create_ws_client(&self, endpoint: &str) -> Result<Arc<PooledWsClient>, PoolError> {
|
|
let mut last_error = None;
|
|
|
|
for attempt in 0..=self.config.max_retries {
|
|
if attempt > 0 {
|
|
let backoff = self.config.retry_backoff * 2u32.pow(attempt - 1);
|
|
tokio::time::sleep(backoff).await;
|
|
}
|
|
|
|
match self.try_create_ws_client(endpoint).await {
|
|
Ok(client) => {
|
|
self.stats
|
|
.connections_created
|
|
.fetch_add(1, Ordering::Relaxed);
|
|
self.mark_healthy(endpoint);
|
|
return Ok(client);
|
|
}
|
|
Err(e) => {
|
|
last_error = Some(e);
|
|
self.mark_unhealthy(endpoint);
|
|
}
|
|
}
|
|
}
|
|
|
|
self.stats.requests_failed.fetch_add(1, Ordering::Relaxed);
|
|
Err(last_error.unwrap_or(PoolError::ConnectionFailed("Unknown error".to_string())))
|
|
}
|
|
|
|
/// Attempts to create an HTTP client.
|
|
async fn try_create_http_client(
|
|
&self,
|
|
endpoint: &str,
|
|
) -> Result<Arc<PooledHttpClient>, PoolError> {
|
|
let client = HttpClientBuilder::default()
|
|
.request_timeout(self.config.request_timeout)
|
|
.build(endpoint)
|
|
.map_err(|e| PoolError::ConnectionFailed(e.to_string()))?;
|
|
|
|
Ok(Arc::new(PooledHttpClient {
|
|
client,
|
|
endpoint: endpoint.to_string(),
|
|
created_at: Instant::now(),
|
|
last_used: Mutex::new(Instant::now()),
|
|
request_count: AtomicU64::new(0),
|
|
}))
|
|
}
|
|
|
|
/// Attempts to create a WebSocket client.
|
|
async fn try_create_ws_client(&self, endpoint: &str) -> Result<Arc<PooledWsClient>, PoolError> {
|
|
let client = WsClientBuilder::default()
|
|
.request_timeout(self.config.request_timeout)
|
|
.connection_timeout(self.config.connect_timeout)
|
|
.build(endpoint)
|
|
.await
|
|
.map_err(|e| PoolError::ConnectionFailed(e.to_string()))?;
|
|
|
|
Ok(Arc::new(PooledWsClient {
|
|
client,
|
|
endpoint: endpoint.to_string(),
|
|
created_at: Instant::now(),
|
|
last_used: Mutex::new(Instant::now()),
|
|
request_count: AtomicU64::new(0),
|
|
}))
|
|
}
|
|
}
|
|
|
|
impl Default for ConnectionPool {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
/// Connection pool errors.
|
|
#[derive(Debug, Clone, thiserror::Error)]
|
|
pub enum PoolError {
|
|
/// Pool exhausted (no available connections).
|
|
#[error("Connection pool exhausted")]
|
|
PoolExhausted,
|
|
|
|
/// Endpoint-specific pool exhausted.
|
|
#[error("Endpoint pool exhausted: {0}")]
|
|
EndpointExhausted(String),
|
|
|
|
/// Connection failed.
|
|
#[error("Connection failed: {0}")]
|
|
ConnectionFailed(String),
|
|
|
|
/// Endpoint unhealthy.
|
|
#[error("Endpoint unhealthy: {0}")]
|
|
EndpointUnhealthy(String),
|
|
|
|
/// Connection timeout.
|
|
#[error("Connection timeout")]
|
|
Timeout,
|
|
}
|
|
|
|
/// A guard that automatically returns a connection to the pool when dropped.
|
|
pub struct PooledHttpClientGuard<'a> {
|
|
pool: &'a ConnectionPool,
|
|
client: Option<Arc<PooledHttpClient>>,
|
|
}
|
|
|
|
impl<'a> PooledHttpClientGuard<'a> {
|
|
/// Creates a new guard.
|
|
pub fn new(pool: &'a ConnectionPool, client: Arc<PooledHttpClient>) -> Self {
|
|
PooledHttpClientGuard {
|
|
pool,
|
|
client: Some(client),
|
|
}
|
|
}
|
|
|
|
/// Returns the underlying client.
|
|
pub fn client(&self) -> &HttpClient {
|
|
self.client.as_ref().unwrap().client()
|
|
}
|
|
|
|
/// Marks the request as successful.
|
|
pub fn success(&self) {
|
|
if let Some(ref client) = self.client {
|
|
client.touch();
|
|
self.pool
|
|
.stats
|
|
.requests_total
|
|
.fetch_add(1, Ordering::Relaxed);
|
|
}
|
|
}
|
|
|
|
/// Marks the request as failed.
|
|
pub fn failed(&self) {
|
|
if let Some(ref client) = self.client {
|
|
self.pool.mark_unhealthy(client.endpoint());
|
|
self.pool
|
|
.stats
|
|
.requests_failed
|
|
.fetch_add(1, Ordering::Relaxed);
|
|
}
|
|
}
|
|
|
|
/// Discards the connection (don't return to pool).
|
|
pub fn discard(mut self) {
|
|
if let Some(client) = self.client.take() {
|
|
self.pool
|
|
.stats
|
|
.connections_closed
|
|
.fetch_add(1, Ordering::Relaxed);
|
|
drop(client);
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<'a> Drop for PooledHttpClientGuard<'a> {
|
|
fn drop(&mut self) {
|
|
if let Some(client) = self.client.take() {
|
|
self.pool.release_http(client);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// A guard that automatically returns a WebSocket connection to the pool when dropped.
|
|
pub struct PooledWsClientGuard<'a> {
|
|
pool: &'a ConnectionPool,
|
|
client: Option<Arc<PooledWsClient>>,
|
|
}
|
|
|
|
impl<'a> PooledWsClientGuard<'a> {
|
|
/// Creates a new guard.
|
|
pub fn new(pool: &'a ConnectionPool, client: Arc<PooledWsClient>) -> Self {
|
|
PooledWsClientGuard {
|
|
pool,
|
|
client: Some(client),
|
|
}
|
|
}
|
|
|
|
/// Returns the underlying client.
|
|
pub fn client(&self) -> &WsClient {
|
|
self.client.as_ref().unwrap().client()
|
|
}
|
|
|
|
/// Marks the request as successful.
|
|
pub fn success(&self) {
|
|
if let Some(ref client) = self.client {
|
|
client.touch();
|
|
self.pool
|
|
.stats
|
|
.requests_total
|
|
.fetch_add(1, Ordering::Relaxed);
|
|
}
|
|
}
|
|
|
|
/// Marks the request as failed.
|
|
pub fn failed(&self) {
|
|
if let Some(ref client) = self.client {
|
|
self.pool.mark_unhealthy(client.endpoint());
|
|
self.pool
|
|
.stats
|
|
.requests_failed
|
|
.fetch_add(1, Ordering::Relaxed);
|
|
}
|
|
}
|
|
|
|
/// Discards the connection (don't return to pool).
|
|
pub fn discard(mut self) {
|
|
if let Some(client) = self.client.take() {
|
|
self.pool
|
|
.stats
|
|
.connections_closed
|
|
.fetch_add(1, Ordering::Relaxed);
|
|
drop(client);
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<'a> Drop for PooledWsClientGuard<'a> {
|
|
fn drop(&mut self) {
|
|
if let Some(client) = self.client.take() {
|
|
self.pool.release_ws(client);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Helper extension trait for connection pool.
|
|
#[async_trait::async_trait]
|
|
pub trait ConnectionPoolExt {
|
|
/// Acquires an HTTP client and returns a guard.
|
|
async fn acquire_http_guard(
|
|
&self,
|
|
endpoint: &str,
|
|
) -> Result<PooledHttpClientGuard<'_>, PoolError>;
|
|
|
|
/// Acquires a WebSocket client and returns a guard.
|
|
async fn acquire_ws_guard(&self, endpoint: &str) -> Result<PooledWsClientGuard<'_>, PoolError>;
|
|
}
|
|
|
|
#[async_trait::async_trait]
|
|
impl ConnectionPoolExt for ConnectionPool {
|
|
async fn acquire_http_guard(
|
|
&self,
|
|
endpoint: &str,
|
|
) -> Result<PooledHttpClientGuard<'_>, PoolError> {
|
|
let client = self.acquire_http(endpoint).await?;
|
|
Ok(PooledHttpClientGuard::new(self, client))
|
|
}
|
|
|
|
async fn acquire_ws_guard(&self, endpoint: &str) -> Result<PooledWsClientGuard<'_>, PoolError> {
|
|
let client = self.acquire_ws(endpoint).await?;
|
|
Ok(PooledWsClientGuard::new(self, client))
|
|
}
|
|
}
|
|
|
|
/// Global connection pool for use across the application.
|
|
static GLOBAL_POOL: std::sync::OnceLock<ConnectionPool> = std::sync::OnceLock::new();
|
|
|
|
/// Initializes the global connection pool.
|
|
pub fn init_global_pool(config: PoolConfig) {
|
|
let _ = GLOBAL_POOL.set(ConnectionPool::with_config(config));
|
|
}
|
|
|
|
/// Returns the global connection pool.
|
|
pub fn global_pool() -> &'static ConnectionPool {
|
|
GLOBAL_POOL.get_or_init(ConnectionPool::new)
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_config_default() {
        let config = PoolConfig::default();
        assert_eq!(config.max_connections_per_endpoint, 4);
        assert_eq!(config.max_total_connections, 64);
    }

    #[test]
    fn test_pool_config_high_throughput() {
        let config = PoolConfig::high_throughput();
        assert!(
            config.max_connections_per_endpoint
                > PoolConfig::default().max_connections_per_endpoint
        );
    }

    #[test]
    fn test_pool_config_low_latency() {
        let config = PoolConfig::low_latency();
        assert!(config.connect_timeout < PoolConfig::default().connect_timeout);
    }

    #[test]
    fn test_pool_creation() {
        let pool = ConnectionPool::new();
        assert_eq!(pool.active_connections(), 0);
        assert_eq!(pool.endpoint_count(), 0);
    }

    #[test]
    fn test_pool_stats() {
        let pool = ConnectionPool::new();
        let stats = pool.stats().snapshot();
        assert_eq!(stats.connections_created, 0);
        assert_eq!(stats.requests_total, 0);
    }

    #[test]
    fn test_global_pool() {
        let pool = global_pool();
        assert_eq!(pool.config().max_connections_per_endpoint, 4);
    }

    #[tokio::test]
    async fn test_endpoint_health() {
        let pool = ConnectionPool::new();
        let endpoint = "http://127.0.0.1:1";

        // Unknown endpoints are assumed healthy. Marking them is a no-op because health
        // is only tracked for endpoints that already have an entry.
        assert!(pool.is_healthy(endpoint));
        for _ in 0..5 {
            pool.mark_unhealthy(endpoint);
        }
        assert!(pool.is_healthy(endpoint));

        // Acquiring registers the endpoint.
        // NOTE(review): relies on jsonrpsee's `HttpClientBuilder::build` not dialing
        // out, so no server needs to listen here. Re-confirm when upgrading jsonrpsee.
        let client = pool
            .acquire_http(endpoint)
            .await
            .expect("building an HTTP client should not require a live server");
        pool.release_http(client);
        assert_eq!(pool.endpoint_count(), 1);

        // A streak of `max_retries` failures flips the endpoint to unhealthy...
        for _ in 0..pool.config().max_retries {
            pool.mark_unhealthy(endpoint);
        }
        assert!(!pool.is_healthy(endpoint));

        // ...and marking it healthy restores it.
        pool.mark_healthy(endpoint);
        assert!(pool.is_healthy(endpoint));
    }
}
|