style: apply cargo fmt formatting

This commit is contained in:
Gulshan Yadav 2026-02-02 05:58:22 +05:30
parent 5126c33113
commit dcd1cccc67
170 changed files with 4463 additions and 2837 deletions

View file

@ -256,7 +256,11 @@ pub async fn handle(
Ok(()) Ok(())
} }
CompilerCommands::Encode { function, args, abi } => { CompilerCommands::Encode {
function,
args,
abi,
} => {
output::print_info(&format!("Encoding call to: {}", function)); output::print_info(&format!("Encoding call to: {}", function));
output::print_kv("Arguments", &args); output::print_kv("Arguments", &args);
if let Some(a) = abi { if let Some(a) = abi {
@ -268,7 +272,11 @@ pub async fn handle(
Ok(()) Ok(())
} }
CompilerCommands::Decode { data, function, abi } => { CompilerCommands::Decode {
data,
function,
abi,
} => {
output::print_info(&format!("Decoding result for: {}", function)); output::print_info(&format!("Decoding result for: {}", function));
output::print_kv("Data", &data); output::print_kv("Data", &data);
if let Some(a) = abi { if let Some(a) = abi {
@ -314,7 +322,11 @@ pub async fn handle(
Ok(()) Ok(())
} }
CompilerCommands::SecurityScan { wasm, min_severity, format: _ } => { CompilerCommands::SecurityScan {
wasm,
min_severity,
format: _,
} => {
output::print_info(&format!("Security scan: {}", wasm.display())); output::print_info(&format!("Security scan: {}", wasm.display()));
output::print_kv("Min severity", &min_severity); output::print_kv("Min severity", &min_severity);
@ -344,7 +356,11 @@ pub async fn handle(
Ok(()) Ok(())
} }
CompilerCommands::Validate { wasm, exports, max_memory } => { CompilerCommands::Validate {
wasm,
exports,
max_memory,
} => {
output::print_info(&format!("Validating: {}", wasm.display())); output::print_info(&format!("Validating: {}", wasm.display()));
if let Some(e) = exports { if let Some(e) = exports {

View file

@ -196,8 +196,7 @@ pub async fn deploy(
} }
// Determine output directory // Determine output directory
let output_path = output_dir let output_path = output_dir.unwrap_or_else(|| cwd.join(config.output_dir()));
.unwrap_or_else(|| cwd.join(config.output_dir()));
if !output_path.exists() { if !output_path.exists() {
return Err(anyhow!( return Err(anyhow!(
@ -270,7 +269,10 @@ fn validate_name(name: &str) -> Result<()> {
if name.len() > 63 { if name.len() > 63 {
return Err(anyhow!("Name must be 63 characters or less")); return Err(anyhow!("Name must be 63 characters or less"));
} }
if !name.chars().all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-') { if !name
.chars()
.all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-')
{
return Err(anyhow!( return Err(anyhow!(
"Name must contain only lowercase letters, numbers, and hyphens" "Name must contain only lowercase letters, numbers, and hyphens"
)); ));
@ -281,8 +283,8 @@ fn validate_name(name: &str) -> Result<()> {
// Reserved names // Reserved names
const RESERVED: &[&str] = &[ const RESERVED: &[&str] = &[
"www", "api", "app", "admin", "mail", "ftp", "ssh", "cdn", "www", "api", "app", "admin", "mail", "ftp", "ssh", "cdn", "storage", "gateway", "hosting",
"storage", "gateway", "hosting", "node", "synor", "node", "synor",
]; ];
if RESERVED.contains(&name) { if RESERVED.contains(&name) {
return Err(anyhow!("Name '{}' is reserved", name)); return Err(anyhow!("Name '{}' is reserved", name));
@ -397,11 +399,7 @@ fn guess_content_type(path: &Path) -> String {
} }
/// Upload files to Synor Storage. /// Upload files to Synor Storage.
async fn upload_files( async fn upload_files(base_dir: &Path, files: &[DeployFile], gateway_url: &str) -> Result<String> {
base_dir: &Path,
files: &[DeployFile],
gateway_url: &str,
) -> Result<String> {
let client = reqwest::Client::new(); let client = reqwest::Client::new();
// Create a multipart form with all files // Create a multipart form with all files
@ -445,11 +443,7 @@ async fn upload_files(
} }
/// Register the deployment with the hosting gateway. /// Register the deployment with the hosting gateway.
async fn register_deployment( async fn register_deployment(name: &str, cid: &str, gateway_url: &str) -> Result<String> {
name: &str,
cid: &str,
gateway_url: &str,
) -> Result<String> {
let client = reqwest::Client::new(); let client = reqwest::Client::new();
#[derive(Serialize)] #[derive(Serialize)]
@ -662,7 +656,11 @@ pub async fn delete(name: &str, gateway_url: &str, format: OutputFormat) -> Resu
if !response.status().is_success() { if !response.status().is_success() {
let status = response.status(); let status = response.status();
let body = response.text().await.unwrap_or_default(); let body = response.text().await.unwrap_or_default();
return Err(anyhow!("Failed to delete deployment: {} - {}", status, body)); return Err(anyhow!(
"Failed to delete deployment: {} - {}",
status,
body
));
} }
match format { match format {
@ -707,22 +705,13 @@ mod tests {
#[test] #[test]
fn test_guess_content_type() { fn test_guess_content_type() {
assert_eq!( assert_eq!(guess_content_type(Path::new("index.html")), "text/html");
guess_content_type(Path::new("index.html")), assert_eq!(guess_content_type(Path::new("style.css")), "text/css");
"text/html"
);
assert_eq!(
guess_content_type(Path::new("style.css")),
"text/css"
);
assert_eq!( assert_eq!(
guess_content_type(Path::new("app.js")), guess_content_type(Path::new("app.js")),
"application/javascript" "application/javascript"
); );
assert_eq!( assert_eq!(guess_content_type(Path::new("image.png")), "image/png");
guess_content_type(Path::new("image.png")),
"image/png"
);
assert_eq!( assert_eq!(
guess_content_type(Path::new("data.wasm")), guess_content_type(Path::new("data.wasm")),
"application/wasm" "application/wasm"

View file

@ -246,7 +246,13 @@ pub async fn handle(
Ok(()) Ok(())
} }
DexCommands::PlaceOrder { market, side, price, quantity, wallet } => { DexCommands::PlaceOrder {
market,
side,
price,
quantity,
wallet,
} => {
output::print_info("Placing limit order..."); output::print_info("Placing limit order...");
output::print_kv("Market", &market); output::print_kv("Market", &market);
output::print_kv("Side", &side); output::print_kv("Side", &side);
@ -257,7 +263,12 @@ pub async fn handle(
Ok(()) Ok(())
} }
DexCommands::MarketOrder { market, side, quantity, wallet } => { DexCommands::MarketOrder {
market,
side,
quantity,
wallet,
} => {
output::print_info("Placing market order..."); output::print_info("Placing market order...");
output::print_kv("Market", &market); output::print_kv("Market", &market);
output::print_kv("Side", &side); output::print_kv("Side", &side);
@ -275,7 +286,10 @@ pub async fn handle(
DexCommands::CancelAll { market, wallet } => { DexCommands::CancelAll { market, wallet } => {
let scope = market.unwrap_or_else(|| "all markets".to_string()); let scope = market.unwrap_or_else(|| "all markets".to_string());
output::print_info(&format!("Cancelling all orders in {} for {}", scope, wallet)); output::print_info(&format!(
"Cancelling all orders in {} for {}",
scope, wallet
));
output::print_success("3 orders cancelled"); output::print_success("3 orders cancelled");
Ok(()) Ok(())
} }
@ -317,7 +331,12 @@ pub async fn handle(
Ok(()) Ok(())
} }
DexCommands::AddLiquidity { pool_id, amount_a, amount_b, wallet } => { DexCommands::AddLiquidity {
pool_id,
amount_a,
amount_b,
wallet,
} => {
output::print_info("Adding liquidity..."); output::print_info("Adding liquidity...");
output::print_kv("Pool", &pool_id); output::print_kv("Pool", &pool_id);
output::print_kv("Amount A", &amount_a); output::print_kv("Amount A", &amount_a);
@ -327,7 +346,11 @@ pub async fn handle(
Ok(()) Ok(())
} }
DexCommands::RemoveLiquidity { pool_id, lp_amount, wallet } => { DexCommands::RemoveLiquidity {
pool_id,
lp_amount,
wallet,
} => {
output::print_info("Removing liquidity..."); output::print_info("Removing liquidity...");
output::print_kv("Pool", &pool_id); output::print_kv("Pool", &pool_id);
output::print_kv("LP Amount", &lp_amount); output::print_kv("LP Amount", &lp_amount);

View file

@ -169,11 +169,7 @@ pub enum ZkCommands {
} }
/// Handle ZK commands. /// Handle ZK commands.
pub async fn handle( pub async fn handle(_client: &RpcClient, command: ZkCommands, _format: OutputFormat) -> Result<()> {
_client: &RpcClient,
command: ZkCommands,
_format: OutputFormat,
) -> Result<()> {
match command { match command {
ZkCommands::Compile { circuit, output } => { ZkCommands::Compile { circuit, output } => {
output::print_info(&format!("Compiling circuit: {}", circuit.display())); output::print_info(&format!("Compiling circuit: {}", circuit.display()));
@ -211,8 +207,16 @@ pub async fn handle(
Ok(()) Ok(())
} }
ZkCommands::ProveGroth16 { circuit, witness, proving_key: _, output } => { ZkCommands::ProveGroth16 {
output::print_info(&format!("Generating Groth16 proof for circuit: {}", circuit)); circuit,
witness,
proving_key: _,
output,
} => {
output::print_info(&format!(
"Generating Groth16 proof for circuit: {}",
circuit
));
output::print_info(&format!("Witness: {}", witness.display())); output::print_info(&format!("Witness: {}", witness.display()));
output::print_info("Computing witness..."); output::print_info("Computing witness...");
output::print_info("Generating proof..."); output::print_info("Generating proof...");
@ -226,7 +230,11 @@ pub async fn handle(
Ok(()) Ok(())
} }
ZkCommands::ProvePlonk { circuit, witness, output } => { ZkCommands::ProvePlonk {
circuit,
witness,
output,
} => {
output::print_info(&format!("Generating PLONK proof for circuit: {}", circuit)); output::print_info(&format!("Generating PLONK proof for circuit: {}", circuit));
output::print_info(&format!("Witness: {}", witness.display())); output::print_info(&format!("Witness: {}", witness.display()));
output::print_info("Computing witness..."); output::print_info("Computing witness...");
@ -240,7 +248,11 @@ pub async fn handle(
Ok(()) Ok(())
} }
ZkCommands::ProveStark { circuit, witness, output } => { ZkCommands::ProveStark {
circuit,
witness,
output,
} => {
output::print_info(&format!("Generating STARK proof for circuit: {}", circuit)); output::print_info(&format!("Generating STARK proof for circuit: {}", circuit));
output::print_info(&format!("Witness: {}", witness.display())); output::print_info(&format!("Witness: {}", witness.display()));
output::print_info("Computing execution trace..."); output::print_info("Computing execution trace...");
@ -256,7 +268,11 @@ pub async fn handle(
Ok(()) Ok(())
} }
ZkCommands::Verify { proof, verification_key: _, public_inputs: _ } => { ZkCommands::Verify {
proof,
verification_key: _,
public_inputs: _,
} => {
output::print_info(&format!("Verifying proof: {}", proof.display())); output::print_info(&format!("Verifying proof: {}", proof.display()));
output::print_info("Loading proof..."); output::print_info("Loading proof...");
output::print_info("Verifying..."); output::print_info("Verifying...");
@ -265,8 +281,15 @@ pub async fn handle(
Ok(()) Ok(())
} }
ZkCommands::Setup { circuit, system, output } => { ZkCommands::Setup {
output::print_info(&format!("Generating {} keys for circuit: {}", system, circuit)); circuit,
system,
output,
} => {
output::print_info(&format!(
"Generating {} keys for circuit: {}",
system, circuit
));
output::print_info("This may take a while for large circuits..."); output::print_info("This may take a while for large circuits...");
output::print_info("Generating proving key..."); output::print_info("Generating proving key...");
output::print_info("Deriving verification key..."); output::print_info("Deriving verification key...");

View file

@ -469,7 +469,11 @@ enum DeployCommands {
output: Option<PathBuf>, output: Option<PathBuf>,
/// Hosting gateway URL /// Hosting gateway URL
#[arg(long, env = "SYNOR_HOSTING_URL", default_value = "http://127.0.0.1:8280")] #[arg(
long,
env = "SYNOR_HOSTING_URL",
default_value = "http://127.0.0.1:8280"
)]
gateway: String, gateway: String,
/// Skip running the build command /// Skip running the build command
@ -495,7 +499,11 @@ enum DeployCommands {
/// List deployments /// List deployments
List { List {
/// Hosting gateway URL /// Hosting gateway URL
#[arg(long, env = "SYNOR_HOSTING_URL", default_value = "http://127.0.0.1:8280")] #[arg(
long,
env = "SYNOR_HOSTING_URL",
default_value = "http://127.0.0.1:8280"
)]
gateway: String, gateway: String,
}, },
@ -505,7 +513,11 @@ enum DeployCommands {
name: String, name: String,
/// Hosting gateway URL /// Hosting gateway URL
#[arg(long, env = "SYNOR_HOSTING_URL", default_value = "http://127.0.0.1:8280")] #[arg(
long,
env = "SYNOR_HOSTING_URL",
default_value = "http://127.0.0.1:8280"
)]
gateway: String, gateway: String,
}, },
@ -515,7 +527,11 @@ enum DeployCommands {
name: String, name: String,
/// Hosting gateway URL /// Hosting gateway URL
#[arg(long, env = "SYNOR_HOSTING_URL", default_value = "http://127.0.0.1:8280")] #[arg(
long,
env = "SYNOR_HOSTING_URL",
default_value = "http://127.0.0.1:8280"
)]
gateway: String, gateway: String,
}, },
} }
@ -591,9 +607,11 @@ async fn main() {
gateway, gateway,
skip_build, skip_build,
} => commands::deploy::deploy(name, out_dir, &gateway, skip_build, output).await, } => commands::deploy::deploy(name, out_dir, &gateway, skip_build, output).await,
DeployCommands::Init { name, spa, output: out_dir } => { DeployCommands::Init {
commands::deploy::init(name, spa, out_dir, output) name,
} spa,
output: out_dir,
} => commands::deploy::init(name, spa, out_dir, output),
DeployCommands::List { gateway } => commands::deploy::list(&gateway, output).await, DeployCommands::List { gateway } => commands::deploy::list(&gateway, output).await,
DeployCommands::Delete { name, gateway } => { DeployCommands::Delete { name, gateway } => {
commands::deploy::delete(&name, &gateway, output).await commands::deploy::delete(&name, &gateway, output).await

View file

@ -676,7 +676,9 @@ async fn get_blocks(
} }
// Fetch blocks by blue score (most recent first) // Fetch blocks by blue score (most recent first)
let start_score = score.blue_score.saturating_sub((params.page.saturating_sub(1) * limit) as u64); let start_score = score
.blue_score
.saturating_sub((params.page.saturating_sub(1) * limit) as u64);
let blocks_data: Vec<serde_json::Value> = state let blocks_data: Vec<serde_json::Value> = state
.rpc_call("synor_getBlocksByBlueScore", (start_score, true)) .rpc_call("synor_getBlocksByBlueScore", (start_score, true))
.await .await
@ -697,17 +699,28 @@ async fn get_blocks(
parent_hashes: header parent_hashes: header
.get("parents") .get("parents")
.and_then(|p| p.as_array()) .and_then(|p| p.as_array())
.map(|a| a.iter().filter_map(|v| v.as_str().map(String::from)).collect()) .map(|a| {
a.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default(), .unwrap_or_default(),
timestamp, timestamp,
timestamp_human: format_timestamp(timestamp), timestamp_human: format_timestamp(timestamp),
bits: header.get("bits")?.as_u64()? as u32, bits: header.get("bits")?.as_u64()? as u32,
nonce: header.get("nonce")?.as_u64()?, nonce: header.get("nonce")?.as_u64()?,
daa_score: header.get("blueScore").and_then(|v| v.as_u64()).unwrap_or(0), daa_score: header
blue_score: header.get("blueScore").and_then(|v| v.as_u64()).unwrap_or(0), .get("blueScore")
.and_then(|v| v.as_u64())
.unwrap_or(0),
blue_score: header
.get("blueScore")
.and_then(|v| v.as_u64())
.unwrap_or(0),
blue_work: String::new(), blue_work: String::new(),
difficulty: 0.0, difficulty: 0.0,
transaction_count: b.get("transactions") transaction_count: b
.get("transactions")
.and_then(|t| t.as_array()) .and_then(|t| t.as_array())
.map(|a| a.len()) .map(|a| a.len())
.unwrap_or(0), .unwrap_or(0),
@ -1102,9 +1115,7 @@ async fn estimate_gas(
}; };
// Call the node's contract_estimateGas RPC method // Call the node's contract_estimateGas RPC method
let gas_used: u64 = state let gas_used: u64 = state.rpc_call("contract_estimateGas", rpc_request).await?;
.rpc_call("contract_estimateGas", rpc_request)
.await?;
// Calculate recommended gas limit with 20% safety margin // Calculate recommended gas limit with 20% safety margin
let gas_limit_recommended = ((gas_used as f64) * 1.2).ceil() as u64; let gas_limit_recommended = ((gas_used as f64) * 1.2).ceil() as u64;
@ -1494,8 +1505,7 @@ async fn main() -> anyhow::Result<()> {
let app = if let Some(ref static_dir) = config.static_dir { let app = if let Some(ref static_dir) = config.static_dir {
// Serve static files with SPA fallback (index.html for client-side routing) // Serve static files with SPA fallback (index.html for client-side routing)
let index_path = format!("{}/index.html", static_dir); let index_path = format!("{}/index.html", static_dir);
let serve_dir = ServeDir::new(static_dir) let serve_dir = ServeDir::new(static_dir).not_found_service(ServeFile::new(&index_path));
.not_found_service(ServeFile::new(&index_path));
api_router api_router
.fallback_service(serve_dir) .fallback_service(serve_dir)

View file

@ -684,10 +684,12 @@ mod tests {
#[test] #[test]
fn test_all_paths_are_distinct() { fn test_all_paths_are_distinct() {
let config = NodeConfig::for_network("mainnet").unwrap(); let config = NodeConfig::for_network("mainnet").unwrap();
let paths = [config.blocks_path(), let paths = [
config.blocks_path(),
config.chainstate_path(), config.chainstate_path(),
config.contracts_path(), config.contracts_path(),
config.keys_path()]; config.keys_path(),
];
for i in 0..paths.len() { for i in 0..paths.len() {
for j in (i + 1)..paths.len() { for j in (i + 1)..paths.len() {
@ -794,9 +796,11 @@ mod tests {
#[test] #[test]
fn test_with_mining_enabled() { fn test_with_mining_enabled() {
let config = NodeConfig::for_network("mainnet") let config = NodeConfig::for_network("mainnet").unwrap().with_mining(
.unwrap() true,
.with_mining(true, Some("synor:test_address".to_string()), 4); Some("synor:test_address".to_string()),
4,
);
assert!(config.mining.enabled); assert!(config.mining.enabled);
assert_eq!( assert_eq!(
@ -828,9 +832,10 @@ mod tests {
#[test] #[test]
fn test_with_p2p() { fn test_with_p2p() {
let seeds = vec!["seed1.example.com:30303".to_string()]; let seeds = vec!["seed1.example.com:30303".to_string()];
let config = NodeConfig::for_network("mainnet") let config =
.unwrap() NodeConfig::for_network("mainnet")
.with_p2p("0.0.0.0", 30303, seeds.clone()); .unwrap()
.with_p2p("0.0.0.0", 30303, seeds.clone());
assert_eq!(config.p2p.listen_addr, "0.0.0.0:30303"); assert_eq!(config.p2p.listen_addr, "0.0.0.0:30303");
assert_eq!(config.p2p.seeds, seeds); assert_eq!(config.p2p.seeds, seeds);
@ -1027,7 +1032,10 @@ mod tests {
let loaded = NodeConfig::load(&path).unwrap(); let loaded = NodeConfig::load(&path).unwrap();
assert_eq!(loaded.mining.enabled, config.mining.enabled); assert_eq!(loaded.mining.enabled, config.mining.enabled);
assert_eq!(loaded.mining.coinbase_address, config.mining.coinbase_address); assert_eq!(
loaded.mining.coinbase_address,
config.mining.coinbase_address
);
assert_eq!(loaded.mining.threads, config.mining.threads); assert_eq!(loaded.mining.threads, config.mining.threads);
assert_eq!(loaded.storage.cache_size_mb, config.storage.cache_size_mb); assert_eq!(loaded.storage.cache_size_mb, config.storage.cache_size_mb);
assert_eq!(loaded.logging.level, config.logging.level); assert_eq!(loaded.logging.level, config.logging.level);

View file

@ -425,11 +425,13 @@ mod tests {
#[test] #[test]
fn test_node_state_all_variants_are_distinct() { fn test_node_state_all_variants_are_distinct() {
let states = [NodeState::Starting, let states = [
NodeState::Starting,
NodeState::Syncing, NodeState::Syncing,
NodeState::Running, NodeState::Running,
NodeState::Stopping, NodeState::Stopping,
NodeState::Stopped]; NodeState::Stopped,
];
for i in 0..states.len() { for i in 0..states.len() {
for j in (i + 1)..states.len() { for j in (i + 1)..states.len() {
@ -605,7 +607,10 @@ mod tests {
.with_mining(true, Some("synor:test".to_string()), 4); .with_mining(true, Some("synor:test".to_string()), 4);
assert!(config.mining.enabled); assert!(config.mining.enabled);
assert_eq!(config.mining.coinbase_address, Some("synor:test".to_string())); assert_eq!(
config.mining.coinbase_address,
Some("synor:test".to_string())
);
assert_eq!(config.mining.threads, 4); assert_eq!(config.mining.threads, 4);
} }

View file

@ -12,9 +12,12 @@ use synor_mining::{
MinerCommand, MinerConfig, MinerEvent, MiningResult, MiningStats as CrateMiningStats, MinerCommand, MinerConfig, MinerEvent, MiningResult, MiningStats as CrateMiningStats,
TemplateTransaction, TemplateTransaction,
}; };
use synor_types::{Address, Amount, Block, BlockHeader, BlockId, BlueScore, Hash256, Network, Timestamp, Transaction, TxOutput};
use synor_types::block::BlockBody; use synor_types::block::BlockBody;
use synor_types::transaction::ScriptPubKey; use synor_types::transaction::ScriptPubKey;
use synor_types::{
Address, Amount, Block, BlockHeader, BlockId, BlueScore, Hash256, Network, Timestamp,
Transaction, TxOutput,
};
use crate::config::NodeConfig; use crate::config::NodeConfig;
use crate::services::{ConsensusService, MempoolService}; use crate::services::{ConsensusService, MempoolService};
@ -473,10 +476,7 @@ impl MinerService {
extra_data.extend_from_slice(&result.nonce.to_le_bytes()); extra_data.extend_from_slice(&result.nonce.to_le_bytes());
extra_data.extend_from_slice(&template.coinbase_data.extra_data); extra_data.extend_from_slice(&template.coinbase_data.extra_data);
let coinbase_tx = Transaction::coinbase( let coinbase_tx = Transaction::coinbase(vec![coinbase_output], extra_data);
vec![coinbase_output],
extra_data,
);
// Start with coinbase transaction // Start with coinbase transaction
let mut transactions = vec![coinbase_tx]; let mut transactions = vec![coinbase_tx];
@ -522,8 +522,7 @@ impl MinerService {
let block = Block { header, body }; let block = Block { header, body };
// Serialize with Borsh // Serialize with Borsh
borsh::to_vec(&block) borsh::to_vec(&block).map_err(|e| anyhow::anyhow!("Failed to serialize block: {}", e))
.map_err(|e| anyhow::anyhow!("Failed to serialize block: {}", e))
} }
/// Submits a mined block (for external submission via RPC). /// Submits a mined block (for external submission via RPC).

View file

@ -234,7 +234,10 @@ mod network_partition_tests {
// Node 0 should have fewer peers after isolation // Node 0 should have fewer peers after isolation
let isolated_peers = network.nodes[0].network().peer_count().await; let isolated_peers = network.nodes[0].network().peer_count().await;
info!(isolated_peers = isolated_peers, "Node 0 peers after isolation"); info!(
isolated_peers = isolated_peers,
"Node 0 peers after isolation"
);
assert!( assert!(
isolated_peers < initial_peer_counts[0] || initial_peer_counts[0] == 0, isolated_peers < initial_peer_counts[0] || initial_peer_counts[0] == 0,
@ -271,7 +274,10 @@ mod network_partition_tests {
// After healing, nodes should have peers // After healing, nodes should have peers
let total_peers = network.total_peer_count().await; let total_peers = network.total_peer_count().await;
info!(total_peers = total_peers, "Total peers after partition recovery"); info!(
total_peers = total_peers,
"Total peers after partition recovery"
);
// Consensus state should converge // Consensus state should converge
let consensus0 = network.nodes[0].consensus(); let consensus0 = network.nodes[0].consensus();
@ -287,7 +293,10 @@ mod network_partition_tests {
); );
// Both should have some consensus state // Both should have some consensus state
assert!(vsp0.is_some() || vsp1.is_some(), "At least one node should have VSP"); assert!(
vsp0.is_some() || vsp1.is_some(),
"At least one node should have VSP"
);
network.stop_all().await.unwrap(); network.stop_all().await.unwrap();
} }
@ -351,10 +360,12 @@ mod network_partition_tests {
// Record blue scores from each partition // Record blue scores from each partition
let scores_before: Vec<u64> = futures::future::join_all( let scores_before: Vec<u64> = futures::future::join_all(
network.nodes.iter().map(|n| async { network
n.consensus().current_blue_score().await .nodes
}) .iter()
).await; .map(|n| async { n.consensus().current_blue_score().await }),
)
.await;
info!(scores_before = ?scores_before, "Blue scores before healing"); info!(scores_before = ?scores_before, "Blue scores before healing");
@ -368,10 +379,12 @@ mod network_partition_tests {
// Blue scores should converge // Blue scores should converge
let scores_after: Vec<u64> = futures::future::join_all( let scores_after: Vec<u64> = futures::future::join_all(
network.nodes.iter().map(|n| async { network
n.consensus().current_blue_score().await .nodes
}) .iter()
).await; .map(|n| async { n.consensus().current_blue_score().await }),
)
.await;
info!(scores_after = ?scores_after, "Blue scores after healing"); info!(scores_after = ?scores_after, "Blue scores after healing");
@ -380,7 +393,9 @@ mod network_partition_tests {
assert!( assert!(
after >= before, after >= before,
"Node {} blue score should not decrease: {} -> {}", "Node {} blue score should not decrease: {} -> {}",
i, before, after i,
before,
after
); );
} }
@ -407,10 +422,7 @@ mod double_spend_tests {
let mempool = network.nodes[0].mempool(); let mempool = network.nodes[0].mempool();
let initial_size = mempool.size().await; let initial_size = mempool.size().await;
info!( info!(initial_mempool_size = initial_size, "Initial mempool state");
initial_mempool_size = initial_size,
"Initial mempool state"
);
// In production, we would: // In production, we would:
// 1. Create two transactions spending the same UTXO // 1. Create two transactions spending the same UTXO
@ -420,7 +432,7 @@ mod double_spend_tests {
// For now, verify mempool API is working // For now, verify mempool API is working
// and handles empty/invalid data gracefully // and handles empty/invalid data gracefully
let _invalid_tx = vec![0u8; 50]; // Invalid transaction bytes (for future use) let _invalid_tx = vec![0u8; 50]; // Invalid transaction bytes (for future use)
// Submitting invalid tx should fail gracefully // Submitting invalid tx should fail gracefully
// Mempool should maintain integrity // Mempool should maintain integrity
let final_size = mempool.size().await; let final_size = mempool.size().await;
@ -653,7 +665,11 @@ mod invalid_block_rejection_tests {
// All valid tips should have known parents in the DAG // All valid tips should have known parents in the DAG
for tip in &tips { for tip in &tips {
let has_parents = consensus.get_block_info(tip).await.map(|info| !info.parents.is_empty()).unwrap_or(false); let has_parents = consensus
.get_block_info(tip)
.await
.map(|info| !info.parents.is_empty())
.unwrap_or(false);
info!( info!(
block = hex::encode(&tip[..8]), block = hex::encode(&tip[..8]),
has_parents = has_parents, has_parents = has_parents,
@ -687,16 +703,22 @@ mod sybil_attack_tests {
// Track blue scores - honest nodes should maintain correct view // Track blue scores - honest nodes should maintain correct view
let honest_scores: Vec<u64> = futures::future::join_all( let honest_scores: Vec<u64> = futures::future::join_all(
network.nodes.iter().take(3).map(|n| async { network
n.consensus().current_blue_score().await .nodes
}) .iter()
).await; .take(3)
.map(|n| async { n.consensus().current_blue_score().await }),
)
.await;
let sybil_scores: Vec<u64> = futures::future::join_all( let sybil_scores: Vec<u64> = futures::future::join_all(
network.nodes.iter().skip(3).map(|n| async { network
n.consensus().current_blue_score().await .nodes
}) .iter()
).await; .skip(3)
.map(|n| async { n.consensus().current_blue_score().await }),
)
.await;
info!( info!(
honest_scores = ?honest_scores, honest_scores = ?honest_scores,
@ -805,7 +827,10 @@ mod eclipse_attack_tests {
sleep(Duration::from_secs(1)).await; sleep(Duration::from_secs(1)).await;
let after_eclipse_peers = victim_network.peer_count().await; let after_eclipse_peers = victim_network.peer_count().await;
info!(after_eclipse_peers = after_eclipse_peers, "Peers after eclipse attempt"); info!(
after_eclipse_peers = after_eclipse_peers,
"Peers after eclipse attempt"
);
// In a real implementation, the node would: // In a real implementation, the node would:
// 1. Detect low peer diversity // 1. Detect low peer diversity
@ -863,7 +888,10 @@ mod eclipse_attack_tests {
sleep(Duration::from_secs(1)).await; sleep(Duration::from_secs(1)).await;
let eclipsed_peers = network.nodes[0].network().peer_count().await; let eclipsed_peers = network.nodes[0].network().peer_count().await;
info!(eclipsed_peers = eclipsed_peers, "Node 0 peers during eclipse"); info!(
eclipsed_peers = eclipsed_peers,
"Node 0 peers during eclipse"
);
// Manually reconnect (simulating recovery mechanism) // Manually reconnect (simulating recovery mechanism)
network.connect_nodes(0, 1).await.unwrap(); network.connect_nodes(0, 1).await.unwrap();
@ -871,7 +899,10 @@ mod eclipse_attack_tests {
sleep(Duration::from_secs(2)).await; sleep(Duration::from_secs(2)).await;
let recovered_peers = network.nodes[0].network().peer_count().await; let recovered_peers = network.nodes[0].network().peer_count().await;
info!(recovered_peers = recovered_peers, "Node 0 peers after recovery"); info!(
recovered_peers = recovered_peers,
"Node 0 peers after recovery"
);
// Should have reconnected // Should have reconnected
assert!( assert!(
@ -1038,10 +1069,12 @@ mod dag_reorg_tests {
// Record divergent states // Record divergent states
let states_before: Vec<u64> = futures::future::join_all( let states_before: Vec<u64> = futures::future::join_all(
network.nodes.iter().map(|n| async { network
n.consensus().current_blue_score().await .nodes
}) .iter()
).await; .map(|n| async { n.consensus().current_blue_score().await }),
)
.await;
info!(states_before = ?states_before, "States before reconnection"); info!(states_before = ?states_before, "States before reconnection");
@ -1052,10 +1085,12 @@ mod dag_reorg_tests {
// Get converged states // Get converged states
let states_after: Vec<u64> = futures::future::join_all( let states_after: Vec<u64> = futures::future::join_all(
network.nodes.iter().map(|n| async { network
n.consensus().current_blue_score().await .nodes
}) .iter()
).await; .map(|n| async { n.consensus().current_blue_score().await }),
)
.await;
info!(states_after = ?states_after, "States after reconnection"); info!(states_after = ?states_after, "States after reconnection");
@ -1064,7 +1099,9 @@ mod dag_reorg_tests {
assert!( assert!(
after >= before, after >= before,
"Node {} blue score regression: {} -> {}", "Node {} blue score regression: {} -> {}",
i, before, after i,
before,
after
); );
} }
@ -1192,10 +1229,12 @@ mod parallel_blocks_tests {
// Collect blue scores from all nodes // Collect blue scores from all nodes
let blue_scores: Vec<u64> = futures::future::join_all( let blue_scores: Vec<u64> = futures::future::join_all(
network.nodes.iter().map(|n| async { network
n.consensus().current_blue_score().await .nodes
}) .iter()
).await; .map(|n| async { n.consensus().current_blue_score().await }),
)
.await;
info!(blue_scores = ?blue_scores, "Blue scores across nodes"); info!(blue_scores = ?blue_scores, "Blue scores across nodes");
@ -1206,7 +1245,8 @@ mod parallel_blocks_tests {
assert!( assert!(
max_score - min_score <= 2, max_score - min_score <= 2,
"Blue scores should be consistent: {} - {} > 2", "Blue scores should be consistent: {} - {} > 2",
max_score, min_score max_score,
min_score
); );
network.stop_all().await.unwrap(); network.stop_all().await.unwrap();
@ -1264,10 +1304,12 @@ mod parallel_blocks_tests {
// Get selected chains from all nodes // Get selected chains from all nodes
let chains: Vec<Vec<[u8; 32]>> = futures::future::join_all( let chains: Vec<Vec<[u8; 32]>> = futures::future::join_all(
network.nodes.iter().map(|n| async { network
n.consensus().get_selected_chain(10).await .nodes
}) .iter()
).await; .map(|n| async { n.consensus().get_selected_chain(10).await }),
)
.await;
info!( info!(
chain_lengths = ?chains.iter().map(|c| c.len()).collect::<Vec<_>>(), chain_lengths = ?chains.iter().map(|c| c.len()).collect::<Vec<_>>(),
@ -1276,7 +1318,8 @@ mod parallel_blocks_tests {
// All nodes should have the same selected chain (after sync) // All nodes should have the same selected chain (after sync)
// Check that genesis (first block) matches // Check that genesis (first block) matches
let genesis_blocks: Vec<_> = chains.iter() let genesis_blocks: Vec<_> = chains
.iter()
.filter(|c| !c.is_empty()) .filter(|c| !c.is_empty())
.map(|c| c[0]) .map(|c| c[0])
.collect(); .collect();
@ -1353,10 +1396,13 @@ mod bft_threshold_tests {
// Honest nodes (0, 1, 2) should maintain consensus // Honest nodes (0, 1, 2) should maintain consensus
let honest_scores: Vec<u64> = futures::future::join_all( let honest_scores: Vec<u64> = futures::future::join_all(
network.nodes.iter().take(3).map(|n| async { network
n.consensus().current_blue_score().await .nodes
}) .iter()
).await; .take(3)
.map(|n| async { n.consensus().current_blue_score().await }),
)
.await;
info!(honest_scores = ?honest_scores, "Honest node blue scores"); info!(honest_scores = ?honest_scores, "Honest node blue scores");
@ -1399,10 +1445,7 @@ mod bft_threshold_tests {
// Blue score should not decrease // Blue score should not decrease
let final_blue = network.nodes[0].consensus().current_blue_score().await; let final_blue = network.nodes[0].consensus().current_blue_score().await;
assert!( assert!(final_blue >= initial_blue, "Blue score should not decrease");
final_blue >= initial_blue,
"Blue score should not decrease"
);
// Stop remaining nodes // Stop remaining nodes
for node in network.nodes.iter().take(3) { for node in network.nodes.iter().take(3) {
@ -1615,10 +1658,12 @@ mod integration_tests {
// Record initial state // Record initial state
let initial_scores: Vec<u64> = futures::future::join_all( let initial_scores: Vec<u64> = futures::future::join_all(
network.nodes.iter().map(|n| async { network
n.consensus().current_blue_score().await .nodes
}) .iter()
).await; .map(|n| async { n.consensus().current_blue_score().await }),
)
.await;
info!(initial_scores = ?initial_scores, "Initial blue scores"); info!(initial_scores = ?initial_scores, "Initial blue scores");
info!("Phase 2: Simulate 2 Byzantine nodes (partition)"); info!("Phase 2: Simulate 2 Byzantine nodes (partition)");
@ -1640,18 +1685,24 @@ mod integration_tests {
info!("Phase 4: Verify convergence"); info!("Phase 4: Verify convergence");
let final_scores: Vec<u64> = futures::future::join_all( let final_scores: Vec<u64> = futures::future::join_all(
network.nodes.iter().map(|n| async { network
n.consensus().current_blue_score().await .nodes
}) .iter()
).await; .map(|n| async { n.consensus().current_blue_score().await }),
)
.await;
info!(final_scores = ?final_scores, "Final blue scores"); info!(final_scores = ?final_scores, "Final blue scores");
// All nodes should have non-decreasing blue scores // All nodes should have non-decreasing blue scores
for (i, (&initial, &final_score)) in initial_scores.iter().zip(final_scores.iter()).enumerate() { for (i, (&initial, &final_score)) in
initial_scores.iter().zip(final_scores.iter()).enumerate()
{
assert!( assert!(
final_score >= initial, final_score >= initial,
"Node {} score regression: {} -> {}", "Node {} score regression: {} -> {}",
i, initial, final_score i,
initial,
final_score
); );
} }

View file

@ -17,8 +17,8 @@
//! 3. Vault contract verifies proof and unlocks original tokens //! 3. Vault contract verifies proof and unlocks original tokens
use crate::{ use crate::{
AssetId, Bridge, BridgeAddress, BridgeError, BridgeResult, BridgeTransfer, ChainType, TransferId, TransferManager, TransferStatus, VaultManager, AssetId, Bridge, BridgeAddress, BridgeError, BridgeResult, BridgeTransfer, ChainType,
ETH_MIN_CONFIRMATIONS, TransferId, TransferManager, TransferStatus, VaultManager, ETH_MIN_CONFIRMATIONS,
}; };
use alloy_primitives::{Address, B256, U256}; use alloy_primitives::{Address, B256, U256};
use alloy_sol_types::sol; use alloy_sol_types::sol;
@ -281,9 +281,9 @@ impl EthereumBridge {
// Check for replay // Check for replay
let event_hash = event.hash(); let event_hash = event.hash();
if self.processed_events.read().contains_key(&event_hash) { if self.processed_events.read().contains_key(&event_hash) {
return Err(BridgeError::TransferAlreadyExists( return Err(BridgeError::TransferAlreadyExists(hex::encode(
hex::encode(event_hash.as_slice()), event_hash.as_slice(),
)); )));
} }
// Verify token is supported // Verify token is supported
@ -393,18 +393,15 @@ impl EthereumBridge {
// Collect matching transfer IDs first // Collect matching transfer IDs first
let matching_transfer_id = { let matching_transfer_id = {
let transfers = self.transfers.read(); let transfers = self.transfers.read();
transfers transfers.pending_transfers().iter().find_map(|transfer| {
.pending_transfers() transfer.source_tx_hash.as_ref().and_then(|tx_hash| {
.iter() if tx_hash.as_slice() == event_hash.as_slice() {
.find_map(|transfer| { Some(transfer.id.clone())
transfer.source_tx_hash.as_ref().and_then(|tx_hash| { } else {
if tx_hash.as_slice() == event_hash.as_slice() { None
Some(transfer.id.clone()) }
} else {
None
}
})
}) })
})
}; };
// Now update the transfer if found // Now update the transfer if found
@ -457,7 +454,9 @@ impl EthereumBridge {
.map_err(|e| BridgeError::InvalidAddress(e.to_string()))?; .map_err(|e| BridgeError::InvalidAddress(e.to_string()))?;
if bytes.len() != 20 { if bytes.len() != 20 {
return Err(BridgeError::InvalidAddress("invalid address length".to_string())); return Err(BridgeError::InvalidAddress(
"invalid address length".to_string(),
));
} }
Address::from_slice(&bytes) Address::from_slice(&bytes)
}; };
@ -801,7 +800,10 @@ mod tests {
wrapped.mint(1000); wrapped.mint(1000);
let result = wrapped.burn(1500); let result = wrapped.burn(1500);
assert!(matches!(result, Err(BridgeError::InsufficientBalance { .. }))); assert!(matches!(
result,
Err(BridgeError::InsufficientBalance { .. })
));
} }
#[test] #[test]
@ -896,7 +898,9 @@ mod tests {
let current_time = 1700000000; let current_time = 1700000000;
let event = create_lock_event(0); let event = create_lock_event(0);
bridge.process_lock_event(event.clone(), current_time).unwrap(); bridge
.process_lock_event(event.clone(), current_time)
.unwrap();
let result = bridge.process_lock_event(event, current_time + 100); let result = bridge.process_lock_event(event, current_time + 100);
assert!(matches!(result, Err(BridgeError::TransferAlreadyExists(_)))); assert!(matches!(result, Err(BridgeError::TransferAlreadyExists(_))));
@ -949,8 +953,12 @@ mod tests {
let event_hash = B256::from([0x11; 32]); let event_hash = B256::from([0x11; 32]);
let unauthorized_relayer = Address::from([0x99; 20]); let unauthorized_relayer = Address::from([0x99; 20]);
let result = bridge.submit_relayer_signature(event_hash, unauthorized_relayer, vec![0x00; 65]); let result =
assert!(matches!(result, Err(BridgeError::SignatureVerificationFailed(_)))); bridge.submit_relayer_signature(event_hash, unauthorized_relayer, vec![0x00; 65]);
assert!(matches!(
result,
Err(BridgeError::SignatureVerificationFailed(_))
));
} }
#[test] #[test]
@ -964,7 +972,9 @@ mod tests {
}); });
let event_hash = B256::from([0x11; 32]); let event_hash = B256::from([0x11; 32]);
let result = bridge.submit_relayer_signature(event_hash, relayer, vec![0x00; 65]).unwrap(); let result = bridge
.submit_relayer_signature(event_hash, relayer, vec![0x00; 65])
.unwrap();
assert!(result); assert!(result);
} }
@ -983,7 +993,9 @@ mod tests {
.update_confirmations(&transfer_id, 12, current_time + 100) .update_confirmations(&transfer_id, 12, current_time + 100)
.unwrap(); .unwrap();
bridge.mint_wrapped_tokens(&transfer_id, current_time + 200).unwrap(); bridge
.mint_wrapped_tokens(&transfer_id, current_time + 200)
.unwrap();
let wrapped = bridge.get_wrapped_token(Address::ZERO).unwrap(); let wrapped = bridge.get_wrapped_token(Address::ZERO).unwrap();
assert_eq!(wrapped.total_supply, 1000); assert_eq!(wrapped.total_supply, 1000);
@ -1022,13 +1034,7 @@ mod tests {
let asset = AssetId::wrapped(&AssetId::eth()); let asset = AssetId::wrapped(&AssetId::eth());
let transfer_id = bridge let transfer_id = bridge
.initiate_burn( .initiate_burn(asset, 1000, test_recipient(), test_sender(), current_time)
asset,
1000,
test_recipient(),
test_sender(),
current_time,
)
.unwrap(); .unwrap();
let transfers = bridge.transfers.read(); let transfers = bridge.transfers.read();
@ -1053,15 +1059,13 @@ mod tests {
drop(wrapped_tokens); drop(wrapped_tokens);
let asset = AssetId::wrapped(&AssetId::eth()); let asset = AssetId::wrapped(&AssetId::eth());
let result = bridge.initiate_burn( let result =
asset, bridge.initiate_burn(asset, 1000, test_recipient(), test_sender(), current_time);
1000,
test_recipient(),
test_sender(),
current_time,
);
assert!(matches!(result, Err(BridgeError::InsufficientBalance { .. }))); assert!(matches!(
result,
Err(BridgeError::InsufficientBalance { .. })
));
} }
#[test] #[test]
@ -1069,13 +1073,7 @@ mod tests {
let bridge = EthereumBridge::new(EthereumBridgeConfig::default()); let bridge = EthereumBridge::new(EthereumBridgeConfig::default());
let asset = AssetId::wrapped(&AssetId::eth()); let asset = AssetId::wrapped(&AssetId::eth());
let result = bridge.initiate_burn( let result = bridge.initiate_burn(asset, 1000, test_recipient(), test_sender(), 0);
asset,
1000,
test_recipient(),
test_sender(),
0,
);
assert!(matches!(result, Err(BridgeError::AssetNotSupported(_)))); assert!(matches!(result, Err(BridgeError::AssetNotSupported(_))));
} }

View file

@ -79,7 +79,9 @@ pub const ETH_MIN_CONFIRMATIONS: u64 = 12;
pub const BTC_MIN_CONFIRMATIONS: u64 = 6; pub const BTC_MIN_CONFIRMATIONS: u64 = 6;
/// Bridge chain identifier /// Bridge chain identifier
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize)] #[derive(
Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
)]
pub enum ChainType { pub enum ChainType {
/// Synor native chain /// Synor native chain
Synor, Synor,
@ -128,7 +130,9 @@ impl fmt::Display for ChainType {
} }
/// Asset identifier across chains /// Asset identifier across chains
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize)] #[derive(
Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
)]
pub struct AssetId { pub struct AssetId {
/// Chain where the asset originates /// Chain where the asset originates
pub chain: ChainType, pub chain: ChainType,
@ -199,7 +203,9 @@ impl fmt::Display for AssetId {
} }
/// Bridge address (unified format for cross-chain addresses) /// Bridge address (unified format for cross-chain addresses)
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize)] #[derive(
Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
)]
pub struct BridgeAddress { pub struct BridgeAddress {
/// Chain type /// Chain type
pub chain: ChainType, pub chain: ChainType,

View file

@ -14,7 +14,9 @@ use std::collections::HashMap;
use std::fmt; use std::fmt;
/// Unique transfer identifier /// Unique transfer identifier
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize)] #[derive(
Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
)]
pub struct TransferId(pub String); pub struct TransferId(pub String);
impl TransferId { impl TransferId {
@ -48,7 +50,9 @@ impl fmt::Display for TransferId {
} }
/// Transfer direction /// Transfer direction
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, BorshSerialize, BorshDeserialize)] #[derive(
Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
)]
pub enum TransferDirection { pub enum TransferDirection {
/// From external chain to Synor (Lock → Mint) /// From external chain to Synor (Lock → Mint)
Inbound, Inbound,
@ -66,7 +70,9 @@ impl fmt::Display for TransferDirection {
} }
/// Transfer status /// Transfer status
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, BorshSerialize, BorshDeserialize)] #[derive(
Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
)]
pub enum TransferStatus { pub enum TransferStatus {
/// Transfer initiated, awaiting lock confirmation /// Transfer initiated, awaiting lock confirmation
Pending, Pending,
@ -1030,7 +1036,10 @@ mod tests {
transfer.fail("Proof verification failed", current_time + 50); transfer.fail("Proof verification failed", current_time + 50);
assert_eq!(transfer.status, TransferStatus::Failed); assert_eq!(transfer.status, TransferStatus::Failed);
assert_eq!(transfer.error, Some("Proof verification failed".to_string())); assert_eq!(
transfer.error,
Some("Proof verification failed".to_string())
);
} }
#[test] #[test]
@ -1331,8 +1340,12 @@ mod tests {
assert_eq!(manager.pending_transfers().len(), 1); assert_eq!(manager.pending_transfers().len(), 1);
manager.confirm_lock(&id, vec![0x11; 32], 100, current_time + 10).unwrap(); manager
manager.update_confirmations(&id, 12, current_time + 120).unwrap(); .confirm_lock(&id, vec![0x11; 32], 100, current_time + 10)
.unwrap();
manager
.update_confirmations(&id, 12, current_time + 120)
.unwrap();
assert_eq!(manager.pending_transfers().len(), 0); assert_eq!(manager.pending_transfers().len(), 0);
} }
@ -1356,8 +1369,12 @@ mod tests {
assert_eq!(manager.ready_for_confirmation().len(), 0); assert_eq!(manager.ready_for_confirmation().len(), 0);
manager.confirm_lock(&id, vec![0x11; 32], 100, current_time + 10).unwrap(); manager
manager.update_confirmations(&id, 12, current_time + 120).unwrap(); .confirm_lock(&id, vec![0x11; 32], 100, current_time + 10)
.unwrap();
manager
.update_confirmations(&id, 12, current_time + 120)
.unwrap();
assert_eq!(manager.ready_for_confirmation().len(), 1); assert_eq!(manager.ready_for_confirmation().len(), 1);
} }
@ -1410,7 +1427,9 @@ mod tests {
) )
.unwrap(); .unwrap();
manager.fail_transfer(&id, "Verification failed", current_time + 50).unwrap(); manager
.fail_transfer(&id, "Verification failed", current_time + 50)
.unwrap();
let transfer = manager.get(&id).unwrap(); let transfer = manager.get(&id).unwrap();
assert_eq!(transfer.status, TransferStatus::Failed); assert_eq!(transfer.status, TransferStatus::Failed);
@ -1452,9 +1471,15 @@ mod tests {
) )
.unwrap(); .unwrap();
manager.confirm_lock(&id1, vec![0x11; 32], 100, current_time).unwrap(); manager
manager.update_confirmations(&id1, 12, current_time).unwrap(); .confirm_lock(&id1, vec![0x11; 32], 100, current_time)
manager.confirm_mint(&id1, vec![0x22; 32], current_time).unwrap(); .unwrap();
manager
.update_confirmations(&id1, 12, current_time)
.unwrap();
manager
.confirm_mint(&id1, vec![0x22; 32], current_time)
.unwrap();
let stats = manager.stats(); let stats = manager.stats();
assert_eq!(stats.total_count, 2); assert_eq!(stats.total_count, 2);
@ -1484,19 +1509,27 @@ mod tests {
let transfer = manager.get(&id).unwrap(); let transfer = manager.get(&id).unwrap();
assert_eq!(transfer.status, TransferStatus::Pending); assert_eq!(transfer.status, TransferStatus::Pending);
manager.confirm_lock(&id, vec![0x11; 32], 100, current_time + 60).unwrap(); manager
.confirm_lock(&id, vec![0x11; 32], 100, current_time + 60)
.unwrap();
let transfer = manager.get(&id).unwrap(); let transfer = manager.get(&id).unwrap();
assert_eq!(transfer.status, TransferStatus::Locked); assert_eq!(transfer.status, TransferStatus::Locked);
manager.update_confirmations(&id, 6, current_time + 120).unwrap(); manager
.update_confirmations(&id, 6, current_time + 120)
.unwrap();
let transfer = manager.get(&id).unwrap(); let transfer = manager.get(&id).unwrap();
assert_eq!(transfer.status, TransferStatus::Locked); assert_eq!(transfer.status, TransferStatus::Locked);
manager.update_confirmations(&id, 12, current_time + 180).unwrap(); manager
.update_confirmations(&id, 12, current_time + 180)
.unwrap();
let transfer = manager.get(&id).unwrap(); let transfer = manager.get(&id).unwrap();
assert_eq!(transfer.status, TransferStatus::Confirmed); assert_eq!(transfer.status, TransferStatus::Confirmed);
manager.confirm_mint(&id, vec![0x22; 32], current_time + 240).unwrap(); manager
.confirm_mint(&id, vec![0x22; 32], current_time + 240)
.unwrap();
let transfer = manager.get(&id).unwrap(); let transfer = manager.get(&id).unwrap();
assert_eq!(transfer.status, TransferStatus::Completed); assert_eq!(transfer.status, TransferStatus::Completed);
} }
@ -1521,15 +1554,21 @@ mod tests {
let transfer = manager.get(&id).unwrap(); let transfer = manager.get(&id).unwrap();
assert_eq!(transfer.status, TransferStatus::Pending); assert_eq!(transfer.status, TransferStatus::Pending);
manager.confirm_lock(&id, vec![0x11; 32], 100, current_time + 60).unwrap(); manager
.confirm_lock(&id, vec![0x11; 32], 100, current_time + 60)
.unwrap();
let transfer = manager.get(&id).unwrap(); let transfer = manager.get(&id).unwrap();
assert_eq!(transfer.status, TransferStatus::Locked); assert_eq!(transfer.status, TransferStatus::Locked);
manager.update_confirmations(&id, 6, current_time + 120).unwrap(); manager
.update_confirmations(&id, 6, current_time + 120)
.unwrap();
let transfer = manager.get(&id).unwrap(); let transfer = manager.get(&id).unwrap();
assert_eq!(transfer.status, TransferStatus::Confirmed); assert_eq!(transfer.status, TransferStatus::Confirmed);
manager.confirm_unlock(&id, vec![0x33; 32], current_time + 180).unwrap(); manager
.confirm_unlock(&id, vec![0x33; 32], current_time + 180)
.unwrap();
let transfer = manager.get(&id).unwrap(); let transfer = manager.get(&id).unwrap();
assert_eq!(transfer.status, TransferStatus::Completed); assert_eq!(transfer.status, TransferStatus::Completed);
} }

View file

@ -12,7 +12,9 @@ use std::collections::HashMap;
use std::fmt; use std::fmt;
/// Unique vault identifier /// Unique vault identifier
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize)] #[derive(
Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
)]
pub struct VaultId(pub String); pub struct VaultId(pub String);
impl VaultId { impl VaultId {
@ -198,13 +200,7 @@ impl Vault {
return Err(BridgeError::TransferAlreadyExists(lock_id)); return Err(BridgeError::TransferAlreadyExists(lock_id));
} }
let locked = LockedAsset::new( let locked = LockedAsset::new(self.asset.clone(), amount, owner, recipient, current_time);
self.asset.clone(),
amount,
owner,
recipient,
current_time,
);
self.locked_assets.insert(lock_id, locked); self.locked_assets.insert(lock_id, locked);
self.total_locked += amount; self.total_locked += amount;
@ -283,7 +279,10 @@ impl Vault {
} }
/// Get expired locked assets /// Get expired locked assets
pub fn expired_locked(&self, current_time: u64) -> impl Iterator<Item = (&String, &LockedAsset)> { pub fn expired_locked(
&self,
current_time: u64,
) -> impl Iterator<Item = (&String, &LockedAsset)> {
self.locked_assets self.locked_assets
.iter() .iter()
.filter(move |(_, l)| !l.released && l.is_expired(current_time)) .filter(move |(_, l)| !l.released && l.is_expired(current_time))
@ -511,14 +510,8 @@ mod tests {
#[test] #[test]
fn test_locked_asset_expiry() { fn test_locked_asset_expiry() {
let locked = LockedAsset::new( let locked = LockedAsset::new(AssetId::eth(), 1000, test_owner(), test_recipient(), 1000)
AssetId::eth(), .with_expiry(2000);
1000,
test_owner(),
test_recipient(),
1000,
)
.with_expiry(2000);
assert!(!locked.is_expired(1500)); assert!(!locked.is_expired(1500));
assert!(locked.is_expired(2000)); assert!(locked.is_expired(2000));
@ -624,11 +617,7 @@ mod tests {
#[test] #[test]
fn test_lock_unlock() { fn test_lock_unlock() {
let mut vault = Vault::new( let mut vault = Vault::new(VaultId::new("test"), ChainType::Ethereum, AssetId::eth());
VaultId::new("test"),
ChainType::Ethereum,
AssetId::eth(),
);
let current_time = 1700000000; let current_time = 1700000000;
@ -654,20 +643,28 @@ mod tests {
); );
let current_time = 1700000000; let current_time = 1700000000;
vault.lock("lock-1", 1000, test_owner(), test_recipient(), current_time).unwrap(); vault
vault.lock("lock-2", 2000, test_owner(), test_recipient(), current_time).unwrap(); .lock("lock-1", 1000, test_owner(), test_recipient(), current_time)
vault.lock("lock-3", 500, test_owner_alt(), test_recipient(), current_time).unwrap(); .unwrap();
vault
.lock("lock-2", 2000, test_owner(), test_recipient(), current_time)
.unwrap();
vault
.lock(
"lock-3",
500,
test_owner_alt(),
test_recipient(),
current_time,
)
.unwrap();
assert_eq!(vault.total_locked, 3500); assert_eq!(vault.total_locked, 3500);
} }
#[test] #[test]
fn test_duplicate_lock() { fn test_duplicate_lock() {
let mut vault = Vault::new( let mut vault = Vault::new(VaultId::new("test"), ChainType::Ethereum, AssetId::eth());
VaultId::new("test"),
ChainType::Ethereum,
AssetId::eth(),
);
vault vault
.lock("lock1", 1000, test_owner(), test_recipient(), 0) .lock("lock1", 1000, test_owner(), test_recipient(), 0)
@ -697,20 +694,21 @@ mod tests {
AssetId::eth(), AssetId::eth(),
); );
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap(); vault
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
.unwrap();
vault.unlock("lock-1").unwrap(); vault.unlock("lock-1").unwrap();
let result = vault.unlock("lock-1"); let result = vault.unlock("lock-1");
assert!(matches!(result, Err(BridgeError::TransferAlreadyCompleted(_)))); assert!(matches!(
result,
Err(BridgeError::TransferAlreadyCompleted(_))
));
} }
#[test] #[test]
fn test_vault_pause() { fn test_vault_pause() {
let mut vault = Vault::new( let mut vault = Vault::new(VaultId::new("test"), ChainType::Ethereum, AssetId::eth());
VaultId::new("test"),
ChainType::Ethereum,
AssetId::eth(),
);
vault.pause(); vault.pause();
@ -730,7 +728,9 @@ mod tests {
vault.resume(); vault.resume();
assert_eq!(vault.state, VaultState::Active); assert_eq!(vault.state, VaultState::Active);
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap(); vault
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
.unwrap();
} }
#[test] #[test]
@ -750,12 +750,8 @@ mod tests {
#[test] #[test]
fn test_daily_limit() { fn test_daily_limit() {
let mut vault = Vault::new( let mut vault = Vault::new(VaultId::new("test"), ChainType::Ethereum, AssetId::eth())
VaultId::new("test"), .with_daily_limit(1000);
ChainType::Ethereum,
AssetId::eth(),
)
.with_daily_limit(1000);
let current_time = 86400 * 100; let current_time = 86400 * 100;
@ -781,8 +777,24 @@ mod tests {
); );
let current_time = 0; let current_time = 0;
vault.lock("lock-1", 1000000000, test_owner(), test_recipient(), current_time).unwrap(); vault
vault.lock("lock-2", 1000000000, test_owner(), test_recipient(), current_time).unwrap(); .lock(
"lock-1",
1000000000,
test_owner(),
test_recipient(),
current_time,
)
.unwrap();
vault
.lock(
"lock-2",
1000000000,
test_owner(),
test_recipient(),
current_time,
)
.unwrap();
assert_eq!(vault.total_locked, 2000000000); assert_eq!(vault.total_locked, 2000000000);
} }
@ -795,7 +807,9 @@ mod tests {
AssetId::eth(), AssetId::eth(),
); );
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap(); vault
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
.unwrap();
assert!(vault.get_locked("lock-1").is_some()); assert!(vault.get_locked("lock-1").is_some());
assert!(vault.get_locked("nonexistent").is_none()); assert!(vault.get_locked("nonexistent").is_none());
@ -809,8 +823,12 @@ mod tests {
AssetId::eth(), AssetId::eth(),
); );
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap(); vault
vault.lock("lock-2", 2000, test_owner(), test_recipient(), 0).unwrap(); .lock("lock-1", 1000, test_owner(), test_recipient(), 0)
.unwrap();
vault
.lock("lock-2", 2000, test_owner(), test_recipient(), 0)
.unwrap();
let all: Vec<_> = vault.all_locked().collect(); let all: Vec<_> = vault.all_locked().collect();
assert_eq!(all.len(), 2); assert_eq!(all.len(), 2);
@ -824,8 +842,12 @@ mod tests {
AssetId::eth(), AssetId::eth(),
); );
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap(); vault
vault.lock("lock-2", 2000, test_owner(), test_recipient(), 0).unwrap(); .lock("lock-1", 1000, test_owner(), test_recipient(), 0)
.unwrap();
vault
.lock("lock-2", 2000, test_owner(), test_recipient(), 0)
.unwrap();
vault.unlock("lock-1").unwrap(); vault.unlock("lock-1").unwrap();
let active: Vec<_> = vault.active_locked().collect(); let active: Vec<_> = vault.active_locked().collect();
@ -858,7 +880,9 @@ mod tests {
assert!(manager.find_vault(&ChainType::Ethereum, &eth).is_some()); assert!(manager.find_vault(&ChainType::Ethereum, &eth).is_some());
let vault = manager.get_or_create_vault(ChainType::Ethereum, eth.clone()); let vault = manager.get_or_create_vault(ChainType::Ethereum, eth.clone());
vault.lock("lock1", 100, test_owner(), test_recipient(), 0).unwrap(); vault
.lock("lock1", 100, test_owner(), test_recipient(), 0)
.unwrap();
assert_eq!(manager.total_locked(), 100); assert_eq!(manager.total_locked(), 100);
} }
@ -881,7 +905,9 @@ mod tests {
{ {
let vault = manager.get_vault_mut(&vault_id).unwrap(); let vault = manager.get_vault_mut(&vault_id).unwrap();
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap(); vault
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
.unwrap();
} }
let vault = manager.get_vault(&vault_id).unwrap(); let vault = manager.get_vault(&vault_id).unwrap();
@ -902,7 +928,9 @@ mod tests {
manager.create_vault(ChainType::Ethereum, eth.clone()); manager.create_vault(ChainType::Ethereum, eth.clone());
let vault = manager.find_vault_mut(&ChainType::Ethereum, &eth).unwrap(); let vault = manager.find_vault_mut(&ChainType::Ethereum, &eth).unwrap();
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap(); vault
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
.unwrap();
assert_eq!(manager.total_locked(), 1000); assert_eq!(manager.total_locked(), 1000);
} }
@ -913,7 +941,9 @@ mod tests {
let eth = AssetId::eth(); let eth = AssetId::eth();
let vault = manager.get_or_create_vault(ChainType::Ethereum, eth.clone()); let vault = manager.get_or_create_vault(ChainType::Ethereum, eth.clone());
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap(); vault
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
.unwrap();
assert_eq!(manager.vault_ids().len(), 1); assert_eq!(manager.vault_ids().len(), 1);
assert_eq!(manager.total_locked(), 1000); assert_eq!(manager.total_locked(), 1000);

View file

@ -241,7 +241,10 @@ impl DeviceRegistry {
} }
/// Gets a processor by ID. /// Gets a processor by ID.
pub fn get_processor(&self, processor_id: ProcessorId) -> Result<Arc<dyn Processor>, ComputeError> { pub fn get_processor(
&self,
processor_id: ProcessorId,
) -> Result<Arc<dyn Processor>, ComputeError> {
self.processors self.processors
.read() .read()
.get(&processor_id) .get(&processor_id)
@ -266,7 +269,10 @@ impl DeviceRegistry {
/// Gets the next processor ID. /// Gets the next processor ID.
pub fn next_processor_id(&self) -> ProcessorId { pub fn next_processor_id(&self) -> ProcessorId {
ProcessorId(self.next_processor_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst)) ProcessorId(
self.next_processor_id
.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
)
} }
/// Gets total number of devices. /// Gets total number of devices.
@ -309,7 +315,10 @@ impl DeviceRegistry {
device.status = status; device.status = status;
Ok(()) Ok(())
} else { } else {
Err(ComputeError::Internal(format!("Device not found: {}", device_id))) Err(ComputeError::Internal(format!(
"Device not found: {}",
device_id
)))
} }
} }
} }
@ -323,7 +332,7 @@ impl Default for DeviceRegistry {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::processor::{CpuVariant, AvxSupport}; use crate::processor::{AvxSupport, CpuVariant};
#[test] #[test]
fn test_device_id() { fn test_device_id() {

View file

@ -67,6 +67,10 @@ pub use market::{
ResourceType, SpotMarket, Trade, ResourceType, SpotMarket, Trade,
}; };
pub use memory::{MemoryManager, TensorHandle, TransferPath, UnifiedMemory}; pub use memory::{MemoryManager, TensorHandle, TransferPath, UnifiedMemory};
pub use model::{
ModelCategory, ModelFormat, ModelId, ModelInfo, ModelRegistry, ModelUploadRequest,
ModelUploadResponse,
};
pub use processor::{ pub use processor::{
ComputeThroughput, CpuVariant, GpuVariant, NpuVariant, Operation, OperationType, Processor, ComputeThroughput, CpuVariant, GpuVariant, NpuVariant, Operation, OperationType, Processor,
ProcessorCapabilities, ProcessorId, ProcessorType, TpuVersion, ProcessorCapabilities, ProcessorId, ProcessorType, TpuVersion,
@ -78,10 +82,6 @@ pub use task::{
ComputeTask, DecomposedWorkload, Task, TaskDecomposer, TaskId, TaskPriority, TaskResult, ComputeTask, DecomposedWorkload, Task, TaskDecomposer, TaskId, TaskPriority, TaskResult,
TaskStatus, TaskStatus,
}; };
pub use model::{
ModelCategory, ModelFormat, ModelId, ModelInfo, ModelRegistry, ModelUploadRequest,
ModelUploadResponse,
};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
@ -434,7 +434,10 @@ impl ComputeCluster {
let jobs = self.jobs.read(); let jobs = self.jobs.read();
let total_nodes = nodes.len(); let total_nodes = nodes.len();
let online_nodes = nodes.values().filter(|n| n.status == NodeStatus::Online).count(); let online_nodes = nodes
.values()
.filter(|n| n.status == NodeStatus::Online)
.count();
let total_gpus: usize = nodes let total_gpus: usize = nodes
.values() .values()
@ -515,16 +518,16 @@ pub enum GpuTier {
impl Default for ComputePricing { impl Default for ComputePricing {
fn default() -> Self { fn default() -> Self {
let mut gpu_hourly = HashMap::new(); let mut gpu_hourly = HashMap::new();
gpu_hourly.insert(GpuTier::Consumer, 100_000_000); // 0.10 SYNOR gpu_hourly.insert(GpuTier::Consumer, 100_000_000); // 0.10 SYNOR
gpu_hourly.insert(GpuTier::Professional, 300_000_000); // 0.30 SYNOR gpu_hourly.insert(GpuTier::Professional, 300_000_000); // 0.30 SYNOR
gpu_hourly.insert(GpuTier::DataCenter, 2_000_000_000); // 2.00 SYNOR gpu_hourly.insert(GpuTier::DataCenter, 2_000_000_000); // 2.00 SYNOR
gpu_hourly.insert(GpuTier::Premium, 4_000_000_000); // 4.00 SYNOR gpu_hourly.insert(GpuTier::Premium, 4_000_000_000); // 4.00 SYNOR
Self { Self {
gpu_hourly, gpu_hourly,
cpu_core_hour: 20_000_000, // 0.02 SYNOR cpu_core_hour: 20_000_000, // 0.02 SYNOR
memory_gb_hour: 5_000_000, // 0.005 SYNOR memory_gb_hour: 5_000_000, // 0.005 SYNOR
network_egress_gb: 50_000_000, // 0.05 SYNOR network_egress_gb: 50_000_000, // 0.05 SYNOR
inference_per_million_tokens: 100_000_000, // 0.10 SYNOR inference_per_million_tokens: 100_000_000, // 0.10 SYNOR
} }
} }

View file

@ -686,24 +686,24 @@ impl PricingEngine {
pub fn greenest_region(&self) -> &str { pub fn greenest_region(&self) -> &str {
self.regions self.regions
.iter() .iter()
.max_by(|a, b| { .max_by(|a, b| a.renewable_pct.partial_cmp(&b.renewable_pct).unwrap())
a.renewable_pct
.partial_cmp(&b.renewable_pct)
.unwrap()
})
.map(|r| r.region.as_str()) .map(|r| r.region.as_str())
.unwrap_or("eu-north") .unwrap_or("eu-north")
} }
/// Compares price to cloud providers. /// Compares price to cloud providers.
pub fn compare_to_cloud(&self, resource: &ResourceType, region: Option<&str>) -> CloudComparison { pub fn compare_to_cloud(
&self,
resource: &ResourceType,
region: Option<&str>,
) -> CloudComparison {
let our_price = self.spot_price(resource, region); let our_price = self.spot_price(resource, region);
// Approximate cloud provider prices (USD/hour for GPU) // Approximate cloud provider prices (USD/hour for GPU)
let (aws_price, gcp_price, azure_price) = match resource { let (aws_price, gcp_price, azure_price) = match resource {
ResourceType::GpuHours(GpuTier::DataCenter) => (3.06, 2.95, 3.10), // A100 equivalents ResourceType::GpuHours(GpuTier::DataCenter) => (3.06, 2.95, 3.10), // A100 equivalents
ResourceType::GpuHours(GpuTier::Ultra) => (5.00, 4.50, 5.20), // H100 equivalents ResourceType::GpuHours(GpuTier::Ultra) => (5.00, 4.50, 5.20), // H100 equivalents
ResourceType::GpuHours(GpuTier::High) => (1.50, 1.40, 1.60), // T4/A10 equivalents ResourceType::GpuHours(GpuTier::High) => (1.50, 1.40, 1.60), // T4/A10 equivalents
ResourceType::CpuHours(CpuTier::Server) => (0.40, 0.35, 0.42), ResourceType::CpuHours(CpuTier::Server) => (0.40, 0.35, 0.42),
_ => (1.0, 1.0, 1.0), _ => (1.0, 1.0, 1.0),
}; };
@ -888,9 +888,18 @@ impl SpotMarket {
); );
} }
order_books.insert(ResourceType::TpuHours, OrderBook::new(ResourceType::TpuHours)); order_books.insert(
order_books.insert(ResourceType::NpuHours, OrderBook::new(ResourceType::NpuHours)); ResourceType::TpuHours,
order_books.insert(ResourceType::LpuCredits, OrderBook::new(ResourceType::LpuCredits)); OrderBook::new(ResourceType::TpuHours),
);
order_books.insert(
ResourceType::NpuHours,
OrderBook::new(ResourceType::NpuHours),
);
order_books.insert(
ResourceType::LpuCredits,
OrderBook::new(ResourceType::LpuCredits),
);
Self { Self {
order_books, order_books,
@ -1074,12 +1083,21 @@ mod tests {
fn test_pricing_engine() { fn test_pricing_engine() {
let engine = PricingEngine::new(); let engine = PricingEngine::new();
let price = engine.spot_price(&ResourceType::GpuHours(GpuTier::DataCenter), Some("eu-north")); let price = engine.spot_price(
&ResourceType::GpuHours(GpuTier::DataCenter),
Some("eu-north"),
);
assert!(price > 0.0); assert!(price > 0.0);
// eu-north should be cheaper (low electricity cost) // eu-north should be cheaper (low electricity cost)
let eu_price = engine.spot_price(&ResourceType::GpuHours(GpuTier::DataCenter), Some("eu-north")); let eu_price = engine.spot_price(
let eu_west_price = engine.spot_price(&ResourceType::GpuHours(GpuTier::DataCenter), Some("eu-west")); &ResourceType::GpuHours(GpuTier::DataCenter),
Some("eu-north"),
);
let eu_west_price = engine.spot_price(
&ResourceType::GpuHours(GpuTier::DataCenter),
Some("eu-west"),
);
// eu-north has cheaper electricity // eu-north has cheaper electricity
assert!(eu_price < eu_west_price); assert!(eu_price < eu_west_price);
@ -1089,7 +1107,8 @@ mod tests {
fn test_cloud_comparison() { fn test_cloud_comparison() {
let engine = PricingEngine::new(); let engine = PricingEngine::new();
let comparison = engine.compare_to_cloud(&ResourceType::GpuHours(GpuTier::DataCenter), None); let comparison =
engine.compare_to_cloud(&ResourceType::GpuHours(GpuTier::DataCenter), None);
// Should show significant savings // Should show significant savings
assert!(comparison.aws_savings > 50.0); assert!(comparison.aws_savings > 50.0);

View file

@ -106,11 +106,11 @@ impl TransferPath {
/// Returns approximate bandwidth in GB/s. /// Returns approximate bandwidth in GB/s.
pub fn bandwidth_gbps(&self) -> f64 { pub fn bandwidth_gbps(&self) -> f64 {
match self { match self {
TransferPath::NvLink => 900.0, // NVLink 4.0 TransferPath::NvLink => 900.0, // NVLink 4.0
TransferPath::PciePeerToPeer => 64.0, // PCIe 5.0 x16 TransferPath::PciePeerToPeer => 64.0, // PCIe 5.0 x16
TransferPath::CpuMediated => 50.0, // DDR5 TransferPath::CpuMediated => 50.0, // DDR5
TransferPath::UnifiedMemory => 400.0, // Apple unified TransferPath::UnifiedMemory => 400.0, // Apple unified
TransferPath::Network => 10.0, // 100Gbps network TransferPath::Network => 10.0, // 100Gbps network
TransferPath::SameMemory => f64::INFINITY, TransferPath::SameMemory => f64::INFINITY,
} }
} }
@ -154,7 +154,11 @@ impl MemoryManager {
} }
/// Allocates a tensor. /// Allocates a tensor.
pub fn allocate(&self, shape: Vec<usize>, dtype: DataType) -> Result<TensorHandle, ComputeError> { pub fn allocate(
&self,
shape: Vec<usize>,
dtype: DataType,
) -> Result<TensorHandle, ComputeError> {
let handle = TensorHandle::new(shape, dtype); let handle = TensorHandle::new(shape, dtype);
self.tensors.write().insert(handle.id, handle.clone()); self.tensors.write().insert(handle.id, handle.clone());
Ok(handle) Ok(handle)
@ -223,9 +227,13 @@ impl MemoryManager {
} }
// Check for NVLink between NVIDIA GPUs // Check for NVLink between NVIDIA GPUs
if matches!(from, ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda { .. })) if matches!(
&& matches!(to, ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda { .. })) from,
{ ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda { .. })
) && matches!(
to,
ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda { .. })
) {
return TransferPath::NvLink; return TransferPath::NvLink;
} }
@ -244,10 +252,22 @@ impl MemoryManager {
match (a, b) { match (a, b) {
// Apple Silicon unified memory // Apple Silicon unified memory
(ProcessorType::Cpu(CpuVariant::Arm64 { .. }), ProcessorType::Gpu(GpuVariant::AppleMetal)) (
| (ProcessorType::Gpu(GpuVariant::AppleMetal), ProcessorType::Cpu(CpuVariant::Arm64 { .. })) ProcessorType::Cpu(CpuVariant::Arm64 { .. }),
| (ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }), ProcessorType::Cpu(CpuVariant::Arm64 { .. })) ProcessorType::Gpu(GpuVariant::AppleMetal),
| (ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }), ProcessorType::Gpu(GpuVariant::AppleMetal)) => true, )
| (
ProcessorType::Gpu(GpuVariant::AppleMetal),
ProcessorType::Cpu(CpuVariant::Arm64 { .. }),
)
| (
ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }),
ProcessorType::Cpu(CpuVariant::Arm64 { .. }),
)
| (
ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }),
ProcessorType::Gpu(GpuVariant::AppleMetal),
) => true,
// Same type // Same type
_ if a == b => true, _ if a == b => true,
_ => false, _ => false,
@ -325,7 +345,9 @@ mod tests {
#[test] #[test]
fn test_transfer_path_bandwidth() { fn test_transfer_path_bandwidth() {
assert!(TransferPath::NvLink.bandwidth_gbps() > TransferPath::PciePeerToPeer.bandwidth_gbps()); assert!(
TransferPath::NvLink.bandwidth_gbps() > TransferPath::PciePeerToPeer.bandwidth_gbps()
);
assert!(TransferPath::SameMemory.bandwidth_gbps().is_infinite()); assert!(TransferPath::SameMemory.bandwidth_gbps().is_infinite());
} }
@ -333,7 +355,9 @@ mod tests {
fn test_memory_manager() { fn test_memory_manager() {
let manager = MemoryManager::new(); let manager = MemoryManager::new();
let handle = manager.allocate(vec![1024, 1024], DataType::Float32).unwrap(); let handle = manager
.allocate(vec![1024, 1024], DataType::Float32)
.unwrap();
assert_eq!(manager.tensor_count(), 1); assert_eq!(manager.tensor_count(), 1);
manager.free(handle.id).unwrap(); manager.free(handle.id).unwrap();
@ -347,22 +371,26 @@ mod tests {
let handle = manager.allocate(vec![1024], DataType::Float32).unwrap(); let handle = manager.allocate(vec![1024], DataType::Float32).unwrap();
// First ensure should allocate // First ensure should allocate
let path = manager.ensure_on( let path = manager
handle.id, .ensure_on(
ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda { handle.id,
compute_capability: (8, 0), ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda {
}), compute_capability: (8, 0),
).unwrap(); }),
)
.unwrap();
assert_eq!(path, TransferPath::SameMemory); assert_eq!(path, TransferPath::SameMemory);
// Second ensure to same location should be same memory // Second ensure to same location should be same memory
let path = manager.ensure_on( let path = manager
handle.id, .ensure_on(
ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda { handle.id,
compute_capability: (8, 0), ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda {
}), compute_capability: (8, 0),
).unwrap(); }),
)
.unwrap();
assert_eq!(path, TransferPath::SameMemory); assert_eq!(path, TransferPath::SameMemory);
} }

View file

@ -140,13 +140,7 @@ pub struct ModelInfo {
impl ModelInfo { impl ModelInfo {
/// Creates a new LLM model info. /// Creates a new LLM model info.
pub fn llm( pub fn llm(alias: &str, name: &str, cid: &str, parameters: u64, context_length: u32) -> Self {
alias: &str,
name: &str,
cid: &str,
parameters: u64,
context_length: u32,
) -> Self {
Self { Self {
id: ModelId::from_alias(alias), id: ModelId::from_alias(alias),
name: name.to_string(), name: name.to_string(),
@ -156,7 +150,12 @@ impl ModelInfo {
format: ModelFormat::SafeTensors, format: ModelFormat::SafeTensors,
size_bytes: parameters * 2, // ~2 bytes per param in fp16 size_bytes: parameters * 2, // ~2 bytes per param in fp16
parameters, parameters,
supported_precisions: vec![Precision::Fp16, Precision::Bf16, Precision::Int8, Precision::Int4], supported_precisions: vec![
Precision::Fp16,
Precision::Bf16,
Precision::Int8,
Precision::Int4,
],
recommended_processor: ProcessorType::Lpu, recommended_processor: ProcessorType::Lpu,
context_length: Some(context_length), context_length: Some(context_length),
input_schema: None, input_schema: None,
@ -238,33 +237,123 @@ impl ModelRegistry {
let default_models = vec![ let default_models = vec![
// ===== LLMs ===== // ===== LLMs =====
// Llama 3 family // Llama 3 family
ModelInfo::llm("llama-3-8b", "Llama 3 8B", "QmLlama3_8B_placeholder", 8_000_000_000, 8192), ModelInfo::llm(
ModelInfo::llm("llama-3-70b", "Llama 3 70B", "QmLlama3_70B_placeholder", 70_000_000_000, 8192), "llama-3-8b",
ModelInfo::llm("llama-3.1-8b", "Llama 3.1 8B", "QmLlama31_8B_placeholder", 8_000_000_000, 128000), "Llama 3 8B",
ModelInfo::llm("llama-3.1-70b", "Llama 3.1 70B", "QmLlama31_70B_placeholder", 70_000_000_000, 128000), "QmLlama3_8B_placeholder",
ModelInfo::llm("llama-3.1-405b", "Llama 3.1 405B", "QmLlama31_405B_placeholder", 405_000_000_000, 128000), 8_000_000_000,
8192,
),
ModelInfo::llm(
"llama-3-70b",
"Llama 3 70B",
"QmLlama3_70B_placeholder",
70_000_000_000,
8192,
),
ModelInfo::llm(
"llama-3.1-8b",
"Llama 3.1 8B",
"QmLlama31_8B_placeholder",
8_000_000_000,
128000,
),
ModelInfo::llm(
"llama-3.1-70b",
"Llama 3.1 70B",
"QmLlama31_70B_placeholder",
70_000_000_000,
128000,
),
ModelInfo::llm(
"llama-3.1-405b",
"Llama 3.1 405B",
"QmLlama31_405B_placeholder",
405_000_000_000,
128000,
),
// Mistral family // Mistral family
ModelInfo::llm("mistral-7b", "Mistral 7B", "QmMistral7B_placeholder", 7_000_000_000, 32768), ModelInfo::llm(
ModelInfo::llm("mixtral-8x7b", "Mixtral 8x7B", "QmMixtral8x7B_placeholder", 46_000_000_000, 32768), "mistral-7b",
ModelInfo::llm("mixtral-8x22b", "Mixtral 8x22B", "QmMixtral8x22B_placeholder", 176_000_000_000, 65536), "Mistral 7B",
"QmMistral7B_placeholder",
7_000_000_000,
32768,
),
ModelInfo::llm(
"mixtral-8x7b",
"Mixtral 8x7B",
"QmMixtral8x7B_placeholder",
46_000_000_000,
32768,
),
ModelInfo::llm(
"mixtral-8x22b",
"Mixtral 8x22B",
"QmMixtral8x22B_placeholder",
176_000_000_000,
65536,
),
// Qwen family // Qwen family
ModelInfo::llm("qwen-2.5-7b", "Qwen 2.5 7B", "QmQwen25_7B_placeholder", 7_000_000_000, 128000), ModelInfo::llm(
ModelInfo::llm("qwen-2.5-72b", "Qwen 2.5 72B", "QmQwen25_72B_placeholder", 72_000_000_000, 128000), "qwen-2.5-7b",
"Qwen 2.5 7B",
"QmQwen25_7B_placeholder",
7_000_000_000,
128000,
),
ModelInfo::llm(
"qwen-2.5-72b",
"Qwen 2.5 72B",
"QmQwen25_72B_placeholder",
72_000_000_000,
128000,
),
// DeepSeek family // DeepSeek family
ModelInfo::llm("deepseek-v2", "DeepSeek V2", "QmDeepSeekV2_placeholder", 236_000_000_000, 128000), ModelInfo::llm(
ModelInfo::llm("deepseek-coder-33b", "DeepSeek Coder 33B", "QmDeepSeekCoder33B_placeholder", 33_000_000_000, 16384), "deepseek-v2",
"DeepSeek V2",
"QmDeepSeekV2_placeholder",
236_000_000_000,
128000,
),
ModelInfo::llm(
"deepseek-coder-33b",
"DeepSeek Coder 33B",
"QmDeepSeekCoder33B_placeholder",
33_000_000_000,
16384,
),
// Phi family (small/efficient) // Phi family (small/efficient)
ModelInfo::llm("phi-3-mini", "Phi 3 Mini", "QmPhi3Mini_placeholder", 3_800_000_000, 128000), ModelInfo::llm(
ModelInfo::llm("phi-3-medium", "Phi 3 Medium", "QmPhi3Medium_placeholder", 14_000_000_000, 128000), "phi-3-mini",
"Phi 3 Mini",
"QmPhi3Mini_placeholder",
3_800_000_000,
128000,
),
ModelInfo::llm(
"phi-3-medium",
"Phi 3 Medium",
"QmPhi3Medium_placeholder",
14_000_000_000,
128000,
),
// Code models // Code models
ModelInfo::llm("codellama-34b", "Code Llama 34B", "QmCodeLlama34B_placeholder", 34_000_000_000, 16384), ModelInfo::llm(
ModelInfo::llm("starcoder2-15b", "StarCoder2 15B", "QmStarCoder2_15B_placeholder", 15_000_000_000, 16384), "codellama-34b",
"Code Llama 34B",
"QmCodeLlama34B_placeholder",
34_000_000_000,
16384,
),
ModelInfo::llm(
"starcoder2-15b",
"StarCoder2 15B",
"QmStarCoder2_15B_placeholder",
15_000_000_000,
16384,
),
// ===== Embedding Models ===== // ===== Embedding Models =====
ModelInfo { ModelInfo {
id: ModelId::from_alias("bge-large"), id: ModelId::from_alias("bge-large"),
@ -306,7 +395,6 @@ impl ModelRegistry {
is_public: true, is_public: true,
owner: None, owner: None,
}, },
// ===== Vision Models ===== // ===== Vision Models =====
ModelInfo { ModelInfo {
id: ModelId::from_alias("stable-diffusion-xl"), id: ModelId::from_alias("stable-diffusion-xl"),
@ -348,7 +436,6 @@ impl ModelRegistry {
is_public: true, is_public: true,
owner: None, owner: None,
}, },
// ===== Speech Models ===== // ===== Speech Models =====
ModelInfo { ModelInfo {
id: ModelId::from_alias("whisper-large-v3"), id: ModelId::from_alias("whisper-large-v3"),
@ -370,7 +457,6 @@ impl ModelRegistry {
is_public: true, is_public: true,
owner: None, owner: None,
}, },
// ===== Multi-Modal Models ===== // ===== Multi-Modal Models =====
ModelInfo { ModelInfo {
id: ModelId::from_alias("llava-1.5-13b"), id: ModelId::from_alias("llava-1.5-13b"),
@ -555,7 +641,9 @@ mod tests {
let registry = ModelRegistry::new(); let registry = ModelRegistry::new();
let results = registry.search("llama"); let results = registry.search("llama");
assert!(!results.is_empty()); assert!(!results.is_empty());
assert!(results.iter().all(|m| m.name.to_lowercase().contains("llama"))); assert!(results
.iter()
.all(|m| m.name.to_lowercase().contains("llama")));
} }
#[test] #[test]

View file

@ -305,7 +305,7 @@ impl ProcessorCapabilities {
}, },
memory: MemorySpecs { memory: MemorySpecs {
capacity_bytes: 230 * 1024 * 1024 * 1024, // 230 GB SRAM! capacity_bytes: 230 * 1024 * 1024 * 1024, // 230 GB SRAM!
bandwidth_gbps: 80_000, // 80 TB/s internal bandwidth_gbps: 80_000, // 80 TB/s internal
type_: MemoryType::Sram, type_: MemoryType::Sram,
}, },
operations: Self::lpu_operations(), operations: Self::lpu_operations(),
@ -349,8 +349,8 @@ impl ProcessorCapabilities {
/// Creates Apple Neural Engine capabilities. /// Creates Apple Neural Engine capabilities.
pub fn apple_neural_engine(cores: u32) -> Self { pub fn apple_neural_engine(cores: u32) -> Self {
let int8_tops = match cores { let int8_tops = match cores {
16 => 18.0, // M3 16 => 18.0, // M3
32 => 35.0, // M3 Max 32 => 35.0, // M3 Max
_ => cores as f64 * 1.1, _ => cores as f64 * 1.1,
}; };
@ -542,6 +542,8 @@ mod tests {
fn test_lpu_capabilities() { fn test_lpu_capabilities() {
let caps = ProcessorCapabilities::lpu(); let caps = ProcessorCapabilities::lpu();
assert!(caps.memory.bandwidth_gbps > 10000); // Very high internal bandwidth assert!(caps.memory.bandwidth_gbps > 10000); // Very high internal bandwidth
assert!(caps.optimal_for.contains(&WorkloadCharacteristic::Sequential)); assert!(caps
.optimal_for
.contains(&WorkloadCharacteristic::Sequential));
} }
} }

View file

@ -253,10 +253,22 @@ impl Processor for GenericProcessor {
fn shares_memory_with(&self, other: &ProcessorType) -> bool { fn shares_memory_with(&self, other: &ProcessorType) -> bool {
match (&self.processor_type, other) { match (&self.processor_type, other) {
// Apple Silicon has unified memory // Apple Silicon has unified memory
(ProcessorType::Cpu(CpuVariant::Arm64 { .. }), ProcessorType::Gpu(GpuVariant::AppleMetal)) (
| (ProcessorType::Gpu(GpuVariant::AppleMetal), ProcessorType::Cpu(CpuVariant::Arm64 { .. })) ProcessorType::Cpu(CpuVariant::Arm64 { .. }),
| (ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }), ProcessorType::Cpu(CpuVariant::Arm64 { .. })) ProcessorType::Gpu(GpuVariant::AppleMetal),
| (ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }), ProcessorType::Gpu(GpuVariant::AppleMetal)) => true, )
| (
ProcessorType::Gpu(GpuVariant::AppleMetal),
ProcessorType::Cpu(CpuVariant::Arm64 { .. }),
)
| (
ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }),
ProcessorType::Cpu(CpuVariant::Arm64 { .. }),
)
| (
ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }),
ProcessorType::Gpu(GpuVariant::AppleMetal),
) => true,
// Same type always shares // Same type always shares
(a, b) if a == b => true, (a, b) if a == b => true,
_ => false, _ => false,

View file

@ -191,10 +191,7 @@ pub enum Operation {
}, },
/// Data loading from storage. /// Data loading from storage.
DataLoad { DataLoad { bytes: usize, async_: bool },
bytes: usize,
async_: bool,
},
/// Data preprocessing. /// Data preprocessing.
DataPreprocess { DataPreprocess {
@ -209,16 +206,10 @@ pub enum Operation {
}, },
/// Detokenization. /// Detokenization.
Detokenization { Detokenization { tokens: usize, vocab_size: usize },
tokens: usize,
vocab_size: usize,
},
/// Checkpoint save. /// Checkpoint save.
Checkpoint { Checkpoint { bytes: usize, async_: bool },
bytes: usize,
async_: bool,
},
/// All-reduce across devices. /// All-reduce across devices.
AllReduce { AllReduce {
@ -228,9 +219,7 @@ pub enum Operation {
}, },
/// Backward pass for a layer. /// Backward pass for a layer.
Backward { Backward { forward_op: Box<Operation> },
forward_op: Box<Operation>,
},
/// Optimizer step. /// Optimizer step.
OptimizerStep { OptimizerStep {
@ -240,16 +229,10 @@ pub enum Operation {
}, },
/// Transpose. /// Transpose.
Transpose { Transpose { shape: Vec<usize>, axes: Vec<usize> },
shape: Vec<usize>,
axes: Vec<usize>,
},
/// Reshape. /// Reshape.
Reshape { Reshape { from: Vec<usize>, to: Vec<usize> },
from: Vec<usize>,
to: Vec<usize>,
},
/// Concatenate tensors. /// Concatenate tensors.
Concat { Concat {
@ -378,9 +361,7 @@ impl Operation {
| Operation::SiLU { elements } => *elements as f64, | Operation::SiLU { elements } => *elements as f64,
// Softmax: ~5 ops per element (exp, sum, div) // Softmax: ~5 ops per element (exp, sum, div)
Operation::Softmax { Operation::Softmax { batch, seq_len, .. } => 5.0 * (*batch as f64) * (*seq_len as f64),
batch, seq_len, ..
} => 5.0 * (*batch as f64) * (*seq_len as f64),
// Embedding: just lookup, minimal FLOPS // Embedding: just lookup, minimal FLOPS
Operation::Embedding { Operation::Embedding {

View file

@ -39,8 +39,7 @@ impl ProcessorProfiles {
bandwidth_gbps: 460, bandwidth_gbps: 460,
type_: MemoryType::Ddr5, type_: MemoryType::Ddr5,
}, },
operations: ProcessorCapabilities::cpu(96, 2.4, false) operations: ProcessorCapabilities::cpu(96, 2.4, false).operations,
.operations,
power: PowerCharacteristics { power: PowerCharacteristics {
tdp_watts: 360, tdp_watts: 360,
efficiency: 0.85, efficiency: 0.85,
@ -70,8 +69,7 @@ impl ProcessorProfiles {
bandwidth_gbps: 307, bandwidth_gbps: 307,
type_: MemoryType::Ddr5, type_: MemoryType::Ddr5,
}, },
operations: ProcessorCapabilities::cpu(56, 2.9, true) operations: ProcessorCapabilities::cpu(56, 2.9, true).operations,
.operations,
power: PowerCharacteristics { power: PowerCharacteristics {
tdp_watts: 350, tdp_watts: 350,
efficiency: 0.80, efficiency: 0.80,
@ -101,8 +99,7 @@ impl ProcessorProfiles {
bandwidth_gbps: 400, bandwidth_gbps: 400,
type_: MemoryType::Unified, type_: MemoryType::Unified,
}, },
operations: ProcessorCapabilities::cpu(16, 4.0, false) operations: ProcessorCapabilities::cpu(16, 4.0, false).operations,
.operations,
power: PowerCharacteristics { power: PowerCharacteristics {
tdp_watts: 40, tdp_watts: 40,
efficiency: 0.95, efficiency: 0.95,
@ -141,8 +138,7 @@ impl ProcessorProfiles {
bandwidth_gbps: 3350, bandwidth_gbps: 3350,
type_: MemoryType::Hbm3, type_: MemoryType::Hbm3,
}, },
operations: ProcessorCapabilities::nvidia_gpu(16896, 528, 80, 3350, (9, 0)) operations: ProcessorCapabilities::nvidia_gpu(16896, 528, 80, 3350, (9, 0)).operations,
.operations,
power: PowerCharacteristics { power: PowerCharacteristics {
tdp_watts: 700, tdp_watts: 700,
efficiency: 0.90, efficiency: 0.90,
@ -173,8 +169,7 @@ impl ProcessorProfiles {
bandwidth_gbps: 2039, bandwidth_gbps: 2039,
type_: MemoryType::Hbm2e, type_: MemoryType::Hbm2e,
}, },
operations: ProcessorCapabilities::nvidia_gpu(6912, 432, 80, 2039, (8, 0)) operations: ProcessorCapabilities::nvidia_gpu(6912, 432, 80, 2039, (8, 0)).operations,
.operations,
power: PowerCharacteristics { power: PowerCharacteristics {
tdp_watts: 400, tdp_watts: 400,
efficiency: 0.88, efficiency: 0.88,
@ -205,8 +200,7 @@ impl ProcessorProfiles {
bandwidth_gbps: 1008, bandwidth_gbps: 1008,
type_: MemoryType::Gddr6, type_: MemoryType::Gddr6,
}, },
operations: ProcessorCapabilities::nvidia_gpu(16384, 512, 24, 1008, (8, 9)) operations: ProcessorCapabilities::nvidia_gpu(16384, 512, 24, 1008, (8, 9)).operations,
.operations,
power: PowerCharacteristics { power: PowerCharacteristics {
tdp_watts: 450, tdp_watts: 450,
efficiency: 0.85, efficiency: 0.85,
@ -236,8 +230,7 @@ impl ProcessorProfiles {
bandwidth_gbps: 936, bandwidth_gbps: 936,
type_: MemoryType::Gddr6, type_: MemoryType::Gddr6,
}, },
operations: ProcessorCapabilities::nvidia_gpu(10496, 328, 24, 936, (8, 6)) operations: ProcessorCapabilities::nvidia_gpu(10496, 328, 24, 936, (8, 6)).operations,
.operations,
power: PowerCharacteristics { power: PowerCharacteristics {
tdp_watts: 350, tdp_watts: 350,
efficiency: 0.82, efficiency: 0.82,
@ -272,8 +265,8 @@ impl ProcessorProfiles {
type_: MemoryType::Hbm3, type_: MemoryType::Hbm3,
}, },
operations: { operations: {
let mut ops = ProcessorCapabilities::nvidia_gpu(16384, 512, 80, 5300, (9, 0)) let mut ops =
.operations; ProcessorCapabilities::nvidia_gpu(16384, 512, 80, 5300, (9, 0)).operations;
ops.remove(&OperationType::FlashAttention); // Different implementation ops.remove(&OperationType::FlashAttention); // Different implementation
ops ops
}, },
@ -308,8 +301,8 @@ impl ProcessorProfiles {
type_: MemoryType::Gddr6, type_: MemoryType::Gddr6,
}, },
operations: { operations: {
let mut ops = ProcessorCapabilities::nvidia_gpu(6144, 0, 24, 960, (8, 0)) let mut ops =
.operations; ProcessorCapabilities::nvidia_gpu(6144, 0, 24, 960, (8, 0)).operations;
ops.remove(&OperationType::FlashAttention); ops.remove(&OperationType::FlashAttention);
ops ops
}, },
@ -318,9 +311,7 @@ impl ProcessorProfiles {
efficiency: 0.80, efficiency: 0.80,
power_tier: PowerTier::High, power_tier: PowerTier::High,
}, },
optimal_for: vec![ optimal_for: vec![WorkloadCharacteristic::HighlyParallel],
WorkloadCharacteristic::HighlyParallel,
],
} }
} }
@ -429,8 +420,7 @@ impl ProcessorProfiles {
bandwidth_gbps: 200, bandwidth_gbps: 200,
type_: MemoryType::Unified, type_: MemoryType::Unified,
}, },
operations: ProcessorCapabilities::apple_neural_engine(16) operations: ProcessorCapabilities::apple_neural_engine(16).operations,
.operations,
power: PowerCharacteristics { power: PowerCharacteristics {
tdp_watts: 8, tdp_watts: 8,
efficiency: 0.98, efficiency: 0.98,
@ -465,8 +455,7 @@ impl ProcessorProfiles {
bandwidth_gbps: 77, bandwidth_gbps: 77,
type_: MemoryType::Lpddr, type_: MemoryType::Lpddr,
}, },
operations: ProcessorCapabilities::apple_neural_engine(16) operations: ProcessorCapabilities::apple_neural_engine(16).operations,
.operations,
power: PowerCharacteristics { power: PowerCharacteristics {
tdp_watts: 10, tdp_watts: 10,
efficiency: 0.95, efficiency: 0.95,

View file

@ -24,10 +24,7 @@ pub enum ProcessorType {
/// WebAssembly runtime. /// WebAssembly runtime.
Wasm, Wasm,
/// Custom/Unknown accelerator. /// Custom/Unknown accelerator.
Custom { Custom { vendor: String, model: String },
vendor: String,
model: String,
},
} }
impl Default for ProcessorType { impl Default for ProcessorType {

View file

@ -6,10 +6,10 @@
//! - Latency-aware scheduling //! - Latency-aware scheduling
//! - Real-time utilization metrics //! - Real-time utilization metrics
use super::TaskAssignment;
use crate::device::DeviceRegistry; use crate::device::DeviceRegistry;
use crate::processor::{Operation, OperationType, ProcessorId, ProcessorType}; use crate::processor::{Operation, OperationType, ProcessorId, ProcessorType};
use crate::task::{Task, TaskId, TaskPriority}; use crate::task::{Task, TaskId, TaskPriority};
use super::TaskAssignment;
use parking_lot::RwLock; use parking_lot::RwLock;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
@ -127,8 +127,12 @@ impl LoadBalancer {
/// Register a processor with its type. /// Register a processor with its type.
pub fn register_processor(&self, processor_id: ProcessorId, processor_type: ProcessorType) { pub fn register_processor(&self, processor_id: ProcessorId, processor_type: ProcessorType) {
self.loads.write().insert(processor_id, AtomicU64::new(0)); self.loads.write().insert(processor_id, AtomicU64::new(0));
self.metrics.write().insert(processor_id, ProcessorMetrics::default()); self.metrics
self.processor_types.write().insert(processor_id, processor_type); .write()
.insert(processor_id, ProcessorMetrics::default());
self.processor_types
.write()
.insert(processor_id, processor_type);
} }
/// Unregister a processor. /// Unregister a processor.
@ -150,7 +154,8 @@ impl LoadBalancer {
/// Get current load for a processor. /// Get current load for a processor.
pub fn get_load(&self, processor_id: ProcessorId) -> u64 { pub fn get_load(&self, processor_id: ProcessorId) -> u64 {
self.loads.read() self.loads
.read()
.get(&processor_id) .get(&processor_id)
.map(|l| l.load(Ordering::Relaxed)) .map(|l| l.load(Ordering::Relaxed))
.unwrap_or(0) .unwrap_or(0)
@ -179,140 +184,140 @@ impl LoadBalancer {
ProcessorType::Cpu(_) => matches!( ProcessorType::Cpu(_) => matches!(
op_type, op_type,
OperationType::MatMul OperationType::MatMul
| OperationType::Conv2d | OperationType::Conv2d
| OperationType::Conv3d | OperationType::Conv3d
| OperationType::DepthwiseConv | OperationType::DepthwiseConv
| OperationType::BatchNorm | OperationType::BatchNorm
| OperationType::LayerNorm | OperationType::LayerNorm
| OperationType::Add | OperationType::Add
| OperationType::Mul | OperationType::Mul
| OperationType::ReLU | OperationType::ReLU
| OperationType::GeLU | OperationType::GeLU
| OperationType::SiLU | OperationType::SiLU
| OperationType::Softmax | OperationType::Softmax
| OperationType::Sum | OperationType::Sum
| OperationType::Mean | OperationType::Mean
| OperationType::Max | OperationType::Max
| OperationType::ArgMax | OperationType::ArgMax
| OperationType::Embedding | OperationType::Embedding
| OperationType::TopK | OperationType::TopK
| OperationType::Sampling | OperationType::Sampling
| OperationType::Tokenization | OperationType::Tokenization
| OperationType::Detokenization | OperationType::Detokenization
| OperationType::DataLoad | OperationType::DataLoad
| OperationType::DataPreprocess | OperationType::DataPreprocess
| OperationType::Transpose | OperationType::Transpose
| OperationType::Reshape | OperationType::Reshape
| OperationType::Concat | OperationType::Concat
| OperationType::Split | OperationType::Split
), ),
// GPUs excel at parallel operations // GPUs excel at parallel operations
ProcessorType::Gpu(_) => matches!( ProcessorType::Gpu(_) => matches!(
op_type, op_type,
OperationType::MatMul OperationType::MatMul
| OperationType::Conv2d | OperationType::Conv2d
| OperationType::Conv3d | OperationType::Conv3d
| OperationType::DepthwiseConv | OperationType::DepthwiseConv
| OperationType::BatchNorm | OperationType::BatchNorm
| OperationType::LayerNorm | OperationType::LayerNorm
| OperationType::SelfAttention | OperationType::SelfAttention
| OperationType::CrossAttention | OperationType::CrossAttention
| OperationType::FlashAttention | OperationType::FlashAttention
| OperationType::Add | OperationType::Add
| OperationType::Mul | OperationType::Mul
| OperationType::ReLU | OperationType::ReLU
| OperationType::GeLU | OperationType::GeLU
| OperationType::SiLU | OperationType::SiLU
| OperationType::Softmax | OperationType::Softmax
| OperationType::Sum | OperationType::Sum
| OperationType::Mean | OperationType::Mean
| OperationType::Max | OperationType::Max
| OperationType::ArgMax | OperationType::ArgMax
| OperationType::Embedding | OperationType::Embedding
| OperationType::RoPE | OperationType::RoPE
| OperationType::KVCache | OperationType::KVCache
| OperationType::TopK | OperationType::TopK
| OperationType::Sampling | OperationType::Sampling
| OperationType::Transpose | OperationType::Transpose
| OperationType::Reshape | OperationType::Reshape
| OperationType::Concat | OperationType::Concat
| OperationType::Split | OperationType::Split
| OperationType::Gather | OperationType::Gather
| OperationType::Scatter | OperationType::Scatter
| OperationType::AllReduce | OperationType::AllReduce
| OperationType::AllGather | OperationType::AllGather
| OperationType::ReduceScatter | OperationType::ReduceScatter
| OperationType::Backward | OperationType::Backward
| OperationType::OptimizerStep | OperationType::OptimizerStep
| OperationType::GradientClip | OperationType::GradientClip
), ),
// TPUs optimized for ML // TPUs optimized for ML
ProcessorType::Tpu(_) => matches!( ProcessorType::Tpu(_) => matches!(
op_type, op_type,
OperationType::MatMul OperationType::MatMul
| OperationType::Conv2d | OperationType::Conv2d
| OperationType::BatchNorm | OperationType::BatchNorm
| OperationType::LayerNorm | OperationType::LayerNorm
| OperationType::SelfAttention | OperationType::SelfAttention
| OperationType::CrossAttention | OperationType::CrossAttention
| OperationType::FlashAttention | OperationType::FlashAttention
| OperationType::Add | OperationType::Add
| OperationType::Mul | OperationType::Mul
| OperationType::ReLU | OperationType::ReLU
| OperationType::GeLU | OperationType::GeLU
| OperationType::SiLU | OperationType::SiLU
| OperationType::Softmax | OperationType::Softmax
| OperationType::Sum | OperationType::Sum
| OperationType::Mean | OperationType::Mean
| OperationType::Embedding | OperationType::Embedding
| OperationType::RoPE | OperationType::RoPE
| OperationType::KVCache | OperationType::KVCache
| OperationType::AllReduce | OperationType::AllReduce
| OperationType::AllGather | OperationType::AllGather
| OperationType::ReduceScatter | OperationType::ReduceScatter
| OperationType::Backward | OperationType::Backward
| OperationType::OptimizerStep | OperationType::OptimizerStep
), ),
// NPUs for neural network inference // NPUs for neural network inference
ProcessorType::Npu(_) => matches!( ProcessorType::Npu(_) => matches!(
op_type, op_type,
OperationType::MatMul OperationType::MatMul
| OperationType::Conv2d | OperationType::Conv2d
| OperationType::DepthwiseConv | OperationType::DepthwiseConv
| OperationType::BatchNorm | OperationType::BatchNorm
| OperationType::LayerNorm | OperationType::LayerNorm
| OperationType::SelfAttention | OperationType::SelfAttention
| OperationType::Add | OperationType::Add
| OperationType::Mul | OperationType::Mul
| OperationType::ReLU | OperationType::ReLU
| OperationType::GeLU | OperationType::GeLU
| OperationType::SiLU | OperationType::SiLU
| OperationType::Softmax | OperationType::Softmax
| OperationType::Sum | OperationType::Sum
| OperationType::Mean | OperationType::Mean
), ),
// LPUs for sequential inference (optimized for LLMs) // LPUs for sequential inference (optimized for LLMs)
ProcessorType::Lpu => matches!( ProcessorType::Lpu => matches!(
op_type, op_type,
OperationType::MatMul OperationType::MatMul
| OperationType::LayerNorm | OperationType::LayerNorm
| OperationType::SelfAttention | OperationType::SelfAttention
| OperationType::FlashAttention | OperationType::FlashAttention
| OperationType::Add | OperationType::Add
| OperationType::Mul | OperationType::Mul
| OperationType::ReLU | OperationType::ReLU
| OperationType::GeLU | OperationType::GeLU
| OperationType::SiLU | OperationType::SiLU
| OperationType::Softmax | OperationType::Softmax
| OperationType::Embedding | OperationType::Embedding
| OperationType::RoPE | OperationType::RoPE
| OperationType::KVCache | OperationType::KVCache
| OperationType::TopK | OperationType::TopK
| OperationType::Sampling | OperationType::Sampling
), ),
// FPGAs can be programmed for anything // FPGAs can be programmed for anything
@ -322,40 +327,40 @@ impl LoadBalancer {
ProcessorType::Dsp(_) => matches!( ProcessorType::Dsp(_) => matches!(
op_type, op_type,
OperationType::Conv2d OperationType::Conv2d
| OperationType::DepthwiseConv | OperationType::DepthwiseConv
| OperationType::Add | OperationType::Add
| OperationType::Mul | OperationType::Mul
| OperationType::Sum | OperationType::Sum
| OperationType::Mean | OperationType::Mean
| OperationType::Max | OperationType::Max
), ),
// WebGPU has limited operations // WebGPU has limited operations
ProcessorType::WebGpu => matches!( ProcessorType::WebGpu => matches!(
op_type, op_type,
OperationType::MatMul OperationType::MatMul
| OperationType::Conv2d | OperationType::Conv2d
| OperationType::Add | OperationType::Add
| OperationType::Mul | OperationType::Mul
| OperationType::ReLU | OperationType::ReLU
| OperationType::Softmax | OperationType::Softmax
| OperationType::Sum | OperationType::Sum
| OperationType::Transpose | OperationType::Transpose
| OperationType::Reshape | OperationType::Reshape
), ),
// WASM for portable compute // WASM for portable compute
ProcessorType::Wasm => matches!( ProcessorType::Wasm => matches!(
op_type, op_type,
OperationType::MatMul OperationType::MatMul
| OperationType::Add | OperationType::Add
| OperationType::Mul | OperationType::Mul
| OperationType::ReLU | OperationType::ReLU
| OperationType::Softmax | OperationType::Softmax
| OperationType::Sum | OperationType::Sum
| OperationType::Mean | OperationType::Mean
| OperationType::Tokenization | OperationType::Tokenization
| OperationType::Detokenization | OperationType::Detokenization
), ),
// Custom processors - assume they can handle anything // Custom processors - assume they can handle anything
@ -381,7 +386,9 @@ impl LoadBalancer {
} }
// Get utilization and metrics // Get utilization and metrics
let utilization = proc_metrics.map(|m| m.utilization).unwrap_or(load as f64 / 100.0); let utilization = proc_metrics
.map(|m| m.utilization)
.unwrap_or(load as f64 / 100.0);
let power = proc_metrics.map(|m| m.power_watts).unwrap_or(100.0); let power = proc_metrics.map(|m| m.power_watts).unwrap_or(100.0);
let avg_completion = proc_metrics.map(|m| m.avg_completion_ms).unwrap_or(100.0); let avg_completion = proc_metrics.map(|m| m.avg_completion_ms).unwrap_or(100.0);
@ -431,13 +438,13 @@ impl LoadBalancer {
BalancingStrategy::Cost => { BalancingStrategy::Cost => {
// Prioritize cheaper resources (consumer devices) // Prioritize cheaper resources (consumer devices)
let cost_factor = match processor_type { let cost_factor = match processor_type {
ProcessorType::Wasm => 0.1, // Cheapest (browser) ProcessorType::Wasm => 0.1, // Cheapest (browser)
ProcessorType::WebGpu => 0.15, ProcessorType::WebGpu => 0.15,
ProcessorType::Cpu(_) => 0.2, ProcessorType::Cpu(_) => 0.2,
ProcessorType::Npu(_) => 0.3, // Mobile NPUs ProcessorType::Npu(_) => 0.3, // Mobile NPUs
ProcessorType::Gpu(_) => 0.5, ProcessorType::Gpu(_) => 0.5,
ProcessorType::Lpu => 0.8, ProcessorType::Lpu => 0.8,
ProcessorType::Tpu(_) => 1.0, // Most expensive ProcessorType::Tpu(_) => 1.0, // Most expensive
_ => 0.5, _ => 0.5,
}; };
@ -450,7 +457,7 @@ impl LoadBalancer {
// Bonus for low-latency processors // Bonus for low-latency processors
let latency_bonus = match processor_type { let latency_bonus = match processor_type {
ProcessorType::Lpu => 5.0, // Designed for low latency ProcessorType::Lpu => 5.0, // Designed for low latency
ProcessorType::Npu(_) => 3.0, ProcessorType::Npu(_) => 3.0,
ProcessorType::Gpu(_) => 2.0, ProcessorType::Gpu(_) => 2.0,
ProcessorType::Tpu(_) => 1.5, ProcessorType::Tpu(_) => 1.5,
@ -550,7 +557,8 @@ impl LoadBalancer {
let mut suggestions = Vec::new(); let mut suggestions = Vec::new();
let loads = self.loads.read(); let loads = self.loads.read();
let load_values: Vec<_> = loads.iter() let load_values: Vec<_> = loads
.iter()
.map(|(id, load)| (*id, load.load(Ordering::Relaxed))) .map(|(id, load)| (*id, load.load(Ordering::Relaxed)))
.collect(); .collect();
@ -558,16 +566,18 @@ impl LoadBalancer {
return suggestions; return suggestions;
} }
let avg_load: f64 = load_values.iter().map(|(_, l)| *l as f64).sum::<f64>() let avg_load: f64 =
/ load_values.len() as f64; load_values.iter().map(|(_, l)| *l as f64).sum::<f64>() / load_values.len() as f64;
let processor_types = self.processor_types.read(); let processor_types = self.processor_types.read();
let overloaded: Vec<_> = load_values.iter() let overloaded: Vec<_> = load_values
.iter()
.filter(|(_, l)| *l as f64 > avg_load * (1.0 + self.rebalance_threshold)) .filter(|(_, l)| *l as f64 > avg_load * (1.0 + self.rebalance_threshold))
.collect(); .collect();
let underloaded: Vec<_> = load_values.iter() let underloaded: Vec<_> = load_values
.iter()
.filter(|(_, l)| (*l as f64) < avg_load * (1.0 - self.rebalance_threshold)) .filter(|(_, l)| (*l as f64) < avg_load * (1.0 - self.rebalance_threshold))
.collect(); .collect();
@ -627,7 +637,9 @@ impl LoadBalancer {
/// Clean up old migration history. /// Clean up old migration history.
pub fn cleanup_history(&self, max_age: Duration) { pub fn cleanup_history(&self, max_age: Duration) {
let cutoff = Instant::now() - max_age; let cutoff = Instant::now() - max_age;
self.migration_history.write().retain(|r| r.timestamp > cutoff); self.migration_history
.write()
.retain(|r| r.timestamp > cutoff);
} }
} }
@ -725,7 +737,9 @@ mod tests {
balancer.register_processor(ProcessorId(0), ProcessorType::Cpu(CpuVariant::default())); balancer.register_processor(ProcessorId(0), ProcessorType::Cpu(CpuVariant::default()));
balancer.register_processor( balancer.register_processor(
ProcessorId(1), ProcessorId(1),
ProcessorType::Gpu(GpuVariant::NvidiaCuda { compute_capability: (8, 9) }), ProcessorType::Gpu(GpuVariant::NvidiaCuda {
compute_capability: (8, 9),
}),
); );
// Give CPU high load // Give CPU high load
@ -757,7 +771,9 @@ mod tests {
}; };
let cpu = ProcessorType::Cpu(CpuVariant::default()); let cpu = ProcessorType::Cpu(CpuVariant::default());
let gpu = ProcessorType::Gpu(GpuVariant::NvidiaCuda { compute_capability: (8, 9) }); let gpu = ProcessorType::Gpu(GpuVariant::NvidiaCuda {
compute_capability: (8, 9),
});
let lpu = ProcessorType::Lpu; let lpu = ProcessorType::Lpu;
// MatMul can run on all // MatMul can run on all
@ -778,7 +794,10 @@ mod tests {
let npu_id = ProcessorId(1); let npu_id = ProcessorId(1);
balancer.register_processor(cpu_id, ProcessorType::Cpu(CpuVariant::default())); balancer.register_processor(cpu_id, ProcessorType::Cpu(CpuVariant::default()));
balancer.register_processor(npu_id, ProcessorType::Npu(crate::processor::NpuVariant::AppleNeuralEngine { cores: 16 })); balancer.register_processor(
npu_id,
ProcessorType::Npu(crate::processor::NpuVariant::AppleNeuralEngine { cores: 16 }),
);
let task = create_test_task(TaskPriority::Normal); let task = create_test_task(TaskPriority::Normal);

View file

@ -69,7 +69,9 @@ impl HeterogeneousScheduler {
let utilization = self.estimate_utilization(&schedule); let utilization = self.estimate_utilization(&schedule);
// 5. Store active schedule // 5. Store active schedule
self.active_schedules.write().insert(schedule.id, schedule.clone()); self.active_schedules
.write()
.insert(schedule.id, schedule.clone());
Ok(ScheduleResult { Ok(ScheduleResult {
schedule, schedule,
@ -89,10 +91,12 @@ impl HeterogeneousScheduler {
let mut handles = Vec::new(); let mut handles = Vec::new();
for task_id in &stage.tasks { for task_id in &stage.tasks {
let task = schedule.tasks.get(task_id) let task = schedule.tasks.get(task_id).ok_or_else(|| {
.ok_or_else(|| ComputeError::Internal(format!("Task not found: {:?}", task_id)))?; ComputeError::Internal(format!("Task not found: {:?}", task_id))
let processor_id = schedule.assignment.get(task_id) })?;
.ok_or_else(|| ComputeError::Internal(format!("No assignment for task: {:?}", task_id)))?; let processor_id = schedule.assignment.get(task_id).ok_or_else(|| {
ComputeError::Internal(format!("No assignment for task: {:?}", task_id))
})?;
let processor = self.device_registry.get_processor(processor_id)?; let processor = self.device_registry.get_processor(processor_id)?;
let task_clone = task.clone(); let task_clone = task.clone();
@ -144,8 +148,9 @@ impl HeterogeneousScheduler {
let best_processor = self.find_best_processor(&task).await?; let best_processor = self.find_best_processor(&task).await?;
// Check if we should rebalance // Check if we should rebalance
let final_processor = self.load_balancer let final_processor =
.maybe_rebalance(&task, best_processor, &assignment); self.load_balancer
.maybe_rebalance(&task, best_processor, &assignment);
assignment.assign(task.id, final_processor); assignment.assign(task.id, final_processor);
} }
@ -207,9 +212,7 @@ impl HeterogeneousScheduler {
fn topological_sort(&self, tasks: &[Task], deps: &DependencyGraph) -> Vec<Task> { fn topological_sort(&self, tasks: &[Task], deps: &DependencyGraph) -> Vec<Task> {
let mut sorted = Vec::new(); let mut sorted = Vec::new();
let mut visited = std::collections::HashSet::new(); let mut visited = std::collections::HashSet::new();
let task_map: HashMap<TaskId, Task> = tasks.iter() let task_map: HashMap<TaskId, Task> = tasks.iter().map(|t| (t.id, t.clone())).collect();
.map(|t| (t.id, t.clone()))
.collect();
fn visit( fn visit(
task_id: TaskId, task_id: TaskId,
@ -254,9 +257,7 @@ impl HeterogeneousScheduler {
) -> Result<Schedule, ComputeError> { ) -> Result<Schedule, ComputeError> {
let mut stages = Vec::new(); let mut stages = Vec::new();
let mut scheduled = std::collections::HashSet::new(); let mut scheduled = std::collections::HashSet::new();
let task_map: HashMap<TaskId, Task> = tasks.iter() let task_map: HashMap<TaskId, Task> = tasks.iter().map(|t| (t.id, t.clone())).collect();
.map(|t| (t.id, t.clone()))
.collect();
while scheduled.len() < tasks.len() { while scheduled.len() < tasks.len() {
let mut stage_tasks = Vec::new(); let mut stage_tasks = Vec::new();
@ -267,8 +268,7 @@ impl HeterogeneousScheduler {
} }
// Check if all dependencies are satisfied // Check if all dependencies are satisfied
let deps_satisfied = task.dependencies.iter() let deps_satisfied = task.dependencies.iter().all(|dep| scheduled.contains(dep));
.all(|dep| scheduled.contains(dep));
if deps_satisfied { if deps_satisfied {
stage_tasks.push(task.id); stage_tasks.push(task.id);
@ -277,7 +277,7 @@ impl HeterogeneousScheduler {
if stage_tasks.is_empty() { if stage_tasks.is_empty() {
return Err(ComputeError::SchedulingFailed( return Err(ComputeError::SchedulingFailed(
"Circular dependency detected".to_string() "Circular dependency detected".to_string(),
)); ));
} }

View file

@ -153,7 +153,10 @@ impl PriorityWorkQueue {
TaskPriority::Normal, TaskPriority::Normal,
TaskPriority::Background, TaskPriority::Background,
] { ] {
queues.insert(priority, WorkQueue::new(processor_type.clone(), capacity_per_priority)); queues.insert(
priority,
WorkQueue::new(processor_type.clone(), capacity_per_priority),
);
} }
Self { Self {
@ -223,10 +226,7 @@ mod tests {
#[test] #[test]
fn test_work_queue_basic() { fn test_work_queue_basic() {
let queue = WorkQueue::new( let queue = WorkQueue::new(ProcessorType::Cpu(CpuVariant::default()), 100);
ProcessorType::Cpu(CpuVariant::default()),
100,
);
assert!(queue.is_empty()); assert!(queue.is_empty());
@ -246,10 +246,7 @@ mod tests {
#[test] #[test]
fn test_priority_queue() { fn test_priority_queue() {
let queue = PriorityWorkQueue::new( let queue = PriorityWorkQueue::new(ProcessorType::Cpu(CpuVariant::default()), 100);
ProcessorType::Cpu(CpuVariant::default()),
100,
);
queue.push(create_test_task(1, TaskPriority::Background)); queue.push(create_test_task(1, TaskPriority::Background));
queue.push(create_test_task(2, TaskPriority::Critical)); queue.push(create_test_task(2, TaskPriority::Critical));

View file

@ -495,9 +495,9 @@ mod tests {
compute_capability: (8, 0) compute_capability: (8, 0)
} }
))); )));
assert!(matmul_task.is_compatible_with(ProcessorType::Tpu( assert!(
crate::processor::TpuVersion::V5p matmul_task.is_compatible_with(ProcessorType::Tpu(crate::processor::TpuVersion::V5p))
))); );
let data_load_task = Task::new(Operation::DataLoad { let data_load_task = Task::new(Operation::DataLoad {
bytes: 1000, bytes: 1000,
@ -505,9 +505,8 @@ mod tests {
}); });
// DataLoad should be compatible with CPU // DataLoad should be compatible with CPU
assert!(data_load_task.is_compatible_with(ProcessorType::Cpu( assert!(data_load_task
crate::processor::CpuVariant::default() .is_compatible_with(ProcessorType::Cpu(crate::processor::CpuVariant::default())));
)));
} }
#[test] #[test]

View file

@ -25,8 +25,7 @@
use std::time::Duration; use std::time::Duration;
/// Blocks per second mode. /// Blocks per second mode.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)]
#[derive(Default)]
pub enum BpsMode { pub enum BpsMode {
/// Standard mode: 10 blocks per second (100ms block time) /// Standard mode: 10 blocks per second (100ms block time)
/// - Suitable for most network conditions /// - Suitable for most network conditions
@ -75,7 +74,6 @@ impl BpsMode {
} }
} }
impl std::fmt::Display for BpsMode { impl std::fmt::Display for BpsMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {
@ -148,39 +146,39 @@ impl NetworkConfig {
bps_mode: mode, bps_mode: mode,
blocks_per_second: 10, blocks_per_second: 10,
target_block_time_ms: 100, target_block_time_ms: 100,
daa_window_size: 2641, // ~264s window daa_window_size: 2641, // ~264s window
ghostdag_k: 18, // For 10 BPS ghostdag_k: 18, // For 10 BPS
dagknight_k_min: 8, dagknight_k_min: 8,
dagknight_k_max: 64, dagknight_k_max: 64,
finality_depth: 864, // ~86 seconds finality_depth: 864, // ~86 seconds
pruning_depth: 864_000, // ~24 hours pruning_depth: 864_000, // ~24 hours
merge_set_size_limit: 180, merge_set_size_limit: 180,
expected_delay_ms: 100, expected_delay_ms: 100,
}, },
BpsMode::Fast32 => Self { BpsMode::Fast32 => Self {
bps_mode: mode, bps_mode: mode,
blocks_per_second: 32, blocks_per_second: 32,
target_block_time_ms: 31, // ~31.25ms target_block_time_ms: 31, // ~31.25ms
daa_window_size: 8461, // ~264s window at 32 BPS daa_window_size: 8461, // ~264s window at 32 BPS
ghostdag_k: 58, // Scaled for 32 BPS ghostdag_k: 58, // Scaled for 32 BPS
dagknight_k_min: 16, // Higher min for faster blocks dagknight_k_min: 16, // Higher min for faster blocks
dagknight_k_max: 128, // Higher max for adaptation dagknight_k_max: 128, // Higher max for adaptation
finality_depth: 2765, // ~86 seconds at 32 BPS finality_depth: 2765, // ~86 seconds at 32 BPS
pruning_depth: 2_764_800, // ~24 hours at 32 BPS pruning_depth: 2_764_800, // ~24 hours at 32 BPS
merge_set_size_limit: 576, // 32/10 * 180 merge_set_size_limit: 576, // 32/10 * 180
expected_delay_ms: 50, expected_delay_ms: 50,
}, },
BpsMode::Ultra100 => Self { BpsMode::Ultra100 => Self {
bps_mode: mode, bps_mode: mode,
blocks_per_second: 100, blocks_per_second: 100,
target_block_time_ms: 10, target_block_time_ms: 10,
daa_window_size: 26410, // ~264s window at 100 BPS daa_window_size: 26410, // ~264s window at 100 BPS
ghostdag_k: 180, // Scaled for 100 BPS ghostdag_k: 180, // Scaled for 100 BPS
dagknight_k_min: 50, // Higher min for very fast blocks dagknight_k_min: 50, // Higher min for very fast blocks
dagknight_k_max: 255, // u8 max - very high for adaptation dagknight_k_max: 255, // u8 max - very high for adaptation
finality_depth: 8640, // ~86 seconds at 100 BPS finality_depth: 8640, // ~86 seconds at 100 BPS
pruning_depth: 8_640_000, // ~24 hours at 100 BPS pruning_depth: 8_640_000, // ~24 hours at 100 BPS
merge_set_size_limit: 1800, // 100/10 * 180 merge_set_size_limit: 1800, // 100/10 * 180
expected_delay_ms: 20, expected_delay_ms: 20,
}, },
BpsMode::Custom(bps) => { BpsMode::Custom(bps) => {
@ -269,7 +267,7 @@ pub fn bps_comparison_table() -> String {
let mut table = String::from( let mut table = String::from(
"| Property | Standard (10 BPS) | Fast (32 BPS) | Ultra (100 BPS) |\n\ "| Property | Standard (10 BPS) | Fast (32 BPS) | Ultra (100 BPS) |\n\
|----------|-------------------|---------------|------------------|\n" |----------|-------------------|---------------|------------------|\n",
); );
// Block Time // Block Time
@ -314,7 +312,9 @@ pub fn bps_comparison_table() -> String {
// Estimated TPS // Estimated TPS
table.push_str(&format!( table.push_str(&format!(
"| Est. TPS @1000tx/block | {:.0} | {:.0} | {:.0} |\n", "| Est. TPS @1000tx/block | {:.0} | {:.0} | {:.0} |\n",
standard.estimate_tps(1000), fast.estimate_tps(1000), ultra.estimate_tps(1000) standard.estimate_tps(1000),
fast.estimate_tps(1000),
ultra.estimate_tps(1000)
)); ));
table table
@ -401,9 +401,9 @@ mod tests {
fn test_latency_acceptable() { fn test_latency_acceptable() {
let config = NetworkConfig::standard(); // expects 100ms let config = NetworkConfig::standard(); // expects 100ms
assert!(config.is_latency_acceptable(50)); // Good assert!(config.is_latency_acceptable(50)); // Good
assert!(config.is_latency_acceptable(100)); // OK assert!(config.is_latency_acceptable(100)); // OK
assert!(config.is_latency_acceptable(200)); // Still OK (2x limit) assert!(config.is_latency_acceptable(200)); // Still OK (2x limit)
assert!(!config.is_latency_acceptable(300)); // Too high assert!(!config.is_latency_acceptable(300)); // Too high
} }

View file

@ -55,8 +55,8 @@
//! | Layer 2 transactions | FALCON-512 (batch efficiency) | //! | Layer 2 transactions | FALCON-512 (batch efficiency) |
//! | High-value transactions | Dilithium3 (conservative choice) | //! | High-value transactions | Dilithium3 (conservative choice) |
use pqcrypto_falcon::falcon512;
use pqcrypto_falcon::falcon1024; use pqcrypto_falcon::falcon1024;
use pqcrypto_falcon::falcon512;
use pqcrypto_traits::sign::{ use pqcrypto_traits::sign::{
DetachedSignature, PublicKey as PqPublicKey, SecretKey as PqSecretKey, DetachedSignature, PublicKey as PqPublicKey, SecretKey as PqSecretKey,
}; };
@ -64,8 +64,7 @@ use thiserror::Error;
use zeroize::{Zeroize, ZeroizeOnDrop}; use zeroize::{Zeroize, ZeroizeOnDrop};
/// FALCON variant selection. /// FALCON variant selection.
#[derive(Clone, Copy, Debug, PartialEq, Eq)] #[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
#[derive(Default)]
pub enum FalconVariant { pub enum FalconVariant {
/// 128-bit security, ~690 byte signatures /// 128-bit security, ~690 byte signatures
#[default] #[default]
@ -124,7 +123,6 @@ impl FalconVariant {
} }
} }
/// FALCON public key. /// FALCON public key.
#[derive(Clone)] #[derive(Clone)]
pub struct FalconPublicKey { pub struct FalconPublicKey {
@ -188,7 +186,10 @@ impl std::fmt::Debug for FalconPublicKey {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("FalconPublicKey") f.debug_struct("FalconPublicKey")
.field("variant", &self.variant) .field("variant", &self.variant)
.field("bytes", &hex::encode(&self.bytes[..8.min(self.bytes.len())])) .field(
"bytes",
&hex::encode(&self.bytes[..8.min(self.bytes.len())]),
)
.finish() .finish()
} }
} }
@ -492,7 +493,10 @@ mod tests {
// Verify with wrong message should fail // Verify with wrong message should fail
let wrong_message = b"Wrong message"; let wrong_message = b"Wrong message";
assert!(keypair.public_key().verify(wrong_message, &signature).is_err()); assert!(keypair
.public_key()
.verify(wrong_message, &signature)
.is_err());
} }
#[test] #[test]

View file

@ -150,9 +150,9 @@ impl PqAlgorithm {
/// Default priority order (higher = more preferred) /// Default priority order (higher = more preferred)
fn default_priority(&self) -> u8 { fn default_priority(&self) -> u8 {
match self { match self {
Self::Dilithium3 => 100, // Default, well-balanced Self::Dilithium3 => 100, // Default, well-balanced
Self::Falcon1024 => 90, // High security, compact Self::Falcon1024 => 90, // High security, compact
Self::Falcon512 => 85, // Compact, mobile-friendly Self::Falcon512 => 85, // Compact, mobile-friendly
Self::SphincsShake192s => 70, // Conservative backup Self::SphincsShake192s => 70, // Conservative backup
Self::SphincsShake256s => 60, // Maximum security Self::SphincsShake256s => 60, // Maximum security
Self::SphincsShake128s => 50, // Basic SPHINCS+ Self::SphincsShake128s => 50, // Basic SPHINCS+
@ -270,7 +270,8 @@ impl AlgorithmCapabilities {
/// Decode capabilities from bytes /// Decode capabilities from bytes
pub fn decode(data: &[u8]) -> Result<Self, NegotiationError> { pub fn decode(data: &[u8]) -> Result<Self, NegotiationError> {
serde_json::from_slice(data).map_err(|e| NegotiationError::InvalidCapabilities(e.to_string())) serde_json::from_slice(data)
.map_err(|e| NegotiationError::InvalidCapabilities(e.to_string()))
} }
} }
@ -384,8 +385,7 @@ impl AlgorithmNegotiator {
// Check security level // Check security level
let meets_local_security = let meets_local_security =
algo.security_level() >= self.local_caps.min_security_level; algo.security_level() >= self.local_caps.min_security_level;
let meets_remote_security = let meets_remote_security = algo.security_level() >= remote_caps.min_security_level;
algo.security_level() >= remote_caps.min_security_level;
// Check signature size // Check signature size
let local_size_ok = self.local_caps.max_signature_size == 0 let local_size_ok = self.local_caps.max_signature_size == 0
@ -513,10 +513,7 @@ impl AlgorithmNegotiator {
} }
/// Quick negotiation using just algorithm names /// Quick negotiation using just algorithm names
pub fn quick_negotiate( pub fn quick_negotiate(local: &[PqAlgorithm], remote: &[PqAlgorithm]) -> Option<PqAlgorithm> {
local: &[PqAlgorithm],
remote: &[PqAlgorithm],
) -> Option<PqAlgorithm> {
// Find common algorithms and return the one with highest default priority // Find common algorithms and return the one with highest default priority
let local_set: HashSet<_> = local.iter().collect(); let local_set: HashSet<_> = local.iter().collect();
let remote_set: HashSet<_> = remote.iter().collect(); let remote_set: HashSet<_> = remote.iter().collect();
@ -604,7 +601,10 @@ pub enum NegotiationMessage {
}, },
/// Acknowledge selection /// Acknowledge selection
Acknowledgment { session_id: [u8; 32], accepted: bool }, Acknowledgment {
session_id: [u8; 32],
accepted: bool,
},
/// Request renegotiation /// Request renegotiation
Renegotiate { reason: String }, Renegotiate { reason: String },
@ -691,8 +691,10 @@ mod tests {
let result = negotiator.negotiate(&remote_caps).unwrap(); let result = negotiator.negotiate(&remote_caps).unwrap();
// Should prefer FALCON for bandwidth-constrained scenarios // Should prefer FALCON for bandwidth-constrained scenarios
assert!(result.algorithm == PqAlgorithm::Falcon512 || assert!(
result.algorithm == PqAlgorithm::Falcon1024); result.algorithm == PqAlgorithm::Falcon512
|| result.algorithm == PqAlgorithm::Falcon1024
);
} }
#[test] #[test]

View file

@ -60,8 +60,7 @@ use zeroize::{Zeroize, ZeroizeOnDrop};
/// All variants use SHAKE (SHA3-based) for hashing. /// All variants use SHAKE (SHA3-based) for hashing.
/// 's' variants have smaller signatures but are slower. /// 's' variants have smaller signatures but are slower.
/// 'f' variants are faster but have larger signatures. /// 'f' variants are faster but have larger signatures.
#[derive(Clone, Copy, Debug, PartialEq, Eq)] #[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
#[derive(Default)]
pub enum SphincsVariant { pub enum SphincsVariant {
/// 128-bit security, small signatures (~7.8KB) /// 128-bit security, small signatures (~7.8KB)
#[default] #[default]
@ -119,7 +118,6 @@ impl SphincsVariant {
} }
} }
/// SPHINCS+ public key. /// SPHINCS+ public key.
#[derive(Clone)] #[derive(Clone)]
pub struct SphincsPublicKey { pub struct SphincsPublicKey {
@ -191,7 +189,10 @@ impl std::fmt::Debug for SphincsPublicKey {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SphincsPublicKey") f.debug_struct("SphincsPublicKey")
.field("variant", &self.variant) .field("variant", &self.variant)
.field("bytes", &hex::encode(&self.bytes[..8.min(self.bytes.len())])) .field(
"bytes",
&hex::encode(&self.bytes[..8.min(self.bytes.len())]),
)
.finish() .finish()
} }
} }
@ -500,7 +501,10 @@ mod tests {
// Verify with wrong message should fail // Verify with wrong message should fail
let wrong_message = b"Wrong message"; let wrong_message = b"Wrong message";
assert!(keypair.public_key().verify(wrong_message, &signature).is_err()); assert!(keypair
.public_key()
.verify(wrong_message, &signature)
.is_err());
} }
#[test] #[test]

View file

@ -155,10 +155,7 @@ pub struct DagKnightManager {
impl DagKnightManager { impl DagKnightManager {
/// Creates a new DAGKnight manager with standard 10 BPS configuration. /// Creates a new DAGKnight manager with standard 10 BPS configuration.
pub fn new( pub fn new(dag: Arc<BlockDag>, reachability: Arc<ReachabilityStore>) -> Self {
dag: Arc<BlockDag>,
reachability: Arc<ReachabilityStore>,
) -> Self {
Self::with_config(dag, reachability, BlockRateConfig::Standard) Self::with_config(dag, reachability, BlockRateConfig::Standard)
} }
@ -269,7 +266,8 @@ impl DagKnightManager {
let anticone_size = self.calculate_anticone_size(&block_id, parents); let anticone_size = self.calculate_anticone_size(&block_id, parents);
// Record observation in latency tracker // Record observation in latency tracker
self.latency_tracker.record_block(block_id, block_time_ms, anticone_size); self.latency_tracker
.record_block(block_id, block_time_ms, anticone_size);
// Process with underlying GHOSTDAG // Process with underlying GHOSTDAG
let data = self.ghostdag.add_block(block_id, parents)?; let data = self.ghostdag.add_block(block_id, parents)?;
@ -292,11 +290,9 @@ impl DagKnightManager {
for tip in tips { for tip in tips {
if tip != *block_id && !parents.contains(&tip) { if tip != *block_id && !parents.contains(&tip) {
// Check if tip is in the past of any parent // Check if tip is in the past of any parent
let in_past = parents.iter().any(|p| { let in_past = parents
self.reachability .iter()
.is_ancestor(p, &tip) .any(|p| self.reachability.is_ancestor(p, &tip).unwrap_or(false));
.unwrap_or(false)
});
if !in_past { if !in_past {
anticone_count += 1; anticone_count += 1;
@ -375,7 +371,8 @@ impl DagKnightManager {
let sigma_multiplier = confidence.sigma_multiplier(); let sigma_multiplier = confidence.sigma_multiplier();
// Required depth scales with variance and confidence level // Required depth scales with variance and confidence level
let required_depth = (self.block_rate_bps * (mean_delay + sigma * sigma_multiplier)).ceil() as u64; let required_depth =
(self.block_rate_bps * (mean_delay + sigma * sigma_multiplier)).ceil() as u64;
// Current confidence based on actual depth // Current confidence based on actual depth
let current_confidence = if depth >= required_depth { let current_confidence = if depth >= required_depth {
@ -388,7 +385,8 @@ impl DagKnightManager {
// Time to reach required depth // Time to reach required depth
let blocks_needed = required_depth.saturating_sub(depth); let blocks_needed = required_depth.saturating_sub(depth);
let time_per_block_ms = 1000.0 / self.block_rate_bps; let time_per_block_ms = 1000.0 / self.block_rate_bps;
let estimated_time = Duration::from_millis((blocks_needed as f64 * time_per_block_ms) as u64); let estimated_time =
Duration::from_millis((blocks_needed as f64 * time_per_block_ms) as u64);
// Block is final if depth exceeds finality threshold for this block rate // Block is final if depth exceeds finality threshold for this block rate
let is_final = depth >= self.finality_depth(); let is_final = depth >= self.finality_depth();
@ -506,7 +504,10 @@ impl std::fmt::Debug for DagKnightManager {
.field("block_rate_config", &self.block_rate_config) .field("block_rate_config", &self.block_rate_config)
.field("block_rate_bps", &self.block_rate_bps) .field("block_rate_bps", &self.block_rate_bps)
.field("adaptive_k", &*self.adaptive_k.read()) .field("adaptive_k", &*self.adaptive_k.read())
.field("k_bounds", &format!("{}-{}", self.k_bounds.min_k, self.k_bounds.max_k)) .field(
"k_bounds",
&format!("{}-{}", self.k_bounds.min_k, self.k_bounds.max_k),
)
.field("mean_delay_ms", &stats.mean_delay_ms) .field("mean_delay_ms", &stats.mean_delay_ms)
.field("sample_count", &stats.sample_count) .field("sample_count", &stats.sample_count)
.finish() .finish()
@ -534,10 +535,7 @@ pub fn calculate_optimal_k(network_delay_ms: f64, block_rate_bps: f64) -> u8 {
} }
/// Calculates the optimal k for a specific block rate configuration. /// Calculates the optimal k for a specific block rate configuration.
pub fn calculate_optimal_k_for_config( pub fn calculate_optimal_k_for_config(network_delay_ms: f64, config: BlockRateConfig) -> u8 {
network_delay_ms: f64,
config: BlockRateConfig,
) -> u8 {
let bounds = AdaptiveKBounds::for_block_rate(config); let bounds = AdaptiveKBounds::for_block_rate(config);
let delay_secs = network_delay_ms / 1000.0; let delay_secs = network_delay_ms / 1000.0;
let k = (config.bps() * delay_secs * SAFETY_MARGIN).ceil() as u16; let k = (config.bps() * delay_secs * SAFETY_MARGIN).ceil() as u16;
@ -578,7 +576,9 @@ mod tests {
(dag, reachability, dagknight) (dag, reachability, dagknight)
} }
fn setup_test_dag_with_config(config: BlockRateConfig) -> (Arc<BlockDag>, Arc<ReachabilityStore>, DagKnightManager) { fn setup_test_dag_with_config(
config: BlockRateConfig,
) -> (Arc<BlockDag>, Arc<ReachabilityStore>, DagKnightManager) {
let genesis = make_block_id(0); let genesis = make_block_id(0);
let dag = Arc::new(BlockDag::new(genesis, 0)); let dag = Arc::new(BlockDag::new(genesis, 0));
let reachability = Arc::new(ReachabilityStore::new(genesis)); let reachability = Arc::new(ReachabilityStore::new(genesis));
@ -671,14 +671,19 @@ mod tests {
let tps_poor = estimate_throughput(10.0, 100, 40.0); let tps_poor = estimate_throughput(10.0, 100, 40.0);
// Good network should have higher throughput // Good network should have higher throughput
assert!(tps_good > tps_poor, "tps_good={} should be > tps_poor={}", tps_good, tps_poor); assert!(
tps_good > tps_poor,
"tps_good={} should be > tps_poor={}",
tps_good,
tps_poor
);
} }
#[test] #[test]
fn test_throughput_by_config() { fn test_throughput_by_config() {
// At same network conditions, higher BPS = higher theoretical TPS // At same network conditions, higher BPS = higher theoretical TPS
let tps_10 = estimate_throughput(10.0, 100, 20.0); // 10 BPS let tps_10 = estimate_throughput(10.0, 100, 20.0); // 10 BPS
let tps_32 = estimate_throughput(32.0, 100, 20.0); // 32 BPS let tps_32 = estimate_throughput(32.0, 100, 20.0); // 32 BPS
let tps_100 = estimate_throughput(100.0, 100, 20.0); // 100 BPS let tps_100 = estimate_throughput(100.0, 100, 20.0); // 100 BPS
// Higher block rates give higher TPS (with network overhead) // Higher block rates give higher TPS (with network overhead)
@ -698,19 +703,37 @@ mod tests {
let maximum_time_hrs = maximum.finality_depth() as f64 / 100.0 / 3600.0; let maximum_time_hrs = maximum.finality_depth() as f64 / 100.0 / 3600.0;
// Should all be approximately 2.4 hours (allow some variance) // Should all be approximately 2.4 hours (allow some variance)
assert!((standard_time_hrs - 2.4).abs() < 0.1, "standard: {}", standard_time_hrs); assert!(
assert!((enhanced_time_hrs - 2.4).abs() < 0.1, "enhanced: {}", enhanced_time_hrs); (standard_time_hrs - 2.4).abs() < 0.1,
assert!((maximum_time_hrs - 2.4).abs() < 0.1, "maximum: {}", maximum_time_hrs); "standard: {}",
standard_time_hrs
);
assert!(
(enhanced_time_hrs - 2.4).abs() < 0.1,
"enhanced: {}",
enhanced_time_hrs
);
assert!(
(maximum_time_hrs - 2.4).abs() < 0.1,
"maximum: {}",
maximum_time_hrs
);
} }
#[test] #[test]
fn test_confidence_levels() { fn test_confidence_levels() {
assert!(ConfirmationConfidence::VeryHigh.sigma_multiplier() assert!(
> ConfirmationConfidence::High.sigma_multiplier()); ConfirmationConfidence::VeryHigh.sigma_multiplier()
assert!(ConfirmationConfidence::High.sigma_multiplier() > ConfirmationConfidence::High.sigma_multiplier()
> ConfirmationConfidence::Medium.sigma_multiplier()); );
assert!(ConfirmationConfidence::Medium.sigma_multiplier() assert!(
> ConfirmationConfidence::Low.sigma_multiplier()); ConfirmationConfidence::High.sigma_multiplier()
> ConfirmationConfidence::Medium.sigma_multiplier()
);
assert!(
ConfirmationConfidence::Medium.sigma_multiplier()
> ConfirmationConfidence::Low.sigma_multiplier()
);
} }
#[test] #[test]

View file

@ -98,12 +98,7 @@ impl LatencyTracker {
/// * `block_id` - Hash of the observed block /// * `block_id` - Hash of the observed block
/// * `block_time_ms` - Timestamp from block header (Unix ms) /// * `block_time_ms` - Timestamp from block header (Unix ms)
/// * `anticone_size` - Number of blocks in the anticone at observation time /// * `anticone_size` - Number of blocks in the anticone at observation time
pub fn record_block( pub fn record_block(&self, block_id: BlockId, block_time_ms: u64, anticone_size: usize) {
&self,
block_id: BlockId,
block_time_ms: u64,
anticone_size: usize,
) {
let local_time = Instant::now(); let local_time = Instant::now();
let now_ms = std::time::SystemTime::now() let now_ms = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH) .duration_since(std::time::UNIX_EPOCH)
@ -208,7 +203,10 @@ impl LatencyTracker {
let anticone_growth_rate = if n > 1 { let anticone_growth_rate = if n > 1 {
let first = samples.front().unwrap(); let first = samples.front().unwrap();
let last = samples.back().unwrap(); let last = samples.back().unwrap();
let time_span_secs = last.local_time.duration_since(first.local_time).as_secs_f64(); let time_span_secs = last
.local_time
.duration_since(first.local_time)
.as_secs_f64();
if time_span_secs > 0.0 { if time_span_secs > 0.0 {
let total_anticone_growth: usize = samples.iter().map(|s| s.anticone_size).sum(); let total_anticone_growth: usize = samples.iter().map(|s| s.anticone_size).sum();

View file

@ -32,8 +32,8 @@ pub mod reachability;
pub use dag::{BlockDag, BlockRelations, DagError}; pub use dag::{BlockDag, BlockRelations, DagError};
pub use dagknight::{ pub use dagknight::{
calculate_optimal_k, calculate_optimal_k_for_config, estimate_throughput, calculate_optimal_k, calculate_optimal_k_for_config, estimate_throughput, AdaptiveKBounds,
AdaptiveKBounds, ConfirmationConfidence, ConfirmationStatus, DagKnightManager, ConfirmationConfidence, ConfirmationStatus, DagKnightManager,
}; };
pub use ghostdag::{GhostdagData, GhostdagError, GhostdagManager}; pub use ghostdag::{GhostdagData, GhostdagError, GhostdagManager};
pub use latency::{LatencySample, LatencyStats, LatencyTracker}; pub use latency::{LatencySample, LatencyStats, LatencyTracker};
@ -116,27 +116,27 @@ impl BlockRateConfig {
/// Returns the merge depth adjusted for block rate. /// Returns the merge depth adjusted for block rate.
pub const fn merge_depth(&self) -> u64 { pub const fn merge_depth(&self) -> u64 {
match self { match self {
BlockRateConfig::Standard => 3600, // ~6 min at 10 bps BlockRateConfig::Standard => 3600, // ~6 min at 10 bps
BlockRateConfig::Enhanced => 11520, // ~6 min at 32 bps BlockRateConfig::Enhanced => 11520, // ~6 min at 32 bps
BlockRateConfig::Maximum => 36000, // ~6 min at 100 bps BlockRateConfig::Maximum => 36000, // ~6 min at 100 bps
} }
} }
/// Returns the finality depth adjusted for block rate. /// Returns the finality depth adjusted for block rate.
pub const fn finality_depth(&self) -> u64 { pub const fn finality_depth(&self) -> u64 {
match self { match self {
BlockRateConfig::Standard => 86400, // ~2.4 hours at 10 bps BlockRateConfig::Standard => 86400, // ~2.4 hours at 10 bps
BlockRateConfig::Enhanced => 276480, // ~2.4 hours at 32 bps BlockRateConfig::Enhanced => 276480, // ~2.4 hours at 32 bps
BlockRateConfig::Maximum => 864000, // ~2.4 hours at 100 bps BlockRateConfig::Maximum => 864000, // ~2.4 hours at 100 bps
} }
} }
/// Returns the pruning depth adjusted for block rate. /// Returns the pruning depth adjusted for block rate.
pub const fn pruning_depth(&self) -> u64 { pub const fn pruning_depth(&self) -> u64 {
match self { match self {
BlockRateConfig::Standard => 288_000, // ~8 hours at 10 bps BlockRateConfig::Standard => 288_000, // ~8 hours at 10 bps
BlockRateConfig::Enhanced => 921_600, // ~8 hours at 32 bps BlockRateConfig::Enhanced => 921_600, // ~8 hours at 32 bps
BlockRateConfig::Maximum => 2_880_000, // ~8 hours at 100 bps BlockRateConfig::Maximum => 2_880_000, // ~8 hours at 100 bps
} }
} }
} }

View file

@ -42,7 +42,9 @@ impl DocumentId {
let bytes = hex::decode(s) let bytes = hex::decode(s)
.map_err(|_| DatabaseError::InvalidOperation("Invalid hex string".into()))?; .map_err(|_| DatabaseError::InvalidOperation("Invalid hex string".into()))?;
if bytes.len() != 32 { if bytes.len() != 32 {
return Err(DatabaseError::InvalidOperation("Invalid document ID length".into())); return Err(DatabaseError::InvalidOperation(
"Invalid document ID length".into(),
));
} }
let mut arr = [0u8; 32]; let mut arr = [0u8; 32];
arr.copy_from_slice(&bytes); arr.copy_from_slice(&bytes);
@ -249,7 +251,11 @@ impl Collection {
} }
/// Updates documents matching a filter. /// Updates documents matching a filter.
pub fn update_many(&self, filter: &DocumentFilter, update: JsonValue) -> Result<u64, DatabaseError> { pub fn update_many(
&self,
filter: &DocumentFilter,
update: JsonValue,
) -> Result<u64, DatabaseError> {
let mut docs = self.documents.write(); let mut docs = self.documents.write();
let mut count = 0; let mut count = 0;
for doc in docs.values_mut() { for doc in docs.values_mut() {
@ -321,60 +327,71 @@ enum FilterCondition {
impl DocumentFilter { impl DocumentFilter {
/// Creates a new empty filter (matches all). /// Creates a new empty filter (matches all).
pub fn new() -> Self { pub fn new() -> Self {
Self { conditions: Vec::new() } Self {
conditions: Vec::new(),
}
} }
/// Equality condition. /// Equality condition.
pub fn eq(mut self, field: impl Into<String>, value: JsonValue) -> Self { pub fn eq(mut self, field: impl Into<String>, value: JsonValue) -> Self {
self.conditions.push(FilterCondition::Eq(field.into(), value)); self.conditions
.push(FilterCondition::Eq(field.into(), value));
self self
} }
/// Not equal condition. /// Not equal condition.
pub fn ne(mut self, field: impl Into<String>, value: JsonValue) -> Self { pub fn ne(mut self, field: impl Into<String>, value: JsonValue) -> Self {
self.conditions.push(FilterCondition::Ne(field.into(), value)); self.conditions
.push(FilterCondition::Ne(field.into(), value));
self self
} }
/// Greater than. /// Greater than.
pub fn gt(mut self, field: impl Into<String>, value: JsonValue) -> Self { pub fn gt(mut self, field: impl Into<String>, value: JsonValue) -> Self {
self.conditions.push(FilterCondition::Gt(field.into(), value)); self.conditions
.push(FilterCondition::Gt(field.into(), value));
self self
} }
/// Greater than or equal. /// Greater than or equal.
pub fn gte(mut self, field: impl Into<String>, value: JsonValue) -> Self { pub fn gte(mut self, field: impl Into<String>, value: JsonValue) -> Self {
self.conditions.push(FilterCondition::Gte(field.into(), value)); self.conditions
.push(FilterCondition::Gte(field.into(), value));
self self
} }
/// Less than. /// Less than.
pub fn lt(mut self, field: impl Into<String>, value: JsonValue) -> Self { pub fn lt(mut self, field: impl Into<String>, value: JsonValue) -> Self {
self.conditions.push(FilterCondition::Lt(field.into(), value)); self.conditions
.push(FilterCondition::Lt(field.into(), value));
self self
} }
/// Less than or equal. /// Less than or equal.
pub fn lte(mut self, field: impl Into<String>, value: JsonValue) -> Self { pub fn lte(mut self, field: impl Into<String>, value: JsonValue) -> Self {
self.conditions.push(FilterCondition::Lte(field.into(), value)); self.conditions
.push(FilterCondition::Lte(field.into(), value));
self self
} }
/// In array. /// In array.
pub fn in_array(mut self, field: impl Into<String>, values: Vec<JsonValue>) -> Self { pub fn in_array(mut self, field: impl Into<String>, values: Vec<JsonValue>) -> Self {
self.conditions.push(FilterCondition::In(field.into(), values)); self.conditions
.push(FilterCondition::In(field.into(), values));
self self
} }
/// String contains. /// String contains.
pub fn contains(mut self, field: impl Into<String>, substring: impl Into<String>) -> Self { pub fn contains(mut self, field: impl Into<String>, substring: impl Into<String>) -> Self {
self.conditions.push(FilterCondition::Contains(field.into(), substring.into())); self.conditions
.push(FilterCondition::Contains(field.into(), substring.into()));
self self
} }
/// Field exists. /// Field exists.
pub fn exists(mut self, field: impl Into<String>, exists: bool) -> Self { pub fn exists(mut self, field: impl Into<String>, exists: bool) -> Self {
self.conditions.push(FilterCondition::Exists(field.into(), exists)); self.conditions
.push(FilterCondition::Exists(field.into(), exists));
self self
} }
@ -396,7 +413,9 @@ impl DocumentFilter {
return true; return true;
} }
self.conditions.iter().all(|cond| self.eval_condition(cond, doc)) self.conditions
.iter()
.all(|cond| self.eval_condition(cond, doc))
} }
fn eval_condition(&self, cond: &FilterCondition, doc: &Document) -> bool { fn eval_condition(&self, cond: &FilterCondition, doc: &Document) -> bool {
@ -419,27 +438,21 @@ impl DocumentFilter {
FilterCondition::Lte(field, value) => { FilterCondition::Lte(field, value) => {
self.compare_values(doc.get_nested(field), value, |a, b| a <= b) self.compare_values(doc.get_nested(field), value, |a, b| a <= b)
} }
FilterCondition::In(field, values) => { FilterCondition::In(field, values) => doc
doc.get_nested(field) .get_nested(field)
.map(|v| values.contains(v)) .map(|v| values.contains(v))
.unwrap_or(false) .unwrap_or(false),
} FilterCondition::Contains(field, substring) => doc
FilterCondition::Contains(field, substring) => { .get_nested(field)
doc.get_nested(field) .and_then(|v| v.as_str())
.and_then(|v| v.as_str()) .map(|s| s.contains(substring))
.map(|s| s.contains(substring)) .unwrap_or(false),
.unwrap_or(false)
}
FilterCondition::Exists(field, should_exist) => { FilterCondition::Exists(field, should_exist) => {
let exists = doc.get_nested(field).is_some(); let exists = doc.get_nested(field).is_some();
exists == *should_exist exists == *should_exist
} }
FilterCondition::And(filters) => { FilterCondition::And(filters) => filters.iter().all(|f| f.matches(doc)),
filters.iter().all(|f| f.matches(doc)) FilterCondition::Or(filters) => filters.iter().any(|f| f.matches(doc)),
}
FilterCondition::Or(filters) => {
filters.iter().any(|f| f.matches(doc))
}
} }
} }
@ -448,12 +461,10 @@ impl DocumentFilter {
F: Fn(f64, f64) -> bool, F: Fn(f64, f64) -> bool,
{ {
match (a, b) { match (a, b) {
(Some(JsonValue::Number(a)), JsonValue::Number(b)) => { (Some(JsonValue::Number(a)), JsonValue::Number(b)) => match (a.as_f64(), b.as_f64()) {
match (a.as_f64(), b.as_f64()) { (Some(a), Some(b)) => cmp(a, b),
(Some(a), Some(b)) => cmp(a, b), _ => false,
_ => false, },
}
}
_ => false, _ => false,
} }
} }
@ -512,7 +523,11 @@ impl DocumentStore {
} }
/// Finds documents in a collection. /// Finds documents in a collection.
pub fn find(&self, collection: &str, filter: &DocumentFilter) -> Result<Vec<Document>, DatabaseError> { pub fn find(
&self,
collection: &str,
filter: &DocumentFilter,
) -> Result<Vec<Document>, DatabaseError> {
let collections = self.collections.read(); let collections = self.collections.read();
let coll = collections let coll = collections
.get(collection) .get(collection)
@ -521,7 +536,11 @@ impl DocumentStore {
} }
/// Finds one document. /// Finds one document.
pub fn find_one(&self, collection: &str, filter: &DocumentFilter) -> Result<Option<Document>, DatabaseError> { pub fn find_one(
&self,
collection: &str,
filter: &DocumentFilter,
) -> Result<Option<Document>, DatabaseError> {
let collections = self.collections.read(); let collections = self.collections.read();
let coll = collections let coll = collections
.get(collection) .get(collection)
@ -530,7 +549,11 @@ impl DocumentStore {
} }
/// Finds a document by ID. /// Finds a document by ID.
pub fn find_by_id(&self, collection: &str, id: &DocumentId) -> Result<Option<Document>, DatabaseError> { pub fn find_by_id(
&self,
collection: &str,
id: &DocumentId,
) -> Result<Option<Document>, DatabaseError> {
let collections = self.collections.read(); let collections = self.collections.read();
let coll = collections let coll = collections
.get(collection) .get(collection)
@ -539,7 +562,12 @@ impl DocumentStore {
} }
/// Updates a document by ID. /// Updates a document by ID.
pub fn update_by_id(&self, collection: &str, id: &DocumentId, update: JsonValue) -> Result<bool, DatabaseError> { pub fn update_by_id(
&self,
collection: &str,
id: &DocumentId,
update: JsonValue,
) -> Result<bool, DatabaseError> {
let collections = self.collections.read(); let collections = self.collections.read();
let coll = collections let coll = collections
.get(collection) .get(collection)
@ -584,7 +612,8 @@ mod tests {
fn test_collection_insert_find() { fn test_collection_insert_find() {
let coll = Collection::new("users"); let coll = Collection::new("users");
coll.insert_one(json!({"name": "Alice", "age": 30})).unwrap(); coll.insert_one(json!({"name": "Alice", "age": 30}))
.unwrap();
coll.insert_one(json!({"name": "Bob", "age": 25})).unwrap(); coll.insert_one(json!({"name": "Bob", "age": 25})).unwrap();
let filter = DocumentFilter::new().eq("name", json!("Alice")); let filter = DocumentFilter::new().eq("name", json!("Alice"));
@ -597,9 +626,11 @@ mod tests {
fn test_filter_comparison() { fn test_filter_comparison() {
let coll = Collection::new("users"); let coll = Collection::new("users");
coll.insert_one(json!({"name": "Alice", "age": 30})).unwrap(); coll.insert_one(json!({"name": "Alice", "age": 30}))
.unwrap();
coll.insert_one(json!({"name": "Bob", "age": 25})).unwrap(); coll.insert_one(json!({"name": "Bob", "age": 25})).unwrap();
coll.insert_one(json!({"name": "Charlie", "age": 35})).unwrap(); coll.insert_one(json!({"name": "Charlie", "age": 35}))
.unwrap();
let filter = DocumentFilter::new().gte("age", json!(30)); let filter = DocumentFilter::new().gte("age", json!(30));
let results = coll.find(&filter); let results = coll.find(&filter);
@ -622,7 +653,9 @@ mod tests {
#[test] #[test]
fn test_update_document() { fn test_update_document() {
let coll = Collection::new("users"); let coll = Collection::new("users");
let id = coll.insert_one(json!({"name": "Alice", "age": 30})).unwrap(); let id = coll
.insert_one(json!({"name": "Alice", "age": 30}))
.unwrap();
coll.update_by_id(&id, json!({"age": 31})).unwrap(); coll.update_by_id(&id, json!({"age": 31})).unwrap();

View file

@ -1,9 +1,9 @@
//! Authentication and authorization for Database Gateway. //! Authentication and authorization for Database Gateway.
use parking_lot::RwLock;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use parking_lot::RwLock;
/// API key for authentication. /// API key for authentication.
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]

View file

@ -272,19 +272,13 @@ pub fn json_to_filter(json: &JsonValue) -> Option<Filter> {
// Handle $and // Handle $and
if let Some(and_arr) = obj.get("$and").and_then(|v| v.as_array()) { if let Some(and_arr) = obj.get("$and").and_then(|v| v.as_array()) {
let filters: Vec<Filter> = and_arr let filters: Vec<Filter> = and_arr.iter().filter_map(json_to_filter).collect();
.iter()
.filter_map(json_to_filter)
.collect();
return Some(Filter::And(filters)); return Some(Filter::And(filters));
} }
// Handle $or // Handle $or
if let Some(or_arr) = obj.get("$or").and_then(|v| v.as_array()) { if let Some(or_arr) = obj.get("$or").and_then(|v| v.as_array()) {
let filters: Vec<Filter> = or_arr let filters: Vec<Filter> = or_arr.iter().filter_map(json_to_filter).collect();
.iter()
.filter_map(json_to_filter)
.collect();
return Some(Filter::Or(filters)); return Some(Filter::Or(filters));
} }

View file

@ -137,7 +137,12 @@ async fn kv_get(
// For demo, use a default database // For demo, use a default database
let db = match get_default_database(&state) { let db = match get_default_database(&state) {
Some(db) => db, Some(db) => db,
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<KvGetResponse>::error("No database"))), None => {
return (
StatusCode::NOT_FOUND,
Json(ApiResponse::<KvGetResponse>::error("No database")),
)
}
}; };
state.record_read(); state.record_read();
@ -153,7 +158,12 @@ async fn kv_set(
) -> impl IntoResponse { ) -> impl IntoResponse {
let db = match get_default_database(&state) { let db = match get_default_database(&state) {
Some(db) => db, Some(db) => db,
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<KvGetResponse>::error("No database"))), None => {
return (
StatusCode::NOT_FOUND,
Json(ApiResponse::<KvGetResponse>::error("No database")),
)
}
}; };
state.record_write(req.value.len() as u64); state.record_write(req.value.len() as u64);
@ -168,7 +178,12 @@ async fn kv_delete(
) -> impl IntoResponse { ) -> impl IntoResponse {
let db = match get_default_database(&state) { let db = match get_default_database(&state) {
Some(db) => db, Some(db) => db,
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<bool>::error("No database"))), None => {
return (
StatusCode::NOT_FOUND,
Json(ApiResponse::<bool>::error("No database")),
)
}
}; };
let response = handle_kv_delete(db.kv(), &key); let response = handle_kv_delete(db.kv(), &key);
@ -182,7 +197,12 @@ async fn kv_batch(
) -> impl IntoResponse { ) -> impl IntoResponse {
let db = match get_default_database(&state) { let db = match get_default_database(&state) {
Some(db) => db, Some(db) => db,
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<KvBatchResponse>::error("No database"))), None => {
return (
StatusCode::NOT_FOUND,
Json(ApiResponse::<KvBatchResponse>::error("No database")),
)
}
}; };
let response = handle_kv_batch(db.kv(), req); let response = handle_kv_batch(db.kv(), req);
@ -217,7 +237,9 @@ async fn list_databases(
}) })
.collect(); .collect();
Json(ApiResponse::ok(ListDatabasesResponse { databases: response })) Json(ApiResponse::ok(ListDatabasesResponse {
databases: response,
}))
} }
async fn create_database( async fn create_database(
@ -250,7 +272,10 @@ async fn create_database(
})), })),
) )
} }
Err(e) => (StatusCode::BAD_REQUEST, Json(ApiResponse::error(e.to_string()))), Err(e) => (
StatusCode::BAD_REQUEST,
Json(ApiResponse::error(e.to_string())),
),
} }
} }
@ -302,12 +327,20 @@ async fn create_collection(
) -> impl IntoResponse { ) -> impl IntoResponse {
let db = match get_database(&state, &db_name) { let db = match get_database(&state, &db_name) {
Some(db) => db, Some(db) => db,
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<bool>::error("Database not found"))), None => {
return (
StatusCode::NOT_FOUND,
Json(ApiResponse::<bool>::error("Database not found")),
)
}
}; };
match db.documents().create_collection(&req.name) { match db.documents().create_collection(&req.name) {
Ok(_) => (StatusCode::CREATED, Json(ApiResponse::ok(true))), Ok(_) => (StatusCode::CREATED, Json(ApiResponse::ok(true))),
Err(e) => (StatusCode::BAD_REQUEST, Json(ApiResponse::<bool>::error(e.to_string()))), Err(e) => (
StatusCode::BAD_REQUEST,
Json(ApiResponse::<bool>::error(e.to_string())),
),
} }
} }
@ -349,15 +382,20 @@ async fn query_documents(
) -> impl IntoResponse { ) -> impl IntoResponse {
let db = match get_database(&state, &db_name) { let db = match get_database(&state, &db_name) {
Some(db) => db, Some(db) => db,
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<QueryDocumentsResponse>::error("Database not found"))), None => {
return (
StatusCode::NOT_FOUND,
Json(ApiResponse::<QueryDocumentsResponse>::error(
"Database not found",
)),
)
}
}; };
state.record_read(); state.record_read();
// Build query // Build query
let mut query = Query::new(&coll_name) let mut query = Query::new(&coll_name).skip(req.skip).limit(req.limit);
.skip(req.skip)
.limit(req.limit);
// Add filter // Add filter
if let Some(filter_json) = &req.filter { if let Some(filter_json) = &req.filter {
@ -384,8 +422,14 @@ async fn query_documents(
} }
match db.query().execute(&query) { match db.query().execute(&query) {
Ok(result) => (StatusCode::OK, Json(ApiResponse::ok(QueryDocumentsResponse::from(result)))), Ok(result) => (
Err(e) => (StatusCode::BAD_REQUEST, Json(ApiResponse::error(e.to_string()))), StatusCode::OK,
Json(ApiResponse::ok(QueryDocumentsResponse::from(result))),
),
Err(e) => (
StatusCode::BAD_REQUEST,
Json(ApiResponse::error(e.to_string())),
),
} }
} }
@ -396,15 +440,25 @@ async fn insert_document(
) -> impl IntoResponse { ) -> impl IntoResponse {
let db = match get_database(&state, &db_name) { let db = match get_database(&state, &db_name) {
Some(db) => db, Some(db) => db,
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<String>::error("Database not found"))), None => {
return (
StatusCode::NOT_FOUND,
Json(ApiResponse::<String>::error("Database not found")),
)
}
}; };
let size = serde_json::to_vec(&req.document).map(|v| v.len()).unwrap_or(0); let size = serde_json::to_vec(&req.document)
.map(|v| v.len())
.unwrap_or(0);
state.record_write(size as u64); state.record_write(size as u64);
match db.documents().insert(&coll_name, req.document) { match db.documents().insert(&coll_name, req.document) {
Ok(id) => (StatusCode::CREATED, Json(ApiResponse::ok(id.to_hex()))), Ok(id) => (StatusCode::CREATED, Json(ApiResponse::ok(id.to_hex()))),
Err(e) => (StatusCode::BAD_REQUEST, Json(ApiResponse::error(e.to_string()))), Err(e) => (
StatusCode::BAD_REQUEST,
Json(ApiResponse::error(e.to_string())),
),
} }
} }
@ -415,7 +469,12 @@ async fn insert_many_documents(
) -> impl IntoResponse { ) -> impl IntoResponse {
let db = match get_database(&state, &db_name) { let db = match get_database(&state, &db_name) {
Some(db) => db, Some(db) => db,
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<Vec<String>>::error("Database not found"))), None => {
return (
StatusCode::NOT_FOUND,
Json(ApiResponse::<Vec<String>>::error("Database not found")),
)
}
}; };
let mut ids = Vec::with_capacity(req.documents.len()); let mut ids = Vec::with_capacity(req.documents.len());
@ -425,7 +484,12 @@ async fn insert_many_documents(
total_size += serde_json::to_vec(&doc).map(|v| v.len()).unwrap_or(0) as u64; total_size += serde_json::to_vec(&doc).map(|v| v.len()).unwrap_or(0) as u64;
match db.documents().insert(&coll_name, doc) { match db.documents().insert(&coll_name, doc) {
Ok(id) => ids.push(id.to_hex()), Ok(id) => ids.push(id.to_hex()),
Err(e) => return (StatusCode::BAD_REQUEST, Json(ApiResponse::error(e.to_string()))), Err(e) => {
return (
StatusCode::BAD_REQUEST,
Json(ApiResponse::error(e.to_string())),
)
}
} }
} }
@ -477,7 +541,9 @@ async fn update_document(
Err(e) => return Json(ApiResponse::error(e.to_string())), Err(e) => return Json(ApiResponse::error(e.to_string())),
}; };
let update_size = serde_json::to_vec(&req.update).map(|v| v.len()).unwrap_or(0); let update_size = serde_json::to_vec(&req.update)
.map(|v| v.len())
.unwrap_or(0);
state.record_write(update_size as u64); state.record_write(update_size as u64);
match db.documents().update_by_id(&coll_name, &id, req.update) { match db.documents().update_by_id(&coll_name, &id, req.update) {
@ -519,7 +585,12 @@ async fn insert_embeddings(
) -> impl IntoResponse { ) -> impl IntoResponse {
let db = match get_database(&state, &db_name) { let db = match get_database(&state, &db_name) {
Some(db) => db, Some(db) => db,
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<usize>::error("Database not found"))), None => {
return (
StatusCode::NOT_FOUND,
Json(ApiResponse::<usize>::error("Database not found")),
)
}
}; };
let mut count = 0; let mut count = 0;
@ -533,7 +604,10 @@ async fn insert_embeddings(
} }
if let Err(e) = db.vectors().insert(embedding) { if let Err(e) = db.vectors().insert(embedding) {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error(e.to_string()))); return (
StatusCode::BAD_REQUEST,
Json(ApiResponse::error(e.to_string())),
);
} }
count += 1; count += 1;
} }
@ -549,7 +623,14 @@ async fn vector_search(
) -> impl IntoResponse { ) -> impl IntoResponse {
let db = match get_database(&state, &db_name) { let db = match get_database(&state, &db_name) {
Some(db) => db, Some(db) => db,
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<VectorSearchResponse>::error("Database not found"))), None => {
return (
StatusCode::NOT_FOUND,
Json(ApiResponse::<VectorSearchResponse>::error(
"Database not found",
)),
)
}
}; };
state.record_vector_search(); state.record_vector_search();
@ -563,9 +644,18 @@ async fn vector_search(
Ok(results) => { Ok(results) => {
let count = results.len(); let count = results.len();
let matches: Vec<VectorMatch> = results.into_iter().map(Into::into).collect(); let matches: Vec<VectorMatch> = results.into_iter().map(Into::into).collect();
(StatusCode::OK, Json(ApiResponse::ok(VectorSearchResponse { results: matches, count }))) (
StatusCode::OK,
Json(ApiResponse::ok(VectorSearchResponse {
results: matches,
count,
})),
)
} }
Err(e) => (StatusCode::BAD_REQUEST, Json(ApiResponse::error(e.to_string()))), Err(e) => (
StatusCode::BAD_REQUEST,
Json(ApiResponse::error(e.to_string())),
),
} }
} }

View file

@ -114,10 +114,7 @@ impl GatewayServer {
tracing::error!("Failed to create default database: {}", e); tracing::error!("Failed to create default database: {}", e);
} }
let state = Arc::new(AppState::new( let state = Arc::new(AppState::new(self.db_manager.clone(), self.auth.clone()));
self.db_manager.clone(),
self.auth.clone(),
));
let app = create_router(state); let app = create_router(state);

View file

@ -86,7 +86,12 @@ pub struct Edge {
impl Edge { impl Edge {
/// Creates a new directed edge. /// Creates a new directed edge.
pub fn new(source: NodeId, target: NodeId, edge_type: impl Into<String>, properties: JsonValue) -> Self { pub fn new(
source: NodeId,
target: NodeId,
edge_type: impl Into<String>,
properties: JsonValue,
) -> Self {
let now = std::time::SystemTime::now() let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH) .duration_since(std::time::UNIX_EPOCH)
.unwrap() .unwrap()
@ -105,7 +110,12 @@ impl Edge {
} }
/// Creates an undirected edge. /// Creates an undirected edge.
pub fn undirected(source: NodeId, target: NodeId, edge_type: impl Into<String>, properties: JsonValue) -> Self { pub fn undirected(
source: NodeId,
target: NodeId,
edge_type: impl Into<String>,
properties: JsonValue,
) -> Self {
let mut edge = Self::new(source, target, edge_type, properties); let mut edge = Self::new(source, target, edge_type, properties);
edge.directed = false; edge.directed = false;
edge edge
@ -138,8 +148,8 @@ impl Edge {
/// Checks if this edge connects two specific nodes. /// Checks if this edge connects two specific nodes.
pub fn connects_pair(&self, a: &NodeId, b: &NodeId) -> bool { pub fn connects_pair(&self, a: &NodeId, b: &NodeId) -> bool {
(&self.source == a && &self.target == b) || (&self.source == a && &self.target == b)
(!self.directed && &self.source == b && &self.target == a) || (!self.directed && &self.source == b && &self.target == a)
} }
/// Gets a property value. /// Gets a property value.
@ -156,7 +166,9 @@ impl Edge {
/// Checks if the edge matches a property filter. /// Checks if the edge matches a property filter.
pub fn matches_properties(&self, filter: &JsonValue) -> bool { pub fn matches_properties(&self, filter: &JsonValue) -> bool {
if let (Some(filter_obj), Some(props_obj)) = (filter.as_object(), self.properties.as_object()) { if let (Some(filter_obj), Some(props_obj)) =
(filter.as_object(), self.properties.as_object())
{
for (key, expected) in filter_obj { for (key, expected) in filter_obj {
if let Some(actual) = props_obj.get(key) { if let Some(actual) = props_obj.get(key) {
if actual != expected { if actual != expected {
@ -216,7 +228,12 @@ impl EdgeBuilder {
/// Builds the edge. /// Builds the edge.
pub fn build(self) -> Edge { pub fn build(self) -> Edge {
let mut edge = Edge::new(self.source, self.target, self.edge_type, JsonValue::Object(self.properties)); let mut edge = Edge::new(
self.source,
self.target,
self.edge_type,
JsonValue::Object(self.properties),
);
edge.directed = self.directed; edge.directed = self.directed;
edge.weight = self.weight; edge.weight = self.weight;
edge edge
@ -264,7 +281,10 @@ mod tests {
assert!(!edge.directed); assert!(!edge.directed);
assert_eq!(edge.weight, 2.5); assert_eq!(edge.weight, 2.5);
assert_eq!(edge.get_property("percentage"), Some(&serde_json::json!(50))); assert_eq!(
edge.get_property("percentage"),
Some(&serde_json::json!(50))
);
} }
#[test] #[test]

View file

@ -172,7 +172,9 @@ impl Node {
/// Checks if the node matches a property filter. /// Checks if the node matches a property filter.
pub fn matches_properties(&self, filter: &JsonValue) -> bool { pub fn matches_properties(&self, filter: &JsonValue) -> bool {
if let (Some(filter_obj), Some(props_obj)) = (filter.as_object(), self.properties.as_object()) { if let (Some(filter_obj), Some(props_obj)) =
(filter.as_object(), self.properties.as_object())
{
for (key, expected) in filter_obj { for (key, expected) in filter_obj {
if let Some(actual) = props_obj.get(key) { if let Some(actual) = props_obj.get(key) {
if actual != expected { if actual != expected {
@ -258,10 +260,7 @@ mod tests {
assert!(node.has_label("User")); assert!(node.has_label("User"));
assert!(!node.has_label("Admin")); assert!(!node.has_label("Admin"));
assert_eq!( assert_eq!(node.get_property("name"), Some(&serde_json::json!("Alice")));
node.get_property("name"),
Some(&serde_json::json!("Alice"))
);
} }
#[test] #[test]

View file

@ -60,7 +60,10 @@ impl Eq for DijkstraState {}
impl Ord for DijkstraState { impl Ord for DijkstraState {
fn cmp(&self, other: &Self) -> Ordering { fn cmp(&self, other: &Self) -> Ordering {
// Reverse ordering for min-heap // Reverse ordering for min-heap
other.distance.partial_cmp(&self.distance).unwrap_or(Ordering::Equal) other
.distance
.partial_cmp(&self.distance)
.unwrap_or(Ordering::Equal)
} }
} }
@ -140,9 +143,16 @@ impl<'a> PathFinder<'a> {
let mut visited = HashSet::new(); let mut visited = HashSet::new();
distances.insert(*from, 0.0); distances.insert(*from, 0.0);
heap.push(DijkstraState { node: *from, distance: 0.0 }); heap.push(DijkstraState {
node: *from,
distance: 0.0,
});
while let Some(DijkstraState { node: current, distance: dist }) = heap.pop() { while let Some(DijkstraState {
node: current,
distance: dist,
}) = heap.pop()
{
if &current == to { if &current == to {
// Reconstruct path // Reconstruct path
let mut path = vec![current]; let mut path = vec![current];
@ -181,12 +191,18 @@ impl<'a> PathFinder<'a> {
} }
let new_dist = dist + edge.weight; let new_dist = dist + edge.weight;
let is_shorter = distances.get(&neighbor).map(|&d| new_dist < d).unwrap_or(true); let is_shorter = distances
.get(&neighbor)
.map(|&d| new_dist < d)
.unwrap_or(true);
if is_shorter { if is_shorter {
distances.insert(neighbor, new_dist); distances.insert(neighbor, new_dist);
previous.insert(neighbor, (current, edge.clone())); previous.insert(neighbor, (current, edge.clone()));
heap.push(DijkstraState { node: neighbor, distance: new_dist }); heap.push(DijkstraState {
node: neighbor,
distance: new_dist,
});
} }
} }
} }
@ -251,7 +267,9 @@ impl<'a> PathFinder<'a> {
path.push(neighbor); path.push(neighbor);
edges.push(edge.clone()); edges.push(edge.clone());
self.find_all_paths_dfs(&neighbor, target, max_length, path, edges, visited, results); self.find_all_paths_dfs(
&neighbor, target, max_length, path, edges, visited, results,
);
path.pop(); path.pop();
edges.pop(); edges.pop();
@ -261,7 +279,12 @@ impl<'a> PathFinder<'a> {
} }
/// Finds the shortest path considering only specific edge types. /// Finds the shortest path considering only specific edge types.
pub fn shortest_path_by_type(&self, from: &NodeId, to: &NodeId, edge_types: &[String]) -> PathResult { pub fn shortest_path_by_type(
&self,
from: &NodeId,
to: &NodeId,
edge_types: &[String],
) -> PathResult {
if from == to { if from == to {
return PathResult::found(vec![*from], Vec::new(), 0.0); return PathResult::found(vec![*from], Vec::new(), 0.0);
} }
@ -377,11 +400,21 @@ mod tests {
let d = store.create_node(vec![], serde_json::json!({"name": "D"})); let d = store.create_node(vec![], serde_json::json!({"name": "D"}));
let e = store.create_node(vec![], serde_json::json!({"name": "E"})); let e = store.create_node(vec![], serde_json::json!({"name": "E"}));
store.create_edge(a, b, "LINK", serde_json::json!({})).unwrap(); store
store.create_edge(b, c, "LINK", serde_json::json!({})).unwrap(); .create_edge(a, b, "LINK", serde_json::json!({}))
store.create_edge(c, d, "LINK", serde_json::json!({})).unwrap(); .unwrap();
store.create_edge(a, e, "LINK", serde_json::json!({})).unwrap(); store
store.create_edge(e, d, "LINK", serde_json::json!({})).unwrap(); .create_edge(b, c, "LINK", serde_json::json!({}))
.unwrap();
store
.create_edge(c, d, "LINK", serde_json::json!({}))
.unwrap();
store
.create_edge(a, e, "LINK", serde_json::json!({}))
.unwrap();
store
.create_edge(e, d, "LINK", serde_json::json!({}))
.unwrap();
store store
} }
@ -392,8 +425,14 @@ mod tests {
let finder = PathFinder::new(&store); let finder = PathFinder::new(&store);
let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({})); let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({}));
let a = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("A"))).unwrap(); let a = nodes
let d = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("D"))).unwrap(); .iter()
.find(|n| n.get_property("name") == Some(&serde_json::json!("A")))
.unwrap();
let d = nodes
.iter()
.find(|n| n.get_property("name") == Some(&serde_json::json!("D")))
.unwrap();
let result = finder.shortest_path_bfs(&a.id, &d.id); let result = finder.shortest_path_bfs(&a.id, &d.id);
@ -413,13 +452,19 @@ mod tests {
// A --(3.0)--> C // A --(3.0)--> C
let mut edge1 = super::super::edge::Edge::new(a, b, "LINK", serde_json::json!({})); let mut edge1 = super::super::edge::Edge::new(a, b, "LINK", serde_json::json!({}));
edge1.weight = 1.0; edge1.weight = 1.0;
store.create_edge(a, b, "LINK", serde_json::json!({})).unwrap(); store
.create_edge(a, b, "LINK", serde_json::json!({}))
.unwrap();
let mut edge2 = super::super::edge::Edge::new(b, c, "LINK", serde_json::json!({})); let mut edge2 = super::super::edge::Edge::new(b, c, "LINK", serde_json::json!({}));
edge2.weight = 1.0; edge2.weight = 1.0;
store.create_edge(b, c, "LINK", serde_json::json!({})).unwrap(); store
.create_edge(b, c, "LINK", serde_json::json!({}))
.unwrap();
store.create_edge(a, c, "DIRECT", serde_json::json!({})).unwrap(); store
.create_edge(a, c, "DIRECT", serde_json::json!({}))
.unwrap();
let finder = PathFinder::new(&store); let finder = PathFinder::new(&store);
let result = finder.shortest_path_dijkstra(&a, &c); let result = finder.shortest_path_dijkstra(&a, &c);
@ -449,8 +494,14 @@ mod tests {
let finder = PathFinder::new(&store); let finder = PathFinder::new(&store);
let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({})); let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({}));
let a = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("A"))).unwrap(); let a = nodes
let d = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("D"))).unwrap(); .iter()
.find(|n| n.get_property("name") == Some(&serde_json::json!("A")))
.unwrap();
let d = nodes
.iter()
.find(|n| n.get_property("name") == Some(&serde_json::json!("D")))
.unwrap();
let paths = finder.all_paths(&a.id, &d.id, 5); let paths = finder.all_paths(&a.id, &d.id, 5);
@ -463,8 +514,14 @@ mod tests {
let finder = PathFinder::new(&store); let finder = PathFinder::new(&store);
let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({})); let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({}));
let a = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("A"))).unwrap(); let a = nodes
let d = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("D"))).unwrap(); .iter()
.find(|n| n.get_property("name") == Some(&serde_json::json!("A")))
.unwrap();
let d = nodes
.iter()
.find(|n| n.get_property("name") == Some(&serde_json::json!("D")))
.unwrap();
assert!(finder.path_exists(&a.id, &d.id)); assert!(finder.path_exists(&a.id, &d.id));
} }
@ -475,9 +532,18 @@ mod tests {
let finder = PathFinder::new(&store); let finder = PathFinder::new(&store);
let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({})); let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({}));
let a = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("A"))).unwrap(); let a = nodes
let d = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("D"))).unwrap(); .iter()
let b = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("B"))).unwrap(); .find(|n| n.get_property("name") == Some(&serde_json::json!("A")))
.unwrap();
let d = nodes
.iter()
.find(|n| n.get_property("name") == Some(&serde_json::json!("D")))
.unwrap();
let b = nodes
.iter()
.find(|n| n.get_property("name") == Some(&serde_json::json!("B")))
.unwrap();
assert_eq!(finder.distance(&a.id, &b.id), Some(1)); assert_eq!(finder.distance(&a.id, &b.id), Some(1));
assert_eq!(finder.distance(&a.id, &d.id), Some(2)); // A -> E -> D assert_eq!(finder.distance(&a.id, &d.id), Some(2)); // A -> E -> D

View file

@ -23,7 +23,10 @@ pub enum GraphQuery {
/// DELETE query for removing nodes/edges. /// DELETE query for removing nodes/edges.
Delete { variable: String, detach: bool }, Delete { variable: String, detach: bool },
/// SET query for updating properties. /// SET query for updating properties.
Set { variable: String, properties: JsonValue }, Set {
variable: String,
properties: JsonValue,
},
} }
/// Pattern to match in the graph. /// Pattern to match in the graph.
@ -78,15 +81,35 @@ pub enum RelationshipDirection {
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub enum WhereClause { pub enum WhereClause {
/// Property comparison. /// Property comparison.
PropertyEquals { variable: String, property: String, value: JsonValue }, PropertyEquals {
variable: String,
property: String,
value: JsonValue,
},
/// Property comparison (not equals). /// Property comparison (not equals).
PropertyNotEquals { variable: String, property: String, value: JsonValue }, PropertyNotEquals {
variable: String,
property: String,
value: JsonValue,
},
/// Property greater than. /// Property greater than.
PropertyGt { variable: String, property: String, value: JsonValue }, PropertyGt {
variable: String,
property: String,
value: JsonValue,
},
/// Property less than. /// Property less than.
PropertyLt { variable: String, property: String, value: JsonValue }, PropertyLt {
variable: String,
property: String,
value: JsonValue,
},
/// Property contains (for text). /// Property contains (for text).
PropertyContains { variable: String, property: String, value: String }, PropertyContains {
variable: String,
property: String,
value: String,
},
/// AND condition. /// AND condition.
And(Box<WhereClause>, Box<WhereClause>), And(Box<WhereClause>, Box<WhereClause>),
/// OR condition. /// OR condition.
@ -105,7 +128,10 @@ pub enum ReturnItem {
/// Return a property of a variable. /// Return a property of a variable.
Property { variable: String, property: String }, Property { variable: String, property: String },
/// Return with an alias. /// Return with an alias.
Alias { item: Box<ReturnItem>, alias: String }, Alias {
item: Box<ReturnItem>,
alias: String,
},
/// Count aggregation. /// Count aggregation.
Count(Option<String>), Count(Option<String>),
} }
@ -114,7 +140,11 @@ pub enum ReturnItem {
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub enum CreateElement { pub enum CreateElement {
/// Create a node. /// Create a node.
Node { variable: Option<String>, labels: Vec<String>, properties: JsonValue }, Node {
variable: Option<String>,
labels: Vec<String>,
properties: JsonValue,
},
/// Create a relationship. /// Create a relationship.
Relationship { Relationship {
from_var: String, from_var: String,
@ -176,7 +206,10 @@ impl GraphQueryParser {
} else if upper.starts_with("SET") { } else if upper.starts_with("SET") {
Self::parse_set(query) Self::parse_set(query)
} else { } else {
Err(GraphError::InvalidOperation(format!("Unknown query type: {}", query))) Err(GraphError::InvalidOperation(format!(
"Unknown query type: {}",
query
)))
} }
} }
@ -185,7 +218,10 @@ impl GraphQueryParser {
let upper = query.to_uppercase(); let upper = query.to_uppercase();
// Find MATCH, WHERE, RETURN, LIMIT positions // Find MATCH, WHERE, RETURN, LIMIT positions
let match_end = upper.find("WHERE").or_else(|| upper.find("RETURN")).unwrap_or(query.len()); let match_end = upper
.find("WHERE")
.or_else(|| upper.find("RETURN"))
.unwrap_or(query.len());
let where_start = upper.find("WHERE"); let where_start = upper.find("WHERE");
let return_start = upper.find("RETURN"); let return_start = upper.find("RETURN");
let limit_start = upper.find("LIMIT"); let limit_start = upper.find("LIMIT");
@ -253,7 +289,9 @@ impl GraphQueryParser {
} }
if nodes.is_empty() { if nodes.is_empty() {
return Err(GraphError::InvalidOperation("No node pattern found".to_string())); return Err(GraphError::InvalidOperation(
"No node pattern found".to_string(),
));
} }
// Combine nodes with relationships // Combine nodes with relationships
@ -264,10 +302,15 @@ impl GraphQueryParser {
} }
} }
Ok(MatchPattern { start, relationships }) Ok(MatchPattern {
start,
relationships,
})
} }
fn parse_node_pattern(chars: &mut std::iter::Peekable<std::str::Chars>) -> Result<NodePattern, GraphError> { fn parse_node_pattern(
chars: &mut std::iter::Peekable<std::str::Chars>,
) -> Result<NodePattern, GraphError> {
// Consume '(' // Consume '('
chars.next(); chars.next();
@ -335,10 +378,16 @@ impl GraphQueryParser {
} }
} }
Ok(NodePattern { variable, labels, properties }) Ok(NodePattern {
variable,
labels,
properties,
})
} }
fn parse_relationship_pattern(chars: &mut std::iter::Peekable<std::str::Chars>) -> Result<RelationshipPattern, GraphError> { fn parse_relationship_pattern(
chars: &mut std::iter::Peekable<std::str::Chars>,
) -> Result<RelationshipPattern, GraphError> {
let mut direction = RelationshipDirection::Undirected; let mut direction = RelationshipDirection::Undirected;
let mut edge_type = None; let mut edge_type = None;
let mut variable = None; let mut variable = None;
@ -408,7 +457,11 @@ impl GraphQueryParser {
variable, variable,
edge_type, edge_type,
direction, direction,
target: NodePattern { variable: None, labels: Vec::new(), properties: None }, target: NodePattern {
variable: None,
labels: Vec::new(),
properties: None,
},
min_hops, min_hops,
max_hops, max_hops,
}) })
@ -476,7 +529,9 @@ impl GraphQueryParser {
elements.push(CreateElement::Node { elements.push(CreateElement::Node {
variable: node.variable, variable: node.variable,
labels: node.labels, labels: node.labels,
properties: node.properties.unwrap_or(JsonValue::Object(serde_json::Map::new())), properties: node
.properties
.unwrap_or(JsonValue::Object(serde_json::Map::new())),
}); });
} else { } else {
break; break;
@ -488,7 +543,11 @@ impl GraphQueryParser {
fn parse_delete(query: &str) -> Result<GraphQuery, GraphError> { fn parse_delete(query: &str) -> Result<GraphQuery, GraphError> {
let detach = query.to_uppercase().starts_with("DETACH"); let detach = query.to_uppercase().starts_with("DETACH");
let start = if detach { "DETACH DELETE".len() } else { "DELETE".len() }; let start = if detach {
"DETACH DELETE".len()
} else {
"DELETE".len()
};
let variable = query[start..].trim().to_string(); let variable = query[start..].trim().to_string();
Ok(GraphQuery::Delete { variable, detach }) Ok(GraphQuery::Delete { variable, detach })
@ -500,19 +559,24 @@ impl GraphQueryParser {
let parts: Vec<_> = content.split('=').collect(); let parts: Vec<_> = content.split('=').collect();
if parts.len() != 2 { if parts.len() != 2 {
return Err(GraphError::InvalidOperation("Invalid SET syntax".to_string())); return Err(GraphError::InvalidOperation(
"Invalid SET syntax".to_string(),
));
} }
let var_prop: Vec<_> = parts[0].trim().split('.').collect(); let var_prop: Vec<_> = parts[0].trim().split('.').collect();
if var_prop.len() != 2 { if var_prop.len() != 2 {
return Err(GraphError::InvalidOperation("Invalid SET variable".to_string())); return Err(GraphError::InvalidOperation(
"Invalid SET variable".to_string(),
));
} }
let variable = var_prop[0].to_string(); let variable = var_prop[0].to_string();
let property = var_prop[1].to_string(); let property = var_prop[1].to_string();
let value_str = parts[1].trim(); let value_str = parts[1].trim();
let value: JsonValue = serde_json::from_str(value_str).unwrap_or(JsonValue::String(value_str.to_string())); let value: JsonValue =
serde_json::from_str(value_str).unwrap_or(JsonValue::String(value_str.to_string()));
Ok(GraphQuery::Set { Ok(GraphQuery::Set {
variable, variable,
@ -535,18 +599,21 @@ impl<'a> GraphQueryExecutor<'a> {
/// Executes a graph query. /// Executes a graph query.
pub fn execute(&self, query: &GraphQuery) -> Result<QueryResult, GraphError> { pub fn execute(&self, query: &GraphQuery) -> Result<QueryResult, GraphError> {
match query { match query {
GraphQuery::Match { pattern, where_clause, return_items, limit } => { GraphQuery::Match {
self.execute_match(pattern, where_clause.as_ref(), return_items, *limit) pattern,
} where_clause,
GraphQuery::Create { .. } => { return_items,
Err(GraphError::InvalidOperation("CREATE requires mutable access".to_string())) limit,
} } => self.execute_match(pattern, where_clause.as_ref(), return_items, *limit),
GraphQuery::Delete { .. } => { GraphQuery::Create { .. } => Err(GraphError::InvalidOperation(
Err(GraphError::InvalidOperation("DELETE requires mutable access".to_string())) "CREATE requires mutable access".to_string(),
} )),
GraphQuery::Set { .. } => { GraphQuery::Delete { .. } => Err(GraphError::InvalidOperation(
Err(GraphError::InvalidOperation("SET requires mutable access".to_string())) "DELETE requires mutable access".to_string(),
} )),
GraphQuery::Set { .. } => Err(GraphError::InvalidOperation(
"SET requires mutable access".to_string(),
)),
} }
} }
@ -586,7 +653,11 @@ impl<'a> GraphQueryExecutor<'a> {
.depth(rel_pattern.max_hops) .depth(rel_pattern.max_hops)
.direction(direction) .direction(direction)
.edge_types( .edge_types(
rel_pattern.edge_type.clone().map(|t| vec![t]).unwrap_or_default(), rel_pattern
.edge_type
.clone()
.map(|t| vec![t])
.unwrap_or_default(),
) )
.labels(rel_pattern.target.labels.clone()); .labels(rel_pattern.target.labels.clone());
@ -635,7 +706,10 @@ impl<'a> GraphQueryExecutor<'a> {
fn find_matching_nodes(&self, pattern: &NodePattern) -> Vec<Node> { fn find_matching_nodes(&self, pattern: &NodePattern) -> Vec<Node> {
let label = pattern.labels.first().map(|s| s.as_str()); let label = pattern.labels.first().map(|s| s.as_str());
let filter = pattern.properties.clone().unwrap_or(JsonValue::Object(serde_json::Map::new())); let filter = pattern
.properties
.clone()
.unwrap_or(JsonValue::Object(serde_json::Map::new()));
self.store.find_nodes(label, &filter) self.store.find_nodes(label, &filter)
} }
@ -657,7 +731,11 @@ impl<'a> GraphQueryExecutor<'a> {
}) })
} }
fn get_column_names(&self, return_items: &[ReturnItem], bindings: &[HashMap<String, JsonValue>]) -> Vec<String> { fn get_column_names(
&self,
return_items: &[ReturnItem],
bindings: &[HashMap<String, JsonValue>],
) -> Vec<String> {
let mut columns = Vec::new(); let mut columns = Vec::new();
for item in return_items { for item in return_items {
@ -673,7 +751,10 @@ impl<'a> GraphQueryExecutor<'a> {
} }
ReturnItem::Alias { alias, .. } => columns.push(alias.clone()), ReturnItem::Alias { alias, .. } => columns.push(alias.clone()),
ReturnItem::Count(var) => { ReturnItem::Count(var) => {
columns.push(format!("count({})", var.as_ref().map(|s| s.as_str()).unwrap_or("*"))); columns.push(format!(
"count({})",
var.as_ref().map(|s| s.as_str()).unwrap_or("*")
));
} }
} }
} }
@ -681,11 +762,18 @@ impl<'a> GraphQueryExecutor<'a> {
columns columns
} }
fn extract_rows(&self, return_items: &[ReturnItem], bindings: &[HashMap<String, JsonValue>]) -> Vec<Vec<JsonValue>> { fn extract_rows(
&self,
return_items: &[ReturnItem],
bindings: &[HashMap<String, JsonValue>],
) -> Vec<Vec<JsonValue>> {
let mut rows = Vec::new(); let mut rows = Vec::new();
// Handle COUNT specially // Handle COUNT specially
if return_items.iter().any(|i| matches!(i, ReturnItem::Count(_))) { if return_items
.iter()
.any(|i| matches!(i, ReturnItem::Count(_)))
{
rows.push(vec![JsonValue::Number(bindings.len().into())]); rows.push(vec![JsonValue::Number(bindings.len().into())]);
return rows; return rows;
} }
@ -760,8 +848,14 @@ mod tests {
if let GraphQuery::Match { pattern, .. } = parsed { if let GraphQuery::Match { pattern, .. } = parsed {
assert_eq!(pattern.start.labels, vec!["User".to_string()]); assert_eq!(pattern.start.labels, vec!["User".to_string()]);
assert_eq!(pattern.relationships.len(), 1); assert_eq!(pattern.relationships.len(), 1);
assert_eq!(pattern.relationships[0].edge_type, Some("FRIEND".to_string())); assert_eq!(
assert_eq!(pattern.relationships[0].direction, RelationshipDirection::Outgoing); pattern.relationships[0].edge_type,
Some("FRIEND".to_string())
);
assert_eq!(
pattern.relationships[0].direction,
RelationshipDirection::Outgoing
);
} else { } else {
panic!("Expected Match query"); panic!("Expected Match query");
} }
@ -771,9 +865,14 @@ mod tests {
fn test_execute_match() { fn test_execute_match() {
let store = GraphStore::new(); let store = GraphStore::new();
let alice = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"})); let alice = store.create_node(
vec!["User".to_string()],
serde_json::json!({"name": "Alice"}),
);
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"})); let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
store.create_edge(alice, bob, "FRIEND", serde_json::json!({})).unwrap(); store
.create_edge(alice, bob, "FRIEND", serde_json::json!({}))
.unwrap();
let query = GraphQueryParser::parse("MATCH (n:User) RETURN n").unwrap(); let query = GraphQueryParser::parse("MATCH (n:User) RETURN n").unwrap();
let executor = GraphQueryExecutor::new(&store); let executor = GraphQueryExecutor::new(&store);

View file

@ -177,18 +177,8 @@ impl GraphStore {
/// Deletes a node and all its connected edges. /// Deletes a node and all its connected edges.
pub fn delete_node(&self, id: &NodeId) -> Result<(), GraphError> { pub fn delete_node(&self, id: &NodeId) -> Result<(), GraphError> {
// Get connected edges // Get connected edges
let outgoing: Vec<EdgeId> = self let outgoing: Vec<EdgeId> = self.adjacency.read().get(id).cloned().unwrap_or_default();
.adjacency let incoming: Vec<EdgeId> = self.reverse_adj.read().get(id).cloned().unwrap_or_default();
.read()
.get(id)
.cloned()
.unwrap_or_default();
let incoming: Vec<EdgeId> = self
.reverse_adj
.read()
.get(id)
.cloned()
.unwrap_or_default();
// Delete all connected edges // Delete all connected edges
for edge_id in outgoing.iter().chain(incoming.iter()) { for edge_id in outgoing.iter().chain(incoming.iter()) {
@ -457,7 +447,12 @@ impl GraphStore {
} }
/// Gets the neighbor node from an edge. /// Gets the neighbor node from an edge.
fn get_neighbor_from_edge(&self, edge: &Edge, from: &NodeId, direction: Direction) -> Option<NodeId> { fn get_neighbor_from_edge(
&self,
edge: &Edge,
from: &NodeId,
direction: Direction,
) -> Option<NodeId> {
match direction { match direction {
Direction::Outgoing => { Direction::Outgoing => {
if &edge.source == from { if &edge.source == from {
@ -491,7 +486,12 @@ impl GraphStore {
} }
/// Gets neighbors connected by a specific edge type. /// Gets neighbors connected by a specific edge type.
pub fn neighbors_by_type(&self, id: &NodeId, edge_type: &str, direction: Direction) -> Vec<Node> { pub fn neighbors_by_type(
&self,
id: &NodeId,
edge_type: &str,
direction: Direction,
) -> Vec<Node> {
let edges = self.edges_of(id, direction); let edges = self.edges_of(id, direction);
let nodes = self.nodes.read(); let nodes = self.nodes.read();
@ -565,7 +565,10 @@ mod tests {
fn test_create_edge() { fn test_create_edge() {
let store = GraphStore::new(); let store = GraphStore::new();
let alice = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"})); let alice = store.create_node(
vec!["User".to_string()],
serde_json::json!({"name": "Alice"}),
);
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"})); let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
let edge_id = store let edge_id = store
@ -582,12 +585,22 @@ mod tests {
fn test_neighbors() { fn test_neighbors() {
let store = GraphStore::new(); let store = GraphStore::new();
let alice = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"})); let alice = store.create_node(
vec!["User".to_string()],
serde_json::json!({"name": "Alice"}),
);
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"})); let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
let charlie = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Charlie"})); let charlie = store.create_node(
vec!["User".to_string()],
serde_json::json!({"name": "Charlie"}),
);
store.create_edge(alice, bob, "FRIEND", serde_json::json!({})).unwrap(); store
store.create_edge(alice, charlie, "FRIEND", serde_json::json!({})).unwrap(); .create_edge(alice, bob, "FRIEND", serde_json::json!({}))
.unwrap();
store
.create_edge(alice, charlie, "FRIEND", serde_json::json!({}))
.unwrap();
let neighbors = store.neighbors(&alice, Direction::Outgoing); let neighbors = store.neighbors(&alice, Direction::Outgoing);
assert_eq!(neighbors.len(), 2); assert_eq!(neighbors.len(), 2);
@ -597,9 +610,15 @@ mod tests {
fn test_find_by_label() { fn test_find_by_label() {
let store = GraphStore::new(); let store = GraphStore::new();
store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"})); store.create_node(
vec!["User".to_string()],
serde_json::json!({"name": "Alice"}),
);
store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"})); store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
store.create_node(vec!["Product".to_string()], serde_json::json!({"name": "Widget"})); store.create_node(
vec!["Product".to_string()],
serde_json::json!({"name": "Widget"}),
);
let users = store.find_nodes_by_label("User"); let users = store.find_nodes_by_label("User");
assert_eq!(users.len(), 2); assert_eq!(users.len(), 2);
@ -615,7 +634,9 @@ mod tests {
let alice = store.create_node(vec!["User".to_string()], serde_json::json!({})); let alice = store.create_node(vec!["User".to_string()], serde_json::json!({}));
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({})); let bob = store.create_node(vec!["User".to_string()], serde_json::json!({}));
store.create_edge(alice, bob, "FRIEND", serde_json::json!({})).unwrap(); store
.create_edge(alice, bob, "FRIEND", serde_json::json!({}))
.unwrap();
// Delete Alice - should also delete the edge // Delete Alice - should also delete the edge
store.delete_node(&alice).unwrap(); store.delete_node(&alice).unwrap();
@ -631,7 +652,9 @@ mod tests {
let a = store.create_node(vec![], serde_json::json!({})); let a = store.create_node(vec![], serde_json::json!({}));
let b = store.create_node(vec![], serde_json::json!({})); let b = store.create_node(vec![], serde_json::json!({}));
store.create_undirected_edge(a, b, "LINK", serde_json::json!({})).unwrap(); store
.create_undirected_edge(a, b, "LINK", serde_json::json!({}))
.unwrap();
// Both directions should work // Both directions should work
let a_neighbors = store.neighbors(&a, Direction::Outgoing); let a_neighbors = store.neighbors(&a, Direction::Outgoing);
@ -648,8 +671,12 @@ mod tests {
let a = store.create_node(vec![], serde_json::json!({})); let a = store.create_node(vec![], serde_json::json!({}));
let b = store.create_node(vec![], serde_json::json!({})); let b = store.create_node(vec![], serde_json::json!({}));
store.create_edge(a, b, "TYPE_A", serde_json::json!({})).unwrap(); store
store.create_edge(a, b, "TYPE_B", serde_json::json!({})).unwrap(); .create_edge(a, b, "TYPE_A", serde_json::json!({}))
.unwrap();
store
.create_edge(a, b, "TYPE_B", serde_json::json!({}))
.unwrap();
let edges = store.edges_between(&a, &b); let edges = store.edges_between(&a, &b);
assert_eq!(edges.len(), 2); assert_eq!(edges.len(), 2);

View file

@ -171,7 +171,9 @@ impl<'a> Traverser<'a> {
for edge in edges { for edge in edges {
// Check edge type filter // Check edge type filter
if !query.edge_types.is_empty() && !query.edge_types.contains(&edge.edge_type) { if !query.edge_types.is_empty()
&& !query.edge_types.contains(&edge.edge_type)
{
continue; continue;
} }
@ -395,18 +397,35 @@ mod tests {
fn setup_social_graph() -> GraphStore { fn setup_social_graph() -> GraphStore {
let store = GraphStore::new(); let store = GraphStore::new();
let alice = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"})); let alice = store.create_node(
vec!["User".to_string()],
serde_json::json!({"name": "Alice"}),
);
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"})); let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
let charlie = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Charlie"})); let charlie = store.create_node(
let dave = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Dave"})); vec!["User".to_string()],
serde_json::json!({"name": "Charlie"}),
);
let dave = store.create_node(
vec!["User".to_string()],
serde_json::json!({"name": "Dave"}),
);
// Alice -> Bob -> Charlie -> Dave // Alice -> Bob -> Charlie -> Dave
store.create_edge(alice, bob, "FRIEND", serde_json::json!({})).unwrap(); store
store.create_edge(bob, charlie, "FRIEND", serde_json::json!({})).unwrap(); .create_edge(alice, bob, "FRIEND", serde_json::json!({}))
store.create_edge(charlie, dave, "FRIEND", serde_json::json!({})).unwrap(); .unwrap();
store
.create_edge(bob, charlie, "FRIEND", serde_json::json!({}))
.unwrap();
store
.create_edge(charlie, dave, "FRIEND", serde_json::json!({}))
.unwrap();
// Alice -> Charlie (shortcut) // Alice -> Charlie (shortcut)
store.create_edge(alice, charlie, "KNOWS", serde_json::json!({})).unwrap(); store
.create_edge(alice, charlie, "KNOWS", serde_json::json!({}))
.unwrap();
store store
} }
@ -417,7 +436,10 @@ mod tests {
let traverser = Traverser::new(&store); let traverser = Traverser::new(&store);
let users = store.find_nodes_by_label("User"); let users = store.find_nodes_by_label("User");
let alice = users.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("Alice"))).unwrap(); let alice = users
.iter()
.find(|n| n.get_property("name") == Some(&serde_json::json!("Alice")))
.unwrap();
let query = TraversalQuery::new().depth(2); let query = TraversalQuery::new().depth(2);
let results = traverser.traverse(&alice.id, &query); let results = traverser.traverse(&alice.id, &query);
@ -432,7 +454,10 @@ mod tests {
let traverser = Traverser::new(&store); let traverser = Traverser::new(&store);
let users = store.find_nodes_by_label("User"); let users = store.find_nodes_by_label("User");
let alice = users.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("Alice"))).unwrap(); let alice = users
.iter()
.find(|n| n.get_property("name") == Some(&serde_json::json!("Alice")))
.unwrap();
let query = TraversalQuery::new() let query = TraversalQuery::new()
.depth(2) .depth(2)
@ -440,7 +465,10 @@ mod tests {
let results = traverser.traverse(&alice.id, &query); let results = traverser.traverse(&alice.id, &query);
// Following only FRIEND edges: Alice -> Bob -> Charlie // Following only FRIEND edges: Alice -> Bob -> Charlie
let names: Vec<_> = results.iter().filter_map(|r| r.node.get_property("name")).collect(); let names: Vec<_> = results
.iter()
.filter_map(|r| r.node.get_property("name"))
.collect();
assert!(names.contains(&&serde_json::json!("Bob"))); assert!(names.contains(&&serde_json::json!("Bob")));
assert!(names.contains(&&serde_json::json!("Charlie"))); assert!(names.contains(&&serde_json::json!("Charlie")));
} }
@ -451,7 +479,10 @@ mod tests {
let traverser = Traverser::new(&store); let traverser = Traverser::new(&store);
let users = store.find_nodes_by_label("User"); let users = store.find_nodes_by_label("User");
let alice = users.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("Alice"))).unwrap(); let alice = users
.iter()
.find(|n| n.get_property("name") == Some(&serde_json::json!("Alice")))
.unwrap();
let query = TraversalQuery::new().depth(1); let query = TraversalQuery::new().depth(1);
let results = traverser.traverse(&alice.id, &query); let results = traverser.traverse(&alice.id, &query);
@ -468,7 +499,10 @@ mod tests {
let traverser = Traverser::new(&store); let traverser = Traverser::new(&store);
let users = store.find_nodes_by_label("User"); let users = store.find_nodes_by_label("User");
let alice = users.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("Alice"))).unwrap(); let alice = users
.iter()
.find(|n| n.get_property("name") == Some(&serde_json::json!("Alice")))
.unwrap();
let query = TraversalQuery::new().depth(10).limit(2); let query = TraversalQuery::new().depth(10).limit(2);
let results = traverser.traverse(&alice.id, &query); let results = traverser.traverse(&alice.id, &query);
@ -486,11 +520,21 @@ mod tests {
let mutual2 = store.create_node(vec![], serde_json::json!({"name": "Mutual2"})); let mutual2 = store.create_node(vec![], serde_json::json!({"name": "Mutual2"}));
let only_alice = store.create_node(vec![], serde_json::json!({"name": "OnlyAlice"})); let only_alice = store.create_node(vec![], serde_json::json!({"name": "OnlyAlice"}));
store.create_edge(alice, mutual1, "FRIEND", serde_json::json!({})).unwrap(); store
store.create_edge(alice, mutual2, "FRIEND", serde_json::json!({})).unwrap(); .create_edge(alice, mutual1, "FRIEND", serde_json::json!({}))
store.create_edge(alice, only_alice, "FRIEND", serde_json::json!({})).unwrap(); .unwrap();
store.create_edge(bob, mutual1, "FRIEND", serde_json::json!({})).unwrap(); store
store.create_edge(bob, mutual2, "FRIEND", serde_json::json!({})).unwrap(); .create_edge(alice, mutual2, "FRIEND", serde_json::json!({}))
.unwrap();
store
.create_edge(alice, only_alice, "FRIEND", serde_json::json!({}))
.unwrap();
store
.create_edge(bob, mutual1, "FRIEND", serde_json::json!({}))
.unwrap();
store
.create_edge(bob, mutual2, "FRIEND", serde_json::json!({}))
.unwrap();
let traverser = Traverser::new(&store); let traverser = Traverser::new(&store);
let mutual = traverser.mutual_connections(&alice, &bob, Some("FRIEND")); let mutual = traverser.mutual_connections(&alice, &bob, Some("FRIEND"));

View file

@ -174,17 +174,24 @@ impl Index {
// Check uniqueness if required // Check uniqueness if required
if self.config.unique { if self.config.unique {
let exists = match self.config.index_type { let exists = match self.config.index_type {
IndexType::Hash | IndexType::Unique => { IndexType::Hash | IndexType::Unique => self
self.hash.read().get(&key).map(|s| !s.is_empty()).unwrap_or(false) .hash
} .read()
_ => { .get(&key)
self.btree.read().get(&key).map(|s| !s.is_empty()).unwrap_or(false) .map(|s| !s.is_empty())
} .unwrap_or(false),
_ => self
.btree
.read()
.get(&key)
.map(|s| !s.is_empty())
.unwrap_or(false),
}; };
if exists { if exists {
return Err(DatabaseError::AlreadyExists( return Err(DatabaseError::AlreadyExists(format!(
format!("Unique constraint violation on index '{}'", self.config.name) "Unique constraint violation on index '{}'",
)); self.config.name
)));
} }
} }
@ -239,20 +246,18 @@ impl Index {
self.stats.write().lookups += 1; self.stats.write().lookups += 1;
let result: Vec<DocumentId> = match self.config.index_type { let result: Vec<DocumentId> = match self.config.index_type {
IndexType::Hash | IndexType::Unique => { IndexType::Hash | IndexType::Unique => self
self.hash .hash
.read() .read()
.get(&key) .get(&key)
.map(|s| s.iter().cloned().collect()) .map(|s| s.iter().cloned().collect())
.unwrap_or_default() .unwrap_or_default(),
} _ => self
_ => { .btree
self.btree .read()
.read() .get(&key)
.get(&key) .map(|s| s.iter().cloned().collect())
.map(|s| s.iter().cloned().collect()) .unwrap_or_default(),
.unwrap_or_default()
}
}; };
if !result.is_empty() { if !result.is_empty() {
@ -407,12 +412,7 @@ impl IndexManager {
} }
/// Removes a document from indexes. /// Removes a document from indexes.
pub fn unindex_document( pub fn unindex_document(&self, collection: &str, doc_id: &DocumentId, document: &JsonValue) {
&self,
collection: &str,
doc_id: &DocumentId,
document: &JsonValue,
) {
let index_names = self.get_collection_indexes(collection); let index_names = self.get_collection_indexes(collection);
let indexes = self.indexes.read(); let indexes = self.indexes.read();
@ -483,7 +483,9 @@ mod tests {
let index = Index::new(config); let index = Index::new(config);
let doc1 = DocumentId::new(); let doc1 = DocumentId::new();
index.insert(doc1.clone(), &json!("alice@example.com")).unwrap(); index
.insert(doc1.clone(), &json!("alice@example.com"))
.unwrap();
let results = index.lookup(&json!("alice@example.com")); let results = index.lookup(&json!("alice@example.com"));
assert_eq!(results.len(), 1); assert_eq!(results.len(), 1);
@ -521,7 +523,9 @@ mod tests {
let doc_id = DocumentId::new(); let doc_id = DocumentId::new();
let doc = json!({"name": "Alice", "age": 30}); let doc = json!({"name": "Alice", "age": 30});
manager.index_document("users", doc_id.clone(), &doc).unwrap(); manager
.index_document("users", doc_id.clone(), &doc)
.unwrap();
let indexes = manager.list_indexes(); let indexes = manager.list_indexes();
assert_eq!(indexes.len(), 1); assert_eq!(indexes.len(), 1);

View file

@ -126,8 +126,7 @@ impl KeyValueStore {
/// Gets a value as string. /// Gets a value as string.
pub fn get_string(&self, key: &str) -> Option<String> { pub fn get_string(&self, key: &str) -> Option<String> {
self.get(key) self.get(key).and_then(|v| String::from_utf8(v).ok())
.and_then(|v| String::from_utf8(v).ok())
} }
/// Sets a value with optional TTL. /// Sets a value with optional TTL.
@ -224,8 +223,9 @@ impl KeyValueStore {
} else { } else {
let s = String::from_utf8(entry.value.clone()) let s = String::from_utf8(entry.value.clone())
.map_err(|_| DatabaseError::InvalidOperation("Value is not a string".into()))?; .map_err(|_| DatabaseError::InvalidOperation("Value is not a string".into()))?;
s.parse::<i64>() s.parse::<i64>().map_err(|_| {
.map_err(|_| DatabaseError::InvalidOperation("Value is not an integer".into()))? DatabaseError::InvalidOperation("Value is not an integer".into())
})?
} }
} else { } else {
0 0
@ -243,9 +243,9 @@ impl KeyValueStore {
pub fn append(&self, key: &str, value: &[u8]) -> Result<usize, DatabaseError> { pub fn append(&self, key: &str, value: &[u8]) -> Result<usize, DatabaseError> {
let mut data = self.data.write(); let mut data = self.data.write();
let entry = data.entry(key.to_string()).or_insert_with(|| { let entry = data
KvEntry::new(Vec::new(), 0) .entry(key.to_string())
}); .or_insert_with(|| KvEntry::new(Vec::new(), 0));
if entry.is_expired() { if entry.is_expired() {
entry.value.clear(); entry.value.clear();
@ -393,11 +393,16 @@ mod tests {
fn test_mget_mset() { fn test_mget_mset() {
let store = KeyValueStore::new(); let store = KeyValueStore::new();
store.mset(&[ store
("k1", b"v1".to_vec()), .mset(
("k2", b"v2".to_vec()), &[
("k3", b"v3".to_vec()), ("k1", b"v1".to_vec()),
], 0).unwrap(); ("k2", b"v2".to_vec()),
("k3", b"v3".to_vec()),
],
0,
)
.unwrap();
let results = store.mget(&["k1", "k2", "k4"]); let results = store.mget(&["k1", "k2", "k4"]);
assert_eq!(results.len(), 3); assert_eq!(results.len(), 3);

View file

@ -65,12 +65,14 @@ pub use graph::{
pub use index::{Index, IndexConfig, IndexManager, IndexType}; pub use index::{Index, IndexConfig, IndexManager, IndexType};
pub use keyvalue::{KeyValue, KeyValueStore, KvEntry}; pub use keyvalue::{KeyValue, KeyValueStore, KvEntry};
pub use query::{Filter, Query, QueryEngine, QueryResult, SortOrder}; pub use query::{Filter, Query, QueryEngine, QueryResult, SortOrder};
pub use schema::{Field, FieldType, Schema, SchemaValidator};
pub use replication::{ pub use replication::{
ClusterConfig, Command as RaftCommand, NodeRole, RaftConfig, RaftEvent, RaftNode, RaftState, ClusterConfig, Command as RaftCommand, NodeRole, RaftConfig, RaftEvent, RaftNode, RaftState,
ReplicatedLog, ReplicatedLog,
}; };
pub use sql::{QueryResult as SqlQueryResult, SqlEngine, SqlParser, SqlType, SqlValue, Table, TableDef}; pub use schema::{Field, FieldType, Schema, SchemaValidator};
pub use sql::{
QueryResult as SqlQueryResult, SqlEngine, SqlParser, SqlType, SqlValue, Table, TableDef,
};
pub use timeseries::{DataPoint, Metric, TimeSeries, TimeSeriesStore}; pub use timeseries::{DataPoint, Metric, TimeSeries, TimeSeriesStore};
pub use vector::{Embedding, SimilarityMetric, VectorIndex, VectorStore}; pub use vector::{Embedding, SimilarityMetric, VectorIndex, VectorStore};

View file

@ -419,10 +419,7 @@ impl QueryEngine {
let values: Vec<f64> = docs let values: Vec<f64> = docs
.iter() .iter()
.filter_map(|doc| { .filter_map(|doc| doc.get(field).and_then(|v| v.as_f64()))
doc.get(field)
.and_then(|v| v.as_f64())
})
.collect(); .collect();
let result = match op { let result = match op {
@ -439,22 +436,18 @@ impl QueryEngine {
serde_json::to_value(avg).unwrap_or(JsonValue::Null) serde_json::to_value(avg).unwrap_or(JsonValue::Null)
} }
} }
AggregateOp::Min => { AggregateOp::Min => values
values .iter()
.iter() .copied()
.copied() .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) .map(|v| serde_json::to_value(v).unwrap_or(JsonValue::Null))
.map(|v| serde_json::to_value(v).unwrap_or(JsonValue::Null)) .unwrap_or(JsonValue::Null),
.unwrap_or(JsonValue::Null) AggregateOp::Max => values
} .iter()
AggregateOp::Max => { .copied()
values .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.iter() .map(|v| serde_json::to_value(v).unwrap_or(JsonValue::Null))
.copied() .unwrap_or(JsonValue::Null),
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map(|v| serde_json::to_value(v).unwrap_or(JsonValue::Null))
.unwrap_or(JsonValue::Null)
}
}; };
Ok(result) Ok(result)
@ -507,8 +500,10 @@ mod tests {
fn test_simple_query() { fn test_simple_query() {
let docs = Arc::new(DocumentStore::new()); let docs = Arc::new(DocumentStore::new());
docs.create_collection("users").unwrap(); docs.create_collection("users").unwrap();
docs.insert("users", json!({"name": "Alice", "age": 30})).unwrap(); docs.insert("users", json!({"name": "Alice", "age": 30}))
docs.insert("users", json!({"name": "Bob", "age": 25})).unwrap(); .unwrap();
docs.insert("users", json!({"name": "Bob", "age": 25}))
.unwrap();
let vectors = Arc::new(VectorStore::new(3)); let vectors = Arc::new(VectorStore::new(3));
let indexes = Arc::new(IndexManager::new()); let indexes = Arc::new(IndexManager::new());
@ -525,8 +520,10 @@ mod tests {
fn test_filter_query() { fn test_filter_query() {
let docs = Arc::new(DocumentStore::new()); let docs = Arc::new(DocumentStore::new());
docs.create_collection("users").unwrap(); docs.create_collection("users").unwrap();
docs.insert("users", json!({"name": "Alice", "age": 30})).unwrap(); docs.insert("users", json!({"name": "Alice", "age": 30}))
docs.insert("users", json!({"name": "Bob", "age": 25})).unwrap(); .unwrap();
docs.insert("users", json!({"name": "Bob", "age": 25}))
.unwrap();
let vectors = Arc::new(VectorStore::new(3)); let vectors = Arc::new(VectorStore::new(3));
let indexes = Arc::new(IndexManager::new()); let indexes = Arc::new(IndexManager::new());
@ -543,9 +540,12 @@ mod tests {
fn test_sorted_query() { fn test_sorted_query() {
let docs = Arc::new(DocumentStore::new()); let docs = Arc::new(DocumentStore::new());
docs.create_collection("users").unwrap(); docs.create_collection("users").unwrap();
docs.insert("users", json!({"name": "Alice", "age": 30})).unwrap(); docs.insert("users", json!({"name": "Alice", "age": 30}))
docs.insert("users", json!({"name": "Bob", "age": 25})).unwrap(); .unwrap();
docs.insert("users", json!({"name": "Charlie", "age": 35})).unwrap(); docs.insert("users", json!({"name": "Bob", "age": 25}))
.unwrap();
docs.insert("users", json!({"name": "Charlie", "age": 35}))
.unwrap();
let vectors = Arc::new(VectorStore::new(3)); let vectors = Arc::new(VectorStore::new(3));
let indexes = Arc::new(IndexManager::new()); let indexes = Arc::new(IndexManager::new());

View file

@ -92,12 +92,7 @@ impl Election {
/// Creates a RequestVote message for this election. /// Creates a RequestVote message for this election.
pub fn create_request(&self, log: &ReplicatedLog) -> RequestVote { pub fn create_request(&self, log: &ReplicatedLog) -> RequestVote {
RequestVote::new( RequestVote::new(self.term, self.node_id, log.last_index(), log.last_term())
self.term,
self.node_id,
log.last_index(),
log.last_term(),
)
} }
/// Checks the current result of the election. /// Checks the current result of the election.
@ -217,8 +212,8 @@ impl Default for ElectionTimeout {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::replication::state::Command;
use crate::replication::log::LogEntry; use crate::replication::log::LogEntry;
use crate::replication::state::Command;
#[test] #[test]
fn test_election_basic() { fn test_election_basic() {

View file

@ -139,7 +139,10 @@ impl ReplicatedLog {
let entries = self.entries.read(); let entries = self.entries.read();
let offset = (from_index - start) as usize; let offset = (from_index - start) as usize;
entries.get(offset..).map(|s| s.to_vec()).unwrap_or_default() entries
.get(offset..)
.map(|s| s.to_vec())
.unwrap_or_default()
} }
/// Appends an entry to the log. /// Appends an entry to the log.
@ -151,7 +154,12 @@ impl ReplicatedLog {
} }
/// Appends multiple entries, potentially overwriting conflicting entries. /// Appends multiple entries, potentially overwriting conflicting entries.
pub fn append_entries(&self, prev_index: u64, prev_term: u64, new_entries: Vec<LogEntry>) -> bool { pub fn append_entries(
&self,
prev_index: u64,
prev_term: u64,
new_entries: Vec<LogEntry>,
) -> bool {
// Check that prev entry matches // Check that prev entry matches
if prev_index > 0 { if prev_index > 0 {
if let Some(prev_entry_term) = self.term_at(prev_index) { if let Some(prev_entry_term) = self.term_at(prev_index) {
@ -245,7 +253,11 @@ impl ReplicatedLog {
} }
/// Creates entries for replication starting from a given index. /// Creates entries for replication starting from a given index.
pub fn entries_for_replication(&self, from_index: u64, max_entries: usize) -> (u64, u64, Vec<LogEntry>) { pub fn entries_for_replication(
&self,
from_index: u64,
max_entries: usize,
) -> (u64, u64, Vec<LogEntry>) {
let prev_index = from_index.saturating_sub(1); let prev_index = from_index.saturating_sub(1);
let prev_term = self.term_at(prev_index).unwrap_or(0); let prev_term = self.term_at(prev_index).unwrap_or(0);

View file

@ -222,7 +222,11 @@ impl RaftNode {
// Create new election // Create new election
let cluster_size = self.cluster.voting_members(); let cluster_size = self.cluster.voting_members();
self.election = Some(Election::new(self.id, self.state.current_term, cluster_size)); self.election = Some(Election::new(
self.id,
self.state.current_term,
cluster_size,
));
// Create RequestVote message // Create RequestVote message
let request = RequestVote::new( let request = RequestVote::new(
@ -295,9 +299,9 @@ impl RaftNode {
return; return;
} }
let (prev_log_index, prev_log_term, entries) = let (prev_log_index, prev_log_term, entries) = self
self.log .log
.entries_for_replication(next_index, self.config.max_entries_per_rpc); .entries_for_replication(next_index, self.config.max_entries_per_rpc);
let request = AppendEntries::with_entries( let request = AppendEntries::with_entries(
self.state.current_term, self.state.current_term,
@ -308,8 +312,10 @@ impl RaftNode {
self.state.commit_index, self.state.commit_index,
); );
self.events self.events.push(RaftEvent::SendRpc(
.push(RaftEvent::SendRpc(peer_id, RpcMessage::AppendEntries(request))); peer_id,
RpcMessage::AppendEntries(request),
));
} }
fn send_install_snapshot(&mut self, peer_id: NodeId) { fn send_install_snapshot(&mut self, peer_id: NodeId) {
@ -332,8 +338,10 @@ impl RaftNode {
done, done,
); );
self.events self.events.push(RaftEvent::SendRpc(
.push(RaftEvent::SendRpc(peer_id, RpcMessage::InstallSnapshot(request))); peer_id,
RpcMessage::InstallSnapshot(request),
));
} }
} }
@ -395,7 +403,11 @@ impl RaftNode {
} }
} }
fn handle_append_entries(&mut self, _from: NodeId, req: AppendEntries) -> AppendEntriesResponse { fn handle_append_entries(
&mut self,
_from: NodeId,
req: AppendEntries,
) -> AppendEntriesResponse {
// Rule: If term > currentTerm, become follower // Rule: If term > currentTerm, become follower
if req.term > self.state.current_term { if req.term > self.state.current_term {
self.become_follower(req.term, Some(req.leader_id)); self.become_follower(req.term, Some(req.leader_id));
@ -416,9 +428,9 @@ impl RaftNode {
} }
// Try to append entries // Try to append entries
let success = let success = self
self.log .log
.append_entries(req.prev_log_index, req.prev_log_term, req.entries); .append_entries(req.prev_log_index, req.prev_log_term, req.entries);
if success { if success {
// Update commit index // Update commit index
@ -443,7 +455,11 @@ impl RaftNode {
} }
conflict_index -= 1; conflict_index -= 1;
} }
AppendEntriesResponse::conflict(self.state.current_term, conflict_term, conflict_index) AppendEntriesResponse::conflict(
self.state.current_term,
conflict_term,
conflict_index,
)
} else { } else {
AppendEntriesResponse::failure(self.state.current_term) AppendEntriesResponse::failure(self.state.current_term)
} }
@ -502,7 +518,11 @@ impl RaftNode {
self.cluster.update_peer_state(from, PeerState::Reachable); self.cluster.update_peer_state(from, PeerState::Reachable);
} }
fn handle_install_snapshot(&mut self, _from: NodeId, req: InstallSnapshot) -> InstallSnapshotResponse { fn handle_install_snapshot(
&mut self,
_from: NodeId,
req: InstallSnapshot,
) -> InstallSnapshotResponse {
// Rule: If term > currentTerm, become follower // Rule: If term > currentTerm, become follower
if req.term > self.state.current_term { if req.term > self.state.current_term {
self.become_follower(req.term, Some(req.leader_id)); self.become_follower(req.term, Some(req.leader_id));
@ -692,12 +712,14 @@ impl RaftNode {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*;
use super::super::cluster::PeerAddress; use super::super::cluster::PeerAddress;
use super::*;
fn create_test_cluster(node_id: NodeId, peers: &[NodeId]) -> ClusterConfig { fn create_test_cluster(node_id: NodeId, peers: &[NodeId]) -> ClusterConfig {
let mut cluster = let mut cluster = ClusterConfig::new(
ClusterConfig::new(node_id, PeerAddress::new("127.0.0.1", 9000 + node_id as u16)); node_id,
PeerAddress::new("127.0.0.1", 9000 + node_id as u16),
);
for &peer in peers { for &peer in peers {
cluster.add_peer(super::super::cluster::PeerInfo::new( cluster.add_peer(super::super::cluster::PeerInfo::new(
peer, peer,

View file

@ -176,7 +176,10 @@ impl SnapshotManager {
/// Adds a chunk to the pending snapshot. /// Adds a chunk to the pending snapshot.
pub fn add_chunk(&mut self, offset: u64, data: Vec<u8>) -> bool { pub fn add_chunk(&mut self, offset: u64, data: Vec<u8>) -> bool {
if let Some(ref mut pending) = self.pending_snapshot { if let Some(ref mut pending) = self.pending_snapshot {
if offset == pending.expected_offset + pending.chunks.iter().map(|c| c.len() as u64).sum::<u64>() { if offset
== pending.expected_offset
+ pending.chunks.iter().map(|c| c.len() as u64).sum::<u64>()
{
pending.chunks.push(data); pending.chunks.push(data);
return true; return true;
} }

View file

@ -118,31 +118,55 @@ pub enum Command {
// Key-Value operations // Key-Value operations
/// Set a key-value pair. /// Set a key-value pair.
KvSet { key: String, value: Vec<u8>, ttl: Option<u64> }, KvSet {
key: String,
value: Vec<u8>,
ttl: Option<u64>,
},
/// Delete a key. /// Delete a key.
KvDelete { key: String }, KvDelete { key: String },
// Document operations // Document operations
/// Insert a document. /// Insert a document.
DocInsert { collection: String, document: JsonValue }, DocInsert {
collection: String,
document: JsonValue,
},
/// Update a document. /// Update a document.
DocUpdate { collection: String, id: String, update: JsonValue }, DocUpdate {
collection: String,
id: String,
update: JsonValue,
},
/// Delete a document. /// Delete a document.
DocDelete { collection: String, id: String }, DocDelete { collection: String, id: String },
// Vector operations // Vector operations
/// Insert a vector. /// Insert a vector.
VectorInsert { namespace: String, id: String, vector: Vec<f32>, metadata: JsonValue }, VectorInsert {
namespace: String,
id: String,
vector: Vec<f32>,
metadata: JsonValue,
},
/// Delete a vector. /// Delete a vector.
VectorDelete { namespace: String, id: String }, VectorDelete { namespace: String, id: String },
// Time-series operations // Time-series operations
/// Record a metric data point. /// Record a metric data point.
TimeSeriesRecord { metric: String, value: f64, timestamp: u64, tags: JsonValue }, TimeSeriesRecord {
metric: String,
value: f64,
timestamp: u64,
tags: JsonValue,
},
// Graph operations // Graph operations
/// Create a graph node. /// Create a graph node.
GraphNodeCreate { labels: Vec<String>, properties: JsonValue }, GraphNodeCreate {
labels: Vec<String>,
properties: JsonValue,
},
/// Delete a graph node. /// Delete a graph node.
GraphNodeDelete { id: String }, GraphNodeDelete { id: String },
/// Create a graph edge. /// Create a graph edge.
@ -161,13 +185,20 @@ pub enum Command {
// Schema operations // Schema operations
/// Create a collection/table. /// Create a collection/table.
CreateCollection { name: String, schema: Option<JsonValue> }, CreateCollection {
name: String,
schema: Option<JsonValue>,
},
/// Drop a collection/table. /// Drop a collection/table.
DropCollection { name: String }, DropCollection { name: String },
// Index operations // Index operations
/// Create an index. /// Create an index.
CreateIndex { collection: String, field: String, index_type: String }, CreateIndex {
collection: String,
field: String,
index_type: String,
},
/// Drop an index. /// Drop an index.
DropIndex { name: String }, DropIndex { name: String },
@ -265,7 +296,12 @@ impl LeaderState {
} }
/// Calculates the new commit index based on majority replication. /// Calculates the new commit index based on majority replication.
pub fn calculate_commit_index(&self, current_commit: u64, current_term: u64, log_term_at: impl Fn(u64) -> Option<u64>) -> u64 { pub fn calculate_commit_index(
&self,
current_commit: u64,
current_term: u64,
log_term_at: impl Fn(u64) -> Option<u64>,
) -> u64 {
// Find the highest index that a majority have replicated // Find the highest index that a majority have replicated
let mut indices: Vec<u64> = self.match_index.values().cloned().collect(); let mut indices: Vec<u64> = self.match_index.values().cloned().collect();
indices.sort_unstable(); indices.sort_unstable();

View file

@ -315,8 +315,8 @@ mod tests {
#[test] #[test]
fn test_vector_field() { fn test_vector_field() {
let schema = Schema::new("embedding") let schema =
.field(Field::required("vector", FieldType::Vector(3))); Schema::new("embedding").field(Field::required("vector", FieldType::Vector(3)));
let mut validator = SchemaValidator::new(); let mut validator = SchemaValidator::new();
validator.register(schema); validator.register(schema);

View file

@ -1,8 +1,7 @@
//! SQL query executor. //! SQL query executor.
use super::parser::{ use super::parser::{
BinaryOp, ParsedExpr, ParsedSelect, ParsedSelectItem, BinaryOp, ParsedExpr, ParsedSelect, ParsedSelectItem, ParsedStatement, SqlParser,
ParsedStatement, SqlParser,
}; };
use super::row::{Row, RowId}; use super::row::{Row, RowId};
use super::table::{ColumnDef, Table, TableDef}; use super::table::{ColumnDef, Table, TableDef};
@ -192,11 +191,7 @@ impl SqlEngine {
match a_val.partial_cmp(&b_val) { match a_val.partial_cmp(&b_val) {
Some(std::cmp::Ordering::Equal) => continue, Some(std::cmp::Ordering::Equal) => continue,
Some(ord) => { Some(ord) => {
return if ob.ascending { return if ob.ascending { ord } else { ord.reverse() };
ord
} else {
ord.reverse()
};
} }
None => continue, None => continue,
} }
@ -216,7 +211,11 @@ impl SqlEngine {
} }
// Handle aggregates // Handle aggregates
if select.columns.iter().any(|c| matches!(c, ParsedSelectItem::Aggregate { .. })) { if select
.columns
.iter()
.any(|c| matches!(c, ParsedSelectItem::Aggregate { .. }))
{
return self.execute_aggregate(select, &rows, table); return self.execute_aggregate(select, &rows, table);
} }
@ -244,7 +243,9 @@ impl SqlEngine {
ParsedSelectItem::Wildcard => table.def.column_names(), ParsedSelectItem::Wildcard => table.def.column_names(),
ParsedSelectItem::Column(name) => vec![name.clone()], ParsedSelectItem::Column(name) => vec![name.clone()],
ParsedSelectItem::ColumnAlias { alias, .. } => vec![alias.clone()], ParsedSelectItem::ColumnAlias { alias, .. } => vec![alias.clone()],
ParsedSelectItem::Aggregate { function, alias, .. } => { ParsedSelectItem::Aggregate {
function, alias, ..
} => {
vec![alias.clone().unwrap_or_else(|| function.clone())] vec![alias.clone().unwrap_or_else(|| function.clone())]
} }
}) })
@ -328,7 +329,9 @@ impl SqlEngine {
rows.iter() rows.iter()
.map(|r| r.get_or_null(col)) .map(|r| r.get_or_null(col))
.filter(|v| !v.is_null()) .filter(|v| !v.is_null())
.min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) .min_by(|a, b| {
a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
})
.unwrap_or(SqlValue::Null) .unwrap_or(SqlValue::Null)
} }
"MAX" => { "MAX" => {
@ -338,12 +341,12 @@ impl SqlEngine {
rows.iter() rows.iter()
.map(|r| r.get_or_null(col)) .map(|r| r.get_or_null(col))
.filter(|v| !v.is_null()) .filter(|v| !v.is_null())
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) .max_by(|a, b| {
a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
})
.unwrap_or(SqlValue::Null) .unwrap_or(SqlValue::Null)
} }
_ => { _ => return Err(SqlError::Unsupported(format!("Function: {}", function))),
return Err(SqlError::Unsupported(format!("Function: {}", function)))
}
}; };
result_values.push(value); result_values.push(value);
} }
@ -404,7 +407,11 @@ impl SqlEngine {
ParsedExpr::IsNotNull(inner) => { ParsedExpr::IsNotNull(inner) => {
SqlValue::Boolean(!self.evaluate_expr(row, inner).is_null()) SqlValue::Boolean(!self.evaluate_expr(row, inner).is_null())
} }
ParsedExpr::InList { expr, list, negated } => { ParsedExpr::InList {
expr,
list,
negated,
} => {
let val = self.evaluate_expr(row, expr); let val = self.evaluate_expr(row, expr);
let in_list = list.iter().any(|item| { let in_list = list.iter().any(|item| {
let item_val = self.evaluate_expr(row, item); let item_val = self.evaluate_expr(row, item);
@ -424,9 +431,7 @@ impl SqlEngine {
let between = val >= low_val && val <= high_val; let between = val >= low_val && val <= high_val;
SqlValue::Boolean(if *negated { !between } else { between }) SqlValue::Boolean(if *negated { !between } else { between })
} }
ParsedExpr::Function { name, args } => { ParsedExpr::Function { name, args } => self.evaluate_function(row, name, args),
self.evaluate_function(row, name, args)
}
} }
} }
@ -474,9 +479,7 @@ impl SqlEngine {
_ => SqlValue::Null, _ => SqlValue::Null,
}, },
BinaryOp::Divide => match (left, right) { BinaryOp::Divide => match (left, right) {
(SqlValue::Integer(a), SqlValue::Integer(b)) if *b != 0 => { (SqlValue::Integer(a), SqlValue::Integer(b)) if *b != 0 => SqlValue::Integer(a / b),
SqlValue::Integer(a / b)
}
(SqlValue::Real(a), SqlValue::Real(b)) if *b != 0.0 => SqlValue::Real(a / b), (SqlValue::Real(a), SqlValue::Real(b)) if *b != 0.0 => SqlValue::Real(a / b),
_ => SqlValue::Null, _ => SqlValue::Null,
}, },
@ -536,9 +539,7 @@ impl SqlEngine {
/// Matches a LIKE pattern. /// Matches a LIKE pattern.
fn match_like(&self, text: &str, pattern: &str) -> bool { fn match_like(&self, text: &str, pattern: &str) -> bool {
// Simple LIKE implementation: % = any chars, _ = single char // Simple LIKE implementation: % = any chars, _ = single char
let _regex_pattern = pattern let _regex_pattern = pattern.replace('%', ".*").replace('_', ".");
.replace('%', ".*")
.replace('_', ".");
// For simplicity, just do case-insensitive contains for now // For simplicity, just do case-insensitive contains for now
if pattern.starts_with('%') && pattern.ends_with('%') { if pattern.starts_with('%') && pattern.ends_with('%') {
let inner = &pattern[1..pattern.len() - 1]; let inner = &pattern[1..pattern.len() - 1];
@ -615,8 +616,7 @@ impl SqlEngine {
.unwrap_or(true); .unwrap_or(true);
if matches { if matches {
let updates: HashMap<String, SqlValue> = let updates: HashMap<String, SqlValue> = assignments.iter().cloned().collect();
assignments.iter().cloned().collect();
table.update(row.id, updates)?; table.update(row.id, updates)?;
count += 1; count += 1;
} }
@ -672,9 +672,9 @@ impl SqlEngine {
.ok_or_else(|| SqlError::TableNotFound(table_name.to_string()))?; .ok_or_else(|| SqlError::TableNotFound(table_name.to_string()))?;
// For simplicity, only support single-column indexes // For simplicity, only support single-column indexes
let column = columns let column = columns.first().ok_or_else(|| {
.first() SqlError::InvalidOperation("Index requires at least one column".to_string())
.ok_or_else(|| SqlError::InvalidOperation("Index requires at least one column".to_string()))?; })?;
table.create_index(name, column, unique)?; table.create_index(name, column, unique)?;
Ok(QueryResult::empty()) Ok(QueryResult::empty())
@ -775,7 +775,9 @@ mod tests {
.execute("INSERT INTO users (id, name, age) VALUES (2, 'Bob', 25)") .execute("INSERT INTO users (id, name, age) VALUES (2, 'Bob', 25)")
.unwrap(); .unwrap();
let result = engine.execute("SELECT name FROM users WHERE age > 26").unwrap(); let result = engine
.execute("SELECT name FROM users WHERE age > 26")
.unwrap();
assert_eq!(result.rows.len(), 1); assert_eq!(result.rows.len(), 1);
assert_eq!(result.rows[0][0], SqlValue::Text("Alice".to_string())); assert_eq!(result.rows[0][0], SqlValue::Text("Alice".to_string()));
} }
@ -806,7 +808,9 @@ mod tests {
engine engine
.execute(&format!( .execute(&format!(
"INSERT INTO users (id, name, age) VALUES ({}, 'User{}', {})", "INSERT INTO users (id, name, age) VALUES ({}, 'User{}', {})",
i, i, 20 + i i,
i,
20 + i
)) ))
.unwrap(); .unwrap();
} }

View file

@ -18,10 +18,7 @@ pub enum ParsedStatement {
if_not_exists: bool, if_not_exists: bool,
}, },
/// DROP TABLE statement. /// DROP TABLE statement.
DropTable { DropTable { name: String, if_exists: bool },
name: String,
if_exists: bool,
},
/// SELECT statement. /// SELECT statement.
Select(ParsedSelect), Select(ParsedSelect),
/// INSERT statement. /// INSERT statement.
@ -179,15 +176,17 @@ impl SqlParser {
/// Parses a SQL statement. /// Parses a SQL statement.
pub fn parse(sql: &str) -> Result<ParsedStatement, SqlError> { pub fn parse(sql: &str) -> Result<ParsedStatement, SqlError> {
let dialect = SQLiteDialect {}; let dialect = SQLiteDialect {};
let statements = Parser::parse_sql(&dialect, sql) let statements =
.map_err(|e| SqlError::Parse(e.to_string()))?; Parser::parse_sql(&dialect, sql).map_err(|e| SqlError::Parse(e.to_string()))?;
if statements.is_empty() { if statements.is_empty() {
return Err(SqlError::Parse("Empty SQL statement".to_string())); return Err(SqlError::Parse("Empty SQL statement".to_string()));
} }
if statements.len() > 1 { if statements.len() > 1 {
return Err(SqlError::Parse("Multiple statements not supported".to_string())); return Err(SqlError::Parse(
"Multiple statements not supported".to_string(),
));
} }
Self::convert_statement(&statements[0]) Self::convert_statement(&statements[0])
@ -195,25 +194,42 @@ impl SqlParser {
fn convert_statement(stmt: &Statement) -> Result<ParsedStatement, SqlError> { fn convert_statement(stmt: &Statement) -> Result<ParsedStatement, SqlError> {
match stmt { match stmt {
Statement::CreateTable { name, columns, if_not_exists, constraints, .. } => { Statement::CreateTable {
Self::convert_create_table(name, columns, constraints, *if_not_exists) name,
} columns,
Statement::Drop { object_type, names, if_exists, .. } => { if_not_exists,
Self::convert_drop(object_type, names, *if_exists) constraints,
} ..
} => Self::convert_create_table(name, columns, constraints, *if_not_exists),
Statement::Drop {
object_type,
names,
if_exists,
..
} => Self::convert_drop(object_type, names, *if_exists),
Statement::Query(query) => Self::convert_query(query), Statement::Query(query) => Self::convert_query(query),
Statement::Insert { table_name, columns, source, .. } => { Statement::Insert {
Self::convert_insert(table_name, columns, source) table_name,
} columns,
Statement::Update { table, assignments, selection, .. } => { source,
Self::convert_update(table, assignments, selection) ..
} } => Self::convert_insert(table_name, columns, source),
Statement::Delete { from, selection, .. } => { Statement::Update {
Self::convert_delete(from, selection) table,
} assignments,
Statement::CreateIndex { name, table_name, columns, unique, .. } => { selection,
Self::convert_create_index(name, table_name, columns, *unique) ..
} } => Self::convert_update(table, assignments, selection),
Statement::Delete {
from, selection, ..
} => Self::convert_delete(from, selection),
Statement::CreateIndex {
name,
table_name,
columns,
unique,
..
} => Self::convert_create_index(name, table_name, columns, *unique),
_ => Err(SqlError::Unsupported(format!("Statement not supported"))), _ => Err(SqlError::Unsupported(format!("Statement not supported"))),
} }
} }
@ -230,7 +246,12 @@ impl SqlParser {
// Extract primary keys from table constraints // Extract primary keys from table constraints
for constraint in constraints { for constraint in constraints {
if let sqlparser::ast::TableConstraint::Unique { columns: pk_cols, is_primary: true, .. } = constraint { if let sqlparser::ast::TableConstraint::Unique {
columns: pk_cols,
is_primary: true,
..
} = constraint
{
for col in pk_cols { for col in pk_cols {
primary_keys.push(col.value.clone()); primary_keys.push(col.value.clone());
} }
@ -296,10 +317,9 @@ impl SqlParser {
DataType::Real | DataType::Float(_) | DataType::Double | DataType::DoublePrecision => { DataType::Real | DataType::Float(_) | DataType::Double | DataType::DoublePrecision => {
Ok(SqlType::Real) Ok(SqlType::Real)
} }
DataType::Varchar(_) DataType::Varchar(_) | DataType::Char(_) | DataType::Text | DataType::String(_) => {
| DataType::Char(_) Ok(SqlType::Text)
| DataType::Text }
| DataType::String(_) => Ok(SqlType::Text),
DataType::Binary(_) | DataType::Varbinary(_) | DataType::Blob(_) => Ok(SqlType::Blob), DataType::Binary(_) | DataType::Varbinary(_) | DataType::Blob(_) => Ok(SqlType::Blob),
DataType::Boolean => Ok(SqlType::Boolean), DataType::Boolean => Ok(SqlType::Boolean),
DataType::Timestamp(_, _) | DataType::Date | DataType::Datetime(_) => { DataType::Timestamp(_, _) | DataType::Date | DataType::Datetime(_) => {
@ -367,10 +387,7 @@ impl SqlParser {
.collect(); .collect();
// Parse LIMIT/OFFSET // Parse LIMIT/OFFSET
let limit = query let limit = query.limit.as_ref().and_then(|l| Self::expr_to_usize(l));
.limit
.as_ref()
.and_then(|l| Self::expr_to_usize(l));
let offset = query let offset = query
.offset .offset
.as_ref() .as_ref()
@ -403,16 +420,18 @@ impl SqlParser {
Self::convert_select_expr(expr) Self::convert_select_expr(expr)
} }
} }
_ => Err(SqlError::Unsupported("Select item not supported".to_string())), _ => Err(SqlError::Unsupported(
"Select item not supported".to_string(),
)),
} }
} }
fn convert_select_expr(expr: &Expr) -> Result<ParsedSelectItem, SqlError> { fn convert_select_expr(expr: &Expr) -> Result<ParsedSelectItem, SqlError> {
match expr { match expr {
Expr::Identifier(id) => Ok(ParsedSelectItem::Column(id.value.clone())), Expr::Identifier(id) => Ok(ParsedSelectItem::Column(id.value.clone())),
Expr::CompoundIdentifier(ids) => { Expr::CompoundIdentifier(ids) => Ok(ParsedSelectItem::Column(
Ok(ParsedSelectItem::Column(ids.last().map(|i| i.value.clone()).unwrap_or_default())) ids.last().map(|i| i.value.clone()).unwrap_or_default(),
} )),
Expr::Function(func) => { Expr::Function(func) => {
let name = func.name.to_string().to_uppercase(); let name = func.name.to_string().to_uppercase();
// Try to extract column from first arg - simplified for compatibility // Try to extract column from first arg - simplified for compatibility
@ -423,14 +442,18 @@ impl SqlParser {
alias: None, alias: None,
}) })
} }
_ => Err(SqlError::Unsupported("Select expression not supported".to_string())), _ => Err(SqlError::Unsupported(
"Select expression not supported".to_string(),
)),
} }
} }
fn convert_table_factor(factor: &TableFactor) -> Result<String, SqlError> { fn convert_table_factor(factor: &TableFactor) -> Result<String, SqlError> {
match factor { match factor {
TableFactor::Table { name, .. } => Ok(name.to_string()), TableFactor::Table { name, .. } => Ok(name.to_string()),
_ => Err(SqlError::Unsupported("Table factor not supported".to_string())), _ => Err(SqlError::Unsupported(
"Table factor not supported".to_string(),
)),
} }
} }
@ -461,9 +484,9 @@ impl SqlParser {
fn convert_expr(expr: &Expr) -> Result<ParsedExpr, SqlError> { fn convert_expr(expr: &Expr) -> Result<ParsedExpr, SqlError> {
match expr { match expr {
Expr::Identifier(id) => Ok(ParsedExpr::Column(id.value.clone())), Expr::Identifier(id) => Ok(ParsedExpr::Column(id.value.clone())),
Expr::CompoundIdentifier(ids) => { Expr::CompoundIdentifier(ids) => Ok(ParsedExpr::Column(
Ok(ParsedExpr::Column(ids.last().map(|i| i.value.clone()).unwrap_or_default())) ids.last().map(|i| i.value.clone()).unwrap_or_default(),
} )),
Expr::Value(v) => Ok(ParsedExpr::Literal(Self::convert_value(v)?)), Expr::Value(v) => Ok(ParsedExpr::Literal(Self::convert_value(v)?)),
Expr::BinaryOp { left, op, right } => { Expr::BinaryOp { left, op, right } => {
let left = Box::new(Self::convert_expr(left)?); let left = Box::new(Self::convert_expr(left)?);
@ -471,17 +494,30 @@ impl SqlParser {
let op = Self::convert_binary_op(op)?; let op = Self::convert_binary_op(op)?;
Ok(ParsedExpr::BinaryOp { left, op, right }) Ok(ParsedExpr::BinaryOp { left, op, right })
} }
Expr::UnaryOp { op: sqlparser::ast::UnaryOperator::Not, expr } => { Expr::UnaryOp {
Ok(ParsedExpr::Not(Box::new(Self::convert_expr(expr)?))) op: sqlparser::ast::UnaryOperator::Not,
} expr,
} => Ok(ParsedExpr::Not(Box::new(Self::convert_expr(expr)?))),
Expr::IsNull(expr) => Ok(ParsedExpr::IsNull(Box::new(Self::convert_expr(expr)?))), Expr::IsNull(expr) => Ok(ParsedExpr::IsNull(Box::new(Self::convert_expr(expr)?))),
Expr::IsNotNull(expr) => Ok(ParsedExpr::IsNotNull(Box::new(Self::convert_expr(expr)?))), Expr::IsNotNull(expr) => Ok(ParsedExpr::IsNotNull(Box::new(Self::convert_expr(expr)?))),
Expr::InList { expr, list, negated } => Ok(ParsedExpr::InList { Expr::InList {
expr,
list,
negated,
} => Ok(ParsedExpr::InList {
expr: Box::new(Self::convert_expr(expr)?), expr: Box::new(Self::convert_expr(expr)?),
list: list.iter().map(Self::convert_expr).collect::<Result<_, _>>()?, list: list
.iter()
.map(Self::convert_expr)
.collect::<Result<_, _>>()?,
negated: *negated, negated: *negated,
}), }),
Expr::Between { expr, low, high, negated } => Ok(ParsedExpr::Between { Expr::Between {
expr,
low,
high,
negated,
} => Ok(ParsedExpr::Between {
expr: Box::new(Self::convert_expr(expr)?), expr: Box::new(Self::convert_expr(expr)?),
low: Box::new(Self::convert_expr(low)?), low: Box::new(Self::convert_expr(low)?),
high: Box::new(Self::convert_expr(high)?), high: Box::new(Self::convert_expr(high)?),
@ -490,10 +526,16 @@ impl SqlParser {
Expr::Like { expr, pattern, .. } => { Expr::Like { expr, pattern, .. } => {
let left = Box::new(Self::convert_expr(expr)?); let left = Box::new(Self::convert_expr(expr)?);
let right = Box::new(Self::convert_expr(pattern)?); let right = Box::new(Self::convert_expr(pattern)?);
Ok(ParsedExpr::BinaryOp { left, op: BinaryOp::Like, right }) Ok(ParsedExpr::BinaryOp {
left,
op: BinaryOp::Like,
right,
})
} }
Expr::Nested(inner) => Self::convert_expr(inner), Expr::Nested(inner) => Self::convert_expr(inner),
_ => Err(SqlError::Unsupported("Expression not supported".to_string())), _ => Err(SqlError::Unsupported(
"Expression not supported".to_string(),
)),
} }
} }
@ -587,7 +629,11 @@ impl SqlParser {
let parsed_assignments: Vec<(String, SqlValue)> = assignments let parsed_assignments: Vec<(String, SqlValue)> = assignments
.iter() .iter()
.map(|a| { .map(|a| {
let col = a.id.iter().map(|i| i.value.clone()).collect::<Vec<_>>().join("."); let col =
a.id.iter()
.map(|i| i.value.clone())
.collect::<Vec<_>>()
.join(".");
let val = Self::convert_value_expr(&a.value)?; let val = Self::convert_value_expr(&a.value)?;
Ok((col, val)) Ok((col, val))
}) })
@ -633,10 +679,7 @@ impl SqlParser {
let table = table_name.to_string(); let table = table_name.to_string();
let cols: Vec<String> = columns let cols: Vec<String> = columns.iter().map(|c| c.expr.to_string()).collect();
.iter()
.map(|c| c.expr.to_string())
.collect();
Ok(ParsedStatement::CreateIndex { Ok(ParsedStatement::CreateIndex {
name: index_name, name: index_name,
@ -694,7 +737,12 @@ mod tests {
let sql = "INSERT INTO users (name, age) VALUES ('Alice', 30), ('Bob', 25)"; let sql = "INSERT INTO users (name, age) VALUES ('Alice', 30), ('Bob', 25)";
let stmt = SqlParser::parse(sql).unwrap(); let stmt = SqlParser::parse(sql).unwrap();
if let ParsedStatement::Insert { table, columns, values } = stmt { if let ParsedStatement::Insert {
table,
columns,
values,
} = stmt
{
assert_eq!(table, "users"); assert_eq!(table, "users");
assert_eq!(columns, vec!["name", "age"]); assert_eq!(columns, vec!["name", "age"]);
assert_eq!(values.len(), 2); assert_eq!(values.len(), 2);
@ -708,7 +756,12 @@ mod tests {
let sql = "UPDATE users SET age = 31 WHERE name = 'Alice'"; let sql = "UPDATE users SET age = 31 WHERE name = 'Alice'";
let stmt = SqlParser::parse(sql).unwrap(); let stmt = SqlParser::parse(sql).unwrap();
if let ParsedStatement::Update { table, assignments, where_clause } = stmt { if let ParsedStatement::Update {
table,
assignments,
where_clause,
} = stmt
{
assert_eq!(table, "users"); assert_eq!(table, "users");
assert_eq!(assignments.len(), 1); assert_eq!(assignments.len(), 1);
assert!(where_clause.is_some()); assert!(where_clause.is_some());
@ -722,7 +775,11 @@ mod tests {
let sql = "DELETE FROM users WHERE age < 18"; let sql = "DELETE FROM users WHERE age < 18";
let stmt = SqlParser::parse(sql).unwrap(); let stmt = SqlParser::parse(sql).unwrap();
if let ParsedStatement::Delete { table, where_clause } = stmt { if let ParsedStatement::Delete {
table,
where_clause,
} = stmt
{
assert_eq!(table, "users"); assert_eq!(table, "users");
assert!(where_clause.is_some()); assert!(where_clause.is_some());
} else { } else {

View file

@ -85,7 +85,10 @@ impl Row {
/// Returns all values in column order. /// Returns all values in column order.
pub fn values(&self) -> Vec<&SqlValue> { pub fn values(&self) -> Vec<&SqlValue> {
self.columns.iter().map(|c| self.values.get(c).unwrap()).collect() self.columns
.iter()
.map(|c| self.values.get(c).unwrap())
.collect()
} }
/// Returns the number of columns. /// Returns the number of columns.

View file

@ -299,7 +299,10 @@ impl Table {
let mut indexes = self.indexes.write(); let mut indexes = self.indexes.write();
if indexes.contains_key(&name) { if indexes.contains_key(&name) {
return Err(SqlError::InvalidOperation(format!("Index '{}' already exists", name))); return Err(SqlError::InvalidOperation(format!(
"Index '{}' already exists",
name
)));
} }
let mut index = TableIndex::new(&name, &column, unique); let mut index = TableIndex::new(&name, &column, unique);
@ -319,7 +322,10 @@ impl Table {
pub fn drop_index(&self, name: &str) -> Result<(), SqlError> { pub fn drop_index(&self, name: &str) -> Result<(), SqlError> {
let mut indexes = self.indexes.write(); let mut indexes = self.indexes.write();
if indexes.remove(name).is_none() { if indexes.remove(name).is_none() {
return Err(SqlError::InvalidOperation(format!("Index '{}' not found", name))); return Err(SqlError::InvalidOperation(format!(
"Index '{}' not found",
name
)));
} }
Ok(()) Ok(())
} }
@ -371,9 +377,9 @@ impl Table {
/// Updates a row. /// Updates a row.
pub fn update(&self, id: RowId, updates: HashMap<String, SqlValue>) -> Result<(), SqlError> { pub fn update(&self, id: RowId, updates: HashMap<String, SqlValue>) -> Result<(), SqlError> {
let mut rows = self.rows.write(); let mut rows = self.rows.write();
let row = rows.get_mut(&id).ok_or_else(|| { let row = rows
SqlError::InvalidOperation(format!("Row {} not found", id)) .get_mut(&id)
})?; .ok_or_else(|| SqlError::InvalidOperation(format!("Row {} not found", id)))?;
let old_values: HashMap<String, SqlValue> = updates let old_values: HashMap<String, SqlValue> = updates
.keys() .keys()
@ -392,7 +398,10 @@ impl Table {
let mut indexes = self.indexes.write(); let mut indexes = self.indexes.write();
for (_, index) in indexes.iter_mut() { for (_, index) in indexes.iter_mut() {
if let Some(new_value) = updates.get(&index.column) { if let Some(new_value) = updates.get(&index.column) {
let old_value = old_values.get(&index.column).cloned().unwrap_or(SqlValue::Null); let old_value = old_values
.get(&index.column)
.cloned()
.unwrap_or(SqlValue::Null);
index.remove(&old_value, &id); index.remove(&old_value, &id);
index.insert(new_value.clone(), id)?; index.insert(new_value.clone(), id)?;
} }
@ -480,7 +489,10 @@ mod tests {
values.insert("id".to_string(), SqlValue::Integer(1)); values.insert("id".to_string(), SqlValue::Integer(1));
values.insert("name".to_string(), SqlValue::Text("Alice".to_string())); values.insert("name".to_string(), SqlValue::Text("Alice".to_string()));
values.insert("age".to_string(), SqlValue::Integer(30)); values.insert("age".to_string(), SqlValue::Integer(30));
values.insert("email".to_string(), SqlValue::Text("alice@example.com".to_string())); values.insert(
"email".to_string(),
SqlValue::Text("alice@example.com".to_string()),
);
let row_id = table.insert(values).unwrap(); let row_id = table.insert(values).unwrap();
assert_eq!(table.count(), 1); assert_eq!(table.count(), 1);
@ -508,13 +520,19 @@ mod tests {
let mut values1 = HashMap::new(); let mut values1 = HashMap::new();
values1.insert("id".to_string(), SqlValue::Integer(1)); values1.insert("id".to_string(), SqlValue::Integer(1));
values1.insert("name".to_string(), SqlValue::Text("Alice".to_string())); values1.insert("name".to_string(), SqlValue::Text("Alice".to_string()));
values1.insert("email".to_string(), SqlValue::Text("test@example.com".to_string())); values1.insert(
"email".to_string(),
SqlValue::Text("test@example.com".to_string()),
);
table.insert(values1).unwrap(); table.insert(values1).unwrap();
let mut values2 = HashMap::new(); let mut values2 = HashMap::new();
values2.insert("id".to_string(), SqlValue::Integer(2)); values2.insert("id".to_string(), SqlValue::Integer(2));
values2.insert("name".to_string(), SqlValue::Text("Bob".to_string())); values2.insert("name".to_string(), SqlValue::Text("Bob".to_string()));
values2.insert("email".to_string(), SqlValue::Text("test@example.com".to_string())); values2.insert(
"email".to_string(),
SqlValue::Text("test@example.com".to_string()),
);
let result = table.insert(values2); let result = table.insert(values2);
assert!(result.is_err()); // Duplicate email assert!(result.is_err()); // Duplicate email

View file

@ -124,7 +124,12 @@ impl Transaction {
} }
/// Records an insert operation. /// Records an insert operation.
pub fn record_insert(&mut self, table: String, row_id: RowId, values: HashMap<String, SqlValue>) { pub fn record_insert(
&mut self,
table: String,
row_id: RowId,
values: HashMap<String, SqlValue>,
) {
self.operations.push(TransactionOp::Insert { self.operations.push(TransactionOp::Insert {
table, table,
row_id, row_id,
@ -149,7 +154,12 @@ impl Transaction {
} }
/// Records a delete operation. /// Records a delete operation.
pub fn record_delete(&mut self, table: String, row_id: RowId, old_values: HashMap<String, SqlValue>) { pub fn record_delete(
&mut self,
table: String,
row_id: RowId,
old_values: HashMap<String, SqlValue>,
) {
self.operations.push(TransactionOp::Delete { self.operations.push(TransactionOp::Delete {
table, table,
row_id, row_id,
@ -213,7 +223,10 @@ impl TransactionManager {
.ok_or_else(|| SqlError::Transaction(format!("Transaction {} not found", id)))?; .ok_or_else(|| SqlError::Transaction(format!("Transaction {} not found", id)))?;
if !txn.is_active() { if !txn.is_active() {
return Err(SqlError::Transaction(format!("Transaction {} is not active", id))); return Err(SqlError::Transaction(format!(
"Transaction {} is not active",
id
)));
} }
txn.operations.push(op); txn.operations.push(op);
@ -228,7 +241,10 @@ impl TransactionManager {
.ok_or_else(|| SqlError::Transaction(format!("Transaction {} not found", id)))?; .ok_or_else(|| SqlError::Transaction(format!("Transaction {} not found", id)))?;
if !txn.is_active() { if !txn.is_active() {
return Err(SqlError::Transaction(format!("Transaction {} is not active", id))); return Err(SqlError::Transaction(format!(
"Transaction {} is not active",
id
)));
} }
txn.mark_committed(); txn.mark_committed();
@ -245,7 +261,10 @@ impl TransactionManager {
.ok_or_else(|| SqlError::Transaction(format!("Transaction {} not found", id)))?; .ok_or_else(|| SqlError::Transaction(format!("Transaction {} not found", id)))?;
if !txn.is_active() { if !txn.is_active() {
return Err(SqlError::Transaction(format!("Transaction {} is not active", id))); return Err(SqlError::Transaction(format!(
"Transaction {} is not active",
id
)));
} }
txn.mark_rolled_back(); txn.mark_rolled_back();

View file

@ -233,12 +233,8 @@ impl Ord for SqlValue {
(SqlValue::Blob(a), SqlValue::Blob(b)) => a.cmp(b), (SqlValue::Blob(a), SqlValue::Blob(b)) => a.cmp(b),
(SqlValue::Boolean(a), SqlValue::Boolean(b)) => a.cmp(b), (SqlValue::Boolean(a), SqlValue::Boolean(b)) => a.cmp(b),
(SqlValue::Timestamp(a), SqlValue::Timestamp(b)) => a.cmp(b), (SqlValue::Timestamp(a), SqlValue::Timestamp(b)) => a.cmp(b),
(SqlValue::Integer(a), SqlValue::Real(b)) => { (SqlValue::Integer(a), SqlValue::Real(b)) => (*a as f64).to_bits().cmp(&b.to_bits()),
(*a as f64).to_bits().cmp(&b.to_bits()) (SqlValue::Real(a), SqlValue::Integer(b)) => a.to_bits().cmp(&(*b as f64).to_bits()),
}
(SqlValue::Real(a), SqlValue::Integer(b)) => {
a.to_bits().cmp(&(*b as f64).to_bits())
}
// Different types: order by type discriminant // Different types: order by type discriminant
_ => self.type_order().cmp(&other.type_order()), _ => self.type_order().cmp(&other.type_order()),
} }

View file

@ -158,11 +158,7 @@ impl Metric {
/// Calculates sum in a time range. /// Calculates sum in a time range.
pub fn sum(&self, start: u64, end: u64) -> f64 { pub fn sum(&self, start: u64, end: u64) -> f64 {
self.data self.data.read().range(start..=end).map(|(_, &v)| v).sum()
.read()
.range(start..=end)
.map(|(_, &v)| v)
.sum()
} }
/// Counts data points in a time range. /// Counts data points in a time range.

View file

@ -207,9 +207,7 @@ impl VectorIndex {
let embeddings = self.embeddings.read(); let embeddings = self.embeddings.read();
let mut results: Vec<VectorSearchResult> = embeddings let mut results: Vec<VectorSearchResult> = embeddings
.values() .values()
.filter(|e| { .filter(|e| namespace.map(|ns| e.namespace == ns).unwrap_or(true))
namespace.map(|ns| e.namespace == ns).unwrap_or(true)
})
.map(|e| { .map(|e| {
let score = self.calculate_similarity(&e.vector, query); let score = self.calculate_similarity(&e.vector, query);
VectorSearchResult { VectorSearchResult {
@ -217,14 +215,14 @@ impl VectorIndex {
score, score,
} }
}) })
.filter(|r| { .filter(|r| threshold.map(|t| r.score >= t).unwrap_or(true))
threshold.map(|t| r.score >= t).unwrap_or(true)
})
.collect(); .collect();
// Sort by score descending // Sort by score descending
results.sort_by(|a, b| { results.sort_by(|a, b| {
b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal) b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
}); });
// Apply limit // Apply limit
@ -234,8 +232,9 @@ impl VectorIndex {
let elapsed = start.elapsed().as_millis() as f64; let elapsed = start.elapsed().as_millis() as f64;
let mut stats = self.stats.write(); let mut stats = self.stats.write();
stats.searches += 1; stats.searches += 1;
stats.avg_search_time_ms = stats.avg_search_time_ms = (stats.avg_search_time_ms * (stats.searches - 1) as f64
(stats.avg_search_time_ms * (stats.searches - 1) as f64 + elapsed) / stats.searches as f64; + elapsed)
/ stats.searches as f64;
Ok(results) Ok(results)
} }
@ -329,7 +328,8 @@ impl VectorStore {
namespace: Option<&str>, namespace: Option<&str>,
threshold: Option<f32>, threshold: Option<f32>,
) -> Result<Vec<VectorSearchResult>, DatabaseError> { ) -> Result<Vec<VectorSearchResult>, DatabaseError> {
self.default_index.search(query, limit, namespace, threshold) self.default_index
.search(query, limit, namespace, threshold)
} }
/// Gets an embedding by ID. /// Gets an embedding by ID.
@ -388,10 +388,7 @@ pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
/// Manhattan distance (L1) between two vectors. /// Manhattan distance (L1) between two vectors.
pub fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 { pub fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
a.iter() a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum()
.zip(b.iter())
.map(|(x, y)| (x - y).abs())
.sum()
} }
#[cfg(test)] #[cfg(test)]
@ -414,9 +411,15 @@ mod tests {
fn test_vector_insert_search() { fn test_vector_insert_search() {
let store = VectorStore::new(3); let store = VectorStore::new(3);
store.insert(Embedding::new("a", vec![1.0, 0.0, 0.0])).unwrap(); store
store.insert(Embedding::new("b", vec![0.9, 0.1, 0.0])).unwrap(); .insert(Embedding::new("a", vec![1.0, 0.0, 0.0]))
store.insert(Embedding::new("c", vec![0.0, 1.0, 0.0])).unwrap(); .unwrap();
store
.insert(Embedding::new("b", vec![0.9, 0.1, 0.0]))
.unwrap();
store
.insert(Embedding::new("c", vec![0.0, 1.0, 0.0]))
.unwrap();
let results = store.search(&[1.0, 0.0, 0.0], 2, None, None).unwrap(); let results = store.search(&[1.0, 0.0, 0.0], 2, None, None).unwrap();
@ -429,8 +432,12 @@ mod tests {
fn test_similarity_threshold() { fn test_similarity_threshold() {
let store = VectorStore::new(3); let store = VectorStore::new(3);
store.insert(Embedding::new("a", vec![1.0, 0.0, 0.0])).unwrap(); store
store.insert(Embedding::new("b", vec![0.0, 1.0, 0.0])).unwrap(); .insert(Embedding::new("a", vec![1.0, 0.0, 0.0]))
.unwrap();
store
.insert(Embedding::new("b", vec![0.0, 1.0, 0.0]))
.unwrap();
let results = store.search(&[1.0, 0.0, 0.0], 10, None, Some(0.5)).unwrap(); let results = store.search(&[1.0, 0.0, 0.0], 10, None, Some(0.5)).unwrap();
@ -443,14 +450,16 @@ mod tests {
fn test_namespace_filter() { fn test_namespace_filter() {
let store = VectorStore::new(3); let store = VectorStore::new(3);
store.insert( store
Embedding::new("a", vec![1.0, 0.0, 0.0]).with_namespace("ns1") .insert(Embedding::new("a", vec![1.0, 0.0, 0.0]).with_namespace("ns1"))
).unwrap(); .unwrap();
store.insert( store
Embedding::new("b", vec![1.0, 0.0, 0.0]).with_namespace("ns2") .insert(Embedding::new("b", vec![1.0, 0.0, 0.0]).with_namespace("ns2"))
).unwrap(); .unwrap();
let results = store.search(&[1.0, 0.0, 0.0], 10, Some("ns1"), None).unwrap(); let results = store
.search(&[1.0, 0.0, 0.0], 10, Some("ns1"), None)
.unwrap();
assert_eq!(results.len(), 1); assert_eq!(results.len(), 1);
assert_eq!(results[0].embedding.id, "a"); assert_eq!(results[0].embedding.id, "a");

View file

@ -184,9 +184,7 @@ impl Credit {
/// Check if credit is expired /// Check if credit is expired
pub fn is_expired(&self) -> bool { pub fn is_expired(&self) -> bool {
self.expires_at self.expires_at.map(|exp| Utc::now() > exp).unwrap_or(false)
.map(|exp| Utc::now() > exp)
.unwrap_or(false)
} }
/// Get remaining amount /// Get remaining amount
@ -241,14 +239,21 @@ impl std::fmt::Display for CreditError {
match self { match self {
CreditError::CreditInactive => write!(f, "Credit is no longer active"), CreditError::CreditInactive => write!(f, "Credit is no longer active"),
CreditError::CreditExpired => write!(f, "Credit has expired"), CreditError::CreditExpired => write!(f, "Credit has expired"),
CreditError::InsufficientCredit { requested, available } => { CreditError::InsufficientCredit {
requested,
available,
} => {
write!( write!(
f, f,
"Insufficient credit: requested {}, available {}", "Insufficient credit: requested {}, available {}",
requested, available requested, available
) )
} }
CreditError::ExceedsMaxCredit { current, requested, maximum } => { CreditError::ExceedsMaxCredit {
current,
requested,
maximum,
} => {
write!( write!(
f, f,
"Credit exceeds maximum: current {}, requested {}, maximum {}", "Credit exceeds maximum: current {}, requested {}, maximum {}",
@ -279,9 +284,9 @@ pub struct CreditPolicy {
impl Default for CreditPolicy { impl Default for CreditPolicy {
fn default() -> Self { fn default() -> Self {
Self { Self {
welcome_amount: Decimal::new(10, 0), // 10 SYNOR welcome_amount: Decimal::new(10, 0), // 10 SYNOR
referral_referrer_amount: Decimal::new(25, 0), // 25 SYNOR referral_referrer_amount: Decimal::new(25, 0), // 25 SYNOR
referral_referee_amount: Decimal::new(10, 0), // 10 SYNOR referral_referee_amount: Decimal::new(10, 0), // 10 SYNOR
max_credit_per_account: Decimal::new(1000, 0), // 1000 SYNOR max_credit_per_account: Decimal::new(1000, 0), // 1000 SYNOR
default_expiry_days: 365, default_expiry_days: 365,
} }
@ -334,12 +339,20 @@ impl CreditManager {
let referee_id = referee_id.into(); let referee_id = referee_id.into();
// Credit for the referrer // Credit for the referrer
let referrer_credit = Credit::referral(&referrer_id, self.policy.referral_referrer_amount, &referee_id) let referrer_credit = Credit::referral(
.with_expiry_days(self.policy.default_expiry_days); &referrer_id,
self.policy.referral_referrer_amount,
&referee_id,
)
.with_expiry_days(self.policy.default_expiry_days);
// Credit for the referee // Credit for the referee
let referee_credit = Credit::referral(&referee_id, self.policy.referral_referee_amount, &referrer_id) let referee_credit = Credit::referral(
.with_expiry_days(self.policy.default_expiry_days); &referee_id,
self.policy.referral_referee_amount,
&referrer_id,
)
.with_expiry_days(self.policy.default_expiry_days);
self.credits self.credits
.entry(referrer_id) .entry(referrer_id)
@ -448,13 +461,11 @@ impl CreditManager {
let mut remaining = amount; let mut remaining = amount;
// Sort by expiry date (soonest first) for FIFO // Sort by expiry date (soonest first) for FIFO
credits.sort_by(|a, b| { credits.sort_by(|a, b| match (&a.expires_at, &b.expires_at) {
match (&a.expires_at, &b.expires_at) { (Some(a_exp), Some(b_exp)) => a_exp.cmp(b_exp),
(Some(a_exp), Some(b_exp)) => a_exp.cmp(b_exp), (Some(_), None) => std::cmp::Ordering::Less,
(Some(_), None) => std::cmp::Ordering::Less, (None, Some(_)) => std::cmp::Ordering::Greater,
(None, Some(_)) => std::cmp::Ordering::Greater, (None, None) => a.created_at.cmp(&b.created_at),
(None, None) => a.created_at.cmp(&b.created_at),
}
}); });
for credit in credits.iter_mut() { for credit in credits.iter_mut() {

View file

@ -319,12 +319,7 @@ mod tests {
#[test] #[test]
fn test_line_item() { fn test_line_item() {
let item = InvoiceLineItem::new( let item = InvoiceLineItem::new("Storage L2", ServiceType::Storage, dec!(10), dec!(0.02));
"Storage L2",
ServiceType::Storage,
dec!(10),
dec!(0.02),
);
assert_eq!(item.amount, dec!(0.20)); assert_eq!(item.amount, dec!(0.20));
} }
@ -332,8 +327,18 @@ mod tests {
#[test] #[test]
fn test_invoice_calculate() { fn test_invoice_calculate() {
let mut invoice = Invoice::new("test") let mut invoice = Invoice::new("test")
.add_line_item(InvoiceLineItem::new("Storage", ServiceType::Storage, dec!(100), dec!(0.02))) .add_line_item(InvoiceLineItem::new(
.add_line_item(InvoiceLineItem::new("Compute", ServiceType::Compute, dec!(10), dec!(0.50))); "Storage",
ServiceType::Storage,
dec!(100),
dec!(0.02),
))
.add_line_item(InvoiceLineItem::new(
"Compute",
ServiceType::Compute,
dec!(10),
dec!(0.50),
));
invoice.discount = dec!(1); invoice.discount = dec!(1);
invoice.calculate(); invoice.calculate();

View file

@ -164,17 +164,12 @@ impl BillingEngine {
let outstanding: Vec<_> = account let outstanding: Vec<_> = account
.invoice_ids .invoice_ids
.iter() .iter()
.filter(|id| { .filter(|id| invoices.get(*id).map(|inv| !inv.is_paid()).unwrap_or(false))
invoices
.get(*id)
.map(|inv| !inv.is_paid())
.unwrap_or(false)
})
.cloned() .cloned()
.collect(); .collect();
let next_invoice = account.billing_cycle_start let next_invoice =
+ Duration::days(self.config.billing_cycle_days as i64); account.billing_cycle_start + Duration::days(self.config.billing_cycle_days as i64);
Ok(AccountBillingInfo { Ok(AccountBillingInfo {
account_id: account_id.to_string(), account_id: account_id.to_string(),
@ -198,11 +193,7 @@ impl BillingEngine {
account.prepaid_balance += amount; account.prepaid_balance += amount;
tracing::info!( tracing::info!("Added {} SYNOR prepaid to account {}", amount, account_id);
"Added {} SYNOR prepaid to account {}",
amount,
account_id
);
Ok(()) Ok(())
} }
@ -378,7 +369,9 @@ impl BillingEngine {
PaymentMethod::CreditBalance => { PaymentMethod::CreditBalance => {
// Deduct from credit balance // Deduct from credit balance
if account.credit_balance < payment.amount { if account.credit_balance < payment.amount {
return Err(EconomicsError::PaymentFailed("Insufficient credit balance".to_string())); return Err(EconomicsError::PaymentFailed(
"Insufficient credit balance".to_string(),
));
} }
account.credit_balance -= payment.amount; account.credit_balance -= payment.amount;
payment.mark_completed(); payment.mark_completed();
@ -484,7 +477,10 @@ impl BillingEngine {
/// Get unpaid invoices for an account /// Get unpaid invoices for an account
pub async fn get_unpaid_invoices(&self, account_id: &str) -> Result<Vec<Invoice>> { pub async fn get_unpaid_invoices(&self, account_id: &str) -> Result<Vec<Invoice>> {
let all_invoices = self.get_account_invoices(account_id).await?; let all_invoices = self.get_account_invoices(account_id).await?;
Ok(all_invoices.into_iter().filter(|inv| !inv.is_paid()).collect()) Ok(all_invoices
.into_iter()
.filter(|inv| !inv.is_paid())
.collect())
} }
/// Get detailed account information including creation date /// Get detailed account information including creation date
@ -617,7 +613,10 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_register_account() { async fn test_register_account() {
let engine = setup_engine().await; let engine = setup_engine().await;
engine.register_account("test_account", "standard").await.unwrap(); engine
.register_account("test_account", "standard")
.await
.unwrap();
let info = engine.get_account_details("test_account").await.unwrap(); let info = engine.get_account_details("test_account").await.unwrap();
assert_eq!(info.account_id, "test_account"); assert_eq!(info.account_id, "test_account");
@ -627,7 +626,10 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_add_prepaid() { async fn test_add_prepaid() {
let engine = setup_engine().await; let engine = setup_engine().await;
engine.register_account("prepaid_test", "standard").await.unwrap(); engine
.register_account("prepaid_test", "standard")
.await
.unwrap();
engine.add_prepaid("prepaid_test", dec!(100)).await.unwrap(); engine.add_prepaid("prepaid_test", dec!(100)).await.unwrap();
let info = engine.get_account_info("prepaid_test").await.unwrap(); let info = engine.get_account_info("prepaid_test").await.unwrap();
@ -637,7 +639,10 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_add_credit() { async fn test_add_credit() {
let engine = setup_engine().await; let engine = setup_engine().await;
engine.register_account("credit_test", "standard").await.unwrap(); engine
.register_account("credit_test", "standard")
.await
.unwrap();
let credit = Credit::new("credit_test", dec!(50), "Welcome bonus"); let credit = Credit::new("credit_test", dec!(50), "Welcome bonus");
engine.add_credit("credit_test", credit).await.unwrap(); engine.add_credit("credit_test", credit).await.unwrap();

View file

@ -210,17 +210,23 @@ impl PaymentProcessor {
payment.mark_processing(); payment.mark_processing();
// Simulate transaction // Simulate transaction
let tx_hash = format!("0x{:x}000000000000000000000000000000000000000000000000000000000000", let tx_hash = format!(
"0x{:x}000000000000000000000000000000000000000000000000000000000000",
std::time::SystemTime::now() std::time::SystemTime::now()
.duration_since(UNIX_EPOCH) .duration_since(UNIX_EPOCH)
.unwrap() .unwrap()
.as_secs()); .as_secs()
);
payment.mark_confirmed(tx_hash); payment.mark_confirmed(tx_hash);
// Add addresses to metadata // Add addresses to metadata
payment.metadata.insert("from".to_string(), from_address.to_string()); payment
payment.metadata.insert("to".to_string(), to_address.to_string()); .metadata
.insert("from".to_string(), from_address.to_string());
payment
.metadata
.insert("to".to_string(), to_address.to_string());
payment.mark_completed(); payment.mark_completed();
@ -325,7 +331,10 @@ mod tests {
payment.mark_failed("Insufficient funds"); payment.mark_failed("Insufficient funds");
assert!(!payment.is_complete()); assert!(!payment.is_complete());
assert_eq!(payment.failure_reason, Some("Insufficient funds".to_string())); assert_eq!(
payment.failure_reason,
Some("Insufficient funds".to_string())
);
} }
#[tokio::test] #[tokio::test]

View file

@ -177,7 +177,10 @@ impl CostEstimator {
/// Estimate cost for a usage projection /// Estimate cost for a usage projection
pub async fn estimate(&self, projection: UsageProjection) -> Result<CostEstimate> { pub async fn estimate(&self, projection: UsageProjection) -> Result<CostEstimate> {
let tier_name = projection.tier.clone().unwrap_or_else(|| "free".to_string()); let tier_name = projection
.tier
.clone()
.unwrap_or_else(|| "free".to_string());
let months = projection.duration_months.max(1); let months = projection.duration_months.max(1);
let mut by_service = HashMap::new(); let mut by_service = HashMap::new();

View file

@ -22,10 +22,7 @@ pub enum EconomicsError {
/// Insufficient balance /// Insufficient balance
#[error("Insufficient balance: required {required}, available {available}")] #[error("Insufficient balance: required {required}, available {available}")]
InsufficientBalance { InsufficientBalance { required: String, available: String },
required: String,
available: String,
},
/// Insufficient funds (with Decimal values) /// Insufficient funds (with Decimal values)
#[error("Insufficient funds: required {required}, available {available}")] #[error("Insufficient funds: required {required}, available {available}")]
@ -36,10 +33,7 @@ pub enum EconomicsError {
/// Stale price with specific asset /// Stale price with specific asset
#[error("Price stale for {asset}: {age_seconds} seconds old")] #[error("Price stale for {asset}: {age_seconds} seconds old")]
StalePrice { StalePrice { asset: String, age_seconds: i64 },
asset: String,
age_seconds: i64,
},
/// Account not found /// Account not found
#[error("Account not found: {0}")] #[error("Account not found: {0}")]

View file

@ -251,9 +251,7 @@ impl EconomicsManager {
use rust_decimal_macros::dec; use rust_decimal_macros::dec;
// Default to development oracle with mock feeds at $1.50 base price // Default to development oracle with mock feeds at $1.50 base price
let oracle = Arc::new(RwLock::new( let oracle = Arc::new(RwLock::new(oracle::OracleFactory::development(dec!(1.50))));
oracle::OracleFactory::development(dec!(1.50))
));
let pricing = Arc::new(PricingEngine::new()); let pricing = Arc::new(PricingEngine::new());
let metering = Arc::new(MeteringService::new(pricing.clone())); let metering = Arc::new(MeteringService::new(pricing.clone()));
let billing = Arc::new(BillingEngine::new(metering.clone(), pricing.clone())); let billing = Arc::new(BillingEngine::new(metering.clone(), pricing.clone()));
@ -270,9 +268,7 @@ impl EconomicsManager {
/// Create an economics manager with production oracle configuration /// Create an economics manager with production oracle configuration
pub fn with_production_oracle(config: oracle::ProductionOracleConfig) -> Self { pub fn with_production_oracle(config: oracle::ProductionOracleConfig) -> Self {
let oracle = Arc::new(RwLock::new( let oracle = Arc::new(RwLock::new(oracle::OracleFactory::production(config)));
oracle::OracleFactory::production(config)
));
let pricing = Arc::new(PricingEngine::new()); let pricing = Arc::new(PricingEngine::new());
let metering = Arc::new(MeteringService::new(pricing.clone())); let metering = Arc::new(MeteringService::new(pricing.clone()));
let billing = Arc::new(BillingEngine::new(metering.clone(), pricing.clone())); let billing = Arc::new(BillingEngine::new(metering.clone(), pricing.clone()));

View file

@ -209,23 +209,22 @@ impl MeteringService {
} }
// Calculate cost for this event // Calculate cost for this event
let cost = self.pricing.calculate_cost( let cost =
event.service_type, self.pricing
event.resource_unit, .calculate_cost(event.service_type, event.resource_unit, event.amount)?;
event.amount,
)?;
// Update current usage // Update current usage
{ {
let mut usage = self.current_usage.write().await; let mut usage = self.current_usage.write().await;
let account_usage = usage.entry(event.account_id.clone()).or_insert_with(|| { let account_usage =
AccountUsage { usage
account_id: event.account_id.clone(), .entry(event.account_id.clone())
by_service: HashMap::new(), .or_insert_with(|| AccountUsage {
current_period_start: Utc::now(), account_id: event.account_id.clone(),
last_event: None, by_service: HashMap::new(),
} current_period_start: Utc::now(),
}); last_event: None,
});
*account_usage *account_usage
.by_service .by_service
@ -263,7 +262,8 @@ impl MeteringService {
ServiceType::Storage, ServiceType::Storage,
ResourceUnit::Bytes, ResourceUnit::Bytes,
Decimal::from(usage.bytes_stored), Decimal::from(usage.bytes_stored),
)).await?; ))
.await?;
} }
// Storage: bytes retrieved // Storage: bytes retrieved
@ -273,7 +273,8 @@ impl MeteringService {
ServiceType::Storage, ServiceType::Storage,
ResourceUnit::BandwidthGb, ResourceUnit::BandwidthGb,
Decimal::from(usage.bytes_retrieved) / Decimal::from(1_073_741_824u64), // to GB Decimal::from(usage.bytes_retrieved) / Decimal::from(1_073_741_824u64), // to GB
)).await?; ))
.await?;
} }
Ok(()) Ok(())
@ -288,7 +289,8 @@ impl MeteringService {
ServiceType::Hosting, ServiceType::Hosting,
ResourceUnit::BandwidthGb, ResourceUnit::BandwidthGb,
Decimal::from(usage.bandwidth_bytes) / Decimal::from(1_073_741_824u64), Decimal::from(usage.bandwidth_bytes) / Decimal::from(1_073_741_824u64),
)).await?; ))
.await?;
} }
// Custom domains // Custom domains
@ -298,7 +300,8 @@ impl MeteringService {
ServiceType::Hosting, ServiceType::Hosting,
ResourceUnit::Domains, ResourceUnit::Domains,
Decimal::from(usage.custom_domains), Decimal::from(usage.custom_domains),
)).await?; ))
.await?;
} }
Ok(()) Ok(())
@ -313,7 +316,8 @@ impl MeteringService {
ServiceType::Database, ServiceType::Database,
ResourceUnit::Queries, ResourceUnit::Queries,
Decimal::from(usage.queries), Decimal::from(usage.queries),
)).await?; ))
.await?;
} }
// Vector searches // Vector searches
@ -323,7 +327,8 @@ impl MeteringService {
ServiceType::Database, ServiceType::Database,
ResourceUnit::VectorSearches, ResourceUnit::VectorSearches,
Decimal::from(usage.vector_searches), Decimal::from(usage.vector_searches),
)).await?; ))
.await?;
} }
// Storage // Storage
@ -333,7 +338,8 @@ impl MeteringService {
ServiceType::Database, ServiceType::Database,
ResourceUnit::GbMonth, ResourceUnit::GbMonth,
Decimal::from(usage.storage_bytes) / Decimal::from(1_073_741_824u64), Decimal::from(usage.storage_bytes) / Decimal::from(1_073_741_824u64),
)).await?; ))
.await?;
} }
Ok(()) Ok(())
@ -348,7 +354,8 @@ impl MeteringService {
ServiceType::Compute, ServiceType::Compute,
ResourceUnit::CpuCoreHours, ResourceUnit::CpuCoreHours,
Decimal::from(usage.cpu_core_seconds) / Decimal::from(3600), Decimal::from(usage.cpu_core_seconds) / Decimal::from(3600),
)).await?; ))
.await?;
} }
// GPU hours // GPU hours
@ -358,7 +365,8 @@ impl MeteringService {
ServiceType::Compute, ServiceType::Compute,
ResourceUnit::GpuHours, ResourceUnit::GpuHours,
Decimal::from(usage.gpu_seconds) / Decimal::from(3600), Decimal::from(usage.gpu_seconds) / Decimal::from(3600),
)).await?; ))
.await?;
} }
// Memory GB hours // Memory GB hours
@ -368,7 +376,8 @@ impl MeteringService {
ServiceType::Compute, ServiceType::Compute,
ResourceUnit::MemoryGbHours, ResourceUnit::MemoryGbHours,
Decimal::from(usage.memory_gb_seconds) / Decimal::from(3600), Decimal::from(usage.memory_gb_seconds) / Decimal::from(3600),
)).await?; ))
.await?;
} }
// Invocations (serverless) // Invocations (serverless)
@ -378,7 +387,8 @@ impl MeteringService {
ServiceType::Compute, ServiceType::Compute,
ResourceUnit::Invocations, ResourceUnit::Invocations,
Decimal::from(usage.invocations), Decimal::from(usage.invocations),
)).await?; ))
.await?;
} }
Ok(()) Ok(())
@ -393,7 +403,8 @@ impl MeteringService {
ServiceType::Network, ServiceType::Network,
ResourceUnit::BandwidthGb, ResourceUnit::BandwidthGb,
Decimal::from(total_bytes) / Decimal::from(1_073_741_824u64), Decimal::from(total_bytes) / Decimal::from(1_073_741_824u64),
)).await?; ))
.await?;
} }
Ok(()) Ok(())
@ -421,10 +432,7 @@ impl MeteringService {
// Check buffered events // Check buffered events
let buffer = self.event_buffer.read().await; let buffer = self.event_buffer.read().await;
for event in buffer.iter() { for event in buffer.iter() {
if event.account_id == account_id if event.account_id == account_id && event.timestamp >= start && event.timestamp < end {
&& event.timestamp >= start
&& event.timestamp < end
{
let cost = self.pricing.calculate_cost( let cost = self.pricing.calculate_cost(
event.service_type, event.service_type,
event.resource_unit, event.resource_unit,

View file

@ -224,9 +224,8 @@ impl IsolationTree {
// Random split point // Random split point
let split = min_val + (max_val - min_val) * 0.5; let split = min_val + (max_val - min_val) * 0.5;
let (left_data, right_data): (Vec<_>, Vec<_>) = data.iter() let (left_data, right_data): (Vec<_>, Vec<_>) =
.cloned() data.iter().cloned().partition(|row| row[feature] < split);
.partition(|row| row[feature] < split);
Some(Self { Some(Self {
split_feature: feature, split_feature: feature,
@ -280,7 +279,8 @@ impl IsolationForest {
let trees: Vec<_> = (0..n_trees) let trees: Vec<_> = (0..n_trees)
.filter_map(|i| { .filter_map(|i| {
// Subsample with deterministic "randomness" based on tree index // Subsample with deterministic "randomness" based on tree index
let sample: Vec<_> = data.iter() let sample: Vec<_> = data
.iter()
.enumerate() .enumerate()
.filter(|(j, _)| (i + j) % 3 != 0) .filter(|(j, _)| (i + j) % 3 != 0)
.take(sample_size) .take(sample_size)
@ -299,9 +299,12 @@ impl IsolationForest {
return 0.5; return 0.5;
} }
let avg_path: f64 = self.trees.iter() let avg_path: f64 = self
.trees
.iter()
.map(|tree| tree.path_length(point, 0.0)) .map(|tree| tree.path_length(point, 0.0))
.sum::<f64>() / self.trees.len() as f64; .sum::<f64>()
/ self.trees.len() as f64;
let c = c_factor(self.sample_size); let c = c_factor(self.sample_size);
if c < f64::EPSILON { if c < f64::EPSILON {
@ -365,17 +368,28 @@ impl PairDetector {
// Track addresses // Track addresses
if !point.addresses.is_empty() { if !point.addresses.is_empty() {
self.recent_addresses.push_back((point.timestamp, point.addresses.clone())); self.recent_addresses
.push_back((point.timestamp, point.addresses.clone()));
} }
self.price_history.push_back(point); self.price_history.push_back(point);
// Cleanup old data // Cleanup old data
let cutoff = Utc::now() - Duration::hours(24); let cutoff = Utc::now() - Duration::hours(24);
while self.price_history.front().map(|p| p.timestamp < cutoff).unwrap_or(false) { while self
.price_history
.front()
.map(|p| p.timestamp < cutoff)
.unwrap_or(false)
{
self.price_history.pop_front(); self.price_history.pop_front();
} }
while self.recent_addresses.front().map(|(t, _)| *t < cutoff).unwrap_or(false) { while self
.recent_addresses
.front()
.map(|(t, _)| *t < cutoff)
.unwrap_or(false)
{
self.recent_addresses.pop_front(); self.recent_addresses.pop_front();
} }
} }
@ -386,7 +400,9 @@ impl PairDetector {
} }
// Build feature vectors: [price, volume, return, bid/ask ratio] // Build feature vectors: [price, volume, return, bid/ask ratio]
let data: Vec<Vec<f64>> = self.price_history.iter() let data: Vec<Vec<f64>> = self
.price_history
.iter()
.skip(1) .skip(1)
.zip(self.price_history.iter()) .zip(self.price_history.iter())
.map(|(curr, prev)| { .map(|(curr, prev)| {
@ -403,7 +419,11 @@ impl PairDetector {
(Some(bid), Some(ask)) => { (Some(bid), Some(ask)) => {
let bid_f = bid.to_string().parse::<f64>().unwrap_or(0.0); let bid_f = bid.to_string().parse::<f64>().unwrap_or(0.0);
let ask_f = ask.to_string().parse::<f64>().unwrap_or(1.0); let ask_f = ask.to_string().parse::<f64>().unwrap_or(1.0);
if ask_f > 0.0 { bid_f / ask_f } else { 1.0 } if ask_f > 0.0 {
bid_f / ask_f
} else {
1.0
}
} }
_ => 1.0, _ => 1.0,
}; };
@ -462,19 +482,34 @@ impl AnomalyDetector {
// Run all detectors using the immutable reference first // Run all detectors using the immutable reference first
if let Some(detector) = self.detectors.get(pair) { if let Some(detector) = self.detectors.get(pair) {
if let Some(a) = Self::detect_price_outlier_impl(pair, &data, detector, min_data_points, z_score_threshold) { if let Some(a) = Self::detect_price_outlier_impl(
pair,
&data,
detector,
min_data_points,
z_score_threshold,
) {
anomalies.push(a); anomalies.push(a);
} }
if let Some(a) = Self::detect_volume_spike_impl(pair, &data, detector, min_data_points, volume_spike_multiplier) { if let Some(a) = Self::detect_volume_spike_impl(
pair,
&data,
detector,
min_data_points,
volume_spike_multiplier,
) {
anomalies.push(a); anomalies.push(a);
} }
if let Some(a) = Self::detect_wash_trading_impl(pair, &data, detector, wash_trading_window) { if let Some(a) =
Self::detect_wash_trading_impl(pair, &data, detector, wash_trading_window)
{
anomalies.push(a); anomalies.push(a);
} }
if let Some(a) = Self::detect_pump_dump_impl(pair, detector, pump_dump_window) { if let Some(a) = Self::detect_pump_dump_impl(pair, detector, pump_dump_window) {
anomalies.push(a); anomalies.push(a);
} }
if let Some(a) = Self::detect_flash_loan_impl(pair, &data, detector, flash_loan_window) { if let Some(a) = Self::detect_flash_loan_impl(pair, &data, detector, flash_loan_window)
{
anomalies.push(a); anomalies.push(a);
} }
if ml_enabled { if ml_enabled {
@ -493,7 +528,13 @@ impl AnomalyDetector {
anomalies anomalies
} }
fn detect_price_outlier_impl(pair: &str, data: &MarketDataPoint, detector: &PairDetector, min_data_points: usize, z_score_threshold: f64) -> Option<Anomaly> { fn detect_price_outlier_impl(
pair: &str,
data: &MarketDataPoint,
detector: &PairDetector,
min_data_points: usize,
z_score_threshold: f64,
) -> Option<Anomaly> {
if detector.price_stats.count < min_data_points { if detector.price_stats.count < min_data_points {
return None; return None;
} }
@ -531,7 +572,13 @@ impl AnomalyDetector {
} }
} }
fn detect_volume_spike_impl(pair: &str, data: &MarketDataPoint, detector: &PairDetector, min_data_points: usize, volume_spike_multiplier: f64) -> Option<Anomaly> { fn detect_volume_spike_impl(
pair: &str,
data: &MarketDataPoint,
detector: &PairDetector,
min_data_points: usize,
volume_spike_multiplier: f64,
) -> Option<Anomaly> {
if detector.volume_stats.count < min_data_points { if detector.volume_stats.count < min_data_points {
return None; return None;
} }
@ -550,7 +597,9 @@ impl AnomalyDetector {
confidence: 0.75, confidence: 0.75,
description: format!( description: format!(
"Volume {} is {:.1}x the average {:.2}", "Volume {} is {:.1}x the average {:.2}",
data.volume, volume_f64 / mean, mean data.volume,
volume_f64 / mean,
mean
), ),
data: AnomalyData { data: AnomalyData {
current_value: data.volume, current_value: data.volume,
@ -566,7 +615,12 @@ impl AnomalyDetector {
} }
} }
fn detect_wash_trading_impl(pair: &str, data: &MarketDataPoint, detector: &PairDetector, wash_trading_window: i64) -> Option<Anomaly> { fn detect_wash_trading_impl(
pair: &str,
data: &MarketDataPoint,
detector: &PairDetector,
wash_trading_window: i64,
) -> Option<Anomaly> {
if data.addresses.is_empty() { if data.addresses.is_empty() {
return None; return None;
} }
@ -617,14 +671,20 @@ impl AnomalyDetector {
None None
} }
fn detect_pump_dump_impl(pair: &str, detector: &PairDetector, pump_dump_window: i64) -> Option<Anomaly> { fn detect_pump_dump_impl(
pair: &str,
detector: &PairDetector,
pump_dump_window: i64,
) -> Option<Anomaly> {
// Need enough history // Need enough history
if detector.price_history.len() < 10 { if detector.price_history.len() < 10 {
return None; return None;
} }
let window_start = Utc::now() - Duration::minutes(pump_dump_window); let window_start = Utc::now() - Duration::minutes(pump_dump_window);
let prices: Vec<_> = detector.price_history.iter() let prices: Vec<_> = detector
.price_history
.iter()
.filter(|p| p.timestamp >= window_start) .filter(|p| p.timestamp >= window_start)
.map(|p| p.price.to_string().parse::<f64>().unwrap_or(0.0)) .map(|p| p.price.to_string().parse::<f64>().unwrap_or(0.0))
.collect(); .collect();
@ -635,7 +695,9 @@ impl AnomalyDetector {
// Find max and check for reversal // Find max and check for reversal
let max_price = prices.iter().copied().fold(f64::MIN, f64::max); let max_price = prices.iter().copied().fold(f64::MIN, f64::max);
let max_idx = prices.iter().position(|&p| (p - max_price).abs() < f64::EPSILON)?; let max_idx = prices
.iter()
.position(|&p| (p - max_price).abs() < f64::EPSILON)?;
let first_price = prices.first()?; let first_price = prices.first()?;
let last_price = prices.last()?; let last_price = prices.last()?;
@ -672,10 +734,17 @@ impl AnomalyDetector {
None None
} }
fn detect_flash_loan_impl(pair: &str, data: &MarketDataPoint, detector: &PairDetector, flash_loan_window: i64) -> Option<Anomaly> { fn detect_flash_loan_impl(
pair: &str,
data: &MarketDataPoint,
detector: &PairDetector,
flash_loan_window: i64,
) -> Option<Anomaly> {
// Flash loan signature: huge volume spike + quick price movement + reversal // Flash loan signature: huge volume spike + quick price movement + reversal
let window_start = Utc::now() - Duration::seconds(flash_loan_window); let window_start = Utc::now() - Duration::seconds(flash_loan_window);
let recent: Vec<_> = detector.price_history.iter() let recent: Vec<_> = detector
.price_history
.iter()
.filter(|p| p.timestamp >= window_start) .filter(|p| p.timestamp >= window_start)
.collect(); .collect();
@ -683,11 +752,13 @@ impl AnomalyDetector {
return None; return None;
} }
let volumes: Vec<f64> = recent.iter() let volumes: Vec<f64> = recent
.iter()
.map(|p| p.volume.to_string().parse::<f64>().unwrap_or(0.0)) .map(|p| p.volume.to_string().parse::<f64>().unwrap_or(0.0))
.collect(); .collect();
let prices: Vec<f64> = recent.iter() let prices: Vec<f64> = recent
.iter()
.map(|p| p.price.to_string().parse::<f64>().unwrap_or(0.0)) .map(|p| p.price.to_string().parse::<f64>().unwrap_or(0.0))
.collect(); .collect();
@ -706,7 +777,10 @@ impl AnomalyDetector {
// Big spike and quick reversal // Big spike and quick reversal
if spike > 10.0 && reversal > 8.0 { if spike > 10.0 && reversal > 8.0 {
let mut context = HashMap::new(); let mut context = HashMap::new();
context.insert("volume_spike".to_string(), format!("{:.0}x", max_volume / avg_volume)); context.insert(
"volume_spike".to_string(),
format!("{:.0}x", max_volume / avg_volume),
);
context.insert("price_spike".to_string(), format!("{:.1}%", spike)); context.insert("price_spike".to_string(), format!("{:.1}%", spike));
return Some(Anomaly { return Some(Anomaly {
@ -717,7 +791,8 @@ impl AnomalyDetector {
confidence: 0.65, confidence: 0.65,
description: format!( description: format!(
"Suspected flash loan attack: {}x volume, {:.1}% price spike", "Suspected flash loan attack: {}x volume, {:.1}% price spike",
max_volume / avg_volume, spike max_volume / avg_volume,
spike
), ),
data: AnomalyData { data: AnomalyData {
current_value: data.volume, current_value: data.volume,
@ -734,7 +809,11 @@ impl AnomalyDetector {
None None
} }
fn detect_ml_anomaly_impl(pair: &str, data: &MarketDataPoint, detector: &PairDetector) -> Option<Anomaly> { fn detect_ml_anomaly_impl(
pair: &str,
data: &MarketDataPoint,
detector: &PairDetector,
) -> Option<Anomaly> {
let forest = detector.isolation_forest.as_ref()?; let forest = detector.isolation_forest.as_ref()?;
if detector.price_history.is_empty() { if detector.price_history.is_empty() {
@ -756,7 +835,11 @@ impl AnomalyDetector {
(Some(bid), Some(ask)) => { (Some(bid), Some(ask)) => {
let bid_f = bid.to_string().parse::<f64>().ok()?; let bid_f = bid.to_string().parse::<f64>().ok()?;
let ask_f = ask.to_string().parse::<f64>().ok()?; let ask_f = ask.to_string().parse::<f64>().ok()?;
if ask_f > 0.0 { bid_f / ask_f } else { 1.0 } if ask_f > 0.0 {
bid_f / ask_f
} else {
1.0
}
} }
_ => 1.0, _ => 1.0,
}; };
@ -774,10 +857,7 @@ impl AnomalyDetector {
detected_at: Utc::now(), detected_at: Utc::now(),
severity: score, severity: score,
confidence: 0.6, confidence: 0.6,
description: format!( description: format!("ML model detected anomaly with score {:.3}", score),
"ML model detected anomaly with score {:.3}",
score
),
data: AnomalyData { data: AnomalyData {
current_value: data.price, current_value: data.price,
expected_value: Decimal::from_f64_retain(detector.price_stats.mean)?, expected_value: Decimal::from_f64_retain(detector.price_stats.mean)?,
@ -798,7 +878,8 @@ impl AnomalyDetector {
/// Get recent anomalies for a pair /// Get recent anomalies for a pair
pub fn get_anomalies(&self, pair: &str) -> Vec<Anomaly> { pub fn get_anomalies(&self, pair: &str) -> Vec<Anomaly> {
self.detectors.get(pair) self.detectors
.get(pair)
.map(|d| d.anomalies.clone()) .map(|d| d.anomalies.clone())
.unwrap_or_default() .unwrap_or_default()
} }
@ -807,11 +888,14 @@ impl AnomalyDetector {
pub fn get_stats(&self, pair: &str) -> Option<AnomalyStats> { pub fn get_stats(&self, pair: &str) -> Option<AnomalyStats> {
let detector = self.detectors.get(pair)?; let detector = self.detectors.get(pair)?;
let by_type: HashMap<AnomalyType, usize> = detector.anomalies.iter() let by_type: HashMap<AnomalyType, usize> =
.fold(HashMap::new(), |mut acc, a| { detector
*acc.entry(a.anomaly_type.clone()).or_insert(0) += 1; .anomalies
acc .iter()
}); .fold(HashMap::new(), |mut acc, a| {
*acc.entry(a.anomaly_type.clone()).or_insert(0) += 1;
acc
});
Some(AnomalyStats { Some(AnomalyStats {
total_anomalies: detector.anomalies.len(), total_anomalies: detector.anomalies.len(),
@ -913,7 +997,9 @@ mod tests {
let anomalies = detector.process("SYNOR/USD", outlier); let anomalies = detector.process("SYNOR/USD", outlier);
assert!(!anomalies.is_empty()); assert!(!anomalies.is_empty());
assert!(anomalies.iter().any(|a| a.anomaly_type == AnomalyType::PriceOutlier)); assert!(anomalies
.iter()
.any(|a| a.anomaly_type == AnomalyType::PriceOutlier));
} }
#[test] #[test]
@ -946,6 +1032,8 @@ mod tests {
}; };
let anomalies = detector.process("SYNOR/USD", spike); let anomalies = detector.process("SYNOR/USD", spike);
assert!(anomalies.iter().any(|a| a.anomaly_type == AnomalyType::VolumeSpike)); assert!(anomalies
.iter()
.any(|a| a.anomaly_type == AnomalyType::VolumeSpike));
} }
} }

View file

@ -40,17 +40,11 @@ pub enum TriggerReason {
threshold: SynorDecimal, threshold: SynorDecimal,
}, },
/// Multiple oracle sources disagree /// Multiple oracle sources disagree
OracleDisagreement { OracleDisagreement { spread_percent: Decimal },
spread_percent: Decimal,
},
/// Manual trigger by admin /// Manual trigger by admin
ManualHalt { ManualHalt { reason: String },
reason: String,
},
/// Cascade from related market /// Cascade from related market
CascadeTrigger { CascadeTrigger { source_pair: String },
source_pair: String,
},
} }
/// Circuit breaker event /// Circuit breaker event
@ -98,12 +92,12 @@ pub struct CircuitBreakerConfig {
impl Default for CircuitBreakerConfig { impl Default for CircuitBreakerConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
max_1m_change: Decimal::new(10, 2), // 10% max_1m_change: Decimal::new(10, 2), // 10%
max_5m_change: Decimal::new(20, 2), // 20% max_5m_change: Decimal::new(20, 2), // 20%
max_1h_change: Decimal::new(50, 2), // 50% max_1h_change: Decimal::new(50, 2), // 50%
max_twap_deviation: Decimal::new(30, 2), // 30% max_twap_deviation: Decimal::new(30, 2), // 30%
min_liquidity: Decimal::new(10000, 0), // $10k min_liquidity: Decimal::new(10000, 0), // $10k
max_oracle_spread: Decimal::new(5, 2), // 5% max_oracle_spread: Decimal::new(5, 2), // 5%
cooldown_duration: Duration::minutes(5), cooldown_duration: Duration::minutes(5),
recovery_checks: 3, recovery_checks: 3,
cascade_enabled: true, cascade_enabled: true,
@ -173,7 +167,12 @@ impl PairCircuitBreaker {
// Keep only last 24 hours // Keep only last 24 hours
let cutoff = Utc::now() - Duration::hours(24); let cutoff = Utc::now() - Duration::hours(24);
while self.price_history.front().map(|s| s.timestamp < cutoff).unwrap_or(false) { while self
.price_history
.front()
.map(|s| s.timestamp < cutoff)
.unwrap_or(false)
{
self.price_history.pop_front(); self.price_history.pop_front();
} }
@ -192,7 +191,8 @@ impl PairCircuitBreaker {
fn get_price_at(&self, seconds_ago: i64) -> Option<SynorDecimal> { fn get_price_at(&self, seconds_ago: i64) -> Option<SynorDecimal> {
let target = Utc::now() - Duration::seconds(seconds_ago); let target = Utc::now() - Duration::seconds(seconds_ago);
self.price_history.iter() self.price_history
.iter()
.rev() .rev()
.find(|s| s.timestamp <= target) .find(|s| s.timestamp <= target)
.map(|s| s.price) .map(|s| s.price)
@ -232,7 +232,9 @@ impl CircuitBreakerManager {
price: SynorDecimal, price: SynorDecimal,
liquidity: Option<SynorDecimal>, liquidity: Option<SynorDecimal>,
) -> Result<CircuitState> { ) -> Result<CircuitState> {
let breaker = self.breakers.entry(pair.to_string()) let breaker = self
.breakers
.entry(pair.to_string())
.or_insert_with(PairCircuitBreaker::new); .or_insert_with(PairCircuitBreaker::new);
// Use the convenience method for real-time price recording // Use the convenience method for real-time price recording
@ -256,7 +258,9 @@ impl CircuitBreakerManager {
liquidity: Option<SynorDecimal>, liquidity: Option<SynorDecimal>,
timestamp: DateTime<Utc>, timestamp: DateTime<Utc>,
) -> Result<CircuitState> { ) -> Result<CircuitState> {
let breaker = self.breakers.entry(pair.to_string()) let breaker = self
.breakers
.entry(pair.to_string())
.or_insert_with(PairCircuitBreaker::new); .or_insert_with(PairCircuitBreaker::new);
breaker.record_price_at(price, liquidity, timestamp); breaker.record_price_at(price, liquidity, timestamp);
@ -273,22 +277,26 @@ impl CircuitBreakerManager {
/// Check all trigger conditions /// Check all trigger conditions
fn check_triggers(&mut self, pair: &str) -> Result<()> { fn check_triggers(&mut self, pair: &str) -> Result<()> {
let breaker = self.breakers.get(pair).ok_or_else(|| let breaker = self
EconomicsError::PriceFeedUnavailable(pair.to_string()) .breakers
)?; .get(pair)
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
let current = breaker.current_price().ok_or_else(|| let current = breaker
EconomicsError::PriceFeedUnavailable(pair.to_string()) .current_price()
)?; .ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
// Check 1-minute change // Check 1-minute change
if let Some(price_1m) = breaker.get_price_at(60) { if let Some(price_1m) = breaker.get_price_at(60) {
let change = ((current - price_1m) / price_1m).abs(); let change = ((current - price_1m) / price_1m).abs();
if change > self.config.max_1m_change { if change > self.config.max_1m_change {
return self.trigger_breaker(pair, TriggerReason::RapidPriceChange { return self.trigger_breaker(
change_percent: change * Decimal::ONE_HUNDRED, pair,
window_seconds: 60, TriggerReason::RapidPriceChange {
}); change_percent: change * Decimal::ONE_HUNDRED,
window_seconds: 60,
},
);
} }
} }
@ -296,10 +304,13 @@ impl CircuitBreakerManager {
if let Some(price_5m) = breaker.get_price_at(300) { if let Some(price_5m) = breaker.get_price_at(300) {
let change = ((current - price_5m) / price_5m).abs(); let change = ((current - price_5m) / price_5m).abs();
if change > self.config.max_5m_change { if change > self.config.max_5m_change {
return self.trigger_breaker(pair, TriggerReason::RapidPriceChange { return self.trigger_breaker(
change_percent: change * Decimal::ONE_HUNDRED, pair,
window_seconds: 300, TriggerReason::RapidPriceChange {
}); change_percent: change * Decimal::ONE_HUNDRED,
window_seconds: 300,
},
);
} }
} }
@ -307,10 +318,13 @@ impl CircuitBreakerManager {
if let Some(price_1h) = breaker.get_price_at(3600) { if let Some(price_1h) = breaker.get_price_at(3600) {
let change = ((current - price_1h) / price_1h).abs(); let change = ((current - price_1h) / price_1h).abs();
if change > self.config.max_1h_change { if change > self.config.max_1h_change {
return self.trigger_breaker(pair, TriggerReason::RapidPriceChange { return self.trigger_breaker(
change_percent: change * Decimal::ONE_HUNDRED, pair,
window_seconds: 3600, TriggerReason::RapidPriceChange {
}); change_percent: change * Decimal::ONE_HUNDRED,
window_seconds: 3600,
},
);
} }
} }
@ -318,20 +332,26 @@ impl CircuitBreakerManager {
if let Some(twap) = breaker.twap_24h { if let Some(twap) = breaker.twap_24h {
let deviation = ((current - twap) / twap).abs(); let deviation = ((current - twap) / twap).abs();
if deviation > self.config.max_twap_deviation { if deviation > self.config.max_twap_deviation {
return self.trigger_breaker(pair, TriggerReason::ExcessiveDeviation { return self.trigger_breaker(
deviation_percent: deviation * Decimal::ONE_HUNDRED, pair,
reference_price: twap, TriggerReason::ExcessiveDeviation {
}); deviation_percent: deviation * Decimal::ONE_HUNDRED,
reference_price: twap,
},
);
} }
} }
// Check liquidity // Check liquidity
if let Some(liquidity) = breaker.current_liquidity() { if let Some(liquidity) = breaker.current_liquidity() {
if liquidity < self.config.min_liquidity { if liquidity < self.config.min_liquidity {
return self.trigger_breaker(pair, TriggerReason::LowLiquidity { return self.trigger_breaker(
current: liquidity, pair,
threshold: self.config.min_liquidity, TriggerReason::LowLiquidity {
}); current: liquidity,
threshold: self.config.min_liquidity,
},
);
} }
} }
@ -340,9 +360,10 @@ impl CircuitBreakerManager {
/// Trigger the circuit breaker /// Trigger the circuit breaker
fn trigger_breaker(&mut self, pair: &str, reason: TriggerReason) -> Result<()> { fn trigger_breaker(&mut self, pair: &str, reason: TriggerReason) -> Result<()> {
let breaker = self.breakers.get_mut(pair).ok_or_else(|| let breaker = self
EconomicsError::PriceFeedUnavailable(pair.to_string()) .breakers
)?; .get_mut(pair)
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
let event = CircuitEvent { let event = CircuitEvent {
pair: pair.to_string(), pair: pair.to_string(),
@ -361,7 +382,10 @@ impl CircuitBreakerManager {
// Check cascade triggers // Check cascade triggers
if self.config.cascade_enabled { if self.config.cascade_enabled {
let cascades: Vec<_> = self.config.cascade_pairs.iter() let cascades: Vec<_> = self
.config
.cascade_pairs
.iter()
.filter(|(source, _)| source == pair) .filter(|(source, _)| source == pair)
.map(|(_, target)| target.clone()) .map(|(_, target)| target.clone())
.collect(); .collect();
@ -407,10 +431,15 @@ impl CircuitBreakerManager {
// Get current state first (immutable borrow) // Get current state first (immutable borrow)
let (current_state, triggered_at, trigger_reason) = { let (current_state, triggered_at, trigger_reason) = {
let breaker = self.breakers.get(pair).ok_or_else(|| let breaker = self
EconomicsError::PriceFeedUnavailable(pair.to_string()) .breakers
)?; .get(pair)
(breaker.state, breaker.triggered_at, breaker.trigger_reason.clone()) .ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
(
breaker.state,
breaker.triggered_at,
breaker.trigger_reason.clone(),
)
}; };
// Check stability for half-open state (immutable borrow) // Check stability for half-open state (immutable borrow)
@ -421,9 +450,10 @@ impl CircuitBreakerManager {
}; };
// Now get mutable reference for updates // Now get mutable reference for updates
let breaker = self.breakers.get_mut(pair).ok_or_else(|| let breaker = self
EconomicsError::PriceFeedUnavailable(pair.to_string()) .breakers
)?; .get_mut(pair)
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
match current_state { match current_state {
CircuitState::Open => { CircuitState::Open => {
@ -435,9 +465,9 @@ impl CircuitBreakerManager {
pair: pair.to_string(), pair: pair.to_string(),
from_state: CircuitState::Open, from_state: CircuitState::Open,
to_state: CircuitState::HalfOpen, to_state: CircuitState::HalfOpen,
reason: trigger_reason.clone().unwrap_or( reason: trigger_reason.clone().unwrap_or(TriggerReason::ManualHalt {
TriggerReason::ManualHalt { reason: "Unknown".into() } reason: "Unknown".into(),
), }),
timestamp: Utc::now(), timestamp: Utc::now(),
cooldown: None, cooldown: None,
}; };
@ -457,9 +487,9 @@ impl CircuitBreakerManager {
pair: pair.to_string(), pair: pair.to_string(),
from_state: CircuitState::HalfOpen, from_state: CircuitState::HalfOpen,
to_state: CircuitState::Closed, to_state: CircuitState::Closed,
reason: trigger_reason.unwrap_or( reason: trigger_reason.unwrap_or(TriggerReason::ManualHalt {
TriggerReason::ManualHalt { reason: "Recovery".into() } reason: "Recovery".into(),
), }),
timestamp: Utc::now(), timestamp: Utc::now(),
cooldown: None, cooldown: None,
}; };
@ -482,9 +512,10 @@ impl CircuitBreakerManager {
/// Check if market conditions are stable /// Check if market conditions are stable
fn is_stable(&self, pair: &str) -> Result<bool> { fn is_stable(&self, pair: &str) -> Result<bool> {
let breaker = self.breakers.get(pair).ok_or_else(|| let breaker = self
EconomicsError::PriceFeedUnavailable(pair.to_string()) .breakers
)?; .get(pair)
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
let current = match breaker.current_price() { let current = match breaker.current_price() {
Some(p) => p, Some(p) => p,
@ -511,7 +542,8 @@ impl CircuitBreakerManager {
/// Get current state for a pair /// Get current state for a pair
pub fn get_state(&self, pair: &str) -> CircuitState { pub fn get_state(&self, pair: &str) -> CircuitState {
self.breakers.get(pair) self.breakers
.get(pair)
.map(|b| b.state) .map(|b| b.state)
.unwrap_or(CircuitState::Closed) .unwrap_or(CircuitState::Closed)
} }
@ -523,25 +555,32 @@ impl CircuitBreakerManager {
/// Manually trigger circuit breaker /// Manually trigger circuit breaker
pub fn manual_halt(&mut self, pair: &str, reason: impl Into<String>) -> Result<()> { pub fn manual_halt(&mut self, pair: &str, reason: impl Into<String>) -> Result<()> {
self.breakers.entry(pair.to_string()) self.breakers
.entry(pair.to_string())
.or_insert_with(PairCircuitBreaker::new); .or_insert_with(PairCircuitBreaker::new);
self.trigger_breaker(pair, TriggerReason::ManualHalt { self.trigger_breaker(
reason: reason.into(), pair,
}) TriggerReason::ManualHalt {
reason: reason.into(),
},
)
} }
/// Manually reset circuit breaker /// Manually reset circuit breaker
pub fn manual_reset(&mut self, pair: &str) -> Result<()> { pub fn manual_reset(&mut self, pair: &str) -> Result<()> {
let breaker = self.breakers.get_mut(pair).ok_or_else(|| let breaker = self
EconomicsError::PriceFeedUnavailable(pair.to_string()) .breakers
)?; .get_mut(pair)
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
let event = CircuitEvent { let event = CircuitEvent {
pair: pair.to_string(), pair: pair.to_string(),
from_state: breaker.state, from_state: breaker.state,
to_state: CircuitState::Closed, to_state: CircuitState::Closed,
reason: TriggerReason::ManualHalt { reason: "Manual reset".into() }, reason: TriggerReason::ManualHalt {
reason: "Manual reset".into(),
},
timestamp: Utc::now(), timestamp: Utc::now(),
cooldown: None, cooldown: None,
}; };
@ -558,26 +597,32 @@ impl CircuitBreakerManager {
/// Record oracle disagreement /// Record oracle disagreement
pub fn record_oracle_spread(&mut self, pair: &str, spread: Decimal) -> Result<()> { pub fn record_oracle_spread(&mut self, pair: &str, spread: Decimal) -> Result<()> {
if spread > self.config.max_oracle_spread { if spread > self.config.max_oracle_spread {
self.breakers.entry(pair.to_string()) self.breakers
.entry(pair.to_string())
.or_insert_with(PairCircuitBreaker::new); .or_insert_with(PairCircuitBreaker::new);
self.trigger_breaker(pair, TriggerReason::OracleDisagreement { self.trigger_breaker(
spread_percent: spread * Decimal::ONE_HUNDRED, pair,
})?; TriggerReason::OracleDisagreement {
spread_percent: spread * Decimal::ONE_HUNDRED,
},
)?;
} }
Ok(()) Ok(())
} }
/// Get event history for a pair /// Get event history for a pair
pub fn get_events(&self, pair: &str) -> Vec<CircuitEvent> { pub fn get_events(&self, pair: &str) -> Vec<CircuitEvent> {
self.breakers.get(pair) self.breakers
.get(pair)
.map(|b| b.events.clone()) .map(|b| b.events.clone())
.unwrap_or_default() .unwrap_or_default()
} }
/// Get all currently halted pairs /// Get all currently halted pairs
pub fn get_halted_pairs(&self) -> Vec<(String, CircuitState, Option<TriggerReason>)> { pub fn get_halted_pairs(&self) -> Vec<(String, CircuitState, Option<TriggerReason>)> {
self.breakers.iter() self.breakers
.iter()
.filter(|(_, b)| b.state != CircuitState::Closed) .filter(|(_, b)| b.state != CircuitState::Closed)
.map(|(pair, b)| (pair.clone(), b.state, b.trigger_reason.clone())) .map(|(pair, b)| (pair.clone(), b.state, b.trigger_reason.clone()))
.collect() .collect()
@ -586,8 +631,16 @@ impl CircuitBreakerManager {
/// Get summary statistics /// Get summary statistics
pub fn get_stats(&self) -> CircuitBreakerStats { pub fn get_stats(&self) -> CircuitBreakerStats {
let total = self.breakers.len(); let total = self.breakers.len();
let open = self.breakers.values().filter(|b| b.state == CircuitState::Open).count(); let open = self
let half_open = self.breakers.values().filter(|b| b.state == CircuitState::HalfOpen).count(); .breakers
.values()
.filter(|b| b.state == CircuitState::Open)
.count();
let half_open = self
.breakers
.values()
.filter(|b| b.state == CircuitState::HalfOpen)
.count();
let total_events: usize = self.breakers.values().map(|b| b.events.len()).sum(); let total_events: usize = self.breakers.values().map(|b| b.events.len()).sum();
CircuitBreakerStats { CircuitBreakerStats {
@ -628,7 +681,9 @@ mod tests {
// Normal price movements should not trigger // Normal price movements should not trigger
for i in 0..10 { for i in 0..10 {
let price = dec!(100) + Decimal::from(i); let price = dec!(100) + Decimal::from(i);
let state = manager.record_price("SYNOR/USD", price, Some(dec!(100000))).unwrap(); let state = manager
.record_price("SYNOR/USD", price, Some(dec!(100000)))
.unwrap();
assert_eq!(state, CircuitState::Closed); assert_eq!(state, CircuitState::Closed);
} }
} }
@ -641,10 +696,19 @@ mod tests {
let now = Utc::now(); let now = Utc::now();
// Record baseline 2 minutes ago // Record baseline 2 minutes ago
manager.record_price_at("SYNOR/USD", dec!(100), Some(dec!(100000)), now - Duration::minutes(2)).unwrap(); manager
.record_price_at(
"SYNOR/USD",
dec!(100),
Some(dec!(100000)),
now - Duration::minutes(2),
)
.unwrap();
// Simulate 15% drop (exceeds 10% 1-minute threshold) // Simulate 15% drop (exceeds 10% 1-minute threshold)
let state = manager.record_price_at("SYNOR/USD", dec!(85), Some(dec!(100000)), now).unwrap(); let state = manager
.record_price_at("SYNOR/USD", dec!(85), Some(dec!(100000)), now)
.unwrap();
assert_eq!(state, CircuitState::Open); assert_eq!(state, CircuitState::Open);
} }
@ -653,7 +717,9 @@ mod tests {
let mut manager = CircuitBreakerManager::new(); let mut manager = CircuitBreakerManager::new();
// Record with very low liquidity // Record with very low liquidity
let state = manager.record_price("SYNOR/USD", dec!(100), Some(dec!(100))).unwrap(); let state = manager
.record_price("SYNOR/USD", dec!(100), Some(dec!(100)))
.unwrap();
assert_eq!(state, CircuitState::Open); assert_eq!(state, CircuitState::Open);
} }
@ -661,11 +727,15 @@ mod tests {
fn test_manual_halt_and_reset() { fn test_manual_halt_and_reset() {
let mut manager = CircuitBreakerManager::new(); let mut manager = CircuitBreakerManager::new();
manager.record_price("SYNOR/USD", dec!(100), Some(dec!(100000))).unwrap(); manager
.record_price("SYNOR/USD", dec!(100), Some(dec!(100000)))
.unwrap();
assert!(manager.is_trading_allowed("SYNOR/USD")); assert!(manager.is_trading_allowed("SYNOR/USD"));
// Manual halt // Manual halt
manager.manual_halt("SYNOR/USD", "Scheduled maintenance").unwrap(); manager
.manual_halt("SYNOR/USD", "Scheduled maintenance")
.unwrap();
assert!(!manager.is_trading_allowed("SYNOR/USD")); assert!(!manager.is_trading_allowed("SYNOR/USD"));
// Manual reset // Manual reset
@ -678,10 +748,14 @@ mod tests {
let mut manager = CircuitBreakerManager::new(); let mut manager = CircuitBreakerManager::new();
// Initialize // Initialize
manager.record_price("SYNOR/USD", dec!(100), Some(dec!(100000))).unwrap(); manager
.record_price("SYNOR/USD", dec!(100), Some(dec!(100000)))
.unwrap();
// Record 10% spread (exceeds 5% threshold) // Record 10% spread (exceeds 5% threshold)
manager.record_oracle_spread("SYNOR/USD", dec!(0.10)).unwrap(); manager
.record_oracle_spread("SYNOR/USD", dec!(0.10))
.unwrap();
assert_eq!(manager.get_state("SYNOR/USD"), CircuitState::Open); assert_eq!(manager.get_state("SYNOR/USD"), CircuitState::Open);
} }

View file

@ -200,7 +200,10 @@ pub struct CrossChainConfig {
impl Default for CrossChainConfig { impl Default for CrossChainConfig {
fn default() -> Self { fn default() -> Self {
let mut tracked = HashMap::new(); let mut tracked = HashMap::new();
tracked.insert(ChainNetwork::Ethereum, vec!["ETH".to_string(), "USDC".to_string(), "USDT".to_string()]); tracked.insert(
ChainNetwork::Ethereum,
vec!["ETH".to_string(), "USDC".to_string(), "USDT".to_string()],
);
tracked.insert(ChainNetwork::Bitcoin, vec!["BTC".to_string()]); tracked.insert(ChainNetwork::Bitcoin, vec!["BTC".to_string()]);
tracked.insert(ChainNetwork::Cosmos, vec!["ATOM".to_string()]); tracked.insert(ChainNetwork::Cosmos, vec!["ATOM".to_string()]);
tracked.insert(ChainNetwork::Osmosis, vec!["OSMO".to_string()]); tracked.insert(ChainNetwork::Osmosis, vec!["OSMO".to_string()]);
@ -305,16 +308,21 @@ impl CrossChainOracle {
}; };
if !verified { if !verified {
return Err(EconomicsError::InvalidPrice("Packet verification failed".into())); return Err(EconomicsError::InvalidPrice(
"Packet verification failed".into(),
));
} }
// Cache the price // Cache the price
let pair_key = format!("{}/{}", packet.token, packet.quote); let pair_key = format!("{}/{}", packet.token, packet.quote);
self.cache.insert(pair_key, CrossChainPrice { self.cache.insert(
packet, pair_key,
received_at: Utc::now(), CrossChainPrice {
verified, packet,
}); received_at: Utc::now(),
verified,
},
);
Ok(()) Ok(())
} }
@ -326,7 +334,9 @@ impl CrossChainOracle {
token: &str, token: &str,
quote: &str, quote: &str,
) -> Result<CrossChainPricePacket> { ) -> Result<CrossChainPricePacket> {
let fetcher = self.fetchers.get(&chain) let fetcher = self
.fetchers
.get(&chain)
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(format!("{:?}", chain)))?; .ok_or_else(|| EconomicsError::PriceFeedUnavailable(format!("{:?}", chain)))?;
let packet = fetcher.fetch_price(token, quote).await?; let packet = fetcher.fetch_price(token, quote).await?;
@ -334,11 +344,14 @@ impl CrossChainOracle {
// Verify and cache // Verify and cache
if fetcher.verify_packet(&packet) { if fetcher.verify_packet(&packet) {
let pair_key = format!("{}/{}", token, quote); let pair_key = format!("{}/{}", token, quote);
self.cache.insert(pair_key.clone(), CrossChainPrice { self.cache.insert(
packet: packet.clone(), pair_key.clone(),
received_at: Utc::now(), CrossChainPrice {
verified: true, packet: packet.clone(),
}); received_at: Utc::now(),
verified: true,
},
);
} }
Ok(packet) Ok(packet)
@ -347,7 +360,8 @@ impl CrossChainOracle {
/// Get cached price for a token pair /// Get cached price for a token pair
pub fn get_price(&self, token: &str, quote: &str) -> Option<SynorDecimal> { pub fn get_price(&self, token: &str, quote: &str) -> Option<SynorDecimal> {
let pair_key = format!("{}/{}", token, quote); let pair_key = format!("{}/{}", token, quote);
self.cache.get(&pair_key) self.cache
.get(&pair_key)
.filter(|c| c.verified) .filter(|c| c.verified)
.filter(|c| (Utc::now() - c.received_at).num_seconds() < self.config.max_packet_age) .filter(|c| (Utc::now() - c.received_at).num_seconds() < self.config.max_packet_age)
.map(|c| c.packet.price) .map(|c| c.packet.price)
@ -356,7 +370,8 @@ impl CrossChainOracle {
/// Get price with full packet info /// Get price with full packet info
pub fn get_price_with_info(&self, token: &str, quote: &str) -> Option<&CrossChainPricePacket> { pub fn get_price_with_info(&self, token: &str, quote: &str) -> Option<&CrossChainPricePacket> {
let pair_key = format!("{}/{}", token, quote); let pair_key = format!("{}/{}", token, quote);
self.cache.get(&pair_key) self.cache
.get(&pair_key)
.filter(|c| c.verified) .filter(|c| c.verified)
.map(|c| &c.packet) .map(|c| &c.packet)
} }
@ -408,7 +423,8 @@ impl CrossChainOracle {
/// Get all cached prices /// Get all cached prices
pub fn get_all_prices(&self) -> Vec<TokenPrice> { pub fn get_all_prices(&self) -> Vec<TokenPrice> {
self.cache.values() self.cache
.values()
.filter(|c| c.verified) .filter(|c| c.verified)
.map(|c| c.packet.to_token_price()) .map(|c| c.packet.to_token_price())
.collect() .collect()
@ -417,9 +433,8 @@ impl CrossChainOracle {
/// Clear stale cache entries /// Clear stale cache entries
pub fn cleanup_cache(&mut self) { pub fn cleanup_cache(&mut self) {
let max_age = self.config.max_packet_age; let max_age = self.config.max_packet_age;
self.cache.retain(|_, v| { self.cache
(Utc::now() - v.received_at).num_seconds() < max_age .retain(|_, v| (Utc::now() - v.received_at).num_seconds() < max_age);
});
} }
/// Send an IBC price request and track pending packet /// Send an IBC price request and track pending packet
@ -450,7 +465,11 @@ impl CrossChainOracle {
} }
/// Confirm a pending packet was received /// Confirm a pending packet was received
pub fn confirm_pending_packet(&mut self, channel: &str, sequence: u64) -> Option<PendingPacket> { pub fn confirm_pending_packet(
&mut self,
channel: &str,
sequence: u64,
) -> Option<PendingPacket> {
if let Some(idx) = self if let Some(idx) = self
.pending_packets .pending_packets
.iter() .iter()
@ -469,9 +488,8 @@ impl CrossChainOracle {
/// Cleanup timed out pending packets /// Cleanup timed out pending packets
pub fn cleanup_pending(&mut self, timeout_secs: i64) { pub fn cleanup_pending(&mut self, timeout_secs: i64) {
self.pending_packets.retain(|p| { self.pending_packets
(Utc::now() - p.sent_at).num_seconds() < timeout_secs .retain(|p| (Utc::now() - p.sent_at).num_seconds() < timeout_secs);
});
} }
} }
@ -502,9 +520,18 @@ pub struct EthereumPriceFetcher {
impl EthereumPriceFetcher { impl EthereumPriceFetcher {
pub fn new(rpc_url: impl Into<String>) -> Self { pub fn new(rpc_url: impl Into<String>) -> Self {
let mut feeds = HashMap::new(); let mut feeds = HashMap::new();
feeds.insert("ETH/USD".to_string(), "0x5f4eC3Df9cbd43714FE2740f5E3616155c5b8419".to_string()); feeds.insert(
feeds.insert("BTC/USD".to_string(), "0xF4030086522a5bEEa4988F8cA5B36dbC97BeE88c".to_string()); "ETH/USD".to_string(),
feeds.insert("USDC/USD".to_string(), "0x8fFfFfd4AfB6115b954Bd326cbe7B4BA576818f6".to_string()); "0x5f4eC3Df9cbd43714FE2740f5E3616155c5b8419".to_string(),
);
feeds.insert(
"BTC/USD".to_string(),
"0xF4030086522a5bEEa4988F8cA5B36dbC97BeE88c".to_string(),
);
feeds.insert(
"USDC/USD".to_string(),
"0x8fFfFfd4AfB6115b954Bd326cbe7B4BA576818f6".to_string(),
);
Self { Self {
rpc_url: rpc_url.into(), rpc_url: rpc_url.into(),
@ -526,7 +553,9 @@ impl ChainPriceFetcher for EthereumPriceFetcher {
async fn fetch_price(&self, token: &str, quote: &str) -> Result<CrossChainPricePacket> { async fn fetch_price(&self, token: &str, quote: &str) -> Result<CrossChainPricePacket> {
let pair = format!("{}/{}", token, quote); let pair = format!("{}/{}", token, quote);
let _feed_addr = self.chainlink_feeds.get(&pair) let _feed_addr = self
.chainlink_feeds
.get(&pair)
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.clone()))?; .ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.clone()))?;
// In production: Call Chainlink aggregator via ethers-rs // In production: Call Chainlink aggregator via ethers-rs
@ -539,19 +568,16 @@ impl ChainPriceFetcher for EthereumPriceFetcher {
source_block: 19000000, source_block: 19000000,
source_timestamp: Utc::now(), source_timestamp: Utc::now(),
proof: None, proof: None,
signatures: vec![ signatures: vec![OracleSignature {
OracleSignature { signer: "chainlink".to_string(),
signer: "chainlink".to_string(), signature: vec![0; 65],
signature: vec![0; 65], timestamp: Utc::now(),
timestamp: Utc::now(), }],
},
],
}) })
} }
fn verify_packet(&self, packet: &CrossChainPricePacket) -> bool { fn verify_packet(&self, packet: &CrossChainPricePacket) -> bool {
packet.source_chain == ChainNetwork::Ethereum packet.source_chain == ChainNetwork::Ethereum && !packet.signatures.is_empty()
&& !packet.signatures.is_empty()
} }
fn supported_tokens(&self) -> Vec<String> { fn supported_tokens(&self) -> Vec<String> {
@ -611,8 +637,7 @@ impl ChainPriceFetcher for CosmosPriceFetcher {
} }
fn verify_packet(&self, packet: &CrossChainPricePacket) -> bool { fn verify_packet(&self, packet: &CrossChainPricePacket) -> bool {
packet.source_chain == ChainNetwork::Cosmos packet.source_chain == ChainNetwork::Cosmos && packet.proof.is_some()
&& packet.proof.is_some()
} }
fn supported_tokens(&self) -> Vec<String> { fn supported_tokens(&self) -> Vec<String> {
@ -650,7 +675,11 @@ impl CrossChainOracleBuilder {
} }
/// Add Cosmos/IBC price fetcher /// Add Cosmos/IBC price fetcher
pub fn with_cosmos(mut self, light_client_id: impl Into<String>, chain_id: impl Into<String>) -> Self { pub fn with_cosmos(
mut self,
light_client_id: impl Into<String>,
chain_id: impl Into<String>,
) -> Self {
self.cosmos_light_client = Some((light_client_id.into(), chain_id.into())); self.cosmos_light_client = Some((light_client_id.into(), chain_id.into()));
self self
} }
@ -693,8 +722,7 @@ impl CrossChainOracleFactory {
/// Create a production oracle with real endpoints /// Create a production oracle with real endpoints
pub fn production(config: CrossChainProductionConfig) -> CrossChainOracle { pub fn production(config: CrossChainProductionConfig) -> CrossChainOracle {
let mut builder = CrossChainOracleBuilder::new() let mut builder = CrossChainOracleBuilder::new().with_config(config.cross_chain_config);
.with_config(config.cross_chain_config);
if let Some(eth_rpc) = config.ethereum_rpc_url { if let Some(eth_rpc) = config.ethereum_rpc_url {
builder = builder.with_ethereum(eth_rpc); builder = builder.with_ethereum(eth_rpc);
@ -715,7 +743,10 @@ impl CrossChainOracleFactory {
} }
/// Create an oracle with only Cosmos/IBC support /// Create an oracle with only Cosmos/IBC support
pub fn cosmos_only(light_client_id: impl Into<String>, chain_id: impl Into<String>) -> CrossChainOracle { pub fn cosmos_only(
light_client_id: impl Into<String>,
chain_id: impl Into<String>,
) -> CrossChainOracle {
CrossChainOracleBuilder::new() CrossChainOracleBuilder::new()
.with_cosmos(light_client_id, chain_id) .with_cosmos(light_client_id, chain_id)
.build() .build()

View file

@ -112,7 +112,9 @@ impl AggregationRound {
/// Add a submission to this round /// Add a submission to this round
pub fn add_submission(&mut self, submission: PriceSubmission) -> Result<()> { pub fn add_submission(&mut self, submission: PriceSubmission) -> Result<()> {
if self.finalized { if self.finalized {
return Err(EconomicsError::InvalidPrice("Round already finalized".into())); return Err(EconomicsError::InvalidPrice(
"Round already finalized".into(),
));
} }
if Utc::now() >= self.deadline { if Utc::now() >= self.deadline {
return Err(EconomicsError::InvalidPrice("Round deadline passed".into())); return Err(EconomicsError::InvalidPrice("Round deadline passed".into()));
@ -122,7 +124,11 @@ impl AggregationRound {
} }
// Check for duplicate submission from same node // Check for duplicate submission from same node
if self.submissions.iter().any(|s| s.node_id == submission.node_id) { if self
.submissions
.iter()
.any(|s| s.node_id == submission.node_id)
{
return Err(EconomicsError::InvalidPrice("Duplicate submission".into())); return Err(EconomicsError::InvalidPrice("Duplicate submission".into()));
} }
@ -231,7 +237,9 @@ impl DecentralizedOracle {
/// Update node heartbeat /// Update node heartbeat
pub fn heartbeat(&mut self, node_id: &str) -> Result<()> { pub fn heartbeat(&mut self, node_id: &str) -> Result<()> {
let node = self.nodes.get_mut(node_id) let node = self
.nodes
.get_mut(node_id)
.ok_or_else(|| EconomicsError::InvalidPrice(format!("Unknown node: {}", node_id)))?; .ok_or_else(|| EconomicsError::InvalidPrice(format!("Unknown node: {}", node_id)))?;
node.last_heartbeat = Utc::now(); node.last_heartbeat = Utc::now();
Ok(()) Ok(())
@ -241,11 +249,8 @@ impl DecentralizedOracle {
pub fn start_round(&mut self, pair: impl Into<String>) -> u64 { pub fn start_round(&mut self, pair: impl Into<String>) -> u64 {
let pair = pair.into(); let pair = pair.into();
self.round_counter += 1; self.round_counter += 1;
let round = AggregationRound::new( let round =
self.round_counter, AggregationRound::new(self.round_counter, pair.clone(), self.config.round_duration);
pair.clone(),
self.config.round_duration
);
self.current_rounds.insert(pair, round); self.current_rounds.insert(pair, round);
self.round_counter self.round_counter
} }
@ -253,7 +258,9 @@ impl DecentralizedOracle {
/// Submit a price for the current round /// Submit a price for the current round
pub fn submit_price(&mut self, submission: PriceSubmission) -> Result<()> { pub fn submit_price(&mut self, submission: PriceSubmission) -> Result<()> {
// Verify node exists and is eligible // Verify node exists and is eligible
let node = self.nodes.get(&submission.node_id) let node = self
.nodes
.get(&submission.node_id)
.ok_or_else(|| EconomicsError::InvalidPrice("Unknown node".into()))?; .ok_or_else(|| EconomicsError::InvalidPrice("Unknown node".into()))?;
if !node.is_eligible(self.config.min_stake, self.config.min_reputation) { if !node.is_eligible(self.config.min_stake, self.config.min_reputation) {
@ -264,7 +271,9 @@ impl DecentralizedOracle {
// For now, we trust the submission // For now, we trust the submission
// Add to current round // Add to current round
let round = self.current_rounds.get_mut(&submission.pair) let round = self
.current_rounds
.get_mut(&submission.pair)
.ok_or_else(|| EconomicsError::InvalidPrice("No active round for pair".into()))?; .ok_or_else(|| EconomicsError::InvalidPrice("No active round for pair".into()))?;
round.add_submission(submission) round.add_submission(submission)
@ -274,13 +283,15 @@ impl DecentralizedOracle {
pub fn finalize_round(&mut self, pair: &str) -> Result<SynorDecimal> { pub fn finalize_round(&mut self, pair: &str) -> Result<SynorDecimal> {
// First check state and get submissions (immutable borrow) // First check state and get submissions (immutable borrow)
let (is_finalized, existing_price, submissions) = { let (is_finalized, existing_price, submissions) = {
let round = self.current_rounds.get(pair) let round = self
.current_rounds
.get(pair)
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?; .ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
if round.finalized { if round.finalized {
return round.final_price.ok_or_else(|| return round
EconomicsError::InvalidPrice("Round has no price".into()) .final_price
); .ok_or_else(|| EconomicsError::InvalidPrice("Round has no price".into()));
} }
// Check minimum submissions // Check minimum submissions
@ -292,13 +303,16 @@ impl DecentralizedOracle {
))); )));
} }
(round.finalized, round.final_price, round.submissions.clone()) (
round.finalized,
round.final_price,
round.submissions.clone(),
)
}; };
if is_finalized { if is_finalized {
return existing_price.ok_or_else(|| return existing_price
EconomicsError::InvalidPrice("Round has no price".into()) .ok_or_else(|| EconomicsError::InvalidPrice("Round has no price".into()));
);
} }
// Filter outliers and aggregate (using cloned submissions) // Filter outliers and aggregate (using cloned submissions)
@ -326,7 +340,11 @@ impl DecentralizedOracle {
} }
/// Aggregate prices from a vector of submissions (owned) /// Aggregate prices from a vector of submissions (owned)
fn aggregate_prices_from_vec(&self, pair: &str, submissions: &[PriceSubmission]) -> Result<SynorDecimal> { fn aggregate_prices_from_vec(
&self,
pair: &str,
submissions: &[PriceSubmission],
) -> Result<SynorDecimal> {
if submissions.is_empty() { if submissions.is_empty() {
return Err(EconomicsError::PriceFeedUnavailable(pair.to_string())); return Err(EconomicsError::PriceFeedUnavailable(pair.to_string()));
} }
@ -334,15 +352,21 @@ impl DecentralizedOracle {
// Filter outliers first // Filter outliers first
let filtered = self.filter_outliers_vec(submissions); let filtered = self.filter_outliers_vec(submissions);
if filtered.is_empty() { if filtered.is_empty() {
return Err(EconomicsError::InvalidPrice("All submissions were outliers".into())); return Err(EconomicsError::InvalidPrice(
"All submissions were outliers".into(),
));
} }
let filtered_refs: Vec<_> = filtered.iter().collect(); let filtered_refs: Vec<_> = filtered.iter().collect();
match self.strategy { match self.strategy {
AggregationStrategy::Median => self.calculate_median(&filtered_refs), AggregationStrategy::Median => self.calculate_median(&filtered_refs),
AggregationStrategy::StakeWeightedMedian => self.calculate_stake_weighted_median(&filtered_refs), AggregationStrategy::StakeWeightedMedian => {
self.calculate_stake_weighted_median(&filtered_refs)
}
AggregationStrategy::TrimmedMean => self.calculate_trimmed_mean(&filtered_refs), AggregationStrategy::TrimmedMean => self.calculate_trimmed_mean(&filtered_refs),
AggregationStrategy::ReputationWeighted => self.calculate_reputation_weighted(&filtered_refs), AggregationStrategy::ReputationWeighted => {
self.calculate_reputation_weighted(&filtered_refs)
}
} }
} }
@ -376,13 +400,14 @@ impl DecentralizedOracle {
} }
/// Calculate stake-weighted median /// Calculate stake-weighted median
fn calculate_stake_weighted_median(&self, submissions: &[&PriceSubmission]) -> Result<SynorDecimal> { fn calculate_stake_weighted_median(
&self,
submissions: &[&PriceSubmission],
) -> Result<SynorDecimal> {
// Get stake for each submission // Get stake for each submission
let mut weighted: Vec<(SynorDecimal, SynorDecimal)> = submissions let mut weighted: Vec<(SynorDecimal, SynorDecimal)> = submissions
.iter() .iter()
.filter_map(|s| { .filter_map(|s| self.nodes.get(&s.node_id).map(|n| (s.price, n.stake)))
self.nodes.get(&s.node_id).map(|n| (s.price, n.stake))
})
.collect(); .collect();
if weighted.is_empty() { if weighted.is_empty() {
@ -424,12 +449,17 @@ impl DecentralizedOracle {
} }
/// Calculate reputation-weighted average /// Calculate reputation-weighted average
fn calculate_reputation_weighted(&self, submissions: &[&PriceSubmission]) -> Result<SynorDecimal> { fn calculate_reputation_weighted(
&self,
submissions: &[&PriceSubmission],
) -> Result<SynorDecimal> {
let mut weighted_sum = Decimal::ZERO; let mut weighted_sum = Decimal::ZERO;
let mut total_weight = Decimal::ZERO; let mut total_weight = Decimal::ZERO;
for sub in submissions { for sub in submissions {
let reputation = self.nodes.get(&sub.node_id) let reputation = self
.nodes
.get(&sub.node_id)
.map(|n| n.reputation) .map(|n| n.reputation)
.unwrap_or(0.5); .unwrap_or(0.5);
@ -448,7 +478,9 @@ impl DecentralizedOracle {
/// Update node reputations based on submission accuracy /// Update node reputations based on submission accuracy
fn update_reputations(&mut self, _pair: &str, final_price: SynorDecimal) { fn update_reputations(&mut self, _pair: &str, final_price: SynorDecimal) {
// Get submissions from current round before it was moved // Get submissions from current round before it was moved
let submissions: Vec<_> = self.history.last() let submissions: Vec<_> = self
.history
.last()
.map(|r| r.submissions.clone()) .map(|r| r.submissions.clone())
.unwrap_or_default(); .unwrap_or_default();
@ -457,7 +489,8 @@ impl DecentralizedOracle {
let deviation = (sub.price - final_price).abs() / final_price; let deviation = (sub.price - final_price).abs() / final_price;
// Increase reputation for accurate submissions, decrease for inaccurate // Increase reputation for accurate submissions, decrease for inaccurate
if deviation <= Decimal::new(1, 2) { // Within 1% if deviation <= Decimal::new(1, 2) {
// Within 1%
node.reputation = (node.reputation + 0.01).min(1.0); node.reputation = (node.reputation + 0.01).min(1.0);
} else if deviation > self.config.max_deviation { } else if deviation > self.config.max_deviation {
node.reputation = (node.reputation - 0.05).max(0.0); node.reputation = (node.reputation - 0.05).max(0.0);
@ -473,7 +506,8 @@ impl DecentralizedOracle {
/// Get number of active nodes /// Get number of active nodes
pub fn active_node_count(&self) -> usize { pub fn active_node_count(&self) -> usize {
self.nodes.values() self.nodes
.values()
.filter(|n| n.is_eligible(self.config.min_stake, self.config.min_reputation)) .filter(|n| n.is_eligible(self.config.min_stake, self.config.min_reputation))
.count() .count()
} }
@ -492,20 +526,23 @@ impl DecentralizedOracle {
/// Convert finalized price to TokenPrice /// Convert finalized price to TokenPrice
pub fn to_token_price(&self, pair: &str) -> Option<TokenPrice> { pub fn to_token_price(&self, pair: &str) -> Option<TokenPrice> {
self.history.iter() self.history
.iter()
.rev() .rev()
.find(|r| r.pair == pair && r.finalized) .find(|r| r.pair == pair && r.finalized)
.and_then(|r| r.final_price.map(|price| { .and_then(|r| {
let parts: Vec<_> = pair.split('/').collect(); r.final_price.map(|price| {
TokenPrice { let parts: Vec<_> = pair.split('/').collect();
token: parts.get(0).unwrap_or(&"").to_string(), TokenPrice {
quote: parts.get(1).unwrap_or(&"").to_string(), token: parts.get(0).unwrap_or(&"").to_string(),
price, quote: parts.get(1).unwrap_or(&"").to_string(),
timestamp: r.deadline, price,
source: PriceSource::Aggregated, timestamp: r.deadline,
confidence: 1.0, source: PriceSource::Aggregated,
} confidence: 1.0,
})) }
})
})
} }
} }
@ -521,13 +558,15 @@ mod tests {
use rust_decimal_macros::dec; use rust_decimal_macros::dec;
fn create_test_nodes() -> Vec<OracleNode> { fn create_test_nodes() -> Vec<OracleNode> {
(0..5).map(|i| { (0..5)
OracleNode::new( .map(|i| {
format!("node_{}", i), OracleNode::new(
vec![i as u8; 32], format!("node_{}", i),
dec!(10000), // 10k stake vec![i as u8; 32],
) dec!(10000), // 10k stake
}).collect() )
})
.collect()
} }
#[test] #[test]

View file

@ -6,8 +6,8 @@
use crate::error::{EconomicsError, Result}; use crate::error::{EconomicsError, Result};
use crate::SynorDecimal; use crate::SynorDecimal;
use chrono::{DateTime, Duration, Timelike, Utc}; use chrono::{DateTime, Duration, Timelike, Utc};
use rust_decimal::Decimal;
use rust_decimal::prelude::ToPrimitive; use rust_decimal::prelude::ToPrimitive;
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::f64::consts::PI; use std::f64::consts::PI;
@ -178,7 +178,12 @@ impl BlackScholes {
} }
/// Price a European option /// Price a European option
pub fn price(&self, contract: &OptionContract, spot: SynorDecimal, vol: f64) -> Result<OptionPricing> { pub fn price(
&self,
contract: &OptionContract,
spot: SynorDecimal,
vol: f64,
) -> Result<OptionPricing> {
if contract.is_expired() { if contract.is_expired() {
// At expiration, option is worth intrinsic value // At expiration, option is worth intrinsic value
let intrinsic = contract.intrinsic_value(spot); let intrinsic = contract.intrinsic_value(spot);
@ -208,12 +213,13 @@ impl BlackScholes {
}); });
} }
let s = spot.to_f64().ok_or_else(|| let s = spot
EconomicsError::InvalidPrice("Invalid spot price".into()) .to_f64()
)?; .ok_or_else(|| EconomicsError::InvalidPrice("Invalid spot price".into()))?;
let k = contract.strike.to_f64().ok_or_else(|| let k = contract
EconomicsError::InvalidPrice("Invalid strike price".into()) .strike
)?; .to_f64()
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid strike price".into()))?;
let t = contract.time_to_expiry(); let t = contract.time_to_expiry();
if t <= 0.0 || vol <= 0.0 { if t <= 0.0 || vol <= 0.0 {
@ -286,13 +292,10 @@ impl BlackScholes {
let theta_common = -(s * vol * (-q * t).exp() * n_prime_d1) / (2.0 * sqrt_t); let theta_common = -(s * vol * (-q * t).exp() * n_prime_d1) / (2.0 * sqrt_t);
let theta = match contract.option_type { let theta = match contract.option_type {
OptionType::Call => { OptionType::Call => {
theta_common theta_common + q * s * (-q * t).exp() * n_d1 - r * k * (-r * t).exp() * n_d2
+ q * s * (-q * t).exp() * n_d1
- r * k * (-r * t).exp() * n_d2
} }
OptionType::Put => { OptionType::Put => {
theta_common theta_common - q * s * (-q * t).exp() * (1.0 - n_d1)
- q * s * (-q * t).exp() * (1.0 - n_d1)
+ r * k * (-r * t).exp() * (1.0 - n_d2) + r * k * (-r * t).exp() * (1.0 - n_d2)
} }
} / 365.0; // Per day } / 365.0; // Per day
@ -327,16 +330,16 @@ impl BlackScholes {
spot: SynorDecimal, spot: SynorDecimal,
market_price: SynorDecimal, market_price: SynorDecimal,
) -> Result<f64> { ) -> Result<f64> {
let target = market_price.to_f64().ok_or_else(|| let target = market_price
EconomicsError::InvalidPrice("Invalid market price".into()) .to_f64()
)?; .ok_or_else(|| EconomicsError::InvalidPrice("Invalid market price".into()))?;
// Initial guess based on time value // Initial guess based on time value
let intrinsic = contract.intrinsic_value(spot).to_f64().unwrap_or(0.0); let intrinsic = contract.intrinsic_value(spot).to_f64().unwrap_or(0.0);
let time_value = (target - intrinsic).max(0.0); let time_value = (target - intrinsic).max(0.0);
let s = spot.to_f64().ok_or_else(|| let s = spot
EconomicsError::InvalidPrice("Invalid spot".into()) .to_f64()
)?; .ok_or_else(|| EconomicsError::InvalidPrice("Invalid spot".into()))?;
let t = contract.time_to_expiry(); let t = contract.time_to_expiry();
// Brenner-Subrahmanyam approximation for initial guess // Brenner-Subrahmanyam approximation for initial guess
@ -381,7 +384,11 @@ impl BlackScholes {
for _ in 0..100 { for _ in 0..100 {
let mid = (low + high) / 2.0; let mid = (low + high) / 2.0;
let price = self.price(contract, spot, mid)?.price.to_f64().unwrap_or(0.0); let price = self
.price(contract, spot, mid)?
.price
.to_f64()
.unwrap_or(0.0);
if (price - target).abs() < 0.0001 { if (price - target).abs() < 0.0001 {
return Ok(mid); return Ok(mid);
@ -493,9 +500,9 @@ impl FuturesModel {
/// F = S * e^((r + u - y) * T) /// F = S * e^((r + u - y) * T)
/// where r = risk-free rate, u = storage cost, y = convenience yield /// where r = risk-free rate, u = storage cost, y = convenience yield
pub fn price(&self, contract: &FuturesContract, spot: SynorDecimal) -> Result<FuturesPricing> { pub fn price(&self, contract: &FuturesContract, spot: SynorDecimal) -> Result<FuturesPricing> {
let s = spot.to_f64().ok_or_else(|| let s = spot
EconomicsError::InvalidPrice("Invalid spot price".into()) .to_f64()
)?; .ok_or_else(|| EconomicsError::InvalidPrice("Invalid spot price".into()))?;
let t = contract.time_to_expiry(); let t = contract.time_to_expiry();
if t < 0.0 { if t < 0.0 {
@ -515,8 +522,7 @@ impl FuturesModel {
0.0 0.0
}; };
let coc = Decimal::from_f64_retain(cost_of_carry * t * s) let coc = Decimal::from_f64_retain(cost_of_carry * t * s).unwrap_or(Decimal::ZERO);
.unwrap_or(Decimal::ZERO);
Ok(FuturesPricing { Ok(FuturesPricing {
fair_value, fair_value,
@ -529,13 +535,18 @@ impl FuturesModel {
/// Calculate implied repo rate from futures price /// Calculate implied repo rate from futures price
/// R = (F/S - 1) / T /// R = (F/S - 1) / T
pub fn implied_repo_rate(&self, contract: &FuturesContract, spot: SynorDecimal, futures_price: SynorDecimal) -> Result<f64> { pub fn implied_repo_rate(
let s = spot.to_f64().ok_or_else(|| &self,
EconomicsError::InvalidPrice("Invalid spot".into()) contract: &FuturesContract,
)?; spot: SynorDecimal,
let f = futures_price.to_f64().ok_or_else(|| futures_price: SynorDecimal,
EconomicsError::InvalidPrice("Invalid futures price".into()) ) -> Result<f64> {
)?; let s = spot
.to_f64()
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid spot".into()))?;
let f = futures_price
.to_f64()
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid futures price".into()))?;
let t = contract.time_to_expiry(); let t = contract.time_to_expiry();
if t <= 0.0 || s <= 0.0 { if t <= 0.0 || s <= 0.0 {
@ -595,12 +606,12 @@ impl PerpetualModel {
mark_price: SynorDecimal, mark_price: SynorDecimal,
index_price: SynorDecimal, index_price: SynorDecimal,
) -> Result<f64> { ) -> Result<f64> {
let mark = mark_price.to_f64().ok_or_else(|| let mark = mark_price
EconomicsError::InvalidPrice("Invalid mark price".into()) .to_f64()
)?; .ok_or_else(|| EconomicsError::InvalidPrice("Invalid mark price".into()))?;
let index = index_price.to_f64().ok_or_else(|| let index = index_price
EconomicsError::InvalidPrice("Invalid index price".into()) .to_f64()
)?; .ok_or_else(|| EconomicsError::InvalidPrice("Invalid index price".into()))?;
if index <= 0.0 { if index <= 0.0 {
return Err(EconomicsError::InvalidPrice("Invalid index".into())); return Err(EconomicsError::InvalidPrice("Invalid index".into()));
@ -640,7 +651,8 @@ impl PerpetualModel {
let hours_since_midnight = now.time().hour(); let hours_since_midnight = now.time().hour();
let next_funding_hour = ((hours_since_midnight / self.funding_interval_hours) + 1) let next_funding_hour = ((hours_since_midnight / self.funding_interval_hours) + 1)
* self.funding_interval_hours; * self.funding_interval_hours;
let next_funding = now.date_naive() let next_funding = now
.date_naive()
.and_hms_opt(next_funding_hour % 24, 0, 0) .and_hms_opt(next_funding_hour % 24, 0, 0)
.map(|dt| DateTime::from_naive_utc_and_offset(dt, Utc)) .map(|dt| DateTime::from_naive_utc_and_offset(dt, Utc))
.unwrap_or(now + Duration::hours(self.funding_interval_hours as i64)); .unwrap_or(now + Duration::hours(self.funding_interval_hours as i64));
@ -721,7 +733,8 @@ impl DerivativesOracle {
/// Set volatility surface for an underlying /// Set volatility surface for an underlying
pub fn set_vol_surface(&mut self, surface: VolatilitySurface) { pub fn set_vol_surface(&mut self, surface: VolatilitySurface) {
self.vol_surfaces.insert(surface.underlying.clone(), surface); self.vol_surfaces
.insert(surface.underlying.clone(), surface);
} }
/// Price an option using the volatility surface /// Price an option using the volatility surface
@ -730,12 +743,13 @@ impl DerivativesOracle {
contract: &OptionContract, contract: &OptionContract,
spot: SynorDecimal, spot: SynorDecimal,
) -> Result<OptionPricing> { ) -> Result<OptionPricing> {
let s = spot.to_f64().ok_or_else(|| let s = spot
EconomicsError::InvalidPrice("Invalid spot".into()) .to_f64()
)?; .ok_or_else(|| EconomicsError::InvalidPrice("Invalid spot".into()))?;
let k = contract.strike.to_f64().ok_or_else(|| let k = contract
EconomicsError::InvalidPrice("Invalid strike".into()) .strike
)?; .to_f64()
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid strike".into()))?;
// Get volatility from surface or use default // Get volatility from surface or use default
let vol = if let Some(surface) = self.vol_surfaces.get(&contract.underlying) { let vol = if let Some(surface) = self.vol_surfaces.get(&contract.underlying) {
@ -765,7 +779,8 @@ impl DerivativesOracle {
mark_price: SynorDecimal, mark_price: SynorDecimal,
open_interest: SynorDecimal, open_interest: SynorDecimal,
) -> Result<PerpetualPricing> { ) -> Result<PerpetualPricing> {
self.perpetual_model.price(index_price, mark_price, open_interest) self.perpetual_model
.price(index_price, mark_price, open_interest)
} }
/// Calculate implied vol from market price /// Calculate implied vol from market price
@ -775,7 +790,8 @@ impl DerivativesOracle {
spot: SynorDecimal, spot: SynorDecimal,
market_price: SynorDecimal, market_price: SynorDecimal,
) -> Result<f64> { ) -> Result<f64> {
self.options_model.implied_volatility(contract, spot, market_price) self.options_model
.implied_volatility(contract, spot, market_price)
} }
} }
@ -792,8 +808,18 @@ mod tests {
#[test] #[test]
fn test_option_intrinsic_value() { fn test_option_intrinsic_value() {
let call = OptionContract::new("ETH", dec!(2000), Utc::now() + Duration::days(30), OptionType::Call); let call = OptionContract::new(
let put = OptionContract::new("ETH", dec!(2000), Utc::now() + Duration::days(30), OptionType::Put); "ETH",
dec!(2000),
Utc::now() + Duration::days(30),
OptionType::Call,
);
let put = OptionContract::new(
"ETH",
dec!(2000),
Utc::now() + Duration::days(30),
OptionType::Put,
);
// ITM call // ITM call
assert_eq!(call.intrinsic_value(dec!(2100)), dec!(100)); assert_eq!(call.intrinsic_value(dec!(2100)), dec!(100));
@ -880,7 +906,9 @@ mod tests {
let pricing = model.price(&contract, dec!(2000), vol).unwrap(); let pricing = model.price(&contract, dec!(2000), vol).unwrap();
// Calculate IV from that price // Calculate IV from that price
let iv = model.implied_volatility(&contract, dec!(2000), pricing.price).unwrap(); let iv = model
.implied_volatility(&contract, dec!(2000), pricing.price)
.unwrap();
// Should match original vol // Should match original vol
assert!((iv - vol).abs() < 0.01); assert!((iv - vol).abs() < 0.01);

View file

@ -40,12 +40,12 @@ impl CollateralAsset {
pub fn standard(symbol: impl Into<String>) -> Self { pub fn standard(symbol: impl Into<String>) -> Self {
Self { Self {
symbol: symbol.into(), symbol: symbol.into(),
collateral_factor: Decimal::new(75, 2), // 75% collateral_factor: Decimal::new(75, 2), // 75%
liquidation_threshold: Decimal::new(80, 2), // 80% liquidation_threshold: Decimal::new(80, 2), // 80%
liquidation_bonus: Decimal::new(5, 2), // 5% liquidation_bonus: Decimal::new(5, 2), // 5%
supply_cap: None, supply_cap: None,
borrow_enabled: true, borrow_enabled: true,
reserve_factor: Decimal::new(10, 2), // 10% reserve_factor: Decimal::new(10, 2), // 10%
volatility_multiplier: Decimal::ONE, volatility_multiplier: Decimal::ONE,
} }
} }
@ -54,12 +54,12 @@ impl CollateralAsset {
pub fn stablecoin(symbol: impl Into<String>) -> Self { pub fn stablecoin(symbol: impl Into<String>) -> Self {
Self { Self {
symbol: symbol.into(), symbol: symbol.into(),
collateral_factor: Decimal::new(90, 2), // 90% collateral_factor: Decimal::new(90, 2), // 90%
liquidation_threshold: Decimal::new(95, 2), // 95% liquidation_threshold: Decimal::new(95, 2), // 95%
liquidation_bonus: Decimal::new(2, 2), // 2% liquidation_bonus: Decimal::new(2, 2), // 2%
supply_cap: None, supply_cap: None,
borrow_enabled: true, borrow_enabled: true,
reserve_factor: Decimal::new(5, 2), // 5% reserve_factor: Decimal::new(5, 2), // 5%
volatility_multiplier: Decimal::ONE, volatility_multiplier: Decimal::ONE,
} }
} }
@ -68,13 +68,13 @@ impl CollateralAsset {
pub fn volatile(symbol: impl Into<String>) -> Self { pub fn volatile(symbol: impl Into<String>) -> Self {
Self { Self {
symbol: symbol.into(), symbol: symbol.into(),
collateral_factor: Decimal::new(50, 2), // 50% collateral_factor: Decimal::new(50, 2), // 50%
liquidation_threshold: Decimal::new(65, 2), // 65% liquidation_threshold: Decimal::new(65, 2), // 65%
liquidation_bonus: Decimal::new(10, 2), // 10% liquidation_bonus: Decimal::new(10, 2), // 10%
supply_cap: None, supply_cap: None,
borrow_enabled: true, borrow_enabled: true,
reserve_factor: Decimal::new(20, 2), // 20% reserve_factor: Decimal::new(20, 2), // 20%
volatility_multiplier: Decimal::new(12, 1), // 1.2x volatility_multiplier: Decimal::new(12, 1), // 1.2x
} }
} }
} }
@ -131,7 +131,11 @@ impl LendingPosition {
/// Withdraw collateral /// Withdraw collateral
pub fn withdraw(&mut self, asset: impl Into<String>, amount: SynorDecimal) -> Result<()> { pub fn withdraw(&mut self, asset: impl Into<String>, amount: SynorDecimal) -> Result<()> {
let asset = asset.into(); let asset = asset.into();
let current = self.collateral.get(&asset).copied().unwrap_or(Decimal::ZERO); let current = self
.collateral
.get(&asset)
.copied()
.unwrap_or(Decimal::ZERO);
if amount > current { if amount > current {
return Err(EconomicsError::InsufficientFunds { return Err(EconomicsError::InsufficientFunds {
required: amount, required: amount,
@ -233,10 +237,10 @@ pub struct LiquidationOracleConfig {
impl Default for LiquidationOracleConfig { impl Default for LiquidationOracleConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
max_price_age: 60, // 1 minute (stricter than general oracle) max_price_age: 60, // 1 minute (stricter than general oracle)
min_confidence: 0.9, min_confidence: 0.9,
min_sources: 2, min_sources: 2,
liquidation_grace_period: 300, // 5 minutes liquidation_grace_period: 300, // 5 minutes
min_liquidation_amount: Decimal::new(10, 0), // $10 min_liquidation_amount: Decimal::new(10, 0), // $10
max_liquidation_pct: Decimal::new(50, 2), // 50% at a time max_liquidation_pct: Decimal::new(50, 2), // 50% at a time
partial_liquidation: true, partial_liquidation: true,
@ -289,7 +293,8 @@ impl LiquidationOracle {
/// Create a new position /// Create a new position
pub fn create_position(&mut self, account_id: impl Into<String>) -> &mut LendingPosition { pub fn create_position(&mut self, account_id: impl Into<String>) -> &mut LendingPosition {
let account_id = account_id.into(); let account_id = account_id.into();
self.positions.entry(account_id.clone()) self.positions
.entry(account_id.clone())
.or_insert_with(|| LendingPosition::new(account_id)) .or_insert_with(|| LendingPosition::new(account_id))
} }
@ -320,7 +325,8 @@ impl LiquidationOracle {
freshness_remaining: self.config.max_price_age - age, freshness_remaining: self.config.max_price_age - age,
}; };
self.price_cache.insert(asset.to_string(), liq_price.clone()); self.price_cache
.insert(asset.to_string(), liq_price.clone());
Ok(liq_price) Ok(liq_price)
} }
@ -328,7 +334,9 @@ impl LiquidationOracle {
pub fn calculate_health(&mut self, account_id: &str) -> Result<HealthStatus> { pub fn calculate_health(&mut self, account_id: &str) -> Result<HealthStatus> {
// Clone position data to avoid borrow conflicts with get_liquidation_price // Clone position data to avoid borrow conflicts with get_liquidation_price
let (collateral, borrows, interest_owed) = { let (collateral, borrows, interest_owed) = {
let position = self.positions.get(account_id) let position = self
.positions
.get(account_id)
.ok_or_else(|| EconomicsError::AccountNotFound(account_id.to_string()))?; .ok_or_else(|| EconomicsError::AccountNotFound(account_id.to_string()))?;
( (
position.collateral.clone(), position.collateral.clone(),
@ -352,7 +360,9 @@ impl LiquidationOracle {
}); });
} }
let asset_config = self.assets.get(asset) let asset_config = self
.assets
.get(asset)
.ok_or_else(|| EconomicsError::InvalidPrice(format!("Unknown asset: {}", asset)))?; .ok_or_else(|| EconomicsError::InvalidPrice(format!("Unknown asset: {}", asset)))?;
let value = *amount * price.price; let value = *amount * price.price;
@ -426,24 +436,37 @@ impl LiquidationOracle {
return Err(EconomicsError::InvalidPrice("Position is healthy".into())); return Err(EconomicsError::InvalidPrice("Position is healthy".into()));
} }
let position = self.positions.get(account_id) let position = self
.positions
.get(account_id)
.ok_or_else(|| EconomicsError::AccountNotFound(account_id.to_string()))?; .ok_or_else(|| EconomicsError::AccountNotFound(account_id.to_string()))?;
let debt_amount = position.borrows.get(debt_asset).copied().unwrap_or(Decimal::ZERO); let debt_amount = position
let collateral_amount = position.collateral.get(collateral_asset).copied().unwrap_or(Decimal::ZERO); .borrows
.get(debt_asset)
.copied()
.unwrap_or(Decimal::ZERO);
let collateral_amount = position
.collateral
.get(collateral_asset)
.copied()
.unwrap_or(Decimal::ZERO);
if debt_amount == Decimal::ZERO { if debt_amount == Decimal::ZERO {
return Err(EconomicsError::InvalidPrice("No debt to repay".into())); return Err(EconomicsError::InvalidPrice("No debt to repay".into()));
} }
if collateral_amount == Decimal::ZERO { if collateral_amount == Decimal::ZERO {
return Err(EconomicsError::InvalidPrice("No collateral to seize".into())); return Err(EconomicsError::InvalidPrice(
"No collateral to seize".into(),
));
} }
let debt_price = self.get_liquidation_price(debt_asset)?; let debt_price = self.get_liquidation_price(debt_asset)?;
let collateral_price = self.get_liquidation_price(collateral_asset)?; let collateral_price = self.get_liquidation_price(collateral_asset)?;
let collateral_config = self.assets.get(collateral_asset) let collateral_config = self.assets.get(collateral_asset).ok_or_else(|| {
.ok_or_else(|| EconomicsError::InvalidPrice(format!("Unknown asset: {}", collateral_asset)))?; EconomicsError::InvalidPrice(format!("Unknown asset: {}", collateral_asset))
})?;
// Max debt repayable = close_factor * total_debt // Max debt repayable = close_factor * total_debt
let max_debt_repay = debt_amount * self.config.close_factor; let max_debt_repay = debt_amount * self.config.close_factor;
@ -458,12 +481,14 @@ impl LiquidationOracle {
let actual_collateral_seized = collateral_to_seize.min(collateral_amount); let actual_collateral_seized = collateral_to_seize.min(collateral_amount);
let actual_debt_repaid = if actual_collateral_seized < collateral_to_seize { let actual_debt_repaid = if actual_collateral_seized < collateral_to_seize {
// Partial liquidation // Partial liquidation
(actual_collateral_seized * collateral_price.price) / (bonus_multiplier * debt_price.price) (actual_collateral_seized * collateral_price.price)
/ (bonus_multiplier * debt_price.price)
} else { } else {
max_debt_repay max_debt_repay
}; };
let bonus_amount = actual_collateral_seized * collateral_config.liquidation_bonus / bonus_multiplier; let bonus_amount =
actual_collateral_seized * collateral_config.liquidation_bonus / bonus_multiplier;
Ok(LiquidationCalculation { Ok(LiquidationCalculation {
account_id: account_id.to_string(), account_id: account_id.to_string(),
@ -489,7 +514,9 @@ impl LiquidationOracle {
let calc = self.calculate_liquidation(account_id, debt_asset, collateral_asset)?; let calc = self.calculate_liquidation(account_id, debt_asset, collateral_asset)?;
// Update position // Update position
let position = self.positions.get_mut(account_id) let position = self
.positions
.get_mut(account_id)
.ok_or_else(|| EconomicsError::AccountNotFound(account_id.to_string()))?; .ok_or_else(|| EconomicsError::AccountNotFound(account_id.to_string()))?;
// Reduce debt // Reduce debt
@ -549,7 +576,9 @@ impl LiquidationOracle {
// Protocol gets a portion of the liquidation bonus // Protocol gets a portion of the liquidation bonus
if let Some(asset_config) = self.assets.get(&event.collateral_asset) { if let Some(asset_config) = self.assets.get(&event.collateral_asset) {
let protocol_share = event.bonus_amount * asset_config.reserve_factor; let protocol_share = event.bonus_amount * asset_config.reserve_factor;
*reserves.entry(event.collateral_asset.clone()).or_insert(Decimal::ZERO) += protocol_share; *reserves
.entry(event.collateral_asset.clone())
.or_insert(Decimal::ZERO) += protocol_share;
} }
} }
@ -561,15 +590,18 @@ impl LiquidationOracle {
let total_positions = self.positions.len(); let total_positions = self.positions.len();
let total_liquidations = self.liquidation_history.len(); let total_liquidations = self.liquidation_history.len();
let total_debt_liquidated: SynorDecimal = self.liquidation_history.iter() let total_debt_liquidated: SynorDecimal =
.map(|e| e.debt_amount) self.liquidation_history.iter().map(|e| e.debt_amount).sum();
.sum();
let total_collateral_seized: SynorDecimal = self.liquidation_history.iter() let total_collateral_seized: SynorDecimal = self
.liquidation_history
.iter()
.map(|e| e.collateral_amount) .map(|e| e.collateral_amount)
.sum(); .sum();
let unique_liquidated: std::collections::HashSet<_> = self.liquidation_history.iter() let unique_liquidated: std::collections::HashSet<_> = self
.liquidation_history
.iter()
.map(|e| &e.account_id) .map(|e| &e.account_id)
.collect(); .collect();
@ -617,18 +649,60 @@ mod tests {
let mut price_oracle = PriceOracle::with_config(OracleConfig::default()); let mut price_oracle = PriceOracle::with_config(OracleConfig::default());
// Add prices from multiple sources for test validity // Add prices from multiple sources for test validity
price_oracle.update_price(TokenPrice::new("ETH", "USD", dec!(2000), PriceSource::Internal)).unwrap(); price_oracle
price_oracle.update_price(TokenPrice::new("ETH", "USD", dec!(2000), PriceSource::Aggregated)).unwrap(); .update_price(TokenPrice::new(
price_oracle.update_price(TokenPrice::new("SYNOR", "USD", dec!(1), PriceSource::Internal)).unwrap(); "ETH",
price_oracle.update_price(TokenPrice::new("SYNOR", "USD", dec!(1), PriceSource::Aggregated)).unwrap(); "USD",
price_oracle.update_price(TokenPrice::new("USDC", "USD", dec!(1), PriceSource::Internal)).unwrap(); dec!(2000),
price_oracle.update_price(TokenPrice::new("USDC", "USD", dec!(1), PriceSource::Aggregated)).unwrap(); PriceSource::Internal,
))
.unwrap();
price_oracle
.update_price(TokenPrice::new(
"ETH",
"USD",
dec!(2000),
PriceSource::Aggregated,
))
.unwrap();
price_oracle
.update_price(TokenPrice::new(
"SYNOR",
"USD",
dec!(1),
PriceSource::Internal,
))
.unwrap();
price_oracle
.update_price(TokenPrice::new(
"SYNOR",
"USD",
dec!(1),
PriceSource::Aggregated,
))
.unwrap();
price_oracle
.update_price(TokenPrice::new(
"USDC",
"USD",
dec!(1),
PriceSource::Internal,
))
.unwrap();
price_oracle
.update_price(TokenPrice::new(
"USDC",
"USD",
dec!(1),
PriceSource::Aggregated,
))
.unwrap();
// Use test config with relaxed freshness requirements // Use test config with relaxed freshness requirements
let test_config = LiquidationOracleConfig { let test_config = LiquidationOracleConfig {
max_price_age: 3600, // 1 hour for tests max_price_age: 3600, // 1 hour for tests
min_confidence: 0.5, // Lower confidence threshold min_confidence: 0.5, // Lower confidence threshold
min_sources: 1, // Single source OK for tests min_sources: 1, // Single source OK for tests
..Default::default() ..Default::default()
}; };
@ -660,8 +734,8 @@ mod tests {
// Create position with good health // Create position with good health
let pos = oracle.create_position("user1"); let pos = oracle.create_position("user1");
pos.deposit("ETH", dec!(1)); // $2000 worth pos.deposit("ETH", dec!(1)); // $2000 worth
pos.borrow("USDC", dec!(500)); // Borrow $500 pos.borrow("USDC", dec!(500)); // Borrow $500
let health = oracle.calculate_health("user1").unwrap(); let health = oracle.calculate_health("user1").unwrap();
@ -678,8 +752,8 @@ mod tests {
// Create position close to liquidation // Create position close to liquidation
let pos = oracle.create_position("user2"); let pos = oracle.create_position("user2");
pos.deposit("ETH", dec!(1)); // $2000 worth pos.deposit("ETH", dec!(1)); // $2000 worth
pos.borrow("USDC", dec!(1500)); // Borrow $1500 pos.borrow("USDC", dec!(1500)); // Borrow $1500
let health = oracle.calculate_health("user2").unwrap(); let health = oracle.calculate_health("user2").unwrap();
@ -699,7 +773,9 @@ mod tests {
pos.deposit("ETH", dec!(1)); pos.deposit("ETH", dec!(1));
pos.borrow("USDC", dec!(1500)); pos.borrow("USDC", dec!(1500));
let calc = oracle.calculate_liquidation("user3", "USDC", "ETH").unwrap(); let calc = oracle
.calculate_liquidation("user3", "USDC", "ETH")
.unwrap();
// Should be able to liquidate // Should be able to liquidate
assert!(calc.debt_to_repay > Decimal::ZERO); assert!(calc.debt_to_repay > Decimal::ZERO);

View file

@ -49,8 +49,7 @@ pub use price_feed::{
PriceSource, PriceSource,
}; };
pub use twap::{ pub use twap::{
OnChainTwap, OnChainTwapFactory, TwapCalculator, TwapConfig, TwapObservation, OnChainTwap, OnChainTwapFactory, TwapCalculator, TwapConfig, TwapObservation, TwapOracleBuilder,
TwapOracleBuilder,
}; };
use crate::error::{EconomicsError, Result}; use crate::error::{EconomicsError, Result};
@ -241,7 +240,10 @@ impl PriceOracle {
/// Get price history for a pair /// Get price history for a pair
pub fn get_price_history(&self, token: &str, quote: &str) -> Vec<TokenPrice> { pub fn get_price_history(&self, token: &str, quote: &str) -> Vec<TokenPrice> {
let pair_key = format!("{}/{}", token, quote); let pair_key = format!("{}/{}", token, quote);
self.price_history.get(&pair_key).cloned().unwrap_or_default() self.price_history
.get(&pair_key)
.cloned()
.unwrap_or_default()
} }
/// Fetch prices from all configured feeds /// Fetch prices from all configured feeds
@ -403,7 +405,9 @@ impl PriceOracle {
} }
let healthy = !pairs_status.is_empty() let healthy = !pairs_status.is_empty()
&& pairs_status.values().all(|s| !s.is_stale && s.price_count > 0); && pairs_status
.values()
.all(|s| !s.is_stale && s.price_count > 0);
OracleHealthStatus { OracleHealthStatus {
healthy, healthy,
@ -443,7 +447,10 @@ impl PriceOracleBuilder {
/// Add a mock price feed (for testing) /// Add a mock price feed (for testing)
pub fn with_mock_feed(mut self, base_price: SynorDecimal) -> Self { pub fn with_mock_feed(mut self, base_price: SynorDecimal) -> Self {
use price_feed::MockPriceFeed; use price_feed::MockPriceFeed;
self.feeds.push(Box::new(MockPriceFeed::new(PriceSource::Internal, base_price))); self.feeds.push(Box::new(MockPriceFeed::new(
PriceSource::Internal,
base_price,
)));
self self
} }
@ -455,9 +462,14 @@ impl PriceOracleBuilder {
} }
/// Add Chainlink oracle feed /// Add Chainlink oracle feed
pub fn with_chainlink(mut self, contract_address: impl Into<String>, rpc_url: impl Into<String>) -> Self { pub fn with_chainlink(
mut self,
contract_address: impl Into<String>,
rpc_url: impl Into<String>,
) -> Self {
use price_feed::ChainlinkFeed; use price_feed::ChainlinkFeed;
self.feeds.push(Box::new(ChainlinkFeed::new(contract_address, rpc_url))); self.feeds
.push(Box::new(ChainlinkFeed::new(contract_address, rpc_url)));
self self
} }
@ -508,8 +520,14 @@ impl OracleFactory {
let mut oracle = PriceOracle::new(); let mut oracle = PriceOracle::new();
// Add multiple mock feeds with slight variations for testing // Add multiple mock feeds with slight variations for testing
oracle.add_feed(Box::new(MockPriceFeed::new(PriceSource::Internal, base_price))); oracle.add_feed(Box::new(MockPriceFeed::new(
oracle.add_feed(Box::new(MockPriceFeed::new(PriceSource::SynorDex, base_price))); PriceSource::Internal,
base_price,
)));
oracle.add_feed(Box::new(MockPriceFeed::new(
PriceSource::SynorDex,
base_price,
)));
oracle oracle
} }

View file

@ -110,7 +110,8 @@ impl PriceFeed for MockPriceFeed {
async fn fetch_price(&self, token: &str, quote: &str) -> Result<TokenPrice> { async fn fetch_price(&self, token: &str, quote: &str) -> Result<TokenPrice> {
// Add small random variance // Add small random variance
let variance = (rand_simple() * 2.0 - 1.0) * self.volatility; let variance = (rand_simple() * 2.0 - 1.0) * self.volatility;
let price = self.base_price * (Decimal::ONE + Decimal::from_f64_retain(variance).unwrap_or_default()); let price = self.base_price
* (Decimal::ONE + Decimal::from_f64_retain(variance).unwrap_or_default());
Ok(TokenPrice { Ok(TokenPrice {
token: token.to_string(), token: token.to_string(),
@ -258,9 +259,12 @@ impl PriceFeed for CoinGeckoFeed {
"SYNOR" => "synor", // Would need actual CoinGecko ID "SYNOR" => "synor", // Would need actual CoinGecko ID
"BTC" => "bitcoin", "BTC" => "bitcoin",
"ETH" => "ethereum", "ETH" => "ethereum",
_ => return Err(EconomicsError::PriceFeedUnavailable( _ => {
format!("Token {} not supported on CoinGecko", token) return Err(EconomicsError::PriceFeedUnavailable(format!(
)), "Token {} not supported on CoinGecko",
token
)))
}
}; };
let quote_currency = quote.to_lowercase(); let quote_currency = quote.to_lowercase();
@ -294,8 +298,7 @@ impl PriceFeed for CoinGeckoFeed {
Ok(TokenPrice { Ok(TokenPrice {
token: token.to_string(), token: token.to_string(),
quote: quote.to_string(), quote: quote.to_string(),
price: Decimal::from_f64_retain(price) price: Decimal::from_f64_retain(price).unwrap_or_default(),
.unwrap_or_default(),
timestamp: Utc::now(), timestamp: Utc::now(),
source: PriceSource::CoinGecko, source: PriceSource::CoinGecko,
confidence: 0.90, confidence: 0.90,

View file

@ -116,8 +116,8 @@ impl TwapCalculator {
let duration = (interval_end - interval_start).num_seconds() as f64; let duration = (interval_end - interval_start).num_seconds() as f64;
if duration > 0.0 { if duration > 0.0 {
let weight = Decimal::from_f64_retain(duration / total_duration) let weight =
.unwrap_or(Decimal::ZERO); Decimal::from_f64_retain(duration / total_duration).unwrap_or(Decimal::ZERO);
weighted_sum += price.price * weight; weighted_sum += price.price * weight;
total_weight += weight; total_weight += weight;
@ -343,7 +343,8 @@ impl OnChainTwap {
/// Apply the pending cardinality increase (called during next observation) /// Apply the pending cardinality increase (called during next observation)
pub fn apply_cardinality_growth(&mut self) { pub fn apply_cardinality_growth(&mut self) {
if self.cardinality_next > self.cardinality { if self.cardinality_next > self.cardinality {
self.observations.reserve(self.cardinality_next - self.cardinality); self.observations
.reserve(self.cardinality_next - self.cardinality);
self.cardinality = self.cardinality_next; self.cardinality = self.cardinality_next;
} }
} }
@ -396,7 +397,12 @@ impl TwapOracleBuilder {
} }
/// Add an initial observation /// Add an initial observation
pub fn with_observation(mut self, timestamp: DateTime<Utc>, price_cumulative: SynorDecimal, spl_cumulative: SynorDecimal) -> Self { pub fn with_observation(
mut self,
timestamp: DateTime<Utc>,
price_cumulative: SynorDecimal,
spl_cumulative: SynorDecimal,
) -> Self {
self.initial_observations.push(TwapObservation { self.initial_observations.push(TwapObservation {
timestamp, timestamp,
price_cumulative, price_cumulative,

View file

@ -76,7 +76,11 @@ pub struct Discount {
impl Discount { impl Discount {
/// Create a new percentage discount /// Create a new percentage discount
pub fn percentage(code: impl Into<String>, name: impl Into<String>, percentage: SynorDecimal) -> Self { pub fn percentage(
code: impl Into<String>,
name: impl Into<String>,
percentage: SynorDecimal,
) -> Self {
Self { Self {
code: code.into(), code: code.into(),
name: name.into(), name: name.into(),
@ -96,7 +100,11 @@ impl Discount {
} }
/// Create a fixed amount discount /// Create a fixed amount discount
pub fn fixed_amount(code: impl Into<String>, name: impl Into<String>, amount: SynorDecimal) -> Self { pub fn fixed_amount(
code: impl Into<String>,
name: impl Into<String>,
amount: SynorDecimal,
) -> Self {
Self { Self {
code: code.into(), code: code.into(),
name: name.into(), name: name.into(),
@ -120,7 +128,10 @@ impl Discount {
Self { Self {
code: format!("VOLUME_{}", min_spend), code: format!("VOLUME_{}", min_spend),
name: format!("Volume Discount ({} SYNOR+)", min_spend), name: format!("Volume Discount ({} SYNOR+)", min_spend),
description: format!("{}% off when spending {} SYNOR or more", percentage, min_spend), description: format!(
"{}% off when spending {} SYNOR or more",
percentage, min_spend
),
discount_type: DiscountType::Volume, discount_type: DiscountType::Volume,
value: percentage, value: percentage,
min_spend: Some(min_spend), min_spend: Some(min_spend),
@ -260,9 +271,11 @@ impl Discount {
} }
let discount = match self.discount_type { let discount = match self.discount_type {
DiscountType::Percentage | DiscountType::Volume | DiscountType::Loyalty | DiscountType::Referral | DiscountType::Partner => { DiscountType::Percentage
amount * (self.value / Decimal::ONE_HUNDRED) | DiscountType::Volume
} | DiscountType::Loyalty
| DiscountType::Referral
| DiscountType::Partner => amount * (self.value / Decimal::ONE_HUNDRED),
DiscountType::FixedAmount | DiscountType::Promotional => { DiscountType::FixedAmount | DiscountType::Promotional => {
self.value.min(amount) // Can't discount more than amount self.value.min(amount) // Can't discount more than amount
} }
@ -298,7 +311,7 @@ impl Discount {
/// Volume discount tiers /// Volume discount tiers
pub fn standard_volume_discounts() -> Vec<Discount> { pub fn standard_volume_discounts() -> Vec<Discount> {
vec![ vec![
Discount::volume(Decimal::new(100, 0), Decimal::new(5, 0)), // 5% at 100+ SYNOR Discount::volume(Decimal::new(100, 0), Decimal::new(5, 0)), // 5% at 100+ SYNOR
Discount::volume(Decimal::new(500, 0), Decimal::new(10, 0)), // 10% at 500+ SYNOR Discount::volume(Decimal::new(500, 0), Decimal::new(10, 0)), // 10% at 500+ SYNOR
Discount::volume(Decimal::new(1000, 0), Decimal::new(15, 0)), // 15% at 1000+ SYNOR Discount::volume(Decimal::new(1000, 0), Decimal::new(15, 0)), // 15% at 1000+ SYNOR
Discount::volume(Decimal::new(5000, 0), Decimal::new(20, 0)), // 20% at 5000+ SYNOR Discount::volume(Decimal::new(5000, 0), Decimal::new(20, 0)), // 20% at 5000+ SYNOR
@ -309,11 +322,7 @@ pub fn standard_volume_discounts() -> Vec<Discount> {
pub fn find_best_volume_discount(amount: SynorDecimal) -> Option<Discount> { pub fn find_best_volume_discount(amount: SynorDecimal) -> Option<Discount> {
standard_volume_discounts() standard_volume_discounts()
.into_iter() .into_iter()
.filter(|d| { .filter(|d| d.min_spend.map(|min| amount >= min).unwrap_or(false))
d.min_spend
.map(|min| amount >= min)
.unwrap_or(false)
})
.max_by(|a, b| a.value.cmp(&b.value)) .max_by(|a, b| a.value.cmp(&b.value))
} }
@ -378,8 +387,8 @@ mod tests {
#[test] #[test]
fn test_discount_usage_limit() { fn test_discount_usage_limit() {
let mut discount = Discount::percentage("LIMITED", "Limited Use", dec!(10)) let mut discount =
.with_max_uses(2); Discount::percentage("LIMITED", "Limited Use", dec!(10)).with_max_uses(2);
assert!(discount.use_discount()); assert!(discount.use_discount());
assert!(discount.use_discount()); assert!(discount.use_discount());

View file

@ -198,15 +198,9 @@ impl PricingEngine {
.get(&service_type) .get(&service_type)
.ok_or_else(|| EconomicsError::ServiceNotConfigured(service_type.to_string()))?; .ok_or_else(|| EconomicsError::ServiceNotConfigured(service_type.to_string()))?;
let unit_price = pricing let unit_price = pricing.base_prices.get(&resource_unit).ok_or_else(|| {
.base_prices EconomicsError::ServiceNotConfigured(format!("{} - {}", service_type, resource_unit))
.get(&resource_unit) })?;
.ok_or_else(|| {
EconomicsError::ServiceNotConfigured(format!(
"{} - {}",
service_type, resource_unit
))
})?;
let cost = amount * unit_price; let cost = amount * unit_price;
@ -357,7 +351,8 @@ impl PricingEngine {
storage: ServicePricingSummary { storage: ServicePricingSummary {
gb_month: self.get_base_price(ServiceType::Storage, ResourceUnit::GbMonth), gb_month: self.get_base_price(ServiceType::Storage, ResourceUnit::GbMonth),
retrieval_gb: self.get_base_price(ServiceType::Storage, ResourceUnit::BandwidthGb), retrieval_gb: self.get_base_price(ServiceType::Storage, ResourceUnit::BandwidthGb),
free_storage_gb: self.get_free_allocation(ServiceType::Storage, ResourceUnit::GbMonth), free_storage_gb: self
.get_free_allocation(ServiceType::Storage, ResourceUnit::GbMonth),
}, },
hosting: HostingPricingSummary { hosting: HostingPricingSummary {
bandwidth_gb: self.get_base_price(ServiceType::Hosting, ResourceUnit::BandwidthGb), bandwidth_gb: self.get_base_price(ServiceType::Hosting, ResourceUnit::BandwidthGb),
@ -377,7 +372,8 @@ impl PricingEngine {
.get_free_allocation(ServiceType::Database, ResourceUnit::Queries), .get_free_allocation(ServiceType::Database, ResourceUnit::Queries),
}, },
compute: ComputePricingSummary { compute: ComputePricingSummary {
cpu_core_hour: self.get_base_price(ServiceType::Compute, ResourceUnit::CpuCoreHours), cpu_core_hour: self
.get_base_price(ServiceType::Compute, ResourceUnit::CpuCoreHours),
gpu_hour: self.get_base_price(ServiceType::Compute, ResourceUnit::GpuHours), gpu_hour: self.get_base_price(ServiceType::Compute, ResourceUnit::GpuHours),
memory_gb_hour: self memory_gb_hour: self
.get_base_price(ServiceType::Compute, ResourceUnit::MemoryGbHours), .get_base_price(ServiceType::Compute, ResourceUnit::MemoryGbHours),
@ -472,7 +468,11 @@ mod tests {
// 10 million queries // 10 million queries
let cost = engine let cost = engine
.calculate_cost(ServiceType::Database, ResourceUnit::Queries, dec!(10_000_000)) .calculate_cost(
ServiceType::Database,
ResourceUnit::Queries,
dec!(10_000_000),
)
.unwrap(); .unwrap();
assert_eq!(cost, dec!(0.10)); // 10M * 0.00000001 assert_eq!(cost, dec!(0.10)); // 10M * 0.00000001
@ -484,7 +484,9 @@ mod tests {
// Premium tier gets 20% discount // Premium tier gets 20% discount
let base_cost = dec!(100); let base_cost = dec!(100);
let discount = engine.calculate_tier_discount("premium", base_cost).unwrap(); let discount = engine
.calculate_tier_discount("premium", base_cost)
.unwrap();
assert_eq!(discount, dec!(20)); // 20% assert_eq!(discount, dec!(20)); // 20%
} }

View file

@ -108,8 +108,8 @@ impl PricingTier {
discount_percentage: Decimal::new(30, 0), // 30% discount discount_percentage: Decimal::new(30, 0), // 30% discount
priority_support: true, priority_support: true,
sla_percentage: Decimal::new(9999, 2), // 99.99% SLA sla_percentage: Decimal::new(9999, 2), // 99.99% SLA
custom_domain_limit: 0, // Unlimited custom_domain_limit: 0, // Unlimited
api_rate_limit: 0, // Unlimited api_rate_limit: 0, // Unlimited
features: vec![ features: vec![
"Everything in Premium".to_string(), "Everything in Premium".to_string(),
"30%+ Usage Discount".to_string(), "30%+ Usage Discount".to_string(),
@ -147,7 +147,9 @@ impl PricingTier {
/// Check if this tier has a feature /// Check if this tier has a feature
pub fn has_feature(&self, feature: &str) -> bool { pub fn has_feature(&self, feature: &str) -> bool {
self.features.iter().any(|f| f.to_lowercase().contains(&feature.to_lowercase())) self.features
.iter()
.any(|f| f.to_lowercase().contains(&feature.to_lowercase()))
} }
/// Calculate effective monthly cost including usage discount /// Calculate effective monthly cost including usage discount
@ -163,7 +165,9 @@ impl PricingTier {
let other_cost = other.effective_cost(monthly_usage); let other_cost = other.effective_cost(monthly_usage);
// Upgrade if other tier is cheaper or offers significant benefits // Upgrade if other tier is cheaper or offers significant benefits
other_cost < current_cost || (other.sla_percentage > self.sla_percentage && other_cost <= current_cost * Decimal::new(12, 1)) other_cost < current_cost
|| (other.sla_percentage > self.sla_percentage
&& other_cost <= current_cost * Decimal::new(12, 1))
} }
} }

View file

@ -181,7 +181,12 @@ impl AuthService {
} }
/// Generate a JWT token. /// Generate a JWT token.
pub fn generate_token(&self, user_id: &str, tier: ApiKeyTier, permissions: Permissions) -> Result<String, ApiError> { pub fn generate_token(
&self,
user_id: &str,
tier: ApiKeyTier,
permissions: Permissions,
) -> Result<String, ApiError> {
let now = Utc::now(); let now = Utc::now();
let exp = now + self.jwt_expiration; let exp = now + self.jwt_expiration;
@ -215,8 +220,7 @@ impl AuthService {
})?; })?;
let claims = token_data.claims; let claims = token_data.claims;
let expires_at = DateTime::from_timestamp(claims.exp, 0) let expires_at = DateTime::from_timestamp(claims.exp, 0).map(|dt| dt.with_timezone(&Utc));
.map(|dt| dt.with_timezone(&Utc));
Ok(AuthContext { Ok(AuthContext {
user_id: claims.sub, user_id: claims.sub,
@ -278,9 +282,7 @@ impl AuthService {
// Try API key header // Try API key header
if let Some(api_key) = headers.get("X-API-Key") { if let Some(api_key) = headers.get("X-API-Key") {
let key = api_key let key = api_key.to_str().map_err(|_| ApiError::InvalidApiKey)?;
.to_str()
.map_err(|_| ApiError::InvalidApiKey)?;
return self.validate_api_key(key).await; return self.validate_api_key(key).await;
} }
@ -295,8 +297,7 @@ impl AuthService {
let decoded = BASE64 let decoded = BASE64
.decode(encoded) .decode(encoded)
.map_err(|_| ApiError::InvalidApiKey)?; .map_err(|_| ApiError::InvalidApiKey)?;
let key = String::from_utf8(decoded) let key = String::from_utf8(decoded).map_err(|_| ApiError::InvalidApiKey)?;
.map_err(|_| ApiError::InvalidApiKey)?;
return self.validate_api_key(&key).await; return self.validate_api_key(&key).await;
} }
} }
@ -318,7 +319,9 @@ where
fn from_request_parts<'life0, 'life1, 'async_trait>( fn from_request_parts<'life0, 'life1, 'async_trait>(
parts: &'life0 mut Parts, parts: &'life0 mut Parts,
_state: &'life1 S, _state: &'life1 S,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self, Self::Rejection>> + Send + 'async_trait>> ) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<Self, Self::Rejection>> + Send + 'async_trait>,
>
where where
'life0: 'async_trait, 'life0: 'async_trait,
'life1: 'async_trait, 'life1: 'async_trait,
@ -351,7 +354,9 @@ where
fn from_request_parts<'life0, 'life1, 'async_trait>( fn from_request_parts<'life0, 'life1, 'async_trait>(
parts: &'life0 mut Parts, parts: &'life0 mut Parts,
_state: &'life1 S, _state: &'life1 S,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self, Self::Rejection>> + Send + 'async_trait>> ) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<Self, Self::Rejection>> + Send + 'async_trait>,
>
where where
'life0: 'async_trait, 'life0: 'async_trait,
'life1: 'async_trait, 'life1: 'async_trait,
@ -359,10 +364,7 @@ where
{ {
Box::pin(async move { Box::pin(async move {
// Get auth service from extensions // Get auth service from extensions
let auth_service = parts let auth_service = parts.extensions.get::<AuthService>().cloned();
.extensions
.get::<AuthService>()
.cloned();
if let Some(auth_service) = auth_service { if let Some(auth_service) = auth_service {
match auth_service.authenticate(&parts.headers).await { match auth_service.authenticate(&parts.headers).await {
@ -377,10 +379,7 @@ where
} }
/// Require specific permissions. /// Require specific permissions.
pub fn require_permission( pub fn require_permission(context: &AuthContext, permission: &str) -> Result<(), ApiError> {
context: &AuthContext,
permission: &str,
) -> Result<(), ApiError> {
let has_permission = match permission { let has_permission = match permission {
"read" => context.can_read(), "read" => context.can_read(),
"write" => context.can_write(), "write" => context.can_write(),
@ -397,10 +396,7 @@ pub fn require_permission(
} }
/// Require access to a specific service. /// Require access to a specific service.
pub fn require_service_access( pub fn require_service_access(context: &AuthContext, service: &str) -> Result<(), ApiError> {
context: &AuthContext,
service: &str,
) -> Result<(), ApiError> {
if context.can_access_service(service) { if context.can_access_service(service) {
Ok(()) Ok(())
} else { } else {

View file

@ -160,10 +160,26 @@ pub struct RateLimitTiers {
impl Default for RateLimitTiers { impl Default for RateLimitTiers {
fn default() -> Self { fn default() -> Self {
Self { Self {
free: TierConfig { rpm: 60, burst: 10, concurrent: 5 }, free: TierConfig {
developer: TierConfig { rpm: 600, burst: 100, concurrent: 20 }, rpm: 60,
pro: TierConfig { rpm: 6000, burst: 1000, concurrent: 100 }, burst: 10,
enterprise: TierConfig { rpm: 0, burst: 0, concurrent: 0 }, // Unlimited concurrent: 5,
},
developer: TierConfig {
rpm: 600,
burst: 100,
concurrent: 20,
},
pro: TierConfig {
rpm: 6000,
burst: 1000,
concurrent: 100,
},
enterprise: TierConfig {
rpm: 0,
burst: 0,
concurrent: 0,
}, // Unlimited
} }
} }
} }

View file

@ -170,9 +170,7 @@ impl ApiError {
| Self::ContractError(_) => StatusCode::UNPROCESSABLE_ENTITY, | Self::ContractError(_) => StatusCode::UNPROCESSABLE_ENTITY,
// 429 Too Many Requests // 429 Too Many Requests
Self::RateLimitExceeded | Self::TooManyRequests { .. } => { Self::RateLimitExceeded | Self::TooManyRequests { .. } => StatusCode::TOO_MANY_REQUESTS,
StatusCode::TOO_MANY_REQUESTS
}
// 500 Internal Server Error // 500 Internal Server Error
Self::InternalError | Self::Custom(_) => StatusCode::INTERNAL_SERVER_ERROR, Self::InternalError | Self::Custom(_) => StatusCode::INTERNAL_SERVER_ERROR,
@ -222,7 +220,10 @@ impl ApiError {
/// Build error details with optional extra information. /// Build error details with optional extra information.
pub fn to_details(&self) -> ErrorDetails { pub fn to_details(&self) -> ErrorDetails {
let details = match self { let details = match self {
Self::InsufficientBalance { required, available } => Some(serde_json::json!({ Self::InsufficientBalance {
required,
available,
} => Some(serde_json::json!({
"required": required, "required": required,
"available": available "available": available
})), })),
@ -257,10 +258,9 @@ impl IntoResponse for ApiError {
// Add rate limit headers for 429 errors // Add rate limit headers for 429 errors
if let Self::TooManyRequests { retry_after } = &self { if let Self::TooManyRequests { retry_after } = &self {
response.headers_mut().insert( response
"Retry-After", .headers_mut()
retry_after.to_string().parse().unwrap(), .insert("Retry-After", retry_after.to_string().parse().unwrap());
);
} }
response response

View file

@ -144,7 +144,8 @@ pub async fn timing_middleware(request: Request, next: Next) -> Response {
// Update metrics // Update metrics
metrics::counter!("http_requests_total", "method" => method.to_string(), "status" => status.as_u16().to_string()).increment(1); metrics::counter!("http_requests_total", "method" => method.to_string(), "status" => status.as_u16().to_string()).increment(1);
metrics::histogram!("http_request_duration_seconds", "method" => method.to_string()).record(duration.as_secs_f64()); metrics::histogram!("http_request_duration_seconds", "method" => method.to_string())
.record(duration.as_secs_f64());
response response
} }
@ -169,7 +170,10 @@ impl RateLimiterState {
} }
/// Get or create a rate limiter for an IP. /// Get or create a rate limiter for an IP.
pub async fn get_ip_limiter(&self, ip: &str) -> Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>> { pub async fn get_ip_limiter(
&self,
ip: &str,
) -> Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>> {
{ {
let limiters = self.ip_limiters.read().await; let limiters = self.ip_limiters.read().await;
if let Some(limiter) = limiters.get(ip) { if let Some(limiter) = limiters.get(ip) {
@ -189,7 +193,11 @@ impl RateLimiterState {
} }
/// Get or create a rate limiter for an API key. /// Get or create a rate limiter for an API key.
pub async fn get_key_limiter(&self, key_id: &str, tier: ApiKeyTier) -> Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>> { pub async fn get_key_limiter(
&self,
key_id: &str,
tier: ApiKeyTier,
) -> Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>> {
{ {
let limiters = self.key_limiters.read().await; let limiters = self.key_limiters.read().await;
if let Some(limiter) = limiters.get(key_id) { if let Some(limiter) = limiters.get(key_id) {
@ -255,9 +263,7 @@ pub async fn rate_limit_middleware(
// Use a fixed retry time since we can't easily convert to quanta's instant // Use a fixed retry time since we can't easily convert to quanta's instant
let retry_after = 60; // Default to 60 seconds let retry_after = 60; // Default to 60 seconds
Err(ApiError::TooManyRequests { Err(ApiError::TooManyRequests { retry_after })
retry_after,
})
} }
} }
} }
@ -274,10 +280,7 @@ pub async fn auth_middleware(
} }
/// API version middleware - validates version prefix. /// API version middleware - validates version prefix.
pub async fn version_middleware( pub async fn version_middleware(request: Request, next: Next) -> Result<Response, ApiError> {
request: Request,
next: Next,
) -> Result<Response, ApiError> {
let path = request.uri().path(); let path = request.uri().path();
// Skip version check for health, metrics, and docs // Skip version check for health, metrics, and docs
@ -307,22 +310,13 @@ pub async fn security_headers_middleware(request: Request, next: Next) -> Respon
let headers = response.headers_mut(); let headers = response.headers_mut();
// Prevent XSS // Prevent XSS
headers.insert( headers.insert("X-Content-Type-Options", "nosniff".parse().unwrap());
"X-Content-Type-Options",
"nosniff".parse().unwrap(),
);
// Prevent clickjacking // Prevent clickjacking
headers.insert( headers.insert("X-Frame-Options", "DENY".parse().unwrap());
"X-Frame-Options",
"DENY".parse().unwrap(),
);
// Enable XSS filter // Enable XSS filter
headers.insert( headers.insert("X-XSS-Protection", "1; mode=block".parse().unwrap());
"X-XSS-Protection",
"1; mode=block".parse().unwrap(),
);
// Strict transport security (HTTPS) // Strict transport security (HTTPS)
headers.insert( headers.insert(

View file

@ -6,11 +6,7 @@
//! - Contract analysis and validation //! - Contract analysis and validation
//! - Security scanning //! - Security scanning
use axum::{ use axum::{extract::State, routing::post, Json, Router};
extract::State,
routing::post,
Json, Router,
};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::{ use crate::{
@ -44,7 +40,7 @@ pub fn router() -> Router<AppState> {
// Types // Types
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct CompileRequest { pub struct CompileRequest {
pub wasm: String, // base64 encoded WASM pub wasm: String, // base64 encoded WASM
pub optimization_level: Option<String>, // none, basic, size, aggressive pub optimization_level: Option<String>, // none, basic, size, aggressive
pub strip_debug: Option<bool>, pub strip_debug: Option<bool>,
pub strip_names: Option<bool>, pub strip_names: Option<bool>,
@ -216,16 +212,14 @@ async fn compile_contract(
abi: Some(ContractAbi { abi: Some(ContractAbi {
name: "MyContract".to_string(), name: "MyContract".to_string(),
version: "1.0.0".to_string(), version: "1.0.0".to_string(),
functions: vec![ functions: vec![AbiFunction {
AbiFunction { name: "init".to_string(),
name: "init".to_string(), selector: "0x12345678".to_string(),
selector: "0x12345678".to_string(), inputs: vec![],
inputs: vec![], outputs: vec![],
outputs: vec![], view: false,
view: false, payable: false,
payable: false, }],
},
],
events: vec![], events: vec![],
errors: vec![], errors: vec![],
}), }),
@ -342,23 +336,19 @@ async fn analyze_contract(
imports: 100, imports: 100,
total: 5000, total: 5000,
}, },
functions: vec![ functions: vec![FunctionAnalysis {
FunctionAnalysis { name: "init".to_string(),
name: "init".to_string(), size: 500,
size: 500, instruction_count: 50,
instruction_count: 50, local_count: 3,
local_count: 3, exported: true,
exported: true, estimated_gas: 10000,
estimated_gas: 10000, }],
}, imports: vec![ImportInfo {
], module: "env".to_string(),
imports: vec![ name: "memory".to_string(),
ImportInfo { kind: "memory".to_string(),
module: "env".to_string(), }],
name: "memory".to_string(),
kind: "memory".to_string(),
},
],
gas_analysis: GasAnalysis { gas_analysis: GasAnalysis {
deployment_gas: 100000, deployment_gas: 100000,
memory_init_gas: 5000, memory_init_gas: 5000,
@ -378,14 +368,12 @@ async fn security_scan(
let result = SecurityScanResult { let result = SecurityScanResult {
score: 85, score: 85,
issues: vec![ issues: vec![SecurityIssue {
SecurityIssue { severity: "low".to_string(),
severity: "low".to_string(), issue_type: "unbounded_loop".to_string(),
issue_type: "unbounded_loop".to_string(), description: "Potential unbounded loop detected".to_string(),
description: "Potential unbounded loop detected".to_string(), location: Some("function:process".to_string()),
location: Some("function:process".to_string()), }],
},
],
recommendations: vec![ recommendations: vec![
"Add loop iteration limits".to_string(), "Add loop iteration limits".to_string(),
"Consider using checked arithmetic".to_string(), "Consider using checked arithmetic".to_string(),

View file

@ -122,17 +122,15 @@ async fn list_markets(
) -> ApiResult<Json<ApiResponse<Vec<Market>>>> { ) -> ApiResult<Json<ApiResponse<Vec<Market>>>> {
require_permission(&auth, "read")?; require_permission(&auth, "read")?;
let markets = vec![ let markets = vec![Market {
Market { symbol: "ETH-USDC".to_string(),
symbol: "ETH-USDC".to_string(), base_asset: "ETH".to_string(),
base_asset: "ETH".to_string(), quote_asset: "USDC".to_string(),
quote_asset: "USDC".to_string(), last_price: "3000.00".to_string(),
last_price: "3000.00".to_string(), change_24h: "2.5".to_string(),
change_24h: "2.5".to_string(), volume_24h: "10000000".to_string(),
volume_24h: "10000000".to_string(), status: "active".to_string(),
status: "active".to_string(), }];
},
];
Ok(Json(ApiResponse::success(markets))) Ok(Json(ApiResponse::success(markets)))
} }
@ -165,8 +163,14 @@ async fn get_orderbook(
require_permission(&auth, "read")?; require_permission(&auth, "read")?;
let orderbook = Orderbook { let orderbook = Orderbook {
bids: vec![OrderbookEntry { price: "2999.00".to_string(), quantity: "1.5".to_string() }], bids: vec![OrderbookEntry {
asks: vec![OrderbookEntry { price: "3001.00".to_string(), quantity: "2.0".to_string() }], price: "2999.00".to_string(),
quantity: "1.5".to_string(),
}],
asks: vec![OrderbookEntry {
price: "3001.00".to_string(),
quantity: "2.0".to_string(),
}],
spread: "2.00".to_string(), spread: "2.00".to_string(),
}; };
@ -286,7 +290,9 @@ async fn place_perp_order(
Json(req): Json<serde_json::Value>, Json(req): Json<serde_json::Value>,
) -> ApiResult<Json<ApiResponse<serde_json::Value>>> { ) -> ApiResult<Json<ApiResponse<serde_json::Value>>> {
require_permission(&auth, "write")?; require_permission(&auth, "write")?;
Ok(Json(ApiResponse::success(serde_json::json!({"order_id": "perp_123"})))) Ok(Json(ApiResponse::success(
serde_json::json!({"order_id": "perp_123"}),
)))
} }
async fn list_pools( async fn list_pools(
@ -295,18 +301,16 @@ async fn list_pools(
) -> ApiResult<Json<ApiResponse<Vec<Pool>>>> { ) -> ApiResult<Json<ApiResponse<Vec<Pool>>>> {
require_permission(&auth, "read")?; require_permission(&auth, "read")?;
let pools = vec![ let pools = vec![Pool {
Pool { pool_id: "ETH-USDC".to_string(),
pool_id: "ETH-USDC".to_string(), name: "ETH/USDC".to_string(),
name: "ETH/USDC".to_string(), token_a: "ETH".to_string(),
token_a: "ETH".to_string(), token_b: "USDC".to_string(),
token_b: "USDC".to_string(), reserve_a: "1000".to_string(),
reserve_a: "1000".to_string(), reserve_b: "3000000".to_string(),
reserve_b: "3000000".to_string(), tvl: "6000000".to_string(),
tvl: "6000000".to_string(), apr: "15.5".to_string(),
apr: "15.5".to_string(), }];
},
];
Ok(Json(ApiResponse::success(pools))) Ok(Json(ApiResponse::success(pools)))
} }

View file

@ -128,16 +128,14 @@ async fn list_chains(
) -> ApiResult<Json<ApiResponse<Vec<Chain>>>> { ) -> ApiResult<Json<ApiResponse<Vec<Chain>>>> {
require_permission(&auth, "read")?; require_permission(&auth, "read")?;
let chains = vec![ let chains = vec![Chain {
Chain { chain_id: "cosmoshub-4".to_string(),
chain_id: "cosmoshub-4".to_string(), name: "Cosmos Hub".to_string(),
name: "Cosmos Hub".to_string(), status: "active".to_string(),
status: "active".to_string(), rpc_endpoint: "https://rpc.cosmos.network".to_string(),
rpc_endpoint: "https://rpc.cosmos.network".to_string(), latest_height: 18000000,
latest_height: 18000000, active_channels: 50,
active_channels: 50, }];
},
];
Ok(Json(ApiResponse::success(chains))) Ok(Json(ApiResponse::success(chains)))
} }
@ -280,15 +278,13 @@ async fn get_routes(
) -> ApiResult<Json<ApiResponse<Vec<TransferRoute>>>> { ) -> ApiResult<Json<ApiResponse<Vec<TransferRoute>>>> {
require_permission(&auth, "read")?; require_permission(&auth, "read")?;
let routes = vec![ let routes = vec![TransferRoute {
TransferRoute { source_chain: "cosmoshub-4".to_string(),
source_chain: "cosmoshub-4".to_string(), dest_chain: "synor-mainnet".to_string(),
dest_chain: "synor-mainnet".to_string(), channel_id: "channel-0".to_string(),
channel_id: "channel-0".to_string(), estimated_time: "30s".to_string(),
estimated_time: "30s".to_string(), fee: "0.001 ATOM".to_string(),
fee: "0.001 ATOM".to_string(), }];
},
];
Ok(Json(ApiResponse::success(routes))) Ok(Json(ApiResponse::success(routes)))
} }

View file

@ -303,13 +303,11 @@ async fn get_peers(
) -> ApiResult<Json<ApiResponse<Vec<serde_json::Value>>>> { ) -> ApiResult<Json<ApiResponse<Vec<serde_json::Value>>>> {
require_permission(&auth, "read")?; require_permission(&auth, "read")?;
let peers = vec![ let peers = vec![serde_json::json!({
serde_json::json!({ "id": "peer1",
"id": "peer1", "address": "192.168.1.1:16100",
"address": "192.168.1.1:16100", "connected_since": 1705312200
"connected_since": 1705312200 })];
})
];
Ok(Json(ApiResponse::success(peers))) Ok(Json(ApiResponse::success(peers)))
} }

View file

@ -247,14 +247,12 @@ async fn list_directory(
) -> ApiResult<Json<ApiResponse<Vec<DirectoryEntry>>>> { ) -> ApiResult<Json<ApiResponse<Vec<DirectoryEntry>>>> {
require_permission(&auth, "read")?; require_permission(&auth, "read")?;
let entries = vec![ let entries = vec![DirectoryEntry {
DirectoryEntry { name: "file1.txt".to_string(),
name: "file1.txt".to_string(), cid: "bafyfile1...".to_string(),
cid: "bafyfile1...".to_string(), size: 1024,
size: 1024, is_directory: false,
is_directory: false, }];
},
];
Ok(Json(ApiResponse::success(entries))) Ok(Json(ApiResponse::success(entries)))
} }

View file

@ -409,7 +409,10 @@ async fn list_addresses(
]; ];
let pagination_meta = pagination.to_meta(addresses.len() as u64); let pagination_meta = pagination.to_meta(addresses.len() as u64);
Ok(Json(ApiResponse::success_paginated(addresses, pagination_meta))) Ok(Json(ApiResponse::success_paginated(
addresses,
pagination_meta,
)))
} }
/// Generate a stealth address. /// Generate a stealth address.
@ -477,7 +480,9 @@ async fn get_balances(
require_permission(&auth, "read")?; require_permission(&auth, "read")?;
if req.addresses.is_empty() { if req.addresses.is_empty() {
return Err(ApiError::ValidationError("addresses cannot be empty".to_string())); return Err(ApiError::ValidationError(
"addresses cannot be empty".to_string(),
));
} }
if req.addresses.len() > 100 { if req.addresses.len() > 100 {
@ -585,12 +590,15 @@ async fn send_transaction(
} }
// Validate amount // Validate amount
let amount: f64 = req.amount.parse().map_err(|_| { let amount: f64 = req
ApiError::ValidationError("Invalid amount format".to_string()) .amount
})?; .parse()
.map_err(|_| ApiError::ValidationError("Invalid amount format".to_string()))?;
if amount <= 0.0 { if amount <= 0.0 {
return Err(ApiError::ValidationError("Amount must be positive".to_string())); return Err(ApiError::ValidationError(
"Amount must be positive".to_string(),
));
} }
// In production, build, sign, and broadcast the transaction // In production, build, sign, and broadcast the transaction

View file

@ -139,16 +139,14 @@ async fn list_circuits(
) -> ApiResult<Json<ApiResponse<Vec<Circuit>>>> { ) -> ApiResult<Json<ApiResponse<Vec<Circuit>>>> {
require_permission(&auth, "read")?; require_permission(&auth, "read")?;
let circuits = vec![ let circuits = vec![Circuit {
Circuit { circuit_id: "multiplier-v1".to_string(),
circuit_id: "multiplier-v1".to_string(), name: "Multiplier".to_string(),
name: "Multiplier".to_string(), constraints: 1,
constraints: 1, public_inputs: 1,
public_inputs: 1, private_inputs: 2,
private_inputs: 2, outputs: 1,
outputs: 1, }];
},
];
let meta = pagination.to_meta(circuits.len() as u64); let meta = pagination.to_meta(circuits.len() as u64);
Ok(Json(ApiResponse::success_paginated(circuits, meta))) Ok(Json(ApiResponse::success_paginated(circuits, meta)))

View file

@ -17,15 +17,9 @@ use axum::{
Router, Router,
}; };
use std::{net::SocketAddr, sync::Arc}; use std::{net::SocketAddr, sync::Arc};
use tokio::{ use tokio::{net::TcpListener, signal, sync::oneshot};
net::TcpListener,
signal,
sync::oneshot,
};
use tower_http::{ use tower_http::{
compression::CompressionLayer, compression::CompressionLayer, limit::RequestBodyLimitLayer, timeout::TimeoutLayer,
limit::RequestBodyLimitLayer,
timeout::TimeoutLayer,
trace::TraceLayer, trace::TraceLayer,
}; };
use tracing::info; use tracing::info;

View file

@ -164,7 +164,8 @@ impl VersionRegistry {
/// Register a version. /// Register a version.
pub fn register(&mut self, info: VersionInfo) { pub fn register(&mut self, info: VersionInfo) {
self.versions.insert((info.version.major, info.version.minor), info); self.versions
.insert((info.version.major, info.version.minor), info);
} }
/// Get version info. /// Get version info.
@ -330,10 +331,7 @@ pub async fn version_middleware(req: Request, next: Next) -> Response {
// Add deprecation headers if needed // Add deprecation headers if needed
if let Some(info) = registry.get(&extracted.version) { if let Some(info) = registry.get(&extracted.version) {
if info.is_deprecated { if info.is_deprecated {
headers.insert( headers.insert(X_API_DEPRECATED.clone(), HeaderValue::from_static("true"));
X_API_DEPRECATED.clone(),
HeaderValue::from_static("true"),
);
if let Some(deprecated_at) = &info.deprecated_at { if let Some(deprecated_at) = &info.deprecated_at {
if let Ok(v) = HeaderValue::from_str(&deprecated_at.to_rfc3339()) { if let Ok(v) = HeaderValue::from_str(&deprecated_at.to_rfc3339()) {
@ -427,8 +425,8 @@ impl VersionsResponse {
// Routes // Routes
// ============================================================================ // ============================================================================
use axum::{routing::get, Json, Router};
use crate::routes::AppState; use crate::routes::AppState;
use axum::{routing::get, Json, Router};
/// Build version routes. /// Build version routes.
pub fn router() -> Router<AppState> { pub fn router() -> Router<AppState> {

View file

@ -593,11 +593,7 @@ async fn ws_blocks_handler(
ws.on_upgrade(move |socket| handle_blocks_socket(socket, state, auth)) ws.on_upgrade(move |socket| handle_blocks_socket(socket, state, auth))
} }
async fn handle_blocks_socket( async fn handle_blocks_socket(socket: WebSocket, state: AppState, _auth: Option<AuthContext>) {
socket: WebSocket,
state: AppState,
_auth: Option<AuthContext>,
) {
let ws_state = &state.websocket; let ws_state = &state.websocket;
ws_state.broadcaster.add_connection().await; ws_state.broadcaster.add_connection().await;
@ -834,11 +830,7 @@ async fn ws_markets_handler(
ws.on_upgrade(move |socket| handle_markets_socket(socket, state, auth)) ws.on_upgrade(move |socket| handle_markets_socket(socket, state, auth))
} }
async fn handle_markets_socket( async fn handle_markets_socket(socket: WebSocket, state: AppState, _auth: Option<AuthContext>) {
socket: WebSocket,
state: AppState,
_auth: Option<AuthContext>,
) {
let ws_state = &state.websocket; let ws_state = &state.websocket;
ws_state.broadcaster.add_connection().await; ws_state.broadcaster.add_connection().await;

View file

@ -7,8 +7,8 @@
//! hosting-gateway --domain synor.cc --storage-url http://localhost:8180 //! hosting-gateway --domain synor.cc --storage-url http://localhost:8180
//! hosting-gateway --config /path/to/config.toml //! hosting-gateway --config /path/to/config.toml
use synor_hosting::{HostingGateway, GatewayConfig};
use std::net::SocketAddr; use std::net::SocketAddr;
use synor_hosting::{GatewayConfig, HostingGateway};
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> { async fn main() -> Result<(), Box<dyn std::error::Error>> {
@ -19,28 +19,32 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Parse command line arguments // Parse command line arguments
let args: Vec<String> = std::env::args().collect(); let args: Vec<String> = std::env::args().collect();
let listen_addr: SocketAddr = args.iter() let listen_addr: SocketAddr = args
.iter()
.position(|a| a == "--listen" || a == "-l") .position(|a| a == "--listen" || a == "-l")
.and_then(|i| args.get(i + 1)) .and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok()) .and_then(|s| s.parse().ok())
.or_else(|| std::env::var("LISTEN_ADDR").ok()?.parse().ok()) .or_else(|| std::env::var("LISTEN_ADDR").ok()?.parse().ok())
.unwrap_or_else(|| "0.0.0.0:8080".parse().unwrap()); .unwrap_or_else(|| "0.0.0.0:8080".parse().unwrap());
let hosting_domain = args.iter() let hosting_domain = args
.iter()
.position(|a| a == "--domain" || a == "-d") .position(|a| a == "--domain" || a == "-d")
.and_then(|i| args.get(i + 1)) .and_then(|i| args.get(i + 1))
.cloned() .cloned()
.or_else(|| std::env::var("HOSTING_DOMAIN").ok()) .or_else(|| std::env::var("HOSTING_DOMAIN").ok())
.unwrap_or_else(|| "synor.cc".to_string()); .unwrap_or_else(|| "synor.cc".to_string());
let storage_url = args.iter() let storage_url = args
.iter()
.position(|a| a == "--storage-url" || a == "-s") .position(|a| a == "--storage-url" || a == "-s")
.and_then(|i| args.get(i + 1)) .and_then(|i| args.get(i + 1))
.cloned() .cloned()
.or_else(|| std::env::var("STORAGE_GATEWAY_URL").ok()) .or_else(|| std::env::var("STORAGE_GATEWAY_URL").ok())
.unwrap_or_else(|| "http://localhost:8180".to_string()); .unwrap_or_else(|| "http://localhost:8180".to_string());
let rate_limit: u32 = args.iter() let rate_limit: u32 = args
.iter()
.position(|a| a == "--rate-limit") .position(|a| a == "--rate-limit")
.and_then(|i| args.get(i + 1)) .and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok()) .and_then(|s| s.parse().ok())

View file

@ -233,11 +233,7 @@ impl EdgeCompute {
} }
/// Run AI inference at the edge. /// Run AI inference at the edge.
pub async fn inference( pub async fn inference(&self, _model: &str, _input: &[u8]) -> Result<Vec<u8>, EdgeError> {
&self,
_model: &str,
_input: &[u8],
) -> Result<Vec<u8>, EdgeError> {
if !self.enabled { if !self.enabled {
return Err(EdgeError::NotEnabled); return Err(EdgeError::NotEnabled);
} }

Some files were not shown because too many files have changed in this diff Show more