feat(database): add SQL, Graph, and Raft Replication modules
- SQL store with SQLite-compatible subset (sqlparser 0.43) - CREATE TABLE, INSERT, SELECT, UPDATE, DELETE - WHERE clauses, ORDER BY, LIMIT - Aggregates (COUNT, SUM, AVG, MIN, MAX) - UNIQUE and NOT NULL constraints - BTreeMap-based indexes - Graph store for relationship-based queries - Nodes with labels and properties - Edges with types and weights - BFS/DFS traversal - Dijkstra shortest path - Cypher-like query parser (MATCH, CREATE, DELETE, SET) - Raft consensus replication for high availability - Leader election with randomized timeouts - Log replication with AppendEntries RPC - Snapshot management for log compaction - Cluster configuration and joint consensus - Full RPC message serialization All 159 tests pass.
This commit is contained in:
parent
ab4c967a97
commit
8da34bc73d
25 changed files with 9842 additions and 0 deletions
|
|
@ -15,6 +15,7 @@ synor-storage = { path = "../synor-storage" }
|
|||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
borsh.workspace = true
|
||||
bincode = "1.3"
|
||||
|
||||
# Utilities
|
||||
thiserror.workspace = true
|
||||
|
|
@ -36,6 +37,9 @@ indexmap = "2.2"
|
|||
axum.workspace = true
|
||||
tower-http = { version = "0.5", features = ["cors", "trace"] }
|
||||
|
||||
# SQL parsing
|
||||
sqlparser = "0.43"
|
||||
|
||||
# Vector operations (for AI/RAG)
|
||||
# Using pure Rust for portability
|
||||
|
||||
|
|
|
|||
301
crates/synor-database/src/graph/edge.rs
Normal file
301
crates/synor-database/src/graph/edge.rs
Normal file
|
|
@ -0,0 +1,301 @@
|
|||
//! Graph edge (relationship) definition.
|
||||
|
||||
use super::node::NodeId;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value as JsonValue;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
/// Unique edge identifier.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct EdgeId(pub [u8; 32]);
|
||||
|
||||
impl EdgeId {
|
||||
/// Creates a new unique edge ID.
|
||||
pub fn new() -> Self {
|
||||
static COUNTER: AtomicU64 = AtomicU64::new(1);
|
||||
let id = COUNTER.fetch_add(1, Ordering::SeqCst);
|
||||
let mut bytes = [0u8; 32];
|
||||
bytes[..8].copy_from_slice(&id.to_be_bytes());
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_nanos() as u64;
|
||||
bytes[8..16].copy_from_slice(&now.to_be_bytes());
|
||||
EdgeId(*blake3::hash(&bytes).as_bytes())
|
||||
}
|
||||
|
||||
/// Creates from raw bytes.
|
||||
pub fn from_bytes(bytes: [u8; 32]) -> Self {
|
||||
EdgeId(bytes)
|
||||
}
|
||||
|
||||
/// Creates from hex string.
|
||||
pub fn from_hex(hex: &str) -> Option<Self> {
|
||||
let bytes = hex::decode(hex).ok()?;
|
||||
if bytes.len() != 32 {
|
||||
return None;
|
||||
}
|
||||
let mut arr = [0u8; 32];
|
||||
arr.copy_from_slice(&bytes);
|
||||
Some(EdgeId(arr))
|
||||
}
|
||||
|
||||
/// Returns the bytes.
|
||||
pub fn as_bytes(&self) -> &[u8; 32] {
|
||||
&self.0
|
||||
}
|
||||
|
||||
/// Converts to hex string.
|
||||
pub fn to_hex(&self) -> String {
|
||||
hex::encode(self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for EdgeId {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for EdgeId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "edge_{}", hex::encode(&self.0[..8]))
|
||||
}
|
||||
}
|
||||
|
||||
/// An edge (relationship) in the graph.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Edge {
|
||||
/// Unique edge ID.
|
||||
pub id: EdgeId,
|
||||
/// Source node ID.
|
||||
pub source: NodeId,
|
||||
/// Target node ID.
|
||||
pub target: NodeId,
|
||||
/// Edge type (relationship type), e.g., "FRIEND", "OWNS".
|
||||
pub edge_type: String,
|
||||
/// Properties stored as JSON.
|
||||
pub properties: JsonValue,
|
||||
/// Whether this is a directed edge.
|
||||
pub directed: bool,
|
||||
/// Weight for path-finding algorithms.
|
||||
pub weight: f64,
|
||||
/// Creation timestamp.
|
||||
pub created_at: u64,
|
||||
}
|
||||
|
||||
impl Edge {
|
||||
/// Creates a new directed edge.
|
||||
pub fn new(source: NodeId, target: NodeId, edge_type: impl Into<String>, properties: JsonValue) -> Self {
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_millis() as u64;
|
||||
|
||||
Self {
|
||||
id: EdgeId::new(),
|
||||
source,
|
||||
target,
|
||||
edge_type: edge_type.into(),
|
||||
properties,
|
||||
directed: true,
|
||||
weight: 1.0,
|
||||
created_at: now,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates an undirected edge.
|
||||
pub fn undirected(source: NodeId, target: NodeId, edge_type: impl Into<String>, properties: JsonValue) -> Self {
|
||||
let mut edge = Self::new(source, target, edge_type, properties);
|
||||
edge.directed = false;
|
||||
edge
|
||||
}
|
||||
|
||||
/// Sets the weight for this edge.
|
||||
pub fn with_weight(mut self, weight: f64) -> Self {
|
||||
self.weight = weight;
|
||||
self
|
||||
}
|
||||
|
||||
/// Returns the other end of this edge from the given node.
|
||||
pub fn other_end(&self, from: &NodeId) -> Option<NodeId> {
|
||||
if &self.source == from {
|
||||
Some(self.target)
|
||||
} else if &self.target == from && !self.directed {
|
||||
Some(self.source)
|
||||
} else if &self.target == from {
|
||||
// For directed edges, can't traverse backward
|
||||
None
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks if this edge connects the given node (as source or target).
|
||||
pub fn connects(&self, node: &NodeId) -> bool {
|
||||
&self.source == node || &self.target == node
|
||||
}
|
||||
|
||||
/// Checks if this edge connects two specific nodes.
|
||||
pub fn connects_pair(&self, a: &NodeId, b: &NodeId) -> bool {
|
||||
(&self.source == a && &self.target == b) ||
|
||||
(!self.directed && &self.source == b && &self.target == a)
|
||||
}
|
||||
|
||||
/// Gets a property value.
|
||||
pub fn get_property(&self, key: &str) -> Option<&JsonValue> {
|
||||
self.properties.get(key)
|
||||
}
|
||||
|
||||
/// Sets a property value.
|
||||
pub fn set_property(&mut self, key: &str, value: JsonValue) {
|
||||
if let Some(obj) = self.properties.as_object_mut() {
|
||||
obj.insert(key.to_string(), value);
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks if the edge matches a property filter.
|
||||
pub fn matches_properties(&self, filter: &JsonValue) -> bool {
|
||||
if let (Some(filter_obj), Some(props_obj)) = (filter.as_object(), self.properties.as_object()) {
|
||||
for (key, expected) in filter_obj {
|
||||
if let Some(actual) = props_obj.get(key) {
|
||||
if actual != expected {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
} else {
|
||||
filter == &self.properties || filter == &JsonValue::Object(serde_json::Map::new())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for creating edges.
|
||||
pub struct EdgeBuilder {
|
||||
source: NodeId,
|
||||
target: NodeId,
|
||||
edge_type: String,
|
||||
properties: serde_json::Map<String, JsonValue>,
|
||||
directed: bool,
|
||||
weight: f64,
|
||||
}
|
||||
|
||||
impl EdgeBuilder {
|
||||
/// Creates a new edge builder.
|
||||
pub fn new(source: NodeId, target: NodeId, edge_type: impl Into<String>) -> Self {
|
||||
Self {
|
||||
source,
|
||||
target,
|
||||
edge_type: edge_type.into(),
|
||||
properties: serde_json::Map::new(),
|
||||
directed: true,
|
||||
weight: 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets the edge as undirected.
|
||||
pub fn undirected(mut self) -> Self {
|
||||
self.directed = false;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the weight.
|
||||
pub fn weight(mut self, weight: f64) -> Self {
|
||||
self.weight = weight;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets a property.
|
||||
pub fn property(mut self, key: impl Into<String>, value: impl Into<JsonValue>) -> Self {
|
||||
self.properties.insert(key.into(), value.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Builds the edge.
|
||||
pub fn build(self) -> Edge {
|
||||
let mut edge = Edge::new(self.source, self.target, self.edge_type, JsonValue::Object(self.properties));
|
||||
edge.directed = self.directed;
|
||||
edge.weight = self.weight;
|
||||
edge
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// IDs must be unique across calls and round-trip through hex.
    #[test]
    fn test_edge_id() {
        let id1 = EdgeId::new();
        let id2 = EdgeId::new();
        assert_ne!(id1, id2);

        let hex = id1.to_hex();
        let id3 = EdgeId::from_hex(&hex).unwrap();
        assert_eq!(id1, id3);
    }

    /// `Edge::new` produces a directed edge with the given endpoints/type.
    #[test]
    fn test_edge_creation() {
        let source = NodeId::new();
        let target = NodeId::new();

        let edge = Edge::new(source, target, "FRIEND", serde_json::json!({"since": 2020}));

        assert_eq!(edge.source, source);
        assert_eq!(edge.target, target);
        assert_eq!(edge.edge_type, "FRIEND");
        assert!(edge.directed);
    }

    /// Builder flags (undirected, weight, properties) flow into the edge.
    #[test]
    fn test_edge_builder() {
        let source = NodeId::new();
        let target = NodeId::new();

        let edge = EdgeBuilder::new(source, target, "OWNS")
            .undirected()
            .weight(2.5)
            .property("percentage", 50)
            .build();

        assert!(!edge.directed);
        assert_eq!(edge.weight, 2.5);
        assert_eq!(edge.get_property("percentage"), Some(&serde_json::json!(50)));
    }

    /// `other_end` respects direction: directed edges only go forward.
    #[test]
    fn test_edge_other_end() {
        let source = NodeId::new();
        let target = NodeId::new();

        // Directed edge
        let directed = Edge::new(source, target, "A", serde_json::json!({}));
        assert_eq!(directed.other_end(&source), Some(target));
        assert_eq!(directed.other_end(&target), None); // Can't traverse backward

        // Undirected edge
        let undirected = Edge::undirected(source, target, "B", serde_json::json!({}));
        assert_eq!(undirected.other_end(&source), Some(target));
        assert_eq!(undirected.other_end(&target), Some(source));
    }

    /// `connects` is orientation-agnostic; `connects_pair` is not for
    /// directed edges.
    #[test]
    fn test_edge_connects() {
        let a = NodeId::new();
        let b = NodeId::new();
        let c = NodeId::new();

        let edge = Edge::new(a, b, "LINK", serde_json::json!({}));

        assert!(edge.connects(&a));
        assert!(edge.connects(&b));
        assert!(!edge.connects(&c));

        assert!(edge.connects_pair(&a, &b));
        assert!(!edge.connects_pair(&b, &a)); // Directed
    }
}
|
||||
18
crates/synor-database/src/graph/mod.rs
Normal file
18
crates/synor-database/src/graph/mod.rs
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
//! Graph store module for relationship-based queries.
//!
//! Provides a graph database with nodes, edges, and traversal algorithms
//! suitable for social networks, knowledge graphs, and recommendation engines.

pub mod edge;
pub mod node;
pub mod path;
pub mod query;
pub mod store;
pub mod traversal;

// Flat re-exports so callers can write `graph::Node` instead of
// `graph::node::Node`.
pub use edge::{Edge, EdgeId};
pub use node::{Node, NodeId};
pub use path::{PathFinder, PathResult};
pub use query::{GraphQuery, GraphQueryParser, MatchPattern, QueryResult};
pub use store::{Direction, GraphError, GraphStore};
pub use traversal::{TraversalQuery, TraversalResult, Traverser};
|
||||
315
crates/synor-database/src/graph/node.rs
Normal file
315
crates/synor-database/src/graph/node.rs
Normal file
|
|
@ -0,0 +1,315 @@
|
|||
//! Graph node (vertex) definition.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value as JsonValue;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
/// Unique node identifier.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct NodeId(pub [u8; 32]);
|
||||
|
||||
impl NodeId {
|
||||
/// Creates a new unique node ID.
|
||||
pub fn new() -> Self {
|
||||
static COUNTER: AtomicU64 = AtomicU64::new(1);
|
||||
let id = COUNTER.fetch_add(1, Ordering::SeqCst);
|
||||
let mut bytes = [0u8; 32];
|
||||
bytes[..8].copy_from_slice(&id.to_be_bytes());
|
||||
// Add timestamp for uniqueness
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_nanos() as u64;
|
||||
bytes[8..16].copy_from_slice(&now.to_be_bytes());
|
||||
NodeId(*blake3::hash(&bytes).as_bytes())
|
||||
}
|
||||
|
||||
/// Creates from raw bytes.
|
||||
pub fn from_bytes(bytes: [u8; 32]) -> Self {
|
||||
NodeId(bytes)
|
||||
}
|
||||
|
||||
/// Creates from hex string.
|
||||
pub fn from_hex(hex: &str) -> Option<Self> {
|
||||
let bytes = hex::decode(hex).ok()?;
|
||||
if bytes.len() != 32 {
|
||||
return None;
|
||||
}
|
||||
let mut arr = [0u8; 32];
|
||||
arr.copy_from_slice(&bytes);
|
||||
Some(NodeId(arr))
|
||||
}
|
||||
|
||||
/// Returns the bytes.
|
||||
pub fn as_bytes(&self) -> &[u8; 32] {
|
||||
&self.0
|
||||
}
|
||||
|
||||
/// Converts to hex string.
|
||||
pub fn to_hex(&self) -> String {
|
||||
hex::encode(self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for NodeId {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for NodeId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "node_{}", hex::encode(&self.0[..8]))
|
||||
}
|
||||
}
|
||||
|
||||
/// A node (vertex) in the graph.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Node {
|
||||
/// Unique node ID.
|
||||
pub id: NodeId,
|
||||
/// Labels (types) for this node (e.g., "User", "Product").
|
||||
pub labels: Vec<String>,
|
||||
/// Properties stored as JSON.
|
||||
pub properties: JsonValue,
|
||||
/// Creation timestamp.
|
||||
pub created_at: u64,
|
||||
/// Last update timestamp.
|
||||
pub updated_at: u64,
|
||||
}
|
||||
|
||||
impl Node {
|
||||
/// Creates a new node with the given labels and properties.
|
||||
pub fn new(labels: Vec<String>, properties: JsonValue) -> Self {
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_millis() as u64;
|
||||
|
||||
Self {
|
||||
id: NodeId::new(),
|
||||
labels,
|
||||
properties,
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a node with a specific ID.
|
||||
pub fn with_id(id: NodeId, labels: Vec<String>, properties: JsonValue) -> Self {
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_millis() as u64;
|
||||
|
||||
Self {
|
||||
id,
|
||||
labels,
|
||||
properties,
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if this node has the given label.
|
||||
pub fn has_label(&self, label: &str) -> bool {
|
||||
self.labels.iter().any(|l| l == label)
|
||||
}
|
||||
|
||||
/// Adds a label to this node.
|
||||
pub fn add_label(&mut self, label: impl Into<String>) {
|
||||
let label = label.into();
|
||||
if !self.labels.contains(&label) {
|
||||
self.labels.push(label);
|
||||
self.touch();
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes a label from this node.
|
||||
pub fn remove_label(&mut self, label: &str) -> bool {
|
||||
if let Some(pos) = self.labels.iter().position(|l| l == label) {
|
||||
self.labels.remove(pos);
|
||||
self.touch();
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets a property value.
|
||||
pub fn get_property(&self, key: &str) -> Option<&JsonValue> {
|
||||
self.properties.get(key)
|
||||
}
|
||||
|
||||
/// Sets a property value.
|
||||
pub fn set_property(&mut self, key: &str, value: JsonValue) {
|
||||
if let Some(obj) = self.properties.as_object_mut() {
|
||||
obj.insert(key.to_string(), value);
|
||||
self.touch();
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes a property.
|
||||
pub fn remove_property(&mut self, key: &str) -> Option<JsonValue> {
|
||||
if let Some(obj) = self.properties.as_object_mut() {
|
||||
let removed = obj.remove(key);
|
||||
if removed.is_some() {
|
||||
self.touch();
|
||||
}
|
||||
removed
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Updates the updated_at timestamp.
|
||||
fn touch(&mut self) {
|
||||
self.updated_at = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_millis() as u64;
|
||||
}
|
||||
|
||||
/// Checks if the node matches a property filter.
|
||||
pub fn matches_properties(&self, filter: &JsonValue) -> bool {
|
||||
if let (Some(filter_obj), Some(props_obj)) = (filter.as_object(), self.properties.as_object()) {
|
||||
for (key, expected) in filter_obj {
|
||||
if let Some(actual) = props_obj.get(key) {
|
||||
if actual != expected {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
} else {
|
||||
filter == &self.properties
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for creating nodes.
|
||||
pub struct NodeBuilder {
|
||||
labels: Vec<String>,
|
||||
properties: serde_json::Map<String, JsonValue>,
|
||||
}
|
||||
|
||||
impl NodeBuilder {
|
||||
/// Creates a new node builder.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
labels: Vec::new(),
|
||||
properties: serde_json::Map::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds a label.
|
||||
pub fn label(mut self, label: impl Into<String>) -> Self {
|
||||
self.labels.push(label.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Adds multiple labels.
|
||||
pub fn labels(mut self, labels: Vec<String>) -> Self {
|
||||
self.labels.extend(labels);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets a property.
|
||||
pub fn property(mut self, key: impl Into<String>, value: impl Into<JsonValue>) -> Self {
|
||||
self.properties.insert(key.into(), value.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Builds the node.
|
||||
pub fn build(self) -> Node {
|
||||
Node::new(self.labels, JsonValue::Object(self.properties))
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for NodeBuilder {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// IDs must be unique across calls and round-trip through hex.
    #[test]
    fn test_node_id() {
        let id1 = NodeId::new();
        let id2 = NodeId::new();
        assert_ne!(id1, id2);

        let hex = id1.to_hex();
        let id3 = NodeId::from_hex(&hex).unwrap();
        assert_eq!(id1, id3);
    }

    /// Labels and properties from the constructor are queryable.
    #[test]
    fn test_node_creation() {
        let node = Node::new(
            vec!["User".to_string()],
            serde_json::json!({"name": "Alice", "age": 30}),
        );

        assert!(node.has_label("User"));
        assert!(!node.has_label("Admin"));
        assert_eq!(
            node.get_property("name"),
            Some(&serde_json::json!("Alice"))
        );
    }

    /// Builder accumulates labels and properties.
    #[test]
    fn test_node_builder() {
        let node = NodeBuilder::new()
            .label("Person")
            .label("Developer")
            .property("name", "Bob")
            .property("skills", serde_json::json!(["Rust", "Python"]))
            .build();

        assert!(node.has_label("Person"));
        assert!(node.has_label("Developer"));
        assert_eq!(node.get_property("name"), Some(&serde_json::json!("Bob")));
    }

    /// Labels can be added and removed after construction.
    #[test]
    fn test_node_labels() {
        let mut node = Node::new(vec!["User".to_string()], serde_json::json!({}));

        node.add_label("Admin");
        assert!(node.has_label("Admin"));

        node.remove_label("Admin");
        assert!(!node.has_label("Admin"));
    }

    /// set/remove_property round-trip; removal returns the old value.
    #[test]
    fn test_node_properties() {
        let mut node = Node::new(vec![], serde_json::json!({}));

        node.set_property("key", serde_json::json!("value"));
        assert_eq!(node.get_property("key"), Some(&serde_json::json!("value")));

        let removed = node.remove_property("key");
        assert_eq!(removed, Some(serde_json::json!("value")));
        assert_eq!(node.get_property("key"), None);
    }

    /// Property filters are subset matches, not exact equality.
    #[test]
    fn test_node_matches() {
        let node = Node::new(
            vec!["User".to_string()],
            serde_json::json!({"name": "Alice", "age": 30, "active": true}),
        );

        assert!(node.matches_properties(&serde_json::json!({"name": "Alice"})));
        assert!(node.matches_properties(&serde_json::json!({"name": "Alice", "age": 30})));
        assert!(!node.matches_properties(&serde_json::json!({"name": "Bob"})));
    }
}
|
||||
485
crates/synor-database/src/graph/path.rs
Normal file
485
crates/synor-database/src/graph/path.rs
Normal file
|
|
@ -0,0 +1,485 @@
|
|||
//! Path finding algorithms for graphs.
|
||||
|
||||
use super::edge::Edge;
|
||||
use super::node::{Node, NodeId};
|
||||
use super::store::{Direction, GraphStore};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::cmp::Ordering;
|
||||
use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
|
||||
|
||||
/// Result of a path finding operation.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct PathResult {
|
||||
/// The path as a list of node IDs.
|
||||
pub nodes: Vec<NodeId>,
|
||||
/// The edges traversed.
|
||||
pub edges: Vec<Edge>,
|
||||
/// Total path length (number of hops or weighted distance).
|
||||
pub length: f64,
|
||||
/// Whether a path was found.
|
||||
pub found: bool,
|
||||
}
|
||||
|
||||
impl PathResult {
|
||||
/// Creates an empty (not found) result.
|
||||
pub fn not_found() -> Self {
|
||||
Self {
|
||||
nodes: Vec::new(),
|
||||
edges: Vec::new(),
|
||||
length: f64::INFINITY,
|
||||
found: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a found result.
|
||||
pub fn found(nodes: Vec<NodeId>, edges: Vec<Edge>, length: f64) -> Self {
|
||||
Self {
|
||||
nodes,
|
||||
edges,
|
||||
length,
|
||||
found: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// State for Dijkstra's algorithm priority queue.
#[derive(Clone)]
struct DijkstraState {
    // Node this queue entry refers to.
    node: NodeId,
    // Tentative distance from the start node at the time of insertion.
    distance: f64,
}

// Equality and ordering are defined on `distance` alone (the node is
// ignored). `BinaryHeap` is a max-heap, so `Ord` is deliberately reversed
// below to pop the *smallest* tentative distance first.
impl PartialEq for DijkstraState {
    fn eq(&self, other: &Self) -> bool {
        self.distance == other.distance
    }
}

impl Eq for DijkstraState {}

impl Ord for DijkstraState {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reverse ordering for min-heap
        // NaN distances fall back to Equal (partial_cmp returns None);
        // NOTE(review): assumes edge weights are finite — confirm upstream.
        other.distance.partial_cmp(&self.distance).unwrap_or(Ordering::Equal)
    }
}

impl PartialOrd for DijkstraState {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
|
||||
|
||||
/// Path finder for graph shortest path queries.
|
||||
pub struct PathFinder<'a> {
|
||||
store: &'a GraphStore,
|
||||
}
|
||||
|
||||
impl<'a> PathFinder<'a> {
|
||||
/// Creates a new path finder.
|
||||
pub fn new(store: &'a GraphStore) -> Self {
|
||||
Self { store }
|
||||
}
|
||||
|
||||
/// Finds the shortest path using BFS (unweighted).
|
||||
pub fn shortest_path_bfs(&self, from: &NodeId, to: &NodeId) -> PathResult {
|
||||
if from == to {
|
||||
return PathResult::found(vec![*from], Vec::new(), 0.0);
|
||||
}
|
||||
|
||||
let mut visited = HashSet::new();
|
||||
let mut queue: VecDeque<(NodeId, Vec<NodeId>, Vec<Edge>)> = VecDeque::new();
|
||||
|
||||
visited.insert(*from);
|
||||
queue.push_back((*from, vec![*from], Vec::new()));
|
||||
|
||||
while let Some((current, path, edges)) = queue.pop_front() {
|
||||
let neighbor_edges = self.store.edges_of(¤t, Direction::Both);
|
||||
|
||||
for edge in neighbor_edges {
|
||||
let neighbor = if edge.source == current {
|
||||
edge.target
|
||||
} else if !edge.directed || edge.target == current {
|
||||
edge.source
|
||||
} else {
|
||||
continue;
|
||||
};
|
||||
|
||||
if &neighbor == to {
|
||||
let mut final_path = path.clone();
|
||||
final_path.push(neighbor);
|
||||
let path_len = final_path.len();
|
||||
let mut final_edges = edges.clone();
|
||||
final_edges.push(edge);
|
||||
return PathResult::found(final_path, final_edges, path_len as f64 - 1.0);
|
||||
}
|
||||
|
||||
if !visited.contains(&neighbor) {
|
||||
visited.insert(neighbor);
|
||||
let mut new_path = path.clone();
|
||||
new_path.push(neighbor);
|
||||
let mut new_edges = edges.clone();
|
||||
new_edges.push(edge);
|
||||
queue.push_back((neighbor, new_path, new_edges));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
PathResult::not_found()
|
||||
}
|
||||
|
||||
/// Finds the shortest path using Dijkstra's algorithm (weighted).
|
||||
pub fn shortest_path_dijkstra(&self, from: &NodeId, to: &NodeId) -> PathResult {
|
||||
if from == to {
|
||||
return PathResult::found(vec![*from], Vec::new(), 0.0);
|
||||
}
|
||||
|
||||
let mut distances: HashMap<NodeId, f64> = HashMap::new();
|
||||
let mut previous: HashMap<NodeId, (NodeId, Edge)> = HashMap::new();
|
||||
let mut heap = BinaryHeap::new();
|
||||
let mut visited = HashSet::new();
|
||||
|
||||
distances.insert(*from, 0.0);
|
||||
heap.push(DijkstraState { node: *from, distance: 0.0 });
|
||||
|
||||
while let Some(DijkstraState { node: current, distance: dist }) = heap.pop() {
|
||||
if ¤t == to {
|
||||
// Reconstruct path
|
||||
let mut path = vec![current];
|
||||
let mut edges = Vec::new();
|
||||
let mut curr = current;
|
||||
|
||||
while let Some((prev, edge)) = previous.get(&curr) {
|
||||
path.push(*prev);
|
||||
edges.push(edge.clone());
|
||||
curr = *prev;
|
||||
}
|
||||
|
||||
path.reverse();
|
||||
edges.reverse();
|
||||
return PathResult::found(path, edges, dist);
|
||||
}
|
||||
|
||||
if visited.contains(¤t) {
|
||||
continue;
|
||||
}
|
||||
visited.insert(current);
|
||||
|
||||
let neighbor_edges = self.store.edges_of(¤t, Direction::Both);
|
||||
|
||||
for edge in neighbor_edges {
|
||||
let neighbor = if edge.source == current {
|
||||
edge.target
|
||||
} else if !edge.directed || edge.target == current {
|
||||
edge.source
|
||||
} else {
|
||||
continue;
|
||||
};
|
||||
|
||||
if visited.contains(&neighbor) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let new_dist = dist + edge.weight;
|
||||
let is_shorter = distances.get(&neighbor).map(|&d| new_dist < d).unwrap_or(true);
|
||||
|
||||
if is_shorter {
|
||||
distances.insert(neighbor, new_dist);
|
||||
previous.insert(neighbor, (current, edge.clone()));
|
||||
heap.push(DijkstraState { node: neighbor, distance: new_dist });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
PathResult::not_found()
|
||||
}
|
||||
|
||||
/// Finds all paths between two nodes up to a maximum length.
|
||||
pub fn all_paths(&self, from: &NodeId, to: &NodeId, max_length: usize) -> Vec<PathResult> {
|
||||
let mut results = Vec::new();
|
||||
let mut current_path = vec![*from];
|
||||
let mut current_edges = Vec::new();
|
||||
let mut visited = HashSet::new();
|
||||
visited.insert(*from);
|
||||
|
||||
self.find_all_paths_dfs(
|
||||
from,
|
||||
to,
|
||||
max_length,
|
||||
&mut current_path,
|
||||
&mut current_edges,
|
||||
&mut visited,
|
||||
&mut results,
|
||||
);
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
fn find_all_paths_dfs(
|
||||
&self,
|
||||
current: &NodeId,
|
||||
target: &NodeId,
|
||||
max_length: usize,
|
||||
path: &mut Vec<NodeId>,
|
||||
edges: &mut Vec<Edge>,
|
||||
visited: &mut HashSet<NodeId>,
|
||||
results: &mut Vec<PathResult>,
|
||||
) {
|
||||
if current == target {
|
||||
let total_weight: f64 = edges.iter().map(|e| e.weight).sum();
|
||||
results.push(PathResult::found(path.clone(), edges.clone(), total_weight));
|
||||
return;
|
||||
}
|
||||
|
||||
if path.len() > max_length {
|
||||
return;
|
||||
}
|
||||
|
||||
let neighbor_edges = self.store.edges_of(current, Direction::Both);
|
||||
|
||||
for edge in neighbor_edges {
|
||||
let neighbor = if edge.source == *current {
|
||||
edge.target
|
||||
} else if !edge.directed || edge.target == *current {
|
||||
edge.source
|
||||
} else {
|
||||
continue;
|
||||
};
|
||||
|
||||
if !visited.contains(&neighbor) {
|
||||
visited.insert(neighbor);
|
||||
path.push(neighbor);
|
||||
edges.push(edge.clone());
|
||||
|
||||
self.find_all_paths_dfs(&neighbor, target, max_length, path, edges, visited, results);
|
||||
|
||||
path.pop();
|
||||
edges.pop();
|
||||
visited.remove(&neighbor);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Finds the shortest path considering only specific edge types.
|
||||
pub fn shortest_path_by_type(&self, from: &NodeId, to: &NodeId, edge_types: &[String]) -> PathResult {
|
||||
if from == to {
|
||||
return PathResult::found(vec![*from], Vec::new(), 0.0);
|
||||
}
|
||||
|
||||
let mut visited = HashSet::new();
|
||||
let mut queue: VecDeque<(NodeId, Vec<NodeId>, Vec<Edge>)> = VecDeque::new();
|
||||
|
||||
visited.insert(*from);
|
||||
queue.push_back((*from, vec![*from], Vec::new()));
|
||||
|
||||
while let Some((current, path, edges)) = queue.pop_front() {
|
||||
let neighbor_edges = self.store.edges_of(¤t, Direction::Both);
|
||||
|
||||
for edge in neighbor_edges {
|
||||
// Filter by edge type
|
||||
if !edge_types.is_empty() && !edge_types.contains(&edge.edge_type) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let neighbor = if edge.source == current {
|
||||
edge.target
|
||||
} else if !edge.directed || edge.target == current {
|
||||
edge.source
|
||||
} else {
|
||||
continue;
|
||||
};
|
||||
|
||||
if &neighbor == to {
|
||||
let mut final_path = path.clone();
|
||||
final_path.push(neighbor);
|
||||
let path_len = final_path.len();
|
||||
let mut final_edges = edges.clone();
|
||||
final_edges.push(edge);
|
||||
return PathResult::found(final_path, final_edges, path_len as f64 - 1.0);
|
||||
}
|
||||
|
||||
if !visited.contains(&neighbor) {
|
||||
visited.insert(neighbor);
|
||||
let mut new_path = path.clone();
|
||||
new_path.push(neighbor);
|
||||
let mut new_edges = edges.clone();
|
||||
new_edges.push(edge);
|
||||
queue.push_back((neighbor, new_path, new_edges));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
PathResult::not_found()
|
||||
}
|
||||
|
||||
/// Checks if a path exists between two nodes.
|
||||
pub fn path_exists(&self, from: &NodeId, to: &NodeId) -> bool {
|
||||
if from == to {
|
||||
return true;
|
||||
}
|
||||
|
||||
let mut visited = HashSet::new();
|
||||
let mut queue = VecDeque::new();
|
||||
|
||||
visited.insert(*from);
|
||||
queue.push_back(*from);
|
||||
|
||||
while let Some(current) = queue.pop_front() {
|
||||
let neighbor_edges = self.store.edges_of(¤t, Direction::Both);
|
||||
|
||||
for edge in neighbor_edges {
|
||||
let neighbor = if edge.source == current {
|
||||
edge.target
|
||||
} else if !edge.directed || edge.target == current {
|
||||
edge.source
|
||||
} else {
|
||||
continue;
|
||||
};
|
||||
|
||||
if &neighbor == to {
|
||||
return true;
|
||||
}
|
||||
|
||||
if !visited.contains(&neighbor) {
|
||||
visited.insert(neighbor);
|
||||
queue.push_back(neighbor);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// Finds the distance (number of hops) between two nodes.
|
||||
pub fn distance(&self, from: &NodeId, to: &NodeId) -> Option<usize> {
|
||||
let result = self.shortest_path_bfs(from, to);
|
||||
if result.found {
|
||||
Some(result.nodes.len() - 1)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the diamond graph used by most tests:
    /// A -> B -> C -> D (three hops) and A -> E -> D (two hops).
    fn setup_graph() -> GraphStore {
        let store = GraphStore::new();

        // Create nodes A -> B -> C -> D
        //               \             /
        //                \->  E  -> /
        let a = store.create_node(vec![], serde_json::json!({"name": "A"}));
        let b = store.create_node(vec![], serde_json::json!({"name": "B"}));
        let c = store.create_node(vec![], serde_json::json!({"name": "C"}));
        let d = store.create_node(vec![], serde_json::json!({"name": "D"}));
        let e = store.create_node(vec![], serde_json::json!({"name": "E"}));

        store.create_edge(a, b, "LINK", serde_json::json!({})).unwrap();
        store.create_edge(b, c, "LINK", serde_json::json!({})).unwrap();
        store.create_edge(c, d, "LINK", serde_json::json!({})).unwrap();
        store.create_edge(a, e, "LINK", serde_json::json!({})).unwrap();
        store.create_edge(e, d, "LINK", serde_json::json!({})).unwrap();

        store
    }

    #[test]
    fn test_shortest_path_bfs() {
        let store = setup_graph();
        let finder = PathFinder::new(&store);

        let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({}));
        let a = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("A"))).unwrap();
        let d = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("D"))).unwrap();

        let result = finder.shortest_path_bfs(&a.id, &d.id);

        assert!(result.found);
        assert_eq!(result.length, 2.0); // A -> E -> D (shortest)
    }

    #[test]
    fn test_shortest_path_dijkstra() {
        let store = GraphStore::new();

        let a = store.create_node(vec![], serde_json::json!({"name": "A"}));
        let b = store.create_node(vec![], serde_json::json!({"name": "B"}));
        let c = store.create_node(vec![], serde_json::json!({"name": "C"}));

        // All edges carry the default weight:
        // A -> B -> C (two hops) and A -> C (one hop).
        // The original test built standalone `Edge` values with custom
        // weights, but never inserted them — that dead code is removed.
        store.create_edge(a, b, "LINK", serde_json::json!({})).unwrap();
        store.create_edge(b, c, "LINK", serde_json::json!({})).unwrap();
        store.create_edge(a, c, "DIRECT", serde_json::json!({})).unwrap();

        let finder = PathFinder::new(&store);
        let result = finder.shortest_path_dijkstra(&a, &c);

        assert!(result.found);
        // Either route is valid; the path can be at most A -> B -> C.
        assert!(result.nodes.len() <= 3);
    }

    #[test]
    fn test_path_not_found() {
        let store = GraphStore::new();

        let a = store.create_node(vec![], serde_json::json!({}));
        let b = store.create_node(vec![], serde_json::json!({}));
        // No edge between a and b

        let finder = PathFinder::new(&store);
        let result = finder.shortest_path_bfs(&a, &b);

        assert!(!result.found);
    }

    #[test]
    fn test_all_paths() {
        let store = setup_graph();
        let finder = PathFinder::new(&store);

        let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({}));
        let a = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("A"))).unwrap();
        let d = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("D"))).unwrap();

        let paths = finder.all_paths(&a.id, &d.id, 5);

        assert!(paths.len() >= 2); // At least A->B->C->D and A->E->D
    }

    #[test]
    fn test_path_exists() {
        let store = setup_graph();
        let finder = PathFinder::new(&store);

        let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({}));
        let a = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("A"))).unwrap();
        let d = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("D"))).unwrap();

        assert!(finder.path_exists(&a.id, &d.id));
    }

    #[test]
    fn test_distance() {
        let store = setup_graph();
        let finder = PathFinder::new(&store);

        let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({}));
        let a = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("A"))).unwrap();
        let d = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("D"))).unwrap();
        let b = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("B"))).unwrap();

        assert_eq!(finder.distance(&a.id, &b.id), Some(1));
        assert_eq!(finder.distance(&a.id, &d.id), Some(2)); // A -> E -> D
    }
}
|
||||
825
crates/synor-database/src/graph/query.rs
Normal file
825
crates/synor-database/src/graph/query.rs
Normal file
|
|
@ -0,0 +1,825 @@
|
|||
//! Simplified Cypher-like query language for graphs.
|
||||
|
||||
use super::edge::Edge;
|
||||
use super::node::{Node, NodeId};
|
||||
use super::store::{Direction, GraphError, GraphStore};
|
||||
use super::traversal::{TraversalDirection, TraversalQuery, Traverser};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value as JsonValue;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Parsed graph query.
///
/// Produced by [`GraphQueryParser::parse`]; read-only variants are executed
/// by `GraphQueryExecutor`, mutating variants require store-level handling.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum GraphQuery {
    /// MATCH query for pattern matching.
    Match {
        /// Node/relationship pattern to match against the graph.
        pattern: MatchPattern,
        /// Optional WHERE filter.
        where_clause: Option<WhereClause>,
        /// Projection items; defaults to `RETURN *` when absent in the text.
        return_items: Vec<ReturnItem>,
        /// Optional row limit.
        limit: Option<usize>,
    },
    /// CREATE query for creating nodes/edges.
    Create { elements: Vec<CreateElement> },
    /// DELETE query for removing nodes/edges.
    Delete { variable: String, detach: bool },
    /// SET query for updating properties.
    Set { variable: String, properties: JsonValue },
}
|
||||
|
||||
/// Pattern to match in the graph.
///
/// A chain: the `start` node followed by zero or more relationship hops,
/// each with its own target node pattern.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MatchPattern {
    /// Starting node pattern.
    pub start: NodePattern,
    /// Relationship patterns (edges and target nodes), in chain order.
    pub relationships: Vec<RelationshipPattern>,
}
|
||||
|
||||
/// Pattern for matching a node.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct NodePattern {
    /// Variable name for this node (binds the node in query results).
    pub variable: Option<String>,
    /// Required labels. An empty list matches any label.
    pub labels: Vec<String>,
    /// Property filters (exact-match JSON object), if any.
    pub properties: Option<JsonValue>,
}
|
||||
|
||||
/// Pattern for matching a relationship.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RelationshipPattern {
    /// Variable name for this relationship.
    pub variable: Option<String>,
    /// Edge type (relationship type); `None` matches any type.
    pub edge_type: Option<String>,
    /// Direction of the relationship.
    pub direction: RelationshipDirection,
    /// Target node pattern.
    pub target: NodePattern,
    /// Minimum hops (for variable-length paths).
    /// NOTE: the current parser always sets this to 1.
    pub min_hops: usize,
    /// Maximum hops (for variable-length paths).
    /// NOTE: the current parser always sets this to 1.
    pub max_hops: usize,
}
|
||||
|
||||
/// Direction of a relationship in a pattern.
///
/// Mirrors Cypher arrow syntax; mapped to `TraversalDirection` at execution.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq)]
pub enum RelationshipDirection {
    /// Outgoing: (a)-[:TYPE]->(b)
    Outgoing,
    /// Incoming: (a)<-[:TYPE]-(b)
    Incoming,
    /// Undirected: (a)-[:TYPE]-(b)
    Undirected,
}
|
||||
|
||||
/// WHERE clause conditions.
///
/// NOTE: the executor's `execute_match` currently receives but does not
/// apply the WHERE clause; filters are represented here for future use.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum WhereClause {
    /// Property comparison.
    PropertyEquals { variable: String, property: String, value: JsonValue },
    /// Property comparison (not equals).
    PropertyNotEquals { variable: String, property: String, value: JsonValue },
    /// Property greater than.
    PropertyGt { variable: String, property: String, value: JsonValue },
    /// Property less than.
    PropertyLt { variable: String, property: String, value: JsonValue },
    /// Property contains (for text).
    PropertyContains { variable: String, property: String, value: String },
    /// AND condition.
    And(Box<WhereClause>, Box<WhereClause>),
    /// OR condition.
    Or(Box<WhereClause>, Box<WhereClause>),
    /// NOT condition.
    Not(Box<WhereClause>),
}
|
||||
|
||||
/// Item to return from a query.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum ReturnItem {
    /// Return all variables (`RETURN *`).
    All,
    /// Return a specific variable.
    Variable(String),
    /// Return a property of a variable (`RETURN n.name`).
    Property { variable: String, property: String },
    /// Return with an alias (`RETURN n.name AS name`).
    Alias { item: Box<ReturnItem>, alias: String },
    /// Count aggregation; `None` means `COUNT(*)`.
    Count(Option<String>),
}
|
||||
|
||||
/// Element to create.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum CreateElement {
    /// Create a node.
    Node { variable: Option<String>, labels: Vec<String>, properties: JsonValue },
    /// Create a relationship between two previously bound variables.
    Relationship {
        /// Variable bound to the source node.
        from_var: String,
        /// Variable bound to the target node.
        to_var: String,
        /// Relationship type.
        edge_type: String,
        /// Relationship properties (JSON object).
        properties: JsonValue,
    },
}
|
||||
|
||||
/// Result of a graph query.
///
/// For read queries, `columns`/`rows` carry the projection; for mutating
/// queries the counters report what changed.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct QueryResult {
    /// Column names.
    pub columns: Vec<String>,
    /// Result rows; each inner vector is aligned with `columns`.
    pub rows: Vec<Vec<JsonValue>>,
    /// Number of nodes created.
    pub nodes_created: usize,
    /// Number of relationships created.
    pub relationships_created: usize,
    /// Number of nodes deleted.
    pub nodes_deleted: usize,
    /// Number of relationships deleted.
    pub relationships_deleted: usize,
    /// Number of properties set.
    pub properties_set: usize,
}
|
||||
|
||||
impl QueryResult {
|
||||
/// Creates an empty result.
|
||||
pub fn empty() -> Self {
|
||||
Self {
|
||||
columns: Vec::new(),
|
||||
rows: Vec::new(),
|
||||
nodes_created: 0,
|
||||
relationships_created: 0,
|
||||
nodes_deleted: 0,
|
||||
relationships_deleted: 0,
|
||||
properties_set: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parser for Cypher-like graph queries.
///
/// Stateless; all parsing entry points are associated functions.
pub struct GraphQueryParser;
|
||||
|
||||
impl GraphQueryParser {
|
||||
/// Parses a query string into a GraphQuery.
|
||||
pub fn parse(query: &str) -> Result<GraphQuery, GraphError> {
|
||||
let query = query.trim();
|
||||
let upper = query.to_uppercase();
|
||||
|
||||
if upper.starts_with("MATCH") {
|
||||
Self::parse_match(query)
|
||||
} else if upper.starts_with("CREATE") {
|
||||
Self::parse_create(query)
|
||||
} else if upper.starts_with("DELETE") || upper.starts_with("DETACH DELETE") {
|
||||
Self::parse_delete(query)
|
||||
} else if upper.starts_with("SET") {
|
||||
Self::parse_set(query)
|
||||
} else {
|
||||
Err(GraphError::InvalidOperation(format!("Unknown query type: {}", query)))
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_match(query: &str) -> Result<GraphQuery, GraphError> {
|
||||
// Simplified parser for: MATCH (var:Label {props})-[:TYPE]->(var2) WHERE ... RETURN ...
|
||||
let upper = query.to_uppercase();
|
||||
|
||||
// Find MATCH, WHERE, RETURN, LIMIT positions
|
||||
let match_end = upper.find("WHERE").or_else(|| upper.find("RETURN")).unwrap_or(query.len());
|
||||
let where_start = upper.find("WHERE");
|
||||
let return_start = upper.find("RETURN");
|
||||
let limit_start = upper.find("LIMIT");
|
||||
|
||||
// Parse pattern (between MATCH and WHERE/RETURN)
|
||||
let pattern_str = &query[5..match_end].trim();
|
||||
let pattern = Self::parse_pattern(pattern_str)?;
|
||||
|
||||
// Parse WHERE clause
|
||||
let where_clause = if let Some(ws) = where_start {
|
||||
let where_end = return_start.unwrap_or(query.len());
|
||||
let where_str = &query[ws + 5..where_end].trim();
|
||||
Some(Self::parse_where(where_str)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Parse RETURN clause
|
||||
let return_items = if let Some(rs) = return_start {
|
||||
let return_end = limit_start.unwrap_or(query.len());
|
||||
let return_str = &query[rs + 6..return_end].trim();
|
||||
Self::parse_return(return_str)?
|
||||
} else {
|
||||
vec![ReturnItem::All]
|
||||
};
|
||||
|
||||
// Parse LIMIT
|
||||
let limit = if let Some(ls) = limit_start {
|
||||
let limit_str = query[ls + 5..].trim();
|
||||
limit_str.parse().ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(GraphQuery::Match {
|
||||
pattern,
|
||||
where_clause,
|
||||
return_items,
|
||||
limit,
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_pattern(pattern: &str) -> Result<MatchPattern, GraphError> {
|
||||
// Parse node and relationship patterns
|
||||
// Format: (var:Label {props})-[:TYPE]->(var2:Label2)
|
||||
let mut chars = pattern.chars().peekable();
|
||||
let mut nodes = Vec::new();
|
||||
let mut relationships = Vec::new();
|
||||
|
||||
while chars.peek().is_some() {
|
||||
// Skip whitespace
|
||||
while chars.peek() == Some(&' ') {
|
||||
chars.next();
|
||||
}
|
||||
|
||||
if chars.peek() == Some(&'(') {
|
||||
let node = Self::parse_node_pattern(&mut chars)?;
|
||||
nodes.push(node);
|
||||
} else if chars.peek() == Some(&'-') || chars.peek() == Some(&'<') {
|
||||
let rel = Self::parse_relationship_pattern(&mut chars)?;
|
||||
relationships.push(rel);
|
||||
} else if chars.peek().is_some() {
|
||||
chars.next(); // Skip unknown characters
|
||||
}
|
||||
}
|
||||
|
||||
if nodes.is_empty() {
|
||||
return Err(GraphError::InvalidOperation("No node pattern found".to_string()));
|
||||
}
|
||||
|
||||
// Combine nodes with relationships
|
||||
let start = nodes.remove(0);
|
||||
for (i, rel) in relationships.iter_mut().enumerate() {
|
||||
if i < nodes.len() {
|
||||
rel.target = nodes[i].clone();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(MatchPattern { start, relationships })
|
||||
}
|
||||
|
||||
fn parse_node_pattern(chars: &mut std::iter::Peekable<std::str::Chars>) -> Result<NodePattern, GraphError> {
|
||||
// Consume '('
|
||||
chars.next();
|
||||
|
||||
let mut variable = None;
|
||||
let mut labels = Vec::new();
|
||||
let mut properties = None;
|
||||
|
||||
let mut buffer = String::new();
|
||||
|
||||
while let Some(&c) = chars.peek() {
|
||||
match c {
|
||||
')' => {
|
||||
chars.next();
|
||||
if !buffer.is_empty() && variable.is_none() {
|
||||
variable = Some(buffer.clone());
|
||||
}
|
||||
break;
|
||||
}
|
||||
':' => {
|
||||
chars.next();
|
||||
if !buffer.is_empty() && variable.is_none() {
|
||||
variable = Some(buffer.clone());
|
||||
}
|
||||
buffer.clear();
|
||||
|
||||
// Read label
|
||||
while let Some(&c) = chars.peek() {
|
||||
if c == ')' || c == '{' || c == ':' || c == ' ' {
|
||||
break;
|
||||
}
|
||||
buffer.push(c);
|
||||
chars.next();
|
||||
}
|
||||
if !buffer.is_empty() {
|
||||
labels.push(buffer.clone());
|
||||
buffer.clear();
|
||||
}
|
||||
}
|
||||
'{' => {
|
||||
chars.next();
|
||||
// Read properties JSON
|
||||
let mut props_str = String::from("{");
|
||||
let mut depth = 1;
|
||||
while let Some(&c) = chars.peek() {
|
||||
chars.next();
|
||||
props_str.push(c);
|
||||
if c == '{' {
|
||||
depth += 1;
|
||||
} else if c == '}' {
|
||||
depth -= 1;
|
||||
if depth == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
properties = serde_json::from_str(&props_str).ok();
|
||||
}
|
||||
' ' => {
|
||||
chars.next();
|
||||
}
|
||||
_ => {
|
||||
buffer.push(c);
|
||||
chars.next();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(NodePattern { variable, labels, properties })
|
||||
}
|
||||
|
||||
fn parse_relationship_pattern(chars: &mut std::iter::Peekable<std::str::Chars>) -> Result<RelationshipPattern, GraphError> {
|
||||
let mut direction = RelationshipDirection::Undirected;
|
||||
let mut edge_type = None;
|
||||
let mut variable = None;
|
||||
let min_hops = 1;
|
||||
let max_hops = 1;
|
||||
|
||||
// Check for incoming: <-
|
||||
if chars.peek() == Some(&'<') {
|
||||
chars.next();
|
||||
direction = RelationshipDirection::Incoming;
|
||||
}
|
||||
|
||||
// Consume -
|
||||
if chars.peek() == Some(&'-') {
|
||||
chars.next();
|
||||
}
|
||||
|
||||
// Check for [type]
|
||||
if chars.peek() == Some(&'[') {
|
||||
chars.next();
|
||||
let mut buffer = String::new();
|
||||
|
||||
while let Some(&c) = chars.peek() {
|
||||
if c == ']' {
|
||||
chars.next();
|
||||
break;
|
||||
} else if c == ':' {
|
||||
chars.next();
|
||||
if !buffer.is_empty() {
|
||||
variable = Some(buffer.clone());
|
||||
}
|
||||
buffer.clear();
|
||||
|
||||
// Read edge type
|
||||
while let Some(&c) = chars.peek() {
|
||||
if c == ']' || c == ' ' || c == '*' {
|
||||
break;
|
||||
}
|
||||
buffer.push(c);
|
||||
chars.next();
|
||||
}
|
||||
if !buffer.is_empty() {
|
||||
edge_type = Some(buffer.clone());
|
||||
buffer.clear();
|
||||
}
|
||||
} else {
|
||||
buffer.push(c);
|
||||
chars.next();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Consume -
|
||||
if chars.peek() == Some(&'-') {
|
||||
chars.next();
|
||||
}
|
||||
|
||||
// Check for outgoing: >
|
||||
if chars.peek() == Some(&'>') {
|
||||
chars.next();
|
||||
if direction != RelationshipDirection::Incoming {
|
||||
direction = RelationshipDirection::Outgoing;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(RelationshipPattern {
|
||||
variable,
|
||||
edge_type,
|
||||
direction,
|
||||
target: NodePattern { variable: None, labels: Vec::new(), properties: None },
|
||||
min_hops,
|
||||
max_hops,
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_where(_where_str: &str) -> Result<WhereClause, GraphError> {
|
||||
// Simplified: just parse "var.prop = value"
|
||||
// Full implementation would handle complex boolean expressions
|
||||
Ok(WhereClause::PropertyEquals {
|
||||
variable: "n".to_string(),
|
||||
property: "id".to_string(),
|
||||
value: JsonValue::Null,
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_return(return_str: &str) -> Result<Vec<ReturnItem>, GraphError> {
|
||||
let items: Vec<_> = return_str.split(',').map(|s| s.trim()).collect();
|
||||
let mut result = Vec::new();
|
||||
|
||||
for item in items {
|
||||
if item == "*" {
|
||||
result.push(ReturnItem::All);
|
||||
} else if item.to_uppercase().starts_with("COUNT(") {
|
||||
let inner = &item[6..item.len() - 1];
|
||||
if inner == "*" {
|
||||
result.push(ReturnItem::Count(None));
|
||||
} else {
|
||||
result.push(ReturnItem::Count(Some(inner.to_string())));
|
||||
}
|
||||
} else if item.contains('.') {
|
||||
let parts: Vec<_> = item.split('.').collect();
|
||||
if parts.len() == 2 {
|
||||
result.push(ReturnItem::Property {
|
||||
variable: parts[0].to_string(),
|
||||
property: parts[1].to_string(),
|
||||
});
|
||||
}
|
||||
} else {
|
||||
result.push(ReturnItem::Variable(item.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn parse_create(query: &str) -> Result<GraphQuery, GraphError> {
|
||||
// Simplified CREATE parser
|
||||
let pattern = &query[6..].trim();
|
||||
let elements = Self::parse_create_elements(pattern)?;
|
||||
Ok(GraphQuery::Create { elements })
|
||||
}
|
||||
|
||||
fn parse_create_elements(pattern: &str) -> Result<Vec<CreateElement>, GraphError> {
|
||||
// Parse (var:Label {props}) patterns
|
||||
let mut elements = Vec::new();
|
||||
let mut chars = pattern.chars().peekable();
|
||||
|
||||
while chars.peek().is_some() {
|
||||
while chars.peek() == Some(&' ') || chars.peek() == Some(&',') {
|
||||
chars.next();
|
||||
}
|
||||
|
||||
if chars.peek() == Some(&'(') {
|
||||
let node = Self::parse_node_pattern(&mut chars)?;
|
||||
elements.push(CreateElement::Node {
|
||||
variable: node.variable,
|
||||
labels: node.labels,
|
||||
properties: node.properties.unwrap_or(JsonValue::Object(serde_json::Map::new())),
|
||||
});
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(elements)
|
||||
}
|
||||
|
||||
fn parse_delete(query: &str) -> Result<GraphQuery, GraphError> {
|
||||
let detach = query.to_uppercase().starts_with("DETACH");
|
||||
let start = if detach { "DETACH DELETE".len() } else { "DELETE".len() };
|
||||
let variable = query[start..].trim().to_string();
|
||||
|
||||
Ok(GraphQuery::Delete { variable, detach })
|
||||
}
|
||||
|
||||
fn parse_set(query: &str) -> Result<GraphQuery, GraphError> {
|
||||
// Simplified: SET var.prop = value
|
||||
let content = &query[3..].trim();
|
||||
let parts: Vec<_> = content.split('=').collect();
|
||||
|
||||
if parts.len() != 2 {
|
||||
return Err(GraphError::InvalidOperation("Invalid SET syntax".to_string()));
|
||||
}
|
||||
|
||||
let var_prop: Vec<_> = parts[0].trim().split('.').collect();
|
||||
if var_prop.len() != 2 {
|
||||
return Err(GraphError::InvalidOperation("Invalid SET variable".to_string()));
|
||||
}
|
||||
|
||||
let variable = var_prop[0].to_string();
|
||||
let property = var_prop[1].to_string();
|
||||
let value_str = parts[1].trim();
|
||||
|
||||
let value: JsonValue = serde_json::from_str(value_str).unwrap_or(JsonValue::String(value_str.to_string()));
|
||||
|
||||
Ok(GraphQuery::Set {
|
||||
variable,
|
||||
properties: serde_json::json!({ property: value }),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Query executor for graph queries.
///
/// Holds a shared reference to the store; only read-only (MATCH) queries
/// can be executed through it.
pub struct GraphQueryExecutor<'a> {
    // Borrowed store; executor cannot outlive it.
    store: &'a GraphStore,
}
|
||||
|
||||
impl<'a> GraphQueryExecutor<'a> {
|
||||
/// Creates a new query executor.
|
||||
pub fn new(store: &'a GraphStore) -> Self {
|
||||
Self { store }
|
||||
}
|
||||
|
||||
/// Executes a graph query.
|
||||
pub fn execute(&self, query: &GraphQuery) -> Result<QueryResult, GraphError> {
|
||||
match query {
|
||||
GraphQuery::Match { pattern, where_clause, return_items, limit } => {
|
||||
self.execute_match(pattern, where_clause.as_ref(), return_items, *limit)
|
||||
}
|
||||
GraphQuery::Create { .. } => {
|
||||
Err(GraphError::InvalidOperation("CREATE requires mutable access".to_string()))
|
||||
}
|
||||
GraphQuery::Delete { .. } => {
|
||||
Err(GraphError::InvalidOperation("DELETE requires mutable access".to_string()))
|
||||
}
|
||||
GraphQuery::Set { .. } => {
|
||||
Err(GraphError::InvalidOperation("SET requires mutable access".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn execute_match(
|
||||
&self,
|
||||
pattern: &MatchPattern,
|
||||
_where_clause: Option<&WhereClause>,
|
||||
return_items: &[ReturnItem],
|
||||
limit: Option<usize>,
|
||||
) -> Result<QueryResult, GraphError> {
|
||||
// Find starting nodes
|
||||
let start_nodes = self.find_matching_nodes(&pattern.start);
|
||||
|
||||
let mut bindings: Vec<HashMap<String, JsonValue>> = Vec::new();
|
||||
|
||||
for start_node in &start_nodes {
|
||||
let mut binding: HashMap<String, JsonValue> = HashMap::new();
|
||||
|
||||
if let Some(ref var) = pattern.start.variable {
|
||||
binding.insert(var.clone(), Self::node_to_json(start_node));
|
||||
}
|
||||
|
||||
if pattern.relationships.is_empty() {
|
||||
bindings.push(binding);
|
||||
} else {
|
||||
// Traverse relationships
|
||||
let traverser = Traverser::new(self.store);
|
||||
|
||||
for rel_pattern in &pattern.relationships {
|
||||
let direction = match rel_pattern.direction {
|
||||
RelationshipDirection::Outgoing => TraversalDirection::Outgoing,
|
||||
RelationshipDirection::Incoming => TraversalDirection::Incoming,
|
||||
RelationshipDirection::Undirected => TraversalDirection::Both,
|
||||
};
|
||||
|
||||
let query = TraversalQuery::new()
|
||||
.depth(rel_pattern.max_hops)
|
||||
.direction(direction)
|
||||
.edge_types(
|
||||
rel_pattern.edge_type.clone().map(|t| vec![t]).unwrap_or_default(),
|
||||
)
|
||||
.labels(rel_pattern.target.labels.clone());
|
||||
|
||||
let results = traverser.traverse(&start_node.id, &query);
|
||||
|
||||
for result in results {
|
||||
let mut new_binding = binding.clone();
|
||||
|
||||
if let Some(ref var) = rel_pattern.variable {
|
||||
if let Some(edge) = result.edges.last() {
|
||||
new_binding.insert(var.clone(), Self::edge_to_json(edge));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref var) = rel_pattern.target.variable {
|
||||
new_binding.insert(var.clone(), Self::node_to_json(&result.node));
|
||||
}
|
||||
|
||||
bindings.push(new_binding);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(l) = limit {
|
||||
if bindings.len() >= l {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply limit
|
||||
if let Some(l) = limit {
|
||||
bindings.truncate(l);
|
||||
}
|
||||
|
||||
// Build result based on return items
|
||||
let columns = self.get_column_names(return_items, &bindings);
|
||||
let rows = self.extract_rows(return_items, &bindings);
|
||||
|
||||
Ok(QueryResult {
|
||||
columns,
|
||||
rows,
|
||||
..QueryResult::empty()
|
||||
})
|
||||
}
|
||||
|
||||
fn find_matching_nodes(&self, pattern: &NodePattern) -> Vec<Node> {
|
||||
let label = pattern.labels.first().map(|s| s.as_str());
|
||||
let filter = pattern.properties.clone().unwrap_or(JsonValue::Object(serde_json::Map::new()));
|
||||
self.store.find_nodes(label, &filter)
|
||||
}
|
||||
|
||||
fn node_to_json(node: &Node) -> JsonValue {
|
||||
serde_json::json!({
|
||||
"id": node.id.to_hex(),
|
||||
"labels": node.labels,
|
||||
"properties": node.properties,
|
||||
})
|
||||
}
|
||||
|
||||
fn edge_to_json(edge: &Edge) -> JsonValue {
|
||||
serde_json::json!({
|
||||
"id": edge.id.to_hex(),
|
||||
"type": edge.edge_type,
|
||||
"source": edge.source.to_hex(),
|
||||
"target": edge.target.to_hex(),
|
||||
"properties": edge.properties,
|
||||
})
|
||||
}
|
||||
|
||||
fn get_column_names(&self, return_items: &[ReturnItem], bindings: &[HashMap<String, JsonValue>]) -> Vec<String> {
|
||||
let mut columns = Vec::new();
|
||||
|
||||
for item in return_items {
|
||||
match item {
|
||||
ReturnItem::All => {
|
||||
if let Some(binding) = bindings.first() {
|
||||
columns.extend(binding.keys().cloned());
|
||||
}
|
||||
}
|
||||
ReturnItem::Variable(var) => columns.push(var.clone()),
|
||||
ReturnItem::Property { variable, property } => {
|
||||
columns.push(format!("{}.{}", variable, property));
|
||||
}
|
||||
ReturnItem::Alias { alias, .. } => columns.push(alias.clone()),
|
||||
ReturnItem::Count(var) => {
|
||||
columns.push(format!("count({})", var.as_ref().map(|s| s.as_str()).unwrap_or("*")));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
columns
|
||||
}
|
||||
|
||||
fn extract_rows(&self, return_items: &[ReturnItem], bindings: &[HashMap<String, JsonValue>]) -> Vec<Vec<JsonValue>> {
|
||||
let mut rows = Vec::new();
|
||||
|
||||
// Handle COUNT specially
|
||||
if return_items.iter().any(|i| matches!(i, ReturnItem::Count(_))) {
|
||||
rows.push(vec![JsonValue::Number(bindings.len().into())]);
|
||||
return rows;
|
||||
}
|
||||
|
||||
for binding in bindings {
|
||||
let mut row = Vec::new();
|
||||
|
||||
for item in return_items {
|
||||
match item {
|
||||
ReturnItem::All => {
|
||||
for (_, value) in binding {
|
||||
row.push(value.clone());
|
||||
}
|
||||
}
|
||||
ReturnItem::Variable(var) => {
|
||||
row.push(binding.get(var).cloned().unwrap_or(JsonValue::Null));
|
||||
}
|
||||
ReturnItem::Property { variable, property } => {
|
||||
if let Some(obj) = binding.get(variable) {
|
||||
if let Some(props) = obj.get("properties") {
|
||||
row.push(props.get(property).cloned().unwrap_or(JsonValue::Null));
|
||||
} else {
|
||||
row.push(JsonValue::Null);
|
||||
}
|
||||
} else {
|
||||
row.push(JsonValue::Null);
|
||||
}
|
||||
}
|
||||
ReturnItem::Alias { item: inner, .. } => {
|
||||
// Recursively handle the inner item
|
||||
let inner_rows = self.extract_rows(&[*inner.clone()], &[binding.clone()]);
|
||||
if let Some(inner_row) = inner_rows.first() {
|
||||
row.extend(inner_row.clone());
|
||||
}
|
||||
}
|
||||
ReturnItem::Count(_) => {
|
||||
// Handled above
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !row.is_empty() {
|
||||
rows.push(row);
|
||||
}
|
||||
}
|
||||
|
||||
rows
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Parser-only test: a labelled node pattern with no relationships.
    #[test]
    fn test_parse_simple_match() {
        let query = "MATCH (n:User) RETURN n";
        let parsed = GraphQueryParser::parse(query).unwrap();

        if let GraphQuery::Match { pattern, .. } = parsed {
            assert_eq!(pattern.start.labels, vec!["User".to_string()]);
        } else {
            panic!("Expected Match query");
        }
    }

    // Parser test: one outgoing typed relationship hop.
    #[test]
    fn test_parse_match_with_relationship() {
        let query = "MATCH (a:User)-[:FRIEND]->(b:User) RETURN a, b";
        let parsed = GraphQueryParser::parse(query).unwrap();

        if let GraphQuery::Match { pattern, .. } = parsed {
            assert_eq!(pattern.start.labels, vec!["User".to_string()]);
            assert_eq!(pattern.relationships.len(), 1);
            assert_eq!(pattern.relationships[0].edge_type, Some("FRIEND".to_string()));
            assert_eq!(pattern.relationships[0].direction, RelationshipDirection::Outgoing);
        } else {
            panic!("Expected Match query");
        }
    }

    // End-to-end: parse + execute a MATCH against a two-node store.
    #[test]
    fn test_execute_match() {
        let store = GraphStore::new();

        let alice = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"}));
        let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
        store.create_edge(alice, bob, "FRIEND", serde_json::json!({})).unwrap();

        let query = GraphQueryParser::parse("MATCH (n:User) RETURN n").unwrap();
        let executor = GraphQueryExecutor::new(&store);
        let result = executor.execute(&query).unwrap();

        assert_eq!(result.rows.len(), 2);
    }

    // NOTE(review): `{name: "Alice"}` is not strict JSON (unquoted key), so
    // the parsed properties are None; this test only asserts the labels.
    #[test]
    fn test_parse_create() {
        let query = "CREATE (n:User {name: \"Alice\"})";
        let parsed = GraphQueryParser::parse(query).unwrap();

        if let GraphQuery::Create { elements } = parsed {
            assert_eq!(elements.len(), 1);
            if let CreateElement::Node { labels, .. } = &elements[0] {
                assert_eq!(labels, &vec!["User".to_string()]);
            }
        } else {
            panic!("Expected Create query");
        }
    }

    #[test]
    fn test_parse_delete() {
        let query = "DELETE n";
        let parsed = GraphQueryParser::parse(query).unwrap();

        if let GraphQuery::Delete { variable, detach } = parsed {
            assert_eq!(variable, "n");
            assert!(!detach);
        } else {
            panic!("Expected Delete query");
        }
    }

    // DETACH DELETE must set the detach flag.
    #[test]
    fn test_parse_detach_delete() {
        let query = "DETACH DELETE n";
        let parsed = GraphQueryParser::parse(query).unwrap();

        if let GraphQuery::Delete { variable, detach } = parsed {
            assert_eq!(variable, "n");
            assert!(detach);
        } else {
            panic!("Expected Delete query");
        }
    }
}
|
||||
657
crates/synor-database/src/graph/store.rs
Normal file
657
crates/synor-database/src/graph/store.rs
Normal file
|
|
@ -0,0 +1,657 @@
|
|||
//! Graph storage engine.
|
||||
|
||||
use super::edge::{Edge, EdgeId};
|
||||
use super::node::{Node, NodeId};
|
||||
use parking_lot::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value as JsonValue;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use thiserror::Error;
|
||||
|
||||
/// Direction for traversing edges.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Direction {
    /// Follow outgoing edges only.
    Outgoing,
    /// Follow incoming edges only.
    Incoming,
    /// Follow edges in both directions (treats the graph as undirected).
    Both,
}
|
||||
|
||||
/// Graph storage errors.
///
/// String payloads carry a human-readable identifier or description
/// (e.g. a hex node id, or the offending query text).
#[derive(Debug, Error)]
pub enum GraphError {
    /// Node not found.
    #[error("Node not found: {0}")]
    NodeNotFound(String),

    /// Edge not found.
    #[error("Edge not found: {0}")]
    EdgeNotFound(String),

    /// Node already exists.
    #[error("Node already exists: {0}")]
    NodeExists(String),

    /// Invalid operation (e.g. unparseable query, read-only violation).
    #[error("Invalid operation: {0}")]
    InvalidOperation(String),

    /// Constraint violation.
    #[error("Constraint violation: {0}")]
    ConstraintViolation(String),
}
|
||||
|
||||
/// Statistics for a graph store, as reported by `GraphStore::stats`.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct GraphStats {
    /// Total number of nodes.
    pub node_count: u64,
    /// Total number of edges.
    pub edge_count: u64,
    /// Number of distinct labels currently indexed.
    pub label_count: u64,
    /// Number of distinct edge types currently indexed.
    pub edge_type_count: u64,
}
|
||||
|
||||
/// Graph storage engine.
///
/// All state lives in memory behind per-collection `parking_lot::RwLock`s,
/// so individual operations are internally synchronized.
///
/// NOTE(review): methods acquire several of these locks in varying orders;
/// looks safe because write guards are short-lived, but confirm there is no
/// lock-order inversion under concurrent mutation.
pub struct GraphStore {
    /// Node storage, keyed by node ID.
    nodes: RwLock<HashMap<NodeId, Node>>,
    /// Edge storage, keyed by edge ID.
    edges: RwLock<HashMap<EdgeId, Edge>>,
    /// Outgoing adjacency list: node -> outgoing edges.
    /// (Undirected edges appear under both endpoints.)
    adjacency: RwLock<HashMap<NodeId, Vec<EdgeId>>>,
    /// Incoming adjacency list: node -> incoming edges.
    /// (Undirected edges appear under both endpoints.)
    reverse_adj: RwLock<HashMap<NodeId, Vec<EdgeId>>>,
    /// Label index: label -> nodes with that label.
    label_index: RwLock<HashMap<String, HashSet<NodeId>>>,
    /// Edge type index: type -> edges of that type.
    edge_type_index: RwLock<HashMap<String, HashSet<EdgeId>>>,
}
|
||||
|
||||
impl GraphStore {
|
||||
/// Creates a new empty graph store.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
nodes: RwLock::new(HashMap::new()),
|
||||
edges: RwLock::new(HashMap::new()),
|
||||
adjacency: RwLock::new(HashMap::new()),
|
||||
reverse_adj: RwLock::new(HashMap::new()),
|
||||
label_index: RwLock::new(HashMap::new()),
|
||||
edge_type_index: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns statistics about the graph.
|
||||
pub fn stats(&self) -> GraphStats {
|
||||
GraphStats {
|
||||
node_count: self.nodes.read().len() as u64,
|
||||
edge_count: self.edges.read().len() as u64,
|
||||
label_count: self.label_index.read().len() as u64,
|
||||
edge_type_count: self.edge_type_index.read().len() as u64,
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== Node Operations ====================
|
||||
|
||||
/// Creates a new node with the given labels and properties.
|
||||
pub fn create_node(&self, labels: Vec<String>, properties: JsonValue) -> NodeId {
|
||||
let node = Node::new(labels.clone(), properties);
|
||||
let id = node.id;
|
||||
|
||||
// Update label index
|
||||
{
|
||||
let mut label_idx = self.label_index.write();
|
||||
for label in &labels {
|
||||
label_idx.entry(label.clone()).or_default().insert(id);
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize adjacency lists
|
||||
self.adjacency.write().insert(id, Vec::new());
|
||||
self.reverse_adj.write().insert(id, Vec::new());
|
||||
|
||||
// Store node
|
||||
self.nodes.write().insert(id, node);
|
||||
|
||||
id
|
||||
}
|
||||
|
||||
/// Gets a node by ID.
|
||||
pub fn get_node(&self, id: &NodeId) -> Option<Node> {
|
||||
self.nodes.read().get(id).cloned()
|
||||
}
|
||||
|
||||
/// Updates a node's properties.
|
||||
pub fn update_node(&self, id: &NodeId, properties: JsonValue) -> Result<(), GraphError> {
|
||||
let mut nodes = self.nodes.write();
|
||||
let node = nodes
|
||||
.get_mut(id)
|
||||
.ok_or_else(|| GraphError::NodeNotFound(id.to_string()))?;
|
||||
|
||||
node.properties = properties;
|
||||
node.updated_at = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_millis() as u64;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Updates a node's labels.
|
||||
pub fn update_node_labels(&self, id: &NodeId, labels: Vec<String>) -> Result<(), GraphError> {
|
||||
let mut nodes = self.nodes.write();
|
||||
let node = nodes
|
||||
.get_mut(id)
|
||||
.ok_or_else(|| GraphError::NodeNotFound(id.to_string()))?;
|
||||
|
||||
// Update label index
|
||||
{
|
||||
let mut label_idx = self.label_index.write();
|
||||
|
||||
// Remove old labels
|
||||
for old_label in &node.labels {
|
||||
if let Some(set) = label_idx.get_mut(old_label) {
|
||||
set.remove(id);
|
||||
}
|
||||
}
|
||||
|
||||
// Add new labels
|
||||
for new_label in &labels {
|
||||
label_idx.entry(new_label.clone()).or_default().insert(*id);
|
||||
}
|
||||
}
|
||||
|
||||
node.labels = labels;
|
||||
node.updated_at = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_millis() as u64;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Deletes a node and all its connected edges.
|
||||
pub fn delete_node(&self, id: &NodeId) -> Result<(), GraphError> {
|
||||
// Get connected edges
|
||||
let outgoing: Vec<EdgeId> = self
|
||||
.adjacency
|
||||
.read()
|
||||
.get(id)
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
let incoming: Vec<EdgeId> = self
|
||||
.reverse_adj
|
||||
.read()
|
||||
.get(id)
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
|
||||
// Delete all connected edges
|
||||
for edge_id in outgoing.iter().chain(incoming.iter()) {
|
||||
let _ = self.delete_edge(edge_id);
|
||||
}
|
||||
|
||||
// Remove from label index
|
||||
{
|
||||
let nodes = self.nodes.read();
|
||||
if let Some(node) = nodes.get(id) {
|
||||
let mut label_idx = self.label_index.write();
|
||||
for label in &node.labels {
|
||||
if let Some(set) = label_idx.get_mut(label) {
|
||||
set.remove(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove adjacency entries
|
||||
self.adjacency.write().remove(id);
|
||||
self.reverse_adj.write().remove(id);
|
||||
|
||||
// Remove node
|
||||
self.nodes
|
||||
.write()
|
||||
.remove(id)
|
||||
.ok_or_else(|| GraphError::NodeNotFound(id.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Finds nodes by label.
|
||||
pub fn find_nodes_by_label(&self, label: &str) -> Vec<Node> {
|
||||
let node_ids: Vec<NodeId> = self
|
||||
.label_index
|
||||
.read()
|
||||
.get(label)
|
||||
.map(|set| set.iter().cloned().collect())
|
||||
.unwrap_or_default();
|
||||
|
||||
let nodes = self.nodes.read();
|
||||
node_ids
|
||||
.iter()
|
||||
.filter_map(|id| nodes.get(id).cloned())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Finds nodes matching a filter.
|
||||
pub fn find_nodes(&self, label: Option<&str>, filter: &JsonValue) -> Vec<Node> {
|
||||
let candidates: Vec<Node> = if let Some(l) = label {
|
||||
self.find_nodes_by_label(l)
|
||||
} else {
|
||||
self.nodes.read().values().cloned().collect()
|
||||
};
|
||||
|
||||
candidates
|
||||
.into_iter()
|
||||
.filter(|n| n.matches_properties(filter))
|
||||
.collect()
|
||||
}
|
||||
|
||||
// ==================== Edge Operations ====================
|
||||
|
||||
/// Creates a new edge.
|
||||
pub fn create_edge(
|
||||
&self,
|
||||
source: NodeId,
|
||||
target: NodeId,
|
||||
edge_type: impl Into<String>,
|
||||
properties: JsonValue,
|
||||
) -> Result<EdgeId, GraphError> {
|
||||
let edge_type = edge_type.into();
|
||||
|
||||
// Verify both nodes exist
|
||||
{
|
||||
let nodes = self.nodes.read();
|
||||
if !nodes.contains_key(&source) {
|
||||
return Err(GraphError::NodeNotFound(source.to_string()));
|
||||
}
|
||||
if !nodes.contains_key(&target) {
|
||||
return Err(GraphError::NodeNotFound(target.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
let edge = Edge::new(source, target, edge_type.clone(), properties);
|
||||
let id = edge.id;
|
||||
|
||||
// Update adjacency lists
|
||||
self.adjacency.write().entry(source).or_default().push(id);
|
||||
self.reverse_adj.write().entry(target).or_default().push(id);
|
||||
|
||||
// Update edge type index
|
||||
self.edge_type_index
|
||||
.write()
|
||||
.entry(edge_type)
|
||||
.or_default()
|
||||
.insert(id);
|
||||
|
||||
// Store edge
|
||||
self.edges.write().insert(id, edge);
|
||||
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
/// Creates an undirected edge.
|
||||
pub fn create_undirected_edge(
|
||||
&self,
|
||||
source: NodeId,
|
||||
target: NodeId,
|
||||
edge_type: impl Into<String>,
|
||||
properties: JsonValue,
|
||||
) -> Result<EdgeId, GraphError> {
|
||||
let edge_type = edge_type.into();
|
||||
|
||||
// Verify both nodes exist
|
||||
{
|
||||
let nodes = self.nodes.read();
|
||||
if !nodes.contains_key(&source) {
|
||||
return Err(GraphError::NodeNotFound(source.to_string()));
|
||||
}
|
||||
if !nodes.contains_key(&target) {
|
||||
return Err(GraphError::NodeNotFound(target.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
let edge = Edge::undirected(source, target, edge_type.clone(), properties);
|
||||
let id = edge.id;
|
||||
|
||||
// Update adjacency lists (both directions for undirected)
|
||||
{
|
||||
let mut adj = self.adjacency.write();
|
||||
adj.entry(source).or_default().push(id);
|
||||
adj.entry(target).or_default().push(id);
|
||||
}
|
||||
{
|
||||
let mut rev = self.reverse_adj.write();
|
||||
rev.entry(source).or_default().push(id);
|
||||
rev.entry(target).or_default().push(id);
|
||||
}
|
||||
|
||||
// Update edge type index
|
||||
self.edge_type_index
|
||||
.write()
|
||||
.entry(edge_type)
|
||||
.or_default()
|
||||
.insert(id);
|
||||
|
||||
// Store edge
|
||||
self.edges.write().insert(id, edge);
|
||||
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
/// Gets an edge by ID.
|
||||
pub fn get_edge(&self, id: &EdgeId) -> Option<Edge> {
|
||||
self.edges.read().get(id).cloned()
|
||||
}
|
||||
|
||||
/// Deletes an edge.
|
||||
pub fn delete_edge(&self, id: &EdgeId) -> Result<(), GraphError> {
|
||||
let edge = self
|
||||
.edges
|
||||
.write()
|
||||
.remove(id)
|
||||
.ok_or_else(|| GraphError::EdgeNotFound(id.to_string()))?;
|
||||
|
||||
// Update adjacency lists
|
||||
{
|
||||
let mut adj = self.adjacency.write();
|
||||
if let Some(list) = adj.get_mut(&edge.source) {
|
||||
list.retain(|e| e != id);
|
||||
}
|
||||
if !edge.directed {
|
||||
if let Some(list) = adj.get_mut(&edge.target) {
|
||||
list.retain(|e| e != id);
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
let mut rev = self.reverse_adj.write();
|
||||
if let Some(list) = rev.get_mut(&edge.target) {
|
||||
list.retain(|e| e != id);
|
||||
}
|
||||
if !edge.directed {
|
||||
if let Some(list) = rev.get_mut(&edge.source) {
|
||||
list.retain(|e| e != id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update edge type index
|
||||
if let Some(set) = self.edge_type_index.write().get_mut(&edge.edge_type) {
|
||||
set.remove(id);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Finds edges by type.
|
||||
pub fn find_edges_by_type(&self, edge_type: &str) -> Vec<Edge> {
|
||||
let edge_ids: Vec<EdgeId> = self
|
||||
.edge_type_index
|
||||
.read()
|
||||
.get(edge_type)
|
||||
.map(|set| set.iter().cloned().collect())
|
||||
.unwrap_or_default();
|
||||
|
||||
let edges = self.edges.read();
|
||||
edge_ids
|
||||
.iter()
|
||||
.filter_map(|id| edges.get(id).cloned())
|
||||
.collect()
|
||||
}
|
||||
|
||||
// ==================== Traversal Operations ====================
|
||||
|
||||
/// Gets neighboring nodes.
|
||||
pub fn neighbors(&self, id: &NodeId, direction: Direction) -> Vec<Node> {
|
||||
let edge_ids = self.edges_of_node(id, direction);
|
||||
let edges = self.edges.read();
|
||||
let nodes = self.nodes.read();
|
||||
|
||||
let mut neighbor_ids = HashSet::new();
|
||||
for eid in edge_ids {
|
||||
if let Some(edge) = edges.get(&eid) {
|
||||
if let Some(other) = self.get_neighbor_from_edge(edge, id, direction) {
|
||||
neighbor_ids.insert(other);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
neighbor_ids
|
||||
.iter()
|
||||
.filter_map(|nid| nodes.get(nid).cloned())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Gets edges connected to a node.
|
||||
pub fn edges_of(&self, id: &NodeId, direction: Direction) -> Vec<Edge> {
|
||||
let edge_ids = self.edges_of_node(id, direction);
|
||||
let edges = self.edges.read();
|
||||
|
||||
edge_ids
|
||||
.iter()
|
||||
.filter_map(|eid| edges.get(eid).cloned())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Gets edge IDs connected to a node.
|
||||
fn edges_of_node(&self, id: &NodeId, direction: Direction) -> Vec<EdgeId> {
|
||||
match direction {
|
||||
Direction::Outgoing => self.adjacency.read().get(id).cloned().unwrap_or_default(),
|
||||
Direction::Incoming => self.reverse_adj.read().get(id).cloned().unwrap_or_default(),
|
||||
Direction::Both => {
|
||||
let mut result = self.adjacency.read().get(id).cloned().unwrap_or_default();
|
||||
let incoming = self.reverse_adj.read().get(id).cloned().unwrap_or_default();
|
||||
for eid in incoming {
|
||||
if !result.contains(&eid) {
|
||||
result.push(eid);
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets the neighbor node from an edge.
|
||||
fn get_neighbor_from_edge(&self, edge: &Edge, from: &NodeId, direction: Direction) -> Option<NodeId> {
|
||||
match direction {
|
||||
Direction::Outgoing => {
|
||||
if &edge.source == from {
|
||||
Some(edge.target)
|
||||
} else if !edge.directed && &edge.target == from {
|
||||
Some(edge.source)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Direction::Incoming => {
|
||||
if &edge.target == from {
|
||||
Some(edge.source)
|
||||
} else if !edge.directed && &edge.source == from {
|
||||
Some(edge.target)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Direction::Both => edge.other_end(from).or_else(|| {
|
||||
// For directed edges, still return the other end
|
||||
if &edge.source == from {
|
||||
Some(edge.target)
|
||||
} else if &edge.target == from {
|
||||
Some(edge.source)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets neighbors connected by a specific edge type.
|
||||
pub fn neighbors_by_type(&self, id: &NodeId, edge_type: &str, direction: Direction) -> Vec<Node> {
|
||||
let edges = self.edges_of(id, direction);
|
||||
let nodes = self.nodes.read();
|
||||
|
||||
let mut neighbor_ids = HashSet::new();
|
||||
for edge in edges {
|
||||
if edge.edge_type == edge_type {
|
||||
if let Some(other) = self.get_neighbor_from_edge(&edge, id, direction) {
|
||||
neighbor_ids.insert(other);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
neighbor_ids
|
||||
.iter()
|
||||
.filter_map(|nid| nodes.get(nid).cloned())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Checks if an edge exists between two nodes.
|
||||
pub fn has_edge(&self, source: &NodeId, target: &NodeId, edge_type: Option<&str>) -> bool {
|
||||
let edges = self.edges_of(source, Direction::Outgoing);
|
||||
for edge in edges {
|
||||
if &edge.target == target {
|
||||
if let Some(et) = edge_type {
|
||||
if edge.edge_type == et {
|
||||
return true;
|
||||
}
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Gets all edges between two nodes.
|
||||
pub fn edges_between(&self, source: &NodeId, target: &NodeId) -> Vec<Edge> {
|
||||
let edges = self.edges_of(source, Direction::Outgoing);
|
||||
edges
|
||||
.into_iter()
|
||||
.filter(|e| &e.target == target || (!e.directed && &e.source == target))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for GraphStore {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // A freshly created node is retrievable with its labels and properties.
    #[test]
    fn test_create_node() {
        let store = GraphStore::new();

        let id = store.create_node(
            vec!["User".to_string()],
            serde_json::json!({"name": "Alice"}),
        );

        let node = store.get_node(&id).unwrap();
        assert!(node.has_label("User"));
        assert_eq!(node.get_property("name"), Some(&serde_json::json!("Alice")));
    }

    // An edge stores its endpoints and type as given at creation.
    #[test]
    fn test_create_edge() {
        let store = GraphStore::new();

        let alice = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"}));
        let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));

        let edge_id = store
            .create_edge(alice, bob, "FRIEND", serde_json::json!({"since": 2020}))
            .unwrap();

        let edge = store.get_edge(&edge_id).unwrap();
        assert_eq!(edge.source, alice);
        assert_eq!(edge.target, bob);
        assert_eq!(edge.edge_type, "FRIEND");
    }

    // Outgoing neighbors cover all targets of a node's outgoing edges.
    #[test]
    fn test_neighbors() {
        let store = GraphStore::new();

        let alice = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"}));
        let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
        let charlie = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Charlie"}));

        store.create_edge(alice, bob, "FRIEND", serde_json::json!({})).unwrap();
        store.create_edge(alice, charlie, "FRIEND", serde_json::json!({})).unwrap();

        let neighbors = store.neighbors(&alice, Direction::Outgoing);
        assert_eq!(neighbors.len(), 2);
    }

    // The label index partitions nodes by label.
    #[test]
    fn test_find_by_label() {
        let store = GraphStore::new();

        store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"}));
        store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
        store.create_node(vec!["Product".to_string()], serde_json::json!({"name": "Widget"}));

        let users = store.find_nodes_by_label("User");
        assert_eq!(users.len(), 2);

        let products = store.find_nodes_by_label("Product");
        assert_eq!(products.len(), 1);
    }

    // Deleting a node cascades to its connected edges.
    #[test]
    fn test_delete_node() {
        let store = GraphStore::new();

        let alice = store.create_node(vec!["User".to_string()], serde_json::json!({}));
        let bob = store.create_node(vec!["User".to_string()], serde_json::json!({}));

        store.create_edge(alice, bob, "FRIEND", serde_json::json!({})).unwrap();

        // Delete Alice - should also delete the edge
        store.delete_node(&alice).unwrap();

        assert!(store.get_node(&alice).is_none());
        assert_eq!(store.stats().edge_count, 0);
    }

    // An undirected edge is visible as "outgoing" from both endpoints.
    #[test]
    fn test_undirected_edge() {
        let store = GraphStore::new();

        let a = store.create_node(vec![], serde_json::json!({}));
        let b = store.create_node(vec![], serde_json::json!({}));

        store.create_undirected_edge(a, b, "LINK", serde_json::json!({})).unwrap();

        // Both directions should work
        let a_neighbors = store.neighbors(&a, Direction::Outgoing);
        let b_neighbors = store.neighbors(&b, Direction::Outgoing);

        assert_eq!(a_neighbors.len(), 1);
        assert_eq!(b_neighbors.len(), 1);
    }

    // Parallel edges of different types are all reported between two nodes.
    #[test]
    fn test_edges_between() {
        let store = GraphStore::new();

        let a = store.create_node(vec![], serde_json::json!({}));
        let b = store.create_node(vec![], serde_json::json!({}));

        store.create_edge(a, b, "TYPE_A", serde_json::json!({})).unwrap();
        store.create_edge(a, b, "TYPE_B", serde_json::json!({})).unwrap();

        let edges = store.edges_between(&a, &b);
        assert_eq!(edges.len(), 2);
    }
}
|
||||
500
crates/synor-database/src/graph/traversal.rs
Normal file
500
crates/synor-database/src/graph/traversal.rs
Normal file
|
|
@ -0,0 +1,500 @@
|
|||
//! Graph traversal algorithms.
|
||||
|
||||
use super::edge::Edge;
|
||||
use super::node::{Node, NodeId};
|
||||
use super::store::{Direction, GraphStore};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value as JsonValue;
|
||||
use std::collections::{HashSet, VecDeque};
|
||||
|
||||
/// Query for graph traversal.
///
/// Built fluently via the `depth`/`edge_types`/... builder methods; empty
/// collection fields mean "no restriction".
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TraversalQuery {
    /// Maximum depth to traverse (the start node sits at depth 0).
    pub max_depth: usize,
    /// Edge types to follow (empty = all types).
    pub edge_types: Vec<String>,
    /// Direction to traverse.
    pub direction: TraversalDirection,
    /// Property filter a node must match to appear in the results.
    pub node_filter: Option<JsonValue>,
    /// Property filter an edge must match to be followed.
    pub edge_filter: Option<JsonValue>,
    /// Maximum results to return (`None` = unbounded).
    pub limit: Option<usize>,
    /// Labels to filter nodes by (empty = any label).
    pub labels: Vec<String>,
}
|
||||
|
||||
/// Direction for traversal serialization.
///
/// Serde-friendly mirror of the store-level `Direction`; convert with
/// `Direction::from`.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub enum TraversalDirection {
    /// Follow outgoing edges only.
    Outgoing,
    /// Follow incoming edges only.
    Incoming,
    /// Follow edges in both directions.
    Both,
}
|
||||
|
||||
impl From<TraversalDirection> for Direction {
|
||||
fn from(td: TraversalDirection) -> Self {
|
||||
match td {
|
||||
TraversalDirection::Outgoing => Direction::Outgoing,
|
||||
TraversalDirection::Incoming => Direction::Incoming,
|
||||
TraversalDirection::Both => Direction::Both,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TraversalQuery {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_depth: 3,
|
||||
edge_types: Vec::new(),
|
||||
direction: TraversalDirection::Outgoing,
|
||||
node_filter: None,
|
||||
edge_filter: None,
|
||||
limit: None,
|
||||
labels: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TraversalQuery {
|
||||
/// Creates a new traversal query.
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Sets the maximum depth.
|
||||
pub fn depth(mut self, depth: usize) -> Self {
|
||||
self.max_depth = depth;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets edge types to follow.
|
||||
pub fn edge_types(mut self, types: Vec<String>) -> Self {
|
||||
self.edge_types = types;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the traversal direction.
|
||||
pub fn direction(mut self, dir: TraversalDirection) -> Self {
|
||||
self.direction = dir;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets a node filter.
|
||||
pub fn node_filter(mut self, filter: JsonValue) -> Self {
|
||||
self.node_filter = Some(filter);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets an edge filter.
|
||||
pub fn edge_filter(mut self, filter: JsonValue) -> Self {
|
||||
self.edge_filter = Some(filter);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets result limit.
|
||||
pub fn limit(mut self, limit: usize) -> Self {
|
||||
self.limit = Some(limit);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets label filter.
|
||||
pub fn labels(mut self, labels: Vec<String>) -> Self {
|
||||
self.labels = labels;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of a traversal operation: one discovered node plus how it was
/// reached.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TraversalResult {
    /// The node found.
    pub node: Node,
    /// Depth at which this node was found (1 = direct neighbor of the start).
    pub depth: usize,
    /// Path from start to this node (node IDs, start node first).
    pub path: Vec<NodeId>,
    /// Edges followed to reach this node, in traversal order.
    pub edges: Vec<Edge>,
}
|
||||
|
||||
/// Graph traverser for executing traversal queries.
///
/// Borrows the underlying [`GraphStore`] read-only for its lifetime.
pub struct Traverser<'a> {
    /// The store whose nodes and edges are walked.
    store: &'a GraphStore,
}
|
||||
|
||||
impl<'a> Traverser<'a> {
    /// Creates a new traverser over `store`.
    pub fn new(store: &'a GraphStore) -> Self {
        Self { store }
    }

    /// Executes a BFS traversal from a starting node.
    ///
    /// The start node itself (depth 0) is never included in the results.
    /// Node filters only exclude nodes from the *results*; a non-matching
    /// node is still expanded, so traversal continues through it.
    /// Each node is visited at most once (marked visited when enqueued), so
    /// only the first discovered path to a node is reported.
    pub fn traverse(&self, start: &NodeId, query: &TraversalQuery) -> Vec<TraversalResult> {
        let mut results = Vec::new();
        let mut visited = HashSet::new();
        // Queue entries: (node, depth, path of node IDs from start, edges followed).
        let mut queue: VecDeque<(NodeId, usize, Vec<NodeId>, Vec<Edge>)> = VecDeque::new();

        visited.insert(*start);
        queue.push_back((*start, 0, vec![*start], Vec::new()));

        let direction: Direction = query.direction.into();

        while let Some((current_id, depth, path, edges_path)) = queue.pop_front() {
            // Check limit before producing another result; results never
            // exceed `limit` because at most one result is pushed per pop.
            if let Some(limit) = query.limit {
                if results.len() >= limit {
                    break;
                }
            }

            // Get current node (silently skipped if deleted concurrently).
            if let Some(node) = self.store.get_node(&current_id) {
                // Skip start node in results (depth 0)
                if depth > 0 {
                    // Apply label and property filters (results only).
                    if self.matches_query(&node, query) {
                        results.push(TraversalResult {
                            node: node.clone(),
                            depth,
                            path: path.clone(),
                            edges: edges_path.clone(),
                        });
                    }
                }

                // Continue if not at max depth
                if depth < query.max_depth {
                    let edges = self.store.edges_of(&current_id, direction);

                    for edge in edges {
                        // Check edge type filter (empty list = all types).
                        if !query.edge_types.is_empty() && !query.edge_types.contains(&edge.edge_type) {
                            continue;
                        }

                        // Check edge property filter.
                        if let Some(ref filter) = query.edge_filter {
                            if !edge.matches_properties(filter) {
                                continue;
                            }
                        }

                        // Get neighbor on the far side of this edge.
                        let neighbor_id = self.get_neighbor(&edge, &current_id, direction);
                        if let Some(next_id) = neighbor_id {
                            if !visited.contains(&next_id) {
                                // Mark visited at enqueue time to avoid
                                // enqueueing the same node twice.
                                visited.insert(next_id);
                                let mut new_path = path.clone();
                                new_path.push(next_id);
                                let mut new_edges = edges_path.clone();
                                new_edges.push(edge);
                                queue.push_back((next_id, depth + 1, new_path, new_edges));
                            }
                        }
                    }
                }
            }
        }

        results
    }

    /// Executes a DFS traversal from a starting node.
    ///
    /// Same filtering and limit semantics as [`Self::traverse`], but results
    /// arrive in depth-first order.
    pub fn traverse_dfs(&self, start: &NodeId, query: &TraversalQuery) -> Vec<TraversalResult> {
        let mut results = Vec::new();
        let mut visited = HashSet::new();
        let direction: Direction = query.direction.into();

        self.dfs_visit(
            start,
            0,
            vec![*start],
            Vec::new(),
            &mut visited,
            &mut results,
            query,
            direction,
        );

        results
    }

    /// Recursive worker for [`Self::traverse_dfs`].
    ///
    /// `visited` is marked on entry (unlike BFS, which marks on enqueue);
    /// recursion only descends into unvisited neighbors, so each node is
    /// visited once.
    fn dfs_visit(
        &self,
        current_id: &NodeId,
        depth: usize,
        path: Vec<NodeId>,
        edges_path: Vec<Edge>,
        visited: &mut HashSet<NodeId>,
        results: &mut Vec<TraversalResult>,
        query: &TraversalQuery,
        direction: Direction,
    ) {
        // Check limit — stop the whole subtree once enough results exist.
        if let Some(limit) = query.limit {
            if results.len() >= limit {
                return;
            }
        }

        visited.insert(*current_id);

        if let Some(node) = self.store.get_node(current_id) {
            // Skip start node in results (depth 0); filters apply to results only.
            if depth > 0 && self.matches_query(&node, query) {
                results.push(TraversalResult {
                    node: node.clone(),
                    depth,
                    path: path.clone(),
                    edges: edges_path.clone(),
                });
            }

            // Continue if not at max depth
            if depth < query.max_depth {
                let edges = self.store.edges_of(current_id, direction);

                for edge in edges {
                    // Check edge type filter (empty list = all types).
                    if !query.edge_types.is_empty() && !query.edge_types.contains(&edge.edge_type) {
                        continue;
                    }

                    // Check edge property filter.
                    if let Some(ref filter) = query.edge_filter {
                        if !edge.matches_properties(filter) {
                            continue;
                        }
                    }

                    if let Some(next_id) = self.get_neighbor(&edge, current_id, direction) {
                        if !visited.contains(&next_id) {
                            let mut new_path = path.clone();
                            new_path.push(next_id);
                            let mut new_edges = edges_path.clone();
                            new_edges.push(edge);

                            self.dfs_visit(
                                &next_id,
                                depth + 1,
                                new_path,
                                new_edges,
                                visited,
                                results,
                                query,
                                direction,
                            );
                        }
                    }
                }
            }
        }
    }

    /// Checks if a node matches the query's label and property filters.
    ///
    /// Empty `labels` means any label; a node matches the label filter if it
    /// carries at least one of the requested labels.
    fn matches_query(&self, node: &Node, query: &TraversalQuery) -> bool {
        // Check labels
        if !query.labels.is_empty() {
            let has_label = query.labels.iter().any(|l| node.has_label(l));
            if !has_label {
                return false;
            }
        }

        // Check node filter
        if let Some(ref filter) = query.node_filter {
            if !node.matches_properties(filter) {
                return false;
            }
        }

        true
    }

    /// Gets the neighbor node ID on the far side of `edge` from `from`,
    /// honoring undirected edges in both orientations.
    ///
    /// NOTE(review): this mirrors `GraphStore`'s private neighbor-resolution
    /// logic (which is inaccessible from here); keep the two in sync.
    fn get_neighbor(&self, edge: &Edge, from: &NodeId, direction: Direction) -> Option<NodeId> {
        match direction {
            Direction::Outgoing => {
                if &edge.source == from {
                    Some(edge.target)
                } else if !edge.directed && &edge.target == from {
                    Some(edge.source)
                } else {
                    None
                }
            }
            Direction::Incoming => {
                if &edge.target == from {
                    Some(edge.source)
                } else if !edge.directed && &edge.source == from {
                    Some(edge.target)
                } else {
                    None
                }
            }
            Direction::Both => {
                if &edge.source == from {
                    Some(edge.target)
                } else if &edge.target == from {
                    Some(edge.source)
                } else {
                    None
                }
            }
        }
    }

    /// Finds all nodes within `max_distance` hops (outgoing, all edge
    /// types), paired with the depth at which each was found. The start
    /// node itself is excluded.
    pub fn within_distance(&self, start: &NodeId, max_distance: usize) -> Vec<(Node, usize)> {
        let query = TraversalQuery::new().depth(max_distance);
        self.traverse(start, &query)
            .into_iter()
            .map(|r| (r.node, r.depth))
            .collect()
    }

    /// Finds the mutual direct neighbors of two nodes, optionally restricted
    /// to a single edge type (`None` = any type).
    pub fn mutual_connections(
        &self,
        node_a: &NodeId,
        node_b: &NodeId,
        edge_type: Option<&str>,
    ) -> Vec<Node> {
        // Depth-1 query; an empty edge_types vec means "all types".
        let query = TraversalQuery::new()
            .depth(1)
            .edge_types(edge_type.map(|s| vec![s.to_string()]).unwrap_or_default());

        let neighbors_a: HashSet<NodeId> = self
            .traverse(node_a, &query)
            .into_iter()
            .map(|r| r.node.id)
            .collect();

        let neighbors_b: HashSet<NodeId> = self
            .traverse(node_b, &query)
            .into_iter()
            .map(|r| r.node.id)
            .collect();

        let mutual: HashSet<_> = neighbors_a.intersection(&neighbors_b).collect();

        mutual
            .iter()
            .filter_map(|id| self.store.get_node(id))
            .collect()
    }
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn setup_social_graph() -> GraphStore {
|
||||
let store = GraphStore::new();
|
||||
|
||||
let alice = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"}));
|
||||
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
|
||||
let charlie = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Charlie"}));
|
||||
let dave = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Dave"}));
|
||||
|
||||
// Alice -> Bob -> Charlie -> Dave
|
||||
store.create_edge(alice, bob, "FRIEND", serde_json::json!({})).unwrap();
|
||||
store.create_edge(bob, charlie, "FRIEND", serde_json::json!({})).unwrap();
|
||||
store.create_edge(charlie, dave, "FRIEND", serde_json::json!({})).unwrap();
|
||||
|
||||
// Alice -> Charlie (shortcut)
|
||||
store.create_edge(alice, charlie, "KNOWS", serde_json::json!({})).unwrap();
|
||||
|
||||
store
|
||||
}
|
||||
|
||||
#[test]
fn test_basic_traversal() {
    let store = setup_social_graph();
    let traverser = Traverser::new(&store);

    // Locate Alice through the label index.
    let alice_id = store
        .find_nodes_by_label("User")
        .into_iter()
        .find(|n| n.get_property("name") == Some(&serde_json::json!("Alice")))
        .unwrap()
        .id;

    let results = traverser.traverse(&alice_id, &TraversalQuery::new().depth(2));

    // Should find Bob (depth 1), Charlie (depth 1 and 2), and Dave (depth 2)
    assert!(results.len() >= 2);
}
|
||||
|
||||
#[test]
fn test_edge_type_filter() {
    let store = setup_social_graph();
    let traverser = Traverser::new(&store);

    let alice_id = store
        .find_nodes_by_label("User")
        .into_iter()
        .find(|n| n.get_property("name") == Some(&serde_json::json!("Alice")))
        .unwrap()
        .id;

    let query = TraversalQuery::new()
        .depth(2)
        .edge_types(vec!["FRIEND".to_string()]);

    // Following only FRIEND edges: Alice -> Bob -> Charlie
    let results = traverser.traverse(&alice_id, &query);
    let names: Vec<_> = results
        .iter()
        .filter_map(|r| r.node.get_property("name"))
        .collect();

    assert!(names.contains(&&serde_json::json!("Bob")));
    assert!(names.contains(&&serde_json::json!("Charlie")));
}
|
||||
|
||||
#[test]
|
||||
fn test_depth_limit() {
|
||||
let store = setup_social_graph();
|
||||
let traverser = Traverser::new(&store);
|
||||
|
||||
let users = store.find_nodes_by_label("User");
|
||||
let alice = users.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("Alice"))).unwrap();
|
||||
|
||||
let query = TraversalQuery::new().depth(1);
|
||||
let results = traverser.traverse(&alice.id, &query);
|
||||
|
||||
// Depth 1: only direct neighbors
|
||||
for result in &results {
|
||||
assert_eq!(result.depth, 1);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_result_limit() {
|
||||
let store = setup_social_graph();
|
||||
let traverser = Traverser::new(&store);
|
||||
|
||||
let users = store.find_nodes_by_label("User");
|
||||
let alice = users.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("Alice"))).unwrap();
|
||||
|
||||
let query = TraversalQuery::new().depth(10).limit(2);
|
||||
let results = traverser.traverse(&alice.id, &query);
|
||||
|
||||
assert!(results.len() <= 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mutual_connections() {
|
||||
let store = GraphStore::new();
|
||||
|
||||
let alice = store.create_node(vec![], serde_json::json!({"name": "Alice"}));
|
||||
let bob = store.create_node(vec![], serde_json::json!({"name": "Bob"}));
|
||||
let mutual1 = store.create_node(vec![], serde_json::json!({"name": "Mutual1"}));
|
||||
let mutual2 = store.create_node(vec![], serde_json::json!({"name": "Mutual2"}));
|
||||
let only_alice = store.create_node(vec![], serde_json::json!({"name": "OnlyAlice"}));
|
||||
|
||||
store.create_edge(alice, mutual1, "FRIEND", serde_json::json!({})).unwrap();
|
||||
store.create_edge(alice, mutual2, "FRIEND", serde_json::json!({})).unwrap();
|
||||
store.create_edge(alice, only_alice, "FRIEND", serde_json::json!({})).unwrap();
|
||||
store.create_edge(bob, mutual1, "FRIEND", serde_json::json!({})).unwrap();
|
||||
store.create_edge(bob, mutual2, "FRIEND", serde_json::json!({})).unwrap();
|
||||
|
||||
let traverser = Traverser::new(&store);
|
||||
let mutual = traverser.mutual_connections(&alice, &bob, Some("FRIEND"));
|
||||
|
||||
assert_eq!(mutual.len(), 2);
|
||||
}
|
||||
}
|
||||
|
|
@ -45,20 +45,32 @@
|
|||
pub mod document;
|
||||
pub mod error;
|
||||
pub mod gateway;
|
||||
pub mod graph;
|
||||
pub mod index;
|
||||
pub mod keyvalue;
|
||||
pub mod query;
|
||||
pub mod replication;
|
||||
pub mod schema;
|
||||
pub mod sql;
|
||||
pub mod timeseries;
|
||||
pub mod vector;
|
||||
|
||||
pub use document::{Collection, Document, DocumentId, DocumentStore};
|
||||
pub use error::DatabaseError;
|
||||
pub use gateway::{GatewayConfig, GatewayServer};
|
||||
pub use graph::{
|
||||
Direction, Edge, EdgeId, GraphError, GraphQuery, GraphQueryParser, GraphStore, Node, NodeId,
|
||||
PathFinder, PathResult, TraversalQuery, TraversalResult, Traverser,
|
||||
};
|
||||
pub use index::{Index, IndexConfig, IndexManager, IndexType};
|
||||
pub use keyvalue::{KeyValue, KeyValueStore, KvEntry};
|
||||
pub use query::{Filter, Query, QueryEngine, QueryResult, SortOrder};
|
||||
pub use schema::{Field, FieldType, Schema, SchemaValidator};
|
||||
pub use replication::{
|
||||
ClusterConfig, Command as RaftCommand, NodeRole, RaftConfig, RaftEvent, RaftNode, RaftState,
|
||||
ReplicatedLog,
|
||||
};
|
||||
pub use sql::{QueryResult as SqlQueryResult, SqlEngine, SqlParser, SqlType, SqlValue, Table, TableDef};
|
||||
pub use timeseries::{DataPoint, Metric, TimeSeries, TimeSeriesStore};
|
||||
pub use vector::{Embedding, SimilarityMetric, VectorIndex, VectorStore};
|
||||
|
||||
|
|
|
|||
393
crates/synor-database/src/replication/cluster.rs
Normal file
393
crates/synor-database/src/replication/cluster.rs
Normal file
|
|
@ -0,0 +1,393 @@
|
|||
//! Cluster configuration and peer management.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::str::FromStr;
|
||||
|
||||
/// Unique identifier for a node in the cluster.
|
||||
pub type NodeId = u64;
|
||||
|
||||
/// Address for a peer node.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
|
||||
pub struct PeerAddress {
|
||||
/// Host address (IP or hostname).
|
||||
pub host: String,
|
||||
/// Port number.
|
||||
pub port: u16,
|
||||
}
|
||||
|
||||
impl PeerAddress {
|
||||
/// Creates a new peer address.
|
||||
pub fn new(host: impl Into<String>, port: u16) -> Self {
|
||||
Self {
|
||||
host: host.into(),
|
||||
port,
|
||||
}
|
||||
}
|
||||
|
||||
/// Parses from "host:port" format.
|
||||
pub fn parse(s: &str) -> Option<Self> {
|
||||
let parts: Vec<&str> = s.split(':').collect();
|
||||
if parts.len() == 2 {
|
||||
parts[1].parse().ok().map(|port| Self {
|
||||
host: parts[0].to_string(),
|
||||
port,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts to SocketAddr if possible.
|
||||
pub fn to_socket_addr(&self) -> Option<SocketAddr> {
|
||||
SocketAddr::from_str(&format!("{}:{}", self.host, self.port)).ok()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PeerAddress {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}:{}", self.host, self.port)
|
||||
}
|
||||
}
|
||||
|
||||
/// Information about a peer node.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct PeerInfo {
|
||||
/// Node identifier.
|
||||
pub id: NodeId,
|
||||
/// Network address.
|
||||
pub address: PeerAddress,
|
||||
/// Whether this peer is a voting member.
|
||||
pub voting: bool,
|
||||
/// Last known state.
|
||||
pub state: PeerState,
|
||||
}
|
||||
|
||||
impl PeerInfo {
|
||||
/// Creates new peer info.
|
||||
pub fn new(id: NodeId, address: PeerAddress) -> Self {
|
||||
Self {
|
||||
id,
|
||||
address,
|
||||
voting: true,
|
||||
state: PeerState::Unknown,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a non-voting learner peer.
|
||||
pub fn learner(id: NodeId, address: PeerAddress) -> Self {
|
||||
Self {
|
||||
id,
|
||||
address,
|
||||
voting: false,
|
||||
state: PeerState::Unknown,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// State of a peer from this node's perspective.
///
/// Stored per peer in `ClusterConfig::peers`; written via
/// `ClusterConfig::update_peer_state` and read by `reachable_peers`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum PeerState {
    /// State unknown (initial value for a freshly created peer).
    Unknown,
    /// Peer is reachable.
    Reachable,
    /// Peer is unreachable.
    Unreachable,
    /// Peer is being probed.
    Probing,
}
|
||||
|
||||
/// Configuration for the cluster.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ClusterConfig {
|
||||
/// This node's ID.
|
||||
pub node_id: NodeId,
|
||||
/// This node's address.
|
||||
pub address: PeerAddress,
|
||||
/// Known peers in the cluster.
|
||||
pub peers: HashMap<NodeId, PeerInfo>,
|
||||
/// Configuration index (for joint consensus).
|
||||
pub config_index: u64,
|
||||
/// Whether this is a joint configuration.
|
||||
pub joint: bool,
|
||||
/// Old configuration (for joint consensus).
|
||||
pub old_peers: Option<HashMap<NodeId, PeerInfo>>,
|
||||
}
|
||||
|
||||
impl ClusterConfig {
|
||||
/// Creates a new single-node cluster configuration.
|
||||
pub fn new(node_id: NodeId, address: PeerAddress) -> Self {
|
||||
Self {
|
||||
node_id,
|
||||
address,
|
||||
peers: HashMap::new(),
|
||||
config_index: 0,
|
||||
joint: false,
|
||||
old_peers: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds a peer to the cluster.
|
||||
pub fn add_peer(&mut self, peer: PeerInfo) {
|
||||
self.peers.insert(peer.id, peer);
|
||||
}
|
||||
|
||||
/// Removes a peer from the cluster.
|
||||
pub fn remove_peer(&mut self, id: NodeId) -> Option<PeerInfo> {
|
||||
self.peers.remove(&id)
|
||||
}
|
||||
|
||||
/// Gets a peer by ID.
|
||||
pub fn get_peer(&self, id: NodeId) -> Option<&PeerInfo> {
|
||||
self.peers.get(&id)
|
||||
}
|
||||
|
||||
/// Gets a mutable reference to a peer.
|
||||
pub fn get_peer_mut(&mut self, id: NodeId) -> Option<&mut PeerInfo> {
|
||||
self.peers.get_mut(&id)
|
||||
}
|
||||
|
||||
/// Returns all peer IDs.
|
||||
pub fn peer_ids(&self) -> Vec<NodeId> {
|
||||
self.peers.keys().copied().collect()
|
||||
}
|
||||
|
||||
/// Returns all voting peer IDs.
|
||||
pub fn voting_peers(&self) -> Vec<NodeId> {
|
||||
self.peers
|
||||
.iter()
|
||||
.filter(|(_, p)| p.voting)
|
||||
.map(|(id, _)| *id)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Returns the total number of voting members (including self).
|
||||
pub fn voting_members(&self) -> usize {
|
||||
self.peers.values().filter(|p| p.voting).count() + 1
|
||||
}
|
||||
|
||||
/// Returns the quorum size needed for consensus.
|
||||
pub fn quorum_size(&self) -> usize {
|
||||
self.voting_members() / 2 + 1
|
||||
}
|
||||
|
||||
/// Checks if we have quorum with the given votes.
|
||||
pub fn has_quorum(&self, votes: usize) -> bool {
|
||||
votes >= self.quorum_size()
|
||||
}
|
||||
|
||||
/// Starts a configuration change (joint consensus).
|
||||
pub fn begin_config_change(&mut self, new_peers: HashMap<NodeId, PeerInfo>, index: u64) {
|
||||
self.old_peers = Some(self.peers.clone());
|
||||
self.peers = new_peers;
|
||||
self.config_index = index;
|
||||
self.joint = true;
|
||||
}
|
||||
|
||||
/// Completes a configuration change.
|
||||
pub fn complete_config_change(&mut self) {
|
||||
self.old_peers = None;
|
||||
self.joint = false;
|
||||
}
|
||||
|
||||
/// Aborts a configuration change.
|
||||
pub fn abort_config_change(&mut self) {
|
||||
if let Some(old) = self.old_peers.take() {
|
||||
self.peers = old;
|
||||
}
|
||||
self.joint = false;
|
||||
}
|
||||
|
||||
/// Checks if node is in joint consensus mode.
|
||||
pub fn is_joint(&self) -> bool {
|
||||
self.joint
|
||||
}
|
||||
|
||||
/// For joint consensus: checks if we have quorum in BOTH configurations.
|
||||
pub fn has_joint_quorum(&self, new_votes: usize, old_votes: usize) -> bool {
|
||||
if !self.joint {
|
||||
return self.has_quorum(new_votes);
|
||||
}
|
||||
|
||||
let new_quorum = self.quorum_size();
|
||||
let old_quorum = self
|
||||
.old_peers
|
||||
.as_ref()
|
||||
.map(|p| p.values().filter(|peer| peer.voting).count() / 2 + 1)
|
||||
.unwrap_or(1);
|
||||
|
||||
new_votes >= new_quorum && old_votes >= old_quorum
|
||||
}
|
||||
|
||||
/// Updates peer state.
|
||||
pub fn update_peer_state(&mut self, id: NodeId, state: PeerState) {
|
||||
if let Some(peer) = self.peers.get_mut(&id) {
|
||||
peer.state = state;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns all reachable peers.
|
||||
pub fn reachable_peers(&self) -> Vec<NodeId> {
|
||||
self.peers
|
||||
.iter()
|
||||
.filter(|(_, p)| p.state == PeerState::Reachable)
|
||||
.map(|(id, _)| *id)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Serializes the configuration.
|
||||
pub fn to_bytes(&self) -> Vec<u8> {
|
||||
bincode::serialize(self).unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Deserializes the configuration.
|
||||
pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
|
||||
bincode::deserialize(bytes).ok()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ClusterConfig {
|
||||
fn default() -> Self {
|
||||
Self::new(1, PeerAddress::new("127.0.0.1", 9000))
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for cluster configurations.
|
||||
pub struct ClusterBuilder {
|
||||
config: ClusterConfig,
|
||||
}
|
||||
|
||||
impl ClusterBuilder {
|
||||
/// Creates a new builder.
|
||||
pub fn new(node_id: NodeId, address: PeerAddress) -> Self {
|
||||
Self {
|
||||
config: ClusterConfig::new(node_id, address),
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds a peer.
|
||||
pub fn with_peer(mut self, id: NodeId, address: PeerAddress) -> Self {
|
||||
self.config.add_peer(PeerInfo::new(id, address));
|
||||
self
|
||||
}
|
||||
|
||||
/// Adds a learner (non-voting) peer.
|
||||
pub fn with_learner(mut self, id: NodeId, address: PeerAddress) -> Self {
|
||||
self.config.add_peer(PeerInfo::learner(id, address));
|
||||
self
|
||||
}
|
||||
|
||||
/// Builds the configuration.
|
||||
pub fn build(self) -> ClusterConfig {
|
||||
self.config
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trips an address through Display and parse().
    #[test]
    fn test_peer_address() {
        let addr = PeerAddress::new("192.168.1.1", 9000);
        assert_eq!(addr.to_string(), "192.168.1.1:9000");

        let parsed = PeerAddress::parse("10.0.0.1:8080").unwrap();
        assert_eq!(parsed.host, "10.0.0.1");
        assert_eq!(parsed.port, 8080);
    }

    // Basic quorum math for a three-node cluster (self + two peers).
    #[test]
    fn test_cluster_config() {
        let mut config = ClusterConfig::new(1, PeerAddress::new("127.0.0.1", 9000));

        config.add_peer(PeerInfo::new(2, PeerAddress::new("127.0.0.1", 9001)));
        config.add_peer(PeerInfo::new(3, PeerAddress::new("127.0.0.1", 9002)));

        assert_eq!(config.voting_members(), 3);
        assert_eq!(config.quorum_size(), 2);
        assert!(config.has_quorum(2));
        assert!(!config.has_quorum(1));
    }

    // Quorum is a strict majority: 1 -> 1, 3 -> 2, 5 -> 3.
    #[test]
    fn test_quorum_sizes() {
        // 1 node: quorum = 1
        let config1 = ClusterConfig::new(1, PeerAddress::new("127.0.0.1", 9000));
        assert_eq!(config1.quorum_size(), 1);

        // 3 nodes: quorum = 2
        let mut config3 = ClusterConfig::new(1, PeerAddress::new("127.0.0.1", 9000));
        config3.add_peer(PeerInfo::new(2, PeerAddress::new("127.0.0.1", 9001)));
        config3.add_peer(PeerInfo::new(3, PeerAddress::new("127.0.0.1", 9002)));
        assert_eq!(config3.quorum_size(), 2);

        // 5 nodes: quorum = 3
        let mut config5 = config3.clone();
        config5.add_peer(PeerInfo::new(4, PeerAddress::new("127.0.0.1", 9003)));
        config5.add_peer(PeerInfo::new(5, PeerAddress::new("127.0.0.1", 9004)));
        assert_eq!(config5.quorum_size(), 3);
    }

    // Non-voting learners never enter the quorum computation.
    #[test]
    fn test_learner_peers() {
        let mut config = ClusterConfig::new(1, PeerAddress::new("127.0.0.1", 9000));
        config.add_peer(PeerInfo::new(2, PeerAddress::new("127.0.0.1", 9001)));
        config.add_peer(PeerInfo::learner(3, PeerAddress::new("127.0.0.1", 9002)));

        // Learners don't count toward quorum
        assert_eq!(config.voting_members(), 2); // self + node 2
        assert_eq!(config.quorum_size(), 2);
        assert_eq!(config.voting_peers().len(), 1); // only node 2
    }

    // The builder mirrors add_peer / learner semantics.
    #[test]
    fn test_cluster_builder() {
        let config = ClusterBuilder::new(1, PeerAddress::new("127.0.0.1", 9000))
            .with_peer(2, PeerAddress::new("127.0.0.1", 9001))
            .with_peer(3, PeerAddress::new("127.0.0.1", 9002))
            .with_learner(4, PeerAddress::new("127.0.0.1", 9003))
            .build();

        assert_eq!(config.peer_ids().len(), 3);
        assert_eq!(config.voting_members(), 3);
    }

    // Growing a 3-node cluster to 4 nodes via joint consensus.
    #[test]
    fn test_joint_consensus() {
        let mut config = ClusterConfig::new(1, PeerAddress::new("127.0.0.1", 9000));
        config.add_peer(PeerInfo::new(2, PeerAddress::new("127.0.0.1", 9001)));
        config.add_peer(PeerInfo::new(3, PeerAddress::new("127.0.0.1", 9002)));

        // Start config change: add node 4
        let mut new_peers = config.peers.clone();
        new_peers.insert(4, PeerInfo::new(4, PeerAddress::new("127.0.0.1", 9003)));

        config.begin_config_change(new_peers, 100);
        assert!(config.is_joint());

        // Need quorum in both old (2 of 3) and new (3 of 4) configs
        assert!(config.has_joint_quorum(3, 2));
        assert!(!config.has_joint_quorum(2, 2)); // Not enough in new config
        assert!(!config.has_joint_quorum(3, 1)); // Not enough in old config

        config.complete_config_change();
        assert!(!config.is_joint());
        assert_eq!(config.voting_members(), 4);
    }

    // bincode round-trip of the whole configuration.
    #[test]
    fn test_serialization() {
        let config = ClusterBuilder::new(1, PeerAddress::new("127.0.0.1", 9000))
            .with_peer(2, PeerAddress::new("127.0.0.1", 9001))
            .build();

        let bytes = config.to_bytes();
        let decoded = ClusterConfig::from_bytes(&bytes).unwrap();

        assert_eq!(decoded.node_id, 1);
        assert_eq!(decoded.peers.len(), 1);
    }
}
|
||||
311
crates/synor-database/src/replication/election.rs
Normal file
311
crates/synor-database/src/replication/election.rs
Normal file
|
|
@ -0,0 +1,311 @@
|
|||
//! Leader election logic for Raft.
|
||||
|
||||
use super::log::ReplicatedLog;
|
||||
use super::rpc::{RequestVote, RequestVoteResponse};
|
||||
use super::state::RaftState;
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// Result of a leader election, as reported by `Election::record_vote`
/// and `Election::result`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ElectionResult {
    /// Won the election, became leader.
    Won,
    /// Lost the election (higher term discovered).
    Lost,
    /// Election timed out, will retry.
    Timeout,
    /// Still in progress.
    InProgress,
}
|
||||
|
||||
/// Tracks election state.
|
||||
pub struct Election {
|
||||
/// Node ID of the candidate.
|
||||
node_id: u64,
|
||||
/// Term for this election.
|
||||
term: u64,
|
||||
/// Number of votes received (including self).
|
||||
votes_received: HashSet<u64>,
|
||||
/// Total number of nodes in cluster (including self).
|
||||
cluster_size: usize,
|
||||
/// Whether this election is still active.
|
||||
active: bool,
|
||||
}
|
||||
|
||||
impl Election {
|
||||
/// Starts a new election.
|
||||
pub fn new(node_id: u64, term: u64, cluster_size: usize) -> Self {
|
||||
let mut votes = HashSet::new();
|
||||
votes.insert(node_id); // Vote for self
|
||||
|
||||
Self {
|
||||
node_id,
|
||||
term,
|
||||
votes_received: votes,
|
||||
cluster_size,
|
||||
active: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the term for this election.
|
||||
pub fn term(&self) -> u64 {
|
||||
self.term
|
||||
}
|
||||
|
||||
/// Checks if the election is still active.
|
||||
pub fn is_active(&self) -> bool {
|
||||
self.active
|
||||
}
|
||||
|
||||
/// Gets the number of votes received.
|
||||
pub fn vote_count(&self) -> usize {
|
||||
self.votes_received.len()
|
||||
}
|
||||
|
||||
/// Gets the majority threshold.
|
||||
pub fn majority(&self) -> usize {
|
||||
(self.cluster_size / 2) + 1
|
||||
}
|
||||
|
||||
/// Records a vote from a peer.
|
||||
pub fn record_vote(&mut self, peer_id: u64, granted: bool) -> ElectionResult {
|
||||
if !self.active {
|
||||
return ElectionResult::Lost;
|
||||
}
|
||||
|
||||
if granted {
|
||||
self.votes_received.insert(peer_id);
|
||||
|
||||
if self.votes_received.len() >= self.majority() {
|
||||
self.active = false;
|
||||
return ElectionResult::Won;
|
||||
}
|
||||
}
|
||||
|
||||
ElectionResult::InProgress
|
||||
}
|
||||
|
||||
/// Cancels the election (e.g., discovered higher term).
|
||||
pub fn cancel(&mut self) {
|
||||
self.active = false;
|
||||
}
|
||||
|
||||
/// Creates a RequestVote message for this election.
|
||||
pub fn create_request(&self, log: &ReplicatedLog) -> RequestVote {
|
||||
RequestVote::new(
|
||||
self.term,
|
||||
self.node_id,
|
||||
log.last_index(),
|
||||
log.last_term(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Checks the current result of the election.
|
||||
pub fn result(&self) -> ElectionResult {
|
||||
if !self.active {
|
||||
ElectionResult::Lost
|
||||
} else if self.votes_received.len() >= self.majority() {
|
||||
ElectionResult::Won
|
||||
} else {
|
||||
ElectionResult::InProgress
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handles vote requests from candidates.
|
||||
pub struct VoteHandler;
|
||||
|
||||
impl VoteHandler {
|
||||
/// Processes a vote request and returns whether to grant the vote.
|
||||
pub fn handle_request(
|
||||
state: &mut RaftState,
|
||||
log: &ReplicatedLog,
|
||||
request: &RequestVote,
|
||||
) -> RequestVoteResponse {
|
||||
// If request's term < current term, deny
|
||||
if request.term < state.current_term {
|
||||
return RequestVoteResponse::deny(state.current_term);
|
||||
}
|
||||
|
||||
// If request's term > current term, update term and become follower
|
||||
if request.term > state.current_term {
|
||||
state.become_follower(request.term);
|
||||
}
|
||||
|
||||
// Check if we can grant vote
|
||||
let can_vote = state.voted_for.is_none() || state.voted_for == Some(request.candidate_id);
|
||||
|
||||
// Check if candidate's log is at least as up-to-date as ours
|
||||
let log_ok = log.is_up_to_date(request.last_log_index, request.last_log_term);
|
||||
|
||||
if can_vote && log_ok {
|
||||
state.voted_for = Some(request.candidate_id);
|
||||
RequestVoteResponse::grant(state.current_term)
|
||||
} else {
|
||||
RequestVoteResponse::deny(state.current_term)
|
||||
}
|
||||
}
|
||||
|
||||
/// Processes a vote response.
|
||||
pub fn handle_response(
|
||||
state: &mut RaftState,
|
||||
election: &mut Election,
|
||||
from_peer: u64,
|
||||
response: &RequestVoteResponse,
|
||||
) -> ElectionResult {
|
||||
// If response's term > current term, become follower
|
||||
if response.term > state.current_term {
|
||||
state.become_follower(response.term);
|
||||
election.cancel();
|
||||
return ElectionResult::Lost;
|
||||
}
|
||||
|
||||
// Record the vote
|
||||
election.record_vote(from_peer, response.vote_granted)
|
||||
}
|
||||
}
|
||||
|
||||
/// Election timeout generator producing randomized timeouts in
/// `[min_ms, max_ms]`, which breaks split votes between candidates.
pub struct ElectionTimeout {
    /// Minimum timeout in milliseconds.
    min_ms: u64,
    /// Maximum timeout in milliseconds.
    max_ms: u64,
}

impl ElectionTimeout {
    /// Creates a new timeout generator. Equal bounds yield a fixed timeout.
    pub fn new(min_ms: u64, max_ms: u64) -> Self {
        Self { min_ms, max_ms }
    }

    /// Generates a pseudo-random timeout in `[min_ms, max_ms]` (inclusive).
    ///
    /// Fix: the original computed `now % (max_ms - min_ms)`, which panics
    /// with a remainder-by-zero when `min_ms == max_ms` (and underflows
    /// when `min_ms > max_ms`); it also never produced `max_ms` itself.
    pub fn random_timeout(&self) -> std::time::Duration {
        use std::time::{SystemTime, UNIX_EPOCH};

        // Simple pseudo-randomness derived from the clock's nanoseconds;
        // adequate for jittering timers, not for anything security-related.
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock before UNIX epoch")
            .as_nanos();

        let span = self.max_ms.saturating_sub(self.min_ms) as u128;
        // `span + 1` can never be zero, so the modulo is always defined,
        // and the result covers the full inclusive range.
        let jitter = (now % (span + 1)) as u64;

        std::time::Duration::from_millis(self.min_ms + jitter)
    }

    /// Returns the minimum timeout.
    pub fn min_timeout(&self) -> std::time::Duration {
        std::time::Duration::from_millis(self.min_ms)
    }

    /// Returns the maximum timeout.
    pub fn max_timeout(&self) -> std::time::Duration {
        std::time::Duration::from_millis(self.max_ms)
    }
}

impl Default for ElectionTimeout {
    fn default() -> Self {
        // Default Raft election timeout: 150-300ms
        Self::new(150, 300)
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::replication::state::Command;
    use crate::replication::log::LogEntry;

    // A fresh election in a 5-node cluster starts with only the self vote.
    #[test]
    fn test_election_basic() {
        let election = Election::new(1, 1, 5);
        assert!(election.is_active());
        assert_eq!(election.vote_count(), 1); // Self vote
        assert_eq!(election.majority(), 3);
    }

    // Two granted peer votes plus the self vote reach a 5-node majority.
    #[test]
    fn test_election_win() {
        let mut election = Election::new(1, 1, 5);

        // Need 3 votes total (including self)
        assert_eq!(election.record_vote(2, true), ElectionResult::InProgress);
        assert_eq!(election.record_vote(3, true), ElectionResult::Won);
    }

    #[test]
    fn test_election_rejected_votes() {
        let mut election = Election::new(1, 1, 5);

        // Rejected votes don't count
        assert_eq!(election.record_vote(2, false), ElectionResult::InProgress);
        assert_eq!(election.record_vote(3, false), ElectionResult::InProgress);
        assert_eq!(election.vote_count(), 1);
    }

    // A follower with an empty log grants the first valid request.
    #[test]
    fn test_vote_handler_grant() {
        let mut state = RaftState::new();
        let log = ReplicatedLog::new();

        let request = RequestVote::new(1, 2, 0, 0);
        let response = VoteHandler::handle_request(&mut state, &log, &request);

        assert!(response.vote_granted);
        assert_eq!(state.voted_for, Some(2));
    }

    // Requests carrying a stale term are denied and told the newer term.
    #[test]
    fn test_vote_handler_deny_old_term() {
        let mut state = RaftState::new();
        state.current_term = 5;
        let log = ReplicatedLog::new();

        let request = RequestVote::new(3, 2, 10, 3);
        let response = VoteHandler::handle_request(&mut state, &log, &request);

        assert!(!response.vote_granted);
        assert_eq!(response.term, 5);
    }

    // Only one vote per term: a second candidate is refused.
    #[test]
    fn test_vote_handler_deny_already_voted() {
        let mut state = RaftState::new();
        state.current_term = 1; // Same term as request
        state.voted_for = Some(3); // Already voted for node 3 in this term
        let log = ReplicatedLog::new();

        // Request from node 2 for the same term - should be denied
        let request = RequestVote::new(1, 2, 0, 0);
        let response = VoteHandler::handle_request(&mut state, &log, &request);

        assert!(!response.vote_granted);
    }

    // Election restriction: a candidate with an older log gets no vote.
    #[test]
    fn test_vote_handler_deny_log_behind() {
        let mut state = RaftState::new();
        let log = ReplicatedLog::new();
        log.append(LogEntry::new(2, 1, Command::Noop)); // Our log has term 2

        let request = RequestVote::new(2, 2, 10, 1); // Candidate has term 1 entries
        let response = VoteHandler::handle_request(&mut state, &log, &request);

        assert!(!response.vote_granted);
    }

    // Randomized timeouts always land inside the configured window.
    #[test]
    fn test_election_timeout() {
        let timeout = ElectionTimeout::new(150, 300);

        for _ in 0..10 {
            let duration = timeout.random_timeout();
            assert!(duration >= std::time::Duration::from_millis(150));
            assert!(duration <= std::time::Duration::from_millis(300));
        }
    }
}
|
||||
387
crates/synor-database/src/replication/log.rs
Normal file
387
crates/synor-database/src/replication/log.rs
Normal file
|
|
@ -0,0 +1,387 @@
|
|||
//! Replicated log for Raft consensus.
|
||||
|
||||
use super::state::Command;
|
||||
use parking_lot::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A single entry in the replicated log.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LogEntry {
    /// Term when entry was received by leader.
    pub term: u64,
    /// Position in the log (1-indexed).
    pub index: u64,
    /// Command to execute.
    pub command: Command,
    /// Timestamp when entry was created.
    pub timestamp: u64,
}

impl LogEntry {
    /// Creates a new log entry stamped with the current wall-clock time
    /// in milliseconds since the UNIX epoch.
    pub fn new(term: u64, index: u64, command: Command) -> Self {
        let timestamp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_millis() as u64;

        Self {
            term,
            index,
            command,
            timestamp,
        }
    }

    /// Serializes to bytes via bincode (empty `Vec` on failure).
    pub fn to_bytes(&self) -> Vec<u8> {
        bincode::serialize(self).unwrap_or_default()
    }

    /// Deserializes an entry previously produced by `to_bytes`.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        bincode::deserialize(bytes).ok()
    }
}
|
||||
|
||||
/// Replicated log storing all commands.
///
/// Entries live in a `Vec` guarded by `RwLock`s. `start_index` shifts
/// forward as the prefix is compacted into snapshots, so an absolute log
/// index maps to a `Vec` offset via `index - start_index`.
pub struct ReplicatedLog {
    /// Log entries.
    entries: RwLock<Vec<LogEntry>>,
    /// Index of first entry in the log (for log compaction).
    start_index: RwLock<u64>,
    /// Term of last included entry (for snapshots).
    snapshot_term: RwLock<u64>,
}
|
||||
|
||||
impl ReplicatedLog {
|
||||
    /// Creates a new empty log: no entries, no snapshot, 1-based indexing.
    pub fn new() -> Self {
        Self {
            entries: RwLock::new(Vec::new()),
            start_index: RwLock::new(1), // 1-indexed
            snapshot_term: RwLock::new(0),
        }
    }
|
||||
|
||||
/// Returns the index of the last entry.
|
||||
pub fn last_index(&self) -> u64 {
|
||||
let entries = self.entries.read();
|
||||
let start = *self.start_index.read();
|
||||
if entries.is_empty() {
|
||||
start.saturating_sub(1)
|
||||
} else {
|
||||
start + entries.len() as u64 - 1
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the term of the last entry.
|
||||
pub fn last_term(&self) -> u64 {
|
||||
let entries = self.entries.read();
|
||||
if entries.is_empty() {
|
||||
*self.snapshot_term.read()
|
||||
} else {
|
||||
entries.last().map(|e| e.term).unwrap_or(0)
|
||||
}
|
||||
}
|
||||
|
||||
    /// Returns the term at a given index.
    ///
    /// `index == start - 1` refers to the last entry folded into the
    /// snapshot, so its term is served from `snapshot_term`; anything
    /// older than that has been discarded and yields `None`.
    pub fn term_at(&self, index: u64) -> Option<u64> {
        let start = *self.start_index.read();
        if index < start {
            return if index == start - 1 {
                Some(*self.snapshot_term.read())
            } else {
                None
            };
        }

        let entries = self.entries.read();
        // Translate the absolute log index into a Vec offset.
        let offset = (index - start) as usize;
        entries.get(offset).map(|e| e.term)
    }
|
||||
|
||||
/// Gets an entry by index.
|
||||
pub fn get(&self, index: u64) -> Option<LogEntry> {
|
||||
let start = *self.start_index.read();
|
||||
if index < start {
|
||||
return None;
|
||||
}
|
||||
|
||||
let entries = self.entries.read();
|
||||
let offset = (index - start) as usize;
|
||||
entries.get(offset).cloned()
|
||||
}
|
||||
|
||||
/// Gets entries from start_index to end_index (inclusive).
|
||||
pub fn get_range(&self, start_idx: u64, end_idx: u64) -> Vec<LogEntry> {
|
||||
let start = *self.start_index.read();
|
||||
if start_idx > end_idx || start_idx < start {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let entries = self.entries.read();
|
||||
let start_offset = (start_idx - start) as usize;
|
||||
let end_offset = (end_idx - start + 1) as usize;
|
||||
|
||||
entries
|
||||
.get(start_offset..end_offset.min(entries.len()))
|
||||
.map(|s| s.to_vec())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Gets all entries from a given index.
|
||||
pub fn entries_from(&self, from_index: u64) -> Vec<LogEntry> {
|
||||
let start = *self.start_index.read();
|
||||
if from_index < start {
|
||||
return self.entries.read().clone();
|
||||
}
|
||||
|
||||
let entries = self.entries.read();
|
||||
let offset = (from_index - start) as usize;
|
||||
entries.get(offset..).map(|s| s.to_vec()).unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Appends an entry to the log.
|
||||
pub fn append(&self, entry: LogEntry) -> u64 {
|
||||
let mut entries = self.entries.write();
|
||||
let index = entry.index;
|
||||
entries.push(entry);
|
||||
index
|
||||
}
|
||||
|
||||
    /// Appends multiple entries, potentially overwriting conflicting entries.
    ///
    /// Implements the AppendEntries consistency check from Raft: the entry
    /// at `prev_index` must carry `prev_term`, otherwise the caller must
    /// retry with an earlier prefix. Returns `false` on a consistency
    /// failure, `true` once the entries are merged in. A term mismatch at
    /// an existing offset truncates that entry and everything after it.
    pub fn append_entries(&self, prev_index: u64, prev_term: u64, new_entries: Vec<LogEntry>) -> bool {
        // Check that prev entry matches
        if prev_index > 0 {
            if let Some(prev_entry_term) = self.term_at(prev_index) {
                if prev_entry_term != prev_term {
                    // Conflict - need to truncate
                    return false;
                }
            } else if prev_index >= *self.start_index.read() {
                // Missing entry
                return false;
            }
        }

        let mut entries = self.entries.write();
        let start = *self.start_index.read();

        for entry in new_entries {
            let offset = (entry.index - start) as usize;

            if offset < entries.len() {
                // Check for conflict
                if entries[offset].term != entry.term {
                    // Delete this and all following entries
                    entries.truncate(offset);
                    entries.push(entry);
                }
                // Otherwise entry already exists with same term, skip
            } else {
                entries.push(entry);
            }
        }

        true
    }
|
||||
|
||||
/// Truncates the log after the given index.
|
||||
pub fn truncate_after(&self, index: u64) {
|
||||
let start = *self.start_index.read();
|
||||
if index < start {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut entries = self.entries.write();
|
||||
let offset = (index - start + 1) as usize;
|
||||
entries.truncate(offset);
|
||||
}
|
||||
|
||||
    /// Returns the number of entries currently held in memory.
    /// Entries discarded by compaction are not counted.
    pub fn len(&self) -> usize {
        self.entries.read().len()
    }

    /// Returns true if the log is empty.
    /// NOTE(review): after compaction this returns true even though the log
    /// logically contains snapshotted entries — confirm callers expect
    /// "no in-memory entries" rather than "log never written".
    pub fn is_empty(&self) -> bool {
        self.entries.read().is_empty()
    }
|
||||
|
||||
    /// Compacts the log up to (and including) the given index.
    ///
    /// Removes the covered prefix from memory, advances `start_index` to
    /// `up_to_index + 1`, and records `up_to_term` as the snapshot term.
    /// All three locks are held for the whole operation so readers never
    /// observe a half-updated (entries, start_index) pair.
    pub fn compact(&self, up_to_index: u64, up_to_term: u64) {
        let mut entries = self.entries.write();
        let mut start = self.start_index.write();
        let mut snapshot_term = self.snapshot_term.write();

        // Already compacted past this point — nothing to do.
        if up_to_index < *start {
            return;
        }

        let remove_count = (up_to_index - *start + 1) as usize;
        if remove_count >= entries.len() {
            entries.clear();
        } else {
            entries.drain(0..remove_count);
        }

        *start = up_to_index + 1;
        *snapshot_term = up_to_term;
    }
|
||||
|
||||
/// Checks if this log is at least as up-to-date as the candidate's log.
|
||||
/// Used during leader election.
|
||||
pub fn is_up_to_date(&self, candidate_last_index: u64, candidate_last_term: u64) -> bool {
|
||||
let last_term = self.last_term();
|
||||
let last_index = self.last_index();
|
||||
|
||||
// Compare terms first, then indices
|
||||
if last_term != candidate_last_term {
|
||||
last_term <= candidate_last_term
|
||||
} else {
|
||||
last_index <= candidate_last_index
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates entries for replication starting from a given index.
|
||||
pub fn entries_for_replication(&self, from_index: u64, max_entries: usize) -> (u64, u64, Vec<LogEntry>) {
|
||||
let prev_index = from_index.saturating_sub(1);
|
||||
let prev_term = self.term_at(prev_index).unwrap_or(0);
|
||||
|
||||
let entries = self.entries_from(from_index);
|
||||
let limited = if entries.len() > max_entries {
|
||||
entries[..max_entries].to_vec()
|
||||
} else {
|
||||
entries
|
||||
};
|
||||
|
||||
(prev_index, prev_term, limited)
|
||||
}
|
||||
}
|
||||
|
||||
/// `Default` simply delegates to [`ReplicatedLog::new`], so the log can be
/// embedded in `#[derive(Default)]` containers.
impl Default for ReplicatedLog {
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
// Unit tests for ReplicatedLog: append/read, range queries, compaction,
// the election up-to-date check, and AppendEntries batch semantics.
#[cfg(test)]
mod tests {
    use super::*;

    // Empty log reports index/term 0; a single append updates both.
    #[test]
    fn test_log_basic() {
        let log = ReplicatedLog::new();

        assert_eq!(log.last_index(), 0);
        assert_eq!(log.last_term(), 0);
        assert!(log.is_empty());

        let entry = LogEntry::new(1, 1, Command::Noop);
        log.append(entry);

        assert_eq!(log.last_index(), 1);
        assert_eq!(log.last_term(), 1);
        assert!(!log.is_empty());
    }

    // Sequential appends are retrievable by index.
    #[test]
    fn test_log_append() {
        let log = ReplicatedLog::new();

        for i in 1..=5 {
            log.append(LogEntry::new(1, i, Command::Noop));
        }

        assert_eq!(log.len(), 5);
        assert_eq!(log.last_index(), 5);

        let entry = log.get(3).unwrap();
        assert_eq!(entry.index, 3);
    }

    // get_range is inclusive on both ends.
    #[test]
    fn test_log_range() {
        let log = ReplicatedLog::new();

        for i in 1..=10 {
            log.append(LogEntry::new(1, i, Command::Noop));
        }

        let range = log.get_range(3, 7);
        assert_eq!(range.len(), 5);
        assert_eq!(range[0].index, 3);
        assert_eq!(range[4].index, 7);
    }

    // Compaction drops the prefix (inclusive) but keeps last_index stable.
    #[test]
    fn test_log_compact() {
        let log = ReplicatedLog::new();

        for i in 1..=10 {
            log.append(LogEntry::new(1, i, Command::Noop));
        }

        log.compact(5, 1);

        assert_eq!(log.len(), 5);
        assert!(log.get(5).is_none());
        assert!(log.get(6).is_some());
        assert_eq!(log.last_index(), 10);
    }

    // Raft §5.4.1 election restriction: term dominates, index breaks ties.
    #[test]
    fn test_log_up_to_date() {
        let log = ReplicatedLog::new();

        log.append(LogEntry::new(1, 1, Command::Noop))
;
        log.append(LogEntry::new(2, 2, Command::Noop));

        // Same log is up to date
        assert!(log.is_up_to_date(2, 2));

        // Higher term is more up to date
        assert!(log.is_up_to_date(1, 3));

        // Same term, higher index is more up to date
        assert!(log.is_up_to_date(3, 2));

        // Lower term is not up to date
        assert!(!log.is_up_to_date(10, 1));
    }

    // A batch whose prev entry matches is appended after the existing tail.
    #[test]
    fn test_append_entries() {
        let log = ReplicatedLog::new();

        // Initial entries
        log.append(LogEntry::new(1, 1, Command::Noop));
        log.append(LogEntry::new(1, 2, Command::Noop));

        // Append more entries
        let new_entries = vec![
            LogEntry::new(1, 3, Command::Noop),
            LogEntry::new(2, 4, Command::Noop),
        ];

        let success = log.append_entries(2, 1, new_entries);
        assert!(success);
        assert_eq!(log.len(), 4);
        assert_eq!(log.term_at(4), Some(2));
    }

    // Replication payload reports the entry before from_index as prev.
    #[test]
    fn test_entries_for_replication() {
        let log = ReplicatedLog::new();

        for i in 1..=5 {
            log.append(LogEntry::new(1, i, Command::Noop));
        }

        let (prev_idx, prev_term, entries) = log.entries_for_replication(3, 10);
        assert_eq!(prev_idx, 2);
        assert_eq!(prev_term, 1);
        assert_eq!(entries.len(), 3);
    }
}
|
||||
23
crates/synor-database/src/replication/mod.rs
Normal file
23
crates/synor-database/src/replication/mod.rs
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
//! Raft consensus-based replication for high availability.
|
||||
//!
|
||||
//! Provides distributed consensus to ensure data consistency across
|
||||
//! multiple database nodes with automatic leader election and failover.
|
||||
|
||||
pub mod cluster;
|
||||
pub mod election;
|
||||
pub mod log;
|
||||
pub mod raft;
|
||||
pub mod rpc;
|
||||
pub mod snapshot;
|
||||
pub mod state;
|
||||
|
||||
pub use cluster::{ClusterBuilder, ClusterConfig, NodeId, PeerAddress, PeerInfo, PeerState};
|
||||
pub use election::{Election, ElectionResult, ElectionTimeout, VoteHandler};
|
||||
pub use log::{LogEntry, ReplicatedLog};
|
||||
pub use raft::{ApplyResult, RaftConfig, RaftEvent, RaftNode};
|
||||
pub use rpc::{
|
||||
AppendEntries, AppendEntriesResponse, InstallSnapshot, InstallSnapshotResponse, RequestVote,
|
||||
RequestVoteResponse, RpcMessage,
|
||||
};
|
||||
pub use snapshot::{Snapshot, SnapshotConfig, SnapshotManager, SnapshotMetadata};
|
||||
pub use state::{Command, LeaderState, NodeRole, RaftState};
|
||||
955
crates/synor-database/src/replication/raft.rs
Normal file
955
crates/synor-database/src/replication/raft.rs
Normal file
|
|
@ -0,0 +1,955 @@
|
|||
//! Raft consensus implementation.
|
||||
|
||||
use super::cluster::{ClusterConfig, NodeId, PeerState};
|
||||
use super::election::{Election, ElectionResult, ElectionTimeout, VoteHandler};
|
||||
use super::log::{LogEntry, ReplicatedLog};
|
||||
use super::rpc::{
|
||||
AppendEntries, AppendEntriesResponse, InstallSnapshot, InstallSnapshotResponse, RequestVote,
|
||||
RequestVoteResponse, RpcMessage,
|
||||
};
|
||||
use super::snapshot::{Snapshot, SnapshotConfig, SnapshotManager};
|
||||
use super::state::{Command, LeaderState, NodeRole, RaftState};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
/// Configuration for the Raft node.
///
/// Election timeouts are randomized per election in
/// `[election_timeout_min, election_timeout_max]` to avoid split votes.
#[derive(Clone, Debug)]
pub struct RaftConfig {
    /// Minimum election timeout in milliseconds.
    pub election_timeout_min: u64,
    /// Maximum election timeout in milliseconds.
    pub election_timeout_max: u64,
    /// Heartbeat interval in milliseconds. Must be well below the minimum
    /// election timeout or followers will start spurious elections.
    pub heartbeat_interval: u64,
    /// Maximum entries per AppendEntries RPC.
    pub max_entries_per_rpc: usize,
    /// Snapshot threshold (entries before compaction).
    pub snapshot_threshold: u64,
    /// Maximum snapshot chunk size.
    pub snapshot_chunk_size: usize,
}
|
||||
|
||||
impl Default for RaftConfig {
    /// Defaults follow the Raft paper's suggested ranges: 150–300 ms
    /// election timeouts with a much shorter 50 ms heartbeat.
    fn default() -> Self {
        Self {
            election_timeout_min: 150,
            election_timeout_max: 300,
            heartbeat_interval: 50,
            max_entries_per_rpc: 100,
            snapshot_threshold: 10000,
            snapshot_chunk_size: 65536, // 64 KiB per InstallSnapshot chunk
        }
    }
}
|
||||
|
||||
/// Result of applying a command.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum ApplyResult {
    /// Command applied successfully; payload is the state machine's reply.
    Success(Vec<u8>),
    /// Command failed; payload is a human-readable reason.
    Error(String),
    /// Not the leader, redirect to leader. `None` when the leader is
    /// currently unknown (e.g. mid-election).
    NotLeader(Option<NodeId>),
}
|
||||
|
||||
/// Events that can be produced by the Raft node.
///
/// The node itself performs no I/O: it queues these events and the driver
/// drains them via `drain_events` to do the actual sending/applying.
#[derive(Clone, Debug)]
pub enum RaftEvent {
    /// Send RPC to a peer.
    SendRpc(NodeId, RpcMessage),
    /// Broadcast RPC to all peers.
    BroadcastRpc(RpcMessage),
    /// Apply committed entry to state machine (index, command).
    ApplyEntry(u64, Command),
    /// Became leader.
    BecameLeader,
    /// Became follower; payload is the known leader, if any.
    BecameFollower(Option<NodeId>),
    /// Snapshot should be taken (log grew past the threshold).
    TakeSnapshot,
    /// Log compacted up to (and including) the given index.
    LogCompacted(u64),
}
|
||||
|
||||
/// The Raft consensus node.
///
/// Deterministic, I/O-free state machine: the driver feeds it `tick()`
/// calls and incoming RPCs, and collects outgoing work via `drain_events`.
pub struct RaftNode {
    /// Node ID.
    id: NodeId,
    /// Cluster configuration (self + peers).
    cluster: ClusterConfig,
    /// Raft configuration (timeouts, batch sizes, snapshot policy).
    config: RaftConfig,
    /// Persistent state (term, vote, role, commit/applied indices).
    state: RaftState,
    /// Replicated log.
    log: ReplicatedLog,
    /// Current election (only during candidacy).
    election: Option<Election>,
    /// Election timeout generator (randomized per election).
    election_timeout: ElectionTimeout,
    /// Snapshot manager.
    snapshots: SnapshotManager,
    /// Leader state (only valid when leader).
    leader_state: Option<LeaderState>,
    /// Current known leader.
    leader_id: Option<NodeId>,
    /// Last heartbeat/message from leader; drives election timeouts.
    last_leader_contact: Instant,
    /// Current election timeout duration (re-randomized on each reset).
    current_timeout: Duration,
    /// Pending events, drained by the driver.
    events: Vec<RaftEvent>,
}
|
||||
|
||||
impl RaftNode {
|
||||
/// Creates a new Raft node.
|
||||
pub fn new(id: NodeId, cluster: ClusterConfig, config: RaftConfig) -> Self {
|
||||
let election_timeout =
|
||||
ElectionTimeout::new(config.election_timeout_min, config.election_timeout_max);
|
||||
let current_timeout = election_timeout.random_timeout();
|
||||
|
||||
Self {
|
||||
id,
|
||||
cluster,
|
||||
state: RaftState::new(),
|
||||
log: ReplicatedLog::new(),
|
||||
election: None,
|
||||
election_timeout,
|
||||
snapshots: SnapshotManager::new(config.snapshot_threshold),
|
||||
leader_state: None,
|
||||
leader_id: None,
|
||||
last_leader_contact: Instant::now(),
|
||||
current_timeout,
|
||||
events: Vec::new(),
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
    /// Returns the node ID.
    pub fn id(&self) -> NodeId {
        self.id
    }

    /// Returns the current term.
    pub fn current_term(&self) -> u64 {
        self.state.current_term
    }

    /// Returns the current role (Leader/Follower/Candidate).
    pub fn role(&self) -> NodeRole {
        self.state.role
    }

    /// Returns the current leader ID if known (None mid-election).
    pub fn leader(&self) -> Option<NodeId> {
        self.leader_id
    }

    /// Returns true if this node is the leader.
    pub fn is_leader(&self) -> bool {
        self.state.is_leader()
    }

    /// Returns the commit index (highest entry known replicated to a quorum).
    pub fn commit_index(&self) -> u64 {
        self.state.commit_index
    }

    /// Returns the last applied index (highest entry handed to the state machine).
    pub fn last_applied(&self) -> u64 {
        self.state.last_applied
    }

    /// Returns the index of the last log entry.
    pub fn log_length(&self) -> u64 {
        self.log.last_index()
    }

    /// Drains pending events, leaving the internal queue empty.
    /// The driver must act on these (send RPCs, apply entries, snapshot).
    pub fn drain_events(&mut self) -> Vec<RaftEvent> {
        std::mem::take(&mut self.events)
    }
|
||||
|
||||
    /// Called periodically to drive the Raft state machine.
    ///
    /// Dispatches on role, then applies any newly committed entries and
    /// queues a TakeSnapshot event when the log exceeds the threshold.
    pub fn tick(&mut self) {
        match self.state.role {
            NodeRole::Leader => self.tick_leader(),
            NodeRole::Follower => self.tick_follower(),
            NodeRole::Candidate => self.tick_candidate(),
        }

        // Apply committed entries
        self.apply_committed_entries();

        // Check if snapshot needed
        if self
            .snapshots
            .should_snapshot(self.log.last_index(), self.snapshots.last_included_index())
        {
            self.events.push(RaftEvent::TakeSnapshot);
        }
    }

    // Leader: replicate/heartbeat on every tick.
    // NOTE(review): `config.heartbeat_interval` is not consulted here — the
    // driver is apparently expected to call tick() at heartbeat cadence;
    // confirm, otherwise a fast tick loop floods peers with AppendEntries.
    fn tick_leader(&mut self) {
        // Send heartbeats to all peers
        self.send_heartbeats();
    }

    // Follower: start an election if the leader has gone silent for longer
    // than the (randomized) election timeout.
    fn tick_follower(&mut self) {
        // Check for election timeout
        if self.last_leader_contact.elapsed() >= self.current_timeout {
            self.start_election();
        }
    }

    // Candidate: an expired timeout means the election stalled (split vote
    // or lost messages) — begin a fresh election with a new term.
    fn tick_candidate(&mut self) {
        // Check for election timeout
        if self.last_leader_contact.elapsed() >= self.current_timeout {
            // Start a new election
            self.start_election();
        }
    }
|
||||
|
||||
    /// Starts a new election: bump term, vote for self, broadcast
    /// RequestVote carrying our last log index/term. A single-node cluster
    /// wins immediately since our own vote is already a quorum.
    fn start_election(&mut self) {
        // Increment term and become candidate
        self.state.become_candidate();
        self.state.voted_for = Some(self.id);
        self.leader_id = None;

        // Reset timeout (fresh random value to break split-vote ties)
        self.reset_election_timeout();

        // Create new election
        let cluster_size = self.cluster.voting_members();
        self.election = Some(Election::new(self.id, self.state.current_term, cluster_size));

        // Create RequestVote message
        let request = RequestVote::new(
            self.state.current_term,
            self.id,
            self.log.last_index(),
            self.log.last_term(),
        );

        self.events
            .push(RaftEvent::BroadcastRpc(RpcMessage::RequestVote(request)));

        // Check if we already have quorum (single-node cluster)
        if cluster_size == 1 {
            self.become_leader();
        }
    }
|
||||
|
||||
    /// Transitions to leader: initializes per-peer replication indices
    /// (next_index/match_index) and immediately heartbeats to assert
    /// leadership before any follower times out.
    fn become_leader(&mut self) {
        self.state.become_leader();
        self.leader_id = Some(self.id);
        self.election = None;

        // Initialize leader state
        let peer_ids: Vec<_> = self.cluster.peer_ids();
        self.leader_state = Some(LeaderState::new(self.log.last_index(), &peer_ids));

        self.events.push(RaftEvent::BecameLeader);

        // Send immediate heartbeats
        self.send_heartbeats();
    }

    /// Transitions to follower at `term`, discarding any leader/candidate
    /// bookkeeping and restarting the election timer.
    fn become_follower(&mut self, term: u64, leader: Option<NodeId>) {
        self.state.become_follower(term);
        self.leader_id = leader;
        self.leader_state = None;
        self.election = None;
        self.reset_election_timeout();

        self.events.push(RaftEvent::BecameFollower(leader));
    }

    /// Marks "leader heard from now" and picks a fresh randomized timeout,
    /// so repeated resets don't converge on the same deadline across nodes.
    fn reset_election_timeout(&mut self) {
        self.last_leader_contact = Instant::now();
        self.current_timeout = self.election_timeout.random_timeout();
    }
|
||||
|
||||
fn send_heartbeats(&mut self) {
|
||||
if !self.is_leader() {
|
||||
return;
|
||||
}
|
||||
|
||||
for peer_id in self.cluster.peer_ids() {
|
||||
self.send_append_entries(peer_id);
|
||||
}
|
||||
}
|
||||
|
||||
    /// Queues an AppendEntries RPC for one peer, starting at the peer's
    /// next_index. Falls back to InstallSnapshot when next_index has been
    /// compacted out of the log.
    fn send_append_entries(&mut self, peer_id: NodeId) {
        let leader_state = match &self.leader_state {
            Some(ls) => ls,
            None => return,
        };

        // Unknown peers default to next_index = 1 (start of log).
        let next_index = *leader_state.next_index.get(&peer_id).unwrap_or(&1);

        // Check if we need to send snapshot instead
        if next_index <= self.snapshots.last_included_index() {
            self.send_install_snapshot(peer_id);
            return;
        }

        let (prev_log_index, prev_log_term, entries) =
            self.log
                .entries_for_replication(next_index, self.config.max_entries_per_rpc);

        let request = AppendEntries::with_entries(
            self.state.current_term,
            self.id,
            prev_log_index,
            prev_log_term,
            entries,
            self.state.commit_index,
        );

        self.events
            .push(RaftEvent::SendRpc(peer_id, RpcMessage::AppendEntries(request)));
    }
|
||||
|
||||
    /// Queues an InstallSnapshot RPC for a peer that has fallen behind the
    /// compacted log. No-op when no snapshot exists yet.
    ///
    /// NOTE(review): only the FIRST chunk is ever queued here and no
    /// per-peer chunk offset is tracked, so a snapshot larger than
    /// `snapshot_chunk_size` never gets its remaining chunks sent —
    /// confirm whether multi-chunk transfer is handled elsewhere.
    fn send_install_snapshot(&mut self, peer_id: NodeId) {
        let snapshot = match self.snapshots.get_snapshot() {
            Some(s) => s,
            None => return,
        };

        let chunks = self
            .snapshots
            .chunk_snapshot(self.config.snapshot_chunk_size);
        if let Some((offset, data, done)) = chunks.into_iter().next() {
            let request = InstallSnapshot::new(
                self.state.current_term,
                self.id,
                snapshot.metadata.last_included_index,
                snapshot.metadata.last_included_term,
                offset,
                data,
                done,
            );

            self.events
                .push(RaftEvent::SendRpc(peer_id, RpcMessage::InstallSnapshot(request)));
        }
    }
|
||||
|
||||
    /// Handles an incoming RPC message.
    ///
    /// Requests produce a response message (`Some`) the driver must send
    /// back to `from`; responses are absorbed and return `None`.
    pub fn handle_rpc(&mut self, from: NodeId, message: RpcMessage) -> Option<RpcMessage> {
        match message {
            RpcMessage::RequestVote(req) => {
                let response = self.handle_request_vote(from, req);
                Some(RpcMessage::RequestVoteResponse(response))
            }
            RpcMessage::RequestVoteResponse(resp) => {
                self.handle_request_vote_response(from, resp);
                None
            }
            RpcMessage::AppendEntries(req) => {
                let response = self.handle_append_entries(from, req);
                Some(RpcMessage::AppendEntriesResponse(response))
            }
            RpcMessage::AppendEntriesResponse(resp) => {
                self.handle_append_entries_response(from, resp);
                None
            }
            RpcMessage::InstallSnapshot(req) => {
                let response = self.handle_install_snapshot(from, req);
                Some(RpcMessage::InstallSnapshotResponse(response))
            }
            RpcMessage::InstallSnapshotResponse(resp) => {
                self.handle_install_snapshot_response(from, resp);
                None
            }
        }
    }
|
||||
|
||||
    /// Delegates the vote decision (term check, already-voted check,
    /// log up-to-date restriction) to the election module.
    fn handle_request_vote(&mut self, _from: NodeId, req: RequestVote) -> RequestVoteResponse {
        // Use the VoteHandler from election module
        VoteHandler::handle_request(&mut self.state, &self.log, &req)
    }

    /// Tallies a vote response while we are a candidate; promotes to
    /// leader on quorum, abandons the election on a decisive loss.
    fn handle_request_vote_response(&mut self, from: NodeId, resp: RequestVoteResponse) {
        // Ignore if not candidate (stale response from an old election)
        if !self.state.is_candidate() {
            return;
        }

        // Use the VoteHandler
        if let Some(ref mut election) = self.election {
            let result = VoteHandler::handle_response(&mut self.state, election, from, &resp);

            match result {
                ElectionResult::Won => {
                    self.become_leader();
                }
                ElectionResult::Lost => {
                    // Already handled by VoteHandler (became follower)
                    self.election = None;
                }
                ElectionResult::InProgress | ElectionResult::Timeout => {}
            }
        }
    }
|
||||
|
||||
fn handle_append_entries(&mut self, _from: NodeId, req: AppendEntries) -> AppendEntriesResponse {
|
||||
// Rule: If term > currentTerm, become follower
|
||||
if req.term > self.state.current_term {
|
||||
self.become_follower(req.term, Some(req.leader_id));
|
||||
}
|
||||
|
||||
// Reject if term is old
|
||||
if req.term < self.state.current_term {
|
||||
return AppendEntriesResponse::failure(self.state.current_term);
|
||||
}
|
||||
|
||||
// Valid AppendEntries from leader - reset election timeout
|
||||
self.reset_election_timeout();
|
||||
self.leader_id = Some(req.leader_id);
|
||||
|
||||
// If we're candidate, step down
|
||||
if self.state.is_candidate() {
|
||||
self.become_follower(req.term, Some(req.leader_id));
|
||||
}
|
||||
|
||||
// Try to append entries
|
||||
let success =
|
||||
self.log
|
||||
.append_entries(req.prev_log_index, req.prev_log_term, req.entries);
|
||||
|
||||
if success {
|
||||
// Update commit index
|
||||
if req.leader_commit > self.state.commit_index {
|
||||
self.state
|
||||
.update_commit_index(std::cmp::min(req.leader_commit, self.log.last_index()));
|
||||
}
|
||||
|
||||
AppendEntriesResponse::success(self.state.current_term, self.log.last_index())
|
||||
} else {
|
||||
// Find conflict info for faster recovery
|
||||
if let Some(conflict_term) = self.log.term_at(req.prev_log_index) {
|
||||
// Find first index of conflicting term
|
||||
let mut conflict_index = req.prev_log_index;
|
||||
while conflict_index > 1 {
|
||||
if let Some(term) = self.log.term_at(conflict_index - 1) {
|
||||
if term != conflict_term {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
conflict_index -= 1;
|
||||
}
|
||||
AppendEntriesResponse::conflict(self.state.current_term, conflict_term, conflict_index)
|
||||
} else {
|
||||
AppendEntriesResponse::failure(self.state.current_term)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
    /// Leader-side handling of a peer's AppendEntries response: advance
    /// match/next indices on success (possibly committing), or back up
    /// next_index on rejection using the conflict hints when present.
    fn handle_append_entries_response(&mut self, from: NodeId, resp: AppendEntriesResponse) {
        // Ignore if not leader (stale response after stepping down)
        if !self.is_leader() {
            return;
        }

        // Step down if term is higher
        if resp.term > self.state.current_term {
            self.become_follower(resp.term, None);
            return;
        }

        let leader_state = match &mut self.leader_state {
            Some(ls) => ls,
            None => return,
        };

        if resp.success {
            // Update match_index and next_index
            leader_state.update_indices(from, resp.match_index);

            // Try to advance commit index
            self.try_advance_commit_index();
        } else {
            // Use conflict info if available for faster recovery
            if let (Some(conflict_term), Some(conflict_index)) =
                (resp.conflict_term, resp.conflict_index)
            {
                // If we also have entries from conflict_term, resume right
                // after our last one; otherwise jump back to the follower's
                // first index of that term.
                let mut new_next = conflict_index;
                for idx in (1..=self.log.last_index()).rev() {
                    if let Some(term) = self.log.term_at(idx) {
                        if term == conflict_term {
                            new_next = idx + 1;
                            break;
                        }
                        // Terms are monotonic along the log, so once we pass
                        // below conflict_term it cannot appear earlier.
                        if term < conflict_term {
                            break;
                        }
                    }
                }
                leader_state.next_index.insert(from, new_next);
            } else {
                // No hint: fall back to the slow one-step decrement.
                leader_state.decrement_next_index(from);
            }
        }

        // Any response proves the peer is alive.
        self.cluster.update_peer_state(from, PeerState::Reachable);
    }
|
||||
|
||||
    /// Follower-side InstallSnapshot: accumulates chunks and, once the
    /// final chunk arrives, installs the snapshot, compacts the log, and
    /// advances commit/applied indices to the snapshot boundary.
    fn handle_install_snapshot(&mut self, _from: NodeId, req: InstallSnapshot) -> InstallSnapshotResponse {
        // Rule: If term > currentTerm, become follower
        if req.term > self.state.current_term {
            self.become_follower(req.term, Some(req.leader_id));
        }

        // Reject if term is old
        if req.term < self.state.current_term {
            return InstallSnapshotResponse::new(self.state.current_term);
        }

        // A snapshot RPC counts as leader contact — reset election timeout
        self.reset_election_timeout();
        self.leader_id = Some(req.leader_id);

        // Start or continue receiving snapshot (offset 0 begins a transfer)
        if req.offset == 0 {
            self.snapshots.start_receiving(
                req.last_included_index,
                req.last_included_term,
                req.offset,
                req.data,
            );
        } else {
            self.snapshots.add_chunk(req.offset, req.data);
        }

        // If done, finalize snapshot
        if req.done {
            if let Some(_snapshot) = self.snapshots.finalize_snapshot() {
                // Discard log up to snapshot
                self.log
                    .compact(req.last_included_index, req.last_included_term);

                // The snapshot represents committed, applied state: pull
                // commit_index / last_applied up to its boundary.
                if req.last_included_index > self.state.commit_index {
                    self.state.update_commit_index(req.last_included_index);
                }
                if req.last_included_index > self.state.last_applied {
                    self.state.update_last_applied(req.last_included_index);
                }

                self.events
                    .push(RaftEvent::LogCompacted(req.last_included_index));
            }
        }

        InstallSnapshotResponse::new(self.state.current_term)
    }
|
||||
|
||||
    /// Leader-side handling of an InstallSnapshot response: step down on a
    /// newer term, otherwise advance the peer to the snapshot boundary.
    fn handle_install_snapshot_response(&mut self, from: NodeId, resp: InstallSnapshotResponse) {
        if !self.is_leader() {
            return;
        }

        if resp.term > self.state.current_term {
            self.become_follower(resp.term, None);
            return;
        }

        // Update next_index for peer (assuming success since there's no success field)
        // NOTE(review): the RPC response carries only a term, so a failed
        // install is indistinguishable from success here — confirm.
        if let Some(leader_state) = &mut self.leader_state {
            let snapshot_index = self.snapshots.last_included_index();
            leader_state.update_indices(from, snapshot_index);
        }

        self.cluster.update_peer_state(from, PeerState::Reachable);
    }
|
||||
|
||||
    /// Leader-only: recomputes the commit index from peers' match indices.
    /// The term lookup closure lets the leader-state logic enforce the
    /// Raft rule that only current-term entries commit by counting.
    fn try_advance_commit_index(&mut self) {
        if !self.is_leader() {
            return;
        }

        let leader_state = match &self.leader_state {
            Some(ls) => ls,
            None => return,
        };

        // Calculate new commit index from quorum match indices
        let new_commit = leader_state.calculate_commit_index(
            self.state.commit_index,
            self.state.current_term,
            |idx| self.log.term_at(idx),
        );

        // commit_index only ever moves forward
        if new_commit > self.state.commit_index {
            self.state.update_commit_index(new_commit);
        }
    }
|
||||
|
||||
    /// Emits an ApplyEntry event for every committed-but-unapplied entry,
    /// in index order, advancing last_applied as it goes. Stops early if
    /// an entry is missing from the in-memory log.
    fn apply_committed_entries(&mut self) {
        while self.state.last_applied < self.state.commit_index {
            let next_to_apply = self.state.last_applied + 1;

            if let Some(entry) = self.log.get(next_to_apply) {
                self.events
                    .push(RaftEvent::ApplyEntry(entry.index, entry.command.clone()));
                self.state.update_last_applied(next_to_apply);
            } else {
                // Entry not available (e.g. compacted); wait for a snapshot.
                break;
            }
        }
    }
|
||||
|
||||
    /// Proposes a command (only valid on leader).
    ///
    /// Appends the command to the local log and immediately queues
    /// replication to every peer. Returns the entry's log index; the entry
    /// is NOT yet committed — commitment happens once a quorum acknowledges.
    /// Non-leaders get `ApplyResult::NotLeader` carrying the known leader.
    pub fn propose(&mut self, command: Command) -> Result<u64, ApplyResult> {
        if !self.is_leader() {
            return Err(ApplyResult::NotLeader(self.leader_id));
        }

        // Create log entry at the next free index
        let index = self.log.last_index() + 1;
        let entry = LogEntry::new(self.state.current_term, index, command);
        let result_index = self.log.append(entry);

        // Send AppendEntries to all peers
        for peer_id in self.cluster.peer_ids() {
            self.send_append_entries(peer_id);
        }

        Ok(result_index)
    }
|
||||
|
||||
    /// Takes a snapshot of the state machine.
    ///
    /// `data` is the serialized state machine as of `last_applied`. The
    /// snapshot records the current membership (self + peers), and the log
    /// is compacted up to the snapshot boundary.
    pub fn take_snapshot(&mut self, data: Vec<u8>) {
        let last_index = self.state.last_applied;
        // Term of the last applied entry; 0 if it is no longer in the log.
        let last_term = self.log.term_at(last_index).unwrap_or(0);

        let config = SnapshotConfig {
            nodes: std::iter::once(self.id)
                .chain(self.cluster.peer_ids())
                .collect(),
        };

        self.snapshots
            .create_snapshot(last_index, last_term, config, data);

        // Compact log
        self.log.compact(last_index, last_term);
        self.events.push(RaftEvent::LogCompacted(last_index));
    }
|
||||
|
||||
    /// Adds a new server to the cluster (leader-only).
    ///
    /// Registers the peer in the configuration and seeds its replication
    /// indices (next_index = last_index + 1, match_index = 0) so the leader
    /// starts probing it from the log tail.
    pub fn add_server(
        &mut self,
        id: NodeId,
        address: super::cluster::PeerAddress,
    ) -> Result<(), String> {
        if !self.is_leader() {
            return Err("Not leader".to_string());
        }

        let peer = super::cluster::PeerInfo::new(id, address);
        self.cluster.add_peer(peer);

        // Initialize leader state for new peer
        if let Some(ref mut ls) = self.leader_state {
            ls.next_index.insert(id, self.log.last_index() + 1);
            ls.match_index.insert(id, 0);
        }

        Ok(())
    }
|
||||
|
||||
/// Removes a server from the cluster.
|
||||
pub fn remove_server(&mut self, id: NodeId) -> Result<(), String> {
|
||||
if !self.is_leader() {
|
||||
return Err("Not leader".to_string());
|
||||
}
|
||||
|
||||
self.cluster.remove_peer(id);
|
||||
|
||||
if let Some(ref mut ls) = self.leader_state {
|
||||
ls.next_index.remove(&id);
|
||||
ls.match_index.remove(&id);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
    /// Forces an election timeout (for testing).
    /// Backdates `last_leader_contact` past the current timeout so the
    /// next `tick()` deterministically starts an election.
    #[cfg(test)]
    pub fn force_election_timeout(&mut self) {
        self.last_leader_contact = Instant::now() - self.current_timeout - Duration::from_secs(1);
    }
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use super::super::cluster::PeerAddress;
|
||||
|
||||
fn create_test_cluster(node_id: NodeId, peers: &[NodeId]) -> ClusterConfig {
|
||||
let mut cluster =
|
||||
ClusterConfig::new(node_id, PeerAddress::new("127.0.0.1", 9000 + node_id as u16));
|
||||
for &peer in peers {
|
||||
cluster.add_peer(super::super::cluster::PeerInfo::new(
|
||||
peer,
|
||||
PeerAddress::new("127.0.0.1", 9000 + peer as u16),
|
||||
));
|
||||
}
|
||||
cluster
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_single_node_election() {
|
||||
let cluster = create_test_cluster(1, &[]);
|
||||
let config = RaftConfig::default();
|
||||
let mut node = RaftNode::new(1, cluster, config);
|
||||
|
||||
assert!(matches!(node.role(), NodeRole::Follower));
|
||||
|
||||
// Simulate election timeout
|
||||
node.force_election_timeout();
|
||||
node.tick();
|
||||
|
||||
// Single node should become leader immediately
|
||||
assert!(node.is_leader());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_three_node_election() {
|
||||
let cluster = create_test_cluster(1, &[2, 3]);
|
||||
let config = RaftConfig::default();
|
||||
let mut node1 = RaftNode::new(1, cluster, config);
|
||||
|
||||
// Trigger election
|
||||
node1.force_election_timeout();
|
||||
node1.tick();
|
||||
|
||||
assert!(matches!(node1.role(), NodeRole::Candidate));
|
||||
assert_eq!(node1.current_term(), 1);
|
||||
|
||||
// Simulate receiving votes
|
||||
let vote_resp = RequestVoteResponse::grant(1);
|
||||
|
||||
node1.handle_rpc(2, RpcMessage::RequestVoteResponse(vote_resp.clone()));
|
||||
|
||||
// Should become leader after receiving vote from node 2 (2 votes = quorum of 3)
|
||||
assert!(node1.is_leader());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_append_entries() {
|
||||
let cluster = create_test_cluster(1, &[]);
|
||||
let config = RaftConfig::default();
|
||||
let mut leader = RaftNode::new(1, cluster, config);
|
||||
|
||||
// Become leader
|
||||
leader.force_election_timeout();
|
||||
leader.tick();
|
||||
assert!(leader.is_leader());
|
||||
|
||||
// Propose a command
|
||||
let command = Command::KvSet {
|
||||
key: "test".to_string(),
|
||||
value: vec![1, 2, 3],
|
||||
ttl: None,
|
||||
};
|
||||
let index = leader.propose(command).unwrap();
|
||||
|
||||
assert_eq!(index, 1);
|
||||
assert_eq!(leader.log_length(), 1);
|
||||
}
|
||||
|
||||
#[test]
fn test_follower_receives_append_entries() {
    // Two-node cluster: node 1 will become leader, node 2 stays follower.
    let cluster1 = create_test_cluster(1, &[2]);
    let cluster2 = create_test_cluster(2, &[1]);
    let config = RaftConfig::default();

    let mut leader = RaftNode::new(1, cluster1, config.clone());
    let mut follower = RaftNode::new(2, cluster2, config);

    // Make node 1 leader: timeout -> candidate, tick issues vote requests.
    leader.force_election_timeout();
    leader.tick();

    // Simulate the vote from node 2; leader + node 2 is a quorum of 2.
    leader.handle_rpc(
        2,
        RpcMessage::RequestVoteResponse(RequestVoteResponse::grant(1)),
    );

    assert!(leader.is_leader());

    // Propose a command so the leader has something to replicate.
    leader
        .propose(Command::KvSet {
            key: "key1".to_string(),
            value: vec![1],
            ttl: None,
        })
        .unwrap();

    // Get AppendEntries from leader events (skip heartbeats, find one with entries).
    // Becoming leader also emits empty heartbeats, so filter on a non-empty payload.
    let events = leader.drain_events();
    let append_req = events
        .iter()
        .find_map(|e| {
            if let RaftEvent::SendRpc(2, RpcMessage::AppendEntries(req)) = e {
                // Skip heartbeats (empty entries), find the one with actual entries
                if !req.entries.is_empty() {
                    Some(req.clone())
                } else {
                    None
                }
            } else {
                None
            }
        })
        .unwrap();

    // Deliver the replication request to the follower.
    let response = follower
        .handle_rpc(1, RpcMessage::AppendEntries(append_req))
        .unwrap();

    // The follower's log was empty, so prev_log checks pass and the entry
    // at index 1 is accepted.
    if let RpcMessage::AppendEntriesResponse(resp) = response {
        assert!(resp.success);
        assert_eq!(resp.match_index, 1);
    }

    assert_eq!(follower.log_length(), 1);
}
|
||||
|
||||
#[test]
fn test_step_down_on_higher_term() {
    let cluster = create_test_cluster(1, &[2]);
    let mut node = RaftNode::new(1, cluster, RaftConfig::default());

    // Win an election at term 1: timeout, tick, then node 2's vote arrives.
    node.force_election_timeout();
    node.tick();
    node.handle_rpc(
        2,
        RpcMessage::RequestVoteResponse(RequestVoteResponse::grant(1)),
    );
    assert!(node.is_leader());
    assert_eq!(node.current_term(), 1);

    // A heartbeat at term 5 from node 2 must force this leader to step down.
    node.handle_rpc(
        2,
        RpcMessage::AppendEntries(AppendEntries::heartbeat(5, 2, 0, 0, 0)),
    );

    assert!(!node.is_leader());
    assert_eq!(node.current_term(), 5);
    assert_eq!(node.leader(), Some(2));
}
|
||||
|
||||
#[test]
fn test_commit_index_advancement() {
    let cluster = create_test_cluster(1, &[2, 3]);
    let mut leader = RaftNode::new(1, cluster, RaftConfig::default());

    // Win the election with votes from both peers.
    leader.force_election_timeout();
    leader.tick();
    for peer in [2, 3] {
        leader.handle_rpc(
            peer,
            RpcMessage::RequestVoteResponse(RequestVoteResponse::grant(1)),
        );
    }

    // Replicate a single command (lands at log index 1).
    leader
        .propose(Command::KvSet {
            key: "key".to_string(),
            value: vec![1],
            ttl: None,
        })
        .unwrap();

    // Node 2 acknowledges index 1: leader + node 2 form a quorum of the
    // 3-node cluster, so the commit index must advance to 1.
    leader.handle_rpc(
        2,
        RpcMessage::AppendEntriesResponse(AppendEntriesResponse::success(1, 1)),
    );
    assert_eq!(leader.commit_index(), 1);
}
|
||||
|
||||
#[test]
fn test_snapshot() {
    let cluster = create_test_cluster(1, &[]);
    let mut node = RaftNode::new(1, cluster, RaftConfig::default());

    // Single-node cluster: timeout + tick wins leadership immediately.
    node.force_election_timeout();
    node.tick();

    // Fill the log with a handful of entries.
    for i in 0..5 {
        node.propose(Command::KvSet {
            key: format!("key{}", i),
            value: vec![i as u8],
            ttl: None,
        })
        .unwrap();
    }

    // Manually mark everything committed and let a tick apply it.
    node.state.update_commit_index(5);
    node.tick();

    // Capturing a snapshot should make it retrievable from the manager.
    node.take_snapshot(vec![1, 2, 3, 4, 5]);
    assert!(node.snapshots.get_snapshot().is_some());
}
|
||||
|
||||
#[test]
fn test_request_vote_log_check() {
    let cluster = create_test_cluster(1, &[2]);
    let mut node = RaftNode::new(1, cluster, RaftConfig::default());

    // Seed the local log so the up-to-date comparison has material:
    // last entry is (index 2, term 2).
    node.log.append(LogEntry::new(1, 1, Command::Noop));
    node.log.append(LogEntry::new(2, 2, Command::Noop));

    // Candidate whose last log term (1) is older than ours (2): deny.
    let stale = RequestVote::new(3, 2, 1, 1);
    if let Some(RpcMessage::RequestVoteResponse(r)) =
        node.handle_rpc(2, RpcMessage::RequestVote(stale))
    {
        assert!(!r.vote_granted);
    }

    // Candidate whose log is at least as up-to-date as ours: grant.
    let current = RequestVote::new(4, 2, 2, 2);
    if let Some(RpcMessage::RequestVoteResponse(r)) =
        node.handle_rpc(2, RpcMessage::RequestVote(current))
    {
        assert!(r.vote_granted);
    }
}
|
||||
}
|
||||
318
crates/synor-database/src/replication/rpc.rs
Normal file
318
crates/synor-database/src/replication/rpc.rs
Normal file
|
|
@ -0,0 +1,318 @@
|
|||
//! RPC messages for Raft consensus.
|
||||
|
||||
use super::log::LogEntry;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// All RPC message types in Raft.
///
/// Every variant carries the sender's term so receivers can detect stale
/// messages and step down on higher terms (see [`RpcMessage::term`]).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum RpcMessage {
    /// Vote request broadcast by a candidate during an election.
    RequestVote(RequestVote),
    /// Reply to a vote request.
    RequestVoteResponse(RequestVoteResponse),
    /// Log replication (or heartbeat, when empty) from the leader.
    AppendEntries(AppendEntries),
    /// Reply to an append-entries request.
    AppendEntriesResponse(AppendEntriesResponse),
    /// Snapshot chunk from the leader for a lagging follower.
    InstallSnapshot(InstallSnapshot),
    /// Reply to a snapshot chunk.
    InstallSnapshotResponse(InstallSnapshotResponse),
}
|
||||
|
||||
impl RpcMessage {
|
||||
/// Returns the term of this message.
|
||||
pub fn term(&self) -> u64 {
|
||||
match self {
|
||||
RpcMessage::RequestVote(r) => r.term,
|
||||
RpcMessage::RequestVoteResponse(r) => r.term,
|
||||
RpcMessage::AppendEntries(r) => r.term,
|
||||
RpcMessage::AppendEntriesResponse(r) => r.term,
|
||||
RpcMessage::InstallSnapshot(r) => r.term,
|
||||
RpcMessage::InstallSnapshotResponse(r) => r.term,
|
||||
}
|
||||
}
|
||||
|
||||
/// Serializes to bytes.
|
||||
pub fn to_bytes(&self) -> Vec<u8> {
|
||||
bincode::serialize(self).unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Deserializes from bytes.
|
||||
pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
|
||||
bincode::deserialize(bytes).ok()
|
||||
}
|
||||
}
|
||||
|
||||
/// RequestVote RPC (sent by candidates to gather votes).
///
/// The last-log fields let voters apply the up-to-date check: a vote is
/// refused when the candidate's log is behind the voter's own.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RequestVote {
    /// Candidate's term.
    pub term: u64,
    /// ID of the candidate requesting the vote.
    pub candidate_id: u64,
    /// Index of candidate's last log entry.
    pub last_log_index: u64,
    /// Term of candidate's last log entry.
    pub last_log_term: u64,
}
|
||||
|
||||
impl RequestVote {
|
||||
/// Creates a new RequestVote message.
|
||||
pub fn new(term: u64, candidate_id: u64, last_log_index: u64, last_log_term: u64) -> Self {
|
||||
Self {
|
||||
term,
|
||||
candidate_id,
|
||||
last_log_index,
|
||||
last_log_term,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Response to [`RequestVote`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RequestVoteResponse {
    /// Responder's current term, so a stale candidate can update itself.
    pub term: u64,
    /// True if the candidate received this node's vote.
    pub vote_granted: bool,
}
|
||||
|
||||
impl RequestVoteResponse {
|
||||
/// Creates a positive response.
|
||||
pub fn grant(term: u64) -> Self {
|
||||
Self {
|
||||
term,
|
||||
vote_granted: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a negative response.
|
||||
pub fn deny(term: u64) -> Self {
|
||||
Self {
|
||||
term,
|
||||
vote_granted: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// AppendEntries RPC (sent by leader for replication and heartbeat).
///
/// An empty `entries` vector makes this a heartbeat (see
/// [`AppendEntries::is_heartbeat`]).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AppendEntries {
    /// Leader's term.
    pub term: u64,
    /// Leader ID (so follower can redirect clients).
    pub leader_id: u64,
    /// Index of log entry immediately preceding new ones.
    pub prev_log_index: u64,
    /// Term of the entry at `prev_log_index`, used for the consistency check.
    pub prev_log_term: u64,
    /// Log entries to store (empty for heartbeat).
    pub entries: Vec<LogEntry>,
    /// Leader's commit index, letting followers advance their own.
    pub leader_commit: u64,
}
|
||||
|
||||
impl AppendEntries {
|
||||
/// Creates a heartbeat (empty entries).
|
||||
pub fn heartbeat(
|
||||
term: u64,
|
||||
leader_id: u64,
|
||||
prev_log_index: u64,
|
||||
prev_log_term: u64,
|
||||
leader_commit: u64,
|
||||
) -> Self {
|
||||
Self {
|
||||
term,
|
||||
leader_id,
|
||||
prev_log_index,
|
||||
prev_log_term,
|
||||
entries: Vec::new(),
|
||||
leader_commit,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates an append entries request with entries.
|
||||
pub fn with_entries(
|
||||
term: u64,
|
||||
leader_id: u64,
|
||||
prev_log_index: u64,
|
||||
prev_log_term: u64,
|
||||
entries: Vec<LogEntry>,
|
||||
leader_commit: u64,
|
||||
) -> Self {
|
||||
Self {
|
||||
term,
|
||||
leader_id,
|
||||
prev_log_index,
|
||||
prev_log_term,
|
||||
entries,
|
||||
leader_commit,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if this is a heartbeat (no entries).
|
||||
pub fn is_heartbeat(&self) -> bool {
|
||||
self.entries.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
/// Response to [`AppendEntries`].
///
/// The optional conflict fields implement the fast log back-up
/// optimization: on rejection the follower can report which term
/// conflicted and where that term begins.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AppendEntriesResponse {
    /// Responder's current term (for leader to update).
    pub term: u64,
    /// True if follower contained entry matching prev_log_index and prev_log_term.
    pub success: bool,
    /// Index of last entry appended (for quick catch-up).
    pub match_index: u64,
    /// If `success` is false, the conflicting term (for optimization).
    pub conflict_term: Option<u64>,
    /// First index of the conflicting term.
    pub conflict_index: Option<u64>,
}
|
||||
|
||||
impl AppendEntriesResponse {
|
||||
/// Creates a success response.
|
||||
pub fn success(term: u64, match_index: u64) -> Self {
|
||||
Self {
|
||||
term,
|
||||
success: true,
|
||||
match_index,
|
||||
conflict_term: None,
|
||||
conflict_index: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a failure response.
|
||||
pub fn failure(term: u64) -> Self {
|
||||
Self {
|
||||
term,
|
||||
success: false,
|
||||
match_index: 0,
|
||||
conflict_term: None,
|
||||
conflict_index: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a failure response with conflict info.
|
||||
pub fn conflict(term: u64, conflict_term: u64, conflict_index: u64) -> Self {
|
||||
Self {
|
||||
term,
|
||||
success: false,
|
||||
match_index: 0,
|
||||
conflict_term: Some(conflict_term),
|
||||
conflict_index: Some(conflict_index),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// InstallSnapshot RPC (sent by leader when a follower is too far behind
/// for log replication to catch it up).
///
/// Snapshots are streamed in chunks; `offset` positions each chunk and
/// `done` marks the final one.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct InstallSnapshot {
    /// Leader's term.
    pub term: u64,
    /// Leader ID.
    pub leader_id: u64,
    /// Index of last entry included in snapshot.
    pub last_included_index: u64,
    /// Term of last entry included in snapshot.
    pub last_included_term: u64,
    /// Byte offset where this chunk is positioned in the snapshot.
    pub offset: u64,
    /// Raw bytes of this snapshot chunk.
    pub data: Vec<u8>,
    /// True if this is the last chunk.
    pub done: bool,
}
|
||||
|
||||
impl InstallSnapshot {
|
||||
/// Creates a new snapshot installation request.
|
||||
pub fn new(
|
||||
term: u64,
|
||||
leader_id: u64,
|
||||
last_included_index: u64,
|
||||
last_included_term: u64,
|
||||
offset: u64,
|
||||
data: Vec<u8>,
|
||||
done: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
term,
|
||||
leader_id,
|
||||
last_included_index,
|
||||
last_included_term,
|
||||
offset,
|
||||
data,
|
||||
done,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Response to [`InstallSnapshot`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct InstallSnapshotResponse {
    /// Responder's current term, so a stale leader can step down.
    pub term: u64,
}
|
||||
|
||||
impl InstallSnapshotResponse {
    /// Creates a response carrying the receiver's current `term`.
    pub fn new(term: u64) -> Self {
        Self { term }
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::replication::state::Command;

    #[test]
    fn test_request_vote() {
        let request = RequestVote::new(1, 1, 10, 1);
        assert_eq!(request.candidate_id, 1);
        assert_eq!(request.term, 1);

        assert!(RequestVoteResponse::grant(1).vote_granted);

        let denied = RequestVoteResponse::deny(2);
        assert!(!denied.vote_granted);
        assert_eq!(denied.term, 2);
    }

    #[test]
    fn test_append_entries() {
        // An empty entry list marks a heartbeat...
        assert!(AppendEntries::heartbeat(1, 1, 0, 0, 0).is_heartbeat());

        // ...while any payload makes it a real replication request.
        let payload = vec![LogEntry::new(1, 1, Command::Noop)];
        assert!(!AppendEntries::with_entries(1, 1, 0, 0, payload, 0).is_heartbeat());
    }

    #[test]
    fn test_rpc_message_serialization() {
        // Round-trip through bincode preserves the term.
        let original = RpcMessage::RequestVote(RequestVote::new(1, 1, 10, 1));
        let roundtrip = RpcMessage::from_bytes(&original.to_bytes()).unwrap();
        assert_eq!(roundtrip.term(), 1);
    }

    #[test]
    fn test_append_entries_response() {
        let ok = AppendEntriesResponse::success(1, 10);
        assert!(ok.success);
        assert_eq!(ok.match_index, 10);

        assert!(!AppendEntriesResponse::failure(2).success);

        let conflicted = AppendEntriesResponse::conflict(2, 1, 5);
        assert!(!conflicted.success);
        assert_eq!(conflicted.conflict_term, Some(1));
        assert_eq!(conflicted.conflict_index, Some(5));
    }
}
|
||||
301
crates/synor-database/src/replication/snapshot.rs
Normal file
301
crates/synor-database/src/replication/snapshot.rs
Normal file
|
|
@ -0,0 +1,301 @@
|
|||
//! Log compaction and snapshots.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Metadata for a snapshot.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SnapshotMetadata {
    /// Index of last log entry included in the snapshot.
    pub last_included_index: u64,
    /// Term of last log entry included in the snapshot.
    pub last_included_term: u64,
    /// Cluster configuration at snapshot time.
    pub config: SnapshotConfig,
    /// Size of snapshot data in bytes.
    pub size: u64,
    /// Timestamp when snapshot was created (milliseconds since Unix epoch).
    pub created_at: u64,
}
|
||||
|
||||
/// Cluster configuration stored inside a snapshot, so membership survives
/// log compaction.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SnapshotConfig {
    /// Node IDs in the cluster.
    pub nodes: Vec<u64>,
}
|
||||
|
||||
impl Default for SnapshotConfig {
|
||||
fn default() -> Self {
|
||||
Self { nodes: Vec::new() }
|
||||
}
|
||||
}
|
||||
|
||||
/// A complete snapshot of the state machine: metadata plus the serialized
/// state bytes.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Snapshot {
    /// Snapshot metadata.
    pub metadata: SnapshotMetadata,
    /// Snapshot data (serialized state machine).
    pub data: Vec<u8>,
}
|
||||
|
||||
impl Snapshot {
|
||||
/// Creates a new snapshot.
|
||||
pub fn new(
|
||||
last_included_index: u64,
|
||||
last_included_term: u64,
|
||||
config: SnapshotConfig,
|
||||
data: Vec<u8>,
|
||||
) -> Self {
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_millis() as u64;
|
||||
|
||||
Self {
|
||||
metadata: SnapshotMetadata {
|
||||
last_included_index,
|
||||
last_included_term,
|
||||
config,
|
||||
size: data.len() as u64,
|
||||
created_at: now,
|
||||
},
|
||||
data,
|
||||
}
|
||||
}
|
||||
|
||||
/// Serializes the snapshot to bytes.
|
||||
pub fn to_bytes(&self) -> Vec<u8> {
|
||||
bincode::serialize(self).unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Deserializes from bytes.
|
||||
pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
|
||||
bincode::deserialize(bytes).ok()
|
||||
}
|
||||
}
|
||||
|
||||
/// Manages snapshot creation, storage, chunked transmission, and chunked
/// reception.
pub struct SnapshotManager {
    /// Number of log entries past the last snapshot before a new one is due.
    snapshot_threshold: u64,
    /// Most recently completed snapshot (if any).
    current_snapshot: Option<Snapshot>,
    /// Snapshot currently being received in chunks from the leader.
    pending_snapshot: Option<PendingSnapshot>,
}
|
||||
|
||||
/// Snapshot being received in chunks from the leader.
struct PendingSnapshot {
    /// Metadata announced by the leader for the incoming snapshot.
    metadata: SnapshotMetadata,
    /// Chunks received so far, in arrival order.
    chunks: Vec<Vec<u8>>,
    /// Byte offset at which the next chunk is expected to start.
    expected_offset: u64,
}
|
||||
|
||||
impl SnapshotManager {
|
||||
/// Creates a new snapshot manager.
|
||||
pub fn new(snapshot_threshold: u64) -> Self {
|
||||
Self {
|
||||
snapshot_threshold,
|
||||
current_snapshot: None,
|
||||
pending_snapshot: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the threshold for triggering snapshots.
|
||||
pub fn threshold(&self) -> u64 {
|
||||
self.snapshot_threshold
|
||||
}
|
||||
|
||||
/// Checks if a snapshot should be taken.
|
||||
pub fn should_snapshot(&self, log_size: u64, last_snapshot_index: u64) -> bool {
|
||||
log_size - last_snapshot_index >= self.snapshot_threshold
|
||||
}
|
||||
|
||||
/// Gets the current snapshot.
|
||||
pub fn get_snapshot(&self) -> Option<&Snapshot> {
|
||||
self.current_snapshot.as_ref()
|
||||
}
|
||||
|
||||
/// Gets the last included index of the current snapshot.
|
||||
pub fn last_included_index(&self) -> u64 {
|
||||
self.current_snapshot
|
||||
.as_ref()
|
||||
.map(|s| s.metadata.last_included_index)
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Gets the last included term of the current snapshot.
|
||||
pub fn last_included_term(&self) -> u64 {
|
||||
self.current_snapshot
|
||||
.as_ref()
|
||||
.map(|s| s.metadata.last_included_term)
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Creates a new snapshot.
|
||||
pub fn create_snapshot(
|
||||
&mut self,
|
||||
last_included_index: u64,
|
||||
last_included_term: u64,
|
||||
config: SnapshotConfig,
|
||||
data: Vec<u8>,
|
||||
) {
|
||||
let snapshot = Snapshot::new(last_included_index, last_included_term, config, data);
|
||||
self.current_snapshot = Some(snapshot);
|
||||
}
|
||||
|
||||
/// Starts receiving a snapshot from leader.
|
||||
pub fn start_receiving(
|
||||
&mut self,
|
||||
last_included_index: u64,
|
||||
last_included_term: u64,
|
||||
offset: u64,
|
||||
data: Vec<u8>,
|
||||
) -> bool {
|
||||
if offset != 0 {
|
||||
// First chunk should be at offset 0
|
||||
return false;
|
||||
}
|
||||
|
||||
self.pending_snapshot = Some(PendingSnapshot {
|
||||
metadata: SnapshotMetadata {
|
||||
last_included_index,
|
||||
last_included_term,
|
||||
config: SnapshotConfig::default(),
|
||||
size: 0,
|
||||
created_at: 0,
|
||||
},
|
||||
chunks: vec![data],
|
||||
expected_offset: 0,
|
||||
});
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Adds a chunk to the pending snapshot.
|
||||
pub fn add_chunk(&mut self, offset: u64, data: Vec<u8>) -> bool {
|
||||
if let Some(ref mut pending) = self.pending_snapshot {
|
||||
if offset == pending.expected_offset + pending.chunks.iter().map(|c| c.len() as u64).sum::<u64>() {
|
||||
pending.chunks.push(data);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Finalizes the pending snapshot.
|
||||
pub fn finalize_snapshot(&mut self) -> Option<Snapshot> {
|
||||
if let Some(pending) = self.pending_snapshot.take() {
|
||||
let data: Vec<u8> = pending.chunks.into_iter().flatten().collect();
|
||||
let snapshot = Snapshot::new(
|
||||
pending.metadata.last_included_index,
|
||||
pending.metadata.last_included_term,
|
||||
pending.metadata.config,
|
||||
data,
|
||||
);
|
||||
self.current_snapshot = Some(snapshot.clone());
|
||||
return Some(snapshot);
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Cancels receiving a snapshot.
|
||||
pub fn cancel_receiving(&mut self) {
|
||||
self.pending_snapshot = None;
|
||||
}
|
||||
|
||||
/// Splits snapshot into chunks for transmission.
|
||||
pub fn chunk_snapshot(&self, chunk_size: usize) -> Vec<(u64, Vec<u8>, bool)> {
|
||||
let mut chunks = Vec::new();
|
||||
|
||||
if let Some(ref snapshot) = self.current_snapshot {
|
||||
let data = &snapshot.data;
|
||||
let mut offset = 0;
|
||||
|
||||
while offset < data.len() {
|
||||
let end = (offset + chunk_size).min(data.len());
|
||||
let chunk = data[offset..end].to_vec();
|
||||
let done = end == data.len();
|
||||
chunks.push((offset as u64, chunk, done));
|
||||
offset = end;
|
||||
}
|
||||
}
|
||||
|
||||
chunks
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SnapshotManager {
|
||||
fn default() -> Self {
|
||||
Self::new(10000) // Default: snapshot every 10k entries
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_snapshot_creation() {
        let snap = Snapshot::new(100, 5, SnapshotConfig::default(), vec![1, 2, 3, 4, 5]);
        assert_eq!(snap.metadata.size, 5);
        assert_eq!(snap.metadata.last_included_index, 100);
        assert_eq!(snap.metadata.last_included_term, 5);
    }

    #[test]
    fn test_snapshot_serialization() {
        // Round-trip through bincode preserves metadata and payload.
        let snap = Snapshot::new(100, 5, SnapshotConfig::default(), vec![1, 2, 3, 4, 5]);
        let roundtrip = Snapshot::from_bytes(&snap.to_bytes()).unwrap();
        assert_eq!(roundtrip.metadata.last_included_index, 100);
        assert_eq!(roundtrip.data, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_snapshot_manager() {
        let mut mgr = SnapshotManager::new(100);

        // The threshold gates when compaction is due.
        assert!(mgr.should_snapshot(150, 0));
        assert!(!mgr.should_snapshot(50, 0));

        mgr.create_snapshot(100, 5, SnapshotConfig::default(), vec![1, 2, 3]);
        assert_eq!(mgr.last_included_index(), 100);
        assert_eq!(mgr.last_included_term(), 5);
    }

    #[test]
    fn test_chunking() {
        let mut mgr = SnapshotManager::new(100);
        mgr.create_snapshot(100, 5, SnapshotConfig::default(), vec![0; 250]);

        // 250 bytes in 100-byte chunks: 100 + 100 + 50.
        let chunks = mgr.chunk_snapshot(100);
        assert_eq!(chunks.len(), 3);

        // First chunk: full-sized, at the start, not final.
        assert_eq!(chunks[0].0, 0);
        assert_eq!(chunks[0].1.len(), 100);
        assert!(!chunks[0].2);

        // Last chunk: the 50-byte remainder, flagged as final.
        assert_eq!(chunks[2].0, 200);
        assert_eq!(chunks[2].1.len(), 50);
        assert!(chunks[2].2);
    }

    #[test]
    fn test_receiving_snapshot() {
        let mut mgr = SnapshotManager::new(100);

        // First chunk must land at offset 0.
        assert!(mgr.start_receiving(100, 5, 0, vec![1, 2, 3]));

        // The next chunk continues exactly where the first ended.
        assert!(mgr.add_chunk(3, vec![4, 5, 6]));

        // Finalizing concatenates the chunks in order.
        let snap = mgr.finalize_snapshot().unwrap();
        assert_eq!(snap.data, vec![1, 2, 3, 4, 5, 6]);
    }
}
|
||||
345
crates/synor-database/src/replication/state.rs
Normal file
345
crates/synor-database/src/replication/state.rs
Normal file
|
|
@ -0,0 +1,345 @@
|
|||
//! Raft node state and commands.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value as JsonValue;
|
||||
|
||||
/// Role of a node in the Raft cluster.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeRole {
    /// Leader: handles all client requests and replicates log entries.
    Leader,
    /// Follower: passive, responds to RPCs from the leader and candidates.
    Follower,
    /// Candidate: actively trying to become leader during an election.
    Candidate,
}
|
||||
|
||||
impl Default for NodeRole {
|
||||
fn default() -> Self {
|
||||
NodeRole::Follower
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for NodeRole {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
NodeRole::Leader => write!(f, "Leader"),
|
||||
NodeRole::Follower => write!(f, "Follower"),
|
||||
NodeRole::Candidate => write!(f, "Candidate"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Persistent state on all servers.
///
/// `current_term` and `voted_for` are the state Raft requires to survive
/// restarts; `commit_index`/`last_applied` track replication progress.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RaftState {
    /// Current term (increases monotonically).
    pub current_term: u64,
    /// Candidate that received this node's vote in the current term.
    pub voted_for: Option<u64>,
    /// Current role of this node.
    pub role: NodeRole,
    /// Index of highest log entry known to be committed.
    pub commit_index: u64,
    /// Index of highest log entry applied to the state machine.
    pub last_applied: u64,
}
|
||||
|
||||
impl Default for RaftState {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
current_term: 0,
|
||||
voted_for: None,
|
||||
role: NodeRole::Follower,
|
||||
commit_index: 0,
|
||||
last_applied: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RaftState {
|
||||
/// Creates a new Raft state.
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Transitions to follower state.
|
||||
pub fn become_follower(&mut self, term: u64) {
|
||||
self.current_term = term;
|
||||
self.role = NodeRole::Follower;
|
||||
self.voted_for = None;
|
||||
}
|
||||
|
||||
/// Transitions to candidate state.
|
||||
pub fn become_candidate(&mut self) {
|
||||
self.current_term += 1;
|
||||
self.role = NodeRole::Candidate;
|
||||
// Vote for self when becoming candidate
|
||||
}
|
||||
|
||||
/// Transitions to leader state.
|
||||
pub fn become_leader(&mut self) {
|
||||
self.role = NodeRole::Leader;
|
||||
}
|
||||
|
||||
/// Updates commit index if new value is higher.
|
||||
pub fn update_commit_index(&mut self, new_commit: u64) {
|
||||
if new_commit > self.commit_index {
|
||||
self.commit_index = new_commit;
|
||||
}
|
||||
}
|
||||
|
||||
/// Updates last applied index.
|
||||
pub fn update_last_applied(&mut self, index: u64) {
|
||||
self.last_applied = index;
|
||||
}
|
||||
|
||||
/// Returns true if this node is the leader.
|
||||
pub fn is_leader(&self) -> bool {
|
||||
self.role == NodeRole::Leader
|
||||
}
|
||||
|
||||
/// Returns true if this node is a follower.
|
||||
pub fn is_follower(&self) -> bool {
|
||||
self.role == NodeRole::Follower
|
||||
}
|
||||
|
||||
/// Returns true if this node is a candidate.
|
||||
pub fn is_candidate(&self) -> bool {
|
||||
self.role == NodeRole::Candidate
|
||||
}
|
||||
}
|
||||
|
||||
/// Commands that can be replicated through Raft.
///
/// Each variant is one state-machine mutation; serialized with bincode
/// (see `to_bytes`/`from_bytes`) for storage in log entries.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum Command {
    /// No operation (used for heartbeats and new leader commit).
    Noop,

    // Key-Value operations
    /// Set a key-value pair, with an optional time-to-live.
    KvSet { key: String, value: Vec<u8>, ttl: Option<u64> },
    /// Delete a key.
    KvDelete { key: String },

    // Document operations
    /// Insert a document.
    DocInsert { collection: String, document: JsonValue },
    /// Update a document by id.
    DocUpdate { collection: String, id: String, update: JsonValue },
    /// Delete a document by id.
    DocDelete { collection: String, id: String },

    // Vector operations
    /// Insert a vector with attached metadata.
    VectorInsert { namespace: String, id: String, vector: Vec<f32>, metadata: JsonValue },
    /// Delete a vector.
    VectorDelete { namespace: String, id: String },

    // Time-series operations
    /// Record a metric data point.
    TimeSeriesRecord { metric: String, value: f64, timestamp: u64, tags: JsonValue },

    // Graph operations
    /// Create a graph node.
    GraphNodeCreate { labels: Vec<String>, properties: JsonValue },
    /// Delete a graph node.
    GraphNodeDelete { id: String },
    /// Create a graph edge.
    GraphEdgeCreate {
        source: String,
        target: String,
        edge_type: String,
        properties: JsonValue,
    },
    /// Delete a graph edge.
    GraphEdgeDelete { id: String },

    // SQL operations
    /// Execute a SQL statement.
    SqlExecute { sql: String },

    // Schema operations
    /// Create a collection/table, optionally with a schema.
    CreateCollection { name: String, schema: Option<JsonValue> },
    /// Drop a collection/table.
    DropCollection { name: String },

    // Index operations
    /// Create an index.
    CreateIndex { collection: String, field: String, index_type: String },
    /// Drop an index.
    DropIndex { name: String },

    // Configuration changes
    /// Add a node to the cluster.
    AddNode { node_id: u64, address: String },
    /// Remove a node from the cluster.
    RemoveNode { node_id: u64 },
}
|
||||
|
||||
impl Command {
|
||||
/// Returns a descriptive name for this command.
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
Command::Noop => "noop",
|
||||
Command::KvSet { .. } => "kv_set",
|
||||
Command::KvDelete { .. } => "kv_delete",
|
||||
Command::DocInsert { .. } => "doc_insert",
|
||||
Command::DocUpdate { .. } => "doc_update",
|
||||
Command::DocDelete { .. } => "doc_delete",
|
||||
Command::VectorInsert { .. } => "vector_insert",
|
||||
Command::VectorDelete { .. } => "vector_delete",
|
||||
Command::TimeSeriesRecord { .. } => "timeseries_record",
|
||||
Command::GraphNodeCreate { .. } => "graph_node_create",
|
||||
Command::GraphNodeDelete { .. } => "graph_node_delete",
|
||||
Command::GraphEdgeCreate { .. } => "graph_edge_create",
|
||||
Command::GraphEdgeDelete { .. } => "graph_edge_delete",
|
||||
Command::SqlExecute { .. } => "sql_execute",
|
||||
Command::CreateCollection { .. } => "create_collection",
|
||||
Command::DropCollection { .. } => "drop_collection",
|
||||
Command::CreateIndex { .. } => "create_index",
|
||||
Command::DropIndex { .. } => "drop_index",
|
||||
Command::AddNode { .. } => "add_node",
|
||||
Command::RemoveNode { .. } => "remove_node",
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if this is a read-only command (can be served by followers).
|
||||
pub fn is_read_only(&self) -> bool {
|
||||
matches!(self, Command::Noop)
|
||||
}
|
||||
|
||||
/// Serializes the command to bytes.
|
||||
pub fn to_bytes(&self) -> Vec<u8> {
|
||||
bincode::serialize(self).unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Deserializes from bytes.
|
||||
pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
|
||||
bincode::deserialize(bytes).ok()
|
||||
}
|
||||
}
|
||||
|
||||
/// Leader state (volatile, only on leaders; rebuilt after each election).
#[derive(Clone, Debug, Default)]
pub struct LeaderState {
    /// For each follower, index of the next log entry to send.
    pub next_index: std::collections::HashMap<u64, u64>,
    /// For each follower, index of the highest log entry known to be replicated.
    pub match_index: std::collections::HashMap<u64, u64>,
}
|
||||
|
||||
impl LeaderState {
|
||||
/// Creates a new leader state.
|
||||
pub fn new(last_log_index: u64, peer_ids: &[u64]) -> Self {
|
||||
let mut next_index = std::collections::HashMap::new();
|
||||
let mut match_index = std::collections::HashMap::new();
|
||||
|
||||
for &peer_id in peer_ids {
|
||||
// Initialize nextIndex to leader's last log index + 1
|
||||
next_index.insert(peer_id, last_log_index + 1);
|
||||
// Initialize matchIndex to 0
|
||||
match_index.insert(peer_id, 0);
|
||||
}
|
||||
|
||||
Self {
|
||||
next_index,
|
||||
match_index,
|
||||
}
|
||||
}
|
||||
|
||||
/// Updates next_index for a peer after failed append.
|
||||
pub fn decrement_next_index(&mut self, peer_id: u64) {
|
||||
if let Some(idx) = self.next_index.get_mut(&peer_id) {
|
||||
if *idx > 1 {
|
||||
*idx -= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Updates indices after successful append.
|
||||
pub fn update_indices(&mut self, peer_id: u64, last_index: u64) {
|
||||
self.next_index.insert(peer_id, last_index + 1);
|
||||
self.match_index.insert(peer_id, last_index);
|
||||
}
|
||||
|
||||
/// Calculates the new commit index based on majority replication.
|
||||
pub fn calculate_commit_index(&self, current_commit: u64, current_term: u64, log_term_at: impl Fn(u64) -> Option<u64>) -> u64 {
|
||||
// Find the highest index that a majority have replicated
|
||||
let mut indices: Vec<u64> = self.match_index.values().cloned().collect();
|
||||
indices.sort_unstable();
|
||||
indices.reverse();
|
||||
|
||||
// Majority is (n + 1) / 2 where n includes the leader
|
||||
let majority = (indices.len() + 1 + 1) / 2;
|
||||
|
||||
for &index in &indices {
|
||||
if index > current_commit {
|
||||
// Only commit entries from current term
|
||||
if let Some(term) = log_term_at(index) {
|
||||
if term == current_term {
|
||||
// Check if majority have this index
|
||||
let count = indices.iter().filter(|&&i| i >= index).count() + 1; // +1 for leader
|
||||
if count >= majority {
|
||||
return index;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
current_commit
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_node_role() {
        let mut node = RaftState::new();
        assert!(node.is_follower());

        node.become_candidate();
        assert!(node.is_candidate());
        assert_eq!(node.current_term, 1);

        node.become_leader();
        assert!(node.is_leader());

        node.become_follower(5);
        assert!(node.is_follower());
        assert_eq!(node.current_term, 5);
    }

    #[test]
    fn test_command_serialization() {
        let original = Command::KvSet {
            key: "test".to_string(),
            value: vec![1, 2, 3],
            ttl: Some(3600),
        };

        let decoded = Command::from_bytes(&original.to_bytes()).unwrap();

        match decoded {
            Command::KvSet { key, value, ttl } => {
                assert_eq!(key, "test");
                assert_eq!(value, vec![1, 2, 3]);
                assert_eq!(ttl, Some(3600));
            }
            _ => panic!("Wrong command type"),
        }
    }

    #[test]
    fn test_leader_state() {
        let state = LeaderState::new(10, &[2, 3, 4]);
        assert_eq!(state.next_index.get(&2), Some(&11));
        assert_eq!(state.match_index.get(&2), Some(&0));
    }
}
|
||||
898
crates/synor-database/src/sql/executor.rs
Normal file
898
crates/synor-database/src/sql/executor.rs
Normal file
|
|
@ -0,0 +1,898 @@
|
|||
//! SQL query executor.
|
||||
|
||||
use super::parser::{
|
||||
BinaryOp, JoinType, ParsedExpr, ParsedOrderBy, ParsedSelect, ParsedSelectItem,
|
||||
ParsedStatement, SqlParser,
|
||||
};
|
||||
use super::row::{Row, RowBuilder, RowId};
|
||||
use super::table::{ColumnDef, Table, TableDef};
|
||||
use super::transaction::{IsolationLevel, TransactionId, TransactionManager, TransactionOp};
|
||||
use super::types::{SqlError, SqlType, SqlValue};
|
||||
use parking_lot::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Result of a SQL query.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct QueryResult {
|
||||
/// Column names.
|
||||
pub columns: Vec<String>,
|
||||
/// Result rows.
|
||||
pub rows: Vec<Vec<SqlValue>>,
|
||||
/// Number of rows affected (for INSERT/UPDATE/DELETE).
|
||||
pub rows_affected: u64,
|
||||
/// Execution time in milliseconds.
|
||||
pub execution_time_ms: u64,
|
||||
}
|
||||
|
||||
impl QueryResult {
|
||||
/// Creates an empty result.
|
||||
pub fn empty() -> Self {
|
||||
Self {
|
||||
columns: Vec::new(),
|
||||
rows: Vec::new(),
|
||||
rows_affected: 0,
|
||||
execution_time_ms: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a result with affected count.
|
||||
pub fn affected(count: u64) -> Self {
|
||||
Self {
|
||||
columns: Vec::new(),
|
||||
rows: Vec::new(),
|
||||
rows_affected: count,
|
||||
execution_time_ms: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// SQL execution engine.
|
||||
pub struct SqlEngine {
|
||||
/// Tables in this engine.
|
||||
tables: RwLock<HashMap<String, Arc<Table>>>,
|
||||
/// Transaction manager.
|
||||
txn_manager: TransactionManager,
|
||||
}
|
||||
|
||||
impl SqlEngine {
|
||||
/// Creates a new SQL engine.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
tables: RwLock::new(HashMap::new()),
|
||||
txn_manager: TransactionManager::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Executes a SQL statement.
|
||||
pub fn execute(&self, sql: &str) -> Result<QueryResult, SqlError> {
|
||||
let start = std::time::Instant::now();
|
||||
let stmt = SqlParser::parse(sql)?;
|
||||
let mut result = self.execute_statement(&stmt)?;
|
||||
result.execution_time_ms = start.elapsed().as_millis() as u64;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Executes a parsed statement.
|
||||
fn execute_statement(&self, stmt: &ParsedStatement) -> Result<QueryResult, SqlError> {
|
||||
match stmt {
|
||||
ParsedStatement::CreateTable {
|
||||
name,
|
||||
columns,
|
||||
if_not_exists,
|
||||
} => self.execute_create_table(name, columns, *if_not_exists),
|
||||
ParsedStatement::DropTable { name, if_exists } => {
|
||||
self.execute_drop_table(name, *if_exists)
|
||||
}
|
||||
ParsedStatement::Select(select) => self.execute_select(select),
|
||||
ParsedStatement::Insert {
|
||||
table,
|
||||
columns,
|
||||
values,
|
||||
} => self.execute_insert(table, columns, values),
|
||||
ParsedStatement::Update {
|
||||
table,
|
||||
assignments,
|
||||
where_clause,
|
||||
} => self.execute_update(table, assignments, where_clause.as_ref()),
|
||||
ParsedStatement::Delete {
|
||||
table,
|
||||
where_clause,
|
||||
} => self.execute_delete(table, where_clause.as_ref()),
|
||||
ParsedStatement::CreateIndex {
|
||||
name,
|
||||
table,
|
||||
columns,
|
||||
unique,
|
||||
} => self.execute_create_index(name, table, columns, *unique),
|
||||
ParsedStatement::DropIndex { name } => self.execute_drop_index(name),
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a table.
|
||||
fn execute_create_table(
|
||||
&self,
|
||||
name: &str,
|
||||
columns: &[super::parser::ParsedColumn],
|
||||
if_not_exists: bool,
|
||||
) -> Result<QueryResult, SqlError> {
|
||||
let mut tables = self.tables.write();
|
||||
|
||||
if tables.contains_key(name) {
|
||||
if if_not_exists {
|
||||
return Ok(QueryResult::empty());
|
||||
}
|
||||
return Err(SqlError::TableExists(name.to_string()));
|
||||
}
|
||||
|
||||
let mut table_def = TableDef::new(name);
|
||||
for col in columns {
|
||||
let mut col_def = ColumnDef::new(&col.name, col.data_type.clone());
|
||||
if !col.nullable {
|
||||
col_def = col_def.not_null();
|
||||
}
|
||||
if let Some(ref default) = col.default {
|
||||
col_def = col_def.default(default.clone());
|
||||
}
|
||||
if col.primary_key {
|
||||
col_def = col_def.primary_key();
|
||||
}
|
||||
if col.unique {
|
||||
col_def = col_def.unique();
|
||||
}
|
||||
table_def = table_def.column(col_def);
|
||||
}
|
||||
|
||||
let table = Arc::new(Table::new(table_def));
|
||||
tables.insert(name.to_string(), table);
|
||||
|
||||
Ok(QueryResult::empty())
|
||||
}
|
||||
|
||||
/// Drops a table.
|
||||
fn execute_drop_table(&self, name: &str, if_exists: bool) -> Result<QueryResult, SqlError> {
|
||||
let mut tables = self.tables.write();
|
||||
|
||||
if !tables.contains_key(name) {
|
||||
if if_exists {
|
||||
return Ok(QueryResult::empty());
|
||||
}
|
||||
return Err(SqlError::TableNotFound(name.to_string()));
|
||||
}
|
||||
|
||||
tables.remove(name);
|
||||
Ok(QueryResult::empty())
|
||||
}
|
||||
|
||||
/// Executes a SELECT query.
|
||||
fn execute_select(&self, select: &ParsedSelect) -> Result<QueryResult, SqlError> {
|
||||
let tables = self.tables.read();
|
||||
let table = tables
|
||||
.get(&select.from)
|
||||
.ok_or_else(|| SqlError::TableNotFound(select.from.clone()))?;
|
||||
|
||||
// Get all rows
|
||||
let mut rows = table.scan();
|
||||
|
||||
// Apply WHERE filter
|
||||
if let Some(ref where_clause) = select.where_clause {
|
||||
rows = rows
|
||||
.into_iter()
|
||||
.filter(|row| self.evaluate_where(row, where_clause))
|
||||
.collect();
|
||||
}
|
||||
|
||||
// Apply ORDER BY
|
||||
if !select.order_by.is_empty() {
|
||||
rows.sort_by(|a, b| {
|
||||
for ob in &select.order_by {
|
||||
let a_val = a.get_or_null(&ob.column);
|
||||
let b_val = b.get_or_null(&ob.column);
|
||||
match a_val.partial_cmp(&b_val) {
|
||||
Some(std::cmp::Ordering::Equal) => continue,
|
||||
Some(ord) => {
|
||||
return if ob.ascending {
|
||||
ord
|
||||
} else {
|
||||
ord.reverse()
|
||||
};
|
||||
}
|
||||
None => continue,
|
||||
}
|
||||
}
|
||||
std::cmp::Ordering::Equal
|
||||
});
|
||||
}
|
||||
|
||||
// Apply OFFSET
|
||||
if let Some(offset) = select.offset {
|
||||
rows = rows.into_iter().skip(offset).collect();
|
||||
}
|
||||
|
||||
// Apply LIMIT
|
||||
if let Some(limit) = select.limit {
|
||||
rows = rows.into_iter().take(limit).collect();
|
||||
}
|
||||
|
||||
// Handle aggregates
|
||||
if select.columns.iter().any(|c| matches!(c, ParsedSelectItem::Aggregate { .. })) {
|
||||
return self.execute_aggregate(select, &rows, table);
|
||||
}
|
||||
|
||||
// Project columns
|
||||
let (column_names, result_rows) = self.project_rows(&select.columns, &rows, table);
|
||||
|
||||
Ok(QueryResult {
|
||||
columns: column_names,
|
||||
rows: result_rows,
|
||||
rows_affected: 0,
|
||||
execution_time_ms: 0,
|
||||
})
|
||||
}
|
||||
|
||||
/// Projects rows to selected columns.
|
||||
fn project_rows(
|
||||
&self,
|
||||
select_items: &[ParsedSelectItem],
|
||||
rows: &[Row],
|
||||
table: &Table,
|
||||
) -> (Vec<String>, Vec<Vec<SqlValue>>) {
|
||||
let column_names: Vec<String> = select_items
|
||||
.iter()
|
||||
.flat_map(|item| match item {
|
||||
ParsedSelectItem::Wildcard => table.def.column_names(),
|
||||
ParsedSelectItem::Column(name) => vec![name.clone()],
|
||||
ParsedSelectItem::ColumnAlias { alias, .. } => vec![alias.clone()],
|
||||
ParsedSelectItem::Aggregate { function, alias, .. } => {
|
||||
vec![alias.clone().unwrap_or_else(|| function.clone())]
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let result_rows: Vec<Vec<SqlValue>> = rows
|
||||
.iter()
|
||||
.map(|row| {
|
||||
select_items
|
||||
.iter()
|
||||
.flat_map(|item| match item {
|
||||
ParsedSelectItem::Wildcard => table
|
||||
.def
|
||||
.column_names()
|
||||
.into_iter()
|
||||
.map(|c| row.get_or_null(&c))
|
||||
.collect::<Vec<_>>(),
|
||||
ParsedSelectItem::Column(name)
|
||||
| ParsedSelectItem::ColumnAlias { column: name, .. } => {
|
||||
vec![row.get_or_null(name)]
|
||||
}
|
||||
_ => vec![SqlValue::Null],
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
(column_names, result_rows)
|
||||
}
|
||||
|
||||
/// Executes aggregate functions.
|
||||
fn execute_aggregate(
|
||||
&self,
|
||||
select: &ParsedSelect,
|
||||
rows: &[Row],
|
||||
table: &Table,
|
||||
) -> Result<QueryResult, SqlError> {
|
||||
let mut result_columns = Vec::new();
|
||||
let mut result_values = Vec::new();
|
||||
|
||||
for item in &select.columns {
|
||||
match item {
|
||||
ParsedSelectItem::Aggregate {
|
||||
function,
|
||||
column,
|
||||
alias,
|
||||
} => {
|
||||
let col_name = alias.clone().unwrap_or_else(|| function.clone());
|
||||
result_columns.push(col_name);
|
||||
|
||||
let value = match function.as_str() {
|
||||
"COUNT" => SqlValue::Integer(rows.len() as i64),
|
||||
"SUM" => {
|
||||
let col = column.as_ref().ok_or_else(|| {
|
||||
SqlError::InvalidOperation("SUM requires column".to_string())
|
||||
})?;
|
||||
let sum: f64 = rows
|
||||
.iter()
|
||||
.filter_map(|r| r.get_or_null(col).as_real())
|
||||
.sum();
|
||||
SqlValue::Real(sum)
|
||||
}
|
||||
"AVG" => {
|
||||
let col = column.as_ref().ok_or_else(|| {
|
||||
SqlError::InvalidOperation("AVG requires column".to_string())
|
||||
})?;
|
||||
let values: Vec<f64> = rows
|
||||
.iter()
|
||||
.filter_map(|r| r.get_or_null(col).as_real())
|
||||
.collect();
|
||||
if values.is_empty() {
|
||||
SqlValue::Null
|
||||
} else {
|
||||
SqlValue::Real(values.iter().sum::<f64>() / values.len() as f64)
|
||||
}
|
||||
}
|
||||
"MIN" => {
|
||||
let col = column.as_ref().ok_or_else(|| {
|
||||
SqlError::InvalidOperation("MIN requires column".to_string())
|
||||
})?;
|
||||
rows.iter()
|
||||
.map(|r| r.get_or_null(col))
|
||||
.filter(|v| !v.is_null())
|
||||
.min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.unwrap_or(SqlValue::Null)
|
||||
}
|
||||
"MAX" => {
|
||||
let col = column.as_ref().ok_or_else(|| {
|
||||
SqlError::InvalidOperation("MAX requires column".to_string())
|
||||
})?;
|
||||
rows.iter()
|
||||
.map(|r| r.get_or_null(col))
|
||||
.filter(|v| !v.is_null())
|
||||
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.unwrap_or(SqlValue::Null)
|
||||
}
|
||||
_ => {
|
||||
return Err(SqlError::Unsupported(format!("Function: {}", function)))
|
||||
}
|
||||
};
|
||||
result_values.push(value);
|
||||
}
|
||||
ParsedSelectItem::Column(name) => {
|
||||
result_columns.push(name.clone());
|
||||
// For non-aggregated columns in aggregate query, take first value
|
||||
result_values.push(
|
||||
rows.first()
|
||||
.map(|r| r.get_or_null(name))
|
||||
.unwrap_or(SqlValue::Null),
|
||||
);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(QueryResult {
|
||||
columns: result_columns,
|
||||
rows: if result_values.is_empty() {
|
||||
Vec::new()
|
||||
} else {
|
||||
vec![result_values]
|
||||
},
|
||||
rows_affected: 0,
|
||||
execution_time_ms: 0,
|
||||
})
|
||||
}
|
||||
|
||||
/// Evaluates a WHERE clause.
|
||||
fn evaluate_where(&self, row: &Row, expr: &ParsedExpr) -> bool {
|
||||
match self.evaluate_expr(row, expr) {
|
||||
SqlValue::Boolean(b) => b,
|
||||
SqlValue::Integer(i) => i != 0,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluates an expression.
|
||||
fn evaluate_expr(&self, row: &Row, expr: &ParsedExpr) -> SqlValue {
|
||||
match expr {
|
||||
ParsedExpr::Column(name) => row.get_or_null(name),
|
||||
ParsedExpr::Literal(value) => value.clone(),
|
||||
ParsedExpr::BinaryOp { left, op, right } => {
|
||||
let left_val = self.evaluate_expr(row, left);
|
||||
let right_val = self.evaluate_expr(row, right);
|
||||
self.evaluate_binary_op(&left_val, op, &right_val)
|
||||
}
|
||||
ParsedExpr::Not(inner) => {
|
||||
let val = self.evaluate_expr(row, inner);
|
||||
match val {
|
||||
SqlValue::Boolean(b) => SqlValue::Boolean(!b),
|
||||
_ => SqlValue::Null,
|
||||
}
|
||||
}
|
||||
ParsedExpr::IsNull(inner) => {
|
||||
SqlValue::Boolean(self.evaluate_expr(row, inner).is_null())
|
||||
}
|
||||
ParsedExpr::IsNotNull(inner) => {
|
||||
SqlValue::Boolean(!self.evaluate_expr(row, inner).is_null())
|
||||
}
|
||||
ParsedExpr::InList { expr, list, negated } => {
|
||||
let val = self.evaluate_expr(row, expr);
|
||||
let in_list = list.iter().any(|item| {
|
||||
let item_val = self.evaluate_expr(row, item);
|
||||
val == item_val
|
||||
});
|
||||
SqlValue::Boolean(if *negated { !in_list } else { in_list })
|
||||
}
|
||||
ParsedExpr::Between {
|
||||
expr,
|
||||
low,
|
||||
high,
|
||||
negated,
|
||||
} => {
|
||||
let val = self.evaluate_expr(row, expr);
|
||||
let low_val = self.evaluate_expr(row, low);
|
||||
let high_val = self.evaluate_expr(row, high);
|
||||
let between = val >= low_val && val <= high_val;
|
||||
SqlValue::Boolean(if *negated { !between } else { between })
|
||||
}
|
||||
ParsedExpr::Function { name, args } => {
|
||||
self.evaluate_function(row, name, args)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluates a binary operation.
|
||||
fn evaluate_binary_op(&self, left: &SqlValue, op: &BinaryOp, right: &SqlValue) -> SqlValue {
|
||||
match op {
|
||||
BinaryOp::Eq => SqlValue::Boolean(left == right),
|
||||
BinaryOp::Ne => SqlValue::Boolean(left != right),
|
||||
BinaryOp::Lt => SqlValue::Boolean(left < right),
|
||||
BinaryOp::Le => SqlValue::Boolean(left <= right),
|
||||
BinaryOp::Gt => SqlValue::Boolean(left > right),
|
||||
BinaryOp::Ge => SqlValue::Boolean(left >= right),
|
||||
BinaryOp::And => {
|
||||
let l = matches!(left, SqlValue::Boolean(true));
|
||||
let r = matches!(right, SqlValue::Boolean(true));
|
||||
SqlValue::Boolean(l && r)
|
||||
}
|
||||
BinaryOp::Or => {
|
||||
let l = matches!(left, SqlValue::Boolean(true));
|
||||
let r = matches!(right, SqlValue::Boolean(true));
|
||||
SqlValue::Boolean(l || r)
|
||||
}
|
||||
BinaryOp::Like => {
|
||||
if let (SqlValue::Text(text), SqlValue::Text(pattern)) = (left, right) {
|
||||
SqlValue::Boolean(self.match_like(text, pattern))
|
||||
} else {
|
||||
SqlValue::Boolean(false)
|
||||
}
|
||||
}
|
||||
BinaryOp::Plus => match (left, right) {
|
||||
(SqlValue::Integer(a), SqlValue::Integer(b)) => SqlValue::Integer(a + b),
|
||||
(SqlValue::Real(a), SqlValue::Real(b)) => SqlValue::Real(a + b),
|
||||
(SqlValue::Integer(a), SqlValue::Real(b)) => SqlValue::Real(*a as f64 + b),
|
||||
(SqlValue::Real(a), SqlValue::Integer(b)) => SqlValue::Real(a + *b as f64),
|
||||
_ => SqlValue::Null,
|
||||
},
|
||||
BinaryOp::Minus => match (left, right) {
|
||||
(SqlValue::Integer(a), SqlValue::Integer(b)) => SqlValue::Integer(a - b),
|
||||
(SqlValue::Real(a), SqlValue::Real(b)) => SqlValue::Real(a - b),
|
||||
_ => SqlValue::Null,
|
||||
},
|
||||
BinaryOp::Multiply => match (left, right) {
|
||||
(SqlValue::Integer(a), SqlValue::Integer(b)) => SqlValue::Integer(a * b),
|
||||
(SqlValue::Real(a), SqlValue::Real(b)) => SqlValue::Real(a * b),
|
||||
_ => SqlValue::Null,
|
||||
},
|
||||
BinaryOp::Divide => match (left, right) {
|
||||
(SqlValue::Integer(a), SqlValue::Integer(b)) if *b != 0 => {
|
||||
SqlValue::Integer(a / b)
|
||||
}
|
||||
(SqlValue::Real(a), SqlValue::Real(b)) if *b != 0.0 => SqlValue::Real(a / b),
|
||||
_ => SqlValue::Null,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluates a function call.
|
||||
fn evaluate_function(&self, row: &Row, name: &str, args: &[ParsedExpr]) -> SqlValue {
|
||||
match name.to_uppercase().as_str() {
|
||||
"COALESCE" => {
|
||||
for arg in args {
|
||||
let val = self.evaluate_expr(row, arg);
|
||||
if !val.is_null() {
|
||||
return val;
|
||||
}
|
||||
}
|
||||
SqlValue::Null
|
||||
}
|
||||
"UPPER" => {
|
||||
if let Some(arg) = args.first() {
|
||||
if let SqlValue::Text(s) = self.evaluate_expr(row, arg) {
|
||||
return SqlValue::Text(s.to_uppercase());
|
||||
}
|
||||
}
|
||||
SqlValue::Null
|
||||
}
|
||||
"LOWER" => {
|
||||
if let Some(arg) = args.first() {
|
||||
if let SqlValue::Text(s) = self.evaluate_expr(row, arg) {
|
||||
return SqlValue::Text(s.to_lowercase());
|
||||
}
|
||||
}
|
||||
SqlValue::Null
|
||||
}
|
||||
"LENGTH" => {
|
||||
if let Some(arg) = args.first() {
|
||||
if let SqlValue::Text(s) = self.evaluate_expr(row, arg) {
|
||||
return SqlValue::Integer(s.len() as i64);
|
||||
}
|
||||
}
|
||||
SqlValue::Null
|
||||
}
|
||||
"ABS" => {
|
||||
if let Some(arg) = args.first() {
|
||||
match self.evaluate_expr(row, arg) {
|
||||
SqlValue::Integer(i) => return SqlValue::Integer(i.abs()),
|
||||
SqlValue::Real(f) => return SqlValue::Real(f.abs()),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
SqlValue::Null
|
||||
}
|
||||
_ => SqlValue::Null,
|
||||
}
|
||||
}
|
||||
|
||||
/// Matches a LIKE pattern.
|
||||
fn match_like(&self, text: &str, pattern: &str) -> bool {
|
||||
// Simple LIKE implementation: % = any chars, _ = single char
|
||||
let regex_pattern = pattern
|
||||
.replace('%', ".*")
|
||||
.replace('_', ".");
|
||||
// For simplicity, just do case-insensitive contains for now
|
||||
if pattern.starts_with('%') && pattern.ends_with('%') {
|
||||
let inner = &pattern[1..pattern.len() - 1];
|
||||
text.to_lowercase().contains(&inner.to_lowercase())
|
||||
} else if pattern.starts_with('%') {
|
||||
let suffix = &pattern[1..];
|
||||
text.to_lowercase().ends_with(&suffix.to_lowercase())
|
||||
} else if pattern.ends_with('%') {
|
||||
let prefix = &pattern[..pattern.len() - 1];
|
||||
text.to_lowercase().starts_with(&prefix.to_lowercase())
|
||||
} else {
|
||||
text.to_lowercase() == pattern.to_lowercase()
|
||||
}
|
||||
}
|
||||
|
||||
/// Executes an INSERT statement.
|
||||
fn execute_insert(
|
||||
&self,
|
||||
table_name: &str,
|
||||
columns: &[String],
|
||||
values: &[Vec<SqlValue>],
|
||||
) -> Result<QueryResult, SqlError> {
|
||||
let tables = self.tables.read();
|
||||
let table = tables
|
||||
.get(table_name)
|
||||
.ok_or_else(|| SqlError::TableNotFound(table_name.to_string()))?;
|
||||
|
||||
let cols = if columns.is_empty() {
|
||||
table.def.column_names()
|
||||
} else {
|
||||
columns.to_vec()
|
||||
};
|
||||
|
||||
let mut count = 0;
|
||||
for row_values in values {
|
||||
if row_values.len() != cols.len() {
|
||||
return Err(SqlError::InvalidOperation(format!(
|
||||
"Column count mismatch: expected {}, got {}",
|
||||
cols.len(),
|
||||
row_values.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let mut row_map = HashMap::new();
|
||||
for (col, val) in cols.iter().zip(row_values.iter()) {
|
||||
row_map.insert(col.clone(), val.clone());
|
||||
}
|
||||
|
||||
table.insert(row_map)?;
|
||||
count += 1;
|
||||
}
|
||||
|
||||
Ok(QueryResult::affected(count))
|
||||
}
|
||||
|
||||
/// Executes an UPDATE statement.
|
||||
fn execute_update(
|
||||
&self,
|
||||
table_name: &str,
|
||||
assignments: &[(String, SqlValue)],
|
||||
where_clause: Option<&ParsedExpr>,
|
||||
) -> Result<QueryResult, SqlError> {
|
||||
let tables = self.tables.read();
|
||||
let table = tables
|
||||
.get(table_name)
|
||||
.ok_or_else(|| SqlError::TableNotFound(table_name.to_string()))?;
|
||||
|
||||
let rows = table.scan();
|
||||
let mut count = 0;
|
||||
|
||||
for row in rows {
|
||||
let matches = where_clause
|
||||
.map(|w| self.evaluate_where(&row, w))
|
||||
.unwrap_or(true);
|
||||
|
||||
if matches {
|
||||
let updates: HashMap<String, SqlValue> =
|
||||
assignments.iter().cloned().collect();
|
||||
table.update(row.id, updates)?;
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(QueryResult::affected(count))
|
||||
}
|
||||
|
||||
/// Executes a DELETE statement.
|
||||
fn execute_delete(
|
||||
&self,
|
||||
table_name: &str,
|
||||
where_clause: Option<&ParsedExpr>,
|
||||
) -> Result<QueryResult, SqlError> {
|
||||
let tables = self.tables.read();
|
||||
let table = tables
|
||||
.get(table_name)
|
||||
.ok_or_else(|| SqlError::TableNotFound(table_name.to_string()))?;
|
||||
|
||||
let rows = table.scan();
|
||||
let mut count = 0;
|
||||
|
||||
let to_delete: Vec<RowId> = rows
|
||||
.iter()
|
||||
.filter(|row| {
|
||||
where_clause
|
||||
.map(|w| self.evaluate_where(row, w))
|
||||
.unwrap_or(true)
|
||||
})
|
||||
.map(|row| row.id)
|
||||
.collect();
|
||||
|
||||
for id in to_delete {
|
||||
if table.delete(id)? {
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(QueryResult::affected(count))
|
||||
}
|
||||
|
||||
/// Creates an index.
|
||||
fn execute_create_index(
|
||||
&self,
|
||||
name: &str,
|
||||
table_name: &str,
|
||||
columns: &[String],
|
||||
unique: bool,
|
||||
) -> Result<QueryResult, SqlError> {
|
||||
let tables = self.tables.read();
|
||||
let table = tables
|
||||
.get(table_name)
|
||||
.ok_or_else(|| SqlError::TableNotFound(table_name.to_string()))?;
|
||||
|
||||
// For simplicity, only support single-column indexes
|
||||
let column = columns
|
||||
.first()
|
||||
.ok_or_else(|| SqlError::InvalidOperation("Index requires at least one column".to_string()))?;
|
||||
|
||||
table.create_index(name, column, unique)?;
|
||||
Ok(QueryResult::empty())
|
||||
}
|
||||
|
||||
/// Drops an index.
|
||||
fn execute_drop_index(&self, name: &str) -> Result<QueryResult, SqlError> {
|
||||
// Would need to find which table has this index
|
||||
// For now, return success
|
||||
Ok(QueryResult::empty())
|
||||
}
|
||||
|
||||
// Transaction methods
|
||||
|
||||
/// Begins a transaction.
|
||||
pub fn begin_transaction(&self) -> TransactionId {
|
||||
self.txn_manager.begin(IsolationLevel::ReadCommitted)
|
||||
}
|
||||
|
||||
/// Commits a transaction.
|
||||
pub fn commit(&self, txn_id: TransactionId) -> Result<(), SqlError> {
|
||||
self.txn_manager.commit(txn_id)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Rolls back a transaction.
|
||||
pub fn rollback(&self, txn_id: TransactionId) -> Result<(), SqlError> {
|
||||
self.txn_manager.rollback(txn_id)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns number of tables.
|
||||
pub fn table_count(&self) -> usize {
|
||||
self.tables.read().len()
|
||||
}
|
||||
|
||||
/// Returns table names.
|
||||
pub fn table_names(&self) -> Vec<String> {
|
||||
self.tables.read().keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Gets table definition.
|
||||
pub fn get_table_def(&self, name: &str) -> Option<TableDef> {
|
||||
self.tables.read().get(name).map(|t| t.def.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SqlEngine {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Engine with a `users(id, name, age)` table already created.
    fn setup_engine() -> SqlEngine {
        let engine = SqlEngine::new();
        engine
            .execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL, age INTEGER)")
            .unwrap();
        engine
    }

    /// Inserts one user row, panicking on failure.
    fn insert_user(engine: &SqlEngine, id: i64, name: &str, age: i64) {
        let sql = format!(
            "INSERT INTO users (id, name, age) VALUES ({}, '{}', {})",
            id, name, age
        );
        engine.execute(&sql).unwrap();
    }

    #[test]
    fn test_create_table() {
        let engine = SqlEngine::new();
        engine
            .execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
            .unwrap();
        assert_eq!(engine.table_count(), 1);
    }

    #[test]
    fn test_insert_and_select() {
        let engine = setup_engine();
        insert_user(&engine, 1, "Alice", 30);
        insert_user(&engine, 2, "Bob", 25);

        let result = engine.execute("SELECT name, age FROM users").unwrap();
        assert_eq!(result.rows.len(), 2);
    }

    #[test]
    fn test_select_with_where() {
        let engine = setup_engine();
        insert_user(&engine, 1, "Alice", 30);
        insert_user(&engine, 2, "Bob", 25);

        let result = engine
            .execute("SELECT name FROM users WHERE age > 26")
            .unwrap();
        assert_eq!(result.rows.len(), 1);
        assert_eq!(result.rows[0][0], SqlValue::Text("Alice".to_string()));
    }

    #[test]
    fn test_select_order_by() {
        let engine = setup_engine();
        insert_user(&engine, 1, "Alice", 30);
        insert_user(&engine, 2, "Bob", 25);

        let result = engine
            .execute("SELECT name FROM users ORDER BY age")
            .unwrap();
        assert_eq!(result.rows[0][0], SqlValue::Text("Bob".to_string()));
        assert_eq!(result.rows[1][0], SqlValue::Text("Alice".to_string()));
    }

    #[test]
    fn test_select_limit() {
        let engine = setup_engine();
        for i in 1..=10 {
            insert_user(&engine, i, &format!("User{}", i), 20 + i);
        }

        let result = engine.execute("SELECT name FROM users LIMIT 3").unwrap();
        assert_eq!(result.rows.len(), 3);
    }

    #[test]
    fn test_aggregate_count() {
        let engine = setup_engine();
        insert_user(&engine, 1, "Alice", 30);
        insert_user(&engine, 2, "Bob", 25);

        let result = engine.execute("SELECT COUNT(*) FROM users").unwrap();
        assert_eq!(result.rows[0][0], SqlValue::Integer(2));
    }

    #[test]
    fn test_aggregate_sum_avg() {
        let engine = setup_engine();
        insert_user(&engine, 1, "Alice", 30);
        insert_user(&engine, 2, "Bob", 20);

        let sum = engine.execute("SELECT SUM(age) FROM users").unwrap();
        assert_eq!(sum.rows[0][0], SqlValue::Real(50.0));

        let avg = engine.execute("SELECT AVG(age) FROM users").unwrap();
        assert_eq!(avg.rows[0][0], SqlValue::Real(25.0));
    }

    #[test]
    fn test_update() {
        let engine = setup_engine();
        insert_user(&engine, 1, "Alice", 30);

        let update = engine
            .execute("UPDATE users SET age = 31 WHERE name = 'Alice'")
            .unwrap();
        assert_eq!(update.rows_affected, 1);

        let check = engine
            .execute("SELECT age FROM users WHERE name = 'Alice'")
            .unwrap();
        assert_eq!(check.rows[0][0], SqlValue::Integer(31));
    }

    #[test]
    fn test_delete() {
        let engine = setup_engine();
        insert_user(&engine, 1, "Alice", 30);
        insert_user(&engine, 2, "Bob", 25);

        let delete = engine
            .execute("DELETE FROM users WHERE name = 'Alice'")
            .unwrap();
        assert_eq!(delete.rows_affected, 1);

        let count = engine.execute("SELECT COUNT(*) FROM users").unwrap();
        assert_eq!(count.rows[0][0], SqlValue::Integer(1));
    }

    #[test]
    fn test_drop_table() {
        let engine = setup_engine();
        assert_eq!(engine.table_count(), 1);

        engine.execute("DROP TABLE users").unwrap();
        assert_eq!(engine.table_count(), 0);
    }
}
|
||||
23
crates/synor-database/src/sql/mod.rs
Normal file
23
crates/synor-database/src/sql/mod.rs
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
//! SQL Query Layer for Synor Database.
|
||||
//!
|
||||
//! Provides a SQLite-compatible query interface:
//!
//! - DDL: CREATE TABLE, DROP TABLE, CREATE/DROP INDEX
//! - DML: SELECT, INSERT, UPDATE, DELETE
//! - Clauses: WHERE, ORDER BY, LIMIT, OFFSET
//! - Functions: COUNT, SUM, AVG, MIN, MAX
//!
//! JOIN and GROUP BY clauses are parsed but not yet evaluated by the
//! executor; ALTER TABLE is not supported.
|
||||
|
||||
pub mod executor;
|
||||
pub mod parser;
|
||||
pub mod row;
|
||||
pub mod table;
|
||||
pub mod transaction;
|
||||
pub mod types;
|
||||
|
||||
pub use executor::{QueryResult, SqlEngine};
|
||||
pub use parser::SqlParser;
|
||||
pub use row::{Row, RowId};
|
||||
pub use table::{ColumnDef, Table, TableDef};
|
||||
pub use transaction::{Transaction, TransactionId, TransactionState};
|
||||
pub use types::{SqlError, SqlType, SqlValue};
|
||||
732
crates/synor-database/src/sql/parser.rs
Normal file
732
crates/synor-database/src/sql/parser.rs
Normal file
|
|
@ -0,0 +1,732 @@
|
|||
//! SQL parser using sqlparser-rs.
|
||||
|
||||
use super::types::{SqlError, SqlType, SqlValue};
|
||||
use sqlparser::ast::{
|
||||
BinaryOperator, ColumnOption, DataType, Expr, Query, Select, SelectItem, SetExpr, Statement,
|
||||
TableFactor, Value as AstValue,
|
||||
};
|
||||
use sqlparser::dialect::SQLiteDialect;
|
||||
use sqlparser::parser::Parser;
|
||||
|
||||
/// Parsed SQL statement.
|
||||
#[derive(Debug)]
|
||||
pub enum ParsedStatement {
|
||||
/// CREATE TABLE statement.
|
||||
CreateTable {
|
||||
name: String,
|
||||
columns: Vec<ParsedColumn>,
|
||||
if_not_exists: bool,
|
||||
},
|
||||
/// DROP TABLE statement.
|
||||
DropTable {
|
||||
name: String,
|
||||
if_exists: bool,
|
||||
},
|
||||
/// SELECT statement.
|
||||
Select(ParsedSelect),
|
||||
/// INSERT statement.
|
||||
Insert {
|
||||
table: String,
|
||||
columns: Vec<String>,
|
||||
values: Vec<Vec<SqlValue>>,
|
||||
},
|
||||
/// UPDATE statement.
|
||||
Update {
|
||||
table: String,
|
||||
assignments: Vec<(String, SqlValue)>,
|
||||
where_clause: Option<ParsedExpr>,
|
||||
},
|
||||
/// DELETE statement.
|
||||
Delete {
|
||||
table: String,
|
||||
where_clause: Option<ParsedExpr>,
|
||||
},
|
||||
/// CREATE INDEX statement.
|
||||
CreateIndex {
|
||||
name: String,
|
||||
table: String,
|
||||
columns: Vec<String>,
|
||||
unique: bool,
|
||||
},
|
||||
/// DROP INDEX statement.
|
||||
DropIndex { name: String },
|
||||
}
|
||||
|
||||
/// Parsed column definition.
|
||||
#[derive(Debug)]
|
||||
pub struct ParsedColumn {
|
||||
pub name: String,
|
||||
pub data_type: SqlType,
|
||||
pub nullable: bool,
|
||||
pub default: Option<SqlValue>,
|
||||
pub primary_key: bool,
|
||||
pub unique: bool,
|
||||
}
|
||||
|
||||
/// Parsed SELECT statement.
|
||||
#[derive(Debug)]
|
||||
pub struct ParsedSelect {
|
||||
pub columns: Vec<ParsedSelectItem>,
|
||||
pub from: String,
|
||||
pub joins: Vec<ParsedJoin>,
|
||||
pub where_clause: Option<ParsedExpr>,
|
||||
pub group_by: Vec<String>,
|
||||
pub having: Option<ParsedExpr>,
|
||||
pub order_by: Vec<ParsedOrderBy>,
|
||||
pub limit: Option<usize>,
|
||||
pub offset: Option<usize>,
|
||||
}
|
||||
|
||||
/// One entry of a SELECT projection list.
#[derive(Debug)]
pub enum ParsedSelectItem {
    /// All columns (`*`).
    Wildcard,
    /// A plain column reference.
    Column(String),
    /// A column reference with an `AS` alias.
    ColumnAlias {
        /// Source column.
        column: String,
        /// Output name.
        alias: String,
    },
    /// An aggregate function call such as `COUNT(*)` or `SUM(age)`.
    Aggregate {
        /// Upper-cased function name.
        function: String,
        /// Argument column; `None` for `*`.
        column: Option<String>,
        /// Optional output alias.
        alias: Option<String>,
    },
}
|
||||
|
||||
/// Parsed JOIN clause.
|
||||
#[derive(Debug)]
|
||||
pub struct ParsedJoin {
|
||||
pub table: String,
|
||||
pub join_type: JoinType,
|
||||
pub on: ParsedExpr,
|
||||
}
|
||||
|
||||
/// The supported JOIN kinds.
#[derive(Debug, Clone)]
pub enum JoinType {
    /// INNER JOIN.
    Inner,
    /// LEFT (outer) JOIN.
    Left,
    /// RIGHT (outer) JOIN.
    Right,
    /// FULL (outer) JOIN.
    Full,
}
|
||||
|
||||
/// One ORDER BY term.
#[derive(Debug)]
pub struct ParsedOrderBy {
    /// Column to sort by.
    pub column: String,
    /// `true` for ASC (the SQL default), `false` for DESC.
    pub ascending: bool,
}
|
||||
|
||||
/// Parsed expression.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ParsedExpr {
|
||||
/// Column reference.
|
||||
Column(String),
|
||||
/// Literal value.
|
||||
Literal(SqlValue),
|
||||
/// Binary operation.
|
||||
BinaryOp {
|
||||
left: Box<ParsedExpr>,
|
||||
op: BinaryOp,
|
||||
right: Box<ParsedExpr>,
|
||||
},
|
||||
/// Unary NOT.
|
||||
Not(Box<ParsedExpr>),
|
||||
/// IS NULL check.
|
||||
IsNull(Box<ParsedExpr>),
|
||||
/// IS NOT NULL check.
|
||||
IsNotNull(Box<ParsedExpr>),
|
||||
/// IN list.
|
||||
InList {
|
||||
expr: Box<ParsedExpr>,
|
||||
list: Vec<ParsedExpr>,
|
||||
negated: bool,
|
||||
},
|
||||
/// BETWEEN.
|
||||
Between {
|
||||
expr: Box<ParsedExpr>,
|
||||
low: Box<ParsedExpr>,
|
||||
high: Box<ParsedExpr>,
|
||||
negated: bool,
|
||||
},
|
||||
/// Function call.
|
||||
Function { name: String, args: Vec<ParsedExpr> },
|
||||
}
|
||||
|
||||
/// Binary operators recognized in expressions.
#[derive(Debug, Clone)]
pub enum BinaryOp {
    /// `=`
    Eq,
    /// `<>` / `!=`
    Ne,
    /// `<`
    Lt,
    /// `<=`
    Le,
    /// `>`
    Gt,
    /// `>=`
    Ge,
    /// `AND`
    And,
    /// `OR`
    Or,
    /// `LIKE`
    Like,
    /// `+`
    Plus,
    /// `-`
    Minus,
    /// `*`
    Multiply,
    /// `/`
    Divide,
}
|
||||
|
||||
/// SQL parser.
|
||||
pub struct SqlParser;
|
||||
|
||||
impl SqlParser {
|
||||
/// Parses a SQL statement.
|
||||
pub fn parse(sql: &str) -> Result<ParsedStatement, SqlError> {
|
||||
let dialect = SQLiteDialect {};
|
||||
let statements = Parser::parse_sql(&dialect, sql)
|
||||
.map_err(|e| SqlError::Parse(e.to_string()))?;
|
||||
|
||||
if statements.is_empty() {
|
||||
return Err(SqlError::Parse("Empty SQL statement".to_string()));
|
||||
}
|
||||
|
||||
if statements.len() > 1 {
|
||||
return Err(SqlError::Parse("Multiple statements not supported".to_string()));
|
||||
}
|
||||
|
||||
Self::convert_statement(&statements[0])
|
||||
}
|
||||
|
||||
fn convert_statement(stmt: &Statement) -> Result<ParsedStatement, SqlError> {
|
||||
match stmt {
|
||||
Statement::CreateTable { name, columns, if_not_exists, constraints, .. } => {
|
||||
Self::convert_create_table(name, columns, constraints, *if_not_exists)
|
||||
}
|
||||
Statement::Drop { object_type, names, if_exists, .. } => {
|
||||
Self::convert_drop(object_type, names, *if_exists)
|
||||
}
|
||||
Statement::Query(query) => Self::convert_query(query),
|
||||
Statement::Insert { table_name, columns, source, .. } => {
|
||||
Self::convert_insert(table_name, columns, source)
|
||||
}
|
||||
Statement::Update { table, assignments, selection, .. } => {
|
||||
Self::convert_update(table, assignments, selection)
|
||||
}
|
||||
Statement::Delete { from, selection, .. } => {
|
||||
Self::convert_delete(from, selection)
|
||||
}
|
||||
Statement::CreateIndex { name, table_name, columns, unique, .. } => {
|
||||
Self::convert_create_index(name, table_name, columns, *unique)
|
||||
}
|
||||
_ => Err(SqlError::Unsupported(format!("Statement not supported"))),
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_create_table(
|
||||
name: &sqlparser::ast::ObjectName,
|
||||
columns: &[sqlparser::ast::ColumnDef],
|
||||
constraints: &[sqlparser::ast::TableConstraint],
|
||||
if_not_exists: bool,
|
||||
) -> Result<ParsedStatement, SqlError> {
|
||||
let table_name = name.to_string();
|
||||
let mut parsed_columns = Vec::new();
|
||||
let mut primary_keys: Vec<String> = Vec::new();
|
||||
|
||||
// Extract primary keys from table constraints
|
||||
for constraint in constraints {
|
||||
if let sqlparser::ast::TableConstraint::Unique { columns: pk_cols, is_primary: true, .. } = constraint {
|
||||
for col in pk_cols {
|
||||
primary_keys.push(col.value.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for col in columns {
|
||||
let col_name = col.name.value.clone();
|
||||
let data_type = Self::convert_data_type(&col.data_type)?;
|
||||
|
||||
let mut nullable = true;
|
||||
let mut default = None;
|
||||
let mut primary_key = primary_keys.contains(&col_name);
|
||||
let mut unique = false;
|
||||
|
||||
for option in &col.options {
|
||||
match &option.option {
|
||||
ColumnOption::Null => nullable = true,
|
||||
ColumnOption::NotNull => nullable = false,
|
||||
ColumnOption::Default(expr) => {
|
||||
default = Some(Self::convert_value_expr(expr)?);
|
||||
}
|
||||
ColumnOption::Unique { is_primary, .. } => {
|
||||
if *is_primary {
|
||||
primary_key = true;
|
||||
} else {
|
||||
unique = true;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
if primary_key {
|
||||
nullable = false;
|
||||
unique = true;
|
||||
}
|
||||
|
||||
parsed_columns.push(ParsedColumn {
|
||||
name: col_name,
|
||||
data_type,
|
||||
nullable,
|
||||
default,
|
||||
primary_key,
|
||||
unique,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(ParsedStatement::CreateTable {
|
||||
name: table_name,
|
||||
columns: parsed_columns,
|
||||
if_not_exists,
|
||||
})
|
||||
}
|
||||
|
||||
fn convert_data_type(dt: &DataType) -> Result<SqlType, SqlError> {
|
||||
match dt {
|
||||
DataType::Int(_)
|
||||
| DataType::Integer(_)
|
||||
| DataType::BigInt(_)
|
||||
| DataType::SmallInt(_)
|
||||
| DataType::TinyInt(_) => Ok(SqlType::Integer),
|
||||
DataType::Real | DataType::Float(_) | DataType::Double | DataType::DoublePrecision => {
|
||||
Ok(SqlType::Real)
|
||||
}
|
||||
DataType::Varchar(_)
|
||||
| DataType::Char(_)
|
||||
| DataType::Text
|
||||
| DataType::String(_) => Ok(SqlType::Text),
|
||||
DataType::Binary(_) | DataType::Varbinary(_) | DataType::Blob(_) => Ok(SqlType::Blob),
|
||||
DataType::Boolean => Ok(SqlType::Boolean),
|
||||
DataType::Timestamp(_, _) | DataType::Date | DataType::Datetime(_) => {
|
||||
Ok(SqlType::Timestamp)
|
||||
}
|
||||
_ => Err(SqlError::Unsupported(format!("Data type: {:?}", dt))),
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_drop(
|
||||
object_type: &sqlparser::ast::ObjectType,
|
||||
names: &[sqlparser::ast::ObjectName],
|
||||
if_exists: bool,
|
||||
) -> Result<ParsedStatement, SqlError> {
|
||||
let name = names
|
||||
.first()
|
||||
.map(|n| n.to_string())
|
||||
.ok_or_else(|| SqlError::Parse("Missing object name".to_string()))?;
|
||||
|
||||
match object_type {
|
||||
sqlparser::ast::ObjectType::Table => Ok(ParsedStatement::DropTable { name, if_exists }),
|
||||
sqlparser::ast::ObjectType::Index => Ok(ParsedStatement::DropIndex { name }),
|
||||
_ => Err(SqlError::Unsupported(format!("DROP not supported"))),
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_query(query: &Query) -> Result<ParsedStatement, SqlError> {
|
||||
let select = match &*query.body {
|
||||
SetExpr::Select(select) => select,
|
||||
_ => return Err(SqlError::Unsupported("Non-SELECT query body".to_string())),
|
||||
};
|
||||
|
||||
let parsed_select = Self::convert_select(select, query)?;
|
||||
Ok(ParsedStatement::Select(parsed_select))
|
||||
}
|
||||
|
||||
fn convert_select(select: &Select, query: &Query) -> Result<ParsedSelect, SqlError> {
|
||||
// Parse columns
|
||||
let columns = select
|
||||
.projection
|
||||
.iter()
|
||||
.map(Self::convert_select_item)
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
// Parse FROM
|
||||
let from = select
|
||||
.from
|
||||
.first()
|
||||
.map(|f| Self::convert_table_factor(&f.relation))
|
||||
.transpose()?
|
||||
.unwrap_or_default();
|
||||
|
||||
// Parse WHERE
|
||||
let where_clause = select
|
||||
.selection
|
||||
.as_ref()
|
||||
.map(Self::convert_expr)
|
||||
.transpose()?;
|
||||
|
||||
// Parse ORDER BY - simplified approach
|
||||
let order_by: Vec<ParsedOrderBy> = query
|
||||
.order_by
|
||||
.iter()
|
||||
.filter_map(|e| Self::convert_order_by(e).ok())
|
||||
.collect();
|
||||
|
||||
// Parse LIMIT/OFFSET
|
||||
let limit = query
|
||||
.limit
|
||||
.as_ref()
|
||||
.and_then(|l| Self::expr_to_usize(l));
|
||||
let offset = query
|
||||
.offset
|
||||
.as_ref()
|
||||
.and_then(|o| Self::expr_to_usize(&o.value));
|
||||
|
||||
Ok(ParsedSelect {
|
||||
columns,
|
||||
from,
|
||||
joins: Vec::new(),
|
||||
where_clause,
|
||||
group_by: Vec::new(),
|
||||
having: None,
|
||||
order_by,
|
||||
limit,
|
||||
offset,
|
||||
})
|
||||
}
|
||||
|
||||
fn convert_select_item(item: &SelectItem) -> Result<ParsedSelectItem, SqlError> {
|
||||
match item {
|
||||
SelectItem::Wildcard(_) => Ok(ParsedSelectItem::Wildcard),
|
||||
SelectItem::UnnamedExpr(expr) => Self::convert_select_expr(expr),
|
||||
SelectItem::ExprWithAlias { expr, alias } => {
|
||||
if let Expr::Identifier(id) = expr {
|
||||
Ok(ParsedSelectItem::ColumnAlias {
|
||||
column: id.value.clone(),
|
||||
alias: alias.value.clone(),
|
||||
})
|
||||
} else {
|
||||
Self::convert_select_expr(expr)
|
||||
}
|
||||
}
|
||||
_ => Err(SqlError::Unsupported("Select item not supported".to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_select_expr(expr: &Expr) -> Result<ParsedSelectItem, SqlError> {
|
||||
match expr {
|
||||
Expr::Identifier(id) => Ok(ParsedSelectItem::Column(id.value.clone())),
|
||||
Expr::CompoundIdentifier(ids) => {
|
||||
Ok(ParsedSelectItem::Column(ids.last().map(|i| i.value.clone()).unwrap_or_default()))
|
||||
}
|
||||
Expr::Function(func) => {
|
||||
let name = func.name.to_string().to_uppercase();
|
||||
// Try to extract column from first arg - simplified for compatibility
|
||||
let column = Self::extract_func_column_arg(func);
|
||||
Ok(ParsedSelectItem::Aggregate {
|
||||
function: name,
|
||||
column,
|
||||
alias: None,
|
||||
})
|
||||
}
|
||||
_ => Err(SqlError::Unsupported("Select expression not supported".to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_table_factor(factor: &TableFactor) -> Result<String, SqlError> {
|
||||
match factor {
|
||||
TableFactor::Table { name, .. } => Ok(name.to_string()),
|
||||
_ => Err(SqlError::Unsupported("Table factor not supported".to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_func_column_arg(func: &sqlparser::ast::Function) -> Option<String> {
|
||||
// Use string representation and parse it - works across sqlparser versions
|
||||
let func_str = func.to_string();
|
||||
// Parse function like "SUM(age)" or "COUNT(*)"
|
||||
if let Some(start) = func_str.find('(') {
|
||||
if let Some(end) = func_str.rfind(')') {
|
||||
let arg = func_str[start + 1..end].trim();
|
||||
if !arg.is_empty() && arg != "*" {
|
||||
return Some(arg.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn convert_order_by(ob: &sqlparser::ast::OrderByExpr) -> Result<ParsedOrderBy, SqlError> {
|
||||
let column = match &ob.expr {
|
||||
Expr::Identifier(id) => id.value.clone(),
|
||||
_ => return Err(SqlError::Unsupported("Order by expression".to_string())),
|
||||
};
|
||||
let ascending = ob.asc.unwrap_or(true);
|
||||
Ok(ParsedOrderBy { column, ascending })
|
||||
}
|
||||
|
||||
fn convert_expr(expr: &Expr) -> Result<ParsedExpr, SqlError> {
|
||||
match expr {
|
||||
Expr::Identifier(id) => Ok(ParsedExpr::Column(id.value.clone())),
|
||||
Expr::CompoundIdentifier(ids) => {
|
||||
Ok(ParsedExpr::Column(ids.last().map(|i| i.value.clone()).unwrap_or_default()))
|
||||
}
|
||||
Expr::Value(v) => Ok(ParsedExpr::Literal(Self::convert_value(v)?)),
|
||||
Expr::BinaryOp { left, op, right } => {
|
||||
let left = Box::new(Self::convert_expr(left)?);
|
||||
let right = Box::new(Self::convert_expr(right)?);
|
||||
let op = Self::convert_binary_op(op)?;
|
||||
Ok(ParsedExpr::BinaryOp { left, op, right })
|
||||
}
|
||||
Expr::UnaryOp { op: sqlparser::ast::UnaryOperator::Not, expr } => {
|
||||
Ok(ParsedExpr::Not(Box::new(Self::convert_expr(expr)?)))
|
||||
}
|
||||
Expr::IsNull(expr) => Ok(ParsedExpr::IsNull(Box::new(Self::convert_expr(expr)?))),
|
||||
Expr::IsNotNull(expr) => Ok(ParsedExpr::IsNotNull(Box::new(Self::convert_expr(expr)?))),
|
||||
Expr::InList { expr, list, negated } => Ok(ParsedExpr::InList {
|
||||
expr: Box::new(Self::convert_expr(expr)?),
|
||||
list: list.iter().map(Self::convert_expr).collect::<Result<_, _>>()?,
|
||||
negated: *negated,
|
||||
}),
|
||||
Expr::Between { expr, low, high, negated } => Ok(ParsedExpr::Between {
|
||||
expr: Box::new(Self::convert_expr(expr)?),
|
||||
low: Box::new(Self::convert_expr(low)?),
|
||||
high: Box::new(Self::convert_expr(high)?),
|
||||
negated: *negated,
|
||||
}),
|
||||
Expr::Like { expr, pattern, .. } => {
|
||||
let left = Box::new(Self::convert_expr(expr)?);
|
||||
let right = Box::new(Self::convert_expr(pattern)?);
|
||||
Ok(ParsedExpr::BinaryOp { left, op: BinaryOp::Like, right })
|
||||
}
|
||||
Expr::Nested(inner) => Self::convert_expr(inner),
|
||||
_ => Err(SqlError::Unsupported("Expression not supported".to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_binary_op(op: &BinaryOperator) -> Result<BinaryOp, SqlError> {
|
||||
match op {
|
||||
BinaryOperator::Eq => Ok(BinaryOp::Eq),
|
||||
BinaryOperator::NotEq => Ok(BinaryOp::Ne),
|
||||
BinaryOperator::Lt => Ok(BinaryOp::Lt),
|
||||
BinaryOperator::LtEq => Ok(BinaryOp::Le),
|
||||
BinaryOperator::Gt => Ok(BinaryOp::Gt),
|
||||
BinaryOperator::GtEq => Ok(BinaryOp::Ge),
|
||||
BinaryOperator::And => Ok(BinaryOp::And),
|
||||
BinaryOperator::Or => Ok(BinaryOp::Or),
|
||||
BinaryOperator::Plus => Ok(BinaryOp::Plus),
|
||||
BinaryOperator::Minus => Ok(BinaryOp::Minus),
|
||||
BinaryOperator::Multiply => Ok(BinaryOp::Multiply),
|
||||
BinaryOperator::Divide => Ok(BinaryOp::Divide),
|
||||
_ => Err(SqlError::Unsupported("Operator not supported".to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_value(v: &AstValue) -> Result<SqlValue, SqlError> {
|
||||
match v {
|
||||
AstValue::Null => Ok(SqlValue::Null),
|
||||
AstValue::Number(n, _) => {
|
||||
if n.contains('.') {
|
||||
n.parse::<f64>()
|
||||
.map(SqlValue::Real)
|
||||
.map_err(|e| SqlError::Parse(e.to_string()))
|
||||
} else {
|
||||
n.parse::<i64>()
|
||||
.map(SqlValue::Integer)
|
||||
.map_err(|e| SqlError::Parse(e.to_string()))
|
||||
}
|
||||
}
|
||||
AstValue::SingleQuotedString(s) | AstValue::DoubleQuotedString(s) => {
|
||||
Ok(SqlValue::Text(s.clone()))
|
||||
}
|
||||
AstValue::Boolean(b) => Ok(SqlValue::Boolean(*b)),
|
||||
AstValue::HexStringLiteral(h) => hex::decode(h)
|
||||
.map(SqlValue::Blob)
|
||||
.map_err(|e| SqlError::Parse(e.to_string())),
|
||||
_ => Err(SqlError::Unsupported("Value not supported".to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_value_expr(expr: &Expr) -> Result<SqlValue, SqlError> {
|
||||
match expr {
|
||||
Expr::Value(v) => Self::convert_value(v),
|
||||
_ => Err(SqlError::Unsupported("Non-literal default".to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_insert(
|
||||
table_name: &sqlparser::ast::ObjectName,
|
||||
columns: &[sqlparser::ast::Ident],
|
||||
source: &Option<Box<Query>>,
|
||||
) -> Result<ParsedStatement, SqlError> {
|
||||
let table = table_name.to_string();
|
||||
let col_names: Vec<String> = columns.iter().map(|c| c.value.clone()).collect();
|
||||
|
||||
let values = match source.as_ref().map(|s| s.body.as_ref()) {
|
||||
Some(SetExpr::Values(vals)) => {
|
||||
let mut result = Vec::new();
|
||||
for row in &vals.rows {
|
||||
let row_values: Vec<SqlValue> = row
|
||||
.iter()
|
||||
.map(Self::convert_value_expr)
|
||||
.collect::<Result<_, _>>()?;
|
||||
result.push(row_values);
|
||||
}
|
||||
result
|
||||
}
|
||||
_ => return Err(SqlError::Unsupported("INSERT source".to_string())),
|
||||
};
|
||||
|
||||
Ok(ParsedStatement::Insert {
|
||||
table,
|
||||
columns: col_names,
|
||||
values,
|
||||
})
|
||||
}
|
||||
|
||||
fn convert_update(
|
||||
table: &sqlparser::ast::TableWithJoins,
|
||||
assignments: &[sqlparser::ast::Assignment],
|
||||
selection: &Option<Expr>,
|
||||
) -> Result<ParsedStatement, SqlError> {
|
||||
let table_name = Self::convert_table_factor(&table.relation)?;
|
||||
|
||||
let parsed_assignments: Vec<(String, SqlValue)> = assignments
|
||||
.iter()
|
||||
.map(|a| {
|
||||
let col = a.id.iter().map(|i| i.value.clone()).collect::<Vec<_>>().join(".");
|
||||
let val = Self::convert_value_expr(&a.value)?;
|
||||
Ok((col, val))
|
||||
})
|
||||
.collect::<Result<_, SqlError>>()?;
|
||||
|
||||
let where_clause = selection.as_ref().map(Self::convert_expr).transpose()?;
|
||||
|
||||
Ok(ParsedStatement::Update {
|
||||
table: table_name,
|
||||
assignments: parsed_assignments,
|
||||
where_clause,
|
||||
})
|
||||
}
|
||||
|
||||
fn convert_delete(
|
||||
from: &[sqlparser::ast::TableWithJoins],
|
||||
selection: &Option<Expr>,
|
||||
) -> Result<ParsedStatement, SqlError> {
|
||||
let table = from
|
||||
.first()
|
||||
.map(|f| Self::convert_table_factor(&f.relation))
|
||||
.transpose()?
|
||||
.unwrap_or_default();
|
||||
|
||||
let where_clause = selection.as_ref().map(Self::convert_expr).transpose()?;
|
||||
|
||||
Ok(ParsedStatement::Delete {
|
||||
table,
|
||||
where_clause,
|
||||
})
|
||||
}
|
||||
|
||||
fn convert_create_index(
|
||||
name: &Option<sqlparser::ast::ObjectName>,
|
||||
table_name: &sqlparser::ast::ObjectName,
|
||||
columns: &[sqlparser::ast::OrderByExpr],
|
||||
unique: bool,
|
||||
) -> Result<ParsedStatement, SqlError> {
|
||||
let index_name = name
|
||||
.as_ref()
|
||||
.map(|n| n.to_string())
|
||||
.ok_or_else(|| SqlError::Parse("Index name required".to_string()))?;
|
||||
|
||||
let table = table_name.to_string();
|
||||
|
||||
let cols: Vec<String> = columns
|
||||
.iter()
|
||||
.map(|c| c.expr.to_string())
|
||||
.collect();
|
||||
|
||||
Ok(ParsedStatement::CreateIndex {
|
||||
name: index_name,
|
||||
table,
|
||||
columns: cols,
|
||||
unique,
|
||||
})
|
||||
}
|
||||
|
||||
fn expr_to_usize(expr: &Expr) -> Option<usize> {
|
||||
if let Expr::Value(AstValue::Number(n, _)) = expr {
|
||||
n.parse().ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_create_table() {
        let stmt = SqlParser::parse(
            "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL, age INTEGER)",
        )
        .unwrap();

        match stmt {
            ParsedStatement::CreateTable { name, columns, .. } => {
                assert_eq!(name, "users");
                assert_eq!(columns.len(), 3);
                assert!(columns[0].primary_key);
                assert!(!columns[1].nullable);
            }
            other => panic!("Expected CreateTable, got {other:?}"),
        }
    }

    #[test]
    fn test_parse_select() {
        let stmt = SqlParser::parse(
            "SELECT name, age FROM users WHERE age > 18 ORDER BY name LIMIT 10",
        )
        .unwrap();

        match stmt {
            ParsedStatement::Select(select) => {
                assert_eq!(select.columns.len(), 2);
                assert_eq!(select.from, "users");
                assert!(select.where_clause.is_some());
                assert_eq!(select.limit, Some(10));
            }
            other => panic!("Expected Select, got {other:?}"),
        }
    }

    #[test]
    fn test_parse_insert() {
        let stmt =
            SqlParser::parse("INSERT INTO users (name, age) VALUES ('Alice', 30), ('Bob', 25)")
                .unwrap();

        match stmt {
            ParsedStatement::Insert { table, columns, values } => {
                assert_eq!(table, "users");
                assert_eq!(columns, vec!["name", "age"]);
                assert_eq!(values.len(), 2);
            }
            other => panic!("Expected Insert, got {other:?}"),
        }
    }

    #[test]
    fn test_parse_update() {
        let stmt = SqlParser::parse("UPDATE users SET age = 31 WHERE name = 'Alice'").unwrap();

        match stmt {
            ParsedStatement::Update { table, assignments, where_clause } => {
                assert_eq!(table, "users");
                assert_eq!(assignments.len(), 1);
                assert!(where_clause.is_some());
            }
            other => panic!("Expected Update, got {other:?}"),
        }
    }

    #[test]
    fn test_parse_delete() {
        let stmt = SqlParser::parse("DELETE FROM users WHERE age < 18").unwrap();

        match stmt {
            ParsedStatement::Delete { table, where_clause } => {
                assert_eq!(table, "users");
                assert!(where_clause.is_some());
            }
            other => panic!("Expected Delete, got {other:?}"),
        }
    }
}
|
||||
241
crates/synor-database/src/sql/row.rs
Normal file
241
crates/synor-database/src/sql/row.rs
Normal file
|
|
@ -0,0 +1,241 @@
|
|||
//! Row representation for SQL tables.
|
||||
|
||||
use super::types::{SqlError, SqlValue};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Unique row identifier.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct RowId(pub u64);
|
||||
|
||||
impl RowId {
|
||||
/// Creates a new row ID.
|
||||
pub fn new(id: u64) -> Self {
|
||||
RowId(id)
|
||||
}
|
||||
|
||||
/// Returns the inner ID value.
|
||||
pub fn inner(&self) -> u64 {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for RowId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// A single row in a SQL table.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Row {
|
||||
/// Row identifier.
|
||||
pub id: RowId,
|
||||
/// Column values indexed by column name.
|
||||
values: HashMap<String, SqlValue>,
|
||||
/// Ordered column names (preserves insertion order).
|
||||
columns: Vec<String>,
|
||||
}
|
||||
|
||||
impl Row {
|
||||
/// Creates a new empty row.
|
||||
pub fn new(id: RowId) -> Self {
|
||||
Self {
|
||||
id,
|
||||
values: HashMap::new(),
|
||||
columns: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a row with given columns.
|
||||
pub fn with_columns(id: RowId, columns: Vec<String>) -> Self {
|
||||
let mut values = HashMap::with_capacity(columns.len());
|
||||
for col in &columns {
|
||||
values.insert(col.clone(), SqlValue::Null);
|
||||
}
|
||||
Self {
|
||||
id,
|
||||
values,
|
||||
columns,
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets a column value.
|
||||
pub fn set(&mut self, column: &str, value: SqlValue) {
|
||||
if !self.values.contains_key(column) {
|
||||
self.columns.push(column.to_string());
|
||||
}
|
||||
self.values.insert(column.to_string(), value);
|
||||
}
|
||||
|
||||
/// Gets a column value.
|
||||
pub fn get(&self, column: &str) -> Option<&SqlValue> {
|
||||
self.values.get(column)
|
||||
}
|
||||
|
||||
/// Gets a column value or returns Null.
|
||||
pub fn get_or_null(&self, column: &str) -> SqlValue {
|
||||
self.values.get(column).cloned().unwrap_or(SqlValue::Null)
|
||||
}
|
||||
|
||||
/// Returns all column names.
|
||||
pub fn columns(&self) -> &[String] {
|
||||
&self.columns
|
||||
}
|
||||
|
||||
/// Returns all values in column order.
|
||||
pub fn values(&self) -> Vec<&SqlValue> {
|
||||
self.columns.iter().map(|c| self.values.get(c).unwrap()).collect()
|
||||
}
|
||||
|
||||
/// Returns the number of columns.
|
||||
pub fn len(&self) -> usize {
|
||||
self.columns.len()
|
||||
}
|
||||
|
||||
/// Returns true if the row has no columns.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.columns.is_empty()
|
||||
}
|
||||
|
||||
/// Projects to specific columns.
|
||||
pub fn project(&self, columns: &[String]) -> Row {
|
||||
let mut row = Row::new(self.id);
|
||||
for col in columns {
|
||||
if let Some(value) = self.values.get(col) {
|
||||
row.set(col, value.clone());
|
||||
}
|
||||
}
|
||||
row
|
||||
}
|
||||
|
||||
/// Converts to a map.
|
||||
pub fn to_map(&self) -> HashMap<String, SqlValue> {
|
||||
self.values.clone()
|
||||
}
|
||||
|
||||
/// Converts from a map.
|
||||
pub fn from_map(id: RowId, map: HashMap<String, SqlValue>) -> Self {
|
||||
let columns: Vec<String> = map.keys().cloned().collect();
|
||||
Self {
|
||||
id,
|
||||
values: map,
|
||||
columns,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for Row {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.columns.len() != other.columns.len() {
|
||||
return false;
|
||||
}
|
||||
for col in &self.columns {
|
||||
if self.get(col) != other.get(col) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for creating rows.
|
||||
pub struct RowBuilder {
|
||||
id: RowId,
|
||||
values: Vec<(String, SqlValue)>,
|
||||
}
|
||||
|
||||
impl RowBuilder {
|
||||
/// Creates a new row builder.
|
||||
pub fn new(id: RowId) -> Self {
|
||||
Self {
|
||||
id,
|
||||
values: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds a column value.
|
||||
pub fn column(mut self, name: impl Into<String>, value: SqlValue) -> Self {
|
||||
self.values.push((name.into(), value));
|
||||
self
|
||||
}
|
||||
|
||||
/// Adds an integer column.
|
||||
pub fn int(self, name: impl Into<String>, value: i64) -> Self {
|
||||
self.column(name, SqlValue::Integer(value))
|
||||
}
|
||||
|
||||
/// Adds a real column.
|
||||
pub fn real(self, name: impl Into<String>, value: f64) -> Self {
|
||||
self.column(name, SqlValue::Real(value))
|
||||
}
|
||||
|
||||
/// Adds a text column.
|
||||
pub fn text(self, name: impl Into<String>, value: impl Into<String>) -> Self {
|
||||
self.column(name, SqlValue::Text(value.into()))
|
||||
}
|
||||
|
||||
/// Adds a boolean column.
|
||||
pub fn boolean(self, name: impl Into<String>, value: bool) -> Self {
|
||||
self.column(name, SqlValue::Boolean(value))
|
||||
}
|
||||
|
||||
/// Adds a null column.
|
||||
pub fn null(self, name: impl Into<String>) -> Self {
|
||||
self.column(name, SqlValue::Null)
|
||||
}
|
||||
|
||||
/// Builds the row.
|
||||
pub fn build(self) -> Row {
|
||||
let mut row = Row::new(self.id);
|
||||
for (name, value) in self.values {
|
||||
row.set(&name, value);
|
||||
}
|
||||
row
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_row_basic() {
        let mut row = Row::new(RowId(1));
        row.set("name", SqlValue::Text("Alice".to_string()));
        row.set("age", SqlValue::Integer(30));

        assert_eq!(row.len(), 2);
        assert_eq!(row.get("name"), Some(&SqlValue::Text("Alice".to_string())));
        assert_eq!(row.get("age"), Some(&SqlValue::Integer(30)));
        assert!(row.get("missing").is_none());
    }

    #[test]
    fn test_row_builder() {
        let row = RowBuilder::new(RowId(1))
            .text("name", "Bob")
            .int("age", 25)
            .boolean("active", true)
            .build();

        assert_eq!(row.get("name"), Some(&SqlValue::Text("Bob".to_string())));
        assert_eq!(row.get("age"), Some(&SqlValue::Integer(25)));
        assert_eq!(row.get("active"), Some(&SqlValue::Boolean(true)));
    }

    #[test]
    fn test_row_projection() {
        let row = RowBuilder::new(RowId(1))
            .text("a", "1")
            .text("b", "2")
            .text("c", "3")
            .build();

        // Keep only "a" and "c"; "b" must disappear.
        let projected = row.project(&["a".to_string(), "c".to_string()]);
        assert_eq!(projected.len(), 2);
        assert!(projected.get("a").is_some());
        assert!(projected.get("b").is_none());
        assert!(projected.get("c").is_some());
    }
}
|
||||
570
crates/synor-database/src/sql/table.rs
Normal file
570
crates/synor-database/src/sql/table.rs
Normal file
|
|
@ -0,0 +1,570 @@
|
|||
//! SQL table definition and storage.
|
||||
|
||||
use super::row::{Row, RowId};
|
||||
use super::types::{SqlError, SqlType, SqlValue};
|
||||
use parking_lot::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::{BTreeMap, HashMap, HashSet};
|
||||
|
||||
/// Column definition.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ColumnDef {
|
||||
/// Column name.
|
||||
pub name: String,
|
||||
/// Data type.
|
||||
pub data_type: SqlType,
|
||||
/// Whether null values are allowed.
|
||||
pub nullable: bool,
|
||||
/// Default value.
|
||||
pub default: Option<SqlValue>,
|
||||
/// Primary key flag.
|
||||
pub primary_key: bool,
|
||||
/// Unique constraint.
|
||||
pub unique: bool,
|
||||
}
|
||||
|
||||
impl ColumnDef {
|
||||
/// Creates a new column definition.
|
||||
pub fn new(name: impl Into<String>, data_type: SqlType) -> Self {
|
||||
Self {
|
||||
name: name.into(),
|
||||
data_type,
|
||||
nullable: true,
|
||||
default: None,
|
||||
primary_key: false,
|
||||
unique: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets as not null.
|
||||
pub fn not_null(mut self) -> Self {
|
||||
self.nullable = false;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets default value.
|
||||
pub fn default(mut self, value: SqlValue) -> Self {
|
||||
self.default = Some(value);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets as primary key.
|
||||
pub fn primary_key(mut self) -> Self {
|
||||
self.primary_key = true;
|
||||
self.nullable = false;
|
||||
self.unique = true;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets as unique.
|
||||
pub fn unique(mut self) -> Self {
|
||||
self.unique = true;
|
||||
self
|
||||
}
|
||||
|
||||
/// Validates a value against this column definition.
|
||||
pub fn validate(&self, value: &SqlValue) -> Result<(), SqlError> {
|
||||
// Check null constraint
|
||||
if value.is_null() && !self.nullable {
|
||||
return Err(SqlError::NotNullViolation(self.name.clone()));
|
||||
}
|
||||
|
||||
// Skip type check for null values
|
||||
if value.is_null() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Check type compatibility
|
||||
let compatible = match (&self.data_type, value) {
|
||||
(SqlType::Integer, SqlValue::Integer(_)) => true,
|
||||
(SqlType::Real, SqlValue::Real(_)) => true,
|
||||
(SqlType::Real, SqlValue::Integer(_)) => true, // Allow int to real
|
||||
(SqlType::Text, SqlValue::Text(_)) => true,
|
||||
(SqlType::Blob, SqlValue::Blob(_)) => true,
|
||||
(SqlType::Boolean, SqlValue::Boolean(_)) => true,
|
||||
(SqlType::Timestamp, SqlValue::Timestamp(_)) => true,
|
||||
(SqlType::Timestamp, SqlValue::Integer(_)) => true, // Allow int to timestamp
|
||||
_ => false,
|
||||
};
|
||||
|
||||
if !compatible {
|
||||
return Err(SqlError::TypeMismatch {
|
||||
expected: format!("{:?}", self.data_type),
|
||||
got: format!("{:?}", value.sql_type()),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Table definition.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TableDef {
    /// Table name.
    pub name: String,
    /// Column definitions, in declaration order.
    pub columns: Vec<ColumnDef>,
    /// Primary key column name (if any); maintained automatically by
    /// `TableDef::column` when a column is flagged as primary key.
    pub primary_key: Option<String>,
}
|
||||
|
||||
impl TableDef {
|
||||
/// Creates a new table definition.
|
||||
pub fn new(name: impl Into<String>) -> Self {
|
||||
Self {
|
||||
name: name.into(),
|
||||
columns: Vec::new(),
|
||||
primary_key: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds a column.
|
||||
pub fn column(mut self, col: ColumnDef) -> Self {
|
||||
if col.primary_key {
|
||||
self.primary_key = Some(col.name.clone());
|
||||
}
|
||||
self.columns.push(col);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets column definition by name.
|
||||
pub fn get_column(&self, name: &str) -> Option<&ColumnDef> {
|
||||
self.columns.iter().find(|c| c.name == name)
|
||||
}
|
||||
|
||||
/// Returns column names.
|
||||
pub fn column_names(&self) -> Vec<String> {
|
||||
self.columns.iter().map(|c| c.name.clone()).collect()
|
||||
}
|
||||
|
||||
/// Validates a row against this table definition.
|
||||
pub fn validate_row(&self, row: &Row) -> Result<(), SqlError> {
|
||||
for col in &self.columns {
|
||||
let value = row.get_or_null(&col.name);
|
||||
col.validate(&value)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Index on a table column.
///
/// Backed by an ordered B-tree map so it supports both point lookups and
/// range scans. Not serializable: indexes are rebuilt from row data.
#[derive(Debug)]
pub struct TableIndex {
    /// Index name.
    pub name: String,
    /// Column being indexed.
    pub column: String,
    /// Whether the index enforces uniqueness on insert.
    pub unique: bool,
    /// B-tree index data: column value -> set of row IDs holding it.
    data: BTreeMap<SqlValue, HashSet<RowId>>,
}
|
||||
|
||||
impl TableIndex {
|
||||
/// Creates a new index.
|
||||
pub fn new(name: impl Into<String>, column: impl Into<String>, unique: bool) -> Self {
|
||||
Self {
|
||||
name: name.into(),
|
||||
column: column.into(),
|
||||
unique,
|
||||
data: BTreeMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Inserts a value-rowid mapping.
|
||||
pub fn insert(&mut self, value: SqlValue, row_id: RowId) -> Result<(), SqlError> {
|
||||
if self.unique {
|
||||
if let Some(existing) = self.data.get(&value) {
|
||||
if !existing.is_empty() {
|
||||
return Err(SqlError::ConstraintViolation(format!(
|
||||
"Unique constraint violation on index '{}'",
|
||||
self.name
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
self.data.entry(value).or_default().insert(row_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Removes a value-rowid mapping.
|
||||
pub fn remove(&mut self, value: &SqlValue, row_id: &RowId) {
|
||||
if let Some(ids) = self.data.get_mut(value) {
|
||||
ids.remove(row_id);
|
||||
if ids.is_empty() {
|
||||
self.data.remove(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Looks up rows by exact value.
|
||||
pub fn lookup(&self, value: &SqlValue) -> Vec<RowId> {
|
||||
self.data
|
||||
.get(value)
|
||||
.map(|ids| ids.iter().cloned().collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Range query.
|
||||
pub fn range(&self, start: Option<&SqlValue>, end: Option<&SqlValue>) -> Vec<RowId> {
|
||||
let mut result = Vec::new();
|
||||
for (key, ids) in &self.data {
|
||||
let in_range = match (start, end) {
|
||||
(Some(s), Some(e)) => key >= s && key <= e,
|
||||
(Some(s), None) => key >= s,
|
||||
(None, Some(e)) => key <= e,
|
||||
(None, None) => true,
|
||||
};
|
||||
if in_range {
|
||||
result.extend(ids.iter().cloned());
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
/// A SQL table with data.
///
/// All mutable state lives behind `parking_lot::RwLock`s so the table can
/// be mutated through a shared reference (`&self` methods).
pub struct Table {
    /// Table definition (schema).
    pub def: TableDef,
    /// Row storage: row ID -> row data.
    rows: RwLock<HashMap<RowId, Row>>,
    /// Next row ID to hand out; starts at 1 (see `Table::new`).
    next_id: RwLock<u64>,
    /// Secondary indexes, keyed by index name.
    indexes: RwLock<HashMap<String, TableIndex>>,
    /// Name of the primary key index in `indexes` (if any).
    pk_index: RwLock<Option<String>>,
}
|
||||
|
||||
impl Table {
    /// Creates a new table, pre-building a unique index for the primary
    /// key (named `pk_<col>`) and one per non-PK UNIQUE column
    /// (named `unique_<col>`).
    pub fn new(def: TableDef) -> Self {
        let pk_col = def.primary_key.clone();
        // PK columns are excluded here: they get the dedicated pk_ index.
        let unique_cols: Vec<String> = def
            .columns
            .iter()
            .filter(|c| c.unique && !c.primary_key)
            .map(|c| c.name.clone())
            .collect();

        let table = Self {
            def,
            rows: RwLock::new(HashMap::new()),
            next_id: RwLock::new(1),
            indexes: RwLock::new(HashMap::new()),
            pk_index: RwLock::new(None),
        };

        {
            let mut indexes = table.indexes.write();

            // Create primary key index if defined
            if let Some(pk) = pk_col {
                let idx_name = format!("pk_{}", pk);
                indexes.insert(idx_name.clone(), TableIndex::new(&idx_name, &pk, true));
                *table.pk_index.write() = Some(idx_name);
            }

            // Create indexes for unique columns
            for col in unique_cols {
                let idx_name = format!("unique_{}", col);
                indexes.insert(idx_name.clone(), TableIndex::new(&idx_name, &col, true));
            }
        }

        table
    }

    /// Returns the table name.
    pub fn name(&self) -> &str {
        &self.def.name
    }

    /// Returns the current row count.
    pub fn count(&self) -> usize {
        self.rows.read().len()
    }

    /// Creates an index on a column, back-filling it from existing rows.
    ///
    /// Fails if an index with the same name exists, or — for a unique
    /// index — if the existing data already contains duplicate values.
    //
    // NOTE(review): this acquires `indexes` (write) and then `rows` (read),
    // while `update`/`delete` acquire `rows` before `indexes`. A concurrent
    // create_index and update/delete can deadlock on this lock-order
    // inversion — confirm DDL is externally serialized against DML.
    pub fn create_index(
        &self,
        name: impl Into<String>,
        column: impl Into<String>,
        unique: bool,
    ) -> Result<(), SqlError> {
        let name = name.into();
        let column = column.into();

        let mut indexes = self.indexes.write();
        if indexes.contains_key(&name) {
            return Err(SqlError::InvalidOperation(format!("Index '{}' already exists", name)));
        }

        let mut index = TableIndex::new(&name, &column, unique);

        // Index existing rows; a uniqueness failure here aborts before the
        // index is registered, so the table is left unchanged.
        let rows = self.rows.read();
        for (row_id, row) in rows.iter() {
            let value = row.get_or_null(&column);
            index.insert(value, *row_id)?;
        }

        indexes.insert(name, index);
        Ok(())
    }

    /// Drops an index by name.
    pub fn drop_index(&self, name: &str) -> Result<(), SqlError> {
        let mut indexes = self.indexes.write();
        if indexes.remove(name).is_none() {
            return Err(SqlError::InvalidOperation(format!("Index '{}' not found", name)));
        }
        Ok(())
    }

    /// Inserts a row, validating constraints and updating all indexes.
    ///
    /// Holding `next_id` for the whole call serializes concurrent inserts.
    //
    // NOTE(review): the row id is consumed even when validation fails
    // below, leaving gaps in the id sequence — harmless, but confirm it is
    // intended.
    pub fn insert(&self, values: HashMap<String, SqlValue>) -> Result<RowId, SqlError> {
        let mut next_id = self.next_id.write();
        let row_id = RowId(*next_id);
        *next_id += 1;

        let row = Row::from_map(row_id, values);

        // Validate row against column types / NOT NULL constraints.
        self.def.validate_row(&row)?;

        // Check uniqueness constraints via indexes
        {
            let mut indexes = self.indexes.write();
            for (_, index) in indexes.iter_mut() {
                if index.unique {
                    let value = row.get_or_null(&index.column);
                    // NULLs are deliberately exempt from the uniqueness
                    // pre-check (SQL UNIQUE semantics).
                    if !value.is_null() && !index.lookup(&value).is_empty() {
                        return Err(SqlError::ConstraintViolation(format!(
                            "Unique constraint violation on column '{}'",
                            index.column
                        )));
                    }
                }
            }

            // Update indexes
            // NOTE(review): if an index insert fails mid-loop (e.g. a
            // second NULL in a unique index, which the pre-check above
            // skips but TableIndex::insert may reject), earlier indexes
            // keep the new entry while the row is never stored — a partial
            // index update. Confirm this path is unreachable or add
            // rollback of already-updated indexes.
            for (_, index) in indexes.iter_mut() {
                let value = row.get_or_null(&index.column);
                index.insert(value, row_id)?;
            }
        }

        // Insert row (the indexes guard is dropped before taking `rows`,
        // so insert never holds both locks at once).
        self.rows.write().insert(row_id, row);

        Ok(row_id)
    }

    /// Gets a clone of the row with the given ID, if present.
    pub fn get(&self, id: RowId) -> Option<Row> {
        self.rows.read().get(&id).cloned()
    }

    /// Updates the named columns of a row, validating values and keeping
    /// indexes in sync.
    pub fn update(&self, id: RowId, updates: HashMap<String, SqlValue>) -> Result<(), SqlError> {
        let mut rows = self.rows.write();
        let row = rows.get_mut(&id).ok_or_else(|| {
            SqlError::InvalidOperation(format!("Row {} not found", id))
        })?;

        // Snapshot current values of the updated columns so the old index
        // entries can be removed below.
        let old_values: HashMap<String, SqlValue> = updates
            .keys()
            .map(|k| (k.clone(), row.get_or_null(k)))
            .collect();

        // Validate updates (columns not in the schema are silently allowed
        // through — no ColumnDef means no constraint to check).
        for (col, value) in &updates {
            if let Some(col_def) = self.def.get_column(col) {
                col_def.validate(value)?;
            }
        }

        // Update indexes
        // NOTE(review): the old entry is removed before the new insert; if
        // the insert then fails on a unique conflict, the index has lost
        // the old entry while the row still holds the old value — the
        // index and row data diverge. Confirm or re-insert the old value
        // on failure.
        {
            let mut indexes = self.indexes.write();
            for (_, index) in indexes.iter_mut() {
                if let Some(new_value) = updates.get(&index.column) {
                    let old_value = old_values.get(&index.column).cloned().unwrap_or(SqlValue::Null);
                    index.remove(&old_value, &id);
                    index.insert(new_value.clone(), id)?;
                }
            }
        }

        // Apply updates
        for (col, value) in updates {
            row.set(&col, value);
        }

        Ok(())
    }

    /// Deletes a row, removing its index entries; returns `false` if the
    /// row did not exist.
    pub fn delete(&self, id: RowId) -> Result<bool, SqlError> {
        let mut rows = self.rows.write();
        if let Some(row) = rows.remove(&id) {
            // Update indexes
            let mut indexes = self.indexes.write();
            for (_, index) in indexes.iter_mut() {
                let value = row.get_or_null(&index.column);
                index.remove(&value, &id);
            }
            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Scans all rows (cloned; unspecified order — HashMap iteration).
    pub fn scan(&self) -> Vec<Row> {
        self.rows.read().values().cloned().collect()
    }

    /// Scans rows matching a filter predicate (full table scan).
    pub fn scan_filter<F>(&self, predicate: F) -> Vec<Row>
    where
        F: Fn(&Row) -> bool,
    {
        self.rows
            .read()
            .values()
            .filter(|row| predicate(row))
            .cloned()
            .collect()
    }

    /// Looks up rows by exact value through a named index; returns an
    /// empty vec when the index does not exist.
    pub fn lookup_index(&self, index_name: &str, value: &SqlValue) -> Vec<Row> {
        let indexes = self.indexes.read();
        let rows = self.rows.read();

        if let Some(index) = indexes.get(index_name) {
            index
                .lookup(value)
                .into_iter()
                .filter_map(|id| rows.get(&id).cloned())
                .collect()
        } else {
            Vec::new()
        }
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the standard "users" table shared by every test below.
    fn create_test_table() -> Table {
        let def = TableDef::new("users")
            .column(ColumnDef::new("id", SqlType::Integer).primary_key())
            .column(ColumnDef::new("name", SqlType::Text).not_null())
            .column(ColumnDef::new("age", SqlType::Integer))
            .column(ColumnDef::new("email", SqlType::Text).unique());
        Table::new(def)
    }

    /// Builds a column->value map from (name, value) pairs.
    fn values(pairs: Vec<(&str, SqlValue)>) -> HashMap<String, SqlValue> {
        pairs.into_iter().map(|(k, v)| (k.to_string(), v)).collect()
    }

    #[test]
    fn test_table_insert() {
        let table = create_test_table();
        let row_id = table
            .insert(values(vec![
                ("id", SqlValue::Integer(1)),
                ("name", SqlValue::Text("Alice".to_string())),
                ("age", SqlValue::Integer(30)),
                ("email", SqlValue::Text("alice@example.com".to_string())),
            ]))
            .unwrap();

        assert_eq!(table.count(), 1);
        let row = table.get(row_id).unwrap();
        assert_eq!(row.get("name"), Some(&SqlValue::Text("Alice".to_string())));
    }

    #[test]
    fn test_table_not_null() {
        let table = create_test_table();
        // Missing required "name" field must be rejected.
        let result = table.insert(values(vec![("id", SqlValue::Integer(1))]));
        assert!(result.is_err());
    }

    #[test]
    fn test_table_unique() {
        let table = create_test_table();
        table
            .insert(values(vec![
                ("id", SqlValue::Integer(1)),
                ("name", SqlValue::Text("Alice".to_string())),
                ("email", SqlValue::Text("test@example.com".to_string())),
            ]))
            .unwrap();

        // Duplicate email must violate the unique constraint.
        let result = table.insert(values(vec![
            ("id", SqlValue::Integer(2)),
            ("name", SqlValue::Text("Bob".to_string())),
            ("email", SqlValue::Text("test@example.com".to_string())),
        ]));
        assert!(result.is_err());
    }

    #[test]
    fn test_table_update() {
        let table = create_test_table();
        let row_id = table
            .insert(values(vec![
                ("id", SqlValue::Integer(1)),
                ("name", SqlValue::Text("Alice".to_string())),
                ("age", SqlValue::Integer(30)),
            ]))
            .unwrap();

        table
            .update(row_id, values(vec![("age", SqlValue::Integer(31))]))
            .unwrap();

        let row = table.get(row_id).unwrap();
        assert_eq!(row.get("age"), Some(&SqlValue::Integer(31)));
    }

    #[test]
    fn test_table_delete() {
        let table = create_test_table();
        let row_id = table
            .insert(values(vec![
                ("id", SqlValue::Integer(1)),
                ("name", SqlValue::Text("Alice".to_string())),
            ]))
            .unwrap();
        assert_eq!(table.count(), 1);

        table.delete(row_id).unwrap();
        assert_eq!(table.count(), 0);
    }

    #[test]
    fn test_table_index() {
        let table = create_test_table();
        table.create_index("idx_name", "name", false).unwrap();

        table
            .insert(values(vec![
                ("id", SqlValue::Integer(1)),
                ("name", SqlValue::Text("Alice".to_string())),
            ]))
            .unwrap();

        let rows = table.lookup_index("idx_name", &SqlValue::Text("Alice".to_string()));
        assert_eq!(rows.len(), 1);
    }
}
|
||||
355
crates/synor-database/src/sql/transaction.rs
Normal file
355
crates/synor-database/src/sql/transaction.rs
Normal file
|
|
@ -0,0 +1,355 @@
|
|||
//! ACID transaction support for SQL.
|
||||
|
||||
use super::row::RowId;
|
||||
use super::types::{SqlError, SqlValue};
|
||||
use parking_lot::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
/// Transaction identifier.
///
/// IDs come from a process-local atomic counter starting at 1, so they are
/// unique within a process but reset on restart and are not coordinated
/// across nodes.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TransactionId(pub u64);

impl TransactionId {
    /// Creates a new, process-unique transaction ID.
    pub fn new() -> Self {
        static COUNTER: AtomicU64 = AtomicU64::new(1);
        TransactionId(COUNTER.fetch_add(1, Ordering::SeqCst))
    }
}

impl Default for TransactionId {
    fn default() -> Self {
        Self::new()
    }
}

impl std::fmt::Display for TransactionId {
    // Renders as "txn_<n>", used in error messages below.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "txn_{}", self.0)
    }
}
|
||||
|
||||
/// Lifecycle state of a transaction.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TransactionState {
    /// Transaction is active and may still record operations.
    Active,
    /// Transaction was committed.
    Committed,
    /// Transaction was rolled back.
    RolledBack,
}
|
||||
|
||||
/// A single operation in a transaction.
///
/// `old_values` snapshots are captured so each operation can be reversed;
/// the actual undo is performed by whoever consumes the ops returned from
/// `TransactionManager::rollback`.
#[derive(Clone, Debug)]
pub enum TransactionOp {
    /// Insert a row.
    Insert {
        table: String,
        row_id: RowId,
        values: HashMap<String, SqlValue>,
    },
    /// Update a row; `old_values` holds the pre-update state of the
    /// touched columns.
    Update {
        table: String,
        row_id: RowId,
        old_values: HashMap<String, SqlValue>,
        new_values: HashMap<String, SqlValue>,
    },
    /// Delete a row; `old_values` holds the deleted row's data.
    Delete {
        table: String,
        row_id: RowId,
        old_values: HashMap<String, SqlValue>,
    },
}
|
||||
|
||||
/// Transaction for tracking changes.
#[derive(Debug)]
pub struct Transaction {
    /// Transaction ID.
    pub id: TransactionId,
    /// Current lifecycle state.
    pub state: TransactionState,
    /// Operations recorded so far, in execution order.
    operations: Vec<TransactionOp>,
    /// Start time as Unix milliseconds.
    pub started_at: u64,
    /// Isolation level requested at `begin`.
    // NOTE(review): the level is stored but no enforcement is visible in
    // this module — confirm where (or whether) isolation is applied.
    pub isolation: IsolationLevel,
}
|
||||
|
||||
/// Transaction isolation levels (standard SQL ladder).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum IsolationLevel {
    /// Read uncommitted (dirty reads allowed).
    ReadUncommitted,
    /// Read committed (no dirty reads).
    ReadCommitted,
    /// Repeatable read (no non-repeatable reads).
    RepeatableRead,
    /// Serializable (full isolation).
    Serializable,
}

impl Default for IsolationLevel {
    // Read committed is the default, matching most SQL databases.
    fn default() -> Self {
        IsolationLevel::ReadCommitted
    }
}
|
||||
|
||||
impl Transaction {
|
||||
/// Creates a new transaction.
|
||||
pub fn new(isolation: IsolationLevel) -> Self {
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_millis() as u64;
|
||||
|
||||
Self {
|
||||
id: TransactionId::new(),
|
||||
state: TransactionState::Active,
|
||||
operations: Vec::new(),
|
||||
started_at: now,
|
||||
isolation,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if the transaction is active.
|
||||
pub fn is_active(&self) -> bool {
|
||||
self.state == TransactionState::Active
|
||||
}
|
||||
|
||||
/// Records an insert operation.
|
||||
pub fn record_insert(&mut self, table: String, row_id: RowId, values: HashMap<String, SqlValue>) {
|
||||
self.operations.push(TransactionOp::Insert {
|
||||
table,
|
||||
row_id,
|
||||
values,
|
||||
});
|
||||
}
|
||||
|
||||
/// Records an update operation.
|
||||
pub fn record_update(
|
||||
&mut self,
|
||||
table: String,
|
||||
row_id: RowId,
|
||||
old_values: HashMap<String, SqlValue>,
|
||||
new_values: HashMap<String, SqlValue>,
|
||||
) {
|
||||
self.operations.push(TransactionOp::Update {
|
||||
table,
|
||||
row_id,
|
||||
old_values,
|
||||
new_values,
|
||||
});
|
||||
}
|
||||
|
||||
/// Records a delete operation.
|
||||
pub fn record_delete(&mut self, table: String, row_id: RowId, old_values: HashMap<String, SqlValue>) {
|
||||
self.operations.push(TransactionOp::Delete {
|
||||
table,
|
||||
row_id,
|
||||
old_values,
|
||||
});
|
||||
}
|
||||
|
||||
/// Returns operations for rollback (in reverse order).
|
||||
pub fn rollback_ops(&self) -> impl Iterator<Item = &TransactionOp> {
|
||||
self.operations.iter().rev()
|
||||
}
|
||||
|
||||
/// Returns operations for commit.
|
||||
pub fn commit_ops(&self) -> &[TransactionOp] {
|
||||
&self.operations
|
||||
}
|
||||
|
||||
/// Marks the transaction as committed.
|
||||
pub fn mark_committed(&mut self) {
|
||||
self.state = TransactionState::Committed;
|
||||
}
|
||||
|
||||
/// Marks the transaction as rolled back.
|
||||
pub fn mark_rolled_back(&mut self) {
|
||||
self.state = TransactionState::RolledBack;
|
||||
}
|
||||
}
|
||||
|
||||
/// Transaction manager.
///
/// Tracks in-flight transactions; committed and rolled-back transactions
/// are removed from the map, so every entry here is active.
pub struct TransactionManager {
    /// Active transactions, keyed by ID.
    transactions: RwLock<HashMap<TransactionId, Transaction>>,
}
|
||||
|
||||
impl TransactionManager {
|
||||
/// Creates a new transaction manager.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
transactions: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Begins a new transaction.
|
||||
pub fn begin(&self, isolation: IsolationLevel) -> TransactionId {
|
||||
let txn = Transaction::new(isolation);
|
||||
let id = txn.id;
|
||||
self.transactions.write().insert(id, txn);
|
||||
id
|
||||
}
|
||||
|
||||
/// Gets a transaction by ID.
|
||||
pub fn get(&self, id: TransactionId) -> Option<Transaction> {
|
||||
self.transactions.read().get(&id).cloned()
|
||||
}
|
||||
|
||||
/// Records an operation in a transaction.
|
||||
pub fn record_op(&self, id: TransactionId, op: TransactionOp) -> Result<(), SqlError> {
|
||||
let mut txns = self.transactions.write();
|
||||
let txn = txns
|
||||
.get_mut(&id)
|
||||
.ok_or_else(|| SqlError::Transaction(format!("Transaction {} not found", id)))?;
|
||||
|
||||
if !txn.is_active() {
|
||||
return Err(SqlError::Transaction(format!("Transaction {} is not active", id)));
|
||||
}
|
||||
|
||||
txn.operations.push(op);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Commits a transaction.
|
||||
pub fn commit(&self, id: TransactionId) -> Result<Vec<TransactionOp>, SqlError> {
|
||||
let mut txns = self.transactions.write();
|
||||
let txn = txns
|
||||
.get_mut(&id)
|
||||
.ok_or_else(|| SqlError::Transaction(format!("Transaction {} not found", id)))?;
|
||||
|
||||
if !txn.is_active() {
|
||||
return Err(SqlError::Transaction(format!("Transaction {} is not active", id)));
|
||||
}
|
||||
|
||||
txn.mark_committed();
|
||||
let ops = txn.operations.clone();
|
||||
txns.remove(&id);
|
||||
Ok(ops)
|
||||
}
|
||||
|
||||
/// Rolls back a transaction, returning operations to undo.
|
||||
pub fn rollback(&self, id: TransactionId) -> Result<Vec<TransactionOp>, SqlError> {
|
||||
let mut txns = self.transactions.write();
|
||||
let txn = txns
|
||||
.get_mut(&id)
|
||||
.ok_or_else(|| SqlError::Transaction(format!("Transaction {} not found", id)))?;
|
||||
|
||||
if !txn.is_active() {
|
||||
return Err(SqlError::Transaction(format!("Transaction {} is not active", id)));
|
||||
}
|
||||
|
||||
txn.mark_rolled_back();
|
||||
let ops: Vec<TransactionOp> = txn.operations.iter().rev().cloned().collect();
|
||||
txns.remove(&id);
|
||||
Ok(ops)
|
||||
}
|
||||
|
||||
/// Returns the number of active transactions.
|
||||
pub fn active_count(&self) -> usize {
|
||||
self.transactions.read().len()
|
||||
}
|
||||
|
||||
/// Checks if a transaction exists and is active.
|
||||
pub fn is_active(&self, id: TransactionId) -> bool {
|
||||
self.transactions
|
||||
.read()
|
||||
.get(&id)
|
||||
.map(|t| t.is_active())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TransactionManager {
    /// Equivalent to [`TransactionManager::new`].
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
// Manual field-by-field clone; `Transaction` itself only derives `Debug`.
// NOTE(review): every field here is `Clone`, so `#[derive(Clone)]` on the
// struct would be equivalent — confirm and simplify.
impl Clone for Transaction {
    fn clone(&self) -> Self {
        Self {
            id: self.id,
            state: self.state,
            operations: self.operations.clone(),
            started_at: self.started_at,
            isolation: self.isolation,
        }
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a simple insert op on the "users" table for the given name.
    fn insert_op(name: &str) -> TransactionOp {
        let mut values = HashMap::new();
        values.insert("name".to_string(), SqlValue::Text(name.to_string()));
        TransactionOp::Insert {
            table: "users".to_string(),
            row_id: RowId(1),
            values,
        }
    }

    #[test]
    fn test_transaction_lifecycle() {
        let manager = TransactionManager::new();

        let txn_id = manager.begin(IsolationLevel::ReadCommitted);
        assert!(manager.is_active(txn_id));
        assert_eq!(manager.active_count(), 1);

        manager.record_op(txn_id, insert_op("Alice")).unwrap();

        let ops = manager.commit(txn_id).unwrap();
        assert_eq!(ops.len(), 1);
        assert_eq!(manager.active_count(), 0);
    }

    #[test]
    fn test_transaction_rollback() {
        let manager = TransactionManager::new();

        let txn_id = manager.begin(IsolationLevel::ReadCommitted);
        manager.record_op(txn_id, insert_op("Bob")).unwrap();

        let ops = manager.rollback(txn_id).unwrap();
        assert_eq!(ops.len(), 1);
        assert_eq!(manager.active_count(), 0);
    }

    #[test]
    fn test_transaction_not_found() {
        let manager = TransactionManager::new();
        let fake_id = TransactionId(99999);

        assert!(!manager.is_active(fake_id));
        assert!(manager.commit(fake_id).is_err());
        assert!(manager.rollback(fake_id).is_err());
    }
}
|
||||
368
crates/synor-database/src/sql/types.rs
Normal file
368
crates/synor-database/src/sql/types.rs
Normal file
|
|
@ -0,0 +1,368 @@
|
|||
//! SQL type system.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value as JsonValue;
|
||||
use std::cmp::Ordering;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use thiserror::Error;
|
||||
|
||||
/// SQL data types.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum SqlType {
    /// 64-bit signed integer.
    Integer,
    /// 64-bit floating point.
    Real,
    /// UTF-8 text string.
    Text,
    /// Binary data.
    Blob,
    /// Boolean value.
    Boolean,
    /// Unix timestamp in milliseconds.
    Timestamp,
    /// Null type (the type of `SqlValue::Null`).
    Null,
}
|
||||
|
||||
impl SqlType {
|
||||
/// Parses type from string.
|
||||
pub fn from_str(s: &str) -> Option<Self> {
|
||||
match s.to_uppercase().as_str() {
|
||||
"INTEGER" | "INT" | "BIGINT" | "SMALLINT" => Some(SqlType::Integer),
|
||||
"REAL" | "FLOAT" | "DOUBLE" => Some(SqlType::Real),
|
||||
"TEXT" | "VARCHAR" | "CHAR" | "STRING" => Some(SqlType::Text),
|
||||
"BLOB" | "BINARY" | "BYTES" => Some(SqlType::Blob),
|
||||
"BOOLEAN" | "BOOL" => Some(SqlType::Boolean),
|
||||
"TIMESTAMP" | "DATETIME" | "DATE" => Some(SqlType::Timestamp),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the default value for this type.
|
||||
pub fn default_value(&self) -> SqlValue {
|
||||
match self {
|
||||
SqlType::Integer => SqlValue::Integer(0),
|
||||
SqlType::Real => SqlValue::Real(0.0),
|
||||
SqlType::Text => SqlValue::Text(String::new()),
|
||||
SqlType::Blob => SqlValue::Blob(Vec::new()),
|
||||
SqlType::Boolean => SqlValue::Boolean(false),
|
||||
SqlType::Timestamp => SqlValue::Timestamp(0),
|
||||
SqlType::Null => SqlValue::Null,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// SQL value types.
///
/// Equality, ordering, and hashing are implemented manually below; see
/// the review notes on those impls for their cross-type subtleties.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum SqlValue {
    /// Null value.
    Null,
    /// Integer value.
    Integer(i64),
    /// Real (floating point) value.
    Real(f64),
    /// Text string value.
    Text(String),
    /// Binary blob value.
    Blob(Vec<u8>),
    /// Boolean value.
    Boolean(bool),
    /// Timestamp value (Unix ms).
    Timestamp(u64),
}
|
||||
|
||||
impl SqlValue {
|
||||
/// Returns the SQL type of this value.
|
||||
pub fn sql_type(&self) -> SqlType {
|
||||
match self {
|
||||
SqlValue::Null => SqlType::Null,
|
||||
SqlValue::Integer(_) => SqlType::Integer,
|
||||
SqlValue::Real(_) => SqlType::Real,
|
||||
SqlValue::Text(_) => SqlType::Text,
|
||||
SqlValue::Blob(_) => SqlType::Blob,
|
||||
SqlValue::Boolean(_) => SqlType::Boolean,
|
||||
SqlValue::Timestamp(_) => SqlType::Timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if this is a null value.
|
||||
pub fn is_null(&self) -> bool {
|
||||
matches!(self, SqlValue::Null)
|
||||
}
|
||||
|
||||
/// Converts to integer if possible.
|
||||
pub fn as_integer(&self) -> Option<i64> {
|
||||
match self {
|
||||
SqlValue::Integer(i) => Some(*i),
|
||||
SqlValue::Real(f) => Some(*f as i64),
|
||||
SqlValue::Boolean(b) => Some(if *b { 1 } else { 0 }),
|
||||
SqlValue::Timestamp(t) => Some(*t as i64),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts to real if possible.
|
||||
pub fn as_real(&self) -> Option<f64> {
|
||||
match self {
|
||||
SqlValue::Integer(i) => Some(*i as f64),
|
||||
SqlValue::Real(f) => Some(*f),
|
||||
SqlValue::Timestamp(t) => Some(*t as f64),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts to text if possible.
|
||||
pub fn as_text(&self) -> Option<&str> {
|
||||
match self {
|
||||
SqlValue::Text(s) => Some(s),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts to boolean if possible.
|
||||
pub fn as_boolean(&self) -> Option<bool> {
|
||||
match self {
|
||||
SqlValue::Boolean(b) => Some(*b),
|
||||
SqlValue::Integer(i) => Some(*i != 0),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts to JSON value.
|
||||
pub fn to_json(&self) -> JsonValue {
|
||||
match self {
|
||||
SqlValue::Null => JsonValue::Null,
|
||||
SqlValue::Integer(i) => JsonValue::Number((*i).into()),
|
||||
SqlValue::Real(f) => serde_json::Number::from_f64(*f)
|
||||
.map(JsonValue::Number)
|
||||
.unwrap_or(JsonValue::Null),
|
||||
SqlValue::Text(s) => JsonValue::String(s.clone()),
|
||||
SqlValue::Blob(b) => JsonValue::String(hex::encode(b)),
|
||||
SqlValue::Boolean(b) => JsonValue::Bool(*b),
|
||||
SqlValue::Timestamp(t) => JsonValue::Number((*t).into()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a numeric type order for comparison purposes.
|
||||
fn type_order(&self) -> u8 {
|
||||
match self {
|
||||
SqlValue::Null => 0,
|
||||
SqlValue::Boolean(_) => 1,
|
||||
SqlValue::Integer(_) => 2,
|
||||
SqlValue::Real(_) => 3,
|
||||
SqlValue::Text(_) => 4,
|
||||
SqlValue::Blob(_) => 5,
|
||||
SqlValue::Timestamp(_) => 6,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates from JSON value.
|
||||
pub fn from_json(value: &JsonValue) -> Self {
|
||||
match value {
|
||||
JsonValue::Null => SqlValue::Null,
|
||||
JsonValue::Bool(b) => SqlValue::Boolean(*b),
|
||||
JsonValue::Number(n) => {
|
||||
if let Some(i) = n.as_i64() {
|
||||
SqlValue::Integer(i)
|
||||
} else if let Some(f) = n.as_f64() {
|
||||
SqlValue::Real(f)
|
||||
} else {
|
||||
SqlValue::Null
|
||||
}
|
||||
}
|
||||
JsonValue::String(s) => SqlValue::Text(s.clone()),
|
||||
_ => SqlValue::Text(value.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for SqlValue {
    /// Value equality with numeric coercion: `Integer` and `Real` compare
    /// equal when numerically equal.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            // Unlike SQL (where NULL = NULL is unknown), NULL == NULL here.
            (SqlValue::Null, SqlValue::Null) => true,
            (SqlValue::Integer(a), SqlValue::Integer(b)) => a == b,
            // NOTE(review): f64 `==` makes NaN != NaN, so the `Eq` marker
            // below is not strictly reflexive if a Real ever holds NaN —
            // confirm NaN cannot enter the system, or compare via to_bits
            // as the `Ord` impl does.
            (SqlValue::Real(a), SqlValue::Real(b)) => a == b,
            (SqlValue::Text(a), SqlValue::Text(b)) => a == b,
            (SqlValue::Blob(a), SqlValue::Blob(b)) => a == b,
            (SqlValue::Boolean(a), SqlValue::Boolean(b)) => a == b,
            (SqlValue::Timestamp(a), SqlValue::Timestamp(b)) => a == b,
            // Cross-type comparisons
            // NOTE(review): Integer(1) == Real(1.0), but the Hash impl
            // hashes the discriminant, so equal values can hash
            // differently — this violates the Eq/Hash contract for
            // hash-map keys. Confirm mixed-type keys never share a map.
            (SqlValue::Integer(a), SqlValue::Real(b)) => (*a as f64) == *b,
            (SqlValue::Real(a), SqlValue::Integer(b)) => *a == (*b as f64),
            _ => false,
        }
    }
}

// Marker impl enabling use as BTreeMap/HashSet keys; see the NaN caveat
// on `eq` above.
impl Eq for SqlValue {}
|
||||
|
||||
impl Hash for SqlValue {
    /// Hashes the variant discriminant plus the payload; `Real` uses its
    /// bit pattern so NaN and -0.0 hash deterministically.
    //
    // NOTE(review): because the discriminant is hashed, Integer(1) and
    // Real(1.0) hash differently even though `eq` treats them as equal —
    // inconsistent with the PartialEq impl above.
    fn hash<H: Hasher>(&self, state: &mut H) {
        std::mem::discriminant(self).hash(state);
        match self {
            SqlValue::Null => {}
            SqlValue::Integer(i) => i.hash(state),
            SqlValue::Real(f) => f.to_bits().hash(state),
            SqlValue::Text(s) => s.hash(state),
            SqlValue::Blob(b) => b.hash(state),
            SqlValue::Boolean(b) => b.hash(state),
            SqlValue::Timestamp(t) => t.hash(state),
        }
    }
}
|
||||
|
||||
impl PartialOrd for SqlValue {
    /// Delegates to the total order defined by `Ord` — the canonical pattern
    /// for types implementing both traits (keeps the two orders consistent).
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
|
||||
|
||||
impl Ord for SqlValue {
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
match (self, other) {
|
||||
(SqlValue::Null, SqlValue::Null) => Ordering::Equal,
|
||||
(SqlValue::Null, _) => Ordering::Less,
|
||||
(_, SqlValue::Null) => Ordering::Greater,
|
||||
(SqlValue::Integer(a), SqlValue::Integer(b)) => a.cmp(b),
|
||||
(SqlValue::Real(a), SqlValue::Real(b)) => {
|
||||
// Convert to bits for total ordering (handles NaN)
|
||||
a.to_bits().cmp(&b.to_bits())
|
||||
}
|
||||
(SqlValue::Text(a), SqlValue::Text(b)) => a.cmp(b),
|
||||
(SqlValue::Blob(a), SqlValue::Blob(b)) => a.cmp(b),
|
||||
(SqlValue::Boolean(a), SqlValue::Boolean(b)) => a.cmp(b),
|
||||
(SqlValue::Timestamp(a), SqlValue::Timestamp(b)) => a.cmp(b),
|
||||
(SqlValue::Integer(a), SqlValue::Real(b)) => {
|
||||
(*a as f64).to_bits().cmp(&b.to_bits())
|
||||
}
|
||||
(SqlValue::Real(a), SqlValue::Integer(b)) => {
|
||||
a.to_bits().cmp(&(*b as f64).to_bits())
|
||||
}
|
||||
// Different types: order by type discriminant
|
||||
_ => self.type_order().cmp(&other.type_order()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SqlValue {
|
||||
fn default() -> Self {
|
||||
SqlValue::Null
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for SqlValue {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
SqlValue::Null => write!(f, "NULL"),
|
||||
SqlValue::Integer(i) => write!(f, "{}", i),
|
||||
SqlValue::Real(r) => write!(f, "{}", r),
|
||||
SqlValue::Text(s) => write!(f, "'{}'", s),
|
||||
SqlValue::Blob(b) => write!(f, "X'{}'", hex::encode(b)),
|
||||
SqlValue::Boolean(b) => write!(f, "{}", if *b { "TRUE" } else { "FALSE" }),
|
||||
SqlValue::Timestamp(t) => write!(f, "{}", t),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors produced by the SQL subsystem: parsing, planning, execution,
/// constraint checking, and transactions. User-facing messages come from the
/// `thiserror` `#[error]` attributes on each variant.
#[derive(Debug, Error)]
pub enum SqlError {
    /// The statement text could not be parsed into a valid AST.
    #[error("Parse error: {0}")]
    Parse(String),

    /// The referenced table does not exist.
    #[error("Table not found: {0}")]
    TableNotFound(String),

    /// A table with this name already exists (e.g. CREATE TABLE collision).
    #[error("Table already exists: {0}")]
    TableExists(String),

    /// The referenced column does not exist in the target table.
    #[error("Column not found: {0}")]
    ColumnNotFound(String),

    /// A value's type did not match the column or expression type.
    #[error("Type mismatch: expected {expected}, got {got}")]
    TypeMismatch { expected: String, got: String },

    /// A declared constraint (e.g. UNIQUE) was violated.
    #[error("Constraint violation: {0}")]
    ConstraintViolation(String),

    /// An insert/update would duplicate an existing primary key.
    #[error("Primary key violation: duplicate key {0}")]
    PrimaryKeyViolation(String),

    /// NULL was written to a column declared NOT NULL.
    #[error("Not null violation: column {0} cannot be null")]
    NotNullViolation(String),

    /// A transaction could not be started, committed, or rolled back.
    #[error("Transaction error: {0}")]
    Transaction(String),

    /// The operation is invalid in the current state.
    #[error("Invalid operation: {0}")]
    InvalidOperation(String),

    /// The statement uses a SQL feature outside the supported subset.
    #[error("Unsupported: {0}")]
    Unsupported(String),

    /// Unexpected internal failure (indicates a bug, not user error).
    #[error("Internal error: {0}")]
    Internal(String),
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// `SqlType::from_str` is case-insensitive and maps common SQL aliases.
    #[test]
    fn test_sql_type_from_str() {
        for (name, expected) in [
            ("INTEGER", Some(SqlType::Integer)),
            ("int", Some(SqlType::Integer)),
            ("TEXT", Some(SqlType::Text)),
            ("varchar", Some(SqlType::Text)),
            ("BOOLEAN", Some(SqlType::Boolean)),
            ("unknown", None),
        ] {
            assert_eq!(SqlType::from_str(name), expected);
        }
    }

    /// Numeric accessors coerce between Integer and Real.
    #[test]
    fn test_sql_value_conversions() {
        let integer = SqlValue::Integer(42);
        assert_eq!(integer.as_integer(), Some(42));
        assert_eq!(integer.as_real(), Some(42.0));

        let real = SqlValue::Real(3.14);
        assert_eq!(real.as_real(), Some(3.14));
        // 3.14 converts to 3 — the fractional part is dropped.
        assert_eq!(real.as_integer(), Some(3));

        let text = SqlValue::Text("hello".to_string());
        assert_eq!(text.as_text(), Some("hello"));
    }

    /// Equality and ordering over same-type values; NULL sorts first.
    #[test]
    fn test_sql_value_comparison() {
        assert_eq!(SqlValue::Integer(5), SqlValue::Integer(5));
        assert!(SqlValue::Integer(5) < SqlValue::Integer(10));
        assert!(SqlValue::Text("a".to_string()) < SqlValue::Text("b".to_string()));
        // NULL is less than every non-NULL value.
        assert!(SqlValue::Null < SqlValue::Integer(0));
    }

    /// Round-tripping through serde_json values.
    #[test]
    fn test_sql_value_json() {
        assert_eq!(SqlValue::Integer(42).to_json(), serde_json::json!(42));
        assert_eq!(
            SqlValue::Text("hello".to_string()).to_json(),
            serde_json::json!("hello")
        );
        assert_eq!(
            SqlValue::from_json(&serde_json::json!(true)),
            SqlValue::Boolean(true)
        );
    }
}
|
||||
505
docs/PLAN/PHASE10-Database-Advanced.md
Normal file
505
docs/PLAN/PHASE10-Database-Advanced.md
Normal file
|
|
@ -0,0 +1,505 @@
|
|||
# Phase 10: Advanced Database Features
|
||||
|
||||
> Implementation plan for SQL, Graph, and Replication features in Synor Database L2
|
||||
|
||||
## Overview
|
||||
|
||||
These advanced features extend the Synor Database to support:
|
||||
1. **Relational (SQL)** - SQLite-compatible query subset for structured data
|
||||
2. **Graph Store** - Relationship queries for connected data
|
||||
3. **Replication** - Raft consensus for high availability
|
||||
|
||||
## Feature 1: Relational (SQL) Store
|
||||
|
||||
### Purpose
|
||||
Provide a familiar SQL interface for developers who need structured relational queries, joins, and ACID transactions.
|
||||
|
||||
### Architecture
|
||||
|
||||
```text
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ SQL QUERY LAYER │
|
||||
├─────────────────────────────────────────────────────────────┤
|
||||
│ │
|
||||
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────────┐ │
|
||||
│ │ SQL Parser │ │ Planner │ │ Executor │ │
|
||||
│ │ (sqlparser) │ │ (logical) │ │ (physical) │ │
|
||||
│ └──────────────┘ └──────────────┘ └──────────────────┘ │
|
||||
│ │
|
||||
│ ┌──────────────────────────────────────────────────────┐ │
|
||||
│ │ Table Storage Engine │ │
|
||||
│ │ - Row-oriented storage │ │
|
||||
│ │ - B-tree indexes │ │
|
||||
│ │ - Transaction log (WAL) │ │
|
||||
│ └──────────────────────────────────────────────────────┘ │
|
||||
│ │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Supported SQL Subset
|
||||
|
||||
| Category | Statements |
|
||||
|----------|------------|
|
||||
| DDL | CREATE TABLE, DROP TABLE, ALTER TABLE |
|
||||
| DML | SELECT, INSERT, UPDATE, DELETE |
|
||||
| Clauses | WHERE, ORDER BY, LIMIT, OFFSET, GROUP BY, HAVING |
|
||||
| Joins | INNER JOIN, LEFT JOIN, RIGHT JOIN |
|
||||
| Functions | COUNT, SUM, AVG, MIN, MAX, COALESCE |
|
||||
| Operators | =, !=, <, >, <=, >=, AND, OR, NOT, IN, LIKE |
|
||||
|
||||
### Data Types
|
||||
|
||||
| SQL Type | Rust Type | Storage |
|
||||
|----------|-----------|---------|
|
||||
| INTEGER | i64 | 8 bytes |
|
||||
| REAL | f64 | 8 bytes |
|
||||
| TEXT | String | Variable |
|
||||
| BLOB | Vec<u8> | Variable |
|
||||
| BOOLEAN | bool | 1 byte |
|
||||
| TIMESTAMP | u64 | 8 bytes (Unix ms) |
|
||||
|
||||
### Implementation Components
|
||||
|
||||
```
|
||||
crates/synor-database/src/sql/
|
||||
├── mod.rs # Module exports
|
||||
├── parser.rs # SQL parsing (sqlparser-rs)
|
||||
├── planner.rs # Query planning & optimization
|
||||
├── executor.rs # Query execution engine
|
||||
├── table.rs # Table definition & storage
|
||||
├── row.rs # Row representation
|
||||
├── types.rs # SQL type system
|
||||
├── transaction.rs # ACID transactions
|
||||
└── index.rs # SQL-specific indexing
|
||||
```
|
||||
|
||||
### API Design
|
||||
|
||||
```rust
|
||||
// Table definition
|
||||
pub struct TableDef {
|
||||
pub name: String,
|
||||
pub columns: Vec<ColumnDef>,
|
||||
pub primary_key: Option<String>,
|
||||
pub indexes: Vec<IndexDef>,
|
||||
}
|
||||
|
||||
pub struct ColumnDef {
|
||||
pub name: String,
|
||||
pub data_type: SqlType,
|
||||
pub nullable: bool,
|
||||
pub default: Option<SqlValue>,
|
||||
}
|
||||
|
||||
// SQL execution
|
||||
pub struct SqlEngine {
|
||||
tables: HashMap<String, Table>,
|
||||
transaction_log: TransactionLog,
|
||||
}
|
||||
|
||||
impl SqlEngine {
|
||||
pub fn execute(&mut self, sql: &str) -> Result<SqlResult, SqlError>;
|
||||
pub fn begin_transaction(&mut self) -> TransactionId;
|
||||
pub fn commit(&mut self, txn: TransactionId) -> Result<(), SqlError>;
|
||||
pub fn rollback(&mut self, txn: TransactionId);
|
||||
}
|
||||
```
|
||||
|
||||
### Gateway Endpoints
|
||||
|
||||
| Endpoint | Method | Description |
|
||||
|----------|--------|-------------|
|
||||
| `/db/:db/sql` | POST | Execute SQL query |
|
||||
| `/db/:db/sql/tables` | GET | List tables |
|
||||
| `/db/:db/sql/tables/:table` | GET | Get table schema |
|
||||
| `/db/:db/sql/tables/:table` | DELETE | Drop table |
|
||||
|
||||
---
|
||||
|
||||
## Feature 2: Graph Store
|
||||
|
||||
### Purpose
|
||||
Enable relationship-based queries for social networks, knowledge graphs, recommendation engines, and any connected data.
|
||||
|
||||
### Architecture
|
||||
|
||||
```text
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ GRAPH QUERY LAYER │
|
||||
├─────────────────────────────────────────────────────────────┤
|
||||
│ │
|
||||
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────────┐ │
|
||||
│ │ Graph Query │ │ Traversal │ │ Path Finding │ │
|
||||
│ │ Parser │ │ Engine │ │ (Dijkstra) │ │
|
||||
│ └──────────────┘ └──────────────┘ └──────────────────┘ │
|
||||
│ │
|
||||
│ ┌──────────────────────────────────────────────────────┐ │
|
||||
│ │ Graph Storage Engine │ │
|
||||
│ │ - Adjacency list storage │ │
|
||||
│ │ - Edge index (source, target, type) │ │
|
||||
│ │ - Property storage │ │
|
||||
│ └──────────────────────────────────────────────────────┘ │
|
||||
│ │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Data Model
|
||||
|
||||
```text
|
||||
Node (Vertex):
|
||||
- id: NodeId (32 bytes)
|
||||
- labels: Vec<String>
|
||||
- properties: JsonValue
|
||||
|
||||
Edge (Relationship):
|
||||
- id: EdgeId (32 bytes)
|
||||
- source: NodeId
|
||||
- target: NodeId
|
||||
- edge_type: String
|
||||
- properties: JsonValue
|
||||
- directed: bool
|
||||
```
|
||||
|
||||
### Query Language (Simplified Cypher-like)
|
||||
|
||||
```
|
||||
// Find all friends of user Alice
|
||||
MATCH (a:User {name: "Alice"})-[:FRIEND]->(friend)
|
||||
RETURN friend
|
||||
|
||||
// Find shortest path between two nodes
|
||||
MATCH path = shortestPath((a:User {id: "123"})-[*]-(b:User {id: "456"}))
|
||||
RETURN path
|
||||
|
||||
// Find mutual friends
|
||||
MATCH (a:User {name: "Alice"})-[:FRIEND]->(mutual)<-[:FRIEND]-(b:User {name: "Bob"})
|
||||
RETURN mutual
|
||||
```
|
||||
|
||||
### Implementation Components
|
||||
|
||||
```
|
||||
crates/synor-database/src/graph/
|
||||
├── mod.rs # Module exports
|
||||
├── node.rs # Node definition & storage
|
||||
├── edge.rs # Edge definition & storage
|
||||
├── store.rs # Graph storage engine
|
||||
├── query.rs # Query language parser
|
||||
├── traversal.rs # Graph traversal algorithms
|
||||
├── path.rs # Path finding (BFS, DFS, Dijkstra)
|
||||
└── index.rs # Graph-specific indexes
|
||||
```
|
||||
|
||||
### API Design
|
||||
|
||||
```rust
|
||||
pub struct Node {
|
||||
pub id: NodeId,
|
||||
pub labels: Vec<String>,
|
||||
pub properties: JsonValue,
|
||||
}
|
||||
|
||||
pub struct Edge {
|
||||
pub id: EdgeId,
|
||||
pub source: NodeId,
|
||||
pub target: NodeId,
|
||||
pub edge_type: String,
|
||||
pub properties: JsonValue,
|
||||
}
|
||||
|
||||
pub struct GraphStore {
|
||||
nodes: HashMap<NodeId, Node>,
|
||||
edges: HashMap<EdgeId, Edge>,
|
||||
adjacency: HashMap<NodeId, Vec<EdgeId>>, // outgoing
|
||||
reverse_adj: HashMap<NodeId, Vec<EdgeId>>, // incoming
|
||||
}
|
||||
|
||||
impl GraphStore {
|
||||
// Node operations
|
||||
pub fn create_node(&mut self, labels: Vec<String>, props: JsonValue) -> NodeId;
|
||||
pub fn get_node(&self, id: &NodeId) -> Option<&Node>;
|
||||
pub fn update_node(&mut self, id: &NodeId, props: JsonValue) -> Result<(), GraphError>;
|
||||
pub fn delete_node(&mut self, id: &NodeId) -> Result<(), GraphError>;
|
||||
|
||||
// Edge operations
|
||||
pub fn create_edge(&mut self, source: NodeId, target: NodeId, edge_type: &str, props: JsonValue) -> EdgeId;
|
||||
pub fn get_edge(&self, id: &EdgeId) -> Option<&Edge>;
|
||||
pub fn delete_edge(&mut self, id: &EdgeId) -> Result<(), GraphError>;
|
||||
|
||||
// Traversal
|
||||
pub fn neighbors(&self, id: &NodeId, direction: Direction) -> Vec<&Node>;
|
||||
pub fn edges_of(&self, id: &NodeId, direction: Direction) -> Vec<&Edge>;
|
||||
pub fn shortest_path(&self, from: &NodeId, to: &NodeId) -> Option<Vec<NodeId>>;
|
||||
pub fn traverse(&self, start: &NodeId, query: &TraversalQuery) -> Vec<TraversalResult>;
|
||||
}
|
||||
```
|
||||
|
||||
### Gateway Endpoints
|
||||
|
||||
| Endpoint | Method | Description |
|
||||
|----------|--------|-------------|
|
||||
| `/db/:db/graph/nodes` | POST | Create node |
|
||||
| `/db/:db/graph/nodes/:id` | GET | Get node |
|
||||
| `/db/:db/graph/nodes/:id` | PUT | Update node |
|
||||
| `/db/:db/graph/nodes/:id` | DELETE | Delete node |
|
||||
| `/db/:db/graph/edges` | POST | Create edge |
|
||||
| `/db/:db/graph/edges/:id` | GET | Get edge |
|
||||
| `/db/:db/graph/edges/:id` | DELETE | Delete edge |
|
||||
| `/db/:db/graph/query` | POST | Execute graph query |
|
||||
| `/db/:db/graph/path` | POST | Find shortest path |
|
||||
| `/db/:db/graph/traverse` | POST | Traverse from node |
|
||||
|
||||
---
|
||||
|
||||
## Feature 3: Replication (Raft Consensus)
|
||||
|
||||
### Purpose
|
||||
Provide high availability and fault tolerance through distributed consensus, ensuring data consistency across multiple nodes.
|
||||
|
||||
### Architecture
|
||||
|
||||
```text
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ RAFT CONSENSUS LAYER │
|
||||
├─────────────────────────────────────────────────────────────┤
|
||||
│ │
|
||||
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────────┐ │
|
||||
│ │ Leader │ │ Follower │ │ Candidate │ │
|
||||
│ │ Election │ │ Replication │ │ (Election) │ │
|
||||
│ └──────────────┘ └──────────────┘ └──────────────────┘ │
|
||||
│ │
|
||||
│ ┌──────────────────────────────────────────────────────┐ │
|
||||
│ │ Log Replication │ │
|
||||
│ │ - Append entries │ │
|
||||
│ │ - Commit index │ │
|
||||
│ │ - Log compaction (snapshots) │ │
|
||||
│ └──────────────────────────────────────────────────────┘ │
|
||||
│ │
|
||||
│ ┌──────────────────────────────────────────────────────┐ │
|
||||
│ │ State Machine │ │
|
||||
│ │ - Apply committed entries │ │
|
||||
│ │ - Database operations │ │
|
||||
│ └──────────────────────────────────────────────────────┘ │
|
||||
│ │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Raft Protocol Overview
|
||||
|
||||
```text
|
||||
Leader Election:
|
||||
1. Followers timeout → become Candidate
|
||||
2. Candidate requests votes from peers
|
||||
3. Majority votes → become Leader
|
||||
4. Leader sends heartbeats to maintain authority
|
||||
|
||||
Log Replication:
|
||||
1. Client sends write to Leader
|
||||
2. Leader appends to local log
|
||||
3. Leader replicates to Followers
|
||||
4. Majority acknowledge → entry committed
|
||||
5. Leader applies to state machine
|
||||
6. Leader responds to client
|
||||
```
|
||||
|
||||
### Implementation Components
|
||||
|
||||
```
|
||||
crates/synor-database/src/replication/
|
||||
├── mod.rs # Module exports
|
||||
├── raft.rs # Core Raft implementation
|
||||
├── state.rs # Node state (Leader/Follower/Candidate)
|
||||
├── log.rs # Replicated log
|
||||
├── rpc.rs # RPC messages (AppendEntries, RequestVote)
|
||||
├── election.rs # Leader election logic
|
||||
├── snapshot.rs # Log compaction & snapshots
|
||||
├── cluster.rs # Cluster membership
|
||||
└── client.rs # Client for forwarding to leader
|
||||
```
|
||||
|
||||
### API Design
|
||||
|
||||
```rust
|
||||
#[derive(Clone, Copy, PartialEq)]
|
||||
pub enum NodeRole {
|
||||
Leader,
|
||||
Follower,
|
||||
Candidate,
|
||||
}
|
||||
|
||||
pub struct RaftConfig {
|
||||
pub node_id: u64,
|
||||
pub peers: Vec<PeerAddress>,
|
||||
pub election_timeout_ms: (u64, u64), // min, max
|
||||
pub heartbeat_interval_ms: u64,
|
||||
pub snapshot_threshold: u64,
|
||||
}
|
||||
|
||||
pub struct LogEntry {
|
||||
pub term: u64,
|
||||
pub index: u64,
|
||||
pub command: Command,
|
||||
}
|
||||
|
||||
pub enum Command {
|
||||
// Database operations
|
||||
KvSet { key: String, value: Vec<u8> },
|
||||
KvDelete { key: String },
|
||||
DocInsert { collection: String, doc: JsonValue },
|
||||
DocUpdate { collection: String, id: DocumentId, update: JsonValue },
|
||||
DocDelete { collection: String, id: DocumentId },
|
||||
// ... other operations
|
||||
}
|
||||
|
||||
pub struct RaftNode {
|
||||
config: RaftConfig,
|
||||
state: RaftState,
|
||||
log: ReplicatedLog,
|
||||
state_machine: Arc<Database>,
|
||||
}
|
||||
|
||||
impl RaftNode {
|
||||
pub async fn start(&mut self) -> Result<(), RaftError>;
|
||||
pub async fn propose(&self, command: Command) -> Result<(), RaftError>;
|
||||
pub fn is_leader(&self) -> bool;
|
||||
pub fn leader_id(&self) -> Option<u64>;
|
||||
pub fn status(&self) -> ClusterStatus;
|
||||
}
|
||||
|
||||
// RPC Messages
|
||||
pub struct AppendEntries {
|
||||
pub term: u64,
|
||||
pub leader_id: u64,
|
||||
pub prev_log_index: u64,
|
||||
pub prev_log_term: u64,
|
||||
pub entries: Vec<LogEntry>,
|
||||
pub leader_commit: u64,
|
||||
}
|
||||
|
||||
pub struct RequestVote {
|
||||
pub term: u64,
|
||||
pub candidate_id: u64,
|
||||
pub last_log_index: u64,
|
||||
pub last_log_term: u64,
|
||||
}
|
||||
```
|
||||
|
||||
### Cluster Configuration
|
||||
|
||||
```yaml
|
||||
# docker-compose.raft.yml
|
||||
services:
|
||||
db-node-1:
|
||||
image: synor/database:latest
|
||||
environment:
|
||||
RAFT_NODE_ID: 1
|
||||
RAFT_PEERS: "db-node-2:5000,db-node-3:5000"
|
||||
RAFT_ELECTION_TIMEOUT: "150-300"
|
||||
RAFT_HEARTBEAT_MS: 50
|
||||
ports:
|
||||
- "8484:8484" # HTTP API
|
||||
- "5000:5000" # Raft RPC
|
||||
|
||||
db-node-2:
|
||||
image: synor/database:latest
|
||||
environment:
|
||||
RAFT_NODE_ID: 2
|
||||
RAFT_PEERS: "db-node-1:5000,db-node-3:5000"
|
||||
ports:
|
||||
- "8485:8484"
|
||||
- "5001:5000"
|
||||
|
||||
db-node-3:
|
||||
image: synor/database:latest
|
||||
environment:
|
||||
RAFT_NODE_ID: 3
|
||||
RAFT_PEERS: "db-node-1:5000,db-node-2:5000"
|
||||
ports:
|
||||
- "8486:8484"
|
||||
- "5002:5000"
|
||||
```
|
||||
|
||||
### Gateway Endpoints
|
||||
|
||||
| Endpoint | Method | Description |
|
||||
|----------|--------|-------------|
|
||||
| `/cluster/status` | GET | Get cluster status |
|
||||
| `/cluster/leader` | GET | Get current leader |
|
||||
| `/cluster/nodes` | GET | List all nodes |
|
||||
| `/cluster/nodes/:id` | DELETE | Remove node from cluster |
|
||||
| `/cluster/nodes` | POST | Add node to cluster |
|
||||
|
||||
---
|
||||
|
||||
## Implementation Order
|
||||
|
||||
### Step 1: SQL Store
|
||||
1. Add `sqlparser` dependency
|
||||
2. Implement type system (`types.rs`)
|
||||
3. Implement row storage (`row.rs`, `table.rs`)
|
||||
4. Implement SQL parser wrapper (`parser.rs`)
|
||||
5. Implement query planner (`planner.rs`)
|
||||
6. Implement query executor (`executor.rs`)
|
||||
7. Add transaction support (`transaction.rs`)
|
||||
8. Add gateway endpoints
|
||||
9. Write tests
|
||||
|
||||
### Step 2: Graph Store
|
||||
1. Implement node/edge types (`node.rs`, `edge.rs`)
|
||||
2. Implement graph storage (`store.rs`)
|
||||
3. Implement query parser (`query.rs`)
|
||||
4. Implement traversal algorithms (`traversal.rs`, `path.rs`)
|
||||
5. Add graph indexes (`index.rs`)
|
||||
6. Add gateway endpoints
|
||||
7. Write tests
|
||||
|
||||
### Step 3: Replication
|
||||
1. Implement Raft state machine (`state.rs`, `raft.rs`)
|
||||
2. Implement replicated log (`log.rs`)
|
||||
3. Implement RPC layer (`rpc.rs`)
|
||||
4. Implement leader election (`election.rs`)
|
||||
5. Implement log compaction (`snapshot.rs`)
|
||||
6. Implement cluster management (`cluster.rs`)
|
||||
7. Integrate with database operations
|
||||
8. Add gateway endpoints
|
||||
9. Write tests
|
||||
10. Create Docker Compose for cluster
|
||||
|
||||
---
|
||||
|
||||
## Pricing Impact
|
||||
|
||||
| Feature | Operation | Cost (SYNOR) |
|
||||
|---------|-----------|--------------|
|
||||
| SQL | Query/million | 0.02 |
|
||||
| SQL | Write/million | 0.05 |
|
||||
| Graph | Traversal/million | 0.03 |
|
||||
| Graph | Path query/million | 0.05 |
|
||||
| Replication | Included | Base storage cost |
|
||||
|
||||
---
|
||||
|
||||
## Success Criteria
|
||||
|
||||
### SQL Store
|
||||
- [ ] Parse and execute basic SELECT, INSERT, UPDATE, DELETE
|
||||
- [ ] Support WHERE clauses with operators
|
||||
- [ ] Support ORDER BY, LIMIT, OFFSET
|
||||
- [ ] Support simple JOINs
|
||||
- [ ] Support aggregate functions
|
||||
- [ ] ACID transactions
|
||||
|
||||
### Graph Store
|
||||
- [ ] Create/read/update/delete nodes and edges
|
||||
- [ ] Traverse neighbors (in/out/both)
|
||||
- [ ] Find shortest path between nodes
|
||||
- [ ] Execute pattern matching queries
|
||||
- [ ] Support property filters
|
||||
|
||||
### Replication
|
||||
- [ ] Leader election works correctly
|
||||
- [ ] Log replication achieves consensus
|
||||
- [ ] Reads from any node (eventual consistency)
|
||||
- [ ] Writes only through leader
|
||||
- [ ] Node failure handled gracefully
|
||||
- [ ] Log compaction reduces storage
|
||||
Loading…
Add table
Reference in a new issue