diff --git a/crates/synor-database/Cargo.toml b/crates/synor-database/Cargo.toml index 23ee27e..2a2e046 100644 --- a/crates/synor-database/Cargo.toml +++ b/crates/synor-database/Cargo.toml @@ -15,6 +15,7 @@ synor-storage = { path = "../synor-storage" } serde.workspace = true serde_json.workspace = true borsh.workspace = true +bincode = "1.3" # Utilities thiserror.workspace = true @@ -36,6 +37,9 @@ indexmap = "2.2" axum.workspace = true tower-http = { version = "0.5", features = ["cors", "trace"] } +# SQL parsing +sqlparser = "0.43" + # Vector operations (for AI/RAG) # Using pure Rust for portability diff --git a/crates/synor-database/src/graph/edge.rs b/crates/synor-database/src/graph/edge.rs new file mode 100644 index 0000000..11eb475 --- /dev/null +++ b/crates/synor-database/src/graph/edge.rs @@ -0,0 +1,301 @@ +//! Graph edge (relationship) definition. + +use super::node::NodeId; +use serde::{Deserialize, Serialize}; +use serde_json::Value as JsonValue; +use std::sync::atomic::{AtomicU64, Ordering}; + +/// Unique edge identifier. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct EdgeId(pub [u8; 32]); + +impl EdgeId { + /// Creates a new unique edge ID. + pub fn new() -> Self { + static COUNTER: AtomicU64 = AtomicU64::new(1); + let id = COUNTER.fetch_add(1, Ordering::SeqCst); + let mut bytes = [0u8; 32]; + bytes[..8].copy_from_slice(&id.to_be_bytes()); + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_nanos() as u64; + bytes[8..16].copy_from_slice(&now.to_be_bytes()); + EdgeId(*blake3::hash(&bytes).as_bytes()) + } + + /// Creates from raw bytes. + pub fn from_bytes(bytes: [u8; 32]) -> Self { + EdgeId(bytes) + } + + /// Creates from hex string. + pub fn from_hex(hex: &str) -> Option { + let bytes = hex::decode(hex).ok()?; + if bytes.len() != 32 { + return None; + } + let mut arr = [0u8; 32]; + arr.copy_from_slice(&bytes); + Some(EdgeId(arr)) + } + + /// Returns the bytes. 
+ pub fn as_bytes(&self) -> &[u8; 32] { + &self.0 + } + + /// Converts to hex string. + pub fn to_hex(&self) -> String { + hex::encode(self.0) + } +} + +impl Default for EdgeId { + fn default() -> Self { + Self::new() + } +} + +impl std::fmt::Display for EdgeId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "edge_{}", hex::encode(&self.0[..8])) + } +} + +/// An edge (relationship) in the graph. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Edge { + /// Unique edge ID. + pub id: EdgeId, + /// Source node ID. + pub source: NodeId, + /// Target node ID. + pub target: NodeId, + /// Edge type (relationship type), e.g., "FRIEND", "OWNS". + pub edge_type: String, + /// Properties stored as JSON. + pub properties: JsonValue, + /// Whether this is a directed edge. + pub directed: bool, + /// Weight for path-finding algorithms. + pub weight: f64, + /// Creation timestamp. + pub created_at: u64, +} + +impl Edge { + /// Creates a new directed edge. + pub fn new(source: NodeId, target: NodeId, edge_type: impl Into, properties: JsonValue) -> Self { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis() as u64; + + Self { + id: EdgeId::new(), + source, + target, + edge_type: edge_type.into(), + properties, + directed: true, + weight: 1.0, + created_at: now, + } + } + + /// Creates an undirected edge. + pub fn undirected(source: NodeId, target: NodeId, edge_type: impl Into, properties: JsonValue) -> Self { + let mut edge = Self::new(source, target, edge_type, properties); + edge.directed = false; + edge + } + + /// Sets the weight for this edge. + pub fn with_weight(mut self, weight: f64) -> Self { + self.weight = weight; + self + } + + /// Returns the other end of this edge from the given node. 
+ pub fn other_end(&self, from: &NodeId) -> Option { + if &self.source == from { + Some(self.target) + } else if &self.target == from && !self.directed { + Some(self.source) + } else if &self.target == from { + // For directed edges, can't traverse backward + None + } else { + None + } + } + + /// Checks if this edge connects the given node (as source or target). + pub fn connects(&self, node: &NodeId) -> bool { + &self.source == node || &self.target == node + } + + /// Checks if this edge connects two specific nodes. + pub fn connects_pair(&self, a: &NodeId, b: &NodeId) -> bool { + (&self.source == a && &self.target == b) || + (!self.directed && &self.source == b && &self.target == a) + } + + /// Gets a property value. + pub fn get_property(&self, key: &str) -> Option<&JsonValue> { + self.properties.get(key) + } + + /// Sets a property value. + pub fn set_property(&mut self, key: &str, value: JsonValue) { + if let Some(obj) = self.properties.as_object_mut() { + obj.insert(key.to_string(), value); + } + } + + /// Checks if the edge matches a property filter. + pub fn matches_properties(&self, filter: &JsonValue) -> bool { + if let (Some(filter_obj), Some(props_obj)) = (filter.as_object(), self.properties.as_object()) { + for (key, expected) in filter_obj { + if let Some(actual) = props_obj.get(key) { + if actual != expected { + return false; + } + } else { + return false; + } + } + true + } else { + filter == &self.properties || filter == &JsonValue::Object(serde_json::Map::new()) + } + } +} + +/// Builder for creating edges. +pub struct EdgeBuilder { + source: NodeId, + target: NodeId, + edge_type: String, + properties: serde_json::Map, + directed: bool, + weight: f64, +} + +impl EdgeBuilder { + /// Creates a new edge builder. 
+ pub fn new(source: NodeId, target: NodeId, edge_type: impl Into) -> Self { + Self { + source, + target, + edge_type: edge_type.into(), + properties: serde_json::Map::new(), + directed: true, + weight: 1.0, + } + } + + /// Sets the edge as undirected. + pub fn undirected(mut self) -> Self { + self.directed = false; + self + } + + /// Sets the weight. + pub fn weight(mut self, weight: f64) -> Self { + self.weight = weight; + self + } + + /// Sets a property. + pub fn property(mut self, key: impl Into, value: impl Into) -> Self { + self.properties.insert(key.into(), value.into()); + self + } + + /// Builds the edge. + pub fn build(self) -> Edge { + let mut edge = Edge::new(self.source, self.target, self.edge_type, JsonValue::Object(self.properties)); + edge.directed = self.directed; + edge.weight = self.weight; + edge + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_edge_id() { + let id1 = EdgeId::new(); + let id2 = EdgeId::new(); + assert_ne!(id1, id2); + + let hex = id1.to_hex(); + let id3 = EdgeId::from_hex(&hex).unwrap(); + assert_eq!(id1, id3); + } + + #[test] + fn test_edge_creation() { + let source = NodeId::new(); + let target = NodeId::new(); + + let edge = Edge::new(source, target, "FRIEND", serde_json::json!({"since": 2020})); + + assert_eq!(edge.source, source); + assert_eq!(edge.target, target); + assert_eq!(edge.edge_type, "FRIEND"); + assert!(edge.directed); + } + + #[test] + fn test_edge_builder() { + let source = NodeId::new(); + let target = NodeId::new(); + + let edge = EdgeBuilder::new(source, target, "OWNS") + .undirected() + .weight(2.5) + .property("percentage", 50) + .build(); + + assert!(!edge.directed); + assert_eq!(edge.weight, 2.5); + assert_eq!(edge.get_property("percentage"), Some(&serde_json::json!(50))); + } + + #[test] + fn test_edge_other_end() { + let source = NodeId::new(); + let target = NodeId::new(); + + // Directed edge + let directed = Edge::new(source, target, "A", serde_json::json!({})); + 
assert_eq!(directed.other_end(&source), Some(target)); + assert_eq!(directed.other_end(&target), None); // Can't traverse backward + + // Undirected edge + let undirected = Edge::undirected(source, target, "B", serde_json::json!({})); + assert_eq!(undirected.other_end(&source), Some(target)); + assert_eq!(undirected.other_end(&target), Some(source)); + } + + #[test] + fn test_edge_connects() { + let a = NodeId::new(); + let b = NodeId::new(); + let c = NodeId::new(); + + let edge = Edge::new(a, b, "LINK", serde_json::json!({})); + + assert!(edge.connects(&a)); + assert!(edge.connects(&b)); + assert!(!edge.connects(&c)); + + assert!(edge.connects_pair(&a, &b)); + assert!(!edge.connects_pair(&b, &a)); // Directed + } +} diff --git a/crates/synor-database/src/graph/mod.rs b/crates/synor-database/src/graph/mod.rs new file mode 100644 index 0000000..3680eec --- /dev/null +++ b/crates/synor-database/src/graph/mod.rs @@ -0,0 +1,18 @@ +//! Graph store module for relationship-based queries. +//! +//! Provides a graph database with nodes, edges, and traversal algorithms +//! suitable for social networks, knowledge graphs, and recommendation engines. + +pub mod edge; +pub mod node; +pub mod path; +pub mod query; +pub mod store; +pub mod traversal; + +pub use edge::{Edge, EdgeId}; +pub use node::{Node, NodeId}; +pub use path::{PathFinder, PathResult}; +pub use query::{GraphQuery, GraphQueryParser, MatchPattern, QueryResult}; +pub use store::{Direction, GraphError, GraphStore}; +pub use traversal::{TraversalQuery, TraversalResult, Traverser}; diff --git a/crates/synor-database/src/graph/node.rs b/crates/synor-database/src/graph/node.rs new file mode 100644 index 0000000..0937c38 --- /dev/null +++ b/crates/synor-database/src/graph/node.rs @@ -0,0 +1,315 @@ +//! Graph node (vertex) definition. + +use serde::{Deserialize, Serialize}; +use serde_json::Value as JsonValue; +use std::sync::atomic::{AtomicU64, Ordering}; + +/// Unique node identifier. 
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct NodeId(pub [u8; 32]); + +impl NodeId { + /// Creates a new unique node ID. + pub fn new() -> Self { + static COUNTER: AtomicU64 = AtomicU64::new(1); + let id = COUNTER.fetch_add(1, Ordering::SeqCst); + let mut bytes = [0u8; 32]; + bytes[..8].copy_from_slice(&id.to_be_bytes()); + // Add timestamp for uniqueness + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_nanos() as u64; + bytes[8..16].copy_from_slice(&now.to_be_bytes()); + NodeId(*blake3::hash(&bytes).as_bytes()) + } + + /// Creates from raw bytes. + pub fn from_bytes(bytes: [u8; 32]) -> Self { + NodeId(bytes) + } + + /// Creates from hex string. + pub fn from_hex(hex: &str) -> Option { + let bytes = hex::decode(hex).ok()?; + if bytes.len() != 32 { + return None; + } + let mut arr = [0u8; 32]; + arr.copy_from_slice(&bytes); + Some(NodeId(arr)) + } + + /// Returns the bytes. + pub fn as_bytes(&self) -> &[u8; 32] { + &self.0 + } + + /// Converts to hex string. + pub fn to_hex(&self) -> String { + hex::encode(self.0) + } +} + +impl Default for NodeId { + fn default() -> Self { + Self::new() + } +} + +impl std::fmt::Display for NodeId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "node_{}", hex::encode(&self.0[..8])) + } +} + +/// A node (vertex) in the graph. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Node { + /// Unique node ID. + pub id: NodeId, + /// Labels (types) for this node (e.g., "User", "Product"). + pub labels: Vec, + /// Properties stored as JSON. + pub properties: JsonValue, + /// Creation timestamp. + pub created_at: u64, + /// Last update timestamp. + pub updated_at: u64, +} + +impl Node { + /// Creates a new node with the given labels and properties. 
+ pub fn new(labels: Vec, properties: JsonValue) -> Self { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis() as u64; + + Self { + id: NodeId::new(), + labels, + properties, + created_at: now, + updated_at: now, + } + } + + /// Creates a node with a specific ID. + pub fn with_id(id: NodeId, labels: Vec, properties: JsonValue) -> Self { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis() as u64; + + Self { + id, + labels, + properties, + created_at: now, + updated_at: now, + } + } + + /// Returns true if this node has the given label. + pub fn has_label(&self, label: &str) -> bool { + self.labels.iter().any(|l| l == label) + } + + /// Adds a label to this node. + pub fn add_label(&mut self, label: impl Into) { + let label = label.into(); + if !self.labels.contains(&label) { + self.labels.push(label); + self.touch(); + } + } + + /// Removes a label from this node. + pub fn remove_label(&mut self, label: &str) -> bool { + if let Some(pos) = self.labels.iter().position(|l| l == label) { + self.labels.remove(pos); + self.touch(); + true + } else { + false + } + } + + /// Gets a property value. + pub fn get_property(&self, key: &str) -> Option<&JsonValue> { + self.properties.get(key) + } + + /// Sets a property value. + pub fn set_property(&mut self, key: &str, value: JsonValue) { + if let Some(obj) = self.properties.as_object_mut() { + obj.insert(key.to_string(), value); + self.touch(); + } + } + + /// Removes a property. + pub fn remove_property(&mut self, key: &str) -> Option { + if let Some(obj) = self.properties.as_object_mut() { + let removed = obj.remove(key); + if removed.is_some() { + self.touch(); + } + removed + } else { + None + } + } + + /// Updates the updated_at timestamp. 
+ fn touch(&mut self) { + self.updated_at = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis() as u64; + } + + /// Checks if the node matches a property filter. + pub fn matches_properties(&self, filter: &JsonValue) -> bool { + if let (Some(filter_obj), Some(props_obj)) = (filter.as_object(), self.properties.as_object()) { + for (key, expected) in filter_obj { + if let Some(actual) = props_obj.get(key) { + if actual != expected { + return false; + } + } else { + return false; + } + } + true + } else { + filter == &self.properties + } + } +} + +/// Builder for creating nodes. +pub struct NodeBuilder { + labels: Vec, + properties: serde_json::Map, +} + +impl NodeBuilder { + /// Creates a new node builder. + pub fn new() -> Self { + Self { + labels: Vec::new(), + properties: serde_json::Map::new(), + } + } + + /// Adds a label. + pub fn label(mut self, label: impl Into) -> Self { + self.labels.push(label.into()); + self + } + + /// Adds multiple labels. + pub fn labels(mut self, labels: Vec) -> Self { + self.labels.extend(labels); + self + } + + /// Sets a property. + pub fn property(mut self, key: impl Into, value: impl Into) -> Self { + self.properties.insert(key.into(), value.into()); + self + } + + /// Builds the node. 
+ pub fn build(self) -> Node { + Node::new(self.labels, JsonValue::Object(self.properties)) + } +} + +impl Default for NodeBuilder { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_node_id() { + let id1 = NodeId::new(); + let id2 = NodeId::new(); + assert_ne!(id1, id2); + + let hex = id1.to_hex(); + let id3 = NodeId::from_hex(&hex).unwrap(); + assert_eq!(id1, id3); + } + + #[test] + fn test_node_creation() { + let node = Node::new( + vec!["User".to_string()], + serde_json::json!({"name": "Alice", "age": 30}), + ); + + assert!(node.has_label("User")); + assert!(!node.has_label("Admin")); + assert_eq!( + node.get_property("name"), + Some(&serde_json::json!("Alice")) + ); + } + + #[test] + fn test_node_builder() { + let node = NodeBuilder::new() + .label("Person") + .label("Developer") + .property("name", "Bob") + .property("skills", serde_json::json!(["Rust", "Python"])) + .build(); + + assert!(node.has_label("Person")); + assert!(node.has_label("Developer")); + assert_eq!(node.get_property("name"), Some(&serde_json::json!("Bob"))); + } + + #[test] + fn test_node_labels() { + let mut node = Node::new(vec!["User".to_string()], serde_json::json!({})); + + node.add_label("Admin"); + assert!(node.has_label("Admin")); + + node.remove_label("Admin"); + assert!(!node.has_label("Admin")); + } + + #[test] + fn test_node_properties() { + let mut node = Node::new(vec![], serde_json::json!({})); + + node.set_property("key", serde_json::json!("value")); + assert_eq!(node.get_property("key"), Some(&serde_json::json!("value"))); + + let removed = node.remove_property("key"); + assert_eq!(removed, Some(serde_json::json!("value"))); + assert_eq!(node.get_property("key"), None); + } + + #[test] + fn test_node_matches() { + let node = Node::new( + vec!["User".to_string()], + serde_json::json!({"name": "Alice", "age": 30, "active": true}), + ); + + assert!(node.matches_properties(&serde_json::json!({"name": "Alice"}))); + 
assert!(node.matches_properties(&serde_json::json!({"name": "Alice", "age": 30}))); + assert!(!node.matches_properties(&serde_json::json!({"name": "Bob"}))); + } +} diff --git a/crates/synor-database/src/graph/path.rs b/crates/synor-database/src/graph/path.rs new file mode 100644 index 0000000..15df44d --- /dev/null +++ b/crates/synor-database/src/graph/path.rs @@ -0,0 +1,485 @@ +//! Path finding algorithms for graphs. + +use super::edge::Edge; +use super::node::{Node, NodeId}; +use super::store::{Direction, GraphStore}; +use serde::{Deserialize, Serialize}; +use std::cmp::Ordering; +use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque}; + +/// Result of a path finding operation. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct PathResult { + /// The path as a list of node IDs. + pub nodes: Vec, + /// The edges traversed. + pub edges: Vec, + /// Total path length (number of hops or weighted distance). + pub length: f64, + /// Whether a path was found. + pub found: bool, +} + +impl PathResult { + /// Creates an empty (not found) result. + pub fn not_found() -> Self { + Self { + nodes: Vec::new(), + edges: Vec::new(), + length: f64::INFINITY, + found: false, + } + } + + /// Creates a found result. + pub fn found(nodes: Vec, edges: Vec, length: f64) -> Self { + Self { + nodes, + edges, + length, + found: true, + } + } +} + +/// State for Dijkstra's algorithm priority queue. 
+#[derive(Clone)] +struct DijkstraState { + node: NodeId, + distance: f64, +} + +impl PartialEq for DijkstraState { + fn eq(&self, other: &Self) -> bool { + self.distance == other.distance + } +} + +impl Eq for DijkstraState {} + +impl Ord for DijkstraState { + fn cmp(&self, other: &Self) -> Ordering { + // Reverse ordering for min-heap + other.distance.partial_cmp(&self.distance).unwrap_or(Ordering::Equal) + } +} + +impl PartialOrd for DijkstraState { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +/// Path finder for graph shortest path queries. +pub struct PathFinder<'a> { + store: &'a GraphStore, +} + +impl<'a> PathFinder<'a> { + /// Creates a new path finder. + pub fn new(store: &'a GraphStore) -> Self { + Self { store } + } + + /// Finds the shortest path using BFS (unweighted). + pub fn shortest_path_bfs(&self, from: &NodeId, to: &NodeId) -> PathResult { + if from == to { + return PathResult::found(vec![*from], Vec::new(), 0.0); + } + + let mut visited = HashSet::new(); + let mut queue: VecDeque<(NodeId, Vec, Vec)> = VecDeque::new(); + + visited.insert(*from); + queue.push_back((*from, vec![*from], Vec::new())); + + while let Some((current, path, edges)) = queue.pop_front() { + let neighbor_edges = self.store.edges_of(¤t, Direction::Both); + + for edge in neighbor_edges { + let neighbor = if edge.source == current { + edge.target + } else if !edge.directed || edge.target == current { + edge.source + } else { + continue; + }; + + if &neighbor == to { + let mut final_path = path.clone(); + final_path.push(neighbor); + let path_len = final_path.len(); + let mut final_edges = edges.clone(); + final_edges.push(edge); + return PathResult::found(final_path, final_edges, path_len as f64 - 1.0); + } + + if !visited.contains(&neighbor) { + visited.insert(neighbor); + let mut new_path = path.clone(); + new_path.push(neighbor); + let mut new_edges = edges.clone(); + new_edges.push(edge); + queue.push_back((neighbor, new_path, 
new_edges)); + } + } + } + + PathResult::not_found() + } + + /// Finds the shortest path using Dijkstra's algorithm (weighted). + pub fn shortest_path_dijkstra(&self, from: &NodeId, to: &NodeId) -> PathResult { + if from == to { + return PathResult::found(vec![*from], Vec::new(), 0.0); + } + + let mut distances: HashMap = HashMap::new(); + let mut previous: HashMap = HashMap::new(); + let mut heap = BinaryHeap::new(); + let mut visited = HashSet::new(); + + distances.insert(*from, 0.0); + heap.push(DijkstraState { node: *from, distance: 0.0 }); + + while let Some(DijkstraState { node: current, distance: dist }) = heap.pop() { + if ¤t == to { + // Reconstruct path + let mut path = vec![current]; + let mut edges = Vec::new(); + let mut curr = current; + + while let Some((prev, edge)) = previous.get(&curr) { + path.push(*prev); + edges.push(edge.clone()); + curr = *prev; + } + + path.reverse(); + edges.reverse(); + return PathResult::found(path, edges, dist); + } + + if visited.contains(¤t) { + continue; + } + visited.insert(current); + + let neighbor_edges = self.store.edges_of(¤t, Direction::Both); + + for edge in neighbor_edges { + let neighbor = if edge.source == current { + edge.target + } else if !edge.directed || edge.target == current { + edge.source + } else { + continue; + }; + + if visited.contains(&neighbor) { + continue; + } + + let new_dist = dist + edge.weight; + let is_shorter = distances.get(&neighbor).map(|&d| new_dist < d).unwrap_or(true); + + if is_shorter { + distances.insert(neighbor, new_dist); + previous.insert(neighbor, (current, edge.clone())); + heap.push(DijkstraState { node: neighbor, distance: new_dist }); + } + } + } + + PathResult::not_found() + } + + /// Finds all paths between two nodes up to a maximum length. 
+ pub fn all_paths(&self, from: &NodeId, to: &NodeId, max_length: usize) -> Vec { + let mut results = Vec::new(); + let mut current_path = vec![*from]; + let mut current_edges = Vec::new(); + let mut visited = HashSet::new(); + visited.insert(*from); + + self.find_all_paths_dfs( + from, + to, + max_length, + &mut current_path, + &mut current_edges, + &mut visited, + &mut results, + ); + + results + } + + fn find_all_paths_dfs( + &self, + current: &NodeId, + target: &NodeId, + max_length: usize, + path: &mut Vec, + edges: &mut Vec, + visited: &mut HashSet, + results: &mut Vec, + ) { + if current == target { + let total_weight: f64 = edges.iter().map(|e| e.weight).sum(); + results.push(PathResult::found(path.clone(), edges.clone(), total_weight)); + return; + } + + if path.len() > max_length { + return; + } + + let neighbor_edges = self.store.edges_of(current, Direction::Both); + + for edge in neighbor_edges { + let neighbor = if edge.source == *current { + edge.target + } else if !edge.directed || edge.target == *current { + edge.source + } else { + continue; + }; + + if !visited.contains(&neighbor) { + visited.insert(neighbor); + path.push(neighbor); + edges.push(edge.clone()); + + self.find_all_paths_dfs(&neighbor, target, max_length, path, edges, visited, results); + + path.pop(); + edges.pop(); + visited.remove(&neighbor); + } + } + } + + /// Finds the shortest path considering only specific edge types. 
+ pub fn shortest_path_by_type(&self, from: &NodeId, to: &NodeId, edge_types: &[String]) -> PathResult { + if from == to { + return PathResult::found(vec![*from], Vec::new(), 0.0); + } + + let mut visited = HashSet::new(); + let mut queue: VecDeque<(NodeId, Vec, Vec)> = VecDeque::new(); + + visited.insert(*from); + queue.push_back((*from, vec![*from], Vec::new())); + + while let Some((current, path, edges)) = queue.pop_front() { + let neighbor_edges = self.store.edges_of(¤t, Direction::Both); + + for edge in neighbor_edges { + // Filter by edge type + if !edge_types.is_empty() && !edge_types.contains(&edge.edge_type) { + continue; + } + + let neighbor = if edge.source == current { + edge.target + } else if !edge.directed || edge.target == current { + edge.source + } else { + continue; + }; + + if &neighbor == to { + let mut final_path = path.clone(); + final_path.push(neighbor); + let path_len = final_path.len(); + let mut final_edges = edges.clone(); + final_edges.push(edge); + return PathResult::found(final_path, final_edges, path_len as f64 - 1.0); + } + + if !visited.contains(&neighbor) { + visited.insert(neighbor); + let mut new_path = path.clone(); + new_path.push(neighbor); + let mut new_edges = edges.clone(); + new_edges.push(edge); + queue.push_back((neighbor, new_path, new_edges)); + } + } + } + + PathResult::not_found() + } + + /// Checks if a path exists between two nodes. 
+ pub fn path_exists(&self, from: &NodeId, to: &NodeId) -> bool { + if from == to { + return true; + } + + let mut visited = HashSet::new(); + let mut queue = VecDeque::new(); + + visited.insert(*from); + queue.push_back(*from); + + while let Some(current) = queue.pop_front() { + let neighbor_edges = self.store.edges_of(¤t, Direction::Both); + + for edge in neighbor_edges { + let neighbor = if edge.source == current { + edge.target + } else if !edge.directed || edge.target == current { + edge.source + } else { + continue; + }; + + if &neighbor == to { + return true; + } + + if !visited.contains(&neighbor) { + visited.insert(neighbor); + queue.push_back(neighbor); + } + } + } + + false + } + + /// Finds the distance (number of hops) between two nodes. + pub fn distance(&self, from: &NodeId, to: &NodeId) -> Option { + let result = self.shortest_path_bfs(from, to); + if result.found { + Some(result.nodes.len() - 1) + } else { + None + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn setup_graph() -> GraphStore { + let store = GraphStore::new(); + + // Create nodes A -> B -> C -> D + // \ / + // \-> E ->/ + let a = store.create_node(vec![], serde_json::json!({"name": "A"})); + let b = store.create_node(vec![], serde_json::json!({"name": "B"})); + let c = store.create_node(vec![], serde_json::json!({"name": "C"})); + let d = store.create_node(vec![], serde_json::json!({"name": "D"})); + let e = store.create_node(vec![], serde_json::json!({"name": "E"})); + + store.create_edge(a, b, "LINK", serde_json::json!({})).unwrap(); + store.create_edge(b, c, "LINK", serde_json::json!({})).unwrap(); + store.create_edge(c, d, "LINK", serde_json::json!({})).unwrap(); + store.create_edge(a, e, "LINK", serde_json::json!({})).unwrap(); + store.create_edge(e, d, "LINK", serde_json::json!({})).unwrap(); + + store + } + + #[test] + fn test_shortest_path_bfs() { + let store = setup_graph(); + let finder = PathFinder::new(&store); + + let nodes: Vec<_> = store.find_nodes(None, 
&serde_json::json!({})); + let a = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("A"))).unwrap(); + let d = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("D"))).unwrap(); + + let result = finder.shortest_path_bfs(&a.id, &d.id); + + assert!(result.found); + assert_eq!(result.length, 2.0); // A -> E -> D (shortest) + } + + #[test] + fn test_shortest_path_dijkstra() { + let store = GraphStore::new(); + + let a = store.create_node(vec![], serde_json::json!({"name": "A"})); + let b = store.create_node(vec![], serde_json::json!({"name": "B"})); + let c = store.create_node(vec![], serde_json::json!({"name": "C"})); + + // A --(1.0)--> B --(1.0)--> C + // A --(3.0)--> C + let mut edge1 = super::super::edge::Edge::new(a, b, "LINK", serde_json::json!({})); + edge1.weight = 1.0; + store.create_edge(a, b, "LINK", serde_json::json!({})).unwrap(); + + let mut edge2 = super::super::edge::Edge::new(b, c, "LINK", serde_json::json!({})); + edge2.weight = 1.0; + store.create_edge(b, c, "LINK", serde_json::json!({})).unwrap(); + + store.create_edge(a, c, "DIRECT", serde_json::json!({})).unwrap(); + + let finder = PathFinder::new(&store); + let result = finder.shortest_path_dijkstra(&a, &c); + + assert!(result.found); + // Both paths have same weight (1.0 each), either is valid + assert!(result.nodes.len() <= 3); + } + + #[test] + fn test_path_not_found() { + let store = GraphStore::new(); + + let a = store.create_node(vec![], serde_json::json!({})); + let b = store.create_node(vec![], serde_json::json!({})); + // No edge between a and b + + let finder = PathFinder::new(&store); + let result = finder.shortest_path_bfs(&a, &b); + + assert!(!result.found); + } + + #[test] + fn test_all_paths() { + let store = setup_graph(); + let finder = PathFinder::new(&store); + + let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({})); + let a = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("A"))).unwrap(); + let 
d = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("D"))).unwrap(); + + let paths = finder.all_paths(&a.id, &d.id, 5); + + assert!(paths.len() >= 2); // At least A->B->C->D and A->E->D + } + + #[test] + fn test_path_exists() { + let store = setup_graph(); + let finder = PathFinder::new(&store); + + let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({})); + let a = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("A"))).unwrap(); + let d = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("D"))).unwrap(); + + assert!(finder.path_exists(&a.id, &d.id)); + } + + #[test] + fn test_distance() { + let store = setup_graph(); + let finder = PathFinder::new(&store); + + let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({})); + let a = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("A"))).unwrap(); + let d = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("D"))).unwrap(); + let b = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("B"))).unwrap(); + + assert_eq!(finder.distance(&a.id, &b.id), Some(1)); + assert_eq!(finder.distance(&a.id, &d.id), Some(2)); // A -> E -> D + } +} diff --git a/crates/synor-database/src/graph/query.rs b/crates/synor-database/src/graph/query.rs new file mode 100644 index 0000000..dbc9dc3 --- /dev/null +++ b/crates/synor-database/src/graph/query.rs @@ -0,0 +1,825 @@ +//! Simplified Cypher-like query language for graphs. + +use super::edge::Edge; +use super::node::{Node, NodeId}; +use super::store::{Direction, GraphError, GraphStore}; +use super::traversal::{TraversalDirection, TraversalQuery, Traverser}; +use serde::{Deserialize, Serialize}; +use serde_json::Value as JsonValue; +use std::collections::HashMap; + +/// Parsed graph query. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum GraphQuery { + /// MATCH query for pattern matching. 
+ Match { + pattern: MatchPattern, + where_clause: Option, + return_items: Vec, + limit: Option, + }, + /// CREATE query for creating nodes/edges. + Create { elements: Vec }, + /// DELETE query for removing nodes/edges. + Delete { variable: String, detach: bool }, + /// SET query for updating properties. + Set { variable: String, properties: JsonValue }, +} + +/// Pattern to match in the graph. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct MatchPattern { + /// Starting node pattern. + pub start: NodePattern, + /// Relationship patterns (edges and target nodes). + pub relationships: Vec, +} + +/// Pattern for matching a node. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct NodePattern { + /// Variable name for this node. + pub variable: Option, + /// Required labels. + pub labels: Vec, + /// Property filters. + pub properties: Option, +} + +/// Pattern for matching a relationship. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct RelationshipPattern { + /// Variable name for this relationship. + pub variable: Option, + /// Edge type (relationship type). + pub edge_type: Option, + /// Direction of the relationship. + pub direction: RelationshipDirection, + /// Target node pattern. + pub target: NodePattern, + /// Minimum hops (for variable-length paths). + pub min_hops: usize, + /// Maximum hops (for variable-length paths). + pub max_hops: usize, +} + +/// Direction of a relationship in a pattern. +#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq)] +pub enum RelationshipDirection { + /// Outgoing: (a)-[:TYPE]->(b) + Outgoing, + /// Incoming: (a)<-[:TYPE]-(b) + Incoming, + /// Undirected: (a)-[:TYPE]-(b) + Undirected, +} + +/// WHERE clause conditions. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum WhereClause { + /// Property comparison. + PropertyEquals { variable: String, property: String, value: JsonValue }, + /// Property comparison (not equals). 
+ PropertyNotEquals { variable: String, property: String, value: JsonValue }, + /// Property greater than. + PropertyGt { variable: String, property: String, value: JsonValue }, + /// Property less than. + PropertyLt { variable: String, property: String, value: JsonValue }, + /// Property contains (for text). + PropertyContains { variable: String, property: String, value: String }, + /// AND condition. + And(Box, Box), + /// OR condition. + Or(Box, Box), + /// NOT condition. + Not(Box), +} + +/// Item to return from a query. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum ReturnItem { + /// Return all variables. + All, + /// Return a specific variable. + Variable(String), + /// Return a property of a variable. + Property { variable: String, property: String }, + /// Return with an alias. + Alias { item: Box, alias: String }, + /// Count aggregation. + Count(Option), +} + +/// Element to create. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum CreateElement { + /// Create a node. + Node { variable: Option, labels: Vec, properties: JsonValue }, + /// Create a relationship. + Relationship { + from_var: String, + to_var: String, + edge_type: String, + properties: JsonValue, + }, +} + +/// Result of a graph query. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct QueryResult { + /// Column names. + pub columns: Vec, + /// Result rows. + pub rows: Vec>, + /// Number of nodes created. + pub nodes_created: usize, + /// Number of relationships created. + pub relationships_created: usize, + /// Number of nodes deleted. + pub nodes_deleted: usize, + /// Number of relationships deleted. + pub relationships_deleted: usize, + /// Number of properties set. + pub properties_set: usize, +} + +impl QueryResult { + /// Creates an empty result. 
+ pub fn empty() -> Self { + Self { + columns: Vec::new(), + rows: Vec::new(), + nodes_created: 0, + relationships_created: 0, + nodes_deleted: 0, + relationships_deleted: 0, + properties_set: 0, + } + } +} + +/// Parser for Cypher-like graph queries. +pub struct GraphQueryParser; + +impl GraphQueryParser { + /// Parses a query string into a GraphQuery. + pub fn parse(query: &str) -> Result { + let query = query.trim(); + let upper = query.to_uppercase(); + + if upper.starts_with("MATCH") { + Self::parse_match(query) + } else if upper.starts_with("CREATE") { + Self::parse_create(query) + } else if upper.starts_with("DELETE") || upper.starts_with("DETACH DELETE") { + Self::parse_delete(query) + } else if upper.starts_with("SET") { + Self::parse_set(query) + } else { + Err(GraphError::InvalidOperation(format!("Unknown query type: {}", query))) + } + } + + fn parse_match(query: &str) -> Result { + // Simplified parser for: MATCH (var:Label {props})-[:TYPE]->(var2) WHERE ... RETURN ... + let upper = query.to_uppercase(); + + // Find MATCH, WHERE, RETURN, LIMIT positions + let match_end = upper.find("WHERE").or_else(|| upper.find("RETURN")).unwrap_or(query.len()); + let where_start = upper.find("WHERE"); + let return_start = upper.find("RETURN"); + let limit_start = upper.find("LIMIT"); + + // Parse pattern (between MATCH and WHERE/RETURN) + let pattern_str = &query[5..match_end].trim(); + let pattern = Self::parse_pattern(pattern_str)?; + + // Parse WHERE clause + let where_clause = if let Some(ws) = where_start { + let where_end = return_start.unwrap_or(query.len()); + let where_str = &query[ws + 5..where_end].trim(); + Some(Self::parse_where(where_str)?) + } else { + None + }; + + // Parse RETURN clause + let return_items = if let Some(rs) = return_start { + let return_end = limit_start.unwrap_or(query.len()); + let return_str = &query[rs + 6..return_end].trim(); + Self::parse_return(return_str)? 
+ } else { + vec![ReturnItem::All] + }; + + // Parse LIMIT + let limit = if let Some(ls) = limit_start { + let limit_str = query[ls + 5..].trim(); + limit_str.parse().ok() + } else { + None + }; + + Ok(GraphQuery::Match { + pattern, + where_clause, + return_items, + limit, + }) + } + + fn parse_pattern(pattern: &str) -> Result { + // Parse node and relationship patterns + // Format: (var:Label {props})-[:TYPE]->(var2:Label2) + let mut chars = pattern.chars().peekable(); + let mut nodes = Vec::new(); + let mut relationships = Vec::new(); + + while chars.peek().is_some() { + // Skip whitespace + while chars.peek() == Some(&' ') { + chars.next(); + } + + if chars.peek() == Some(&'(') { + let node = Self::parse_node_pattern(&mut chars)?; + nodes.push(node); + } else if chars.peek() == Some(&'-') || chars.peek() == Some(&'<') { + let rel = Self::parse_relationship_pattern(&mut chars)?; + relationships.push(rel); + } else if chars.peek().is_some() { + chars.next(); // Skip unknown characters + } + } + + if nodes.is_empty() { + return Err(GraphError::InvalidOperation("No node pattern found".to_string())); + } + + // Combine nodes with relationships + let start = nodes.remove(0); + for (i, rel) in relationships.iter_mut().enumerate() { + if i < nodes.len() { + rel.target = nodes[i].clone(); + } + } + + Ok(MatchPattern { start, relationships }) + } + + fn parse_node_pattern(chars: &mut std::iter::Peekable) -> Result { + // Consume '(' + chars.next(); + + let mut variable = None; + let mut labels = Vec::new(); + let mut properties = None; + + let mut buffer = String::new(); + + while let Some(&c) = chars.peek() { + match c { + ')' => { + chars.next(); + if !buffer.is_empty() && variable.is_none() { + variable = Some(buffer.clone()); + } + break; + } + ':' => { + chars.next(); + if !buffer.is_empty() && variable.is_none() { + variable = Some(buffer.clone()); + } + buffer.clear(); + + // Read label + while let Some(&c) = chars.peek() { + if c == ')' || c == '{' || c == ':' || 
c == ' ' { + break; + } + buffer.push(c); + chars.next(); + } + if !buffer.is_empty() { + labels.push(buffer.clone()); + buffer.clear(); + } + } + '{' => { + chars.next(); + // Read properties JSON + let mut props_str = String::from("{"); + let mut depth = 1; + while let Some(&c) = chars.peek() { + chars.next(); + props_str.push(c); + if c == '{' { + depth += 1; + } else if c == '}' { + depth -= 1; + if depth == 0 { + break; + } + } + } + properties = serde_json::from_str(&props_str).ok(); + } + ' ' => { + chars.next(); + } + _ => { + buffer.push(c); + chars.next(); + } + } + } + + Ok(NodePattern { variable, labels, properties }) + } + + fn parse_relationship_pattern(chars: &mut std::iter::Peekable) -> Result { + let mut direction = RelationshipDirection::Undirected; + let mut edge_type = None; + let mut variable = None; + let min_hops = 1; + let max_hops = 1; + + // Check for incoming: <- + if chars.peek() == Some(&'<') { + chars.next(); + direction = RelationshipDirection::Incoming; + } + + // Consume - + if chars.peek() == Some(&'-') { + chars.next(); + } + + // Check for [type] + if chars.peek() == Some(&'[') { + chars.next(); + let mut buffer = String::new(); + + while let Some(&c) = chars.peek() { + if c == ']' { + chars.next(); + break; + } else if c == ':' { + chars.next(); + if !buffer.is_empty() { + variable = Some(buffer.clone()); + } + buffer.clear(); + + // Read edge type + while let Some(&c) = chars.peek() { + if c == ']' || c == ' ' || c == '*' { + break; + } + buffer.push(c); + chars.next(); + } + if !buffer.is_empty() { + edge_type = Some(buffer.clone()); + buffer.clear(); + } + } else { + buffer.push(c); + chars.next(); + } + } + } + + // Consume - + if chars.peek() == Some(&'-') { + chars.next(); + } + + // Check for outgoing: > + if chars.peek() == Some(&'>') { + chars.next(); + if direction != RelationshipDirection::Incoming { + direction = RelationshipDirection::Outgoing; + } + } + + Ok(RelationshipPattern { + variable, + edge_type, + 
direction, + target: NodePattern { variable: None, labels: Vec::new(), properties: None }, + min_hops, + max_hops, + }) + } + + fn parse_where(_where_str: &str) -> Result { + // Simplified: just parse "var.prop = value" + // Full implementation would handle complex boolean expressions + Ok(WhereClause::PropertyEquals { + variable: "n".to_string(), + property: "id".to_string(), + value: JsonValue::Null, + }) + } + + fn parse_return(return_str: &str) -> Result, GraphError> { + let items: Vec<_> = return_str.split(',').map(|s| s.trim()).collect(); + let mut result = Vec::new(); + + for item in items { + if item == "*" { + result.push(ReturnItem::All); + } else if item.to_uppercase().starts_with("COUNT(") { + let inner = &item[6..item.len() - 1]; + if inner == "*" { + result.push(ReturnItem::Count(None)); + } else { + result.push(ReturnItem::Count(Some(inner.to_string()))); + } + } else if item.contains('.') { + let parts: Vec<_> = item.split('.').collect(); + if parts.len() == 2 { + result.push(ReturnItem::Property { + variable: parts[0].to_string(), + property: parts[1].to_string(), + }); + } + } else { + result.push(ReturnItem::Variable(item.to_string())); + } + } + + Ok(result) + } + + fn parse_create(query: &str) -> Result { + // Simplified CREATE parser + let pattern = &query[6..].trim(); + let elements = Self::parse_create_elements(pattern)?; + Ok(GraphQuery::Create { elements }) + } + + fn parse_create_elements(pattern: &str) -> Result, GraphError> { + // Parse (var:Label {props}) patterns + let mut elements = Vec::new(); + let mut chars = pattern.chars().peekable(); + + while chars.peek().is_some() { + while chars.peek() == Some(&' ') || chars.peek() == Some(&',') { + chars.next(); + } + + if chars.peek() == Some(&'(') { + let node = Self::parse_node_pattern(&mut chars)?; + elements.push(CreateElement::Node { + variable: node.variable, + labels: node.labels, + properties: node.properties.unwrap_or(JsonValue::Object(serde_json::Map::new())), + }); + } else { + 
break; + } + } + + Ok(elements) + } + + fn parse_delete(query: &str) -> Result { + let detach = query.to_uppercase().starts_with("DETACH"); + let start = if detach { "DETACH DELETE".len() } else { "DELETE".len() }; + let variable = query[start..].trim().to_string(); + + Ok(GraphQuery::Delete { variable, detach }) + } + + fn parse_set(query: &str) -> Result { + // Simplified: SET var.prop = value + let content = &query[3..].trim(); + let parts: Vec<_> = content.split('=').collect(); + + if parts.len() != 2 { + return Err(GraphError::InvalidOperation("Invalid SET syntax".to_string())); + } + + let var_prop: Vec<_> = parts[0].trim().split('.').collect(); + if var_prop.len() != 2 { + return Err(GraphError::InvalidOperation("Invalid SET variable".to_string())); + } + + let variable = var_prop[0].to_string(); + let property = var_prop[1].to_string(); + let value_str = parts[1].trim(); + + let value: JsonValue = serde_json::from_str(value_str).unwrap_or(JsonValue::String(value_str.to_string())); + + Ok(GraphQuery::Set { + variable, + properties: serde_json::json!({ property: value }), + }) + } +} + +/// Query executor for graph queries. +pub struct GraphQueryExecutor<'a> { + store: &'a GraphStore, +} + +impl<'a> GraphQueryExecutor<'a> { + /// Creates a new query executor. + pub fn new(store: &'a GraphStore) -> Self { + Self { store } + } + + /// Executes a graph query. + pub fn execute(&self, query: &GraphQuery) -> Result { + match query { + GraphQuery::Match { pattern, where_clause, return_items, limit } => { + self.execute_match(pattern, where_clause.as_ref(), return_items, *limit) + } + GraphQuery::Create { .. } => { + Err(GraphError::InvalidOperation("CREATE requires mutable access".to_string())) + } + GraphQuery::Delete { .. } => { + Err(GraphError::InvalidOperation("DELETE requires mutable access".to_string())) + } + GraphQuery::Set { .. 
} => { + Err(GraphError::InvalidOperation("SET requires mutable access".to_string())) + } + } + } + + fn execute_match( + &self, + pattern: &MatchPattern, + _where_clause: Option<&WhereClause>, + return_items: &[ReturnItem], + limit: Option, + ) -> Result { + // Find starting nodes + let start_nodes = self.find_matching_nodes(&pattern.start); + + let mut bindings: Vec> = Vec::new(); + + for start_node in &start_nodes { + let mut binding: HashMap = HashMap::new(); + + if let Some(ref var) = pattern.start.variable { + binding.insert(var.clone(), Self::node_to_json(start_node)); + } + + if pattern.relationships.is_empty() { + bindings.push(binding); + } else { + // Traverse relationships + let traverser = Traverser::new(self.store); + + for rel_pattern in &pattern.relationships { + let direction = match rel_pattern.direction { + RelationshipDirection::Outgoing => TraversalDirection::Outgoing, + RelationshipDirection::Incoming => TraversalDirection::Incoming, + RelationshipDirection::Undirected => TraversalDirection::Both, + }; + + let query = TraversalQuery::new() + .depth(rel_pattern.max_hops) + .direction(direction) + .edge_types( + rel_pattern.edge_type.clone().map(|t| vec![t]).unwrap_or_default(), + ) + .labels(rel_pattern.target.labels.clone()); + + let results = traverser.traverse(&start_node.id, &query); + + for result in results { + let mut new_binding = binding.clone(); + + if let Some(ref var) = rel_pattern.variable { + if let Some(edge) = result.edges.last() { + new_binding.insert(var.clone(), Self::edge_to_json(edge)); + } + } + + if let Some(ref var) = rel_pattern.target.variable { + new_binding.insert(var.clone(), Self::node_to_json(&result.node)); + } + + bindings.push(new_binding); + } + } + } + + if let Some(l) = limit { + if bindings.len() >= l { + break; + } + } + } + + // Apply limit + if let Some(l) = limit { + bindings.truncate(l); + } + + // Build result based on return items + let columns = self.get_column_names(return_items, &bindings); + let 
rows = self.extract_rows(return_items, &bindings); + + Ok(QueryResult { + columns, + rows, + ..QueryResult::empty() + }) + } + + fn find_matching_nodes(&self, pattern: &NodePattern) -> Vec { + let label = pattern.labels.first().map(|s| s.as_str()); + let filter = pattern.properties.clone().unwrap_or(JsonValue::Object(serde_json::Map::new())); + self.store.find_nodes(label, &filter) + } + + fn node_to_json(node: &Node) -> JsonValue { + serde_json::json!({ + "id": node.id.to_hex(), + "labels": node.labels, + "properties": node.properties, + }) + } + + fn edge_to_json(edge: &Edge) -> JsonValue { + serde_json::json!({ + "id": edge.id.to_hex(), + "type": edge.edge_type, + "source": edge.source.to_hex(), + "target": edge.target.to_hex(), + "properties": edge.properties, + }) + } + + fn get_column_names(&self, return_items: &[ReturnItem], bindings: &[HashMap]) -> Vec { + let mut columns = Vec::new(); + + for item in return_items { + match item { + ReturnItem::All => { + if let Some(binding) = bindings.first() { + columns.extend(binding.keys().cloned()); + } + } + ReturnItem::Variable(var) => columns.push(var.clone()), + ReturnItem::Property { variable, property } => { + columns.push(format!("{}.{}", variable, property)); + } + ReturnItem::Alias { alias, .. 
} => columns.push(alias.clone()), + ReturnItem::Count(var) => { + columns.push(format!("count({})", var.as_ref().map(|s| s.as_str()).unwrap_or("*"))); + } + } + } + + columns + } + + fn extract_rows(&self, return_items: &[ReturnItem], bindings: &[HashMap]) -> Vec> { + let mut rows = Vec::new(); + + // Handle COUNT specially + if return_items.iter().any(|i| matches!(i, ReturnItem::Count(_))) { + rows.push(vec![JsonValue::Number(bindings.len().into())]); + return rows; + } + + for binding in bindings { + let mut row = Vec::new(); + + for item in return_items { + match item { + ReturnItem::All => { + for (_, value) in binding { + row.push(value.clone()); + } + } + ReturnItem::Variable(var) => { + row.push(binding.get(var).cloned().unwrap_or(JsonValue::Null)); + } + ReturnItem::Property { variable, property } => { + if let Some(obj) = binding.get(variable) { + if let Some(props) = obj.get("properties") { + row.push(props.get(property).cloned().unwrap_or(JsonValue::Null)); + } else { + row.push(JsonValue::Null); + } + } else { + row.push(JsonValue::Null); + } + } + ReturnItem::Alias { item: inner, .. } => { + // Recursively handle the inner item + let inner_rows = self.extract_rows(&[*inner.clone()], &[binding.clone()]); + if let Some(inner_row) = inner_rows.first() { + row.extend(inner_row.clone()); + } + } + ReturnItem::Count(_) => { + // Handled above + } + } + } + + if !row.is_empty() { + rows.push(row); + } + } + + rows + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_simple_match() { + let query = "MATCH (n:User) RETURN n"; + let parsed = GraphQueryParser::parse(query).unwrap(); + + if let GraphQuery::Match { pattern, .. 
} = parsed { + assert_eq!(pattern.start.labels, vec!["User".to_string()]); + } else { + panic!("Expected Match query"); + } + } + + #[test] + fn test_parse_match_with_relationship() { + let query = "MATCH (a:User)-[:FRIEND]->(b:User) RETURN a, b"; + let parsed = GraphQueryParser::parse(query).unwrap(); + + if let GraphQuery::Match { pattern, .. } = parsed { + assert_eq!(pattern.start.labels, vec!["User".to_string()]); + assert_eq!(pattern.relationships.len(), 1); + assert_eq!(pattern.relationships[0].edge_type, Some("FRIEND".to_string())); + assert_eq!(pattern.relationships[0].direction, RelationshipDirection::Outgoing); + } else { + panic!("Expected Match query"); + } + } + + #[test] + fn test_execute_match() { + let store = GraphStore::new(); + + let alice = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"})); + let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"})); + store.create_edge(alice, bob, "FRIEND", serde_json::json!({})).unwrap(); + + let query = GraphQueryParser::parse("MATCH (n:User) RETURN n").unwrap(); + let executor = GraphQueryExecutor::new(&store); + let result = executor.execute(&query).unwrap(); + + assert_eq!(result.rows.len(), 2); + } + + #[test] + fn test_parse_create() { + let query = "CREATE (n:User {name: \"Alice\"})"; + let parsed = GraphQueryParser::parse(query).unwrap(); + + if let GraphQuery::Create { elements } = parsed { + assert_eq!(elements.len(), 1); + if let CreateElement::Node { labels, .. 
} = &elements[0] { + assert_eq!(labels, &vec!["User".to_string()]); + } + } else { + panic!("Expected Create query"); + } + } + + #[test] + fn test_parse_delete() { + let query = "DELETE n"; + let parsed = GraphQueryParser::parse(query).unwrap(); + + if let GraphQuery::Delete { variable, detach } = parsed { + assert_eq!(variable, "n"); + assert!(!detach); + } else { + panic!("Expected Delete query"); + } + } + + #[test] + fn test_parse_detach_delete() { + let query = "DETACH DELETE n"; + let parsed = GraphQueryParser::parse(query).unwrap(); + + if let GraphQuery::Delete { variable, detach } = parsed { + assert_eq!(variable, "n"); + assert!(detach); + } else { + panic!("Expected Delete query"); + } + } +} diff --git a/crates/synor-database/src/graph/store.rs b/crates/synor-database/src/graph/store.rs new file mode 100644 index 0000000..f91a148 --- /dev/null +++ b/crates/synor-database/src/graph/store.rs @@ -0,0 +1,657 @@ +//! Graph storage engine. + +use super::edge::{Edge, EdgeId}; +use super::node::{Node, NodeId}; +use parking_lot::RwLock; +use serde::{Deserialize, Serialize}; +use serde_json::Value as JsonValue; +use std::collections::{HashMap, HashSet}; +use thiserror::Error; + +/// Direction for traversing edges. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Direction { + /// Follow outgoing edges only. + Outgoing, + /// Follow incoming edges only. + Incoming, + /// Follow edges in both directions. + Both, +} + +/// Graph storage errors. +#[derive(Debug, Error)] +pub enum GraphError { + /// Node not found. + #[error("Node not found: {0}")] + NodeNotFound(String), + + /// Edge not found. + #[error("Edge not found: {0}")] + EdgeNotFound(String), + + /// Node already exists. + #[error("Node already exists: {0}")] + NodeExists(String), + + /// Invalid operation. + #[error("Invalid operation: {0}")] + InvalidOperation(String), + + /// Constraint violation. 
+ #[error("Constraint violation: {0}")] + ConstraintViolation(String), +} + +/// Statistics for a graph store. +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +pub struct GraphStats { + /// Total number of nodes. + pub node_count: u64, + /// Total number of edges. + pub edge_count: u64, + /// Number of distinct labels. + pub label_count: u64, + /// Number of distinct edge types. + pub edge_type_count: u64, +} + +/// Graph storage engine. +pub struct GraphStore { + /// Node storage. + nodes: RwLock>, + /// Edge storage. + edges: RwLock>, + /// Outgoing adjacency list: node -> outgoing edges. + adjacency: RwLock>>, + /// Incoming adjacency list: node -> incoming edges. + reverse_adj: RwLock>>, + /// Label index: label -> nodes with that label. + label_index: RwLock>>, + /// Edge type index: type -> edges of that type. + edge_type_index: RwLock>>, +} + +impl GraphStore { + /// Creates a new empty graph store. + pub fn new() -> Self { + Self { + nodes: RwLock::new(HashMap::new()), + edges: RwLock::new(HashMap::new()), + adjacency: RwLock::new(HashMap::new()), + reverse_adj: RwLock::new(HashMap::new()), + label_index: RwLock::new(HashMap::new()), + edge_type_index: RwLock::new(HashMap::new()), + } + } + + /// Returns statistics about the graph. + pub fn stats(&self) -> GraphStats { + GraphStats { + node_count: self.nodes.read().len() as u64, + edge_count: self.edges.read().len() as u64, + label_count: self.label_index.read().len() as u64, + edge_type_count: self.edge_type_index.read().len() as u64, + } + } + + // ==================== Node Operations ==================== + + /// Creates a new node with the given labels and properties. 
+ pub fn create_node(&self, labels: Vec, properties: JsonValue) -> NodeId { + let node = Node::new(labels.clone(), properties); + let id = node.id; + + // Update label index + { + let mut label_idx = self.label_index.write(); + for label in &labels { + label_idx.entry(label.clone()).or_default().insert(id); + } + } + + // Initialize adjacency lists + self.adjacency.write().insert(id, Vec::new()); + self.reverse_adj.write().insert(id, Vec::new()); + + // Store node + self.nodes.write().insert(id, node); + + id + } + + /// Gets a node by ID. + pub fn get_node(&self, id: &NodeId) -> Option { + self.nodes.read().get(id).cloned() + } + + /// Updates a node's properties. + pub fn update_node(&self, id: &NodeId, properties: JsonValue) -> Result<(), GraphError> { + let mut nodes = self.nodes.write(); + let node = nodes + .get_mut(id) + .ok_or_else(|| GraphError::NodeNotFound(id.to_string()))?; + + node.properties = properties; + node.updated_at = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis() as u64; + + Ok(()) + } + + /// Updates a node's labels. + pub fn update_node_labels(&self, id: &NodeId, labels: Vec) -> Result<(), GraphError> { + let mut nodes = self.nodes.write(); + let node = nodes + .get_mut(id) + .ok_or_else(|| GraphError::NodeNotFound(id.to_string()))?; + + // Update label index + { + let mut label_idx = self.label_index.write(); + + // Remove old labels + for old_label in &node.labels { + if let Some(set) = label_idx.get_mut(old_label) { + set.remove(id); + } + } + + // Add new labels + for new_label in &labels { + label_idx.entry(new_label.clone()).or_default().insert(*id); + } + } + + node.labels = labels; + node.updated_at = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis() as u64; + + Ok(()) + } + + /// Deletes a node and all its connected edges. 
+ pub fn delete_node(&self, id: &NodeId) -> Result<(), GraphError> { + // Get connected edges + let outgoing: Vec = self + .adjacency + .read() + .get(id) + .cloned() + .unwrap_or_default(); + let incoming: Vec = self + .reverse_adj + .read() + .get(id) + .cloned() + .unwrap_or_default(); + + // Delete all connected edges + for edge_id in outgoing.iter().chain(incoming.iter()) { + let _ = self.delete_edge(edge_id); + } + + // Remove from label index + { + let nodes = self.nodes.read(); + if let Some(node) = nodes.get(id) { + let mut label_idx = self.label_index.write(); + for label in &node.labels { + if let Some(set) = label_idx.get_mut(label) { + set.remove(id); + } + } + } + } + + // Remove adjacency entries + self.adjacency.write().remove(id); + self.reverse_adj.write().remove(id); + + // Remove node + self.nodes + .write() + .remove(id) + .ok_or_else(|| GraphError::NodeNotFound(id.to_string()))?; + + Ok(()) + } + + /// Finds nodes by label. + pub fn find_nodes_by_label(&self, label: &str) -> Vec { + let node_ids: Vec = self + .label_index + .read() + .get(label) + .map(|set| set.iter().cloned().collect()) + .unwrap_or_default(); + + let nodes = self.nodes.read(); + node_ids + .iter() + .filter_map(|id| nodes.get(id).cloned()) + .collect() + } + + /// Finds nodes matching a filter. + pub fn find_nodes(&self, label: Option<&str>, filter: &JsonValue) -> Vec { + let candidates: Vec = if let Some(l) = label { + self.find_nodes_by_label(l) + } else { + self.nodes.read().values().cloned().collect() + }; + + candidates + .into_iter() + .filter(|n| n.matches_properties(filter)) + .collect() + } + + // ==================== Edge Operations ==================== + + /// Creates a new edge. 
+ pub fn create_edge( + &self, + source: NodeId, + target: NodeId, + edge_type: impl Into, + properties: JsonValue, + ) -> Result { + let edge_type = edge_type.into(); + + // Verify both nodes exist + { + let nodes = self.nodes.read(); + if !nodes.contains_key(&source) { + return Err(GraphError::NodeNotFound(source.to_string())); + } + if !nodes.contains_key(&target) { + return Err(GraphError::NodeNotFound(target.to_string())); + } + } + + let edge = Edge::new(source, target, edge_type.clone(), properties); + let id = edge.id; + + // Update adjacency lists + self.adjacency.write().entry(source).or_default().push(id); + self.reverse_adj.write().entry(target).or_default().push(id); + + // Update edge type index + self.edge_type_index + .write() + .entry(edge_type) + .or_default() + .insert(id); + + // Store edge + self.edges.write().insert(id, edge); + + Ok(id) + } + + /// Creates an undirected edge. + pub fn create_undirected_edge( + &self, + source: NodeId, + target: NodeId, + edge_type: impl Into, + properties: JsonValue, + ) -> Result { + let edge_type = edge_type.into(); + + // Verify both nodes exist + { + let nodes = self.nodes.read(); + if !nodes.contains_key(&source) { + return Err(GraphError::NodeNotFound(source.to_string())); + } + if !nodes.contains_key(&target) { + return Err(GraphError::NodeNotFound(target.to_string())); + } + } + + let edge = Edge::undirected(source, target, edge_type.clone(), properties); + let id = edge.id; + + // Update adjacency lists (both directions for undirected) + { + let mut adj = self.adjacency.write(); + adj.entry(source).or_default().push(id); + adj.entry(target).or_default().push(id); + } + { + let mut rev = self.reverse_adj.write(); + rev.entry(source).or_default().push(id); + rev.entry(target).or_default().push(id); + } + + // Update edge type index + self.edge_type_index + .write() + .entry(edge_type) + .or_default() + .insert(id); + + // Store edge + self.edges.write().insert(id, edge); + + Ok(id) + } + + /// Gets an 
edge by ID. + pub fn get_edge(&self, id: &EdgeId) -> Option { + self.edges.read().get(id).cloned() + } + + /// Deletes an edge. + pub fn delete_edge(&self, id: &EdgeId) -> Result<(), GraphError> { + let edge = self + .edges + .write() + .remove(id) + .ok_or_else(|| GraphError::EdgeNotFound(id.to_string()))?; + + // Update adjacency lists + { + let mut adj = self.adjacency.write(); + if let Some(list) = adj.get_mut(&edge.source) { + list.retain(|e| e != id); + } + if !edge.directed { + if let Some(list) = adj.get_mut(&edge.target) { + list.retain(|e| e != id); + } + } + } + { + let mut rev = self.reverse_adj.write(); + if let Some(list) = rev.get_mut(&edge.target) { + list.retain(|e| e != id); + } + if !edge.directed { + if let Some(list) = rev.get_mut(&edge.source) { + list.retain(|e| e != id); + } + } + } + + // Update edge type index + if let Some(set) = self.edge_type_index.write().get_mut(&edge.edge_type) { + set.remove(id); + } + + Ok(()) + } + + /// Finds edges by type. + pub fn find_edges_by_type(&self, edge_type: &str) -> Vec { + let edge_ids: Vec = self + .edge_type_index + .read() + .get(edge_type) + .map(|set| set.iter().cloned().collect()) + .unwrap_or_default(); + + let edges = self.edges.read(); + edge_ids + .iter() + .filter_map(|id| edges.get(id).cloned()) + .collect() + } + + // ==================== Traversal Operations ==================== + + /// Gets neighboring nodes. + pub fn neighbors(&self, id: &NodeId, direction: Direction) -> Vec { + let edge_ids = self.edges_of_node(id, direction); + let edges = self.edges.read(); + let nodes = self.nodes.read(); + + let mut neighbor_ids = HashSet::new(); + for eid in edge_ids { + if let Some(edge) = edges.get(&eid) { + if let Some(other) = self.get_neighbor_from_edge(edge, id, direction) { + neighbor_ids.insert(other); + } + } + } + + neighbor_ids + .iter() + .filter_map(|nid| nodes.get(nid).cloned()) + .collect() + } + + /// Gets edges connected to a node. 
+ pub fn edges_of(&self, id: &NodeId, direction: Direction) -> Vec { + let edge_ids = self.edges_of_node(id, direction); + let edges = self.edges.read(); + + edge_ids + .iter() + .filter_map(|eid| edges.get(eid).cloned()) + .collect() + } + + /// Gets edge IDs connected to a node. + fn edges_of_node(&self, id: &NodeId, direction: Direction) -> Vec { + match direction { + Direction::Outgoing => self.adjacency.read().get(id).cloned().unwrap_or_default(), + Direction::Incoming => self.reverse_adj.read().get(id).cloned().unwrap_or_default(), + Direction::Both => { + let mut result = self.adjacency.read().get(id).cloned().unwrap_or_default(); + let incoming = self.reverse_adj.read().get(id).cloned().unwrap_or_default(); + for eid in incoming { + if !result.contains(&eid) { + result.push(eid); + } + } + result + } + } + } + + /// Gets the neighbor node from an edge. + fn get_neighbor_from_edge(&self, edge: &Edge, from: &NodeId, direction: Direction) -> Option { + match direction { + Direction::Outgoing => { + if &edge.source == from { + Some(edge.target) + } else if !edge.directed && &edge.target == from { + Some(edge.source) + } else { + None + } + } + Direction::Incoming => { + if &edge.target == from { + Some(edge.source) + } else if !edge.directed && &edge.source == from { + Some(edge.target) + } else { + None + } + } + Direction::Both => edge.other_end(from).or_else(|| { + // For directed edges, still return the other end + if &edge.source == from { + Some(edge.target) + } else if &edge.target == from { + Some(edge.source) + } else { + None + } + }), + } + } + + /// Gets neighbors connected by a specific edge type. 
+ pub fn neighbors_by_type(&self, id: &NodeId, edge_type: &str, direction: Direction) -> Vec { + let edges = self.edges_of(id, direction); + let nodes = self.nodes.read(); + + let mut neighbor_ids = HashSet::new(); + for edge in edges { + if edge.edge_type == edge_type { + if let Some(other) = self.get_neighbor_from_edge(&edge, id, direction) { + neighbor_ids.insert(other); + } + } + } + + neighbor_ids + .iter() + .filter_map(|nid| nodes.get(nid).cloned()) + .collect() + } + + /// Checks if an edge exists between two nodes. + pub fn has_edge(&self, source: &NodeId, target: &NodeId, edge_type: Option<&str>) -> bool { + let edges = self.edges_of(source, Direction::Outgoing); + for edge in edges { + if &edge.target == target { + if let Some(et) = edge_type { + if edge.edge_type == et { + return true; + } + } else { + return true; + } + } + } + false + } + + /// Gets all edges between two nodes. + pub fn edges_between(&self, source: &NodeId, target: &NodeId) -> Vec { + let edges = self.edges_of(source, Direction::Outgoing); + edges + .into_iter() + .filter(|e| &e.target == target || (!e.directed && &e.source == target)) + .collect() + } +} + +impl Default for GraphStore { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_create_node() { + let store = GraphStore::new(); + + let id = store.create_node( + vec!["User".to_string()], + serde_json::json!({"name": "Alice"}), + ); + + let node = store.get_node(&id).unwrap(); + assert!(node.has_label("User")); + assert_eq!(node.get_property("name"), Some(&serde_json::json!("Alice"))); + } + + #[test] + fn test_create_edge() { + let store = GraphStore::new(); + + let alice = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"})); + let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"})); + + let edge_id = store + .create_edge(alice, bob, "FRIEND", serde_json::json!({"since": 2020})) + .unwrap(); + + let edge 
= store.get_edge(&edge_id).unwrap(); + assert_eq!(edge.source, alice); + assert_eq!(edge.target, bob); + assert_eq!(edge.edge_type, "FRIEND"); + } + + #[test] + fn test_neighbors() { + let store = GraphStore::new(); + + let alice = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"})); + let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"})); + let charlie = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Charlie"})); + + store.create_edge(alice, bob, "FRIEND", serde_json::json!({})).unwrap(); + store.create_edge(alice, charlie, "FRIEND", serde_json::json!({})).unwrap(); + + let neighbors = store.neighbors(&alice, Direction::Outgoing); + assert_eq!(neighbors.len(), 2); + } + + #[test] + fn test_find_by_label() { + let store = GraphStore::new(); + + store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"})); + store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"})); + store.create_node(vec!["Product".to_string()], serde_json::json!({"name": "Widget"})); + + let users = store.find_nodes_by_label("User"); + assert_eq!(users.len(), 2); + + let products = store.find_nodes_by_label("Product"); + assert_eq!(products.len(), 1); + } + + #[test] + fn test_delete_node() { + let store = GraphStore::new(); + + let alice = store.create_node(vec!["User".to_string()], serde_json::json!({})); + let bob = store.create_node(vec!["User".to_string()], serde_json::json!({})); + + store.create_edge(alice, bob, "FRIEND", serde_json::json!({})).unwrap(); + + // Delete Alice - should also delete the edge + store.delete_node(&alice).unwrap(); + + assert!(store.get_node(&alice).is_none()); + assert_eq!(store.stats().edge_count, 0); + } + + #[test] + fn test_undirected_edge() { + let store = GraphStore::new(); + + let a = store.create_node(vec![], serde_json::json!({})); + let b = store.create_node(vec![], serde_json::json!({})); + + 
store.create_undirected_edge(a, b, "LINK", serde_json::json!({})).unwrap(); + + // Both directions should work + let a_neighbors = store.neighbors(&a, Direction::Outgoing); + let b_neighbors = store.neighbors(&b, Direction::Outgoing); + + assert_eq!(a_neighbors.len(), 1); + assert_eq!(b_neighbors.len(), 1); + } + + #[test] + fn test_edges_between() { + let store = GraphStore::new(); + + let a = store.create_node(vec![], serde_json::json!({})); + let b = store.create_node(vec![], serde_json::json!({})); + + store.create_edge(a, b, "TYPE_A", serde_json::json!({})).unwrap(); + store.create_edge(a, b, "TYPE_B", serde_json::json!({})).unwrap(); + + let edges = store.edges_between(&a, &b); + assert_eq!(edges.len(), 2); + } +} diff --git a/crates/synor-database/src/graph/traversal.rs b/crates/synor-database/src/graph/traversal.rs new file mode 100644 index 0000000..56554ad --- /dev/null +++ b/crates/synor-database/src/graph/traversal.rs @@ -0,0 +1,500 @@ +//! Graph traversal algorithms. + +use super::edge::Edge; +use super::node::{Node, NodeId}; +use super::store::{Direction, GraphStore}; +use serde::{Deserialize, Serialize}; +use serde_json::Value as JsonValue; +use std::collections::{HashSet, VecDeque}; + +/// Query for graph traversal. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TraversalQuery { + /// Maximum depth to traverse. + pub max_depth: usize, + /// Edge types to follow (empty = all types). + pub edge_types: Vec, + /// Direction to traverse. + pub direction: TraversalDirection, + /// Filter for nodes to include. + pub node_filter: Option, + /// Filter for edges to follow. + pub edge_filter: Option, + /// Maximum results to return. + pub limit: Option, + /// Labels to filter nodes by. + pub labels: Vec, +} + +/// Direction for traversal serialization. 
+#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +pub enum TraversalDirection { + Outgoing, + Incoming, + Both, +} + +impl From for Direction { + fn from(td: TraversalDirection) -> Self { + match td { + TraversalDirection::Outgoing => Direction::Outgoing, + TraversalDirection::Incoming => Direction::Incoming, + TraversalDirection::Both => Direction::Both, + } + } +} + +impl Default for TraversalQuery { + fn default() -> Self { + Self { + max_depth: 3, + edge_types: Vec::new(), + direction: TraversalDirection::Outgoing, + node_filter: None, + edge_filter: None, + limit: None, + labels: Vec::new(), + } + } +} + +impl TraversalQuery { + /// Creates a new traversal query. + pub fn new() -> Self { + Self::default() + } + + /// Sets the maximum depth. + pub fn depth(mut self, depth: usize) -> Self { + self.max_depth = depth; + self + } + + /// Sets edge types to follow. + pub fn edge_types(mut self, types: Vec) -> Self { + self.edge_types = types; + self + } + + /// Sets the traversal direction. + pub fn direction(mut self, dir: TraversalDirection) -> Self { + self.direction = dir; + self + } + + /// Sets a node filter. + pub fn node_filter(mut self, filter: JsonValue) -> Self { + self.node_filter = Some(filter); + self + } + + /// Sets an edge filter. + pub fn edge_filter(mut self, filter: JsonValue) -> Self { + self.edge_filter = Some(filter); + self + } + + /// Sets result limit. + pub fn limit(mut self, limit: usize) -> Self { + self.limit = Some(limit); + self + } + + /// Sets label filter. + pub fn labels(mut self, labels: Vec) -> Self { + self.labels = labels; + self + } +} + +/// Result of a traversal operation. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TraversalResult { + /// The node found. + pub node: Node, + /// Depth at which this node was found. + pub depth: usize, + /// Path from start to this node (node IDs). + pub path: Vec, + /// Edges followed to reach this node. 
+ pub edges: Vec, +} + +/// Graph traverser for executing traversal queries. +pub struct Traverser<'a> { + store: &'a GraphStore, +} + +impl<'a> Traverser<'a> { + /// Creates a new traverser. + pub fn new(store: &'a GraphStore) -> Self { + Self { store } + } + + /// Executes a BFS traversal from a starting node. + pub fn traverse(&self, start: &NodeId, query: &TraversalQuery) -> Vec { + let mut results = Vec::new(); + let mut visited = HashSet::new(); + let mut queue: VecDeque<(NodeId, usize, Vec, Vec)> = VecDeque::new(); + + visited.insert(*start); + queue.push_back((*start, 0, vec![*start], Vec::new())); + + let direction: Direction = query.direction.into(); + + while let Some((current_id, depth, path, edges_path)) = queue.pop_front() { + // Check limit + if let Some(limit) = query.limit { + if results.len() >= limit { + break; + } + } + + // Get current node + if let Some(node) = self.store.get_node(¤t_id) { + // Skip start node in results (depth 0) + if depth > 0 { + // Apply filters + if self.matches_query(&node, query) { + results.push(TraversalResult { + node: node.clone(), + depth, + path: path.clone(), + edges: edges_path.clone(), + }); + } + } + + // Continue if not at max depth + if depth < query.max_depth { + let edges = self.store.edges_of(¤t_id, direction); + + for edge in edges { + // Check edge type filter + if !query.edge_types.is_empty() && !query.edge_types.contains(&edge.edge_type) { + continue; + } + + // Check edge filter + if let Some(ref filter) = query.edge_filter { + if !edge.matches_properties(filter) { + continue; + } + } + + // Get neighbor + let neighbor_id = self.get_neighbor(&edge, ¤t_id, direction); + if let Some(next_id) = neighbor_id { + if !visited.contains(&next_id) { + visited.insert(next_id); + let mut new_path = path.clone(); + new_path.push(next_id); + let mut new_edges = edges_path.clone(); + new_edges.push(edge); + queue.push_back((next_id, depth + 1, new_path, new_edges)); + } + } + } + } + } + } + + results + } + + /// 
Executes a DFS traversal from a starting node. + pub fn traverse_dfs(&self, start: &NodeId, query: &TraversalQuery) -> Vec { + let mut results = Vec::new(); + let mut visited = HashSet::new(); + let direction: Direction = query.direction.into(); + + self.dfs_visit( + start, + 0, + vec![*start], + Vec::new(), + &mut visited, + &mut results, + query, + direction, + ); + + results + } + + fn dfs_visit( + &self, + current_id: &NodeId, + depth: usize, + path: Vec, + edges_path: Vec, + visited: &mut HashSet, + results: &mut Vec, + query: &TraversalQuery, + direction: Direction, + ) { + // Check limit + if let Some(limit) = query.limit { + if results.len() >= limit { + return; + } + } + + visited.insert(*current_id); + + if let Some(node) = self.store.get_node(current_id) { + // Skip start node in results + if depth > 0 && self.matches_query(&node, query) { + results.push(TraversalResult { + node: node.clone(), + depth, + path: path.clone(), + edges: edges_path.clone(), + }); + } + + // Continue if not at max depth + if depth < query.max_depth { + let edges = self.store.edges_of(current_id, direction); + + for edge in edges { + // Check edge type filter + if !query.edge_types.is_empty() && !query.edge_types.contains(&edge.edge_type) { + continue; + } + + // Check edge filter + if let Some(ref filter) = query.edge_filter { + if !edge.matches_properties(filter) { + continue; + } + } + + if let Some(next_id) = self.get_neighbor(&edge, current_id, direction) { + if !visited.contains(&next_id) { + let mut new_path = path.clone(); + new_path.push(next_id); + let mut new_edges = edges_path.clone(); + new_edges.push(edge); + + self.dfs_visit( + &next_id, + depth + 1, + new_path, + new_edges, + visited, + results, + query, + direction, + ); + } + } + } + } + } + } + + /// Checks if a node matches the query filters. 
+ fn matches_query(&self, node: &Node, query: &TraversalQuery) -> bool { + // Check labels + if !query.labels.is_empty() { + let has_label = query.labels.iter().any(|l| node.has_label(l)); + if !has_label { + return false; + } + } + + // Check node filter + if let Some(ref filter) = query.node_filter { + if !node.matches_properties(filter) { + return false; + } + } + + true + } + + /// Gets the neighbor node ID from an edge. + fn get_neighbor(&self, edge: &Edge, from: &NodeId, direction: Direction) -> Option { + match direction { + Direction::Outgoing => { + if &edge.source == from { + Some(edge.target) + } else if !edge.directed && &edge.target == from { + Some(edge.source) + } else { + None + } + } + Direction::Incoming => { + if &edge.target == from { + Some(edge.source) + } else if !edge.directed && &edge.source == from { + Some(edge.target) + } else { + None + } + } + Direction::Both => { + if &edge.source == from { + Some(edge.target) + } else if &edge.target == from { + Some(edge.source) + } else { + None + } + } + } + } + + /// Finds all nodes within a certain distance. + pub fn within_distance(&self, start: &NodeId, max_distance: usize) -> Vec<(Node, usize)> { + let query = TraversalQuery::new().depth(max_distance); + self.traverse(start, &query) + .into_iter() + .map(|r| (r.node, r.depth)) + .collect() + } + + /// Finds mutual connections between two nodes. 
+ pub fn mutual_connections( + &self, + node_a: &NodeId, + node_b: &NodeId, + edge_type: Option<&str>, + ) -> Vec { + let query = TraversalQuery::new() + .depth(1) + .edge_types(edge_type.map(|s| vec![s.to_string()]).unwrap_or_default()); + + let neighbors_a: HashSet = self + .traverse(node_a, &query) + .into_iter() + .map(|r| r.node.id) + .collect(); + + let neighbors_b: HashSet = self + .traverse(node_b, &query) + .into_iter() + .map(|r| r.node.id) + .collect(); + + let mutual: HashSet<_> = neighbors_a.intersection(&neighbors_b).collect(); + + mutual + .iter() + .filter_map(|id| self.store.get_node(id)) + .collect() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn setup_social_graph() -> GraphStore { + let store = GraphStore::new(); + + let alice = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"})); + let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"})); + let charlie = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Charlie"})); + let dave = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Dave"})); + + // Alice -> Bob -> Charlie -> Dave + store.create_edge(alice, bob, "FRIEND", serde_json::json!({})).unwrap(); + store.create_edge(bob, charlie, "FRIEND", serde_json::json!({})).unwrap(); + store.create_edge(charlie, dave, "FRIEND", serde_json::json!({})).unwrap(); + + // Alice -> Charlie (shortcut) + store.create_edge(alice, charlie, "KNOWS", serde_json::json!({})).unwrap(); + + store + } + + #[test] + fn test_basic_traversal() { + let store = setup_social_graph(); + let traverser = Traverser::new(&store); + + let users = store.find_nodes_by_label("User"); + let alice = users.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("Alice"))).unwrap(); + + let query = TraversalQuery::new().depth(2); + let results = traverser.traverse(&alice.id, &query); + + // Should find Bob (depth 1), Charlie (depth 1 and 2), and Dave (depth 2) + 
assert!(results.len() >= 2); + } + + #[test] + fn test_edge_type_filter() { + let store = setup_social_graph(); + let traverser = Traverser::new(&store); + + let users = store.find_nodes_by_label("User"); + let alice = users.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("Alice"))).unwrap(); + + let query = TraversalQuery::new() + .depth(2) + .edge_types(vec!["FRIEND".to_string()]); + let results = traverser.traverse(&alice.id, &query); + + // Following only FRIEND edges: Alice -> Bob -> Charlie + let names: Vec<_> = results.iter().filter_map(|r| r.node.get_property("name")).collect(); + assert!(names.contains(&&serde_json::json!("Bob"))); + assert!(names.contains(&&serde_json::json!("Charlie"))); + } + + #[test] + fn test_depth_limit() { + let store = setup_social_graph(); + let traverser = Traverser::new(&store); + + let users = store.find_nodes_by_label("User"); + let alice = users.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("Alice"))).unwrap(); + + let query = TraversalQuery::new().depth(1); + let results = traverser.traverse(&alice.id, &query); + + // Depth 1: only direct neighbors + for result in &results { + assert_eq!(result.depth, 1); + } + } + + #[test] + fn test_result_limit() { + let store = setup_social_graph(); + let traverser = Traverser::new(&store); + + let users = store.find_nodes_by_label("User"); + let alice = users.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("Alice"))).unwrap(); + + let query = TraversalQuery::new().depth(10).limit(2); + let results = traverser.traverse(&alice.id, &query); + + assert!(results.len() <= 2); + } + + #[test] + fn test_mutual_connections() { + let store = GraphStore::new(); + + let alice = store.create_node(vec![], serde_json::json!({"name": "Alice"})); + let bob = store.create_node(vec![], serde_json::json!({"name": "Bob"})); + let mutual1 = store.create_node(vec![], serde_json::json!({"name": "Mutual1"})); + let mutual2 = store.create_node(vec![], 
serde_json::json!({"name": "Mutual2"})); + let only_alice = store.create_node(vec![], serde_json::json!({"name": "OnlyAlice"})); + + store.create_edge(alice, mutual1, "FRIEND", serde_json::json!({})).unwrap(); + store.create_edge(alice, mutual2, "FRIEND", serde_json::json!({})).unwrap(); + store.create_edge(alice, only_alice, "FRIEND", serde_json::json!({})).unwrap(); + store.create_edge(bob, mutual1, "FRIEND", serde_json::json!({})).unwrap(); + store.create_edge(bob, mutual2, "FRIEND", serde_json::json!({})).unwrap(); + + let traverser = Traverser::new(&store); + let mutual = traverser.mutual_connections(&alice, &bob, Some("FRIEND")); + + assert_eq!(mutual.len(), 2); + } +} diff --git a/crates/synor-database/src/lib.rs b/crates/synor-database/src/lib.rs index 469b259..410c778 100644 --- a/crates/synor-database/src/lib.rs +++ b/crates/synor-database/src/lib.rs @@ -45,20 +45,32 @@ pub mod document; pub mod error; pub mod gateway; +pub mod graph; pub mod index; pub mod keyvalue; pub mod query; +pub mod replication; pub mod schema; +pub mod sql; pub mod timeseries; pub mod vector; pub use document::{Collection, Document, DocumentId, DocumentStore}; pub use error::DatabaseError; pub use gateway::{GatewayConfig, GatewayServer}; +pub use graph::{ + Direction, Edge, EdgeId, GraphError, GraphQuery, GraphQueryParser, GraphStore, Node, NodeId, + PathFinder, PathResult, TraversalQuery, TraversalResult, Traverser, +}; pub use index::{Index, IndexConfig, IndexManager, IndexType}; pub use keyvalue::{KeyValue, KeyValueStore, KvEntry}; pub use query::{Filter, Query, QueryEngine, QueryResult, SortOrder}; pub use schema::{Field, FieldType, Schema, SchemaValidator}; +pub use replication::{ + ClusterConfig, Command as RaftCommand, NodeRole, RaftConfig, RaftEvent, RaftNode, RaftState, + ReplicatedLog, +}; +pub use sql::{QueryResult as SqlQueryResult, SqlEngine, SqlParser, SqlType, SqlValue, Table, TableDef}; pub use timeseries::{DataPoint, Metric, TimeSeries, TimeSeriesStore}; pub use 
vector::{Embedding, SimilarityMetric, VectorIndex, VectorStore}; diff --git a/crates/synor-database/src/replication/cluster.rs b/crates/synor-database/src/replication/cluster.rs new file mode 100644 index 0000000..2dee4a1 --- /dev/null +++ b/crates/synor-database/src/replication/cluster.rs @@ -0,0 +1,393 @@ +//! Cluster configuration and peer management. + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::net::SocketAddr; +use std::str::FromStr; + +/// Unique identifier for a node in the cluster. +pub type NodeId = u64; + +/// Address for a peer node. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct PeerAddress { + /// Host address (IP or hostname). + pub host: String, + /// Port number. + pub port: u16, +} + +impl PeerAddress { + /// Creates a new peer address. + pub fn new(host: impl Into, port: u16) -> Self { + Self { + host: host.into(), + port, + } + } + + /// Parses from "host:port" format. + pub fn parse(s: &str) -> Option { + let parts: Vec<&str> = s.split(':').collect(); + if parts.len() == 2 { + parts[1].parse().ok().map(|port| Self { + host: parts[0].to_string(), + port, + }) + } else { + None + } + } + + /// Converts to SocketAddr if possible. + pub fn to_socket_addr(&self) -> Option { + SocketAddr::from_str(&format!("{}:{}", self.host, self.port)).ok() + } +} + +impl std::fmt::Display for PeerAddress { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}:{}", self.host, self.port) + } +} + +/// Information about a peer node. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct PeerInfo { + /// Node identifier. + pub id: NodeId, + /// Network address. + pub address: PeerAddress, + /// Whether this peer is a voting member. + pub voting: bool, + /// Last known state. + pub state: PeerState, +} + +impl PeerInfo { + /// Creates new peer info. 
+ pub fn new(id: NodeId, address: PeerAddress) -> Self { + Self { + id, + address, + voting: true, + state: PeerState::Unknown, + } + } + + /// Creates a non-voting learner peer. + pub fn learner(id: NodeId, address: PeerAddress) -> Self { + Self { + id, + address, + voting: false, + state: PeerState::Unknown, + } + } +} + +/// State of a peer from this node's perspective. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +pub enum PeerState { + /// State unknown (initial). + Unknown, + /// Peer is reachable. + Reachable, + /// Peer is unreachable. + Unreachable, + /// Peer is being probed. + Probing, +} + +/// Configuration for the cluster. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ClusterConfig { + /// This node's ID. + pub node_id: NodeId, + /// This node's address. + pub address: PeerAddress, + /// Known peers in the cluster. + pub peers: HashMap, + /// Configuration index (for joint consensus). + pub config_index: u64, + /// Whether this is a joint configuration. + pub joint: bool, + /// Old configuration (for joint consensus). + pub old_peers: Option>, +} + +impl ClusterConfig { + /// Creates a new single-node cluster configuration. + pub fn new(node_id: NodeId, address: PeerAddress) -> Self { + Self { + node_id, + address, + peers: HashMap::new(), + config_index: 0, + joint: false, + old_peers: None, + } + } + + /// Adds a peer to the cluster. + pub fn add_peer(&mut self, peer: PeerInfo) { + self.peers.insert(peer.id, peer); + } + + /// Removes a peer from the cluster. + pub fn remove_peer(&mut self, id: NodeId) -> Option { + self.peers.remove(&id) + } + + /// Gets a peer by ID. + pub fn get_peer(&self, id: NodeId) -> Option<&PeerInfo> { + self.peers.get(&id) + } + + /// Gets a mutable reference to a peer. + pub fn get_peer_mut(&mut self, id: NodeId) -> Option<&mut PeerInfo> { + self.peers.get_mut(&id) + } + + /// Returns all peer IDs. 
+ pub fn peer_ids(&self) -> Vec { + self.peers.keys().copied().collect() + } + + /// Returns all voting peer IDs. + pub fn voting_peers(&self) -> Vec { + self.peers + .iter() + .filter(|(_, p)| p.voting) + .map(|(id, _)| *id) + .collect() + } + + /// Returns the total number of voting members (including self). + pub fn voting_members(&self) -> usize { + self.peers.values().filter(|p| p.voting).count() + 1 + } + + /// Returns the quorum size needed for consensus. + pub fn quorum_size(&self) -> usize { + self.voting_members() / 2 + 1 + } + + /// Checks if we have quorum with the given votes. + pub fn has_quorum(&self, votes: usize) -> bool { + votes >= self.quorum_size() + } + + /// Starts a configuration change (joint consensus). + pub fn begin_config_change(&mut self, new_peers: HashMap, index: u64) { + self.old_peers = Some(self.peers.clone()); + self.peers = new_peers; + self.config_index = index; + self.joint = true; + } + + /// Completes a configuration change. + pub fn complete_config_change(&mut self) { + self.old_peers = None; + self.joint = false; + } + + /// Aborts a configuration change. + pub fn abort_config_change(&mut self) { + if let Some(old) = self.old_peers.take() { + self.peers = old; + } + self.joint = false; + } + + /// Checks if node is in joint consensus mode. + pub fn is_joint(&self) -> bool { + self.joint + } + + /// For joint consensus: checks if we have quorum in BOTH configurations. + pub fn has_joint_quorum(&self, new_votes: usize, old_votes: usize) -> bool { + if !self.joint { + return self.has_quorum(new_votes); + } + + let new_quorum = self.quorum_size(); + let old_quorum = self + .old_peers + .as_ref() + .map(|p| p.values().filter(|peer| peer.voting).count() / 2 + 1) + .unwrap_or(1); + + new_votes >= new_quorum && old_votes >= old_quorum + } + + /// Updates peer state. 
+ pub fn update_peer_state(&mut self, id: NodeId, state: PeerState) { + if let Some(peer) = self.peers.get_mut(&id) { + peer.state = state; + } + } + + /// Returns all reachable peers. + pub fn reachable_peers(&self) -> Vec { + self.peers + .iter() + .filter(|(_, p)| p.state == PeerState::Reachable) + .map(|(id, _)| *id) + .collect() + } + + /// Serializes the configuration. + pub fn to_bytes(&self) -> Vec { + bincode::serialize(self).unwrap_or_default() + } + + /// Deserializes the configuration. + pub fn from_bytes(bytes: &[u8]) -> Option { + bincode::deserialize(bytes).ok() + } +} + +impl Default for ClusterConfig { + fn default() -> Self { + Self::new(1, PeerAddress::new("127.0.0.1", 9000)) + } +} + +/// Builder for cluster configurations. +pub struct ClusterBuilder { + config: ClusterConfig, +} + +impl ClusterBuilder { + /// Creates a new builder. + pub fn new(node_id: NodeId, address: PeerAddress) -> Self { + Self { + config: ClusterConfig::new(node_id, address), + } + } + + /// Adds a peer. + pub fn with_peer(mut self, id: NodeId, address: PeerAddress) -> Self { + self.config.add_peer(PeerInfo::new(id, address)); + self + } + + /// Adds a learner (non-voting) peer. + pub fn with_learner(mut self, id: NodeId, address: PeerAddress) -> Self { + self.config.add_peer(PeerInfo::learner(id, address)); + self + } + + /// Builds the configuration. 
+ pub fn build(self) -> ClusterConfig { + self.config + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_peer_address() { + let addr = PeerAddress::new("192.168.1.1", 9000); + assert_eq!(addr.to_string(), "192.168.1.1:9000"); + + let parsed = PeerAddress::parse("10.0.0.1:8080").unwrap(); + assert_eq!(parsed.host, "10.0.0.1"); + assert_eq!(parsed.port, 8080); + } + + #[test] + fn test_cluster_config() { + let mut config = ClusterConfig::new(1, PeerAddress::new("127.0.0.1", 9000)); + + config.add_peer(PeerInfo::new(2, PeerAddress::new("127.0.0.1", 9001))); + config.add_peer(PeerInfo::new(3, PeerAddress::new("127.0.0.1", 9002))); + + assert_eq!(config.voting_members(), 3); + assert_eq!(config.quorum_size(), 2); + assert!(config.has_quorum(2)); + assert!(!config.has_quorum(1)); + } + + #[test] + fn test_quorum_sizes() { + // 1 node: quorum = 1 + let config1 = ClusterConfig::new(1, PeerAddress::new("127.0.0.1", 9000)); + assert_eq!(config1.quorum_size(), 1); + + // 3 nodes: quorum = 2 + let mut config3 = ClusterConfig::new(1, PeerAddress::new("127.0.0.1", 9000)); + config3.add_peer(PeerInfo::new(2, PeerAddress::new("127.0.0.1", 9001))); + config3.add_peer(PeerInfo::new(3, PeerAddress::new("127.0.0.1", 9002))); + assert_eq!(config3.quorum_size(), 2); + + // 5 nodes: quorum = 3 + let mut config5 = config3.clone(); + config5.add_peer(PeerInfo::new(4, PeerAddress::new("127.0.0.1", 9003))); + config5.add_peer(PeerInfo::new(5, PeerAddress::new("127.0.0.1", 9004))); + assert_eq!(config5.quorum_size(), 3); + } + + #[test] + fn test_learner_peers() { + let mut config = ClusterConfig::new(1, PeerAddress::new("127.0.0.1", 9000)); + config.add_peer(PeerInfo::new(2, PeerAddress::new("127.0.0.1", 9001))); + config.add_peer(PeerInfo::learner(3, PeerAddress::new("127.0.0.1", 9002))); + + // Learners don't count toward quorum + assert_eq!(config.voting_members(), 2); // self + node 2 + assert_eq!(config.quorum_size(), 2); + assert_eq!(config.voting_peers().len(), 
1); // only node 2 + } + + #[test] + fn test_cluster_builder() { + let config = ClusterBuilder::new(1, PeerAddress::new("127.0.0.1", 9000)) + .with_peer(2, PeerAddress::new("127.0.0.1", 9001)) + .with_peer(3, PeerAddress::new("127.0.0.1", 9002)) + .with_learner(4, PeerAddress::new("127.0.0.1", 9003)) + .build(); + + assert_eq!(config.peer_ids().len(), 3); + assert_eq!(config.voting_members(), 3); + } + + #[test] + fn test_joint_consensus() { + let mut config = ClusterConfig::new(1, PeerAddress::new("127.0.0.1", 9000)); + config.add_peer(PeerInfo::new(2, PeerAddress::new("127.0.0.1", 9001))); + config.add_peer(PeerInfo::new(3, PeerAddress::new("127.0.0.1", 9002))); + + // Start config change: add node 4 + let mut new_peers = config.peers.clone(); + new_peers.insert(4, PeerInfo::new(4, PeerAddress::new("127.0.0.1", 9003))); + + config.begin_config_change(new_peers, 100); + assert!(config.is_joint()); + + // Need quorum in both old (2 of 3) and new (3 of 4) configs + assert!(config.has_joint_quorum(3, 2)); + assert!(!config.has_joint_quorum(2, 2)); // Not enough in new config + assert!(!config.has_joint_quorum(3, 1)); // Not enough in old config + + config.complete_config_change(); + assert!(!config.is_joint()); + assert_eq!(config.voting_members(), 4); + } + + #[test] + fn test_serialization() { + let config = ClusterBuilder::new(1, PeerAddress::new("127.0.0.1", 9000)) + .with_peer(2, PeerAddress::new("127.0.0.1", 9001)) + .build(); + + let bytes = config.to_bytes(); + let decoded = ClusterConfig::from_bytes(&bytes).unwrap(); + + assert_eq!(decoded.node_id, 1); + assert_eq!(decoded.peers.len(), 1); + } +} diff --git a/crates/synor-database/src/replication/election.rs b/crates/synor-database/src/replication/election.rs new file mode 100644 index 0000000..1b2c93f --- /dev/null +++ b/crates/synor-database/src/replication/election.rs @@ -0,0 +1,311 @@ +//! Leader election logic for Raft. 
+ +use super::log::ReplicatedLog; +use super::rpc::{RequestVote, RequestVoteResponse}; +use super::state::RaftState; +use std::collections::HashSet; + +/// Result of a leader election. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum ElectionResult { + /// Won the election, became leader. + Won, + /// Lost the election (higher term discovered). + Lost, + /// Election timed out, will retry. + Timeout, + /// Still in progress. + InProgress, +} + +/// Tracks election state. +pub struct Election { + /// Node ID of the candidate. + node_id: u64, + /// Term for this election. + term: u64, + /// Number of votes received (including self). + votes_received: HashSet, + /// Total number of nodes in cluster (including self). + cluster_size: usize, + /// Whether this election is still active. + active: bool, +} + +impl Election { + /// Starts a new election. + pub fn new(node_id: u64, term: u64, cluster_size: usize) -> Self { + let mut votes = HashSet::new(); + votes.insert(node_id); // Vote for self + + Self { + node_id, + term, + votes_received: votes, + cluster_size, + active: true, + } + } + + /// Returns the term for this election. + pub fn term(&self) -> u64 { + self.term + } + + /// Checks if the election is still active. + pub fn is_active(&self) -> bool { + self.active + } + + /// Gets the number of votes received. + pub fn vote_count(&self) -> usize { + self.votes_received.len() + } + + /// Gets the majority threshold. + pub fn majority(&self) -> usize { + (self.cluster_size / 2) + 1 + } + + /// Records a vote from a peer. + pub fn record_vote(&mut self, peer_id: u64, granted: bool) -> ElectionResult { + if !self.active { + return ElectionResult::Lost; + } + + if granted { + self.votes_received.insert(peer_id); + + if self.votes_received.len() >= self.majority() { + self.active = false; + return ElectionResult::Won; + } + } + + ElectionResult::InProgress + } + + /// Cancels the election (e.g., discovered higher term). 
+ pub fn cancel(&mut self) { + self.active = false; + } + + /// Creates a RequestVote message for this election. + pub fn create_request(&self, log: &ReplicatedLog) -> RequestVote { + RequestVote::new( + self.term, + self.node_id, + log.last_index(), + log.last_term(), + ) + } + + /// Checks the current result of the election. + pub fn result(&self) -> ElectionResult { + if !self.active { + ElectionResult::Lost + } else if self.votes_received.len() >= self.majority() { + ElectionResult::Won + } else { + ElectionResult::InProgress + } + } +} + +/// Handles vote requests from candidates. +pub struct VoteHandler; + +impl VoteHandler { + /// Processes a vote request and returns whether to grant the vote. + pub fn handle_request( + state: &mut RaftState, + log: &ReplicatedLog, + request: &RequestVote, + ) -> RequestVoteResponse { + // If request's term < current term, deny + if request.term < state.current_term { + return RequestVoteResponse::deny(state.current_term); + } + + // If request's term > current term, update term and become follower + if request.term > state.current_term { + state.become_follower(request.term); + } + + // Check if we can grant vote + let can_vote = state.voted_for.is_none() || state.voted_for == Some(request.candidate_id); + + // Check if candidate's log is at least as up-to-date as ours + let log_ok = log.is_up_to_date(request.last_log_index, request.last_log_term); + + if can_vote && log_ok { + state.voted_for = Some(request.candidate_id); + RequestVoteResponse::grant(state.current_term) + } else { + RequestVoteResponse::deny(state.current_term) + } + } + + /// Processes a vote response. 
+ pub fn handle_response( + state: &mut RaftState, + election: &mut Election, + from_peer: u64, + response: &RequestVoteResponse, + ) -> ElectionResult { + // If response's term > current term, become follower + if response.term > state.current_term { + state.become_follower(response.term); + election.cancel(); + return ElectionResult::Lost; + } + + // Record the vote + election.record_vote(from_peer, response.vote_granted) + } +} + +/// Election timeout generator. +pub struct ElectionTimeout { + /// Minimum timeout in milliseconds. + min_ms: u64, + /// Maximum timeout in milliseconds. + max_ms: u64, +} + +impl ElectionTimeout { + /// Creates a new timeout generator. + pub fn new(min_ms: u64, max_ms: u64) -> Self { + Self { min_ms, max_ms } + } + + /// Generates a random timeout duration. + pub fn random_timeout(&self) -> std::time::Duration { + use std::time::{SystemTime, UNIX_EPOCH}; + + // Simple pseudo-random based on current time + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos(); + + let range = self.max_ms - self.min_ms; + let random_add = (now % range as u128) as u64; + let timeout_ms = self.min_ms + random_add; + + std::time::Duration::from_millis(timeout_ms) + } + + /// Returns the minimum timeout. + pub fn min_timeout(&self) -> std::time::Duration { + std::time::Duration::from_millis(self.min_ms) + } + + /// Returns the maximum timeout. 
+ pub fn max_timeout(&self) -> std::time::Duration { + std::time::Duration::from_millis(self.max_ms) + } +} + +impl Default for ElectionTimeout { + fn default() -> Self { + // Default Raft election timeout: 150-300ms + Self::new(150, 300) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::replication::state::Command; + use crate::replication::log::LogEntry; + + #[test] + fn test_election_basic() { + let election = Election::new(1, 1, 5); + assert!(election.is_active()); + assert_eq!(election.vote_count(), 1); // Self vote + assert_eq!(election.majority(), 3); + } + + #[test] + fn test_election_win() { + let mut election = Election::new(1, 1, 5); + + // Need 3 votes total (including self) + assert_eq!(election.record_vote(2, true), ElectionResult::InProgress); + assert_eq!(election.record_vote(3, true), ElectionResult::Won); + } + + #[test] + fn test_election_rejected_votes() { + let mut election = Election::new(1, 1, 5); + + // Rejected votes don't count + assert_eq!(election.record_vote(2, false), ElectionResult::InProgress); + assert_eq!(election.record_vote(3, false), ElectionResult::InProgress); + assert_eq!(election.vote_count(), 1); + } + + #[test] + fn test_vote_handler_grant() { + let mut state = RaftState::new(); + let log = ReplicatedLog::new(); + + let request = RequestVote::new(1, 2, 0, 0); + let response = VoteHandler::handle_request(&mut state, &log, &request); + + assert!(response.vote_granted); + assert_eq!(state.voted_for, Some(2)); + } + + #[test] + fn test_vote_handler_deny_old_term() { + let mut state = RaftState::new(); + state.current_term = 5; + let log = ReplicatedLog::new(); + + let request = RequestVote::new(3, 2, 10, 3); + let response = VoteHandler::handle_request(&mut state, &log, &request); + + assert!(!response.vote_granted); + assert_eq!(response.term, 5); + } + + #[test] + fn test_vote_handler_deny_already_voted() { + let mut state = RaftState::new(); + state.current_term = 1; // Same term as request + state.voted_for 
= Some(3); // Already voted for node 3 in this term + let log = ReplicatedLog::new(); + + // Request from node 2 for the same term - should be denied + let request = RequestVote::new(1, 2, 0, 0); + let response = VoteHandler::handle_request(&mut state, &log, &request); + + assert!(!response.vote_granted); + } + + #[test] + fn test_vote_handler_deny_log_behind() { + let mut state = RaftState::new(); + let log = ReplicatedLog::new(); + log.append(LogEntry::new(2, 1, Command::Noop)); // Our log has term 2 + + let request = RequestVote::new(2, 2, 10, 1); // Candidate has term 1 entries + let response = VoteHandler::handle_request(&mut state, &log, &request); + + assert!(!response.vote_granted); + } + + #[test] + fn test_election_timeout() { + let timeout = ElectionTimeout::new(150, 300); + + for _ in 0..10 { + let duration = timeout.random_timeout(); + assert!(duration >= std::time::Duration::from_millis(150)); + assert!(duration <= std::time::Duration::from_millis(300)); + } + } +} diff --git a/crates/synor-database/src/replication/log.rs b/crates/synor-database/src/replication/log.rs new file mode 100644 index 0000000..417fbae --- /dev/null +++ b/crates/synor-database/src/replication/log.rs @@ -0,0 +1,387 @@ +//! Replicated log for Raft consensus. + +use super::state::Command; +use parking_lot::RwLock; +use serde::{Deserialize, Serialize}; + +/// A single entry in the replicated log. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct LogEntry { + /// Term when entry was received by leader. + pub term: u64, + /// Position in the log (1-indexed). + pub index: u64, + /// Command to execute. + pub command: Command, + /// Timestamp when entry was created. + pub timestamp: u64, +} + +impl LogEntry { + /// Creates a new log entry. 
+ pub fn new(term: u64, index: u64, command: Command) -> Self { + let timestamp = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis() as u64; + + Self { + term, + index, + command, + timestamp, + } + } + + /// Serializes to bytes. + pub fn to_bytes(&self) -> Vec { + bincode::serialize(self).unwrap_or_default() + } + + /// Deserializes from bytes. + pub fn from_bytes(bytes: &[u8]) -> Option { + bincode::deserialize(bytes).ok() + } +} + +/// Replicated log storing all commands. +pub struct ReplicatedLog { + /// Log entries. + entries: RwLock>, + /// Index of first entry in the log (for log compaction). + start_index: RwLock, + /// Term of last included entry (for snapshots). + snapshot_term: RwLock, +} + +impl ReplicatedLog { + /// Creates a new empty log. + pub fn new() -> Self { + Self { + entries: RwLock::new(Vec::new()), + start_index: RwLock::new(1), // 1-indexed + snapshot_term: RwLock::new(0), + } + } + + /// Returns the index of the last entry. + pub fn last_index(&self) -> u64 { + let entries = self.entries.read(); + let start = *self.start_index.read(); + if entries.is_empty() { + start.saturating_sub(1) + } else { + start + entries.len() as u64 - 1 + } + } + + /// Returns the term of the last entry. + pub fn last_term(&self) -> u64 { + let entries = self.entries.read(); + if entries.is_empty() { + *self.snapshot_term.read() + } else { + entries.last().map(|e| e.term).unwrap_or(0) + } + } + + /// Returns the term at a given index. + pub fn term_at(&self, index: u64) -> Option { + let start = *self.start_index.read(); + if index < start { + return if index == start - 1 { + Some(*self.snapshot_term.read()) + } else { + None + }; + } + + let entries = self.entries.read(); + let offset = (index - start) as usize; + entries.get(offset).map(|e| e.term) + } + + /// Gets an entry by index. 
+ pub fn get(&self, index: u64) -> Option { + let start = *self.start_index.read(); + if index < start { + return None; + } + + let entries = self.entries.read(); + let offset = (index - start) as usize; + entries.get(offset).cloned() + } + + /// Gets entries from start_index to end_index (inclusive). + pub fn get_range(&self, start_idx: u64, end_idx: u64) -> Vec { + let start = *self.start_index.read(); + if start_idx > end_idx || start_idx < start { + return Vec::new(); + } + + let entries = self.entries.read(); + let start_offset = (start_idx - start) as usize; + let end_offset = (end_idx - start + 1) as usize; + + entries + .get(start_offset..end_offset.min(entries.len())) + .map(|s| s.to_vec()) + .unwrap_or_default() + } + + /// Gets all entries from a given index. + pub fn entries_from(&self, from_index: u64) -> Vec { + let start = *self.start_index.read(); + if from_index < start { + return self.entries.read().clone(); + } + + let entries = self.entries.read(); + let offset = (from_index - start) as usize; + entries.get(offset..).map(|s| s.to_vec()).unwrap_or_default() + } + + /// Appends an entry to the log. + pub fn append(&self, entry: LogEntry) -> u64 { + let mut entries = self.entries.write(); + let index = entry.index; + entries.push(entry); + index + } + + /// Appends multiple entries, potentially overwriting conflicting entries. 
+ pub fn append_entries(&self, prev_index: u64, prev_term: u64, new_entries: Vec) -> bool { + // Check that prev entry matches + if prev_index > 0 { + if let Some(prev_entry_term) = self.term_at(prev_index) { + if prev_entry_term != prev_term { + // Conflict - need to truncate + return false; + } + } else if prev_index >= *self.start_index.read() { + // Missing entry + return false; + } + } + + let mut entries = self.entries.write(); + let start = *self.start_index.read(); + + for entry in new_entries { + let offset = (entry.index - start) as usize; + + if offset < entries.len() { + // Check for conflict + if entries[offset].term != entry.term { + // Delete this and all following entries + entries.truncate(offset); + entries.push(entry); + } + // Otherwise entry already exists with same term, skip + } else { + entries.push(entry); + } + } + + true + } + + /// Truncates the log after the given index. + pub fn truncate_after(&self, index: u64) { + let start = *self.start_index.read(); + if index < start { + return; + } + + let mut entries = self.entries.write(); + let offset = (index - start + 1) as usize; + entries.truncate(offset); + } + + /// Returns the number of entries. + pub fn len(&self) -> usize { + self.entries.read().len() + } + + /// Returns true if the log is empty. + pub fn is_empty(&self) -> bool { + self.entries.read().is_empty() + } + + /// Compacts the log up to (and including) the given index. + pub fn compact(&self, up_to_index: u64, up_to_term: u64) { + let mut entries = self.entries.write(); + let mut start = self.start_index.write(); + let mut snapshot_term = self.snapshot_term.write(); + + if up_to_index < *start { + return; + } + + let remove_count = (up_to_index - *start + 1) as usize; + if remove_count >= entries.len() { + entries.clear(); + } else { + entries.drain(0..remove_count); + } + + *start = up_to_index + 1; + *snapshot_term = up_to_term; + } + + /// Checks if this log is at least as up-to-date as the candidate's log. 
+ /// Used during leader election. + pub fn is_up_to_date(&self, candidate_last_index: u64, candidate_last_term: u64) -> bool { + let last_term = self.last_term(); + let last_index = self.last_index(); + + // Compare terms first, then indices + if last_term != candidate_last_term { + last_term <= candidate_last_term + } else { + last_index <= candidate_last_index + } + } + + /// Creates entries for replication starting from a given index. + pub fn entries_for_replication(&self, from_index: u64, max_entries: usize) -> (u64, u64, Vec) { + let prev_index = from_index.saturating_sub(1); + let prev_term = self.term_at(prev_index).unwrap_or(0); + + let entries = self.entries_from(from_index); + let limited = if entries.len() > max_entries { + entries[..max_entries].to_vec() + } else { + entries + }; + + (prev_index, prev_term, limited) + } +} + +impl Default for ReplicatedLog { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_log_basic() { + let log = ReplicatedLog::new(); + + assert_eq!(log.last_index(), 0); + assert_eq!(log.last_term(), 0); + assert!(log.is_empty()); + + let entry = LogEntry::new(1, 1, Command::Noop); + log.append(entry); + + assert_eq!(log.last_index(), 1); + assert_eq!(log.last_term(), 1); + assert!(!log.is_empty()); + } + + #[test] + fn test_log_append() { + let log = ReplicatedLog::new(); + + for i in 1..=5 { + log.append(LogEntry::new(1, i, Command::Noop)); + } + + assert_eq!(log.len(), 5); + assert_eq!(log.last_index(), 5); + + let entry = log.get(3).unwrap(); + assert_eq!(entry.index, 3); + } + + #[test] + fn test_log_range() { + let log = ReplicatedLog::new(); + + for i in 1..=10 { + log.append(LogEntry::new(1, i, Command::Noop)); + } + + let range = log.get_range(3, 7); + assert_eq!(range.len(), 5); + assert_eq!(range[0].index, 3); + assert_eq!(range[4].index, 7); + } + + #[test] + fn test_log_compact() { + let log = ReplicatedLog::new(); + + for i in 1..=10 { + 
log.append(LogEntry::new(1, i, Command::Noop)); + } + + log.compact(5, 1); + + assert_eq!(log.len(), 5); + assert!(log.get(5).is_none()); + assert!(log.get(6).is_some()); + assert_eq!(log.last_index(), 10); + } + + #[test] + fn test_log_up_to_date() { + let log = ReplicatedLog::new(); + + log.append(LogEntry::new(1, 1, Command::Noop)); + log.append(LogEntry::new(2, 2, Command::Noop)); + + // Same log is up to date + assert!(log.is_up_to_date(2, 2)); + + // Higher term is more up to date + assert!(log.is_up_to_date(1, 3)); + + // Same term, higher index is more up to date + assert!(log.is_up_to_date(3, 2)); + + // Lower term is not up to date + assert!(!log.is_up_to_date(10, 1)); + } + + #[test] + fn test_append_entries() { + let log = ReplicatedLog::new(); + + // Initial entries + log.append(LogEntry::new(1, 1, Command::Noop)); + log.append(LogEntry::new(1, 2, Command::Noop)); + + // Append more entries + let new_entries = vec![ + LogEntry::new(1, 3, Command::Noop), + LogEntry::new(2, 4, Command::Noop), + ]; + + let success = log.append_entries(2, 1, new_entries); + assert!(success); + assert_eq!(log.len(), 4); + assert_eq!(log.term_at(4), Some(2)); + } + + #[test] + fn test_entries_for_replication() { + let log = ReplicatedLog::new(); + + for i in 1..=5 { + log.append(LogEntry::new(1, i, Command::Noop)); + } + + let (prev_idx, prev_term, entries) = log.entries_for_replication(3, 10); + assert_eq!(prev_idx, 2); + assert_eq!(prev_term, 1); + assert_eq!(entries.len(), 3); + } +} diff --git a/crates/synor-database/src/replication/mod.rs b/crates/synor-database/src/replication/mod.rs new file mode 100644 index 0000000..2215cce --- /dev/null +++ b/crates/synor-database/src/replication/mod.rs @@ -0,0 +1,23 @@ +//! Raft consensus-based replication for high availability. +//! +//! Provides distributed consensus to ensure data consistency across +//! multiple database nodes with automatic leader election and failover. 
+ +pub mod cluster; +pub mod election; +pub mod log; +pub mod raft; +pub mod rpc; +pub mod snapshot; +pub mod state; + +pub use cluster::{ClusterBuilder, ClusterConfig, NodeId, PeerAddress, PeerInfo, PeerState}; +pub use election::{Election, ElectionResult, ElectionTimeout, VoteHandler}; +pub use log::{LogEntry, ReplicatedLog}; +pub use raft::{ApplyResult, RaftConfig, RaftEvent, RaftNode}; +pub use rpc::{ + AppendEntries, AppendEntriesResponse, InstallSnapshot, InstallSnapshotResponse, RequestVote, + RequestVoteResponse, RpcMessage, +}; +pub use snapshot::{Snapshot, SnapshotConfig, SnapshotManager, SnapshotMetadata}; +pub use state::{Command, LeaderState, NodeRole, RaftState}; diff --git a/crates/synor-database/src/replication/raft.rs b/crates/synor-database/src/replication/raft.rs new file mode 100644 index 0000000..2a9adcb --- /dev/null +++ b/crates/synor-database/src/replication/raft.rs @@ -0,0 +1,955 @@ +//! Raft consensus implementation. + +use super::cluster::{ClusterConfig, NodeId, PeerState}; +use super::election::{Election, ElectionResult, ElectionTimeout, VoteHandler}; +use super::log::{LogEntry, ReplicatedLog}; +use super::rpc::{ + AppendEntries, AppendEntriesResponse, InstallSnapshot, InstallSnapshotResponse, RequestVote, + RequestVoteResponse, RpcMessage, +}; +use super::snapshot::{Snapshot, SnapshotConfig, SnapshotManager}; +use super::state::{Command, LeaderState, NodeRole, RaftState}; +use serde::{Deserialize, Serialize}; +use std::time::{Duration, Instant}; + +/// Configuration for the Raft node. +#[derive(Clone, Debug)] +pub struct RaftConfig { + /// Minimum election timeout in milliseconds. + pub election_timeout_min: u64, + /// Maximum election timeout in milliseconds. + pub election_timeout_max: u64, + /// Heartbeat interval in milliseconds. + pub heartbeat_interval: u64, + /// Maximum entries per AppendEntries RPC. + pub max_entries_per_rpc: usize, + /// Snapshot threshold (entries before compaction). 
+ pub snapshot_threshold: u64, + /// Maximum snapshot chunk size. + pub snapshot_chunk_size: usize, +} + +impl Default for RaftConfig { + fn default() -> Self { + Self { + election_timeout_min: 150, + election_timeout_max: 300, + heartbeat_interval: 50, + max_entries_per_rpc: 100, + snapshot_threshold: 10000, + snapshot_chunk_size: 65536, + } + } +} + +/// Result of applying a command. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum ApplyResult { + /// Command applied successfully. + Success(Vec), + /// Command failed. + Error(String), + /// Not the leader, redirect to leader. + NotLeader(Option), +} + +/// Events that can be produced by the Raft node. +#[derive(Clone, Debug)] +pub enum RaftEvent { + /// Send RPC to a peer. + SendRpc(NodeId, RpcMessage), + /// Broadcast RPC to all peers. + BroadcastRpc(RpcMessage), + /// Apply committed entry to state machine. + ApplyEntry(u64, Command), + /// Became leader. + BecameLeader, + /// Became follower. + BecameFollower(Option), + /// Snapshot should be taken. + TakeSnapshot, + /// Log compacted up to index. + LogCompacted(u64), +} + +/// The Raft consensus node. +pub struct RaftNode { + /// Node ID. + id: NodeId, + /// Cluster configuration. + cluster: ClusterConfig, + /// Raft configuration. + config: RaftConfig, + /// Persistent state. + state: RaftState, + /// Replicated log. + log: ReplicatedLog, + /// Current election (only during candidacy). + election: Option, + /// Election timeout generator. + election_timeout: ElectionTimeout, + /// Snapshot manager. + snapshots: SnapshotManager, + /// Leader state (only valid when leader). + leader_state: Option, + /// Current known leader. + leader_id: Option, + /// Last heartbeat/message from leader. + last_leader_contact: Instant, + /// Current election timeout duration. + current_timeout: Duration, + /// Pending events. + events: Vec, +} + +impl RaftNode { + /// Creates a new Raft node. 
+ pub fn new(id: NodeId, cluster: ClusterConfig, config: RaftConfig) -> Self { + let election_timeout = + ElectionTimeout::new(config.election_timeout_min, config.election_timeout_max); + let current_timeout = election_timeout.random_timeout(); + + Self { + id, + cluster, + state: RaftState::new(), + log: ReplicatedLog::new(), + election: None, + election_timeout, + snapshots: SnapshotManager::new(config.snapshot_threshold), + leader_state: None, + leader_id: None, + last_leader_contact: Instant::now(), + current_timeout, + events: Vec::new(), + config, + } + } + + /// Returns the node ID. + pub fn id(&self) -> NodeId { + self.id + } + + /// Returns the current term. + pub fn current_term(&self) -> u64 { + self.state.current_term + } + + /// Returns the current role. + pub fn role(&self) -> NodeRole { + self.state.role + } + + /// Returns the current leader ID if known. + pub fn leader(&self) -> Option { + self.leader_id + } + + /// Returns true if this node is the leader. + pub fn is_leader(&self) -> bool { + self.state.is_leader() + } + + /// Returns the commit index. + pub fn commit_index(&self) -> u64 { + self.state.commit_index + } + + /// Returns the last applied index. + pub fn last_applied(&self) -> u64 { + self.state.last_applied + } + + /// Returns the log length. + pub fn log_length(&self) -> u64 { + self.log.last_index() + } + + /// Drains pending events. + pub fn drain_events(&mut self) -> Vec { + std::mem::take(&mut self.events) + } + + /// Called periodically to drive the Raft state machine. 
+ pub fn tick(&mut self) { + match self.state.role { + NodeRole::Leader => self.tick_leader(), + NodeRole::Follower => self.tick_follower(), + NodeRole::Candidate => self.tick_candidate(), + } + + // Apply committed entries + self.apply_committed_entries(); + + // Check if snapshot needed + if self + .snapshots + .should_snapshot(self.log.last_index(), self.snapshots.last_included_index()) + { + self.events.push(RaftEvent::TakeSnapshot); + } + } + + fn tick_leader(&mut self) { + // Send heartbeats to all peers + self.send_heartbeats(); + } + + fn tick_follower(&mut self) { + // Check for election timeout + if self.last_leader_contact.elapsed() >= self.current_timeout { + self.start_election(); + } + } + + fn tick_candidate(&mut self) { + // Check for election timeout + if self.last_leader_contact.elapsed() >= self.current_timeout { + // Start a new election + self.start_election(); + } + } + + fn start_election(&mut self) { + // Increment term and become candidate + self.state.become_candidate(); + self.state.voted_for = Some(self.id); + self.leader_id = None; + + // Reset timeout + self.reset_election_timeout(); + + // Create new election + let cluster_size = self.cluster.voting_members(); + self.election = Some(Election::new(self.id, self.state.current_term, cluster_size)); + + // Create RequestVote message + let request = RequestVote::new( + self.state.current_term, + self.id, + self.log.last_index(), + self.log.last_term(), + ); + + self.events + .push(RaftEvent::BroadcastRpc(RpcMessage::RequestVote(request))); + + // Check if we already have quorum (single-node cluster) + if cluster_size == 1 { + self.become_leader(); + } + } + + fn become_leader(&mut self) { + self.state.become_leader(); + self.leader_id = Some(self.id); + self.election = None; + + // Initialize leader state + let peer_ids: Vec<_> = self.cluster.peer_ids(); + self.leader_state = Some(LeaderState::new(self.log.last_index(), &peer_ids)); + + self.events.push(RaftEvent::BecameLeader); + + // 
Send immediate heartbeats + self.send_heartbeats(); + } + + fn become_follower(&mut self, term: u64, leader: Option) { + self.state.become_follower(term); + self.leader_id = leader; + self.leader_state = None; + self.election = None; + self.reset_election_timeout(); + + self.events.push(RaftEvent::BecameFollower(leader)); + } + + fn reset_election_timeout(&mut self) { + self.last_leader_contact = Instant::now(); + self.current_timeout = self.election_timeout.random_timeout(); + } + + fn send_heartbeats(&mut self) { + if !self.is_leader() { + return; + } + + for peer_id in self.cluster.peer_ids() { + self.send_append_entries(peer_id); + } + } + + fn send_append_entries(&mut self, peer_id: NodeId) { + let leader_state = match &self.leader_state { + Some(ls) => ls, + None => return, + }; + + let next_index = *leader_state.next_index.get(&peer_id).unwrap_or(&1); + + // Check if we need to send snapshot instead + if next_index <= self.snapshots.last_included_index() { + self.send_install_snapshot(peer_id); + return; + } + + let (prev_log_index, prev_log_term, entries) = + self.log + .entries_for_replication(next_index, self.config.max_entries_per_rpc); + + let request = AppendEntries::with_entries( + self.state.current_term, + self.id, + prev_log_index, + prev_log_term, + entries, + self.state.commit_index, + ); + + self.events + .push(RaftEvent::SendRpc(peer_id, RpcMessage::AppendEntries(request))); + } + + fn send_install_snapshot(&mut self, peer_id: NodeId) { + let snapshot = match self.snapshots.get_snapshot() { + Some(s) => s, + None => return, + }; + + let chunks = self + .snapshots + .chunk_snapshot(self.config.snapshot_chunk_size); + if let Some((offset, data, done)) = chunks.into_iter().next() { + let request = InstallSnapshot::new( + self.state.current_term, + self.id, + snapshot.metadata.last_included_index, + snapshot.metadata.last_included_term, + offset, + data, + done, + ); + + self.events + .push(RaftEvent::SendRpc(peer_id, 
RpcMessage::InstallSnapshot(request))); + } + } + + /// Handles an incoming RPC message. + pub fn handle_rpc(&mut self, from: NodeId, message: RpcMessage) -> Option { + match message { + RpcMessage::RequestVote(req) => { + let response = self.handle_request_vote(from, req); + Some(RpcMessage::RequestVoteResponse(response)) + } + RpcMessage::RequestVoteResponse(resp) => { + self.handle_request_vote_response(from, resp); + None + } + RpcMessage::AppendEntries(req) => { + let response = self.handle_append_entries(from, req); + Some(RpcMessage::AppendEntriesResponse(response)) + } + RpcMessage::AppendEntriesResponse(resp) => { + self.handle_append_entries_response(from, resp); + None + } + RpcMessage::InstallSnapshot(req) => { + let response = self.handle_install_snapshot(from, req); + Some(RpcMessage::InstallSnapshotResponse(response)) + } + RpcMessage::InstallSnapshotResponse(resp) => { + self.handle_install_snapshot_response(from, resp); + None + } + } + } + + fn handle_request_vote(&mut self, _from: NodeId, req: RequestVote) -> RequestVoteResponse { + // Use the VoteHandler from election module + VoteHandler::handle_request(&mut self.state, &self.log, &req) + } + + fn handle_request_vote_response(&mut self, from: NodeId, resp: RequestVoteResponse) { + // Ignore if not candidate + if !self.state.is_candidate() { + return; + } + + // Use the VoteHandler + if let Some(ref mut election) = self.election { + let result = VoteHandler::handle_response(&mut self.state, election, from, &resp); + + match result { + ElectionResult::Won => { + self.become_leader(); + } + ElectionResult::Lost => { + // Already handled by VoteHandler (became follower) + self.election = None; + } + ElectionResult::InProgress | ElectionResult::Timeout => {} + } + } + } + + fn handle_append_entries(&mut self, _from: NodeId, req: AppendEntries) -> AppendEntriesResponse { + // Rule: If term > currentTerm, become follower + if req.term > self.state.current_term { + self.become_follower(req.term, 
Some(req.leader_id)); + } + + // Reject if term is old + if req.term < self.state.current_term { + return AppendEntriesResponse::failure(self.state.current_term); + } + + // Valid AppendEntries from leader - reset election timeout + self.reset_election_timeout(); + self.leader_id = Some(req.leader_id); + + // If we're candidate, step down + if self.state.is_candidate() { + self.become_follower(req.term, Some(req.leader_id)); + } + + // Try to append entries + let success = + self.log + .append_entries(req.prev_log_index, req.prev_log_term, req.entries); + + if success { + // Update commit index + if req.leader_commit > self.state.commit_index { + self.state + .update_commit_index(std::cmp::min(req.leader_commit, self.log.last_index())); + } + + AppendEntriesResponse::success(self.state.current_term, self.log.last_index()) + } else { + // Find conflict info for faster recovery + if let Some(conflict_term) = self.log.term_at(req.prev_log_index) { + // Find first index of conflicting term + let mut conflict_index = req.prev_log_index; + while conflict_index > 1 { + if let Some(term) = self.log.term_at(conflict_index - 1) { + if term != conflict_term { + break; + } + } else { + break; + } + conflict_index -= 1; + } + AppendEntriesResponse::conflict(self.state.current_term, conflict_term, conflict_index) + } else { + AppendEntriesResponse::failure(self.state.current_term) + } + } + } + + fn handle_append_entries_response(&mut self, from: NodeId, resp: AppendEntriesResponse) { + // Ignore if not leader + if !self.is_leader() { + return; + } + + // Step down if term is higher + if resp.term > self.state.current_term { + self.become_follower(resp.term, None); + return; + } + + let leader_state = match &mut self.leader_state { + Some(ls) => ls, + None => return, + }; + + if resp.success { + // Update match_index and next_index + leader_state.update_indices(from, resp.match_index); + + // Try to advance commit index + self.try_advance_commit_index(); + } else { + // Use 
conflict info if available for faster recovery + if let (Some(conflict_term), Some(conflict_index)) = + (resp.conflict_term, resp.conflict_index) + { + // Search for last entry with conflict_term + let mut new_next = conflict_index; + for idx in (1..=self.log.last_index()).rev() { + if let Some(term) = self.log.term_at(idx) { + if term == conflict_term { + new_next = idx + 1; + break; + } + if term < conflict_term { + break; + } + } + } + leader_state.next_index.insert(from, new_next); + } else { + // Simple decrement + leader_state.decrement_next_index(from); + } + } + + // Update peer state + self.cluster.update_peer_state(from, PeerState::Reachable); + } + + fn handle_install_snapshot(&mut self, _from: NodeId, req: InstallSnapshot) -> InstallSnapshotResponse { + // Rule: If term > currentTerm, become follower + if req.term > self.state.current_term { + self.become_follower(req.term, Some(req.leader_id)); + } + + // Reject if term is old + if req.term < self.state.current_term { + return InstallSnapshotResponse::new(self.state.current_term); + } + + // Reset election timeout + self.reset_election_timeout(); + self.leader_id = Some(req.leader_id); + + // Start or continue receiving snapshot + if req.offset == 0 { + self.snapshots.start_receiving( + req.last_included_index, + req.last_included_term, + req.offset, + req.data, + ); + } else { + self.snapshots.add_chunk(req.offset, req.data); + } + + // If done, finalize snapshot + if req.done { + if let Some(_snapshot) = self.snapshots.finalize_snapshot() { + // Discard log up to snapshot + self.log + .compact(req.last_included_index, req.last_included_term); + + // Update state + if req.last_included_index > self.state.commit_index { + self.state.update_commit_index(req.last_included_index); + } + if req.last_included_index > self.state.last_applied { + self.state.update_last_applied(req.last_included_index); + } + + self.events + .push(RaftEvent::LogCompacted(req.last_included_index)); + } + } + + 
InstallSnapshotResponse::new(self.state.current_term) + } + + fn handle_install_snapshot_response(&mut self, from: NodeId, resp: InstallSnapshotResponse) { + if !self.is_leader() { + return; + } + + if resp.term > self.state.current_term { + self.become_follower(resp.term, None); + return; + } + + // Update next_index for peer (assuming success since there's no success field) + if let Some(leader_state) = &mut self.leader_state { + let snapshot_index = self.snapshots.last_included_index(); + leader_state.update_indices(from, snapshot_index); + } + + self.cluster.update_peer_state(from, PeerState::Reachable); + } + + fn try_advance_commit_index(&mut self) { + if !self.is_leader() { + return; + } + + let leader_state = match &self.leader_state { + Some(ls) => ls, + None => return, + }; + + // Calculate new commit index + let new_commit = leader_state.calculate_commit_index( + self.state.commit_index, + self.state.current_term, + |idx| self.log.term_at(idx), + ); + + if new_commit > self.state.commit_index { + self.state.update_commit_index(new_commit); + } + } + + fn apply_committed_entries(&mut self) { + while self.state.last_applied < self.state.commit_index { + let next_to_apply = self.state.last_applied + 1; + + if let Some(entry) = self.log.get(next_to_apply) { + self.events + .push(RaftEvent::ApplyEntry(entry.index, entry.command.clone())); + self.state.update_last_applied(next_to_apply); + } else { + break; + } + } + } + + /// Proposes a command (only valid on leader). 
+ pub fn propose(&mut self, command: Command) -> Result { + if !self.is_leader() { + return Err(ApplyResult::NotLeader(self.leader_id)); + } + + // Create log entry + let index = self.log.last_index() + 1; + let entry = LogEntry::new(self.state.current_term, index, command); + let result_index = self.log.append(entry); + + // Send AppendEntries to all peers + for peer_id in self.cluster.peer_ids() { + self.send_append_entries(peer_id); + } + + Ok(result_index) + } + + /// Takes a snapshot of the state machine. + pub fn take_snapshot(&mut self, data: Vec) { + let last_index = self.state.last_applied; + let last_term = self.log.term_at(last_index).unwrap_or(0); + + let config = SnapshotConfig { + nodes: std::iter::once(self.id) + .chain(self.cluster.peer_ids()) + .collect(), + }; + + self.snapshots + .create_snapshot(last_index, last_term, config, data); + + // Compact log + self.log.compact(last_index, last_term); + self.events.push(RaftEvent::LogCompacted(last_index)); + } + + /// Adds a new server to the cluster. + pub fn add_server( + &mut self, + id: NodeId, + address: super::cluster::PeerAddress, + ) -> Result<(), String> { + if !self.is_leader() { + return Err("Not leader".to_string()); + } + + let peer = super::cluster::PeerInfo::new(id, address); + self.cluster.add_peer(peer); + + // Initialize leader state for new peer + if let Some(ref mut ls) = self.leader_state { + ls.next_index.insert(id, self.log.last_index() + 1); + ls.match_index.insert(id, 0); + } + + Ok(()) + } + + /// Removes a server from the cluster. + pub fn remove_server(&mut self, id: NodeId) -> Result<(), String> { + if !self.is_leader() { + return Err("Not leader".to_string()); + } + + self.cluster.remove_peer(id); + + if let Some(ref mut ls) = self.leader_state { + ls.next_index.remove(&id); + ls.match_index.remove(&id); + } + + Ok(()) + } + + /// Forces an election timeout (for testing). 
+ #[cfg(test)] + pub fn force_election_timeout(&mut self) { + self.last_leader_contact = Instant::now() - self.current_timeout - Duration::from_secs(1); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use super::super::cluster::PeerAddress; + + fn create_test_cluster(node_id: NodeId, peers: &[NodeId]) -> ClusterConfig { + let mut cluster = + ClusterConfig::new(node_id, PeerAddress::new("127.0.0.1", 9000 + node_id as u16)); + for &peer in peers { + cluster.add_peer(super::super::cluster::PeerInfo::new( + peer, + PeerAddress::new("127.0.0.1", 9000 + peer as u16), + )); + } + cluster + } + + #[test] + fn test_single_node_election() { + let cluster = create_test_cluster(1, &[]); + let config = RaftConfig::default(); + let mut node = RaftNode::new(1, cluster, config); + + assert!(matches!(node.role(), NodeRole::Follower)); + + // Simulate election timeout + node.force_election_timeout(); + node.tick(); + + // Single node should become leader immediately + assert!(node.is_leader()); + } + + #[test] + fn test_three_node_election() { + let cluster = create_test_cluster(1, &[2, 3]); + let config = RaftConfig::default(); + let mut node1 = RaftNode::new(1, cluster, config); + + // Trigger election + node1.force_election_timeout(); + node1.tick(); + + assert!(matches!(node1.role(), NodeRole::Candidate)); + assert_eq!(node1.current_term(), 1); + + // Simulate receiving votes + let vote_resp = RequestVoteResponse::grant(1); + + node1.handle_rpc(2, RpcMessage::RequestVoteResponse(vote_resp.clone())); + + // Should become leader after receiving vote from node 2 (2 votes = quorum of 3) + assert!(node1.is_leader()); + } + + #[test] + fn test_append_entries() { + let cluster = create_test_cluster(1, &[]); + let config = RaftConfig::default(); + let mut leader = RaftNode::new(1, cluster, config); + + // Become leader + leader.force_election_timeout(); + leader.tick(); + assert!(leader.is_leader()); + + // Propose a command + let command = Command::KvSet { + key: 
"test".to_string(), + value: vec![1, 2, 3], + ttl: None, + }; + let index = leader.propose(command).unwrap(); + + assert_eq!(index, 1); + assert_eq!(leader.log_length(), 1); + } + + #[test] + fn test_follower_receives_append_entries() { + let cluster1 = create_test_cluster(1, &[2]); + let cluster2 = create_test_cluster(2, &[1]); + let config = RaftConfig::default(); + + let mut leader = RaftNode::new(1, cluster1, config.clone()); + let mut follower = RaftNode::new(2, cluster2, config); + + // Make node 1 leader + leader.force_election_timeout(); + leader.tick(); + + // Simulate vote from node 2 + leader.handle_rpc( + 2, + RpcMessage::RequestVoteResponse(RequestVoteResponse::grant(1)), + ); + + assert!(leader.is_leader()); + + // Propose a command + leader + .propose(Command::KvSet { + key: "key1".to_string(), + value: vec![1], + ttl: None, + }) + .unwrap(); + + // Get AppendEntries from leader events (skip heartbeats, find one with entries) + let events = leader.drain_events(); + let append_req = events + .iter() + .find_map(|e| { + if let RaftEvent::SendRpc(2, RpcMessage::AppendEntries(req)) = e { + // Skip heartbeats (empty entries), find the one with actual entries + if !req.entries.is_empty() { + Some(req.clone()) + } else { + None + } + } else { + None + } + }) + .unwrap(); + + // Send to follower + let response = follower + .handle_rpc(1, RpcMessage::AppendEntries(append_req)) + .unwrap(); + + if let RpcMessage::AppendEntriesResponse(resp) = response { + assert!(resp.success); + assert_eq!(resp.match_index, 1); + } + + assert_eq!(follower.log_length(), 1); + } + + #[test] + fn test_step_down_on_higher_term() { + let cluster = create_test_cluster(1, &[2]); + let config = RaftConfig::default(); + let mut node = RaftNode::new(1, cluster, config); + + // Make it leader + node.force_election_timeout(); + node.tick(); + node.handle_rpc( + 2, + RpcMessage::RequestVoteResponse(RequestVoteResponse::grant(1)), + ); + + assert!(node.is_leader()); + 
assert_eq!(node.current_term(), 1); + + // Receive AppendEntries with higher term + node.handle_rpc( + 2, + RpcMessage::AppendEntries(AppendEntries::heartbeat(5, 2, 0, 0, 0)), + ); + + assert!(!node.is_leader()); + assert_eq!(node.current_term(), 5); + assert_eq!(node.leader(), Some(2)); + } + + #[test] + fn test_commit_index_advancement() { + let cluster = create_test_cluster(1, &[2, 3]); + let config = RaftConfig::default(); + let mut leader = RaftNode::new(1, cluster, config); + + // Become leader + leader.force_election_timeout(); + leader.tick(); + leader.handle_rpc( + 2, + RpcMessage::RequestVoteResponse(RequestVoteResponse::grant(1)), + ); + leader.handle_rpc( + 3, + RpcMessage::RequestVoteResponse(RequestVoteResponse::grant(1)), + ); + + // Propose command + leader + .propose(Command::KvSet { + key: "key".to_string(), + value: vec![1], + ttl: None, + }) + .unwrap(); + + // Simulate successful replication to node 2 + leader.handle_rpc( + 2, + RpcMessage::AppendEntriesResponse(AppendEntriesResponse::success(1, 1)), + ); + + // Commit index should advance (quorum = 2, we have leader + node2) + assert_eq!(leader.commit_index(), 1); + } + + #[test] + fn test_snapshot() { + let cluster = create_test_cluster(1, &[]); + let config = RaftConfig::default(); + let mut node = RaftNode::new(1, cluster, config); + + // Become leader + node.force_election_timeout(); + node.tick(); + + // Propose and commit some entries + for i in 0..5 { + node.propose(Command::KvSet { + key: format!("key{}", i), + value: vec![i as u8], + ttl: None, + }) + .unwrap(); + } + + // Manually advance commit and apply + node.state.update_commit_index(5); + node.tick(); + + // Take snapshot + node.take_snapshot(vec![1, 2, 3, 4, 5]); + + // Check snapshot was created + assert!(node.snapshots.get_snapshot().is_some()); + } + + #[test] + fn test_request_vote_log_check() { + let cluster = create_test_cluster(1, &[2]); + let config = RaftConfig::default(); + let mut node = RaftNode::new(1, cluster, 
config); + + // Add some entries to node's log + node.log.append(LogEntry::new(1, 1, Command::Noop)); + node.log.append(LogEntry::new(2, 2, Command::Noop)); + + // Request vote from candidate with shorter log (lower term) + let req = RequestVote::new(3, 2, 1, 1); + + let resp = node.handle_rpc(2, RpcMessage::RequestVote(req)); + if let Some(RpcMessage::RequestVoteResponse(r)) = resp { + // Should not grant vote - our log is more up-to-date (term 2 > term 1) + assert!(!r.vote_granted); + } + + // Request vote from candidate with equal or better log + let req2 = RequestVote::new(4, 2, 2, 2); + + let resp2 = node.handle_rpc(2, RpcMessage::RequestVote(req2)); + if let Some(RpcMessage::RequestVoteResponse(r)) = resp2 { + assert!(r.vote_granted); + } + } +} diff --git a/crates/synor-database/src/replication/rpc.rs b/crates/synor-database/src/replication/rpc.rs new file mode 100644 index 0000000..2ed68f7 --- /dev/null +++ b/crates/synor-database/src/replication/rpc.rs @@ -0,0 +1,318 @@ +//! RPC messages for Raft consensus. + +use super::log::LogEntry; +use serde::{Deserialize, Serialize}; + +/// All RPC message types in Raft. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum RpcMessage { + /// Request vote from candidate. + RequestVote(RequestVote), + /// Response to vote request. + RequestVoteResponse(RequestVoteResponse), + /// Append entries from leader. + AppendEntries(AppendEntries), + /// Response to append entries. + AppendEntriesResponse(AppendEntriesResponse), + /// Install snapshot from leader. + InstallSnapshot(InstallSnapshot), + /// Response to install snapshot. + InstallSnapshotResponse(InstallSnapshotResponse), +} + +impl RpcMessage { + /// Returns the term of this message. 
+ pub fn term(&self) -> u64 { + match self { + RpcMessage::RequestVote(r) => r.term, + RpcMessage::RequestVoteResponse(r) => r.term, + RpcMessage::AppendEntries(r) => r.term, + RpcMessage::AppendEntriesResponse(r) => r.term, + RpcMessage::InstallSnapshot(r) => r.term, + RpcMessage::InstallSnapshotResponse(r) => r.term, + } + } + + /// Serializes to bytes. + pub fn to_bytes(&self) -> Vec { + bincode::serialize(self).unwrap_or_default() + } + + /// Deserializes from bytes. + pub fn from_bytes(bytes: &[u8]) -> Option { + bincode::deserialize(bytes).ok() + } +} + +/// RequestVote RPC (sent by candidates to gather votes). +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct RequestVote { + /// Candidate's term. + pub term: u64, + /// Candidate requesting vote. + pub candidate_id: u64, + /// Index of candidate's last log entry. + pub last_log_index: u64, + /// Term of candidate's last log entry. + pub last_log_term: u64, +} + +impl RequestVote { + /// Creates a new RequestVote message. + pub fn new(term: u64, candidate_id: u64, last_log_index: u64, last_log_term: u64) -> Self { + Self { + term, + candidate_id, + last_log_index, + last_log_term, + } + } +} + +/// Response to RequestVote. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct RequestVoteResponse { + /// Current term (for candidate to update). + pub term: u64, + /// True if candidate received vote. + pub vote_granted: bool, +} + +impl RequestVoteResponse { + /// Creates a positive response. + pub fn grant(term: u64) -> Self { + Self { + term, + vote_granted: true, + } + } + + /// Creates a negative response. + pub fn deny(term: u64) -> Self { + Self { + term, + vote_granted: false, + } + } +} + +/// AppendEntries RPC (sent by leader for replication and heartbeat). +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct AppendEntries { + /// Leader's term. + pub term: u64, + /// Leader ID (so follower can redirect clients). 
+ pub leader_id: u64, + /// Index of log entry immediately preceding new ones. + pub prev_log_index: u64, + /// Term of prev_log_index entry. + pub prev_log_term: u64, + /// Log entries to store (empty for heartbeat). + pub entries: Vec, + /// Leader's commit index. + pub leader_commit: u64, +} + +impl AppendEntries { + /// Creates a heartbeat (empty entries). + pub fn heartbeat( + term: u64, + leader_id: u64, + prev_log_index: u64, + prev_log_term: u64, + leader_commit: u64, + ) -> Self { + Self { + term, + leader_id, + prev_log_index, + prev_log_term, + entries: Vec::new(), + leader_commit, + } + } + + /// Creates an append entries request with entries. + pub fn with_entries( + term: u64, + leader_id: u64, + prev_log_index: u64, + prev_log_term: u64, + entries: Vec, + leader_commit: u64, + ) -> Self { + Self { + term, + leader_id, + prev_log_index, + prev_log_term, + entries, + leader_commit, + } + } + + /// Returns true if this is a heartbeat (no entries). + pub fn is_heartbeat(&self) -> bool { + self.entries.is_empty() + } +} + +/// Response to AppendEntries. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct AppendEntriesResponse { + /// Current term (for leader to update). + pub term: u64, + /// True if follower contained entry matching prev_log_index and prev_log_term. + pub success: bool, + /// Index of last entry appended (for quick catch-up). + pub match_index: u64, + /// If false, the conflicting term (for optimization). + pub conflict_term: Option, + /// First index of conflicting term. + pub conflict_index: Option, +} + +impl AppendEntriesResponse { + /// Creates a success response. + pub fn success(term: u64, match_index: u64) -> Self { + Self { + term, + success: true, + match_index, + conflict_term: None, + conflict_index: None, + } + } + + /// Creates a failure response. 
+ pub fn failure(term: u64) -> Self { + Self { + term, + success: false, + match_index: 0, + conflict_term: None, + conflict_index: None, + } + } + + /// Creates a failure response with conflict info. + pub fn conflict(term: u64, conflict_term: u64, conflict_index: u64) -> Self { + Self { + term, + success: false, + match_index: 0, + conflict_term: Some(conflict_term), + conflict_index: Some(conflict_index), + } + } +} + +/// InstallSnapshot RPC (sent by leader when follower is too far behind). +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct InstallSnapshot { + /// Leader's term. + pub term: u64, + /// Leader ID. + pub leader_id: u64, + /// Index of last entry included in snapshot. + pub last_included_index: u64, + /// Term of last entry included in snapshot. + pub last_included_term: u64, + /// Byte offset where chunk is positioned. + pub offset: u64, + /// Raw bytes of snapshot chunk. + pub data: Vec, + /// True if this is the last chunk. + pub done: bool, +} + +impl InstallSnapshot { + /// Creates a new snapshot installation request. + pub fn new( + term: u64, + leader_id: u64, + last_included_index: u64, + last_included_term: u64, + offset: u64, + data: Vec, + done: bool, + ) -> Self { + Self { + term, + leader_id, + last_included_index, + last_included_term, + offset, + data, + done, + } + } +} + +/// Response to InstallSnapshot. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct InstallSnapshotResponse { + /// Current term (for leader to update). + pub term: u64, +} + +impl InstallSnapshotResponse { + /// Creates a new response. 
    pub fn new(term: u64) -> Self {
        Self { term }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::replication::state::Command;

    // Constructors carry their fields through unchanged.
    #[test]
    fn test_request_vote() {
        let request = RequestVote::new(1, 1, 10, 1);
        assert_eq!(request.term, 1);
        assert_eq!(request.candidate_id, 1);

        let grant = RequestVoteResponse::grant(1);
        assert!(grant.vote_granted);

        let deny = RequestVoteResponse::deny(2);
        assert!(!deny.vote_granted);
        assert_eq!(deny.term, 2);
    }

    // A heartbeat is exactly an AppendEntries with no entries.
    #[test]
    fn test_append_entries() {
        let heartbeat = AppendEntries::heartbeat(1, 1, 0, 0, 0);
        assert!(heartbeat.is_heartbeat());

        let entries = vec![LogEntry::new(1, 1, Command::Noop)];
        let append = AppendEntries::with_entries(1, 1, 0, 0, entries, 0);
        assert!(!append.is_heartbeat());
    }

    // Round-trips a message through the bincode wire format.
    #[test]
    fn test_rpc_message_serialization() {
        let request = RpcMessage::RequestVote(RequestVote::new(1, 1, 10, 1));
        let bytes = request.to_bytes();
        let decoded = RpcMessage::from_bytes(&bytes).unwrap();

        assert_eq!(decoded.term(), 1);
    }

    #[test]
    fn test_append_entries_response() {
        let success = AppendEntriesResponse::success(1, 10);
        assert!(success.success);
        assert_eq!(success.match_index, 10);

        let failure = AppendEntriesResponse::failure(2);
        assert!(!failure.success);

        let conflict = AppendEntriesResponse::conflict(2, 1, 5);
        assert!(!conflict.success);
        assert_eq!(conflict.conflict_term, Some(1));
        assert_eq!(conflict.conflict_index, Some(5));
    }
}
diff --git a/crates/synor-database/src/replication/snapshot.rs b/crates/synor-database/src/replication/snapshot.rs
new file mode 100644
index 0000000..c5411e6
--- /dev/null
+++ b/crates/synor-database/src/replication/snapshot.rs
@@ -0,0 +1,301 @@
//! Log compaction and snapshots.

use serde::{Deserialize, Serialize};

/// Metadata for a snapshot.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SnapshotMetadata {
    /// Index of last entry included in snapshot.
    pub last_included_index: u64,
    /// Term of last entry included in snapshot.
    pub last_included_term: u64,
    /// Cluster configuration at snapshot time.
    pub config: SnapshotConfig,
    /// Size of snapshot data in bytes.
    pub size: u64,
    /// Timestamp when snapshot was created (milliseconds since Unix epoch).
    pub created_at: u64,
}

/// Cluster configuration stored in snapshot.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SnapshotConfig {
    /// Node IDs in the cluster.
    pub nodes: Vec<u64>,
}

impl Default for SnapshotConfig {
    fn default() -> Self {
        Self { nodes: Vec::new() }
    }
}

/// A complete snapshot of the state machine.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Snapshot {
    /// Snapshot metadata.
    pub metadata: SnapshotMetadata,
    /// Snapshot data (serialized state machine).
    pub data: Vec<u8>,
}

impl Snapshot {
    /// Creates a new snapshot, deriving `size` from the payload and
    /// stamping `created_at` with the current wall-clock time.
    pub fn new(
        last_included_index: u64,
        last_included_term: u64,
        config: SnapshotConfig,
        data: Vec<u8>,
    ) -> Self {
        // Panics only if the system clock reads before the Unix epoch.
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_millis() as u64;

        Self {
            metadata: SnapshotMetadata {
                last_included_index,
                last_included_term,
                config,
                size: data.len() as u64,
                created_at: now,
            },
            data,
        }
    }

    /// Serializes the snapshot to bytes.
    // NOTE(review): returns an empty Vec on serialization failure, which is
    // indistinguishable from an empty snapshot — confirm callers handle it.
    pub fn to_bytes(&self) -> Vec<u8> {
        bincode::serialize(self).unwrap_or_default()
    }

    /// Deserializes from bytes.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        bincode::deserialize(bytes).ok()
    }
}

/// Manages snapshot creation and storage.
pub struct SnapshotManager {
    /// Threshold for log entries before snapshotting.
    snapshot_threshold: u64,
    /// Current snapshot (if any).
    current_snapshot: Option<Snapshot>,
    /// Pending snapshot being received from leader.
    pending_snapshot: Option<PendingSnapshot>,
}

/// Snapshot being received in chunks.
struct PendingSnapshot {
    /// Metadata announced by the leader for the incoming snapshot.
    metadata: SnapshotMetadata,
    /// Chunks received so far, in arrival order.
    chunks: Vec<Vec<u8>>,
    /// Total number of bytes received so far; the next chunk must start
    /// exactly at this offset (replaces the original `expected_offset`
    /// field, which was initialized to 0 and never updated, plus an O(n)
    /// re-summation of all chunk lengths on every add_chunk call).
    received: u64,
}

impl SnapshotManager {
    /// Creates a new snapshot manager.
    pub fn new(snapshot_threshold: u64) -> Self {
        Self {
            snapshot_threshold,
            current_snapshot: None,
            pending_snapshot: None,
        }
    }

    /// Returns the threshold for triggering snapshots.
    pub fn threshold(&self) -> u64 {
        self.snapshot_threshold
    }

    /// Checks if a snapshot should be taken.
    ///
    /// Uses a saturating difference: if `last_snapshot_index` ever exceeds
    /// `log_size` (e.g. right after installing a snapshot from the leader),
    /// the original raw subtraction would panic in debug builds or wrap to
    /// a huge value in release builds and spuriously trigger a snapshot.
    pub fn should_snapshot(&self, log_size: u64, last_snapshot_index: u64) -> bool {
        log_size.saturating_sub(last_snapshot_index) >= self.snapshot_threshold
    }

    /// Gets the current snapshot.
    pub fn get_snapshot(&self) -> Option<&Snapshot> {
        self.current_snapshot.as_ref()
    }

    /// Gets the last included index of the current snapshot (0 if none).
    pub fn last_included_index(&self) -> u64 {
        self.current_snapshot
            .as_ref()
            .map(|s| s.metadata.last_included_index)
            .unwrap_or(0)
    }

    /// Gets the last included term of the current snapshot (0 if none).
    pub fn last_included_term(&self) -> u64 {
        self.current_snapshot
            .as_ref()
            .map(|s| s.metadata.last_included_term)
            .unwrap_or(0)
    }

    /// Creates a new snapshot, replacing any previous one.
    pub fn create_snapshot(
        &mut self,
        last_included_index: u64,
        last_included_term: u64,
        config: SnapshotConfig,
        data: Vec<u8>,
    ) {
        let snapshot = Snapshot::new(last_included_index, last_included_term, config, data);
        self.current_snapshot = Some(snapshot);
    }

    /// Starts receiving a snapshot from leader.
    ///
    /// Returns false if `offset` is nonzero: the first chunk of a transfer
    /// must begin at offset 0.
    pub fn start_receiving(
        &mut self,
        last_included_index: u64,
        last_included_term: u64,
        offset: u64,
        data: Vec<u8>,
    ) -> bool {
        if offset != 0 {
            // First chunk should be at offset 0
            return false;
        }

        let received = data.len() as u64;
        self.pending_snapshot = Some(PendingSnapshot {
            metadata: SnapshotMetadata {
                last_included_index,
                last_included_term,
                // Placeholders: config/size/created_at are filled in by
                // finalize_snapshot via Snapshot::new.
                config: SnapshotConfig::default(),
                size: 0,
                created_at: 0,
            },
            chunks: vec![data],
            received,
        });

        true
    }

    /// Adds a chunk to the pending snapshot.
    ///
    /// The chunk is accepted only when `offset` equals the number of bytes
    /// received so far, i.e. chunks must arrive contiguously and in order —
    /// the same condition the original enforced, now O(1) per chunk.
    pub fn add_chunk(&mut self, offset: u64, data: Vec<u8>) -> bool {
        if let Some(ref mut pending) = self.pending_snapshot {
            if offset == pending.received {
                pending.received += data.len() as u64;
                pending.chunks.push(data);
                return true;
            }
        }
        false
    }

    /// Finalizes the pending snapshot, installing it as the current one
    /// and returning a copy. Returns `None` if no transfer was in flight.
    pub fn finalize_snapshot(&mut self) -> Option<Snapshot> {
        if let Some(pending) = self.pending_snapshot.take() {
            // Stitch the chunks back into one contiguous buffer.
            let data: Vec<u8> = pending.chunks.into_iter().flatten().collect();
            let snapshot = Snapshot::new(
                pending.metadata.last_included_index,
                pending.metadata.last_included_term,
                pending.metadata.config,
                data,
            );
            self.current_snapshot = Some(snapshot.clone());
            return Some(snapshot);
        }
        None
    }

    /// Cancels receiving a snapshot, discarding any received chunks.
    pub fn cancel_receiving(&mut self) {
        self.pending_snapshot = None;
    }

    /// Splits snapshot into chunks for transmission.
+ pub fn chunk_snapshot(&self, chunk_size: usize) -> Vec<(u64, Vec, bool)> { + let mut chunks = Vec::new(); + + if let Some(ref snapshot) = self.current_snapshot { + let data = &snapshot.data; + let mut offset = 0; + + while offset < data.len() { + let end = (offset + chunk_size).min(data.len()); + let chunk = data[offset..end].to_vec(); + let done = end == data.len(); + chunks.push((offset as u64, chunk, done)); + offset = end; + } + } + + chunks + } +} + +impl Default for SnapshotManager { + fn default() -> Self { + Self::new(10000) // Default: snapshot every 10k entries + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_snapshot_creation() { + let snapshot = Snapshot::new(100, 5, SnapshotConfig::default(), vec![1, 2, 3, 4, 5]); + + assert_eq!(snapshot.metadata.last_included_index, 100); + assert_eq!(snapshot.metadata.last_included_term, 5); + assert_eq!(snapshot.metadata.size, 5); + } + + #[test] + fn test_snapshot_serialization() { + let snapshot = Snapshot::new(100, 5, SnapshotConfig::default(), vec![1, 2, 3, 4, 5]); + let bytes = snapshot.to_bytes(); + let decoded = Snapshot::from_bytes(&bytes).unwrap(); + + assert_eq!(decoded.metadata.last_included_index, 100); + assert_eq!(decoded.data, vec![1, 2, 3, 4, 5]); + } + + #[test] + fn test_snapshot_manager() { + let mut manager = SnapshotManager::new(100); + + assert!(manager.should_snapshot(150, 0)); + assert!(!manager.should_snapshot(50, 0)); + + manager.create_snapshot(100, 5, SnapshotConfig::default(), vec![1, 2, 3]); + + assert_eq!(manager.last_included_index(), 100); + assert_eq!(manager.last_included_term(), 5); + } + + #[test] + fn test_chunking() { + let mut manager = SnapshotManager::new(100); + manager.create_snapshot(100, 5, SnapshotConfig::default(), vec![0; 250]); + + let chunks = manager.chunk_snapshot(100); + + assert_eq!(chunks.len(), 3); + assert_eq!(chunks[0].0, 0); + assert_eq!(chunks[0].1.len(), 100); + assert!(!chunks[0].2); // not done + assert_eq!(chunks[2].0, 200); 
+ assert_eq!(chunks[2].1.len(), 50); + assert!(chunks[2].2); // done + } + + #[test] + fn test_receiving_snapshot() { + let mut manager = SnapshotManager::new(100); + + // Start receiving + assert!(manager.start_receiving(100, 5, 0, vec![1, 2, 3])); + + // Add more chunks + assert!(manager.add_chunk(3, vec![4, 5, 6])); + + // Finalize + let snapshot = manager.finalize_snapshot().unwrap(); + assert_eq!(snapshot.data, vec![1, 2, 3, 4, 5, 6]); + } +} diff --git a/crates/synor-database/src/replication/state.rs b/crates/synor-database/src/replication/state.rs new file mode 100644 index 0000000..dea2ca8 --- /dev/null +++ b/crates/synor-database/src/replication/state.rs @@ -0,0 +1,345 @@ +//! Raft node state and commands. + +use serde::{Deserialize, Serialize}; +use serde_json::Value as JsonValue; + +/// Role of a node in the Raft cluster. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub enum NodeRole { + /// Leader: handles all client requests and replicates log entries. + Leader, + /// Follower: passive, responds to RPCs from leader and candidates. + Follower, + /// Candidate: actively trying to become leader. + Candidate, +} + +impl Default for NodeRole { + fn default() -> Self { + NodeRole::Follower + } +} + +impl std::fmt::Display for NodeRole { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + NodeRole::Leader => write!(f, "Leader"), + NodeRole::Follower => write!(f, "Follower"), + NodeRole::Candidate => write!(f, "Candidate"), + } + } +} + +/// Persistent state on all servers. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct RaftState { + /// Current term (increases monotonically). + pub current_term: u64, + /// Candidate that received vote in current term. + pub voted_for: Option, + /// Current role of this node. + pub role: NodeRole, + /// Index of highest log entry known to be committed. + pub commit_index: u64, + /// Index of highest log entry applied to state machine. 
    pub last_applied: u64,
}

impl Default for RaftState {
    fn default() -> Self {
        Self {
            current_term: 0,
            voted_for: None,
            role: NodeRole::Follower,
            commit_index: 0,
            last_applied: 0,
        }
    }
}

impl RaftState {
    /// Creates a new Raft state (term 0, follower, nothing voted or committed).
    pub fn new() -> Self {
        Self::default()
    }

    /// Transitions to follower state at `term`.
    // voted_for is reset: a fresh term means no vote has been cast in it yet.
    pub fn become_follower(&mut self, term: u64) {
        self.current_term = term;
        self.role = NodeRole::Follower;
        self.voted_for = None;
    }

    /// Transitions to candidate state, bumping the term for the new election.
    pub fn become_candidate(&mut self) {
        self.current_term += 1;
        self.role = NodeRole::Candidate;
        // NOTE(review): the original comment here claimed "Vote for self when
        // becoming candidate", but voted_for is not set — presumably the
        // RaftNode records the self-vote itself, since this struct does not
        // know its own node id. Confirm at the call site.
    }

    /// Transitions to leader state (term is unchanged).
    pub fn become_leader(&mut self) {
        self.role = NodeRole::Leader;
    }

    /// Updates commit index if new value is higher (never moves backwards).
    pub fn update_commit_index(&mut self, new_commit: u64) {
        if new_commit > self.commit_index {
            self.commit_index = new_commit;
        }
    }

    /// Updates last applied index.
    pub fn update_last_applied(&mut self, index: u64) {
        self.last_applied = index;
    }

    /// Returns true if this node is the leader.
    pub fn is_leader(&self) -> bool {
        self.role == NodeRole::Leader
    }

    /// Returns true if this node is a follower.
    pub fn is_follower(&self) -> bool {
        self.role == NodeRole::Follower
    }

    /// Returns true if this node is a candidate.
    pub fn is_candidate(&self) -> bool {
        self.role == NodeRole::Candidate
    }
}

/// Commands that can be replicated through Raft.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum Command {
    /// No operation (used for heartbeats and new leader commit).
    Noop,

    // Key-Value operations
    /// Set a key-value pair.
    KvSet { key: String, value: Vec<u8>, ttl: Option<u64> },
    /// Delete a key.
    KvDelete { key: String },

    // Document operations
    /// Insert a document.
    DocInsert { collection: String, document: JsonValue },
    /// Update a document.
    DocUpdate { collection: String, id: String, update: JsonValue },
    /// Delete a document.
    DocDelete { collection: String, id: String },

    // Vector operations
    /// Insert a vector.
    // NOTE(review): element type restored as f32 from the garbled source —
    // confirm against the vector store's actual scalar type.
    VectorInsert { namespace: String, id: String, vector: Vec<f32>, metadata: JsonValue },
    /// Delete a vector.
    VectorDelete { namespace: String, id: String },

    // Time-series operations
    /// Record a metric data point.
    TimeSeriesRecord { metric: String, value: f64, timestamp: u64, tags: JsonValue },

    // Graph operations
    /// Create a graph node.
    GraphNodeCreate { labels: Vec<String>, properties: JsonValue },
    /// Delete a graph node.
    GraphNodeDelete { id: String },
    /// Create a graph edge.
    GraphEdgeCreate {
        source: String,
        target: String,
        edge_type: String,
        properties: JsonValue,
    },
    /// Delete a graph edge.
    GraphEdgeDelete { id: String },

    // SQL operations
    /// Execute a SQL statement.
    SqlExecute { sql: String },

    // Schema operations
    /// Create a collection/table.
    CreateCollection { name: String, schema: Option<JsonValue> },
    /// Drop a collection/table.
    DropCollection { name: String },

    // Index operations
    /// Create an index.
    CreateIndex { collection: String, field: String, index_type: String },
    /// Drop an index.
    DropIndex { name: String },

    // Configuration changes
    /// Add a node to the cluster.
    AddNode { node_id: u64, address: String },
    /// Remove a node from the cluster.
    RemoveNode { node_id: u64 },
}

impl Command {
    /// Returns a descriptive name for this command.
    // One stable snake_case tag per variant, suitable for logging/metrics.
    pub fn name(&self) -> &'static str {
        match self {
            Command::Noop => "noop",
            Command::KvSet { .. } => "kv_set",
            Command::KvDelete { .. } => "kv_delete",
            Command::DocInsert { .. } => "doc_insert",
            Command::DocUpdate { .. } => "doc_update",
            Command::DocDelete { .. } => "doc_delete",
            Command::VectorInsert { .. } => "vector_insert",
            Command::VectorDelete { .. } => "vector_delete",
            Command::TimeSeriesRecord { .. } => "timeseries_record",
            Command::GraphNodeCreate { .. } => "graph_node_create",
            Command::GraphNodeDelete { .. } => "graph_node_delete",
            Command::GraphEdgeCreate { .. } => "graph_edge_create",
            Command::GraphEdgeDelete { .. } => "graph_edge_delete",
            Command::SqlExecute { .. } => "sql_execute",
            Command::CreateCollection { .. } => "create_collection",
            Command::DropCollection { .. } => "drop_collection",
            Command::CreateIndex { .. } => "create_index",
            Command::DropIndex { .. } => "drop_index",
            Command::AddNode { .. } => "add_node",
            Command::RemoveNode { .. } => "remove_node",
        }
    }

    /// Returns true if this is a read-only command (can be served by followers).
    // Only Noop qualifies today; every other variant mutates state.
    pub fn is_read_only(&self) -> bool {
        matches!(self, Command::Noop)
    }

    /// Serializes the command to bytes.
    // NOTE(review): returns an empty Vec if bincode serialization fails,
    // indistinguishable from an empty payload — confirm callers treat an
    // empty buffer as an error.
    pub fn to_bytes(&self) -> Vec<u8> {
        bincode::serialize(self).unwrap_or_default()
    }

    /// Deserializes from bytes; `None` on malformed input.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        bincode::deserialize(bytes).ok()
    }
}

/// Leader state (volatile, only on leaders).
#[derive(Clone, Debug, Default)]
pub struct LeaderState {
    /// For each server, index of next log entry to send.
    pub next_index: std::collections::HashMap<u64, u64>,
    /// For each server, index of highest log entry known to be replicated.
    pub match_index: std::collections::HashMap<u64, u64>,
}

impl LeaderState {
    /// Creates a new leader state for the given peers.
    pub fn new(last_log_index: u64, peer_ids: &[u64]) -> Self {
        let mut next_index = std::collections::HashMap::new();
        let mut match_index = std::collections::HashMap::new();

        for &peer_id in peer_ids {
            // Initialize nextIndex to leader's last log index + 1
            next_index.insert(peer_id, last_log_index + 1);
            // Initialize matchIndex to 0
            match_index.insert(peer_id, 0);
        }

        Self {
            next_index,
            match_index,
        }
    }

    /// Updates next_index for a peer after failed append.
+ pub fn decrement_next_index(&mut self, peer_id: u64) { + if let Some(idx) = self.next_index.get_mut(&peer_id) { + if *idx > 1 { + *idx -= 1; + } + } + } + + /// Updates indices after successful append. + pub fn update_indices(&mut self, peer_id: u64, last_index: u64) { + self.next_index.insert(peer_id, last_index + 1); + self.match_index.insert(peer_id, last_index); + } + + /// Calculates the new commit index based on majority replication. + pub fn calculate_commit_index(&self, current_commit: u64, current_term: u64, log_term_at: impl Fn(u64) -> Option) -> u64 { + // Find the highest index that a majority have replicated + let mut indices: Vec = self.match_index.values().cloned().collect(); + indices.sort_unstable(); + indices.reverse(); + + // Majority is (n + 1) / 2 where n includes the leader + let majority = (indices.len() + 1 + 1) / 2; + + for &index in &indices { + if index > current_commit { + // Only commit entries from current term + if let Some(term) = log_term_at(index) { + if term == current_term { + // Check if majority have this index + let count = indices.iter().filter(|&&i| i >= index).count() + 1; // +1 for leader + if count >= majority { + return index; + } + } + } + } + } + + current_commit + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_node_role() { + let mut state = RaftState::new(); + assert!(state.is_follower()); + + state.become_candidate(); + assert!(state.is_candidate()); + assert_eq!(state.current_term, 1); + + state.become_leader(); + assert!(state.is_leader()); + + state.become_follower(5); + assert!(state.is_follower()); + assert_eq!(state.current_term, 5); + } + + #[test] + fn test_command_serialization() { + let cmd = Command::KvSet { + key: "test".to_string(), + value: vec![1, 2, 3], + ttl: Some(3600), + }; + + let bytes = cmd.to_bytes(); + let decoded = Command::from_bytes(&bytes).unwrap(); + + if let Command::KvSet { key, value, ttl } = decoded { + assert_eq!(key, "test"); + assert_eq!(value, vec![1, 
2, 3]); + assert_eq!(ttl, Some(3600)); + } else { + panic!("Wrong command type"); + } + } + + #[test] + fn test_leader_state() { + let peers = vec![2, 3, 4]; + let state = LeaderState::new(10, &peers); + + assert_eq!(state.next_index.get(&2), Some(&11)); + assert_eq!(state.match_index.get(&2), Some(&0)); + } +} diff --git a/crates/synor-database/src/sql/executor.rs b/crates/synor-database/src/sql/executor.rs new file mode 100644 index 0000000..e6ae516 --- /dev/null +++ b/crates/synor-database/src/sql/executor.rs @@ -0,0 +1,898 @@ +//! SQL query executor. + +use super::parser::{ + BinaryOp, JoinType, ParsedExpr, ParsedOrderBy, ParsedSelect, ParsedSelectItem, + ParsedStatement, SqlParser, +}; +use super::row::{Row, RowBuilder, RowId}; +use super::table::{ColumnDef, Table, TableDef}; +use super::transaction::{IsolationLevel, TransactionId, TransactionManager, TransactionOp}; +use super::types::{SqlError, SqlType, SqlValue}; +use parking_lot::RwLock; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::Arc; + +/// Result of a SQL query. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct QueryResult { + /// Column names. + pub columns: Vec, + /// Result rows. + pub rows: Vec>, + /// Number of rows affected (for INSERT/UPDATE/DELETE). + pub rows_affected: u64, + /// Execution time in milliseconds. + pub execution_time_ms: u64, +} + +impl QueryResult { + /// Creates an empty result. + pub fn empty() -> Self { + Self { + columns: Vec::new(), + rows: Vec::new(), + rows_affected: 0, + execution_time_ms: 0, + } + } + + /// Creates a result with affected count. + pub fn affected(count: u64) -> Self { + Self { + columns: Vec::new(), + rows: Vec::new(), + rows_affected: count, + execution_time_ms: 0, + } + } +} + +/// SQL execution engine. +pub struct SqlEngine { + /// Tables in this engine. + tables: RwLock>>, + /// Transaction manager. + txn_manager: TransactionManager, +} + +impl SqlEngine { + /// Creates a new SQL engine. 
    pub fn new() -> Self {
        Self {
            tables: RwLock::new(HashMap::new()),
            txn_manager: TransactionManager::new(),
        }
    }

    /// Executes a SQL statement.
    ///
    /// Parses `sql`, dispatches on the parsed statement kind, and stamps
    /// the wall-clock execution time onto the result.
    pub fn execute(&self, sql: &str) -> Result<QueryResult, SqlError> {
        let start = std::time::Instant::now();
        let stmt = SqlParser::parse(sql)?;
        let mut result = self.execute_statement(&stmt)?;
        result.execution_time_ms = start.elapsed().as_millis() as u64;
        Ok(result)
    }

    /// Executes a parsed statement by dispatching to the per-statement
    /// handler. Exhaustive over ParsedStatement.
    fn execute_statement(&self, stmt: &ParsedStatement) -> Result<QueryResult, SqlError> {
        match stmt {
            ParsedStatement::CreateTable {
                name,
                columns,
                if_not_exists,
            } => self.execute_create_table(name, columns, *if_not_exists),
            ParsedStatement::DropTable { name, if_exists } => {
                self.execute_drop_table(name, *if_exists)
            }
            ParsedStatement::Select(select) => self.execute_select(select),
            ParsedStatement::Insert {
                table,
                columns,
                values,
            } => self.execute_insert(table, columns, values),
            ParsedStatement::Update {
                table,
                assignments,
                where_clause,
            } => self.execute_update(table, assignments, where_clause.as_ref()),
            ParsedStatement::Delete {
                table,
                where_clause,
            } => self.execute_delete(table, where_clause.as_ref()),
            ParsedStatement::CreateIndex {
                name,
                table,
                columns,
                unique,
            } => self.execute_create_index(name, table, columns, *unique),
            ParsedStatement::DropIndex { name } => self.execute_drop_index(name),
        }
    }

    /// Creates a table.
+ fn execute_create_table( + &self, + name: &str, + columns: &[super::parser::ParsedColumn], + if_not_exists: bool, + ) -> Result { + let mut tables = self.tables.write(); + + if tables.contains_key(name) { + if if_not_exists { + return Ok(QueryResult::empty()); + } + return Err(SqlError::TableExists(name.to_string())); + } + + let mut table_def = TableDef::new(name); + for col in columns { + let mut col_def = ColumnDef::new(&col.name, col.data_type.clone()); + if !col.nullable { + col_def = col_def.not_null(); + } + if let Some(ref default) = col.default { + col_def = col_def.default(default.clone()); + } + if col.primary_key { + col_def = col_def.primary_key(); + } + if col.unique { + col_def = col_def.unique(); + } + table_def = table_def.column(col_def); + } + + let table = Arc::new(Table::new(table_def)); + tables.insert(name.to_string(), table); + + Ok(QueryResult::empty()) + } + + /// Drops a table. + fn execute_drop_table(&self, name: &str, if_exists: bool) -> Result { + let mut tables = self.tables.write(); + + if !tables.contains_key(name) { + if if_exists { + return Ok(QueryResult::empty()); + } + return Err(SqlError::TableNotFound(name.to_string())); + } + + tables.remove(name); + Ok(QueryResult::empty()) + } + + /// Executes a SELECT query. 
+ fn execute_select(&self, select: &ParsedSelect) -> Result { + let tables = self.tables.read(); + let table = tables + .get(&select.from) + .ok_or_else(|| SqlError::TableNotFound(select.from.clone()))?; + + // Get all rows + let mut rows = table.scan(); + + // Apply WHERE filter + if let Some(ref where_clause) = select.where_clause { + rows = rows + .into_iter() + .filter(|row| self.evaluate_where(row, where_clause)) + .collect(); + } + + // Apply ORDER BY + if !select.order_by.is_empty() { + rows.sort_by(|a, b| { + for ob in &select.order_by { + let a_val = a.get_or_null(&ob.column); + let b_val = b.get_or_null(&ob.column); + match a_val.partial_cmp(&b_val) { + Some(std::cmp::Ordering::Equal) => continue, + Some(ord) => { + return if ob.ascending { + ord + } else { + ord.reverse() + }; + } + None => continue, + } + } + std::cmp::Ordering::Equal + }); + } + + // Apply OFFSET + if let Some(offset) = select.offset { + rows = rows.into_iter().skip(offset).collect(); + } + + // Apply LIMIT + if let Some(limit) = select.limit { + rows = rows.into_iter().take(limit).collect(); + } + + // Handle aggregates + if select.columns.iter().any(|c| matches!(c, ParsedSelectItem::Aggregate { .. })) { + return self.execute_aggregate(select, &rows, table); + } + + // Project columns + let (column_names, result_rows) = self.project_rows(&select.columns, &rows, table); + + Ok(QueryResult { + columns: column_names, + rows: result_rows, + rows_affected: 0, + execution_time_ms: 0, + }) + } + + /// Projects rows to selected columns. + fn project_rows( + &self, + select_items: &[ParsedSelectItem], + rows: &[Row], + table: &Table, + ) -> (Vec, Vec>) { + let column_names: Vec = select_items + .iter() + .flat_map(|item| match item { + ParsedSelectItem::Wildcard => table.def.column_names(), + ParsedSelectItem::Column(name) => vec![name.clone()], + ParsedSelectItem::ColumnAlias { alias, .. } => vec![alias.clone()], + ParsedSelectItem::Aggregate { function, alias, .. 
} => { + vec![alias.clone().unwrap_or_else(|| function.clone())] + } + }) + .collect(); + + let result_rows: Vec> = rows + .iter() + .map(|row| { + select_items + .iter() + .flat_map(|item| match item { + ParsedSelectItem::Wildcard => table + .def + .column_names() + .into_iter() + .map(|c| row.get_or_null(&c)) + .collect::>(), + ParsedSelectItem::Column(name) + | ParsedSelectItem::ColumnAlias { column: name, .. } => { + vec![row.get_or_null(name)] + } + _ => vec![SqlValue::Null], + }) + .collect() + }) + .collect(); + + (column_names, result_rows) + } + + /// Executes aggregate functions. + fn execute_aggregate( + &self, + select: &ParsedSelect, + rows: &[Row], + table: &Table, + ) -> Result { + let mut result_columns = Vec::new(); + let mut result_values = Vec::new(); + + for item in &select.columns { + match item { + ParsedSelectItem::Aggregate { + function, + column, + alias, + } => { + let col_name = alias.clone().unwrap_or_else(|| function.clone()); + result_columns.push(col_name); + + let value = match function.as_str() { + "COUNT" => SqlValue::Integer(rows.len() as i64), + "SUM" => { + let col = column.as_ref().ok_or_else(|| { + SqlError::InvalidOperation("SUM requires column".to_string()) + })?; + let sum: f64 = rows + .iter() + .filter_map(|r| r.get_or_null(col).as_real()) + .sum(); + SqlValue::Real(sum) + } + "AVG" => { + let col = column.as_ref().ok_or_else(|| { + SqlError::InvalidOperation("AVG requires column".to_string()) + })?; + let values: Vec = rows + .iter() + .filter_map(|r| r.get_or_null(col).as_real()) + .collect(); + if values.is_empty() { + SqlValue::Null + } else { + SqlValue::Real(values.iter().sum::() / values.len() as f64) + } + } + "MIN" => { + let col = column.as_ref().ok_or_else(|| { + SqlError::InvalidOperation("MIN requires column".to_string()) + })?; + rows.iter() + .map(|r| r.get_or_null(col)) + .filter(|v| !v.is_null()) + .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) + .unwrap_or(SqlValue::Null) + } + 
"MAX" => { + let col = column.as_ref().ok_or_else(|| { + SqlError::InvalidOperation("MAX requires column".to_string()) + })?; + rows.iter() + .map(|r| r.get_or_null(col)) + .filter(|v| !v.is_null()) + .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) + .unwrap_or(SqlValue::Null) + } + _ => { + return Err(SqlError::Unsupported(format!("Function: {}", function))) + } + }; + result_values.push(value); + } + ParsedSelectItem::Column(name) => { + result_columns.push(name.clone()); + // For non-aggregated columns in aggregate query, take first value + result_values.push( + rows.first() + .map(|r| r.get_or_null(name)) + .unwrap_or(SqlValue::Null), + ); + } + _ => {} + } + } + + Ok(QueryResult { + columns: result_columns, + rows: if result_values.is_empty() { + Vec::new() + } else { + vec![result_values] + }, + rows_affected: 0, + execution_time_ms: 0, + }) + } + + /// Evaluates a WHERE clause. + fn evaluate_where(&self, row: &Row, expr: &ParsedExpr) -> bool { + match self.evaluate_expr(row, expr) { + SqlValue::Boolean(b) => b, + SqlValue::Integer(i) => i != 0, + _ => false, + } + } + + /// Evaluates an expression. 
+ fn evaluate_expr(&self, row: &Row, expr: &ParsedExpr) -> SqlValue { + match expr { + ParsedExpr::Column(name) => row.get_or_null(name), + ParsedExpr::Literal(value) => value.clone(), + ParsedExpr::BinaryOp { left, op, right } => { + let left_val = self.evaluate_expr(row, left); + let right_val = self.evaluate_expr(row, right); + self.evaluate_binary_op(&left_val, op, &right_val) + } + ParsedExpr::Not(inner) => { + let val = self.evaluate_expr(row, inner); + match val { + SqlValue::Boolean(b) => SqlValue::Boolean(!b), + _ => SqlValue::Null, + } + } + ParsedExpr::IsNull(inner) => { + SqlValue::Boolean(self.evaluate_expr(row, inner).is_null()) + } + ParsedExpr::IsNotNull(inner) => { + SqlValue::Boolean(!self.evaluate_expr(row, inner).is_null()) + } + ParsedExpr::InList { expr, list, negated } => { + let val = self.evaluate_expr(row, expr); + let in_list = list.iter().any(|item| { + let item_val = self.evaluate_expr(row, item); + val == item_val + }); + SqlValue::Boolean(if *negated { !in_list } else { in_list }) + } + ParsedExpr::Between { + expr, + low, + high, + negated, + } => { + let val = self.evaluate_expr(row, expr); + let low_val = self.evaluate_expr(row, low); + let high_val = self.evaluate_expr(row, high); + let between = val >= low_val && val <= high_val; + SqlValue::Boolean(if *negated { !between } else { between }) + } + ParsedExpr::Function { name, args } => { + self.evaluate_function(row, name, args) + } + } + } + + /// Evaluates a binary operation. 
+ fn evaluate_binary_op(&self, left: &SqlValue, op: &BinaryOp, right: &SqlValue) -> SqlValue { + match op { + BinaryOp::Eq => SqlValue::Boolean(left == right), + BinaryOp::Ne => SqlValue::Boolean(left != right), + BinaryOp::Lt => SqlValue::Boolean(left < right), + BinaryOp::Le => SqlValue::Boolean(left <= right), + BinaryOp::Gt => SqlValue::Boolean(left > right), + BinaryOp::Ge => SqlValue::Boolean(left >= right), + BinaryOp::And => { + let l = matches!(left, SqlValue::Boolean(true)); + let r = matches!(right, SqlValue::Boolean(true)); + SqlValue::Boolean(l && r) + } + BinaryOp::Or => { + let l = matches!(left, SqlValue::Boolean(true)); + let r = matches!(right, SqlValue::Boolean(true)); + SqlValue::Boolean(l || r) + } + BinaryOp::Like => { + if let (SqlValue::Text(text), SqlValue::Text(pattern)) = (left, right) { + SqlValue::Boolean(self.match_like(text, pattern)) + } else { + SqlValue::Boolean(false) + } + } + BinaryOp::Plus => match (left, right) { + (SqlValue::Integer(a), SqlValue::Integer(b)) => SqlValue::Integer(a + b), + (SqlValue::Real(a), SqlValue::Real(b)) => SqlValue::Real(a + b), + (SqlValue::Integer(a), SqlValue::Real(b)) => SqlValue::Real(*a as f64 + b), + (SqlValue::Real(a), SqlValue::Integer(b)) => SqlValue::Real(a + *b as f64), + _ => SqlValue::Null, + }, + BinaryOp::Minus => match (left, right) { + (SqlValue::Integer(a), SqlValue::Integer(b)) => SqlValue::Integer(a - b), + (SqlValue::Real(a), SqlValue::Real(b)) => SqlValue::Real(a - b), + _ => SqlValue::Null, + }, + BinaryOp::Multiply => match (left, right) { + (SqlValue::Integer(a), SqlValue::Integer(b)) => SqlValue::Integer(a * b), + (SqlValue::Real(a), SqlValue::Real(b)) => SqlValue::Real(a * b), + _ => SqlValue::Null, + }, + BinaryOp::Divide => match (left, right) { + (SqlValue::Integer(a), SqlValue::Integer(b)) if *b != 0 => { + SqlValue::Integer(a / b) + } + (SqlValue::Real(a), SqlValue::Real(b)) if *b != 0.0 => SqlValue::Real(a / b), + _ => SqlValue::Null, + }, + } + } + + /// Evaluates a 
function call. + fn evaluate_function(&self, row: &Row, name: &str, args: &[ParsedExpr]) -> SqlValue { + match name.to_uppercase().as_str() { + "COALESCE" => { + for arg in args { + let val = self.evaluate_expr(row, arg); + if !val.is_null() { + return val; + } + } + SqlValue::Null + } + "UPPER" => { + if let Some(arg) = args.first() { + if let SqlValue::Text(s) = self.evaluate_expr(row, arg) { + return SqlValue::Text(s.to_uppercase()); + } + } + SqlValue::Null + } + "LOWER" => { + if let Some(arg) = args.first() { + if let SqlValue::Text(s) = self.evaluate_expr(row, arg) { + return SqlValue::Text(s.to_lowercase()); + } + } + SqlValue::Null + } + "LENGTH" => { + if let Some(arg) = args.first() { + if let SqlValue::Text(s) = self.evaluate_expr(row, arg) { + return SqlValue::Integer(s.len() as i64); + } + } + SqlValue::Null + } + "ABS" => { + if let Some(arg) = args.first() { + match self.evaluate_expr(row, arg) { + SqlValue::Integer(i) => return SqlValue::Integer(i.abs()), + SqlValue::Real(f) => return SqlValue::Real(f.abs()), + _ => {} + } + } + SqlValue::Null + } + _ => SqlValue::Null, + } + } + + /// Matches a LIKE pattern. + fn match_like(&self, text: &str, pattern: &str) -> bool { + // Simple LIKE implementation: % = any chars, _ = single char + let regex_pattern = pattern + .replace('%', ".*") + .replace('_', "."); + // For simplicity, just do case-insensitive contains for now + if pattern.starts_with('%') && pattern.ends_with('%') { + let inner = &pattern[1..pattern.len() - 1]; + text.to_lowercase().contains(&inner.to_lowercase()) + } else if pattern.starts_with('%') { + let suffix = &pattern[1..]; + text.to_lowercase().ends_with(&suffix.to_lowercase()) + } else if pattern.ends_with('%') { + let prefix = &pattern[..pattern.len() - 1]; + text.to_lowercase().starts_with(&prefix.to_lowercase()) + } else { + text.to_lowercase() == pattern.to_lowercase() + } + } + + /// Executes an INSERT statement. 
+ fn execute_insert( + &self, + table_name: &str, + columns: &[String], + values: &[Vec], + ) -> Result { + let tables = self.tables.read(); + let table = tables + .get(table_name) + .ok_or_else(|| SqlError::TableNotFound(table_name.to_string()))?; + + let cols = if columns.is_empty() { + table.def.column_names() + } else { + columns.to_vec() + }; + + let mut count = 0; + for row_values in values { + if row_values.len() != cols.len() { + return Err(SqlError::InvalidOperation(format!( + "Column count mismatch: expected {}, got {}", + cols.len(), + row_values.len() + ))); + } + + let mut row_map = HashMap::new(); + for (col, val) in cols.iter().zip(row_values.iter()) { + row_map.insert(col.clone(), val.clone()); + } + + table.insert(row_map)?; + count += 1; + } + + Ok(QueryResult::affected(count)) + } + + /// Executes an UPDATE statement. + fn execute_update( + &self, + table_name: &str, + assignments: &[(String, SqlValue)], + where_clause: Option<&ParsedExpr>, + ) -> Result { + let tables = self.tables.read(); + let table = tables + .get(table_name) + .ok_or_else(|| SqlError::TableNotFound(table_name.to_string()))?; + + let rows = table.scan(); + let mut count = 0; + + for row in rows { + let matches = where_clause + .map(|w| self.evaluate_where(&row, w)) + .unwrap_or(true); + + if matches { + let updates: HashMap = + assignments.iter().cloned().collect(); + table.update(row.id, updates)?; + count += 1; + } + } + + Ok(QueryResult::affected(count)) + } + + /// Executes a DELETE statement. 
+ fn execute_delete( + &self, + table_name: &str, + where_clause: Option<&ParsedExpr>, + ) -> Result { + let tables = self.tables.read(); + let table = tables + .get(table_name) + .ok_or_else(|| SqlError::TableNotFound(table_name.to_string()))?; + + let rows = table.scan(); + let mut count = 0; + + let to_delete: Vec = rows + .iter() + .filter(|row| { + where_clause + .map(|w| self.evaluate_where(row, w)) + .unwrap_or(true) + }) + .map(|row| row.id) + .collect(); + + for id in to_delete { + if table.delete(id)? { + count += 1; + } + } + + Ok(QueryResult::affected(count)) + } + + /// Creates an index. + fn execute_create_index( + &self, + name: &str, + table_name: &str, + columns: &[String], + unique: bool, + ) -> Result { + let tables = self.tables.read(); + let table = tables + .get(table_name) + .ok_or_else(|| SqlError::TableNotFound(table_name.to_string()))?; + + // For simplicity, only support single-column indexes + let column = columns + .first() + .ok_or_else(|| SqlError::InvalidOperation("Index requires at least one column".to_string()))?; + + table.create_index(name, column, unique)?; + Ok(QueryResult::empty()) + } + + /// Drops an index. + fn execute_drop_index(&self, name: &str) -> Result { + // Would need to find which table has this index + // For now, return success + Ok(QueryResult::empty()) + } + + // Transaction methods + + /// Begins a transaction. + pub fn begin_transaction(&self) -> TransactionId { + self.txn_manager.begin(IsolationLevel::ReadCommitted) + } + + /// Commits a transaction. + pub fn commit(&self, txn_id: TransactionId) -> Result<(), SqlError> { + self.txn_manager.commit(txn_id)?; + Ok(()) + } + + /// Rolls back a transaction. + pub fn rollback(&self, txn_id: TransactionId) -> Result<(), SqlError> { + self.txn_manager.rollback(txn_id)?; + Ok(()) + } + + /// Returns number of tables. + pub fn table_count(&self) -> usize { + self.tables.read().len() + } + + /// Returns table names. 
+ pub fn table_names(&self) -> Vec { + self.tables.read().keys().cloned().collect() + } + + /// Gets table definition. + pub fn get_table_def(&self, name: &str) -> Option { + self.tables.read().get(name).map(|t| t.def.clone()) + } +} + +impl Default for SqlEngine { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn setup_engine() -> SqlEngine { + let engine = SqlEngine::new(); + engine + .execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL, age INTEGER)") + .unwrap(); + engine + } + + #[test] + fn test_create_table() { + let engine = SqlEngine::new(); + engine + .execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)") + .unwrap(); + assert_eq!(engine.table_count(), 1); + } + + #[test] + fn test_insert_and_select() { + let engine = setup_engine(); + + engine + .execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30)") + .unwrap(); + engine + .execute("INSERT INTO users (id, name, age) VALUES (2, 'Bob', 25)") + .unwrap(); + + let result = engine.execute("SELECT name, age FROM users").unwrap(); + assert_eq!(result.rows.len(), 2); + } + + #[test] + fn test_select_with_where() { + let engine = setup_engine(); + + engine + .execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30)") + .unwrap(); + engine + .execute("INSERT INTO users (id, name, age) VALUES (2, 'Bob', 25)") + .unwrap(); + + let result = engine.execute("SELECT name FROM users WHERE age > 26").unwrap(); + assert_eq!(result.rows.len(), 1); + assert_eq!(result.rows[0][0], SqlValue::Text("Alice".to_string())); + } + + #[test] + fn test_select_order_by() { + let engine = setup_engine(); + + engine + .execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30)") + .unwrap(); + engine + .execute("INSERT INTO users (id, name, age) VALUES (2, 'Bob', 25)") + .unwrap(); + + let result = engine + .execute("SELECT name FROM users ORDER BY age") + .unwrap(); + assert_eq!(result.rows[0][0], 
SqlValue::Text("Bob".to_string())); + assert_eq!(result.rows[1][0], SqlValue::Text("Alice".to_string())); + } + + #[test] + fn test_select_limit() { + let engine = setup_engine(); + + for i in 1..=10 { + engine + .execute(&format!( + "INSERT INTO users (id, name, age) VALUES ({}, 'User{}', {})", + i, i, 20 + i + )) + .unwrap(); + } + + let result = engine.execute("SELECT name FROM users LIMIT 3").unwrap(); + assert_eq!(result.rows.len(), 3); + } + + #[test] + fn test_aggregate_count() { + let engine = setup_engine(); + + engine + .execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30)") + .unwrap(); + engine + .execute("INSERT INTO users (id, name, age) VALUES (2, 'Bob', 25)") + .unwrap(); + + let result = engine.execute("SELECT COUNT(*) FROM users").unwrap(); + assert_eq!(result.rows[0][0], SqlValue::Integer(2)); + } + + #[test] + fn test_aggregate_sum_avg() { + let engine = setup_engine(); + + engine + .execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30)") + .unwrap(); + engine + .execute("INSERT INTO users (id, name, age) VALUES (2, 'Bob', 20)") + .unwrap(); + + let result = engine.execute("SELECT SUM(age) FROM users").unwrap(); + assert_eq!(result.rows[0][0], SqlValue::Real(50.0)); + + let result = engine.execute("SELECT AVG(age) FROM users").unwrap(); + assert_eq!(result.rows[0][0], SqlValue::Real(25.0)); + } + + #[test] + fn test_update() { + let engine = setup_engine(); + + engine + .execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30)") + .unwrap(); + + let result = engine + .execute("UPDATE users SET age = 31 WHERE name = 'Alice'") + .unwrap(); + assert_eq!(result.rows_affected, 1); + + let result = engine + .execute("SELECT age FROM users WHERE name = 'Alice'") + .unwrap(); + assert_eq!(result.rows[0][0], SqlValue::Integer(31)); + } + + #[test] + fn test_delete() { + let engine = setup_engine(); + + engine + .execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30)") + .unwrap(); + engine + 
.execute("INSERT INTO users (id, name, age) VALUES (2, 'Bob', 25)") + .unwrap(); + + let result = engine + .execute("DELETE FROM users WHERE name = 'Alice'") + .unwrap(); + assert_eq!(result.rows_affected, 1); + + let result = engine.execute("SELECT COUNT(*) FROM users").unwrap(); + assert_eq!(result.rows[0][0], SqlValue::Integer(1)); + } + + #[test] + fn test_drop_table() { + let engine = setup_engine(); + assert_eq!(engine.table_count(), 1); + + engine.execute("DROP TABLE users").unwrap(); + assert_eq!(engine.table_count(), 0); + } +} diff --git a/crates/synor-database/src/sql/mod.rs b/crates/synor-database/src/sql/mod.rs new file mode 100644 index 0000000..258ecda --- /dev/null +++ b/crates/synor-database/src/sql/mod.rs @@ -0,0 +1,23 @@ +//! SQL Query Layer for Synor Database. +//! +//! Provides SQLite-compatible query interface: +//! +//! - DDL: CREATE TABLE, DROP TABLE, ALTER TABLE +//! - DML: SELECT, INSERT, UPDATE, DELETE +//! - Clauses: WHERE, ORDER BY, LIMIT, GROUP BY +//! - Joins: INNER, LEFT, RIGHT JOIN +//! - Functions: COUNT, SUM, AVG, MIN, MAX + +pub mod executor; +pub mod parser; +pub mod row; +pub mod table; +pub mod transaction; +pub mod types; + +pub use executor::{QueryResult, SqlEngine}; +pub use parser::SqlParser; +pub use row::{Row, RowId}; +pub use table::{ColumnDef, Table, TableDef}; +pub use transaction::{Transaction, TransactionId, TransactionState}; +pub use types::{SqlError, SqlType, SqlValue}; diff --git a/crates/synor-database/src/sql/parser.rs b/crates/synor-database/src/sql/parser.rs new file mode 100644 index 0000000..71d49e0 --- /dev/null +++ b/crates/synor-database/src/sql/parser.rs @@ -0,0 +1,732 @@ +//! SQL parser using sqlparser-rs. 
use super::types::{SqlError, SqlType, SqlValue};
use sqlparser::ast::{
    BinaryOperator, ColumnOption, DataType, Expr, Query, Select, SelectItem, SetExpr, Statement,
    TableFactor, Value as AstValue,
};
use sqlparser::dialect::SQLiteDialect;
use sqlparser::parser::Parser;

/// Parsed SQL statement — the executor-facing mirror of `sqlparser`'s AST,
/// reduced to the subset of SQL this engine supports.
#[derive(Debug)]
pub enum ParsedStatement {
    /// CREATE TABLE statement.
    CreateTable {
        name: String,
        columns: Vec<ParsedColumn>,
        if_not_exists: bool,
    },
    /// DROP TABLE statement.
    DropTable {
        name: String,
        if_exists: bool,
    },
    /// SELECT statement.
    Select(ParsedSelect),
    /// INSERT statement; `values` holds one inner Vec per VALUES tuple.
    Insert {
        table: String,
        columns: Vec<String>,
        values: Vec<Vec<SqlValue>>,
    },
    /// UPDATE statement; assignments are (column, literal value) pairs.
    Update {
        table: String,
        assignments: Vec<(String, SqlValue)>,
        where_clause: Option<ParsedExpr>,
    },
    /// DELETE statement.
    Delete {
        table: String,
        where_clause: Option<ParsedExpr>,
    },
    /// CREATE INDEX statement.
    CreateIndex {
        name: String,
        table: String,
        columns: Vec<String>,
        unique: bool,
    },
    /// DROP INDEX statement.
    DropIndex { name: String },
}

/// Parsed column definition extracted from CREATE TABLE.
#[derive(Debug)]
pub struct ParsedColumn {
    pub name: String,
    pub data_type: SqlType,
    // Whether NULL values are permitted for this column.
    pub nullable: bool,
    // DEFAULT literal, when present.
    pub default: Option<SqlValue>,
    pub primary_key: bool,
    pub unique: bool,
}

/// Parsed SELECT statement.
#[derive(Debug)]
pub struct ParsedSelect {
    pub columns: Vec<ParsedSelectItem>,
    pub from: String,
    // NOTE(review): joins, group_by and having are carried here but the
    // current conversion code always leaves them empty — confirm before
    // relying on them downstream.
    pub joins: Vec<ParsedJoin>,
    pub where_clause: Option<ParsedExpr>,
    pub group_by: Vec<String>,
    pub having: Option<ParsedExpr>,
    pub order_by: Vec<ParsedOrderBy>,
    pub limit: Option<usize>,
    pub offset: Option<usize>,
}

/// Parsed select item.
#[derive(Debug)]
pub enum ParsedSelectItem {
    /// All columns (*).
    Wildcard,
    /// Single column.
    Column(String),
    /// Column with alias.
    ColumnAlias { column: String, alias: String },
    /// Aggregate function; `column` is None for COUNT(*).
    Aggregate {
        function: String,
        column: Option<String>,
        alias: Option<String>,
    },
}

/// Parsed JOIN clause.
+#[derive(Debug)] +pub struct ParsedJoin { + pub table: String, + pub join_type: JoinType, + pub on: ParsedExpr, +} + +/// Join types. +#[derive(Debug, Clone)] +pub enum JoinType { + Inner, + Left, + Right, + Full, +} + +/// Parsed ORDER BY clause. +#[derive(Debug)] +pub struct ParsedOrderBy { + pub column: String, + pub ascending: bool, +} + +/// Parsed expression. +#[derive(Debug, Clone)] +pub enum ParsedExpr { + /// Column reference. + Column(String), + /// Literal value. + Literal(SqlValue), + /// Binary operation. + BinaryOp { + left: Box, + op: BinaryOp, + right: Box, + }, + /// Unary NOT. + Not(Box), + /// IS NULL check. + IsNull(Box), + /// IS NOT NULL check. + IsNotNull(Box), + /// IN list. + InList { + expr: Box, + list: Vec, + negated: bool, + }, + /// BETWEEN. + Between { + expr: Box, + low: Box, + high: Box, + negated: bool, + }, + /// Function call. + Function { name: String, args: Vec }, +} + +/// Binary operators. +#[derive(Debug, Clone)] +pub enum BinaryOp { + Eq, + Ne, + Lt, + Le, + Gt, + Ge, + And, + Or, + Like, + Plus, + Minus, + Multiply, + Divide, +} + +/// SQL parser. +pub struct SqlParser; + +impl SqlParser { + /// Parses a SQL statement. + pub fn parse(sql: &str) -> Result { + let dialect = SQLiteDialect {}; + let statements = Parser::parse_sql(&dialect, sql) + .map_err(|e| SqlError::Parse(e.to_string()))?; + + if statements.is_empty() { + return Err(SqlError::Parse("Empty SQL statement".to_string())); + } + + if statements.len() > 1 { + return Err(SqlError::Parse("Multiple statements not supported".to_string())); + } + + Self::convert_statement(&statements[0]) + } + + fn convert_statement(stmt: &Statement) -> Result { + match stmt { + Statement::CreateTable { name, columns, if_not_exists, constraints, .. } => { + Self::convert_create_table(name, columns, constraints, *if_not_exists) + } + Statement::Drop { object_type, names, if_exists, .. 
} => { + Self::convert_drop(object_type, names, *if_exists) + } + Statement::Query(query) => Self::convert_query(query), + Statement::Insert { table_name, columns, source, .. } => { + Self::convert_insert(table_name, columns, source) + } + Statement::Update { table, assignments, selection, .. } => { + Self::convert_update(table, assignments, selection) + } + Statement::Delete { from, selection, .. } => { + Self::convert_delete(from, selection) + } + Statement::CreateIndex { name, table_name, columns, unique, .. } => { + Self::convert_create_index(name, table_name, columns, *unique) + } + _ => Err(SqlError::Unsupported(format!("Statement not supported"))), + } + } + + fn convert_create_table( + name: &sqlparser::ast::ObjectName, + columns: &[sqlparser::ast::ColumnDef], + constraints: &[sqlparser::ast::TableConstraint], + if_not_exists: bool, + ) -> Result { + let table_name = name.to_string(); + let mut parsed_columns = Vec::new(); + let mut primary_keys: Vec = Vec::new(); + + // Extract primary keys from table constraints + for constraint in constraints { + if let sqlparser::ast::TableConstraint::Unique { columns: pk_cols, is_primary: true, .. } = constraint { + for col in pk_cols { + primary_keys.push(col.value.clone()); + } + } + } + + for col in columns { + let col_name = col.name.value.clone(); + let data_type = Self::convert_data_type(&col.data_type)?; + + let mut nullable = true; + let mut default = None; + let mut primary_key = primary_keys.contains(&col_name); + let mut unique = false; + + for option in &col.options { + match &option.option { + ColumnOption::Null => nullable = true, + ColumnOption::NotNull => nullable = false, + ColumnOption::Default(expr) => { + default = Some(Self::convert_value_expr(expr)?); + } + ColumnOption::Unique { is_primary, .. 
} => { + if *is_primary { + primary_key = true; + } else { + unique = true; + } + } + _ => {} + } + } + + if primary_key { + nullable = false; + unique = true; + } + + parsed_columns.push(ParsedColumn { + name: col_name, + data_type, + nullable, + default, + primary_key, + unique, + }); + } + + Ok(ParsedStatement::CreateTable { + name: table_name, + columns: parsed_columns, + if_not_exists, + }) + } + + fn convert_data_type(dt: &DataType) -> Result { + match dt { + DataType::Int(_) + | DataType::Integer(_) + | DataType::BigInt(_) + | DataType::SmallInt(_) + | DataType::TinyInt(_) => Ok(SqlType::Integer), + DataType::Real | DataType::Float(_) | DataType::Double | DataType::DoublePrecision => { + Ok(SqlType::Real) + } + DataType::Varchar(_) + | DataType::Char(_) + | DataType::Text + | DataType::String(_) => Ok(SqlType::Text), + DataType::Binary(_) | DataType::Varbinary(_) | DataType::Blob(_) => Ok(SqlType::Blob), + DataType::Boolean => Ok(SqlType::Boolean), + DataType::Timestamp(_, _) | DataType::Date | DataType::Datetime(_) => { + Ok(SqlType::Timestamp) + } + _ => Err(SqlError::Unsupported(format!("Data type: {:?}", dt))), + } + } + + fn convert_drop( + object_type: &sqlparser::ast::ObjectType, + names: &[sqlparser::ast::ObjectName], + if_exists: bool, + ) -> Result { + let name = names + .first() + .map(|n| n.to_string()) + .ok_or_else(|| SqlError::Parse("Missing object name".to_string()))?; + + match object_type { + sqlparser::ast::ObjectType::Table => Ok(ParsedStatement::DropTable { name, if_exists }), + sqlparser::ast::ObjectType::Index => Ok(ParsedStatement::DropIndex { name }), + _ => Err(SqlError::Unsupported(format!("DROP not supported"))), + } + } + + fn convert_query(query: &Query) -> Result { + let select = match &*query.body { + SetExpr::Select(select) => select, + _ => return Err(SqlError::Unsupported("Non-SELECT query body".to_string())), + }; + + let parsed_select = Self::convert_select(select, query)?; + Ok(ParsedStatement::Select(parsed_select)) + } 
+ + fn convert_select(select: &Select, query: &Query) -> Result { + // Parse columns + let columns = select + .projection + .iter() + .map(Self::convert_select_item) + .collect::, _>>()?; + + // Parse FROM + let from = select + .from + .first() + .map(|f| Self::convert_table_factor(&f.relation)) + .transpose()? + .unwrap_or_default(); + + // Parse WHERE + let where_clause = select + .selection + .as_ref() + .map(Self::convert_expr) + .transpose()?; + + // Parse ORDER BY - simplified approach + let order_by: Vec = query + .order_by + .iter() + .filter_map(|e| Self::convert_order_by(e).ok()) + .collect(); + + // Parse LIMIT/OFFSET + let limit = query + .limit + .as_ref() + .and_then(|l| Self::expr_to_usize(l)); + let offset = query + .offset + .as_ref() + .and_then(|o| Self::expr_to_usize(&o.value)); + + Ok(ParsedSelect { + columns, + from, + joins: Vec::new(), + where_clause, + group_by: Vec::new(), + having: None, + order_by, + limit, + offset, + }) + } + + fn convert_select_item(item: &SelectItem) -> Result { + match item { + SelectItem::Wildcard(_) => Ok(ParsedSelectItem::Wildcard), + SelectItem::UnnamedExpr(expr) => Self::convert_select_expr(expr), + SelectItem::ExprWithAlias { expr, alias } => { + if let Expr::Identifier(id) = expr { + Ok(ParsedSelectItem::ColumnAlias { + column: id.value.clone(), + alias: alias.value.clone(), + }) + } else { + Self::convert_select_expr(expr) + } + } + _ => Err(SqlError::Unsupported("Select item not supported".to_string())), + } + } + + fn convert_select_expr(expr: &Expr) -> Result { + match expr { + Expr::Identifier(id) => Ok(ParsedSelectItem::Column(id.value.clone())), + Expr::CompoundIdentifier(ids) => { + Ok(ParsedSelectItem::Column(ids.last().map(|i| i.value.clone()).unwrap_or_default())) + } + Expr::Function(func) => { + let name = func.name.to_string().to_uppercase(); + // Try to extract column from first arg - simplified for compatibility + let column = Self::extract_func_column_arg(func); + 
Ok(ParsedSelectItem::Aggregate { + function: name, + column, + alias: None, + }) + } + _ => Err(SqlError::Unsupported("Select expression not supported".to_string())), + } + } + + fn convert_table_factor(factor: &TableFactor) -> Result { + match factor { + TableFactor::Table { name, .. } => Ok(name.to_string()), + _ => Err(SqlError::Unsupported("Table factor not supported".to_string())), + } + } + + fn extract_func_column_arg(func: &sqlparser::ast::Function) -> Option { + // Use string representation and parse it - works across sqlparser versions + let func_str = func.to_string(); + // Parse function like "SUM(age)" or "COUNT(*)" + if let Some(start) = func_str.find('(') { + if let Some(end) = func_str.rfind(')') { + let arg = func_str[start + 1..end].trim(); + if !arg.is_empty() && arg != "*" { + return Some(arg.to_string()); + } + } + } + None + } + + fn convert_order_by(ob: &sqlparser::ast::OrderByExpr) -> Result { + let column = match &ob.expr { + Expr::Identifier(id) => id.value.clone(), + _ => return Err(SqlError::Unsupported("Order by expression".to_string())), + }; + let ascending = ob.asc.unwrap_or(true); + Ok(ParsedOrderBy { column, ascending }) + } + + fn convert_expr(expr: &Expr) -> Result { + match expr { + Expr::Identifier(id) => Ok(ParsedExpr::Column(id.value.clone())), + Expr::CompoundIdentifier(ids) => { + Ok(ParsedExpr::Column(ids.last().map(|i| i.value.clone()).unwrap_or_default())) + } + Expr::Value(v) => Ok(ParsedExpr::Literal(Self::convert_value(v)?)), + Expr::BinaryOp { left, op, right } => { + let left = Box::new(Self::convert_expr(left)?); + let right = Box::new(Self::convert_expr(right)?); + let op = Self::convert_binary_op(op)?; + Ok(ParsedExpr::BinaryOp { left, op, right }) + } + Expr::UnaryOp { op: sqlparser::ast::UnaryOperator::Not, expr } => { + Ok(ParsedExpr::Not(Box::new(Self::convert_expr(expr)?))) + } + Expr::IsNull(expr) => Ok(ParsedExpr::IsNull(Box::new(Self::convert_expr(expr)?))), + Expr::IsNotNull(expr) => 
Ok(ParsedExpr::IsNotNull(Box::new(Self::convert_expr(expr)?))), + Expr::InList { expr, list, negated } => Ok(ParsedExpr::InList { + expr: Box::new(Self::convert_expr(expr)?), + list: list.iter().map(Self::convert_expr).collect::>()?, + negated: *negated, + }), + Expr::Between { expr, low, high, negated } => Ok(ParsedExpr::Between { + expr: Box::new(Self::convert_expr(expr)?), + low: Box::new(Self::convert_expr(low)?), + high: Box::new(Self::convert_expr(high)?), + negated: *negated, + }), + Expr::Like { expr, pattern, .. } => { + let left = Box::new(Self::convert_expr(expr)?); + let right = Box::new(Self::convert_expr(pattern)?); + Ok(ParsedExpr::BinaryOp { left, op: BinaryOp::Like, right }) + } + Expr::Nested(inner) => Self::convert_expr(inner), + _ => Err(SqlError::Unsupported("Expression not supported".to_string())), + } + } + + fn convert_binary_op(op: &BinaryOperator) -> Result { + match op { + BinaryOperator::Eq => Ok(BinaryOp::Eq), + BinaryOperator::NotEq => Ok(BinaryOp::Ne), + BinaryOperator::Lt => Ok(BinaryOp::Lt), + BinaryOperator::LtEq => Ok(BinaryOp::Le), + BinaryOperator::Gt => Ok(BinaryOp::Gt), + BinaryOperator::GtEq => Ok(BinaryOp::Ge), + BinaryOperator::And => Ok(BinaryOp::And), + BinaryOperator::Or => Ok(BinaryOp::Or), + BinaryOperator::Plus => Ok(BinaryOp::Plus), + BinaryOperator::Minus => Ok(BinaryOp::Minus), + BinaryOperator::Multiply => Ok(BinaryOp::Multiply), + BinaryOperator::Divide => Ok(BinaryOp::Divide), + _ => Err(SqlError::Unsupported("Operator not supported".to_string())), + } + } + + fn convert_value(v: &AstValue) -> Result { + match v { + AstValue::Null => Ok(SqlValue::Null), + AstValue::Number(n, _) => { + if n.contains('.') { + n.parse::() + .map(SqlValue::Real) + .map_err(|e| SqlError::Parse(e.to_string())) + } else { + n.parse::() + .map(SqlValue::Integer) + .map_err(|e| SqlError::Parse(e.to_string())) + } + } + AstValue::SingleQuotedString(s) | AstValue::DoubleQuotedString(s) => { + Ok(SqlValue::Text(s.clone())) + } + 
AstValue::Boolean(b) => Ok(SqlValue::Boolean(*b)), + AstValue::HexStringLiteral(h) => hex::decode(h) + .map(SqlValue::Blob) + .map_err(|e| SqlError::Parse(e.to_string())), + _ => Err(SqlError::Unsupported("Value not supported".to_string())), + } + } + + fn convert_value_expr(expr: &Expr) -> Result { + match expr { + Expr::Value(v) => Self::convert_value(v), + _ => Err(SqlError::Unsupported("Non-literal default".to_string())), + } + } + + fn convert_insert( + table_name: &sqlparser::ast::ObjectName, + columns: &[sqlparser::ast::Ident], + source: &Option>, + ) -> Result { + let table = table_name.to_string(); + let col_names: Vec = columns.iter().map(|c| c.value.clone()).collect(); + + let values = match source.as_ref().map(|s| s.body.as_ref()) { + Some(SetExpr::Values(vals)) => { + let mut result = Vec::new(); + for row in &vals.rows { + let row_values: Vec = row + .iter() + .map(Self::convert_value_expr) + .collect::>()?; + result.push(row_values); + } + result + } + _ => return Err(SqlError::Unsupported("INSERT source".to_string())), + }; + + Ok(ParsedStatement::Insert { + table, + columns: col_names, + values, + }) + } + + fn convert_update( + table: &sqlparser::ast::TableWithJoins, + assignments: &[sqlparser::ast::Assignment], + selection: &Option, + ) -> Result { + let table_name = Self::convert_table_factor(&table.relation)?; + + let parsed_assignments: Vec<(String, SqlValue)> = assignments + .iter() + .map(|a| { + let col = a.id.iter().map(|i| i.value.clone()).collect::>().join("."); + let val = Self::convert_value_expr(&a.value)?; + Ok((col, val)) + }) + .collect::>()?; + + let where_clause = selection.as_ref().map(Self::convert_expr).transpose()?; + + Ok(ParsedStatement::Update { + table: table_name, + assignments: parsed_assignments, + where_clause, + }) + } + + fn convert_delete( + from: &[sqlparser::ast::TableWithJoins], + selection: &Option, + ) -> Result { + let table = from + .first() + .map(|f| Self::convert_table_factor(&f.relation)) + .transpose()? 
+ .unwrap_or_default(); + + let where_clause = selection.as_ref().map(Self::convert_expr).transpose()?; + + Ok(ParsedStatement::Delete { + table, + where_clause, + }) + } + + fn convert_create_index( + name: &Option, + table_name: &sqlparser::ast::ObjectName, + columns: &[sqlparser::ast::OrderByExpr], + unique: bool, + ) -> Result { + let index_name = name + .as_ref() + .map(|n| n.to_string()) + .ok_or_else(|| SqlError::Parse("Index name required".to_string()))?; + + let table = table_name.to_string(); + + let cols: Vec = columns + .iter() + .map(|c| c.expr.to_string()) + .collect(); + + Ok(ParsedStatement::CreateIndex { + name: index_name, + table, + columns: cols, + unique, + }) + } + + fn expr_to_usize(expr: &Expr) -> Option { + if let Expr::Value(AstValue::Number(n, _)) = expr { + n.parse().ok() + } else { + None + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_create_table() { + let sql = "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL, age INTEGER)"; + let stmt = SqlParser::parse(sql).unwrap(); + + if let ParsedStatement::CreateTable { name, columns, .. 
} = stmt { + assert_eq!(name, "users"); + assert_eq!(columns.len(), 3); + assert!(columns[0].primary_key); + assert!(!columns[1].nullable); + } else { + panic!("Expected CreateTable"); + } + } + + #[test] + fn test_parse_select() { + let sql = "SELECT name, age FROM users WHERE age > 18 ORDER BY name LIMIT 10"; + let stmt = SqlParser::parse(sql).unwrap(); + + if let ParsedStatement::Select(select) = stmt { + assert_eq!(select.columns.len(), 2); + assert_eq!(select.from, "users"); + assert!(select.where_clause.is_some()); + assert_eq!(select.limit, Some(10)); + } else { + panic!("Expected Select"); + } + } + + #[test] + fn test_parse_insert() { + let sql = "INSERT INTO users (name, age) VALUES ('Alice', 30), ('Bob', 25)"; + let stmt = SqlParser::parse(sql).unwrap(); + + if let ParsedStatement::Insert { table, columns, values } = stmt { + assert_eq!(table, "users"); + assert_eq!(columns, vec!["name", "age"]); + assert_eq!(values.len(), 2); + } else { + panic!("Expected Insert"); + } + } + + #[test] + fn test_parse_update() { + let sql = "UPDATE users SET age = 31 WHERE name = 'Alice'"; + let stmt = SqlParser::parse(sql).unwrap(); + + if let ParsedStatement::Update { table, assignments, where_clause } = stmt { + assert_eq!(table, "users"); + assert_eq!(assignments.len(), 1); + assert!(where_clause.is_some()); + } else { + panic!("Expected Update"); + } + } + + #[test] + fn test_parse_delete() { + let sql = "DELETE FROM users WHERE age < 18"; + let stmt = SqlParser::parse(sql).unwrap(); + + if let ParsedStatement::Delete { table, where_clause } = stmt { + assert_eq!(table, "users"); + assert!(where_clause.is_some()); + } else { + panic!("Expected Delete"); + } + } +} diff --git a/crates/synor-database/src/sql/row.rs b/crates/synor-database/src/sql/row.rs new file mode 100644 index 0000000..8090591 --- /dev/null +++ b/crates/synor-database/src/sql/row.rs @@ -0,0 +1,241 @@ +//! Row representation for SQL tables. 
+ +use super::types::{SqlError, SqlValue}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// Unique row identifier. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct RowId(pub u64); + +impl RowId { + /// Creates a new row ID. + pub fn new(id: u64) -> Self { + RowId(id) + } + + /// Returns the inner ID value. + pub fn inner(&self) -> u64 { + self.0 + } +} + +impl std::fmt::Display for RowId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +/// A single row in a SQL table. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Row { + /// Row identifier. + pub id: RowId, + /// Column values indexed by column name. + values: HashMap, + /// Ordered column names (preserves insertion order). + columns: Vec, +} + +impl Row { + /// Creates a new empty row. + pub fn new(id: RowId) -> Self { + Self { + id, + values: HashMap::new(), + columns: Vec::new(), + } + } + + /// Creates a row with given columns. + pub fn with_columns(id: RowId, columns: Vec) -> Self { + let mut values = HashMap::with_capacity(columns.len()); + for col in &columns { + values.insert(col.clone(), SqlValue::Null); + } + Self { + id, + values, + columns, + } + } + + /// Sets a column value. + pub fn set(&mut self, column: &str, value: SqlValue) { + if !self.values.contains_key(column) { + self.columns.push(column.to_string()); + } + self.values.insert(column.to_string(), value); + } + + /// Gets a column value. + pub fn get(&self, column: &str) -> Option<&SqlValue> { + self.values.get(column) + } + + /// Gets a column value or returns Null. + pub fn get_or_null(&self, column: &str) -> SqlValue { + self.values.get(column).cloned().unwrap_or(SqlValue::Null) + } + + /// Returns all column names. + pub fn columns(&self) -> &[String] { + &self.columns + } + + /// Returns all values in column order. 
+ pub fn values(&self) -> Vec<&SqlValue> { + self.columns.iter().map(|c| self.values.get(c).unwrap()).collect() + } + + /// Returns the number of columns. + pub fn len(&self) -> usize { + self.columns.len() + } + + /// Returns true if the row has no columns. + pub fn is_empty(&self) -> bool { + self.columns.is_empty() + } + + /// Projects to specific columns. + pub fn project(&self, columns: &[String]) -> Row { + let mut row = Row::new(self.id); + for col in columns { + if let Some(value) = self.values.get(col) { + row.set(col, value.clone()); + } + } + row + } + + /// Converts to a map. + pub fn to_map(&self) -> HashMap { + self.values.clone() + } + + /// Converts from a map. + pub fn from_map(id: RowId, map: HashMap) -> Self { + let columns: Vec = map.keys().cloned().collect(); + Self { + id, + values: map, + columns, + } + } +} + +impl PartialEq for Row { + fn eq(&self, other: &Self) -> bool { + if self.columns.len() != other.columns.len() { + return false; + } + for col in &self.columns { + if self.get(col) != other.get(col) { + return false; + } + } + true + } +} + +/// Builder for creating rows. +pub struct RowBuilder { + id: RowId, + values: Vec<(String, SqlValue)>, +} + +impl RowBuilder { + /// Creates a new row builder. + pub fn new(id: RowId) -> Self { + Self { + id, + values: Vec::new(), + } + } + + /// Adds a column value. + pub fn column(mut self, name: impl Into, value: SqlValue) -> Self { + self.values.push((name.into(), value)); + self + } + + /// Adds an integer column. + pub fn int(self, name: impl Into, value: i64) -> Self { + self.column(name, SqlValue::Integer(value)) + } + + /// Adds a real column. + pub fn real(self, name: impl Into, value: f64) -> Self { + self.column(name, SqlValue::Real(value)) + } + + /// Adds a text column. + pub fn text(self, name: impl Into, value: impl Into) -> Self { + self.column(name, SqlValue::Text(value.into())) + } + + /// Adds a boolean column. 
+ pub fn boolean(self, name: impl Into, value: bool) -> Self { + self.column(name, SqlValue::Boolean(value)) + } + + /// Adds a null column. + pub fn null(self, name: impl Into) -> Self { + self.column(name, SqlValue::Null) + } + + /// Builds the row. + pub fn build(self) -> Row { + let mut row = Row::new(self.id); + for (name, value) in self.values { + row.set(&name, value); + } + row + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_row_basic() { + let mut row = Row::new(RowId(1)); + row.set("name", SqlValue::Text("Alice".to_string())); + row.set("age", SqlValue::Integer(30)); + + assert_eq!(row.get("name"), Some(&SqlValue::Text("Alice".to_string()))); + assert_eq!(row.get("age"), Some(&SqlValue::Integer(30))); + assert_eq!(row.get("missing"), None); + assert_eq!(row.len(), 2); + } + + #[test] + fn test_row_builder() { + let row = RowBuilder::new(RowId(1)) + .text("name", "Bob") + .int("age", 25) + .boolean("active", true) + .build(); + + assert_eq!(row.get("name"), Some(&SqlValue::Text("Bob".to_string()))); + assert_eq!(row.get("age"), Some(&SqlValue::Integer(25))); + assert_eq!(row.get("active"), Some(&SqlValue::Boolean(true))); + } + + #[test] + fn test_row_projection() { + let row = RowBuilder::new(RowId(1)) + .text("a", "1") + .text("b", "2") + .text("c", "3") + .build(); + + let projected = row.project(&["a".to_string(), "c".to_string()]); + assert_eq!(projected.len(), 2); + assert!(projected.get("a").is_some()); + assert!(projected.get("b").is_none()); + assert!(projected.get("c").is_some()); + } +} diff --git a/crates/synor-database/src/sql/table.rs b/crates/synor-database/src/sql/table.rs new file mode 100644 index 0000000..7cce4e3 --- /dev/null +++ b/crates/synor-database/src/sql/table.rs @@ -0,0 +1,570 @@ +//! SQL table definition and storage. 
+ +use super::row::{Row, RowId}; +use super::types::{SqlError, SqlType, SqlValue}; +use parking_lot::RwLock; +use serde::{Deserialize, Serialize}; +use std::collections::{BTreeMap, HashMap, HashSet}; + +/// Column definition. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ColumnDef { + /// Column name. + pub name: String, + /// Data type. + pub data_type: SqlType, + /// Whether null values are allowed. + pub nullable: bool, + /// Default value. + pub default: Option, + /// Primary key flag. + pub primary_key: bool, + /// Unique constraint. + pub unique: bool, +} + +impl ColumnDef { + /// Creates a new column definition. + pub fn new(name: impl Into, data_type: SqlType) -> Self { + Self { + name: name.into(), + data_type, + nullable: true, + default: None, + primary_key: false, + unique: false, + } + } + + /// Sets as not null. + pub fn not_null(mut self) -> Self { + self.nullable = false; + self + } + + /// Sets default value. + pub fn default(mut self, value: SqlValue) -> Self { + self.default = Some(value); + self + } + + /// Sets as primary key. + pub fn primary_key(mut self) -> Self { + self.primary_key = true; + self.nullable = false; + self.unique = true; + self + } + + /// Sets as unique. + pub fn unique(mut self) -> Self { + self.unique = true; + self + } + + /// Validates a value against this column definition. 
+ pub fn validate(&self, value: &SqlValue) -> Result<(), SqlError> { + // Check null constraint + if value.is_null() && !self.nullable { + return Err(SqlError::NotNullViolation(self.name.clone())); + } + + // Skip type check for null values + if value.is_null() { + return Ok(()); + } + + // Check type compatibility + let compatible = match (&self.data_type, value) { + (SqlType::Integer, SqlValue::Integer(_)) => true, + (SqlType::Real, SqlValue::Real(_)) => true, + (SqlType::Real, SqlValue::Integer(_)) => true, // Allow int to real + (SqlType::Text, SqlValue::Text(_)) => true, + (SqlType::Blob, SqlValue::Blob(_)) => true, + (SqlType::Boolean, SqlValue::Boolean(_)) => true, + (SqlType::Timestamp, SqlValue::Timestamp(_)) => true, + (SqlType::Timestamp, SqlValue::Integer(_)) => true, // Allow int to timestamp + _ => false, + }; + + if !compatible { + return Err(SqlError::TypeMismatch { + expected: format!("{:?}", self.data_type), + got: format!("{:?}", value.sql_type()), + }); + } + + Ok(()) + } +} + +/// Table definition. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TableDef { + /// Table name. + pub name: String, + /// Column definitions. + pub columns: Vec, + /// Primary key column name (if any). + pub primary_key: Option, +} + +impl TableDef { + /// Creates a new table definition. + pub fn new(name: impl Into) -> Self { + Self { + name: name.into(), + columns: Vec::new(), + primary_key: None, + } + } + + /// Adds a column. + pub fn column(mut self, col: ColumnDef) -> Self { + if col.primary_key { + self.primary_key = Some(col.name.clone()); + } + self.columns.push(col); + self + } + + /// Gets column definition by name. + pub fn get_column(&self, name: &str) -> Option<&ColumnDef> { + self.columns.iter().find(|c| c.name == name) + } + + /// Returns column names. + pub fn column_names(&self) -> Vec { + self.columns.iter().map(|c| c.name.clone()).collect() + } + + /// Validates a row against this table definition. 
+ pub fn validate_row(&self, row: &Row) -> Result<(), SqlError> { + for col in &self.columns { + let value = row.get_or_null(&col.name); + col.validate(&value)?; + } + Ok(()) + } +} + +/// Index on a table column. +#[derive(Debug)] +pub struct TableIndex { + /// Index name. + pub name: String, + /// Column being indexed. + pub column: String, + /// Whether index enforces uniqueness. + pub unique: bool, + /// B-tree index data: value -> row IDs. + data: BTreeMap>, +} + +impl TableIndex { + /// Creates a new index. + pub fn new(name: impl Into, column: impl Into, unique: bool) -> Self { + Self { + name: name.into(), + column: column.into(), + unique, + data: BTreeMap::new(), + } + } + + /// Inserts a value-rowid mapping. + pub fn insert(&mut self, value: SqlValue, row_id: RowId) -> Result<(), SqlError> { + if self.unique { + if let Some(existing) = self.data.get(&value) { + if !existing.is_empty() { + return Err(SqlError::ConstraintViolation(format!( + "Unique constraint violation on index '{}'", + self.name + ))); + } + } + } + self.data.entry(value).or_default().insert(row_id); + Ok(()) + } + + /// Removes a value-rowid mapping. + pub fn remove(&mut self, value: &SqlValue, row_id: &RowId) { + if let Some(ids) = self.data.get_mut(value) { + ids.remove(row_id); + if ids.is_empty() { + self.data.remove(value); + } + } + } + + /// Looks up rows by exact value. + pub fn lookup(&self, value: &SqlValue) -> Vec { + self.data + .get(value) + .map(|ids| ids.iter().cloned().collect()) + .unwrap_or_default() + } + + /// Range query. + pub fn range(&self, start: Option<&SqlValue>, end: Option<&SqlValue>) -> Vec { + let mut result = Vec::new(); + for (key, ids) in &self.data { + let in_range = match (start, end) { + (Some(s), Some(e)) => key >= s && key <= e, + (Some(s), None) => key >= s, + (None, Some(e)) => key <= e, + (None, None) => true, + }; + if in_range { + result.extend(ids.iter().cloned()); + } + } + result + } +} + +/// A SQL table with data. 
+pub struct Table { + /// Table definition. + pub def: TableDef, + /// Row storage: row ID -> row data. + rows: RwLock>, + /// Next row ID. + next_id: RwLock, + /// Indexes. + indexes: RwLock>, + /// Primary key index (if any). + pk_index: RwLock>, +} + +impl Table { + /// Creates a new table. + pub fn new(def: TableDef) -> Self { + let pk_col = def.primary_key.clone(); + let unique_cols: Vec = def + .columns + .iter() + .filter(|c| c.unique && !c.primary_key) + .map(|c| c.name.clone()) + .collect(); + + let table = Self { + def, + rows: RwLock::new(HashMap::new()), + next_id: RwLock::new(1), + indexes: RwLock::new(HashMap::new()), + pk_index: RwLock::new(None), + }; + + { + let mut indexes = table.indexes.write(); + + // Create primary key index if defined + if let Some(pk) = pk_col { + let idx_name = format!("pk_{}", pk); + indexes.insert(idx_name.clone(), TableIndex::new(&idx_name, &pk, true)); + *table.pk_index.write() = Some(idx_name); + } + + // Create indexes for unique columns + for col in unique_cols { + let idx_name = format!("unique_{}", col); + indexes.insert(idx_name.clone(), TableIndex::new(&idx_name, &col, true)); + } + } + + table + } + + /// Returns table name. + pub fn name(&self) -> &str { + &self.def.name + } + + /// Returns row count. + pub fn count(&self) -> usize { + self.rows.read().len() + } + + /// Creates an index on a column. 
+ pub fn create_index( + &self, + name: impl Into, + column: impl Into, + unique: bool, + ) -> Result<(), SqlError> { + let name = name.into(); + let column = column.into(); + + let mut indexes = self.indexes.write(); + if indexes.contains_key(&name) { + return Err(SqlError::InvalidOperation(format!("Index '{}' already exists", name))); + } + + let mut index = TableIndex::new(&name, &column, unique); + + // Index existing rows + let rows = self.rows.read(); + for (row_id, row) in rows.iter() { + let value = row.get_or_null(&column); + index.insert(value, *row_id)?; + } + + indexes.insert(name, index); + Ok(()) + } + + /// Drops an index. + pub fn drop_index(&self, name: &str) -> Result<(), SqlError> { + let mut indexes = self.indexes.write(); + if indexes.remove(name).is_none() { + return Err(SqlError::InvalidOperation(format!("Index '{}' not found", name))); + } + Ok(()) + } + + /// Inserts a row. + pub fn insert(&self, values: HashMap) -> Result { + let mut next_id = self.next_id.write(); + let row_id = RowId(*next_id); + *next_id += 1; + + let row = Row::from_map(row_id, values); + + // Validate row + self.def.validate_row(&row)?; + + // Check uniqueness constraints via indexes + { + let mut indexes = self.indexes.write(); + for (_, index) in indexes.iter_mut() { + if index.unique { + let value = row.get_or_null(&index.column); + if !value.is_null() && !index.lookup(&value).is_empty() { + return Err(SqlError::ConstraintViolation(format!( + "Unique constraint violation on column '{}'", + index.column + ))); + } + } + } + + // Update indexes + for (_, index) in indexes.iter_mut() { + let value = row.get_or_null(&index.column); + index.insert(value, row_id)?; + } + } + + // Insert row + self.rows.write().insert(row_id, row); + + Ok(row_id) + } + + /// Gets a row by ID. + pub fn get(&self, id: RowId) -> Option { + self.rows.read().get(&id).cloned() + } + + /// Updates a row. 
+ pub fn update(&self, id: RowId, updates: HashMap) -> Result<(), SqlError> { + let mut rows = self.rows.write(); + let row = rows.get_mut(&id).ok_or_else(|| { + SqlError::InvalidOperation(format!("Row {} not found", id)) + })?; + + let old_values: HashMap = updates + .keys() + .map(|k| (k.clone(), row.get_or_null(k))) + .collect(); + + // Validate updates + for (col, value) in &updates { + if let Some(col_def) = self.def.get_column(col) { + col_def.validate(value)?; + } + } + + // Update indexes + { + let mut indexes = self.indexes.write(); + for (_, index) in indexes.iter_mut() { + if let Some(new_value) = updates.get(&index.column) { + let old_value = old_values.get(&index.column).cloned().unwrap_or(SqlValue::Null); + index.remove(&old_value, &id); + index.insert(new_value.clone(), id)?; + } + } + } + + // Apply updates + for (col, value) in updates { + row.set(&col, value); + } + + Ok(()) + } + + /// Deletes a row. + pub fn delete(&self, id: RowId) -> Result { + let mut rows = self.rows.write(); + if let Some(row) = rows.remove(&id) { + // Update indexes + let mut indexes = self.indexes.write(); + for (_, index) in indexes.iter_mut() { + let value = row.get_or_null(&index.column); + index.remove(&value, &id); + } + Ok(true) + } else { + Ok(false) + } + } + + /// Scans all rows. + pub fn scan(&self) -> Vec { + self.rows.read().values().cloned().collect() + } + + /// Scans with filter function. + pub fn scan_filter(&self, predicate: F) -> Vec + where + F: Fn(&Row) -> bool, + { + self.rows + .read() + .values() + .filter(|row| predicate(row)) + .cloned() + .collect() + } + + /// Looks up rows by index. 
+ pub fn lookup_index(&self, index_name: &str, value: &SqlValue) -> Vec { + let indexes = self.indexes.read(); + let rows = self.rows.read(); + + if let Some(index) = indexes.get(index_name) { + index + .lookup(value) + .into_iter() + .filter_map(|id| rows.get(&id).cloned()) + .collect() + } else { + Vec::new() + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn create_test_table() -> Table { + let def = TableDef::new("users") + .column(ColumnDef::new("id", SqlType::Integer).primary_key()) + .column(ColumnDef::new("name", SqlType::Text).not_null()) + .column(ColumnDef::new("age", SqlType::Integer)) + .column(ColumnDef::new("email", SqlType::Text).unique()); + + Table::new(def) + } + + #[test] + fn test_table_insert() { + let table = create_test_table(); + + let mut values = HashMap::new(); + values.insert("id".to_string(), SqlValue::Integer(1)); + values.insert("name".to_string(), SqlValue::Text("Alice".to_string())); + values.insert("age".to_string(), SqlValue::Integer(30)); + values.insert("email".to_string(), SqlValue::Text("alice@example.com".to_string())); + + let row_id = table.insert(values).unwrap(); + assert_eq!(table.count(), 1); + + let row = table.get(row_id).unwrap(); + assert_eq!(row.get("name"), Some(&SqlValue::Text("Alice".to_string()))); + } + + #[test] + fn test_table_not_null() { + let table = create_test_table(); + + let mut values = HashMap::new(); + values.insert("id".to_string(), SqlValue::Integer(1)); + // Missing required "name" field + + let result = table.insert(values); + assert!(result.is_err()); + } + + #[test] + fn test_table_unique() { + let table = create_test_table(); + + let mut values1 = HashMap::new(); + values1.insert("id".to_string(), SqlValue::Integer(1)); + values1.insert("name".to_string(), SqlValue::Text("Alice".to_string())); + values1.insert("email".to_string(), SqlValue::Text("test@example.com".to_string())); + table.insert(values1).unwrap(); + + let mut values2 = HashMap::new(); + 
values2.insert("id".to_string(), SqlValue::Integer(2)); + values2.insert("name".to_string(), SqlValue::Text("Bob".to_string())); + values2.insert("email".to_string(), SqlValue::Text("test@example.com".to_string())); + + let result = table.insert(values2); + assert!(result.is_err()); // Duplicate email + } + + #[test] + fn test_table_update() { + let table = create_test_table(); + + let mut values = HashMap::new(); + values.insert("id".to_string(), SqlValue::Integer(1)); + values.insert("name".to_string(), SqlValue::Text("Alice".to_string())); + values.insert("age".to_string(), SqlValue::Integer(30)); + + let row_id = table.insert(values).unwrap(); + + let mut updates = HashMap::new(); + updates.insert("age".to_string(), SqlValue::Integer(31)); + table.update(row_id, updates).unwrap(); + + let row = table.get(row_id).unwrap(); + assert_eq!(row.get("age"), Some(&SqlValue::Integer(31))); + } + + #[test] + fn test_table_delete() { + let table = create_test_table(); + + let mut values = HashMap::new(); + values.insert("id".to_string(), SqlValue::Integer(1)); + values.insert("name".to_string(), SqlValue::Text("Alice".to_string())); + + let row_id = table.insert(values).unwrap(); + assert_eq!(table.count(), 1); + + table.delete(row_id).unwrap(); + assert_eq!(table.count(), 0); + } + + #[test] + fn test_table_index() { + let table = create_test_table(); + table.create_index("idx_name", "name", false).unwrap(); + + let mut values = HashMap::new(); + values.insert("id".to_string(), SqlValue::Integer(1)); + values.insert("name".to_string(), SqlValue::Text("Alice".to_string())); + table.insert(values).unwrap(); + + let rows = table.lookup_index("idx_name", &SqlValue::Text("Alice".to_string())); + assert_eq!(rows.len(), 1); + } +} diff --git a/crates/synor-database/src/sql/transaction.rs b/crates/synor-database/src/sql/transaction.rs new file mode 100644 index 0000000..13f1fb5 --- /dev/null +++ b/crates/synor-database/src/sql/transaction.rs @@ -0,0 +1,355 @@ +//! 
ACID transaction support for SQL. + +use super::row::RowId; +use super::types::{SqlError, SqlValue}; +use parking_lot::RwLock; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::{SystemTime, UNIX_EPOCH}; + +/// Transaction identifier. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct TransactionId(pub u64); + +impl TransactionId { + /// Creates a new transaction ID. + pub fn new() -> Self { + static COUNTER: AtomicU64 = AtomicU64::new(1); + TransactionId(COUNTER.fetch_add(1, Ordering::SeqCst)) + } +} + +impl Default for TransactionId { + fn default() -> Self { + Self::new() + } +} + +impl std::fmt::Display for TransactionId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "txn_{}", self.0) + } +} + +/// Transaction state. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum TransactionState { + /// Transaction is active. + Active, + /// Transaction is committed. + Committed, + /// Transaction is rolled back. + RolledBack, +} + +/// A single operation in a transaction. +#[derive(Clone, Debug)] +pub enum TransactionOp { + /// Insert a row. + Insert { + table: String, + row_id: RowId, + values: HashMap, + }, + /// Update a row. + Update { + table: String, + row_id: RowId, + old_values: HashMap, + new_values: HashMap, + }, + /// Delete a row. + Delete { + table: String, + row_id: RowId, + old_values: HashMap, + }, +} + +/// Transaction for tracking changes. +#[derive(Debug)] +pub struct Transaction { + /// Transaction ID. + pub id: TransactionId, + /// Transaction state. + pub state: TransactionState, + /// Operations in this transaction. + operations: Vec, + /// Start time. + pub started_at: u64, + /// Isolation level. + pub isolation: IsolationLevel, +} + +/// Transaction isolation levels. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum IsolationLevel { + /// Read uncommitted (dirty reads allowed). 
+    ReadUncommitted,
+    /// Read committed (no dirty reads).
+    ReadCommitted,
+    /// Repeatable read (no non-repeatable reads).
+    RepeatableRead,
+    /// Serializable (full isolation).
+    Serializable,
+}
+
+impl Default for IsolationLevel {
+    fn default() -> Self {
+        IsolationLevel::ReadCommitted
+    }
+}
+
+impl Transaction {
+    /// Creates a new transaction in the `Active` state, stamped with the
+    /// current wall-clock time (Unix ms).
+    pub fn new(isolation: IsolationLevel) -> Self {
+        let now = SystemTime::now()
+            .duration_since(UNIX_EPOCH)
+            .expect("system clock is before the Unix epoch")
+            .as_millis() as u64;
+
+        Self {
+            id: TransactionId::new(),
+            state: TransactionState::Active,
+            operations: Vec::new(),
+            started_at: now,
+            isolation,
+        }
+    }
+
+    /// Returns true if the transaction is still accepting operations.
+    pub fn is_active(&self) -> bool {
+        self.state == TransactionState::Active
+    }
+
+    /// Records an insert operation for later commit/rollback replay.
+    pub fn record_insert(&mut self, table: String, row_id: RowId, values: HashMap<String, SqlValue>) {
+        self.operations.push(TransactionOp::Insert {
+            table,
+            row_id,
+            values,
+        });
+    }
+
+    /// Records an update operation, keeping both row images so the change
+    /// can be undone (old) or re-applied (new).
+    pub fn record_update(
+        &mut self,
+        table: String,
+        row_id: RowId,
+        old_values: HashMap<String, SqlValue>,
+        new_values: HashMap<String, SqlValue>,
+    ) {
+        self.operations.push(TransactionOp::Update {
+            table,
+            row_id,
+            old_values,
+            new_values,
+        });
+    }
+
+    /// Records a delete operation, keeping the old row image for undo.
+    pub fn record_delete(&mut self, table: String, row_id: RowId, old_values: HashMap<String, SqlValue>) {
+        self.operations.push(TransactionOp::Delete {
+            table,
+            row_id,
+            old_values,
+        });
+    }
+
+    /// Returns operations for rollback (most recent first, so undo is applied
+    /// in reverse execution order).
+    pub fn rollback_ops(&self) -> impl Iterator<Item = &TransactionOp> {
+        self.operations.iter().rev()
+    }
+
+    /// Returns operations for commit, in execution order.
+    pub fn commit_ops(&self) -> &[TransactionOp] {
+        &self.operations
+    }
+
+    /// Marks the transaction as committed.
+    pub fn mark_committed(&mut self) {
+        self.state = TransactionState::Committed;
+    }
+
+    /// Marks the transaction as rolled back.
+    pub fn mark_rolled_back(&mut self) {
+        self.state = TransactionState::RolledBack;
+    }
+}
+
+/// Transaction manager: tracks all in-flight transactions by ID.
+pub struct TransactionManager {
+    /// Active transactions. `parking_lot::RwLock` assumed (no `unwrap()` on
+    /// the guards) — TODO confirm against the module imports.
+    transactions: RwLock<HashMap<TransactionId, Transaction>>,
+}
+
+impl TransactionManager {
+    /// Creates a new, empty transaction manager.
+    pub fn new() -> Self {
+        Self {
+            transactions: RwLock::new(HashMap::new()),
+        }
+    }
+
+    /// Begins a new transaction and returns its ID.
+    pub fn begin(&self, isolation: IsolationLevel) -> TransactionId {
+        let txn = Transaction::new(isolation);
+        let id = txn.id;
+        self.transactions.write().insert(id, txn);
+        id
+    }
+
+    /// Gets a snapshot (clone) of a transaction by ID.
+    pub fn get(&self, id: TransactionId) -> Option<Transaction> {
+        self.transactions.read().get(&id).cloned()
+    }
+
+    /// Records an operation in a transaction.
+    ///
+    /// # Errors
+    /// Returns `SqlError::Transaction` if the transaction does not exist or
+    /// is no longer active.
+    pub fn record_op(&self, id: TransactionId, op: TransactionOp) -> Result<(), SqlError> {
+        let mut txns = self.transactions.write();
+        let txn = txns
+            .get_mut(&id)
+            .ok_or_else(|| SqlError::Transaction(format!("Transaction {} not found", id)))?;
+
+        if !txn.is_active() {
+            return Err(SqlError::Transaction(format!("Transaction {} is not active", id)));
+        }
+
+        txn.operations.push(op);
+        Ok(())
+    }
+
+    /// Commits a transaction, removing it and returning its operations in
+    /// execution order.
+    ///
+    /// # Errors
+    /// Returns `SqlError::Transaction` if the transaction does not exist or
+    /// is no longer active.
+    pub fn commit(&self, id: TransactionId) -> Result<Vec<TransactionOp>, SqlError> {
+        let mut txns = self.transactions.write();
+        match txns.get(&id) {
+            None => Err(SqlError::Transaction(format!("Transaction {} not found", id))),
+            Some(txn) if !txn.is_active() => {
+                Err(SqlError::Transaction(format!("Transaction {} is not active", id)))
+            }
+            Some(_) => {
+                // Remove the entry and take ownership of the op log instead
+                // of cloning it; the transaction is dropped either way.
+                let txn = txns.remove(&id).expect("present: just looked up");
+                Ok(txn.operations)
+            }
+        }
+    }
+
+    /// Rolls back a transaction, removing it and returning the operations to
+    /// undo (most recent first).
+    ///
+    /// # Errors
+    /// Returns `SqlError::Transaction` if the transaction does not exist or
+    /// is no longer active.
+    pub fn rollback(&self, id: TransactionId) -> Result<Vec<TransactionOp>, SqlError> {
+        let mut txns = self.transactions.write();
+        match txns.get(&id) {
+            None => Err(SqlError::Transaction(format!("Transaction {} not found", id))),
+            Some(txn) if !txn.is_active() => {
+                Err(SqlError::Transaction(format!("Transaction {} is not active", id)))
+            }
+            Some(_) => {
+                // Own the op log and reverse it without an intermediate clone.
+                let txn = txns.remove(&id).expect("present: just looked up");
+                Ok(txn.operations.into_iter().rev().collect())
+            }
+        }
+    }
+
+    /// Returns the number of tracked (in-flight) transactions.
+    pub fn active_count(&self) -> usize {
+        self.transactions.read().len()
+    }
+
+    /// Checks if a transaction exists and is active.
+    pub fn is_active(&self, id: TransactionId) -> bool {
+        self.transactions
+            .read()
+            .get(&id)
+            .map(|t| t.is_active())
+            .unwrap_or(false)
+    }
+}
+
+impl Default for TransactionManager {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+// Manual impl because the struct's derives are fixed at its declaration;
+// field-by-field clone, identical to what `#[derive(Clone)]` would emit.
+impl Clone for Transaction {
+    fn clone(&self) -> Self {
+        Self {
+            id: self.id,
+            state: self.state,
+            operations: self.operations.clone(),
+            started_at: self.started_at,
+            isolation: self.isolation,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_transaction_lifecycle() {
+        let manager = TransactionManager::new();
+
+        let txn_id = manager.begin(IsolationLevel::ReadCommitted);
+        assert!(manager.is_active(txn_id));
+        assert_eq!(manager.active_count(), 1);
+
+        let mut values = HashMap::new();
+        values.insert("name".to_string(), SqlValue::Text("Alice".to_string()));
+
+        manager
+            .record_op(
+                txn_id,
+                TransactionOp::Insert {
+                    table: "users".to_string(),
+                    row_id: RowId(1),
+                    values,
+                },
+            )
+            .unwrap();
+
+        let ops = manager.commit(txn_id).unwrap();
+        assert_eq!(ops.len(), 1);
+        assert_eq!(manager.active_count(), 0);
+    }
+
+    #[test]
+    fn test_transaction_rollback() {
+        let manager = TransactionManager::new();
+
+        let txn_id = manager.begin(IsolationLevel::ReadCommitted);
+
+        let mut values = HashMap::new();
+        values.insert("name".to_string(), SqlValue::Text("Bob".to_string()));
+
+        manager
+            .record_op(
+                txn_id,
+                TransactionOp::Insert {
+                    table: "users".to_string(),
+                    row_id: RowId(1),
+                    values,
+                },
+            )
+            .unwrap();
+
+        let ops = manager.rollback(txn_id).unwrap();
+        assert_eq!(ops.len(), 1);
+        assert_eq!(manager.active_count(), 0);
+    }
+
+    #[test]
+    fn test_transaction_not_found() {
+        let manager = TransactionManager::new();
+        let fake_id = TransactionId(99999);
+
+        assert!(!manager.is_active(fake_id));
+        assert!(manager.commit(fake_id).is_err());
+        assert!(manager.rollback(fake_id).is_err());
+    }
+}
diff --git a/crates/synor-database/src/sql/types.rs b/crates/synor-database/src/sql/types.rs
new file mode 100644
index 0000000..b3bf1f7
--- /dev/null
+++ b/crates/synor-database/src/sql/types.rs
@@ -0,0 +1,368 @@
+//! SQL type system.
+
+use serde::{Deserialize, Serialize};
+use serde_json::Value as JsonValue;
+use std::cmp::Ordering;
+use std::hash::{Hash, Hasher};
+use thiserror::Error;
+
+/// SQL data types.
+// `Copy` is free here: the enum is fieldless.
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
+pub enum SqlType {
+    /// 64-bit signed integer.
+    Integer,
+    /// 64-bit floating point.
+    Real,
+    /// UTF-8 text string.
+    Text,
+    /// Binary data.
+    Blob,
+    /// Boolean value.
+    Boolean,
+    /// Unix timestamp in milliseconds.
+    Timestamp,
+    /// Null type.
+    Null,
+}
+
+impl SqlType {
+    /// Parses a type name, case-insensitively, accepting common SQL aliases
+    /// (e.g. `INT`, `VARCHAR`, `DATETIME`). Returns `None` for unknown names.
+    pub fn from_str(s: &str) -> Option<Self> {
+        match s.to_uppercase().as_str() {
+            "INTEGER" | "INT" | "BIGINT" | "SMALLINT" => Some(SqlType::Integer),
+            "REAL" | "FLOAT" | "DOUBLE" => Some(SqlType::Real),
+            "TEXT" | "VARCHAR" | "CHAR" | "STRING" => Some(SqlType::Text),
+            "BLOB" | "BINARY" | "BYTES" => Some(SqlType::Blob),
+            "BOOLEAN" | "BOOL" => Some(SqlType::Boolean),
+            "TIMESTAMP" | "DATETIME" | "DATE" => Some(SqlType::Timestamp),
+            _ => None,
+        }
+    }
+
+    /// Returns the default value for this type.
+    pub fn default_value(&self) -> SqlValue {
+        match self {
+            SqlType::Integer => SqlValue::Integer(0),
+            SqlType::Real => SqlValue::Real(0.0),
+            SqlType::Text => SqlValue::Text(String::new()),
+            SqlType::Blob => SqlValue::Blob(Vec::new()),
+            SqlType::Boolean => SqlValue::Boolean(false),
+            SqlType::Timestamp => SqlValue::Timestamp(0),
+            SqlType::Null => SqlValue::Null,
+        }
+    }
+}
+
+/// SQL value types.
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub enum SqlValue {
+    /// Null value.
+    Null,
+    /// Integer value.
+    Integer(i64),
+    /// Real (floating point) value.
+    Real(f64),
+    /// Text string value.
+    Text(String),
+    /// Binary blob value.
+    Blob(Vec<u8>),
+    /// Boolean value.
+    Boolean(bool),
+    /// Timestamp value (Unix ms).
+    Timestamp(u64),
+}
+
+impl SqlValue {
+    /// Returns the SQL type of this value.
+    pub fn sql_type(&self) -> SqlType {
+        match self {
+            SqlValue::Null => SqlType::Null,
+            SqlValue::Integer(_) => SqlType::Integer,
+            SqlValue::Real(_) => SqlType::Real,
+            SqlValue::Text(_) => SqlType::Text,
+            SqlValue::Blob(_) => SqlType::Blob,
+            SqlValue::Boolean(_) => SqlType::Boolean,
+            SqlValue::Timestamp(_) => SqlType::Timestamp,
+        }
+    }
+
+    /// Returns true if this is a null value.
+    pub fn is_null(&self) -> bool {
+        matches!(self, SqlValue::Null)
+    }
+
+    /// Converts to integer if possible: reals truncate toward zero, booleans
+    /// map to 1/0, timestamps reinterpret as `i64`. Text/blob/null yield `None`.
+    pub fn as_integer(&self) -> Option<i64> {
+        match self {
+            SqlValue::Integer(i) => Some(*i),
+            SqlValue::Real(f) => Some(*f as i64),
+            SqlValue::Boolean(b) => Some(if *b { 1 } else { 0 }),
+            SqlValue::Timestamp(t) => Some(*t as i64),
+            _ => None,
+        }
+    }
+
+    /// Converts to real if possible (integers and timestamps widen).
+    pub fn as_real(&self) -> Option<f64> {
+        match self {
+            SqlValue::Integer(i) => Some(*i as f64),
+            SqlValue::Real(f) => Some(*f),
+            SqlValue::Timestamp(t) => Some(*t as f64),
+            _ => None,
+        }
+    }
+
+    /// Borrows the text if this is a `Text` value.
+    pub fn as_text(&self) -> Option<&str> {
+        match self {
+            SqlValue::Text(s) => Some(s),
+            _ => None,
+        }
+    }
+
+    /// Converts to boolean if possible (any non-zero integer is true).
+    pub fn as_boolean(&self) -> Option<bool> {
+        match self {
+            SqlValue::Boolean(b) => Some(*b),
+            SqlValue::Integer(i) => Some(*i != 0),
+            _ => None,
+        }
+    }
+
+    /// Converts to a JSON value. Blobs are hex-encoded strings; non-finite
+    /// reals (NaN/inf) become JSON null because JSON cannot represent them.
+    pub fn to_json(&self) -> JsonValue {
+        match self {
+            SqlValue::Null => JsonValue::Null,
+            SqlValue::Integer(i) => JsonValue::Number((*i).into()),
+            SqlValue::Real(f) => serde_json::Number::from_f64(*f)
+                .map(JsonValue::Number)
+                .unwrap_or(JsonValue::Null),
+            SqlValue::Text(s) => JsonValue::String(s.clone()),
+            SqlValue::Blob(b) => JsonValue::String(hex::encode(b)),
+            SqlValue::Boolean(b) => JsonValue::Bool(*b),
+            SqlValue::Timestamp(t) => JsonValue::Number((*t).into()),
+        }
+    }
+
+    /// Returns a numeric type rank used to order values of different types.
+    fn type_order(&self) -> u8 {
+        match self {
+            SqlValue::Null => 0,
+            SqlValue::Boolean(_) => 1,
+            SqlValue::Integer(_) => 2,
+            SqlValue::Real(_) => 3,
+            SqlValue::Text(_) => 4,
+            SqlValue::Blob(_) => 5,
+            SqlValue::Timestamp(_) => 6,
+        }
+    }
+
+    /// Creates a value from JSON. Arrays and objects are stored as their
+    /// serialized text representation.
+    pub fn from_json(value: &JsonValue) -> Self {
+        match value {
+            JsonValue::Null => SqlValue::Null,
+            JsonValue::Bool(b) => SqlValue::Boolean(*b),
+            JsonValue::Number(n) => {
+                if let Some(i) = n.as_i64() {
+                    SqlValue::Integer(i)
+                } else if let Some(f) = n.as_f64() {
+                    SqlValue::Real(f)
+                } else {
+                    SqlValue::Null
+                }
+            }
+            JsonValue::String(s) => SqlValue::Text(s.clone()),
+            _ => SqlValue::Text(value.to_string()),
+        }
+    }
+}
+
+impl PartialEq for SqlValue {
+    fn eq(&self, other: &Self) -> bool {
+        match (self, other) {
+            (SqlValue::Null, SqlValue::Null) => true,
+            (SqlValue::Integer(a), SqlValue::Integer(b)) => a == b,
+            (SqlValue::Real(a), SqlValue::Real(b)) => a == b,
+            (SqlValue::Text(a), SqlValue::Text(b)) => a == b,
+            (SqlValue::Blob(a), SqlValue::Blob(b)) => a == b,
+            (SqlValue::Boolean(a), SqlValue::Boolean(b)) => a == b,
+            (SqlValue::Timestamp(a), SqlValue::Timestamp(b)) => a == b,
+            // Cross-type numeric comparisons (SQL-style coercion).
+            // NOTE(review): Integer(1) == Real(1.0) here, but `Hash` below
+            // hashes them differently (discriminant + bits), so equal values
+            // can land in different hash buckets; also Real(NaN) != itself
+            // despite the `Eq` impl. Verify no SqlValue-keyed maps rely on
+            // cross-type or NaN keys before tightening this.
+            (SqlValue::Integer(a), SqlValue::Real(b)) => (*a as f64) == *b,
+            (SqlValue::Real(a), SqlValue::Integer(b)) => *a == (*b as f64),
+            _ => false,
+        }
+    }
+}
+
+impl Eq for SqlValue {}
+
+impl Hash for SqlValue {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        std::mem::discriminant(self).hash(state);
+        match self {
+            SqlValue::Null => {}
+            SqlValue::Integer(i) => i.hash(state),
+            SqlValue::Real(f) => f.to_bits().hash(state),
+            SqlValue::Text(s) => s.hash(state),
+            SqlValue::Blob(b) => b.hash(state),
+            SqlValue::Boolean(b) => b.hash(state),
+            SqlValue::Timestamp(t) => t.hash(state),
+        }
+    }
+}
+
+impl PartialOrd for SqlValue {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl Ord for SqlValue {
+    fn cmp(&self, other: &Self) -> Ordering {
+        match (self, other) {
+            (SqlValue::Null, SqlValue::Null) => Ordering::Equal,
+            (SqlValue::Null, _) => Ordering::Less,
+            (_, SqlValue::Null) => Ordering::Greater,
+            (SqlValue::Integer(a), SqlValue::Integer(b)) => a.cmp(b),
+            // `f64::total_cmp` gives the IEEE-754 total order: correct for
+            // negatives and NaN. (Raw `to_bits()` comparison would sort every
+            // negative float AFTER every positive one, since the sign bit is
+            // the most significant bit.)
+            (SqlValue::Real(a), SqlValue::Real(b)) => a.total_cmp(b),
+            (SqlValue::Text(a), SqlValue::Text(b)) => a.cmp(b),
+            (SqlValue::Blob(a), SqlValue::Blob(b)) => a.cmp(b),
+            (SqlValue::Boolean(a), SqlValue::Boolean(b)) => a.cmp(b),
+            (SqlValue::Timestamp(a), SqlValue::Timestamp(b)) => a.cmp(b),
+            (SqlValue::Integer(a), SqlValue::Real(b)) => (*a as f64).total_cmp(b),
+            (SqlValue::Real(a), SqlValue::Integer(b)) => a.total_cmp(&(*b as f64)),
+            // Different types: order by type rank.
+            _ => self.type_order().cmp(&other.type_order()),
+        }
+    }
+}
+
+impl Default for SqlValue {
+    fn default() -> Self {
+        SqlValue::Null
+    }
+}
+
+impl std::fmt::Display for SqlValue {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            SqlValue::Null => write!(f, "NULL"),
+            SqlValue::Integer(i) => write!(f, "{}", i),
+            SqlValue::Real(r) => write!(f, "{}", r),
+            // NOTE(review): embedded single quotes are not doubled, so this
+            // rendering is for display/logging only — never splice it into SQL.
+            SqlValue::Text(s) => write!(f, "'{}'", s),
+            SqlValue::Blob(b) => write!(f, "X'{}'", hex::encode(b)),
+            SqlValue::Boolean(b) => write!(f, "{}", if *b { "TRUE" } else { "FALSE" }),
+            SqlValue::Timestamp(t) => write!(f, "{}", t),
+        }
+    }
+}
+
+/// SQL errors.
+#[derive(Debug, Error)]
+pub enum SqlError {
+    /// Parse error.
+    #[error("Parse error: {0}")]
+    Parse(String),
+
+    /// Table not found.
+    #[error("Table not found: {0}")]
+    TableNotFound(String),
+
+    /// Table already exists.
+    #[error("Table already exists: {0}")]
+    TableExists(String),
+
+    /// Column not found.
+    #[error("Column not found: {0}")]
+    ColumnNotFound(String),
+
+    /// Type mismatch.
+    #[error("Type mismatch: expected {expected}, got {got}")]
+    TypeMismatch { expected: String, got: String },
+
+    /// Constraint violation.
+    #[error("Constraint violation: {0}")]
+    ConstraintViolation(String),
+
+    /// Primary key violation.
+    #[error("Primary key violation: duplicate key {0}")]
+    PrimaryKeyViolation(String),
+
+    /// Not null violation.
+    #[error("Not null violation: column {0} cannot be null")]
+    NotNullViolation(String),
+
+    /// Transaction error.
+    #[error("Transaction error: {0}")]
+    Transaction(String),
+
+    /// Invalid operation.
+    #[error("Invalid operation: {0}")]
+    InvalidOperation(String),
+
+    /// Unsupported feature.
+    #[error("Unsupported: {0}")]
+    Unsupported(String),
+
+    /// Internal error.
+    #[error("Internal error: {0}")]
+    Internal(String),
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_sql_type_from_str() {
+        assert_eq!(SqlType::from_str("INTEGER"), Some(SqlType::Integer));
+        assert_eq!(SqlType::from_str("int"), Some(SqlType::Integer));
+        assert_eq!(SqlType::from_str("TEXT"), Some(SqlType::Text));
+        assert_eq!(SqlType::from_str("varchar"), Some(SqlType::Text));
+        assert_eq!(SqlType::from_str("BOOLEAN"), Some(SqlType::Boolean));
+        assert_eq!(SqlType::from_str("unknown"), None);
+    }
+
+    #[test]
+    fn test_sql_value_conversions() {
+        let int_val = SqlValue::Integer(42);
+        assert_eq!(int_val.as_integer(), Some(42));
+        assert_eq!(int_val.as_real(), Some(42.0));
+
+        let real_val = SqlValue::Real(3.14);
+        assert_eq!(real_val.as_real(), Some(3.14));
+        assert_eq!(real_val.as_integer(), Some(3));
+
+        let text_val = SqlValue::Text("hello".to_string());
+        assert_eq!(text_val.as_text(), Some("hello"));
+    }
+
+    #[test]
+    fn test_sql_value_comparison() {
+        assert_eq!(SqlValue::Integer(5), SqlValue::Integer(5));
+        assert!(SqlValue::Integer(5) < SqlValue::Integer(10));
+        assert!(SqlValue::Text("a".to_string()) < SqlValue::Text("b".to_string()));
+        assert!(SqlValue::Null < SqlValue::Integer(0));
+    }
+
+    #[test]
+    fn test_sql_value_json() {
+        let val = SqlValue::Integer(42);
+        assert_eq!(val.to_json(), serde_json::json!(42));
+
+        let val = SqlValue::Text("hello".to_string());
+        assert_eq!(val.to_json(), serde_json::json!("hello"));
+
+        let json = serde_json::json!(true);
+        let val = SqlValue::from_json(&json);
+        assert_eq!(val, SqlValue::Boolean(true));
+    }
+}
diff --git
a/docs/PLAN/PHASE10-Database-Advanced.md b/docs/PLAN/PHASE10-Database-Advanced.md new file mode 100644 index 0000000..43544ca --- /dev/null +++ b/docs/PLAN/PHASE10-Database-Advanced.md @@ -0,0 +1,505 @@ +# Phase 10 Advanced Database Features + +> Implementation plan for SQL, Graph, and Replication features in Synor Database L2 + +## Overview + +These advanced features extend the Synor Database to support: +1. **Relational (SQL)** - SQLite-compatible query subset for structured data +2. **Graph Store** - Relationship queries for connected data +3. **Replication** - Raft consensus for high availability + +## Feature 1: Relational (SQL) Store + +### Purpose +Provide a familiar SQL interface for developers who need structured relational queries, joins, and ACID transactions. + +### Architecture + +```text +┌─────────────────────────────────────────────────────────────┐ +│ SQL QUERY LAYER │ +├─────────────────────────────────────────────────────────────┤ +│ │ +│ ┌──────────────┐ ┌──────────────┐ ┌──────────────────┐ │ +│ │ SQL Parser │ │ Planner │ │ Executor │ │ +│ │ (sqlparser) │ │ (logical) │ │ (physical) │ │ +│ └──────────────┘ └──────────────┘ └──────────────────┘ │ +│ │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ Table Storage Engine │ │ +│ │ - Row-oriented storage │ │ +│ │ - B-tree indexes │ │ +│ │ - Transaction log (WAL) │ │ +│ └──────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────┘ +``` + +### Supported SQL Subset + +| Category | Statements | +|----------|------------| +| DDL | CREATE TABLE, DROP TABLE, ALTER TABLE | +| DML | SELECT, INSERT, UPDATE, DELETE | +| Clauses | WHERE, ORDER BY, LIMIT, OFFSET, GROUP BY, HAVING | +| Joins | INNER JOIN, LEFT JOIN, RIGHT JOIN | +| Functions | COUNT, SUM, AVG, MIN, MAX, COALESCE | +| Operators | =, !=, <, >, <=, >=, AND, OR, NOT, IN, LIKE | + +### Data Types + +| SQL Type | Rust Type | Storage | 
+|----------|-----------|---------|
+| INTEGER | i64 | 8 bytes |
+| REAL | f64 | 8 bytes |
+| TEXT | String | Variable |
+| BLOB | Vec<u8> | Variable |
+| BOOLEAN | bool | 1 byte |
+| TIMESTAMP | u64 | 8 bytes (Unix ms) |
+
+### Implementation Components
+
+```
+crates/synor-database/src/sql/
+├── mod.rs # Module exports
+├── parser.rs # SQL parsing (sqlparser-rs)
+├── planner.rs # Query planning & optimization
+├── executor.rs # Query execution engine
+├── table.rs # Table definition & storage
+├── row.rs # Row representation
+├── types.rs # SQL type system
+├── transaction.rs # ACID transactions
+└── index.rs # SQL-specific indexing
+```
+
+### API Design
+
+```rust
+// Table definition
+pub struct TableDef {
+    pub name: String,
+    pub columns: Vec<ColumnDef>,
+    pub primary_key: Option<String>,
+    pub indexes: Vec<IndexDef>,
+}
+
+pub struct ColumnDef {
+    pub name: String,
+    pub data_type: SqlType,
+    pub nullable: bool,
+    pub default: Option<SqlValue>,
+}
+
+// SQL execution
+pub struct SqlEngine {
+    tables: HashMap<String, Table>,
+    transaction_log: TransactionLog,
+}
+
+impl SqlEngine {
+    pub fn execute(&mut self, sql: &str) -> Result<QueryResult, SqlError>;
+    pub fn begin_transaction(&mut self) -> TransactionId;
+    pub fn commit(&mut self, txn: TransactionId) -> Result<(), SqlError>;
+    pub fn rollback(&mut self, txn: TransactionId);
+}
+```
+
+### Gateway Endpoints
+
+| Endpoint | Method | Description |
+|----------|--------|-------------|
+| `/db/:db/sql` | POST | Execute SQL query |
+| `/db/:db/sql/tables` | GET | List tables |
+| `/db/:db/sql/tables/:table` | GET | Get table schema |
+| `/db/:db/sql/tables/:table` | DELETE | Drop table |
+
+---
+
+## Feature 2: Graph Store
+
+### Purpose
+Enable relationship-based queries for social networks, knowledge graphs, recommendation engines, and any connected data.
+ +### Architecture + +```text +┌─────────────────────────────────────────────────────────────┐ +│ GRAPH QUERY LAYER │ +├─────────────────────────────────────────────────────────────┤ +│ │ +│ ┌──────────────┐ ┌──────────────┐ ┌──────────────────┐ │ +│ │ Graph Query │ │ Traversal │ │ Path Finding │ │ +│ │ Parser │ │ Engine │ │ (Dijkstra) │ │ +│ └──────────────┘ └──────────────┘ └──────────────────┘ │ +│ │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ Graph Storage Engine │ │ +│ │ - Adjacency list storage │ │ +│ │ - Edge index (source, target, type) │ │ +│ │ - Property storage │ │ +│ └──────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────┘ +``` + +### Data Model + +```text +Node (Vertex): + - id: NodeId (32 bytes) + - labels: Vec + - properties: JsonValue + +Edge (Relationship): + - id: EdgeId (32 bytes) + - source: NodeId + - target: NodeId + - edge_type: String + - properties: JsonValue + - directed: bool +``` + +### Query Language (Simplified Cypher-like) + +``` +// Find all friends of user Alice +MATCH (a:User {name: "Alice"})-[:FRIEND]->(friend) +RETURN friend + +// Find shortest path between two nodes +MATCH path = shortestPath((a:User {id: "123"})-[*]-(b:User {id: "456"})) +RETURN path + +// Find mutual friends +MATCH (a:User {name: "Alice"})-[:FRIEND]->(mutual)<-[:FRIEND]-(b:User {name: "Bob"}) +RETURN mutual +``` + +### Implementation Components + +``` +crates/synor-database/src/graph/ +├── mod.rs # Module exports +├── node.rs # Node definition & storage +├── edge.rs # Edge definition & storage +├── store.rs # Graph storage engine +├── query.rs # Query language parser +├── traversal.rs # Graph traversal algorithms +├── path.rs # Path finding (BFS, DFS, Dijkstra) +└── index.rs # Graph-specific indexes +``` + +### API Design + +```rust +pub struct Node { + pub id: NodeId, + pub labels: Vec, + pub properties: JsonValue, +} + +pub struct Edge { + pub id: EdgeId, + pub 
source: NodeId,
+    pub target: NodeId,
+    pub edge_type: String,
+    pub properties: JsonValue,
+}
+
+pub struct GraphStore {
+    nodes: HashMap<NodeId, Node>,
+    edges: HashMap<EdgeId, Edge>,
+    adjacency: HashMap<NodeId, Vec<EdgeId>>, // outgoing
+    reverse_adj: HashMap<NodeId, Vec<EdgeId>>, // incoming
+}
+
+impl GraphStore {
+    // Node operations
+    pub fn create_node(&mut self, labels: Vec<String>, props: JsonValue) -> NodeId;
+    pub fn get_node(&self, id: &NodeId) -> Option<&Node>;
+    pub fn update_node(&mut self, id: &NodeId, props: JsonValue) -> Result<(), GraphError>;
+    pub fn delete_node(&mut self, id: &NodeId) -> Result<(), GraphError>;
+
+    // Edge operations
+    pub fn create_edge(&mut self, source: NodeId, target: NodeId, edge_type: &str, props: JsonValue) -> EdgeId;
+    pub fn get_edge(&self, id: &EdgeId) -> Option<&Edge>;
+    pub fn delete_edge(&mut self, id: &EdgeId) -> Result<(), GraphError>;
+
+    // Traversal
+    pub fn neighbors(&self, id: &NodeId, direction: Direction) -> Vec<&Node>;
+    pub fn edges_of(&self, id: &NodeId, direction: Direction) -> Vec<&Edge>;
+    pub fn shortest_path(&self, from: &NodeId, to: &NodeId) -> Option<Vec<NodeId>>;
+    pub fn traverse(&self, start: &NodeId, query: &TraversalQuery) -> Vec<Node>;
+}
+```
+
+### Gateway Endpoints
+
+| Endpoint | Method | Description |
+|----------|--------|-------------|
+| `/db/:db/graph/nodes` | POST | Create node |
+| `/db/:db/graph/nodes/:id` | GET | Get node |
+| `/db/:db/graph/nodes/:id` | PUT | Update node |
+| `/db/:db/graph/nodes/:id` | DELETE | Delete node |
+| `/db/:db/graph/edges` | POST | Create edge |
+| `/db/:db/graph/edges/:id` | GET | Get edge |
+| `/db/:db/graph/edges/:id` | DELETE | Delete edge |
+| `/db/:db/graph/query` | POST | Execute graph query |
+| `/db/:db/graph/path` | POST | Find shortest path |
+| `/db/:db/graph/traverse` | POST | Traverse from node |
+
+---
+
+## Feature 3: Replication (Raft Consensus)
+
+### Purpose
+Provide high availability and fault tolerance through distributed consensus, ensuring data consistency across multiple nodes.
+ +### Architecture + +```text +┌─────────────────────────────────────────────────────────────┐ +│ RAFT CONSENSUS LAYER │ +├─────────────────────────────────────────────────────────────┤ +│ │ +│ ┌──────────────┐ ┌──────────────┐ ┌──────────────────┐ │ +│ │ Leader │ │ Follower │ │ Candidate │ │ +│ │ Election │ │ Replication │ │ (Election) │ │ +│ └──────────────┘ └──────────────┘ └──────────────────┘ │ +│ │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ Log Replication │ │ +│ │ - Append entries │ │ +│ │ - Commit index │ │ +│ │ - Log compaction (snapshots) │ │ +│ └──────────────────────────────────────────────────────┘ │ +│ │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ State Machine │ │ +│ │ - Apply committed entries │ │ +│ │ - Database operations │ │ +│ └──────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────┘ +``` + +### Raft Protocol Overview + +```text +Leader Election: +1. Followers timeout → become Candidate +2. Candidate requests votes from peers +3. Majority votes → become Leader +4. Leader sends heartbeats to maintain authority + +Log Replication: +1. Client sends write to Leader +2. Leader appends to local log +3. Leader replicates to Followers +4. Majority acknowledge → entry committed +5. Leader applies to state machine +6. 
Leader responds to client +``` + +### Implementation Components + +``` +crates/synor-database/src/replication/ +├── mod.rs # Module exports +├── raft.rs # Core Raft implementation +├── state.rs # Node state (Leader/Follower/Candidate) +├── log.rs # Replicated log +├── rpc.rs # RPC messages (AppendEntries, RequestVote) +├── election.rs # Leader election logic +├── snapshot.rs # Log compaction & snapshots +├── cluster.rs # Cluster membership +└── client.rs # Client for forwarding to leader +``` + +### API Design + +```rust +#[derive(Clone, Copy, PartialEq)] +pub enum NodeRole { + Leader, + Follower, + Candidate, +} + +pub struct RaftConfig { + pub node_id: u64, + pub peers: Vec, + pub election_timeout_ms: (u64, u64), // min, max + pub heartbeat_interval_ms: u64, + pub snapshot_threshold: u64, +} + +pub struct LogEntry { + pub term: u64, + pub index: u64, + pub command: Command, +} + +pub enum Command { + // Database operations + KvSet { key: String, value: Vec }, + KvDelete { key: String }, + DocInsert { collection: String, doc: JsonValue }, + DocUpdate { collection: String, id: DocumentId, update: JsonValue }, + DocDelete { collection: String, id: DocumentId }, + // ... 
other operations +} + +pub struct RaftNode { + config: RaftConfig, + state: RaftState, + log: ReplicatedLog, + state_machine: Arc, +} + +impl RaftNode { + pub async fn start(&mut self) -> Result<(), RaftError>; + pub async fn propose(&self, command: Command) -> Result<(), RaftError>; + pub fn is_leader(&self) -> bool; + pub fn leader_id(&self) -> Option; + pub fn status(&self) -> ClusterStatus; +} + +// RPC Messages +pub struct AppendEntries { + pub term: u64, + pub leader_id: u64, + pub prev_log_index: u64, + pub prev_log_term: u64, + pub entries: Vec, + pub leader_commit: u64, +} + +pub struct RequestVote { + pub term: u64, + pub candidate_id: u64, + pub last_log_index: u64, + pub last_log_term: u64, +} +``` + +### Cluster Configuration + +```yaml +# docker-compose.raft.yml +services: + db-node-1: + image: synor/database:latest + environment: + RAFT_NODE_ID: 1 + RAFT_PEERS: "db-node-2:5000,db-node-3:5000" + RAFT_ELECTION_TIMEOUT: "150-300" + RAFT_HEARTBEAT_MS: 50 + ports: + - "8484:8484" # HTTP API + - "5000:5000" # Raft RPC + + db-node-2: + image: synor/database:latest + environment: + RAFT_NODE_ID: 2 + RAFT_PEERS: "db-node-1:5000,db-node-3:5000" + ports: + - "8485:8484" + - "5001:5000" + + db-node-3: + image: synor/database:latest + environment: + RAFT_NODE_ID: 3 + RAFT_PEERS: "db-node-1:5000,db-node-2:5000" + ports: + - "8486:8484" + - "5002:5000" +``` + +### Gateway Endpoints + +| Endpoint | Method | Description | +|----------|--------|-------------| +| `/cluster/status` | GET | Get cluster status | +| `/cluster/leader` | GET | Get current leader | +| `/cluster/nodes` | GET | List all nodes | +| `/cluster/nodes/:id` | DELETE | Remove node from cluster | +| `/cluster/nodes` | POST | Add node to cluster | + +--- + +## Implementation Order + +### Step 1: SQL Store +1. Add `sqlparser` dependency +2. Implement type system (`types.rs`) +3. Implement row storage (`row.rs`, `table.rs`) +4. Implement SQL parser wrapper (`parser.rs`) +5. 
Implement query planner (`planner.rs`) +6. Implement query executor (`executor.rs`) +7. Add transaction support (`transaction.rs`) +8. Add gateway endpoints +9. Write tests + +### Step 2: Graph Store +1. Implement node/edge types (`node.rs`, `edge.rs`) +2. Implement graph storage (`store.rs`) +3. Implement query parser (`query.rs`) +4. Implement traversal algorithms (`traversal.rs`, `path.rs`) +5. Add graph indexes (`index.rs`) +6. Add gateway endpoints +7. Write tests + +### Step 3: Replication +1. Implement Raft state machine (`state.rs`, `raft.rs`) +2. Implement replicated log (`log.rs`) +3. Implement RPC layer (`rpc.rs`) +4. Implement leader election (`election.rs`) +5. Implement log compaction (`snapshot.rs`) +6. Implement cluster management (`cluster.rs`) +7. Integrate with database operations +8. Add gateway endpoints +9. Write tests +10. Create Docker Compose for cluster + +--- + +## Pricing Impact + +| Feature | Operation | Cost (SYNOR) | +|---------|-----------|--------------| +| SQL | Query/million | 0.02 | +| SQL | Write/million | 0.05 | +| Graph | Traversal/million | 0.03 | +| Graph | Path query/million | 0.05 | +| Replication | Included | Base storage cost | + +--- + +## Success Criteria + +### SQL Store +- [ ] Parse and execute basic SELECT, INSERT, UPDATE, DELETE +- [ ] Support WHERE clauses with operators +- [ ] Support ORDER BY, LIMIT, OFFSET +- [ ] Support simple JOINs +- [ ] Support aggregate functions +- [ ] ACID transactions + +### Graph Store +- [ ] Create/read/update/delete nodes and edges +- [ ] Traverse neighbors (in/out/both) +- [ ] Find shortest path between nodes +- [ ] Execute pattern matching queries +- [ ] Support property filters + +### Replication +- [ ] Leader election works correctly +- [ ] Log replication achieves consensus +- [ ] Reads from any node (eventual consistency) +- [ ] Writes only through leader +- [ ] Node failure handled gracefully +- [ ] Log compaction reduces storage