//! Raft consensus implementation. use super::cluster::{ClusterConfig, NodeId, PeerState}; use super::election::{Election, ElectionResult, ElectionTimeout, VoteHandler}; use super::log::{LogEntry, ReplicatedLog}; use super::rpc::{ AppendEntries, AppendEntriesResponse, InstallSnapshot, InstallSnapshotResponse, RequestVote, RequestVoteResponse, RpcMessage, }; use super::snapshot::{SnapshotConfig, SnapshotManager}; use super::state::{Command, LeaderState, NodeRole, RaftState}; use serde::{Deserialize, Serialize}; use std::time::{Duration, Instant}; /// Configuration for the Raft node. #[derive(Clone, Debug)] pub struct RaftConfig { /// Minimum election timeout in milliseconds. pub election_timeout_min: u64, /// Maximum election timeout in milliseconds. pub election_timeout_max: u64, /// Heartbeat interval in milliseconds. pub heartbeat_interval: u64, /// Maximum entries per AppendEntries RPC. pub max_entries_per_rpc: usize, /// Snapshot threshold (entries before compaction). pub snapshot_threshold: u64, /// Maximum snapshot chunk size. pub snapshot_chunk_size: usize, } impl Default for RaftConfig { fn default() -> Self { Self { election_timeout_min: 150, election_timeout_max: 300, heartbeat_interval: 50, max_entries_per_rpc: 100, snapshot_threshold: 10000, snapshot_chunk_size: 65536, } } } /// Result of applying a command. #[derive(Clone, Debug, Serialize, Deserialize)] pub enum ApplyResult { /// Command applied successfully. Success(Vec), /// Command failed. Error(String), /// Not the leader, redirect to leader. NotLeader(Option), } /// Events that can be produced by the Raft node. #[derive(Clone, Debug)] pub enum RaftEvent { /// Send RPC to a peer. SendRpc(NodeId, RpcMessage), /// Broadcast RPC to all peers. BroadcastRpc(RpcMessage), /// Apply committed entry to state machine. ApplyEntry(u64, Command), /// Became leader. BecameLeader, /// Became follower. BecameFollower(Option), /// Snapshot should be taken. TakeSnapshot, /// Log compacted up to index. LogCompacted(u64), } /// The Raft consensus node. pub struct RaftNode { /// Node ID. id: NodeId, /// Cluster configuration. cluster: ClusterConfig, /// Raft configuration. config: RaftConfig, /// Persistent state. state: RaftState, /// Replicated log. log: ReplicatedLog, /// Current election (only during candidacy). election: Option, /// Election timeout generator. election_timeout: ElectionTimeout, /// Snapshot manager. snapshots: SnapshotManager, /// Leader state (only valid when leader). leader_state: Option, /// Current known leader. leader_id: Option, /// Last heartbeat/message from leader. last_leader_contact: Instant, /// Current election timeout duration. current_timeout: Duration, /// Pending events. events: Vec, } impl RaftNode { /// Creates a new Raft node. pub fn new(id: NodeId, cluster: ClusterConfig, config: RaftConfig) -> Self { let election_timeout = ElectionTimeout::new(config.election_timeout_min, config.election_timeout_max); let current_timeout = election_timeout.random_timeout(); Self { id, cluster, state: RaftState::new(), log: ReplicatedLog::new(), election: None, election_timeout, snapshots: SnapshotManager::new(config.snapshot_threshold), leader_state: None, leader_id: None, last_leader_contact: Instant::now(), current_timeout, events: Vec::new(), config, } } /// Returns the node ID. pub fn id(&self) -> NodeId { self.id } /// Returns the current term. pub fn current_term(&self) -> u64 { self.state.current_term } /// Returns the current role. pub fn role(&self) -> NodeRole { self.state.role } /// Returns the current leader ID if known. pub fn leader(&self) -> Option { self.leader_id } /// Returns true if this node is the leader. pub fn is_leader(&self) -> bool { self.state.is_leader() } /// Returns the commit index. pub fn commit_index(&self) -> u64 { self.state.commit_index } /// Returns the last applied index. pub fn last_applied(&self) -> u64 { self.state.last_applied } /// Returns the log length. pub fn log_length(&self) -> u64 { self.log.last_index() } /// Drains pending events. pub fn drain_events(&mut self) -> Vec { std::mem::take(&mut self.events) } /// Called periodically to drive the Raft state machine. pub fn tick(&mut self) { match self.state.role { NodeRole::Leader => self.tick_leader(), NodeRole::Follower => self.tick_follower(), NodeRole::Candidate => self.tick_candidate(), } // Apply committed entries self.apply_committed_entries(); // Check if snapshot needed if self .snapshots .should_snapshot(self.log.last_index(), self.snapshots.last_included_index()) { self.events.push(RaftEvent::TakeSnapshot); } } fn tick_leader(&mut self) { // Send heartbeats to all peers self.send_heartbeats(); } fn tick_follower(&mut self) { // Check for election timeout if self.last_leader_contact.elapsed() >= self.current_timeout { self.start_election(); } } fn tick_candidate(&mut self) { // Check for election timeout if self.last_leader_contact.elapsed() >= self.current_timeout { // Start a new election self.start_election(); } } fn start_election(&mut self) { // Increment term and become candidate self.state.become_candidate(); self.state.voted_for = Some(self.id); self.leader_id = None; // Reset timeout self.reset_election_timeout(); // Create new election let cluster_size = self.cluster.voting_members(); self.election = Some(Election::new(self.id, self.state.current_term, cluster_size)); // Create RequestVote message let request = RequestVote::new( self.state.current_term, self.id, self.log.last_index(), self.log.last_term(), ); self.events .push(RaftEvent::BroadcastRpc(RpcMessage::RequestVote(request))); // Check if we already have quorum (single-node cluster) if cluster_size == 1 { self.become_leader(); } } fn become_leader(&mut self) { self.state.become_leader(); self.leader_id = Some(self.id); self.election = None; // Initialize leader state let peer_ids: Vec<_> = self.cluster.peer_ids(); self.leader_state = Some(LeaderState::new(self.log.last_index(), &peer_ids)); self.events.push(RaftEvent::BecameLeader); // Send immediate heartbeats self.send_heartbeats(); } fn become_follower(&mut self, term: u64, leader: Option) { self.state.become_follower(term); self.leader_id = leader; self.leader_state = None; self.election = None; self.reset_election_timeout(); self.events.push(RaftEvent::BecameFollower(leader)); } fn reset_election_timeout(&mut self) { self.last_leader_contact = Instant::now(); self.current_timeout = self.election_timeout.random_timeout(); } fn send_heartbeats(&mut self) { if !self.is_leader() { return; } for peer_id in self.cluster.peer_ids() { self.send_append_entries(peer_id); } } fn send_append_entries(&mut self, peer_id: NodeId) { let leader_state = match &self.leader_state { Some(ls) => ls, None => return, }; let next_index = *leader_state.next_index.get(&peer_id).unwrap_or(&1); // Check if we need to send snapshot instead if next_index <= self.snapshots.last_included_index() { self.send_install_snapshot(peer_id); return; } let (prev_log_index, prev_log_term, entries) = self.log .entries_for_replication(next_index, self.config.max_entries_per_rpc); let request = AppendEntries::with_entries( self.state.current_term, self.id, prev_log_index, prev_log_term, entries, self.state.commit_index, ); self.events .push(RaftEvent::SendRpc(peer_id, RpcMessage::AppendEntries(request))); } fn send_install_snapshot(&mut self, peer_id: NodeId) { let snapshot = match self.snapshots.get_snapshot() { Some(s) => s, None => return, }; let chunks = self .snapshots .chunk_snapshot(self.config.snapshot_chunk_size); if let Some((offset, data, done)) = chunks.into_iter().next() { let request = InstallSnapshot::new( self.state.current_term, self.id, snapshot.metadata.last_included_index, snapshot.metadata.last_included_term, offset, data, done, ); self.events .push(RaftEvent::SendRpc(peer_id, RpcMessage::InstallSnapshot(request))); } } /// Handles an incoming RPC message. pub fn handle_rpc(&mut self, from: NodeId, message: RpcMessage) -> Option { match message { RpcMessage::RequestVote(req) => { let response = self.handle_request_vote(from, req); Some(RpcMessage::RequestVoteResponse(response)) } RpcMessage::RequestVoteResponse(resp) => { self.handle_request_vote_response(from, resp); None } RpcMessage::AppendEntries(req) => { let response = self.handle_append_entries(from, req); Some(RpcMessage::AppendEntriesResponse(response)) } RpcMessage::AppendEntriesResponse(resp) => { self.handle_append_entries_response(from, resp); None } RpcMessage::InstallSnapshot(req) => { let response = self.handle_install_snapshot(from, req); Some(RpcMessage::InstallSnapshotResponse(response)) } RpcMessage::InstallSnapshotResponse(resp) => { self.handle_install_snapshot_response(from, resp); None } } } fn handle_request_vote(&mut self, _from: NodeId, req: RequestVote) -> RequestVoteResponse { // Use the VoteHandler from election module VoteHandler::handle_request(&mut self.state, &self.log, &req) } fn handle_request_vote_response(&mut self, from: NodeId, resp: RequestVoteResponse) { // Ignore if not candidate if !self.state.is_candidate() { return; } // Use the VoteHandler if let Some(ref mut election) = self.election { let result = VoteHandler::handle_response(&mut self.state, election, from, &resp); match result { ElectionResult::Won => { self.become_leader(); } ElectionResult::Lost => { // Already handled by VoteHandler (became follower) self.election = None; } ElectionResult::InProgress | ElectionResult::Timeout => {} } } } fn handle_append_entries(&mut self, _from: NodeId, req: AppendEntries) -> AppendEntriesResponse { // Rule: If term > currentTerm, become follower if req.term > self.state.current_term { self.become_follower(req.term, Some(req.leader_id)); } // Reject if term is old if req.term < self.state.current_term { return AppendEntriesResponse::failure(self.state.current_term); } // Valid AppendEntries from leader - reset election timeout self.reset_election_timeout(); self.leader_id = Some(req.leader_id); // If we're candidate, step down if self.state.is_candidate() { self.become_follower(req.term, Some(req.leader_id)); } // Try to append entries let success = self.log .append_entries(req.prev_log_index, req.prev_log_term, req.entries); if success { // Update commit index if req.leader_commit > self.state.commit_index { self.state .update_commit_index(std::cmp::min(req.leader_commit, self.log.last_index())); } AppendEntriesResponse::success(self.state.current_term, self.log.last_index()) } else { // Find conflict info for faster recovery if let Some(conflict_term) = self.log.term_at(req.prev_log_index) { // Find first index of conflicting term let mut conflict_index = req.prev_log_index; while conflict_index > 1 { if let Some(term) = self.log.term_at(conflict_index - 1) { if term != conflict_term { break; } } else { break; } conflict_index -= 1; } AppendEntriesResponse::conflict(self.state.current_term, conflict_term, conflict_index) } else { AppendEntriesResponse::failure(self.state.current_term) } } } fn handle_append_entries_response(&mut self, from: NodeId, resp: AppendEntriesResponse) { // Ignore if not leader if !self.is_leader() { return; } // Step down if term is higher if resp.term > self.state.current_term { self.become_follower(resp.term, None); return; } let leader_state = match &mut self.leader_state { Some(ls) => ls, None => return, }; if resp.success { // Update match_index and next_index leader_state.update_indices(from, resp.match_index); // Try to advance commit index self.try_advance_commit_index(); } else { // Use conflict info if available for faster recovery if let (Some(conflict_term), Some(conflict_index)) = (resp.conflict_term, resp.conflict_index) { // Search for last entry with conflict_term let mut new_next = conflict_index; for idx in (1..=self.log.last_index()).rev() { if let Some(term) = self.log.term_at(idx) { if term == conflict_term { new_next = idx + 1; break; } if term < conflict_term { break; } } } leader_state.next_index.insert(from, new_next); } else { // Simple decrement leader_state.decrement_next_index(from); } } // Update peer state self.cluster.update_peer_state(from, PeerState::Reachable); } fn handle_install_snapshot(&mut self, _from: NodeId, req: InstallSnapshot) -> InstallSnapshotResponse { // Rule: If term > currentTerm, become follower if req.term > self.state.current_term { self.become_follower(req.term, Some(req.leader_id)); } // Reject if term is old if req.term < self.state.current_term { return InstallSnapshotResponse::new(self.state.current_term); } // Reset election timeout self.reset_election_timeout(); self.leader_id = Some(req.leader_id); // Start or continue receiving snapshot if req.offset == 0 { self.snapshots.start_receiving( req.last_included_index, req.last_included_term, req.offset, req.data, ); } else { self.snapshots.add_chunk(req.offset, req.data); } // If done, finalize snapshot if req.done { if let Some(_snapshot) = self.snapshots.finalize_snapshot() { // Discard log up to snapshot self.log .compact(req.last_included_index, req.last_included_term); // Update state if req.last_included_index > self.state.commit_index { self.state.update_commit_index(req.last_included_index); } if req.last_included_index > self.state.last_applied { self.state.update_last_applied(req.last_included_index); } self.events .push(RaftEvent::LogCompacted(req.last_included_index)); } } InstallSnapshotResponse::new(self.state.current_term) } fn handle_install_snapshot_response(&mut self, from: NodeId, resp: InstallSnapshotResponse) { if !self.is_leader() { return; } if resp.term > self.state.current_term { self.become_follower(resp.term, None); return; } // Update next_index for peer (assuming success since there's no success field) if let Some(leader_state) = &mut self.leader_state { let snapshot_index = self.snapshots.last_included_index(); leader_state.update_indices(from, snapshot_index); } self.cluster.update_peer_state(from, PeerState::Reachable); } fn try_advance_commit_index(&mut self) { if !self.is_leader() { return; } let leader_state = match &self.leader_state { Some(ls) => ls, None => return, }; // Calculate new commit index let new_commit = leader_state.calculate_commit_index( self.state.commit_index, self.state.current_term, |idx| self.log.term_at(idx), ); if new_commit > self.state.commit_index { self.state.update_commit_index(new_commit); } } fn apply_committed_entries(&mut self) { while self.state.last_applied < self.state.commit_index { let next_to_apply = self.state.last_applied + 1; if let Some(entry) = self.log.get(next_to_apply) { self.events .push(RaftEvent::ApplyEntry(entry.index, entry.command.clone())); self.state.update_last_applied(next_to_apply); } else { break; } } } /// Proposes a command (only valid on leader). pub fn propose(&mut self, command: Command) -> Result { if !self.is_leader() { return Err(ApplyResult::NotLeader(self.leader_id)); } // Create log entry let index = self.log.last_index() + 1; let entry = LogEntry::new(self.state.current_term, index, command); let result_index = self.log.append(entry); // Send AppendEntries to all peers for peer_id in self.cluster.peer_ids() { self.send_append_entries(peer_id); } Ok(result_index) } /// Takes a snapshot of the state machine. pub fn take_snapshot(&mut self, data: Vec) { let last_index = self.state.last_applied; let last_term = self.log.term_at(last_index).unwrap_or(0); let config = SnapshotConfig { nodes: std::iter::once(self.id) .chain(self.cluster.peer_ids()) .collect(), }; self.snapshots .create_snapshot(last_index, last_term, config, data); // Compact log self.log.compact(last_index, last_term); self.events.push(RaftEvent::LogCompacted(last_index)); } /// Adds a new server to the cluster. pub fn add_server( &mut self, id: NodeId, address: super::cluster::PeerAddress, ) -> Result<(), String> { if !self.is_leader() { return Err("Not leader".to_string()); } let peer = super::cluster::PeerInfo::new(id, address); self.cluster.add_peer(peer); // Initialize leader state for new peer if let Some(ref mut ls) = self.leader_state { ls.next_index.insert(id, self.log.last_index() + 1); ls.match_index.insert(id, 0); } Ok(()) } /// Removes a server from the cluster. pub fn remove_server(&mut self, id: NodeId) -> Result<(), String> { if !self.is_leader() { return Err("Not leader".to_string()); } self.cluster.remove_peer(id); if let Some(ref mut ls) = self.leader_state { ls.next_index.remove(&id); ls.match_index.remove(&id); } Ok(()) } /// Forces an election timeout (for testing). #[cfg(test)] pub fn force_election_timeout(&mut self) { self.last_leader_contact = Instant::now() - self.current_timeout - Duration::from_secs(1); } } #[cfg(test)] mod tests { use super::*; use super::super::cluster::PeerAddress; fn create_test_cluster(node_id: NodeId, peers: &[NodeId]) -> ClusterConfig { let mut cluster = ClusterConfig::new(node_id, PeerAddress::new("127.0.0.1", 9000 + node_id as u16)); for &peer in peers { cluster.add_peer(super::super::cluster::PeerInfo::new( peer, PeerAddress::new("127.0.0.1", 9000 + peer as u16), )); } cluster } #[test] fn test_single_node_election() { let cluster = create_test_cluster(1, &[]); let config = RaftConfig::default(); let mut node = RaftNode::new(1, cluster, config); assert!(matches!(node.role(), NodeRole::Follower)); // Simulate election timeout node.force_election_timeout(); node.tick(); // Single node should become leader immediately assert!(node.is_leader()); } #[test] fn test_three_node_election() { let cluster = create_test_cluster(1, &[2, 3]); let config = RaftConfig::default(); let mut node1 = RaftNode::new(1, cluster, config); // Trigger election node1.force_election_timeout(); node1.tick(); assert!(matches!(node1.role(), NodeRole::Candidate)); assert_eq!(node1.current_term(), 1); // Simulate receiving votes let vote_resp = RequestVoteResponse::grant(1); node1.handle_rpc(2, RpcMessage::RequestVoteResponse(vote_resp.clone())); // Should become leader after receiving vote from node 2 (2 votes = quorum of 3) assert!(node1.is_leader()); } #[test] fn test_append_entries() { let cluster = create_test_cluster(1, &[]); let config = RaftConfig::default(); let mut leader = RaftNode::new(1, cluster, config); // Become leader leader.force_election_timeout(); leader.tick(); assert!(leader.is_leader()); // Propose a command let command = Command::KvSet { key: "test".to_string(), value: vec![1, 2, 3], ttl: None, }; let index = leader.propose(command).unwrap(); assert_eq!(index, 1); assert_eq!(leader.log_length(), 1); } #[test] fn test_follower_receives_append_entries() { let cluster1 = create_test_cluster(1, &[2]); let cluster2 = create_test_cluster(2, &[1]); let config = RaftConfig::default(); let mut leader = RaftNode::new(1, cluster1, config.clone()); let mut follower = RaftNode::new(2, cluster2, config); // Make node 1 leader leader.force_election_timeout(); leader.tick(); // Simulate vote from node 2 leader.handle_rpc( 2, RpcMessage::RequestVoteResponse(RequestVoteResponse::grant(1)), ); assert!(leader.is_leader()); // Propose a command leader .propose(Command::KvSet { key: "key1".to_string(), value: vec![1], ttl: None, }) .unwrap(); // Get AppendEntries from leader events (skip heartbeats, find one with entries) let events = leader.drain_events(); let append_req = events .iter() .find_map(|e| { if let RaftEvent::SendRpc(2, RpcMessage::AppendEntries(req)) = e { // Skip heartbeats (empty entries), find the one with actual entries if !req.entries.is_empty() { Some(req.clone()) } else { None } } else { None } }) .unwrap(); // Send to follower let response = follower .handle_rpc(1, RpcMessage::AppendEntries(append_req)) .unwrap(); if let RpcMessage::AppendEntriesResponse(resp) = response { assert!(resp.success); assert_eq!(resp.match_index, 1); } assert_eq!(follower.log_length(), 1); } #[test] fn test_step_down_on_higher_term() { let cluster = create_test_cluster(1, &[2]); let config = RaftConfig::default(); let mut node = RaftNode::new(1, cluster, config); // Make it leader node.force_election_timeout(); node.tick(); node.handle_rpc( 2, RpcMessage::RequestVoteResponse(RequestVoteResponse::grant(1)), ); assert!(node.is_leader()); assert_eq!(node.current_term(), 1); // Receive AppendEntries with higher term node.handle_rpc( 2, RpcMessage::AppendEntries(AppendEntries::heartbeat(5, 2, 0, 0, 0)), ); assert!(!node.is_leader()); assert_eq!(node.current_term(), 5); assert_eq!(node.leader(), Some(2)); } #[test] fn test_commit_index_advancement() { let cluster = create_test_cluster(1, &[2, 3]); let config = RaftConfig::default(); let mut leader = RaftNode::new(1, cluster, config); // Become leader leader.force_election_timeout(); leader.tick(); leader.handle_rpc( 2, RpcMessage::RequestVoteResponse(RequestVoteResponse::grant(1)), ); leader.handle_rpc( 3, RpcMessage::RequestVoteResponse(RequestVoteResponse::grant(1)), ); // Propose command leader .propose(Command::KvSet { key: "key".to_string(), value: vec![1], ttl: None, }) .unwrap(); // Simulate successful replication to node 2 leader.handle_rpc( 2, RpcMessage::AppendEntriesResponse(AppendEntriesResponse::success(1, 1)), ); // Commit index should advance (quorum = 2, we have leader + node2) assert_eq!(leader.commit_index(), 1); } #[test] fn test_snapshot() { let cluster = create_test_cluster(1, &[]); let config = RaftConfig::default(); let mut node = RaftNode::new(1, cluster, config); // Become leader node.force_election_timeout(); node.tick(); // Propose and commit some entries for i in 0..5 { node.propose(Command::KvSet { key: format!("key{}", i), value: vec![i as u8], ttl: None, }) .unwrap(); } // Manually advance commit and apply node.state.update_commit_index(5); node.tick(); // Take snapshot node.take_snapshot(vec![1, 2, 3, 4, 5]); // Check snapshot was created assert!(node.snapshots.get_snapshot().is_some()); } #[test] fn test_request_vote_log_check() { let cluster = create_test_cluster(1, &[2]); let config = RaftConfig::default(); let mut node = RaftNode::new(1, cluster, config); // Add some entries to node's log node.log.append(LogEntry::new(1, 1, Command::Noop)); node.log.append(LogEntry::new(2, 2, Command::Noop)); // Request vote from candidate with shorter log (lower term) let req = RequestVote::new(3, 2, 1, 1); let resp = node.handle_rpc(2, RpcMessage::RequestVote(req)); if let Some(RpcMessage::RequestVoteResponse(r)) = resp { // Should not grant vote - our log is more up-to-date (term 2 > term 1) assert!(!r.vote_granted); } // Request vote from candidate with equal or better log let req2 = RequestVote::new(4, 2, 2, 2); let resp2 = node.handle_rpc(2, RpcMessage::RequestVote(req2)); if let Some(RpcMessage::RequestVoteResponse(r)) = resp2 { assert!(r.vote_granted); } } }