//! Cross-shard messaging protocol. //! //! Enables atomic operations across shards via receipt-based messaging. use std::collections::{HashMap, VecDeque}; use serde::{Deserialize, Serialize}; use synor_types::Hash256; use crate::{ShardError, ShardId, ShardResult, Slot}; /// Message status. #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum MessageStatus { /// Message is pending delivery. Pending, /// Message has been delivered. Delivered, /// Message delivery confirmed with receipt. Confirmed, /// Message timed out. TimedOut, /// Message processing failed. Failed, } /// Cross-shard message. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct CrossShardMessage { /// Unique message ID. pub id: Hash256, /// Source shard. pub from_shard: ShardId, /// Destination shard. pub to_shard: ShardId, /// Message payload. pub payload: Vec, /// Slot when message was sent. pub sent_slot: Slot, /// Message status. pub status: MessageStatus, /// Optional transaction hash that triggered this message. pub tx_hash: Option, } impl CrossShardMessage { /// Creates a new cross-shard message. pub fn new( from_shard: ShardId, to_shard: ShardId, payload: Vec, sent_slot: Slot, ) -> Self { // Generate unique ID let mut hasher = blake3::Hasher::new(); hasher.update(&from_shard.to_le_bytes()); hasher.update(&to_shard.to_le_bytes()); hasher.update(&sent_slot.to_le_bytes()); hasher.update(&payload); let hash = hasher.finalize(); Self { id: Hash256::from_bytes(*hash.as_bytes()), from_shard, to_shard, payload, sent_slot, status: MessageStatus::Pending, tx_hash: None, } } /// Sets the transaction hash. pub fn with_tx_hash(mut self, tx_hash: Hash256) -> Self { self.tx_hash = Some(tx_hash); self } } /// Receipt for delivered cross-shard message. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct MessageReceipt { /// Message ID. pub message_id: Hash256, /// Receiving shard. pub receiver_shard: ShardId, /// Slot when processed. pub processed_slot: Slot, /// Success flag. pub success: bool, /// Result data (e.g., return value or error). pub result: Vec, /// State root after processing. pub post_state_root: Hash256, } /// Routes messages between shards. pub struct MessageRouter { /// Number of shards. num_shards: u16, /// Message timeout in slots. timeout_slots: u64, /// Outbound message queues per shard. outbound: HashMap>, /// Inbound message queues per shard. inbound: HashMap>, /// Pending receipts. pending_receipts: HashMap, /// Confirmed receipts. receipts: HashMap, } impl MessageRouter { /// Creates a new message router. pub fn new(num_shards: u16, timeout_slots: u64) -> Self { let mut outbound = HashMap::new(); let mut inbound = HashMap::new(); for i in 0..num_shards { outbound.insert(i, VecDeque::new()); inbound.insert(i, VecDeque::new()); } Self { num_shards, timeout_slots, outbound, inbound, pending_receipts: HashMap::new(), receipts: HashMap::new(), } } /// Sends a cross-shard message. pub fn send( &mut self, from: ShardId, to: ShardId, payload: Vec, current_slot: Slot, ) -> ShardResult { if from >= self.num_shards || to >= self.num_shards { return Err(ShardError::InvalidShardId(from.max(to), self.num_shards - 1)); } if from == to { return Err(ShardError::MessageFailed( "Cannot send message to same shard".into(), )); } let message = CrossShardMessage::new(from, to, payload, current_slot); let message_id = message.id; // Add to outbound queue if let Some(queue) = self.outbound.get_mut(&from) { queue.push_back(message.clone()); } // Add to inbound queue of destination if let Some(queue) = self.inbound.get_mut(&to) { queue.push_back(message.clone()); } // Track for receipt self.pending_receipts.insert(message_id, message); Ok(message_id) } /// Receives pending messages for a shard. pub fn receive(&mut self, shard_id: ShardId, current_slot: Slot) -> Vec { let mut messages = Vec::new(); let mut timed_out = Vec::new(); let timeout_threshold = self.timeout_slots; if let Some(queue) = self.inbound.get_mut(&shard_id) { // Take all pending messages while let Some(mut msg) = queue.pop_front() { // Check for timeout if current_slot - msg.sent_slot > timeout_threshold { msg.status = MessageStatus::TimedOut; timed_out.push(msg); } else { msg.status = MessageStatus::Delivered; messages.push(msg); } } } // Process timeouts after releasing the borrow for msg in timed_out { self.handle_timeout(&msg); } messages } /// Confirms message processing with receipt. pub fn confirm( &mut self, message_id: Hash256, receiver_shard: ShardId, processed_slot: Slot, success: bool, result: Vec, post_state_root: Hash256, ) -> ShardResult<()> { if let Some(mut message) = self.pending_receipts.remove(&message_id) { message.status = if success { MessageStatus::Confirmed } else { MessageStatus::Failed }; let receipt = MessageReceipt { message_id, receiver_shard, processed_slot, success, result, post_state_root, }; self.receipts.insert(message_id, receipt); Ok(()) } else { Err(ShardError::MessageFailed(format!( "Message {} not found in pending", message_id ))) } } /// Gets a receipt for a message. pub fn get_receipt(&self, message_id: &Hash256) -> Option<&MessageReceipt> { self.receipts.get(message_id) } /// Gets pending message count for a shard. pub fn pending_count(&self, shard_id: ShardId) -> usize { self.inbound .get(&shard_id) .map(|q| q.len()) .unwrap_or(0) } /// Handles message timeout. fn handle_timeout(&mut self, message: &CrossShardMessage) { // Remove from pending and mark as timed out self.pending_receipts.remove(&message.id); // Create timeout receipt let receipt = MessageReceipt { message_id: message.id, receiver_shard: message.to_shard, processed_slot: message.sent_slot + self.timeout_slots, success: false, result: b"TIMEOUT".to_vec(), post_state_root: Hash256::from_bytes([0u8; 32]), }; self.receipts.insert(message.id, receipt); } /// Cleans up old receipts. pub fn cleanup_old_receipts(&mut self, min_slot: Slot) { self.receipts.retain(|_, receipt| receipt.processed_slot >= min_slot); } /// Gets statistics about message routing. pub fn stats(&self) -> MessageStats { let total_pending: usize = self.inbound.values().map(|q| q.len()).sum(); let total_receipts = self.receipts.len(); let successful = self.receipts.values().filter(|r| r.success).count(); MessageStats { pending_messages: total_pending, total_receipts, successful_deliveries: successful, failed_deliveries: total_receipts - successful, } } } /// Message routing statistics. #[derive(Clone, Debug)] pub struct MessageStats { /// Pending messages across all shards. pub pending_messages: usize, /// Total receipts. pub total_receipts: usize, /// Successful deliveries. pub successful_deliveries: usize, /// Failed deliveries. pub failed_deliveries: usize, } #[cfg(test)] mod tests { use super::*; #[test] fn test_message_creation() { let msg = CrossShardMessage::new(0, 1, b"hello".to_vec(), 100); assert_eq!(msg.from_shard, 0); assert_eq!(msg.to_shard, 1); assert_eq!(msg.sent_slot, 100); assert_eq!(msg.status, MessageStatus::Pending); } #[test] fn test_send_receive() { let mut router = MessageRouter::new(4, 64); // Send message from shard 0 to shard 1 let msg_id = router.send(0, 1, b"test payload".to_vec(), 10).unwrap(); // Receive on shard 1 let messages = router.receive(1, 15); assert_eq!(messages.len(), 1); assert_eq!(messages[0].id, msg_id); assert_eq!(messages[0].status, MessageStatus::Delivered); } #[test] fn test_confirm_receipt() { let mut router = MessageRouter::new(4, 64); let msg_id = router.send(0, 1, b"test".to_vec(), 10).unwrap(); let _ = router.receive(1, 15); // Confirm processing router .confirm(msg_id, 1, 20, true, b"ok".to_vec(), Hash256::from_bytes([1u8; 32])) .unwrap(); let receipt = router.get_receipt(&msg_id).unwrap(); assert!(receipt.success); assert_eq!(receipt.result, b"ok"); } #[test] fn test_message_timeout() { let mut router = MessageRouter::new(4, 10); let msg_id = router.send(0, 1, b"test".to_vec(), 10).unwrap(); // Receive after timeout let messages = router.receive(1, 100); assert_eq!(messages.len(), 0); // Timed out messages not returned // Check timeout receipt let receipt = router.get_receipt(&msg_id).unwrap(); assert!(!receipt.success); } #[test] fn test_same_shard_error() { let mut router = MessageRouter::new(4, 64); let result = router.send(0, 0, b"test".to_vec(), 10); assert!(result.is_err()); } #[test] fn test_stats() { let mut router = MessageRouter::new(4, 64); router.send(0, 1, b"msg1".to_vec(), 10).unwrap(); router.send(0, 2, b"msg2".to_vec(), 10).unwrap(); let stats = router.stats(); assert_eq!(stats.pending_messages, 2); } }