style: apply cargo fmt formatting

This commit is contained in:
Gulshan Yadav 2026-02-02 05:58:22 +05:30
parent 5126c33113
commit dcd1cccc67
170 changed files with 4463 additions and 2837 deletions

View file

@ -256,7 +256,11 @@ pub async fn handle(
Ok(())
}
CompilerCommands::Encode { function, args, abi } => {
CompilerCommands::Encode {
function,
args,
abi,
} => {
output::print_info(&format!("Encoding call to: {}", function));
output::print_kv("Arguments", &args);
if let Some(a) = abi {
@ -268,7 +272,11 @@ pub async fn handle(
Ok(())
}
CompilerCommands::Decode { data, function, abi } => {
CompilerCommands::Decode {
data,
function,
abi,
} => {
output::print_info(&format!("Decoding result for: {}", function));
output::print_kv("Data", &data);
if let Some(a) = abi {
@ -314,7 +322,11 @@ pub async fn handle(
Ok(())
}
CompilerCommands::SecurityScan { wasm, min_severity, format: _ } => {
CompilerCommands::SecurityScan {
wasm,
min_severity,
format: _,
} => {
output::print_info(&format!("Security scan: {}", wasm.display()));
output::print_kv("Min severity", &min_severity);
@ -344,7 +356,11 @@ pub async fn handle(
Ok(())
}
CompilerCommands::Validate { wasm, exports, max_memory } => {
CompilerCommands::Validate {
wasm,
exports,
max_memory,
} => {
output::print_info(&format!("Validating: {}", wasm.display()));
if let Some(e) = exports {

View file

@ -196,8 +196,7 @@ pub async fn deploy(
}
// Determine output directory
let output_path = output_dir
.unwrap_or_else(|| cwd.join(config.output_dir()));
let output_path = output_dir.unwrap_or_else(|| cwd.join(config.output_dir()));
if !output_path.exists() {
return Err(anyhow!(
@ -270,7 +269,10 @@ fn validate_name(name: &str) -> Result<()> {
if name.len() > 63 {
return Err(anyhow!("Name must be 63 characters or less"));
}
if !name.chars().all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-') {
if !name
.chars()
.all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-')
{
return Err(anyhow!(
"Name must contain only lowercase letters, numbers, and hyphens"
));
@ -281,8 +283,8 @@ fn validate_name(name: &str) -> Result<()> {
// Reserved names
const RESERVED: &[&str] = &[
"www", "api", "app", "admin", "mail", "ftp", "ssh", "cdn",
"storage", "gateway", "hosting", "node", "synor",
"www", "api", "app", "admin", "mail", "ftp", "ssh", "cdn", "storage", "gateway", "hosting",
"node", "synor",
];
if RESERVED.contains(&name) {
return Err(anyhow!("Name '{}' is reserved", name));
@ -397,11 +399,7 @@ fn guess_content_type(path: &Path) -> String {
}
/// Upload files to Synor Storage.
async fn upload_files(
base_dir: &Path,
files: &[DeployFile],
gateway_url: &str,
) -> Result<String> {
async fn upload_files(base_dir: &Path, files: &[DeployFile], gateway_url: &str) -> Result<String> {
let client = reqwest::Client::new();
// Create a multipart form with all files
@ -445,11 +443,7 @@ async fn upload_files(
}
/// Register the deployment with the hosting gateway.
async fn register_deployment(
name: &str,
cid: &str,
gateway_url: &str,
) -> Result<String> {
async fn register_deployment(name: &str, cid: &str, gateway_url: &str) -> Result<String> {
let client = reqwest::Client::new();
#[derive(Serialize)]
@ -662,7 +656,11 @@ pub async fn delete(name: &str, gateway_url: &str, format: OutputFormat) -> Resu
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(anyhow!("Failed to delete deployment: {} - {}", status, body));
return Err(anyhow!(
"Failed to delete deployment: {} - {}",
status,
body
));
}
match format {
@ -707,22 +705,13 @@ mod tests {
#[test]
fn test_guess_content_type() {
assert_eq!(
guess_content_type(Path::new("index.html")),
"text/html"
);
assert_eq!(
guess_content_type(Path::new("style.css")),
"text/css"
);
assert_eq!(guess_content_type(Path::new("index.html")), "text/html");
assert_eq!(guess_content_type(Path::new("style.css")), "text/css");
assert_eq!(
guess_content_type(Path::new("app.js")),
"application/javascript"
);
assert_eq!(
guess_content_type(Path::new("image.png")),
"image/png"
);
assert_eq!(guess_content_type(Path::new("image.png")), "image/png");
assert_eq!(
guess_content_type(Path::new("data.wasm")),
"application/wasm"

View file

@ -246,7 +246,13 @@ pub async fn handle(
Ok(())
}
DexCommands::PlaceOrder { market, side, price, quantity, wallet } => {
DexCommands::PlaceOrder {
market,
side,
price,
quantity,
wallet,
} => {
output::print_info("Placing limit order...");
output::print_kv("Market", &market);
output::print_kv("Side", &side);
@ -257,7 +263,12 @@ pub async fn handle(
Ok(())
}
DexCommands::MarketOrder { market, side, quantity, wallet } => {
DexCommands::MarketOrder {
market,
side,
quantity,
wallet,
} => {
output::print_info("Placing market order...");
output::print_kv("Market", &market);
output::print_kv("Side", &side);
@ -275,7 +286,10 @@ pub async fn handle(
DexCommands::CancelAll { market, wallet } => {
let scope = market.unwrap_or_else(|| "all markets".to_string());
output::print_info(&format!("Cancelling all orders in {} for {}", scope, wallet));
output::print_info(&format!(
"Cancelling all orders in {} for {}",
scope, wallet
));
output::print_success("3 orders cancelled");
Ok(())
}
@ -317,7 +331,12 @@ pub async fn handle(
Ok(())
}
DexCommands::AddLiquidity { pool_id, amount_a, amount_b, wallet } => {
DexCommands::AddLiquidity {
pool_id,
amount_a,
amount_b,
wallet,
} => {
output::print_info("Adding liquidity...");
output::print_kv("Pool", &pool_id);
output::print_kv("Amount A", &amount_a);
@ -327,7 +346,11 @@ pub async fn handle(
Ok(())
}
DexCommands::RemoveLiquidity { pool_id, lp_amount, wallet } => {
DexCommands::RemoveLiquidity {
pool_id,
lp_amount,
wallet,
} => {
output::print_info("Removing liquidity...");
output::print_kv("Pool", &pool_id);
output::print_kv("LP Amount", &lp_amount);

View file

@ -169,11 +169,7 @@ pub enum ZkCommands {
}
/// Handle ZK commands.
pub async fn handle(
_client: &RpcClient,
command: ZkCommands,
_format: OutputFormat,
) -> Result<()> {
pub async fn handle(_client: &RpcClient, command: ZkCommands, _format: OutputFormat) -> Result<()> {
match command {
ZkCommands::Compile { circuit, output } => {
output::print_info(&format!("Compiling circuit: {}", circuit.display()));
@ -211,8 +207,16 @@ pub async fn handle(
Ok(())
}
ZkCommands::ProveGroth16 { circuit, witness, proving_key: _, output } => {
output::print_info(&format!("Generating Groth16 proof for circuit: {}", circuit));
ZkCommands::ProveGroth16 {
circuit,
witness,
proving_key: _,
output,
} => {
output::print_info(&format!(
"Generating Groth16 proof for circuit: {}",
circuit
));
output::print_info(&format!("Witness: {}", witness.display()));
output::print_info("Computing witness...");
output::print_info("Generating proof...");
@ -226,7 +230,11 @@ pub async fn handle(
Ok(())
}
ZkCommands::ProvePlonk { circuit, witness, output } => {
ZkCommands::ProvePlonk {
circuit,
witness,
output,
} => {
output::print_info(&format!("Generating PLONK proof for circuit: {}", circuit));
output::print_info(&format!("Witness: {}", witness.display()));
output::print_info("Computing witness...");
@ -240,7 +248,11 @@ pub async fn handle(
Ok(())
}
ZkCommands::ProveStark { circuit, witness, output } => {
ZkCommands::ProveStark {
circuit,
witness,
output,
} => {
output::print_info(&format!("Generating STARK proof for circuit: {}", circuit));
output::print_info(&format!("Witness: {}", witness.display()));
output::print_info("Computing execution trace...");
@ -256,7 +268,11 @@ pub async fn handle(
Ok(())
}
ZkCommands::Verify { proof, verification_key: _, public_inputs: _ } => {
ZkCommands::Verify {
proof,
verification_key: _,
public_inputs: _,
} => {
output::print_info(&format!("Verifying proof: {}", proof.display()));
output::print_info("Loading proof...");
output::print_info("Verifying...");
@ -265,8 +281,15 @@ pub async fn handle(
Ok(())
}
ZkCommands::Setup { circuit, system, output } => {
output::print_info(&format!("Generating {} keys for circuit: {}", system, circuit));
ZkCommands::Setup {
circuit,
system,
output,
} => {
output::print_info(&format!(
"Generating {} keys for circuit: {}",
system, circuit
));
output::print_info("This may take a while for large circuits...");
output::print_info("Generating proving key...");
output::print_info("Deriving verification key...");

View file

@ -469,7 +469,11 @@ enum DeployCommands {
output: Option<PathBuf>,
/// Hosting gateway URL
#[arg(long, env = "SYNOR_HOSTING_URL", default_value = "http://127.0.0.1:8280")]
#[arg(
long,
env = "SYNOR_HOSTING_URL",
default_value = "http://127.0.0.1:8280"
)]
gateway: String,
/// Skip running the build command
@ -495,7 +499,11 @@ enum DeployCommands {
/// List deployments
List {
/// Hosting gateway URL
#[arg(long, env = "SYNOR_HOSTING_URL", default_value = "http://127.0.0.1:8280")]
#[arg(
long,
env = "SYNOR_HOSTING_URL",
default_value = "http://127.0.0.1:8280"
)]
gateway: String,
},
@ -505,7 +513,11 @@ enum DeployCommands {
name: String,
/// Hosting gateway URL
#[arg(long, env = "SYNOR_HOSTING_URL", default_value = "http://127.0.0.1:8280")]
#[arg(
long,
env = "SYNOR_HOSTING_URL",
default_value = "http://127.0.0.1:8280"
)]
gateway: String,
},
@ -515,7 +527,11 @@ enum DeployCommands {
name: String,
/// Hosting gateway URL
#[arg(long, env = "SYNOR_HOSTING_URL", default_value = "http://127.0.0.1:8280")]
#[arg(
long,
env = "SYNOR_HOSTING_URL",
default_value = "http://127.0.0.1:8280"
)]
gateway: String,
},
}
@ -591,9 +607,11 @@ async fn main() {
gateway,
skip_build,
} => commands::deploy::deploy(name, out_dir, &gateway, skip_build, output).await,
DeployCommands::Init { name, spa, output: out_dir } => {
commands::deploy::init(name, spa, out_dir, output)
}
DeployCommands::Init {
name,
spa,
output: out_dir,
} => commands::deploy::init(name, spa, out_dir, output),
DeployCommands::List { gateway } => commands::deploy::list(&gateway, output).await,
DeployCommands::Delete { name, gateway } => {
commands::deploy::delete(&name, &gateway, output).await

View file

@ -676,7 +676,9 @@ async fn get_blocks(
}
// Fetch blocks by blue score (most recent first)
let start_score = score.blue_score.saturating_sub((params.page.saturating_sub(1) * limit) as u64);
let start_score = score
.blue_score
.saturating_sub((params.page.saturating_sub(1) * limit) as u64);
let blocks_data: Vec<serde_json::Value> = state
.rpc_call("synor_getBlocksByBlueScore", (start_score, true))
.await
@ -697,17 +699,28 @@ async fn get_blocks(
parent_hashes: header
.get("parents")
.and_then(|p| p.as_array())
.map(|a| a.iter().filter_map(|v| v.as_str().map(String::from)).collect())
.map(|a| {
a.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default(),
timestamp,
timestamp_human: format_timestamp(timestamp),
bits: header.get("bits")?.as_u64()? as u32,
nonce: header.get("nonce")?.as_u64()?,
daa_score: header.get("blueScore").and_then(|v| v.as_u64()).unwrap_or(0),
blue_score: header.get("blueScore").and_then(|v| v.as_u64()).unwrap_or(0),
daa_score: header
.get("blueScore")
.and_then(|v| v.as_u64())
.unwrap_or(0),
blue_score: header
.get("blueScore")
.and_then(|v| v.as_u64())
.unwrap_or(0),
blue_work: String::new(),
difficulty: 0.0,
transaction_count: b.get("transactions")
transaction_count: b
.get("transactions")
.and_then(|t| t.as_array())
.map(|a| a.len())
.unwrap_or(0),
@ -1102,9 +1115,7 @@ async fn estimate_gas(
};
// Call the node's contract_estimateGas RPC method
let gas_used: u64 = state
.rpc_call("contract_estimateGas", rpc_request)
.await?;
let gas_used: u64 = state.rpc_call("contract_estimateGas", rpc_request).await?;
// Calculate recommended gas limit with 20% safety margin
let gas_limit_recommended = ((gas_used as f64) * 1.2).ceil() as u64;
@ -1494,8 +1505,7 @@ async fn main() -> anyhow::Result<()> {
let app = if let Some(ref static_dir) = config.static_dir {
// Serve static files with SPA fallback (index.html for client-side routing)
let index_path = format!("{}/index.html", static_dir);
let serve_dir = ServeDir::new(static_dir)
.not_found_service(ServeFile::new(&index_path));
let serve_dir = ServeDir::new(static_dir).not_found_service(ServeFile::new(&index_path));
api_router
.fallback_service(serve_dir)

View file

@ -684,10 +684,12 @@ mod tests {
#[test]
fn test_all_paths_are_distinct() {
let config = NodeConfig::for_network("mainnet").unwrap();
let paths = [config.blocks_path(),
let paths = [
config.blocks_path(),
config.chainstate_path(),
config.contracts_path(),
config.keys_path()];
config.keys_path(),
];
for i in 0..paths.len() {
for j in (i + 1)..paths.len() {
@ -794,9 +796,11 @@ mod tests {
#[test]
fn test_with_mining_enabled() {
let config = NodeConfig::for_network("mainnet")
.unwrap()
.with_mining(true, Some("synor:test_address".to_string()), 4);
let config = NodeConfig::for_network("mainnet").unwrap().with_mining(
true,
Some("synor:test_address".to_string()),
4,
);
assert!(config.mining.enabled);
assert_eq!(
@ -828,9 +832,10 @@ mod tests {
#[test]
fn test_with_p2p() {
let seeds = vec!["seed1.example.com:30303".to_string()];
let config = NodeConfig::for_network("mainnet")
.unwrap()
.with_p2p("0.0.0.0", 30303, seeds.clone());
let config =
NodeConfig::for_network("mainnet")
.unwrap()
.with_p2p("0.0.0.0", 30303, seeds.clone());
assert_eq!(config.p2p.listen_addr, "0.0.0.0:30303");
assert_eq!(config.p2p.seeds, seeds);
@ -1027,7 +1032,10 @@ mod tests {
let loaded = NodeConfig::load(&path).unwrap();
assert_eq!(loaded.mining.enabled, config.mining.enabled);
assert_eq!(loaded.mining.coinbase_address, config.mining.coinbase_address);
assert_eq!(
loaded.mining.coinbase_address,
config.mining.coinbase_address
);
assert_eq!(loaded.mining.threads, config.mining.threads);
assert_eq!(loaded.storage.cache_size_mb, config.storage.cache_size_mb);
assert_eq!(loaded.logging.level, config.logging.level);

View file

@ -425,11 +425,13 @@ mod tests {
#[test]
fn test_node_state_all_variants_are_distinct() {
let states = [NodeState::Starting,
let states = [
NodeState::Starting,
NodeState::Syncing,
NodeState::Running,
NodeState::Stopping,
NodeState::Stopped];
NodeState::Stopped,
];
for i in 0..states.len() {
for j in (i + 1)..states.len() {
@ -605,7 +607,10 @@ mod tests {
.with_mining(true, Some("synor:test".to_string()), 4);
assert!(config.mining.enabled);
assert_eq!(config.mining.coinbase_address, Some("synor:test".to_string()));
assert_eq!(
config.mining.coinbase_address,
Some("synor:test".to_string())
);
assert_eq!(config.mining.threads, 4);
}

View file

@ -12,9 +12,12 @@ use synor_mining::{
MinerCommand, MinerConfig, MinerEvent, MiningResult, MiningStats as CrateMiningStats,
TemplateTransaction,
};
use synor_types::{Address, Amount, Block, BlockHeader, BlockId, BlueScore, Hash256, Network, Timestamp, Transaction, TxOutput};
use synor_types::block::BlockBody;
use synor_types::transaction::ScriptPubKey;
use synor_types::{
Address, Amount, Block, BlockHeader, BlockId, BlueScore, Hash256, Network, Timestamp,
Transaction, TxOutput,
};
use crate::config::NodeConfig;
use crate::services::{ConsensusService, MempoolService};
@ -473,10 +476,7 @@ impl MinerService {
extra_data.extend_from_slice(&result.nonce.to_le_bytes());
extra_data.extend_from_slice(&template.coinbase_data.extra_data);
let coinbase_tx = Transaction::coinbase(
vec![coinbase_output],
extra_data,
);
let coinbase_tx = Transaction::coinbase(vec![coinbase_output], extra_data);
// Start with coinbase transaction
let mut transactions = vec![coinbase_tx];
@ -522,8 +522,7 @@ impl MinerService {
let block = Block { header, body };
// Serialize with Borsh
borsh::to_vec(&block)
.map_err(|e| anyhow::anyhow!("Failed to serialize block: {}", e))
borsh::to_vec(&block).map_err(|e| anyhow::anyhow!("Failed to serialize block: {}", e))
}
/// Submits a mined block (for external submission via RPC).

View file

@ -234,7 +234,10 @@ mod network_partition_tests {
// Node 0 should have fewer peers after isolation
let isolated_peers = network.nodes[0].network().peer_count().await;
info!(isolated_peers = isolated_peers, "Node 0 peers after isolation");
info!(
isolated_peers = isolated_peers,
"Node 0 peers after isolation"
);
assert!(
isolated_peers < initial_peer_counts[0] || initial_peer_counts[0] == 0,
@ -271,7 +274,10 @@ mod network_partition_tests {
// After healing, nodes should have peers
let total_peers = network.total_peer_count().await;
info!(total_peers = total_peers, "Total peers after partition recovery");
info!(
total_peers = total_peers,
"Total peers after partition recovery"
);
// Consensus state should converge
let consensus0 = network.nodes[0].consensus();
@ -287,7 +293,10 @@ mod network_partition_tests {
);
// Both should have some consensus state
assert!(vsp0.is_some() || vsp1.is_some(), "At least one node should have VSP");
assert!(
vsp0.is_some() || vsp1.is_some(),
"At least one node should have VSP"
);
network.stop_all().await.unwrap();
}
@ -351,10 +360,12 @@ mod network_partition_tests {
// Record blue scores from each partition
let scores_before: Vec<u64> = futures::future::join_all(
network.nodes.iter().map(|n| async {
n.consensus().current_blue_score().await
})
).await;
network
.nodes
.iter()
.map(|n| async { n.consensus().current_blue_score().await }),
)
.await;
info!(scores_before = ?scores_before, "Blue scores before healing");
@ -368,10 +379,12 @@ mod network_partition_tests {
// Blue scores should converge
let scores_after: Vec<u64> = futures::future::join_all(
network.nodes.iter().map(|n| async {
n.consensus().current_blue_score().await
})
).await;
network
.nodes
.iter()
.map(|n| async { n.consensus().current_blue_score().await }),
)
.await;
info!(scores_after = ?scores_after, "Blue scores after healing");
@ -380,7 +393,9 @@ mod network_partition_tests {
assert!(
after >= before,
"Node {} blue score should not decrease: {} -> {}",
i, before, after
i,
before,
after
);
}
@ -407,10 +422,7 @@ mod double_spend_tests {
let mempool = network.nodes[0].mempool();
let initial_size = mempool.size().await;
info!(
initial_mempool_size = initial_size,
"Initial mempool state"
);
info!(initial_mempool_size = initial_size, "Initial mempool state");
// In production, we would:
// 1. Create two transactions spending the same UTXO
@ -420,7 +432,7 @@ mod double_spend_tests {
// For now, verify mempool API is working
// and handles empty/invalid data gracefully
let _invalid_tx = vec![0u8; 50]; // Invalid transaction bytes (for future use)
// Submitting invalid tx should fail gracefully
// Submitting invalid tx should fail gracefully
// Mempool should maintain integrity
let final_size = mempool.size().await;
@ -653,7 +665,11 @@ mod invalid_block_rejection_tests {
// All valid tips should have known parents in the DAG
for tip in &tips {
let has_parents = consensus.get_block_info(tip).await.map(|info| !info.parents.is_empty()).unwrap_or(false);
let has_parents = consensus
.get_block_info(tip)
.await
.map(|info| !info.parents.is_empty())
.unwrap_or(false);
info!(
block = hex::encode(&tip[..8]),
has_parents = has_parents,
@ -687,16 +703,22 @@ mod sybil_attack_tests {
// Track blue scores - honest nodes should maintain correct view
let honest_scores: Vec<u64> = futures::future::join_all(
network.nodes.iter().take(3).map(|n| async {
n.consensus().current_blue_score().await
})
).await;
network
.nodes
.iter()
.take(3)
.map(|n| async { n.consensus().current_blue_score().await }),
)
.await;
let sybil_scores: Vec<u64> = futures::future::join_all(
network.nodes.iter().skip(3).map(|n| async {
n.consensus().current_blue_score().await
})
).await;
network
.nodes
.iter()
.skip(3)
.map(|n| async { n.consensus().current_blue_score().await }),
)
.await;
info!(
honest_scores = ?honest_scores,
@ -805,7 +827,10 @@ mod eclipse_attack_tests {
sleep(Duration::from_secs(1)).await;
let after_eclipse_peers = victim_network.peer_count().await;
info!(after_eclipse_peers = after_eclipse_peers, "Peers after eclipse attempt");
info!(
after_eclipse_peers = after_eclipse_peers,
"Peers after eclipse attempt"
);
// In a real implementation, the node would:
// 1. Detect low peer diversity
@ -863,7 +888,10 @@ mod eclipse_attack_tests {
sleep(Duration::from_secs(1)).await;
let eclipsed_peers = network.nodes[0].network().peer_count().await;
info!(eclipsed_peers = eclipsed_peers, "Node 0 peers during eclipse");
info!(
eclipsed_peers = eclipsed_peers,
"Node 0 peers during eclipse"
);
// Manually reconnect (simulating recovery mechanism)
network.connect_nodes(0, 1).await.unwrap();
@ -871,7 +899,10 @@ mod eclipse_attack_tests {
sleep(Duration::from_secs(2)).await;
let recovered_peers = network.nodes[0].network().peer_count().await;
info!(recovered_peers = recovered_peers, "Node 0 peers after recovery");
info!(
recovered_peers = recovered_peers,
"Node 0 peers after recovery"
);
// Should have reconnected
assert!(
@ -1038,10 +1069,12 @@ mod dag_reorg_tests {
// Record divergent states
let states_before: Vec<u64> = futures::future::join_all(
network.nodes.iter().map(|n| async {
n.consensus().current_blue_score().await
})
).await;
network
.nodes
.iter()
.map(|n| async { n.consensus().current_blue_score().await }),
)
.await;
info!(states_before = ?states_before, "States before reconnection");
@ -1052,10 +1085,12 @@ mod dag_reorg_tests {
// Get converged states
let states_after: Vec<u64> = futures::future::join_all(
network.nodes.iter().map(|n| async {
n.consensus().current_blue_score().await
})
).await;
network
.nodes
.iter()
.map(|n| async { n.consensus().current_blue_score().await }),
)
.await;
info!(states_after = ?states_after, "States after reconnection");
@ -1064,7 +1099,9 @@ mod dag_reorg_tests {
assert!(
after >= before,
"Node {} blue score regression: {} -> {}",
i, before, after
i,
before,
after
);
}
@ -1192,10 +1229,12 @@ mod parallel_blocks_tests {
// Collect blue scores from all nodes
let blue_scores: Vec<u64> = futures::future::join_all(
network.nodes.iter().map(|n| async {
n.consensus().current_blue_score().await
})
).await;
network
.nodes
.iter()
.map(|n| async { n.consensus().current_blue_score().await }),
)
.await;
info!(blue_scores = ?blue_scores, "Blue scores across nodes");
@ -1206,7 +1245,8 @@ mod parallel_blocks_tests {
assert!(
max_score - min_score <= 2,
"Blue scores should be consistent: {} - {} > 2",
max_score, min_score
max_score,
min_score
);
network.stop_all().await.unwrap();
@ -1264,10 +1304,12 @@ mod parallel_blocks_tests {
// Get selected chains from all nodes
let chains: Vec<Vec<[u8; 32]>> = futures::future::join_all(
network.nodes.iter().map(|n| async {
n.consensus().get_selected_chain(10).await
})
).await;
network
.nodes
.iter()
.map(|n| async { n.consensus().get_selected_chain(10).await }),
)
.await;
info!(
chain_lengths = ?chains.iter().map(|c| c.len()).collect::<Vec<_>>(),
@ -1276,7 +1318,8 @@ mod parallel_blocks_tests {
// All nodes should have the same selected chain (after sync)
// Check that genesis (first block) matches
let genesis_blocks: Vec<_> = chains.iter()
let genesis_blocks: Vec<_> = chains
.iter()
.filter(|c| !c.is_empty())
.map(|c| c[0])
.collect();
@ -1353,10 +1396,13 @@ mod bft_threshold_tests {
// Honest nodes (0, 1, 2) should maintain consensus
let honest_scores: Vec<u64> = futures::future::join_all(
network.nodes.iter().take(3).map(|n| async {
n.consensus().current_blue_score().await
})
).await;
network
.nodes
.iter()
.take(3)
.map(|n| async { n.consensus().current_blue_score().await }),
)
.await;
info!(honest_scores = ?honest_scores, "Honest node blue scores");
@ -1399,10 +1445,7 @@ mod bft_threshold_tests {
// Blue score should not decrease
let final_blue = network.nodes[0].consensus().current_blue_score().await;
assert!(
final_blue >= initial_blue,
"Blue score should not decrease"
);
assert!(final_blue >= initial_blue, "Blue score should not decrease");
// Stop remaining nodes
for node in network.nodes.iter().take(3) {
@ -1615,10 +1658,12 @@ mod integration_tests {
// Record initial state
let initial_scores: Vec<u64> = futures::future::join_all(
network.nodes.iter().map(|n| async {
n.consensus().current_blue_score().await
})
).await;
network
.nodes
.iter()
.map(|n| async { n.consensus().current_blue_score().await }),
)
.await;
info!(initial_scores = ?initial_scores, "Initial blue scores");
info!("Phase 2: Simulate 2 Byzantine nodes (partition)");
@ -1640,18 +1685,24 @@ mod integration_tests {
info!("Phase 4: Verify convergence");
let final_scores: Vec<u64> = futures::future::join_all(
network.nodes.iter().map(|n| async {
n.consensus().current_blue_score().await
})
).await;
network
.nodes
.iter()
.map(|n| async { n.consensus().current_blue_score().await }),
)
.await;
info!(final_scores = ?final_scores, "Final blue scores");
// All nodes should have non-decreasing blue scores
for (i, (&initial, &final_score)) in initial_scores.iter().zip(final_scores.iter()).enumerate() {
for (i, (&initial, &final_score)) in
initial_scores.iter().zip(final_scores.iter()).enumerate()
{
assert!(
final_score >= initial,
"Node {} score regression: {} -> {}",
i, initial, final_score
i,
initial,
final_score
);
}

View file

@ -17,8 +17,8 @@
//! 3. Vault contract verifies proof and unlocks original tokens
use crate::{
AssetId, Bridge, BridgeAddress, BridgeError, BridgeResult, BridgeTransfer, ChainType, TransferId, TransferManager, TransferStatus, VaultManager,
ETH_MIN_CONFIRMATIONS,
AssetId, Bridge, BridgeAddress, BridgeError, BridgeResult, BridgeTransfer, ChainType,
TransferId, TransferManager, TransferStatus, VaultManager, ETH_MIN_CONFIRMATIONS,
};
use alloy_primitives::{Address, B256, U256};
use alloy_sol_types::sol;
@ -281,9 +281,9 @@ impl EthereumBridge {
// Check for replay
let event_hash = event.hash();
if self.processed_events.read().contains_key(&event_hash) {
return Err(BridgeError::TransferAlreadyExists(
hex::encode(event_hash.as_slice()),
));
return Err(BridgeError::TransferAlreadyExists(hex::encode(
event_hash.as_slice(),
)));
}
// Verify token is supported
@ -393,18 +393,15 @@ impl EthereumBridge {
// Collect matching transfer IDs first
let matching_transfer_id = {
let transfers = self.transfers.read();
transfers
.pending_transfers()
.iter()
.find_map(|transfer| {
transfer.source_tx_hash.as_ref().and_then(|tx_hash| {
if tx_hash.as_slice() == event_hash.as_slice() {
Some(transfer.id.clone())
} else {
None
}
})
transfers.pending_transfers().iter().find_map(|transfer| {
transfer.source_tx_hash.as_ref().and_then(|tx_hash| {
if tx_hash.as_slice() == event_hash.as_slice() {
Some(transfer.id.clone())
} else {
None
}
})
})
};
// Now update the transfer if found
@ -457,7 +454,9 @@ impl EthereumBridge {
.map_err(|e| BridgeError::InvalidAddress(e.to_string()))?;
if bytes.len() != 20 {
return Err(BridgeError::InvalidAddress("invalid address length".to_string()));
return Err(BridgeError::InvalidAddress(
"invalid address length".to_string(),
));
}
Address::from_slice(&bytes)
};
@ -801,7 +800,10 @@ mod tests {
wrapped.mint(1000);
let result = wrapped.burn(1500);
assert!(matches!(result, Err(BridgeError::InsufficientBalance { .. })));
assert!(matches!(
result,
Err(BridgeError::InsufficientBalance { .. })
));
}
#[test]
@ -896,7 +898,9 @@ mod tests {
let current_time = 1700000000;
let event = create_lock_event(0);
bridge.process_lock_event(event.clone(), current_time).unwrap();
bridge
.process_lock_event(event.clone(), current_time)
.unwrap();
let result = bridge.process_lock_event(event, current_time + 100);
assert!(matches!(result, Err(BridgeError::TransferAlreadyExists(_))));
@ -949,8 +953,12 @@ mod tests {
let event_hash = B256::from([0x11; 32]);
let unauthorized_relayer = Address::from([0x99; 20]);
let result = bridge.submit_relayer_signature(event_hash, unauthorized_relayer, vec![0x00; 65]);
assert!(matches!(result, Err(BridgeError::SignatureVerificationFailed(_))));
let result =
bridge.submit_relayer_signature(event_hash, unauthorized_relayer, vec![0x00; 65]);
assert!(matches!(
result,
Err(BridgeError::SignatureVerificationFailed(_))
));
}
#[test]
@ -964,7 +972,9 @@ mod tests {
});
let event_hash = B256::from([0x11; 32]);
let result = bridge.submit_relayer_signature(event_hash, relayer, vec![0x00; 65]).unwrap();
let result = bridge
.submit_relayer_signature(event_hash, relayer, vec![0x00; 65])
.unwrap();
assert!(result);
}
@ -983,7 +993,9 @@ mod tests {
.update_confirmations(&transfer_id, 12, current_time + 100)
.unwrap();
bridge.mint_wrapped_tokens(&transfer_id, current_time + 200).unwrap();
bridge
.mint_wrapped_tokens(&transfer_id, current_time + 200)
.unwrap();
let wrapped = bridge.get_wrapped_token(Address::ZERO).unwrap();
assert_eq!(wrapped.total_supply, 1000);
@ -1022,13 +1034,7 @@ mod tests {
let asset = AssetId::wrapped(&AssetId::eth());
let transfer_id = bridge
.initiate_burn(
asset,
1000,
test_recipient(),
test_sender(),
current_time,
)
.initiate_burn(asset, 1000, test_recipient(), test_sender(), current_time)
.unwrap();
let transfers = bridge.transfers.read();
@ -1053,15 +1059,13 @@ mod tests {
drop(wrapped_tokens);
let asset = AssetId::wrapped(&AssetId::eth());
let result = bridge.initiate_burn(
asset,
1000,
test_recipient(),
test_sender(),
current_time,
);
let result =
bridge.initiate_burn(asset, 1000, test_recipient(), test_sender(), current_time);
assert!(matches!(result, Err(BridgeError::InsufficientBalance { .. })));
assert!(matches!(
result,
Err(BridgeError::InsufficientBalance { .. })
));
}
#[test]
@ -1069,13 +1073,7 @@ mod tests {
let bridge = EthereumBridge::new(EthereumBridgeConfig::default());
let asset = AssetId::wrapped(&AssetId::eth());
let result = bridge.initiate_burn(
asset,
1000,
test_recipient(),
test_sender(),
0,
);
let result = bridge.initiate_burn(asset, 1000, test_recipient(), test_sender(), 0);
assert!(matches!(result, Err(BridgeError::AssetNotSupported(_))));
}

View file

@ -79,7 +79,9 @@ pub const ETH_MIN_CONFIRMATIONS: u64 = 12;
pub const BTC_MIN_CONFIRMATIONS: u64 = 6;
/// Bridge chain identifier
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize)]
#[derive(
Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
)]
pub enum ChainType {
/// Synor native chain
Synor,
@ -128,7 +130,9 @@ impl fmt::Display for ChainType {
}
/// Asset identifier across chains
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize)]
#[derive(
Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
)]
pub struct AssetId {
/// Chain where the asset originates
pub chain: ChainType,
@ -199,7 +203,9 @@ impl fmt::Display for AssetId {
}
/// Bridge address (unified format for cross-chain addresses)
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize)]
#[derive(
Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
)]
pub struct BridgeAddress {
/// Chain type
pub chain: ChainType,

View file

@ -14,7 +14,9 @@ use std::collections::HashMap;
use std::fmt;
/// Unique transfer identifier
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize)]
#[derive(
Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
)]
pub struct TransferId(pub String);
impl TransferId {
@ -48,7 +50,9 @@ impl fmt::Display for TransferId {
}
/// Transfer direction
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, BorshSerialize, BorshDeserialize)]
#[derive(
Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
)]
pub enum TransferDirection {
/// From external chain to Synor (Lock → Mint)
Inbound,
@ -66,7 +70,9 @@ impl fmt::Display for TransferDirection {
}
/// Transfer status
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, BorshSerialize, BorshDeserialize)]
#[derive(
Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
)]
pub enum TransferStatus {
/// Transfer initiated, awaiting lock confirmation
Pending,
@ -1030,7 +1036,10 @@ mod tests {
transfer.fail("Proof verification failed", current_time + 50);
assert_eq!(transfer.status, TransferStatus::Failed);
assert_eq!(transfer.error, Some("Proof verification failed".to_string()));
assert_eq!(
transfer.error,
Some("Proof verification failed".to_string())
);
}
#[test]
@ -1331,8 +1340,12 @@ mod tests {
assert_eq!(manager.pending_transfers().len(), 1);
manager.confirm_lock(&id, vec![0x11; 32], 100, current_time + 10).unwrap();
manager.update_confirmations(&id, 12, current_time + 120).unwrap();
manager
.confirm_lock(&id, vec![0x11; 32], 100, current_time + 10)
.unwrap();
manager
.update_confirmations(&id, 12, current_time + 120)
.unwrap();
assert_eq!(manager.pending_transfers().len(), 0);
}
@ -1356,8 +1369,12 @@ mod tests {
assert_eq!(manager.ready_for_confirmation().len(), 0);
manager.confirm_lock(&id, vec![0x11; 32], 100, current_time + 10).unwrap();
manager.update_confirmations(&id, 12, current_time + 120).unwrap();
manager
.confirm_lock(&id, vec![0x11; 32], 100, current_time + 10)
.unwrap();
manager
.update_confirmations(&id, 12, current_time + 120)
.unwrap();
assert_eq!(manager.ready_for_confirmation().len(), 1);
}
@ -1410,7 +1427,9 @@ mod tests {
)
.unwrap();
manager.fail_transfer(&id, "Verification failed", current_time + 50).unwrap();
manager
.fail_transfer(&id, "Verification failed", current_time + 50)
.unwrap();
let transfer = manager.get(&id).unwrap();
assert_eq!(transfer.status, TransferStatus::Failed);
@ -1452,9 +1471,15 @@ mod tests {
)
.unwrap();
manager.confirm_lock(&id1, vec![0x11; 32], 100, current_time).unwrap();
manager.update_confirmations(&id1, 12, current_time).unwrap();
manager.confirm_mint(&id1, vec![0x22; 32], current_time).unwrap();
manager
.confirm_lock(&id1, vec![0x11; 32], 100, current_time)
.unwrap();
manager
.update_confirmations(&id1, 12, current_time)
.unwrap();
manager
.confirm_mint(&id1, vec![0x22; 32], current_time)
.unwrap();
let stats = manager.stats();
assert_eq!(stats.total_count, 2);
@ -1484,19 +1509,27 @@ mod tests {
let transfer = manager.get(&id).unwrap();
assert_eq!(transfer.status, TransferStatus::Pending);
manager.confirm_lock(&id, vec![0x11; 32], 100, current_time + 60).unwrap();
manager
.confirm_lock(&id, vec![0x11; 32], 100, current_time + 60)
.unwrap();
let transfer = manager.get(&id).unwrap();
assert_eq!(transfer.status, TransferStatus::Locked);
manager.update_confirmations(&id, 6, current_time + 120).unwrap();
manager
.update_confirmations(&id, 6, current_time + 120)
.unwrap();
let transfer = manager.get(&id).unwrap();
assert_eq!(transfer.status, TransferStatus::Locked);
manager.update_confirmations(&id, 12, current_time + 180).unwrap();
manager
.update_confirmations(&id, 12, current_time + 180)
.unwrap();
let transfer = manager.get(&id).unwrap();
assert_eq!(transfer.status, TransferStatus::Confirmed);
manager.confirm_mint(&id, vec![0x22; 32], current_time + 240).unwrap();
manager
.confirm_mint(&id, vec![0x22; 32], current_time + 240)
.unwrap();
let transfer = manager.get(&id).unwrap();
assert_eq!(transfer.status, TransferStatus::Completed);
}
@ -1521,15 +1554,21 @@ mod tests {
let transfer = manager.get(&id).unwrap();
assert_eq!(transfer.status, TransferStatus::Pending);
manager.confirm_lock(&id, vec![0x11; 32], 100, current_time + 60).unwrap();
manager
.confirm_lock(&id, vec![0x11; 32], 100, current_time + 60)
.unwrap();
let transfer = manager.get(&id).unwrap();
assert_eq!(transfer.status, TransferStatus::Locked);
manager.update_confirmations(&id, 6, current_time + 120).unwrap();
manager
.update_confirmations(&id, 6, current_time + 120)
.unwrap();
let transfer = manager.get(&id).unwrap();
assert_eq!(transfer.status, TransferStatus::Confirmed);
manager.confirm_unlock(&id, vec![0x33; 32], current_time + 180).unwrap();
manager
.confirm_unlock(&id, vec![0x33; 32], current_time + 180)
.unwrap();
let transfer = manager.get(&id).unwrap();
assert_eq!(transfer.status, TransferStatus::Completed);
}

View file

@ -12,7 +12,9 @@ use std::collections::HashMap;
use std::fmt;
/// Unique vault identifier
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize)]
#[derive(
Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
)]
pub struct VaultId(pub String);
impl VaultId {
@ -198,13 +200,7 @@ impl Vault {
return Err(BridgeError::TransferAlreadyExists(lock_id));
}
let locked = LockedAsset::new(
self.asset.clone(),
amount,
owner,
recipient,
current_time,
);
let locked = LockedAsset::new(self.asset.clone(), amount, owner, recipient, current_time);
self.locked_assets.insert(lock_id, locked);
self.total_locked += amount;
@ -283,7 +279,10 @@ impl Vault {
}
/// Get expired locked assets
pub fn expired_locked(&self, current_time: u64) -> impl Iterator<Item = (&String, &LockedAsset)> {
pub fn expired_locked(
&self,
current_time: u64,
) -> impl Iterator<Item = (&String, &LockedAsset)> {
self.locked_assets
.iter()
.filter(move |(_, l)| !l.released && l.is_expired(current_time))
@ -511,14 +510,8 @@ mod tests {
#[test]
fn test_locked_asset_expiry() {
let locked = LockedAsset::new(
AssetId::eth(),
1000,
test_owner(),
test_recipient(),
1000,
)
.with_expiry(2000);
let locked = LockedAsset::new(AssetId::eth(), 1000, test_owner(), test_recipient(), 1000)
.with_expiry(2000);
assert!(!locked.is_expired(1500));
assert!(locked.is_expired(2000));
@ -624,11 +617,7 @@ mod tests {
#[test]
fn test_lock_unlock() {
let mut vault = Vault::new(
VaultId::new("test"),
ChainType::Ethereum,
AssetId::eth(),
);
let mut vault = Vault::new(VaultId::new("test"), ChainType::Ethereum, AssetId::eth());
let current_time = 1700000000;
@ -654,20 +643,28 @@ mod tests {
);
let current_time = 1700000000;
vault.lock("lock-1", 1000, test_owner(), test_recipient(), current_time).unwrap();
vault.lock("lock-2", 2000, test_owner(), test_recipient(), current_time).unwrap();
vault.lock("lock-3", 500, test_owner_alt(), test_recipient(), current_time).unwrap();
vault
.lock("lock-1", 1000, test_owner(), test_recipient(), current_time)
.unwrap();
vault
.lock("lock-2", 2000, test_owner(), test_recipient(), current_time)
.unwrap();
vault
.lock(
"lock-3",
500,
test_owner_alt(),
test_recipient(),
current_time,
)
.unwrap();
assert_eq!(vault.total_locked, 3500);
}
#[test]
fn test_duplicate_lock() {
let mut vault = Vault::new(
VaultId::new("test"),
ChainType::Ethereum,
AssetId::eth(),
);
let mut vault = Vault::new(VaultId::new("test"), ChainType::Ethereum, AssetId::eth());
vault
.lock("lock1", 1000, test_owner(), test_recipient(), 0)
@ -697,20 +694,21 @@ mod tests {
AssetId::eth(),
);
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap();
vault
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
.unwrap();
vault.unlock("lock-1").unwrap();
let result = vault.unlock("lock-1");
assert!(matches!(result, Err(BridgeError::TransferAlreadyCompleted(_))));
assert!(matches!(
result,
Err(BridgeError::TransferAlreadyCompleted(_))
));
}
#[test]
fn test_vault_pause() {
let mut vault = Vault::new(
VaultId::new("test"),
ChainType::Ethereum,
AssetId::eth(),
);
let mut vault = Vault::new(VaultId::new("test"), ChainType::Ethereum, AssetId::eth());
vault.pause();
@ -730,7 +728,9 @@ mod tests {
vault.resume();
assert_eq!(vault.state, VaultState::Active);
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap();
vault
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
.unwrap();
}
#[test]
@ -750,12 +750,8 @@ mod tests {
#[test]
fn test_daily_limit() {
let mut vault = Vault::new(
VaultId::new("test"),
ChainType::Ethereum,
AssetId::eth(),
)
.with_daily_limit(1000);
let mut vault = Vault::new(VaultId::new("test"), ChainType::Ethereum, AssetId::eth())
.with_daily_limit(1000);
let current_time = 86400 * 100;
@ -781,8 +777,24 @@ mod tests {
);
let current_time = 0;
vault.lock("lock-1", 1000000000, test_owner(), test_recipient(), current_time).unwrap();
vault.lock("lock-2", 1000000000, test_owner(), test_recipient(), current_time).unwrap();
vault
.lock(
"lock-1",
1000000000,
test_owner(),
test_recipient(),
current_time,
)
.unwrap();
vault
.lock(
"lock-2",
1000000000,
test_owner(),
test_recipient(),
current_time,
)
.unwrap();
assert_eq!(vault.total_locked, 2000000000);
}
@ -795,7 +807,9 @@ mod tests {
AssetId::eth(),
);
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap();
vault
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
.unwrap();
assert!(vault.get_locked("lock-1").is_some());
assert!(vault.get_locked("nonexistent").is_none());
@ -809,8 +823,12 @@ mod tests {
AssetId::eth(),
);
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap();
vault.lock("lock-2", 2000, test_owner(), test_recipient(), 0).unwrap();
vault
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
.unwrap();
vault
.lock("lock-2", 2000, test_owner(), test_recipient(), 0)
.unwrap();
let all: Vec<_> = vault.all_locked().collect();
assert_eq!(all.len(), 2);
@ -824,8 +842,12 @@ mod tests {
AssetId::eth(),
);
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap();
vault.lock("lock-2", 2000, test_owner(), test_recipient(), 0).unwrap();
vault
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
.unwrap();
vault
.lock("lock-2", 2000, test_owner(), test_recipient(), 0)
.unwrap();
vault.unlock("lock-1").unwrap();
let active: Vec<_> = vault.active_locked().collect();
@ -858,7 +880,9 @@ mod tests {
assert!(manager.find_vault(&ChainType::Ethereum, &eth).is_some());
let vault = manager.get_or_create_vault(ChainType::Ethereum, eth.clone());
vault.lock("lock1", 100, test_owner(), test_recipient(), 0).unwrap();
vault
.lock("lock1", 100, test_owner(), test_recipient(), 0)
.unwrap();
assert_eq!(manager.total_locked(), 100);
}
@ -881,7 +905,9 @@ mod tests {
{
let vault = manager.get_vault_mut(&vault_id).unwrap();
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap();
vault
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
.unwrap();
}
let vault = manager.get_vault(&vault_id).unwrap();
@ -902,7 +928,9 @@ mod tests {
manager.create_vault(ChainType::Ethereum, eth.clone());
let vault = manager.find_vault_mut(&ChainType::Ethereum, &eth).unwrap();
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap();
vault
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
.unwrap();
assert_eq!(manager.total_locked(), 1000);
}
@ -913,7 +941,9 @@ mod tests {
let eth = AssetId::eth();
let vault = manager.get_or_create_vault(ChainType::Ethereum, eth.clone());
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap();
vault
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
.unwrap();
assert_eq!(manager.vault_ids().len(), 1);
assert_eq!(manager.total_locked(), 1000);

View file

@ -241,7 +241,10 @@ impl DeviceRegistry {
}
/// Gets a processor by ID.
pub fn get_processor(&self, processor_id: ProcessorId) -> Result<Arc<dyn Processor>, ComputeError> {
pub fn get_processor(
&self,
processor_id: ProcessorId,
) -> Result<Arc<dyn Processor>, ComputeError> {
self.processors
.read()
.get(&processor_id)
@ -266,7 +269,10 @@ impl DeviceRegistry {
/// Gets the next processor ID.
pub fn next_processor_id(&self) -> ProcessorId {
ProcessorId(self.next_processor_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst))
ProcessorId(
self.next_processor_id
.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
)
}
/// Gets total number of devices.
@ -309,7 +315,10 @@ impl DeviceRegistry {
device.status = status;
Ok(())
} else {
Err(ComputeError::Internal(format!("Device not found: {}", device_id)))
Err(ComputeError::Internal(format!(
"Device not found: {}",
device_id
)))
}
}
}
@ -323,7 +332,7 @@ impl Default for DeviceRegistry {
#[cfg(test)]
mod tests {
use super::*;
use crate::processor::{CpuVariant, AvxSupport};
use crate::processor::{AvxSupport, CpuVariant};
#[test]
fn test_device_id() {

View file

@ -67,6 +67,10 @@ pub use market::{
ResourceType, SpotMarket, Trade,
};
pub use memory::{MemoryManager, TensorHandle, TransferPath, UnifiedMemory};
pub use model::{
ModelCategory, ModelFormat, ModelId, ModelInfo, ModelRegistry, ModelUploadRequest,
ModelUploadResponse,
};
pub use processor::{
ComputeThroughput, CpuVariant, GpuVariant, NpuVariant, Operation, OperationType, Processor,
ProcessorCapabilities, ProcessorId, ProcessorType, TpuVersion,
@ -78,10 +82,6 @@ pub use task::{
ComputeTask, DecomposedWorkload, Task, TaskDecomposer, TaskId, TaskPriority, TaskResult,
TaskStatus,
};
pub use model::{
ModelCategory, ModelFormat, ModelId, ModelInfo, ModelRegistry, ModelUploadRequest,
ModelUploadResponse,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
@ -434,7 +434,10 @@ impl ComputeCluster {
let jobs = self.jobs.read();
let total_nodes = nodes.len();
let online_nodes = nodes.values().filter(|n| n.status == NodeStatus::Online).count();
let online_nodes = nodes
.values()
.filter(|n| n.status == NodeStatus::Online)
.count();
let total_gpus: usize = nodes
.values()
@ -515,16 +518,16 @@ pub enum GpuTier {
impl Default for ComputePricing {
fn default() -> Self {
let mut gpu_hourly = HashMap::new();
gpu_hourly.insert(GpuTier::Consumer, 100_000_000); // 0.10 SYNOR
gpu_hourly.insert(GpuTier::Professional, 300_000_000); // 0.30 SYNOR
gpu_hourly.insert(GpuTier::DataCenter, 2_000_000_000); // 2.00 SYNOR
gpu_hourly.insert(GpuTier::Premium, 4_000_000_000); // 4.00 SYNOR
gpu_hourly.insert(GpuTier::Consumer, 100_000_000); // 0.10 SYNOR
gpu_hourly.insert(GpuTier::Professional, 300_000_000); // 0.30 SYNOR
gpu_hourly.insert(GpuTier::DataCenter, 2_000_000_000); // 2.00 SYNOR
gpu_hourly.insert(GpuTier::Premium, 4_000_000_000); // 4.00 SYNOR
Self {
gpu_hourly,
cpu_core_hour: 20_000_000, // 0.02 SYNOR
memory_gb_hour: 5_000_000, // 0.005 SYNOR
network_egress_gb: 50_000_000, // 0.05 SYNOR
cpu_core_hour: 20_000_000, // 0.02 SYNOR
memory_gb_hour: 5_000_000, // 0.005 SYNOR
network_egress_gb: 50_000_000, // 0.05 SYNOR
inference_per_million_tokens: 100_000_000, // 0.10 SYNOR
}
}

View file

@ -686,24 +686,24 @@ impl PricingEngine {
pub fn greenest_region(&self) -> &str {
self.regions
.iter()
.max_by(|a, b| {
a.renewable_pct
.partial_cmp(&b.renewable_pct)
.unwrap()
})
.max_by(|a, b| a.renewable_pct.partial_cmp(&b.renewable_pct).unwrap())
.map(|r| r.region.as_str())
.unwrap_or("eu-north")
}
/// Compares price to cloud providers.
pub fn compare_to_cloud(&self, resource: &ResourceType, region: Option<&str>) -> CloudComparison {
pub fn compare_to_cloud(
&self,
resource: &ResourceType,
region: Option<&str>,
) -> CloudComparison {
let our_price = self.spot_price(resource, region);
// Approximate cloud provider prices (USD/hour for GPU)
let (aws_price, gcp_price, azure_price) = match resource {
ResourceType::GpuHours(GpuTier::DataCenter) => (3.06, 2.95, 3.10), // A100 equivalents
ResourceType::GpuHours(GpuTier::Ultra) => (5.00, 4.50, 5.20), // H100 equivalents
ResourceType::GpuHours(GpuTier::High) => (1.50, 1.40, 1.60), // T4/A10 equivalents
ResourceType::GpuHours(GpuTier::Ultra) => (5.00, 4.50, 5.20), // H100 equivalents
ResourceType::GpuHours(GpuTier::High) => (1.50, 1.40, 1.60), // T4/A10 equivalents
ResourceType::CpuHours(CpuTier::Server) => (0.40, 0.35, 0.42),
_ => (1.0, 1.0, 1.0),
};
@ -888,9 +888,18 @@ impl SpotMarket {
);
}
order_books.insert(ResourceType::TpuHours, OrderBook::new(ResourceType::TpuHours));
order_books.insert(ResourceType::NpuHours, OrderBook::new(ResourceType::NpuHours));
order_books.insert(ResourceType::LpuCredits, OrderBook::new(ResourceType::LpuCredits));
order_books.insert(
ResourceType::TpuHours,
OrderBook::new(ResourceType::TpuHours),
);
order_books.insert(
ResourceType::NpuHours,
OrderBook::new(ResourceType::NpuHours),
);
order_books.insert(
ResourceType::LpuCredits,
OrderBook::new(ResourceType::LpuCredits),
);
Self {
order_books,
@ -1074,12 +1083,21 @@ mod tests {
fn test_pricing_engine() {
let engine = PricingEngine::new();
let price = engine.spot_price(&ResourceType::GpuHours(GpuTier::DataCenter), Some("eu-north"));
let price = engine.spot_price(
&ResourceType::GpuHours(GpuTier::DataCenter),
Some("eu-north"),
);
assert!(price > 0.0);
// eu-north should be cheaper (low electricity cost)
let eu_price = engine.spot_price(&ResourceType::GpuHours(GpuTier::DataCenter), Some("eu-north"));
let eu_west_price = engine.spot_price(&ResourceType::GpuHours(GpuTier::DataCenter), Some("eu-west"));
let eu_price = engine.spot_price(
&ResourceType::GpuHours(GpuTier::DataCenter),
Some("eu-north"),
);
let eu_west_price = engine.spot_price(
&ResourceType::GpuHours(GpuTier::DataCenter),
Some("eu-west"),
);
// eu-north has cheaper electricity
assert!(eu_price < eu_west_price);
@ -1089,7 +1107,8 @@ mod tests {
fn test_cloud_comparison() {
let engine = PricingEngine::new();
let comparison = engine.compare_to_cloud(&ResourceType::GpuHours(GpuTier::DataCenter), None);
let comparison =
engine.compare_to_cloud(&ResourceType::GpuHours(GpuTier::DataCenter), None);
// Should show significant savings
assert!(comparison.aws_savings > 50.0);

View file

@ -106,11 +106,11 @@ impl TransferPath {
/// Returns approximate bandwidth in GB/s.
pub fn bandwidth_gbps(&self) -> f64 {
match self {
TransferPath::NvLink => 900.0, // NVLink 4.0
TransferPath::NvLink => 900.0, // NVLink 4.0
TransferPath::PciePeerToPeer => 64.0, // PCIe 5.0 x16
TransferPath::CpuMediated => 50.0, // DDR5
TransferPath::CpuMediated => 50.0, // DDR5
TransferPath::UnifiedMemory => 400.0, // Apple unified
TransferPath::Network => 10.0, // 100Gbps network
TransferPath::Network => 10.0, // 100Gbps network
TransferPath::SameMemory => f64::INFINITY,
}
}
@ -154,7 +154,11 @@ impl MemoryManager {
}
/// Allocates a tensor.
pub fn allocate(&self, shape: Vec<usize>, dtype: DataType) -> Result<TensorHandle, ComputeError> {
pub fn allocate(
&self,
shape: Vec<usize>,
dtype: DataType,
) -> Result<TensorHandle, ComputeError> {
let handle = TensorHandle::new(shape, dtype);
self.tensors.write().insert(handle.id, handle.clone());
Ok(handle)
@ -223,9 +227,13 @@ impl MemoryManager {
}
// Check for NVLink between NVIDIA GPUs
if matches!(from, ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda { .. }))
&& matches!(to, ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda { .. }))
{
if matches!(
from,
ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda { .. })
) && matches!(
to,
ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda { .. })
) {
return TransferPath::NvLink;
}
@ -244,10 +252,22 @@ impl MemoryManager {
match (a, b) {
// Apple Silicon unified memory
(ProcessorType::Cpu(CpuVariant::Arm64 { .. }), ProcessorType::Gpu(GpuVariant::AppleMetal))
| (ProcessorType::Gpu(GpuVariant::AppleMetal), ProcessorType::Cpu(CpuVariant::Arm64 { .. }))
| (ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }), ProcessorType::Cpu(CpuVariant::Arm64 { .. }))
| (ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }), ProcessorType::Gpu(GpuVariant::AppleMetal)) => true,
(
ProcessorType::Cpu(CpuVariant::Arm64 { .. }),
ProcessorType::Gpu(GpuVariant::AppleMetal),
)
| (
ProcessorType::Gpu(GpuVariant::AppleMetal),
ProcessorType::Cpu(CpuVariant::Arm64 { .. }),
)
| (
ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }),
ProcessorType::Cpu(CpuVariant::Arm64 { .. }),
)
| (
ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }),
ProcessorType::Gpu(GpuVariant::AppleMetal),
) => true,
// Same type
_ if a == b => true,
_ => false,
@ -325,7 +345,9 @@ mod tests {
#[test]
fn test_transfer_path_bandwidth() {
assert!(TransferPath::NvLink.bandwidth_gbps() > TransferPath::PciePeerToPeer.bandwidth_gbps());
assert!(
TransferPath::NvLink.bandwidth_gbps() > TransferPath::PciePeerToPeer.bandwidth_gbps()
);
assert!(TransferPath::SameMemory.bandwidth_gbps().is_infinite());
}
@ -333,7 +355,9 @@ mod tests {
fn test_memory_manager() {
let manager = MemoryManager::new();
let handle = manager.allocate(vec![1024, 1024], DataType::Float32).unwrap();
let handle = manager
.allocate(vec![1024, 1024], DataType::Float32)
.unwrap();
assert_eq!(manager.tensor_count(), 1);
manager.free(handle.id).unwrap();
@ -347,22 +371,26 @@ mod tests {
let handle = manager.allocate(vec![1024], DataType::Float32).unwrap();
// First ensure should allocate
let path = manager.ensure_on(
handle.id,
ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda {
compute_capability: (8, 0),
}),
).unwrap();
let path = manager
.ensure_on(
handle.id,
ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda {
compute_capability: (8, 0),
}),
)
.unwrap();
assert_eq!(path, TransferPath::SameMemory);
// Second ensure to same location should be same memory
let path = manager.ensure_on(
handle.id,
ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda {
compute_capability: (8, 0),
}),
).unwrap();
let path = manager
.ensure_on(
handle.id,
ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda {
compute_capability: (8, 0),
}),
)
.unwrap();
assert_eq!(path, TransferPath::SameMemory);
}

View file

@ -140,13 +140,7 @@ pub struct ModelInfo {
impl ModelInfo {
/// Creates a new LLM model info.
pub fn llm(
alias: &str,
name: &str,
cid: &str,
parameters: u64,
context_length: u32,
) -> Self {
pub fn llm(alias: &str, name: &str, cid: &str, parameters: u64, context_length: u32) -> Self {
Self {
id: ModelId::from_alias(alias),
name: name.to_string(),
@ -156,7 +150,12 @@ impl ModelInfo {
format: ModelFormat::SafeTensors,
size_bytes: parameters * 2, // ~2 bytes per param in fp16
parameters,
supported_precisions: vec![Precision::Fp16, Precision::Bf16, Precision::Int8, Precision::Int4],
supported_precisions: vec![
Precision::Fp16,
Precision::Bf16,
Precision::Int8,
Precision::Int4,
],
recommended_processor: ProcessorType::Lpu,
context_length: Some(context_length),
input_schema: None,
@ -238,33 +237,123 @@ impl ModelRegistry {
let default_models = vec![
// ===== LLMs =====
// Llama 3 family
ModelInfo::llm("llama-3-8b", "Llama 3 8B", "QmLlama3_8B_placeholder", 8_000_000_000, 8192),
ModelInfo::llm("llama-3-70b", "Llama 3 70B", "QmLlama3_70B_placeholder", 70_000_000_000, 8192),
ModelInfo::llm("llama-3.1-8b", "Llama 3.1 8B", "QmLlama31_8B_placeholder", 8_000_000_000, 128000),
ModelInfo::llm("llama-3.1-70b", "Llama 3.1 70B", "QmLlama31_70B_placeholder", 70_000_000_000, 128000),
ModelInfo::llm("llama-3.1-405b", "Llama 3.1 405B", "QmLlama31_405B_placeholder", 405_000_000_000, 128000),
ModelInfo::llm(
"llama-3-8b",
"Llama 3 8B",
"QmLlama3_8B_placeholder",
8_000_000_000,
8192,
),
ModelInfo::llm(
"llama-3-70b",
"Llama 3 70B",
"QmLlama3_70B_placeholder",
70_000_000_000,
8192,
),
ModelInfo::llm(
"llama-3.1-8b",
"Llama 3.1 8B",
"QmLlama31_8B_placeholder",
8_000_000_000,
128000,
),
ModelInfo::llm(
"llama-3.1-70b",
"Llama 3.1 70B",
"QmLlama31_70B_placeholder",
70_000_000_000,
128000,
),
ModelInfo::llm(
"llama-3.1-405b",
"Llama 3.1 405B",
"QmLlama31_405B_placeholder",
405_000_000_000,
128000,
),
// Mistral family
ModelInfo::llm("mistral-7b", "Mistral 7B", "QmMistral7B_placeholder", 7_000_000_000, 32768),
ModelInfo::llm("mixtral-8x7b", "Mixtral 8x7B", "QmMixtral8x7B_placeholder", 46_000_000_000, 32768),
ModelInfo::llm("mixtral-8x22b", "Mixtral 8x22B", "QmMixtral8x22B_placeholder", 176_000_000_000, 65536),
ModelInfo::llm(
"mistral-7b",
"Mistral 7B",
"QmMistral7B_placeholder",
7_000_000_000,
32768,
),
ModelInfo::llm(
"mixtral-8x7b",
"Mixtral 8x7B",
"QmMixtral8x7B_placeholder",
46_000_000_000,
32768,
),
ModelInfo::llm(
"mixtral-8x22b",
"Mixtral 8x22B",
"QmMixtral8x22B_placeholder",
176_000_000_000,
65536,
),
// Qwen family
ModelInfo::llm("qwen-2.5-7b", "Qwen 2.5 7B", "QmQwen25_7B_placeholder", 7_000_000_000, 128000),
ModelInfo::llm("qwen-2.5-72b", "Qwen 2.5 72B", "QmQwen25_72B_placeholder", 72_000_000_000, 128000),
ModelInfo::llm(
"qwen-2.5-7b",
"Qwen 2.5 7B",
"QmQwen25_7B_placeholder",
7_000_000_000,
128000,
),
ModelInfo::llm(
"qwen-2.5-72b",
"Qwen 2.5 72B",
"QmQwen25_72B_placeholder",
72_000_000_000,
128000,
),
// DeepSeek family
ModelInfo::llm("deepseek-v2", "DeepSeek V2", "QmDeepSeekV2_placeholder", 236_000_000_000, 128000),
ModelInfo::llm("deepseek-coder-33b", "DeepSeek Coder 33B", "QmDeepSeekCoder33B_placeholder", 33_000_000_000, 16384),
ModelInfo::llm(
"deepseek-v2",
"DeepSeek V2",
"QmDeepSeekV2_placeholder",
236_000_000_000,
128000,
),
ModelInfo::llm(
"deepseek-coder-33b",
"DeepSeek Coder 33B",
"QmDeepSeekCoder33B_placeholder",
33_000_000_000,
16384,
),
// Phi family (small/efficient)
ModelInfo::llm("phi-3-mini", "Phi 3 Mini", "QmPhi3Mini_placeholder", 3_800_000_000, 128000),
ModelInfo::llm("phi-3-medium", "Phi 3 Medium", "QmPhi3Medium_placeholder", 14_000_000_000, 128000),
ModelInfo::llm(
"phi-3-mini",
"Phi 3 Mini",
"QmPhi3Mini_placeholder",
3_800_000_000,
128000,
),
ModelInfo::llm(
"phi-3-medium",
"Phi 3 Medium",
"QmPhi3Medium_placeholder",
14_000_000_000,
128000,
),
// Code models
ModelInfo::llm("codellama-34b", "Code Llama 34B", "QmCodeLlama34B_placeholder", 34_000_000_000, 16384),
ModelInfo::llm("starcoder2-15b", "StarCoder2 15B", "QmStarCoder2_15B_placeholder", 15_000_000_000, 16384),
ModelInfo::llm(
"codellama-34b",
"Code Llama 34B",
"QmCodeLlama34B_placeholder",
34_000_000_000,
16384,
),
ModelInfo::llm(
"starcoder2-15b",
"StarCoder2 15B",
"QmStarCoder2_15B_placeholder",
15_000_000_000,
16384,
),
// ===== Embedding Models =====
ModelInfo {
id: ModelId::from_alias("bge-large"),
@ -306,7 +395,6 @@ impl ModelRegistry {
is_public: true,
owner: None,
},
// ===== Vision Models =====
ModelInfo {
id: ModelId::from_alias("stable-diffusion-xl"),
@ -348,7 +436,6 @@ impl ModelRegistry {
is_public: true,
owner: None,
},
// ===== Speech Models =====
ModelInfo {
id: ModelId::from_alias("whisper-large-v3"),
@ -370,7 +457,6 @@ impl ModelRegistry {
is_public: true,
owner: None,
},
// ===== Multi-Modal Models =====
ModelInfo {
id: ModelId::from_alias("llava-1.5-13b"),
@ -555,7 +641,9 @@ mod tests {
let registry = ModelRegistry::new();
let results = registry.search("llama");
assert!(!results.is_empty());
assert!(results.iter().all(|m| m.name.to_lowercase().contains("llama")));
assert!(results
.iter()
.all(|m| m.name.to_lowercase().contains("llama")));
}
#[test]

View file

@ -305,7 +305,7 @@ impl ProcessorCapabilities {
},
memory: MemorySpecs {
capacity_bytes: 230 * 1024 * 1024 * 1024, // 230 GB SRAM!
bandwidth_gbps: 80_000, // 80 TB/s internal
bandwidth_gbps: 80_000, // 80 TB/s internal
type_: MemoryType::Sram,
},
operations: Self::lpu_operations(),
@ -349,8 +349,8 @@ impl ProcessorCapabilities {
/// Creates Apple Neural Engine capabilities.
pub fn apple_neural_engine(cores: u32) -> Self {
let int8_tops = match cores {
16 => 18.0, // M3
32 => 35.0, // M3 Max
16 => 18.0, // M3
32 => 35.0, // M3 Max
_ => cores as f64 * 1.1,
};
@ -542,6 +542,8 @@ mod tests {
fn test_lpu_capabilities() {
let caps = ProcessorCapabilities::lpu();
assert!(caps.memory.bandwidth_gbps > 10000); // Very high internal bandwidth
assert!(caps.optimal_for.contains(&WorkloadCharacteristic::Sequential));
assert!(caps
.optimal_for
.contains(&WorkloadCharacteristic::Sequential));
}
}

View file

@ -253,10 +253,22 @@ impl Processor for GenericProcessor {
fn shares_memory_with(&self, other: &ProcessorType) -> bool {
match (&self.processor_type, other) {
// Apple Silicon has unified memory
(ProcessorType::Cpu(CpuVariant::Arm64 { .. }), ProcessorType::Gpu(GpuVariant::AppleMetal))
| (ProcessorType::Gpu(GpuVariant::AppleMetal), ProcessorType::Cpu(CpuVariant::Arm64 { .. }))
| (ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }), ProcessorType::Cpu(CpuVariant::Arm64 { .. }))
| (ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }), ProcessorType::Gpu(GpuVariant::AppleMetal)) => true,
(
ProcessorType::Cpu(CpuVariant::Arm64 { .. }),
ProcessorType::Gpu(GpuVariant::AppleMetal),
)
| (
ProcessorType::Gpu(GpuVariant::AppleMetal),
ProcessorType::Cpu(CpuVariant::Arm64 { .. }),
)
| (
ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }),
ProcessorType::Cpu(CpuVariant::Arm64 { .. }),
)
| (
ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }),
ProcessorType::Gpu(GpuVariant::AppleMetal),
) => true,
// Same type always shares
(a, b) if a == b => true,
_ => false,

View file

@ -191,10 +191,7 @@ pub enum Operation {
},
/// Data loading from storage.
DataLoad {
bytes: usize,
async_: bool,
},
DataLoad { bytes: usize, async_: bool },
/// Data preprocessing.
DataPreprocess {
@ -209,16 +206,10 @@ pub enum Operation {
},
/// Detokenization.
Detokenization {
tokens: usize,
vocab_size: usize,
},
Detokenization { tokens: usize, vocab_size: usize },
/// Checkpoint save.
Checkpoint {
bytes: usize,
async_: bool,
},
Checkpoint { bytes: usize, async_: bool },
/// All-reduce across devices.
AllReduce {
@ -228,9 +219,7 @@ pub enum Operation {
},
/// Backward pass for a layer.
Backward {
forward_op: Box<Operation>,
},
Backward { forward_op: Box<Operation> },
/// Optimizer step.
OptimizerStep {
@ -240,16 +229,10 @@ pub enum Operation {
},
/// Transpose.
Transpose {
shape: Vec<usize>,
axes: Vec<usize>,
},
Transpose { shape: Vec<usize>, axes: Vec<usize> },
/// Reshape.
Reshape {
from: Vec<usize>,
to: Vec<usize>,
},
Reshape { from: Vec<usize>, to: Vec<usize> },
/// Concatenate tensors.
Concat {
@ -378,9 +361,7 @@ impl Operation {
| Operation::SiLU { elements } => *elements as f64,
// Softmax: ~5 ops per element (exp, sum, div)
Operation::Softmax {
batch, seq_len, ..
} => 5.0 * (*batch as f64) * (*seq_len as f64),
Operation::Softmax { batch, seq_len, .. } => 5.0 * (*batch as f64) * (*seq_len as f64),
// Embedding: just lookup, minimal FLOPS
Operation::Embedding {

View file

@ -39,8 +39,7 @@ impl ProcessorProfiles {
bandwidth_gbps: 460,
type_: MemoryType::Ddr5,
},
operations: ProcessorCapabilities::cpu(96, 2.4, false)
.operations,
operations: ProcessorCapabilities::cpu(96, 2.4, false).operations,
power: PowerCharacteristics {
tdp_watts: 360,
efficiency: 0.85,
@ -70,8 +69,7 @@ impl ProcessorProfiles {
bandwidth_gbps: 307,
type_: MemoryType::Ddr5,
},
operations: ProcessorCapabilities::cpu(56, 2.9, true)
.operations,
operations: ProcessorCapabilities::cpu(56, 2.9, true).operations,
power: PowerCharacteristics {
tdp_watts: 350,
efficiency: 0.80,
@ -101,8 +99,7 @@ impl ProcessorProfiles {
bandwidth_gbps: 400,
type_: MemoryType::Unified,
},
operations: ProcessorCapabilities::cpu(16, 4.0, false)
.operations,
operations: ProcessorCapabilities::cpu(16, 4.0, false).operations,
power: PowerCharacteristics {
tdp_watts: 40,
efficiency: 0.95,
@ -141,8 +138,7 @@ impl ProcessorProfiles {
bandwidth_gbps: 3350,
type_: MemoryType::Hbm3,
},
operations: ProcessorCapabilities::nvidia_gpu(16896, 528, 80, 3350, (9, 0))
.operations,
operations: ProcessorCapabilities::nvidia_gpu(16896, 528, 80, 3350, (9, 0)).operations,
power: PowerCharacteristics {
tdp_watts: 700,
efficiency: 0.90,
@ -173,8 +169,7 @@ impl ProcessorProfiles {
bandwidth_gbps: 2039,
type_: MemoryType::Hbm2e,
},
operations: ProcessorCapabilities::nvidia_gpu(6912, 432, 80, 2039, (8, 0))
.operations,
operations: ProcessorCapabilities::nvidia_gpu(6912, 432, 80, 2039, (8, 0)).operations,
power: PowerCharacteristics {
tdp_watts: 400,
efficiency: 0.88,
@ -205,8 +200,7 @@ impl ProcessorProfiles {
bandwidth_gbps: 1008,
type_: MemoryType::Gddr6,
},
operations: ProcessorCapabilities::nvidia_gpu(16384, 512, 24, 1008, (8, 9))
.operations,
operations: ProcessorCapabilities::nvidia_gpu(16384, 512, 24, 1008, (8, 9)).operations,
power: PowerCharacteristics {
tdp_watts: 450,
efficiency: 0.85,
@ -236,8 +230,7 @@ impl ProcessorProfiles {
bandwidth_gbps: 936,
type_: MemoryType::Gddr6,
},
operations: ProcessorCapabilities::nvidia_gpu(10496, 328, 24, 936, (8, 6))
.operations,
operations: ProcessorCapabilities::nvidia_gpu(10496, 328, 24, 936, (8, 6)).operations,
power: PowerCharacteristics {
tdp_watts: 350,
efficiency: 0.82,
@ -272,8 +265,8 @@ impl ProcessorProfiles {
type_: MemoryType::Hbm3,
},
operations: {
let mut ops = ProcessorCapabilities::nvidia_gpu(16384, 512, 80, 5300, (9, 0))
.operations;
let mut ops =
ProcessorCapabilities::nvidia_gpu(16384, 512, 80, 5300, (9, 0)).operations;
ops.remove(&OperationType::FlashAttention); // Different implementation
ops
},
@ -308,8 +301,8 @@ impl ProcessorProfiles {
type_: MemoryType::Gddr6,
},
operations: {
let mut ops = ProcessorCapabilities::nvidia_gpu(6144, 0, 24, 960, (8, 0))
.operations;
let mut ops =
ProcessorCapabilities::nvidia_gpu(6144, 0, 24, 960, (8, 0)).operations;
ops.remove(&OperationType::FlashAttention);
ops
},
@ -318,9 +311,7 @@ impl ProcessorProfiles {
efficiency: 0.80,
power_tier: PowerTier::High,
},
optimal_for: vec![
WorkloadCharacteristic::HighlyParallel,
],
optimal_for: vec![WorkloadCharacteristic::HighlyParallel],
}
}
@ -429,8 +420,7 @@ impl ProcessorProfiles {
bandwidth_gbps: 200,
type_: MemoryType::Unified,
},
operations: ProcessorCapabilities::apple_neural_engine(16)
.operations,
operations: ProcessorCapabilities::apple_neural_engine(16).operations,
power: PowerCharacteristics {
tdp_watts: 8,
efficiency: 0.98,
@ -465,8 +455,7 @@ impl ProcessorProfiles {
bandwidth_gbps: 77,
type_: MemoryType::Lpddr,
},
operations: ProcessorCapabilities::apple_neural_engine(16)
.operations,
operations: ProcessorCapabilities::apple_neural_engine(16).operations,
power: PowerCharacteristics {
tdp_watts: 10,
efficiency: 0.95,

View file

@ -24,10 +24,7 @@ pub enum ProcessorType {
/// WebAssembly runtime.
Wasm,
/// Custom/Unknown accelerator.
Custom {
vendor: String,
model: String,
},
Custom { vendor: String, model: String },
}
impl Default for ProcessorType {

View file

@ -6,10 +6,10 @@
//! - Latency-aware scheduling
//! - Real-time utilization metrics
use super::TaskAssignment;
use crate::device::DeviceRegistry;
use crate::processor::{Operation, OperationType, ProcessorId, ProcessorType};
use crate::task::{Task, TaskId, TaskPriority};
use super::TaskAssignment;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
@ -127,8 +127,12 @@ impl LoadBalancer {
/// Register a processor with its type.
pub fn register_processor(&self, processor_id: ProcessorId, processor_type: ProcessorType) {
self.loads.write().insert(processor_id, AtomicU64::new(0));
self.metrics.write().insert(processor_id, ProcessorMetrics::default());
self.processor_types.write().insert(processor_id, processor_type);
self.metrics
.write()
.insert(processor_id, ProcessorMetrics::default());
self.processor_types
.write()
.insert(processor_id, processor_type);
}
/// Unregister a processor.
@ -150,7 +154,8 @@ impl LoadBalancer {
/// Get current load for a processor.
pub fn get_load(&self, processor_id: ProcessorId) -> u64 {
self.loads.read()
self.loads
.read()
.get(&processor_id)
.map(|l| l.load(Ordering::Relaxed))
.unwrap_or(0)
@ -179,140 +184,140 @@ impl LoadBalancer {
ProcessorType::Cpu(_) => matches!(
op_type,
OperationType::MatMul
| OperationType::Conv2d
| OperationType::Conv3d
| OperationType::DepthwiseConv
| OperationType::BatchNorm
| OperationType::LayerNorm
| OperationType::Add
| OperationType::Mul
| OperationType::ReLU
| OperationType::GeLU
| OperationType::SiLU
| OperationType::Softmax
| OperationType::Sum
| OperationType::Mean
| OperationType::Max
| OperationType::ArgMax
| OperationType::Embedding
| OperationType::TopK
| OperationType::Sampling
| OperationType::Tokenization
| OperationType::Detokenization
| OperationType::DataLoad
| OperationType::DataPreprocess
| OperationType::Transpose
| OperationType::Reshape
| OperationType::Concat
| OperationType::Split
| OperationType::Conv2d
| OperationType::Conv3d
| OperationType::DepthwiseConv
| OperationType::BatchNorm
| OperationType::LayerNorm
| OperationType::Add
| OperationType::Mul
| OperationType::ReLU
| OperationType::GeLU
| OperationType::SiLU
| OperationType::Softmax
| OperationType::Sum
| OperationType::Mean
| OperationType::Max
| OperationType::ArgMax
| OperationType::Embedding
| OperationType::TopK
| OperationType::Sampling
| OperationType::Tokenization
| OperationType::Detokenization
| OperationType::DataLoad
| OperationType::DataPreprocess
| OperationType::Transpose
| OperationType::Reshape
| OperationType::Concat
| OperationType::Split
),
// GPUs excel at parallel operations
ProcessorType::Gpu(_) => matches!(
op_type,
OperationType::MatMul
| OperationType::Conv2d
| OperationType::Conv3d
| OperationType::DepthwiseConv
| OperationType::BatchNorm
| OperationType::LayerNorm
| OperationType::SelfAttention
| OperationType::CrossAttention
| OperationType::FlashAttention
| OperationType::Add
| OperationType::Mul
| OperationType::ReLU
| OperationType::GeLU
| OperationType::SiLU
| OperationType::Softmax
| OperationType::Sum
| OperationType::Mean
| OperationType::Max
| OperationType::ArgMax
| OperationType::Embedding
| OperationType::RoPE
| OperationType::KVCache
| OperationType::TopK
| OperationType::Sampling
| OperationType::Transpose
| OperationType::Reshape
| OperationType::Concat
| OperationType::Split
| OperationType::Gather
| OperationType::Scatter
| OperationType::AllReduce
| OperationType::AllGather
| OperationType::ReduceScatter
| OperationType::Backward
| OperationType::OptimizerStep
| OperationType::GradientClip
| OperationType::Conv2d
| OperationType::Conv3d
| OperationType::DepthwiseConv
| OperationType::BatchNorm
| OperationType::LayerNorm
| OperationType::SelfAttention
| OperationType::CrossAttention
| OperationType::FlashAttention
| OperationType::Add
| OperationType::Mul
| OperationType::ReLU
| OperationType::GeLU
| OperationType::SiLU
| OperationType::Softmax
| OperationType::Sum
| OperationType::Mean
| OperationType::Max
| OperationType::ArgMax
| OperationType::Embedding
| OperationType::RoPE
| OperationType::KVCache
| OperationType::TopK
| OperationType::Sampling
| OperationType::Transpose
| OperationType::Reshape
| OperationType::Concat
| OperationType::Split
| OperationType::Gather
| OperationType::Scatter
| OperationType::AllReduce
| OperationType::AllGather
| OperationType::ReduceScatter
| OperationType::Backward
| OperationType::OptimizerStep
| OperationType::GradientClip
),
// TPUs optimized for ML
ProcessorType::Tpu(_) => matches!(
op_type,
OperationType::MatMul
| OperationType::Conv2d
| OperationType::BatchNorm
| OperationType::LayerNorm
| OperationType::SelfAttention
| OperationType::CrossAttention
| OperationType::FlashAttention
| OperationType::Add
| OperationType::Mul
| OperationType::ReLU
| OperationType::GeLU
| OperationType::SiLU
| OperationType::Softmax
| OperationType::Sum
| OperationType::Mean
| OperationType::Embedding
| OperationType::RoPE
| OperationType::KVCache
| OperationType::AllReduce
| OperationType::AllGather
| OperationType::ReduceScatter
| OperationType::Backward
| OperationType::OptimizerStep
| OperationType::Conv2d
| OperationType::BatchNorm
| OperationType::LayerNorm
| OperationType::SelfAttention
| OperationType::CrossAttention
| OperationType::FlashAttention
| OperationType::Add
| OperationType::Mul
| OperationType::ReLU
| OperationType::GeLU
| OperationType::SiLU
| OperationType::Softmax
| OperationType::Sum
| OperationType::Mean
| OperationType::Embedding
| OperationType::RoPE
| OperationType::KVCache
| OperationType::AllReduce
| OperationType::AllGather
| OperationType::ReduceScatter
| OperationType::Backward
| OperationType::OptimizerStep
),
// NPUs for neural network inference
ProcessorType::Npu(_) => matches!(
op_type,
OperationType::MatMul
| OperationType::Conv2d
| OperationType::DepthwiseConv
| OperationType::BatchNorm
| OperationType::LayerNorm
| OperationType::SelfAttention
| OperationType::Add
| OperationType::Mul
| OperationType::ReLU
| OperationType::GeLU
| OperationType::SiLU
| OperationType::Softmax
| OperationType::Sum
| OperationType::Mean
| OperationType::Conv2d
| OperationType::DepthwiseConv
| OperationType::BatchNorm
| OperationType::LayerNorm
| OperationType::SelfAttention
| OperationType::Add
| OperationType::Mul
| OperationType::ReLU
| OperationType::GeLU
| OperationType::SiLU
| OperationType::Softmax
| OperationType::Sum
| OperationType::Mean
),
// LPUs for sequential inference (optimized for LLMs)
ProcessorType::Lpu => matches!(
op_type,
OperationType::MatMul
| OperationType::LayerNorm
| OperationType::SelfAttention
| OperationType::FlashAttention
| OperationType::Add
| OperationType::Mul
| OperationType::ReLU
| OperationType::GeLU
| OperationType::SiLU
| OperationType::Softmax
| OperationType::Embedding
| OperationType::RoPE
| OperationType::KVCache
| OperationType::TopK
| OperationType::Sampling
| OperationType::LayerNorm
| OperationType::SelfAttention
| OperationType::FlashAttention
| OperationType::Add
| OperationType::Mul
| OperationType::ReLU
| OperationType::GeLU
| OperationType::SiLU
| OperationType::Softmax
| OperationType::Embedding
| OperationType::RoPE
| OperationType::KVCache
| OperationType::TopK
| OperationType::Sampling
),
// FPGAs can be programmed for anything
@ -322,40 +327,40 @@ impl LoadBalancer {
ProcessorType::Dsp(_) => matches!(
op_type,
OperationType::Conv2d
| OperationType::DepthwiseConv
| OperationType::Add
| OperationType::Mul
| OperationType::Sum
| OperationType::Mean
| OperationType::Max
| OperationType::DepthwiseConv
| OperationType::Add
| OperationType::Mul
| OperationType::Sum
| OperationType::Mean
| OperationType::Max
),
// WebGPU has limited operations
ProcessorType::WebGpu => matches!(
op_type,
OperationType::MatMul
| OperationType::Conv2d
| OperationType::Add
| OperationType::Mul
| OperationType::ReLU
| OperationType::Softmax
| OperationType::Sum
| OperationType::Transpose
| OperationType::Reshape
| OperationType::Conv2d
| OperationType::Add
| OperationType::Mul
| OperationType::ReLU
| OperationType::Softmax
| OperationType::Sum
| OperationType::Transpose
| OperationType::Reshape
),
// WASM for portable compute
ProcessorType::Wasm => matches!(
op_type,
OperationType::MatMul
| OperationType::Add
| OperationType::Mul
| OperationType::ReLU
| OperationType::Softmax
| OperationType::Sum
| OperationType::Mean
| OperationType::Tokenization
| OperationType::Detokenization
| OperationType::Add
| OperationType::Mul
| OperationType::ReLU
| OperationType::Softmax
| OperationType::Sum
| OperationType::Mean
| OperationType::Tokenization
| OperationType::Detokenization
),
// Custom processors - assume they can handle anything
@ -381,7 +386,9 @@ impl LoadBalancer {
}
// Get utilization and metrics
let utilization = proc_metrics.map(|m| m.utilization).unwrap_or(load as f64 / 100.0);
let utilization = proc_metrics
.map(|m| m.utilization)
.unwrap_or(load as f64 / 100.0);
let power = proc_metrics.map(|m| m.power_watts).unwrap_or(100.0);
let avg_completion = proc_metrics.map(|m| m.avg_completion_ms).unwrap_or(100.0);
@ -431,13 +438,13 @@ impl LoadBalancer {
BalancingStrategy::Cost => {
// Prioritize cheaper resources (consumer devices)
let cost_factor = match processor_type {
ProcessorType::Wasm => 0.1, // Cheapest (browser)
ProcessorType::Wasm => 0.1, // Cheapest (browser)
ProcessorType::WebGpu => 0.15,
ProcessorType::Cpu(_) => 0.2,
ProcessorType::Npu(_) => 0.3, // Mobile NPUs
ProcessorType::Npu(_) => 0.3, // Mobile NPUs
ProcessorType::Gpu(_) => 0.5,
ProcessorType::Lpu => 0.8,
ProcessorType::Tpu(_) => 1.0, // Most expensive
ProcessorType::Tpu(_) => 1.0, // Most expensive
_ => 0.5,
};
@ -450,7 +457,7 @@ impl LoadBalancer {
// Bonus for low-latency processors
let latency_bonus = match processor_type {
ProcessorType::Lpu => 5.0, // Designed for low latency
ProcessorType::Lpu => 5.0, // Designed for low latency
ProcessorType::Npu(_) => 3.0,
ProcessorType::Gpu(_) => 2.0,
ProcessorType::Tpu(_) => 1.5,
@ -550,7 +557,8 @@ impl LoadBalancer {
let mut suggestions = Vec::new();
let loads = self.loads.read();
let load_values: Vec<_> = loads.iter()
let load_values: Vec<_> = loads
.iter()
.map(|(id, load)| (*id, load.load(Ordering::Relaxed)))
.collect();
@ -558,16 +566,18 @@ impl LoadBalancer {
return suggestions;
}
let avg_load: f64 = load_values.iter().map(|(_, l)| *l as f64).sum::<f64>()
/ load_values.len() as f64;
let avg_load: f64 =
load_values.iter().map(|(_, l)| *l as f64).sum::<f64>() / load_values.len() as f64;
let processor_types = self.processor_types.read();
let overloaded: Vec<_> = load_values.iter()
let overloaded: Vec<_> = load_values
.iter()
.filter(|(_, l)| *l as f64 > avg_load * (1.0 + self.rebalance_threshold))
.collect();
let underloaded: Vec<_> = load_values.iter()
let underloaded: Vec<_> = load_values
.iter()
.filter(|(_, l)| (*l as f64) < avg_load * (1.0 - self.rebalance_threshold))
.collect();
@ -627,7 +637,9 @@ impl LoadBalancer {
/// Clean up old migration history.
pub fn cleanup_history(&self, max_age: Duration) {
let cutoff = Instant::now() - max_age;
self.migration_history.write().retain(|r| r.timestamp > cutoff);
self.migration_history
.write()
.retain(|r| r.timestamp > cutoff);
}
}
@ -725,7 +737,9 @@ mod tests {
balancer.register_processor(ProcessorId(0), ProcessorType::Cpu(CpuVariant::default()));
balancer.register_processor(
ProcessorId(1),
ProcessorType::Gpu(GpuVariant::NvidiaCuda { compute_capability: (8, 9) }),
ProcessorType::Gpu(GpuVariant::NvidiaCuda {
compute_capability: (8, 9),
}),
);
// Give CPU high load
@ -757,7 +771,9 @@ mod tests {
};
let cpu = ProcessorType::Cpu(CpuVariant::default());
let gpu = ProcessorType::Gpu(GpuVariant::NvidiaCuda { compute_capability: (8, 9) });
let gpu = ProcessorType::Gpu(GpuVariant::NvidiaCuda {
compute_capability: (8, 9),
});
let lpu = ProcessorType::Lpu;
// MatMul can run on all
@ -778,7 +794,10 @@ mod tests {
let npu_id = ProcessorId(1);
balancer.register_processor(cpu_id, ProcessorType::Cpu(CpuVariant::default()));
balancer.register_processor(npu_id, ProcessorType::Npu(crate::processor::NpuVariant::AppleNeuralEngine { cores: 16 }));
balancer.register_processor(
npu_id,
ProcessorType::Npu(crate::processor::NpuVariant::AppleNeuralEngine { cores: 16 }),
);
let task = create_test_task(TaskPriority::Normal);

View file

@ -69,7 +69,9 @@ impl HeterogeneousScheduler {
let utilization = self.estimate_utilization(&schedule);
// 5. Store active schedule
self.active_schedules.write().insert(schedule.id, schedule.clone());
self.active_schedules
.write()
.insert(schedule.id, schedule.clone());
Ok(ScheduleResult {
schedule,
@ -89,10 +91,12 @@ impl HeterogeneousScheduler {
let mut handles = Vec::new();
for task_id in &stage.tasks {
let task = schedule.tasks.get(task_id)
.ok_or_else(|| ComputeError::Internal(format!("Task not found: {:?}", task_id)))?;
let processor_id = schedule.assignment.get(task_id)
.ok_or_else(|| ComputeError::Internal(format!("No assignment for task: {:?}", task_id)))?;
let task = schedule.tasks.get(task_id).ok_or_else(|| {
ComputeError::Internal(format!("Task not found: {:?}", task_id))
})?;
let processor_id = schedule.assignment.get(task_id).ok_or_else(|| {
ComputeError::Internal(format!("No assignment for task: {:?}", task_id))
})?;
let processor = self.device_registry.get_processor(processor_id)?;
let task_clone = task.clone();
@ -144,8 +148,9 @@ impl HeterogeneousScheduler {
let best_processor = self.find_best_processor(&task).await?;
// Check if we should rebalance
let final_processor = self.load_balancer
.maybe_rebalance(&task, best_processor, &assignment);
let final_processor =
self.load_balancer
.maybe_rebalance(&task, best_processor, &assignment);
assignment.assign(task.id, final_processor);
}
@ -207,9 +212,7 @@ impl HeterogeneousScheduler {
fn topological_sort(&self, tasks: &[Task], deps: &DependencyGraph) -> Vec<Task> {
let mut sorted = Vec::new();
let mut visited = std::collections::HashSet::new();
let task_map: HashMap<TaskId, Task> = tasks.iter()
.map(|t| (t.id, t.clone()))
.collect();
let task_map: HashMap<TaskId, Task> = tasks.iter().map(|t| (t.id, t.clone())).collect();
fn visit(
task_id: TaskId,
@ -254,9 +257,7 @@ impl HeterogeneousScheduler {
) -> Result<Schedule, ComputeError> {
let mut stages = Vec::new();
let mut scheduled = std::collections::HashSet::new();
let task_map: HashMap<TaskId, Task> = tasks.iter()
.map(|t| (t.id, t.clone()))
.collect();
let task_map: HashMap<TaskId, Task> = tasks.iter().map(|t| (t.id, t.clone())).collect();
while scheduled.len() < tasks.len() {
let mut stage_tasks = Vec::new();
@ -267,8 +268,7 @@ impl HeterogeneousScheduler {
}
// Check if all dependencies are satisfied
let deps_satisfied = task.dependencies.iter()
.all(|dep| scheduled.contains(dep));
let deps_satisfied = task.dependencies.iter().all(|dep| scheduled.contains(dep));
if deps_satisfied {
stage_tasks.push(task.id);
@ -277,7 +277,7 @@ impl HeterogeneousScheduler {
if stage_tasks.is_empty() {
return Err(ComputeError::SchedulingFailed(
"Circular dependency detected".to_string()
"Circular dependency detected".to_string(),
));
}

View file

@ -153,7 +153,10 @@ impl PriorityWorkQueue {
TaskPriority::Normal,
TaskPriority::Background,
] {
queues.insert(priority, WorkQueue::new(processor_type.clone(), capacity_per_priority));
queues.insert(
priority,
WorkQueue::new(processor_type.clone(), capacity_per_priority),
);
}
Self {
@ -223,10 +226,7 @@ mod tests {
#[test]
fn test_work_queue_basic() {
let queue = WorkQueue::new(
ProcessorType::Cpu(CpuVariant::default()),
100,
);
let queue = WorkQueue::new(ProcessorType::Cpu(CpuVariant::default()), 100);
assert!(queue.is_empty());
@ -246,10 +246,7 @@ mod tests {
#[test]
fn test_priority_queue() {
let queue = PriorityWorkQueue::new(
ProcessorType::Cpu(CpuVariant::default()),
100,
);
let queue = PriorityWorkQueue::new(ProcessorType::Cpu(CpuVariant::default()), 100);
queue.push(create_test_task(1, TaskPriority::Background));
queue.push(create_test_task(2, TaskPriority::Critical));

View file

@ -495,9 +495,9 @@ mod tests {
compute_capability: (8, 0)
}
)));
assert!(matmul_task.is_compatible_with(ProcessorType::Tpu(
crate::processor::TpuVersion::V5p
)));
assert!(
matmul_task.is_compatible_with(ProcessorType::Tpu(crate::processor::TpuVersion::V5p))
);
let data_load_task = Task::new(Operation::DataLoad {
bytes: 1000,
@ -505,9 +505,8 @@ mod tests {
});
// DataLoad should be compatible with CPU
assert!(data_load_task.is_compatible_with(ProcessorType::Cpu(
crate::processor::CpuVariant::default()
)));
assert!(data_load_task
.is_compatible_with(ProcessorType::Cpu(crate::processor::CpuVariant::default())));
}
#[test]

View file

@ -25,8 +25,7 @@
use std::time::Duration;
/// Blocks per second mode.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[derive(Default)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)]
pub enum BpsMode {
/// Standard mode: 10 blocks per second (100ms block time)
/// - Suitable for most network conditions
@ -75,7 +74,6 @@ impl BpsMode {
}
}
impl std::fmt::Display for BpsMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
@ -148,39 +146,39 @@ impl NetworkConfig {
bps_mode: mode,
blocks_per_second: 10,
target_block_time_ms: 100,
daa_window_size: 2641, // ~264s window
ghostdag_k: 18, // For 10 BPS
daa_window_size: 2641, // ~264s window
ghostdag_k: 18, // For 10 BPS
dagknight_k_min: 8,
dagknight_k_max: 64,
finality_depth: 864, // ~86 seconds
pruning_depth: 864_000, // ~24 hours
finality_depth: 864, // ~86 seconds
pruning_depth: 864_000, // ~24 hours
merge_set_size_limit: 180,
expected_delay_ms: 100,
},
BpsMode::Fast32 => Self {
bps_mode: mode,
blocks_per_second: 32,
target_block_time_ms: 31, // ~31.25ms
daa_window_size: 8461, // ~264s window at 32 BPS
ghostdag_k: 58, // Scaled for 32 BPS
dagknight_k_min: 16, // Higher min for faster blocks
dagknight_k_max: 128, // Higher max for adaptation
finality_depth: 2765, // ~86 seconds at 32 BPS
pruning_depth: 2_764_800, // ~24 hours at 32 BPS
merge_set_size_limit: 576, // 32/10 * 180
target_block_time_ms: 31, // ~31.25ms
daa_window_size: 8461, // ~264s window at 32 BPS
ghostdag_k: 58, // Scaled for 32 BPS
dagknight_k_min: 16, // Higher min for faster blocks
dagknight_k_max: 128, // Higher max for adaptation
finality_depth: 2765, // ~86 seconds at 32 BPS
pruning_depth: 2_764_800, // ~24 hours at 32 BPS
merge_set_size_limit: 576, // 32/10 * 180
expected_delay_ms: 50,
},
BpsMode::Ultra100 => Self {
bps_mode: mode,
blocks_per_second: 100,
target_block_time_ms: 10,
daa_window_size: 26410, // ~264s window at 100 BPS
ghostdag_k: 180, // Scaled for 100 BPS
dagknight_k_min: 50, // Higher min for very fast blocks
dagknight_k_max: 255, // u8 max - very high for adaptation
finality_depth: 8640, // ~86 seconds at 100 BPS
pruning_depth: 8_640_000, // ~24 hours at 100 BPS
merge_set_size_limit: 1800, // 100/10 * 180
daa_window_size: 26410, // ~264s window at 100 BPS
ghostdag_k: 180, // Scaled for 100 BPS
dagknight_k_min: 50, // Higher min for very fast blocks
dagknight_k_max: 255, // u8 max - very high for adaptation
finality_depth: 8640, // ~86 seconds at 100 BPS
pruning_depth: 8_640_000, // ~24 hours at 100 BPS
merge_set_size_limit: 1800, // 100/10 * 180
expected_delay_ms: 20,
},
BpsMode::Custom(bps) => {
@ -269,7 +267,7 @@ pub fn bps_comparison_table() -> String {
let mut table = String::from(
"| Property | Standard (10 BPS) | Fast (32 BPS) | Ultra (100 BPS) |\n\
|----------|-------------------|---------------|------------------|\n"
|----------|-------------------|---------------|------------------|\n",
);
// Block Time
@ -314,7 +312,9 @@ pub fn bps_comparison_table() -> String {
// Estimated TPS
table.push_str(&format!(
"| Est. TPS @1000tx/block | {:.0} | {:.0} | {:.0} |\n",
standard.estimate_tps(1000), fast.estimate_tps(1000), ultra.estimate_tps(1000)
standard.estimate_tps(1000),
fast.estimate_tps(1000),
ultra.estimate_tps(1000)
));
table
@ -401,9 +401,9 @@ mod tests {
fn test_latency_acceptable() {
let config = NetworkConfig::standard(); // expects 100ms
assert!(config.is_latency_acceptable(50)); // Good
assert!(config.is_latency_acceptable(100)); // OK
assert!(config.is_latency_acceptable(200)); // Still OK (2x limit)
assert!(config.is_latency_acceptable(50)); // Good
assert!(config.is_latency_acceptable(100)); // OK
assert!(config.is_latency_acceptable(200)); // Still OK (2x limit)
assert!(!config.is_latency_acceptable(300)); // Too high
}

View file

@ -55,8 +55,8 @@
//! | Layer 2 transactions | FALCON-512 (batch efficiency) |
//! | High-value transactions | Dilithium3 (conservative choice) |
use pqcrypto_falcon::falcon512;
use pqcrypto_falcon::falcon1024;
use pqcrypto_falcon::falcon512;
use pqcrypto_traits::sign::{
DetachedSignature, PublicKey as PqPublicKey, SecretKey as PqSecretKey,
};
@ -64,8 +64,7 @@ use thiserror::Error;
use zeroize::{Zeroize, ZeroizeOnDrop};
/// FALCON variant selection.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[derive(Default)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum FalconVariant {
/// 128-bit security, ~690 byte signatures
#[default]
@ -124,7 +123,6 @@ impl FalconVariant {
}
}
/// FALCON public key.
#[derive(Clone)]
pub struct FalconPublicKey {
@ -188,7 +186,10 @@ impl std::fmt::Debug for FalconPublicKey {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("FalconPublicKey")
.field("variant", &self.variant)
.field("bytes", &hex::encode(&self.bytes[..8.min(self.bytes.len())]))
.field(
"bytes",
&hex::encode(&self.bytes[..8.min(self.bytes.len())]),
)
.finish()
}
}
@ -492,7 +493,10 @@ mod tests {
// Verify with wrong message should fail
let wrong_message = b"Wrong message";
assert!(keypair.public_key().verify(wrong_message, &signature).is_err());
assert!(keypair
.public_key()
.verify(wrong_message, &signature)
.is_err());
}
#[test]

View file

@ -150,9 +150,9 @@ impl PqAlgorithm {
/// Default priority order (higher = more preferred)
fn default_priority(&self) -> u8 {
match self {
Self::Dilithium3 => 100, // Default, well-balanced
Self::Falcon1024 => 90, // High security, compact
Self::Falcon512 => 85, // Compact, mobile-friendly
Self::Dilithium3 => 100, // Default, well-balanced
Self::Falcon1024 => 90, // High security, compact
Self::Falcon512 => 85, // Compact, mobile-friendly
Self::SphincsShake192s => 70, // Conservative backup
Self::SphincsShake256s => 60, // Maximum security
Self::SphincsShake128s => 50, // Basic SPHINCS+
@ -270,7 +270,8 @@ impl AlgorithmCapabilities {
/// Decode capabilities from bytes
pub fn decode(data: &[u8]) -> Result<Self, NegotiationError> {
serde_json::from_slice(data).map_err(|e| NegotiationError::InvalidCapabilities(e.to_string()))
serde_json::from_slice(data)
.map_err(|e| NegotiationError::InvalidCapabilities(e.to_string()))
}
}
@ -384,8 +385,7 @@ impl AlgorithmNegotiator {
// Check security level
let meets_local_security =
algo.security_level() >= self.local_caps.min_security_level;
let meets_remote_security =
algo.security_level() >= remote_caps.min_security_level;
let meets_remote_security = algo.security_level() >= remote_caps.min_security_level;
// Check signature size
let local_size_ok = self.local_caps.max_signature_size == 0
@ -513,10 +513,7 @@ impl AlgorithmNegotiator {
}
/// Quick negotiation using just algorithm names
pub fn quick_negotiate(
local: &[PqAlgorithm],
remote: &[PqAlgorithm],
) -> Option<PqAlgorithm> {
pub fn quick_negotiate(local: &[PqAlgorithm], remote: &[PqAlgorithm]) -> Option<PqAlgorithm> {
// Find common algorithms and return the one with highest default priority
let local_set: HashSet<_> = local.iter().collect();
let remote_set: HashSet<_> = remote.iter().collect();
@ -604,7 +601,10 @@ pub enum NegotiationMessage {
},
/// Acknowledge selection
Acknowledgment { session_id: [u8; 32], accepted: bool },
Acknowledgment {
session_id: [u8; 32],
accepted: bool,
},
/// Request renegotiation
Renegotiate { reason: String },
@ -691,8 +691,10 @@ mod tests {
let result = negotiator.negotiate(&remote_caps).unwrap();
// Should prefer FALCON for bandwidth-constrained scenarios
assert!(result.algorithm == PqAlgorithm::Falcon512 ||
result.algorithm == PqAlgorithm::Falcon1024);
assert!(
result.algorithm == PqAlgorithm::Falcon512
|| result.algorithm == PqAlgorithm::Falcon1024
);
}
#[test]

View file

@ -60,8 +60,7 @@ use zeroize::{Zeroize, ZeroizeOnDrop};
/// All variants use SHAKE (SHA3-based) for hashing.
/// 's' variants have smaller signatures but are slower.
/// 'f' variants are faster but have larger signatures.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[derive(Default)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum SphincsVariant {
/// 128-bit security, small signatures (~7.8KB)
#[default]
@ -119,7 +118,6 @@ impl SphincsVariant {
}
}
/// SPHINCS+ public key.
#[derive(Clone)]
pub struct SphincsPublicKey {
@ -191,7 +189,10 @@ impl std::fmt::Debug for SphincsPublicKey {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SphincsPublicKey")
.field("variant", &self.variant)
.field("bytes", &hex::encode(&self.bytes[..8.min(self.bytes.len())]))
.field(
"bytes",
&hex::encode(&self.bytes[..8.min(self.bytes.len())]),
)
.finish()
}
}
@ -500,7 +501,10 @@ mod tests {
// Verify with wrong message should fail
let wrong_message = b"Wrong message";
assert!(keypair.public_key().verify(wrong_message, &signature).is_err());
assert!(keypair
.public_key()
.verify(wrong_message, &signature)
.is_err());
}
#[test]

View file

@ -155,10 +155,7 @@ pub struct DagKnightManager {
impl DagKnightManager {
/// Creates a new DAGKnight manager with standard 10 BPS configuration.
pub fn new(
dag: Arc<BlockDag>,
reachability: Arc<ReachabilityStore>,
) -> Self {
pub fn new(dag: Arc<BlockDag>, reachability: Arc<ReachabilityStore>) -> Self {
Self::with_config(dag, reachability, BlockRateConfig::Standard)
}
@ -269,7 +266,8 @@ impl DagKnightManager {
let anticone_size = self.calculate_anticone_size(&block_id, parents);
// Record observation in latency tracker
self.latency_tracker.record_block(block_id, block_time_ms, anticone_size);
self.latency_tracker
.record_block(block_id, block_time_ms, anticone_size);
// Process with underlying GHOSTDAG
let data = self.ghostdag.add_block(block_id, parents)?;
@ -292,11 +290,9 @@ impl DagKnightManager {
for tip in tips {
if tip != *block_id && !parents.contains(&tip) {
// Check if tip is in the past of any parent
let in_past = parents.iter().any(|p| {
self.reachability
.is_ancestor(p, &tip)
.unwrap_or(false)
});
let in_past = parents
.iter()
.any(|p| self.reachability.is_ancestor(p, &tip).unwrap_or(false));
if !in_past {
anticone_count += 1;
@ -375,7 +371,8 @@ impl DagKnightManager {
let sigma_multiplier = confidence.sigma_multiplier();
// Required depth scales with variance and confidence level
let required_depth = (self.block_rate_bps * (mean_delay + sigma * sigma_multiplier)).ceil() as u64;
let required_depth =
(self.block_rate_bps * (mean_delay + sigma * sigma_multiplier)).ceil() as u64;
// Current confidence based on actual depth
let current_confidence = if depth >= required_depth {
@ -388,7 +385,8 @@ impl DagKnightManager {
// Time to reach required depth
let blocks_needed = required_depth.saturating_sub(depth);
let time_per_block_ms = 1000.0 / self.block_rate_bps;
let estimated_time = Duration::from_millis((blocks_needed as f64 * time_per_block_ms) as u64);
let estimated_time =
Duration::from_millis((blocks_needed as f64 * time_per_block_ms) as u64);
// Block is final if depth exceeds finality threshold for this block rate
let is_final = depth >= self.finality_depth();
@ -506,7 +504,10 @@ impl std::fmt::Debug for DagKnightManager {
.field("block_rate_config", &self.block_rate_config)
.field("block_rate_bps", &self.block_rate_bps)
.field("adaptive_k", &*self.adaptive_k.read())
.field("k_bounds", &format!("{}-{}", self.k_bounds.min_k, self.k_bounds.max_k))
.field(
"k_bounds",
&format!("{}-{}", self.k_bounds.min_k, self.k_bounds.max_k),
)
.field("mean_delay_ms", &stats.mean_delay_ms)
.field("sample_count", &stats.sample_count)
.finish()
@ -534,10 +535,7 @@ pub fn calculate_optimal_k(network_delay_ms: f64, block_rate_bps: f64) -> u8 {
}
/// Calculates the optimal k for a specific block rate configuration.
pub fn calculate_optimal_k_for_config(
network_delay_ms: f64,
config: BlockRateConfig,
) -> u8 {
pub fn calculate_optimal_k_for_config(network_delay_ms: f64, config: BlockRateConfig) -> u8 {
let bounds = AdaptiveKBounds::for_block_rate(config);
let delay_secs = network_delay_ms / 1000.0;
let k = (config.bps() * delay_secs * SAFETY_MARGIN).ceil() as u16;
@ -578,7 +576,9 @@ mod tests {
(dag, reachability, dagknight)
}
fn setup_test_dag_with_config(config: BlockRateConfig) -> (Arc<BlockDag>, Arc<ReachabilityStore>, DagKnightManager) {
fn setup_test_dag_with_config(
config: BlockRateConfig,
) -> (Arc<BlockDag>, Arc<ReachabilityStore>, DagKnightManager) {
let genesis = make_block_id(0);
let dag = Arc::new(BlockDag::new(genesis, 0));
let reachability = Arc::new(ReachabilityStore::new(genesis));
@ -671,14 +671,19 @@ mod tests {
let tps_poor = estimate_throughput(10.0, 100, 40.0);
// Good network should have higher throughput
assert!(tps_good > tps_poor, "tps_good={} should be > tps_poor={}", tps_good, tps_poor);
assert!(
tps_good > tps_poor,
"tps_good={} should be > tps_poor={}",
tps_good,
tps_poor
);
}
#[test]
fn test_throughput_by_config() {
// At same network conditions, higher BPS = higher theoretical TPS
let tps_10 = estimate_throughput(10.0, 100, 20.0); // 10 BPS
let tps_32 = estimate_throughput(32.0, 100, 20.0); // 32 BPS
let tps_10 = estimate_throughput(10.0, 100, 20.0); // 10 BPS
let tps_32 = estimate_throughput(32.0, 100, 20.0); // 32 BPS
let tps_100 = estimate_throughput(100.0, 100, 20.0); // 100 BPS
// Higher block rates give higher TPS (with network overhead)
@ -698,19 +703,37 @@ mod tests {
let maximum_time_hrs = maximum.finality_depth() as f64 / 100.0 / 3600.0;
// Should all be approximately 2.4 hours (allow some variance)
assert!((standard_time_hrs - 2.4).abs() < 0.1, "standard: {}", standard_time_hrs);
assert!((enhanced_time_hrs - 2.4).abs() < 0.1, "enhanced: {}", enhanced_time_hrs);
assert!((maximum_time_hrs - 2.4).abs() < 0.1, "maximum: {}", maximum_time_hrs);
assert!(
(standard_time_hrs - 2.4).abs() < 0.1,
"standard: {}",
standard_time_hrs
);
assert!(
(enhanced_time_hrs - 2.4).abs() < 0.1,
"enhanced: {}",
enhanced_time_hrs
);
assert!(
(maximum_time_hrs - 2.4).abs() < 0.1,
"maximum: {}",
maximum_time_hrs
);
}
#[test]
fn test_confidence_levels() {
assert!(ConfirmationConfidence::VeryHigh.sigma_multiplier()
> ConfirmationConfidence::High.sigma_multiplier());
assert!(ConfirmationConfidence::High.sigma_multiplier()
> ConfirmationConfidence::Medium.sigma_multiplier());
assert!(ConfirmationConfidence::Medium.sigma_multiplier()
> ConfirmationConfidence::Low.sigma_multiplier());
assert!(
ConfirmationConfidence::VeryHigh.sigma_multiplier()
> ConfirmationConfidence::High.sigma_multiplier()
);
assert!(
ConfirmationConfidence::High.sigma_multiplier()
> ConfirmationConfidence::Medium.sigma_multiplier()
);
assert!(
ConfirmationConfidence::Medium.sigma_multiplier()
> ConfirmationConfidence::Low.sigma_multiplier()
);
}
#[test]

View file

@ -98,12 +98,7 @@ impl LatencyTracker {
/// * `block_id` - Hash of the observed block
/// * `block_time_ms` - Timestamp from block header (Unix ms)
/// * `anticone_size` - Number of blocks in the anticone at observation time
pub fn record_block(
&self,
block_id: BlockId,
block_time_ms: u64,
anticone_size: usize,
) {
pub fn record_block(&self, block_id: BlockId, block_time_ms: u64, anticone_size: usize) {
let local_time = Instant::now();
let now_ms = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
@ -208,7 +203,10 @@ impl LatencyTracker {
let anticone_growth_rate = if n > 1 {
let first = samples.front().unwrap();
let last = samples.back().unwrap();
let time_span_secs = last.local_time.duration_since(first.local_time).as_secs_f64();
let time_span_secs = last
.local_time
.duration_since(first.local_time)
.as_secs_f64();
if time_span_secs > 0.0 {
let total_anticone_growth: usize = samples.iter().map(|s| s.anticone_size).sum();

View file

@ -32,8 +32,8 @@ pub mod reachability;
pub use dag::{BlockDag, BlockRelations, DagError};
pub use dagknight::{
calculate_optimal_k, calculate_optimal_k_for_config, estimate_throughput,
AdaptiveKBounds, ConfirmationConfidence, ConfirmationStatus, DagKnightManager,
calculate_optimal_k, calculate_optimal_k_for_config, estimate_throughput, AdaptiveKBounds,
ConfirmationConfidence, ConfirmationStatus, DagKnightManager,
};
pub use ghostdag::{GhostdagData, GhostdagError, GhostdagManager};
pub use latency::{LatencySample, LatencyStats, LatencyTracker};
@ -116,27 +116,27 @@ impl BlockRateConfig {
/// Returns the merge depth adjusted for block rate.
pub const fn merge_depth(&self) -> u64 {
match self {
BlockRateConfig::Standard => 3600, // ~6 min at 10 bps
BlockRateConfig::Enhanced => 11520, // ~6 min at 32 bps
BlockRateConfig::Maximum => 36000, // ~6 min at 100 bps
BlockRateConfig::Standard => 3600, // ~6 min at 10 bps
BlockRateConfig::Enhanced => 11520, // ~6 min at 32 bps
BlockRateConfig::Maximum => 36000, // ~6 min at 100 bps
}
}
/// Returns the finality depth adjusted for block rate.
pub const fn finality_depth(&self) -> u64 {
match self {
BlockRateConfig::Standard => 86400, // ~2.4 hours at 10 bps
BlockRateConfig::Enhanced => 276480, // ~2.4 hours at 32 bps
BlockRateConfig::Maximum => 864000, // ~2.4 hours at 100 bps
BlockRateConfig::Standard => 86400, // ~2.4 hours at 10 bps
BlockRateConfig::Enhanced => 276480, // ~2.4 hours at 32 bps
BlockRateConfig::Maximum => 864000, // ~2.4 hours at 100 bps
}
}
/// Returns the pruning depth adjusted for block rate.
pub const fn pruning_depth(&self) -> u64 {
match self {
BlockRateConfig::Standard => 288_000, // ~8 hours at 10 bps
BlockRateConfig::Enhanced => 921_600, // ~8 hours at 32 bps
BlockRateConfig::Maximum => 2_880_000, // ~8 hours at 100 bps
BlockRateConfig::Standard => 288_000, // ~8 hours at 10 bps
BlockRateConfig::Enhanced => 921_600, // ~8 hours at 32 bps
BlockRateConfig::Maximum => 2_880_000, // ~8 hours at 100 bps
}
}
}

View file

@ -42,7 +42,9 @@ impl DocumentId {
let bytes = hex::decode(s)
.map_err(|_| DatabaseError::InvalidOperation("Invalid hex string".into()))?;
if bytes.len() != 32 {
return Err(DatabaseError::InvalidOperation("Invalid document ID length".into()));
return Err(DatabaseError::InvalidOperation(
"Invalid document ID length".into(),
));
}
let mut arr = [0u8; 32];
arr.copy_from_slice(&bytes);
@ -249,7 +251,11 @@ impl Collection {
}
/// Updates documents matching a filter.
pub fn update_many(&self, filter: &DocumentFilter, update: JsonValue) -> Result<u64, DatabaseError> {
pub fn update_many(
&self,
filter: &DocumentFilter,
update: JsonValue,
) -> Result<u64, DatabaseError> {
let mut docs = self.documents.write();
let mut count = 0;
for doc in docs.values_mut() {
@ -321,60 +327,71 @@ enum FilterCondition {
impl DocumentFilter {
/// Creates a new empty filter (matches all).
pub fn new() -> Self {
Self { conditions: Vec::new() }
Self {
conditions: Vec::new(),
}
}
/// Equality condition.
pub fn eq(mut self, field: impl Into<String>, value: JsonValue) -> Self {
self.conditions.push(FilterCondition::Eq(field.into(), value));
self.conditions
.push(FilterCondition::Eq(field.into(), value));
self
}
/// Not equal condition.
pub fn ne(mut self, field: impl Into<String>, value: JsonValue) -> Self {
self.conditions.push(FilterCondition::Ne(field.into(), value));
self.conditions
.push(FilterCondition::Ne(field.into(), value));
self
}
/// Greater than.
pub fn gt(mut self, field: impl Into<String>, value: JsonValue) -> Self {
self.conditions.push(FilterCondition::Gt(field.into(), value));
self.conditions
.push(FilterCondition::Gt(field.into(), value));
self
}
/// Greater than or equal.
pub fn gte(mut self, field: impl Into<String>, value: JsonValue) -> Self {
self.conditions.push(FilterCondition::Gte(field.into(), value));
self.conditions
.push(FilterCondition::Gte(field.into(), value));
self
}
/// Less than.
pub fn lt(mut self, field: impl Into<String>, value: JsonValue) -> Self {
self.conditions.push(FilterCondition::Lt(field.into(), value));
self.conditions
.push(FilterCondition::Lt(field.into(), value));
self
}
/// Less than or equal.
pub fn lte(mut self, field: impl Into<String>, value: JsonValue) -> Self {
self.conditions.push(FilterCondition::Lte(field.into(), value));
self.conditions
.push(FilterCondition::Lte(field.into(), value));
self
}
/// In array.
pub fn in_array(mut self, field: impl Into<String>, values: Vec<JsonValue>) -> Self {
self.conditions.push(FilterCondition::In(field.into(), values));
self.conditions
.push(FilterCondition::In(field.into(), values));
self
}
/// String contains.
pub fn contains(mut self, field: impl Into<String>, substring: impl Into<String>) -> Self {
self.conditions.push(FilterCondition::Contains(field.into(), substring.into()));
self.conditions
.push(FilterCondition::Contains(field.into(), substring.into()));
self
}
/// Field exists.
pub fn exists(mut self, field: impl Into<String>, exists: bool) -> Self {
self.conditions.push(FilterCondition::Exists(field.into(), exists));
self.conditions
.push(FilterCondition::Exists(field.into(), exists));
self
}
@ -396,7 +413,9 @@ impl DocumentFilter {
return true;
}
self.conditions.iter().all(|cond| self.eval_condition(cond, doc))
self.conditions
.iter()
.all(|cond| self.eval_condition(cond, doc))
}
fn eval_condition(&self, cond: &FilterCondition, doc: &Document) -> bool {
@ -419,27 +438,21 @@ impl DocumentFilter {
FilterCondition::Lte(field, value) => {
self.compare_values(doc.get_nested(field), value, |a, b| a <= b)
}
FilterCondition::In(field, values) => {
doc.get_nested(field)
.map(|v| values.contains(v))
.unwrap_or(false)
}
FilterCondition::Contains(field, substring) => {
doc.get_nested(field)
.and_then(|v| v.as_str())
.map(|s| s.contains(substring))
.unwrap_or(false)
}
FilterCondition::In(field, values) => doc
.get_nested(field)
.map(|v| values.contains(v))
.unwrap_or(false),
FilterCondition::Contains(field, substring) => doc
.get_nested(field)
.and_then(|v| v.as_str())
.map(|s| s.contains(substring))
.unwrap_or(false),
FilterCondition::Exists(field, should_exist) => {
let exists = doc.get_nested(field).is_some();
exists == *should_exist
}
FilterCondition::And(filters) => {
filters.iter().all(|f| f.matches(doc))
}
FilterCondition::Or(filters) => {
filters.iter().any(|f| f.matches(doc))
}
FilterCondition::And(filters) => filters.iter().all(|f| f.matches(doc)),
FilterCondition::Or(filters) => filters.iter().any(|f| f.matches(doc)),
}
}
@ -448,12 +461,10 @@ impl DocumentFilter {
F: Fn(f64, f64) -> bool,
{
match (a, b) {
(Some(JsonValue::Number(a)), JsonValue::Number(b)) => {
match (a.as_f64(), b.as_f64()) {
(Some(a), Some(b)) => cmp(a, b),
_ => false,
}
}
(Some(JsonValue::Number(a)), JsonValue::Number(b)) => match (a.as_f64(), b.as_f64()) {
(Some(a), Some(b)) => cmp(a, b),
_ => false,
},
_ => false,
}
}
@ -512,7 +523,11 @@ impl DocumentStore {
}
/// Finds documents in a collection.
pub fn find(&self, collection: &str, filter: &DocumentFilter) -> Result<Vec<Document>, DatabaseError> {
pub fn find(
&self,
collection: &str,
filter: &DocumentFilter,
) -> Result<Vec<Document>, DatabaseError> {
let collections = self.collections.read();
let coll = collections
.get(collection)
@ -521,7 +536,11 @@ impl DocumentStore {
}
/// Finds one document.
pub fn find_one(&self, collection: &str, filter: &DocumentFilter) -> Result<Option<Document>, DatabaseError> {
pub fn find_one(
&self,
collection: &str,
filter: &DocumentFilter,
) -> Result<Option<Document>, DatabaseError> {
let collections = self.collections.read();
let coll = collections
.get(collection)
@ -530,7 +549,11 @@ impl DocumentStore {
}
/// Finds a document by ID.
pub fn find_by_id(&self, collection: &str, id: &DocumentId) -> Result<Option<Document>, DatabaseError> {
pub fn find_by_id(
&self,
collection: &str,
id: &DocumentId,
) -> Result<Option<Document>, DatabaseError> {
let collections = self.collections.read();
let coll = collections
.get(collection)
@ -539,7 +562,12 @@ impl DocumentStore {
}
/// Updates a document by ID.
pub fn update_by_id(&self, collection: &str, id: &DocumentId, update: JsonValue) -> Result<bool, DatabaseError> {
pub fn update_by_id(
&self,
collection: &str,
id: &DocumentId,
update: JsonValue,
) -> Result<bool, DatabaseError> {
let collections = self.collections.read();
let coll = collections
.get(collection)
@ -584,7 +612,8 @@ mod tests {
fn test_collection_insert_find() {
let coll = Collection::new("users");
coll.insert_one(json!({"name": "Alice", "age": 30})).unwrap();
coll.insert_one(json!({"name": "Alice", "age": 30}))
.unwrap();
coll.insert_one(json!({"name": "Bob", "age": 25})).unwrap();
let filter = DocumentFilter::new().eq("name", json!("Alice"));
@ -597,9 +626,11 @@ mod tests {
fn test_filter_comparison() {
let coll = Collection::new("users");
coll.insert_one(json!({"name": "Alice", "age": 30})).unwrap();
coll.insert_one(json!({"name": "Alice", "age": 30}))
.unwrap();
coll.insert_one(json!({"name": "Bob", "age": 25})).unwrap();
coll.insert_one(json!({"name": "Charlie", "age": 35})).unwrap();
coll.insert_one(json!({"name": "Charlie", "age": 35}))
.unwrap();
let filter = DocumentFilter::new().gte("age", json!(30));
let results = coll.find(&filter);
@ -622,7 +653,9 @@ mod tests {
#[test]
fn test_update_document() {
let coll = Collection::new("users");
let id = coll.insert_one(json!({"name": "Alice", "age": 30})).unwrap();
let id = coll
.insert_one(json!({"name": "Alice", "age": 30}))
.unwrap();
coll.update_by_id(&id, json!({"age": 31})).unwrap();

View file

@ -1,9 +1,9 @@
//! Authentication and authorization for Database Gateway.
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use parking_lot::RwLock;
/// API key for authentication.
#[derive(Clone, Debug, Serialize, Deserialize)]

View file

@ -272,19 +272,13 @@ pub fn json_to_filter(json: &JsonValue) -> Option<Filter> {
// Handle $and
if let Some(and_arr) = obj.get("$and").and_then(|v| v.as_array()) {
let filters: Vec<Filter> = and_arr
.iter()
.filter_map(json_to_filter)
.collect();
let filters: Vec<Filter> = and_arr.iter().filter_map(json_to_filter).collect();
return Some(Filter::And(filters));
}
// Handle $or
if let Some(or_arr) = obj.get("$or").and_then(|v| v.as_array()) {
let filters: Vec<Filter> = or_arr
.iter()
.filter_map(json_to_filter)
.collect();
let filters: Vec<Filter> = or_arr.iter().filter_map(json_to_filter).collect();
return Some(Filter::Or(filters));
}

View file

@ -137,7 +137,12 @@ async fn kv_get(
// For demo, use a default database
let db = match get_default_database(&state) {
Some(db) => db,
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<KvGetResponse>::error("No database"))),
None => {
return (
StatusCode::NOT_FOUND,
Json(ApiResponse::<KvGetResponse>::error("No database")),
)
}
};
state.record_read();
@ -153,7 +158,12 @@ async fn kv_set(
) -> impl IntoResponse {
let db = match get_default_database(&state) {
Some(db) => db,
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<KvGetResponse>::error("No database"))),
None => {
return (
StatusCode::NOT_FOUND,
Json(ApiResponse::<KvGetResponse>::error("No database")),
)
}
};
state.record_write(req.value.len() as u64);
@ -168,7 +178,12 @@ async fn kv_delete(
) -> impl IntoResponse {
let db = match get_default_database(&state) {
Some(db) => db,
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<bool>::error("No database"))),
None => {
return (
StatusCode::NOT_FOUND,
Json(ApiResponse::<bool>::error("No database")),
)
}
};
let response = handle_kv_delete(db.kv(), &key);
@ -182,7 +197,12 @@ async fn kv_batch(
) -> impl IntoResponse {
let db = match get_default_database(&state) {
Some(db) => db,
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<KvBatchResponse>::error("No database"))),
None => {
return (
StatusCode::NOT_FOUND,
Json(ApiResponse::<KvBatchResponse>::error("No database")),
)
}
};
let response = handle_kv_batch(db.kv(), req);
@ -217,7 +237,9 @@ async fn list_databases(
})
.collect();
Json(ApiResponse::ok(ListDatabasesResponse { databases: response }))
Json(ApiResponse::ok(ListDatabasesResponse {
databases: response,
}))
}
async fn create_database(
@ -250,7 +272,10 @@ async fn create_database(
})),
)
}
Err(e) => (StatusCode::BAD_REQUEST, Json(ApiResponse::error(e.to_string()))),
Err(e) => (
StatusCode::BAD_REQUEST,
Json(ApiResponse::error(e.to_string())),
),
}
}
@ -302,12 +327,20 @@ async fn create_collection(
) -> impl IntoResponse {
let db = match get_database(&state, &db_name) {
Some(db) => db,
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<bool>::error("Database not found"))),
None => {
return (
StatusCode::NOT_FOUND,
Json(ApiResponse::<bool>::error("Database not found")),
)
}
};
match db.documents().create_collection(&req.name) {
Ok(_) => (StatusCode::CREATED, Json(ApiResponse::ok(true))),
Err(e) => (StatusCode::BAD_REQUEST, Json(ApiResponse::<bool>::error(e.to_string()))),
Err(e) => (
StatusCode::BAD_REQUEST,
Json(ApiResponse::<bool>::error(e.to_string())),
),
}
}
@ -349,15 +382,20 @@ async fn query_documents(
) -> impl IntoResponse {
let db = match get_database(&state, &db_name) {
Some(db) => db,
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<QueryDocumentsResponse>::error("Database not found"))),
None => {
return (
StatusCode::NOT_FOUND,
Json(ApiResponse::<QueryDocumentsResponse>::error(
"Database not found",
)),
)
}
};
state.record_read();
// Build query
let mut query = Query::new(&coll_name)
.skip(req.skip)
.limit(req.limit);
let mut query = Query::new(&coll_name).skip(req.skip).limit(req.limit);
// Add filter
if let Some(filter_json) = &req.filter {
@ -384,8 +422,14 @@ async fn query_documents(
}
match db.query().execute(&query) {
Ok(result) => (StatusCode::OK, Json(ApiResponse::ok(QueryDocumentsResponse::from(result)))),
Err(e) => (StatusCode::BAD_REQUEST, Json(ApiResponse::error(e.to_string()))),
Ok(result) => (
StatusCode::OK,
Json(ApiResponse::ok(QueryDocumentsResponse::from(result))),
),
Err(e) => (
StatusCode::BAD_REQUEST,
Json(ApiResponse::error(e.to_string())),
),
}
}
@ -396,15 +440,25 @@ async fn insert_document(
) -> impl IntoResponse {
let db = match get_database(&state, &db_name) {
Some(db) => db,
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<String>::error("Database not found"))),
None => {
return (
StatusCode::NOT_FOUND,
Json(ApiResponse::<String>::error("Database not found")),
)
}
};
let size = serde_json::to_vec(&req.document).map(|v| v.len()).unwrap_or(0);
let size = serde_json::to_vec(&req.document)
.map(|v| v.len())
.unwrap_or(0);
state.record_write(size as u64);
match db.documents().insert(&coll_name, req.document) {
Ok(id) => (StatusCode::CREATED, Json(ApiResponse::ok(id.to_hex()))),
Err(e) => (StatusCode::BAD_REQUEST, Json(ApiResponse::error(e.to_string()))),
Err(e) => (
StatusCode::BAD_REQUEST,
Json(ApiResponse::error(e.to_string())),
),
}
}
@ -415,7 +469,12 @@ async fn insert_many_documents(
) -> impl IntoResponse {
let db = match get_database(&state, &db_name) {
Some(db) => db,
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<Vec<String>>::error("Database not found"))),
None => {
return (
StatusCode::NOT_FOUND,
Json(ApiResponse::<Vec<String>>::error("Database not found")),
)
}
};
let mut ids = Vec::with_capacity(req.documents.len());
@ -425,7 +484,12 @@ async fn insert_many_documents(
total_size += serde_json::to_vec(&doc).map(|v| v.len()).unwrap_or(0) as u64;
match db.documents().insert(&coll_name, doc) {
Ok(id) => ids.push(id.to_hex()),
Err(e) => return (StatusCode::BAD_REQUEST, Json(ApiResponse::error(e.to_string()))),
Err(e) => {
return (
StatusCode::BAD_REQUEST,
Json(ApiResponse::error(e.to_string())),
)
}
}
}
@ -477,7 +541,9 @@ async fn update_document(
Err(e) => return Json(ApiResponse::error(e.to_string())),
};
let update_size = serde_json::to_vec(&req.update).map(|v| v.len()).unwrap_or(0);
let update_size = serde_json::to_vec(&req.update)
.map(|v| v.len())
.unwrap_or(0);
state.record_write(update_size as u64);
match db.documents().update_by_id(&coll_name, &id, req.update) {
@ -519,7 +585,12 @@ async fn insert_embeddings(
) -> impl IntoResponse {
let db = match get_database(&state, &db_name) {
Some(db) => db,
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<usize>::error("Database not found"))),
None => {
return (
StatusCode::NOT_FOUND,
Json(ApiResponse::<usize>::error("Database not found")),
)
}
};
let mut count = 0;
@ -533,7 +604,10 @@ async fn insert_embeddings(
}
if let Err(e) = db.vectors().insert(embedding) {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error(e.to_string())));
return (
StatusCode::BAD_REQUEST,
Json(ApiResponse::error(e.to_string())),
);
}
count += 1;
}
@ -549,7 +623,14 @@ async fn vector_search(
) -> impl IntoResponse {
let db = match get_database(&state, &db_name) {
Some(db) => db,
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<VectorSearchResponse>::error("Database not found"))),
None => {
return (
StatusCode::NOT_FOUND,
Json(ApiResponse::<VectorSearchResponse>::error(
"Database not found",
)),
)
}
};
state.record_vector_search();
@ -563,9 +644,18 @@ async fn vector_search(
Ok(results) => {
let count = results.len();
let matches: Vec<VectorMatch> = results.into_iter().map(Into::into).collect();
(StatusCode::OK, Json(ApiResponse::ok(VectorSearchResponse { results: matches, count })))
(
StatusCode::OK,
Json(ApiResponse::ok(VectorSearchResponse {
results: matches,
count,
})),
)
}
Err(e) => (StatusCode::BAD_REQUEST, Json(ApiResponse::error(e.to_string()))),
Err(e) => (
StatusCode::BAD_REQUEST,
Json(ApiResponse::error(e.to_string())),
),
}
}

View file

@ -114,10 +114,7 @@ impl GatewayServer {
tracing::error!("Failed to create default database: {}", e);
}
let state = Arc::new(AppState::new(
self.db_manager.clone(),
self.auth.clone(),
));
let state = Arc::new(AppState::new(self.db_manager.clone(), self.auth.clone()));
let app = create_router(state);

View file

@ -86,7 +86,12 @@ pub struct Edge {
impl Edge {
/// Creates a new directed edge.
pub fn new(source: NodeId, target: NodeId, edge_type: impl Into<String>, properties: JsonValue) -> Self {
pub fn new(
source: NodeId,
target: NodeId,
edge_type: impl Into<String>,
properties: JsonValue,
) -> Self {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
@ -105,7 +110,12 @@ impl Edge {
}
/// Creates an undirected edge.
pub fn undirected(source: NodeId, target: NodeId, edge_type: impl Into<String>, properties: JsonValue) -> Self {
pub fn undirected(
source: NodeId,
target: NodeId,
edge_type: impl Into<String>,
properties: JsonValue,
) -> Self {
let mut edge = Self::new(source, target, edge_type, properties);
edge.directed = false;
edge
@ -138,8 +148,8 @@ impl Edge {
/// Checks if this edge connects two specific nodes.
pub fn connects_pair(&self, a: &NodeId, b: &NodeId) -> bool {
(&self.source == a && &self.target == b) ||
(!self.directed && &self.source == b && &self.target == a)
(&self.source == a && &self.target == b)
|| (!self.directed && &self.source == b && &self.target == a)
}
/// Gets a property value.
@ -156,7 +166,9 @@ impl Edge {
/// Checks if the edge matches a property filter.
pub fn matches_properties(&self, filter: &JsonValue) -> bool {
if let (Some(filter_obj), Some(props_obj)) = (filter.as_object(), self.properties.as_object()) {
if let (Some(filter_obj), Some(props_obj)) =
(filter.as_object(), self.properties.as_object())
{
for (key, expected) in filter_obj {
if let Some(actual) = props_obj.get(key) {
if actual != expected {
@ -216,7 +228,12 @@ impl EdgeBuilder {
/// Builds the edge.
pub fn build(self) -> Edge {
let mut edge = Edge::new(self.source, self.target, self.edge_type, JsonValue::Object(self.properties));
let mut edge = Edge::new(
self.source,
self.target,
self.edge_type,
JsonValue::Object(self.properties),
);
edge.directed = self.directed;
edge.weight = self.weight;
edge
@ -264,7 +281,10 @@ mod tests {
assert!(!edge.directed);
assert_eq!(edge.weight, 2.5);
assert_eq!(edge.get_property("percentage"), Some(&serde_json::json!(50)));
assert_eq!(
edge.get_property("percentage"),
Some(&serde_json::json!(50))
);
}
#[test]

View file

@ -172,7 +172,9 @@ impl Node {
/// Checks if the node matches a property filter.
pub fn matches_properties(&self, filter: &JsonValue) -> bool {
if let (Some(filter_obj), Some(props_obj)) = (filter.as_object(), self.properties.as_object()) {
if let (Some(filter_obj), Some(props_obj)) =
(filter.as_object(), self.properties.as_object())
{
for (key, expected) in filter_obj {
if let Some(actual) = props_obj.get(key) {
if actual != expected {
@ -258,10 +260,7 @@ mod tests {
assert!(node.has_label("User"));
assert!(!node.has_label("Admin"));
assert_eq!(
node.get_property("name"),
Some(&serde_json::json!("Alice"))
);
assert_eq!(node.get_property("name"), Some(&serde_json::json!("Alice")));
}
#[test]

View file

@ -60,7 +60,10 @@ impl Eq for DijkstraState {}
impl Ord for DijkstraState {
fn cmp(&self, other: &Self) -> Ordering {
// Reverse ordering for min-heap
other.distance.partial_cmp(&self.distance).unwrap_or(Ordering::Equal)
other
.distance
.partial_cmp(&self.distance)
.unwrap_or(Ordering::Equal)
}
}
@ -140,9 +143,16 @@ impl<'a> PathFinder<'a> {
let mut visited = HashSet::new();
distances.insert(*from, 0.0);
heap.push(DijkstraState { node: *from, distance: 0.0 });
heap.push(DijkstraState {
node: *from,
distance: 0.0,
});
while let Some(DijkstraState { node: current, distance: dist }) = heap.pop() {
while let Some(DijkstraState {
node: current,
distance: dist,
}) = heap.pop()
{
if &current == to {
// Reconstruct path
let mut path = vec![current];
@ -181,12 +191,18 @@ impl<'a> PathFinder<'a> {
}
let new_dist = dist + edge.weight;
let is_shorter = distances.get(&neighbor).map(|&d| new_dist < d).unwrap_or(true);
let is_shorter = distances
.get(&neighbor)
.map(|&d| new_dist < d)
.unwrap_or(true);
if is_shorter {
distances.insert(neighbor, new_dist);
previous.insert(neighbor, (current, edge.clone()));
heap.push(DijkstraState { node: neighbor, distance: new_dist });
heap.push(DijkstraState {
node: neighbor,
distance: new_dist,
});
}
}
}
@ -251,7 +267,9 @@ impl<'a> PathFinder<'a> {
path.push(neighbor);
edges.push(edge.clone());
self.find_all_paths_dfs(&neighbor, target, max_length, path, edges, visited, results);
self.find_all_paths_dfs(
&neighbor, target, max_length, path, edges, visited, results,
);
path.pop();
edges.pop();
@ -261,7 +279,12 @@ impl<'a> PathFinder<'a> {
}
/// Finds the shortest path considering only specific edge types.
pub fn shortest_path_by_type(&self, from: &NodeId, to: &NodeId, edge_types: &[String]) -> PathResult {
pub fn shortest_path_by_type(
&self,
from: &NodeId,
to: &NodeId,
edge_types: &[String],
) -> PathResult {
if from == to {
return PathResult::found(vec![*from], Vec::new(), 0.0);
}
@ -377,11 +400,21 @@ mod tests {
let d = store.create_node(vec![], serde_json::json!({"name": "D"}));
let e = store.create_node(vec![], serde_json::json!({"name": "E"}));
store.create_edge(a, b, "LINK", serde_json::json!({})).unwrap();
store.create_edge(b, c, "LINK", serde_json::json!({})).unwrap();
store.create_edge(c, d, "LINK", serde_json::json!({})).unwrap();
store.create_edge(a, e, "LINK", serde_json::json!({})).unwrap();
store.create_edge(e, d, "LINK", serde_json::json!({})).unwrap();
store
.create_edge(a, b, "LINK", serde_json::json!({}))
.unwrap();
store
.create_edge(b, c, "LINK", serde_json::json!({}))
.unwrap();
store
.create_edge(c, d, "LINK", serde_json::json!({}))
.unwrap();
store
.create_edge(a, e, "LINK", serde_json::json!({}))
.unwrap();
store
.create_edge(e, d, "LINK", serde_json::json!({}))
.unwrap();
store
}
@ -392,8 +425,14 @@ mod tests {
let finder = PathFinder::new(&store);
let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({}));
let a = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("A"))).unwrap();
let d = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("D"))).unwrap();
let a = nodes
.iter()
.find(|n| n.get_property("name") == Some(&serde_json::json!("A")))
.unwrap();
let d = nodes
.iter()
.find(|n| n.get_property("name") == Some(&serde_json::json!("D")))
.unwrap();
let result = finder.shortest_path_bfs(&a.id, &d.id);
@ -413,13 +452,19 @@ mod tests {
// A --(3.0)--> C
let mut edge1 = super::super::edge::Edge::new(a, b, "LINK", serde_json::json!({}));
edge1.weight = 1.0;
store.create_edge(a, b, "LINK", serde_json::json!({})).unwrap();
store
.create_edge(a, b, "LINK", serde_json::json!({}))
.unwrap();
let mut edge2 = super::super::edge::Edge::new(b, c, "LINK", serde_json::json!({}));
edge2.weight = 1.0;
store.create_edge(b, c, "LINK", serde_json::json!({})).unwrap();
store
.create_edge(b, c, "LINK", serde_json::json!({}))
.unwrap();
store.create_edge(a, c, "DIRECT", serde_json::json!({})).unwrap();
store
.create_edge(a, c, "DIRECT", serde_json::json!({}))
.unwrap();
let finder = PathFinder::new(&store);
let result = finder.shortest_path_dijkstra(&a, &c);
@ -449,8 +494,14 @@ mod tests {
let finder = PathFinder::new(&store);
let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({}));
let a = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("A"))).unwrap();
let d = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("D"))).unwrap();
let a = nodes
.iter()
.find(|n| n.get_property("name") == Some(&serde_json::json!("A")))
.unwrap();
let d = nodes
.iter()
.find(|n| n.get_property("name") == Some(&serde_json::json!("D")))
.unwrap();
let paths = finder.all_paths(&a.id, &d.id, 5);
@ -463,8 +514,14 @@ mod tests {
let finder = PathFinder::new(&store);
let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({}));
let a = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("A"))).unwrap();
let d = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("D"))).unwrap();
let a = nodes
.iter()
.find(|n| n.get_property("name") == Some(&serde_json::json!("A")))
.unwrap();
let d = nodes
.iter()
.find(|n| n.get_property("name") == Some(&serde_json::json!("D")))
.unwrap();
assert!(finder.path_exists(&a.id, &d.id));
}
@ -475,9 +532,18 @@ mod tests {
let finder = PathFinder::new(&store);
let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({}));
let a = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("A"))).unwrap();
let d = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("D"))).unwrap();
let b = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("B"))).unwrap();
let a = nodes
.iter()
.find(|n| n.get_property("name") == Some(&serde_json::json!("A")))
.unwrap();
let d = nodes
.iter()
.find(|n| n.get_property("name") == Some(&serde_json::json!("D")))
.unwrap();
let b = nodes
.iter()
.find(|n| n.get_property("name") == Some(&serde_json::json!("B")))
.unwrap();
assert_eq!(finder.distance(&a.id, &b.id), Some(1));
assert_eq!(finder.distance(&a.id, &d.id), Some(2)); // A -> E -> D

View file

@ -23,7 +23,10 @@ pub enum GraphQuery {
/// DELETE query for removing nodes/edges.
Delete { variable: String, detach: bool },
/// SET query for updating properties.
Set { variable: String, properties: JsonValue },
Set {
variable: String,
properties: JsonValue,
},
}
/// Pattern to match in the graph.
@ -78,15 +81,35 @@ pub enum RelationshipDirection {
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum WhereClause {
/// Property comparison.
PropertyEquals { variable: String, property: String, value: JsonValue },
PropertyEquals {
variable: String,
property: String,
value: JsonValue,
},
/// Property comparison (not equals).
PropertyNotEquals { variable: String, property: String, value: JsonValue },
PropertyNotEquals {
variable: String,
property: String,
value: JsonValue,
},
/// Property greater than.
PropertyGt { variable: String, property: String, value: JsonValue },
PropertyGt {
variable: String,
property: String,
value: JsonValue,
},
/// Property less than.
PropertyLt { variable: String, property: String, value: JsonValue },
PropertyLt {
variable: String,
property: String,
value: JsonValue,
},
/// Property contains (for text).
PropertyContains { variable: String, property: String, value: String },
PropertyContains {
variable: String,
property: String,
value: String,
},
/// AND condition.
And(Box<WhereClause>, Box<WhereClause>),
/// OR condition.
@ -105,7 +128,10 @@ pub enum ReturnItem {
/// Return a property of a variable.
Property { variable: String, property: String },
/// Return with an alias.
Alias { item: Box<ReturnItem>, alias: String },
Alias {
item: Box<ReturnItem>,
alias: String,
},
/// Count aggregation.
Count(Option<String>),
}
@ -114,7 +140,11 @@ pub enum ReturnItem {
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum CreateElement {
/// Create a node.
Node { variable: Option<String>, labels: Vec<String>, properties: JsonValue },
Node {
variable: Option<String>,
labels: Vec<String>,
properties: JsonValue,
},
/// Create a relationship.
Relationship {
from_var: String,
@ -176,7 +206,10 @@ impl GraphQueryParser {
} else if upper.starts_with("SET") {
Self::parse_set(query)
} else {
Err(GraphError::InvalidOperation(format!("Unknown query type: {}", query)))
Err(GraphError::InvalidOperation(format!(
"Unknown query type: {}",
query
)))
}
}
@ -185,7 +218,10 @@ impl GraphQueryParser {
let upper = query.to_uppercase();
// Find MATCH, WHERE, RETURN, LIMIT positions
let match_end = upper.find("WHERE").or_else(|| upper.find("RETURN")).unwrap_or(query.len());
let match_end = upper
.find("WHERE")
.or_else(|| upper.find("RETURN"))
.unwrap_or(query.len());
let where_start = upper.find("WHERE");
let return_start = upper.find("RETURN");
let limit_start = upper.find("LIMIT");
@ -253,7 +289,9 @@ impl GraphQueryParser {
}
if nodes.is_empty() {
return Err(GraphError::InvalidOperation("No node pattern found".to_string()));
return Err(GraphError::InvalidOperation(
"No node pattern found".to_string(),
));
}
// Combine nodes with relationships
@ -264,10 +302,15 @@ impl GraphQueryParser {
}
}
Ok(MatchPattern { start, relationships })
Ok(MatchPattern {
start,
relationships,
})
}
fn parse_node_pattern(chars: &mut std::iter::Peekable<std::str::Chars>) -> Result<NodePattern, GraphError> {
fn parse_node_pattern(
chars: &mut std::iter::Peekable<std::str::Chars>,
) -> Result<NodePattern, GraphError> {
// Consume '('
chars.next();
@ -335,10 +378,16 @@ impl GraphQueryParser {
}
}
Ok(NodePattern { variable, labels, properties })
Ok(NodePattern {
variable,
labels,
properties,
})
}
fn parse_relationship_pattern(chars: &mut std::iter::Peekable<std::str::Chars>) -> Result<RelationshipPattern, GraphError> {
fn parse_relationship_pattern(
chars: &mut std::iter::Peekable<std::str::Chars>,
) -> Result<RelationshipPattern, GraphError> {
let mut direction = RelationshipDirection::Undirected;
let mut edge_type = None;
let mut variable = None;
@ -408,7 +457,11 @@ impl GraphQueryParser {
variable,
edge_type,
direction,
target: NodePattern { variable: None, labels: Vec::new(), properties: None },
target: NodePattern {
variable: None,
labels: Vec::new(),
properties: None,
},
min_hops,
max_hops,
})
@ -476,7 +529,9 @@ impl GraphQueryParser {
elements.push(CreateElement::Node {
variable: node.variable,
labels: node.labels,
properties: node.properties.unwrap_or(JsonValue::Object(serde_json::Map::new())),
properties: node
.properties
.unwrap_or(JsonValue::Object(serde_json::Map::new())),
});
} else {
break;
@ -488,7 +543,11 @@ impl GraphQueryParser {
fn parse_delete(query: &str) -> Result<GraphQuery, GraphError> {
let detach = query.to_uppercase().starts_with("DETACH");
let start = if detach { "DETACH DELETE".len() } else { "DELETE".len() };
let start = if detach {
"DETACH DELETE".len()
} else {
"DELETE".len()
};
let variable = query[start..].trim().to_string();
Ok(GraphQuery::Delete { variable, detach })
@ -500,19 +559,24 @@ impl GraphQueryParser {
let parts: Vec<_> = content.split('=').collect();
if parts.len() != 2 {
return Err(GraphError::InvalidOperation("Invalid SET syntax".to_string()));
return Err(GraphError::InvalidOperation(
"Invalid SET syntax".to_string(),
));
}
let var_prop: Vec<_> = parts[0].trim().split('.').collect();
if var_prop.len() != 2 {
return Err(GraphError::InvalidOperation("Invalid SET variable".to_string()));
return Err(GraphError::InvalidOperation(
"Invalid SET variable".to_string(),
));
}
let variable = var_prop[0].to_string();
let property = var_prop[1].to_string();
let value_str = parts[1].trim();
let value: JsonValue = serde_json::from_str(value_str).unwrap_or(JsonValue::String(value_str.to_string()));
let value: JsonValue =
serde_json::from_str(value_str).unwrap_or(JsonValue::String(value_str.to_string()));
Ok(GraphQuery::Set {
variable,
@ -535,18 +599,21 @@ impl<'a> GraphQueryExecutor<'a> {
/// Executes a graph query.
pub fn execute(&self, query: &GraphQuery) -> Result<QueryResult, GraphError> {
match query {
GraphQuery::Match { pattern, where_clause, return_items, limit } => {
self.execute_match(pattern, where_clause.as_ref(), return_items, *limit)
}
GraphQuery::Create { .. } => {
Err(GraphError::InvalidOperation("CREATE requires mutable access".to_string()))
}
GraphQuery::Delete { .. } => {
Err(GraphError::InvalidOperation("DELETE requires mutable access".to_string()))
}
GraphQuery::Set { .. } => {
Err(GraphError::InvalidOperation("SET requires mutable access".to_string()))
}
GraphQuery::Match {
pattern,
where_clause,
return_items,
limit,
} => self.execute_match(pattern, where_clause.as_ref(), return_items, *limit),
GraphQuery::Create { .. } => Err(GraphError::InvalidOperation(
"CREATE requires mutable access".to_string(),
)),
GraphQuery::Delete { .. } => Err(GraphError::InvalidOperation(
"DELETE requires mutable access".to_string(),
)),
GraphQuery::Set { .. } => Err(GraphError::InvalidOperation(
"SET requires mutable access".to_string(),
)),
}
}
@ -586,7 +653,11 @@ impl<'a> GraphQueryExecutor<'a> {
.depth(rel_pattern.max_hops)
.direction(direction)
.edge_types(
rel_pattern.edge_type.clone().map(|t| vec![t]).unwrap_or_default(),
rel_pattern
.edge_type
.clone()
.map(|t| vec![t])
.unwrap_or_default(),
)
.labels(rel_pattern.target.labels.clone());
@ -635,7 +706,10 @@ impl<'a> GraphQueryExecutor<'a> {
fn find_matching_nodes(&self, pattern: &NodePattern) -> Vec<Node> {
let label = pattern.labels.first().map(|s| s.as_str());
let filter = pattern.properties.clone().unwrap_or(JsonValue::Object(serde_json::Map::new()));
let filter = pattern
.properties
.clone()
.unwrap_or(JsonValue::Object(serde_json::Map::new()));
self.store.find_nodes(label, &filter)
}
@ -657,7 +731,11 @@ impl<'a> GraphQueryExecutor<'a> {
})
}
fn get_column_names(&self, return_items: &[ReturnItem], bindings: &[HashMap<String, JsonValue>]) -> Vec<String> {
fn get_column_names(
&self,
return_items: &[ReturnItem],
bindings: &[HashMap<String, JsonValue>],
) -> Vec<String> {
let mut columns = Vec::new();
for item in return_items {
@ -673,7 +751,10 @@ impl<'a> GraphQueryExecutor<'a> {
}
ReturnItem::Alias { alias, .. } => columns.push(alias.clone()),
ReturnItem::Count(var) => {
columns.push(format!("count({})", var.as_ref().map(|s| s.as_str()).unwrap_or("*")));
columns.push(format!(
"count({})",
var.as_ref().map(|s| s.as_str()).unwrap_or("*")
));
}
}
}
@ -681,11 +762,18 @@ impl<'a> GraphQueryExecutor<'a> {
columns
}
fn extract_rows(&self, return_items: &[ReturnItem], bindings: &[HashMap<String, JsonValue>]) -> Vec<Vec<JsonValue>> {
fn extract_rows(
&self,
return_items: &[ReturnItem],
bindings: &[HashMap<String, JsonValue>],
) -> Vec<Vec<JsonValue>> {
let mut rows = Vec::new();
// Handle COUNT specially
if return_items.iter().any(|i| matches!(i, ReturnItem::Count(_))) {
if return_items
.iter()
.any(|i| matches!(i, ReturnItem::Count(_)))
{
rows.push(vec![JsonValue::Number(bindings.len().into())]);
return rows;
}
@ -760,8 +848,14 @@ mod tests {
if let GraphQuery::Match { pattern, .. } = parsed {
assert_eq!(pattern.start.labels, vec!["User".to_string()]);
assert_eq!(pattern.relationships.len(), 1);
assert_eq!(pattern.relationships[0].edge_type, Some("FRIEND".to_string()));
assert_eq!(pattern.relationships[0].direction, RelationshipDirection::Outgoing);
assert_eq!(
pattern.relationships[0].edge_type,
Some("FRIEND".to_string())
);
assert_eq!(
pattern.relationships[0].direction,
RelationshipDirection::Outgoing
);
} else {
panic!("Expected Match query");
}
@ -771,9 +865,14 @@ mod tests {
fn test_execute_match() {
let store = GraphStore::new();
let alice = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"}));
let alice = store.create_node(
vec!["User".to_string()],
serde_json::json!({"name": "Alice"}),
);
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
store.create_edge(alice, bob, "FRIEND", serde_json::json!({})).unwrap();
store
.create_edge(alice, bob, "FRIEND", serde_json::json!({}))
.unwrap();
let query = GraphQueryParser::parse("MATCH (n:User) RETURN n").unwrap();
let executor = GraphQueryExecutor::new(&store);

View file

@ -177,18 +177,8 @@ impl GraphStore {
/// Deletes a node and all its connected edges.
pub fn delete_node(&self, id: &NodeId) -> Result<(), GraphError> {
// Get connected edges
let outgoing: Vec<EdgeId> = self
.adjacency
.read()
.get(id)
.cloned()
.unwrap_or_default();
let incoming: Vec<EdgeId> = self
.reverse_adj
.read()
.get(id)
.cloned()
.unwrap_or_default();
let outgoing: Vec<EdgeId> = self.adjacency.read().get(id).cloned().unwrap_or_default();
let incoming: Vec<EdgeId> = self.reverse_adj.read().get(id).cloned().unwrap_or_default();
// Delete all connected edges
for edge_id in outgoing.iter().chain(incoming.iter()) {
@ -457,7 +447,12 @@ impl GraphStore {
}
/// Gets the neighbor node from an edge.
fn get_neighbor_from_edge(&self, edge: &Edge, from: &NodeId, direction: Direction) -> Option<NodeId> {
fn get_neighbor_from_edge(
&self,
edge: &Edge,
from: &NodeId,
direction: Direction,
) -> Option<NodeId> {
match direction {
Direction::Outgoing => {
if &edge.source == from {
@ -491,7 +486,12 @@ impl GraphStore {
}
/// Gets neighbors connected by a specific edge type.
pub fn neighbors_by_type(&self, id: &NodeId, edge_type: &str, direction: Direction) -> Vec<Node> {
pub fn neighbors_by_type(
&self,
id: &NodeId,
edge_type: &str,
direction: Direction,
) -> Vec<Node> {
let edges = self.edges_of(id, direction);
let nodes = self.nodes.read();
@ -565,7 +565,10 @@ mod tests {
fn test_create_edge() {
let store = GraphStore::new();
let alice = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"}));
let alice = store.create_node(
vec!["User".to_string()],
serde_json::json!({"name": "Alice"}),
);
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
let edge_id = store
@ -582,12 +585,22 @@ mod tests {
fn test_neighbors() {
let store = GraphStore::new();
let alice = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"}));
let alice = store.create_node(
vec!["User".to_string()],
serde_json::json!({"name": "Alice"}),
);
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
let charlie = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Charlie"}));
let charlie = store.create_node(
vec!["User".to_string()],
serde_json::json!({"name": "Charlie"}),
);
store.create_edge(alice, bob, "FRIEND", serde_json::json!({})).unwrap();
store.create_edge(alice, charlie, "FRIEND", serde_json::json!({})).unwrap();
store
.create_edge(alice, bob, "FRIEND", serde_json::json!({}))
.unwrap();
store
.create_edge(alice, charlie, "FRIEND", serde_json::json!({}))
.unwrap();
let neighbors = store.neighbors(&alice, Direction::Outgoing);
assert_eq!(neighbors.len(), 2);
@ -597,9 +610,15 @@ mod tests {
fn test_find_by_label() {
let store = GraphStore::new();
store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"}));
store.create_node(
vec!["User".to_string()],
serde_json::json!({"name": "Alice"}),
);
store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
store.create_node(vec!["Product".to_string()], serde_json::json!({"name": "Widget"}));
store.create_node(
vec!["Product".to_string()],
serde_json::json!({"name": "Widget"}),
);
let users = store.find_nodes_by_label("User");
assert_eq!(users.len(), 2);
@ -615,7 +634,9 @@ mod tests {
let alice = store.create_node(vec!["User".to_string()], serde_json::json!({}));
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({}));
store.create_edge(alice, bob, "FRIEND", serde_json::json!({})).unwrap();
store
.create_edge(alice, bob, "FRIEND", serde_json::json!({}))
.unwrap();
// Delete Alice - should also delete the edge
store.delete_node(&alice).unwrap();
@ -631,7 +652,9 @@ mod tests {
let a = store.create_node(vec![], serde_json::json!({}));
let b = store.create_node(vec![], serde_json::json!({}));
store.create_undirected_edge(a, b, "LINK", serde_json::json!({})).unwrap();
store
.create_undirected_edge(a, b, "LINK", serde_json::json!({}))
.unwrap();
// Both directions should work
let a_neighbors = store.neighbors(&a, Direction::Outgoing);
@ -648,8 +671,12 @@ mod tests {
let a = store.create_node(vec![], serde_json::json!({}));
let b = store.create_node(vec![], serde_json::json!({}));
store.create_edge(a, b, "TYPE_A", serde_json::json!({})).unwrap();
store.create_edge(a, b, "TYPE_B", serde_json::json!({})).unwrap();
store
.create_edge(a, b, "TYPE_A", serde_json::json!({}))
.unwrap();
store
.create_edge(a, b, "TYPE_B", serde_json::json!({}))
.unwrap();
let edges = store.edges_between(&a, &b);
assert_eq!(edges.len(), 2);

View file

@ -171,7 +171,9 @@ impl<'a> Traverser<'a> {
for edge in edges {
// Check edge type filter
if !query.edge_types.is_empty() && !query.edge_types.contains(&edge.edge_type) {
if !query.edge_types.is_empty()
&& !query.edge_types.contains(&edge.edge_type)
{
continue;
}
@ -395,18 +397,35 @@ mod tests {
fn setup_social_graph() -> GraphStore {
let store = GraphStore::new();
let alice = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"}));
let alice = store.create_node(
vec!["User".to_string()],
serde_json::json!({"name": "Alice"}),
);
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
let charlie = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Charlie"}));
let dave = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Dave"}));
let charlie = store.create_node(
vec!["User".to_string()],
serde_json::json!({"name": "Charlie"}),
);
let dave = store.create_node(
vec!["User".to_string()],
serde_json::json!({"name": "Dave"}),
);
// Alice -> Bob -> Charlie -> Dave
store.create_edge(alice, bob, "FRIEND", serde_json::json!({})).unwrap();
store.create_edge(bob, charlie, "FRIEND", serde_json::json!({})).unwrap();
store.create_edge(charlie, dave, "FRIEND", serde_json::json!({})).unwrap();
store
.create_edge(alice, bob, "FRIEND", serde_json::json!({}))
.unwrap();
store
.create_edge(bob, charlie, "FRIEND", serde_json::json!({}))
.unwrap();
store
.create_edge(charlie, dave, "FRIEND", serde_json::json!({}))
.unwrap();
// Alice -> Charlie (shortcut)
store.create_edge(alice, charlie, "KNOWS", serde_json::json!({})).unwrap();
store
.create_edge(alice, charlie, "KNOWS", serde_json::json!({}))
.unwrap();
store
}
@ -417,7 +436,10 @@ mod tests {
let traverser = Traverser::new(&store);
let users = store.find_nodes_by_label("User");
let alice = users.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("Alice"))).unwrap();
let alice = users
.iter()
.find(|n| n.get_property("name") == Some(&serde_json::json!("Alice")))
.unwrap();
let query = TraversalQuery::new().depth(2);
let results = traverser.traverse(&alice.id, &query);
@ -432,7 +454,10 @@ mod tests {
let traverser = Traverser::new(&store);
let users = store.find_nodes_by_label("User");
let alice = users.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("Alice"))).unwrap();
let alice = users
.iter()
.find(|n| n.get_property("name") == Some(&serde_json::json!("Alice")))
.unwrap();
let query = TraversalQuery::new()
.depth(2)
@ -440,7 +465,10 @@ mod tests {
let results = traverser.traverse(&alice.id, &query);
// Following only FRIEND edges: Alice -> Bob -> Charlie
let names: Vec<_> = results.iter().filter_map(|r| r.node.get_property("name")).collect();
let names: Vec<_> = results
.iter()
.filter_map(|r| r.node.get_property("name"))
.collect();
assert!(names.contains(&&serde_json::json!("Bob")));
assert!(names.contains(&&serde_json::json!("Charlie")));
}
@ -451,7 +479,10 @@ mod tests {
let traverser = Traverser::new(&store);
let users = store.find_nodes_by_label("User");
let alice = users.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("Alice"))).unwrap();
let alice = users
.iter()
.find(|n| n.get_property("name") == Some(&serde_json::json!("Alice")))
.unwrap();
let query = TraversalQuery::new().depth(1);
let results = traverser.traverse(&alice.id, &query);
@ -468,7 +499,10 @@ mod tests {
let traverser = Traverser::new(&store);
let users = store.find_nodes_by_label("User");
let alice = users.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("Alice"))).unwrap();
let alice = users
.iter()
.find(|n| n.get_property("name") == Some(&serde_json::json!("Alice")))
.unwrap();
let query = TraversalQuery::new().depth(10).limit(2);
let results = traverser.traverse(&alice.id, &query);
@ -486,11 +520,21 @@ mod tests {
let mutual2 = store.create_node(vec![], serde_json::json!({"name": "Mutual2"}));
let only_alice = store.create_node(vec![], serde_json::json!({"name": "OnlyAlice"}));
store.create_edge(alice, mutual1, "FRIEND", serde_json::json!({})).unwrap();
store.create_edge(alice, mutual2, "FRIEND", serde_json::json!({})).unwrap();
store.create_edge(alice, only_alice, "FRIEND", serde_json::json!({})).unwrap();
store.create_edge(bob, mutual1, "FRIEND", serde_json::json!({})).unwrap();
store.create_edge(bob, mutual2, "FRIEND", serde_json::json!({})).unwrap();
store
.create_edge(alice, mutual1, "FRIEND", serde_json::json!({}))
.unwrap();
store
.create_edge(alice, mutual2, "FRIEND", serde_json::json!({}))
.unwrap();
store
.create_edge(alice, only_alice, "FRIEND", serde_json::json!({}))
.unwrap();
store
.create_edge(bob, mutual1, "FRIEND", serde_json::json!({}))
.unwrap();
store
.create_edge(bob, mutual2, "FRIEND", serde_json::json!({}))
.unwrap();
let traverser = Traverser::new(&store);
let mutual = traverser.mutual_connections(&alice, &bob, Some("FRIEND"));

View file

@ -174,17 +174,24 @@ impl Index {
// Check uniqueness if required
if self.config.unique {
let exists = match self.config.index_type {
IndexType::Hash | IndexType::Unique => {
self.hash.read().get(&key).map(|s| !s.is_empty()).unwrap_or(false)
}
_ => {
self.btree.read().get(&key).map(|s| !s.is_empty()).unwrap_or(false)
}
IndexType::Hash | IndexType::Unique => self
.hash
.read()
.get(&key)
.map(|s| !s.is_empty())
.unwrap_or(false),
_ => self
.btree
.read()
.get(&key)
.map(|s| !s.is_empty())
.unwrap_or(false),
};
if exists {
return Err(DatabaseError::AlreadyExists(
format!("Unique constraint violation on index '{}'", self.config.name)
));
return Err(DatabaseError::AlreadyExists(format!(
"Unique constraint violation on index '{}'",
self.config.name
)));
}
}
@ -239,20 +246,18 @@ impl Index {
self.stats.write().lookups += 1;
let result: Vec<DocumentId> = match self.config.index_type {
IndexType::Hash | IndexType::Unique => {
self.hash
.read()
.get(&key)
.map(|s| s.iter().cloned().collect())
.unwrap_or_default()
}
_ => {
self.btree
.read()
.get(&key)
.map(|s| s.iter().cloned().collect())
.unwrap_or_default()
}
IndexType::Hash | IndexType::Unique => self
.hash
.read()
.get(&key)
.map(|s| s.iter().cloned().collect())
.unwrap_or_default(),
_ => self
.btree
.read()
.get(&key)
.map(|s| s.iter().cloned().collect())
.unwrap_or_default(),
};
if !result.is_empty() {
@ -407,12 +412,7 @@ impl IndexManager {
}
/// Removes a document from indexes.
pub fn unindex_document(
&self,
collection: &str,
doc_id: &DocumentId,
document: &JsonValue,
) {
pub fn unindex_document(&self, collection: &str, doc_id: &DocumentId, document: &JsonValue) {
let index_names = self.get_collection_indexes(collection);
let indexes = self.indexes.read();
@ -483,7 +483,9 @@ mod tests {
let index = Index::new(config);
let doc1 = DocumentId::new();
index.insert(doc1.clone(), &json!("alice@example.com")).unwrap();
index
.insert(doc1.clone(), &json!("alice@example.com"))
.unwrap();
let results = index.lookup(&json!("alice@example.com"));
assert_eq!(results.len(), 1);
@ -521,7 +523,9 @@ mod tests {
let doc_id = DocumentId::new();
let doc = json!({"name": "Alice", "age": 30});
manager.index_document("users", doc_id.clone(), &doc).unwrap();
manager
.index_document("users", doc_id.clone(), &doc)
.unwrap();
let indexes = manager.list_indexes();
assert_eq!(indexes.len(), 1);

View file

@ -126,8 +126,7 @@ impl KeyValueStore {
/// Gets a value as string.
pub fn get_string(&self, key: &str) -> Option<String> {
self.get(key)
.and_then(|v| String::from_utf8(v).ok())
self.get(key).and_then(|v| String::from_utf8(v).ok())
}
/// Sets a value with optional TTL.
@ -224,8 +223,9 @@ impl KeyValueStore {
} else {
let s = String::from_utf8(entry.value.clone())
.map_err(|_| DatabaseError::InvalidOperation("Value is not a string".into()))?;
s.parse::<i64>()
.map_err(|_| DatabaseError::InvalidOperation("Value is not an integer".into()))?
s.parse::<i64>().map_err(|_| {
DatabaseError::InvalidOperation("Value is not an integer".into())
})?
}
} else {
0
@ -243,9 +243,9 @@ impl KeyValueStore {
pub fn append(&self, key: &str, value: &[u8]) -> Result<usize, DatabaseError> {
let mut data = self.data.write();
let entry = data.entry(key.to_string()).or_insert_with(|| {
KvEntry::new(Vec::new(), 0)
});
let entry = data
.entry(key.to_string())
.or_insert_with(|| KvEntry::new(Vec::new(), 0));
if entry.is_expired() {
entry.value.clear();
@ -393,11 +393,16 @@ mod tests {
fn test_mget_mset() {
let store = KeyValueStore::new();
store.mset(&[
("k1", b"v1".to_vec()),
("k2", b"v2".to_vec()),
("k3", b"v3".to_vec()),
], 0).unwrap();
store
.mset(
&[
("k1", b"v1".to_vec()),
("k2", b"v2".to_vec()),
("k3", b"v3".to_vec()),
],
0,
)
.unwrap();
let results = store.mget(&["k1", "k2", "k4"]);
assert_eq!(results.len(), 3);

View file

@ -65,12 +65,14 @@ pub use graph::{
pub use index::{Index, IndexConfig, IndexManager, IndexType};
pub use keyvalue::{KeyValue, KeyValueStore, KvEntry};
pub use query::{Filter, Query, QueryEngine, QueryResult, SortOrder};
pub use schema::{Field, FieldType, Schema, SchemaValidator};
pub use replication::{
ClusterConfig, Command as RaftCommand, NodeRole, RaftConfig, RaftEvent, RaftNode, RaftState,
ReplicatedLog,
};
pub use sql::{QueryResult as SqlQueryResult, SqlEngine, SqlParser, SqlType, SqlValue, Table, TableDef};
pub use schema::{Field, FieldType, Schema, SchemaValidator};
pub use sql::{
QueryResult as SqlQueryResult, SqlEngine, SqlParser, SqlType, SqlValue, Table, TableDef,
};
pub use timeseries::{DataPoint, Metric, TimeSeries, TimeSeriesStore};
pub use vector::{Embedding, SimilarityMetric, VectorIndex, VectorStore};

View file

@ -419,10 +419,7 @@ impl QueryEngine {
let values: Vec<f64> = docs
.iter()
.filter_map(|doc| {
doc.get(field)
.and_then(|v| v.as_f64())
})
.filter_map(|doc| doc.get(field).and_then(|v| v.as_f64()))
.collect();
let result = match op {
@ -439,22 +436,18 @@ impl QueryEngine {
serde_json::to_value(avg).unwrap_or(JsonValue::Null)
}
}
AggregateOp::Min => {
values
.iter()
.copied()
.min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map(|v| serde_json::to_value(v).unwrap_or(JsonValue::Null))
.unwrap_or(JsonValue::Null)
}
AggregateOp::Max => {
values
.iter()
.copied()
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map(|v| serde_json::to_value(v).unwrap_or(JsonValue::Null))
.unwrap_or(JsonValue::Null)
}
AggregateOp::Min => values
.iter()
.copied()
.min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map(|v| serde_json::to_value(v).unwrap_or(JsonValue::Null))
.unwrap_or(JsonValue::Null),
AggregateOp::Max => values
.iter()
.copied()
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map(|v| serde_json::to_value(v).unwrap_or(JsonValue::Null))
.unwrap_or(JsonValue::Null),
};
Ok(result)
@ -507,8 +500,10 @@ mod tests {
fn test_simple_query() {
let docs = Arc::new(DocumentStore::new());
docs.create_collection("users").unwrap();
docs.insert("users", json!({"name": "Alice", "age": 30})).unwrap();
docs.insert("users", json!({"name": "Bob", "age": 25})).unwrap();
docs.insert("users", json!({"name": "Alice", "age": 30}))
.unwrap();
docs.insert("users", json!({"name": "Bob", "age": 25}))
.unwrap();
let vectors = Arc::new(VectorStore::new(3));
let indexes = Arc::new(IndexManager::new());
@ -525,8 +520,10 @@ mod tests {
fn test_filter_query() {
let docs = Arc::new(DocumentStore::new());
docs.create_collection("users").unwrap();
docs.insert("users", json!({"name": "Alice", "age": 30})).unwrap();
docs.insert("users", json!({"name": "Bob", "age": 25})).unwrap();
docs.insert("users", json!({"name": "Alice", "age": 30}))
.unwrap();
docs.insert("users", json!({"name": "Bob", "age": 25}))
.unwrap();
let vectors = Arc::new(VectorStore::new(3));
let indexes = Arc::new(IndexManager::new());
@ -543,9 +540,12 @@ mod tests {
fn test_sorted_query() {
let docs = Arc::new(DocumentStore::new());
docs.create_collection("users").unwrap();
docs.insert("users", json!({"name": "Alice", "age": 30})).unwrap();
docs.insert("users", json!({"name": "Bob", "age": 25})).unwrap();
docs.insert("users", json!({"name": "Charlie", "age": 35})).unwrap();
docs.insert("users", json!({"name": "Alice", "age": 30}))
.unwrap();
docs.insert("users", json!({"name": "Bob", "age": 25}))
.unwrap();
docs.insert("users", json!({"name": "Charlie", "age": 35}))
.unwrap();
let vectors = Arc::new(VectorStore::new(3));
let indexes = Arc::new(IndexManager::new());

View file

@ -92,12 +92,7 @@ impl Election {
/// Creates a RequestVote message for this election.
pub fn create_request(&self, log: &ReplicatedLog) -> RequestVote {
RequestVote::new(
self.term,
self.node_id,
log.last_index(),
log.last_term(),
)
RequestVote::new(self.term, self.node_id, log.last_index(), log.last_term())
}
/// Checks the current result of the election.
@ -217,8 +212,8 @@ impl Default for ElectionTimeout {
#[cfg(test)]
mod tests {
use super::*;
use crate::replication::state::Command;
use crate::replication::log::LogEntry;
use crate::replication::state::Command;
#[test]
fn test_election_basic() {

View file

@ -139,7 +139,10 @@ impl ReplicatedLog {
let entries = self.entries.read();
let offset = (from_index - start) as usize;
entries.get(offset..).map(|s| s.to_vec()).unwrap_or_default()
entries
.get(offset..)
.map(|s| s.to_vec())
.unwrap_or_default()
}
/// Appends an entry to the log.
@ -151,7 +154,12 @@ impl ReplicatedLog {
}
/// Appends multiple entries, potentially overwriting conflicting entries.
pub fn append_entries(&self, prev_index: u64, prev_term: u64, new_entries: Vec<LogEntry>) -> bool {
pub fn append_entries(
&self,
prev_index: u64,
prev_term: u64,
new_entries: Vec<LogEntry>,
) -> bool {
// Check that prev entry matches
if prev_index > 0 {
if let Some(prev_entry_term) = self.term_at(prev_index) {
@ -245,7 +253,11 @@ impl ReplicatedLog {
}
/// Creates entries for replication starting from a given index.
pub fn entries_for_replication(&self, from_index: u64, max_entries: usize) -> (u64, u64, Vec<LogEntry>) {
pub fn entries_for_replication(
&self,
from_index: u64,
max_entries: usize,
) -> (u64, u64, Vec<LogEntry>) {
let prev_index = from_index.saturating_sub(1);
let prev_term = self.term_at(prev_index).unwrap_or(0);

View file

@ -222,7 +222,11 @@ impl RaftNode {
// Create new election
let cluster_size = self.cluster.voting_members();
self.election = Some(Election::new(self.id, self.state.current_term, cluster_size));
self.election = Some(Election::new(
self.id,
self.state.current_term,
cluster_size,
));
// Create RequestVote message
let request = RequestVote::new(
@ -295,9 +299,9 @@ impl RaftNode {
return;
}
let (prev_log_index, prev_log_term, entries) =
self.log
.entries_for_replication(next_index, self.config.max_entries_per_rpc);
let (prev_log_index, prev_log_term, entries) = self
.log
.entries_for_replication(next_index, self.config.max_entries_per_rpc);
let request = AppendEntries::with_entries(
self.state.current_term,
@ -308,8 +312,10 @@ impl RaftNode {
self.state.commit_index,
);
self.events
.push(RaftEvent::SendRpc(peer_id, RpcMessage::AppendEntries(request)));
self.events.push(RaftEvent::SendRpc(
peer_id,
RpcMessage::AppendEntries(request),
));
}
fn send_install_snapshot(&mut self, peer_id: NodeId) {
@ -332,8 +338,10 @@ impl RaftNode {
done,
);
self.events
.push(RaftEvent::SendRpc(peer_id, RpcMessage::InstallSnapshot(request)));
self.events.push(RaftEvent::SendRpc(
peer_id,
RpcMessage::InstallSnapshot(request),
));
}
}
@ -395,7 +403,11 @@ impl RaftNode {
}
}
fn handle_append_entries(&mut self, _from: NodeId, req: AppendEntries) -> AppendEntriesResponse {
fn handle_append_entries(
&mut self,
_from: NodeId,
req: AppendEntries,
) -> AppendEntriesResponse {
// Rule: If term > currentTerm, become follower
if req.term > self.state.current_term {
self.become_follower(req.term, Some(req.leader_id));
@ -416,9 +428,9 @@ impl RaftNode {
}
// Try to append entries
let success =
self.log
.append_entries(req.prev_log_index, req.prev_log_term, req.entries);
let success = self
.log
.append_entries(req.prev_log_index, req.prev_log_term, req.entries);
if success {
// Update commit index
@ -443,7 +455,11 @@ impl RaftNode {
}
conflict_index -= 1;
}
AppendEntriesResponse::conflict(self.state.current_term, conflict_term, conflict_index)
AppendEntriesResponse::conflict(
self.state.current_term,
conflict_term,
conflict_index,
)
} else {
AppendEntriesResponse::failure(self.state.current_term)
}
@ -502,7 +518,11 @@ impl RaftNode {
self.cluster.update_peer_state(from, PeerState::Reachable);
}
fn handle_install_snapshot(&mut self, _from: NodeId, req: InstallSnapshot) -> InstallSnapshotResponse {
fn handle_install_snapshot(
&mut self,
_from: NodeId,
req: InstallSnapshot,
) -> InstallSnapshotResponse {
// Rule: If term > currentTerm, become follower
if req.term > self.state.current_term {
self.become_follower(req.term, Some(req.leader_id));
@ -692,12 +712,14 @@ impl RaftNode {
#[cfg(test)]
mod tests {
use super::*;
use super::super::cluster::PeerAddress;
use super::*;
fn create_test_cluster(node_id: NodeId, peers: &[NodeId]) -> ClusterConfig {
let mut cluster =
ClusterConfig::new(node_id, PeerAddress::new("127.0.0.1", 9000 + node_id as u16));
let mut cluster = ClusterConfig::new(
node_id,
PeerAddress::new("127.0.0.1", 9000 + node_id as u16),
);
for &peer in peers {
cluster.add_peer(super::super::cluster::PeerInfo::new(
peer,

View file

@ -176,7 +176,10 @@ impl SnapshotManager {
/// Adds a chunk to the pending snapshot.
pub fn add_chunk(&mut self, offset: u64, data: Vec<u8>) -> bool {
if let Some(ref mut pending) = self.pending_snapshot {
if offset == pending.expected_offset + pending.chunks.iter().map(|c| c.len() as u64).sum::<u64>() {
if offset
== pending.expected_offset
+ pending.chunks.iter().map(|c| c.len() as u64).sum::<u64>()
{
pending.chunks.push(data);
return true;
}

View file

@ -118,31 +118,55 @@ pub enum Command {
// Key-Value operations
/// Set a key-value pair.
KvSet { key: String, value: Vec<u8>, ttl: Option<u64> },
KvSet {
key: String,
value: Vec<u8>,
ttl: Option<u64>,
},
/// Delete a key.
KvDelete { key: String },
// Document operations
/// Insert a document.
DocInsert { collection: String, document: JsonValue },
DocInsert {
collection: String,
document: JsonValue,
},
/// Update a document.
DocUpdate { collection: String, id: String, update: JsonValue },
DocUpdate {
collection: String,
id: String,
update: JsonValue,
},
/// Delete a document.
DocDelete { collection: String, id: String },
// Vector operations
/// Insert a vector.
VectorInsert { namespace: String, id: String, vector: Vec<f32>, metadata: JsonValue },
VectorInsert {
namespace: String,
id: String,
vector: Vec<f32>,
metadata: JsonValue,
},
/// Delete a vector.
VectorDelete { namespace: String, id: String },
// Time-series operations
/// Record a metric data point.
TimeSeriesRecord { metric: String, value: f64, timestamp: u64, tags: JsonValue },
TimeSeriesRecord {
metric: String,
value: f64,
timestamp: u64,
tags: JsonValue,
},
// Graph operations
/// Create a graph node.
GraphNodeCreate { labels: Vec<String>, properties: JsonValue },
GraphNodeCreate {
labels: Vec<String>,
properties: JsonValue,
},
/// Delete a graph node.
GraphNodeDelete { id: String },
/// Create a graph edge.
@ -161,13 +185,20 @@ pub enum Command {
// Schema operations
/// Create a collection/table.
CreateCollection { name: String, schema: Option<JsonValue> },
CreateCollection {
name: String,
schema: Option<JsonValue>,
},
/// Drop a collection/table.
DropCollection { name: String },
// Index operations
/// Create an index.
CreateIndex { collection: String, field: String, index_type: String },
CreateIndex {
collection: String,
field: String,
index_type: String,
},
/// Drop an index.
DropIndex { name: String },
@ -265,7 +296,12 @@ impl LeaderState {
}
/// Calculates the new commit index based on majority replication.
pub fn calculate_commit_index(&self, current_commit: u64, current_term: u64, log_term_at: impl Fn(u64) -> Option<u64>) -> u64 {
pub fn calculate_commit_index(
&self,
current_commit: u64,
current_term: u64,
log_term_at: impl Fn(u64) -> Option<u64>,
) -> u64 {
// Find the highest index that a majority have replicated
let mut indices: Vec<u64> = self.match_index.values().cloned().collect();
indices.sort_unstable();

View file

@ -315,8 +315,8 @@ mod tests {
#[test]
fn test_vector_field() {
let schema = Schema::new("embedding")
.field(Field::required("vector", FieldType::Vector(3)));
let schema =
Schema::new("embedding").field(Field::required("vector", FieldType::Vector(3)));
let mut validator = SchemaValidator::new();
validator.register(schema);

View file

@ -1,8 +1,7 @@
//! SQL query executor.
use super::parser::{
BinaryOp, ParsedExpr, ParsedSelect, ParsedSelectItem,
ParsedStatement, SqlParser,
BinaryOp, ParsedExpr, ParsedSelect, ParsedSelectItem, ParsedStatement, SqlParser,
};
use super::row::{Row, RowId};
use super::table::{ColumnDef, Table, TableDef};
@ -192,11 +191,7 @@ impl SqlEngine {
match a_val.partial_cmp(&b_val) {
Some(std::cmp::Ordering::Equal) => continue,
Some(ord) => {
return if ob.ascending {
ord
} else {
ord.reverse()
};
return if ob.ascending { ord } else { ord.reverse() };
}
None => continue,
}
@ -216,7 +211,11 @@ impl SqlEngine {
}
// Handle aggregates
if select.columns.iter().any(|c| matches!(c, ParsedSelectItem::Aggregate { .. })) {
if select
.columns
.iter()
.any(|c| matches!(c, ParsedSelectItem::Aggregate { .. }))
{
return self.execute_aggregate(select, &rows, table);
}
@ -244,7 +243,9 @@ impl SqlEngine {
ParsedSelectItem::Wildcard => table.def.column_names(),
ParsedSelectItem::Column(name) => vec![name.clone()],
ParsedSelectItem::ColumnAlias { alias, .. } => vec![alias.clone()],
ParsedSelectItem::Aggregate { function, alias, .. } => {
ParsedSelectItem::Aggregate {
function, alias, ..
} => {
vec![alias.clone().unwrap_or_else(|| function.clone())]
}
})
@ -328,7 +329,9 @@ impl SqlEngine {
rows.iter()
.map(|r| r.get_or_null(col))
.filter(|v| !v.is_null())
.min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.min_by(|a, b| {
a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
})
.unwrap_or(SqlValue::Null)
}
"MAX" => {
@ -338,12 +341,12 @@ impl SqlEngine {
rows.iter()
.map(|r| r.get_or_null(col))
.filter(|v| !v.is_null())
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.max_by(|a, b| {
a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
})
.unwrap_or(SqlValue::Null)
}
_ => {
return Err(SqlError::Unsupported(format!("Function: {}", function)))
}
_ => return Err(SqlError::Unsupported(format!("Function: {}", function))),
};
result_values.push(value);
}
@ -404,7 +407,11 @@ impl SqlEngine {
ParsedExpr::IsNotNull(inner) => {
SqlValue::Boolean(!self.evaluate_expr(row, inner).is_null())
}
ParsedExpr::InList { expr, list, negated } => {
ParsedExpr::InList {
expr,
list,
negated,
} => {
let val = self.evaluate_expr(row, expr);
let in_list = list.iter().any(|item| {
let item_val = self.evaluate_expr(row, item);
@ -424,9 +431,7 @@ impl SqlEngine {
let between = val >= low_val && val <= high_val;
SqlValue::Boolean(if *negated { !between } else { between })
}
ParsedExpr::Function { name, args } => {
self.evaluate_function(row, name, args)
}
ParsedExpr::Function { name, args } => self.evaluate_function(row, name, args),
}
}
@ -474,9 +479,7 @@ impl SqlEngine {
_ => SqlValue::Null,
},
BinaryOp::Divide => match (left, right) {
(SqlValue::Integer(a), SqlValue::Integer(b)) if *b != 0 => {
SqlValue::Integer(a / b)
}
(SqlValue::Integer(a), SqlValue::Integer(b)) if *b != 0 => SqlValue::Integer(a / b),
(SqlValue::Real(a), SqlValue::Real(b)) if *b != 0.0 => SqlValue::Real(a / b),
_ => SqlValue::Null,
},
@ -536,9 +539,7 @@ impl SqlEngine {
/// Matches a LIKE pattern.
fn match_like(&self, text: &str, pattern: &str) -> bool {
// Simple LIKE implementation: % = any chars, _ = single char
let _regex_pattern = pattern
.replace('%', ".*")
.replace('_', ".");
let _regex_pattern = pattern.replace('%', ".*").replace('_', ".");
// For simplicity, just do case-insensitive contains for now
if pattern.starts_with('%') && pattern.ends_with('%') {
let inner = &pattern[1..pattern.len() - 1];
@ -615,8 +616,7 @@ impl SqlEngine {
.unwrap_or(true);
if matches {
let updates: HashMap<String, SqlValue> =
assignments.iter().cloned().collect();
let updates: HashMap<String, SqlValue> = assignments.iter().cloned().collect();
table.update(row.id, updates)?;
count += 1;
}
@ -672,9 +672,9 @@ impl SqlEngine {
.ok_or_else(|| SqlError::TableNotFound(table_name.to_string()))?;
// For simplicity, only support single-column indexes
let column = columns
.first()
.ok_or_else(|| SqlError::InvalidOperation("Index requires at least one column".to_string()))?;
let column = columns.first().ok_or_else(|| {
SqlError::InvalidOperation("Index requires at least one column".to_string())
})?;
table.create_index(name, column, unique)?;
Ok(QueryResult::empty())
@ -775,7 +775,9 @@ mod tests {
.execute("INSERT INTO users (id, name, age) VALUES (2, 'Bob', 25)")
.unwrap();
let result = engine.execute("SELECT name FROM users WHERE age > 26").unwrap();
let result = engine
.execute("SELECT name FROM users WHERE age > 26")
.unwrap();
assert_eq!(result.rows.len(), 1);
assert_eq!(result.rows[0][0], SqlValue::Text("Alice".to_string()));
}
@ -806,7 +808,9 @@ mod tests {
engine
.execute(&format!(
"INSERT INTO users (id, name, age) VALUES ({}, 'User{}', {})",
i, i, 20 + i
i,
i,
20 + i
))
.unwrap();
}

View file

@ -18,10 +18,7 @@ pub enum ParsedStatement {
if_not_exists: bool,
},
/// DROP TABLE statement.
DropTable {
name: String,
if_exists: bool,
},
DropTable { name: String, if_exists: bool },
/// SELECT statement.
Select(ParsedSelect),
/// INSERT statement.
@ -179,15 +176,17 @@ impl SqlParser {
/// Parses a SQL statement.
pub fn parse(sql: &str) -> Result<ParsedStatement, SqlError> {
let dialect = SQLiteDialect {};
let statements = Parser::parse_sql(&dialect, sql)
.map_err(|e| SqlError::Parse(e.to_string()))?;
let statements =
Parser::parse_sql(&dialect, sql).map_err(|e| SqlError::Parse(e.to_string()))?;
if statements.is_empty() {
return Err(SqlError::Parse("Empty SQL statement".to_string()));
}
if statements.len() > 1 {
return Err(SqlError::Parse("Multiple statements not supported".to_string()));
return Err(SqlError::Parse(
"Multiple statements not supported".to_string(),
));
}
Self::convert_statement(&statements[0])
@ -195,25 +194,42 @@ impl SqlParser {
fn convert_statement(stmt: &Statement) -> Result<ParsedStatement, SqlError> {
match stmt {
Statement::CreateTable { name, columns, if_not_exists, constraints, .. } => {
Self::convert_create_table(name, columns, constraints, *if_not_exists)
}
Statement::Drop { object_type, names, if_exists, .. } => {
Self::convert_drop(object_type, names, *if_exists)
}
Statement::CreateTable {
name,
columns,
if_not_exists,
constraints,
..
} => Self::convert_create_table(name, columns, constraints, *if_not_exists),
Statement::Drop {
object_type,
names,
if_exists,
..
} => Self::convert_drop(object_type, names, *if_exists),
Statement::Query(query) => Self::convert_query(query),
Statement::Insert { table_name, columns, source, .. } => {
Self::convert_insert(table_name, columns, source)
}
Statement::Update { table, assignments, selection, .. } => {
Self::convert_update(table, assignments, selection)
}
Statement::Delete { from, selection, .. } => {
Self::convert_delete(from, selection)
}
Statement::CreateIndex { name, table_name, columns, unique, .. } => {
Self::convert_create_index(name, table_name, columns, *unique)
}
Statement::Insert {
table_name,
columns,
source,
..
} => Self::convert_insert(table_name, columns, source),
Statement::Update {
table,
assignments,
selection,
..
} => Self::convert_update(table, assignments, selection),
Statement::Delete {
from, selection, ..
} => Self::convert_delete(from, selection),
Statement::CreateIndex {
name,
table_name,
columns,
unique,
..
} => Self::convert_create_index(name, table_name, columns, *unique),
_ => Err(SqlError::Unsupported(format!("Statement not supported"))),
}
}
@ -230,7 +246,12 @@ impl SqlParser {
// Extract primary keys from table constraints
for constraint in constraints {
if let sqlparser::ast::TableConstraint::Unique { columns: pk_cols, is_primary: true, .. } = constraint {
if let sqlparser::ast::TableConstraint::Unique {
columns: pk_cols,
is_primary: true,
..
} = constraint
{
for col in pk_cols {
primary_keys.push(col.value.clone());
}
@ -296,10 +317,9 @@ impl SqlParser {
DataType::Real | DataType::Float(_) | DataType::Double | DataType::DoublePrecision => {
Ok(SqlType::Real)
}
DataType::Varchar(_)
| DataType::Char(_)
| DataType::Text
| DataType::String(_) => Ok(SqlType::Text),
DataType::Varchar(_) | DataType::Char(_) | DataType::Text | DataType::String(_) => {
Ok(SqlType::Text)
}
DataType::Binary(_) | DataType::Varbinary(_) | DataType::Blob(_) => Ok(SqlType::Blob),
DataType::Boolean => Ok(SqlType::Boolean),
DataType::Timestamp(_, _) | DataType::Date | DataType::Datetime(_) => {
@ -367,10 +387,7 @@ impl SqlParser {
.collect();
// Parse LIMIT/OFFSET
let limit = query
.limit
.as_ref()
.and_then(|l| Self::expr_to_usize(l));
let limit = query.limit.as_ref().and_then(|l| Self::expr_to_usize(l));
let offset = query
.offset
.as_ref()
@ -403,16 +420,18 @@ impl SqlParser {
Self::convert_select_expr(expr)
}
}
_ => Err(SqlError::Unsupported("Select item not supported".to_string())),
_ => Err(SqlError::Unsupported(
"Select item not supported".to_string(),
)),
}
}
fn convert_select_expr(expr: &Expr) -> Result<ParsedSelectItem, SqlError> {
match expr {
Expr::Identifier(id) => Ok(ParsedSelectItem::Column(id.value.clone())),
Expr::CompoundIdentifier(ids) => {
Ok(ParsedSelectItem::Column(ids.last().map(|i| i.value.clone()).unwrap_or_default()))
}
Expr::CompoundIdentifier(ids) => Ok(ParsedSelectItem::Column(
ids.last().map(|i| i.value.clone()).unwrap_or_default(),
)),
Expr::Function(func) => {
let name = func.name.to_string().to_uppercase();
// Try to extract column from first arg - simplified for compatibility
@ -423,14 +442,18 @@ impl SqlParser {
alias: None,
})
}
_ => Err(SqlError::Unsupported("Select expression not supported".to_string())),
_ => Err(SqlError::Unsupported(
"Select expression not supported".to_string(),
)),
}
}
fn convert_table_factor(factor: &TableFactor) -> Result<String, SqlError> {
match factor {
TableFactor::Table { name, .. } => Ok(name.to_string()),
_ => Err(SqlError::Unsupported("Table factor not supported".to_string())),
_ => Err(SqlError::Unsupported(
"Table factor not supported".to_string(),
)),
}
}
@ -461,9 +484,9 @@ impl SqlParser {
fn convert_expr(expr: &Expr) -> Result<ParsedExpr, SqlError> {
match expr {
Expr::Identifier(id) => Ok(ParsedExpr::Column(id.value.clone())),
Expr::CompoundIdentifier(ids) => {
Ok(ParsedExpr::Column(ids.last().map(|i| i.value.clone()).unwrap_or_default()))
}
Expr::CompoundIdentifier(ids) => Ok(ParsedExpr::Column(
ids.last().map(|i| i.value.clone()).unwrap_or_default(),
)),
Expr::Value(v) => Ok(ParsedExpr::Literal(Self::convert_value(v)?)),
Expr::BinaryOp { left, op, right } => {
let left = Box::new(Self::convert_expr(left)?);
@ -471,17 +494,30 @@ impl SqlParser {
let op = Self::convert_binary_op(op)?;
Ok(ParsedExpr::BinaryOp { left, op, right })
}
Expr::UnaryOp { op: sqlparser::ast::UnaryOperator::Not, expr } => {
Ok(ParsedExpr::Not(Box::new(Self::convert_expr(expr)?)))
}
Expr::UnaryOp {
op: sqlparser::ast::UnaryOperator::Not,
expr,
} => Ok(ParsedExpr::Not(Box::new(Self::convert_expr(expr)?))),
Expr::IsNull(expr) => Ok(ParsedExpr::IsNull(Box::new(Self::convert_expr(expr)?))),
Expr::IsNotNull(expr) => Ok(ParsedExpr::IsNotNull(Box::new(Self::convert_expr(expr)?))),
Expr::InList { expr, list, negated } => Ok(ParsedExpr::InList {
Expr::InList {
expr,
list,
negated,
} => Ok(ParsedExpr::InList {
expr: Box::new(Self::convert_expr(expr)?),
list: list.iter().map(Self::convert_expr).collect::<Result<_, _>>()?,
list: list
.iter()
.map(Self::convert_expr)
.collect::<Result<_, _>>()?,
negated: *negated,
}),
Expr::Between { expr, low, high, negated } => Ok(ParsedExpr::Between {
Expr::Between {
expr,
low,
high,
negated,
} => Ok(ParsedExpr::Between {
expr: Box::new(Self::convert_expr(expr)?),
low: Box::new(Self::convert_expr(low)?),
high: Box::new(Self::convert_expr(high)?),
@ -490,10 +526,16 @@ impl SqlParser {
Expr::Like { expr, pattern, .. } => {
let left = Box::new(Self::convert_expr(expr)?);
let right = Box::new(Self::convert_expr(pattern)?);
Ok(ParsedExpr::BinaryOp { left, op: BinaryOp::Like, right })
Ok(ParsedExpr::BinaryOp {
left,
op: BinaryOp::Like,
right,
})
}
Expr::Nested(inner) => Self::convert_expr(inner),
_ => Err(SqlError::Unsupported("Expression not supported".to_string())),
_ => Err(SqlError::Unsupported(
"Expression not supported".to_string(),
)),
}
}
@ -587,7 +629,11 @@ impl SqlParser {
let parsed_assignments: Vec<(String, SqlValue)> = assignments
.iter()
.map(|a| {
let col = a.id.iter().map(|i| i.value.clone()).collect::<Vec<_>>().join(".");
let col =
a.id.iter()
.map(|i| i.value.clone())
.collect::<Vec<_>>()
.join(".");
let val = Self::convert_value_expr(&a.value)?;
Ok((col, val))
})
@ -633,10 +679,7 @@ impl SqlParser {
let table = table_name.to_string();
let cols: Vec<String> = columns
.iter()
.map(|c| c.expr.to_string())
.collect();
let cols: Vec<String> = columns.iter().map(|c| c.expr.to_string()).collect();
Ok(ParsedStatement::CreateIndex {
name: index_name,
@ -694,7 +737,12 @@ mod tests {
let sql = "INSERT INTO users (name, age) VALUES ('Alice', 30), ('Bob', 25)";
let stmt = SqlParser::parse(sql).unwrap();
if let ParsedStatement::Insert { table, columns, values } = stmt {
if let ParsedStatement::Insert {
table,
columns,
values,
} = stmt
{
assert_eq!(table, "users");
assert_eq!(columns, vec!["name", "age"]);
assert_eq!(values.len(), 2);
@ -708,7 +756,12 @@ mod tests {
let sql = "UPDATE users SET age = 31 WHERE name = 'Alice'";
let stmt = SqlParser::parse(sql).unwrap();
if let ParsedStatement::Update { table, assignments, where_clause } = stmt {
if let ParsedStatement::Update {
table,
assignments,
where_clause,
} = stmt
{
assert_eq!(table, "users");
assert_eq!(assignments.len(), 1);
assert!(where_clause.is_some());
@ -722,7 +775,11 @@ mod tests {
let sql = "DELETE FROM users WHERE age < 18";
let stmt = SqlParser::parse(sql).unwrap();
if let ParsedStatement::Delete { table, where_clause } = stmt {
if let ParsedStatement::Delete {
table,
where_clause,
} = stmt
{
assert_eq!(table, "users");
assert!(where_clause.is_some());
} else {

View file

@ -85,7 +85,10 @@ impl Row {
/// Returns all values in column order.
pub fn values(&self) -> Vec<&SqlValue> {
self.columns.iter().map(|c| self.values.get(c).unwrap()).collect()
self.columns
.iter()
.map(|c| self.values.get(c).unwrap())
.collect()
}
/// Returns the number of columns.

View file

@ -299,7 +299,10 @@ impl Table {
let mut indexes = self.indexes.write();
if indexes.contains_key(&name) {
return Err(SqlError::InvalidOperation(format!("Index '{}' already exists", name)));
return Err(SqlError::InvalidOperation(format!(
"Index '{}' already exists",
name
)));
}
let mut index = TableIndex::new(&name, &column, unique);
@ -319,7 +322,10 @@ impl Table {
pub fn drop_index(&self, name: &str) -> Result<(), SqlError> {
let mut indexes = self.indexes.write();
if indexes.remove(name).is_none() {
return Err(SqlError::InvalidOperation(format!("Index '{}' not found", name)));
return Err(SqlError::InvalidOperation(format!(
"Index '{}' not found",
name
)));
}
Ok(())
}
@ -371,9 +377,9 @@ impl Table {
/// Updates a row.
pub fn update(&self, id: RowId, updates: HashMap<String, SqlValue>) -> Result<(), SqlError> {
let mut rows = self.rows.write();
let row = rows.get_mut(&id).ok_or_else(|| {
SqlError::InvalidOperation(format!("Row {} not found", id))
})?;
let row = rows
.get_mut(&id)
.ok_or_else(|| SqlError::InvalidOperation(format!("Row {} not found", id)))?;
let old_values: HashMap<String, SqlValue> = updates
.keys()
@ -392,7 +398,10 @@ impl Table {
let mut indexes = self.indexes.write();
for (_, index) in indexes.iter_mut() {
if let Some(new_value) = updates.get(&index.column) {
let old_value = old_values.get(&index.column).cloned().unwrap_or(SqlValue::Null);
let old_value = old_values
.get(&index.column)
.cloned()
.unwrap_or(SqlValue::Null);
index.remove(&old_value, &id);
index.insert(new_value.clone(), id)?;
}
@ -480,7 +489,10 @@ mod tests {
values.insert("id".to_string(), SqlValue::Integer(1));
values.insert("name".to_string(), SqlValue::Text("Alice".to_string()));
values.insert("age".to_string(), SqlValue::Integer(30));
values.insert("email".to_string(), SqlValue::Text("alice@example.com".to_string()));
values.insert(
"email".to_string(),
SqlValue::Text("alice@example.com".to_string()),
);
let row_id = table.insert(values).unwrap();
assert_eq!(table.count(), 1);
@ -508,13 +520,19 @@ mod tests {
let mut values1 = HashMap::new();
values1.insert("id".to_string(), SqlValue::Integer(1));
values1.insert("name".to_string(), SqlValue::Text("Alice".to_string()));
values1.insert("email".to_string(), SqlValue::Text("test@example.com".to_string()));
values1.insert(
"email".to_string(),
SqlValue::Text("test@example.com".to_string()),
);
table.insert(values1).unwrap();
let mut values2 = HashMap::new();
values2.insert("id".to_string(), SqlValue::Integer(2));
values2.insert("name".to_string(), SqlValue::Text("Bob".to_string()));
values2.insert("email".to_string(), SqlValue::Text("test@example.com".to_string()));
values2.insert(
"email".to_string(),
SqlValue::Text("test@example.com".to_string()),
);
let result = table.insert(values2);
assert!(result.is_err()); // Duplicate email

View file

@ -124,7 +124,12 @@ impl Transaction {
}
/// Records an insert operation.
pub fn record_insert(&mut self, table: String, row_id: RowId, values: HashMap<String, SqlValue>) {
pub fn record_insert(
&mut self,
table: String,
row_id: RowId,
values: HashMap<String, SqlValue>,
) {
self.operations.push(TransactionOp::Insert {
table,
row_id,
@ -149,7 +154,12 @@ impl Transaction {
}
/// Records a delete operation.
pub fn record_delete(&mut self, table: String, row_id: RowId, old_values: HashMap<String, SqlValue>) {
pub fn record_delete(
&mut self,
table: String,
row_id: RowId,
old_values: HashMap<String, SqlValue>,
) {
self.operations.push(TransactionOp::Delete {
table,
row_id,
@ -213,7 +223,10 @@ impl TransactionManager {
.ok_or_else(|| SqlError::Transaction(format!("Transaction {} not found", id)))?;
if !txn.is_active() {
return Err(SqlError::Transaction(format!("Transaction {} is not active", id)));
return Err(SqlError::Transaction(format!(
"Transaction {} is not active",
id
)));
}
txn.operations.push(op);
@ -228,7 +241,10 @@ impl TransactionManager {
.ok_or_else(|| SqlError::Transaction(format!("Transaction {} not found", id)))?;
if !txn.is_active() {
return Err(SqlError::Transaction(format!("Transaction {} is not active", id)));
return Err(SqlError::Transaction(format!(
"Transaction {} is not active",
id
)));
}
txn.mark_committed();
@ -245,7 +261,10 @@ impl TransactionManager {
.ok_or_else(|| SqlError::Transaction(format!("Transaction {} not found", id)))?;
if !txn.is_active() {
return Err(SqlError::Transaction(format!("Transaction {} is not active", id)));
return Err(SqlError::Transaction(format!(
"Transaction {} is not active",
id
)));
}
txn.mark_rolled_back();

View file

@ -233,12 +233,8 @@ impl Ord for SqlValue {
(SqlValue::Blob(a), SqlValue::Blob(b)) => a.cmp(b),
(SqlValue::Boolean(a), SqlValue::Boolean(b)) => a.cmp(b),
(SqlValue::Timestamp(a), SqlValue::Timestamp(b)) => a.cmp(b),
(SqlValue::Integer(a), SqlValue::Real(b)) => {
(*a as f64).to_bits().cmp(&b.to_bits())
}
(SqlValue::Real(a), SqlValue::Integer(b)) => {
a.to_bits().cmp(&(*b as f64).to_bits())
}
(SqlValue::Integer(a), SqlValue::Real(b)) => (*a as f64).to_bits().cmp(&b.to_bits()),
(SqlValue::Real(a), SqlValue::Integer(b)) => a.to_bits().cmp(&(*b as f64).to_bits()),
// Different types: order by type discriminant
_ => self.type_order().cmp(&other.type_order()),
}

View file

@ -158,11 +158,7 @@ impl Metric {
/// Calculates sum in a time range.
pub fn sum(&self, start: u64, end: u64) -> f64 {
self.data
.read()
.range(start..=end)
.map(|(_, &v)| v)
.sum()
self.data.read().range(start..=end).map(|(_, &v)| v).sum()
}
/// Counts data points in a time range.

View file

@ -207,9 +207,7 @@ impl VectorIndex {
let embeddings = self.embeddings.read();
let mut results: Vec<VectorSearchResult> = embeddings
.values()
.filter(|e| {
namespace.map(|ns| e.namespace == ns).unwrap_or(true)
})
.filter(|e| namespace.map(|ns| e.namespace == ns).unwrap_or(true))
.map(|e| {
let score = self.calculate_similarity(&e.vector, query);
VectorSearchResult {
@ -217,14 +215,14 @@ impl VectorIndex {
score,
}
})
.filter(|r| {
threshold.map(|t| r.score >= t).unwrap_or(true)
})
.filter(|r| threshold.map(|t| r.score >= t).unwrap_or(true))
.collect();
// Sort by score descending
results.sort_by(|a, b| {
b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal)
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
// Apply limit
@ -234,8 +232,9 @@ impl VectorIndex {
let elapsed = start.elapsed().as_millis() as f64;
let mut stats = self.stats.write();
stats.searches += 1;
stats.avg_search_time_ms =
(stats.avg_search_time_ms * (stats.searches - 1) as f64 + elapsed) / stats.searches as f64;
stats.avg_search_time_ms = (stats.avg_search_time_ms * (stats.searches - 1) as f64
+ elapsed)
/ stats.searches as f64;
Ok(results)
}
@ -329,7 +328,8 @@ impl VectorStore {
namespace: Option<&str>,
threshold: Option<f32>,
) -> Result<Vec<VectorSearchResult>, DatabaseError> {
self.default_index.search(query, limit, namespace, threshold)
self.default_index
.search(query, limit, namespace, threshold)
}
/// Gets an embedding by ID.
@ -388,10 +388,7 @@ pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
/// Manhattan distance (L1) between two vectors.
pub fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
a.iter()
.zip(b.iter())
.map(|(x, y)| (x - y).abs())
.sum()
a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum()
}
#[cfg(test)]
@ -414,9 +411,15 @@ mod tests {
fn test_vector_insert_search() {
let store = VectorStore::new(3);
store.insert(Embedding::new("a", vec![1.0, 0.0, 0.0])).unwrap();
store.insert(Embedding::new("b", vec![0.9, 0.1, 0.0])).unwrap();
store.insert(Embedding::new("c", vec![0.0, 1.0, 0.0])).unwrap();
store
.insert(Embedding::new("a", vec![1.0, 0.0, 0.0]))
.unwrap();
store
.insert(Embedding::new("b", vec![0.9, 0.1, 0.0]))
.unwrap();
store
.insert(Embedding::new("c", vec![0.0, 1.0, 0.0]))
.unwrap();
let results = store.search(&[1.0, 0.0, 0.0], 2, None, None).unwrap();
@ -429,8 +432,12 @@ mod tests {
fn test_similarity_threshold() {
let store = VectorStore::new(3);
store.insert(Embedding::new("a", vec![1.0, 0.0, 0.0])).unwrap();
store.insert(Embedding::new("b", vec![0.0, 1.0, 0.0])).unwrap();
store
.insert(Embedding::new("a", vec![1.0, 0.0, 0.0]))
.unwrap();
store
.insert(Embedding::new("b", vec![0.0, 1.0, 0.0]))
.unwrap();
let results = store.search(&[1.0, 0.0, 0.0], 10, None, Some(0.5)).unwrap();
@ -443,14 +450,16 @@ mod tests {
fn test_namespace_filter() {
let store = VectorStore::new(3);
store.insert(
Embedding::new("a", vec![1.0, 0.0, 0.0]).with_namespace("ns1")
).unwrap();
store.insert(
Embedding::new("b", vec![1.0, 0.0, 0.0]).with_namespace("ns2")
).unwrap();
store
.insert(Embedding::new("a", vec![1.0, 0.0, 0.0]).with_namespace("ns1"))
.unwrap();
store
.insert(Embedding::new("b", vec![1.0, 0.0, 0.0]).with_namespace("ns2"))
.unwrap();
let results = store.search(&[1.0, 0.0, 0.0], 10, Some("ns1"), None).unwrap();
let results = store
.search(&[1.0, 0.0, 0.0], 10, Some("ns1"), None)
.unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].embedding.id, "a");

View file

@ -184,9 +184,7 @@ impl Credit {
/// Check if credit is expired
pub fn is_expired(&self) -> bool {
self.expires_at
.map(|exp| Utc::now() > exp)
.unwrap_or(false)
self.expires_at.map(|exp| Utc::now() > exp).unwrap_or(false)
}
/// Get remaining amount
@ -241,14 +239,21 @@ impl std::fmt::Display for CreditError {
match self {
CreditError::CreditInactive => write!(f, "Credit is no longer active"),
CreditError::CreditExpired => write!(f, "Credit has expired"),
CreditError::InsufficientCredit { requested, available } => {
CreditError::InsufficientCredit {
requested,
available,
} => {
write!(
f,
"Insufficient credit: requested {}, available {}",
requested, available
)
}
CreditError::ExceedsMaxCredit { current, requested, maximum } => {
CreditError::ExceedsMaxCredit {
current,
requested,
maximum,
} => {
write!(
f,
"Credit exceeds maximum: current {}, requested {}, maximum {}",
@ -279,9 +284,9 @@ pub struct CreditPolicy {
impl Default for CreditPolicy {
fn default() -> Self {
Self {
welcome_amount: Decimal::new(10, 0), // 10 SYNOR
welcome_amount: Decimal::new(10, 0), // 10 SYNOR
referral_referrer_amount: Decimal::new(25, 0), // 25 SYNOR
referral_referee_amount: Decimal::new(10, 0), // 10 SYNOR
referral_referee_amount: Decimal::new(10, 0), // 10 SYNOR
max_credit_per_account: Decimal::new(1000, 0), // 1000 SYNOR
default_expiry_days: 365,
}
@ -334,12 +339,20 @@ impl CreditManager {
let referee_id = referee_id.into();
// Credit for the referrer
let referrer_credit = Credit::referral(&referrer_id, self.policy.referral_referrer_amount, &referee_id)
.with_expiry_days(self.policy.default_expiry_days);
let referrer_credit = Credit::referral(
&referrer_id,
self.policy.referral_referrer_amount,
&referee_id,
)
.with_expiry_days(self.policy.default_expiry_days);
// Credit for the referee
let referee_credit = Credit::referral(&referee_id, self.policy.referral_referee_amount, &referrer_id)
.with_expiry_days(self.policy.default_expiry_days);
let referee_credit = Credit::referral(
&referee_id,
self.policy.referral_referee_amount,
&referrer_id,
)
.with_expiry_days(self.policy.default_expiry_days);
self.credits
.entry(referrer_id)
@ -448,13 +461,11 @@ impl CreditManager {
let mut remaining = amount;
// Sort by expiry date (soonest first) for FIFO
credits.sort_by(|a, b| {
match (&a.expires_at, &b.expires_at) {
(Some(a_exp), Some(b_exp)) => a_exp.cmp(b_exp),
(Some(_), None) => std::cmp::Ordering::Less,
(None, Some(_)) => std::cmp::Ordering::Greater,
(None, None) => a.created_at.cmp(&b.created_at),
}
credits.sort_by(|a, b| match (&a.expires_at, &b.expires_at) {
(Some(a_exp), Some(b_exp)) => a_exp.cmp(b_exp),
(Some(_), None) => std::cmp::Ordering::Less,
(None, Some(_)) => std::cmp::Ordering::Greater,
(None, None) => a.created_at.cmp(&b.created_at),
});
for credit in credits.iter_mut() {

View file

@ -319,12 +319,7 @@ mod tests {
#[test]
fn test_line_item() {
let item = InvoiceLineItem::new(
"Storage L2",
ServiceType::Storage,
dec!(10),
dec!(0.02),
);
let item = InvoiceLineItem::new("Storage L2", ServiceType::Storage, dec!(10), dec!(0.02));
assert_eq!(item.amount, dec!(0.20));
}
@ -332,8 +327,18 @@ mod tests {
#[test]
fn test_invoice_calculate() {
let mut invoice = Invoice::new("test")
.add_line_item(InvoiceLineItem::new("Storage", ServiceType::Storage, dec!(100), dec!(0.02)))
.add_line_item(InvoiceLineItem::new("Compute", ServiceType::Compute, dec!(10), dec!(0.50)));
.add_line_item(InvoiceLineItem::new(
"Storage",
ServiceType::Storage,
dec!(100),
dec!(0.02),
))
.add_line_item(InvoiceLineItem::new(
"Compute",
ServiceType::Compute,
dec!(10),
dec!(0.50),
));
invoice.discount = dec!(1);
invoice.calculate();

View file

@ -164,17 +164,12 @@ impl BillingEngine {
let outstanding: Vec<_> = account
.invoice_ids
.iter()
.filter(|id| {
invoices
.get(*id)
.map(|inv| !inv.is_paid())
.unwrap_or(false)
})
.filter(|id| invoices.get(*id).map(|inv| !inv.is_paid()).unwrap_or(false))
.cloned()
.collect();
let next_invoice = account.billing_cycle_start
+ Duration::days(self.config.billing_cycle_days as i64);
let next_invoice =
account.billing_cycle_start + Duration::days(self.config.billing_cycle_days as i64);
Ok(AccountBillingInfo {
account_id: account_id.to_string(),
@ -198,11 +193,7 @@ impl BillingEngine {
account.prepaid_balance += amount;
tracing::info!(
"Added {} SYNOR prepaid to account {}",
amount,
account_id
);
tracing::info!("Added {} SYNOR prepaid to account {}", amount, account_id);
Ok(())
}
@ -378,7 +369,9 @@ impl BillingEngine {
PaymentMethod::CreditBalance => {
// Deduct from credit balance
if account.credit_balance < payment.amount {
return Err(EconomicsError::PaymentFailed("Insufficient credit balance".to_string()));
return Err(EconomicsError::PaymentFailed(
"Insufficient credit balance".to_string(),
));
}
account.credit_balance -= payment.amount;
payment.mark_completed();
@ -484,7 +477,10 @@ impl BillingEngine {
/// Get unpaid invoices for an account
pub async fn get_unpaid_invoices(&self, account_id: &str) -> Result<Vec<Invoice>> {
let all_invoices = self.get_account_invoices(account_id).await?;
Ok(all_invoices.into_iter().filter(|inv| !inv.is_paid()).collect())
Ok(all_invoices
.into_iter()
.filter(|inv| !inv.is_paid())
.collect())
}
/// Get detailed account information including creation date
@ -617,7 +613,10 @@ mod tests {
#[tokio::test]
async fn test_register_account() {
let engine = setup_engine().await;
engine.register_account("test_account", "standard").await.unwrap();
engine
.register_account("test_account", "standard")
.await
.unwrap();
let info = engine.get_account_details("test_account").await.unwrap();
assert_eq!(info.account_id, "test_account");
@ -627,7 +626,10 @@ mod tests {
#[tokio::test]
async fn test_add_prepaid() {
let engine = setup_engine().await;
engine.register_account("prepaid_test", "standard").await.unwrap();
engine
.register_account("prepaid_test", "standard")
.await
.unwrap();
engine.add_prepaid("prepaid_test", dec!(100)).await.unwrap();
let info = engine.get_account_info("prepaid_test").await.unwrap();
@ -637,7 +639,10 @@ mod tests {
#[tokio::test]
async fn test_add_credit() {
let engine = setup_engine().await;
engine.register_account("credit_test", "standard").await.unwrap();
engine
.register_account("credit_test", "standard")
.await
.unwrap();
let credit = Credit::new("credit_test", dec!(50), "Welcome bonus");
engine.add_credit("credit_test", credit).await.unwrap();

View file

@ -210,17 +210,23 @@ impl PaymentProcessor {
payment.mark_processing();
// Simulate transaction
let tx_hash = format!("0x{:x}000000000000000000000000000000000000000000000000000000000000",
let tx_hash = format!(
"0x{:x}000000000000000000000000000000000000000000000000000000000000",
std::time::SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs());
.as_secs()
);
payment.mark_confirmed(tx_hash);
// Add addresses to metadata
payment.metadata.insert("from".to_string(), from_address.to_string());
payment.metadata.insert("to".to_string(), to_address.to_string());
payment
.metadata
.insert("from".to_string(), from_address.to_string());
payment
.metadata
.insert("to".to_string(), to_address.to_string());
payment.mark_completed();
@ -325,7 +331,10 @@ mod tests {
payment.mark_failed("Insufficient funds");
assert!(!payment.is_complete());
assert_eq!(payment.failure_reason, Some("Insufficient funds".to_string()));
assert_eq!(
payment.failure_reason,
Some("Insufficient funds".to_string())
);
}
#[tokio::test]

View file

@ -177,7 +177,10 @@ impl CostEstimator {
/// Estimate cost for a usage projection
pub async fn estimate(&self, projection: UsageProjection) -> Result<CostEstimate> {
let tier_name = projection.tier.clone().unwrap_or_else(|| "free".to_string());
let tier_name = projection
.tier
.clone()
.unwrap_or_else(|| "free".to_string());
let months = projection.duration_months.max(1);
let mut by_service = HashMap::new();

View file

@ -22,10 +22,7 @@ pub enum EconomicsError {
/// Insufficient balance
#[error("Insufficient balance: required {required}, available {available}")]
InsufficientBalance {
required: String,
available: String,
},
InsufficientBalance { required: String, available: String },
/// Insufficient funds (with Decimal values)
#[error("Insufficient funds: required {required}, available {available}")]
@ -36,10 +33,7 @@ pub enum EconomicsError {
/// Stale price with specific asset
#[error("Price stale for {asset}: {age_seconds} seconds old")]
StalePrice {
asset: String,
age_seconds: i64,
},
StalePrice { asset: String, age_seconds: i64 },
/// Account not found
#[error("Account not found: {0}")]

View file

@ -251,9 +251,7 @@ impl EconomicsManager {
use rust_decimal_macros::dec;
// Default to development oracle with mock feeds at $1.50 base price
let oracle = Arc::new(RwLock::new(
oracle::OracleFactory::development(dec!(1.50))
));
let oracle = Arc::new(RwLock::new(oracle::OracleFactory::development(dec!(1.50))));
let pricing = Arc::new(PricingEngine::new());
let metering = Arc::new(MeteringService::new(pricing.clone()));
let billing = Arc::new(BillingEngine::new(metering.clone(), pricing.clone()));
@ -270,9 +268,7 @@ impl EconomicsManager {
/// Create an economics manager with production oracle configuration
pub fn with_production_oracle(config: oracle::ProductionOracleConfig) -> Self {
let oracle = Arc::new(RwLock::new(
oracle::OracleFactory::production(config)
));
let oracle = Arc::new(RwLock::new(oracle::OracleFactory::production(config)));
let pricing = Arc::new(PricingEngine::new());
let metering = Arc::new(MeteringService::new(pricing.clone()));
let billing = Arc::new(BillingEngine::new(metering.clone(), pricing.clone()));

View file

@ -209,23 +209,22 @@ impl MeteringService {
}
// Calculate cost for this event
let cost = self.pricing.calculate_cost(
event.service_type,
event.resource_unit,
event.amount,
)?;
let cost =
self.pricing
.calculate_cost(event.service_type, event.resource_unit, event.amount)?;
// Update current usage
{
let mut usage = self.current_usage.write().await;
let account_usage = usage.entry(event.account_id.clone()).or_insert_with(|| {
AccountUsage {
account_id: event.account_id.clone(),
by_service: HashMap::new(),
current_period_start: Utc::now(),
last_event: None,
}
});
let account_usage =
usage
.entry(event.account_id.clone())
.or_insert_with(|| AccountUsage {
account_id: event.account_id.clone(),
by_service: HashMap::new(),
current_period_start: Utc::now(),
last_event: None,
});
*account_usage
.by_service
@ -263,7 +262,8 @@ impl MeteringService {
ServiceType::Storage,
ResourceUnit::Bytes,
Decimal::from(usage.bytes_stored),
)).await?;
))
.await?;
}
// Storage: bytes retrieved
@ -273,7 +273,8 @@ impl MeteringService {
ServiceType::Storage,
ResourceUnit::BandwidthGb,
Decimal::from(usage.bytes_retrieved) / Decimal::from(1_073_741_824u64), // to GB
)).await?;
))
.await?;
}
Ok(())
@ -288,7 +289,8 @@ impl MeteringService {
ServiceType::Hosting,
ResourceUnit::BandwidthGb,
Decimal::from(usage.bandwidth_bytes) / Decimal::from(1_073_741_824u64),
)).await?;
))
.await?;
}
// Custom domains
@ -298,7 +300,8 @@ impl MeteringService {
ServiceType::Hosting,
ResourceUnit::Domains,
Decimal::from(usage.custom_domains),
)).await?;
))
.await?;
}
Ok(())
@ -313,7 +316,8 @@ impl MeteringService {
ServiceType::Database,
ResourceUnit::Queries,
Decimal::from(usage.queries),
)).await?;
))
.await?;
}
// Vector searches
@ -323,7 +327,8 @@ impl MeteringService {
ServiceType::Database,
ResourceUnit::VectorSearches,
Decimal::from(usage.vector_searches),
)).await?;
))
.await?;
}
// Storage
@ -333,7 +338,8 @@ impl MeteringService {
ServiceType::Database,
ResourceUnit::GbMonth,
Decimal::from(usage.storage_bytes) / Decimal::from(1_073_741_824u64),
)).await?;
))
.await?;
}
Ok(())
@ -348,7 +354,8 @@ impl MeteringService {
ServiceType::Compute,
ResourceUnit::CpuCoreHours,
Decimal::from(usage.cpu_core_seconds) / Decimal::from(3600),
)).await?;
))
.await?;
}
// GPU hours
@ -358,7 +365,8 @@ impl MeteringService {
ServiceType::Compute,
ResourceUnit::GpuHours,
Decimal::from(usage.gpu_seconds) / Decimal::from(3600),
)).await?;
))
.await?;
}
// Memory GB hours
@ -368,7 +376,8 @@ impl MeteringService {
ServiceType::Compute,
ResourceUnit::MemoryGbHours,
Decimal::from(usage.memory_gb_seconds) / Decimal::from(3600),
)).await?;
))
.await?;
}
// Invocations (serverless)
@ -378,7 +387,8 @@ impl MeteringService {
ServiceType::Compute,
ResourceUnit::Invocations,
Decimal::from(usage.invocations),
)).await?;
))
.await?;
}
Ok(())
@ -393,7 +403,8 @@ impl MeteringService {
ServiceType::Network,
ResourceUnit::BandwidthGb,
Decimal::from(total_bytes) / Decimal::from(1_073_741_824u64),
)).await?;
))
.await?;
}
Ok(())
@ -421,10 +432,7 @@ impl MeteringService {
// Check buffered events
let buffer = self.event_buffer.read().await;
for event in buffer.iter() {
if event.account_id == account_id
&& event.timestamp >= start
&& event.timestamp < end
{
if event.account_id == account_id && event.timestamp >= start && event.timestamp < end {
let cost = self.pricing.calculate_cost(
event.service_type,
event.resource_unit,

View file

@ -224,9 +224,8 @@ impl IsolationTree {
// Random split point
let split = min_val + (max_val - min_val) * 0.5;
let (left_data, right_data): (Vec<_>, Vec<_>) = data.iter()
.cloned()
.partition(|row| row[feature] < split);
let (left_data, right_data): (Vec<_>, Vec<_>) =
data.iter().cloned().partition(|row| row[feature] < split);
Some(Self {
split_feature: feature,
@ -280,7 +279,8 @@ impl IsolationForest {
let trees: Vec<_> = (0..n_trees)
.filter_map(|i| {
// Subsample with deterministic "randomness" based on tree index
let sample: Vec<_> = data.iter()
let sample: Vec<_> = data
.iter()
.enumerate()
.filter(|(j, _)| (i + j) % 3 != 0)
.take(sample_size)
@ -299,9 +299,12 @@ impl IsolationForest {
return 0.5;
}
let avg_path: f64 = self.trees.iter()
let avg_path: f64 = self
.trees
.iter()
.map(|tree| tree.path_length(point, 0.0))
.sum::<f64>() / self.trees.len() as f64;
.sum::<f64>()
/ self.trees.len() as f64;
let c = c_factor(self.sample_size);
if c < f64::EPSILON {
@ -365,17 +368,28 @@ impl PairDetector {
// Track addresses
if !point.addresses.is_empty() {
self.recent_addresses.push_back((point.timestamp, point.addresses.clone()));
self.recent_addresses
.push_back((point.timestamp, point.addresses.clone()));
}
self.price_history.push_back(point);
// Cleanup old data
let cutoff = Utc::now() - Duration::hours(24);
while self.price_history.front().map(|p| p.timestamp < cutoff).unwrap_or(false) {
while self
.price_history
.front()
.map(|p| p.timestamp < cutoff)
.unwrap_or(false)
{
self.price_history.pop_front();
}
while self.recent_addresses.front().map(|(t, _)| *t < cutoff).unwrap_or(false) {
while self
.recent_addresses
.front()
.map(|(t, _)| *t < cutoff)
.unwrap_or(false)
{
self.recent_addresses.pop_front();
}
}
@ -386,7 +400,9 @@ impl PairDetector {
}
// Build feature vectors: [price, volume, return, bid/ask ratio]
let data: Vec<Vec<f64>> = self.price_history.iter()
let data: Vec<Vec<f64>> = self
.price_history
.iter()
.skip(1)
.zip(self.price_history.iter())
.map(|(curr, prev)| {
@ -403,7 +419,11 @@ impl PairDetector {
(Some(bid), Some(ask)) => {
let bid_f = bid.to_string().parse::<f64>().unwrap_or(0.0);
let ask_f = ask.to_string().parse::<f64>().unwrap_or(1.0);
if ask_f > 0.0 { bid_f / ask_f } else { 1.0 }
if ask_f > 0.0 {
bid_f / ask_f
} else {
1.0
}
}
_ => 1.0,
};
@ -462,19 +482,34 @@ impl AnomalyDetector {
// Run all detectors using the immutable reference first
if let Some(detector) = self.detectors.get(pair) {
if let Some(a) = Self::detect_price_outlier_impl(pair, &data, detector, min_data_points, z_score_threshold) {
if let Some(a) = Self::detect_price_outlier_impl(
pair,
&data,
detector,
min_data_points,
z_score_threshold,
) {
anomalies.push(a);
}
if let Some(a) = Self::detect_volume_spike_impl(pair, &data, detector, min_data_points, volume_spike_multiplier) {
if let Some(a) = Self::detect_volume_spike_impl(
pair,
&data,
detector,
min_data_points,
volume_spike_multiplier,
) {
anomalies.push(a);
}
if let Some(a) = Self::detect_wash_trading_impl(pair, &data, detector, wash_trading_window) {
if let Some(a) =
Self::detect_wash_trading_impl(pair, &data, detector, wash_trading_window)
{
anomalies.push(a);
}
if let Some(a) = Self::detect_pump_dump_impl(pair, detector, pump_dump_window) {
anomalies.push(a);
}
if let Some(a) = Self::detect_flash_loan_impl(pair, &data, detector, flash_loan_window) {
if let Some(a) = Self::detect_flash_loan_impl(pair, &data, detector, flash_loan_window)
{
anomalies.push(a);
}
if ml_enabled {
@ -493,7 +528,13 @@ impl AnomalyDetector {
anomalies
}
fn detect_price_outlier_impl(pair: &str, data: &MarketDataPoint, detector: &PairDetector, min_data_points: usize, z_score_threshold: f64) -> Option<Anomaly> {
fn detect_price_outlier_impl(
pair: &str,
data: &MarketDataPoint,
detector: &PairDetector,
min_data_points: usize,
z_score_threshold: f64,
) -> Option<Anomaly> {
if detector.price_stats.count < min_data_points {
return None;
}
@ -531,7 +572,13 @@ impl AnomalyDetector {
}
}
fn detect_volume_spike_impl(pair: &str, data: &MarketDataPoint, detector: &PairDetector, min_data_points: usize, volume_spike_multiplier: f64) -> Option<Anomaly> {
fn detect_volume_spike_impl(
pair: &str,
data: &MarketDataPoint,
detector: &PairDetector,
min_data_points: usize,
volume_spike_multiplier: f64,
) -> Option<Anomaly> {
if detector.volume_stats.count < min_data_points {
return None;
}
@ -550,7 +597,9 @@ impl AnomalyDetector {
confidence: 0.75,
description: format!(
"Volume {} is {:.1}x the average {:.2}",
data.volume, volume_f64 / mean, mean
data.volume,
volume_f64 / mean,
mean
),
data: AnomalyData {
current_value: data.volume,
@ -566,7 +615,12 @@ impl AnomalyDetector {
}
}
fn detect_wash_trading_impl(pair: &str, data: &MarketDataPoint, detector: &PairDetector, wash_trading_window: i64) -> Option<Anomaly> {
fn detect_wash_trading_impl(
pair: &str,
data: &MarketDataPoint,
detector: &PairDetector,
wash_trading_window: i64,
) -> Option<Anomaly> {
if data.addresses.is_empty() {
return None;
}
@ -617,14 +671,20 @@ impl AnomalyDetector {
None
}
fn detect_pump_dump_impl(pair: &str, detector: &PairDetector, pump_dump_window: i64) -> Option<Anomaly> {
fn detect_pump_dump_impl(
pair: &str,
detector: &PairDetector,
pump_dump_window: i64,
) -> Option<Anomaly> {
// Need enough history
if detector.price_history.len() < 10 {
return None;
}
let window_start = Utc::now() - Duration::minutes(pump_dump_window);
let prices: Vec<_> = detector.price_history.iter()
let prices: Vec<_> = detector
.price_history
.iter()
.filter(|p| p.timestamp >= window_start)
.map(|p| p.price.to_string().parse::<f64>().unwrap_or(0.0))
.collect();
@ -635,7 +695,9 @@ impl AnomalyDetector {
// Find max and check for reversal
let max_price = prices.iter().copied().fold(f64::MIN, f64::max);
let max_idx = prices.iter().position(|&p| (p - max_price).abs() < f64::EPSILON)?;
let max_idx = prices
.iter()
.position(|&p| (p - max_price).abs() < f64::EPSILON)?;
let first_price = prices.first()?;
let last_price = prices.last()?;
@ -672,10 +734,17 @@ impl AnomalyDetector {
None
}
fn detect_flash_loan_impl(pair: &str, data: &MarketDataPoint, detector: &PairDetector, flash_loan_window: i64) -> Option<Anomaly> {
fn detect_flash_loan_impl(
pair: &str,
data: &MarketDataPoint,
detector: &PairDetector,
flash_loan_window: i64,
) -> Option<Anomaly> {
// Flash loan signature: huge volume spike + quick price movement + reversal
let window_start = Utc::now() - Duration::seconds(flash_loan_window);
let recent: Vec<_> = detector.price_history.iter()
let recent: Vec<_> = detector
.price_history
.iter()
.filter(|p| p.timestamp >= window_start)
.collect();
@ -683,11 +752,13 @@ impl AnomalyDetector {
return None;
}
let volumes: Vec<f64> = recent.iter()
let volumes: Vec<f64> = recent
.iter()
.map(|p| p.volume.to_string().parse::<f64>().unwrap_or(0.0))
.collect();
let prices: Vec<f64> = recent.iter()
let prices: Vec<f64> = recent
.iter()
.map(|p| p.price.to_string().parse::<f64>().unwrap_or(0.0))
.collect();
@ -706,7 +777,10 @@ impl AnomalyDetector {
// Big spike and quick reversal
if spike > 10.0 && reversal > 8.0 {
let mut context = HashMap::new();
context.insert("volume_spike".to_string(), format!("{:.0}x", max_volume / avg_volume));
context.insert(
"volume_spike".to_string(),
format!("{:.0}x", max_volume / avg_volume),
);
context.insert("price_spike".to_string(), format!("{:.1}%", spike));
return Some(Anomaly {
@ -717,7 +791,8 @@ impl AnomalyDetector {
confidence: 0.65,
description: format!(
"Suspected flash loan attack: {}x volume, {:.1}% price spike",
max_volume / avg_volume, spike
max_volume / avg_volume,
spike
),
data: AnomalyData {
current_value: data.volume,
@ -734,7 +809,11 @@ impl AnomalyDetector {
None
}
fn detect_ml_anomaly_impl(pair: &str, data: &MarketDataPoint, detector: &PairDetector) -> Option<Anomaly> {
fn detect_ml_anomaly_impl(
pair: &str,
data: &MarketDataPoint,
detector: &PairDetector,
) -> Option<Anomaly> {
let forest = detector.isolation_forest.as_ref()?;
if detector.price_history.is_empty() {
@ -756,7 +835,11 @@ impl AnomalyDetector {
(Some(bid), Some(ask)) => {
let bid_f = bid.to_string().parse::<f64>().ok()?;
let ask_f = ask.to_string().parse::<f64>().ok()?;
if ask_f > 0.0 { bid_f / ask_f } else { 1.0 }
if ask_f > 0.0 {
bid_f / ask_f
} else {
1.0
}
}
_ => 1.0,
};
@ -774,10 +857,7 @@ impl AnomalyDetector {
detected_at: Utc::now(),
severity: score,
confidence: 0.6,
description: format!(
"ML model detected anomaly with score {:.3}",
score
),
description: format!("ML model detected anomaly with score {:.3}", score),
data: AnomalyData {
current_value: data.price,
expected_value: Decimal::from_f64_retain(detector.price_stats.mean)?,
@ -798,7 +878,8 @@ impl AnomalyDetector {
/// Get recent anomalies for a pair
pub fn get_anomalies(&self, pair: &str) -> Vec<Anomaly> {
self.detectors.get(pair)
self.detectors
.get(pair)
.map(|d| d.anomalies.clone())
.unwrap_or_default()
}
@ -807,11 +888,14 @@ impl AnomalyDetector {
pub fn get_stats(&self, pair: &str) -> Option<AnomalyStats> {
let detector = self.detectors.get(pair)?;
let by_type: HashMap<AnomalyType, usize> = detector.anomalies.iter()
.fold(HashMap::new(), |mut acc, a| {
*acc.entry(a.anomaly_type.clone()).or_insert(0) += 1;
acc
});
let by_type: HashMap<AnomalyType, usize> =
detector
.anomalies
.iter()
.fold(HashMap::new(), |mut acc, a| {
*acc.entry(a.anomaly_type.clone()).or_insert(0) += 1;
acc
});
Some(AnomalyStats {
total_anomalies: detector.anomalies.len(),
@ -913,7 +997,9 @@ mod tests {
let anomalies = detector.process("SYNOR/USD", outlier);
assert!(!anomalies.is_empty());
assert!(anomalies.iter().any(|a| a.anomaly_type == AnomalyType::PriceOutlier));
assert!(anomalies
.iter()
.any(|a| a.anomaly_type == AnomalyType::PriceOutlier));
}
#[test]
@ -946,6 +1032,8 @@ mod tests {
};
let anomalies = detector.process("SYNOR/USD", spike);
assert!(anomalies.iter().any(|a| a.anomaly_type == AnomalyType::VolumeSpike));
assert!(anomalies
.iter()
.any(|a| a.anomaly_type == AnomalyType::VolumeSpike));
}
}

View file

@ -40,17 +40,11 @@ pub enum TriggerReason {
threshold: SynorDecimal,
},
/// Multiple oracle sources disagree
OracleDisagreement {
spread_percent: Decimal,
},
OracleDisagreement { spread_percent: Decimal },
/// Manual trigger by admin
ManualHalt {
reason: String,
},
ManualHalt { reason: String },
/// Cascade from related market
CascadeTrigger {
source_pair: String,
},
CascadeTrigger { source_pair: String },
}
/// Circuit breaker event
@ -98,12 +92,12 @@ pub struct CircuitBreakerConfig {
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
max_1m_change: Decimal::new(10, 2), // 10%
max_5m_change: Decimal::new(20, 2), // 20%
max_1h_change: Decimal::new(50, 2), // 50%
max_1m_change: Decimal::new(10, 2), // 10%
max_5m_change: Decimal::new(20, 2), // 20%
max_1h_change: Decimal::new(50, 2), // 50%
max_twap_deviation: Decimal::new(30, 2), // 30%
min_liquidity: Decimal::new(10000, 0), // $10k
max_oracle_spread: Decimal::new(5, 2), // 5%
min_liquidity: Decimal::new(10000, 0), // $10k
max_oracle_spread: Decimal::new(5, 2), // 5%
cooldown_duration: Duration::minutes(5),
recovery_checks: 3,
cascade_enabled: true,
@ -173,7 +167,12 @@ impl PairCircuitBreaker {
// Keep only last 24 hours
let cutoff = Utc::now() - Duration::hours(24);
while self.price_history.front().map(|s| s.timestamp < cutoff).unwrap_or(false) {
while self
.price_history
.front()
.map(|s| s.timestamp < cutoff)
.unwrap_or(false)
{
self.price_history.pop_front();
}
@ -192,7 +191,8 @@ impl PairCircuitBreaker {
fn get_price_at(&self, seconds_ago: i64) -> Option<SynorDecimal> {
let target = Utc::now() - Duration::seconds(seconds_ago);
self.price_history.iter()
self.price_history
.iter()
.rev()
.find(|s| s.timestamp <= target)
.map(|s| s.price)
@ -232,7 +232,9 @@ impl CircuitBreakerManager {
price: SynorDecimal,
liquidity: Option<SynorDecimal>,
) -> Result<CircuitState> {
let breaker = self.breakers.entry(pair.to_string())
let breaker = self
.breakers
.entry(pair.to_string())
.or_insert_with(PairCircuitBreaker::new);
// Use the convenience method for real-time price recording
@ -256,7 +258,9 @@ impl CircuitBreakerManager {
liquidity: Option<SynorDecimal>,
timestamp: DateTime<Utc>,
) -> Result<CircuitState> {
let breaker = self.breakers.entry(pair.to_string())
let breaker = self
.breakers
.entry(pair.to_string())
.or_insert_with(PairCircuitBreaker::new);
breaker.record_price_at(price, liquidity, timestamp);
@ -273,22 +277,26 @@ impl CircuitBreakerManager {
/// Check all trigger conditions
fn check_triggers(&mut self, pair: &str) -> Result<()> {
let breaker = self.breakers.get(pair).ok_or_else(||
EconomicsError::PriceFeedUnavailable(pair.to_string())
)?;
let breaker = self
.breakers
.get(pair)
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
let current = breaker.current_price().ok_or_else(||
EconomicsError::PriceFeedUnavailable(pair.to_string())
)?;
let current = breaker
.current_price()
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
// Check 1-minute change
if let Some(price_1m) = breaker.get_price_at(60) {
let change = ((current - price_1m) / price_1m).abs();
if change > self.config.max_1m_change {
return self.trigger_breaker(pair, TriggerReason::RapidPriceChange {
change_percent: change * Decimal::ONE_HUNDRED,
window_seconds: 60,
});
return self.trigger_breaker(
pair,
TriggerReason::RapidPriceChange {
change_percent: change * Decimal::ONE_HUNDRED,
window_seconds: 60,
},
);
}
}
@ -296,10 +304,13 @@ impl CircuitBreakerManager {
if let Some(price_5m) = breaker.get_price_at(300) {
let change = ((current - price_5m) / price_5m).abs();
if change > self.config.max_5m_change {
return self.trigger_breaker(pair, TriggerReason::RapidPriceChange {
change_percent: change * Decimal::ONE_HUNDRED,
window_seconds: 300,
});
return self.trigger_breaker(
pair,
TriggerReason::RapidPriceChange {
change_percent: change * Decimal::ONE_HUNDRED,
window_seconds: 300,
},
);
}
}
@ -307,10 +318,13 @@ impl CircuitBreakerManager {
if let Some(price_1h) = breaker.get_price_at(3600) {
let change = ((current - price_1h) / price_1h).abs();
if change > self.config.max_1h_change {
return self.trigger_breaker(pair, TriggerReason::RapidPriceChange {
change_percent: change * Decimal::ONE_HUNDRED,
window_seconds: 3600,
});
return self.trigger_breaker(
pair,
TriggerReason::RapidPriceChange {
change_percent: change * Decimal::ONE_HUNDRED,
window_seconds: 3600,
},
);
}
}
@ -318,20 +332,26 @@ impl CircuitBreakerManager {
if let Some(twap) = breaker.twap_24h {
let deviation = ((current - twap) / twap).abs();
if deviation > self.config.max_twap_deviation {
return self.trigger_breaker(pair, TriggerReason::ExcessiveDeviation {
deviation_percent: deviation * Decimal::ONE_HUNDRED,
reference_price: twap,
});
return self.trigger_breaker(
pair,
TriggerReason::ExcessiveDeviation {
deviation_percent: deviation * Decimal::ONE_HUNDRED,
reference_price: twap,
},
);
}
}
// Check liquidity
if let Some(liquidity) = breaker.current_liquidity() {
if liquidity < self.config.min_liquidity {
return self.trigger_breaker(pair, TriggerReason::LowLiquidity {
current: liquidity,
threshold: self.config.min_liquidity,
});
return self.trigger_breaker(
pair,
TriggerReason::LowLiquidity {
current: liquidity,
threshold: self.config.min_liquidity,
},
);
}
}
@ -340,9 +360,10 @@ impl CircuitBreakerManager {
/// Trigger the circuit breaker
fn trigger_breaker(&mut self, pair: &str, reason: TriggerReason) -> Result<()> {
let breaker = self.breakers.get_mut(pair).ok_or_else(||
EconomicsError::PriceFeedUnavailable(pair.to_string())
)?;
let breaker = self
.breakers
.get_mut(pair)
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
let event = CircuitEvent {
pair: pair.to_string(),
@ -361,7 +382,10 @@ impl CircuitBreakerManager {
// Check cascade triggers
if self.config.cascade_enabled {
let cascades: Vec<_> = self.config.cascade_pairs.iter()
let cascades: Vec<_> = self
.config
.cascade_pairs
.iter()
.filter(|(source, _)| source == pair)
.map(|(_, target)| target.clone())
.collect();
@ -407,10 +431,15 @@ impl CircuitBreakerManager {
// Get current state first (immutable borrow)
let (current_state, triggered_at, trigger_reason) = {
let breaker = self.breakers.get(pair).ok_or_else(||
EconomicsError::PriceFeedUnavailable(pair.to_string())
)?;
(breaker.state, breaker.triggered_at, breaker.trigger_reason.clone())
let breaker = self
.breakers
.get(pair)
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
(
breaker.state,
breaker.triggered_at,
breaker.trigger_reason.clone(),
)
};
// Check stability for half-open state (immutable borrow)
@ -421,9 +450,10 @@ impl CircuitBreakerManager {
};
// Now get mutable reference for updates
let breaker = self.breakers.get_mut(pair).ok_or_else(||
EconomicsError::PriceFeedUnavailable(pair.to_string())
)?;
let breaker = self
.breakers
.get_mut(pair)
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
match current_state {
CircuitState::Open => {
@ -435,9 +465,9 @@ impl CircuitBreakerManager {
pair: pair.to_string(),
from_state: CircuitState::Open,
to_state: CircuitState::HalfOpen,
reason: trigger_reason.clone().unwrap_or(
TriggerReason::ManualHalt { reason: "Unknown".into() }
),
reason: trigger_reason.clone().unwrap_or(TriggerReason::ManualHalt {
reason: "Unknown".into(),
}),
timestamp: Utc::now(),
cooldown: None,
};
@ -457,9 +487,9 @@ impl CircuitBreakerManager {
pair: pair.to_string(),
from_state: CircuitState::HalfOpen,
to_state: CircuitState::Closed,
reason: trigger_reason.unwrap_or(
TriggerReason::ManualHalt { reason: "Recovery".into() }
),
reason: trigger_reason.unwrap_or(TriggerReason::ManualHalt {
reason: "Recovery".into(),
}),
timestamp: Utc::now(),
cooldown: None,
};
@ -482,9 +512,10 @@ impl CircuitBreakerManager {
/// Check if market conditions are stable
fn is_stable(&self, pair: &str) -> Result<bool> {
let breaker = self.breakers.get(pair).ok_or_else(||
EconomicsError::PriceFeedUnavailable(pair.to_string())
)?;
let breaker = self
.breakers
.get(pair)
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
let current = match breaker.current_price() {
Some(p) => p,
@ -511,7 +542,8 @@ impl CircuitBreakerManager {
/// Get current state for a pair
pub fn get_state(&self, pair: &str) -> CircuitState {
self.breakers.get(pair)
self.breakers
.get(pair)
.map(|b| b.state)
.unwrap_or(CircuitState::Closed)
}
@ -523,25 +555,32 @@ impl CircuitBreakerManager {
/// Manually trigger circuit breaker
pub fn manual_halt(&mut self, pair: &str, reason: impl Into<String>) -> Result<()> {
self.breakers.entry(pair.to_string())
self.breakers
.entry(pair.to_string())
.or_insert_with(PairCircuitBreaker::new);
self.trigger_breaker(pair, TriggerReason::ManualHalt {
reason: reason.into(),
})
self.trigger_breaker(
pair,
TriggerReason::ManualHalt {
reason: reason.into(),
},
)
}
/// Manually reset circuit breaker
pub fn manual_reset(&mut self, pair: &str) -> Result<()> {
let breaker = self.breakers.get_mut(pair).ok_or_else(||
EconomicsError::PriceFeedUnavailable(pair.to_string())
)?;
let breaker = self
.breakers
.get_mut(pair)
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
let event = CircuitEvent {
pair: pair.to_string(),
from_state: breaker.state,
to_state: CircuitState::Closed,
reason: TriggerReason::ManualHalt { reason: "Manual reset".into() },
reason: TriggerReason::ManualHalt {
reason: "Manual reset".into(),
},
timestamp: Utc::now(),
cooldown: None,
};
@ -558,26 +597,32 @@ impl CircuitBreakerManager {
/// Record oracle disagreement
pub fn record_oracle_spread(&mut self, pair: &str, spread: Decimal) -> Result<()> {
if spread > self.config.max_oracle_spread {
self.breakers.entry(pair.to_string())
self.breakers
.entry(pair.to_string())
.or_insert_with(PairCircuitBreaker::new);
self.trigger_breaker(pair, TriggerReason::OracleDisagreement {
spread_percent: spread * Decimal::ONE_HUNDRED,
})?;
self.trigger_breaker(
pair,
TriggerReason::OracleDisagreement {
spread_percent: spread * Decimal::ONE_HUNDRED,
},
)?;
}
Ok(())
}
/// Get event history for a pair
pub fn get_events(&self, pair: &str) -> Vec<CircuitEvent> {
self.breakers.get(pair)
self.breakers
.get(pair)
.map(|b| b.events.clone())
.unwrap_or_default()
}
/// Get all currently halted pairs
pub fn get_halted_pairs(&self) -> Vec<(String, CircuitState, Option<TriggerReason>)> {
self.breakers.iter()
self.breakers
.iter()
.filter(|(_, b)| b.state != CircuitState::Closed)
.map(|(pair, b)| (pair.clone(), b.state, b.trigger_reason.clone()))
.collect()
@ -586,8 +631,16 @@ impl CircuitBreakerManager {
/// Get summary statistics
pub fn get_stats(&self) -> CircuitBreakerStats {
let total = self.breakers.len();
let open = self.breakers.values().filter(|b| b.state == CircuitState::Open).count();
let half_open = self.breakers.values().filter(|b| b.state == CircuitState::HalfOpen).count();
let open = self
.breakers
.values()
.filter(|b| b.state == CircuitState::Open)
.count();
let half_open = self
.breakers
.values()
.filter(|b| b.state == CircuitState::HalfOpen)
.count();
let total_events: usize = self.breakers.values().map(|b| b.events.len()).sum();
CircuitBreakerStats {
@ -628,7 +681,9 @@ mod tests {
// Normal price movements should not trigger
for i in 0..10 {
let price = dec!(100) + Decimal::from(i);
let state = manager.record_price("SYNOR/USD", price, Some(dec!(100000))).unwrap();
let state = manager
.record_price("SYNOR/USD", price, Some(dec!(100000)))
.unwrap();
assert_eq!(state, CircuitState::Closed);
}
}
@ -641,10 +696,19 @@ mod tests {
let now = Utc::now();
// Record baseline 2 minutes ago
manager.record_price_at("SYNOR/USD", dec!(100), Some(dec!(100000)), now - Duration::minutes(2)).unwrap();
manager
.record_price_at(
"SYNOR/USD",
dec!(100),
Some(dec!(100000)),
now - Duration::minutes(2),
)
.unwrap();
// Simulate 15% drop (exceeds 10% 1-minute threshold)
let state = manager.record_price_at("SYNOR/USD", dec!(85), Some(dec!(100000)), now).unwrap();
let state = manager
.record_price_at("SYNOR/USD", dec!(85), Some(dec!(100000)), now)
.unwrap();
assert_eq!(state, CircuitState::Open);
}
@ -653,7 +717,9 @@ mod tests {
let mut manager = CircuitBreakerManager::new();
// Record with very low liquidity
let state = manager.record_price("SYNOR/USD", dec!(100), Some(dec!(100))).unwrap();
let state = manager
.record_price("SYNOR/USD", dec!(100), Some(dec!(100)))
.unwrap();
assert_eq!(state, CircuitState::Open);
}
@ -661,11 +727,15 @@ mod tests {
fn test_manual_halt_and_reset() {
let mut manager = CircuitBreakerManager::new();
manager.record_price("SYNOR/USD", dec!(100), Some(dec!(100000))).unwrap();
manager
.record_price("SYNOR/USD", dec!(100), Some(dec!(100000)))
.unwrap();
assert!(manager.is_trading_allowed("SYNOR/USD"));
// Manual halt
manager.manual_halt("SYNOR/USD", "Scheduled maintenance").unwrap();
manager
.manual_halt("SYNOR/USD", "Scheduled maintenance")
.unwrap();
assert!(!manager.is_trading_allowed("SYNOR/USD"));
// Manual reset
@ -678,10 +748,14 @@ mod tests {
let mut manager = CircuitBreakerManager::new();
// Initialize
manager.record_price("SYNOR/USD", dec!(100), Some(dec!(100000))).unwrap();
manager
.record_price("SYNOR/USD", dec!(100), Some(dec!(100000)))
.unwrap();
// Record 10% spread (exceeds 5% threshold)
manager.record_oracle_spread("SYNOR/USD", dec!(0.10)).unwrap();
manager
.record_oracle_spread("SYNOR/USD", dec!(0.10))
.unwrap();
assert_eq!(manager.get_state("SYNOR/USD"), CircuitState::Open);
}

View file

@ -200,7 +200,10 @@ pub struct CrossChainConfig {
impl Default for CrossChainConfig {
fn default() -> Self {
let mut tracked = HashMap::new();
tracked.insert(ChainNetwork::Ethereum, vec!["ETH".to_string(), "USDC".to_string(), "USDT".to_string()]);
tracked.insert(
ChainNetwork::Ethereum,
vec!["ETH".to_string(), "USDC".to_string(), "USDT".to_string()],
);
tracked.insert(ChainNetwork::Bitcoin, vec!["BTC".to_string()]);
tracked.insert(ChainNetwork::Cosmos, vec!["ATOM".to_string()]);
tracked.insert(ChainNetwork::Osmosis, vec!["OSMO".to_string()]);
@ -305,16 +308,21 @@ impl CrossChainOracle {
};
if !verified {
return Err(EconomicsError::InvalidPrice("Packet verification failed".into()));
return Err(EconomicsError::InvalidPrice(
"Packet verification failed".into(),
));
}
// Cache the price
let pair_key = format!("{}/{}", packet.token, packet.quote);
self.cache.insert(pair_key, CrossChainPrice {
packet,
received_at: Utc::now(),
verified,
});
self.cache.insert(
pair_key,
CrossChainPrice {
packet,
received_at: Utc::now(),
verified,
},
);
Ok(())
}
@ -326,7 +334,9 @@ impl CrossChainOracle {
token: &str,
quote: &str,
) -> Result<CrossChainPricePacket> {
let fetcher = self.fetchers.get(&chain)
let fetcher = self
.fetchers
.get(&chain)
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(format!("{:?}", chain)))?;
let packet = fetcher.fetch_price(token, quote).await?;
@ -334,11 +344,14 @@ impl CrossChainOracle {
// Verify and cache
if fetcher.verify_packet(&packet) {
let pair_key = format!("{}/{}", token, quote);
self.cache.insert(pair_key.clone(), CrossChainPrice {
packet: packet.clone(),
received_at: Utc::now(),
verified: true,
});
self.cache.insert(
pair_key.clone(),
CrossChainPrice {
packet: packet.clone(),
received_at: Utc::now(),
verified: true,
},
);
}
Ok(packet)
@ -347,7 +360,8 @@ impl CrossChainOracle {
/// Get cached price for a token pair
pub fn get_price(&self, token: &str, quote: &str) -> Option<SynorDecimal> {
let pair_key = format!("{}/{}", token, quote);
self.cache.get(&pair_key)
self.cache
.get(&pair_key)
.filter(|c| c.verified)
.filter(|c| (Utc::now() - c.received_at).num_seconds() < self.config.max_packet_age)
.map(|c| c.packet.price)
@ -356,7 +370,8 @@ impl CrossChainOracle {
/// Get price with full packet info
pub fn get_price_with_info(&self, token: &str, quote: &str) -> Option<&CrossChainPricePacket> {
let pair_key = format!("{}/{}", token, quote);
self.cache.get(&pair_key)
self.cache
.get(&pair_key)
.filter(|c| c.verified)
.map(|c| &c.packet)
}
@ -408,7 +423,8 @@ impl CrossChainOracle {
/// Get all cached prices
pub fn get_all_prices(&self) -> Vec<TokenPrice> {
self.cache.values()
self.cache
.values()
.filter(|c| c.verified)
.map(|c| c.packet.to_token_price())
.collect()
@ -417,9 +433,8 @@ impl CrossChainOracle {
/// Clear stale cache entries
pub fn cleanup_cache(&mut self) {
let max_age = self.config.max_packet_age;
self.cache.retain(|_, v| {
(Utc::now() - v.received_at).num_seconds() < max_age
});
self.cache
.retain(|_, v| (Utc::now() - v.received_at).num_seconds() < max_age);
}
/// Send an IBC price request and track pending packet
@ -450,7 +465,11 @@ impl CrossChainOracle {
}
/// Confirm a pending packet was received
pub fn confirm_pending_packet(&mut self, channel: &str, sequence: u64) -> Option<PendingPacket> {
pub fn confirm_pending_packet(
&mut self,
channel: &str,
sequence: u64,
) -> Option<PendingPacket> {
if let Some(idx) = self
.pending_packets
.iter()
@ -469,9 +488,8 @@ impl CrossChainOracle {
/// Cleanup timed out pending packets
pub fn cleanup_pending(&mut self, timeout_secs: i64) {
self.pending_packets.retain(|p| {
(Utc::now() - p.sent_at).num_seconds() < timeout_secs
});
self.pending_packets
.retain(|p| (Utc::now() - p.sent_at).num_seconds() < timeout_secs);
}
}
@ -502,9 +520,18 @@ pub struct EthereumPriceFetcher {
impl EthereumPriceFetcher {
pub fn new(rpc_url: impl Into<String>) -> Self {
let mut feeds = HashMap::new();
feeds.insert("ETH/USD".to_string(), "0x5f4eC3Df9cbd43714FE2740f5E3616155c5b8419".to_string());
feeds.insert("BTC/USD".to_string(), "0xF4030086522a5bEEa4988F8cA5B36dbC97BeE88c".to_string());
feeds.insert("USDC/USD".to_string(), "0x8fFfFfd4AfB6115b954Bd326cbe7B4BA576818f6".to_string());
feeds.insert(
"ETH/USD".to_string(),
"0x5f4eC3Df9cbd43714FE2740f5E3616155c5b8419".to_string(),
);
feeds.insert(
"BTC/USD".to_string(),
"0xF4030086522a5bEEa4988F8cA5B36dbC97BeE88c".to_string(),
);
feeds.insert(
"USDC/USD".to_string(),
"0x8fFfFfd4AfB6115b954Bd326cbe7B4BA576818f6".to_string(),
);
Self {
rpc_url: rpc_url.into(),
@ -526,7 +553,9 @@ impl ChainPriceFetcher for EthereumPriceFetcher {
async fn fetch_price(&self, token: &str, quote: &str) -> Result<CrossChainPricePacket> {
let pair = format!("{}/{}", token, quote);
let _feed_addr = self.chainlink_feeds.get(&pair)
let _feed_addr = self
.chainlink_feeds
.get(&pair)
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.clone()))?;
// In production: Call Chainlink aggregator via ethers-rs
@ -539,19 +568,16 @@ impl ChainPriceFetcher for EthereumPriceFetcher {
source_block: 19000000,
source_timestamp: Utc::now(),
proof: None,
signatures: vec![
OracleSignature {
signer: "chainlink".to_string(),
signature: vec![0; 65],
timestamp: Utc::now(),
},
],
signatures: vec![OracleSignature {
signer: "chainlink".to_string(),
signature: vec![0; 65],
timestamp: Utc::now(),
}],
})
}
fn verify_packet(&self, packet: &CrossChainPricePacket) -> bool {
packet.source_chain == ChainNetwork::Ethereum
&& !packet.signatures.is_empty()
packet.source_chain == ChainNetwork::Ethereum && !packet.signatures.is_empty()
}
fn supported_tokens(&self) -> Vec<String> {
@ -611,8 +637,7 @@ impl ChainPriceFetcher for CosmosPriceFetcher {
}
fn verify_packet(&self, packet: &CrossChainPricePacket) -> bool {
packet.source_chain == ChainNetwork::Cosmos
&& packet.proof.is_some()
packet.source_chain == ChainNetwork::Cosmos && packet.proof.is_some()
}
fn supported_tokens(&self) -> Vec<String> {
@ -650,7 +675,11 @@ impl CrossChainOracleBuilder {
}
/// Add Cosmos/IBC price fetcher
pub fn with_cosmos(mut self, light_client_id: impl Into<String>, chain_id: impl Into<String>) -> Self {
pub fn with_cosmos(
mut self,
light_client_id: impl Into<String>,
chain_id: impl Into<String>,
) -> Self {
self.cosmos_light_client = Some((light_client_id.into(), chain_id.into()));
self
}
@ -693,8 +722,7 @@ impl CrossChainOracleFactory {
/// Create a production oracle with real endpoints
pub fn production(config: CrossChainProductionConfig) -> CrossChainOracle {
let mut builder = CrossChainOracleBuilder::new()
.with_config(config.cross_chain_config);
let mut builder = CrossChainOracleBuilder::new().with_config(config.cross_chain_config);
if let Some(eth_rpc) = config.ethereum_rpc_url {
builder = builder.with_ethereum(eth_rpc);
@ -715,7 +743,10 @@ impl CrossChainOracleFactory {
}
/// Create an oracle with only Cosmos/IBC support
pub fn cosmos_only(light_client_id: impl Into<String>, chain_id: impl Into<String>) -> CrossChainOracle {
pub fn cosmos_only(
light_client_id: impl Into<String>,
chain_id: impl Into<String>,
) -> CrossChainOracle {
CrossChainOracleBuilder::new()
.with_cosmos(light_client_id, chain_id)
.build()

View file

@ -112,7 +112,9 @@ impl AggregationRound {
/// Add a submission to this round
pub fn add_submission(&mut self, submission: PriceSubmission) -> Result<()> {
if self.finalized {
return Err(EconomicsError::InvalidPrice("Round already finalized".into()));
return Err(EconomicsError::InvalidPrice(
"Round already finalized".into(),
));
}
if Utc::now() >= self.deadline {
return Err(EconomicsError::InvalidPrice("Round deadline passed".into()));
@ -122,7 +124,11 @@ impl AggregationRound {
}
// Check for duplicate submission from same node
if self.submissions.iter().any(|s| s.node_id == submission.node_id) {
if self
.submissions
.iter()
.any(|s| s.node_id == submission.node_id)
{
return Err(EconomicsError::InvalidPrice("Duplicate submission".into()));
}
@ -231,7 +237,9 @@ impl DecentralizedOracle {
/// Update node heartbeat
pub fn heartbeat(&mut self, node_id: &str) -> Result<()> {
let node = self.nodes.get_mut(node_id)
let node = self
.nodes
.get_mut(node_id)
.ok_or_else(|| EconomicsError::InvalidPrice(format!("Unknown node: {}", node_id)))?;
node.last_heartbeat = Utc::now();
Ok(())
@ -241,11 +249,8 @@ impl DecentralizedOracle {
pub fn start_round(&mut self, pair: impl Into<String>) -> u64 {
let pair = pair.into();
self.round_counter += 1;
let round = AggregationRound::new(
self.round_counter,
pair.clone(),
self.config.round_duration
);
let round =
AggregationRound::new(self.round_counter, pair.clone(), self.config.round_duration);
self.current_rounds.insert(pair, round);
self.round_counter
}
@ -253,7 +258,9 @@ impl DecentralizedOracle {
/// Submit a price for the current round
pub fn submit_price(&mut self, submission: PriceSubmission) -> Result<()> {
// Verify node exists and is eligible
let node = self.nodes.get(&submission.node_id)
let node = self
.nodes
.get(&submission.node_id)
.ok_or_else(|| EconomicsError::InvalidPrice("Unknown node".into()))?;
if !node.is_eligible(self.config.min_stake, self.config.min_reputation) {
@ -264,7 +271,9 @@ impl DecentralizedOracle {
// For now, we trust the submission
// Add to current round
let round = self.current_rounds.get_mut(&submission.pair)
let round = self
.current_rounds
.get_mut(&submission.pair)
.ok_or_else(|| EconomicsError::InvalidPrice("No active round for pair".into()))?;
round.add_submission(submission)
@ -274,13 +283,15 @@ impl DecentralizedOracle {
pub fn finalize_round(&mut self, pair: &str) -> Result<SynorDecimal> {
// First check state and get submissions (immutable borrow)
let (is_finalized, existing_price, submissions) = {
let round = self.current_rounds.get(pair)
let round = self
.current_rounds
.get(pair)
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
if round.finalized {
return round.final_price.ok_or_else(||
EconomicsError::InvalidPrice("Round has no price".into())
);
return round
.final_price
.ok_or_else(|| EconomicsError::InvalidPrice("Round has no price".into()));
}
// Check minimum submissions
@ -292,13 +303,16 @@ impl DecentralizedOracle {
)));
}
(round.finalized, round.final_price, round.submissions.clone())
(
round.finalized,
round.final_price,
round.submissions.clone(),
)
};
if is_finalized {
return existing_price.ok_or_else(||
EconomicsError::InvalidPrice("Round has no price".into())
);
return existing_price
.ok_or_else(|| EconomicsError::InvalidPrice("Round has no price".into()));
}
// Filter outliers and aggregate (using cloned submissions)
@ -326,7 +340,11 @@ impl DecentralizedOracle {
}
/// Aggregate prices from a vector of submissions (owned)
fn aggregate_prices_from_vec(&self, pair: &str, submissions: &[PriceSubmission]) -> Result<SynorDecimal> {
fn aggregate_prices_from_vec(
&self,
pair: &str,
submissions: &[PriceSubmission],
) -> Result<SynorDecimal> {
if submissions.is_empty() {
return Err(EconomicsError::PriceFeedUnavailable(pair.to_string()));
}
@ -334,15 +352,21 @@ impl DecentralizedOracle {
// Filter outliers first
let filtered = self.filter_outliers_vec(submissions);
if filtered.is_empty() {
return Err(EconomicsError::InvalidPrice("All submissions were outliers".into()));
return Err(EconomicsError::InvalidPrice(
"All submissions were outliers".into(),
));
}
let filtered_refs: Vec<_> = filtered.iter().collect();
match self.strategy {
AggregationStrategy::Median => self.calculate_median(&filtered_refs),
AggregationStrategy::StakeWeightedMedian => self.calculate_stake_weighted_median(&filtered_refs),
AggregationStrategy::StakeWeightedMedian => {
self.calculate_stake_weighted_median(&filtered_refs)
}
AggregationStrategy::TrimmedMean => self.calculate_trimmed_mean(&filtered_refs),
AggregationStrategy::ReputationWeighted => self.calculate_reputation_weighted(&filtered_refs),
AggregationStrategy::ReputationWeighted => {
self.calculate_reputation_weighted(&filtered_refs)
}
}
}
@ -376,13 +400,14 @@ impl DecentralizedOracle {
}
/// Calculate stake-weighted median
fn calculate_stake_weighted_median(&self, submissions: &[&PriceSubmission]) -> Result<SynorDecimal> {
fn calculate_stake_weighted_median(
&self,
submissions: &[&PriceSubmission],
) -> Result<SynorDecimal> {
// Get stake for each submission
let mut weighted: Vec<(SynorDecimal, SynorDecimal)> = submissions
.iter()
.filter_map(|s| {
self.nodes.get(&s.node_id).map(|n| (s.price, n.stake))
})
.filter_map(|s| self.nodes.get(&s.node_id).map(|n| (s.price, n.stake)))
.collect();
if weighted.is_empty() {
@ -424,12 +449,17 @@ impl DecentralizedOracle {
}
/// Calculate reputation-weighted average
fn calculate_reputation_weighted(&self, submissions: &[&PriceSubmission]) -> Result<SynorDecimal> {
fn calculate_reputation_weighted(
&self,
submissions: &[&PriceSubmission],
) -> Result<SynorDecimal> {
let mut weighted_sum = Decimal::ZERO;
let mut total_weight = Decimal::ZERO;
for sub in submissions {
let reputation = self.nodes.get(&sub.node_id)
let reputation = self
.nodes
.get(&sub.node_id)
.map(|n| n.reputation)
.unwrap_or(0.5);
@ -448,7 +478,9 @@ impl DecentralizedOracle {
/// Update node reputations based on submission accuracy
fn update_reputations(&mut self, _pair: &str, final_price: SynorDecimal) {
// Get submissions from current round before it was moved
let submissions: Vec<_> = self.history.last()
let submissions: Vec<_> = self
.history
.last()
.map(|r| r.submissions.clone())
.unwrap_or_default();
@ -457,7 +489,8 @@ impl DecentralizedOracle {
let deviation = (sub.price - final_price).abs() / final_price;
// Increase reputation for accurate submissions, decrease for inaccurate
if deviation <= Decimal::new(1, 2) { // Within 1%
if deviation <= Decimal::new(1, 2) {
// Within 1%
node.reputation = (node.reputation + 0.01).min(1.0);
} else if deviation > self.config.max_deviation {
node.reputation = (node.reputation - 0.05).max(0.0);
@ -473,7 +506,8 @@ impl DecentralizedOracle {
/// Get number of active nodes
pub fn active_node_count(&self) -> usize {
self.nodes.values()
self.nodes
.values()
.filter(|n| n.is_eligible(self.config.min_stake, self.config.min_reputation))
.count()
}
@ -492,20 +526,23 @@ impl DecentralizedOracle {
/// Convert finalized price to TokenPrice
pub fn to_token_price(&self, pair: &str) -> Option<TokenPrice> {
self.history.iter()
self.history
.iter()
.rev()
.find(|r| r.pair == pair && r.finalized)
.and_then(|r| r.final_price.map(|price| {
let parts: Vec<_> = pair.split('/').collect();
TokenPrice {
token: parts.get(0).unwrap_or(&"").to_string(),
quote: parts.get(1).unwrap_or(&"").to_string(),
price,
timestamp: r.deadline,
source: PriceSource::Aggregated,
confidence: 1.0,
}
}))
.and_then(|r| {
r.final_price.map(|price| {
let parts: Vec<_> = pair.split('/').collect();
TokenPrice {
token: parts.get(0).unwrap_or(&"").to_string(),
quote: parts.get(1).unwrap_or(&"").to_string(),
price,
timestamp: r.deadline,
source: PriceSource::Aggregated,
confidence: 1.0,
}
})
})
}
}
@ -521,13 +558,15 @@ mod tests {
use rust_decimal_macros::dec;
fn create_test_nodes() -> Vec<OracleNode> {
(0..5).map(|i| {
OracleNode::new(
format!("node_{}", i),
vec![i as u8; 32],
dec!(10000), // 10k stake
)
}).collect()
(0..5)
.map(|i| {
OracleNode::new(
format!("node_{}", i),
vec![i as u8; 32],
dec!(10000), // 10k stake
)
})
.collect()
}
#[test]

View file

@ -6,8 +6,8 @@
use crate::error::{EconomicsError, Result};
use crate::SynorDecimal;
use chrono::{DateTime, Duration, Timelike, Utc};
use rust_decimal::Decimal;
use rust_decimal::prelude::ToPrimitive;
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::f64::consts::PI;
@ -178,7 +178,12 @@ impl BlackScholes {
}
/// Price a European option
pub fn price(&self, contract: &OptionContract, spot: SynorDecimal, vol: f64) -> Result<OptionPricing> {
pub fn price(
&self,
contract: &OptionContract,
spot: SynorDecimal,
vol: f64,
) -> Result<OptionPricing> {
if contract.is_expired() {
// At expiration, option is worth intrinsic value
let intrinsic = contract.intrinsic_value(spot);
@ -208,12 +213,13 @@ impl BlackScholes {
});
}
let s = spot.to_f64().ok_or_else(||
EconomicsError::InvalidPrice("Invalid spot price".into())
)?;
let k = contract.strike.to_f64().ok_or_else(||
EconomicsError::InvalidPrice("Invalid strike price".into())
)?;
let s = spot
.to_f64()
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid spot price".into()))?;
let k = contract
.strike
.to_f64()
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid strike price".into()))?;
let t = contract.time_to_expiry();
if t <= 0.0 || vol <= 0.0 {
@ -286,13 +292,10 @@ impl BlackScholes {
let theta_common = -(s * vol * (-q * t).exp() * n_prime_d1) / (2.0 * sqrt_t);
let theta = match contract.option_type {
OptionType::Call => {
theta_common
+ q * s * (-q * t).exp() * n_d1
- r * k * (-r * t).exp() * n_d2
theta_common + q * s * (-q * t).exp() * n_d1 - r * k * (-r * t).exp() * n_d2
}
OptionType::Put => {
theta_common
- q * s * (-q * t).exp() * (1.0 - n_d1)
theta_common - q * s * (-q * t).exp() * (1.0 - n_d1)
+ r * k * (-r * t).exp() * (1.0 - n_d2)
}
} / 365.0; // Per day
@ -327,16 +330,16 @@ impl BlackScholes {
spot: SynorDecimal,
market_price: SynorDecimal,
) -> Result<f64> {
let target = market_price.to_f64().ok_or_else(||
EconomicsError::InvalidPrice("Invalid market price".into())
)?;
let target = market_price
.to_f64()
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid market price".into()))?;
// Initial guess based on time value
let intrinsic = contract.intrinsic_value(spot).to_f64().unwrap_or(0.0);
let time_value = (target - intrinsic).max(0.0);
let s = spot.to_f64().ok_or_else(||
EconomicsError::InvalidPrice("Invalid spot".into())
)?;
let s = spot
.to_f64()
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid spot".into()))?;
let t = contract.time_to_expiry();
// Brenner-Subrahmanyam approximation for initial guess
@ -381,7 +384,11 @@ impl BlackScholes {
for _ in 0..100 {
let mid = (low + high) / 2.0;
let price = self.price(contract, spot, mid)?.price.to_f64().unwrap_or(0.0);
let price = self
.price(contract, spot, mid)?
.price
.to_f64()
.unwrap_or(0.0);
if (price - target).abs() < 0.0001 {
return Ok(mid);
@ -493,9 +500,9 @@ impl FuturesModel {
/// F = S * e^((r + u - y) * T)
/// where r = risk-free rate, u = storage cost, y = convenience yield
pub fn price(&self, contract: &FuturesContract, spot: SynorDecimal) -> Result<FuturesPricing> {
let s = spot.to_f64().ok_or_else(||
EconomicsError::InvalidPrice("Invalid spot price".into())
)?;
let s = spot
.to_f64()
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid spot price".into()))?;
let t = contract.time_to_expiry();
if t < 0.0 {
@ -515,8 +522,7 @@ impl FuturesModel {
0.0
};
let coc = Decimal::from_f64_retain(cost_of_carry * t * s)
.unwrap_or(Decimal::ZERO);
let coc = Decimal::from_f64_retain(cost_of_carry * t * s).unwrap_or(Decimal::ZERO);
Ok(FuturesPricing {
fair_value,
@ -529,13 +535,18 @@ impl FuturesModel {
/// Calculate implied repo rate from futures price
/// R = (F/S - 1) / T
pub fn implied_repo_rate(&self, contract: &FuturesContract, spot: SynorDecimal, futures_price: SynorDecimal) -> Result<f64> {
let s = spot.to_f64().ok_or_else(||
EconomicsError::InvalidPrice("Invalid spot".into())
)?;
let f = futures_price.to_f64().ok_or_else(||
EconomicsError::InvalidPrice("Invalid futures price".into())
)?;
pub fn implied_repo_rate(
&self,
contract: &FuturesContract,
spot: SynorDecimal,
futures_price: SynorDecimal,
) -> Result<f64> {
let s = spot
.to_f64()
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid spot".into()))?;
let f = futures_price
.to_f64()
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid futures price".into()))?;
let t = contract.time_to_expiry();
if t <= 0.0 || s <= 0.0 {
@ -595,12 +606,12 @@ impl PerpetualModel {
mark_price: SynorDecimal,
index_price: SynorDecimal,
) -> Result<f64> {
let mark = mark_price.to_f64().ok_or_else(||
EconomicsError::InvalidPrice("Invalid mark price".into())
)?;
let index = index_price.to_f64().ok_or_else(||
EconomicsError::InvalidPrice("Invalid index price".into())
)?;
let mark = mark_price
.to_f64()
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid mark price".into()))?;
let index = index_price
.to_f64()
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid index price".into()))?;
if index <= 0.0 {
return Err(EconomicsError::InvalidPrice("Invalid index".into()));
@ -640,7 +651,8 @@ impl PerpetualModel {
let hours_since_midnight = now.time().hour();
let next_funding_hour = ((hours_since_midnight / self.funding_interval_hours) + 1)
* self.funding_interval_hours;
let next_funding = now.date_naive()
let next_funding = now
.date_naive()
.and_hms_opt(next_funding_hour % 24, 0, 0)
.map(|dt| DateTime::from_naive_utc_and_offset(dt, Utc))
.unwrap_or(now + Duration::hours(self.funding_interval_hours as i64));
@ -721,7 +733,8 @@ impl DerivativesOracle {
/// Set volatility surface for an underlying
pub fn set_vol_surface(&mut self, surface: VolatilitySurface) {
self.vol_surfaces.insert(surface.underlying.clone(), surface);
self.vol_surfaces
.insert(surface.underlying.clone(), surface);
}
/// Price an option using the volatility surface
@ -730,12 +743,13 @@ impl DerivativesOracle {
contract: &OptionContract,
spot: SynorDecimal,
) -> Result<OptionPricing> {
let s = spot.to_f64().ok_or_else(||
EconomicsError::InvalidPrice("Invalid spot".into())
)?;
let k = contract.strike.to_f64().ok_or_else(||
EconomicsError::InvalidPrice("Invalid strike".into())
)?;
let s = spot
.to_f64()
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid spot".into()))?;
let k = contract
.strike
.to_f64()
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid strike".into()))?;
// Get volatility from surface or use default
let vol = if let Some(surface) = self.vol_surfaces.get(&contract.underlying) {
@ -765,7 +779,8 @@ impl DerivativesOracle {
mark_price: SynorDecimal,
open_interest: SynorDecimal,
) -> Result<PerpetualPricing> {
self.perpetual_model.price(index_price, mark_price, open_interest)
self.perpetual_model
.price(index_price, mark_price, open_interest)
}
/// Calculate implied vol from market price
@ -775,7 +790,8 @@ impl DerivativesOracle {
spot: SynorDecimal,
market_price: SynorDecimal,
) -> Result<f64> {
self.options_model.implied_volatility(contract, spot, market_price)
self.options_model
.implied_volatility(contract, spot, market_price)
}
}
@ -792,8 +808,18 @@ mod tests {
#[test]
fn test_option_intrinsic_value() {
let call = OptionContract::new("ETH", dec!(2000), Utc::now() + Duration::days(30), OptionType::Call);
let put = OptionContract::new("ETH", dec!(2000), Utc::now() + Duration::days(30), OptionType::Put);
let call = OptionContract::new(
"ETH",
dec!(2000),
Utc::now() + Duration::days(30),
OptionType::Call,
);
let put = OptionContract::new(
"ETH",
dec!(2000),
Utc::now() + Duration::days(30),
OptionType::Put,
);
// ITM call
assert_eq!(call.intrinsic_value(dec!(2100)), dec!(100));
@ -880,7 +906,9 @@ mod tests {
let pricing = model.price(&contract, dec!(2000), vol).unwrap();
// Calculate IV from that price
let iv = model.implied_volatility(&contract, dec!(2000), pricing.price).unwrap();
let iv = model
.implied_volatility(&contract, dec!(2000), pricing.price)
.unwrap();
// Should match original vol
assert!((iv - vol).abs() < 0.01);

View file

@ -40,12 +40,12 @@ impl CollateralAsset {
pub fn standard(symbol: impl Into<String>) -> Self {
Self {
symbol: symbol.into(),
collateral_factor: Decimal::new(75, 2), // 75%
liquidation_threshold: Decimal::new(80, 2), // 80%
liquidation_bonus: Decimal::new(5, 2), // 5%
collateral_factor: Decimal::new(75, 2), // 75%
liquidation_threshold: Decimal::new(80, 2), // 80%
liquidation_bonus: Decimal::new(5, 2), // 5%
supply_cap: None,
borrow_enabled: true,
reserve_factor: Decimal::new(10, 2), // 10%
reserve_factor: Decimal::new(10, 2), // 10%
volatility_multiplier: Decimal::ONE,
}
}
@ -54,12 +54,12 @@ impl CollateralAsset {
pub fn stablecoin(symbol: impl Into<String>) -> Self {
Self {
symbol: symbol.into(),
collateral_factor: Decimal::new(90, 2), // 90%
liquidation_threshold: Decimal::new(95, 2), // 95%
liquidation_bonus: Decimal::new(2, 2), // 2%
collateral_factor: Decimal::new(90, 2), // 90%
liquidation_threshold: Decimal::new(95, 2), // 95%
liquidation_bonus: Decimal::new(2, 2), // 2%
supply_cap: None,
borrow_enabled: true,
reserve_factor: Decimal::new(5, 2), // 5%
reserve_factor: Decimal::new(5, 2), // 5%
volatility_multiplier: Decimal::ONE,
}
}
@ -68,13 +68,13 @@ impl CollateralAsset {
pub fn volatile(symbol: impl Into<String>) -> Self {
Self {
symbol: symbol.into(),
collateral_factor: Decimal::new(50, 2), // 50%
liquidation_threshold: Decimal::new(65, 2), // 65%
liquidation_bonus: Decimal::new(10, 2), // 10%
collateral_factor: Decimal::new(50, 2), // 50%
liquidation_threshold: Decimal::new(65, 2), // 65%
liquidation_bonus: Decimal::new(10, 2), // 10%
supply_cap: None,
borrow_enabled: true,
reserve_factor: Decimal::new(20, 2), // 20%
volatility_multiplier: Decimal::new(12, 1), // 1.2x
reserve_factor: Decimal::new(20, 2), // 20%
volatility_multiplier: Decimal::new(12, 1), // 1.2x
}
}
}
@ -131,7 +131,11 @@ impl LendingPosition {
/// Withdraw collateral
pub fn withdraw(&mut self, asset: impl Into<String>, amount: SynorDecimal) -> Result<()> {
let asset = asset.into();
let current = self.collateral.get(&asset).copied().unwrap_or(Decimal::ZERO);
let current = self
.collateral
.get(&asset)
.copied()
.unwrap_or(Decimal::ZERO);
if amount > current {
return Err(EconomicsError::InsufficientFunds {
required: amount,
@ -233,10 +237,10 @@ pub struct LiquidationOracleConfig {
impl Default for LiquidationOracleConfig {
fn default() -> Self {
Self {
max_price_age: 60, // 1 minute (stricter than general oracle)
max_price_age: 60, // 1 minute (stricter than general oracle)
min_confidence: 0.9,
min_sources: 2,
liquidation_grace_period: 300, // 5 minutes
liquidation_grace_period: 300, // 5 minutes
min_liquidation_amount: Decimal::new(10, 0), // $10
max_liquidation_pct: Decimal::new(50, 2), // 50% at a time
partial_liquidation: true,
@ -289,7 +293,8 @@ impl LiquidationOracle {
/// Create a new position
pub fn create_position(&mut self, account_id: impl Into<String>) -> &mut LendingPosition {
let account_id = account_id.into();
self.positions.entry(account_id.clone())
self.positions
.entry(account_id.clone())
.or_insert_with(|| LendingPosition::new(account_id))
}
@ -320,7 +325,8 @@ impl LiquidationOracle {
freshness_remaining: self.config.max_price_age - age,
};
self.price_cache.insert(asset.to_string(), liq_price.clone());
self.price_cache
.insert(asset.to_string(), liq_price.clone());
Ok(liq_price)
}
@ -328,7 +334,9 @@ impl LiquidationOracle {
pub fn calculate_health(&mut self, account_id: &str) -> Result<HealthStatus> {
// Clone position data to avoid borrow conflicts with get_liquidation_price
let (collateral, borrows, interest_owed) = {
let position = self.positions.get(account_id)
let position = self
.positions
.get(account_id)
.ok_or_else(|| EconomicsError::AccountNotFound(account_id.to_string()))?;
(
position.collateral.clone(),
@ -352,7 +360,9 @@ impl LiquidationOracle {
});
}
let asset_config = self.assets.get(asset)
let asset_config = self
.assets
.get(asset)
.ok_or_else(|| EconomicsError::InvalidPrice(format!("Unknown asset: {}", asset)))?;
let value = *amount * price.price;
@ -426,24 +436,37 @@ impl LiquidationOracle {
return Err(EconomicsError::InvalidPrice("Position is healthy".into()));
}
let position = self.positions.get(account_id)
let position = self
.positions
.get(account_id)
.ok_or_else(|| EconomicsError::AccountNotFound(account_id.to_string()))?;
let debt_amount = position.borrows.get(debt_asset).copied().unwrap_or(Decimal::ZERO);
let collateral_amount = position.collateral.get(collateral_asset).copied().unwrap_or(Decimal::ZERO);
let debt_amount = position
.borrows
.get(debt_asset)
.copied()
.unwrap_or(Decimal::ZERO);
let collateral_amount = position
.collateral
.get(collateral_asset)
.copied()
.unwrap_or(Decimal::ZERO);
if debt_amount == Decimal::ZERO {
return Err(EconomicsError::InvalidPrice("No debt to repay".into()));
}
if collateral_amount == Decimal::ZERO {
return Err(EconomicsError::InvalidPrice("No collateral to seize".into()));
return Err(EconomicsError::InvalidPrice(
"No collateral to seize".into(),
));
}
let debt_price = self.get_liquidation_price(debt_asset)?;
let collateral_price = self.get_liquidation_price(collateral_asset)?;
let collateral_config = self.assets.get(collateral_asset)
.ok_or_else(|| EconomicsError::InvalidPrice(format!("Unknown asset: {}", collateral_asset)))?;
let collateral_config = self.assets.get(collateral_asset).ok_or_else(|| {
EconomicsError::InvalidPrice(format!("Unknown asset: {}", collateral_asset))
})?;
// Max debt repayable = close_factor * total_debt
let max_debt_repay = debt_amount * self.config.close_factor;
@ -458,12 +481,14 @@ impl LiquidationOracle {
let actual_collateral_seized = collateral_to_seize.min(collateral_amount);
let actual_debt_repaid = if actual_collateral_seized < collateral_to_seize {
// Partial liquidation
(actual_collateral_seized * collateral_price.price) / (bonus_multiplier * debt_price.price)
(actual_collateral_seized * collateral_price.price)
/ (bonus_multiplier * debt_price.price)
} else {
max_debt_repay
};
let bonus_amount = actual_collateral_seized * collateral_config.liquidation_bonus / bonus_multiplier;
let bonus_amount =
actual_collateral_seized * collateral_config.liquidation_bonus / bonus_multiplier;
Ok(LiquidationCalculation {
account_id: account_id.to_string(),
@ -489,7 +514,9 @@ impl LiquidationOracle {
let calc = self.calculate_liquidation(account_id, debt_asset, collateral_asset)?;
// Update position
let position = self.positions.get_mut(account_id)
let position = self
.positions
.get_mut(account_id)
.ok_or_else(|| EconomicsError::AccountNotFound(account_id.to_string()))?;
// Reduce debt
@ -549,7 +576,9 @@ impl LiquidationOracle {
// Protocol gets a portion of the liquidation bonus
if let Some(asset_config) = self.assets.get(&event.collateral_asset) {
let protocol_share = event.bonus_amount * asset_config.reserve_factor;
*reserves.entry(event.collateral_asset.clone()).or_insert(Decimal::ZERO) += protocol_share;
*reserves
.entry(event.collateral_asset.clone())
.or_insert(Decimal::ZERO) += protocol_share;
}
}
@ -561,15 +590,18 @@ impl LiquidationOracle {
let total_positions = self.positions.len();
let total_liquidations = self.liquidation_history.len();
let total_debt_liquidated: SynorDecimal = self.liquidation_history.iter()
.map(|e| e.debt_amount)
.sum();
let total_debt_liquidated: SynorDecimal =
self.liquidation_history.iter().map(|e| e.debt_amount).sum();
let total_collateral_seized: SynorDecimal = self.liquidation_history.iter()
let total_collateral_seized: SynorDecimal = self
.liquidation_history
.iter()
.map(|e| e.collateral_amount)
.sum();
let unique_liquidated: std::collections::HashSet<_> = self.liquidation_history.iter()
let unique_liquidated: std::collections::HashSet<_> = self
.liquidation_history
.iter()
.map(|e| &e.account_id)
.collect();
@ -617,18 +649,60 @@ mod tests {
let mut price_oracle = PriceOracle::with_config(OracleConfig::default());
// Add prices from multiple sources for test validity
price_oracle.update_price(TokenPrice::new("ETH", "USD", dec!(2000), PriceSource::Internal)).unwrap();
price_oracle.update_price(TokenPrice::new("ETH", "USD", dec!(2000), PriceSource::Aggregated)).unwrap();
price_oracle.update_price(TokenPrice::new("SYNOR", "USD", dec!(1), PriceSource::Internal)).unwrap();
price_oracle.update_price(TokenPrice::new("SYNOR", "USD", dec!(1), PriceSource::Aggregated)).unwrap();
price_oracle.update_price(TokenPrice::new("USDC", "USD", dec!(1), PriceSource::Internal)).unwrap();
price_oracle.update_price(TokenPrice::new("USDC", "USD", dec!(1), PriceSource::Aggregated)).unwrap();
price_oracle
.update_price(TokenPrice::new(
"ETH",
"USD",
dec!(2000),
PriceSource::Internal,
))
.unwrap();
price_oracle
.update_price(TokenPrice::new(
"ETH",
"USD",
dec!(2000),
PriceSource::Aggregated,
))
.unwrap();
price_oracle
.update_price(TokenPrice::new(
"SYNOR",
"USD",
dec!(1),
PriceSource::Internal,
))
.unwrap();
price_oracle
.update_price(TokenPrice::new(
"SYNOR",
"USD",
dec!(1),
PriceSource::Aggregated,
))
.unwrap();
price_oracle
.update_price(TokenPrice::new(
"USDC",
"USD",
dec!(1),
PriceSource::Internal,
))
.unwrap();
price_oracle
.update_price(TokenPrice::new(
"USDC",
"USD",
dec!(1),
PriceSource::Aggregated,
))
.unwrap();
// Use test config with relaxed freshness requirements
let test_config = LiquidationOracleConfig {
max_price_age: 3600, // 1 hour for tests
min_confidence: 0.5, // Lower confidence threshold
min_sources: 1, // Single source OK for tests
max_price_age: 3600, // 1 hour for tests
min_confidence: 0.5, // Lower confidence threshold
min_sources: 1, // Single source OK for tests
..Default::default()
};
@ -660,8 +734,8 @@ mod tests {
// Create position with good health
let pos = oracle.create_position("user1");
pos.deposit("ETH", dec!(1)); // $2000 worth
pos.borrow("USDC", dec!(500)); // Borrow $500
pos.deposit("ETH", dec!(1)); // $2000 worth
pos.borrow("USDC", dec!(500)); // Borrow $500
let health = oracle.calculate_health("user1").unwrap();
@ -678,8 +752,8 @@ mod tests {
// Create position close to liquidation
let pos = oracle.create_position("user2");
pos.deposit("ETH", dec!(1)); // $2000 worth
pos.borrow("USDC", dec!(1500)); // Borrow $1500
pos.deposit("ETH", dec!(1)); // $2000 worth
pos.borrow("USDC", dec!(1500)); // Borrow $1500
let health = oracle.calculate_health("user2").unwrap();
@ -699,7 +773,9 @@ mod tests {
pos.deposit("ETH", dec!(1));
pos.borrow("USDC", dec!(1500));
let calc = oracle.calculate_liquidation("user3", "USDC", "ETH").unwrap();
let calc = oracle
.calculate_liquidation("user3", "USDC", "ETH")
.unwrap();
// Should be able to liquidate
assert!(calc.debt_to_repay > Decimal::ZERO);

View file

@ -49,8 +49,7 @@ pub use price_feed::{
PriceSource,
};
pub use twap::{
OnChainTwap, OnChainTwapFactory, TwapCalculator, TwapConfig, TwapObservation,
TwapOracleBuilder,
OnChainTwap, OnChainTwapFactory, TwapCalculator, TwapConfig, TwapObservation, TwapOracleBuilder,
};
use crate::error::{EconomicsError, Result};
@ -241,7 +240,10 @@ impl PriceOracle {
/// Get price history for a pair
pub fn get_price_history(&self, token: &str, quote: &str) -> Vec<TokenPrice> {
let pair_key = format!("{}/{}", token, quote);
self.price_history.get(&pair_key).cloned().unwrap_or_default()
self.price_history
.get(&pair_key)
.cloned()
.unwrap_or_default()
}
/// Fetch prices from all configured feeds
@ -403,7 +405,9 @@ impl PriceOracle {
}
let healthy = !pairs_status.is_empty()
&& pairs_status.values().all(|s| !s.is_stale && s.price_count > 0);
&& pairs_status
.values()
.all(|s| !s.is_stale && s.price_count > 0);
OracleHealthStatus {
healthy,
@ -443,7 +447,10 @@ impl PriceOracleBuilder {
/// Add a mock price feed (for testing)
pub fn with_mock_feed(mut self, base_price: SynorDecimal) -> Self {
use price_feed::MockPriceFeed;
self.feeds.push(Box::new(MockPriceFeed::new(PriceSource::Internal, base_price)));
self.feeds.push(Box::new(MockPriceFeed::new(
PriceSource::Internal,
base_price,
)));
self
}
@ -455,9 +462,14 @@ impl PriceOracleBuilder {
}
/// Add Chainlink oracle feed
pub fn with_chainlink(mut self, contract_address: impl Into<String>, rpc_url: impl Into<String>) -> Self {
pub fn with_chainlink(
mut self,
contract_address: impl Into<String>,
rpc_url: impl Into<String>,
) -> Self {
use price_feed::ChainlinkFeed;
self.feeds.push(Box::new(ChainlinkFeed::new(contract_address, rpc_url)));
self.feeds
.push(Box::new(ChainlinkFeed::new(contract_address, rpc_url)));
self
}
@ -508,8 +520,14 @@ impl OracleFactory {
let mut oracle = PriceOracle::new();
// Add multiple mock feeds with slight variations for testing
oracle.add_feed(Box::new(MockPriceFeed::new(PriceSource::Internal, base_price)));
oracle.add_feed(Box::new(MockPriceFeed::new(PriceSource::SynorDex, base_price)));
oracle.add_feed(Box::new(MockPriceFeed::new(
PriceSource::Internal,
base_price,
)));
oracle.add_feed(Box::new(MockPriceFeed::new(
PriceSource::SynorDex,
base_price,
)));
oracle
}

View file

@ -110,7 +110,8 @@ impl PriceFeed for MockPriceFeed {
async fn fetch_price(&self, token: &str, quote: &str) -> Result<TokenPrice> {
// Add small random variance
let variance = (rand_simple() * 2.0 - 1.0) * self.volatility;
let price = self.base_price * (Decimal::ONE + Decimal::from_f64_retain(variance).unwrap_or_default());
let price = self.base_price
* (Decimal::ONE + Decimal::from_f64_retain(variance).unwrap_or_default());
Ok(TokenPrice {
token: token.to_string(),
@ -258,9 +259,12 @@ impl PriceFeed for CoinGeckoFeed {
"SYNOR" => "synor", // Would need actual CoinGecko ID
"BTC" => "bitcoin",
"ETH" => "ethereum",
_ => return Err(EconomicsError::PriceFeedUnavailable(
format!("Token {} not supported on CoinGecko", token)
)),
_ => {
return Err(EconomicsError::PriceFeedUnavailable(format!(
"Token {} not supported on CoinGecko",
token
)))
}
};
let quote_currency = quote.to_lowercase();
@ -294,8 +298,7 @@ impl PriceFeed for CoinGeckoFeed {
Ok(TokenPrice {
token: token.to_string(),
quote: quote.to_string(),
price: Decimal::from_f64_retain(price)
.unwrap_or_default(),
price: Decimal::from_f64_retain(price).unwrap_or_default(),
timestamp: Utc::now(),
source: PriceSource::CoinGecko,
confidence: 0.90,

View file

@ -116,8 +116,8 @@ impl TwapCalculator {
let duration = (interval_end - interval_start).num_seconds() as f64;
if duration > 0.0 {
let weight = Decimal::from_f64_retain(duration / total_duration)
.unwrap_or(Decimal::ZERO);
let weight =
Decimal::from_f64_retain(duration / total_duration).unwrap_or(Decimal::ZERO);
weighted_sum += price.price * weight;
total_weight += weight;
@ -343,7 +343,8 @@ impl OnChainTwap {
/// Apply the pending cardinality increase (called during next observation)
pub fn apply_cardinality_growth(&mut self) {
if self.cardinality_next > self.cardinality {
self.observations.reserve(self.cardinality_next - self.cardinality);
self.observations
.reserve(self.cardinality_next - self.cardinality);
self.cardinality = self.cardinality_next;
}
}
@ -396,7 +397,12 @@ impl TwapOracleBuilder {
}
/// Add an initial observation
pub fn with_observation(mut self, timestamp: DateTime<Utc>, price_cumulative: SynorDecimal, spl_cumulative: SynorDecimal) -> Self {
pub fn with_observation(
mut self,
timestamp: DateTime<Utc>,
price_cumulative: SynorDecimal,
spl_cumulative: SynorDecimal,
) -> Self {
self.initial_observations.push(TwapObservation {
timestamp,
price_cumulative,

View file

@ -76,7 +76,11 @@ pub struct Discount {
impl Discount {
/// Create a new percentage discount
pub fn percentage(code: impl Into<String>, name: impl Into<String>, percentage: SynorDecimal) -> Self {
pub fn percentage(
code: impl Into<String>,
name: impl Into<String>,
percentage: SynorDecimal,
) -> Self {
Self {
code: code.into(),
name: name.into(),
@ -96,7 +100,11 @@ impl Discount {
}
/// Create a fixed amount discount
pub fn fixed_amount(code: impl Into<String>, name: impl Into<String>, amount: SynorDecimal) -> Self {
pub fn fixed_amount(
code: impl Into<String>,
name: impl Into<String>,
amount: SynorDecimal,
) -> Self {
Self {
code: code.into(),
name: name.into(),
@ -120,7 +128,10 @@ impl Discount {
Self {
code: format!("VOLUME_{}", min_spend),
name: format!("Volume Discount ({} SYNOR+)", min_spend),
description: format!("{}% off when spending {} SYNOR or more", percentage, min_spend),
description: format!(
"{}% off when spending {} SYNOR or more",
percentage, min_spend
),
discount_type: DiscountType::Volume,
value: percentage,
min_spend: Some(min_spend),
@ -260,9 +271,11 @@ impl Discount {
}
let discount = match self.discount_type {
DiscountType::Percentage | DiscountType::Volume | DiscountType::Loyalty | DiscountType::Referral | DiscountType::Partner => {
amount * (self.value / Decimal::ONE_HUNDRED)
}
DiscountType::Percentage
| DiscountType::Volume
| DiscountType::Loyalty
| DiscountType::Referral
| DiscountType::Partner => amount * (self.value / Decimal::ONE_HUNDRED),
DiscountType::FixedAmount | DiscountType::Promotional => {
self.value.min(amount) // Can't discount more than amount
}
@ -298,7 +311,7 @@ impl Discount {
/// Volume discount tiers
pub fn standard_volume_discounts() -> Vec<Discount> {
vec![
Discount::volume(Decimal::new(100, 0), Decimal::new(5, 0)), // 5% at 100+ SYNOR
Discount::volume(Decimal::new(100, 0), Decimal::new(5, 0)), // 5% at 100+ SYNOR
Discount::volume(Decimal::new(500, 0), Decimal::new(10, 0)), // 10% at 500+ SYNOR
Discount::volume(Decimal::new(1000, 0), Decimal::new(15, 0)), // 15% at 1000+ SYNOR
Discount::volume(Decimal::new(5000, 0), Decimal::new(20, 0)), // 20% at 5000+ SYNOR
@ -309,11 +322,7 @@ pub fn standard_volume_discounts() -> Vec<Discount> {
pub fn find_best_volume_discount(amount: SynorDecimal) -> Option<Discount> {
standard_volume_discounts()
.into_iter()
.filter(|d| {
d.min_spend
.map(|min| amount >= min)
.unwrap_or(false)
})
.filter(|d| d.min_spend.map(|min| amount >= min).unwrap_or(false))
.max_by(|a, b| a.value.cmp(&b.value))
}
@ -378,8 +387,8 @@ mod tests {
#[test]
fn test_discount_usage_limit() {
let mut discount = Discount::percentage("LIMITED", "Limited Use", dec!(10))
.with_max_uses(2);
let mut discount =
Discount::percentage("LIMITED", "Limited Use", dec!(10)).with_max_uses(2);
assert!(discount.use_discount());
assert!(discount.use_discount());

View file

@ -198,15 +198,9 @@ impl PricingEngine {
.get(&service_type)
.ok_or_else(|| EconomicsError::ServiceNotConfigured(service_type.to_string()))?;
let unit_price = pricing
.base_prices
.get(&resource_unit)
.ok_or_else(|| {
EconomicsError::ServiceNotConfigured(format!(
"{} - {}",
service_type, resource_unit
))
})?;
let unit_price = pricing.base_prices.get(&resource_unit).ok_or_else(|| {
EconomicsError::ServiceNotConfigured(format!("{} - {}", service_type, resource_unit))
})?;
let cost = amount * unit_price;
@ -357,7 +351,8 @@ impl PricingEngine {
storage: ServicePricingSummary {
gb_month: self.get_base_price(ServiceType::Storage, ResourceUnit::GbMonth),
retrieval_gb: self.get_base_price(ServiceType::Storage, ResourceUnit::BandwidthGb),
free_storage_gb: self.get_free_allocation(ServiceType::Storage, ResourceUnit::GbMonth),
free_storage_gb: self
.get_free_allocation(ServiceType::Storage, ResourceUnit::GbMonth),
},
hosting: HostingPricingSummary {
bandwidth_gb: self.get_base_price(ServiceType::Hosting, ResourceUnit::BandwidthGb),
@ -377,7 +372,8 @@ impl PricingEngine {
.get_free_allocation(ServiceType::Database, ResourceUnit::Queries),
},
compute: ComputePricingSummary {
cpu_core_hour: self.get_base_price(ServiceType::Compute, ResourceUnit::CpuCoreHours),
cpu_core_hour: self
.get_base_price(ServiceType::Compute, ResourceUnit::CpuCoreHours),
gpu_hour: self.get_base_price(ServiceType::Compute, ResourceUnit::GpuHours),
memory_gb_hour: self
.get_base_price(ServiceType::Compute, ResourceUnit::MemoryGbHours),
@ -472,7 +468,11 @@ mod tests {
// 10 million queries
let cost = engine
.calculate_cost(ServiceType::Database, ResourceUnit::Queries, dec!(10_000_000))
.calculate_cost(
ServiceType::Database,
ResourceUnit::Queries,
dec!(10_000_000),
)
.unwrap();
assert_eq!(cost, dec!(0.10)); // 10M * 0.00000001
@ -484,7 +484,9 @@ mod tests {
// Premium tier gets 20% discount
let base_cost = dec!(100);
let discount = engine.calculate_tier_discount("premium", base_cost).unwrap();
let discount = engine
.calculate_tier_discount("premium", base_cost)
.unwrap();
assert_eq!(discount, dec!(20)); // 20%
}

View file

@ -108,8 +108,8 @@ impl PricingTier {
discount_percentage: Decimal::new(30, 0), // 30% discount
priority_support: true,
sla_percentage: Decimal::new(9999, 2), // 99.99% SLA
custom_domain_limit: 0, // Unlimited
api_rate_limit: 0, // Unlimited
custom_domain_limit: 0, // Unlimited
api_rate_limit: 0, // Unlimited
features: vec![
"Everything in Premium".to_string(),
"30%+ Usage Discount".to_string(),
@ -147,7 +147,9 @@ impl PricingTier {
/// Check if this tier has a feature
pub fn has_feature(&self, feature: &str) -> bool {
self.features.iter().any(|f| f.to_lowercase().contains(&feature.to_lowercase()))
self.features
.iter()
.any(|f| f.to_lowercase().contains(&feature.to_lowercase()))
}
/// Calculate effective monthly cost including usage discount
@ -163,7 +165,9 @@ impl PricingTier {
let other_cost = other.effective_cost(monthly_usage);
// Upgrade if other tier is cheaper or offers significant benefits
other_cost < current_cost || (other.sla_percentage > self.sla_percentage && other_cost <= current_cost * Decimal::new(12, 1))
other_cost < current_cost
|| (other.sla_percentage > self.sla_percentage
&& other_cost <= current_cost * Decimal::new(12, 1))
}
}

View file

@ -181,7 +181,12 @@ impl AuthService {
}
/// Generate a JWT token.
pub fn generate_token(&self, user_id: &str, tier: ApiKeyTier, permissions: Permissions) -> Result<String, ApiError> {
pub fn generate_token(
&self,
user_id: &str,
tier: ApiKeyTier,
permissions: Permissions,
) -> Result<String, ApiError> {
let now = Utc::now();
let exp = now + self.jwt_expiration;
@ -215,8 +220,7 @@ impl AuthService {
})?;
let claims = token_data.claims;
let expires_at = DateTime::from_timestamp(claims.exp, 0)
.map(|dt| dt.with_timezone(&Utc));
let expires_at = DateTime::from_timestamp(claims.exp, 0).map(|dt| dt.with_timezone(&Utc));
Ok(AuthContext {
user_id: claims.sub,
@ -278,9 +282,7 @@ impl AuthService {
// Try API key header
if let Some(api_key) = headers.get("X-API-Key") {
let key = api_key
.to_str()
.map_err(|_| ApiError::InvalidApiKey)?;
let key = api_key.to_str().map_err(|_| ApiError::InvalidApiKey)?;
return self.validate_api_key(key).await;
}
@ -295,8 +297,7 @@ impl AuthService {
let decoded = BASE64
.decode(encoded)
.map_err(|_| ApiError::InvalidApiKey)?;
let key = String::from_utf8(decoded)
.map_err(|_| ApiError::InvalidApiKey)?;
let key = String::from_utf8(decoded).map_err(|_| ApiError::InvalidApiKey)?;
return self.validate_api_key(&key).await;
}
}
@ -318,7 +319,9 @@ where
fn from_request_parts<'life0, 'life1, 'async_trait>(
parts: &'life0 mut Parts,
_state: &'life1 S,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self, Self::Rejection>> + Send + 'async_trait>>
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<Self, Self::Rejection>> + Send + 'async_trait>,
>
where
'life0: 'async_trait,
'life1: 'async_trait,
@ -351,7 +354,9 @@ where
fn from_request_parts<'life0, 'life1, 'async_trait>(
parts: &'life0 mut Parts,
_state: &'life1 S,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self, Self::Rejection>> + Send + 'async_trait>>
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<Self, Self::Rejection>> + Send + 'async_trait>,
>
where
'life0: 'async_trait,
'life1: 'async_trait,
@ -359,10 +364,7 @@ where
{
Box::pin(async move {
// Get auth service from extensions
let auth_service = parts
.extensions
.get::<AuthService>()
.cloned();
let auth_service = parts.extensions.get::<AuthService>().cloned();
if let Some(auth_service) = auth_service {
match auth_service.authenticate(&parts.headers).await {
@ -377,10 +379,7 @@ where
}
/// Require specific permissions.
pub fn require_permission(
context: &AuthContext,
permission: &str,
) -> Result<(), ApiError> {
pub fn require_permission(context: &AuthContext, permission: &str) -> Result<(), ApiError> {
let has_permission = match permission {
"read" => context.can_read(),
"write" => context.can_write(),
@ -397,10 +396,7 @@ pub fn require_permission(
}
/// Require access to a specific service.
pub fn require_service_access(
context: &AuthContext,
service: &str,
) -> Result<(), ApiError> {
pub fn require_service_access(context: &AuthContext, service: &str) -> Result<(), ApiError> {
if context.can_access_service(service) {
Ok(())
} else {

View file

@ -160,10 +160,26 @@ pub struct RateLimitTiers {
impl Default for RateLimitTiers {
fn default() -> Self {
Self {
free: TierConfig { rpm: 60, burst: 10, concurrent: 5 },
developer: TierConfig { rpm: 600, burst: 100, concurrent: 20 },
pro: TierConfig { rpm: 6000, burst: 1000, concurrent: 100 },
enterprise: TierConfig { rpm: 0, burst: 0, concurrent: 0 }, // Unlimited
free: TierConfig {
rpm: 60,
burst: 10,
concurrent: 5,
},
developer: TierConfig {
rpm: 600,
burst: 100,
concurrent: 20,
},
pro: TierConfig {
rpm: 6000,
burst: 1000,
concurrent: 100,
},
enterprise: TierConfig {
rpm: 0,
burst: 0,
concurrent: 0,
}, // Unlimited
}
}
}

View file

@ -170,9 +170,7 @@ impl ApiError {
| Self::ContractError(_) => StatusCode::UNPROCESSABLE_ENTITY,
// 429 Too Many Requests
Self::RateLimitExceeded | Self::TooManyRequests { .. } => {
StatusCode::TOO_MANY_REQUESTS
}
Self::RateLimitExceeded | Self::TooManyRequests { .. } => StatusCode::TOO_MANY_REQUESTS,
// 500 Internal Server Error
Self::InternalError | Self::Custom(_) => StatusCode::INTERNAL_SERVER_ERROR,
@ -222,7 +220,10 @@ impl ApiError {
/// Build error details with optional extra information.
pub fn to_details(&self) -> ErrorDetails {
let details = match self {
Self::InsufficientBalance { required, available } => Some(serde_json::json!({
Self::InsufficientBalance {
required,
available,
} => Some(serde_json::json!({
"required": required,
"available": available
})),
@ -257,10 +258,9 @@ impl IntoResponse for ApiError {
// Add rate limit headers for 429 errors
if let Self::TooManyRequests { retry_after } = &self {
response.headers_mut().insert(
"Retry-After",
retry_after.to_string().parse().unwrap(),
);
response
.headers_mut()
.insert("Retry-After", retry_after.to_string().parse().unwrap());
}
response

View file

@ -144,7 +144,8 @@ pub async fn timing_middleware(request: Request, next: Next) -> Response {
// Update metrics
metrics::counter!("http_requests_total", "method" => method.to_string(), "status" => status.as_u16().to_string()).increment(1);
metrics::histogram!("http_request_duration_seconds", "method" => method.to_string()).record(duration.as_secs_f64());
metrics::histogram!("http_request_duration_seconds", "method" => method.to_string())
.record(duration.as_secs_f64());
response
}
@ -169,7 +170,10 @@ impl RateLimiterState {
}
/// Get or create a rate limiter for an IP.
pub async fn get_ip_limiter(&self, ip: &str) -> Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>> {
pub async fn get_ip_limiter(
&self,
ip: &str,
) -> Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>> {
{
let limiters = self.ip_limiters.read().await;
if let Some(limiter) = limiters.get(ip) {
@ -189,7 +193,11 @@ impl RateLimiterState {
}
/// Get or create a rate limiter for an API key.
pub async fn get_key_limiter(&self, key_id: &str, tier: ApiKeyTier) -> Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>> {
pub async fn get_key_limiter(
&self,
key_id: &str,
tier: ApiKeyTier,
) -> Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>> {
{
let limiters = self.key_limiters.read().await;
if let Some(limiter) = limiters.get(key_id) {
@ -255,9 +263,7 @@ pub async fn rate_limit_middleware(
// Use a fixed retry time since we can't easily convert to quanta's instant
let retry_after = 60; // Default to 60 seconds
Err(ApiError::TooManyRequests {
retry_after,
})
Err(ApiError::TooManyRequests { retry_after })
}
}
}
@ -274,10 +280,7 @@ pub async fn auth_middleware(
}
/// API version middleware - validates version prefix.
pub async fn version_middleware(
request: Request,
next: Next,
) -> Result<Response, ApiError> {
pub async fn version_middleware(request: Request, next: Next) -> Result<Response, ApiError> {
let path = request.uri().path();
// Skip version check for health, metrics, and docs
@ -307,22 +310,13 @@ pub async fn security_headers_middleware(request: Request, next: Next) -> Respon
let headers = response.headers_mut();
// Prevent XSS
headers.insert(
"X-Content-Type-Options",
"nosniff".parse().unwrap(),
);
headers.insert("X-Content-Type-Options", "nosniff".parse().unwrap());
// Prevent clickjacking
headers.insert(
"X-Frame-Options",
"DENY".parse().unwrap(),
);
headers.insert("X-Frame-Options", "DENY".parse().unwrap());
// Enable XSS filter
headers.insert(
"X-XSS-Protection",
"1; mode=block".parse().unwrap(),
);
headers.insert("X-XSS-Protection", "1; mode=block".parse().unwrap());
// Strict transport security (HTTPS)
headers.insert(

View file

@ -6,11 +6,7 @@
//! - Contract analysis and validation
//! - Security scanning
use axum::{
extract::State,
routing::post,
Json, Router,
};
use axum::{extract::State, routing::post, Json, Router};
use serde::{Deserialize, Serialize};
use crate::{
@ -44,7 +40,7 @@ pub fn router() -> Router<AppState> {
// Types
#[derive(Debug, Deserialize)]
pub struct CompileRequest {
pub wasm: String, // base64 encoded WASM
pub wasm: String, // base64 encoded WASM
pub optimization_level: Option<String>, // none, basic, size, aggressive
pub strip_debug: Option<bool>,
pub strip_names: Option<bool>,
@ -216,16 +212,14 @@ async fn compile_contract(
abi: Some(ContractAbi {
name: "MyContract".to_string(),
version: "1.0.0".to_string(),
functions: vec![
AbiFunction {
name: "init".to_string(),
selector: "0x12345678".to_string(),
inputs: vec![],
outputs: vec![],
view: false,
payable: false,
},
],
functions: vec![AbiFunction {
name: "init".to_string(),
selector: "0x12345678".to_string(),
inputs: vec![],
outputs: vec![],
view: false,
payable: false,
}],
events: vec![],
errors: vec![],
}),
@ -342,23 +336,19 @@ async fn analyze_contract(
imports: 100,
total: 5000,
},
functions: vec![
FunctionAnalysis {
name: "init".to_string(),
size: 500,
instruction_count: 50,
local_count: 3,
exported: true,
estimated_gas: 10000,
},
],
imports: vec![
ImportInfo {
module: "env".to_string(),
name: "memory".to_string(),
kind: "memory".to_string(),
},
],
functions: vec![FunctionAnalysis {
name: "init".to_string(),
size: 500,
instruction_count: 50,
local_count: 3,
exported: true,
estimated_gas: 10000,
}],
imports: vec![ImportInfo {
module: "env".to_string(),
name: "memory".to_string(),
kind: "memory".to_string(),
}],
gas_analysis: GasAnalysis {
deployment_gas: 100000,
memory_init_gas: 5000,
@ -378,14 +368,12 @@ async fn security_scan(
let result = SecurityScanResult {
score: 85,
issues: vec![
SecurityIssue {
severity: "low".to_string(),
issue_type: "unbounded_loop".to_string(),
description: "Potential unbounded loop detected".to_string(),
location: Some("function:process".to_string()),
},
],
issues: vec![SecurityIssue {
severity: "low".to_string(),
issue_type: "unbounded_loop".to_string(),
description: "Potential unbounded loop detected".to_string(),
location: Some("function:process".to_string()),
}],
recommendations: vec![
"Add loop iteration limits".to_string(),
"Consider using checked arithmetic".to_string(),

View file

@ -122,17 +122,15 @@ async fn list_markets(
) -> ApiResult<Json<ApiResponse<Vec<Market>>>> {
require_permission(&auth, "read")?;
let markets = vec![
Market {
symbol: "ETH-USDC".to_string(),
base_asset: "ETH".to_string(),
quote_asset: "USDC".to_string(),
last_price: "3000.00".to_string(),
change_24h: "2.5".to_string(),
volume_24h: "10000000".to_string(),
status: "active".to_string(),
},
];
let markets = vec![Market {
symbol: "ETH-USDC".to_string(),
base_asset: "ETH".to_string(),
quote_asset: "USDC".to_string(),
last_price: "3000.00".to_string(),
change_24h: "2.5".to_string(),
volume_24h: "10000000".to_string(),
status: "active".to_string(),
}];
Ok(Json(ApiResponse::success(markets)))
}
@ -165,8 +163,14 @@ async fn get_orderbook(
require_permission(&auth, "read")?;
let orderbook = Orderbook {
bids: vec![OrderbookEntry { price: "2999.00".to_string(), quantity: "1.5".to_string() }],
asks: vec![OrderbookEntry { price: "3001.00".to_string(), quantity: "2.0".to_string() }],
bids: vec![OrderbookEntry {
price: "2999.00".to_string(),
quantity: "1.5".to_string(),
}],
asks: vec![OrderbookEntry {
price: "3001.00".to_string(),
quantity: "2.0".to_string(),
}],
spread: "2.00".to_string(),
};
@ -286,7 +290,9 @@ async fn place_perp_order(
Json(req): Json<serde_json::Value>,
) -> ApiResult<Json<ApiResponse<serde_json::Value>>> {
require_permission(&auth, "write")?;
Ok(Json(ApiResponse::success(serde_json::json!({"order_id": "perp_123"}))))
Ok(Json(ApiResponse::success(
serde_json::json!({"order_id": "perp_123"}),
)))
}
async fn list_pools(
@ -295,18 +301,16 @@ async fn list_pools(
) -> ApiResult<Json<ApiResponse<Vec<Pool>>>> {
require_permission(&auth, "read")?;
let pools = vec![
Pool {
pool_id: "ETH-USDC".to_string(),
name: "ETH/USDC".to_string(),
token_a: "ETH".to_string(),
token_b: "USDC".to_string(),
reserve_a: "1000".to_string(),
reserve_b: "3000000".to_string(),
tvl: "6000000".to_string(),
apr: "15.5".to_string(),
},
];
let pools = vec![Pool {
pool_id: "ETH-USDC".to_string(),
name: "ETH/USDC".to_string(),
token_a: "ETH".to_string(),
token_b: "USDC".to_string(),
reserve_a: "1000".to_string(),
reserve_b: "3000000".to_string(),
tvl: "6000000".to_string(),
apr: "15.5".to_string(),
}];
Ok(Json(ApiResponse::success(pools)))
}

View file

@ -128,16 +128,14 @@ async fn list_chains(
) -> ApiResult<Json<ApiResponse<Vec<Chain>>>> {
require_permission(&auth, "read")?;
let chains = vec![
Chain {
chain_id: "cosmoshub-4".to_string(),
name: "Cosmos Hub".to_string(),
status: "active".to_string(),
rpc_endpoint: "https://rpc.cosmos.network".to_string(),
latest_height: 18000000,
active_channels: 50,
},
];
let chains = vec![Chain {
chain_id: "cosmoshub-4".to_string(),
name: "Cosmos Hub".to_string(),
status: "active".to_string(),
rpc_endpoint: "https://rpc.cosmos.network".to_string(),
latest_height: 18000000,
active_channels: 50,
}];
Ok(Json(ApiResponse::success(chains)))
}
@ -280,15 +278,13 @@ async fn get_routes(
) -> ApiResult<Json<ApiResponse<Vec<TransferRoute>>>> {
require_permission(&auth, "read")?;
let routes = vec![
TransferRoute {
source_chain: "cosmoshub-4".to_string(),
dest_chain: "synor-mainnet".to_string(),
channel_id: "channel-0".to_string(),
estimated_time: "30s".to_string(),
fee: "0.001 ATOM".to_string(),
},
];
let routes = vec![TransferRoute {
source_chain: "cosmoshub-4".to_string(),
dest_chain: "synor-mainnet".to_string(),
channel_id: "channel-0".to_string(),
estimated_time: "30s".to_string(),
fee: "0.001 ATOM".to_string(),
}];
Ok(Json(ApiResponse::success(routes)))
}

View file

@ -303,13 +303,11 @@ async fn get_peers(
) -> ApiResult<Json<ApiResponse<Vec<serde_json::Value>>>> {
require_permission(&auth, "read")?;
let peers = vec![
serde_json::json!({
"id": "peer1",
"address": "192.168.1.1:16100",
"connected_since": 1705312200
})
];
let peers = vec![serde_json::json!({
"id": "peer1",
"address": "192.168.1.1:16100",
"connected_since": 1705312200
})];
Ok(Json(ApiResponse::success(peers)))
}

View file

@ -247,14 +247,12 @@ async fn list_directory(
) -> ApiResult<Json<ApiResponse<Vec<DirectoryEntry>>>> {
require_permission(&auth, "read")?;
let entries = vec![
DirectoryEntry {
name: "file1.txt".to_string(),
cid: "bafyfile1...".to_string(),
size: 1024,
is_directory: false,
},
];
let entries = vec![DirectoryEntry {
name: "file1.txt".to_string(),
cid: "bafyfile1...".to_string(),
size: 1024,
is_directory: false,
}];
Ok(Json(ApiResponse::success(entries)))
}

View file

@ -409,7 +409,10 @@ async fn list_addresses(
];
let pagination_meta = pagination.to_meta(addresses.len() as u64);
Ok(Json(ApiResponse::success_paginated(addresses, pagination_meta)))
Ok(Json(ApiResponse::success_paginated(
addresses,
pagination_meta,
)))
}
/// Generate a stealth address.
@ -477,7 +480,9 @@ async fn get_balances(
require_permission(&auth, "read")?;
if req.addresses.is_empty() {
return Err(ApiError::ValidationError("addresses cannot be empty".to_string()));
return Err(ApiError::ValidationError(
"addresses cannot be empty".to_string(),
));
}
if req.addresses.len() > 100 {
@ -585,12 +590,15 @@ async fn send_transaction(
}
// Validate amount
let amount: f64 = req.amount.parse().map_err(|_| {
ApiError::ValidationError("Invalid amount format".to_string())
})?;
let amount: f64 = req
.amount
.parse()
.map_err(|_| ApiError::ValidationError("Invalid amount format".to_string()))?;
if amount <= 0.0 {
return Err(ApiError::ValidationError("Amount must be positive".to_string()));
return Err(ApiError::ValidationError(
"Amount must be positive".to_string(),
));
}
// In production, build, sign, and broadcast the transaction

View file

@ -139,16 +139,14 @@ async fn list_circuits(
) -> ApiResult<Json<ApiResponse<Vec<Circuit>>>> {
require_permission(&auth, "read")?;
let circuits = vec![
Circuit {
circuit_id: "multiplier-v1".to_string(),
name: "Multiplier".to_string(),
constraints: 1,
public_inputs: 1,
private_inputs: 2,
outputs: 1,
},
];
let circuits = vec![Circuit {
circuit_id: "multiplier-v1".to_string(),
name: "Multiplier".to_string(),
constraints: 1,
public_inputs: 1,
private_inputs: 2,
outputs: 1,
}];
let meta = pagination.to_meta(circuits.len() as u64);
Ok(Json(ApiResponse::success_paginated(circuits, meta)))

View file

@ -17,15 +17,9 @@ use axum::{
Router,
};
use std::{net::SocketAddr, sync::Arc};
use tokio::{
net::TcpListener,
signal,
sync::oneshot,
};
use tokio::{net::TcpListener, signal, sync::oneshot};
use tower_http::{
compression::CompressionLayer,
limit::RequestBodyLimitLayer,
timeout::TimeoutLayer,
compression::CompressionLayer, limit::RequestBodyLimitLayer, timeout::TimeoutLayer,
trace::TraceLayer,
};
use tracing::info;

View file

@ -164,7 +164,8 @@ impl VersionRegistry {
/// Register a version.
pub fn register(&mut self, info: VersionInfo) {
self.versions.insert((info.version.major, info.version.minor), info);
self.versions
.insert((info.version.major, info.version.minor), info);
}
/// Get version info.
@ -330,10 +331,7 @@ pub async fn version_middleware(req: Request, next: Next) -> Response {
// Add deprecation headers if needed
if let Some(info) = registry.get(&extracted.version) {
if info.is_deprecated {
headers.insert(
X_API_DEPRECATED.clone(),
HeaderValue::from_static("true"),
);
headers.insert(X_API_DEPRECATED.clone(), HeaderValue::from_static("true"));
if let Some(deprecated_at) = &info.deprecated_at {
if let Ok(v) = HeaderValue::from_str(&deprecated_at.to_rfc3339()) {
@ -427,8 +425,8 @@ impl VersionsResponse {
// Routes
// ============================================================================
use axum::{routing::get, Json, Router};
use crate::routes::AppState;
use axum::{routing::get, Json, Router};
/// Build version routes.
pub fn router() -> Router<AppState> {

View file

@ -593,11 +593,7 @@ async fn ws_blocks_handler(
ws.on_upgrade(move |socket| handle_blocks_socket(socket, state, auth))
}
async fn handle_blocks_socket(
socket: WebSocket,
state: AppState,
_auth: Option<AuthContext>,
) {
async fn handle_blocks_socket(socket: WebSocket, state: AppState, _auth: Option<AuthContext>) {
let ws_state = &state.websocket;
ws_state.broadcaster.add_connection().await;
@ -834,11 +830,7 @@ async fn ws_markets_handler(
ws.on_upgrade(move |socket| handle_markets_socket(socket, state, auth))
}
async fn handle_markets_socket(
socket: WebSocket,
state: AppState,
_auth: Option<AuthContext>,
) {
async fn handle_markets_socket(socket: WebSocket, state: AppState, _auth: Option<AuthContext>) {
let ws_state = &state.websocket;
ws_state.broadcaster.add_connection().await;

View file

@ -7,8 +7,8 @@
//! hosting-gateway --domain synor.cc --storage-url http://localhost:8180
//! hosting-gateway --config /path/to/config.toml
use synor_hosting::{HostingGateway, GatewayConfig};
use std::net::SocketAddr;
use synor_hosting::{GatewayConfig, HostingGateway};
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
@ -19,28 +19,32 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Parse command line arguments
let args: Vec<String> = std::env::args().collect();
let listen_addr: SocketAddr = args.iter()
let listen_addr: SocketAddr = args
.iter()
.position(|a| a == "--listen" || a == "-l")
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
.or_else(|| std::env::var("LISTEN_ADDR").ok()?.parse().ok())
.unwrap_or_else(|| "0.0.0.0:8080".parse().unwrap());
let hosting_domain = args.iter()
let hosting_domain = args
.iter()
.position(|a| a == "--domain" || a == "-d")
.and_then(|i| args.get(i + 1))
.cloned()
.or_else(|| std::env::var("HOSTING_DOMAIN").ok())
.unwrap_or_else(|| "synor.cc".to_string());
let storage_url = args.iter()
let storage_url = args
.iter()
.position(|a| a == "--storage-url" || a == "-s")
.and_then(|i| args.get(i + 1))
.cloned()
.or_else(|| std::env::var("STORAGE_GATEWAY_URL").ok())
.unwrap_or_else(|| "http://localhost:8180".to_string());
let rate_limit: u32 = args.iter()
let rate_limit: u32 = args
.iter()
.position(|a| a == "--rate-limit")
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())

View file

@ -233,11 +233,7 @@ impl EdgeCompute {
}
/// Run AI inference at the edge.
pub async fn inference(
&self,
_model: &str,
_input: &[u8],
) -> Result<Vec<u8>, EdgeError> {
pub async fn inference(&self, _model: &str, _input: &[u8]) -> Result<Vec<u8>, EdgeError> {
if !self.enabled {
return Err(EdgeError::NotEnabled);
}

Some files were not shown because too many files have changed in this diff Show more