//! Transaction routing to shards. //! //! Routes transactions to appropriate shards based on account addresses. use std::collections::HashMap; use serde::{Deserialize, Serialize}; use synor_types::Hash256; use crate::{ShardConfig, ShardId}; /// Routing table entry. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct RoutingEntry { /// Address prefix. pub prefix: Vec, /// Assigned shard. pub shard_id: ShardId, /// Entry weight (for load balancing). pub weight: u32, } /// Routing table for shard assignment. #[derive(Clone, Debug)] pub struct RoutingTable { /// Entries sorted by prefix. entries: Vec, /// Cache for recent lookups. cache: HashMap, /// Max cache size. max_cache_size: usize, } impl RoutingTable { /// Creates a new routing table. pub fn new() -> Self { Self { entries: Vec::new(), cache: HashMap::new(), max_cache_size: 10000, } } /// Adds a routing entry. pub fn add_entry(&mut self, entry: RoutingEntry) { self.entries.push(entry); self.entries.sort_by(|a, b| a.prefix.cmp(&b.prefix)); } /// Looks up the shard for an address. pub fn lookup(&mut self, address: &Hash256) -> Option { // Check cache first if let Some(&shard) = self.cache.get(address) { return Some(shard); } // Search entries let addr_bytes = address.as_bytes(); for entry in &self.entries { if addr_bytes.starts_with(&entry.prefix) { // Update cache if self.cache.len() < self.max_cache_size { self.cache.insert(*address, entry.shard_id); } return Some(entry.shard_id); } } None } /// Clears the cache. pub fn clear_cache(&mut self) { self.cache.clear(); } } impl Default for RoutingTable { fn default() -> Self { Self::new() } } /// Routes transactions to shards. pub struct ShardRouter { /// Configuration. config: ShardConfig, /// Custom routing table (optional). routing_table: Option, /// Shard load (for load balancing). shard_load: HashMap, } impl ShardRouter { /// Creates a new router with the given configuration. pub fn new(config: ShardConfig) -> Self { let mut shard_load = HashMap::new(); for i in 0..config.num_shards { shard_load.insert(i, 0); } Self { config, routing_table: None, shard_load, } } /// Sets a custom routing table. pub fn with_routing_table(mut self, table: RoutingTable) -> Self { self.routing_table = Some(table); self } /// Routes an address to a shard. pub fn route(&self, address: &Hash256) -> ShardId { // Check custom routing table first if let Some(ref table) = self.routing_table { let mut table = table.clone(); if let Some(shard) = table.lookup(address) { return shard; } } // Default: hash-based routing self.config.shard_for_address(address) } /// Routes a transaction with sender and receiver. /// Returns (sender_shard, receiver_shard, is_cross_shard). pub fn route_transaction( &self, sender: &Hash256, receiver: &Hash256, ) -> (ShardId, ShardId, bool) { let sender_shard = self.route(sender); let receiver_shard = self.route(receiver); let is_cross_shard = sender_shard != receiver_shard; (sender_shard, receiver_shard, is_cross_shard) } /// Records transaction load on a shard. pub fn record_transaction(&mut self, shard_id: ShardId) { if let Some(load) = self.shard_load.get_mut(&shard_id) { *load += 1; } } /// Resets load counters. pub fn reset_load(&mut self) { for load in self.shard_load.values_mut() { *load = 0; } } /// Gets the current load for a shard. pub fn get_load(&self, shard_id: ShardId) -> u64 { self.shard_load.get(&shard_id).copied().unwrap_or(0) } /// Gets the least loaded shard. pub fn least_loaded_shard(&self) -> ShardId { self.shard_load .iter() .min_by_key(|(_, &load)| load) .map(|(&id, _)| id) .unwrap_or(0) } /// Gets the most loaded shard. pub fn most_loaded_shard(&self) -> ShardId { self.shard_load .iter() .max_by_key(|(_, &load)| load) .map(|(&id, _)| id) .unwrap_or(0) } /// Checks if load is imbalanced (for resharding trigger). pub fn is_load_imbalanced(&self, threshold: f64) -> bool { if self.shard_load.is_empty() { return false; } let loads: Vec = self.shard_load.values().copied().collect(); let avg = loads.iter().sum::() as f64 / loads.len() as f64; if avg == 0.0 { return false; } // Check if any shard is significantly above average loads.iter().any(|&load| (load as f64 / avg) > threshold) } /// Returns load distribution stats. pub fn load_stats(&self) -> LoadStats { let loads: Vec = self.shard_load.values().copied().collect(); let total: u64 = loads.iter().sum(); let avg = if loads.is_empty() { 0.0 } else { total as f64 / loads.len() as f64 }; let min = loads.iter().min().copied().unwrap_or(0); let max = loads.iter().max().copied().unwrap_or(0); LoadStats { total_transactions: total, average_per_shard: avg, min_load: min, max_load: max, num_shards: loads.len(), } } } /// Load distribution statistics. #[derive(Clone, Debug)] pub struct LoadStats { /// Total transactions routed. pub total_transactions: u64, /// Average per shard. pub average_per_shard: f64, /// Minimum shard load. pub min_load: u64, /// Maximum shard load. pub max_load: u64, /// Number of shards. pub num_shards: usize, } #[cfg(test)] mod tests { use super::*; #[test] fn test_routing_deterministic() { let config = ShardConfig::with_shards(32); let router = ShardRouter::new(config); let addr = Hash256::from_bytes([1u8; 32]); let shard1 = router.route(&addr); let shard2 = router.route(&addr); assert_eq!(shard1, shard2); } #[test] fn test_routing_distribution() { let config = ShardConfig::with_shards(8); let router = ShardRouter::new(config); let mut counts = [0u32; 8]; for i in 0..1000u32 { let mut bytes = [0u8; 32]; bytes[0..4].copy_from_slice(&i.to_le_bytes()); let addr = Hash256::from_bytes(bytes); let shard = router.route(&addr) as usize; counts[shard] += 1; } // Check all shards have some load for (i, count) in counts.iter().enumerate() { assert!(*count > 0, "Shard {} has no transactions", i); } } #[test] fn test_cross_shard_detection() { let config = ShardConfig::with_shards(32); let router = ShardRouter::new(config); // Same shard let sender = Hash256::from_bytes([1u8; 32]); let receiver = Hash256::from_bytes([1u8; 32]); // Same address let (_, _, is_cross) = router.route_transaction(&sender, &receiver); assert!(!is_cross); // Different shards let mut other_bytes = [0u8; 32]; other_bytes[0] = 255; // Different first bytes let other = Hash256::from_bytes(other_bytes); let sender_shard = router.route(&sender); let other_shard = router.route(&other); if sender_shard != other_shard { let (_, _, is_cross) = router.route_transaction(&sender, &other); assert!(is_cross); } } #[test] fn test_load_tracking() { let config = ShardConfig::with_shards(4); let mut router = ShardRouter::new(config); router.record_transaction(0); router.record_transaction(0); router.record_transaction(1); assert_eq!(router.get_load(0), 2); assert_eq!(router.get_load(1), 1); assert_eq!(router.least_loaded_shard(), 2); // or 3, both have 0 assert_eq!(router.most_loaded_shard(), 0); } #[test] fn test_load_imbalance() { let config = ShardConfig::with_shards(4); let mut router = ShardRouter::new(config); // Even load for _ in 0..10 { router.record_transaction(0); router.record_transaction(1); router.record_transaction(2); router.record_transaction(3); } assert!(!router.is_load_imbalanced(2.0)); // Reset and create imbalance router.reset_load(); for _ in 0..100 { router.record_transaction(0); } router.record_transaction(1); assert!(router.is_load_imbalanced(2.0)); } }