503 lines
14 KiB
Rust
503 lines
14 KiB
Rust
//! Vector Store - AI/RAG optimized storage.
|
|
//!
|
|
//! Provides storage for vector embeddings with exact (brute-force) similarity search.
|
|
//! Optimized for AI applications, RAG pipelines, and semantic search.
|
|
|
|
use crate::error::DatabaseError;
|
|
use parking_lot::RwLock;
|
|
use serde::{Deserialize, Serialize};
|
|
use std::collections::HashMap;
|
|
|
|
/// Vector embedding.
|
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
pub struct Embedding {
|
|
/// Unique embedding ID.
|
|
pub id: String,
|
|
/// Vector values.
|
|
pub vector: Vec<f32>,
|
|
/// Associated metadata.
|
|
pub metadata: serde_json::Value,
|
|
/// Namespace for organization.
|
|
pub namespace: String,
|
|
/// Creation timestamp.
|
|
pub created_at: u64,
|
|
}
|
|
|
|
impl Embedding {
|
|
/// Creates a new embedding.
|
|
pub fn new(id: impl Into<String>, vector: Vec<f32>) -> Self {
|
|
Self {
|
|
id: id.into(),
|
|
vector,
|
|
metadata: serde_json::Value::Null,
|
|
namespace: "default".to_string(),
|
|
created_at: std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.unwrap()
|
|
.as_millis() as u64,
|
|
}
|
|
}
|
|
|
|
/// Sets metadata.
|
|
pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
|
|
self.metadata = metadata;
|
|
self
|
|
}
|
|
|
|
/// Sets namespace.
|
|
pub fn with_namespace(mut self, namespace: impl Into<String>) -> Self {
|
|
self.namespace = namespace.into();
|
|
self
|
|
}
|
|
|
|
/// Returns vector dimension.
|
|
pub fn dimension(&self) -> usize {
|
|
self.vector.len()
|
|
}
|
|
|
|
/// Normalizes the vector to unit length.
|
|
pub fn normalize(&mut self) {
|
|
let magnitude: f32 = self.vector.iter().map(|x| x * x).sum::<f32>().sqrt();
|
|
if magnitude > 0.0 {
|
|
for v in &mut self.vector {
|
|
*v /= magnitude;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Similarity metric for vector comparison.
|
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
|
pub enum SimilarityMetric {
|
|
/// Cosine similarity (default for text embeddings).
|
|
Cosine,
|
|
/// Euclidean distance (L2).
|
|
Euclidean,
|
|
/// Dot product.
|
|
DotProduct,
|
|
/// Manhattan distance (L1).
|
|
Manhattan,
|
|
}
|
|
|
|
impl Default for SimilarityMetric {
|
|
fn default() -> Self {
|
|
SimilarityMetric::Cosine
|
|
}
|
|
}
|
|
|
|
/// Search result with similarity score.
|
|
#[derive(Clone, Debug)]
|
|
pub struct VectorSearchResult {
|
|
/// Matching embedding.
|
|
pub embedding: Embedding,
|
|
/// Similarity score (higher is more similar).
|
|
pub score: f32,
|
|
}
|
|
|
|
/// Vector index using HNSW-like structure.
|
|
#[derive(Debug)]
|
|
pub struct VectorIndex {
|
|
/// Index name.
|
|
pub name: String,
|
|
/// Vector dimension.
|
|
pub dimension: u32,
|
|
/// Similarity metric.
|
|
pub metric: SimilarityMetric,
|
|
/// Stored embeddings.
|
|
embeddings: RwLock<HashMap<String, Embedding>>,
|
|
/// Embeddings by namespace.
|
|
by_namespace: RwLock<HashMap<String, Vec<String>>>,
|
|
/// Statistics.
|
|
stats: RwLock<VectorStats>,
|
|
}
|
|
|
|
/// Counters describing the contents and usage of a single vector index.
#[derive(Debug, Default, Clone)]
pub struct VectorStats {
    /// Number of embeddings currently tracked by the index.
    pub count: u64,
    /// Number of searches executed so far.
    pub searches: u64,
    /// Mean wall-clock duration of those searches, in milliseconds.
    pub avg_search_time_ms: f64,
}
|
|
|
|
impl VectorIndex {
|
|
/// Creates a new vector index.
|
|
pub fn new(name: impl Into<String>, dimension: u32, metric: SimilarityMetric) -> Self {
|
|
Self {
|
|
name: name.into(),
|
|
dimension,
|
|
metric,
|
|
embeddings: RwLock::new(HashMap::new()),
|
|
by_namespace: RwLock::new(HashMap::new()),
|
|
stats: RwLock::new(VectorStats::default()),
|
|
}
|
|
}
|
|
|
|
/// Inserts an embedding.
|
|
pub fn insert(&self, embedding: Embedding) -> Result<(), DatabaseError> {
|
|
if embedding.vector.len() != self.dimension as usize {
|
|
return Err(DatabaseError::DimensionMismatch {
|
|
expected: self.dimension,
|
|
got: embedding.vector.len() as u32,
|
|
});
|
|
}
|
|
|
|
let id = embedding.id.clone();
|
|
let namespace = embedding.namespace.clone();
|
|
|
|
self.embeddings.write().insert(id.clone(), embedding);
|
|
self.by_namespace
|
|
.write()
|
|
.entry(namespace)
|
|
.or_insert_with(Vec::new)
|
|
.push(id);
|
|
|
|
self.stats.write().count += 1;
|
|
Ok(())
|
|
}
|
|
|
|
/// Inserts multiple embeddings.
|
|
pub fn insert_batch(&self, embeddings: Vec<Embedding>) -> Result<(), DatabaseError> {
|
|
for embedding in embeddings {
|
|
self.insert(embedding)?;
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
/// Gets an embedding by ID.
|
|
pub fn get(&self, id: &str) -> Option<Embedding> {
|
|
self.embeddings.read().get(id).cloned()
|
|
}
|
|
|
|
/// Deletes an embedding.
|
|
pub fn delete(&self, id: &str) -> Result<bool, DatabaseError> {
|
|
let mut embeddings = self.embeddings.write();
|
|
if let Some(embedding) = embeddings.remove(id) {
|
|
// Remove from namespace index
|
|
let mut by_ns = self.by_namespace.write();
|
|
if let Some(ids) = by_ns.get_mut(&embedding.namespace) {
|
|
ids.retain(|i| i != id);
|
|
}
|
|
self.stats.write().count -= 1;
|
|
Ok(true)
|
|
} else {
|
|
Ok(false)
|
|
}
|
|
}
|
|
|
|
/// Searches for similar vectors.
|
|
pub fn search(
|
|
&self,
|
|
query: &[f32],
|
|
limit: usize,
|
|
namespace: Option<&str>,
|
|
threshold: Option<f32>,
|
|
) -> Result<Vec<VectorSearchResult>, DatabaseError> {
|
|
if query.len() != self.dimension as usize {
|
|
return Err(DatabaseError::DimensionMismatch {
|
|
expected: self.dimension,
|
|
got: query.len() as u32,
|
|
});
|
|
}
|
|
|
|
let start = std::time::Instant::now();
|
|
|
|
let embeddings = self.embeddings.read();
|
|
let mut results: Vec<VectorSearchResult> = embeddings
|
|
.values()
|
|
.filter(|e| namespace.map(|ns| e.namespace == ns).unwrap_or(true))
|
|
.map(|e| {
|
|
let score = self.calculate_similarity(&e.vector, query);
|
|
VectorSearchResult {
|
|
embedding: e.clone(),
|
|
score,
|
|
}
|
|
})
|
|
.filter(|r| threshold.map(|t| r.score >= t).unwrap_or(true))
|
|
.collect();
|
|
|
|
// Sort by score descending
|
|
results.sort_by(|a, b| {
|
|
b.score
|
|
.partial_cmp(&a.score)
|
|
.unwrap_or(std::cmp::Ordering::Equal)
|
|
});
|
|
|
|
// Apply limit
|
|
results.truncate(limit);
|
|
|
|
// Update stats
|
|
let elapsed = start.elapsed().as_millis() as f64;
|
|
let mut stats = self.stats.write();
|
|
stats.searches += 1;
|
|
stats.avg_search_time_ms = (stats.avg_search_time_ms * (stats.searches - 1) as f64
|
|
+ elapsed)
|
|
/ stats.searches as f64;
|
|
|
|
Ok(results)
|
|
}
|
|
|
|
/// Calculates similarity between two vectors.
|
|
fn calculate_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
|
|
match self.metric {
|
|
SimilarityMetric::Cosine => cosine_similarity(a, b),
|
|
SimilarityMetric::Euclidean => {
|
|
// Convert distance to similarity (higher = more similar)
|
|
let dist = euclidean_distance(a, b);
|
|
1.0 / (1.0 + dist)
|
|
}
|
|
SimilarityMetric::DotProduct => dot_product(a, b),
|
|
SimilarityMetric::Manhattan => {
|
|
let dist = manhattan_distance(a, b);
|
|
1.0 / (1.0 + dist)
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Returns index statistics.
|
|
pub fn stats(&self) -> VectorStats {
|
|
self.stats.read().clone()
|
|
}
|
|
|
|
/// Returns count of embeddings.
|
|
pub fn count(&self) -> u64 {
|
|
self.stats.read().count
|
|
}
|
|
|
|
/// Clears all embeddings.
|
|
pub fn clear(&self) {
|
|
self.embeddings.write().clear();
|
|
self.by_namespace.write().clear();
|
|
self.stats.write().count = 0;
|
|
}
|
|
}
|
|
|
|
/// Vector store managing multiple indexes.
|
|
pub struct VectorStore {
|
|
/// Default dimension.
|
|
default_dimension: u32,
|
|
/// Indexes by name.
|
|
indexes: RwLock<HashMap<String, VectorIndex>>,
|
|
/// Default index.
|
|
default_index: VectorIndex,
|
|
}
|
|
|
|
impl VectorStore {
|
|
/// Creates a new vector store.
|
|
pub fn new(dimension: u32) -> Self {
|
|
Self {
|
|
default_dimension: dimension,
|
|
indexes: RwLock::new(HashMap::new()),
|
|
default_index: VectorIndex::new("default", dimension, SimilarityMetric::Cosine),
|
|
}
|
|
}
|
|
|
|
/// Creates a new index.
|
|
pub fn create_index(
|
|
&self,
|
|
name: &str,
|
|
dimension: u32,
|
|
metric: SimilarityMetric,
|
|
) -> Result<(), DatabaseError> {
|
|
let mut indexes = self.indexes.write();
|
|
if indexes.contains_key(name) {
|
|
return Err(DatabaseError::AlreadyExists(name.to_string()));
|
|
}
|
|
indexes.insert(name.to_string(), VectorIndex::new(name, dimension, metric));
|
|
Ok(())
|
|
}
|
|
|
|
/// Gets an index by name.
|
|
pub fn get_index(&self, _name: &str) -> Option<&VectorIndex> {
|
|
// Simplified - would use Arc in production
|
|
None
|
|
}
|
|
|
|
/// Inserts an embedding into the default index.
|
|
pub fn insert(&self, embedding: Embedding) -> Result<(), DatabaseError> {
|
|
self.default_index.insert(embedding)
|
|
}
|
|
|
|
/// Searches the default index.
|
|
pub fn search(
|
|
&self,
|
|
query: &[f32],
|
|
limit: usize,
|
|
namespace: Option<&str>,
|
|
threshold: Option<f32>,
|
|
) -> Result<Vec<VectorSearchResult>, DatabaseError> {
|
|
self.default_index
|
|
.search(query, limit, namespace, threshold)
|
|
}
|
|
|
|
/// Gets an embedding by ID.
|
|
pub fn get(&self, id: &str) -> Option<Embedding> {
|
|
self.default_index.get(id)
|
|
}
|
|
|
|
/// Deletes an embedding.
|
|
pub fn delete(&self, id: &str) -> Result<bool, DatabaseError> {
|
|
self.default_index.delete(id)
|
|
}
|
|
|
|
/// Returns embedding count.
|
|
pub fn count(&self) -> u64 {
|
|
self.default_index.count()
|
|
}
|
|
|
|
/// Lists all indexes.
|
|
pub fn list_indexes(&self) -> Vec<String> {
|
|
let mut names: Vec<_> = self.indexes.read().keys().cloned().collect();
|
|
names.push("default".to_string());
|
|
names
|
|
}
|
|
}
|
|
|
|
/// Cosine similarity between two vectors: `a·b / (|a| |b|)`.
///
/// Returns `0.0` when the slices differ in length or when either vector has zero
/// magnitude (including empty slices).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let (norm_a, norm_b) = (norm(a), norm(b));

        if norm_a != 0.0 && norm_b != 0.0 {
            let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
            return dot / (norm_a * norm_b);
        }
    }
    0.0
}
|
|
|
|
/// Dot product `Σ aᵢ·bᵢ` of two vectors.
///
/// Pairs components positionally; if the slices differ in length, the extra trailing
/// components of the longer one are ignored.
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| x * y)
        .sum::<f32>()
}
|
|
|
|
/// Euclidean (L2) distance `sqrt(Σ (aᵢ - bᵢ)²)` between two vectors.
///
/// Pairs components positionally; if the slices differ in length, the extra trailing
/// components of the longer one are ignored.
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum = a
        .iter()
        .zip(b)
        .map(|(x, y)| (x - y).powi(2))
        .sum::<f32>();
    squared_sum.sqrt()
}
|
|
|
|
/// Manhattan (L1) distance `Σ |aᵢ - bᵢ|` between two vectors.
///
/// Pairs components positionally; if the slices differ in length, the extra trailing
/// components of the longer one are ignored.
pub fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
    let abs_diffs = a.iter().zip(b).map(|(x, y)| (x - y).abs());
    abs_diffs.sum()
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a 3-D store pre-populated with the given `(id, vector)` pairs.
    fn store_with(entries: &[(&str, [f32; 3])]) -> VectorStore {
        let store = VectorStore::new(3);
        for (id, vector) in entries {
            store.insert(Embedding::new(*id, vector.to_vec())).unwrap();
        }
        store
    }

    #[test]
    fn test_embedding_creation() {
        let doc = Embedding::new("doc1", vec![1.0, 0.0, 0.0])
            .with_metadata(json!({"title": "Hello World"}))
            .with_namespace("documents");

        assert_eq!(doc.id, "doc1");
        assert_eq!(doc.dimension(), 3);
        assert_eq!(doc.namespace, "documents");
    }

    #[test]
    fn test_vector_insert_search() {
        let store = store_with(&[
            ("a", [1.0, 0.0, 0.0]),
            ("b", [0.9, 0.1, 0.0]),
            ("c", [0.0, 1.0, 0.0]),
        ]);

        let hits = store.search(&[1.0, 0.0, 0.0], 2, None, None).unwrap();

        assert_eq!(hits.len(), 2);
        let best = &hits[0];
        assert_eq!(best.embedding.id, "a");
        assert!((best.score - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_similarity_threshold() {
        let store = store_with(&[("a", [1.0, 0.0, 0.0]), ("b", [0.0, 1.0, 0.0])]);

        // The orthogonal vector "b" scores 0.0 and falls below the 0.5 cut-off.
        let hits = store.search(&[1.0, 0.0, 0.0], 10, None, Some(0.5)).unwrap();

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].embedding.id, "a");
    }

    #[test]
    fn test_namespace_filter() {
        let store = VectorStore::new(3);
        for &(id, ns) in &[("a", "ns1"), ("b", "ns2")] {
            store
                .insert(Embedding::new(id, vec![1.0, 0.0, 0.0]).with_namespace(ns))
                .unwrap();
        }

        let hits = store
            .search(&[1.0, 0.0, 0.0], 10, Some("ns1"), None)
            .unwrap();

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].embedding.id, "a");
    }

    #[test]
    fn test_dimension_mismatch() {
        let store = VectorStore::new(3);

        // A 2-D vector must be rejected by a 3-D store.
        let outcome = store.insert(Embedding::new("a", vec![1.0, 0.0]));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_similarity_metrics() {
        let x_axis = [1.0, 0.0, 0.0];
        let same = [1.0, 0.0, 0.0];
        let y_axis = [0.0, 1.0, 0.0];
        let close = |got: f32, want: f32, tol: f32| (got - want).abs() < tol;

        // Cosine
        assert!(close(cosine_similarity(&x_axis, &same), 1.0, 0.001));
        assert!(close(cosine_similarity(&x_axis, &y_axis), 0.0, 0.001));

        // Euclidean
        assert!(close(euclidean_distance(&x_axis, &same), 0.0, 0.001));
        assert!(close(euclidean_distance(&x_axis, &y_axis), 1.414, 0.01));

        // Dot product
        assert!(close(dot_product(&x_axis, &same), 1.0, 0.001));
        assert!(close(dot_product(&x_axis, &y_axis), 0.0, 0.001));
    }

    #[test]
    fn test_embedding_normalize() {
        let mut unit = Embedding::new("test", vec![3.0, 4.0, 0.0]);
        unit.normalize();

        let norm = unit.vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.001);
    }
}
|