synor/crates/synor-database/src/vector.rs
2026-02-02 05:58:22 +05:30

503 lines
14 KiB
Rust

//! Vector Store - AI/RAG optimized storage.
//!
//! Provides vector embeddings storage with similarity search.
//! Optimized for AI applications, RAG pipelines, and semantic search.
use crate::error::DatabaseError;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Vector embedding: one stored vector plus its identifying and descriptive data.
///
/// Values are normally built with [`Embedding::new`] and then customised with
/// the `with_*` builder methods.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Embedding {
/// Unique embedding ID; used as the lookup key once stored in an index.
pub id: String,
/// Vector values. The length of this vector is the embedding's dimension.
pub vector: Vec<f32>,
/// Associated metadata as arbitrary JSON (`Null` until set).
pub metadata: serde_json::Value,
/// Namespace for organization; searches can be restricted to one namespace.
pub namespace: String,
/// Creation timestamp, in milliseconds since the Unix epoch.
pub created_at: u64,
}
impl Embedding {
    /// Creates a new embedding with the given ID and vector.
    ///
    /// The metadata starts as JSON `Null`, the namespace as `"default"`, and
    /// `created_at` is the current wall-clock time in milliseconds since the
    /// Unix epoch. If the system clock reports a time before the epoch the
    /// timestamp falls back to `0` rather than panicking, so constructing an
    /// embedding can never abort the caller because of clock skew.
    pub fn new(id: impl Into<String>, vector: Vec<f32>) -> Self {
        let created_at = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            // Milliseconds since 1970 fit in a u64 for hundreds of millions of years.
            .map(|elapsed| elapsed.as_millis() as u64)
            .unwrap_or(0);

        Self {
            id: id.into(),
            vector,
            metadata: serde_json::Value::Null,
            namespace: "default".to_string(),
            created_at,
        }
    }

    /// Replaces the metadata and returns the embedding (builder style).
    pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
        self.metadata = metadata;
        self
    }

    /// Replaces the namespace and returns the embedding (builder style).
    pub fn with_namespace(mut self, namespace: impl Into<String>) -> Self {
        self.namespace = namespace.into();
        self
    }

    /// Returns the vector dimension, i.e. the number of components.
    pub fn dimension(&self) -> usize {
        self.vector.len()
    }

    /// Normalizes the vector to unit (L2) length in place.
    ///
    /// The vector is left untouched when its norm is zero (there is no
    /// direction to keep) or not finite (dividing by an infinite or NaN norm
    /// would collapse every component to zero or NaN).
    pub fn normalize(&mut self) {
        let magnitude = self.vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        if magnitude > 0.0 && magnitude.is_finite() {
            self.vector.iter_mut().for_each(|v| *v /= magnitude);
        }
    }
}
/// Similarity metric for vector comparison.
///
/// An index is created with one metric and uses it to score every stored
/// vector against the query. Distance-based metrics are converted to a
/// "higher is more similar" score by the index.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum SimilarityMetric {
/// Cosine similarity (default for text embeddings).
Cosine,
/// Euclidean distance (L2); scored by the index as `1 / (1 + distance)`.
Euclidean,
/// Dot product; used directly as the score.
DotProduct,
/// Manhattan distance (L1); scored by the index as `1 / (1 + distance)`.
Manhattan,
}
impl Default for SimilarityMetric {
fn default() -> Self {
SimilarityMetric::Cosine
}
}
/// Search result with similarity score: one hit returned by a vector search.
#[derive(Clone, Debug)]
pub struct VectorSearchResult {
/// Matching embedding (a clone of the stored value).
pub embedding: Embedding,
/// Similarity score (higher is more similar). For the distance-based
/// metrics this is `1 / (1 + distance)` rather than the raw distance.
pub score: f32,
}
/// In-memory vector index with a fixed dimension and similarity metric.
///
/// NOTE(review): this type was described as using an "HNSW-like structure",
/// but the code in this file keeps embeddings in a hash map and `search`
/// scores every stored embedding; no graph structure is visible here.
#[derive(Debug)]
pub struct VectorIndex {
/// Index name.
pub name: String,
/// Vector dimension every inserted embedding and every query must have.
pub dimension: u32,
/// Similarity metric used to score stored vectors against a query.
pub metric: SimilarityMetric,
/// Stored embeddings, keyed by embedding ID.
embeddings: RwLock<HashMap<String, Embedding>>,
/// Embedding IDs grouped by namespace.
by_namespace: RwLock<HashMap<String, Vec<String>>>,
/// Statistics (embedding count and search timing).
stats: RwLock<VectorStats>,
}
/// Vector index statistics, returned as a snapshot by `VectorIndex::stats`.
#[derive(Clone, Debug, Default)]
pub struct VectorStats {
/// Total embeddings currently counted by the index.
pub count: u64,
/// Total searches; incremented once per search that passes the dimension check.
pub searches: u64,
/// Average search time (ms), maintained as a running mean over all searches.
pub avg_search_time_ms: f64,
}
impl VectorIndex {
/// Creates a new vector index.
pub fn new(name: impl Into<String>, dimension: u32, metric: SimilarityMetric) -> Self {
Self {
name: name.into(),
dimension,
metric,
embeddings: RwLock::new(HashMap::new()),
by_namespace: RwLock::new(HashMap::new()),
stats: RwLock::new(VectorStats::default()),
}
}
/// Inserts an embedding.
pub fn insert(&self, embedding: Embedding) -> Result<(), DatabaseError> {
if embedding.vector.len() != self.dimension as usize {
return Err(DatabaseError::DimensionMismatch {
expected: self.dimension,
got: embedding.vector.len() as u32,
});
}
let id = embedding.id.clone();
let namespace = embedding.namespace.clone();
self.embeddings.write().insert(id.clone(), embedding);
self.by_namespace
.write()
.entry(namespace)
.or_insert_with(Vec::new)
.push(id);
self.stats.write().count += 1;
Ok(())
}
/// Inserts multiple embeddings.
pub fn insert_batch(&self, embeddings: Vec<Embedding>) -> Result<(), DatabaseError> {
for embedding in embeddings {
self.insert(embedding)?;
}
Ok(())
}
/// Gets an embedding by ID.
pub fn get(&self, id: &str) -> Option<Embedding> {
self.embeddings.read().get(id).cloned()
}
/// Deletes an embedding.
pub fn delete(&self, id: &str) -> Result<bool, DatabaseError> {
let mut embeddings = self.embeddings.write();
if let Some(embedding) = embeddings.remove(id) {
// Remove from namespace index
let mut by_ns = self.by_namespace.write();
if let Some(ids) = by_ns.get_mut(&embedding.namespace) {
ids.retain(|i| i != id);
}
self.stats.write().count -= 1;
Ok(true)
} else {
Ok(false)
}
}
/// Searches for similar vectors.
pub fn search(
&self,
query: &[f32],
limit: usize,
namespace: Option<&str>,
threshold: Option<f32>,
) -> Result<Vec<VectorSearchResult>, DatabaseError> {
if query.len() != self.dimension as usize {
return Err(DatabaseError::DimensionMismatch {
expected: self.dimension,
got: query.len() as u32,
});
}
let start = std::time::Instant::now();
let embeddings = self.embeddings.read();
let mut results: Vec<VectorSearchResult> = embeddings
.values()
.filter(|e| namespace.map(|ns| e.namespace == ns).unwrap_or(true))
.map(|e| {
let score = self.calculate_similarity(&e.vector, query);
VectorSearchResult {
embedding: e.clone(),
score,
}
})
.filter(|r| threshold.map(|t| r.score >= t).unwrap_or(true))
.collect();
// Sort by score descending
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
// Apply limit
results.truncate(limit);
// Update stats
let elapsed = start.elapsed().as_millis() as f64;
let mut stats = self.stats.write();
stats.searches += 1;
stats.avg_search_time_ms = (stats.avg_search_time_ms * (stats.searches - 1) as f64
+ elapsed)
/ stats.searches as f64;
Ok(results)
}
/// Calculates similarity between two vectors.
fn calculate_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
match self.metric {
SimilarityMetric::Cosine => cosine_similarity(a, b),
SimilarityMetric::Euclidean => {
// Convert distance to similarity (higher = more similar)
let dist = euclidean_distance(a, b);
1.0 / (1.0 + dist)
}
SimilarityMetric::DotProduct => dot_product(a, b),
SimilarityMetric::Manhattan => {
let dist = manhattan_distance(a, b);
1.0 / (1.0 + dist)
}
}
}
/// Returns index statistics.
pub fn stats(&self) -> VectorStats {
self.stats.read().clone()
}
/// Returns count of embeddings.
pub fn count(&self) -> u64 {
self.stats.read().count
}
/// Clears all embeddings.
pub fn clear(&self) {
self.embeddings.write().clear();
self.by_namespace.write().clear();
self.stats.write().count = 0;
}
}
/// Vector store managing multiple indexes.
///
/// The data operations in this file (`insert`, `search`, `get`, `delete`,
/// `count`) all act on `default_index`; indexes created with `create_index`
/// are kept in `indexes`.
pub struct VectorStore {
/// Default dimension, as passed to `VectorStore::new`.
default_dimension: u32,
/// Additional indexes by name.
indexes: RwLock<HashMap<String, VectorIndex>>,
/// Default index, named "default" and using the cosine metric.
default_index: VectorIndex,
}
impl VectorStore {
    /// Name under which the built-in default index is exposed.
    const DEFAULT_INDEX_NAME: &'static str = "default";

    /// Creates a new vector store whose default index has the given
    /// dimension and uses cosine similarity.
    pub fn new(dimension: u32) -> Self {
        Self {
            default_dimension: dimension,
            indexes: RwLock::new(HashMap::new()),
            default_index: VectorIndex::new(
                Self::DEFAULT_INDEX_NAME,
                dimension,
                SimilarityMetric::Cosine,
            ),
        }
    }

    /// Creates a new, empty named index.
    ///
    /// # Errors
    ///
    /// Returns `DatabaseError::AlreadyExists` when an index with that name
    /// already exists. The name "default" is reserved for the built-in
    /// default index and is rejected the same way, so `list_indexes` never
    /// reports it twice.
    pub fn create_index(
        &self,
        name: &str,
        dimension: u32,
        metric: SimilarityMetric,
    ) -> Result<(), DatabaseError> {
        let mut indexes = self.indexes.write();
        if name == Self::DEFAULT_INDEX_NAME || indexes.contains_key(name) {
            return Err(DatabaseError::AlreadyExists(name.to_string()));
        }
        indexes.insert(name.to_string(), VectorIndex::new(name, dimension, metric));
        Ok(())
    }

    /// Gets an index by name.
    ///
    /// Only the default index ("default") can currently be returned. Indexes
    /// created through `create_index` are stored by value behind a lock, so a
    /// plain reference to them cannot be handed out; they would need shared
    /// ownership (e.g. `Arc`) to be returned from here.
    pub fn get_index(&self, name: &str) -> Option<&VectorIndex> {
        if name == Self::DEFAULT_INDEX_NAME {
            Some(&self.default_index)
        } else {
            None
        }
    }

    /// Inserts an embedding into the default index.
    pub fn insert(&self, embedding: Embedding) -> Result<(), DatabaseError> {
        self.default_index.insert(embedding)
    }

    /// Searches the default index; see `VectorIndex::search` for the
    /// meaning of `limit`, `namespace` and `threshold`.
    pub fn search(
        &self,
        query: &[f32],
        limit: usize,
        namespace: Option<&str>,
        threshold: Option<f32>,
    ) -> Result<Vec<VectorSearchResult>, DatabaseError> {
        self.default_index
            .search(query, limit, namespace, threshold)
    }

    /// Gets an embedding by ID from the default index.
    pub fn get(&self, id: &str) -> Option<Embedding> {
        self.default_index.get(id)
    }

    /// Deletes an embedding from the default index.
    pub fn delete(&self, id: &str) -> Result<bool, DatabaseError> {
        self.default_index.delete(id)
    }

    /// Returns the embedding count of the default index.
    pub fn count(&self) -> u64 {
        self.default_index.count()
    }

    /// Lists all index names: the created indexes (in no particular order)
    /// followed by "default".
    pub fn list_indexes(&self) -> Vec<String> {
        let mut names: Vec<_> = self.indexes.read().keys().cloned().collect();
        names.push(Self::DEFAULT_INDEX_NAME.to_string());
        names
    }
}
/// Cosine similarity between two vectors.
///
/// Returns the dot product divided by the product of the two L2 norms.
/// Yields `0.0` when the slices differ in length or when either vector has
/// a zero norm (which includes empty input).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Vectors of different dimension are treated as having no similarity.
    if a.len() != b.len() {
        return 0.0;
    }

    let norm = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (norm(a), norm(b));

    // A zero-length direction cannot be compared; avoid dividing by zero.
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    dot / (norm_a * norm_b)
}
/// Dot product of two vectors.
///
/// Components are paired positionally; if the slices differ in length, only
/// the common prefix contributes to the result.
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(lhs, rhs)| lhs * rhs);
    products.sum()
}
/// Euclidean distance (L2) between two vectors.
///
/// Components are paired positionally; if the slices differ in length, only
/// the common prefix contributes to the result.
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(p, q)| {
            let delta = p - q;
            delta.powi(2)
        })
        .sum();
    squared_sum.sqrt()
}
/// Manhattan distance (L1) between two vectors.
///
/// Sums the absolute component-wise differences. Components are paired
/// positionally; if the slices differ in length, only the common prefix
/// contributes to the result.
pub fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
    let gaps = a.iter().zip(b).map(|(p, q)| (p - q).abs());
    gaps.sum()
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builder methods set ID, namespace and dimension as expected.
    #[test]
    fn test_embedding_creation() {
        let embedding = Embedding::new("doc1", vec![1.0, 0.0, 0.0])
            .with_metadata(json!({"title": "Hello World"}))
            .with_namespace("documents");
        assert_eq!(embedding.id, "doc1");
        assert_eq!(embedding.dimension(), 3);
        assert_eq!(embedding.namespace, "documents");
    }

    /// The closest vector comes first and `limit` caps the result count.
    #[test]
    fn test_vector_insert_search() {
        let store = VectorStore::new(3);
        store
            .insert(Embedding::new("a", vec![1.0, 0.0, 0.0]))
            .unwrap();
        store
            .insert(Embedding::new("b", vec![0.9, 0.1, 0.0]))
            .unwrap();
        store
            .insert(Embedding::new("c", vec![0.0, 1.0, 0.0]))
            .unwrap();
        let results = store.search(&[1.0, 0.0, 0.0], 2, None, None).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].embedding.id, "a");
        assert!((results[0].score - 1.0).abs() < 0.001);
    }

    /// Results scoring below the threshold are filtered out.
    #[test]
    fn test_similarity_threshold() {
        let store = VectorStore::new(3);
        store
            .insert(Embedding::new("a", vec![1.0, 0.0, 0.0]))
            .unwrap();
        store
            .insert(Embedding::new("b", vec![0.0, 1.0, 0.0]))
            .unwrap();
        let results = store.search(&[1.0, 0.0, 0.0], 10, None, Some(0.5)).unwrap();
        // Only "a" should match with high threshold
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].embedding.id, "a");
    }

    /// A namespace filter restricts results to that namespace.
    #[test]
    fn test_namespace_filter() {
        let store = VectorStore::new(3);
        store
            .insert(Embedding::new("a", vec![1.0, 0.0, 0.0]).with_namespace("ns1"))
            .unwrap();
        store
            .insert(Embedding::new("b", vec![1.0, 0.0, 0.0]).with_namespace("ns2"))
            .unwrap();
        let results = store
            .search(&[1.0, 0.0, 0.0], 10, Some("ns1"), None)
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].embedding.id, "a");
    }

    /// Inserting or querying with the wrong dimension is rejected.
    #[test]
    fn test_dimension_mismatch() {
        let store = VectorStore::new(3);
        let result = store.insert(Embedding::new("a", vec![1.0, 0.0])); // 2D instead of 3D
        assert!(result.is_err());
        assert!(store.search(&[1.0, 0.0], 10, None, None).is_err());
        assert_eq!(store.count(), 0);
    }

    /// Reference values for every free-standing metric function.
    #[test]
    fn test_similarity_metrics() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let c = vec![0.0, 1.0, 0.0];
        // Cosine
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
        assert!((cosine_similarity(&a, &c) - 0.0).abs() < 0.001);
        // Euclidean
        assert!((euclidean_distance(&a, &b) - 0.0).abs() < 0.001);
        assert!((euclidean_distance(&a, &c) - 1.414).abs() < 0.01);
        // Dot product
        assert!((dot_product(&a, &b) - 1.0).abs() < 0.001);
        assert!((dot_product(&a, &c) - 0.0).abs() < 0.001);
        // Manhattan
        assert!((manhattan_distance(&a, &b) - 0.0).abs() < 0.001);
        assert!((manhattan_distance(&a, &c) - 2.0).abs() < 0.001);
    }

    /// Normalizing yields a unit-length vector.
    #[test]
    fn test_embedding_normalize() {
        let mut embedding = Embedding::new("test", vec![3.0, 4.0, 0.0]);
        embedding.normalize();
        let magnitude: f32 = embedding.vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((magnitude - 1.0).abs() < 0.001);
    }

    /// Deleting removes the embedding, updates the count and reports
    /// whether anything was removed.
    #[test]
    fn test_delete_updates_count() {
        let store = VectorStore::new(3);
        store
            .insert(Embedding::new("a", vec![1.0, 0.0, 0.0]))
            .unwrap();
        assert_eq!(store.count(), 1);
        assert!(store.get("a").is_some());

        assert!(store.delete("a").unwrap());
        assert_eq!(store.count(), 0);
        assert!(store.get("a").is_none());

        // Deleting an unknown ID is not an error; it just reports `false`.
        assert!(!store.delete("a").unwrap());
        assert_eq!(store.count(), 0);
    }

    /// Index names are unique and the default index is always listed.
    #[test]
    fn test_create_index_rejects_duplicates() {
        let store = VectorStore::new(3);
        store
            .create_index("docs", 3, SimilarityMetric::Euclidean)
            .unwrap();
        assert!(store
            .create_index("docs", 3, SimilarityMetric::Cosine)
            .is_err());

        let names = store.list_indexes();
        assert!(names.contains(&"docs".to_string()));
        assert!(names.contains(&"default".to_string()));
    }
}