feat(database): add SQL, Graph, and Raft Replication modules

- SQL store with SQLite-compatible subset (sqlparser 0.43)
  - CREATE TABLE, INSERT, SELECT, UPDATE, DELETE
  - WHERE clauses, ORDER BY, LIMIT
  - Aggregates (COUNT, SUM, AVG, MIN, MAX)
  - UNIQUE and NOT NULL constraints
  - BTreeMap-based indexes

- Graph store for relationship-based queries
  - Nodes with labels and properties
  - Edges with types and weights
  - BFS/DFS traversal
  - Dijkstra shortest path
  - Cypher-like query parser (MATCH, CREATE, DELETE, SET)

- Raft consensus replication for high availability
  - Leader election with randomized timeouts
  - Log replication with AppendEntries RPC
  - Snapshot management for log compaction
  - Cluster configuration and joint consensus
  - Full RPC message serialization

All 159 tests pass.
Author: Gulshan Yadav 2026-01-10 19:32:14 +05:30
parent ab4c967a97
commit 8da34bc73d
25 changed files with 9842 additions and 0 deletions


@@ -15,6 +15,7 @@ synor-storage = { path = "../synor-storage" }
serde.workspace = true
serde_json.workspace = true
borsh.workspace = true
bincode = "1.3"
# Utilities
thiserror.workspace = true
@@ -36,6 +37,9 @@ indexmap = "2.2"
axum.workspace = true
tower-http = { version = "0.5", features = ["cors", "trace"] }
# SQL parsing
sqlparser = "0.43"
# Vector operations (for AI/RAG)
# Using pure Rust for portability


@@ -0,0 +1,301 @@
//! Graph edge (relationship) definition.
use super::node::NodeId;
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use std::sync::atomic::{AtomicU64, Ordering};
/// Unique edge identifier.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct EdgeId(pub [u8; 32]);
impl EdgeId {
/// Creates a new unique edge ID.
pub fn new() -> Self {
static COUNTER: AtomicU64 = AtomicU64::new(1);
let id = COUNTER.fetch_add(1, Ordering::SeqCst);
let mut bytes = [0u8; 32];
bytes[..8].copy_from_slice(&id.to_be_bytes());
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_nanos() as u64;
bytes[8..16].copy_from_slice(&now.to_be_bytes());
EdgeId(*blake3::hash(&bytes).as_bytes())
}
/// Creates from raw bytes.
pub fn from_bytes(bytes: [u8; 32]) -> Self {
EdgeId(bytes)
}
/// Creates from hex string.
pub fn from_hex(hex: &str) -> Option<Self> {
let bytes = hex::decode(hex).ok()?;
if bytes.len() != 32 {
return None;
}
let mut arr = [0u8; 32];
arr.copy_from_slice(&bytes);
Some(EdgeId(arr))
}
/// Returns the bytes.
pub fn as_bytes(&self) -> &[u8; 32] {
&self.0
}
/// Converts to hex string.
pub fn to_hex(&self) -> String {
hex::encode(self.0)
}
}
impl Default for EdgeId {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Display for EdgeId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "edge_{}", hex::encode(&self.0[..8]))
}
}
/// An edge (relationship) in the graph.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Edge {
/// Unique edge ID.
pub id: EdgeId,
/// Source node ID.
pub source: NodeId,
/// Target node ID.
pub target: NodeId,
/// Edge type (relationship type), e.g., "FRIEND", "OWNS".
pub edge_type: String,
/// Properties stored as JSON.
pub properties: JsonValue,
/// Whether this is a directed edge.
pub directed: bool,
/// Weight for path-finding algorithms.
pub weight: f64,
/// Creation timestamp.
pub created_at: u64,
}
impl Edge {
/// Creates a new directed edge.
pub fn new(source: NodeId, target: NodeId, edge_type: impl Into<String>, properties: JsonValue) -> Self {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as u64;
Self {
id: EdgeId::new(),
source,
target,
edge_type: edge_type.into(),
properties,
directed: true,
weight: 1.0,
created_at: now,
}
}
/// Creates an undirected edge.
pub fn undirected(source: NodeId, target: NodeId, edge_type: impl Into<String>, properties: JsonValue) -> Self {
let mut edge = Self::new(source, target, edge_type, properties);
edge.directed = false;
edge
}
/// Sets the weight for this edge.
pub fn with_weight(mut self, weight: f64) -> Self {
self.weight = weight;
self
}
/// Returns the other end of this edge from the given node.
pub fn other_end(&self, from: &NodeId) -> Option<NodeId> {
if &self.source == from {
Some(self.target)
} else if &self.target == from && !self.directed {
Some(self.source)
} else {
// Directed edges can't be traversed backward; unrelated nodes have no other end.
None
}
}
/// Checks if this edge connects the given node (as source or target).
pub fn connects(&self, node: &NodeId) -> bool {
&self.source == node || &self.target == node
}
/// Checks if this edge connects two specific nodes.
pub fn connects_pair(&self, a: &NodeId, b: &NodeId) -> bool {
(&self.source == a && &self.target == b) ||
(!self.directed && &self.source == b && &self.target == a)
}
/// Gets a property value.
pub fn get_property(&self, key: &str) -> Option<&JsonValue> {
self.properties.get(key)
}
/// Sets a property value.
pub fn set_property(&mut self, key: &str, value: JsonValue) {
if let Some(obj) = self.properties.as_object_mut() {
obj.insert(key.to_string(), value);
}
}
/// Checks if the edge matches a property filter.
pub fn matches_properties(&self, filter: &JsonValue) -> bool {
if let (Some(filter_obj), Some(props_obj)) = (filter.as_object(), self.properties.as_object()) {
for (key, expected) in filter_obj {
if let Some(actual) = props_obj.get(key) {
if actual != expected {
return false;
}
} else {
return false;
}
}
true
} else {
filter == &self.properties || filter == &JsonValue::Object(serde_json::Map::new())
}
}
}
/// Builder for creating edges.
pub struct EdgeBuilder {
source: NodeId,
target: NodeId,
edge_type: String,
properties: serde_json::Map<String, JsonValue>,
directed: bool,
weight: f64,
}
impl EdgeBuilder {
/// Creates a new edge builder.
pub fn new(source: NodeId, target: NodeId, edge_type: impl Into<String>) -> Self {
Self {
source,
target,
edge_type: edge_type.into(),
properties: serde_json::Map::new(),
directed: true,
weight: 1.0,
}
}
/// Sets the edge as undirected.
pub fn undirected(mut self) -> Self {
self.directed = false;
self
}
/// Sets the weight.
pub fn weight(mut self, weight: f64) -> Self {
self.weight = weight;
self
}
/// Sets a property.
pub fn property(mut self, key: impl Into<String>, value: impl Into<JsonValue>) -> Self {
self.properties.insert(key.into(), value.into());
self
}
/// Builds the edge.
pub fn build(self) -> Edge {
let mut edge = Edge::new(self.source, self.target, self.edge_type, JsonValue::Object(self.properties));
edge.directed = self.directed;
edge.weight = self.weight;
edge
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_edge_id() {
let id1 = EdgeId::new();
let id2 = EdgeId::new();
assert_ne!(id1, id2);
let hex = id1.to_hex();
let id3 = EdgeId::from_hex(&hex).unwrap();
assert_eq!(id1, id3);
}
#[test]
fn test_edge_creation() {
let source = NodeId::new();
let target = NodeId::new();
let edge = Edge::new(source, target, "FRIEND", serde_json::json!({"since": 2020}));
assert_eq!(edge.source, source);
assert_eq!(edge.target, target);
assert_eq!(edge.edge_type, "FRIEND");
assert!(edge.directed);
}
#[test]
fn test_edge_builder() {
let source = NodeId::new();
let target = NodeId::new();
let edge = EdgeBuilder::new(source, target, "OWNS")
.undirected()
.weight(2.5)
.property("percentage", 50)
.build();
assert!(!edge.directed);
assert_eq!(edge.weight, 2.5);
assert_eq!(edge.get_property("percentage"), Some(&serde_json::json!(50)));
}
#[test]
fn test_edge_other_end() {
let source = NodeId::new();
let target = NodeId::new();
// Directed edge
let directed = Edge::new(source, target, "A", serde_json::json!({}));
assert_eq!(directed.other_end(&source), Some(target));
assert_eq!(directed.other_end(&target), None); // Can't traverse backward
// Undirected edge
let undirected = Edge::undirected(source, target, "B", serde_json::json!({}));
assert_eq!(undirected.other_end(&source), Some(target));
assert_eq!(undirected.other_end(&target), Some(source));
}
#[test]
fn test_edge_connects() {
let a = NodeId::new();
let b = NodeId::new();
let c = NodeId::new();
let edge = Edge::new(a, b, "LINK", serde_json::json!({}));
assert!(edge.connects(&a));
assert!(edge.connects(&b));
assert!(!edge.connects(&c));
assert!(edge.connects_pair(&a, &b));
assert!(!edge.connects_pair(&b, &a)); // Directed
}
}


@@ -0,0 +1,18 @@
//! Graph store module for relationship-based queries.
//!
//! Provides a graph database with nodes, edges, and traversal algorithms
//! suitable for social networks, knowledge graphs, and recommendation engines.
pub mod edge;
pub mod node;
pub mod path;
pub mod query;
pub mod store;
pub mod traversal;
pub use edge::{Edge, EdgeId};
pub use node::{Node, NodeId};
pub use path::{PathFinder, PathResult};
pub use query::{GraphQuery, GraphQueryParser, MatchPattern, QueryResult};
pub use store::{Direction, GraphError, GraphStore};
pub use traversal::{TraversalQuery, TraversalResult, Traverser};


@@ -0,0 +1,315 @@
//! Graph node (vertex) definition.
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use std::sync::atomic::{AtomicU64, Ordering};
/// Unique node identifier.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct NodeId(pub [u8; 32]);
impl NodeId {
/// Creates a new unique node ID.
pub fn new() -> Self {
static COUNTER: AtomicU64 = AtomicU64::new(1);
let id = COUNTER.fetch_add(1, Ordering::SeqCst);
let mut bytes = [0u8; 32];
bytes[..8].copy_from_slice(&id.to_be_bytes());
// Add timestamp for uniqueness
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_nanos() as u64;
bytes[8..16].copy_from_slice(&now.to_be_bytes());
NodeId(*blake3::hash(&bytes).as_bytes())
}
/// Creates from raw bytes.
pub fn from_bytes(bytes: [u8; 32]) -> Self {
NodeId(bytes)
}
/// Creates from hex string.
pub fn from_hex(hex: &str) -> Option<Self> {
let bytes = hex::decode(hex).ok()?;
if bytes.len() != 32 {
return None;
}
let mut arr = [0u8; 32];
arr.copy_from_slice(&bytes);
Some(NodeId(arr))
}
/// Returns the bytes.
pub fn as_bytes(&self) -> &[u8; 32] {
&self.0
}
/// Converts to hex string.
pub fn to_hex(&self) -> String {
hex::encode(self.0)
}
}
impl Default for NodeId {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Display for NodeId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "node_{}", hex::encode(&self.0[..8]))
}
}
/// A node (vertex) in the graph.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Node {
/// Unique node ID.
pub id: NodeId,
/// Labels (types) for this node (e.g., "User", "Product").
pub labels: Vec<String>,
/// Properties stored as JSON.
pub properties: JsonValue,
/// Creation timestamp.
pub created_at: u64,
/// Last update timestamp.
pub updated_at: u64,
}
impl Node {
/// Creates a new node with the given labels and properties.
pub fn new(labels: Vec<String>, properties: JsonValue) -> Self {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as u64;
Self {
id: NodeId::new(),
labels,
properties,
created_at: now,
updated_at: now,
}
}
/// Creates a node with a specific ID.
pub fn with_id(id: NodeId, labels: Vec<String>, properties: JsonValue) -> Self {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as u64;
Self {
id,
labels,
properties,
created_at: now,
updated_at: now,
}
}
/// Returns true if this node has the given label.
pub fn has_label(&self, label: &str) -> bool {
self.labels.iter().any(|l| l == label)
}
/// Adds a label to this node.
pub fn add_label(&mut self, label: impl Into<String>) {
let label = label.into();
if !self.labels.contains(&label) {
self.labels.push(label);
self.touch();
}
}
/// Removes a label from this node.
pub fn remove_label(&mut self, label: &str) -> bool {
if let Some(pos) = self.labels.iter().position(|l| l == label) {
self.labels.remove(pos);
self.touch();
true
} else {
false
}
}
/// Gets a property value.
pub fn get_property(&self, key: &str) -> Option<&JsonValue> {
self.properties.get(key)
}
/// Sets a property value.
pub fn set_property(&mut self, key: &str, value: JsonValue) {
if let Some(obj) = self.properties.as_object_mut() {
obj.insert(key.to_string(), value);
self.touch();
}
}
/// Removes a property.
pub fn remove_property(&mut self, key: &str) -> Option<JsonValue> {
if let Some(obj) = self.properties.as_object_mut() {
let removed = obj.remove(key);
if removed.is_some() {
self.touch();
}
removed
} else {
None
}
}
/// Updates the updated_at timestamp.
fn touch(&mut self) {
self.updated_at = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as u64;
}
/// Checks if the node matches a property filter.
pub fn matches_properties(&self, filter: &JsonValue) -> bool {
if let (Some(filter_obj), Some(props_obj)) = (filter.as_object(), self.properties.as_object()) {
for (key, expected) in filter_obj {
if let Some(actual) = props_obj.get(key) {
if actual != expected {
return false;
}
} else {
return false;
}
}
true
} else {
filter == &self.properties
}
}
}
/// Builder for creating nodes.
pub struct NodeBuilder {
labels: Vec<String>,
properties: serde_json::Map<String, JsonValue>,
}
impl NodeBuilder {
/// Creates a new node builder.
pub fn new() -> Self {
Self {
labels: Vec::new(),
properties: serde_json::Map::new(),
}
}
/// Adds a label.
pub fn label(mut self, label: impl Into<String>) -> Self {
self.labels.push(label.into());
self
}
/// Adds multiple labels.
pub fn labels(mut self, labels: Vec<String>) -> Self {
self.labels.extend(labels);
self
}
/// Sets a property.
pub fn property(mut self, key: impl Into<String>, value: impl Into<JsonValue>) -> Self {
self.properties.insert(key.into(), value.into());
self
}
/// Builds the node.
pub fn build(self) -> Node {
Node::new(self.labels, JsonValue::Object(self.properties))
}
}
impl Default for NodeBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_node_id() {
let id1 = NodeId::new();
let id2 = NodeId::new();
assert_ne!(id1, id2);
let hex = id1.to_hex();
let id3 = NodeId::from_hex(&hex).unwrap();
assert_eq!(id1, id3);
}
#[test]
fn test_node_creation() {
let node = Node::new(
vec!["User".to_string()],
serde_json::json!({"name": "Alice", "age": 30}),
);
assert!(node.has_label("User"));
assert!(!node.has_label("Admin"));
assert_eq!(
node.get_property("name"),
Some(&serde_json::json!("Alice"))
);
}
#[test]
fn test_node_builder() {
let node = NodeBuilder::new()
.label("Person")
.label("Developer")
.property("name", "Bob")
.property("skills", serde_json::json!(["Rust", "Python"]))
.build();
assert!(node.has_label("Person"));
assert!(node.has_label("Developer"));
assert_eq!(node.get_property("name"), Some(&serde_json::json!("Bob")));
}
#[test]
fn test_node_labels() {
let mut node = Node::new(vec!["User".to_string()], serde_json::json!({}));
node.add_label("Admin");
assert!(node.has_label("Admin"));
node.remove_label("Admin");
assert!(!node.has_label("Admin"));
}
#[test]
fn test_node_properties() {
let mut node = Node::new(vec![], serde_json::json!({}));
node.set_property("key", serde_json::json!("value"));
assert_eq!(node.get_property("key"), Some(&serde_json::json!("value")));
let removed = node.remove_property("key");
assert_eq!(removed, Some(serde_json::json!("value")));
assert_eq!(node.get_property("key"), None);
}
#[test]
fn test_node_matches() {
let node = Node::new(
vec!["User".to_string()],
serde_json::json!({"name": "Alice", "age": 30, "active": true}),
);
assert!(node.matches_properties(&serde_json::json!({"name": "Alice"})));
assert!(node.matches_properties(&serde_json::json!({"name": "Alice", "age": 30})));
assert!(!node.matches_properties(&serde_json::json!({"name": "Bob"})));
}
}


@@ -0,0 +1,485 @@
//! Path finding algorithms for graphs.
use super::edge::Edge;
use super::node::NodeId;
use super::store::{Direction, GraphStore};
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
/// Result of a path finding operation.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PathResult {
/// The path as a list of node IDs.
pub nodes: Vec<NodeId>,
/// The edges traversed.
pub edges: Vec<Edge>,
/// Total path length (number of hops or weighted distance).
pub length: f64,
/// Whether a path was found.
pub found: bool,
}
impl PathResult {
/// Creates an empty (not found) result.
pub fn not_found() -> Self {
Self {
nodes: Vec::new(),
edges: Vec::new(),
length: f64::INFINITY,
found: false,
}
}
/// Creates a found result.
pub fn found(nodes: Vec<NodeId>, edges: Vec<Edge>, length: f64) -> Self {
Self {
nodes,
edges,
length,
found: true,
}
}
}
/// State for Dijkstra's algorithm priority queue.
#[derive(Clone)]
struct DijkstraState {
node: NodeId,
distance: f64,
}
impl PartialEq for DijkstraState {
fn eq(&self, other: &Self) -> bool {
self.distance == other.distance
}
}
impl Eq for DijkstraState {}
impl Ord for DijkstraState {
fn cmp(&self, other: &Self) -> Ordering {
// Reverse ordering for min-heap
other.distance.partial_cmp(&self.distance).unwrap_or(Ordering::Equal)
}
}
impl PartialOrd for DijkstraState {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
/// Path finder for graph shortest path queries.
pub struct PathFinder<'a> {
store: &'a GraphStore,
}
impl<'a> PathFinder<'a> {
/// Creates a new path finder.
pub fn new(store: &'a GraphStore) -> Self {
Self { store }
}
/// Finds the shortest path using BFS (unweighted).
pub fn shortest_path_bfs(&self, from: &NodeId, to: &NodeId) -> PathResult {
if from == to {
return PathResult::found(vec![*from], Vec::new(), 0.0);
}
let mut visited = HashSet::new();
let mut queue: VecDeque<(NodeId, Vec<NodeId>, Vec<Edge>)> = VecDeque::new();
visited.insert(*from);
queue.push_back((*from, vec![*from], Vec::new()));
while let Some((current, path, edges)) = queue.pop_front() {
let neighbor_edges = self.store.edges_of(&current, Direction::Both);
for edge in neighbor_edges {
let neighbor = if edge.source == current {
edge.target
} else if !edge.directed && edge.target == current {
edge.source
} else {
// Directed edge pointing at `current`: can't traverse it backward.
continue;
};
if &neighbor == to {
let mut final_path = path.clone();
final_path.push(neighbor);
let path_len = final_path.len();
let mut final_edges = edges.clone();
final_edges.push(edge);
return PathResult::found(final_path, final_edges, path_len as f64 - 1.0);
}
if !visited.contains(&neighbor) {
visited.insert(neighbor);
let mut new_path = path.clone();
new_path.push(neighbor);
let mut new_edges = edges.clone();
new_edges.push(edge);
queue.push_back((neighbor, new_path, new_edges));
}
}
}
PathResult::not_found()
}
/// Finds the shortest path using Dijkstra's algorithm (weighted).
pub fn shortest_path_dijkstra(&self, from: &NodeId, to: &NodeId) -> PathResult {
if from == to {
return PathResult::found(vec![*from], Vec::new(), 0.0);
}
let mut distances: HashMap<NodeId, f64> = HashMap::new();
let mut previous: HashMap<NodeId, (NodeId, Edge)> = HashMap::new();
let mut heap = BinaryHeap::new();
let mut visited = HashSet::new();
distances.insert(*from, 0.0);
heap.push(DijkstraState { node: *from, distance: 0.0 });
while let Some(DijkstraState { node: current, distance: dist }) = heap.pop() {
if &current == to {
// Reconstruct path
let mut path = vec![current];
let mut edges = Vec::new();
let mut curr = current;
while let Some((prev, edge)) = previous.get(&curr) {
path.push(*prev);
edges.push(edge.clone());
curr = *prev;
}
path.reverse();
edges.reverse();
return PathResult::found(path, edges, dist);
}
if visited.contains(&current) {
continue;
}
visited.insert(current);
let neighbor_edges = self.store.edges_of(&current, Direction::Both);
for edge in neighbor_edges {
let neighbor = if edge.source == current {
edge.target
} else if !edge.directed && edge.target == current {
edge.source
} else {
// Directed edge pointing at `current`: can't traverse it backward.
continue;
};
if visited.contains(&neighbor) {
continue;
}
let new_dist = dist + edge.weight;
let is_shorter = distances.get(&neighbor).map(|&d| new_dist < d).unwrap_or(true);
if is_shorter {
distances.insert(neighbor, new_dist);
previous.insert(neighbor, (current, edge.clone()));
heap.push(DijkstraState { node: neighbor, distance: new_dist });
}
}
}
PathResult::not_found()
}
/// Finds all paths between two nodes up to a maximum length.
pub fn all_paths(&self, from: &NodeId, to: &NodeId, max_length: usize) -> Vec<PathResult> {
let mut results = Vec::new();
let mut current_path = vec![*from];
let mut current_edges = Vec::new();
let mut visited = HashSet::new();
visited.insert(*from);
self.find_all_paths_dfs(
from,
to,
max_length,
&mut current_path,
&mut current_edges,
&mut visited,
&mut results,
);
results
}
fn find_all_paths_dfs(
&self,
current: &NodeId,
target: &NodeId,
max_length: usize,
path: &mut Vec<NodeId>,
edges: &mut Vec<Edge>,
visited: &mut HashSet<NodeId>,
results: &mut Vec<PathResult>,
) {
if current == target {
let total_weight: f64 = edges.iter().map(|e| e.weight).sum();
results.push(PathResult::found(path.clone(), edges.clone(), total_weight));
return;
}
if path.len() > max_length {
return;
}
let neighbor_edges = self.store.edges_of(current, Direction::Both);
for edge in neighbor_edges {
let neighbor = if edge.source == *current {
edge.target
} else if !edge.directed && edge.target == *current {
edge.source
} else {
// Directed edge pointing at `current`: can't traverse it backward.
continue;
};
if !visited.contains(&neighbor) {
visited.insert(neighbor);
path.push(neighbor);
edges.push(edge.clone());
self.find_all_paths_dfs(&neighbor, target, max_length, path, edges, visited, results);
path.pop();
edges.pop();
visited.remove(&neighbor);
}
}
}
/// Finds the shortest path considering only specific edge types.
pub fn shortest_path_by_type(&self, from: &NodeId, to: &NodeId, edge_types: &[String]) -> PathResult {
if from == to {
return PathResult::found(vec![*from], Vec::new(), 0.0);
}
let mut visited = HashSet::new();
let mut queue: VecDeque<(NodeId, Vec<NodeId>, Vec<Edge>)> = VecDeque::new();
visited.insert(*from);
queue.push_back((*from, vec![*from], Vec::new()));
while let Some((current, path, edges)) = queue.pop_front() {
let neighbor_edges = self.store.edges_of(&current, Direction::Both);
for edge in neighbor_edges {
// Filter by edge type
if !edge_types.is_empty() && !edge_types.contains(&edge.edge_type) {
continue;
}
let neighbor = if edge.source == current {
edge.target
} else if !edge.directed && edge.target == current {
edge.source
} else {
// Directed edge pointing at `current`: can't traverse it backward.
continue;
};
if &neighbor == to {
let mut final_path = path.clone();
final_path.push(neighbor);
let path_len = final_path.len();
let mut final_edges = edges.clone();
final_edges.push(edge);
return PathResult::found(final_path, final_edges, path_len as f64 - 1.0);
}
if !visited.contains(&neighbor) {
visited.insert(neighbor);
let mut new_path = path.clone();
new_path.push(neighbor);
let mut new_edges = edges.clone();
new_edges.push(edge);
queue.push_back((neighbor, new_path, new_edges));
}
}
}
PathResult::not_found()
}
/// Checks if a path exists between two nodes.
pub fn path_exists(&self, from: &NodeId, to: &NodeId) -> bool {
if from == to {
return true;
}
let mut visited = HashSet::new();
let mut queue = VecDeque::new();
visited.insert(*from);
queue.push_back(*from);
while let Some(current) = queue.pop_front() {
let neighbor_edges = self.store.edges_of(&current, Direction::Both);
for edge in neighbor_edges {
let neighbor = if edge.source == current {
edge.target
} else if !edge.directed && edge.target == current {
edge.source
} else {
// Directed edge pointing at `current`: can't traverse it backward.
continue;
};
if &neighbor == to {
return true;
}
if !visited.contains(&neighbor) {
visited.insert(neighbor);
queue.push_back(neighbor);
}
}
}
false
}
/// Finds the distance (number of hops) between two nodes.
pub fn distance(&self, from: &NodeId, to: &NodeId) -> Option<usize> {
let result = self.shortest_path_bfs(from, to);
if result.found {
Some(result.nodes.len() - 1)
} else {
None
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn setup_graph() -> GraphStore {
let store = GraphStore::new();
// Graph layout:
//   A -> B -> C -> D
//   A -> E -> D
let a = store.create_node(vec![], serde_json::json!({"name": "A"}));
let b = store.create_node(vec![], serde_json::json!({"name": "B"}));
let c = store.create_node(vec![], serde_json::json!({"name": "C"}));
let d = store.create_node(vec![], serde_json::json!({"name": "D"}));
let e = store.create_node(vec![], serde_json::json!({"name": "E"}));
store.create_edge(a, b, "LINK", serde_json::json!({})).unwrap();
store.create_edge(b, c, "LINK", serde_json::json!({})).unwrap();
store.create_edge(c, d, "LINK", serde_json::json!({})).unwrap();
store.create_edge(a, e, "LINK", serde_json::json!({})).unwrap();
store.create_edge(e, d, "LINK", serde_json::json!({})).unwrap();
store
}
#[test]
fn test_shortest_path_bfs() {
let store = setup_graph();
let finder = PathFinder::new(&store);
let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({}));
let a = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("A"))).unwrap();
let d = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("D"))).unwrap();
let result = finder.shortest_path_bfs(&a.id, &d.id);
assert!(result.found);
assert_eq!(result.length, 2.0); // A -> E -> D (shortest)
}
#[test]
fn test_shortest_path_dijkstra() {
let store = GraphStore::new();
let a = store.create_node(vec![], serde_json::json!({"name": "A"}));
let b = store.create_node(vec![], serde_json::json!({"name": "B"}));
let c = store.create_node(vec![], serde_json::json!({"name": "C"}));
// All edges carry the default weight of 1.0:
// A -> B -> C costs 2.0; the direct A -> C edge costs 1.0.
store.create_edge(a, b, "LINK", serde_json::json!({})).unwrap();
store.create_edge(b, c, "LINK", serde_json::json!({})).unwrap();
store.create_edge(a, c, "DIRECT", serde_json::json!({})).unwrap();
let finder = PathFinder::new(&store);
let result = finder.shortest_path_dijkstra(&a, &c);
assert!(result.found);
// Dijkstra should take the cheaper direct edge: A -> C.
assert_eq!(result.nodes.len(), 2);
assert_eq!(result.length, 1.0);
}
#[test]
fn test_path_not_found() {
let store = GraphStore::new();
let a = store.create_node(vec![], serde_json::json!({}));
let b = store.create_node(vec![], serde_json::json!({}));
// No edge between a and b
let finder = PathFinder::new(&store);
let result = finder.shortest_path_bfs(&a, &b);
assert!(!result.found);
}
#[test]
fn test_all_paths() {
let store = setup_graph();
let finder = PathFinder::new(&store);
let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({}));
let a = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("A"))).unwrap();
let d = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("D"))).unwrap();
let paths = finder.all_paths(&a.id, &d.id, 5);
assert!(paths.len() >= 2); // At least A->B->C->D and A->E->D
}
#[test]
fn test_path_exists() {
let store = setup_graph();
let finder = PathFinder::new(&store);
let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({}));
let a = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("A"))).unwrap();
let d = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("D"))).unwrap();
assert!(finder.path_exists(&a.id, &d.id));
}
#[test]
fn test_distance() {
let store = setup_graph();
let finder = PathFinder::new(&store);
let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({}));
let a = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("A"))).unwrap();
let d = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("D"))).unwrap();
let b = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("B"))).unwrap();
assert_eq!(finder.distance(&a.id, &b.id), Some(1));
assert_eq!(finder.distance(&a.id, &d.id), Some(2)); // A -> E -> D
}
}


@@ -0,0 +1,825 @@
//! Simplified Cypher-like query language for graphs.
use super::edge::Edge;
use super::node::{Node, NodeId};
use super::store::{Direction, GraphError, GraphStore};
use super::traversal::{TraversalDirection, TraversalQuery, Traverser};
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use std::collections::HashMap;
/// Parsed graph query.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum GraphQuery {
/// MATCH query for pattern matching.
Match {
pattern: MatchPattern,
where_clause: Option<WhereClause>,
return_items: Vec<ReturnItem>,
limit: Option<usize>,
},
/// CREATE query for creating nodes/edges.
Create { elements: Vec<CreateElement> },
/// DELETE query for removing nodes/edges.
Delete { variable: String, detach: bool },
/// SET query for updating properties.
Set { variable: String, properties: JsonValue },
}
/// Pattern to match in the graph.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MatchPattern {
/// Starting node pattern.
pub start: NodePattern,
/// Relationship patterns (edges and target nodes).
pub relationships: Vec<RelationshipPattern>,
}
/// Pattern for matching a node.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct NodePattern {
/// Variable name for this node.
pub variable: Option<String>,
/// Required labels.
pub labels: Vec<String>,
/// Property filters.
pub properties: Option<JsonValue>,
}
/// Pattern for matching a relationship.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RelationshipPattern {
/// Variable name for this relationship.
pub variable: Option<String>,
/// Edge type (relationship type).
pub edge_type: Option<String>,
/// Direction of the relationship.
pub direction: RelationshipDirection,
/// Target node pattern.
pub target: NodePattern,
/// Minimum hops (for variable-length paths).
pub min_hops: usize,
/// Maximum hops (for variable-length paths).
pub max_hops: usize,
}
/// Direction of a relationship in a pattern.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq)]
pub enum RelationshipDirection {
/// Outgoing: (a)-[:TYPE]->(b)
Outgoing,
/// Incoming: (a)<-[:TYPE]-(b)
Incoming,
/// Undirected: (a)-[:TYPE]-(b)
Undirected,
}
/// WHERE clause conditions.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum WhereClause {
/// Property comparison.
PropertyEquals { variable: String, property: String, value: JsonValue },
/// Property comparison (not equals).
PropertyNotEquals { variable: String, property: String, value: JsonValue },
/// Property greater than.
PropertyGt { variable: String, property: String, value: JsonValue },
/// Property less than.
PropertyLt { variable: String, property: String, value: JsonValue },
/// Property contains (for text).
PropertyContains { variable: String, property: String, value: String },
/// AND condition.
And(Box<WhereClause>, Box<WhereClause>),
/// OR condition.
Or(Box<WhereClause>, Box<WhereClause>),
/// NOT condition.
Not(Box<WhereClause>),
}
/// Item to return from a query.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum ReturnItem {
/// Return all variables.
All,
/// Return a specific variable.
Variable(String),
/// Return a property of a variable.
Property { variable: String, property: String },
/// Return with an alias.
Alias { item: Box<ReturnItem>, alias: String },
/// Count aggregation.
Count(Option<String>),
}
/// Element to create.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum CreateElement {
/// Create a node.
Node { variable: Option<String>, labels: Vec<String>, properties: JsonValue },
/// Create a relationship.
Relationship {
from_var: String,
to_var: String,
edge_type: String,
properties: JsonValue,
},
}
/// Result of a graph query.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct QueryResult {
/// Column names.
pub columns: Vec<String>,
/// Result rows.
pub rows: Vec<Vec<JsonValue>>,
/// Number of nodes created.
pub nodes_created: usize,
/// Number of relationships created.
pub relationships_created: usize,
/// Number of nodes deleted.
pub nodes_deleted: usize,
/// Number of relationships deleted.
pub relationships_deleted: usize,
/// Number of properties set.
pub properties_set: usize,
}
impl QueryResult {
/// Creates an empty result.
pub fn empty() -> Self {
Self {
columns: Vec::new(),
rows: Vec::new(),
nodes_created: 0,
relationships_created: 0,
nodes_deleted: 0,
relationships_deleted: 0,
properties_set: 0,
}
}
}
/// Parser for Cypher-like graph queries.
pub struct GraphQueryParser;
impl GraphQueryParser {
/// Parses a query string into a GraphQuery.
pub fn parse(query: &str) -> Result<GraphQuery, GraphError> {
let query = query.trim();
let upper = query.to_uppercase();
if upper.starts_with("MATCH") {
Self::parse_match(query)
} else if upper.starts_with("CREATE") {
Self::parse_create(query)
} else if upper.starts_with("DELETE") || upper.starts_with("DETACH DELETE") {
Self::parse_delete(query)
} else if upper.starts_with("SET") {
Self::parse_set(query)
} else {
Err(GraphError::InvalidOperation(format!("Unknown query type: {}", query)))
}
}
fn parse_match(query: &str) -> Result<GraphQuery, GraphError> {
// Simplified parser for: MATCH (var:Label {props})-[:TYPE]->(var2) WHERE ... RETURN ...
let upper = query.to_uppercase();
// Find MATCH, WHERE, RETURN, LIMIT positions
let match_end = upper.find("WHERE").or_else(|| upper.find("RETURN")).unwrap_or(query.len());
let where_start = upper.find("WHERE");
let return_start = upper.find("RETURN");
let limit_start = upper.find("LIMIT");
// Parse pattern (between MATCH and WHERE/RETURN)
        let pattern_str = query[5..match_end].trim();
let pattern = Self::parse_pattern(pattern_str)?;
// Parse WHERE clause
let where_clause = if let Some(ws) = where_start {
let where_end = return_start.unwrap_or(query.len());
            let where_str = query[ws + 5..where_end].trim();
Some(Self::parse_where(where_str)?)
} else {
None
};
// Parse RETURN clause
let return_items = if let Some(rs) = return_start {
let return_end = limit_start.unwrap_or(query.len());
            let return_str = query[rs + 6..return_end].trim();
Self::parse_return(return_str)?
} else {
vec![ReturnItem::All]
};
// Parse LIMIT
let limit = if let Some(ls) = limit_start {
let limit_str = query[ls + 5..].trim();
limit_str.parse().ok()
} else {
None
};
Ok(GraphQuery::Match {
pattern,
where_clause,
return_items,
limit,
})
}
fn parse_pattern(pattern: &str) -> Result<MatchPattern, GraphError> {
// Parse node and relationship patterns
// Format: (var:Label {props})-[:TYPE]->(var2:Label2)
let mut chars = pattern.chars().peekable();
let mut nodes = Vec::new();
let mut relationships = Vec::new();
while chars.peek().is_some() {
// Skip whitespace
while chars.peek() == Some(&' ') {
chars.next();
}
if chars.peek() == Some(&'(') {
let node = Self::parse_node_pattern(&mut chars)?;
nodes.push(node);
} else if chars.peek() == Some(&'-') || chars.peek() == Some(&'<') {
let rel = Self::parse_relationship_pattern(&mut chars)?;
relationships.push(rel);
} else if chars.peek().is_some() {
chars.next(); // Skip unknown characters
}
}
if nodes.is_empty() {
return Err(GraphError::InvalidOperation("No node pattern found".to_string()));
}
// Combine nodes with relationships
let start = nodes.remove(0);
for (i, rel) in relationships.iter_mut().enumerate() {
if i < nodes.len() {
rel.target = nodes[i].clone();
}
}
Ok(MatchPattern { start, relationships })
}
fn parse_node_pattern(chars: &mut std::iter::Peekable<std::str::Chars>) -> Result<NodePattern, GraphError> {
// Consume '('
chars.next();
let mut variable = None;
let mut labels = Vec::new();
let mut properties = None;
let mut buffer = String::new();
while let Some(&c) = chars.peek() {
match c {
')' => {
chars.next();
if !buffer.is_empty() && variable.is_none() {
variable = Some(buffer.clone());
}
break;
}
':' => {
chars.next();
if !buffer.is_empty() && variable.is_none() {
variable = Some(buffer.clone());
}
buffer.clear();
// Read label
while let Some(&c) = chars.peek() {
if c == ')' || c == '{' || c == ':' || c == ' ' {
break;
}
buffer.push(c);
chars.next();
}
if !buffer.is_empty() {
labels.push(buffer.clone());
buffer.clear();
}
}
'{' => {
chars.next();
// Read properties JSON
let mut props_str = String::from("{");
let mut depth = 1;
while let Some(&c) = chars.peek() {
chars.next();
props_str.push(c);
if c == '{' {
depth += 1;
} else if c == '}' {
depth -= 1;
if depth == 0 {
break;
}
}
}
                    // Note: only valid JSON is accepted here; Cypher-style
                    // unquoted keys (e.g. {name: "x"}) will parse to None.
                    properties = serde_json::from_str(&props_str).ok();
}
' ' => {
chars.next();
}
_ => {
buffer.push(c);
chars.next();
}
}
}
Ok(NodePattern { variable, labels, properties })
}
fn parse_relationship_pattern(chars: &mut std::iter::Peekable<std::str::Chars>) -> Result<RelationshipPattern, GraphError> {
let mut direction = RelationshipDirection::Undirected;
let mut edge_type = None;
let mut variable = None;
let min_hops = 1;
let max_hops = 1;
// Check for incoming: <-
if chars.peek() == Some(&'<') {
chars.next();
direction = RelationshipDirection::Incoming;
}
// Consume -
if chars.peek() == Some(&'-') {
chars.next();
}
// Check for [type]
if chars.peek() == Some(&'[') {
chars.next();
let mut buffer = String::new();
while let Some(&c) = chars.peek() {
if c == ']' {
chars.next();
break;
} else if c == ':' {
chars.next();
if !buffer.is_empty() {
variable = Some(buffer.clone());
}
buffer.clear();
// Read edge type
while let Some(&c) = chars.peek() {
if c == ']' || c == ' ' || c == '*' {
break;
}
buffer.push(c);
chars.next();
}
if !buffer.is_empty() {
edge_type = Some(buffer.clone());
buffer.clear();
}
} else {
buffer.push(c);
chars.next();
}
}
}
// Consume -
if chars.peek() == Some(&'-') {
chars.next();
}
// Check for outgoing: >
if chars.peek() == Some(&'>') {
chars.next();
if direction != RelationshipDirection::Incoming {
direction = RelationshipDirection::Outgoing;
}
}
Ok(RelationshipPattern {
variable,
edge_type,
direction,
target: NodePattern { variable: None, labels: Vec::new(), properties: None },
min_hops,
max_hops,
})
}
    fn parse_where(where_str: &str) -> Result<WhereClause, GraphError> {
        // Simplified: parses a single "var.prop = value" comparison.
        // A full implementation would handle complex boolean expressions
        // and the other comparison variants.
        let err = || GraphError::InvalidOperation(format!("Unsupported WHERE clause: {}", where_str));
        let (lhs, rhs) = where_str.split_once('=').ok_or_else(err)?;
        let (variable, property) = lhs.trim().split_once('.').ok_or_else(err)?;
        let rhs = rhs.trim();
        let value: JsonValue = serde_json::from_str(rhs).unwrap_or(JsonValue::String(rhs.trim_matches('"').to_string()));
        Ok(WhereClause::PropertyEquals {
            variable: variable.to_string(),
            property: property.to_string(),
            value,
        })
    }
fn parse_return(return_str: &str) -> Result<Vec<ReturnItem>, GraphError> {
let items: Vec<_> = return_str.split(',').map(|s| s.trim()).collect();
let mut result = Vec::new();
for item in items {
if item == "*" {
result.push(ReturnItem::All);
} else if item.to_uppercase().starts_with("COUNT(") {
let inner = &item[6..item.len() - 1];
if inner == "*" {
result.push(ReturnItem::Count(None));
} else {
result.push(ReturnItem::Count(Some(inner.to_string())));
}
} else if item.contains('.') {
let parts: Vec<_> = item.split('.').collect();
if parts.len() == 2 {
result.push(ReturnItem::Property {
variable: parts[0].to_string(),
property: parts[1].to_string(),
});
}
} else {
result.push(ReturnItem::Variable(item.to_string()));
}
}
Ok(result)
}
fn parse_create(query: &str) -> Result<GraphQuery, GraphError> {
// Simplified CREATE parser
        let pattern = query[6..].trim();
let elements = Self::parse_create_elements(pattern)?;
Ok(GraphQuery::Create { elements })
}
fn parse_create_elements(pattern: &str) -> Result<Vec<CreateElement>, GraphError> {
// Parse (var:Label {props}) patterns
let mut elements = Vec::new();
let mut chars = pattern.chars().peekable();
while chars.peek().is_some() {
while chars.peek() == Some(&' ') || chars.peek() == Some(&',') {
chars.next();
}
if chars.peek() == Some(&'(') {
let node = Self::parse_node_pattern(&mut chars)?;
elements.push(CreateElement::Node {
variable: node.variable,
labels: node.labels,
properties: node.properties.unwrap_or(JsonValue::Object(serde_json::Map::new())),
});
} else {
break;
}
}
Ok(elements)
}
fn parse_delete(query: &str) -> Result<GraphQuery, GraphError> {
let detach = query.to_uppercase().starts_with("DETACH");
let start = if detach { "DETACH DELETE".len() } else { "DELETE".len() };
let variable = query[start..].trim().to_string();
Ok(GraphQuery::Delete { variable, detach })
}
fn parse_set(query: &str) -> Result<GraphQuery, GraphError> {
// Simplified: SET var.prop = value
        let content = query[3..].trim();
let parts: Vec<_> = content.split('=').collect();
if parts.len() != 2 {
return Err(GraphError::InvalidOperation("Invalid SET syntax".to_string()));
}
let var_prop: Vec<_> = parts[0].trim().split('.').collect();
if var_prop.len() != 2 {
return Err(GraphError::InvalidOperation("Invalid SET variable".to_string()));
}
let variable = var_prop[0].to_string();
let property = var_prop[1].to_string();
let value_str = parts[1].trim();
let value: JsonValue = serde_json::from_str(value_str).unwrap_or(JsonValue::String(value_str.to_string()));
Ok(GraphQuery::Set {
variable,
properties: serde_json::json!({ property: value }),
})
}
}
/// Query executor for graph queries.
pub struct GraphQueryExecutor<'a> {
store: &'a GraphStore,
}
impl<'a> GraphQueryExecutor<'a> {
/// Creates a new query executor.
pub fn new(store: &'a GraphStore) -> Self {
Self { store }
}
/// Executes a graph query.
pub fn execute(&self, query: &GraphQuery) -> Result<QueryResult, GraphError> {
match query {
GraphQuery::Match { pattern, where_clause, return_items, limit } => {
self.execute_match(pattern, where_clause.as_ref(), return_items, *limit)
}
GraphQuery::Create { .. } => {
Err(GraphError::InvalidOperation("CREATE requires mutable access".to_string()))
}
GraphQuery::Delete { .. } => {
Err(GraphError::InvalidOperation("DELETE requires mutable access".to_string()))
}
GraphQuery::Set { .. } => {
Err(GraphError::InvalidOperation("SET requires mutable access".to_string()))
}
}
}
fn execute_match(
&self,
pattern: &MatchPattern,
_where_clause: Option<&WhereClause>,
return_items: &[ReturnItem],
limit: Option<usize>,
) -> Result<QueryResult, GraphError> {
// Find starting nodes
let start_nodes = self.find_matching_nodes(&pattern.start);
let mut bindings: Vec<HashMap<String, JsonValue>> = Vec::new();
for start_node in &start_nodes {
let mut binding: HashMap<String, JsonValue> = HashMap::new();
if let Some(ref var) = pattern.start.variable {
binding.insert(var.clone(), Self::node_to_json(start_node));
}
if pattern.relationships.is_empty() {
bindings.push(binding);
} else {
// Traverse relationships
let traverser = Traverser::new(self.store);
for rel_pattern in &pattern.relationships {
let direction = match rel_pattern.direction {
RelationshipDirection::Outgoing => TraversalDirection::Outgoing,
RelationshipDirection::Incoming => TraversalDirection::Incoming,
RelationshipDirection::Undirected => TraversalDirection::Both,
};
let query = TraversalQuery::new()
.depth(rel_pattern.max_hops)
.direction(direction)
.edge_types(
rel_pattern.edge_type.clone().map(|t| vec![t]).unwrap_or_default(),
)
.labels(rel_pattern.target.labels.clone());
let results = traverser.traverse(&start_node.id, &query);
for result in results {
let mut new_binding = binding.clone();
if let Some(ref var) = rel_pattern.variable {
if let Some(edge) = result.edges.last() {
new_binding.insert(var.clone(), Self::edge_to_json(edge));
}
}
if let Some(ref var) = rel_pattern.target.variable {
new_binding.insert(var.clone(), Self::node_to_json(&result.node));
}
bindings.push(new_binding);
}
}
}
if let Some(l) = limit {
if bindings.len() >= l {
break;
}
}
}
// Apply limit
if let Some(l) = limit {
bindings.truncate(l);
}
// Build result based on return items
let columns = self.get_column_names(return_items, &bindings);
let rows = self.extract_rows(return_items, &bindings);
Ok(QueryResult {
columns,
rows,
..QueryResult::empty()
})
}
fn find_matching_nodes(&self, pattern: &NodePattern) -> Vec<Node> {
let label = pattern.labels.first().map(|s| s.as_str());
let filter = pattern.properties.clone().unwrap_or(JsonValue::Object(serde_json::Map::new()));
self.store.find_nodes(label, &filter)
}
fn node_to_json(node: &Node) -> JsonValue {
serde_json::json!({
"id": node.id.to_hex(),
"labels": node.labels,
"properties": node.properties,
})
}
fn edge_to_json(edge: &Edge) -> JsonValue {
serde_json::json!({
"id": edge.id.to_hex(),
"type": edge.edge_type,
"source": edge.source.to_hex(),
"target": edge.target.to_hex(),
"properties": edge.properties,
})
}
fn get_column_names(&self, return_items: &[ReturnItem], bindings: &[HashMap<String, JsonValue>]) -> Vec<String> {
let mut columns = Vec::new();
for item in return_items {
match item {
                ReturnItem::All => {
                    if let Some(binding) = bindings.first() {
                        // Sort for a deterministic column order; HashMap
                        // iteration order is unspecified.
                        let mut keys: Vec<String> = binding.keys().cloned().collect();
                        keys.sort();
                        columns.extend(keys);
                    }
                }
ReturnItem::Variable(var) => columns.push(var.clone()),
ReturnItem::Property { variable, property } => {
columns.push(format!("{}.{}", variable, property));
}
ReturnItem::Alias { alias, .. } => columns.push(alias.clone()),
ReturnItem::Count(var) => {
columns.push(format!("count({})", var.as_ref().map(|s| s.as_str()).unwrap_or("*")));
}
}
}
columns
}
fn extract_rows(&self, return_items: &[ReturnItem], bindings: &[HashMap<String, JsonValue>]) -> Vec<Vec<JsonValue>> {
let mut rows = Vec::new();
// Handle COUNT specially
if return_items.iter().any(|i| matches!(i, ReturnItem::Count(_))) {
rows.push(vec![JsonValue::Number(bindings.len().into())]);
return rows;
}
for binding in bindings {
let mut row = Vec::new();
for item in return_items {
match item {
                    ReturnItem::All => {
                        // Iterate keys in a deterministic (sorted) order;
                        // HashMap iteration order is unspecified and can
                        // differ between bindings.
                        let mut keys: Vec<&String> = binding.keys().collect();
                        keys.sort();
                        for key in keys {
                            row.push(binding[key].clone());
                        }
                    }
ReturnItem::Variable(var) => {
row.push(binding.get(var).cloned().unwrap_or(JsonValue::Null));
}
ReturnItem::Property { variable, property } => {
if let Some(obj) = binding.get(variable) {
if let Some(props) = obj.get("properties") {
row.push(props.get(property).cloned().unwrap_or(JsonValue::Null));
} else {
row.push(JsonValue::Null);
}
} else {
row.push(JsonValue::Null);
}
}
ReturnItem::Alias { item: inner, .. } => {
// Recursively handle the inner item
let inner_rows = self.extract_rows(&[*inner.clone()], &[binding.clone()]);
if let Some(inner_row) = inner_rows.first() {
row.extend(inner_row.clone());
}
}
ReturnItem::Count(_) => {
// Handled above
}
}
}
if !row.is_empty() {
rows.push(row);
}
}
rows
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_simple_match() {
let query = "MATCH (n:User) RETURN n";
let parsed = GraphQueryParser::parse(query).unwrap();
if let GraphQuery::Match { pattern, .. } = parsed {
assert_eq!(pattern.start.labels, vec!["User".to_string()]);
} else {
panic!("Expected Match query");
}
}
#[test]
fn test_parse_match_with_relationship() {
let query = "MATCH (a:User)-[:FRIEND]->(b:User) RETURN a, b";
let parsed = GraphQueryParser::parse(query).unwrap();
if let GraphQuery::Match { pattern, .. } = parsed {
assert_eq!(pattern.start.labels, vec!["User".to_string()]);
assert_eq!(pattern.relationships.len(), 1);
assert_eq!(pattern.relationships[0].edge_type, Some("FRIEND".to_string()));
assert_eq!(pattern.relationships[0].direction, RelationshipDirection::Outgoing);
} else {
panic!("Expected Match query");
}
}
#[test]
fn test_execute_match() {
let store = GraphStore::new();
let alice = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"}));
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
store.create_edge(alice, bob, "FRIEND", serde_json::json!({})).unwrap();
let query = GraphQueryParser::parse("MATCH (n:User) RETURN n").unwrap();
let executor = GraphQueryExecutor::new(&store);
let result = executor.execute(&query).unwrap();
assert_eq!(result.rows.len(), 2);
}
#[test]
fn test_parse_create() {
let query = "CREATE (n:User {name: \"Alice\"})";
let parsed = GraphQueryParser::parse(query).unwrap();
if let GraphQuery::Create { elements } = parsed {
assert_eq!(elements.len(), 1);
if let CreateElement::Node { labels, .. } = &elements[0] {
assert_eq!(labels, &vec!["User".to_string()]);
}
} else {
panic!("Expected Create query");
}
}
#[test]
fn test_parse_delete() {
let query = "DELETE n";
let parsed = GraphQueryParser::parse(query).unwrap();
if let GraphQuery::Delete { variable, detach } = parsed {
assert_eq!(variable, "n");
assert!(!detach);
} else {
panic!("Expected Delete query");
}
}
#[test]
fn test_parse_detach_delete() {
let query = "DETACH DELETE n";
let parsed = GraphQueryParser::parse(query).unwrap();
if let GraphQuery::Delete { variable, detach } = parsed {
assert_eq!(variable, "n");
assert!(detach);
} else {
panic!("Expected Delete query");
}
}
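    // Hedged examples added for coverage: these exercise the simplified SET
    // and LIMIT paths of the parser above; expected values follow from the
    // parsing code itself, not from external documentation.
    #[test]
    fn test_parse_set() {
        let parsed = GraphQueryParser::parse("SET n.age = 30").unwrap();
        if let GraphQuery::Set { variable, properties } = parsed {
            assert_eq!(variable, "n");
            assert_eq!(properties, serde_json::json!({"age": 30}));
        } else {
            panic!("Expected Set query");
        }
    }
    #[test]
    fn test_parse_match_with_limit() {
        let parsed = GraphQueryParser::parse("MATCH (n:User) RETURN n LIMIT 5").unwrap();
        if let GraphQuery::Match { limit, .. } = parsed {
            assert_eq!(limit, Some(5));
        } else {
            panic!("Expected Match query");
        }
    }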
}


@ -0,0 +1,657 @@
//! Graph storage engine.
use super::edge::{Edge, EdgeId};
use super::node::{Node, NodeId};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use std::collections::{HashMap, HashSet};
use thiserror::Error;
/// Direction for traversing edges.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Direction {
/// Follow outgoing edges only.
Outgoing,
/// Follow incoming edges only.
Incoming,
/// Follow edges in both directions.
Both,
}
/// Graph storage errors.
#[derive(Debug, Error)]
pub enum GraphError {
/// Node not found.
#[error("Node not found: {0}")]
NodeNotFound(String),
/// Edge not found.
#[error("Edge not found: {0}")]
EdgeNotFound(String),
/// Node already exists.
#[error("Node already exists: {0}")]
NodeExists(String),
/// Invalid operation.
#[error("Invalid operation: {0}")]
InvalidOperation(String),
/// Constraint violation.
#[error("Constraint violation: {0}")]
ConstraintViolation(String),
}
/// Statistics for a graph store.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct GraphStats {
/// Total number of nodes.
pub node_count: u64,
/// Total number of edges.
pub edge_count: u64,
/// Number of distinct labels.
pub label_count: u64,
/// Number of distinct edge types.
pub edge_type_count: u64,
}
/// Graph storage engine.
pub struct GraphStore {
/// Node storage.
nodes: RwLock<HashMap<NodeId, Node>>,
/// Edge storage.
edges: RwLock<HashMap<EdgeId, Edge>>,
/// Outgoing adjacency list: node -> outgoing edges.
adjacency: RwLock<HashMap<NodeId, Vec<EdgeId>>>,
/// Incoming adjacency list: node -> incoming edges.
reverse_adj: RwLock<HashMap<NodeId, Vec<EdgeId>>>,
/// Label index: label -> nodes with that label.
label_index: RwLock<HashMap<String, HashSet<NodeId>>>,
/// Edge type index: type -> edges of that type.
edge_type_index: RwLock<HashMap<String, HashSet<EdgeId>>>,
}
impl GraphStore {
/// Creates a new empty graph store.
pub fn new() -> Self {
Self {
nodes: RwLock::new(HashMap::new()),
edges: RwLock::new(HashMap::new()),
adjacency: RwLock::new(HashMap::new()),
reverse_adj: RwLock::new(HashMap::new()),
label_index: RwLock::new(HashMap::new()),
edge_type_index: RwLock::new(HashMap::new()),
}
}
/// Returns statistics about the graph.
pub fn stats(&self) -> GraphStats {
GraphStats {
node_count: self.nodes.read().len() as u64,
edge_count: self.edges.read().len() as u64,
label_count: self.label_index.read().len() as u64,
edge_type_count: self.edge_type_index.read().len() as u64,
}
}
// ==================== Node Operations ====================
/// Creates a new node with the given labels and properties.
pub fn create_node(&self, labels: Vec<String>, properties: JsonValue) -> NodeId {
let node = Node::new(labels.clone(), properties);
let id = node.id;
// Update label index
{
let mut label_idx = self.label_index.write();
for label in &labels {
label_idx.entry(label.clone()).or_default().insert(id);
}
}
// Initialize adjacency lists
self.adjacency.write().insert(id, Vec::new());
self.reverse_adj.write().insert(id, Vec::new());
// Store node
self.nodes.write().insert(id, node);
id
}
/// Gets a node by ID.
pub fn get_node(&self, id: &NodeId) -> Option<Node> {
self.nodes.read().get(id).cloned()
}
/// Updates a node's properties.
pub fn update_node(&self, id: &NodeId, properties: JsonValue) -> Result<(), GraphError> {
let mut nodes = self.nodes.write();
let node = nodes
.get_mut(id)
.ok_or_else(|| GraphError::NodeNotFound(id.to_string()))?;
node.properties = properties;
node.updated_at = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as u64;
Ok(())
}
/// Updates a node's labels.
pub fn update_node_labels(&self, id: &NodeId, labels: Vec<String>) -> Result<(), GraphError> {
let mut nodes = self.nodes.write();
let node = nodes
.get_mut(id)
.ok_or_else(|| GraphError::NodeNotFound(id.to_string()))?;
// Update label index
{
let mut label_idx = self.label_index.write();
// Remove old labels
for old_label in &node.labels {
if let Some(set) = label_idx.get_mut(old_label) {
set.remove(id);
}
}
// Add new labels
for new_label in &labels {
label_idx.entry(new_label.clone()).or_default().insert(*id);
}
}
node.labels = labels;
node.updated_at = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as u64;
Ok(())
}
/// Deletes a node and all its connected edges.
pub fn delete_node(&self, id: &NodeId) -> Result<(), GraphError> {
// Get connected edges
let outgoing: Vec<EdgeId> = self
.adjacency
.read()
.get(id)
.cloned()
.unwrap_or_default();
let incoming: Vec<EdgeId> = self
.reverse_adj
.read()
.get(id)
.cloned()
.unwrap_or_default();
// Delete all connected edges
for edge_id in outgoing.iter().chain(incoming.iter()) {
let _ = self.delete_edge(edge_id);
}
// Remove from label index
{
let nodes = self.nodes.read();
if let Some(node) = nodes.get(id) {
let mut label_idx = self.label_index.write();
for label in &node.labels {
if let Some(set) = label_idx.get_mut(label) {
set.remove(id);
}
}
}
}
// Remove adjacency entries
self.adjacency.write().remove(id);
self.reverse_adj.write().remove(id);
// Remove node
self.nodes
.write()
.remove(id)
.ok_or_else(|| GraphError::NodeNotFound(id.to_string()))?;
Ok(())
}
/// Finds nodes by label.
pub fn find_nodes_by_label(&self, label: &str) -> Vec<Node> {
let node_ids: Vec<NodeId> = self
.label_index
.read()
.get(label)
.map(|set| set.iter().cloned().collect())
.unwrap_or_default();
let nodes = self.nodes.read();
node_ids
.iter()
.filter_map(|id| nodes.get(id).cloned())
.collect()
}
/// Finds nodes matching a filter.
pub fn find_nodes(&self, label: Option<&str>, filter: &JsonValue) -> Vec<Node> {
let candidates: Vec<Node> = if let Some(l) = label {
self.find_nodes_by_label(l)
} else {
self.nodes.read().values().cloned().collect()
};
candidates
.into_iter()
.filter(|n| n.matches_properties(filter))
.collect()
}
// ==================== Edge Operations ====================
/// Creates a new edge.
pub fn create_edge(
&self,
source: NodeId,
target: NodeId,
edge_type: impl Into<String>,
properties: JsonValue,
) -> Result<EdgeId, GraphError> {
let edge_type = edge_type.into();
// Verify both nodes exist
{
let nodes = self.nodes.read();
if !nodes.contains_key(&source) {
return Err(GraphError::NodeNotFound(source.to_string()));
}
if !nodes.contains_key(&target) {
return Err(GraphError::NodeNotFound(target.to_string()));
}
}
let edge = Edge::new(source, target, edge_type.clone(), properties);
let id = edge.id;
// Update adjacency lists
self.adjacency.write().entry(source).or_default().push(id);
self.reverse_adj.write().entry(target).or_default().push(id);
// Update edge type index
self.edge_type_index
.write()
.entry(edge_type)
.or_default()
.insert(id);
// Store edge
self.edges.write().insert(id, edge);
Ok(id)
}
/// Creates an undirected edge.
pub fn create_undirected_edge(
&self,
source: NodeId,
target: NodeId,
edge_type: impl Into<String>,
properties: JsonValue,
) -> Result<EdgeId, GraphError> {
let edge_type = edge_type.into();
// Verify both nodes exist
{
let nodes = self.nodes.read();
if !nodes.contains_key(&source) {
return Err(GraphError::NodeNotFound(source.to_string()));
}
if !nodes.contains_key(&target) {
return Err(GraphError::NodeNotFound(target.to_string()));
}
}
let edge = Edge::undirected(source, target, edge_type.clone(), properties);
let id = edge.id;
// Update adjacency lists (both directions for undirected)
{
let mut adj = self.adjacency.write();
adj.entry(source).or_default().push(id);
adj.entry(target).or_default().push(id);
}
{
let mut rev = self.reverse_adj.write();
rev.entry(source).or_default().push(id);
rev.entry(target).or_default().push(id);
}
// Update edge type index
self.edge_type_index
.write()
.entry(edge_type)
.or_default()
.insert(id);
// Store edge
self.edges.write().insert(id, edge);
Ok(id)
}
/// Gets an edge by ID.
pub fn get_edge(&self, id: &EdgeId) -> Option<Edge> {
self.edges.read().get(id).cloned()
}
/// Deletes an edge.
pub fn delete_edge(&self, id: &EdgeId) -> Result<(), GraphError> {
let edge = self
.edges
.write()
.remove(id)
.ok_or_else(|| GraphError::EdgeNotFound(id.to_string()))?;
// Update adjacency lists
{
let mut adj = self.adjacency.write();
if let Some(list) = adj.get_mut(&edge.source) {
list.retain(|e| e != id);
}
if !edge.directed {
if let Some(list) = adj.get_mut(&edge.target) {
list.retain(|e| e != id);
}
}
}
{
let mut rev = self.reverse_adj.write();
if let Some(list) = rev.get_mut(&edge.target) {
list.retain(|e| e != id);
}
if !edge.directed {
if let Some(list) = rev.get_mut(&edge.source) {
list.retain(|e| e != id);
}
}
}
// Update edge type index
if let Some(set) = self.edge_type_index.write().get_mut(&edge.edge_type) {
set.remove(id);
}
Ok(())
}
/// Finds edges by type.
pub fn find_edges_by_type(&self, edge_type: &str) -> Vec<Edge> {
let edge_ids: Vec<EdgeId> = self
.edge_type_index
.read()
.get(edge_type)
.map(|set| set.iter().cloned().collect())
.unwrap_or_default();
let edges = self.edges.read();
edge_ids
.iter()
.filter_map(|id| edges.get(id).cloned())
.collect()
}
// ==================== Traversal Operations ====================
/// Gets neighboring nodes.
pub fn neighbors(&self, id: &NodeId, direction: Direction) -> Vec<Node> {
let edge_ids = self.edges_of_node(id, direction);
let edges = self.edges.read();
let nodes = self.nodes.read();
let mut neighbor_ids = HashSet::new();
for eid in edge_ids {
if let Some(edge) = edges.get(&eid) {
if let Some(other) = self.get_neighbor_from_edge(edge, id, direction) {
neighbor_ids.insert(other);
}
}
}
neighbor_ids
.iter()
.filter_map(|nid| nodes.get(nid).cloned())
.collect()
}
/// Gets edges connected to a node.
pub fn edges_of(&self, id: &NodeId, direction: Direction) -> Vec<Edge> {
let edge_ids = self.edges_of_node(id, direction);
let edges = self.edges.read();
edge_ids
.iter()
.filter_map(|eid| edges.get(eid).cloned())
.collect()
}
/// Gets edge IDs connected to a node.
fn edges_of_node(&self, id: &NodeId, direction: Direction) -> Vec<EdgeId> {
match direction {
Direction::Outgoing => self.adjacency.read().get(id).cloned().unwrap_or_default(),
Direction::Incoming => self.reverse_adj.read().get(id).cloned().unwrap_or_default(),
Direction::Both => {
let mut result = self.adjacency.read().get(id).cloned().unwrap_or_default();
let incoming = self.reverse_adj.read().get(id).cloned().unwrap_or_default();
for eid in incoming {
if !result.contains(&eid) {
result.push(eid);
}
}
result
}
}
}
/// Gets the neighbor node from an edge.
fn get_neighbor_from_edge(&self, edge: &Edge, from: &NodeId, direction: Direction) -> Option<NodeId> {
match direction {
Direction::Outgoing => {
if &edge.source == from {
Some(edge.target)
} else if !edge.directed && &edge.target == from {
Some(edge.source)
} else {
None
}
}
Direction::Incoming => {
if &edge.target == from {
Some(edge.source)
} else if !edge.directed && &edge.source == from {
Some(edge.target)
} else {
None
}
}
Direction::Both => edge.other_end(from).or_else(|| {
// For directed edges, still return the other end
if &edge.source == from {
Some(edge.target)
} else if &edge.target == from {
Some(edge.source)
} else {
None
}
}),
}
}
/// Gets neighbors connected by a specific edge type.
pub fn neighbors_by_type(&self, id: &NodeId, edge_type: &str, direction: Direction) -> Vec<Node> {
let edges = self.edges_of(id, direction);
let nodes = self.nodes.read();
let mut neighbor_ids = HashSet::new();
for edge in edges {
if edge.edge_type == edge_type {
if let Some(other) = self.get_neighbor_from_edge(&edge, id, direction) {
neighbor_ids.insert(other);
}
}
}
neighbor_ids
.iter()
.filter_map(|nid| nodes.get(nid).cloned())
.collect()
}
/// Checks if an edge exists between two nodes.
pub fn has_edge(&self, source: &NodeId, target: &NodeId, edge_type: Option<&str>) -> bool {
let edges = self.edges_of(source, Direction::Outgoing);
for edge in edges {
if &edge.target == target {
if let Some(et) = edge_type {
if edge.edge_type == et {
return true;
}
} else {
return true;
}
}
}
false
}
/// Gets all edges between two nodes.
pub fn edges_between(&self, source: &NodeId, target: &NodeId) -> Vec<Edge> {
let edges = self.edges_of(source, Direction::Outgoing);
edges
.into_iter()
.filter(|e| &e.target == target || (!e.directed && &e.source == target))
.collect()
}
}
impl Default for GraphStore {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_create_node() {
let store = GraphStore::new();
let id = store.create_node(
vec!["User".to_string()],
serde_json::json!({"name": "Alice"}),
);
let node = store.get_node(&id).unwrap();
assert!(node.has_label("User"));
assert_eq!(node.get_property("name"), Some(&serde_json::json!("Alice")));
}
#[test]
fn test_create_edge() {
let store = GraphStore::new();
let alice = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"}));
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
let edge_id = store
.create_edge(alice, bob, "FRIEND", serde_json::json!({"since": 2020}))
.unwrap();
let edge = store.get_edge(&edge_id).unwrap();
assert_eq!(edge.source, alice);
assert_eq!(edge.target, bob);
assert_eq!(edge.edge_type, "FRIEND");
}
#[test]
fn test_neighbors() {
let store = GraphStore::new();
let alice = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"}));
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
let charlie = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Charlie"}));
store.create_edge(alice, bob, "FRIEND", serde_json::json!({})).unwrap();
store.create_edge(alice, charlie, "FRIEND", serde_json::json!({})).unwrap();
let neighbors = store.neighbors(&alice, Direction::Outgoing);
assert_eq!(neighbors.len(), 2);
}
#[test]
fn test_find_by_label() {
let store = GraphStore::new();
store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"}));
store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
store.create_node(vec!["Product".to_string()], serde_json::json!({"name": "Widget"}));
let users = store.find_nodes_by_label("User");
assert_eq!(users.len(), 2);
let products = store.find_nodes_by_label("Product");
assert_eq!(products.len(), 1);
}
#[test]
fn test_delete_node() {
let store = GraphStore::new();
let alice = store.create_node(vec!["User".to_string()], serde_json::json!({}));
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({}));
store.create_edge(alice, bob, "FRIEND", serde_json::json!({})).unwrap();
// Delete Alice - should also delete the edge
store.delete_node(&alice).unwrap();
assert!(store.get_node(&alice).is_none());
assert_eq!(store.stats().edge_count, 0);
}
#[test]
fn test_undirected_edge() {
let store = GraphStore::new();
let a = store.create_node(vec![], serde_json::json!({}));
let b = store.create_node(vec![], serde_json::json!({}));
store.create_undirected_edge(a, b, "LINK", serde_json::json!({})).unwrap();
// Both directions should work
let a_neighbors = store.neighbors(&a, Direction::Outgoing);
let b_neighbors = store.neighbors(&b, Direction::Outgoing);
assert_eq!(a_neighbors.len(), 1);
assert_eq!(b_neighbors.len(), 1);
}
#[test]
fn test_edges_between() {
let store = GraphStore::new();
let a = store.create_node(vec![], serde_json::json!({}));
let b = store.create_node(vec![], serde_json::json!({}));
store.create_edge(a, b, "TYPE_A", serde_json::json!({})).unwrap();
store.create_edge(a, b, "TYPE_B", serde_json::json!({})).unwrap();
let edges = store.edges_between(&a, &b);
assert_eq!(edges.len(), 2);
}
}
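The undirected-edge and neighbor tests above hinge on one rule: a directed edge only resolves source→target, while an undirected edge resolves from either endpoint. A minimal self-contained sketch of that rule — `Edge` and `outgoing_neighbor` here are illustrative stand-ins with plain `u64` IDs, not the crate's real 32-byte-ID types:

```rust
// Minimal sketch of directed/undirected neighbor resolution.
// Node IDs are plain u64s here; the real store uses 32-byte IDs.
struct Edge {
    source: u64,
    target: u64,
    directed: bool,
}

// Returns the node reachable from `from` along `edge` in the
// outgoing direction, honoring the `directed` flag.
fn outgoing_neighbor(edge: &Edge, from: u64) -> Option<u64> {
    if edge.source == from {
        Some(edge.target)
    } else if !edge.directed && edge.target == from {
        Some(edge.source)
    } else {
        None
    }
}

fn main() {
    let directed = Edge { source: 1, target: 2, directed: true };
    let undirected = Edge { source: 1, target: 2, directed: false };
    // Directed: only source -> target resolves.
    assert_eq!(outgoing_neighbor(&directed, 1), Some(2));
    assert_eq!(outgoing_neighbor(&directed, 2), None);
    // Undirected: both endpoints see each other.
    assert_eq!(outgoing_neighbor(&undirected, 2), Some(1));
    println!("ok");
}
```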


@ -0,0 +1,500 @@
//! Graph traversal algorithms.
use super::edge::Edge;
use super::node::{Node, NodeId};
use super::store::{Direction, GraphStore};
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use std::collections::{HashSet, VecDeque};
/// Query for graph traversal.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TraversalQuery {
/// Maximum depth to traverse.
pub max_depth: usize,
/// Edge types to follow (empty = all types).
pub edge_types: Vec<String>,
/// Direction to traverse.
pub direction: TraversalDirection,
/// Filter for nodes to include.
pub node_filter: Option<JsonValue>,
/// Filter for edges to follow.
pub edge_filter: Option<JsonValue>,
/// Maximum results to return.
pub limit: Option<usize>,
/// Labels to filter nodes by.
pub labels: Vec<String>,
}
/// Serializable direction for traversal queries (mirrors the store's `Direction`).
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub enum TraversalDirection {
Outgoing,
Incoming,
Both,
}
impl From<TraversalDirection> for Direction {
fn from(td: TraversalDirection) -> Self {
match td {
TraversalDirection::Outgoing => Direction::Outgoing,
TraversalDirection::Incoming => Direction::Incoming,
TraversalDirection::Both => Direction::Both,
}
}
}
impl Default for TraversalQuery {
fn default() -> Self {
Self {
max_depth: 3,
edge_types: Vec::new(),
direction: TraversalDirection::Outgoing,
node_filter: None,
edge_filter: None,
limit: None,
labels: Vec::new(),
}
}
}
impl TraversalQuery {
/// Creates a new traversal query.
pub fn new() -> Self {
Self::default()
}
/// Sets the maximum depth.
pub fn depth(mut self, depth: usize) -> Self {
self.max_depth = depth;
self
}
/// Sets edge types to follow.
pub fn edge_types(mut self, types: Vec<String>) -> Self {
self.edge_types = types;
self
}
/// Sets the traversal direction.
pub fn direction(mut self, dir: TraversalDirection) -> Self {
self.direction = dir;
self
}
/// Sets a node filter.
pub fn node_filter(mut self, filter: JsonValue) -> Self {
self.node_filter = Some(filter);
self
}
/// Sets an edge filter.
pub fn edge_filter(mut self, filter: JsonValue) -> Self {
self.edge_filter = Some(filter);
self
}
/// Sets result limit.
pub fn limit(mut self, limit: usize) -> Self {
self.limit = Some(limit);
self
}
/// Sets label filter.
pub fn labels(mut self, labels: Vec<String>) -> Self {
self.labels = labels;
self
}
}
/// Result of a traversal operation.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TraversalResult {
/// The node found.
pub node: Node,
/// Depth at which this node was found.
pub depth: usize,
/// Path from start to this node (node IDs).
pub path: Vec<NodeId>,
/// Edges followed to reach this node.
pub edges: Vec<Edge>,
}
/// Graph traverser for executing traversal queries.
pub struct Traverser<'a> {
store: &'a GraphStore,
}
impl<'a> Traverser<'a> {
/// Creates a new traverser.
pub fn new(store: &'a GraphStore) -> Self {
Self { store }
}
/// Executes a BFS traversal from a starting node.
pub fn traverse(&self, start: &NodeId, query: &TraversalQuery) -> Vec<TraversalResult> {
let mut results = Vec::new();
let mut visited = HashSet::new();
let mut queue: VecDeque<(NodeId, usize, Vec<NodeId>, Vec<Edge>)> = VecDeque::new();
visited.insert(*start);
queue.push_back((*start, 0, vec![*start], Vec::new()));
let direction: Direction = query.direction.into();
while let Some((current_id, depth, path, edges_path)) = queue.pop_front() {
// Check limit
if let Some(limit) = query.limit {
if results.len() >= limit {
break;
}
}
// Get current node
if let Some(node) = self.store.get_node(&current_id) {
// Skip start node in results (depth 0)
if depth > 0 {
// Apply filters
if self.matches_query(&node, query) {
results.push(TraversalResult {
node: node.clone(),
depth,
path: path.clone(),
edges: edges_path.clone(),
});
}
}
// Continue if not at max depth
if depth < query.max_depth {
let edges = self.store.edges_of(&current_id, direction);
for edge in edges {
// Check edge type filter
if !query.edge_types.is_empty() && !query.edge_types.contains(&edge.edge_type) {
continue;
}
// Check edge filter
if let Some(ref filter) = query.edge_filter {
if !edge.matches_properties(filter) {
continue;
}
}
// Get neighbor
let neighbor_id = self.get_neighbor(&edge, &current_id, direction);
if let Some(next_id) = neighbor_id {
if !visited.contains(&next_id) {
visited.insert(next_id);
let mut new_path = path.clone();
new_path.push(next_id);
let mut new_edges = edges_path.clone();
new_edges.push(edge);
queue.push_back((next_id, depth + 1, new_path, new_edges));
}
}
}
}
}
}
results
}
/// Executes a DFS traversal from a starting node.
pub fn traverse_dfs(&self, start: &NodeId, query: &TraversalQuery) -> Vec<TraversalResult> {
let mut results = Vec::new();
let mut visited = HashSet::new();
let direction: Direction = query.direction.into();
self.dfs_visit(
start,
0,
vec![*start],
Vec::new(),
&mut visited,
&mut results,
query,
direction,
);
results
}
fn dfs_visit(
&self,
current_id: &NodeId,
depth: usize,
path: Vec<NodeId>,
edges_path: Vec<Edge>,
visited: &mut HashSet<NodeId>,
results: &mut Vec<TraversalResult>,
query: &TraversalQuery,
direction: Direction,
) {
// Check limit
if let Some(limit) = query.limit {
if results.len() >= limit {
return;
}
}
visited.insert(*current_id);
if let Some(node) = self.store.get_node(current_id) {
// Skip start node in results
if depth > 0 && self.matches_query(&node, query) {
results.push(TraversalResult {
node: node.clone(),
depth,
path: path.clone(),
edges: edges_path.clone(),
});
}
// Continue if not at max depth
if depth < query.max_depth {
let edges = self.store.edges_of(current_id, direction);
for edge in edges {
// Check edge type filter
if !query.edge_types.is_empty() && !query.edge_types.contains(&edge.edge_type) {
continue;
}
// Check edge filter
if let Some(ref filter) = query.edge_filter {
if !edge.matches_properties(filter) {
continue;
}
}
if let Some(next_id) = self.get_neighbor(&edge, current_id, direction) {
if !visited.contains(&next_id) {
let mut new_path = path.clone();
new_path.push(next_id);
let mut new_edges = edges_path.clone();
new_edges.push(edge);
self.dfs_visit(
&next_id,
depth + 1,
new_path,
new_edges,
visited,
results,
query,
direction,
);
}
}
}
}
}
}
/// Checks if a node matches the query filters.
fn matches_query(&self, node: &Node, query: &TraversalQuery) -> bool {
// Check labels
if !query.labels.is_empty() {
let has_label = query.labels.iter().any(|l| node.has_label(l));
if !has_label {
return false;
}
}
// Check node filter
if let Some(ref filter) = query.node_filter {
if !node.matches_properties(filter) {
return false;
}
}
true
}
/// Gets the neighbor node ID from an edge.
fn get_neighbor(&self, edge: &Edge, from: &NodeId, direction: Direction) -> Option<NodeId> {
match direction {
Direction::Outgoing => {
if &edge.source == from {
Some(edge.target)
} else if !edge.directed && &edge.target == from {
Some(edge.source)
} else {
None
}
}
Direction::Incoming => {
if &edge.target == from {
Some(edge.source)
} else if !edge.directed && &edge.source == from {
Some(edge.target)
} else {
None
}
}
Direction::Both => {
if &edge.source == from {
Some(edge.target)
} else if &edge.target == from {
Some(edge.source)
} else {
None
}
}
}
}
/// Finds all nodes within a certain distance.
pub fn within_distance(&self, start: &NodeId, max_distance: usize) -> Vec<(Node, usize)> {
let query = TraversalQuery::new().depth(max_distance);
self.traverse(start, &query)
.into_iter()
.map(|r| (r.node, r.depth))
.collect()
}
/// Finds mutual connections between two nodes.
pub fn mutual_connections(
&self,
node_a: &NodeId,
node_b: &NodeId,
edge_type: Option<&str>,
) -> Vec<Node> {
let query = TraversalQuery::new()
.depth(1)
.edge_types(edge_type.map(|s| vec![s.to_string()]).unwrap_or_default());
let neighbors_a: HashSet<NodeId> = self
.traverse(node_a, &query)
.into_iter()
.map(|r| r.node.id)
.collect();
let neighbors_b: HashSet<NodeId> = self
.traverse(node_b, &query)
.into_iter()
.map(|r| r.node.id)
.collect();
let mutual: HashSet<_> = neighbors_a.intersection(&neighbors_b).collect();
mutual
.iter()
.filter_map(|id| self.store.get_node(id))
.collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
fn setup_social_graph() -> GraphStore {
let store = GraphStore::new();
let alice = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"}));
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
let charlie = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Charlie"}));
let dave = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Dave"}));
// Alice -> Bob -> Charlie -> Dave
store.create_edge(alice, bob, "FRIEND", serde_json::json!({})).unwrap();
store.create_edge(bob, charlie, "FRIEND", serde_json::json!({})).unwrap();
store.create_edge(charlie, dave, "FRIEND", serde_json::json!({})).unwrap();
// Alice -> Charlie (shortcut)
store.create_edge(alice, charlie, "KNOWS", serde_json::json!({})).unwrap();
store
}
#[test]
fn test_basic_traversal() {
let store = setup_social_graph();
let traverser = Traverser::new(&store);
let users = store.find_nodes_by_label("User");
let alice = users.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("Alice"))).unwrap();
let query = TraversalQuery::new().depth(2);
let results = traverser.traverse(&alice.id, &query);
        // Should find Bob and Charlie at depth 1, and Dave at depth 2
        // (the visited set prevents Charlie from being reported again at depth 2)
assert!(results.len() >= 2);
}
#[test]
fn test_edge_type_filter() {
let store = setup_social_graph();
let traverser = Traverser::new(&store);
let users = store.find_nodes_by_label("User");
let alice = users.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("Alice"))).unwrap();
let query = TraversalQuery::new()
.depth(2)
.edge_types(vec!["FRIEND".to_string()]);
let results = traverser.traverse(&alice.id, &query);
// Following only FRIEND edges: Alice -> Bob -> Charlie
let names: Vec<_> = results.iter().filter_map(|r| r.node.get_property("name")).collect();
assert!(names.contains(&&serde_json::json!("Bob")));
assert!(names.contains(&&serde_json::json!("Charlie")));
}
#[test]
fn test_depth_limit() {
let store = setup_social_graph();
let traverser = Traverser::new(&store);
let users = store.find_nodes_by_label("User");
let alice = users.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("Alice"))).unwrap();
let query = TraversalQuery::new().depth(1);
let results = traverser.traverse(&alice.id, &query);
// Depth 1: only direct neighbors
for result in &results {
assert_eq!(result.depth, 1);
}
}
#[test]
fn test_result_limit() {
let store = setup_social_graph();
let traverser = Traverser::new(&store);
let users = store.find_nodes_by_label("User");
let alice = users.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("Alice"))).unwrap();
let query = TraversalQuery::new().depth(10).limit(2);
let results = traverser.traverse(&alice.id, &query);
assert!(results.len() <= 2);
}
#[test]
fn test_mutual_connections() {
let store = GraphStore::new();
let alice = store.create_node(vec![], serde_json::json!({"name": "Alice"}));
let bob = store.create_node(vec![], serde_json::json!({"name": "Bob"}));
let mutual1 = store.create_node(vec![], serde_json::json!({"name": "Mutual1"}));
let mutual2 = store.create_node(vec![], serde_json::json!({"name": "Mutual2"}));
let only_alice = store.create_node(vec![], serde_json::json!({"name": "OnlyAlice"}));
store.create_edge(alice, mutual1, "FRIEND", serde_json::json!({})).unwrap();
store.create_edge(alice, mutual2, "FRIEND", serde_json::json!({})).unwrap();
store.create_edge(alice, only_alice, "FRIEND", serde_json::json!({})).unwrap();
store.create_edge(bob, mutual1, "FRIEND", serde_json::json!({})).unwrap();
store.create_edge(bob, mutual2, "FRIEND", serde_json::json!({})).unwrap();
let traverser = Traverser::new(&store);
let mutual = traverser.mutual_connections(&alice, &bob, Some("FRIEND"));
assert_eq!(mutual.len(), 2);
}
}
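The core of `Traverser::traverse` — a BFS with a visited set, a per-node depth counter, and a depth cutoff — can be sketched over a plain adjacency map. This is a simplification for illustration only: the real traverser also applies edge-type, label, and property filters, and tracks paths and edges:

```rust
use std::collections::{HashMap, HashSet, VecDeque};

// Breadth-first traversal up to `max_depth`, returning (node, depth)
// pairs for every node other than the start, each visited once.
fn bfs(adj: &HashMap<u64, Vec<u64>>, start: u64, max_depth: usize) -> Vec<(u64, usize)> {
    let mut results = Vec::new();
    let mut visited: HashSet<u64> = HashSet::from([start]);
    let mut queue: VecDeque<(u64, usize)> = VecDeque::from([(start, 0)]);
    while let Some((node, depth)) = queue.pop_front() {
        if depth > 0 {
            results.push((node, depth));
        }
        if depth < max_depth {
            for &next in adj.get(&node).into_iter().flatten() {
                // insert() returns false if already visited.
                if visited.insert(next) {
                    queue.push_back((next, depth + 1));
                }
            }
        }
    }
    results
}

fn main() {
    // Alice(1) -> Bob(2) -> Charlie(3) -> Dave(4), plus Alice -> Charlie.
    let adj = HashMap::from([(1, vec![2, 3]), (2, vec![3]), (3, vec![4])]);
    let found = bfs(&adj, 1, 2);
    // Bob and Charlie at depth 1, Dave at depth 2; Charlie is not revisited.
    assert_eq!(found.len(), 3);
    assert!(found.contains(&(4, 2)));
    println!("ok");
}
```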


@ -45,20 +45,32 @@
pub mod document;
pub mod error;
pub mod gateway;
pub mod graph;
pub mod index;
pub mod keyvalue;
pub mod query;
pub mod replication;
pub mod schema;
pub mod sql;
pub mod timeseries;
pub mod vector;
pub use document::{Collection, Document, DocumentId, DocumentStore};
pub use error::DatabaseError;
pub use gateway::{GatewayConfig, GatewayServer};
pub use graph::{
    Direction, Edge, EdgeId, GraphError, GraphQuery, GraphQueryParser, GraphStore, Node, NodeId,
    PathFinder, PathResult, TraversalQuery, TraversalResult, Traverser,
};
pub use index::{Index, IndexConfig, IndexManager, IndexType};
pub use keyvalue::{KeyValue, KeyValueStore, KvEntry};
pub use query::{Filter, Query, QueryEngine, QueryResult, SortOrder};
pub use schema::{Field, FieldType, Schema, SchemaValidator};
pub use replication::{
    ClusterConfig, Command as RaftCommand, NodeRole, RaftConfig, RaftEvent, RaftNode, RaftState,
    ReplicatedLog,
};
pub use sql::{QueryResult as SqlQueryResult, SqlEngine, SqlParser, SqlType, SqlValue, Table, TableDef};
pub use timeseries::{DataPoint, Metric, TimeSeries, TimeSeriesStore};
pub use vector::{Embedding, SimilarityMetric, VectorIndex, VectorStore};


@ -0,0 +1,393 @@
//! Cluster configuration and peer management.
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::str::FromStr;
/// Unique identifier for a node in the cluster.
pub type NodeId = u64;
/// Address for a peer node.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct PeerAddress {
/// Host address (IP or hostname).
pub host: String,
/// Port number.
pub port: u16,
}
impl PeerAddress {
/// Creates a new peer address.
pub fn new(host: impl Into<String>, port: u16) -> Self {
Self {
host: host.into(),
port,
}
}
/// Parses from "host:port" format.
pub fn parse(s: &str) -> Option<Self> {
let parts: Vec<&str> = s.split(':').collect();
if parts.len() == 2 {
parts[1].parse().ok().map(|port| Self {
host: parts[0].to_string(),
port,
})
} else {
None
}
}
/// Converts to SocketAddr if possible.
pub fn to_socket_addr(&self) -> Option<SocketAddr> {
SocketAddr::from_str(&format!("{}:{}", self.host, self.port)).ok()
}
}
impl std::fmt::Display for PeerAddress {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}:{}", self.host, self.port)
}
}
/// Information about a peer node.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PeerInfo {
/// Node identifier.
pub id: NodeId,
/// Network address.
pub address: PeerAddress,
/// Whether this peer is a voting member.
pub voting: bool,
/// Last known state.
pub state: PeerState,
}
impl PeerInfo {
/// Creates new peer info.
pub fn new(id: NodeId, address: PeerAddress) -> Self {
Self {
id,
address,
voting: true,
state: PeerState::Unknown,
}
}
/// Creates a non-voting learner peer.
pub fn learner(id: NodeId, address: PeerAddress) -> Self {
Self {
id,
address,
voting: false,
state: PeerState::Unknown,
}
}
}
/// State of a peer from this node's perspective.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum PeerState {
/// State unknown (initial).
Unknown,
/// Peer is reachable.
Reachable,
/// Peer is unreachable.
Unreachable,
/// Peer is being probed.
Probing,
}
/// Configuration for the cluster.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ClusterConfig {
/// This node's ID.
pub node_id: NodeId,
/// This node's address.
pub address: PeerAddress,
/// Known peers in the cluster.
pub peers: HashMap<NodeId, PeerInfo>,
/// Configuration index (for joint consensus).
pub config_index: u64,
/// Whether this is a joint configuration.
pub joint: bool,
/// Old configuration (for joint consensus).
pub old_peers: Option<HashMap<NodeId, PeerInfo>>,
}
impl ClusterConfig {
/// Creates a new single-node cluster configuration.
pub fn new(node_id: NodeId, address: PeerAddress) -> Self {
Self {
node_id,
address,
peers: HashMap::new(),
config_index: 0,
joint: false,
old_peers: None,
}
}
/// Adds a peer to the cluster.
pub fn add_peer(&mut self, peer: PeerInfo) {
self.peers.insert(peer.id, peer);
}
/// Removes a peer from the cluster.
pub fn remove_peer(&mut self, id: NodeId) -> Option<PeerInfo> {
self.peers.remove(&id)
}
/// Gets a peer by ID.
pub fn get_peer(&self, id: NodeId) -> Option<&PeerInfo> {
self.peers.get(&id)
}
/// Gets a mutable reference to a peer.
pub fn get_peer_mut(&mut self, id: NodeId) -> Option<&mut PeerInfo> {
self.peers.get_mut(&id)
}
/// Returns all peer IDs.
pub fn peer_ids(&self) -> Vec<NodeId> {
self.peers.keys().copied().collect()
}
/// Returns all voting peer IDs.
pub fn voting_peers(&self) -> Vec<NodeId> {
self.peers
.iter()
.filter(|(_, p)| p.voting)
.map(|(id, _)| *id)
.collect()
}
/// Returns the total number of voting members (including self).
pub fn voting_members(&self) -> usize {
self.peers.values().filter(|p| p.voting).count() + 1
}
/// Returns the quorum size needed for consensus.
pub fn quorum_size(&self) -> usize {
self.voting_members() / 2 + 1
}
/// Checks if we have quorum with the given votes.
pub fn has_quorum(&self, votes: usize) -> bool {
votes >= self.quorum_size()
}
/// Starts a configuration change (joint consensus).
pub fn begin_config_change(&mut self, new_peers: HashMap<NodeId, PeerInfo>, index: u64) {
self.old_peers = Some(self.peers.clone());
self.peers = new_peers;
self.config_index = index;
self.joint = true;
}
/// Completes a configuration change.
pub fn complete_config_change(&mut self) {
self.old_peers = None;
self.joint = false;
}
/// Aborts a configuration change.
pub fn abort_config_change(&mut self) {
if let Some(old) = self.old_peers.take() {
self.peers = old;
}
self.joint = false;
}
/// Checks if node is in joint consensus mode.
pub fn is_joint(&self) -> bool {
self.joint
}
/// For joint consensus: checks if we have quorum in BOTH configurations.
pub fn has_joint_quorum(&self, new_votes: usize, old_votes: usize) -> bool {
if !self.joint {
return self.has_quorum(new_votes);
}
let new_quorum = self.quorum_size();
        let old_quorum = self
            .old_peers
            .as_ref()
            // +1 counts this node itself, mirroring voting_members().
            .map(|p| (p.values().filter(|peer| peer.voting).count() + 1) / 2 + 1)
            .unwrap_or(1);
new_votes >= new_quorum && old_votes >= old_quorum
}
/// Updates peer state.
pub fn update_peer_state(&mut self, id: NodeId, state: PeerState) {
if let Some(peer) = self.peers.get_mut(&id) {
peer.state = state;
}
}
/// Returns all reachable peers.
pub fn reachable_peers(&self) -> Vec<NodeId> {
self.peers
.iter()
.filter(|(_, p)| p.state == PeerState::Reachable)
.map(|(id, _)| *id)
.collect()
}
/// Serializes the configuration.
pub fn to_bytes(&self) -> Vec<u8> {
bincode::serialize(self).unwrap_or_default()
}
/// Deserializes the configuration.
pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
bincode::deserialize(bytes).ok()
}
}
impl Default for ClusterConfig {
fn default() -> Self {
Self::new(1, PeerAddress::new("127.0.0.1", 9000))
}
}
/// Builder for cluster configurations.
pub struct ClusterBuilder {
config: ClusterConfig,
}
impl ClusterBuilder {
/// Creates a new builder.
pub fn new(node_id: NodeId, address: PeerAddress) -> Self {
Self {
config: ClusterConfig::new(node_id, address),
}
}
/// Adds a peer.
pub fn with_peer(mut self, id: NodeId, address: PeerAddress) -> Self {
self.config.add_peer(PeerInfo::new(id, address));
self
}
/// Adds a learner (non-voting) peer.
pub fn with_learner(mut self, id: NodeId, address: PeerAddress) -> Self {
self.config.add_peer(PeerInfo::learner(id, address));
self
}
/// Builds the configuration.
pub fn build(self) -> ClusterConfig {
self.config
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_peer_address() {
let addr = PeerAddress::new("192.168.1.1", 9000);
assert_eq!(addr.to_string(), "192.168.1.1:9000");
let parsed = PeerAddress::parse("10.0.0.1:8080").unwrap();
assert_eq!(parsed.host, "10.0.0.1");
assert_eq!(parsed.port, 8080);
}
#[test]
fn test_cluster_config() {
let mut config = ClusterConfig::new(1, PeerAddress::new("127.0.0.1", 9000));
config.add_peer(PeerInfo::new(2, PeerAddress::new("127.0.0.1", 9001)));
config.add_peer(PeerInfo::new(3, PeerAddress::new("127.0.0.1", 9002)));
assert_eq!(config.voting_members(), 3);
assert_eq!(config.quorum_size(), 2);
assert!(config.has_quorum(2));
assert!(!config.has_quorum(1));
}
#[test]
fn test_quorum_sizes() {
// 1 node: quorum = 1
let config1 = ClusterConfig::new(1, PeerAddress::new("127.0.0.1", 9000));
assert_eq!(config1.quorum_size(), 1);
// 3 nodes: quorum = 2
let mut config3 = ClusterConfig::new(1, PeerAddress::new("127.0.0.1", 9000));
config3.add_peer(PeerInfo::new(2, PeerAddress::new("127.0.0.1", 9001)));
config3.add_peer(PeerInfo::new(3, PeerAddress::new("127.0.0.1", 9002)));
assert_eq!(config3.quorum_size(), 2);
// 5 nodes: quorum = 3
let mut config5 = config3.clone();
config5.add_peer(PeerInfo::new(4, PeerAddress::new("127.0.0.1", 9003)));
config5.add_peer(PeerInfo::new(5, PeerAddress::new("127.0.0.1", 9004)));
assert_eq!(config5.quorum_size(), 3);
}
#[test]
fn test_learner_peers() {
let mut config = ClusterConfig::new(1, PeerAddress::new("127.0.0.1", 9000));
config.add_peer(PeerInfo::new(2, PeerAddress::new("127.0.0.1", 9001)));
config.add_peer(PeerInfo::learner(3, PeerAddress::new("127.0.0.1", 9002)));
// Learners don't count toward quorum
assert_eq!(config.voting_members(), 2); // self + node 2
assert_eq!(config.quorum_size(), 2);
assert_eq!(config.voting_peers().len(), 1); // only node 2
}
#[test]
fn test_cluster_builder() {
let config = ClusterBuilder::new(1, PeerAddress::new("127.0.0.1", 9000))
.with_peer(2, PeerAddress::new("127.0.0.1", 9001))
.with_peer(3, PeerAddress::new("127.0.0.1", 9002))
.with_learner(4, PeerAddress::new("127.0.0.1", 9003))
.build();
assert_eq!(config.peer_ids().len(), 3);
assert_eq!(config.voting_members(), 3);
}
#[test]
fn test_joint_consensus() {
let mut config = ClusterConfig::new(1, PeerAddress::new("127.0.0.1", 9000));
config.add_peer(PeerInfo::new(2, PeerAddress::new("127.0.0.1", 9001)));
config.add_peer(PeerInfo::new(3, PeerAddress::new("127.0.0.1", 9002)));
// Start config change: add node 4
let mut new_peers = config.peers.clone();
new_peers.insert(4, PeerInfo::new(4, PeerAddress::new("127.0.0.1", 9003)));
config.begin_config_change(new_peers, 100);
assert!(config.is_joint());
// Need quorum in both old (2 of 3) and new (3 of 4) configs
assert!(config.has_joint_quorum(3, 2));
assert!(!config.has_joint_quorum(2, 2)); // Not enough in new config
assert!(!config.has_joint_quorum(3, 1)); // Not enough in old config
config.complete_config_change();
assert!(!config.is_joint());
assert_eq!(config.voting_members(), 4);
}
#[test]
fn test_serialization() {
let config = ClusterBuilder::new(1, PeerAddress::new("127.0.0.1", 9000))
.with_peer(2, PeerAddress::new("127.0.0.1", 9001))
.build();
let bytes = config.to_bytes();
let decoded = ClusterConfig::from_bytes(&bytes).unwrap();
assert_eq!(decoded.node_id, 1);
assert_eq!(decoded.peers.len(), 1);
}
}
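The quorum arithmetic above reduces to two rules: a single configuration needs a strict majority of voting members, and joint consensus needs that majority in both the old and new configurations independently. A self-contained restatement — `quorum` and `joint_quorum_ok` are illustrative helpers, not the crate API:

```rust
// Quorum for a cluster of `voters`: a strict majority.
fn quorum(voters: usize) -> usize {
    voters / 2 + 1
}

// Joint consensus needs a majority in BOTH the old and new configurations.
fn joint_quorum_ok(new_votes: usize, new_voters: usize, old_votes: usize, old_voters: usize) -> bool {
    new_votes >= quorum(new_voters) && old_votes >= quorum(old_voters)
}

fn main() {
    assert_eq!(quorum(1), 1);
    assert_eq!(quorum(3), 2);
    assert_eq!(quorum(5), 3);
    // Growing 3 -> 4 nodes: need 3 of 4 in the new config and 2 of 3 in the old.
    assert!(joint_quorum_ok(3, 4, 2, 3));
    assert!(!joint_quorum_ok(2, 4, 2, 3));
    println!("ok");
}
```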


@ -0,0 +1,311 @@
//! Leader election logic for Raft.
use super::log::ReplicatedLog;
use super::rpc::{RequestVote, RequestVoteResponse};
use super::state::RaftState;
use std::collections::HashSet;
/// Result of a leader election.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ElectionResult {
/// Won the election, became leader.
Won,
/// Lost the election (higher term discovered).
Lost,
/// Election timed out, will retry.
Timeout,
/// Still in progress.
InProgress,
}
/// Tracks election state.
pub struct Election {
/// Node ID of the candidate.
node_id: u64,
/// Term for this election.
term: u64,
/// Number of votes received (including self).
votes_received: HashSet<u64>,
/// Total number of nodes in cluster (including self).
cluster_size: usize,
/// Whether this election is still active.
active: bool,
}
impl Election {
/// Starts a new election.
pub fn new(node_id: u64, term: u64, cluster_size: usize) -> Self {
let mut votes = HashSet::new();
votes.insert(node_id); // Vote for self
Self {
node_id,
term,
votes_received: votes,
cluster_size,
active: true,
}
}
/// Returns the term for this election.
pub fn term(&self) -> u64 {
self.term
}
/// Checks if the election is still active.
pub fn is_active(&self) -> bool {
self.active
}
/// Gets the number of votes received.
pub fn vote_count(&self) -> usize {
self.votes_received.len()
}
/// Gets the majority threshold.
pub fn majority(&self) -> usize {
(self.cluster_size / 2) + 1
}
/// Records a vote from a peer.
pub fn record_vote(&mut self, peer_id: u64, granted: bool) -> ElectionResult {
        if !self.active {
            // Election already concluded; report its final outcome.
            return self.result();
        }
if granted {
self.votes_received.insert(peer_id);
if self.votes_received.len() >= self.majority() {
self.active = false;
return ElectionResult::Won;
}
}
ElectionResult::InProgress
}
/// Cancels the election (e.g., discovered higher term).
pub fn cancel(&mut self) {
self.active = false;
}
/// Creates a RequestVote message for this election.
pub fn create_request(&self, log: &ReplicatedLog) -> RequestVote {
RequestVote::new(
self.term,
self.node_id,
log.last_index(),
log.last_term(),
)
}
/// Checks the current result of the election.
    pub fn result(&self) -> ElectionResult {
        // A won election is also marked inactive, so check the vote
        // count before `active`; otherwise a win would report as Lost.
        if self.votes_received.len() >= self.majority() {
            ElectionResult::Won
        } else if !self.active {
            ElectionResult::Lost
        } else {
            ElectionResult::InProgress
        }
    }
}
/// Handles vote requests from candidates.
pub struct VoteHandler;
impl VoteHandler {
/// Processes a vote request and returns whether to grant the vote.
pub fn handle_request(
state: &mut RaftState,
log: &ReplicatedLog,
request: &RequestVote,
) -> RequestVoteResponse {
// If request's term < current term, deny
if request.term < state.current_term {
return RequestVoteResponse::deny(state.current_term);
}
// If request's term > current term, update term and become follower
if request.term > state.current_term {
state.become_follower(request.term);
}
// Check if we can grant vote
let can_vote = state.voted_for.is_none() || state.voted_for == Some(request.candidate_id);
// Check if candidate's log is at least as up-to-date as ours
let log_ok = log.is_up_to_date(request.last_log_index, request.last_log_term);
if can_vote && log_ok {
state.voted_for = Some(request.candidate_id);
RequestVoteResponse::grant(state.current_term)
} else {
RequestVoteResponse::deny(state.current_term)
}
}
/// Processes a vote response.
pub fn handle_response(
state: &mut RaftState,
election: &mut Election,
from_peer: u64,
response: &RequestVoteResponse,
) -> ElectionResult {
// If response's term > current term, become follower
if response.term > state.current_term {
state.become_follower(response.term);
election.cancel();
return ElectionResult::Lost;
}
// Record the vote
election.record_vote(from_peer, response.vote_granted)
}
}
/// Election timeout generator.
pub struct ElectionTimeout {
/// Minimum timeout in milliseconds.
min_ms: u64,
/// Maximum timeout in milliseconds.
max_ms: u64,
}
impl ElectionTimeout {
/// Creates a new timeout generator.
pub fn new(min_ms: u64, max_ms: u64) -> Self {
Self { min_ms, max_ms }
}
/// Generates a random timeout duration.
pub fn random_timeout(&self) -> std::time::Duration {
use std::time::{SystemTime, UNIX_EPOCH};
// Simple pseudo-random based on current time
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_nanos();
        // Guard against min == max, which would otherwise divide by zero.
        let range = (self.max_ms - self.min_ms).max(1);
        let random_add = (now % range as u128) as u64;
let timeout_ms = self.min_ms + random_add;
std::time::Duration::from_millis(timeout_ms)
}
/// Returns the minimum timeout.
pub fn min_timeout(&self) -> std::time::Duration {
std::time::Duration::from_millis(self.min_ms)
}
/// Returns the maximum timeout.
pub fn max_timeout(&self) -> std::time::Duration {
std::time::Duration::from_millis(self.max_ms)
}
}
impl Default for ElectionTimeout {
fn default() -> Self {
// Default Raft election timeout: 150-300ms
Self::new(150, 300)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::replication::state::Command;
use crate::replication::log::LogEntry;
#[test]
fn test_election_basic() {
let election = Election::new(1, 1, 5);
assert!(election.is_active());
assert_eq!(election.vote_count(), 1); // Self vote
assert_eq!(election.majority(), 3);
}
#[test]
fn test_election_win() {
let mut election = Election::new(1, 1, 5);
// Need 3 votes total (including self)
assert_eq!(election.record_vote(2, true), ElectionResult::InProgress);
assert_eq!(election.record_vote(3, true), ElectionResult::Won);
}
#[test]
fn test_election_rejected_votes() {
let mut election = Election::new(1, 1, 5);
// Rejected votes don't count
assert_eq!(election.record_vote(2, false), ElectionResult::InProgress);
assert_eq!(election.record_vote(3, false), ElectionResult::InProgress);
assert_eq!(election.vote_count(), 1);
}
#[test]
fn test_vote_handler_grant() {
let mut state = RaftState::new();
let log = ReplicatedLog::new();
let request = RequestVote::new(1, 2, 0, 0);
let response = VoteHandler::handle_request(&mut state, &log, &request);
assert!(response.vote_granted);
assert_eq!(state.voted_for, Some(2));
}
#[test]
fn test_vote_handler_deny_old_term() {
let mut state = RaftState::new();
state.current_term = 5;
let log = ReplicatedLog::new();
let request = RequestVote::new(3, 2, 10, 3);
let response = VoteHandler::handle_request(&mut state, &log, &request);
assert!(!response.vote_granted);
assert_eq!(response.term, 5);
}
#[test]
fn test_vote_handler_deny_already_voted() {
let mut state = RaftState::new();
state.current_term = 1; // Same term as request
state.voted_for = Some(3); // Already voted for node 3 in this term
let log = ReplicatedLog::new();
// Request from node 2 for the same term - should be denied
let request = RequestVote::new(1, 2, 0, 0);
let response = VoteHandler::handle_request(&mut state, &log, &request);
assert!(!response.vote_granted);
}
#[test]
fn test_vote_handler_deny_log_behind() {
let mut state = RaftState::new();
let log = ReplicatedLog::new();
log.append(LogEntry::new(2, 1, Command::Noop)); // Our log has term 2
let request = RequestVote::new(2, 2, 10, 1); // Candidate has term 1 entries
let response = VoteHandler::handle_request(&mut state, &log, &request);
assert!(!response.vote_granted);
}
#[test]
fn test_election_timeout() {
let timeout = ElectionTimeout::new(150, 300);
for _ in 0..10 {
let duration = timeout.random_timeout();
assert!(duration >= std::time::Duration::from_millis(150));
assert!(duration <= std::time::Duration::from_millis(300));
}
}
}
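The grant/deny decision in `VoteHandler::handle_request` comes down to three checks from the Raft vote rule: the candidate's term is not stale, this node has not already voted for someone else in that term, and the candidate's log is at least as up-to-date (higher last term wins; equal terms compare by index). A self-contained restatement — `grant_vote` is a hypothetical helper that omits the state mutation (term bump, persisting `voted_for`) the real handler performs:

```rust
// Raft vote rule: grant iff the candidate's term is current, we have
// not voted for a different candidate this term, and the candidate's
// log is at least as up-to-date as ours.
fn grant_vote(
    current_term: u64,
    voted_for: Option<u64>,
    our_last_term: u64,
    our_last_index: u64,
    candidate: u64,
    req_term: u64,
    req_last_term: u64,
    req_last_index: u64,
) -> bool {
    if req_term < current_term {
        return false; // Stale candidate.
    }
    let can_vote = voted_for.is_none() || voted_for == Some(candidate);
    // "Up-to-date": higher last term wins; equal terms compare by length.
    let log_ok = req_last_term > our_last_term
        || (req_last_term == our_last_term && req_last_index >= our_last_index);
    can_vote && log_ok
}

fn main() {
    // Fresh follower grants to a current candidate.
    assert!(grant_vote(1, None, 0, 0, 2, 1, 0, 0));
    // Already voted for node 3 this term: deny node 2.
    assert!(!grant_vote(1, Some(3), 0, 0, 2, 1, 0, 0));
    // Candidate's log is behind (lower last term): deny despite more entries.
    assert!(!grant_vote(2, None, 2, 1, 2, 2, 1, 10));
    println!("ok");
}
```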


@ -0,0 +1,387 @@
//! Replicated log for Raft consensus.
use super::state::Command;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
/// A single entry in the replicated log.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LogEntry {
/// Term when entry was received by leader.
pub term: u64,
/// Position in the log (1-indexed).
pub index: u64,
/// Command to execute.
pub command: Command,
/// Timestamp when entry was created.
pub timestamp: u64,
}
impl LogEntry {
/// Creates a new log entry.
pub fn new(term: u64, index: u64, command: Command) -> Self {
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as u64;
Self {
term,
index,
command,
timestamp,
}
}
/// Serializes to bytes.
pub fn to_bytes(&self) -> Vec<u8> {
bincode::serialize(self).unwrap_or_default()
}
/// Deserializes from bytes.
pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
bincode::deserialize(bytes).ok()
}
}
/// Replicated log storing all commands.
pub struct ReplicatedLog {
/// Log entries.
entries: RwLock<Vec<LogEntry>>,
/// Index of first entry in the log (for log compaction).
start_index: RwLock<u64>,
/// Term of last included entry (for snapshots).
snapshot_term: RwLock<u64>,
}
impl ReplicatedLog {
/// Creates a new empty log.
pub fn new() -> Self {
Self {
entries: RwLock::new(Vec::new()),
start_index: RwLock::new(1), // 1-indexed
snapshot_term: RwLock::new(0),
}
}
/// Returns the index of the last entry.
pub fn last_index(&self) -> u64 {
let entries = self.entries.read();
let start = *self.start_index.read();
if entries.is_empty() {
start.saturating_sub(1)
} else {
start + entries.len() as u64 - 1
}
}
/// Returns the term of the last entry.
pub fn last_term(&self) -> u64 {
let entries = self.entries.read();
if entries.is_empty() {
*self.snapshot_term.read()
} else {
entries.last().map(|e| e.term).unwrap_or(0)
}
}
/// Returns the term at a given index.
pub fn term_at(&self, index: u64) -> Option<u64> {
let start = *self.start_index.read();
if index < start {
return if index == start - 1 {
Some(*self.snapshot_term.read())
} else {
None
};
}
let entries = self.entries.read();
let offset = (index - start) as usize;
entries.get(offset).map(|e| e.term)
}
/// Gets an entry by index.
pub fn get(&self, index: u64) -> Option<LogEntry> {
let start = *self.start_index.read();
if index < start {
return None;
}
let entries = self.entries.read();
let offset = (index - start) as usize;
entries.get(offset).cloned()
}
/// Gets entries from start_index to end_index (inclusive).
pub fn get_range(&self, start_idx: u64, end_idx: u64) -> Vec<LogEntry> {
let start = *self.start_index.read();
if start_idx > end_idx || start_idx < start {
return Vec::new();
}
let entries = self.entries.read();
let start_offset = (start_idx - start) as usize;
let end_offset = (end_idx - start + 1) as usize;
entries
.get(start_offset..end_offset.min(entries.len()))
.map(|s| s.to_vec())
.unwrap_or_default()
}
/// Gets all entries from a given index.
pub fn entries_from(&self, from_index: u64) -> Vec<LogEntry> {
let start = *self.start_index.read();
if from_index < start {
return self.entries.read().clone();
}
let entries = self.entries.read();
let offset = (from_index - start) as usize;
entries.get(offset..).map(|s| s.to_vec()).unwrap_or_default()
}
/// Appends an entry to the log.
pub fn append(&self, entry: LogEntry) -> u64 {
let mut entries = self.entries.write();
let index = entry.index;
entries.push(entry);
index
}
/// Appends multiple entries, potentially overwriting conflicting entries.
pub fn append_entries(&self, prev_index: u64, prev_term: u64, new_entries: Vec<LogEntry>) -> bool {
// Check that prev entry matches
if prev_index > 0 {
if let Some(prev_entry_term) = self.term_at(prev_index) {
if prev_entry_term != prev_term {
// Conflict - need to truncate
return false;
}
} else if prev_index >= *self.start_index.read() {
// Missing entry
return false;
}
}
let mut entries = self.entries.write();
let start = *self.start_index.read();
for entry in new_entries {
if entry.index < start {
// Entry is already covered by a snapshot; skip it to avoid index underflow.
continue;
}
let offset = (entry.index - start) as usize;
if offset < entries.len() {
// Check for conflict
if entries[offset].term != entry.term {
// Delete this and all following entries
entries.truncate(offset);
entries.push(entry);
}
// Otherwise entry already exists with same term, skip
} else {
entries.push(entry);
}
}
true
}
/// Truncates the log after the given index.
pub fn truncate_after(&self, index: u64) {
let start = *self.start_index.read();
if index < start {
return;
}
let mut entries = self.entries.write();
let offset = (index - start + 1) as usize;
entries.truncate(offset);
}
/// Returns the number of entries.
pub fn len(&self) -> usize {
self.entries.read().len()
}
/// Returns true if the log is empty.
pub fn is_empty(&self) -> bool {
self.entries.read().is_empty()
}
/// Compacts the log up to (and including) the given index.
pub fn compact(&self, up_to_index: u64, up_to_term: u64) {
let mut entries = self.entries.write();
let mut start = self.start_index.write();
let mut snapshot_term = self.snapshot_term.write();
if up_to_index < *start {
return;
}
let remove_count = (up_to_index - *start + 1) as usize;
if remove_count >= entries.len() {
entries.clear();
} else {
entries.drain(0..remove_count);
}
*start = up_to_index + 1;
*snapshot_term = up_to_term;
}
/// Checks whether the candidate's log is at least as up-to-date as this log
/// (compare last terms first, then last indices). Used when deciding to grant a vote.
pub fn is_up_to_date(&self, candidate_last_index: u64, candidate_last_term: u64) -> bool {
let last_term = self.last_term();
let last_index = self.last_index();
// Compare terms first, then indices
if last_term != candidate_last_term {
last_term <= candidate_last_term
} else {
last_index <= candidate_last_index
}
}
/// Creates entries for replication starting from a given index.
pub fn entries_for_replication(&self, from_index: u64, max_entries: usize) -> (u64, u64, Vec<LogEntry>) {
let prev_index = from_index.saturating_sub(1);
let prev_term = self.term_at(prev_index).unwrap_or(0);
let entries = self.entries_from(from_index);
let limited = if entries.len() > max_entries {
entries[..max_entries].to_vec()
} else {
entries
};
(prev_index, prev_term, limited)
}
}
impl Default for ReplicatedLog {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_log_basic() {
let log = ReplicatedLog::new();
assert_eq!(log.last_index(), 0);
assert_eq!(log.last_term(), 0);
assert!(log.is_empty());
let entry = LogEntry::new(1, 1, Command::Noop);
log.append(entry);
assert_eq!(log.last_index(), 1);
assert_eq!(log.last_term(), 1);
assert!(!log.is_empty());
}
#[test]
fn test_log_append() {
let log = ReplicatedLog::new();
for i in 1..=5 {
log.append(LogEntry::new(1, i, Command::Noop));
}
assert_eq!(log.len(), 5);
assert_eq!(log.last_index(), 5);
let entry = log.get(3).unwrap();
assert_eq!(entry.index, 3);
}
#[test]
fn test_log_range() {
let log = ReplicatedLog::new();
for i in 1..=10 {
log.append(LogEntry::new(1, i, Command::Noop));
}
let range = log.get_range(3, 7);
assert_eq!(range.len(), 5);
assert_eq!(range[0].index, 3);
assert_eq!(range[4].index, 7);
}
#[test]
fn test_log_compact() {
let log = ReplicatedLog::new();
for i in 1..=10 {
log.append(LogEntry::new(1, i, Command::Noop));
}
log.compact(5, 1);
assert_eq!(log.len(), 5);
assert!(log.get(5).is_none());
assert!(log.get(6).is_some());
assert_eq!(log.last_index(), 10);
}
#[test]
fn test_log_up_to_date() {
let log = ReplicatedLog::new();
log.append(LogEntry::new(1, 1, Command::Noop));
log.append(LogEntry::new(2, 2, Command::Noop));
// Same log is up to date
assert!(log.is_up_to_date(2, 2));
// Higher term is more up to date
assert!(log.is_up_to_date(1, 3));
// Same term, higher index is more up to date
assert!(log.is_up_to_date(3, 2));
// Lower term is not up to date
assert!(!log.is_up_to_date(10, 1));
}
#[test]
fn test_append_entries() {
let log = ReplicatedLog::new();
// Initial entries
log.append(LogEntry::new(1, 1, Command::Noop));
log.append(LogEntry::new(1, 2, Command::Noop));
// Append more entries
let new_entries = vec![
LogEntry::new(1, 3, Command::Noop),
LogEntry::new(2, 4, Command::Noop),
];
let success = log.append_entries(2, 1, new_entries);
assert!(success);
assert_eq!(log.len(), 4);
assert_eq!(log.term_at(4), Some(2));
}
#[test]
fn test_entries_for_replication() {
let log = ReplicatedLog::new();
for i in 1..=5 {
log.append(LogEntry::new(1, i, Command::Noop));
}
let (prev_idx, prev_term, entries) = log.entries_for_replication(3, 10);
assert_eq!(prev_idx, 2);
assert_eq!(prev_term, 1);
assert_eq!(entries.len(), 3);
}
}
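The trickiest part of `ReplicatedLog` is the index arithmetic after compaction: the log stays 1-indexed while `start_index` moves forward, so entry `i` lives at vector offset `i - start_index`. A minimal standalone model of that mapping (`MiniLog` is a sketch for illustration, not a type from this crate), mirroring the behavior checked in `test_log_compact`:

```rust
/// Minimal model of ReplicatedLog's index arithmetic:
/// the entry with log index `i` lives at vector offset `i - start_index`.
struct MiniLog {
    start_index: u64, // log index of terms[0]
    terms: Vec<u64>,  // term of each retained entry
}

impl MiniLog {
    fn last_index(&self) -> u64 {
        if self.terms.is_empty() {
            self.start_index.saturating_sub(1)
        } else {
            self.start_index + self.terms.len() as u64 - 1
        }
    }
    fn term_at(&self, index: u64) -> Option<u64> {
        if index < self.start_index {
            return None; // covered by the snapshot
        }
        self.terms.get((index - self.start_index) as usize).copied()
    }
    fn compact(&mut self, up_to_index: u64) {
        if up_to_index < self.start_index {
            return; // nothing to remove
        }
        let remove = (up_to_index + 1 - self.start_index) as usize;
        self.terms.drain(0..remove.min(self.terms.len()));
        self.start_index = up_to_index + 1;
    }
}

fn main() {
    // Indices 1..=10, all in term 1; compact through index 5.
    let mut log = MiniLog { start_index: 1, terms: vec![1; 10] };
    log.compact(5);
    assert_eq!(log.start_index, 6);
    assert!(log.term_at(5).is_none()); // now only in the snapshot
    assert_eq!(log.term_at(6), Some(1));
    assert_eq!(log.last_index(), 10); // last index is unaffected
}
```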


@ -0,0 +1,23 @@
//! Raft consensus-based replication for high availability.
//!
//! Provides distributed consensus to ensure data consistency across
//! multiple database nodes with automatic leader election and failover.
pub mod cluster;
pub mod election;
pub mod log;
pub mod raft;
pub mod rpc;
pub mod snapshot;
pub mod state;
pub use cluster::{ClusterBuilder, ClusterConfig, NodeId, PeerAddress, PeerInfo, PeerState};
pub use election::{Election, ElectionResult, ElectionTimeout, VoteHandler};
pub use log::{LogEntry, ReplicatedLog};
pub use raft::{ApplyResult, RaftConfig, RaftEvent, RaftNode};
pub use rpc::{
AppendEntries, AppendEntriesResponse, InstallSnapshot, InstallSnapshotResponse, RequestVote,
RequestVoteResponse, RpcMessage,
};
pub use snapshot::{Snapshot, SnapshotConfig, SnapshotManager, SnapshotMetadata};
pub use state::{Command, LeaderState, NodeRole, RaftState};


@ -0,0 +1,955 @@
//! Raft consensus implementation.
use super::cluster::{ClusterConfig, NodeId, PeerState};
use super::election::{Election, ElectionResult, ElectionTimeout, VoteHandler};
use super::log::{LogEntry, ReplicatedLog};
use super::rpc::{
AppendEntries, AppendEntriesResponse, InstallSnapshot, InstallSnapshotResponse, RequestVote,
RequestVoteResponse, RpcMessage,
};
use super::snapshot::{Snapshot, SnapshotConfig, SnapshotManager};
use super::state::{Command, LeaderState, NodeRole, RaftState};
use serde::{Deserialize, Serialize};
use std::time::{Duration, Instant};
/// Configuration for the Raft node.
#[derive(Clone, Debug)]
pub struct RaftConfig {
/// Minimum election timeout in milliseconds.
pub election_timeout_min: u64,
/// Maximum election timeout in milliseconds.
pub election_timeout_max: u64,
/// Heartbeat interval in milliseconds.
pub heartbeat_interval: u64,
/// Maximum entries per AppendEntries RPC.
pub max_entries_per_rpc: usize,
/// Snapshot threshold (entries before compaction).
pub snapshot_threshold: u64,
/// Maximum snapshot chunk size.
pub snapshot_chunk_size: usize,
}
impl Default for RaftConfig {
fn default() -> Self {
Self {
election_timeout_min: 150,
election_timeout_max: 300,
heartbeat_interval: 50,
max_entries_per_rpc: 100,
snapshot_threshold: 10000,
snapshot_chunk_size: 65536,
}
}
}
/// Result of applying a command.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum ApplyResult {
/// Command applied successfully.
Success(Vec<u8>),
/// Command failed.
Error(String),
/// Not the leader, redirect to leader.
NotLeader(Option<NodeId>),
}
/// Events that can be produced by the Raft node.
#[derive(Clone, Debug)]
pub enum RaftEvent {
/// Send RPC to a peer.
SendRpc(NodeId, RpcMessage),
/// Broadcast RPC to all peers.
BroadcastRpc(RpcMessage),
/// Apply committed entry to state machine.
ApplyEntry(u64, Command),
/// Became leader.
BecameLeader,
/// Became follower.
BecameFollower(Option<NodeId>),
/// Snapshot should be taken.
TakeSnapshot,
/// Log compacted up to index.
LogCompacted(u64),
}
/// The Raft consensus node.
pub struct RaftNode {
/// Node ID.
id: NodeId,
/// Cluster configuration.
cluster: ClusterConfig,
/// Raft configuration.
config: RaftConfig,
/// Persistent state.
state: RaftState,
/// Replicated log.
log: ReplicatedLog,
/// Current election (only during candidacy).
election: Option<Election>,
/// Election timeout generator.
election_timeout: ElectionTimeout,
/// Snapshot manager.
snapshots: SnapshotManager,
/// Leader state (only valid when leader).
leader_state: Option<LeaderState>,
/// Current known leader.
leader_id: Option<NodeId>,
/// Last heartbeat/message from leader.
last_leader_contact: Instant,
/// Current election timeout duration.
current_timeout: Duration,
/// Pending events.
events: Vec<RaftEvent>,
}
impl RaftNode {
/// Creates a new Raft node.
pub fn new(id: NodeId, cluster: ClusterConfig, config: RaftConfig) -> Self {
let election_timeout =
ElectionTimeout::new(config.election_timeout_min, config.election_timeout_max);
let current_timeout = election_timeout.random_timeout();
Self {
id,
cluster,
state: RaftState::new(),
log: ReplicatedLog::new(),
election: None,
election_timeout,
snapshots: SnapshotManager::new(config.snapshot_threshold),
leader_state: None,
leader_id: None,
last_leader_contact: Instant::now(),
current_timeout,
events: Vec::new(),
config,
}
}
/// Returns the node ID.
pub fn id(&self) -> NodeId {
self.id
}
/// Returns the current term.
pub fn current_term(&self) -> u64 {
self.state.current_term
}
/// Returns the current role.
pub fn role(&self) -> NodeRole {
self.state.role
}
/// Returns the current leader ID if known.
pub fn leader(&self) -> Option<NodeId> {
self.leader_id
}
/// Returns true if this node is the leader.
pub fn is_leader(&self) -> bool {
self.state.is_leader()
}
/// Returns the commit index.
pub fn commit_index(&self) -> u64 {
self.state.commit_index
}
/// Returns the last applied index.
pub fn last_applied(&self) -> u64 {
self.state.last_applied
}
/// Returns the log length.
pub fn log_length(&self) -> u64 {
self.log.last_index()
}
/// Drains pending events.
pub fn drain_events(&mut self) -> Vec<RaftEvent> {
std::mem::take(&mut self.events)
}
/// Called periodically to drive the Raft state machine.
pub fn tick(&mut self) {
match self.state.role {
NodeRole::Leader => self.tick_leader(),
NodeRole::Follower => self.tick_follower(),
NodeRole::Candidate => self.tick_candidate(),
}
// Apply committed entries
self.apply_committed_entries();
// Check if snapshot needed
if self
.snapshots
.should_snapshot(self.log.last_index(), self.snapshots.last_included_index())
{
self.events.push(RaftEvent::TakeSnapshot);
}
}
fn tick_leader(&mut self) {
// Send heartbeats to all peers; tick() is expected to be driven at heartbeat_interval
self.send_heartbeats();
}
fn tick_follower(&mut self) {
// Check for election timeout
if self.last_leader_contact.elapsed() >= self.current_timeout {
self.start_election();
}
}
fn tick_candidate(&mut self) {
// Check for election timeout
if self.last_leader_contact.elapsed() >= self.current_timeout {
// Start a new election
self.start_election();
}
}
fn start_election(&mut self) {
// Increment term and become candidate
self.state.become_candidate();
self.state.voted_for = Some(self.id);
self.leader_id = None;
// Reset timeout
self.reset_election_timeout();
// Create new election
let cluster_size = self.cluster.voting_members();
self.election = Some(Election::new(self.id, self.state.current_term, cluster_size));
// Create RequestVote message
let request = RequestVote::new(
self.state.current_term,
self.id,
self.log.last_index(),
self.log.last_term(),
);
self.events
.push(RaftEvent::BroadcastRpc(RpcMessage::RequestVote(request)));
// Check if we already have quorum (single-node cluster)
if cluster_size == 1 {
self.become_leader();
}
}
fn become_leader(&mut self) {
self.state.become_leader();
self.leader_id = Some(self.id);
self.election = None;
// Initialize leader state
let peer_ids: Vec<_> = self.cluster.peer_ids();
self.leader_state = Some(LeaderState::new(self.log.last_index(), &peer_ids));
self.events.push(RaftEvent::BecameLeader);
// Send immediate heartbeats
self.send_heartbeats();
}
fn become_follower(&mut self, term: u64, leader: Option<NodeId>) {
self.state.become_follower(term);
self.leader_id = leader;
self.leader_state = None;
self.election = None;
self.reset_election_timeout();
self.events.push(RaftEvent::BecameFollower(leader));
}
fn reset_election_timeout(&mut self) {
self.last_leader_contact = Instant::now();
self.current_timeout = self.election_timeout.random_timeout();
}
fn send_heartbeats(&mut self) {
if !self.is_leader() {
return;
}
for peer_id in self.cluster.peer_ids() {
self.send_append_entries(peer_id);
}
}
fn send_append_entries(&mut self, peer_id: NodeId) {
let leader_state = match &self.leader_state {
Some(ls) => ls,
None => return,
};
let next_index = *leader_state.next_index.get(&peer_id).unwrap_or(&1);
// Check if we need to send snapshot instead
if next_index <= self.snapshots.last_included_index() {
self.send_install_snapshot(peer_id);
return;
}
let (prev_log_index, prev_log_term, entries) =
self.log
.entries_for_replication(next_index, self.config.max_entries_per_rpc);
let request = AppendEntries::with_entries(
self.state.current_term,
self.id,
prev_log_index,
prev_log_term,
entries,
self.state.commit_index,
);
self.events
.push(RaftEvent::SendRpc(peer_id, RpcMessage::AppendEntries(request)));
}
fn send_install_snapshot(&mut self, peer_id: NodeId) {
let snapshot = match self.snapshots.get_snapshot() {
Some(s) => s,
None => return,
};
let chunks = self
.snapshots
.chunk_snapshot(self.config.snapshot_chunk_size);
if let Some((offset, data, done)) = chunks.into_iter().next() {
let request = InstallSnapshot::new(
self.state.current_term,
self.id,
snapshot.metadata.last_included_index,
snapshot.metadata.last_included_term,
offset,
data,
done,
);
self.events
.push(RaftEvent::SendRpc(peer_id, RpcMessage::InstallSnapshot(request)));
}
}
/// Handles an incoming RPC message.
pub fn handle_rpc(&mut self, from: NodeId, message: RpcMessage) -> Option<RpcMessage> {
match message {
RpcMessage::RequestVote(req) => {
let response = self.handle_request_vote(from, req);
Some(RpcMessage::RequestVoteResponse(response))
}
RpcMessage::RequestVoteResponse(resp) => {
self.handle_request_vote_response(from, resp);
None
}
RpcMessage::AppendEntries(req) => {
let response = self.handle_append_entries(from, req);
Some(RpcMessage::AppendEntriesResponse(response))
}
RpcMessage::AppendEntriesResponse(resp) => {
self.handle_append_entries_response(from, resp);
None
}
RpcMessage::InstallSnapshot(req) => {
let response = self.handle_install_snapshot(from, req);
Some(RpcMessage::InstallSnapshotResponse(response))
}
RpcMessage::InstallSnapshotResponse(resp) => {
self.handle_install_snapshot_response(from, resp);
None
}
}
}
fn handle_request_vote(&mut self, _from: NodeId, req: RequestVote) -> RequestVoteResponse {
// Use the VoteHandler from election module
VoteHandler::handle_request(&mut self.state, &self.log, &req)
}
fn handle_request_vote_response(&mut self, from: NodeId, resp: RequestVoteResponse) {
// Ignore if not candidate
if !self.state.is_candidate() {
return;
}
// Use the VoteHandler
if let Some(ref mut election) = self.election {
let result = VoteHandler::handle_response(&mut self.state, election, from, &resp);
match result {
ElectionResult::Won => {
self.become_leader();
}
ElectionResult::Lost => {
// Already handled by VoteHandler (became follower)
self.election = None;
}
ElectionResult::InProgress | ElectionResult::Timeout => {}
}
}
}
fn handle_append_entries(&mut self, _from: NodeId, req: AppendEntries) -> AppendEntriesResponse {
// Rule: If term > currentTerm, become follower
if req.term > self.state.current_term {
self.become_follower(req.term, Some(req.leader_id));
}
// Reject if term is old
if req.term < self.state.current_term {
return AppendEntriesResponse::failure(self.state.current_term);
}
// Valid AppendEntries from leader - reset election timeout
self.reset_election_timeout();
self.leader_id = Some(req.leader_id);
// If we're candidate, step down
if self.state.is_candidate() {
self.become_follower(req.term, Some(req.leader_id));
}
// Try to append entries
let success =
self.log
.append_entries(req.prev_log_index, req.prev_log_term, req.entries);
if success {
// Update commit index
if req.leader_commit > self.state.commit_index {
self.state
.update_commit_index(std::cmp::min(req.leader_commit, self.log.last_index()));
}
AppendEntriesResponse::success(self.state.current_term, self.log.last_index())
} else {
// Find conflict info for faster recovery
if let Some(conflict_term) = self.log.term_at(req.prev_log_index) {
// Find first index of conflicting term
let mut conflict_index = req.prev_log_index;
while conflict_index > 1 {
if let Some(term) = self.log.term_at(conflict_index - 1) {
if term != conflict_term {
break;
}
} else {
break;
}
conflict_index -= 1;
}
AppendEntriesResponse::conflict(self.state.current_term, conflict_term, conflict_index)
} else {
AppendEntriesResponse::failure(self.state.current_term)
}
}
}
fn handle_append_entries_response(&mut self, from: NodeId, resp: AppendEntriesResponse) {
// Ignore if not leader
if !self.is_leader() {
return;
}
// Step down if term is higher
if resp.term > self.state.current_term {
self.become_follower(resp.term, None);
return;
}
let leader_state = match &mut self.leader_state {
Some(ls) => ls,
None => return,
};
if resp.success {
// Update match_index and next_index
leader_state.update_indices(from, resp.match_index);
// Try to advance commit index
self.try_advance_commit_index();
} else {
// Use conflict info if available for faster recovery
if let (Some(conflict_term), Some(conflict_index)) =
(resp.conflict_term, resp.conflict_index)
{
// Search for last entry with conflict_term
let mut new_next = conflict_index;
for idx in (1..=self.log.last_index()).rev() {
if let Some(term) = self.log.term_at(idx) {
if term == conflict_term {
new_next = idx + 1;
break;
}
if term < conflict_term {
break;
}
}
}
leader_state.next_index.insert(from, new_next);
} else {
// Simple decrement
leader_state.decrement_next_index(from);
}
}
// Update peer state
self.cluster.update_peer_state(from, PeerState::Reachable);
}
fn handle_install_snapshot(&mut self, _from: NodeId, req: InstallSnapshot) -> InstallSnapshotResponse {
// Rule: If term > currentTerm, become follower
if req.term > self.state.current_term {
self.become_follower(req.term, Some(req.leader_id));
}
// Reject if term is old
if req.term < self.state.current_term {
return InstallSnapshotResponse::new(self.state.current_term);
}
// Reset election timeout
self.reset_election_timeout();
self.leader_id = Some(req.leader_id);
// Start or continue receiving snapshot
if req.offset == 0 {
self.snapshots.start_receiving(
req.last_included_index,
req.last_included_term,
req.offset,
req.data,
);
} else {
self.snapshots.add_chunk(req.offset, req.data);
}
// If done, finalize snapshot
if req.done {
if let Some(_snapshot) = self.snapshots.finalize_snapshot() {
// Discard log up to snapshot
self.log
.compact(req.last_included_index, req.last_included_term);
// Update state
if req.last_included_index > self.state.commit_index {
self.state.update_commit_index(req.last_included_index);
}
if req.last_included_index > self.state.last_applied {
self.state.update_last_applied(req.last_included_index);
}
self.events
.push(RaftEvent::LogCompacted(req.last_included_index));
}
}
InstallSnapshotResponse::new(self.state.current_term)
}
fn handle_install_snapshot_response(&mut self, from: NodeId, resp: InstallSnapshotResponse) {
if !self.is_leader() {
return;
}
if resp.term > self.state.current_term {
self.become_follower(resp.term, None);
return;
}
// Update next_index for peer (assuming success since there's no success field)
if let Some(leader_state) = &mut self.leader_state {
let snapshot_index = self.snapshots.last_included_index();
leader_state.update_indices(from, snapshot_index);
}
self.cluster.update_peer_state(from, PeerState::Reachable);
}
fn try_advance_commit_index(&mut self) {
if !self.is_leader() {
return;
}
let leader_state = match &self.leader_state {
Some(ls) => ls,
None => return,
};
// Calculate new commit index
let new_commit = leader_state.calculate_commit_index(
self.state.commit_index,
self.state.current_term,
|idx| self.log.term_at(idx),
);
if new_commit > self.state.commit_index {
self.state.update_commit_index(new_commit);
}
}
fn apply_committed_entries(&mut self) {
while self.state.last_applied < self.state.commit_index {
let next_to_apply = self.state.last_applied + 1;
if let Some(entry) = self.log.get(next_to_apply) {
self.events
.push(RaftEvent::ApplyEntry(entry.index, entry.command.clone()));
self.state.update_last_applied(next_to_apply);
} else {
break;
}
}
}
/// Proposes a command (only valid on leader).
pub fn propose(&mut self, command: Command) -> Result<u64, ApplyResult> {
if !self.is_leader() {
return Err(ApplyResult::NotLeader(self.leader_id));
}
// Create log entry
let index = self.log.last_index() + 1;
let entry = LogEntry::new(self.state.current_term, index, command);
let result_index = self.log.append(entry);
// Send AppendEntries to all peers
for peer_id in self.cluster.peer_ids() {
self.send_append_entries(peer_id);
}
Ok(result_index)
}
/// Takes a snapshot of the state machine.
pub fn take_snapshot(&mut self, data: Vec<u8>) {
let last_index = self.state.last_applied;
let last_term = self.log.term_at(last_index).unwrap_or(0);
let config = SnapshotConfig {
nodes: std::iter::once(self.id)
.chain(self.cluster.peer_ids())
.collect(),
};
self.snapshots
.create_snapshot(last_index, last_term, config, data);
// Compact log
self.log.compact(last_index, last_term);
self.events.push(RaftEvent::LogCompacted(last_index));
}
/// Adds a new server to the cluster.
pub fn add_server(
&mut self,
id: NodeId,
address: super::cluster::PeerAddress,
) -> Result<(), String> {
if !self.is_leader() {
return Err("Not leader".to_string());
}
let peer = super::cluster::PeerInfo::new(id, address);
self.cluster.add_peer(peer);
// Initialize leader state for new peer
if let Some(ref mut ls) = self.leader_state {
ls.next_index.insert(id, self.log.last_index() + 1);
ls.match_index.insert(id, 0);
}
Ok(())
}
/// Removes a server from the cluster.
pub fn remove_server(&mut self, id: NodeId) -> Result<(), String> {
if !self.is_leader() {
return Err("Not leader".to_string());
}
self.cluster.remove_peer(id);
if let Some(ref mut ls) = self.leader_state {
ls.next_index.remove(&id);
ls.match_index.remove(&id);
}
Ok(())
}
/// Forces an election timeout (for testing).
#[cfg(test)]
pub fn force_election_timeout(&mut self) {
self.last_leader_contact = Instant::now() - self.current_timeout - Duration::from_secs(1);
}
}
#[cfg(test)]
mod tests {
use super::*;
use super::super::cluster::PeerAddress;
fn create_test_cluster(node_id: NodeId, peers: &[NodeId]) -> ClusterConfig {
let mut cluster =
ClusterConfig::new(node_id, PeerAddress::new("127.0.0.1", 9000 + node_id as u16));
for &peer in peers {
cluster.add_peer(super::super::cluster::PeerInfo::new(
peer,
PeerAddress::new("127.0.0.1", 9000 + peer as u16),
));
}
cluster
}
#[test]
fn test_single_node_election() {
let cluster = create_test_cluster(1, &[]);
let config = RaftConfig::default();
let mut node = RaftNode::new(1, cluster, config);
assert!(matches!(node.role(), NodeRole::Follower));
// Simulate election timeout
node.force_election_timeout();
node.tick();
// Single node should become leader immediately
assert!(node.is_leader());
}
#[test]
fn test_three_node_election() {
let cluster = create_test_cluster(1, &[2, 3]);
let config = RaftConfig::default();
let mut node1 = RaftNode::new(1, cluster, config);
// Trigger election
node1.force_election_timeout();
node1.tick();
assert!(matches!(node1.role(), NodeRole::Candidate));
assert_eq!(node1.current_term(), 1);
// Simulate receiving votes
let vote_resp = RequestVoteResponse::grant(1);
node1.handle_rpc(2, RpcMessage::RequestVoteResponse(vote_resp.clone()));
// Should become leader after receiving vote from node 2 (2 votes = quorum of 3)
assert!(node1.is_leader());
}
#[test]
fn test_append_entries() {
let cluster = create_test_cluster(1, &[]);
let config = RaftConfig::default();
let mut leader = RaftNode::new(1, cluster, config);
// Become leader
leader.force_election_timeout();
leader.tick();
assert!(leader.is_leader());
// Propose a command
let command = Command::KvSet {
key: "test".to_string(),
value: vec![1, 2, 3],
ttl: None,
};
let index = leader.propose(command).unwrap();
assert_eq!(index, 1);
assert_eq!(leader.log_length(), 1);
}
#[test]
fn test_follower_receives_append_entries() {
let cluster1 = create_test_cluster(1, &[2]);
let cluster2 = create_test_cluster(2, &[1]);
let config = RaftConfig::default();
let mut leader = RaftNode::new(1, cluster1, config.clone());
let mut follower = RaftNode::new(2, cluster2, config);
// Make node 1 leader
leader.force_election_timeout();
leader.tick();
// Simulate vote from node 2
leader.handle_rpc(
2,
RpcMessage::RequestVoteResponse(RequestVoteResponse::grant(1)),
);
assert!(leader.is_leader());
// Propose a command
leader
.propose(Command::KvSet {
key: "key1".to_string(),
value: vec![1],
ttl: None,
})
.unwrap();
// Get AppendEntries from leader events (skip heartbeats, find one with entries)
let events = leader.drain_events();
let append_req = events
.iter()
.find_map(|e| {
if let RaftEvent::SendRpc(2, RpcMessage::AppendEntries(req)) = e {
// Skip heartbeats (empty entries), find the one with actual entries
if !req.entries.is_empty() {
Some(req.clone())
} else {
None
}
} else {
None
}
})
.unwrap();
// Send to follower
let response = follower
.handle_rpc(1, RpcMessage::AppendEntries(append_req))
.unwrap();
if let RpcMessage::AppendEntriesResponse(resp) = response {
assert!(resp.success);
assert_eq!(resp.match_index, 1);
}
assert_eq!(follower.log_length(), 1);
}
#[test]
fn test_step_down_on_higher_term() {
let cluster = create_test_cluster(1, &[2]);
let config = RaftConfig::default();
let mut node = RaftNode::new(1, cluster, config);
// Make it leader
node.force_election_timeout();
node.tick();
node.handle_rpc(
2,
RpcMessage::RequestVoteResponse(RequestVoteResponse::grant(1)),
);
assert!(node.is_leader());
assert_eq!(node.current_term(), 1);
// Receive AppendEntries with higher term
node.handle_rpc(
2,
RpcMessage::AppendEntries(AppendEntries::heartbeat(5, 2, 0, 0, 0)),
);
assert!(!node.is_leader());
assert_eq!(node.current_term(), 5);
assert_eq!(node.leader(), Some(2));
}
#[test]
fn test_commit_index_advancement() {
let cluster = create_test_cluster(1, &[2, 3]);
let config = RaftConfig::default();
let mut leader = RaftNode::new(1, cluster, config);
// Become leader
leader.force_election_timeout();
leader.tick();
leader.handle_rpc(
2,
RpcMessage::RequestVoteResponse(RequestVoteResponse::grant(1)),
);
leader.handle_rpc(
3,
RpcMessage::RequestVoteResponse(RequestVoteResponse::grant(1)),
);
// Propose command
leader
.propose(Command::KvSet {
key: "key".to_string(),
value: vec![1],
ttl: None,
})
.unwrap();
// Simulate successful replication to node 2
leader.handle_rpc(
2,
RpcMessage::AppendEntriesResponse(AppendEntriesResponse::success(1, 1)),
);
// Commit index should advance (quorum = 2, we have leader + node2)
assert_eq!(leader.commit_index(), 1);
}
#[test]
fn test_snapshot() {
let cluster = create_test_cluster(1, &[]);
let config = RaftConfig::default();
let mut node = RaftNode::new(1, cluster, config);
// Become leader
node.force_election_timeout();
node.tick();
// Propose and commit some entries
for i in 0..5 {
node.propose(Command::KvSet {
key: format!("key{}", i),
value: vec![i as u8],
ttl: None,
})
.unwrap();
}
// Manually advance commit and apply
node.state.update_commit_index(5);
node.tick();
// Take snapshot
node.take_snapshot(vec![1, 2, 3, 4, 5]);
// Check snapshot was created
assert!(node.snapshots.get_snapshot().is_some());
}
#[test]
fn test_request_vote_log_check() {
let cluster = create_test_cluster(1, &[2]);
let config = RaftConfig::default();
let mut node = RaftNode::new(1, cluster, config);
// Add some entries to node's log
node.log.append(LogEntry::new(1, 1, Command::Noop));
node.log.append(LogEntry::new(2, 2, Command::Noop));
// Request vote from candidate with shorter log (lower term)
let req = RequestVote::new(3, 2, 1, 1);
let resp = node.handle_rpc(2, RpcMessage::RequestVote(req));
if let Some(RpcMessage::RequestVoteResponse(r)) = resp {
// Should not grant vote - our log is more up-to-date (term 2 > term 1)
assert!(!r.vote_granted);
}
// Request vote from candidate with equal or better log
let req2 = RequestVote::new(4, 2, 2, 2);
let resp2 = node.handle_rpc(2, RpcMessage::RequestVote(req2));
if let Some(RpcMessage::RequestVoteResponse(r)) = resp2 {
assert!(r.vote_granted);
}
}
}
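`try_advance_commit_index` above delegates to `LeaderState::calculate_commit_index`, whose body is not part of this diff. The standard Raft rule it presumably implements: the leader may commit index N once a majority of the cluster (counting the leader itself) has `match_index >= N`, and only for entries from the leader's own term. A hedged, self-contained sketch of that rule (`advance_commit_index` is an illustrative name, not this crate's API):

```rust
/// Returns the highest index replicated on a quorum whose entry is from
/// `current_term`, or `commit_index` if none qualifies. `match_indices`
/// covers the peers only; the leader implicitly matches its own last index.
fn advance_commit_index(
    commit_index: u64,
    current_term: u64,
    leader_last_index: u64,
    match_indices: &[u64],
    term_at: impl Fn(u64) -> Option<u64>,
) -> u64 {
    let cluster_size = match_indices.len() + 1; // peers + leader
    let quorum = cluster_size / 2 + 1;
    let mut best = commit_index;
    for n in (commit_index + 1)..=leader_last_index {
        let replicas = 1 + match_indices.iter().filter(|&&m| m >= n).count();
        // Only entries from the leader's current term may be committed directly.
        if replicas >= quorum && term_at(n) == Some(current_term) {
            best = n;
        }
    }
    best
}

fn main() {
    // 3-node cluster: leader holds entries 1..=3 in term 2; peers match 2 and 1.
    let terms = [2u64, 2, 2];
    let term_at = |i: u64| terms.get(i as usize - 1).copied();
    let new_commit = advance_commit_index(0, 2, 3, &[2, 1], term_at);
    assert_eq!(new_commit, 2); // index 2 is on {leader, peer 1}: a quorum of 2/3
}
```

This matches `test_commit_index_advancement` above: one successful response carrying `match_index = 1` gives the entry a 2-of-3 quorum, so the commit index advances to 1.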


@ -0,0 +1,318 @@
//! RPC messages for Raft consensus.
use super::log::LogEntry;
use serde::{Deserialize, Serialize};
/// All RPC message types in Raft.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum RpcMessage {
/// Request vote from candidate.
RequestVote(RequestVote),
/// Response to vote request.
RequestVoteResponse(RequestVoteResponse),
/// Append entries from leader.
AppendEntries(AppendEntries),
/// Response to append entries.
AppendEntriesResponse(AppendEntriesResponse),
/// Install snapshot from leader.
InstallSnapshot(InstallSnapshot),
/// Response to install snapshot.
InstallSnapshotResponse(InstallSnapshotResponse),
}
impl RpcMessage {
/// Returns the term of this message.
pub fn term(&self) -> u64 {
match self {
RpcMessage::RequestVote(r) => r.term,
RpcMessage::RequestVoteResponse(r) => r.term,
RpcMessage::AppendEntries(r) => r.term,
RpcMessage::AppendEntriesResponse(r) => r.term,
RpcMessage::InstallSnapshot(r) => r.term,
RpcMessage::InstallSnapshotResponse(r) => r.term,
}
}
/// Serializes to bytes.
pub fn to_bytes(&self) -> Vec<u8> {
bincode::serialize(self).unwrap_or_default()
}
/// Deserializes from bytes.
pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
bincode::deserialize(bytes).ok()
}
}
/// RequestVote RPC (sent by candidates to gather votes).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RequestVote {
/// Candidate's term.
pub term: u64,
/// Candidate requesting vote.
pub candidate_id: u64,
/// Index of candidate's last log entry.
pub last_log_index: u64,
/// Term of candidate's last log entry.
pub last_log_term: u64,
}
impl RequestVote {
/// Creates a new RequestVote message.
pub fn new(term: u64, candidate_id: u64, last_log_index: u64, last_log_term: u64) -> Self {
Self {
term,
candidate_id,
last_log_index,
last_log_term,
}
}
}
/// Response to RequestVote.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RequestVoteResponse {
/// Current term (for candidate to update).
pub term: u64,
/// True if candidate received vote.
pub vote_granted: bool,
}
impl RequestVoteResponse {
/// Creates a positive response.
pub fn grant(term: u64) -> Self {
Self {
term,
vote_granted: true,
}
}
/// Creates a negative response.
pub fn deny(term: u64) -> Self {
Self {
term,
vote_granted: false,
}
}
}
/// AppendEntries RPC (sent by leader for replication and heartbeat).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AppendEntries {
/// Leader's term.
pub term: u64,
/// Leader ID (so follower can redirect clients).
pub leader_id: u64,
/// Index of log entry immediately preceding new ones.
pub prev_log_index: u64,
/// Term of prev_log_index entry.
pub prev_log_term: u64,
/// Log entries to store (empty for heartbeat).
pub entries: Vec<LogEntry>,
/// Leader's commit index.
pub leader_commit: u64,
}
impl AppendEntries {
/// Creates a heartbeat (empty entries).
pub fn heartbeat(
term: u64,
leader_id: u64,
prev_log_index: u64,
prev_log_term: u64,
leader_commit: u64,
) -> Self {
Self {
term,
leader_id,
prev_log_index,
prev_log_term,
entries: Vec::new(),
leader_commit,
}
}
/// Creates an append entries request with entries.
pub fn with_entries(
term: u64,
leader_id: u64,
prev_log_index: u64,
prev_log_term: u64,
entries: Vec<LogEntry>,
leader_commit: u64,
) -> Self {
Self {
term,
leader_id,
prev_log_index,
prev_log_term,
entries,
leader_commit,
}
}
/// Returns true if this is a heartbeat (no entries).
pub fn is_heartbeat(&self) -> bool {
self.entries.is_empty()
}
}
/// Response to AppendEntries.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AppendEntriesResponse {
/// Current term (for leader to update).
pub term: u64,
/// True if follower contained entry matching prev_log_index and prev_log_term.
pub success: bool,
/// Index of last entry appended (for quick catch-up).
pub match_index: u64,
    /// On failure, the term of the conflicting entry (enables fast log backtracking).
    pub conflict_term: Option<u64>,
    /// On failure, the first index the follower holds for the conflicting term.
    pub conflict_index: Option<u64>,
}
impl AppendEntriesResponse {
/// Creates a success response.
pub fn success(term: u64, match_index: u64) -> Self {
Self {
term,
success: true,
match_index,
conflict_term: None,
conflict_index: None,
}
}
/// Creates a failure response.
pub fn failure(term: u64) -> Self {
Self {
term,
success: false,
match_index: 0,
conflict_term: None,
conflict_index: None,
}
}
/// Creates a failure response with conflict info.
pub fn conflict(term: u64, conflict_term: u64, conflict_index: u64) -> Self {
Self {
term,
success: false,
match_index: 0,
conflict_term: Some(conflict_term),
conflict_index: Some(conflict_index),
}
}
}
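When a follower rejects AppendEntries with conflict info, the leader can skip `next_index` back past the entire conflicting term instead of decrementing one entry at a time. A minimal standalone sketch of that backtracking rule (the `next_index_after_conflict` helper and its slice-based log are illustrative, not part of this crate):

```rust
/// Pick the index to retry AppendEntries from, given the follower's
/// conflict report. `leader_terms[i]` holds the term of the entry at
/// 1-based log index i + 1.
fn next_index_after_conflict(
    leader_terms: &[u64],
    conflict_term: Option<u64>,
    conflict_index: u64,
) -> u64 {
    if let Some(t) = conflict_term {
        // If the leader also has entries from the conflicting term, retry
        // from just past its last entry of that term.
        if let Some(pos) = leader_terms.iter().rposition(|&term| term == t) {
            return pos as u64 + 2; // +1 to make it 1-based, +1 for "just past"
        }
    }
    // Otherwise jump straight to the follower's first index of that term.
    conflict_index.max(1)
}
```

Compared with `decrement_next_index`, this converges in one round trip per divergent term rather than one per divergent entry.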
/// InstallSnapshot RPC (sent by leader when follower is too far behind).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct InstallSnapshot {
/// Leader's term.
pub term: u64,
/// Leader ID.
pub leader_id: u64,
/// Index of last entry included in snapshot.
pub last_included_index: u64,
/// Term of last entry included in snapshot.
pub last_included_term: u64,
/// Byte offset where chunk is positioned.
pub offset: u64,
/// Raw bytes of snapshot chunk.
pub data: Vec<u8>,
/// True if this is the last chunk.
pub done: bool,
}
impl InstallSnapshot {
/// Creates a new snapshot installation request.
pub fn new(
term: u64,
leader_id: u64,
last_included_index: u64,
last_included_term: u64,
offset: u64,
data: Vec<u8>,
done: bool,
) -> Self {
Self {
term,
leader_id,
last_included_index,
last_included_term,
offset,
data,
done,
}
}
}
/// Response to InstallSnapshot.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct InstallSnapshotResponse {
/// Current term (for leader to update).
pub term: u64,
}
impl InstallSnapshotResponse {
/// Creates a new response.
pub fn new(term: u64) -> Self {
Self { term }
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::replication::state::Command;
#[test]
fn test_request_vote() {
let request = RequestVote::new(1, 1, 10, 1);
assert_eq!(request.term, 1);
assert_eq!(request.candidate_id, 1);
let grant = RequestVoteResponse::grant(1);
assert!(grant.vote_granted);
let deny = RequestVoteResponse::deny(2);
assert!(!deny.vote_granted);
assert_eq!(deny.term, 2);
}
#[test]
fn test_append_entries() {
let heartbeat = AppendEntries::heartbeat(1, 1, 0, 0, 0);
assert!(heartbeat.is_heartbeat());
let entries = vec![LogEntry::new(1, 1, Command::Noop)];
let append = AppendEntries::with_entries(1, 1, 0, 0, entries, 0);
assert!(!append.is_heartbeat());
}
#[test]
fn test_rpc_message_serialization() {
let request = RpcMessage::RequestVote(RequestVote::new(1, 1, 10, 1));
let bytes = request.to_bytes();
let decoded = RpcMessage::from_bytes(&bytes).unwrap();
assert_eq!(decoded.term(), 1);
}
#[test]
fn test_append_entries_response() {
let success = AppendEntriesResponse::success(1, 10);
assert!(success.success);
assert_eq!(success.match_index, 10);
let failure = AppendEntriesResponse::failure(2);
assert!(!failure.success);
let conflict = AppendEntriesResponse::conflict(2, 1, 5);
assert!(!conflict.success);
assert_eq!(conflict.conflict_term, Some(1));
assert_eq!(conflict.conflict_index, Some(5));
}
}


@ -0,0 +1,301 @@
//! Log compaction and snapshots.
use serde::{Deserialize, Serialize};
/// Metadata for a snapshot.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SnapshotMetadata {
/// Index of last entry included in snapshot.
pub last_included_index: u64,
/// Term of last entry included in snapshot.
pub last_included_term: u64,
/// Cluster configuration at snapshot time.
pub config: SnapshotConfig,
/// Size of snapshot data in bytes.
pub size: u64,
/// Timestamp when snapshot was created.
pub created_at: u64,
}
/// Cluster configuration stored in snapshot.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SnapshotConfig {
/// Node IDs in the cluster.
pub nodes: Vec<u64>,
}
impl Default for SnapshotConfig {
fn default() -> Self {
Self { nodes: Vec::new() }
}
}
/// A complete snapshot of the state machine.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Snapshot {
/// Snapshot metadata.
pub metadata: SnapshotMetadata,
/// Snapshot data (serialized state machine).
pub data: Vec<u8>,
}
impl Snapshot {
/// Creates a new snapshot.
pub fn new(
last_included_index: u64,
last_included_term: u64,
config: SnapshotConfig,
data: Vec<u8>,
) -> Self {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as u64;
Self {
metadata: SnapshotMetadata {
last_included_index,
last_included_term,
config,
size: data.len() as u64,
created_at: now,
},
data,
}
}
/// Serializes the snapshot to bytes.
pub fn to_bytes(&self) -> Vec<u8> {
bincode::serialize(self).unwrap_or_default()
}
/// Deserializes from bytes.
pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
bincode::deserialize(bytes).ok()
}
}
/// Manages snapshot creation and storage.
pub struct SnapshotManager {
/// Threshold for log entries before snapshotting.
snapshot_threshold: u64,
/// Current snapshot (if any).
current_snapshot: Option<Snapshot>,
/// Pending snapshot being received from leader.
pending_snapshot: Option<PendingSnapshot>,
}
/// Snapshot being received in chunks.
struct PendingSnapshot {
metadata: SnapshotMetadata,
chunks: Vec<Vec<u8>>,
expected_offset: u64,
}
impl SnapshotManager {
/// Creates a new snapshot manager.
pub fn new(snapshot_threshold: u64) -> Self {
Self {
snapshot_threshold,
current_snapshot: None,
pending_snapshot: None,
}
}
/// Returns the threshold for triggering snapshots.
pub fn threshold(&self) -> u64 {
self.snapshot_threshold
}
/// Checks if a snapshot should be taken.
    pub fn should_snapshot(&self, log_size: u64, last_snapshot_index: u64) -> bool {
        // saturating_sub guards against underflow if the caller passes a stale index
        log_size.saturating_sub(last_snapshot_index) >= self.snapshot_threshold
    }
/// Gets the current snapshot.
pub fn get_snapshot(&self) -> Option<&Snapshot> {
self.current_snapshot.as_ref()
}
/// Gets the last included index of the current snapshot.
pub fn last_included_index(&self) -> u64 {
self.current_snapshot
.as_ref()
.map(|s| s.metadata.last_included_index)
.unwrap_or(0)
}
/// Gets the last included term of the current snapshot.
pub fn last_included_term(&self) -> u64 {
self.current_snapshot
.as_ref()
.map(|s| s.metadata.last_included_term)
.unwrap_or(0)
}
/// Creates a new snapshot.
pub fn create_snapshot(
&mut self,
last_included_index: u64,
last_included_term: u64,
config: SnapshotConfig,
data: Vec<u8>,
) {
let snapshot = Snapshot::new(last_included_index, last_included_term, config, data);
self.current_snapshot = Some(snapshot);
}
/// Starts receiving a snapshot from leader.
pub fn start_receiving(
&mut self,
last_included_index: u64,
last_included_term: u64,
offset: u64,
data: Vec<u8>,
) -> bool {
if offset != 0 {
// First chunk should be at offset 0
return false;
}
self.pending_snapshot = Some(PendingSnapshot {
metadata: SnapshotMetadata {
last_included_index,
last_included_term,
config: SnapshotConfig::default(),
size: 0,
created_at: 0,
},
chunks: vec![data],
expected_offset: 0,
});
true
}
/// Adds a chunk to the pending snapshot.
    pub fn add_chunk(&mut self, offset: u64, data: Vec<u8>) -> bool {
        if let Some(ref mut pending) = self.pending_snapshot {
            // The next chunk must start exactly where the bytes received so far end.
            let received: u64 = pending.chunks.iter().map(|c| c.len() as u64).sum();
            if offset == pending.expected_offset + received {
                pending.chunks.push(data);
                return true;
            }
        }
        false
    }
/// Finalizes the pending snapshot.
pub fn finalize_snapshot(&mut self) -> Option<Snapshot> {
if let Some(pending) = self.pending_snapshot.take() {
let data: Vec<u8> = pending.chunks.into_iter().flatten().collect();
let snapshot = Snapshot::new(
pending.metadata.last_included_index,
pending.metadata.last_included_term,
pending.metadata.config,
data,
);
self.current_snapshot = Some(snapshot.clone());
return Some(snapshot);
}
None
}
/// Cancels receiving a snapshot.
pub fn cancel_receiving(&mut self) {
self.pending_snapshot = None;
}
/// Splits snapshot into chunks for transmission.
pub fn chunk_snapshot(&self, chunk_size: usize) -> Vec<(u64, Vec<u8>, bool)> {
let mut chunks = Vec::new();
if let Some(ref snapshot) = self.current_snapshot {
let data = &snapshot.data;
let mut offset = 0;
while offset < data.len() {
let end = (offset + chunk_size).min(data.len());
let chunk = data[offset..end].to_vec();
let done = end == data.len();
chunks.push((offset as u64, chunk, done));
offset = end;
}
}
chunks
}
}
impl Default for SnapshotManager {
fn default() -> Self {
Self::new(10000) // Default: snapshot every 10k entries
}
}
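The chunk/reassemble round trip behind `chunk_snapshot` and `add_chunk` can be exercised in isolation. A standalone sketch using the same (offset, bytes, done) wire format (the `chunk` and `reassemble` helpers are illustrative, and assume a non-zero chunk size):

```rust
/// Split a snapshot into (offset, bytes, done) triples for transmission.
fn chunk(data: &[u8], chunk_size: usize) -> Vec<(u64, Vec<u8>, bool)> {
    let mut out = Vec::new();
    let mut offset = 0;
    while offset < data.len() {
        let end = (offset + chunk_size).min(data.len());
        out.push((offset as u64, data[offset..end].to_vec(), end == data.len()));
        offset = end;
    }
    out
}

/// Reassemble chunks on the receiving side, rejecting gaps or duplicates.
fn reassemble(chunks: &[(u64, Vec<u8>, bool)]) -> Option<Vec<u8>> {
    let mut data = Vec::new();
    for (offset, bytes, _done) in chunks {
        if *offset != data.len() as u64 {
            return None; // out-of-order chunk: cancel and restart the transfer
        }
        data.extend_from_slice(bytes);
    }
    Some(data)
}
```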
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_snapshot_creation() {
let snapshot = Snapshot::new(100, 5, SnapshotConfig::default(), vec![1, 2, 3, 4, 5]);
assert_eq!(snapshot.metadata.last_included_index, 100);
assert_eq!(snapshot.metadata.last_included_term, 5);
assert_eq!(snapshot.metadata.size, 5);
}
#[test]
fn test_snapshot_serialization() {
let snapshot = Snapshot::new(100, 5, SnapshotConfig::default(), vec![1, 2, 3, 4, 5]);
let bytes = snapshot.to_bytes();
let decoded = Snapshot::from_bytes(&bytes).unwrap();
assert_eq!(decoded.metadata.last_included_index, 100);
assert_eq!(decoded.data, vec![1, 2, 3, 4, 5]);
}
#[test]
fn test_snapshot_manager() {
let mut manager = SnapshotManager::new(100);
assert!(manager.should_snapshot(150, 0));
assert!(!manager.should_snapshot(50, 0));
manager.create_snapshot(100, 5, SnapshotConfig::default(), vec![1, 2, 3]);
assert_eq!(manager.last_included_index(), 100);
assert_eq!(manager.last_included_term(), 5);
}
#[test]
fn test_chunking() {
let mut manager = SnapshotManager::new(100);
manager.create_snapshot(100, 5, SnapshotConfig::default(), vec![0; 250]);
let chunks = manager.chunk_snapshot(100);
assert_eq!(chunks.len(), 3);
assert_eq!(chunks[0].0, 0);
assert_eq!(chunks[0].1.len(), 100);
assert!(!chunks[0].2); // not done
assert_eq!(chunks[2].0, 200);
assert_eq!(chunks[2].1.len(), 50);
assert!(chunks[2].2); // done
}
#[test]
fn test_receiving_snapshot() {
let mut manager = SnapshotManager::new(100);
// Start receiving
assert!(manager.start_receiving(100, 5, 0, vec![1, 2, 3]));
// Add more chunks
assert!(manager.add_chunk(3, vec![4, 5, 6]));
// Finalize
let snapshot = manager.finalize_snapshot().unwrap();
assert_eq!(snapshot.data, vec![1, 2, 3, 4, 5, 6]);
}
}


@ -0,0 +1,345 @@
//! Raft node state and commands.
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
/// Role of a node in the Raft cluster.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeRole {
/// Leader: handles all client requests and replicates log entries.
Leader,
/// Follower: passive, responds to RPCs from leader and candidates.
Follower,
/// Candidate: actively trying to become leader.
Candidate,
}
impl Default for NodeRole {
fn default() -> Self {
NodeRole::Follower
}
}
impl std::fmt::Display for NodeRole {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
NodeRole::Leader => write!(f, "Leader"),
NodeRole::Follower => write!(f, "Follower"),
NodeRole::Candidate => write!(f, "Candidate"),
}
}
}
/// Persistent state on all servers.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RaftState {
/// Current term (increases monotonically).
pub current_term: u64,
/// Candidate that received vote in current term.
pub voted_for: Option<u64>,
/// Current role of this node.
pub role: NodeRole,
/// Index of highest log entry known to be committed.
pub commit_index: u64,
/// Index of highest log entry applied to state machine.
pub last_applied: u64,
}
impl Default for RaftState {
fn default() -> Self {
Self {
current_term: 0,
voted_for: None,
role: NodeRole::Follower,
commit_index: 0,
last_applied: 0,
}
}
}
impl RaftState {
/// Creates a new Raft state.
pub fn new() -> Self {
Self::default()
}
/// Transitions to follower state.
pub fn become_follower(&mut self, term: u64) {
self.current_term = term;
self.role = NodeRole::Follower;
self.voted_for = None;
}
/// Transitions to candidate state.
    pub fn become_candidate(&mut self) {
        self.current_term += 1;
        self.role = NodeRole::Candidate;
        // Clear the previous term's vote; the caller records the self-vote,
        // since RaftState does not know its own node ID.
        self.voted_for = None;
    }
/// Transitions to leader state.
pub fn become_leader(&mut self) {
self.role = NodeRole::Leader;
}
/// Updates commit index if new value is higher.
pub fn update_commit_index(&mut self, new_commit: u64) {
if new_commit > self.commit_index {
self.commit_index = new_commit;
}
}
/// Updates last applied index.
pub fn update_last_applied(&mut self, index: u64) {
self.last_applied = index;
}
/// Returns true if this node is the leader.
pub fn is_leader(&self) -> bool {
self.role == NodeRole::Leader
}
/// Returns true if this node is a follower.
pub fn is_follower(&self) -> bool {
self.role == NodeRole::Follower
}
/// Returns true if this node is a candidate.
pub fn is_candidate(&self) -> bool {
self.role == NodeRole::Candidate
}
}
/// Commands that can be replicated through Raft.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum Command {
/// No operation (used for heartbeats and new leader commit).
Noop,
// Key-Value operations
/// Set a key-value pair.
KvSet { key: String, value: Vec<u8>, ttl: Option<u64> },
/// Delete a key.
KvDelete { key: String },
// Document operations
/// Insert a document.
DocInsert { collection: String, document: JsonValue },
/// Update a document.
DocUpdate { collection: String, id: String, update: JsonValue },
/// Delete a document.
DocDelete { collection: String, id: String },
// Vector operations
/// Insert a vector.
VectorInsert { namespace: String, id: String, vector: Vec<f32>, metadata: JsonValue },
/// Delete a vector.
VectorDelete { namespace: String, id: String },
// Time-series operations
/// Record a metric data point.
TimeSeriesRecord { metric: String, value: f64, timestamp: u64, tags: JsonValue },
// Graph operations
/// Create a graph node.
GraphNodeCreate { labels: Vec<String>, properties: JsonValue },
/// Delete a graph node.
GraphNodeDelete { id: String },
/// Create a graph edge.
GraphEdgeCreate {
source: String,
target: String,
edge_type: String,
properties: JsonValue,
},
/// Delete a graph edge.
GraphEdgeDelete { id: String },
// SQL operations
/// Execute a SQL statement.
SqlExecute { sql: String },
// Schema operations
/// Create a collection/table.
CreateCollection { name: String, schema: Option<JsonValue> },
/// Drop a collection/table.
DropCollection { name: String },
// Index operations
/// Create an index.
CreateIndex { collection: String, field: String, index_type: String },
/// Drop an index.
DropIndex { name: String },
// Configuration changes
/// Add a node to the cluster.
AddNode { node_id: u64, address: String },
/// Remove a node from the cluster.
RemoveNode { node_id: u64 },
}
impl Command {
/// Returns a descriptive name for this command.
pub fn name(&self) -> &'static str {
match self {
Command::Noop => "noop",
Command::KvSet { .. } => "kv_set",
Command::KvDelete { .. } => "kv_delete",
Command::DocInsert { .. } => "doc_insert",
Command::DocUpdate { .. } => "doc_update",
Command::DocDelete { .. } => "doc_delete",
Command::VectorInsert { .. } => "vector_insert",
Command::VectorDelete { .. } => "vector_delete",
Command::TimeSeriesRecord { .. } => "timeseries_record",
Command::GraphNodeCreate { .. } => "graph_node_create",
Command::GraphNodeDelete { .. } => "graph_node_delete",
Command::GraphEdgeCreate { .. } => "graph_edge_create",
Command::GraphEdgeDelete { .. } => "graph_edge_delete",
Command::SqlExecute { .. } => "sql_execute",
Command::CreateCollection { .. } => "create_collection",
Command::DropCollection { .. } => "drop_collection",
Command::CreateIndex { .. } => "create_index",
Command::DropIndex { .. } => "drop_index",
Command::AddNode { .. } => "add_node",
Command::RemoveNode { .. } => "remove_node",
}
}
/// Returns true if this is a read-only command (can be served by followers).
pub fn is_read_only(&self) -> bool {
matches!(self, Command::Noop)
}
/// Serializes the command to bytes.
pub fn to_bytes(&self) -> Vec<u8> {
bincode::serialize(self).unwrap_or_default()
}
/// Deserializes from bytes.
pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
bincode::deserialize(bytes).ok()
}
}
/// Leader state (volatile, only on leaders).
#[derive(Clone, Debug, Default)]
pub struct LeaderState {
/// For each server, index of next log entry to send.
pub next_index: std::collections::HashMap<u64, u64>,
/// For each server, index of highest log entry known to be replicated.
pub match_index: std::collections::HashMap<u64, u64>,
}
impl LeaderState {
/// Creates a new leader state.
pub fn new(last_log_index: u64, peer_ids: &[u64]) -> Self {
let mut next_index = std::collections::HashMap::new();
let mut match_index = std::collections::HashMap::new();
for &peer_id in peer_ids {
// Initialize nextIndex to leader's last log index + 1
next_index.insert(peer_id, last_log_index + 1);
// Initialize matchIndex to 0
match_index.insert(peer_id, 0);
}
Self {
next_index,
match_index,
}
}
/// Updates next_index for a peer after failed append.
pub fn decrement_next_index(&mut self, peer_id: u64) {
if let Some(idx) = self.next_index.get_mut(&peer_id) {
if *idx > 1 {
*idx -= 1;
}
}
}
/// Updates indices after successful append.
pub fn update_indices(&mut self, peer_id: u64, last_index: u64) {
self.next_index.insert(peer_id, last_index + 1);
self.match_index.insert(peer_id, last_index);
}
/// Calculates the new commit index based on majority replication.
    pub fn calculate_commit_index(
        &self,
        current_commit: u64,
        current_term: u64,
        log_term_at: impl Fn(u64) -> Option<u64>,
    ) -> u64 {
// Find the highest index that a majority have replicated
let mut indices: Vec<u64> = self.match_index.values().cloned().collect();
indices.sort_unstable();
indices.reverse();
        // Majority of the full cluster (peers + leader): floor(n / 2) + 1
        let majority = (indices.len() + 1) / 2 + 1;
for &index in &indices {
if index > current_commit {
// Only commit entries from current term
if let Some(term) = log_term_at(index) {
if term == current_term {
// Check if majority have this index
let count = indices.iter().filter(|&&i| i >= index).count() + 1; // +1 for leader
if count >= majority {
return index;
}
}
}
}
}
current_commit
}
}
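The quorum rule behind `calculate_commit_index` is: an index may commit once floor(n/2) + 1 of the n cluster members (leader included) hold it, and its entry is from the current term. A standalone sketch of that rule (the `majority_commit` helper is illustrative, not part of this crate):

```rust
/// Largest index replicated on a majority of the cluster whose entry
/// is from the current term; otherwise the commit index stays put.
fn majority_commit(
    match_index: &[u64], // highest replicated index, one entry per peer
    current_commit: u64,
    current_term: u64,
    term_at: impl Fn(u64) -> Option<u64>,
) -> u64 {
    let cluster_size = match_index.len() + 1; // peers plus the leader
    let majority = cluster_size / 2 + 1;
    let mut best = current_commit;
    for &idx in match_index {
        if idx > best && term_at(idx) == Some(current_term) {
            // The leader always holds its own entries, hence the +1.
            let replicas = 1 + match_index.iter().filter(|&&m| m >= idx).count();
            if replicas >= majority {
                best = idx;
            }
        }
    }
    best
}
```

The current-term check is the Raft safety rule (§5.4.2 of the paper): entries from earlier terms commit only indirectly, once a current-term entry above them commits.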
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_node_role() {
let mut state = RaftState::new();
assert!(state.is_follower());
state.become_candidate();
assert!(state.is_candidate());
assert_eq!(state.current_term, 1);
state.become_leader();
assert!(state.is_leader());
state.become_follower(5);
assert!(state.is_follower());
assert_eq!(state.current_term, 5);
}
#[test]
fn test_command_serialization() {
let cmd = Command::KvSet {
key: "test".to_string(),
value: vec![1, 2, 3],
ttl: Some(3600),
};
let bytes = cmd.to_bytes();
let decoded = Command::from_bytes(&bytes).unwrap();
if let Command::KvSet { key, value, ttl } = decoded {
assert_eq!(key, "test");
assert_eq!(value, vec![1, 2, 3]);
assert_eq!(ttl, Some(3600));
} else {
panic!("Wrong command type");
}
}
#[test]
fn test_leader_state() {
let peers = vec![2, 3, 4];
let state = LeaderState::new(10, &peers);
assert_eq!(state.next_index.get(&2), Some(&11));
assert_eq!(state.match_index.get(&2), Some(&0));
}
}


@ -0,0 +1,898 @@
//! SQL query executor.
use super::parser::{
BinaryOp, JoinType, ParsedExpr, ParsedOrderBy, ParsedSelect, ParsedSelectItem,
ParsedStatement, SqlParser,
};
use super::row::{Row, RowBuilder, RowId};
use super::table::{ColumnDef, Table, TableDef};
use super::transaction::{IsolationLevel, TransactionId, TransactionManager, TransactionOp};
use super::types::{SqlError, SqlType, SqlValue};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
/// Result of a SQL query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResult {
/// Column names.
pub columns: Vec<String>,
/// Result rows.
pub rows: Vec<Vec<SqlValue>>,
/// Number of rows affected (for INSERT/UPDATE/DELETE).
pub rows_affected: u64,
/// Execution time in milliseconds.
pub execution_time_ms: u64,
}
impl QueryResult {
/// Creates an empty result.
pub fn empty() -> Self {
Self {
columns: Vec::new(),
rows: Vec::new(),
rows_affected: 0,
execution_time_ms: 0,
}
}
/// Creates a result with affected count.
pub fn affected(count: u64) -> Self {
Self {
columns: Vec::new(),
rows: Vec::new(),
rows_affected: count,
execution_time_ms: 0,
}
}
}
/// SQL execution engine.
pub struct SqlEngine {
/// Tables in this engine.
tables: RwLock<HashMap<String, Arc<Table>>>,
/// Transaction manager.
txn_manager: TransactionManager,
}
impl SqlEngine {
/// Creates a new SQL engine.
pub fn new() -> Self {
Self {
tables: RwLock::new(HashMap::new()),
txn_manager: TransactionManager::new(),
}
}
/// Executes a SQL statement.
pub fn execute(&self, sql: &str) -> Result<QueryResult, SqlError> {
let start = std::time::Instant::now();
let stmt = SqlParser::parse(sql)?;
let mut result = self.execute_statement(&stmt)?;
result.execution_time_ms = start.elapsed().as_millis() as u64;
Ok(result)
}
/// Executes a parsed statement.
fn execute_statement(&self, stmt: &ParsedStatement) -> Result<QueryResult, SqlError> {
match stmt {
ParsedStatement::CreateTable {
name,
columns,
if_not_exists,
} => self.execute_create_table(name, columns, *if_not_exists),
ParsedStatement::DropTable { name, if_exists } => {
self.execute_drop_table(name, *if_exists)
}
ParsedStatement::Select(select) => self.execute_select(select),
ParsedStatement::Insert {
table,
columns,
values,
} => self.execute_insert(table, columns, values),
ParsedStatement::Update {
table,
assignments,
where_clause,
} => self.execute_update(table, assignments, where_clause.as_ref()),
ParsedStatement::Delete {
table,
where_clause,
} => self.execute_delete(table, where_clause.as_ref()),
ParsedStatement::CreateIndex {
name,
table,
columns,
unique,
} => self.execute_create_index(name, table, columns, *unique),
ParsedStatement::DropIndex { name } => self.execute_drop_index(name),
}
}
/// Creates a table.
fn execute_create_table(
&self,
name: &str,
columns: &[super::parser::ParsedColumn],
if_not_exists: bool,
) -> Result<QueryResult, SqlError> {
let mut tables = self.tables.write();
if tables.contains_key(name) {
if if_not_exists {
return Ok(QueryResult::empty());
}
return Err(SqlError::TableExists(name.to_string()));
}
let mut table_def = TableDef::new(name);
for col in columns {
let mut col_def = ColumnDef::new(&col.name, col.data_type.clone());
if !col.nullable {
col_def = col_def.not_null();
}
if let Some(ref default) = col.default {
col_def = col_def.default(default.clone());
}
if col.primary_key {
col_def = col_def.primary_key();
}
if col.unique {
col_def = col_def.unique();
}
table_def = table_def.column(col_def);
}
let table = Arc::new(Table::new(table_def));
tables.insert(name.to_string(), table);
Ok(QueryResult::empty())
}
/// Drops a table.
fn execute_drop_table(&self, name: &str, if_exists: bool) -> Result<QueryResult, SqlError> {
let mut tables = self.tables.write();
if !tables.contains_key(name) {
if if_exists {
return Ok(QueryResult::empty());
}
return Err(SqlError::TableNotFound(name.to_string()));
}
tables.remove(name);
Ok(QueryResult::empty())
}
/// Executes a SELECT query.
fn execute_select(&self, select: &ParsedSelect) -> Result<QueryResult, SqlError> {
let tables = self.tables.read();
let table = tables
.get(&select.from)
.ok_or_else(|| SqlError::TableNotFound(select.from.clone()))?;
// Get all rows
let mut rows = table.scan();
// Apply WHERE filter
if let Some(ref where_clause) = select.where_clause {
rows = rows
.into_iter()
.filter(|row| self.evaluate_where(row, where_clause))
.collect();
}
// Apply ORDER BY
if !select.order_by.is_empty() {
rows.sort_by(|a, b| {
for ob in &select.order_by {
let a_val = a.get_or_null(&ob.column);
let b_val = b.get_or_null(&ob.column);
match a_val.partial_cmp(&b_val) {
Some(std::cmp::Ordering::Equal) => continue,
Some(ord) => {
return if ob.ascending {
ord
} else {
ord.reverse()
};
}
None => continue,
}
}
std::cmp::Ordering::Equal
});
}
        // Handle aggregates before OFFSET/LIMIT: in SQL, LIMIT applies to the
        // aggregated output, not to the rows feeding the aggregate.
        if select.columns.iter().any(|c| matches!(c, ParsedSelectItem::Aggregate { .. })) {
            return self.execute_aggregate(select, &rows, table);
        }
        // Apply OFFSET
        if let Some(offset) = select.offset {
            rows = rows.into_iter().skip(offset).collect();
        }
        // Apply LIMIT
        if let Some(limit) = select.limit {
            rows = rows.into_iter().take(limit).collect();
        }
// Project columns
let (column_names, result_rows) = self.project_rows(&select.columns, &rows, table);
Ok(QueryResult {
columns: column_names,
rows: result_rows,
rows_affected: 0,
execution_time_ms: 0,
})
}
/// Projects rows to selected columns.
fn project_rows(
&self,
select_items: &[ParsedSelectItem],
rows: &[Row],
table: &Table,
) -> (Vec<String>, Vec<Vec<SqlValue>>) {
let column_names: Vec<String> = select_items
.iter()
.flat_map(|item| match item {
ParsedSelectItem::Wildcard => table.def.column_names(),
ParsedSelectItem::Column(name) => vec![name.clone()],
ParsedSelectItem::ColumnAlias { alias, .. } => vec![alias.clone()],
ParsedSelectItem::Aggregate { function, alias, .. } => {
vec![alias.clone().unwrap_or_else(|| function.clone())]
}
})
.collect();
let result_rows: Vec<Vec<SqlValue>> = rows
.iter()
.map(|row| {
select_items
.iter()
.flat_map(|item| match item {
ParsedSelectItem::Wildcard => table
.def
.column_names()
.into_iter()
.map(|c| row.get_or_null(&c))
.collect::<Vec<_>>(),
ParsedSelectItem::Column(name)
| ParsedSelectItem::ColumnAlias { column: name, .. } => {
vec![row.get_or_null(name)]
}
_ => vec![SqlValue::Null],
})
.collect()
})
.collect();
(column_names, result_rows)
}
/// Executes aggregate functions.
fn execute_aggregate(
&self,
select: &ParsedSelect,
rows: &[Row],
table: &Table,
) -> Result<QueryResult, SqlError> {
let mut result_columns = Vec::new();
let mut result_values = Vec::new();
for item in &select.columns {
match item {
ParsedSelectItem::Aggregate {
function,
column,
alias,
} => {
let col_name = alias.clone().unwrap_or_else(|| function.clone());
result_columns.push(col_name);
let value = match function.as_str() {
"COUNT" => SqlValue::Integer(rows.len() as i64),
"SUM" => {
let col = column.as_ref().ok_or_else(|| {
SqlError::InvalidOperation("SUM requires column".to_string())
})?;
let sum: f64 = rows
.iter()
.filter_map(|r| r.get_or_null(col).as_real())
.sum();
SqlValue::Real(sum)
}
"AVG" => {
let col = column.as_ref().ok_or_else(|| {
SqlError::InvalidOperation("AVG requires column".to_string())
})?;
let values: Vec<f64> = rows
.iter()
.filter_map(|r| r.get_or_null(col).as_real())
.collect();
if values.is_empty() {
SqlValue::Null
} else {
SqlValue::Real(values.iter().sum::<f64>() / values.len() as f64)
}
}
"MIN" => {
let col = column.as_ref().ok_or_else(|| {
SqlError::InvalidOperation("MIN requires column".to_string())
})?;
rows.iter()
.map(|r| r.get_or_null(col))
.filter(|v| !v.is_null())
.min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.unwrap_or(SqlValue::Null)
}
"MAX" => {
let col = column.as_ref().ok_or_else(|| {
SqlError::InvalidOperation("MAX requires column".to_string())
})?;
rows.iter()
.map(|r| r.get_or_null(col))
.filter(|v| !v.is_null())
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.unwrap_or(SqlValue::Null)
}
_ => {
return Err(SqlError::Unsupported(format!("Function: {}", function)))
}
};
result_values.push(value);
}
ParsedSelectItem::Column(name) => {
result_columns.push(name.clone());
// For non-aggregated columns in aggregate query, take first value
result_values.push(
rows.first()
.map(|r| r.get_or_null(name))
.unwrap_or(SqlValue::Null),
);
}
_ => {}
}
}
Ok(QueryResult {
columns: result_columns,
rows: if result_values.is_empty() {
Vec::new()
} else {
vec![result_values]
},
rows_affected: 0,
execution_time_ms: 0,
})
}
/// Evaluates a WHERE clause.
fn evaluate_where(&self, row: &Row, expr: &ParsedExpr) -> bool {
match self.evaluate_expr(row, expr) {
SqlValue::Boolean(b) => b,
SqlValue::Integer(i) => i != 0,
_ => false,
}
}
/// Evaluates an expression.
fn evaluate_expr(&self, row: &Row, expr: &ParsedExpr) -> SqlValue {
match expr {
ParsedExpr::Column(name) => row.get_or_null(name),
ParsedExpr::Literal(value) => value.clone(),
ParsedExpr::BinaryOp { left, op, right } => {
let left_val = self.evaluate_expr(row, left);
let right_val = self.evaluate_expr(row, right);
self.evaluate_binary_op(&left_val, op, &right_val)
}
ParsedExpr::Not(inner) => {
let val = self.evaluate_expr(row, inner);
match val {
SqlValue::Boolean(b) => SqlValue::Boolean(!b),
_ => SqlValue::Null,
}
}
ParsedExpr::IsNull(inner) => {
SqlValue::Boolean(self.evaluate_expr(row, inner).is_null())
}
ParsedExpr::IsNotNull(inner) => {
SqlValue::Boolean(!self.evaluate_expr(row, inner).is_null())
}
ParsedExpr::InList { expr, list, negated } => {
let val = self.evaluate_expr(row, expr);
let in_list = list.iter().any(|item| {
let item_val = self.evaluate_expr(row, item);
val == item_val
});
SqlValue::Boolean(if *negated { !in_list } else { in_list })
}
ParsedExpr::Between {
expr,
low,
high,
negated,
} => {
let val = self.evaluate_expr(row, expr);
let low_val = self.evaluate_expr(row, low);
let high_val = self.evaluate_expr(row, high);
let between = val >= low_val && val <= high_val;
SqlValue::Boolean(if *negated { !between } else { between })
}
ParsedExpr::Function { name, args } => {
self.evaluate_function(row, name, args)
}
}
}
/// Evaluates a binary operation.
fn evaluate_binary_op(&self, left: &SqlValue, op: &BinaryOp, right: &SqlValue) -> SqlValue {
match op {
BinaryOp::Eq => SqlValue::Boolean(left == right),
BinaryOp::Ne => SqlValue::Boolean(left != right),
BinaryOp::Lt => SqlValue::Boolean(left < right),
BinaryOp::Le => SqlValue::Boolean(left <= right),
BinaryOp::Gt => SqlValue::Boolean(left > right),
BinaryOp::Ge => SqlValue::Boolean(left >= right),
BinaryOp::And => {
let l = matches!(left, SqlValue::Boolean(true));
let r = matches!(right, SqlValue::Boolean(true));
SqlValue::Boolean(l && r)
}
BinaryOp::Or => {
let l = matches!(left, SqlValue::Boolean(true));
let r = matches!(right, SqlValue::Boolean(true));
SqlValue::Boolean(l || r)
}
BinaryOp::Like => {
if let (SqlValue::Text(text), SqlValue::Text(pattern)) = (left, right) {
SqlValue::Boolean(self.match_like(text, pattern))
} else {
SqlValue::Boolean(false)
}
}
BinaryOp::Plus => match (left, right) {
(SqlValue::Integer(a), SqlValue::Integer(b)) => SqlValue::Integer(a + b),
(SqlValue::Real(a), SqlValue::Real(b)) => SqlValue::Real(a + b),
(SqlValue::Integer(a), SqlValue::Real(b)) => SqlValue::Real(*a as f64 + b),
(SqlValue::Real(a), SqlValue::Integer(b)) => SqlValue::Real(a + *b as f64),
_ => SqlValue::Null,
},
BinaryOp::Minus => match (left, right) {
(SqlValue::Integer(a), SqlValue::Integer(b)) => SqlValue::Integer(a - b),
(SqlValue::Real(a), SqlValue::Real(b)) => SqlValue::Real(a - b),
(SqlValue::Integer(a), SqlValue::Real(b)) => SqlValue::Real(*a as f64 - b),
(SqlValue::Real(a), SqlValue::Integer(b)) => SqlValue::Real(a - *b as f64),
_ => SqlValue::Null,
},
BinaryOp::Multiply => match (left, right) {
(SqlValue::Integer(a), SqlValue::Integer(b)) => SqlValue::Integer(a * b),
(SqlValue::Real(a), SqlValue::Real(b)) => SqlValue::Real(a * b),
(SqlValue::Integer(a), SqlValue::Real(b)) => SqlValue::Real(*a as f64 * b),
(SqlValue::Real(a), SqlValue::Integer(b)) => SqlValue::Real(a * *b as f64),
_ => SqlValue::Null,
},
BinaryOp::Divide => match (left, right) {
(SqlValue::Integer(a), SqlValue::Integer(b)) if *b != 0 => {
SqlValue::Integer(a / b)
}
(SqlValue::Real(a), SqlValue::Real(b)) if *b != 0.0 => SqlValue::Real(a / b),
(SqlValue::Integer(a), SqlValue::Real(b)) if *b != 0.0 => SqlValue::Real(*a as f64 / b),
(SqlValue::Real(a), SqlValue::Integer(b)) if *b != 0 => SqlValue::Real(a / *b as f64),
_ => SqlValue::Null,
},
}
}
/// Evaluates a function call.
fn evaluate_function(&self, row: &Row, name: &str, args: &[ParsedExpr]) -> SqlValue {
match name.to_uppercase().as_str() {
"COALESCE" => {
for arg in args {
let val = self.evaluate_expr(row, arg);
if !val.is_null() {
return val;
}
}
SqlValue::Null
}
"UPPER" => {
if let Some(arg) = args.first() {
if let SqlValue::Text(s) = self.evaluate_expr(row, arg) {
return SqlValue::Text(s.to_uppercase());
}
}
SqlValue::Null
}
"LOWER" => {
if let Some(arg) = args.first() {
if let SqlValue::Text(s) = self.evaluate_expr(row, arg) {
return SqlValue::Text(s.to_lowercase());
}
}
SqlValue::Null
}
"LENGTH" => {
if let Some(arg) = args.first() {
if let SqlValue::Text(s) = self.evaluate_expr(row, arg) {
// Count characters, not bytes, matching SQLite's LENGTH semantics.
return SqlValue::Integer(s.chars().count() as i64);
}
}
SqlValue::Null
}
"ABS" => {
if let Some(arg) = args.first() {
match self.evaluate_expr(row, arg) {
SqlValue::Integer(i) => return SqlValue::Integer(i.abs()),
SqlValue::Real(f) => return SqlValue::Real(f.abs()),
_ => {}
}
}
SqlValue::Null
}
_ => SqlValue::Null,
}
}
/// Matches a LIKE pattern.
fn match_like(&self, text: &str, pattern: &str) -> bool {
// Simplified LIKE: only leading/trailing `%` wildcards are handled,
// matching is case-insensitive, and `_` is not yet supported.
if pattern.len() > 1 && pattern.starts_with('%') && pattern.ends_with('%') {
let inner = &pattern[1..pattern.len() - 1];
text.to_lowercase().contains(&inner.to_lowercase())
} else if pattern.starts_with('%') {
let suffix = &pattern[1..];
text.to_lowercase().ends_with(&suffix.to_lowercase())
} else if pattern.ends_with('%') {
let prefix = &pattern[..pattern.len() - 1];
text.to_lowercase().starts_with(&prefix.to_lowercase())
} else {
text.to_lowercase() == pattern.to_lowercase()
}
}
/// Executes an INSERT statement.
fn execute_insert(
&self,
table_name: &str,
columns: &[String],
values: &[Vec<SqlValue>],
) -> Result<QueryResult, SqlError> {
let tables = self.tables.read();
let table = tables
.get(table_name)
.ok_or_else(|| SqlError::TableNotFound(table_name.to_string()))?;
let cols = if columns.is_empty() {
table.def.column_names()
} else {
columns.to_vec()
};
let mut count = 0;
for row_values in values {
if row_values.len() != cols.len() {
return Err(SqlError::InvalidOperation(format!(
"Column count mismatch: expected {}, got {}",
cols.len(),
row_values.len()
)));
}
let mut row_map = HashMap::new();
for (col, val) in cols.iter().zip(row_values.iter()) {
row_map.insert(col.clone(), val.clone());
}
table.insert(row_map)?;
count += 1;
}
Ok(QueryResult::affected(count))
}
/// Executes an UPDATE statement.
fn execute_update(
&self,
table_name: &str,
assignments: &[(String, SqlValue)],
where_clause: Option<&ParsedExpr>,
) -> Result<QueryResult, SqlError> {
let tables = self.tables.read();
let table = tables
.get(table_name)
.ok_or_else(|| SqlError::TableNotFound(table_name.to_string()))?;
let rows = table.scan();
let mut count = 0;
for row in rows {
let matches = where_clause
.map(|w| self.evaluate_where(&row, w))
.unwrap_or(true);
if matches {
let updates: HashMap<String, SqlValue> =
assignments.iter().cloned().collect();
table.update(row.id, updates)?;
count += 1;
}
}
Ok(QueryResult::affected(count))
}
/// Executes a DELETE statement.
fn execute_delete(
&self,
table_name: &str,
where_clause: Option<&ParsedExpr>,
) -> Result<QueryResult, SqlError> {
let tables = self.tables.read();
let table = tables
.get(table_name)
.ok_or_else(|| SqlError::TableNotFound(table_name.to_string()))?;
let rows = table.scan();
let mut count = 0;
let to_delete: Vec<RowId> = rows
.iter()
.filter(|row| {
where_clause
.map(|w| self.evaluate_where(row, w))
.unwrap_or(true)
})
.map(|row| row.id)
.collect();
for id in to_delete {
if table.delete(id)? {
count += 1;
}
}
Ok(QueryResult::affected(count))
}
/// Creates an index.
fn execute_create_index(
&self,
name: &str,
table_name: &str,
columns: &[String],
unique: bool,
) -> Result<QueryResult, SqlError> {
let tables = self.tables.read();
let table = tables
.get(table_name)
.ok_or_else(|| SqlError::TableNotFound(table_name.to_string()))?;
// For simplicity, only support single-column indexes
let column = columns
.first()
.ok_or_else(|| SqlError::InvalidOperation("Index requires at least one column".to_string()))?;
table.create_index(name, column, unique)?;
Ok(QueryResult::empty())
}
/// Drops an index.
fn execute_drop_index(&self, _name: &str) -> Result<QueryResult, SqlError> {
// Indexes are not yet tracked by name across tables, so this is
// currently a no-op that reports success.
Ok(QueryResult::empty())
}
// Transaction methods
/// Begins a transaction.
pub fn begin_transaction(&self) -> TransactionId {
self.txn_manager.begin(IsolationLevel::ReadCommitted)
}
/// Commits a transaction.
pub fn commit(&self, txn_id: TransactionId) -> Result<(), SqlError> {
self.txn_manager.commit(txn_id)?;
Ok(())
}
/// Rolls back a transaction.
pub fn rollback(&self, txn_id: TransactionId) -> Result<(), SqlError> {
self.txn_manager.rollback(txn_id)?;
Ok(())
}
/// Returns number of tables.
pub fn table_count(&self) -> usize {
self.tables.read().len()
}
/// Returns table names.
pub fn table_names(&self) -> Vec<String> {
self.tables.read().keys().cloned().collect()
}
/// Gets table definition.
pub fn get_table_def(&self, name: &str) -> Option<TableDef> {
self.tables.read().get(name).map(|t| t.def.clone())
}
}
impl Default for SqlEngine {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
fn setup_engine() -> SqlEngine {
let engine = SqlEngine::new();
engine
.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL, age INTEGER)")
.unwrap();
engine
}
#[test]
fn test_create_table() {
let engine = SqlEngine::new();
engine
.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
.unwrap();
assert_eq!(engine.table_count(), 1);
}
#[test]
fn test_insert_and_select() {
let engine = setup_engine();
engine
.execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30)")
.unwrap();
engine
.execute("INSERT INTO users (id, name, age) VALUES (2, 'Bob', 25)")
.unwrap();
let result = engine.execute("SELECT name, age FROM users").unwrap();
assert_eq!(result.rows.len(), 2);
}
#[test]
fn test_select_with_where() {
let engine = setup_engine();
engine
.execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30)")
.unwrap();
engine
.execute("INSERT INTO users (id, name, age) VALUES (2, 'Bob', 25)")
.unwrap();
let result = engine.execute("SELECT name FROM users WHERE age > 26").unwrap();
assert_eq!(result.rows.len(), 1);
assert_eq!(result.rows[0][0], SqlValue::Text("Alice".to_string()));
}
#[test]
fn test_select_order_by() {
let engine = setup_engine();
engine
.execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30)")
.unwrap();
engine
.execute("INSERT INTO users (id, name, age) VALUES (2, 'Bob', 25)")
.unwrap();
let result = engine
.execute("SELECT name FROM users ORDER BY age")
.unwrap();
assert_eq!(result.rows[0][0], SqlValue::Text("Bob".to_string()));
assert_eq!(result.rows[1][0], SqlValue::Text("Alice".to_string()));
}
#[test]
fn test_select_limit() {
let engine = setup_engine();
for i in 1..=10 {
engine
.execute(&format!(
"INSERT INTO users (id, name, age) VALUES ({}, 'User{}', {})",
i, i, 20 + i
))
.unwrap();
}
let result = engine.execute("SELECT name FROM users LIMIT 3").unwrap();
assert_eq!(result.rows.len(), 3);
}
#[test]
fn test_aggregate_count() {
let engine = setup_engine();
engine
.execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30)")
.unwrap();
engine
.execute("INSERT INTO users (id, name, age) VALUES (2, 'Bob', 25)")
.unwrap();
let result = engine.execute("SELECT COUNT(*) FROM users").unwrap();
assert_eq!(result.rows[0][0], SqlValue::Integer(2));
}
#[test]
fn test_aggregate_sum_avg() {
let engine = setup_engine();
engine
.execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30)")
.unwrap();
engine
.execute("INSERT INTO users (id, name, age) VALUES (2, 'Bob', 20)")
.unwrap();
let result = engine.execute("SELECT SUM(age) FROM users").unwrap();
assert_eq!(result.rows[0][0], SqlValue::Real(50.0));
let result = engine.execute("SELECT AVG(age) FROM users").unwrap();
assert_eq!(result.rows[0][0], SqlValue::Real(25.0));
}
#[test]
fn test_update() {
let engine = setup_engine();
engine
.execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30)")
.unwrap();
let result = engine
.execute("UPDATE users SET age = 31 WHERE name = 'Alice'")
.unwrap();
assert_eq!(result.rows_affected, 1);
let result = engine
.execute("SELECT age FROM users WHERE name = 'Alice'")
.unwrap();
assert_eq!(result.rows[0][0], SqlValue::Integer(31));
}
#[test]
fn test_delete() {
let engine = setup_engine();
engine
.execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30)")
.unwrap();
engine
.execute("INSERT INTO users (id, name, age) VALUES (2, 'Bob', 25)")
.unwrap();
let result = engine
.execute("DELETE FROM users WHERE name = 'Alice'")
.unwrap();
assert_eq!(result.rows_affected, 1);
let result = engine.execute("SELECT COUNT(*) FROM users").unwrap();
assert_eq!(result.rows[0][0], SqlValue::Integer(1));
}
#[test]
fn test_drop_table() {
let engine = setup_engine();
assert_eq!(engine.table_count(), 1);
engine.execute("DROP TABLE users").unwrap();
assert_eq!(engine.table_count(), 0);
}
}
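The engine's private `match_like` helper above handles only leading and trailing `%` wildcards, case-insensitively. The same strategy as a self-contained sketch (the function name here is illustrative, not part of the crate's API):

```rust
/// Simplified LIKE matching: only leading/trailing `%` wildcards,
/// case-insensitive, `_` unsupported (mirrors the engine's approach).
fn like_matches(text: &str, pattern: &str) -> bool {
    let (t, p) = (text.to_lowercase(), pattern.to_lowercase());
    match (p.starts_with('%'), p.ends_with('%') && p.len() > 1) {
        (true, true) => t.contains(&p[1..p.len() - 1]),   // %abc%
        (true, false) => t.ends_with(&p[1..]),            // %abc
        (false, true) => t.starts_with(&p[..p.len() - 1]), // abc%
        (false, false) => t == p,                          // exact match
    }
}

fn main() {
    assert!(like_matches("Alice", "%lic%"));
    assert!(like_matches("Alice", "Al%"));
    assert!(like_matches("Alice", "%ce"));
    assert!(!like_matches("Alice", "Bob%"));
    println!("ok");
}
```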

@@ -0,0 +1,23 @@
//! SQL Query Layer for Synor Database.
//!
//! Provides SQLite-compatible query interface:
//!
//! - DDL: CREATE TABLE, DROP TABLE, ALTER TABLE
//! - DML: SELECT, INSERT, UPDATE, DELETE
//! - Clauses: WHERE, ORDER BY, LIMIT, GROUP BY
//! - Joins: INNER, LEFT, RIGHT JOIN
//! - Functions: COUNT, SUM, AVG, MIN, MAX
pub mod executor;
pub mod parser;
pub mod row;
pub mod table;
pub mod transaction;
pub mod types;
pub use executor::{QueryResult, SqlEngine};
pub use parser::SqlParser;
pub use row::{Row, RowId};
pub use table::{ColumnDef, Table, TableDef};
pub use transaction::{Transaction, TransactionId, TransactionState};
pub use types::{SqlError, SqlType, SqlValue};
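The parser in the next file classifies numeric literals by the presence of a decimal point: `42` becomes an INTEGER value, `3.5` a REAL. A minimal standalone sketch of that rule, using a stand-in enum (`Num` is illustrative, not part of the crate):

```rust
#[derive(Debug, PartialEq)]
enum Num {
    Integer(i64),
    Real(f64),
}

/// Classify a SQL numeric literal: a decimal point selects REAL,
/// otherwise the literal is parsed as a 64-bit INTEGER.
fn classify(literal: &str) -> Option<Num> {
    if literal.contains('.') {
        literal.parse::<f64>().ok().map(Num::Real)
    } else {
        literal.parse::<i64>().ok().map(Num::Integer)
    }
}

fn main() {
    assert_eq!(classify("42"), Some(Num::Integer(42)));
    assert_eq!(classify("3.5"), Some(Num::Real(3.5)));
    assert_eq!(classify("abc"), None);
    println!("ok");
}
```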

@@ -0,0 +1,732 @@
//! SQL parser using sqlparser-rs.
use super::types::{SqlError, SqlType, SqlValue};
use sqlparser::ast::{
BinaryOperator, ColumnOption, DataType, Expr, Query, Select, SelectItem, SetExpr, Statement,
TableFactor, Value as AstValue,
};
use sqlparser::dialect::SQLiteDialect;
use sqlparser::parser::Parser;
/// Parsed SQL statement.
#[derive(Debug)]
pub enum ParsedStatement {
/// CREATE TABLE statement.
CreateTable {
name: String,
columns: Vec<ParsedColumn>,
if_not_exists: bool,
},
/// DROP TABLE statement.
DropTable {
name: String,
if_exists: bool,
},
/// SELECT statement.
Select(ParsedSelect),
/// INSERT statement.
Insert {
table: String,
columns: Vec<String>,
values: Vec<Vec<SqlValue>>,
},
/// UPDATE statement.
Update {
table: String,
assignments: Vec<(String, SqlValue)>,
where_clause: Option<ParsedExpr>,
},
/// DELETE statement.
Delete {
table: String,
where_clause: Option<ParsedExpr>,
},
/// CREATE INDEX statement.
CreateIndex {
name: String,
table: String,
columns: Vec<String>,
unique: bool,
},
/// DROP INDEX statement.
DropIndex { name: String },
}
/// Parsed column definition.
#[derive(Debug)]
pub struct ParsedColumn {
pub name: String,
pub data_type: SqlType,
pub nullable: bool,
pub default: Option<SqlValue>,
pub primary_key: bool,
pub unique: bool,
}
/// Parsed SELECT statement.
#[derive(Debug)]
pub struct ParsedSelect {
pub columns: Vec<ParsedSelectItem>,
pub from: String,
pub joins: Vec<ParsedJoin>,
pub where_clause: Option<ParsedExpr>,
pub group_by: Vec<String>,
pub having: Option<ParsedExpr>,
pub order_by: Vec<ParsedOrderBy>,
pub limit: Option<usize>,
pub offset: Option<usize>,
}
/// Parsed select item.
#[derive(Debug)]
pub enum ParsedSelectItem {
/// All columns (*).
Wildcard,
/// Single column.
Column(String),
/// Column with alias.
ColumnAlias { column: String, alias: String },
/// Aggregate function.
Aggregate {
function: String,
column: Option<String>,
alias: Option<String>,
},
}
/// Parsed JOIN clause.
#[derive(Debug)]
pub struct ParsedJoin {
pub table: String,
pub join_type: JoinType,
pub on: ParsedExpr,
}
/// Join types.
#[derive(Debug, Clone)]
pub enum JoinType {
Inner,
Left,
Right,
Full,
}
/// Parsed ORDER BY clause.
#[derive(Debug)]
pub struct ParsedOrderBy {
pub column: String,
pub ascending: bool,
}
/// Parsed expression.
#[derive(Debug, Clone)]
pub enum ParsedExpr {
/// Column reference.
Column(String),
/// Literal value.
Literal(SqlValue),
/// Binary operation.
BinaryOp {
left: Box<ParsedExpr>,
op: BinaryOp,
right: Box<ParsedExpr>,
},
/// Unary NOT.
Not(Box<ParsedExpr>),
/// IS NULL check.
IsNull(Box<ParsedExpr>),
/// IS NOT NULL check.
IsNotNull(Box<ParsedExpr>),
/// IN list.
InList {
expr: Box<ParsedExpr>,
list: Vec<ParsedExpr>,
negated: bool,
},
/// BETWEEN.
Between {
expr: Box<ParsedExpr>,
low: Box<ParsedExpr>,
high: Box<ParsedExpr>,
negated: bool,
},
/// Function call.
Function { name: String, args: Vec<ParsedExpr> },
}
/// Binary operators.
#[derive(Debug, Clone)]
pub enum BinaryOp {
Eq,
Ne,
Lt,
Le,
Gt,
Ge,
And,
Or,
Like,
Plus,
Minus,
Multiply,
Divide,
}
/// SQL parser.
pub struct SqlParser;
impl SqlParser {
/// Parses a SQL statement.
pub fn parse(sql: &str) -> Result<ParsedStatement, SqlError> {
let dialect = SQLiteDialect {};
let statements = Parser::parse_sql(&dialect, sql)
.map_err(|e| SqlError::Parse(e.to_string()))?;
if statements.is_empty() {
return Err(SqlError::Parse("Empty SQL statement".to_string()));
}
if statements.len() > 1 {
return Err(SqlError::Parse("Multiple statements not supported".to_string()));
}
Self::convert_statement(&statements[0])
}
fn convert_statement(stmt: &Statement) -> Result<ParsedStatement, SqlError> {
match stmt {
Statement::CreateTable { name, columns, if_not_exists, constraints, .. } => {
Self::convert_create_table(name, columns, constraints, *if_not_exists)
}
Statement::Drop { object_type, names, if_exists, .. } => {
Self::convert_drop(object_type, names, *if_exists)
}
Statement::Query(query) => Self::convert_query(query),
Statement::Insert { table_name, columns, source, .. } => {
Self::convert_insert(table_name, columns, source)
}
Statement::Update { table, assignments, selection, .. } => {
Self::convert_update(table, assignments, selection)
}
Statement::Delete { from, selection, .. } => {
Self::convert_delete(from, selection)
}
Statement::CreateIndex { name, table_name, columns, unique, .. } => {
Self::convert_create_index(name, table_name, columns, *unique)
}
_ => Err(SqlError::Unsupported("Statement not supported".to_string())),
}
}
fn convert_create_table(
name: &sqlparser::ast::ObjectName,
columns: &[sqlparser::ast::ColumnDef],
constraints: &[sqlparser::ast::TableConstraint],
if_not_exists: bool,
) -> Result<ParsedStatement, SqlError> {
let table_name = name.to_string();
let mut parsed_columns = Vec::new();
let mut primary_keys: Vec<String> = Vec::new();
// Extract primary keys from table constraints
for constraint in constraints {
if let sqlparser::ast::TableConstraint::Unique { columns: pk_cols, is_primary: true, .. } = constraint {
for col in pk_cols {
primary_keys.push(col.value.clone());
}
}
}
for col in columns {
let col_name = col.name.value.clone();
let data_type = Self::convert_data_type(&col.data_type)?;
let mut nullable = true;
let mut default = None;
let mut primary_key = primary_keys.contains(&col_name);
let mut unique = false;
for option in &col.options {
match &option.option {
ColumnOption::Null => nullable = true,
ColumnOption::NotNull => nullable = false,
ColumnOption::Default(expr) => {
default = Some(Self::convert_value_expr(expr)?);
}
ColumnOption::Unique { is_primary, .. } => {
if *is_primary {
primary_key = true;
} else {
unique = true;
}
}
_ => {}
}
}
if primary_key {
nullable = false;
unique = true;
}
parsed_columns.push(ParsedColumn {
name: col_name,
data_type,
nullable,
default,
primary_key,
unique,
});
}
Ok(ParsedStatement::CreateTable {
name: table_name,
columns: parsed_columns,
if_not_exists,
})
}
fn convert_data_type(dt: &DataType) -> Result<SqlType, SqlError> {
match dt {
DataType::Int(_)
| DataType::Integer(_)
| DataType::BigInt(_)
| DataType::SmallInt(_)
| DataType::TinyInt(_) => Ok(SqlType::Integer),
DataType::Real | DataType::Float(_) | DataType::Double | DataType::DoublePrecision => {
Ok(SqlType::Real)
}
DataType::Varchar(_)
| DataType::Char(_)
| DataType::Text
| DataType::String(_) => Ok(SqlType::Text),
DataType::Binary(_) | DataType::Varbinary(_) | DataType::Blob(_) => Ok(SqlType::Blob),
DataType::Boolean => Ok(SqlType::Boolean),
DataType::Timestamp(_, _) | DataType::Date | DataType::Datetime(_) => {
Ok(SqlType::Timestamp)
}
_ => Err(SqlError::Unsupported(format!("Data type: {:?}", dt))),
}
}
fn convert_drop(
object_type: &sqlparser::ast::ObjectType,
names: &[sqlparser::ast::ObjectName],
if_exists: bool,
) -> Result<ParsedStatement, SqlError> {
let name = names
.first()
.map(|n| n.to_string())
.ok_or_else(|| SqlError::Parse("Missing object name".to_string()))?;
match object_type {
sqlparser::ast::ObjectType::Table => Ok(ParsedStatement::DropTable { name, if_exists }),
sqlparser::ast::ObjectType::Index => Ok(ParsedStatement::DropIndex { name }),
_ => Err(SqlError::Unsupported("DROP not supported".to_string())),
}
}
fn convert_query(query: &Query) -> Result<ParsedStatement, SqlError> {
let select = match &*query.body {
SetExpr::Select(select) => select,
_ => return Err(SqlError::Unsupported("Non-SELECT query body".to_string())),
};
let parsed_select = Self::convert_select(select, query)?;
Ok(ParsedStatement::Select(parsed_select))
}
fn convert_select(select: &Select, query: &Query) -> Result<ParsedSelect, SqlError> {
// Parse columns
let columns = select
.projection
.iter()
.map(Self::convert_select_item)
.collect::<Result<Vec<_>, _>>()?;
// Parse FROM
let from = select
.from
.first()
.map(|f| Self::convert_table_factor(&f.relation))
.transpose()?
.unwrap_or_default();
// Parse WHERE
let where_clause = select
.selection
.as_ref()
.map(Self::convert_expr)
.transpose()?;
// Parse ORDER BY - simplified approach
let order_by: Vec<ParsedOrderBy> = query
.order_by
.iter()
.filter_map(|e| Self::convert_order_by(e).ok())
.collect();
// Parse LIMIT/OFFSET
let limit = query
.limit
.as_ref()
.and_then(|l| Self::expr_to_usize(l));
let offset = query
.offset
.as_ref()
.and_then(|o| Self::expr_to_usize(&o.value));
Ok(ParsedSelect {
columns,
from,
joins: Vec::new(),
where_clause,
group_by: Vec::new(),
having: None,
order_by,
limit,
offset,
})
}
fn convert_select_item(item: &SelectItem) -> Result<ParsedSelectItem, SqlError> {
match item {
SelectItem::Wildcard(_) => Ok(ParsedSelectItem::Wildcard),
SelectItem::UnnamedExpr(expr) => Self::convert_select_expr(expr),
SelectItem::ExprWithAlias { expr, alias } => {
if let Expr::Identifier(id) = expr {
Ok(ParsedSelectItem::ColumnAlias {
column: id.value.clone(),
alias: alias.value.clone(),
})
} else {
Self::convert_select_expr(expr)
}
}
_ => Err(SqlError::Unsupported("Select item not supported".to_string())),
}
}
fn convert_select_expr(expr: &Expr) -> Result<ParsedSelectItem, SqlError> {
match expr {
Expr::Identifier(id) => Ok(ParsedSelectItem::Column(id.value.clone())),
Expr::CompoundIdentifier(ids) => {
Ok(ParsedSelectItem::Column(ids.last().map(|i| i.value.clone()).unwrap_or_default()))
}
Expr::Function(func) => {
let name = func.name.to_string().to_uppercase();
// Try to extract column from first arg - simplified for compatibility
let column = Self::extract_func_column_arg(func);
Ok(ParsedSelectItem::Aggregate {
function: name,
column,
alias: None,
})
}
_ => Err(SqlError::Unsupported("Select expression not supported".to_string())),
}
}
fn convert_table_factor(factor: &TableFactor) -> Result<String, SqlError> {
match factor {
TableFactor::Table { name, .. } => Ok(name.to_string()),
_ => Err(SqlError::Unsupported("Table factor not supported".to_string())),
}
}
fn extract_func_column_arg(func: &sqlparser::ast::Function) -> Option<String> {
// Use string representation and parse it - works across sqlparser versions
let func_str = func.to_string();
// Parse function like "SUM(age)" or "COUNT(*)"
if let Some(start) = func_str.find('(') {
if let Some(end) = func_str.rfind(')') {
let arg = func_str[start + 1..end].trim();
if !arg.is_empty() && arg != "*" {
return Some(arg.to_string());
}
}
}
None
}
fn convert_order_by(ob: &sqlparser::ast::OrderByExpr) -> Result<ParsedOrderBy, SqlError> {
let column = match &ob.expr {
Expr::Identifier(id) => id.value.clone(),
_ => return Err(SqlError::Unsupported("Order by expression".to_string())),
};
let ascending = ob.asc.unwrap_or(true);
Ok(ParsedOrderBy { column, ascending })
}
fn convert_expr(expr: &Expr) -> Result<ParsedExpr, SqlError> {
match expr {
Expr::Identifier(id) => Ok(ParsedExpr::Column(id.value.clone())),
Expr::CompoundIdentifier(ids) => {
Ok(ParsedExpr::Column(ids.last().map(|i| i.value.clone()).unwrap_or_default()))
}
Expr::Value(v) => Ok(ParsedExpr::Literal(Self::convert_value(v)?)),
Expr::BinaryOp { left, op, right } => {
let left = Box::new(Self::convert_expr(left)?);
let right = Box::new(Self::convert_expr(right)?);
let op = Self::convert_binary_op(op)?;
Ok(ParsedExpr::BinaryOp { left, op, right })
}
Expr::UnaryOp { op: sqlparser::ast::UnaryOperator::Not, expr } => {
Ok(ParsedExpr::Not(Box::new(Self::convert_expr(expr)?)))
}
Expr::IsNull(expr) => Ok(ParsedExpr::IsNull(Box::new(Self::convert_expr(expr)?))),
Expr::IsNotNull(expr) => Ok(ParsedExpr::IsNotNull(Box::new(Self::convert_expr(expr)?))),
Expr::InList { expr, list, negated } => Ok(ParsedExpr::InList {
expr: Box::new(Self::convert_expr(expr)?),
list: list.iter().map(Self::convert_expr).collect::<Result<_, _>>()?,
negated: *negated,
}),
Expr::Between { expr, low, high, negated } => Ok(ParsedExpr::Between {
expr: Box::new(Self::convert_expr(expr)?),
low: Box::new(Self::convert_expr(low)?),
high: Box::new(Self::convert_expr(high)?),
negated: *negated,
}),
Expr::Like { expr, pattern, .. } => {
let left = Box::new(Self::convert_expr(expr)?);
let right = Box::new(Self::convert_expr(pattern)?);
Ok(ParsedExpr::BinaryOp { left, op: BinaryOp::Like, right })
}
Expr::Nested(inner) => Self::convert_expr(inner),
_ => Err(SqlError::Unsupported("Expression not supported".to_string())),
}
}
fn convert_binary_op(op: &BinaryOperator) -> Result<BinaryOp, SqlError> {
match op {
BinaryOperator::Eq => Ok(BinaryOp::Eq),
BinaryOperator::NotEq => Ok(BinaryOp::Ne),
BinaryOperator::Lt => Ok(BinaryOp::Lt),
BinaryOperator::LtEq => Ok(BinaryOp::Le),
BinaryOperator::Gt => Ok(BinaryOp::Gt),
BinaryOperator::GtEq => Ok(BinaryOp::Ge),
BinaryOperator::And => Ok(BinaryOp::And),
BinaryOperator::Or => Ok(BinaryOp::Or),
BinaryOperator::Plus => Ok(BinaryOp::Plus),
BinaryOperator::Minus => Ok(BinaryOp::Minus),
BinaryOperator::Multiply => Ok(BinaryOp::Multiply),
BinaryOperator::Divide => Ok(BinaryOp::Divide),
_ => Err(SqlError::Unsupported("Operator not supported".to_string())),
}
}
fn convert_value(v: &AstValue) -> Result<SqlValue, SqlError> {
match v {
AstValue::Null => Ok(SqlValue::Null),
AstValue::Number(n, _) => {
if n.contains('.') {
n.parse::<f64>()
.map(SqlValue::Real)
.map_err(|e| SqlError::Parse(e.to_string()))
} else {
n.parse::<i64>()
.map(SqlValue::Integer)
.map_err(|e| SqlError::Parse(e.to_string()))
}
}
AstValue::SingleQuotedString(s) | AstValue::DoubleQuotedString(s) => {
Ok(SqlValue::Text(s.clone()))
}
AstValue::Boolean(b) => Ok(SqlValue::Boolean(*b)),
AstValue::HexStringLiteral(h) => hex::decode(h)
.map(SqlValue::Blob)
.map_err(|e| SqlError::Parse(e.to_string())),
_ => Err(SqlError::Unsupported("Value not supported".to_string())),
}
}
fn convert_value_expr(expr: &Expr) -> Result<SqlValue, SqlError> {
match expr {
Expr::Value(v) => Self::convert_value(v),
_ => Err(SqlError::Unsupported("Non-literal default".to_string())),
}
}
fn convert_insert(
table_name: &sqlparser::ast::ObjectName,
columns: &[sqlparser::ast::Ident],
source: &Option<Box<Query>>,
) -> Result<ParsedStatement, SqlError> {
let table = table_name.to_string();
let col_names: Vec<String> = columns.iter().map(|c| c.value.clone()).collect();
let values = match source.as_ref().map(|s| s.body.as_ref()) {
Some(SetExpr::Values(vals)) => {
let mut result = Vec::new();
for row in &vals.rows {
let row_values: Vec<SqlValue> = row
.iter()
.map(Self::convert_value_expr)
.collect::<Result<_, _>>()?;
result.push(row_values);
}
result
}
_ => return Err(SqlError::Unsupported("INSERT source".to_string())),
};
Ok(ParsedStatement::Insert {
table,
columns: col_names,
values,
})
}
fn convert_update(
table: &sqlparser::ast::TableWithJoins,
assignments: &[sqlparser::ast::Assignment],
selection: &Option<Expr>,
) -> Result<ParsedStatement, SqlError> {
let table_name = Self::convert_table_factor(&table.relation)?;
let parsed_assignments: Vec<(String, SqlValue)> = assignments
.iter()
.map(|a| {
let col = a.id.iter().map(|i| i.value.clone()).collect::<Vec<_>>().join(".");
let val = Self::convert_value_expr(&a.value)?;
Ok((col, val))
})
.collect::<Result<_, SqlError>>()?;
let where_clause = selection.as_ref().map(Self::convert_expr).transpose()?;
Ok(ParsedStatement::Update {
table: table_name,
assignments: parsed_assignments,
where_clause,
})
}
fn convert_delete(
from: &[sqlparser::ast::TableWithJoins],
selection: &Option<Expr>,
) -> Result<ParsedStatement, SqlError> {
let table = from
.first()
.map(|f| Self::convert_table_factor(&f.relation))
.transpose()?
.unwrap_or_default();
let where_clause = selection.as_ref().map(Self::convert_expr).transpose()?;
Ok(ParsedStatement::Delete {
table,
where_clause,
})
}
fn convert_create_index(
name: &Option<sqlparser::ast::ObjectName>,
table_name: &sqlparser::ast::ObjectName,
columns: &[sqlparser::ast::OrderByExpr],
unique: bool,
) -> Result<ParsedStatement, SqlError> {
let index_name = name
.as_ref()
.map(|n| n.to_string())
.ok_or_else(|| SqlError::Parse("Index name required".to_string()))?;
let table = table_name.to_string();
let cols: Vec<String> = columns
.iter()
.map(|c| c.expr.to_string())
.collect();
Ok(ParsedStatement::CreateIndex {
name: index_name,
table,
columns: cols,
unique,
})
}
fn expr_to_usize(expr: &Expr) -> Option<usize> {
if let Expr::Value(AstValue::Number(n, _)) = expr {
n.parse().ok()
} else {
None
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_create_table() {
let sql = "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL, age INTEGER)";
let stmt = SqlParser::parse(sql).unwrap();
if let ParsedStatement::CreateTable { name, columns, .. } = stmt {
assert_eq!(name, "users");
assert_eq!(columns.len(), 3);
assert!(columns[0].primary_key);
assert!(!columns[1].nullable);
} else {
panic!("Expected CreateTable");
}
}
#[test]
fn test_parse_select() {
let sql = "SELECT name, age FROM users WHERE age > 18 ORDER BY name LIMIT 10";
let stmt = SqlParser::parse(sql).unwrap();
if let ParsedStatement::Select(select) = stmt {
assert_eq!(select.columns.len(), 2);
assert_eq!(select.from, "users");
assert!(select.where_clause.is_some());
assert_eq!(select.limit, Some(10));
} else {
panic!("Expected Select");
}
}
#[test]
fn test_parse_insert() {
let sql = "INSERT INTO users (name, age) VALUES ('Alice', 30), ('Bob', 25)";
let stmt = SqlParser::parse(sql).unwrap();
if let ParsedStatement::Insert { table, columns, values } = stmt {
assert_eq!(table, "users");
assert_eq!(columns, vec!["name", "age"]);
assert_eq!(values.len(), 2);
} else {
panic!("Expected Insert");
}
}
#[test]
fn test_parse_update() {
let sql = "UPDATE users SET age = 31 WHERE name = 'Alice'";
let stmt = SqlParser::parse(sql).unwrap();
if let ParsedStatement::Update { table, assignments, where_clause } = stmt {
assert_eq!(table, "users");
assert_eq!(assignments.len(), 1);
assert!(where_clause.is_some());
} else {
panic!("Expected Update");
}
}
#[test]
fn test_parse_delete() {
let sql = "DELETE FROM users WHERE age < 18";
let stmt = SqlParser::parse(sql).unwrap();
if let ParsedStatement::Delete { table, where_clause } = stmt {
assert_eq!(table, "users");
assert!(where_clause.is_some());
} else {
panic!("Expected Delete");
}
}
}
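`extract_func_column_arg` above falls back to string parsing of the rendered function call so it stays compatible across sqlparser versions. The same trick as a standalone sketch (the function name is illustrative):

```rust
/// Pull the argument out of a rendered function call such as "SUM(age)".
/// Returns None for "COUNT(*)" or an empty argument list, mirroring the
/// parser's version-agnostic fallback strategy.
fn func_column_arg(func_str: &str) -> Option<String> {
    let start = func_str.find('(')?;
    let end = func_str.rfind(')')?;
    if end < start + 1 {
        return None; // malformed, e.g. ")(": avoid an invalid slice
    }
    let arg = func_str[start + 1..end].trim();
    if arg.is_empty() || arg == "*" {
        None
    } else {
        Some(arg.to_string())
    }
}

fn main() {
    assert_eq!(func_column_arg("SUM(age)"), Some("age".to_string()));
    assert_eq!(func_column_arg("COUNT(*)"), None);
    assert_eq!(func_column_arg("NOW()"), None);
    println!("ok");
}
```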

@@ -0,0 +1,241 @@
//! Row representation for SQL tables.
use super::types::{SqlError, SqlValue};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Unique row identifier.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct RowId(pub u64);
impl RowId {
/// Creates a new row ID.
pub fn new(id: u64) -> Self {
RowId(id)
}
/// Returns the inner ID value.
pub fn inner(&self) -> u64 {
self.0
}
}
impl std::fmt::Display for RowId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
/// A single row in a SQL table.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Row {
/// Row identifier.
pub id: RowId,
/// Column values indexed by column name.
values: HashMap<String, SqlValue>,
/// Ordered column names (preserves insertion order).
columns: Vec<String>,
}
impl Row {
/// Creates a new empty row.
pub fn new(id: RowId) -> Self {
Self {
id,
values: HashMap::new(),
columns: Vec::new(),
}
}
/// Creates a row with given columns.
pub fn with_columns(id: RowId, columns: Vec<String>) -> Self {
let mut values = HashMap::with_capacity(columns.len());
for col in &columns {
values.insert(col.clone(), SqlValue::Null);
}
Self {
id,
values,
columns,
}
}
/// Sets a column value.
pub fn set(&mut self, column: &str, value: SqlValue) {
if !self.values.contains_key(column) {
self.columns.push(column.to_string());
}
self.values.insert(column.to_string(), value);
}
/// Gets a column value.
pub fn get(&self, column: &str) -> Option<&SqlValue> {
self.values.get(column)
}
/// Gets a column value or returns Null.
pub fn get_or_null(&self, column: &str) -> SqlValue {
self.values.get(column).cloned().unwrap_or(SqlValue::Null)
}
/// Returns all column names.
pub fn columns(&self) -> &[String] {
&self.columns
}
/// Returns all values in column order.
pub fn values(&self) -> Vec<&SqlValue> {
self.columns.iter().map(|c| self.values.get(c).unwrap()).collect()
}
/// Returns the number of columns.
pub fn len(&self) -> usize {
self.columns.len()
}
/// Returns true if the row has no columns.
pub fn is_empty(&self) -> bool {
self.columns.is_empty()
}
/// Projects to specific columns.
pub fn project(&self, columns: &[String]) -> Row {
let mut row = Row::new(self.id);
for col in columns {
if let Some(value) = self.values.get(col) {
row.set(col, value.clone());
}
}
row
}
/// Converts to a map.
pub fn to_map(&self) -> HashMap<String, SqlValue> {
self.values.clone()
}
/// Converts from a map.
pub fn from_map(id: RowId, map: HashMap<String, SqlValue>) -> Self {
let columns: Vec<String> = map.keys().cloned().collect();
Self {
id,
values: map,
columns,
}
}
}
impl PartialEq for Row {
fn eq(&self, other: &Self) -> bool {
if self.columns.len() != other.columns.len() {
return false;
}
for col in &self.columns {
if self.get(col) != other.get(col) {
return false;
}
}
true
}
}
/// Builder for creating rows.
pub struct RowBuilder {
id: RowId,
values: Vec<(String, SqlValue)>,
}
impl RowBuilder {
/// Creates a new row builder.
pub fn new(id: RowId) -> Self {
Self {
id,
values: Vec::new(),
}
}
/// Adds a column value.
pub fn column(mut self, name: impl Into<String>, value: SqlValue) -> Self {
self.values.push((name.into(), value));
self
}
/// Adds an integer column.
pub fn int(self, name: impl Into<String>, value: i64) -> Self {
self.column(name, SqlValue::Integer(value))
}
/// Adds a real column.
pub fn real(self, name: impl Into<String>, value: f64) -> Self {
self.column(name, SqlValue::Real(value))
}
/// Adds a text column.
pub fn text(self, name: impl Into<String>, value: impl Into<String>) -> Self {
self.column(name, SqlValue::Text(value.into()))
}
/// Adds a boolean column.
pub fn boolean(self, name: impl Into<String>, value: bool) -> Self {
self.column(name, SqlValue::Boolean(value))
}
/// Adds a null column.
pub fn null(self, name: impl Into<String>) -> Self {
self.column(name, SqlValue::Null)
}
/// Builds the row.
pub fn build(self) -> Row {
let mut row = Row::new(self.id);
for (name, value) in self.values {
row.set(&name, value);
}
row
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_row_basic() {
let mut row = Row::new(RowId(1));
row.set("name", SqlValue::Text("Alice".to_string()));
row.set("age", SqlValue::Integer(30));
assert_eq!(row.get("name"), Some(&SqlValue::Text("Alice".to_string())));
assert_eq!(row.get("age"), Some(&SqlValue::Integer(30)));
assert_eq!(row.get("missing"), None);
assert_eq!(row.len(), 2);
}
#[test]
fn test_row_builder() {
let row = RowBuilder::new(RowId(1))
.text("name", "Bob")
.int("age", 25)
.boolean("active", true)
.build();
assert_eq!(row.get("name"), Some(&SqlValue::Text("Bob".to_string())));
assert_eq!(row.get("age"), Some(&SqlValue::Integer(25)));
assert_eq!(row.get("active"), Some(&SqlValue::Boolean(true)));
}
#[test]
fn test_row_projection() {
let row = RowBuilder::new(RowId(1))
.text("a", "1")
.text("b", "2")
.text("c", "3")
.build();
let projected = row.project(&["a".to_string(), "c".to_string()]);
assert_eq!(projected.len(), 2);
assert!(projected.get("a").is_some());
assert!(projected.get("b").is_none());
assert!(projected.get("c").is_some());
}
}
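The `Row` type above pairs a `HashMap` (O(1) lookup) with a `Vec` of column names so iteration preserves insertion order, which a `HashMap` alone does not. A minimal standalone sketch of that idea; the `MiniRow` name and `i64` values are illustrative, not part of the crate:

```rust
use std::collections::HashMap;

/// Sketch of the Row design: a map for lookup plus an ordered column list.
pub struct MiniRow {
    values: HashMap<String, i64>,
    columns: Vec<String>,
}

impl MiniRow {
    pub fn new() -> Self {
        Self { values: HashMap::new(), columns: Vec::new() }
    }

    /// Inserts or overwrites; the column list only grows on first insert,
    /// so an overwrite keeps the column's original position.
    pub fn set(&mut self, column: &str, value: i64) {
        if !self.values.contains_key(column) {
            self.columns.push(column.to_string());
        }
        self.values.insert(column.to_string(), value);
    }

    /// Values come back in insertion order, unlike bare HashMap iteration.
    pub fn ordered(&self) -> Vec<i64> {
        self.columns.iter().map(|c| self.values[c]).collect()
    }
}

fn main() {
    let mut row = MiniRow::new();
    row.set("a", 1);
    row.set("b", 2);
    row.set("a", 10); // overwrite keeps "a" first
    assert_eq!(row.ordered(), vec![10, 2]);
}
```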


@@ -0,0 +1,570 @@
//! SQL table definition and storage.
use super::row::{Row, RowId};
use super::types::{SqlError, SqlType, SqlValue};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap, HashSet};
/// Column definition.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ColumnDef {
/// Column name.
pub name: String,
/// Data type.
pub data_type: SqlType,
/// Whether null values are allowed.
pub nullable: bool,
/// Default value.
pub default: Option<SqlValue>,
/// Primary key flag.
pub primary_key: bool,
/// Unique constraint.
pub unique: bool,
}
impl ColumnDef {
/// Creates a new column definition.
pub fn new(name: impl Into<String>, data_type: SqlType) -> Self {
Self {
name: name.into(),
data_type,
nullable: true,
default: None,
primary_key: false,
unique: false,
}
}
/// Sets as not null.
pub fn not_null(mut self) -> Self {
self.nullable = false;
self
}
/// Sets default value.
pub fn default(mut self, value: SqlValue) -> Self {
self.default = Some(value);
self
}
/// Sets as primary key.
pub fn primary_key(mut self) -> Self {
self.primary_key = true;
self.nullable = false;
self.unique = true;
self
}
/// Sets as unique.
pub fn unique(mut self) -> Self {
self.unique = true;
self
}
/// Validates a value against this column definition.
pub fn validate(&self, value: &SqlValue) -> Result<(), SqlError> {
// Check null constraint
if value.is_null() && !self.nullable {
return Err(SqlError::NotNullViolation(self.name.clone()));
}
// Skip type check for null values
if value.is_null() {
return Ok(());
}
// Check type compatibility
let compatible = match (&self.data_type, value) {
(SqlType::Integer, SqlValue::Integer(_)) => true,
(SqlType::Real, SqlValue::Real(_)) => true,
(SqlType::Real, SqlValue::Integer(_)) => true, // Allow int to real
(SqlType::Text, SqlValue::Text(_)) => true,
(SqlType::Blob, SqlValue::Blob(_)) => true,
(SqlType::Boolean, SqlValue::Boolean(_)) => true,
(SqlType::Timestamp, SqlValue::Timestamp(_)) => true,
(SqlType::Timestamp, SqlValue::Integer(_)) => true, // Allow int to timestamp
_ => false,
};
if !compatible {
return Err(SqlError::TypeMismatch {
expected: format!("{:?}", self.data_type),
got: format!("{:?}", value.sql_type()),
});
}
Ok(())
}
}
/// Table definition.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TableDef {
/// Table name.
pub name: String,
/// Column definitions.
pub columns: Vec<ColumnDef>,
/// Primary key column name (if any).
pub primary_key: Option<String>,
}
impl TableDef {
/// Creates a new table definition.
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
columns: Vec::new(),
primary_key: None,
}
}
/// Adds a column.
pub fn column(mut self, col: ColumnDef) -> Self {
if col.primary_key {
self.primary_key = Some(col.name.clone());
}
self.columns.push(col);
self
}
/// Gets column definition by name.
pub fn get_column(&self, name: &str) -> Option<&ColumnDef> {
self.columns.iter().find(|c| c.name == name)
}
/// Returns column names.
pub fn column_names(&self) -> Vec<String> {
self.columns.iter().map(|c| c.name.clone()).collect()
}
/// Validates a row against this table definition.
pub fn validate_row(&self, row: &Row) -> Result<(), SqlError> {
for col in &self.columns {
let value = row.get_or_null(&col.name);
col.validate(&value)?;
}
Ok(())
}
}
/// Index on a table column.
#[derive(Debug)]
pub struct TableIndex {
/// Index name.
pub name: String,
/// Column being indexed.
pub column: String,
/// Whether index enforces uniqueness.
pub unique: bool,
/// B-tree index data: value -> row IDs.
data: BTreeMap<SqlValue, HashSet<RowId>>,
}
impl TableIndex {
/// Creates a new index.
pub fn new(name: impl Into<String>, column: impl Into<String>, unique: bool) -> Self {
Self {
name: name.into(),
column: column.into(),
unique,
data: BTreeMap::new(),
}
}
/// Inserts a value-rowid mapping.
pub fn insert(&mut self, value: SqlValue, row_id: RowId) -> Result<(), SqlError> {
// SQL semantics: a UNIQUE constraint does not apply to NULL values,
// so multiple rows may leave a unique column null.
if self.unique && !value.is_null() {
if let Some(existing) = self.data.get(&value) {
if !existing.is_empty() {
return Err(SqlError::ConstraintViolation(format!(
"Unique constraint violation on index '{}'",
self.name
)));
}
}
}
self.data.entry(value).or_default().insert(row_id);
Ok(())
}
/// Removes a value-rowid mapping.
pub fn remove(&mut self, value: &SqlValue, row_id: &RowId) {
if let Some(ids) = self.data.get_mut(value) {
ids.remove(row_id);
if ids.is_empty() {
self.data.remove(value);
}
}
}
/// Looks up rows by exact value.
pub fn lookup(&self, value: &SqlValue) -> Vec<RowId> {
self.data
.get(value)
.map(|ids| ids.iter().cloned().collect())
.unwrap_or_default()
}
/// Range query.
pub fn range(&self, start: Option<&SqlValue>, end: Option<&SqlValue>) -> Vec<RowId> {
let mut result = Vec::new();
for (key, ids) in &self.data {
let in_range = match (start, end) {
(Some(s), Some(e)) => key >= s && key <= e,
(Some(s), None) => key >= s,
(None, Some(e)) => key <= e,
(None, None) => true,
};
if in_range {
result.extend(ids.iter().cloned());
}
}
result
}
}
/// A SQL table with data.
pub struct Table {
/// Table definition.
pub def: TableDef,
/// Row storage: row ID -> row data.
rows: RwLock<HashMap<RowId, Row>>,
/// Next row ID.
next_id: RwLock<u64>,
/// Indexes.
indexes: RwLock<HashMap<String, TableIndex>>,
/// Primary key index (if any).
pk_index: RwLock<Option<String>>,
}
impl Table {
/// Creates a new table.
pub fn new(def: TableDef) -> Self {
let pk_col = def.primary_key.clone();
let unique_cols: Vec<String> = def
.columns
.iter()
.filter(|c| c.unique && !c.primary_key)
.map(|c| c.name.clone())
.collect();
let table = Self {
def,
rows: RwLock::new(HashMap::new()),
next_id: RwLock::new(1),
indexes: RwLock::new(HashMap::new()),
pk_index: RwLock::new(None),
};
{
let mut indexes = table.indexes.write();
// Create primary key index if defined
if let Some(pk) = pk_col {
let idx_name = format!("pk_{}", pk);
indexes.insert(idx_name.clone(), TableIndex::new(&idx_name, &pk, true));
*table.pk_index.write() = Some(idx_name);
}
// Create indexes for unique columns
for col in unique_cols {
let idx_name = format!("unique_{}", col);
indexes.insert(idx_name.clone(), TableIndex::new(&idx_name, &col, true));
}
}
table
}
/// Returns table name.
pub fn name(&self) -> &str {
&self.def.name
}
/// Returns row count.
pub fn count(&self) -> usize {
self.rows.read().len()
}
/// Creates an index on a column.
pub fn create_index(
&self,
name: impl Into<String>,
column: impl Into<String>,
unique: bool,
) -> Result<(), SqlError> {
let name = name.into();
let column = column.into();
let mut indexes = self.indexes.write();
if indexes.contains_key(&name) {
return Err(SqlError::InvalidOperation(format!("Index '{}' already exists", name)));
}
let mut index = TableIndex::new(&name, &column, unique);
// Index existing rows
let rows = self.rows.read();
for (row_id, row) in rows.iter() {
let value = row.get_or_null(&column);
index.insert(value, *row_id)?;
}
indexes.insert(name, index);
Ok(())
}
/// Drops an index.
pub fn drop_index(&self, name: &str) -> Result<(), SqlError> {
let mut indexes = self.indexes.write();
if indexes.remove(name).is_none() {
return Err(SqlError::InvalidOperation(format!("Index '{}' not found", name)));
}
Ok(())
}
/// Inserts a row.
pub fn insert(&self, values: HashMap<String, SqlValue>) -> Result<RowId, SqlError> {
let mut next_id = self.next_id.write();
let row_id = RowId(*next_id);
*next_id += 1;
let row = Row::from_map(row_id, values);
// Validate row
self.def.validate_row(&row)?;
// Check uniqueness constraints via indexes
{
let mut indexes = self.indexes.write();
for (_, index) in indexes.iter_mut() {
if index.unique {
let value = row.get_or_null(&index.column);
if !value.is_null() && !index.lookup(&value).is_empty() {
return Err(SqlError::ConstraintViolation(format!(
"Unique constraint violation on column '{}'",
index.column
)));
}
}
}
// Update indexes
for (_, index) in indexes.iter_mut() {
let value = row.get_or_null(&index.column);
index.insert(value, row_id)?;
}
}
// Insert row
self.rows.write().insert(row_id, row);
Ok(row_id)
}
/// Gets a row by ID.
pub fn get(&self, id: RowId) -> Option<Row> {
self.rows.read().get(&id).cloned()
}
/// Updates a row.
pub fn update(&self, id: RowId, updates: HashMap<String, SqlValue>) -> Result<(), SqlError> {
let mut rows = self.rows.write();
let row = rows.get_mut(&id).ok_or_else(|| {
SqlError::InvalidOperation(format!("Row {} not found", id))
})?;
let old_values: HashMap<String, SqlValue> = updates
.keys()
.map(|k| (k.clone(), row.get_or_null(k)))
.collect();
// Validate updates
for (col, value) in &updates {
if let Some(col_def) = self.def.get_column(col) {
col_def.validate(value)?;
}
}
// Update indexes
{
let mut indexes = self.indexes.write();
for (_, index) in indexes.iter_mut() {
if let Some(new_value) = updates.get(&index.column) {
let old_value = old_values.get(&index.column).cloned().unwrap_or(SqlValue::Null);
index.remove(&old_value, &id);
// On a unique violation, restore the old entry so the index
// stays consistent instead of silently dropping the old value.
if let Err(e) = index.insert(new_value.clone(), id) {
let _ = index.insert(old_value, id);
return Err(e);
}
}
}
}
// Apply updates
for (col, value) in updates {
row.set(&col, value);
}
Ok(())
}
/// Deletes a row.
pub fn delete(&self, id: RowId) -> Result<bool, SqlError> {
let mut rows = self.rows.write();
if let Some(row) = rows.remove(&id) {
// Update indexes
let mut indexes = self.indexes.write();
for (_, index) in indexes.iter_mut() {
let value = row.get_or_null(&index.column);
index.remove(&value, &id);
}
Ok(true)
} else {
Ok(false)
}
}
/// Scans all rows.
pub fn scan(&self) -> Vec<Row> {
self.rows.read().values().cloned().collect()
}
/// Scans with filter function.
pub fn scan_filter<F>(&self, predicate: F) -> Vec<Row>
where
F: Fn(&Row) -> bool,
{
self.rows
.read()
.values()
.filter(|row| predicate(row))
.cloned()
.collect()
}
/// Looks up rows by index.
pub fn lookup_index(&self, index_name: &str, value: &SqlValue) -> Vec<Row> {
let indexes = self.indexes.read();
let rows = self.rows.read();
if let Some(index) = indexes.get(index_name) {
index
.lookup(value)
.into_iter()
.filter_map(|id| rows.get(&id).cloned())
.collect()
} else {
Vec::new()
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn create_test_table() -> Table {
let def = TableDef::new("users")
.column(ColumnDef::new("id", SqlType::Integer).primary_key())
.column(ColumnDef::new("name", SqlType::Text).not_null())
.column(ColumnDef::new("age", SqlType::Integer))
.column(ColumnDef::new("email", SqlType::Text).unique());
Table::new(def)
}
#[test]
fn test_table_insert() {
let table = create_test_table();
let mut values = HashMap::new();
values.insert("id".to_string(), SqlValue::Integer(1));
values.insert("name".to_string(), SqlValue::Text("Alice".to_string()));
values.insert("age".to_string(), SqlValue::Integer(30));
values.insert("email".to_string(), SqlValue::Text("alice@example.com".to_string()));
let row_id = table.insert(values).unwrap();
assert_eq!(table.count(), 1);
let row = table.get(row_id).unwrap();
assert_eq!(row.get("name"), Some(&SqlValue::Text("Alice".to_string())));
}
#[test]
fn test_table_not_null() {
let table = create_test_table();
let mut values = HashMap::new();
values.insert("id".to_string(), SqlValue::Integer(1));
// Missing required "name" field
let result = table.insert(values);
assert!(result.is_err());
}
#[test]
fn test_table_unique() {
let table = create_test_table();
let mut values1 = HashMap::new();
values1.insert("id".to_string(), SqlValue::Integer(1));
values1.insert("name".to_string(), SqlValue::Text("Alice".to_string()));
values1.insert("email".to_string(), SqlValue::Text("test@example.com".to_string()));
table.insert(values1).unwrap();
let mut values2 = HashMap::new();
values2.insert("id".to_string(), SqlValue::Integer(2));
values2.insert("name".to_string(), SqlValue::Text("Bob".to_string()));
values2.insert("email".to_string(), SqlValue::Text("test@example.com".to_string()));
let result = table.insert(values2);
assert!(result.is_err()); // Duplicate email
}
#[test]
fn test_table_update() {
let table = create_test_table();
let mut values = HashMap::new();
values.insert("id".to_string(), SqlValue::Integer(1));
values.insert("name".to_string(), SqlValue::Text("Alice".to_string()));
values.insert("age".to_string(), SqlValue::Integer(30));
let row_id = table.insert(values).unwrap();
let mut updates = HashMap::new();
updates.insert("age".to_string(), SqlValue::Integer(31));
table.update(row_id, updates).unwrap();
let row = table.get(row_id).unwrap();
assert_eq!(row.get("age"), Some(&SqlValue::Integer(31)));
}
#[test]
fn test_table_delete() {
let table = create_test_table();
let mut values = HashMap::new();
values.insert("id".to_string(), SqlValue::Integer(1));
values.insert("name".to_string(), SqlValue::Text("Alice".to_string()));
let row_id = table.insert(values).unwrap();
assert_eq!(table.count(), 1);
table.delete(row_id).unwrap();
assert_eq!(table.count(), 0);
}
#[test]
fn test_table_index() {
let table = create_test_table();
table.create_index("idx_name", "name", false).unwrap();
let mut values = HashMap::new();
values.insert("id".to_string(), SqlValue::Integer(1));
values.insert("name".to_string(), SqlValue::Text("Alice".to_string()));
table.insert(values).unwrap();
let rows = table.lookup_index("idx_name", &SqlValue::Text("Alice".to_string()));
assert_eq!(rows.len(), 1);
}
}
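`TableIndex` above keeps a `BTreeMap` from value to row-id set. Because the map is ordered, a range scan can walk only the matching keys via `BTreeMap::range` instead of iterating every entry the way the linear loop in `TableIndex::range` does. A standalone sketch under simplified types; `MiniIndex`, `i64` keys, and `u64` row ids are illustrative:

```rust
use std::collections::{BTreeMap, BTreeSet};

/// Sketch of a secondary index: ordered map from value to row ids.
pub struct MiniIndex {
    data: BTreeMap<i64, BTreeSet<u64>>,
}

impl MiniIndex {
    pub fn new() -> Self {
        Self { data: BTreeMap::new() }
    }

    pub fn insert(&mut self, value: i64, row_id: u64) {
        self.data.entry(value).or_default().insert(row_id);
    }

    /// Inclusive range scan; BTreeMap::range visits only keys in bounds,
    /// so cost is proportional to the matches, not the whole index.
    pub fn range(&self, start: i64, end: i64) -> Vec<u64> {
        self.data
            .range(start..=end)
            .flat_map(|(_, ids)| ids.iter().copied())
            .collect()
    }
}

fn main() {
    let mut idx = MiniIndex::new();
    idx.insert(10, 1);
    idx.insert(20, 2);
    idx.insert(30, 3);
    assert_eq!(idx.range(15, 30), vec![2, 3]);
}
```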


@@ -0,0 +1,355 @@
//! ACID transaction support for SQL.
use super::row::RowId;
use super::types::{SqlError, SqlValue};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
/// Transaction identifier.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TransactionId(pub u64);
impl TransactionId {
/// Creates a new transaction ID.
pub fn new() -> Self {
static COUNTER: AtomicU64 = AtomicU64::new(1);
TransactionId(COUNTER.fetch_add(1, Ordering::SeqCst))
}
}
impl Default for TransactionId {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Display for TransactionId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "txn_{}", self.0)
}
}
/// Transaction state.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TransactionState {
/// Transaction is active.
Active,
/// Transaction is committed.
Committed,
/// Transaction is rolled back.
RolledBack,
}
/// A single operation in a transaction.
#[derive(Clone, Debug)]
pub enum TransactionOp {
/// Insert a row.
Insert {
table: String,
row_id: RowId,
values: HashMap<String, SqlValue>,
},
/// Update a row.
Update {
table: String,
row_id: RowId,
old_values: HashMap<String, SqlValue>,
new_values: HashMap<String, SqlValue>,
},
/// Delete a row.
Delete {
table: String,
row_id: RowId,
old_values: HashMap<String, SqlValue>,
},
}
/// Transaction for tracking changes.
#[derive(Debug)]
pub struct Transaction {
/// Transaction ID.
pub id: TransactionId,
/// Transaction state.
pub state: TransactionState,
/// Operations in this transaction.
operations: Vec<TransactionOp>,
/// Start time.
pub started_at: u64,
/// Isolation level.
pub isolation: IsolationLevel,
}
/// Transaction isolation levels.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum IsolationLevel {
/// Read uncommitted (dirty reads allowed).
ReadUncommitted,
/// Read committed (no dirty reads).
ReadCommitted,
/// Repeatable read (no non-repeatable reads).
RepeatableRead,
/// Serializable (full isolation).
Serializable,
}
impl Default for IsolationLevel {
fn default() -> Self {
IsolationLevel::ReadCommitted
}
}
impl Transaction {
/// Creates a new transaction.
pub fn new(isolation: IsolationLevel) -> Self {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_millis() as u64;
Self {
id: TransactionId::new(),
state: TransactionState::Active,
operations: Vec::new(),
started_at: now,
isolation,
}
}
/// Returns true if the transaction is active.
pub fn is_active(&self) -> bool {
self.state == TransactionState::Active
}
/// Records an insert operation.
pub fn record_insert(&mut self, table: String, row_id: RowId, values: HashMap<String, SqlValue>) {
self.operations.push(TransactionOp::Insert {
table,
row_id,
values,
});
}
/// Records an update operation.
pub fn record_update(
&mut self,
table: String,
row_id: RowId,
old_values: HashMap<String, SqlValue>,
new_values: HashMap<String, SqlValue>,
) {
self.operations.push(TransactionOp::Update {
table,
row_id,
old_values,
new_values,
});
}
/// Records a delete operation.
pub fn record_delete(&mut self, table: String, row_id: RowId, old_values: HashMap<String, SqlValue>) {
self.operations.push(TransactionOp::Delete {
table,
row_id,
old_values,
});
}
/// Returns operations for rollback (in reverse order).
pub fn rollback_ops(&self) -> impl Iterator<Item = &TransactionOp> {
self.operations.iter().rev()
}
/// Returns operations for commit.
pub fn commit_ops(&self) -> &[TransactionOp] {
&self.operations
}
/// Marks the transaction as committed.
pub fn mark_committed(&mut self) {
self.state = TransactionState::Committed;
}
/// Marks the transaction as rolled back.
pub fn mark_rolled_back(&mut self) {
self.state = TransactionState::RolledBack;
}
}
/// Transaction manager.
pub struct TransactionManager {
/// Active transactions.
transactions: RwLock<HashMap<TransactionId, Transaction>>,
}
impl TransactionManager {
/// Creates a new transaction manager.
pub fn new() -> Self {
Self {
transactions: RwLock::new(HashMap::new()),
}
}
/// Begins a new transaction.
pub fn begin(&self, isolation: IsolationLevel) -> TransactionId {
let txn = Transaction::new(isolation);
let id = txn.id;
self.transactions.write().insert(id, txn);
id
}
/// Gets a transaction by ID.
pub fn get(&self, id: TransactionId) -> Option<Transaction> {
self.transactions.read().get(&id).cloned()
}
/// Records an operation in a transaction.
pub fn record_op(&self, id: TransactionId, op: TransactionOp) -> Result<(), SqlError> {
let mut txns = self.transactions.write();
let txn = txns
.get_mut(&id)
.ok_or_else(|| SqlError::Transaction(format!("Transaction {} not found", id)))?;
if !txn.is_active() {
return Err(SqlError::Transaction(format!("Transaction {} is not active", id)));
}
txn.operations.push(op);
Ok(())
}
/// Commits a transaction.
pub fn commit(&self, id: TransactionId) -> Result<Vec<TransactionOp>, SqlError> {
let mut txns = self.transactions.write();
let txn = txns
.get_mut(&id)
.ok_or_else(|| SqlError::Transaction(format!("Transaction {} not found", id)))?;
if !txn.is_active() {
return Err(SqlError::Transaction(format!("Transaction {} is not active", id)));
}
txn.mark_committed();
let ops = txn.operations.clone();
txns.remove(&id);
Ok(ops)
}
/// Rolls back a transaction, returning operations to undo.
pub fn rollback(&self, id: TransactionId) -> Result<Vec<TransactionOp>, SqlError> {
let mut txns = self.transactions.write();
let txn = txns
.get_mut(&id)
.ok_or_else(|| SqlError::Transaction(format!("Transaction {} not found", id)))?;
if !txn.is_active() {
return Err(SqlError::Transaction(format!("Transaction {} is not active", id)));
}
txn.mark_rolled_back();
let ops: Vec<TransactionOp> = txn.operations.iter().rev().cloned().collect();
txns.remove(&id);
Ok(ops)
}
/// Returns the number of active transactions.
pub fn active_count(&self) -> usize {
self.transactions.read().len()
}
/// Checks if a transaction exists and is active.
pub fn is_active(&self, id: TransactionId) -> bool {
self.transactions
.read()
.get(&id)
.map(|t| t.is_active())
.unwrap_or(false)
}
}
impl Default for TransactionManager {
fn default() -> Self {
Self::new()
}
}
impl Clone for Transaction {
fn clone(&self) -> Self {
Self {
id: self.id,
state: self.state,
operations: self.operations.clone(),
started_at: self.started_at,
isolation: self.isolation,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_transaction_lifecycle() {
let manager = TransactionManager::new();
let txn_id = manager.begin(IsolationLevel::ReadCommitted);
assert!(manager.is_active(txn_id));
assert_eq!(manager.active_count(), 1);
let mut values = HashMap::new();
values.insert("name".to_string(), SqlValue::Text("Alice".to_string()));
manager
.record_op(
txn_id,
TransactionOp::Insert {
table: "users".to_string(),
row_id: RowId(1),
values,
},
)
.unwrap();
let ops = manager.commit(txn_id).unwrap();
assert_eq!(ops.len(), 1);
assert_eq!(manager.active_count(), 0);
}
#[test]
fn test_transaction_rollback() {
let manager = TransactionManager::new();
let txn_id = manager.begin(IsolationLevel::ReadCommitted);
let mut values = HashMap::new();
values.insert("name".to_string(), SqlValue::Text("Bob".to_string()));
manager
.record_op(
txn_id,
TransactionOp::Insert {
table: "users".to_string(),
row_id: RowId(1),
values,
},
)
.unwrap();
let ops = manager.rollback(txn_id).unwrap();
assert_eq!(ops.len(), 1);
assert_eq!(manager.active_count(), 0);
}
#[test]
fn test_transaction_not_found() {
let manager = TransactionManager::new();
let fake_id = TransactionId(99999);
assert!(!manager.is_active(fake_id));
assert!(manager.commit(fake_id).is_err());
assert!(manager.rollback(fake_id).is_err());
}
}
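The `TransactionManager` above rolls back by replaying an undo log newest-first, so each undone operation sees the state its predecessor expected. A self-contained sketch of that pattern against a plain `HashMap`; the `Op` type and key/value shapes are illustrative, not the crate's API:

```rust
use std::collections::HashMap;

/// Undo-log entries: enough inverse information to reverse each change.
#[derive(Clone, Debug, PartialEq)]
pub enum Op {
    Insert { key: String },
    Update { key: String, old: i64 },
}

/// Applies the undo log in reverse order, restoring the prior state.
pub fn rollback(state: &mut HashMap<String, i64>, log: &[Op]) {
    for op in log.iter().rev() {
        match op {
            Op::Insert { key } => {
                state.remove(key); // undo insert: drop the row
            }
            Op::Update { key, old } => {
                state.insert(key.clone(), *old); // undo update: restore old value
            }
        }
    }
}

fn main() {
    let mut state = HashMap::new();
    state.insert("a".to_string(), 1);
    let mut log = Vec::new();
    // Forward operations, each recorded before/as it is applied.
    log.push(Op::Update { key: "a".to_string(), old: 1 });
    state.insert("a".to_string(), 2);
    log.push(Op::Insert { key: "b".to_string() });
    state.insert("b".to_string(), 9);
    rollback(&mut state, &log);
    assert_eq!(state.get("a"), Some(&1));
    assert_eq!(state.get("b"), None);
}
```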


@@ -0,0 +1,368 @@
//! SQL type system.
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use std::cmp::Ordering;
use std::hash::{Hash, Hasher};
use thiserror::Error;
/// SQL data types.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum SqlType {
/// 64-bit signed integer.
Integer,
/// 64-bit floating point.
Real,
/// UTF-8 text string.
Text,
/// Binary data.
Blob,
/// Boolean value.
Boolean,
/// Unix timestamp in milliseconds.
Timestamp,
/// Null type.
Null,
}
impl SqlType {
/// Parses type from string.
pub fn from_str(s: &str) -> Option<Self> {
match s.to_uppercase().as_str() {
"INTEGER" | "INT" | "BIGINT" | "SMALLINT" => Some(SqlType::Integer),
"REAL" | "FLOAT" | "DOUBLE" => Some(SqlType::Real),
"TEXT" | "VARCHAR" | "CHAR" | "STRING" => Some(SqlType::Text),
"BLOB" | "BINARY" | "BYTES" => Some(SqlType::Blob),
"BOOLEAN" | "BOOL" => Some(SqlType::Boolean),
"TIMESTAMP" | "DATETIME" | "DATE" => Some(SqlType::Timestamp),
_ => None,
}
}
/// Returns the default value for this type.
pub fn default_value(&self) -> SqlValue {
match self {
SqlType::Integer => SqlValue::Integer(0),
SqlType::Real => SqlValue::Real(0.0),
SqlType::Text => SqlValue::Text(String::new()),
SqlType::Blob => SqlValue::Blob(Vec::new()),
SqlType::Boolean => SqlValue::Boolean(false),
SqlType::Timestamp => SqlValue::Timestamp(0),
SqlType::Null => SqlValue::Null,
}
}
}
/// SQL value types.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum SqlValue {
/// Null value.
Null,
/// Integer value.
Integer(i64),
/// Real (floating point) value.
Real(f64),
/// Text string value.
Text(String),
/// Binary blob value.
Blob(Vec<u8>),
/// Boolean value.
Boolean(bool),
/// Timestamp value (Unix ms).
Timestamp(u64),
}
impl SqlValue {
/// Returns the SQL type of this value.
pub fn sql_type(&self) -> SqlType {
match self {
SqlValue::Null => SqlType::Null,
SqlValue::Integer(_) => SqlType::Integer,
SqlValue::Real(_) => SqlType::Real,
SqlValue::Text(_) => SqlType::Text,
SqlValue::Blob(_) => SqlType::Blob,
SqlValue::Boolean(_) => SqlType::Boolean,
SqlValue::Timestamp(_) => SqlType::Timestamp,
}
}
/// Returns true if this is a null value.
pub fn is_null(&self) -> bool {
matches!(self, SqlValue::Null)
}
/// Converts to integer if possible.
pub fn as_integer(&self) -> Option<i64> {
match self {
SqlValue::Integer(i) => Some(*i),
SqlValue::Real(f) => Some(*f as i64),
SqlValue::Boolean(b) => Some(if *b { 1 } else { 0 }),
SqlValue::Timestamp(t) => Some(*t as i64),
_ => None,
}
}
/// Converts to real if possible.
pub fn as_real(&self) -> Option<f64> {
match self {
SqlValue::Integer(i) => Some(*i as f64),
SqlValue::Real(f) => Some(*f),
SqlValue::Timestamp(t) => Some(*t as f64),
_ => None,
}
}
/// Converts to text if possible.
pub fn as_text(&self) -> Option<&str> {
match self {
SqlValue::Text(s) => Some(s),
_ => None,
}
}
/// Converts to boolean if possible.
pub fn as_boolean(&self) -> Option<bool> {
match self {
SqlValue::Boolean(b) => Some(*b),
SqlValue::Integer(i) => Some(*i != 0),
_ => None,
}
}
/// Converts to JSON value.
pub fn to_json(&self) -> JsonValue {
match self {
SqlValue::Null => JsonValue::Null,
SqlValue::Integer(i) => JsonValue::Number((*i).into()),
SqlValue::Real(f) => serde_json::Number::from_f64(*f)
.map(JsonValue::Number)
.unwrap_or(JsonValue::Null),
SqlValue::Text(s) => JsonValue::String(s.clone()),
SqlValue::Blob(b) => JsonValue::String(hex::encode(b)),
SqlValue::Boolean(b) => JsonValue::Bool(*b),
SqlValue::Timestamp(t) => JsonValue::Number((*t).into()),
}
}
/// Returns a numeric type order for comparison purposes.
fn type_order(&self) -> u8 {
match self {
SqlValue::Null => 0,
SqlValue::Boolean(_) => 1,
SqlValue::Integer(_) => 2,
SqlValue::Real(_) => 3,
SqlValue::Text(_) => 4,
SqlValue::Blob(_) => 5,
SqlValue::Timestamp(_) => 6,
}
}
/// Creates from JSON value.
pub fn from_json(value: &JsonValue) -> Self {
match value {
JsonValue::Null => SqlValue::Null,
JsonValue::Bool(b) => SqlValue::Boolean(*b),
JsonValue::Number(n) => {
if let Some(i) = n.as_i64() {
SqlValue::Integer(i)
} else if let Some(f) = n.as_f64() {
SqlValue::Real(f)
} else {
SqlValue::Null
}
}
JsonValue::String(s) => SqlValue::Text(s.clone()),
_ => SqlValue::Text(value.to_string()),
}
}
}
impl PartialEq for SqlValue {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(SqlValue::Null, SqlValue::Null) => true,
(SqlValue::Integer(a), SqlValue::Integer(b)) => a == b,
(SqlValue::Real(a), SqlValue::Real(b)) => a == b,
(SqlValue::Text(a), SqlValue::Text(b)) => a == b,
(SqlValue::Blob(a), SqlValue::Blob(b)) => a == b,
(SqlValue::Boolean(a), SqlValue::Boolean(b)) => a == b,
(SqlValue::Timestamp(a), SqlValue::Timestamp(b)) => a == b,
// Cross-type comparisons
(SqlValue::Integer(a), SqlValue::Real(b)) => (*a as f64) == *b,
(SqlValue::Real(a), SqlValue::Integer(b)) => *a == (*b as f64),
_ => false,
}
}
}
impl Eq for SqlValue {}
impl Hash for SqlValue {
fn hash<H: Hasher>(&self, state: &mut H) {
std::mem::discriminant(self).hash(state);
match self {
SqlValue::Null => {}
SqlValue::Integer(i) => i.hash(state),
SqlValue::Real(f) => f.to_bits().hash(state),
SqlValue::Text(s) => s.hash(state),
SqlValue::Blob(b) => b.hash(state),
SqlValue::Boolean(b) => b.hash(state),
SqlValue::Timestamp(t) => t.hash(state),
}
}
}
impl PartialOrd for SqlValue {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for SqlValue {
fn cmp(&self, other: &Self) -> Ordering {
match (self, other) {
(SqlValue::Null, SqlValue::Null) => Ordering::Equal,
(SqlValue::Null, _) => Ordering::Less,
(_, SqlValue::Null) => Ordering::Greater,
(SqlValue::Integer(a), SqlValue::Integer(b)) => a.cmp(b),
(SqlValue::Real(a), SqlValue::Real(b)) => {
// IEEE 754 total ordering (correct for NaN and negative values)
a.total_cmp(b)
}
(SqlValue::Text(a), SqlValue::Text(b)) => a.cmp(b),
(SqlValue::Blob(a), SqlValue::Blob(b)) => a.cmp(b),
(SqlValue::Boolean(a), SqlValue::Boolean(b)) => a.cmp(b),
(SqlValue::Timestamp(a), SqlValue::Timestamp(b)) => a.cmp(b),
(SqlValue::Integer(a), SqlValue::Real(b)) => (*a as f64).total_cmp(b),
(SqlValue::Real(a), SqlValue::Integer(b)) => a.total_cmp(&(*b as f64)),
// Different types: order by type discriminant
_ => self.type_order().cmp(&other.type_order()),
}
}
}
impl Default for SqlValue {
fn default() -> Self {
SqlValue::Null
}
}
impl std::fmt::Display for SqlValue {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SqlValue::Null => write!(f, "NULL"),
SqlValue::Integer(i) => write!(f, "{}", i),
SqlValue::Real(r) => write!(f, "{}", r),
SqlValue::Text(s) => write!(f, "'{}'", s),
SqlValue::Blob(b) => write!(f, "X'{}'", hex::encode(b)),
SqlValue::Boolean(b) => write!(f, "{}", if *b { "TRUE" } else { "FALSE" }),
SqlValue::Timestamp(t) => write!(f, "{}", t),
}
}
}
/// SQL errors.
#[derive(Debug, Error)]
pub enum SqlError {
/// Parse error.
#[error("Parse error: {0}")]
Parse(String),
/// Table not found.
#[error("Table not found: {0}")]
TableNotFound(String),
/// Table already exists.
#[error("Table already exists: {0}")]
TableExists(String),
/// Column not found.
#[error("Column not found: {0}")]
ColumnNotFound(String),
/// Type mismatch.
#[error("Type mismatch: expected {expected}, got {got}")]
TypeMismatch { expected: String, got: String },
/// Constraint violation.
#[error("Constraint violation: {0}")]
ConstraintViolation(String),
/// Primary key violation.
#[error("Primary key violation: duplicate key {0}")]
PrimaryKeyViolation(String),
/// Not null violation.
#[error("Not null violation: column {0} cannot be null")]
NotNullViolation(String),
/// Transaction error.
#[error("Transaction error: {0}")]
Transaction(String),
/// Invalid operation.
#[error("Invalid operation: {0}")]
InvalidOperation(String),
/// Unsupported feature.
#[error("Unsupported: {0}")]
Unsupported(String),
/// Internal error.
#[error("Internal error: {0}")]
Internal(String),
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_sql_type_from_str() {
assert_eq!(SqlType::from_str("INTEGER"), Some(SqlType::Integer));
assert_eq!(SqlType::from_str("int"), Some(SqlType::Integer));
assert_eq!(SqlType::from_str("TEXT"), Some(SqlType::Text));
assert_eq!(SqlType::from_str("varchar"), Some(SqlType::Text));
assert_eq!(SqlType::from_str("BOOLEAN"), Some(SqlType::Boolean));
assert_eq!(SqlType::from_str("unknown"), None);
}
#[test]
fn test_sql_value_conversions() {
let int_val = SqlValue::Integer(42);
assert_eq!(int_val.as_integer(), Some(42));
assert_eq!(int_val.as_real(), Some(42.0));
let real_val = SqlValue::Real(3.14);
assert_eq!(real_val.as_real(), Some(3.14));
assert_eq!(real_val.as_integer(), Some(3));
let text_val = SqlValue::Text("hello".to_string());
assert_eq!(text_val.as_text(), Some("hello"));
}
#[test]
fn test_sql_value_comparison() {
assert_eq!(SqlValue::Integer(5), SqlValue::Integer(5));
assert!(SqlValue::Integer(5) < SqlValue::Integer(10));
assert!(SqlValue::Text("a".to_string()) < SqlValue::Text("b".to_string()));
assert!(SqlValue::Null < SqlValue::Integer(0));
}
#[test]
fn test_sql_value_json() {
let val = SqlValue::Integer(42);
assert_eq!(val.to_json(), serde_json::json!(42));
let val = SqlValue::Text("hello".to_string());
assert_eq!(val.to_json(), serde_json::json!("hello"));
let json = serde_json::json!(true);
let val = SqlValue::from_json(&json);
assert_eq!(val, SqlValue::Boolean(true));
}
}

View file

@ -0,0 +1,505 @@
# Phase 10: Advanced Database Features
> Implementation plan for SQL, Graph, and Replication features in Synor Database L2
## Overview
These advanced features extend the Synor Database to support:
1. **Relational (SQL)** - SQLite-compatible query subset for structured data
2. **Graph Store** - Relationship queries for connected data
3. **Replication** - Raft consensus for high availability
## Feature 1: Relational (SQL) Store
### Purpose
Provide a familiar SQL interface for developers who need structured relational queries, joins, and ACID transactions.
### Architecture
```text
┌─────────────────────────────────────────────────────────────┐
│ SQL QUERY LAYER │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────────┐ │
│ │ SQL Parser │ │ Planner │ │ Executor │ │
│ │ (sqlparser) │ │ (logical) │ │ (physical) │ │
│ └──────────────┘ └──────────────┘ └──────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ Table Storage Engine │ │
│ │ - Row-oriented storage │ │
│ │ - B-tree indexes │ │
│ │ - Transaction log (WAL) │ │
│ └──────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
```
### Supported SQL Subset
| Category | Statements |
|----------|------------|
| DDL | CREATE TABLE, DROP TABLE, ALTER TABLE |
| DML | SELECT, INSERT, UPDATE, DELETE |
| Clauses | WHERE, ORDER BY, LIMIT, OFFSET, GROUP BY, HAVING |
| Joins | INNER JOIN, LEFT JOIN, RIGHT JOIN |
| Functions | COUNT, SUM, AVG, MIN, MAX, COALESCE |
| Operators | =, !=, <, >, <=, >=, AND, OR, NOT, IN, LIKE |
### Data Types
| SQL Type | Rust Type | Storage |
|----------|-----------|---------|
| INTEGER | i64 | 8 bytes |
| REAL | f64 | 8 bytes |
| TEXT | String | Variable |
| BLOB | `Vec<u8>` | Variable |
| BOOLEAN | bool | 1 byte |
| TIMESTAMP | u64 | 8 bytes (Unix ms) |
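The fixed-width rows of the table above can be sketched as plain little-endian encodings. These `encode_*` helpers are illustrative names, not part of the planned API; variable-width TEXT/BLOB encoding is omitted.

```rust
// Sketch (assumption): fixed-width encodings matching the storage column above.
fn encode_integer(v: i64) -> [u8; 8] {
    v.to_le_bytes() // INTEGER: 8 bytes
}

fn encode_real(v: f64) -> [u8; 8] {
    v.to_le_bytes() // REAL: 8 bytes
}

fn encode_timestamp(ms: u64) -> [u8; 8] {
    ms.to_le_bytes() // TIMESTAMP: 8 bytes, Unix milliseconds
}

fn encode_boolean(v: bool) -> [u8; 1] {
    [v as u8] // BOOLEAN: 1 byte
}
```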
### Implementation Components
```
crates/synor-database/src/sql/
├── mod.rs # Module exports
├── parser.rs # SQL parsing (sqlparser-rs)
├── planner.rs # Query planning & optimization
├── executor.rs # Query execution engine
├── table.rs # Table definition & storage
├── row.rs # Row representation
├── types.rs # SQL type system
├── transaction.rs # ACID transactions
└── index.rs # SQL-specific indexing
```
### API Design
```rust
// Table definition
pub struct TableDef {
pub name: String,
pub columns: Vec<ColumnDef>,
pub primary_key: Option<String>,
pub indexes: Vec<IndexDef>,
}
pub struct ColumnDef {
pub name: String,
pub data_type: SqlType,
pub nullable: bool,
pub default: Option<SqlValue>,
}
// SQL execution
pub struct SqlEngine {
tables: HashMap<String, Table>,
transaction_log: TransactionLog,
}
impl SqlEngine {
pub fn execute(&mut self, sql: &str) -> Result<SqlResult, SqlError>;
pub fn begin_transaction(&mut self) -> TransactionId;
pub fn commit(&mut self, txn: TransactionId) -> Result<(), SqlError>;
pub fn rollback(&mut self, txn: TransactionId);
}
```
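The row-oriented storage idea behind `SqlEngine` can be sketched with a toy table; the real engine adds parsing, planning, B-tree indexes, and transactions. `MiniTable`, `Value`, and `select_eq` are hypothetical names for illustration only.

```rust
// Minimal sketch of row-oriented table storage with a full-scan
// equality filter (the unindexed WHERE path).
#[derive(Clone, Debug, PartialEq)]
enum Value {
    Null,
    Integer(i64),
    Text(String),
}

struct MiniTable {
    columns: Vec<String>,
    rows: Vec<Vec<Value>>,
}

impl MiniTable {
    fn insert(&mut self, row: Vec<Value>) {
        assert_eq!(row.len(), self.columns.len(), "row arity must match schema");
        self.rows.push(row);
    }

    // SELECT * WHERE <col> = <value>, evaluated as a full table scan.
    fn select_eq(&self, col: &str, value: &Value) -> Vec<&Vec<Value>> {
        let idx = self.columns.iter().position(|c| c == col).expect("unknown column");
        self.rows.iter().filter(|r| &r[idx] == value).collect()
    }
}
```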
### Gateway Endpoints
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/db/:db/sql` | POST | Execute SQL query |
| `/db/:db/sql/tables` | GET | List tables |
| `/db/:db/sql/tables/:table` | GET | Get table schema |
| `/db/:db/sql/tables/:table` | DELETE | Drop table |
---
## Feature 2: Graph Store
### Purpose
Enable relationship-based queries for social networks, knowledge graphs, recommendation engines, and any connected data.
### Architecture
```text
┌─────────────────────────────────────────────────────────────┐
│ GRAPH QUERY LAYER │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────────┐ │
│ │ Graph Query │ │ Traversal │ │ Path Finding │ │
│ │ Parser │ │ Engine │ │ (Dijkstra) │ │
│ └──────────────┘ └──────────────┘ └──────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ Graph Storage Engine │ │
│ │ - Adjacency list storage │ │
│ │ - Edge index (source, target, type) │ │
│ │ - Property storage │ │
│ └──────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
```
### Data Model
```text
Node (Vertex):
- id: NodeId (32 bytes)
- labels: Vec<String>
- properties: JsonValue
Edge (Relationship):
- id: EdgeId (32 bytes)
- source: NodeId
- target: NodeId
- edge_type: String
- properties: JsonValue
- directed: bool
```
### Query Language (Simplified Cypher-like)
```
// Find all friends of user Alice
MATCH (a:User {name: "Alice"})-[:FRIEND]->(friend)
RETURN friend
// Find shortest path between two nodes
MATCH path = shortestPath((a:User {id: "123"})-[*]-(b:User {id: "456"}))
RETURN path
// Find mutual friends
MATCH (a:User {name: "Alice"})-[:FRIEND]->(mutual)<-[:FRIEND]-(b:User {name: "Bob"})
RETURN mutual
```
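A query parser for this language would recognize node labels and relationship types from patterns like those above. As a very rough sketch (assumed behavior, nothing like the real parser), a single-hop pattern can be split on its `:` delimiters:

```rust
// Extracts the first node label and relationship type from a single-hop
// pattern such as (a:User {name: "Alice"})-[:FRIEND]->(friend).
// Illustrative only; a real parser would tokenize properly.
fn parse_pattern(p: &str) -> Option<(String, String)> {
    // Node label: identifier immediately after the first ':'.
    let node_label: String = p
        .split(':')
        .nth(1)?
        .chars()
        .take_while(|c| c.is_alphanumeric())
        .collect();
    // Relationship type: identifier after "[:".
    let rel_type: String = p
        .split("[:")
        .nth(1)?
        .chars()
        .take_while(|c| c.is_alphanumeric() || *c == '_')
        .collect();
    Some((node_label, rel_type))
}
```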
### Implementation Components
```
crates/synor-database/src/graph/
├── mod.rs # Module exports
├── node.rs # Node definition & storage
├── edge.rs # Edge definition & storage
├── store.rs # Graph storage engine
├── query.rs # Query language parser
├── traversal.rs # Graph traversal algorithms
├── path.rs # Path finding (BFS, DFS, Dijkstra)
└── index.rs # Graph-specific indexes
```
### API Design
```rust
pub struct Node {
pub id: NodeId,
pub labels: Vec<String>,
pub properties: JsonValue,
}
pub struct Edge {
pub id: EdgeId,
pub source: NodeId,
pub target: NodeId,
pub edge_type: String,
pub properties: JsonValue,
}
pub struct GraphStore {
nodes: HashMap<NodeId, Node>,
edges: HashMap<EdgeId, Edge>,
adjacency: HashMap<NodeId, Vec<EdgeId>>, // outgoing
reverse_adj: HashMap<NodeId, Vec<EdgeId>>, // incoming
}
impl GraphStore {
// Node operations
pub fn create_node(&mut self, labels: Vec<String>, props: JsonValue) -> NodeId;
pub fn get_node(&self, id: &NodeId) -> Option<&Node>;
pub fn update_node(&mut self, id: &NodeId, props: JsonValue) -> Result<(), GraphError>;
pub fn delete_node(&mut self, id: &NodeId) -> Result<(), GraphError>;
// Edge operations
pub fn create_edge(&mut self, source: NodeId, target: NodeId, edge_type: &str, props: JsonValue) -> EdgeId;
pub fn get_edge(&self, id: &EdgeId) -> Option<&Edge>;
pub fn delete_edge(&mut self, id: &EdgeId) -> Result<(), GraphError>;
// Traversal
pub fn neighbors(&self, id: &NodeId, direction: Direction) -> Vec<&Node>;
pub fn edges_of(&self, id: &NodeId, direction: Direction) -> Vec<&Edge>;
pub fn shortest_path(&self, from: &NodeId, to: &NodeId) -> Option<Vec<NodeId>>;
pub fn traverse(&self, start: &NodeId, query: &TraversalQuery) -> Vec<TraversalResult>;
}
```
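The `shortest_path` routine boils down to Dijkstra's algorithm over the adjacency lists. A self-contained sketch over `u64` node ids and edge weights (the real store uses 32-byte `NodeId`s and `Edge` records, and would return the path, not just its cost):

```rust
use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashMap};

// Dijkstra over a weighted adjacency list: adj[node] = [(neighbor, weight)].
// Returns the cost of the cheapest path, or None if unreachable.
fn dijkstra(adj: &HashMap<u64, Vec<(u64, u64)>>, from: u64, to: u64) -> Option<u64> {
    let mut dist: HashMap<u64, u64> = HashMap::new();
    let mut heap = BinaryHeap::new();
    heap.push(Reverse((0u64, from))); // Reverse turns the max-heap into a min-heap
    while let Some(Reverse((d, n))) = heap.pop() {
        if n == to {
            return Some(d);
        }
        if *dist.get(&n).unwrap_or(&u64::MAX) < d {
            continue; // stale heap entry
        }
        for &(next, w) in adj.get(&n).into_iter().flatten() {
            let nd = d + w;
            if nd < *dist.get(&next).unwrap_or(&u64::MAX) {
                dist.insert(next, nd);
                heap.push(Reverse((nd, next)));
            }
        }
    }
    None
}
```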
### Gateway Endpoints
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/db/:db/graph/nodes` | POST | Create node |
| `/db/:db/graph/nodes/:id` | GET | Get node |
| `/db/:db/graph/nodes/:id` | PUT | Update node |
| `/db/:db/graph/nodes/:id` | DELETE | Delete node |
| `/db/:db/graph/edges` | POST | Create edge |
| `/db/:db/graph/edges/:id` | GET | Get edge |
| `/db/:db/graph/edges/:id` | DELETE | Delete edge |
| `/db/:db/graph/query` | POST | Execute graph query |
| `/db/:db/graph/path` | POST | Find shortest path |
| `/db/:db/graph/traverse` | POST | Traverse from node |
---
## Feature 3: Replication (Raft Consensus)
### Purpose
Provide high availability and fault tolerance through distributed consensus, ensuring data consistency across multiple nodes.
### Architecture
```text
┌─────────────────────────────────────────────────────────────┐
│ RAFT CONSENSUS LAYER │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────────┐ │
│ │ Leader │ │ Follower │ │ Candidate │ │
│ │ Election │ │ Replication │ │ (Election) │ │
│ └──────────────┘ └──────────────┘ └──────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ Log Replication │ │
│ │ - Append entries │ │
│ │ - Commit index │ │
│ │ - Log compaction (snapshots) │ │
│ └──────────────────────────────────────────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ State Machine │ │
│ │ - Apply committed entries │ │
│ │ - Database operations │ │
│ └──────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
```
### Raft Protocol Overview
```text
Leader Election:
1. Followers timeout → become Candidate
2. Candidate requests votes from peers
3. Majority votes → become Leader
4. Leader sends heartbeats to maintain authority
Log Replication:
1. Client sends write to Leader
2. Leader appends to local log
3. Leader replicates to Followers
4. Majority acknowledge → entry committed
5. Leader applies to state machine
6. Leader responds to client
```
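Step 4 above ("majority acknowledge → entry committed") reduces to a median computation: given each node's highest replicated index, the leader may advance its commit index to the largest index stored on a strict majority. A sketch (the real rule additionally requires the entry at that index to carry the leader's current term):

```rust
// match_index holds, for every node including the leader, the highest
// log index known to be replicated there. The ((len - 1) / 2)-th
// smallest value is the largest index held by a strict majority.
fn majority_commit_index(match_index: &[u64]) -> u64 {
    let mut sorted = match_index.to_vec();
    sorted.sort_unstable();
    sorted[(sorted.len() - 1) / 2]
}
```
For example, with `match_index = [5, 3, 7]` two of three nodes hold index 5, so the leader may commit through 5.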
### Implementation Components
```
crates/synor-database/src/replication/
├── mod.rs # Module exports
├── raft.rs # Core Raft implementation
├── state.rs # Node state (Leader/Follower/Candidate)
├── log.rs # Replicated log
├── rpc.rs # RPC messages (AppendEntries, RequestVote)
├── election.rs # Leader election logic
├── snapshot.rs # Log compaction & snapshots
├── cluster.rs # Cluster membership
└── client.rs # Client for forwarding to leader
```
### API Design
```rust
#[derive(Clone, Copy, PartialEq)]
pub enum NodeRole {
Leader,
Follower,
Candidate,
}
pub struct RaftConfig {
pub node_id: u64,
pub peers: Vec<PeerAddress>,
pub election_timeout_ms: (u64, u64), // min, max
pub heartbeat_interval_ms: u64,
pub snapshot_threshold: u64,
}
pub struct LogEntry {
pub term: u64,
pub index: u64,
pub command: Command,
}
pub enum Command {
// Database operations
KvSet { key: String, value: Vec<u8> },
KvDelete { key: String },
DocInsert { collection: String, doc: JsonValue },
DocUpdate { collection: String, id: DocumentId, update: JsonValue },
DocDelete { collection: String, id: DocumentId },
// ... other operations
}
pub struct RaftNode {
config: RaftConfig,
state: RaftState,
log: ReplicatedLog,
state_machine: Arc<Database>,
}
impl RaftNode {
pub async fn start(&mut self) -> Result<(), RaftError>;
pub async fn propose(&self, command: Command) -> Result<(), RaftError>;
pub fn is_leader(&self) -> bool;
pub fn leader_id(&self) -> Option<u64>;
pub fn status(&self) -> ClusterStatus;
}
// RPC Messages
pub struct AppendEntries {
pub term: u64,
pub leader_id: u64,
pub prev_log_index: u64,
pub prev_log_term: u64,
pub entries: Vec<LogEntry>,
pub leader_commit: u64,
}
pub struct RequestVote {
pub term: u64,
pub candidate_id: u64,
pub last_log_index: u64,
pub last_log_term: u64,
}
```
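The `prev_log_index`/`prev_log_term` fields of `AppendEntries` implement Raft's log-matching check: a follower rejects the RPC unless its own log has an entry with that term at that index. A sketch, representing the follower's log as a 1-indexed slice of per-entry terms:

```rust
// Returns true if the follower's log is consistent with the leader's
// view at (prev_log_index, prev_log_term), i.e. the RPC may be accepted.
fn log_matches(log_terms: &[u64], prev_log_index: u64, prev_log_term: u64) -> bool {
    if prev_log_index == 0 {
        return true; // appending from the very start of the log
    }
    match log_terms.get(prev_log_index as usize - 1) {
        Some(&term) => term == prev_log_term,
        None => false, // follower's log is too short
    }
}
```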
### Cluster Configuration
```yaml
# docker-compose.raft.yml
services:
db-node-1:
image: synor/database:latest
environment:
RAFT_NODE_ID: 1
RAFT_PEERS: "db-node-2:5000,db-node-3:5000"
RAFT_ELECTION_TIMEOUT: "150-300"
RAFT_HEARTBEAT_MS: 50
ports:
- "8484:8484" # HTTP API
- "5000:5000" # Raft RPC
db-node-2:
image: synor/database:latest
environment:
RAFT_NODE_ID: 2
RAFT_PEERS: "db-node-1:5000,db-node-3:5000"
ports:
- "8485:8484"
- "5001:5000"
db-node-3:
image: synor/database:latest
environment:
RAFT_NODE_ID: 3
RAFT_PEERS: "db-node-1:5000,db-node-2:5000"
ports:
- "8486:8484"
- "5002:5000"
```
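On startup, each node would turn the `RAFT_PEERS` value from the compose file above into addresses. A sketch of that parsing (the `parse_peers` helper and the exact `host:port,host:port` format are assumptions taken from the config):

```rust
// Parses "db-node-2:5000,db-node-3:5000" into (host, port) pairs,
// silently skipping malformed entries.
fn parse_peers(raw: &str) -> Vec<(String, u16)> {
    raw.split(',')
        .filter_map(|p| {
            let (host, port) = p.trim().rsplit_once(':')?;
            Some((host.to_string(), port.parse().ok()?))
        })
        .collect()
}
```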
### Gateway Endpoints
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/cluster/status` | GET | Get cluster status |
| `/cluster/leader` | GET | Get current leader |
| `/cluster/nodes` | GET | List all nodes |
| `/cluster/nodes/:id` | DELETE | Remove node from cluster |
| `/cluster/nodes` | POST | Add node to cluster |
---
## Implementation Order
### Step 1: SQL Store
1. Add `sqlparser` dependency
2. Implement type system (`types.rs`)
3. Implement row storage (`row.rs`, `table.rs`)
4. Implement SQL parser wrapper (`parser.rs`)
5. Implement query planner (`planner.rs`)
6. Implement query executor (`executor.rs`)
7. Add transaction support (`transaction.rs`)
8. Add gateway endpoints
9. Write tests
### Step 2: Graph Store
1. Implement node/edge types (`node.rs`, `edge.rs`)
2. Implement graph storage (`store.rs`)
3. Implement query parser (`query.rs`)
4. Implement traversal algorithms (`traversal.rs`, `path.rs`)
5. Add graph indexes (`index.rs`)
6. Add gateway endpoints
7. Write tests
### Step 3: Replication
1. Implement Raft state machine (`state.rs`, `raft.rs`)
2. Implement replicated log (`log.rs`)
3. Implement RPC layer (`rpc.rs`)
4. Implement leader election (`election.rs`)
5. Implement log compaction (`snapshot.rs`)
6. Implement cluster management (`cluster.rs`)
7. Integrate with database operations
8. Add gateway endpoints
9. Write tests
10. Create Docker Compose for cluster
---
## Pricing Impact
| Feature | Operation | Cost (SYNOR) |
|---------|-----------|--------------|
| SQL | Query/million | 0.02 |
| SQL | Write/million | 0.05 |
| Graph | Traversal/million | 0.03 |
| Graph | Path query/million | 0.05 |
| Replication | Included | Base storage cost |
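Costs in the table scale linearly per million operations. As a worked example with an illustrative helper (not a real API): 5 million SQL queries at 0.02 SYNOR/million cost 0.1 SYNOR.

```rust
// Cost in SYNOR for `ops` operations at a per-million rate from the table.
fn cost_synor(ops: u64, rate_per_million: f64) -> f64 {
    ops as f64 / 1_000_000.0 * rate_per_million
}
```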
---
## Success Criteria
### SQL Store
- [ ] Parse and execute basic SELECT, INSERT, UPDATE, DELETE
- [ ] Support WHERE clauses with operators
- [ ] Support ORDER BY, LIMIT, OFFSET
- [ ] Support simple JOINs
- [ ] Support aggregate functions
- [ ] ACID transactions
### Graph Store
- [ ] Create/read/update/delete nodes and edges
- [ ] Traverse neighbors (in/out/both)
- [ ] Find shortest path between nodes
- [ ] Execute pattern matching queries
- [ ] Support property filters
### Replication
- [ ] Leader election works correctly
- [ ] Log replication achieves consensus
- [ ] Reads from any node (eventual consistency)
- [ ] Writes only through leader
- [ ] Node failure handled gracefully
- [ ] Log compaction reduces storage