//! Gateway server implementation. //! //! This module provides the main Gateway server that runs the REST API, //! WebSocket, and metrics endpoints. use crate::{ auth::AuthService, config::GatewayConfig, middleware::{ auth_middleware, build_cors_layer, rate_limit_middleware, request_id_middleware, security_headers_middleware, timing_middleware, version_middleware, RateLimiterState, }, routes::{self, AppState}, }; use axum::{ middleware::{from_fn, from_fn_with_state}, Router, }; use std::{net::SocketAddr, sync::Arc}; use tokio::{ net::TcpListener, signal, sync::oneshot, }; use tower_http::{ compression::CompressionLayer, limit::RequestBodyLimitLayer, timeout::TimeoutLayer, trace::TraceLayer, }; use tracing::info; /// Synor API Gateway server. pub struct Gateway { config: GatewayConfig, auth_service: Arc, rate_limiter: Arc, } impl Gateway { /// Create a new gateway instance. pub fn new(config: GatewayConfig) -> anyhow::Result { let auth_service = Arc::new(AuthService::from_config(config.auth.clone())); let rate_limiter = Arc::new(RateLimiterState::new(config.rate_limit.clone())); Ok(Self { config, auth_service, rate_limiter, }) } /// Create gateway from environment configuration. pub fn from_env() -> anyhow::Result { let config = GatewayConfig::from_env()?; Self::new(config) } /// Create gateway from configuration file. pub fn from_file(path: &str) -> anyhow::Result { let config = GatewayConfig::from_file(path)?; Self::new(config) } /// Build the router with all middleware and routes. fn build_router(&self) -> Router { let app_state = AppState::new(self.config.clone()); // Build base router with all service routes let router = routes::build_router(app_state.clone()); // Apply middleware stack (order matters - applied bottom to top) let router = router // Innermost: service routes (already applied) // Security headers .layer(from_fn(security_headers_middleware)) // Version checking .layer(from_fn(version_middleware)) // Authentication context injection .layer(from_fn_with_state( self.auth_service.clone(), auth_middleware, )) // Rate limiting .layer(from_fn_with_state( self.rate_limiter.clone(), |state, connect_info, req, next| async move { rate_limit_middleware(state, connect_info, req, next).await }, )) // Request timing .layer(from_fn(timing_middleware)) // Request ID .layer(from_fn(request_id_middleware)); // Optional layers based on configuration let router = if self.config.cors.enabled { router.layer(build_cors_layer(&self.config.cors)) } else { router }; let router = if self.config.server.compression { router.layer(CompressionLayer::new()) } else { router }; // Request body limit let router = router.layer(RequestBodyLimitLayer::new(self.config.server.max_body_size)); // Request timeout let router = router.layer(TimeoutLayer::new(self.config.server.request_timeout)); // Tracing let router = router.layer(TraceLayer::new_for_http()); router } /// Start the gateway server. pub async fn serve(self) -> anyhow::Result<()> { let listen_addr = self.config.server.listen_addr; let shutdown_timeout = self.config.server.shutdown_timeout; info!( listen_addr = %listen_addr, "Starting Synor API Gateway" ); let router = self.build_router(); // Create TCP listener let listener = TcpListener::bind(listen_addr).await?; info!( listen_addr = %listen_addr, "Gateway listening for connections" ); // Graceful shutdown handling let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); // Spawn shutdown signal handler tokio::spawn(async move { shutdown_signal().await; let _ = shutdown_tx.send(()); }); // Serve with graceful shutdown axum::serve( listener, router.into_make_service_with_connect_info::(), ) .with_graceful_shutdown(async { let _ = shutdown_rx.await; info!("Shutdown signal received, initiating graceful shutdown"); }) .await?; info!("Gateway shutdown complete"); Ok(()) } /// Start the gateway server and return a handle for programmatic shutdown. pub async fn serve_with_shutdown(self) -> anyhow::Result { let listen_addr = self.config.server.listen_addr; info!( listen_addr = %listen_addr, "Starting Synor API Gateway" ); let router = self.build_router(); // Create TCP listener let listener = TcpListener::bind(listen_addr).await?; let local_addr = listener.local_addr()?; info!( listen_addr = %local_addr, "Gateway listening for connections" ); // Create shutdown channel let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); // Spawn server in background let server_handle = tokio::spawn(async move { axum::serve( listener, router.into_make_service_with_connect_info::(), ) .with_graceful_shutdown(async { let _ = shutdown_rx.await; info!("Shutdown signal received"); }) .await }); Ok(GatewayHandle { shutdown_tx: Some(shutdown_tx), server_handle, local_addr, }) } /// Get the configured listen address. pub fn listen_addr(&self) -> SocketAddr { self.config.server.listen_addr } /// Get the configured WebSocket address. pub fn ws_addr(&self) -> SocketAddr { self.config.server.ws_addr } } /// Handle for controlling a running gateway. pub struct GatewayHandle { shutdown_tx: Option>, server_handle: tokio::task::JoinHandle>, local_addr: SocketAddr, } impl GatewayHandle { /// Get the local address the server is bound to. pub fn local_addr(&self) -> SocketAddr { self.local_addr } /// Trigger graceful shutdown. pub fn shutdown(&mut self) { if let Some(tx) = self.shutdown_tx.take() { let _ = tx.send(()); } } /// Wait for the server to finish. pub async fn wait(self) -> anyhow::Result<()> { self.server_handle.await??; Ok(()) } /// Shutdown and wait for completion. pub async fn shutdown_and_wait(mut self) -> anyhow::Result<()> { self.shutdown(); self.wait().await } } /// Wait for shutdown signal (Ctrl+C or SIGTERM). async fn shutdown_signal() { let ctrl_c = async { signal::ctrl_c() .await .expect("Failed to install Ctrl+C handler"); }; #[cfg(unix)] let terminate = async { signal::unix::signal(signal::unix::SignalKind::terminate()) .expect("Failed to install signal handler") .recv() .await; }; #[cfg(not(unix))] let terminate = std::future::pending::<()>(); tokio::select! { _ = ctrl_c => { info!("Received Ctrl+C signal"); } _ = terminate => { info!("Received SIGTERM signal"); } } } /// Builder for Gateway configuration. pub struct GatewayBuilder { config: GatewayConfig, } impl GatewayBuilder { /// Create a new gateway builder with default configuration. pub fn new() -> Self { Self { config: GatewayConfig::default(), } } /// Set the listen address. pub fn listen_addr(mut self, addr: SocketAddr) -> Self { self.config.server.listen_addr = addr; self } /// Set the WebSocket address. pub fn ws_addr(mut self, addr: SocketAddr) -> Self { self.config.server.ws_addr = addr; self } /// Set the JWT secret. pub fn jwt_secret(mut self, secret: impl Into) -> Self { self.config.auth.jwt_secret = secret.into(); self } /// Disable authentication (for development). pub fn disable_auth(mut self) -> Self { self.config.auth.enabled = false; self } /// Disable rate limiting. pub fn disable_rate_limit(mut self) -> Self { self.config.rate_limit.enabled = false; self } /// Set maximum request body size. pub fn max_body_size(mut self, size: usize) -> Self { self.config.server.max_body_size = size; self } /// Enable or disable compression. pub fn compression(mut self, enabled: bool) -> Self { self.config.server.compression = enabled; self } /// Build the Gateway. pub fn build(self) -> anyhow::Result { Gateway::new(self.config) } } impl Default for GatewayBuilder { fn default() -> Self { Self::new() } } #[cfg(test)] mod tests { use super::*; #[tokio::test] async fn test_gateway_builder() { let gateway = GatewayBuilder::new() .listen_addr("127.0.0.1:0".parse().unwrap()) .disable_auth() .disable_rate_limit() .build() .unwrap(); assert!(!gateway.config.auth.enabled); assert!(!gateway.config.rate_limit.enabled); } #[tokio::test] async fn test_gateway_start_stop() { let gateway = GatewayBuilder::new() .listen_addr("127.0.0.1:0".parse().unwrap()) .disable_auth() .disable_rate_limit() .build() .unwrap(); let mut handle = gateway.serve_with_shutdown().await.unwrap(); // Server should be running let addr = handle.local_addr(); assert!(addr.port() > 0); // Trigger shutdown handle.shutdown(); // Wait for server to stop handle.wait().await.unwrap(); } }