style: apply cargo fmt formatting
This commit is contained in:
parent
5126c33113
commit
dcd1cccc67
170 changed files with 4463 additions and 2837 deletions
|
|
@ -256,7 +256,11 @@ pub async fn handle(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
CompilerCommands::Encode { function, args, abi } => {
|
||||
CompilerCommands::Encode {
|
||||
function,
|
||||
args,
|
||||
abi,
|
||||
} => {
|
||||
output::print_info(&format!("Encoding call to: {}", function));
|
||||
output::print_kv("Arguments", &args);
|
||||
if let Some(a) = abi {
|
||||
|
|
@ -268,7 +272,11 @@ pub async fn handle(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
CompilerCommands::Decode { data, function, abi } => {
|
||||
CompilerCommands::Decode {
|
||||
data,
|
||||
function,
|
||||
abi,
|
||||
} => {
|
||||
output::print_info(&format!("Decoding result for: {}", function));
|
||||
output::print_kv("Data", &data);
|
||||
if let Some(a) = abi {
|
||||
|
|
@ -314,7 +322,11 @@ pub async fn handle(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
CompilerCommands::SecurityScan { wasm, min_severity, format: _ } => {
|
||||
CompilerCommands::SecurityScan {
|
||||
wasm,
|
||||
min_severity,
|
||||
format: _,
|
||||
} => {
|
||||
output::print_info(&format!("Security scan: {}", wasm.display()));
|
||||
output::print_kv("Min severity", &min_severity);
|
||||
|
||||
|
|
@ -344,7 +356,11 @@ pub async fn handle(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
CompilerCommands::Validate { wasm, exports, max_memory } => {
|
||||
CompilerCommands::Validate {
|
||||
wasm,
|
||||
exports,
|
||||
max_memory,
|
||||
} => {
|
||||
output::print_info(&format!("Validating: {}", wasm.display()));
|
||||
|
||||
if let Some(e) = exports {
|
||||
|
|
|
|||
|
|
@ -196,8 +196,7 @@ pub async fn deploy(
|
|||
}
|
||||
|
||||
// Determine output directory
|
||||
let output_path = output_dir
|
||||
.unwrap_or_else(|| cwd.join(config.output_dir()));
|
||||
let output_path = output_dir.unwrap_or_else(|| cwd.join(config.output_dir()));
|
||||
|
||||
if !output_path.exists() {
|
||||
return Err(anyhow!(
|
||||
|
|
@ -270,7 +269,10 @@ fn validate_name(name: &str) -> Result<()> {
|
|||
if name.len() > 63 {
|
||||
return Err(anyhow!("Name must be 63 characters or less"));
|
||||
}
|
||||
if !name.chars().all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-') {
|
||||
if !name
|
||||
.chars()
|
||||
.all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-')
|
||||
{
|
||||
return Err(anyhow!(
|
||||
"Name must contain only lowercase letters, numbers, and hyphens"
|
||||
));
|
||||
|
|
@ -281,8 +283,8 @@ fn validate_name(name: &str) -> Result<()> {
|
|||
|
||||
// Reserved names
|
||||
const RESERVED: &[&str] = &[
|
||||
"www", "api", "app", "admin", "mail", "ftp", "ssh", "cdn",
|
||||
"storage", "gateway", "hosting", "node", "synor",
|
||||
"www", "api", "app", "admin", "mail", "ftp", "ssh", "cdn", "storage", "gateway", "hosting",
|
||||
"node", "synor",
|
||||
];
|
||||
if RESERVED.contains(&name) {
|
||||
return Err(anyhow!("Name '{}' is reserved", name));
|
||||
|
|
@ -397,11 +399,7 @@ fn guess_content_type(path: &Path) -> String {
|
|||
}
|
||||
|
||||
/// Upload files to Synor Storage.
|
||||
async fn upload_files(
|
||||
base_dir: &Path,
|
||||
files: &[DeployFile],
|
||||
gateway_url: &str,
|
||||
) -> Result<String> {
|
||||
async fn upload_files(base_dir: &Path, files: &[DeployFile], gateway_url: &str) -> Result<String> {
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Create a multipart form with all files
|
||||
|
|
@ -445,11 +443,7 @@ async fn upload_files(
|
|||
}
|
||||
|
||||
/// Register the deployment with the hosting gateway.
|
||||
async fn register_deployment(
|
||||
name: &str,
|
||||
cid: &str,
|
||||
gateway_url: &str,
|
||||
) -> Result<String> {
|
||||
async fn register_deployment(name: &str, cid: &str, gateway_url: &str) -> Result<String> {
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
#[derive(Serialize)]
|
||||
|
|
@ -662,7 +656,11 @@ pub async fn delete(name: &str, gateway_url: &str, format: OutputFormat) -> Resu
|
|||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let body = response.text().await.unwrap_or_default();
|
||||
return Err(anyhow!("Failed to delete deployment: {} - {}", status, body));
|
||||
return Err(anyhow!(
|
||||
"Failed to delete deployment: {} - {}",
|
||||
status,
|
||||
body
|
||||
));
|
||||
}
|
||||
|
||||
match format {
|
||||
|
|
@ -707,22 +705,13 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_guess_content_type() {
|
||||
assert_eq!(
|
||||
guess_content_type(Path::new("index.html")),
|
||||
"text/html"
|
||||
);
|
||||
assert_eq!(
|
||||
guess_content_type(Path::new("style.css")),
|
||||
"text/css"
|
||||
);
|
||||
assert_eq!(guess_content_type(Path::new("index.html")), "text/html");
|
||||
assert_eq!(guess_content_type(Path::new("style.css")), "text/css");
|
||||
assert_eq!(
|
||||
guess_content_type(Path::new("app.js")),
|
||||
"application/javascript"
|
||||
);
|
||||
assert_eq!(
|
||||
guess_content_type(Path::new("image.png")),
|
||||
"image/png"
|
||||
);
|
||||
assert_eq!(guess_content_type(Path::new("image.png")), "image/png");
|
||||
assert_eq!(
|
||||
guess_content_type(Path::new("data.wasm")),
|
||||
"application/wasm"
|
||||
|
|
|
|||
|
|
@ -246,7 +246,13 @@ pub async fn handle(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
DexCommands::PlaceOrder { market, side, price, quantity, wallet } => {
|
||||
DexCommands::PlaceOrder {
|
||||
market,
|
||||
side,
|
||||
price,
|
||||
quantity,
|
||||
wallet,
|
||||
} => {
|
||||
output::print_info("Placing limit order...");
|
||||
output::print_kv("Market", &market);
|
||||
output::print_kv("Side", &side);
|
||||
|
|
@ -257,7 +263,12 @@ pub async fn handle(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
DexCommands::MarketOrder { market, side, quantity, wallet } => {
|
||||
DexCommands::MarketOrder {
|
||||
market,
|
||||
side,
|
||||
quantity,
|
||||
wallet,
|
||||
} => {
|
||||
output::print_info("Placing market order...");
|
||||
output::print_kv("Market", &market);
|
||||
output::print_kv("Side", &side);
|
||||
|
|
@ -275,7 +286,10 @@ pub async fn handle(
|
|||
|
||||
DexCommands::CancelAll { market, wallet } => {
|
||||
let scope = market.unwrap_or_else(|| "all markets".to_string());
|
||||
output::print_info(&format!("Cancelling all orders in {} for {}", scope, wallet));
|
||||
output::print_info(&format!(
|
||||
"Cancelling all orders in {} for {}",
|
||||
scope, wallet
|
||||
));
|
||||
output::print_success("3 orders cancelled");
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -317,7 +331,12 @@ pub async fn handle(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
DexCommands::AddLiquidity { pool_id, amount_a, amount_b, wallet } => {
|
||||
DexCommands::AddLiquidity {
|
||||
pool_id,
|
||||
amount_a,
|
||||
amount_b,
|
||||
wallet,
|
||||
} => {
|
||||
output::print_info("Adding liquidity...");
|
||||
output::print_kv("Pool", &pool_id);
|
||||
output::print_kv("Amount A", &amount_a);
|
||||
|
|
@ -327,7 +346,11 @@ pub async fn handle(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
DexCommands::RemoveLiquidity { pool_id, lp_amount, wallet } => {
|
||||
DexCommands::RemoveLiquidity {
|
||||
pool_id,
|
||||
lp_amount,
|
||||
wallet,
|
||||
} => {
|
||||
output::print_info("Removing liquidity...");
|
||||
output::print_kv("Pool", &pool_id);
|
||||
output::print_kv("LP Amount", &lp_amount);
|
||||
|
|
|
|||
|
|
@ -169,11 +169,7 @@ pub enum ZkCommands {
|
|||
}
|
||||
|
||||
/// Handle ZK commands.
|
||||
pub async fn handle(
|
||||
_client: &RpcClient,
|
||||
command: ZkCommands,
|
||||
_format: OutputFormat,
|
||||
) -> Result<()> {
|
||||
pub async fn handle(_client: &RpcClient, command: ZkCommands, _format: OutputFormat) -> Result<()> {
|
||||
match command {
|
||||
ZkCommands::Compile { circuit, output } => {
|
||||
output::print_info(&format!("Compiling circuit: {}", circuit.display()));
|
||||
|
|
@ -211,8 +207,16 @@ pub async fn handle(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
ZkCommands::ProveGroth16 { circuit, witness, proving_key: _, output } => {
|
||||
output::print_info(&format!("Generating Groth16 proof for circuit: {}", circuit));
|
||||
ZkCommands::ProveGroth16 {
|
||||
circuit,
|
||||
witness,
|
||||
proving_key: _,
|
||||
output,
|
||||
} => {
|
||||
output::print_info(&format!(
|
||||
"Generating Groth16 proof for circuit: {}",
|
||||
circuit
|
||||
));
|
||||
output::print_info(&format!("Witness: {}", witness.display()));
|
||||
output::print_info("Computing witness...");
|
||||
output::print_info("Generating proof...");
|
||||
|
|
@ -226,7 +230,11 @@ pub async fn handle(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
ZkCommands::ProvePlonk { circuit, witness, output } => {
|
||||
ZkCommands::ProvePlonk {
|
||||
circuit,
|
||||
witness,
|
||||
output,
|
||||
} => {
|
||||
output::print_info(&format!("Generating PLONK proof for circuit: {}", circuit));
|
||||
output::print_info(&format!("Witness: {}", witness.display()));
|
||||
output::print_info("Computing witness...");
|
||||
|
|
@ -240,7 +248,11 @@ pub async fn handle(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
ZkCommands::ProveStark { circuit, witness, output } => {
|
||||
ZkCommands::ProveStark {
|
||||
circuit,
|
||||
witness,
|
||||
output,
|
||||
} => {
|
||||
output::print_info(&format!("Generating STARK proof for circuit: {}", circuit));
|
||||
output::print_info(&format!("Witness: {}", witness.display()));
|
||||
output::print_info("Computing execution trace...");
|
||||
|
|
@ -256,7 +268,11 @@ pub async fn handle(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
ZkCommands::Verify { proof, verification_key: _, public_inputs: _ } => {
|
||||
ZkCommands::Verify {
|
||||
proof,
|
||||
verification_key: _,
|
||||
public_inputs: _,
|
||||
} => {
|
||||
output::print_info(&format!("Verifying proof: {}", proof.display()));
|
||||
output::print_info("Loading proof...");
|
||||
output::print_info("Verifying...");
|
||||
|
|
@ -265,8 +281,15 @@ pub async fn handle(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
ZkCommands::Setup { circuit, system, output } => {
|
||||
output::print_info(&format!("Generating {} keys for circuit: {}", system, circuit));
|
||||
ZkCommands::Setup {
|
||||
circuit,
|
||||
system,
|
||||
output,
|
||||
} => {
|
||||
output::print_info(&format!(
|
||||
"Generating {} keys for circuit: {}",
|
||||
system, circuit
|
||||
));
|
||||
output::print_info("This may take a while for large circuits...");
|
||||
output::print_info("Generating proving key...");
|
||||
output::print_info("Deriving verification key...");
|
||||
|
|
|
|||
|
|
@ -469,7 +469,11 @@ enum DeployCommands {
|
|||
output: Option<PathBuf>,
|
||||
|
||||
/// Hosting gateway URL
|
||||
#[arg(long, env = "SYNOR_HOSTING_URL", default_value = "http://127.0.0.1:8280")]
|
||||
#[arg(
|
||||
long,
|
||||
env = "SYNOR_HOSTING_URL",
|
||||
default_value = "http://127.0.0.1:8280"
|
||||
)]
|
||||
gateway: String,
|
||||
|
||||
/// Skip running the build command
|
||||
|
|
@ -495,7 +499,11 @@ enum DeployCommands {
|
|||
/// List deployments
|
||||
List {
|
||||
/// Hosting gateway URL
|
||||
#[arg(long, env = "SYNOR_HOSTING_URL", default_value = "http://127.0.0.1:8280")]
|
||||
#[arg(
|
||||
long,
|
||||
env = "SYNOR_HOSTING_URL",
|
||||
default_value = "http://127.0.0.1:8280"
|
||||
)]
|
||||
gateway: String,
|
||||
},
|
||||
|
||||
|
|
@ -505,7 +513,11 @@ enum DeployCommands {
|
|||
name: String,
|
||||
|
||||
/// Hosting gateway URL
|
||||
#[arg(long, env = "SYNOR_HOSTING_URL", default_value = "http://127.0.0.1:8280")]
|
||||
#[arg(
|
||||
long,
|
||||
env = "SYNOR_HOSTING_URL",
|
||||
default_value = "http://127.0.0.1:8280"
|
||||
)]
|
||||
gateway: String,
|
||||
},
|
||||
|
||||
|
|
@ -515,7 +527,11 @@ enum DeployCommands {
|
|||
name: String,
|
||||
|
||||
/// Hosting gateway URL
|
||||
#[arg(long, env = "SYNOR_HOSTING_URL", default_value = "http://127.0.0.1:8280")]
|
||||
#[arg(
|
||||
long,
|
||||
env = "SYNOR_HOSTING_URL",
|
||||
default_value = "http://127.0.0.1:8280"
|
||||
)]
|
||||
gateway: String,
|
||||
},
|
||||
}
|
||||
|
|
@ -591,9 +607,11 @@ async fn main() {
|
|||
gateway,
|
||||
skip_build,
|
||||
} => commands::deploy::deploy(name, out_dir, &gateway, skip_build, output).await,
|
||||
DeployCommands::Init { name, spa, output: out_dir } => {
|
||||
commands::deploy::init(name, spa, out_dir, output)
|
||||
}
|
||||
DeployCommands::Init {
|
||||
name,
|
||||
spa,
|
||||
output: out_dir,
|
||||
} => commands::deploy::init(name, spa, out_dir, output),
|
||||
DeployCommands::List { gateway } => commands::deploy::list(&gateway, output).await,
|
||||
DeployCommands::Delete { name, gateway } => {
|
||||
commands::deploy::delete(&name, &gateway, output).await
|
||||
|
|
|
|||
|
|
@ -676,7 +676,9 @@ async fn get_blocks(
|
|||
}
|
||||
|
||||
// Fetch blocks by blue score (most recent first)
|
||||
let start_score = score.blue_score.saturating_sub((params.page.saturating_sub(1) * limit) as u64);
|
||||
let start_score = score
|
||||
.blue_score
|
||||
.saturating_sub((params.page.saturating_sub(1) * limit) as u64);
|
||||
let blocks_data: Vec<serde_json::Value> = state
|
||||
.rpc_call("synor_getBlocksByBlueScore", (start_score, true))
|
||||
.await
|
||||
|
|
@ -697,17 +699,28 @@ async fn get_blocks(
|
|||
parent_hashes: header
|
||||
.get("parents")
|
||||
.and_then(|p| p.as_array())
|
||||
.map(|a| a.iter().filter_map(|v| v.as_str().map(String::from)).collect())
|
||||
.map(|a| {
|
||||
a.iter()
|
||||
.filter_map(|v| v.as_str().map(String::from))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default(),
|
||||
timestamp,
|
||||
timestamp_human: format_timestamp(timestamp),
|
||||
bits: header.get("bits")?.as_u64()? as u32,
|
||||
nonce: header.get("nonce")?.as_u64()?,
|
||||
daa_score: header.get("blueScore").and_then(|v| v.as_u64()).unwrap_or(0),
|
||||
blue_score: header.get("blueScore").and_then(|v| v.as_u64()).unwrap_or(0),
|
||||
daa_score: header
|
||||
.get("blueScore")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0),
|
||||
blue_score: header
|
||||
.get("blueScore")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0),
|
||||
blue_work: String::new(),
|
||||
difficulty: 0.0,
|
||||
transaction_count: b.get("transactions")
|
||||
transaction_count: b
|
||||
.get("transactions")
|
||||
.and_then(|t| t.as_array())
|
||||
.map(|a| a.len())
|
||||
.unwrap_or(0),
|
||||
|
|
@ -1102,9 +1115,7 @@ async fn estimate_gas(
|
|||
};
|
||||
|
||||
// Call the node's contract_estimateGas RPC method
|
||||
let gas_used: u64 = state
|
||||
.rpc_call("contract_estimateGas", rpc_request)
|
||||
.await?;
|
||||
let gas_used: u64 = state.rpc_call("contract_estimateGas", rpc_request).await?;
|
||||
|
||||
// Calculate recommended gas limit with 20% safety margin
|
||||
let gas_limit_recommended = ((gas_used as f64) * 1.2).ceil() as u64;
|
||||
|
|
@ -1494,8 +1505,7 @@ async fn main() -> anyhow::Result<()> {
|
|||
let app = if let Some(ref static_dir) = config.static_dir {
|
||||
// Serve static files with SPA fallback (index.html for client-side routing)
|
||||
let index_path = format!("{}/index.html", static_dir);
|
||||
let serve_dir = ServeDir::new(static_dir)
|
||||
.not_found_service(ServeFile::new(&index_path));
|
||||
let serve_dir = ServeDir::new(static_dir).not_found_service(ServeFile::new(&index_path));
|
||||
|
||||
api_router
|
||||
.fallback_service(serve_dir)
|
||||
|
|
|
|||
|
|
@ -684,10 +684,12 @@ mod tests {
|
|||
#[test]
|
||||
fn test_all_paths_are_distinct() {
|
||||
let config = NodeConfig::for_network("mainnet").unwrap();
|
||||
let paths = [config.blocks_path(),
|
||||
let paths = [
|
||||
config.blocks_path(),
|
||||
config.chainstate_path(),
|
||||
config.contracts_path(),
|
||||
config.keys_path()];
|
||||
config.keys_path(),
|
||||
];
|
||||
|
||||
for i in 0..paths.len() {
|
||||
for j in (i + 1)..paths.len() {
|
||||
|
|
@ -794,9 +796,11 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_with_mining_enabled() {
|
||||
let config = NodeConfig::for_network("mainnet")
|
||||
.unwrap()
|
||||
.with_mining(true, Some("synor:test_address".to_string()), 4);
|
||||
let config = NodeConfig::for_network("mainnet").unwrap().with_mining(
|
||||
true,
|
||||
Some("synor:test_address".to_string()),
|
||||
4,
|
||||
);
|
||||
|
||||
assert!(config.mining.enabled);
|
||||
assert_eq!(
|
||||
|
|
@ -828,9 +832,10 @@ mod tests {
|
|||
#[test]
|
||||
fn test_with_p2p() {
|
||||
let seeds = vec!["seed1.example.com:30303".to_string()];
|
||||
let config = NodeConfig::for_network("mainnet")
|
||||
.unwrap()
|
||||
.with_p2p("0.0.0.0", 30303, seeds.clone());
|
||||
let config =
|
||||
NodeConfig::for_network("mainnet")
|
||||
.unwrap()
|
||||
.with_p2p("0.0.0.0", 30303, seeds.clone());
|
||||
|
||||
assert_eq!(config.p2p.listen_addr, "0.0.0.0:30303");
|
||||
assert_eq!(config.p2p.seeds, seeds);
|
||||
|
|
@ -1027,7 +1032,10 @@ mod tests {
|
|||
let loaded = NodeConfig::load(&path).unwrap();
|
||||
|
||||
assert_eq!(loaded.mining.enabled, config.mining.enabled);
|
||||
assert_eq!(loaded.mining.coinbase_address, config.mining.coinbase_address);
|
||||
assert_eq!(
|
||||
loaded.mining.coinbase_address,
|
||||
config.mining.coinbase_address
|
||||
);
|
||||
assert_eq!(loaded.mining.threads, config.mining.threads);
|
||||
assert_eq!(loaded.storage.cache_size_mb, config.storage.cache_size_mb);
|
||||
assert_eq!(loaded.logging.level, config.logging.level);
|
||||
|
|
|
|||
|
|
@ -425,11 +425,13 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_node_state_all_variants_are_distinct() {
|
||||
let states = [NodeState::Starting,
|
||||
let states = [
|
||||
NodeState::Starting,
|
||||
NodeState::Syncing,
|
||||
NodeState::Running,
|
||||
NodeState::Stopping,
|
||||
NodeState::Stopped];
|
||||
NodeState::Stopped,
|
||||
];
|
||||
|
||||
for i in 0..states.len() {
|
||||
for j in (i + 1)..states.len() {
|
||||
|
|
@ -605,7 +607,10 @@ mod tests {
|
|||
.with_mining(true, Some("synor:test".to_string()), 4);
|
||||
|
||||
assert!(config.mining.enabled);
|
||||
assert_eq!(config.mining.coinbase_address, Some("synor:test".to_string()));
|
||||
assert_eq!(
|
||||
config.mining.coinbase_address,
|
||||
Some("synor:test".to_string())
|
||||
);
|
||||
assert_eq!(config.mining.threads, 4);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -12,9 +12,12 @@ use synor_mining::{
|
|||
MinerCommand, MinerConfig, MinerEvent, MiningResult, MiningStats as CrateMiningStats,
|
||||
TemplateTransaction,
|
||||
};
|
||||
use synor_types::{Address, Amount, Block, BlockHeader, BlockId, BlueScore, Hash256, Network, Timestamp, Transaction, TxOutput};
|
||||
use synor_types::block::BlockBody;
|
||||
use synor_types::transaction::ScriptPubKey;
|
||||
use synor_types::{
|
||||
Address, Amount, Block, BlockHeader, BlockId, BlueScore, Hash256, Network, Timestamp,
|
||||
Transaction, TxOutput,
|
||||
};
|
||||
|
||||
use crate::config::NodeConfig;
|
||||
use crate::services::{ConsensusService, MempoolService};
|
||||
|
|
@ -473,10 +476,7 @@ impl MinerService {
|
|||
extra_data.extend_from_slice(&result.nonce.to_le_bytes());
|
||||
extra_data.extend_from_slice(&template.coinbase_data.extra_data);
|
||||
|
||||
let coinbase_tx = Transaction::coinbase(
|
||||
vec![coinbase_output],
|
||||
extra_data,
|
||||
);
|
||||
let coinbase_tx = Transaction::coinbase(vec![coinbase_output], extra_data);
|
||||
|
||||
// Start with coinbase transaction
|
||||
let mut transactions = vec![coinbase_tx];
|
||||
|
|
@ -522,8 +522,7 @@ impl MinerService {
|
|||
let block = Block { header, body };
|
||||
|
||||
// Serialize with Borsh
|
||||
borsh::to_vec(&block)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to serialize block: {}", e))
|
||||
borsh::to_vec(&block).map_err(|e| anyhow::anyhow!("Failed to serialize block: {}", e))
|
||||
}
|
||||
|
||||
/// Submits a mined block (for external submission via RPC).
|
||||
|
|
|
|||
|
|
@ -234,7 +234,10 @@ mod network_partition_tests {
|
|||
|
||||
// Node 0 should have fewer peers after isolation
|
||||
let isolated_peers = network.nodes[0].network().peer_count().await;
|
||||
info!(isolated_peers = isolated_peers, "Node 0 peers after isolation");
|
||||
info!(
|
||||
isolated_peers = isolated_peers,
|
||||
"Node 0 peers after isolation"
|
||||
);
|
||||
|
||||
assert!(
|
||||
isolated_peers < initial_peer_counts[0] || initial_peer_counts[0] == 0,
|
||||
|
|
@ -271,7 +274,10 @@ mod network_partition_tests {
|
|||
|
||||
// After healing, nodes should have peers
|
||||
let total_peers = network.total_peer_count().await;
|
||||
info!(total_peers = total_peers, "Total peers after partition recovery");
|
||||
info!(
|
||||
total_peers = total_peers,
|
||||
"Total peers after partition recovery"
|
||||
);
|
||||
|
||||
// Consensus state should converge
|
||||
let consensus0 = network.nodes[0].consensus();
|
||||
|
|
@ -287,7 +293,10 @@ mod network_partition_tests {
|
|||
);
|
||||
|
||||
// Both should have some consensus state
|
||||
assert!(vsp0.is_some() || vsp1.is_some(), "At least one node should have VSP");
|
||||
assert!(
|
||||
vsp0.is_some() || vsp1.is_some(),
|
||||
"At least one node should have VSP"
|
||||
);
|
||||
|
||||
network.stop_all().await.unwrap();
|
||||
}
|
||||
|
|
@ -351,10 +360,12 @@ mod network_partition_tests {
|
|||
|
||||
// Record blue scores from each partition
|
||||
let scores_before: Vec<u64> = futures::future::join_all(
|
||||
network.nodes.iter().map(|n| async {
|
||||
n.consensus().current_blue_score().await
|
||||
})
|
||||
).await;
|
||||
network
|
||||
.nodes
|
||||
.iter()
|
||||
.map(|n| async { n.consensus().current_blue_score().await }),
|
||||
)
|
||||
.await;
|
||||
|
||||
info!(scores_before = ?scores_before, "Blue scores before healing");
|
||||
|
||||
|
|
@ -368,10 +379,12 @@ mod network_partition_tests {
|
|||
|
||||
// Blue scores should converge
|
||||
let scores_after: Vec<u64> = futures::future::join_all(
|
||||
network.nodes.iter().map(|n| async {
|
||||
n.consensus().current_blue_score().await
|
||||
})
|
||||
).await;
|
||||
network
|
||||
.nodes
|
||||
.iter()
|
||||
.map(|n| async { n.consensus().current_blue_score().await }),
|
||||
)
|
||||
.await;
|
||||
|
||||
info!(scores_after = ?scores_after, "Blue scores after healing");
|
||||
|
||||
|
|
@ -380,7 +393,9 @@ mod network_partition_tests {
|
|||
assert!(
|
||||
after >= before,
|
||||
"Node {} blue score should not decrease: {} -> {}",
|
||||
i, before, after
|
||||
i,
|
||||
before,
|
||||
after
|
||||
);
|
||||
}
|
||||
|
||||
|
|
@ -407,10 +422,7 @@ mod double_spend_tests {
|
|||
let mempool = network.nodes[0].mempool();
|
||||
let initial_size = mempool.size().await;
|
||||
|
||||
info!(
|
||||
initial_mempool_size = initial_size,
|
||||
"Initial mempool state"
|
||||
);
|
||||
info!(initial_mempool_size = initial_size, "Initial mempool state");
|
||||
|
||||
// In production, we would:
|
||||
// 1. Create two transactions spending the same UTXO
|
||||
|
|
@ -420,7 +432,7 @@ mod double_spend_tests {
|
|||
// For now, verify mempool API is working
|
||||
// and handles empty/invalid data gracefully
|
||||
let _invalid_tx = vec![0u8; 50]; // Invalid transaction bytes (for future use)
|
||||
// Submitting invalid tx should fail gracefully
|
||||
// Submitting invalid tx should fail gracefully
|
||||
|
||||
// Mempool should maintain integrity
|
||||
let final_size = mempool.size().await;
|
||||
|
|
@ -653,7 +665,11 @@ mod invalid_block_rejection_tests {
|
|||
|
||||
// All valid tips should have known parents in the DAG
|
||||
for tip in &tips {
|
||||
let has_parents = consensus.get_block_info(tip).await.map(|info| !info.parents.is_empty()).unwrap_or(false);
|
||||
let has_parents = consensus
|
||||
.get_block_info(tip)
|
||||
.await
|
||||
.map(|info| !info.parents.is_empty())
|
||||
.unwrap_or(false);
|
||||
info!(
|
||||
block = hex::encode(&tip[..8]),
|
||||
has_parents = has_parents,
|
||||
|
|
@ -687,16 +703,22 @@ mod sybil_attack_tests {
|
|||
|
||||
// Track blue scores - honest nodes should maintain correct view
|
||||
let honest_scores: Vec<u64> = futures::future::join_all(
|
||||
network.nodes.iter().take(3).map(|n| async {
|
||||
n.consensus().current_blue_score().await
|
||||
})
|
||||
).await;
|
||||
network
|
||||
.nodes
|
||||
.iter()
|
||||
.take(3)
|
||||
.map(|n| async { n.consensus().current_blue_score().await }),
|
||||
)
|
||||
.await;
|
||||
|
||||
let sybil_scores: Vec<u64> = futures::future::join_all(
|
||||
network.nodes.iter().skip(3).map(|n| async {
|
||||
n.consensus().current_blue_score().await
|
||||
})
|
||||
).await;
|
||||
network
|
||||
.nodes
|
||||
.iter()
|
||||
.skip(3)
|
||||
.map(|n| async { n.consensus().current_blue_score().await }),
|
||||
)
|
||||
.await;
|
||||
|
||||
info!(
|
||||
honest_scores = ?honest_scores,
|
||||
|
|
@ -805,7 +827,10 @@ mod eclipse_attack_tests {
|
|||
sleep(Duration::from_secs(1)).await;
|
||||
|
||||
let after_eclipse_peers = victim_network.peer_count().await;
|
||||
info!(after_eclipse_peers = after_eclipse_peers, "Peers after eclipse attempt");
|
||||
info!(
|
||||
after_eclipse_peers = after_eclipse_peers,
|
||||
"Peers after eclipse attempt"
|
||||
);
|
||||
|
||||
// In a real implementation, the node would:
|
||||
// 1. Detect low peer diversity
|
||||
|
|
@ -863,7 +888,10 @@ mod eclipse_attack_tests {
|
|||
sleep(Duration::from_secs(1)).await;
|
||||
|
||||
let eclipsed_peers = network.nodes[0].network().peer_count().await;
|
||||
info!(eclipsed_peers = eclipsed_peers, "Node 0 peers during eclipse");
|
||||
info!(
|
||||
eclipsed_peers = eclipsed_peers,
|
||||
"Node 0 peers during eclipse"
|
||||
);
|
||||
|
||||
// Manually reconnect (simulating recovery mechanism)
|
||||
network.connect_nodes(0, 1).await.unwrap();
|
||||
|
|
@ -871,7 +899,10 @@ mod eclipse_attack_tests {
|
|||
sleep(Duration::from_secs(2)).await;
|
||||
|
||||
let recovered_peers = network.nodes[0].network().peer_count().await;
|
||||
info!(recovered_peers = recovered_peers, "Node 0 peers after recovery");
|
||||
info!(
|
||||
recovered_peers = recovered_peers,
|
||||
"Node 0 peers after recovery"
|
||||
);
|
||||
|
||||
// Should have reconnected
|
||||
assert!(
|
||||
|
|
@ -1038,10 +1069,12 @@ mod dag_reorg_tests {
|
|||
|
||||
// Record divergent states
|
||||
let states_before: Vec<u64> = futures::future::join_all(
|
||||
network.nodes.iter().map(|n| async {
|
||||
n.consensus().current_blue_score().await
|
||||
})
|
||||
).await;
|
||||
network
|
||||
.nodes
|
||||
.iter()
|
||||
.map(|n| async { n.consensus().current_blue_score().await }),
|
||||
)
|
||||
.await;
|
||||
|
||||
info!(states_before = ?states_before, "States before reconnection");
|
||||
|
||||
|
|
@ -1052,10 +1085,12 @@ mod dag_reorg_tests {
|
|||
|
||||
// Get converged states
|
||||
let states_after: Vec<u64> = futures::future::join_all(
|
||||
network.nodes.iter().map(|n| async {
|
||||
n.consensus().current_blue_score().await
|
||||
})
|
||||
).await;
|
||||
network
|
||||
.nodes
|
||||
.iter()
|
||||
.map(|n| async { n.consensus().current_blue_score().await }),
|
||||
)
|
||||
.await;
|
||||
|
||||
info!(states_after = ?states_after, "States after reconnection");
|
||||
|
||||
|
|
@ -1064,7 +1099,9 @@ mod dag_reorg_tests {
|
|||
assert!(
|
||||
after >= before,
|
||||
"Node {} blue score regression: {} -> {}",
|
||||
i, before, after
|
||||
i,
|
||||
before,
|
||||
after
|
||||
);
|
||||
}
|
||||
|
||||
|
|
@ -1192,10 +1229,12 @@ mod parallel_blocks_tests {
|
|||
|
||||
// Collect blue scores from all nodes
|
||||
let blue_scores: Vec<u64> = futures::future::join_all(
|
||||
network.nodes.iter().map(|n| async {
|
||||
n.consensus().current_blue_score().await
|
||||
})
|
||||
).await;
|
||||
network
|
||||
.nodes
|
||||
.iter()
|
||||
.map(|n| async { n.consensus().current_blue_score().await }),
|
||||
)
|
||||
.await;
|
||||
|
||||
info!(blue_scores = ?blue_scores, "Blue scores across nodes");
|
||||
|
||||
|
|
@ -1206,7 +1245,8 @@ mod parallel_blocks_tests {
|
|||
assert!(
|
||||
max_score - min_score <= 2,
|
||||
"Blue scores should be consistent: {} - {} > 2",
|
||||
max_score, min_score
|
||||
max_score,
|
||||
min_score
|
||||
);
|
||||
|
||||
network.stop_all().await.unwrap();
|
||||
|
|
@ -1264,10 +1304,12 @@ mod parallel_blocks_tests {
|
|||
|
||||
// Get selected chains from all nodes
|
||||
let chains: Vec<Vec<[u8; 32]>> = futures::future::join_all(
|
||||
network.nodes.iter().map(|n| async {
|
||||
n.consensus().get_selected_chain(10).await
|
||||
})
|
||||
).await;
|
||||
network
|
||||
.nodes
|
||||
.iter()
|
||||
.map(|n| async { n.consensus().get_selected_chain(10).await }),
|
||||
)
|
||||
.await;
|
||||
|
||||
info!(
|
||||
chain_lengths = ?chains.iter().map(|c| c.len()).collect::<Vec<_>>(),
|
||||
|
|
@ -1276,7 +1318,8 @@ mod parallel_blocks_tests {
|
|||
|
||||
// All nodes should have the same selected chain (after sync)
|
||||
// Check that genesis (first block) matches
|
||||
let genesis_blocks: Vec<_> = chains.iter()
|
||||
let genesis_blocks: Vec<_> = chains
|
||||
.iter()
|
||||
.filter(|c| !c.is_empty())
|
||||
.map(|c| c[0])
|
||||
.collect();
|
||||
|
|
@ -1353,10 +1396,13 @@ mod bft_threshold_tests {
|
|||
|
||||
// Honest nodes (0, 1, 2) should maintain consensus
|
||||
let honest_scores: Vec<u64> = futures::future::join_all(
|
||||
network.nodes.iter().take(3).map(|n| async {
|
||||
n.consensus().current_blue_score().await
|
||||
})
|
||||
).await;
|
||||
network
|
||||
.nodes
|
||||
.iter()
|
||||
.take(3)
|
||||
.map(|n| async { n.consensus().current_blue_score().await }),
|
||||
)
|
||||
.await;
|
||||
|
||||
info!(honest_scores = ?honest_scores, "Honest node blue scores");
|
||||
|
||||
|
|
@ -1399,10 +1445,7 @@ mod bft_threshold_tests {
|
|||
|
||||
// Blue score should not decrease
|
||||
let final_blue = network.nodes[0].consensus().current_blue_score().await;
|
||||
assert!(
|
||||
final_blue >= initial_blue,
|
||||
"Blue score should not decrease"
|
||||
);
|
||||
assert!(final_blue >= initial_blue, "Blue score should not decrease");
|
||||
|
||||
// Stop remaining nodes
|
||||
for node in network.nodes.iter().take(3) {
|
||||
|
|
@ -1615,10 +1658,12 @@ mod integration_tests {
|
|||
|
||||
// Record initial state
|
||||
let initial_scores: Vec<u64> = futures::future::join_all(
|
||||
network.nodes.iter().map(|n| async {
|
||||
n.consensus().current_blue_score().await
|
||||
})
|
||||
).await;
|
||||
network
|
||||
.nodes
|
||||
.iter()
|
||||
.map(|n| async { n.consensus().current_blue_score().await }),
|
||||
)
|
||||
.await;
|
||||
info!(initial_scores = ?initial_scores, "Initial blue scores");
|
||||
|
||||
info!("Phase 2: Simulate 2 Byzantine nodes (partition)");
|
||||
|
|
@ -1640,18 +1685,24 @@ mod integration_tests {
|
|||
|
||||
info!("Phase 4: Verify convergence");
|
||||
let final_scores: Vec<u64> = futures::future::join_all(
|
||||
network.nodes.iter().map(|n| async {
|
||||
n.consensus().current_blue_score().await
|
||||
})
|
||||
).await;
|
||||
network
|
||||
.nodes
|
||||
.iter()
|
||||
.map(|n| async { n.consensus().current_blue_score().await }),
|
||||
)
|
||||
.await;
|
||||
info!(final_scores = ?final_scores, "Final blue scores");
|
||||
|
||||
// All nodes should have non-decreasing blue scores
|
||||
for (i, (&initial, &final_score)) in initial_scores.iter().zip(final_scores.iter()).enumerate() {
|
||||
for (i, (&initial, &final_score)) in
|
||||
initial_scores.iter().zip(final_scores.iter()).enumerate()
|
||||
{
|
||||
assert!(
|
||||
final_score >= initial,
|
||||
"Node {} score regression: {} -> {}",
|
||||
i, initial, final_score
|
||||
i,
|
||||
initial,
|
||||
final_score
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -17,8 +17,8 @@
|
|||
//! 3. Vault contract verifies proof and unlocks original tokens
|
||||
|
||||
use crate::{
|
||||
AssetId, Bridge, BridgeAddress, BridgeError, BridgeResult, BridgeTransfer, ChainType, TransferId, TransferManager, TransferStatus, VaultManager,
|
||||
ETH_MIN_CONFIRMATIONS,
|
||||
AssetId, Bridge, BridgeAddress, BridgeError, BridgeResult, BridgeTransfer, ChainType,
|
||||
TransferId, TransferManager, TransferStatus, VaultManager, ETH_MIN_CONFIRMATIONS,
|
||||
};
|
||||
use alloy_primitives::{Address, B256, U256};
|
||||
use alloy_sol_types::sol;
|
||||
|
|
@ -281,9 +281,9 @@ impl EthereumBridge {
|
|||
// Check for replay
|
||||
let event_hash = event.hash();
|
||||
if self.processed_events.read().contains_key(&event_hash) {
|
||||
return Err(BridgeError::TransferAlreadyExists(
|
||||
hex::encode(event_hash.as_slice()),
|
||||
));
|
||||
return Err(BridgeError::TransferAlreadyExists(hex::encode(
|
||||
event_hash.as_slice(),
|
||||
)));
|
||||
}
|
||||
|
||||
// Verify token is supported
|
||||
|
|
@ -393,18 +393,15 @@ impl EthereumBridge {
|
|||
// Collect matching transfer IDs first
|
||||
let matching_transfer_id = {
|
||||
let transfers = self.transfers.read();
|
||||
transfers
|
||||
.pending_transfers()
|
||||
.iter()
|
||||
.find_map(|transfer| {
|
||||
transfer.source_tx_hash.as_ref().and_then(|tx_hash| {
|
||||
if tx_hash.as_slice() == event_hash.as_slice() {
|
||||
Some(transfer.id.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
transfers.pending_transfers().iter().find_map(|transfer| {
|
||||
transfer.source_tx_hash.as_ref().and_then(|tx_hash| {
|
||||
if tx_hash.as_slice() == event_hash.as_slice() {
|
||||
Some(transfer.id.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
})
|
||||
};
|
||||
|
||||
// Now update the transfer if found
|
||||
|
|
@ -457,7 +454,9 @@ impl EthereumBridge {
|
|||
.map_err(|e| BridgeError::InvalidAddress(e.to_string()))?;
|
||||
|
||||
if bytes.len() != 20 {
|
||||
return Err(BridgeError::InvalidAddress("invalid address length".to_string()));
|
||||
return Err(BridgeError::InvalidAddress(
|
||||
"invalid address length".to_string(),
|
||||
));
|
||||
}
|
||||
Address::from_slice(&bytes)
|
||||
};
|
||||
|
|
@ -801,7 +800,10 @@ mod tests {
|
|||
wrapped.mint(1000);
|
||||
|
||||
let result = wrapped.burn(1500);
|
||||
assert!(matches!(result, Err(BridgeError::InsufficientBalance { .. })));
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(BridgeError::InsufficientBalance { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -896,7 +898,9 @@ mod tests {
|
|||
let current_time = 1700000000;
|
||||
|
||||
let event = create_lock_event(0);
|
||||
bridge.process_lock_event(event.clone(), current_time).unwrap();
|
||||
bridge
|
||||
.process_lock_event(event.clone(), current_time)
|
||||
.unwrap();
|
||||
|
||||
let result = bridge.process_lock_event(event, current_time + 100);
|
||||
assert!(matches!(result, Err(BridgeError::TransferAlreadyExists(_))));
|
||||
|
|
@ -949,8 +953,12 @@ mod tests {
|
|||
let event_hash = B256::from([0x11; 32]);
|
||||
let unauthorized_relayer = Address::from([0x99; 20]);
|
||||
|
||||
let result = bridge.submit_relayer_signature(event_hash, unauthorized_relayer, vec![0x00; 65]);
|
||||
assert!(matches!(result, Err(BridgeError::SignatureVerificationFailed(_))));
|
||||
let result =
|
||||
bridge.submit_relayer_signature(event_hash, unauthorized_relayer, vec![0x00; 65]);
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(BridgeError::SignatureVerificationFailed(_))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -964,7 +972,9 @@ mod tests {
|
|||
});
|
||||
|
||||
let event_hash = B256::from([0x11; 32]);
|
||||
let result = bridge.submit_relayer_signature(event_hash, relayer, vec![0x00; 65]).unwrap();
|
||||
let result = bridge
|
||||
.submit_relayer_signature(event_hash, relayer, vec![0x00; 65])
|
||||
.unwrap();
|
||||
|
||||
assert!(result);
|
||||
}
|
||||
|
|
@ -983,7 +993,9 @@ mod tests {
|
|||
.update_confirmations(&transfer_id, 12, current_time + 100)
|
||||
.unwrap();
|
||||
|
||||
bridge.mint_wrapped_tokens(&transfer_id, current_time + 200).unwrap();
|
||||
bridge
|
||||
.mint_wrapped_tokens(&transfer_id, current_time + 200)
|
||||
.unwrap();
|
||||
|
||||
let wrapped = bridge.get_wrapped_token(Address::ZERO).unwrap();
|
||||
assert_eq!(wrapped.total_supply, 1000);
|
||||
|
|
@ -1022,13 +1034,7 @@ mod tests {
|
|||
|
||||
let asset = AssetId::wrapped(&AssetId::eth());
|
||||
let transfer_id = bridge
|
||||
.initiate_burn(
|
||||
asset,
|
||||
1000,
|
||||
test_recipient(),
|
||||
test_sender(),
|
||||
current_time,
|
||||
)
|
||||
.initiate_burn(asset, 1000, test_recipient(), test_sender(), current_time)
|
||||
.unwrap();
|
||||
|
||||
let transfers = bridge.transfers.read();
|
||||
|
|
@ -1053,15 +1059,13 @@ mod tests {
|
|||
drop(wrapped_tokens);
|
||||
|
||||
let asset = AssetId::wrapped(&AssetId::eth());
|
||||
let result = bridge.initiate_burn(
|
||||
asset,
|
||||
1000,
|
||||
test_recipient(),
|
||||
test_sender(),
|
||||
current_time,
|
||||
);
|
||||
let result =
|
||||
bridge.initiate_burn(asset, 1000, test_recipient(), test_sender(), current_time);
|
||||
|
||||
assert!(matches!(result, Err(BridgeError::InsufficientBalance { .. })));
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(BridgeError::InsufficientBalance { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -1069,13 +1073,7 @@ mod tests {
|
|||
let bridge = EthereumBridge::new(EthereumBridgeConfig::default());
|
||||
|
||||
let asset = AssetId::wrapped(&AssetId::eth());
|
||||
let result = bridge.initiate_burn(
|
||||
asset,
|
||||
1000,
|
||||
test_recipient(),
|
||||
test_sender(),
|
||||
0,
|
||||
);
|
||||
let result = bridge.initiate_burn(asset, 1000, test_recipient(), test_sender(), 0);
|
||||
|
||||
assert!(matches!(result, Err(BridgeError::AssetNotSupported(_))));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -79,7 +79,9 @@ pub const ETH_MIN_CONFIRMATIONS: u64 = 12;
|
|||
pub const BTC_MIN_CONFIRMATIONS: u64 = 6;
|
||||
|
||||
/// Bridge chain identifier
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize)]
|
||||
#[derive(
|
||||
Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
|
||||
)]
|
||||
pub enum ChainType {
|
||||
/// Synor native chain
|
||||
Synor,
|
||||
|
|
@ -128,7 +130,9 @@ impl fmt::Display for ChainType {
|
|||
}
|
||||
|
||||
/// Asset identifier across chains
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize)]
|
||||
#[derive(
|
||||
Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
|
||||
)]
|
||||
pub struct AssetId {
|
||||
/// Chain where the asset originates
|
||||
pub chain: ChainType,
|
||||
|
|
@ -199,7 +203,9 @@ impl fmt::Display for AssetId {
|
|||
}
|
||||
|
||||
/// Bridge address (unified format for cross-chain addresses)
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize)]
|
||||
#[derive(
|
||||
Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
|
||||
)]
|
||||
pub struct BridgeAddress {
|
||||
/// Chain type
|
||||
pub chain: ChainType,
|
||||
|
|
|
|||
|
|
@ -14,7 +14,9 @@ use std::collections::HashMap;
|
|||
use std::fmt;
|
||||
|
||||
/// Unique transfer identifier
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize)]
|
||||
#[derive(
|
||||
Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
|
||||
)]
|
||||
pub struct TransferId(pub String);
|
||||
|
||||
impl TransferId {
|
||||
|
|
@ -48,7 +50,9 @@ impl fmt::Display for TransferId {
|
|||
}
|
||||
|
||||
/// Transfer direction
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, BorshSerialize, BorshDeserialize)]
|
||||
#[derive(
|
||||
Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
|
||||
)]
|
||||
pub enum TransferDirection {
|
||||
/// From external chain to Synor (Lock → Mint)
|
||||
Inbound,
|
||||
|
|
@ -66,7 +70,9 @@ impl fmt::Display for TransferDirection {
|
|||
}
|
||||
|
||||
/// Transfer status
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, BorshSerialize, BorshDeserialize)]
|
||||
#[derive(
|
||||
Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
|
||||
)]
|
||||
pub enum TransferStatus {
|
||||
/// Transfer initiated, awaiting lock confirmation
|
||||
Pending,
|
||||
|
|
@ -1030,7 +1036,10 @@ mod tests {
|
|||
transfer.fail("Proof verification failed", current_time + 50);
|
||||
|
||||
assert_eq!(transfer.status, TransferStatus::Failed);
|
||||
assert_eq!(transfer.error, Some("Proof verification failed".to_string()));
|
||||
assert_eq!(
|
||||
transfer.error,
|
||||
Some("Proof verification failed".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -1331,8 +1340,12 @@ mod tests {
|
|||
|
||||
assert_eq!(manager.pending_transfers().len(), 1);
|
||||
|
||||
manager.confirm_lock(&id, vec![0x11; 32], 100, current_time + 10).unwrap();
|
||||
manager.update_confirmations(&id, 12, current_time + 120).unwrap();
|
||||
manager
|
||||
.confirm_lock(&id, vec![0x11; 32], 100, current_time + 10)
|
||||
.unwrap();
|
||||
manager
|
||||
.update_confirmations(&id, 12, current_time + 120)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(manager.pending_transfers().len(), 0);
|
||||
}
|
||||
|
|
@ -1356,8 +1369,12 @@ mod tests {
|
|||
|
||||
assert_eq!(manager.ready_for_confirmation().len(), 0);
|
||||
|
||||
manager.confirm_lock(&id, vec![0x11; 32], 100, current_time + 10).unwrap();
|
||||
manager.update_confirmations(&id, 12, current_time + 120).unwrap();
|
||||
manager
|
||||
.confirm_lock(&id, vec![0x11; 32], 100, current_time + 10)
|
||||
.unwrap();
|
||||
manager
|
||||
.update_confirmations(&id, 12, current_time + 120)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(manager.ready_for_confirmation().len(), 1);
|
||||
}
|
||||
|
|
@ -1410,7 +1427,9 @@ mod tests {
|
|||
)
|
||||
.unwrap();
|
||||
|
||||
manager.fail_transfer(&id, "Verification failed", current_time + 50).unwrap();
|
||||
manager
|
||||
.fail_transfer(&id, "Verification failed", current_time + 50)
|
||||
.unwrap();
|
||||
|
||||
let transfer = manager.get(&id).unwrap();
|
||||
assert_eq!(transfer.status, TransferStatus::Failed);
|
||||
|
|
@ -1452,9 +1471,15 @@ mod tests {
|
|||
)
|
||||
.unwrap();
|
||||
|
||||
manager.confirm_lock(&id1, vec![0x11; 32], 100, current_time).unwrap();
|
||||
manager.update_confirmations(&id1, 12, current_time).unwrap();
|
||||
manager.confirm_mint(&id1, vec![0x22; 32], current_time).unwrap();
|
||||
manager
|
||||
.confirm_lock(&id1, vec![0x11; 32], 100, current_time)
|
||||
.unwrap();
|
||||
manager
|
||||
.update_confirmations(&id1, 12, current_time)
|
||||
.unwrap();
|
||||
manager
|
||||
.confirm_mint(&id1, vec![0x22; 32], current_time)
|
||||
.unwrap();
|
||||
|
||||
let stats = manager.stats();
|
||||
assert_eq!(stats.total_count, 2);
|
||||
|
|
@ -1484,19 +1509,27 @@ mod tests {
|
|||
let transfer = manager.get(&id).unwrap();
|
||||
assert_eq!(transfer.status, TransferStatus::Pending);
|
||||
|
||||
manager.confirm_lock(&id, vec![0x11; 32], 100, current_time + 60).unwrap();
|
||||
manager
|
||||
.confirm_lock(&id, vec![0x11; 32], 100, current_time + 60)
|
||||
.unwrap();
|
||||
let transfer = manager.get(&id).unwrap();
|
||||
assert_eq!(transfer.status, TransferStatus::Locked);
|
||||
|
||||
manager.update_confirmations(&id, 6, current_time + 120).unwrap();
|
||||
manager
|
||||
.update_confirmations(&id, 6, current_time + 120)
|
||||
.unwrap();
|
||||
let transfer = manager.get(&id).unwrap();
|
||||
assert_eq!(transfer.status, TransferStatus::Locked);
|
||||
|
||||
manager.update_confirmations(&id, 12, current_time + 180).unwrap();
|
||||
manager
|
||||
.update_confirmations(&id, 12, current_time + 180)
|
||||
.unwrap();
|
||||
let transfer = manager.get(&id).unwrap();
|
||||
assert_eq!(transfer.status, TransferStatus::Confirmed);
|
||||
|
||||
manager.confirm_mint(&id, vec![0x22; 32], current_time + 240).unwrap();
|
||||
manager
|
||||
.confirm_mint(&id, vec![0x22; 32], current_time + 240)
|
||||
.unwrap();
|
||||
let transfer = manager.get(&id).unwrap();
|
||||
assert_eq!(transfer.status, TransferStatus::Completed);
|
||||
}
|
||||
|
|
@ -1521,15 +1554,21 @@ mod tests {
|
|||
let transfer = manager.get(&id).unwrap();
|
||||
assert_eq!(transfer.status, TransferStatus::Pending);
|
||||
|
||||
manager.confirm_lock(&id, vec![0x11; 32], 100, current_time + 60).unwrap();
|
||||
manager
|
||||
.confirm_lock(&id, vec![0x11; 32], 100, current_time + 60)
|
||||
.unwrap();
|
||||
let transfer = manager.get(&id).unwrap();
|
||||
assert_eq!(transfer.status, TransferStatus::Locked);
|
||||
|
||||
manager.update_confirmations(&id, 6, current_time + 120).unwrap();
|
||||
manager
|
||||
.update_confirmations(&id, 6, current_time + 120)
|
||||
.unwrap();
|
||||
let transfer = manager.get(&id).unwrap();
|
||||
assert_eq!(transfer.status, TransferStatus::Confirmed);
|
||||
|
||||
manager.confirm_unlock(&id, vec![0x33; 32], current_time + 180).unwrap();
|
||||
manager
|
||||
.confirm_unlock(&id, vec![0x33; 32], current_time + 180)
|
||||
.unwrap();
|
||||
let transfer = manager.get(&id).unwrap();
|
||||
assert_eq!(transfer.status, TransferStatus::Completed);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,7 +12,9 @@ use std::collections::HashMap;
|
|||
use std::fmt;
|
||||
|
||||
/// Unique vault identifier
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize)]
|
||||
#[derive(
|
||||
Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
|
||||
)]
|
||||
pub struct VaultId(pub String);
|
||||
|
||||
impl VaultId {
|
||||
|
|
@ -198,13 +200,7 @@ impl Vault {
|
|||
return Err(BridgeError::TransferAlreadyExists(lock_id));
|
||||
}
|
||||
|
||||
let locked = LockedAsset::new(
|
||||
self.asset.clone(),
|
||||
amount,
|
||||
owner,
|
||||
recipient,
|
||||
current_time,
|
||||
);
|
||||
let locked = LockedAsset::new(self.asset.clone(), amount, owner, recipient, current_time);
|
||||
|
||||
self.locked_assets.insert(lock_id, locked);
|
||||
self.total_locked += amount;
|
||||
|
|
@ -283,7 +279,10 @@ impl Vault {
|
|||
}
|
||||
|
||||
/// Get expired locked assets
|
||||
pub fn expired_locked(&self, current_time: u64) -> impl Iterator<Item = (&String, &LockedAsset)> {
|
||||
pub fn expired_locked(
|
||||
&self,
|
||||
current_time: u64,
|
||||
) -> impl Iterator<Item = (&String, &LockedAsset)> {
|
||||
self.locked_assets
|
||||
.iter()
|
||||
.filter(move |(_, l)| !l.released && l.is_expired(current_time))
|
||||
|
|
@ -511,14 +510,8 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_locked_asset_expiry() {
|
||||
let locked = LockedAsset::new(
|
||||
AssetId::eth(),
|
||||
1000,
|
||||
test_owner(),
|
||||
test_recipient(),
|
||||
1000,
|
||||
)
|
||||
.with_expiry(2000);
|
||||
let locked = LockedAsset::new(AssetId::eth(), 1000, test_owner(), test_recipient(), 1000)
|
||||
.with_expiry(2000);
|
||||
|
||||
assert!(!locked.is_expired(1500));
|
||||
assert!(locked.is_expired(2000));
|
||||
|
|
@ -624,11 +617,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_lock_unlock() {
|
||||
let mut vault = Vault::new(
|
||||
VaultId::new("test"),
|
||||
ChainType::Ethereum,
|
||||
AssetId::eth(),
|
||||
);
|
||||
let mut vault = Vault::new(VaultId::new("test"), ChainType::Ethereum, AssetId::eth());
|
||||
|
||||
let current_time = 1700000000;
|
||||
|
||||
|
|
@ -654,20 +643,28 @@ mod tests {
|
|||
);
|
||||
|
||||
let current_time = 1700000000;
|
||||
vault.lock("lock-1", 1000, test_owner(), test_recipient(), current_time).unwrap();
|
||||
vault.lock("lock-2", 2000, test_owner(), test_recipient(), current_time).unwrap();
|
||||
vault.lock("lock-3", 500, test_owner_alt(), test_recipient(), current_time).unwrap();
|
||||
vault
|
||||
.lock("lock-1", 1000, test_owner(), test_recipient(), current_time)
|
||||
.unwrap();
|
||||
vault
|
||||
.lock("lock-2", 2000, test_owner(), test_recipient(), current_time)
|
||||
.unwrap();
|
||||
vault
|
||||
.lock(
|
||||
"lock-3",
|
||||
500,
|
||||
test_owner_alt(),
|
||||
test_recipient(),
|
||||
current_time,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(vault.total_locked, 3500);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_duplicate_lock() {
|
||||
let mut vault = Vault::new(
|
||||
VaultId::new("test"),
|
||||
ChainType::Ethereum,
|
||||
AssetId::eth(),
|
||||
);
|
||||
let mut vault = Vault::new(VaultId::new("test"), ChainType::Ethereum, AssetId::eth());
|
||||
|
||||
vault
|
||||
.lock("lock1", 1000, test_owner(), test_recipient(), 0)
|
||||
|
|
@ -697,20 +694,21 @@ mod tests {
|
|||
AssetId::eth(),
|
||||
);
|
||||
|
||||
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap();
|
||||
vault
|
||||
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
|
||||
.unwrap();
|
||||
vault.unlock("lock-1").unwrap();
|
||||
|
||||
let result = vault.unlock("lock-1");
|
||||
assert!(matches!(result, Err(BridgeError::TransferAlreadyCompleted(_))));
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(BridgeError::TransferAlreadyCompleted(_))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_vault_pause() {
|
||||
let mut vault = Vault::new(
|
||||
VaultId::new("test"),
|
||||
ChainType::Ethereum,
|
||||
AssetId::eth(),
|
||||
);
|
||||
let mut vault = Vault::new(VaultId::new("test"), ChainType::Ethereum, AssetId::eth());
|
||||
|
||||
vault.pause();
|
||||
|
||||
|
|
@ -730,7 +728,9 @@ mod tests {
|
|||
vault.resume();
|
||||
|
||||
assert_eq!(vault.state, VaultState::Active);
|
||||
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap();
|
||||
vault
|
||||
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -750,12 +750,8 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_daily_limit() {
|
||||
let mut vault = Vault::new(
|
||||
VaultId::new("test"),
|
||||
ChainType::Ethereum,
|
||||
AssetId::eth(),
|
||||
)
|
||||
.with_daily_limit(1000);
|
||||
let mut vault = Vault::new(VaultId::new("test"), ChainType::Ethereum, AssetId::eth())
|
||||
.with_daily_limit(1000);
|
||||
|
||||
let current_time = 86400 * 100;
|
||||
|
||||
|
|
@ -781,8 +777,24 @@ mod tests {
|
|||
);
|
||||
|
||||
let current_time = 0;
|
||||
vault.lock("lock-1", 1000000000, test_owner(), test_recipient(), current_time).unwrap();
|
||||
vault.lock("lock-2", 1000000000, test_owner(), test_recipient(), current_time).unwrap();
|
||||
vault
|
||||
.lock(
|
||||
"lock-1",
|
||||
1000000000,
|
||||
test_owner(),
|
||||
test_recipient(),
|
||||
current_time,
|
||||
)
|
||||
.unwrap();
|
||||
vault
|
||||
.lock(
|
||||
"lock-2",
|
||||
1000000000,
|
||||
test_owner(),
|
||||
test_recipient(),
|
||||
current_time,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(vault.total_locked, 2000000000);
|
||||
}
|
||||
|
|
@ -795,7 +807,9 @@ mod tests {
|
|||
AssetId::eth(),
|
||||
);
|
||||
|
||||
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap();
|
||||
vault
|
||||
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
|
||||
.unwrap();
|
||||
|
||||
assert!(vault.get_locked("lock-1").is_some());
|
||||
assert!(vault.get_locked("nonexistent").is_none());
|
||||
|
|
@ -809,8 +823,12 @@ mod tests {
|
|||
AssetId::eth(),
|
||||
);
|
||||
|
||||
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap();
|
||||
vault.lock("lock-2", 2000, test_owner(), test_recipient(), 0).unwrap();
|
||||
vault
|
||||
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
|
||||
.unwrap();
|
||||
vault
|
||||
.lock("lock-2", 2000, test_owner(), test_recipient(), 0)
|
||||
.unwrap();
|
||||
|
||||
let all: Vec<_> = vault.all_locked().collect();
|
||||
assert_eq!(all.len(), 2);
|
||||
|
|
@ -824,8 +842,12 @@ mod tests {
|
|||
AssetId::eth(),
|
||||
);
|
||||
|
||||
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap();
|
||||
vault.lock("lock-2", 2000, test_owner(), test_recipient(), 0).unwrap();
|
||||
vault
|
||||
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
|
||||
.unwrap();
|
||||
vault
|
||||
.lock("lock-2", 2000, test_owner(), test_recipient(), 0)
|
||||
.unwrap();
|
||||
vault.unlock("lock-1").unwrap();
|
||||
|
||||
let active: Vec<_> = vault.active_locked().collect();
|
||||
|
|
@ -858,7 +880,9 @@ mod tests {
|
|||
assert!(manager.find_vault(&ChainType::Ethereum, ð).is_some());
|
||||
|
||||
let vault = manager.get_or_create_vault(ChainType::Ethereum, eth.clone());
|
||||
vault.lock("lock1", 100, test_owner(), test_recipient(), 0).unwrap();
|
||||
vault
|
||||
.lock("lock1", 100, test_owner(), test_recipient(), 0)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(manager.total_locked(), 100);
|
||||
}
|
||||
|
|
@ -881,7 +905,9 @@ mod tests {
|
|||
|
||||
{
|
||||
let vault = manager.get_vault_mut(&vault_id).unwrap();
|
||||
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap();
|
||||
vault
|
||||
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let vault = manager.get_vault(&vault_id).unwrap();
|
||||
|
|
@ -902,7 +928,9 @@ mod tests {
|
|||
manager.create_vault(ChainType::Ethereum, eth.clone());
|
||||
|
||||
let vault = manager.find_vault_mut(&ChainType::Ethereum, ð).unwrap();
|
||||
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap();
|
||||
vault
|
||||
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(manager.total_locked(), 1000);
|
||||
}
|
||||
|
|
@ -913,7 +941,9 @@ mod tests {
|
|||
let eth = AssetId::eth();
|
||||
|
||||
let vault = manager.get_or_create_vault(ChainType::Ethereum, eth.clone());
|
||||
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap();
|
||||
vault
|
||||
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(manager.vault_ids().len(), 1);
|
||||
assert_eq!(manager.total_locked(), 1000);
|
||||
|
|
|
|||
|
|
@ -241,7 +241,10 @@ impl DeviceRegistry {
|
|||
}
|
||||
|
||||
/// Gets a processor by ID.
|
||||
pub fn get_processor(&self, processor_id: ProcessorId) -> Result<Arc<dyn Processor>, ComputeError> {
|
||||
pub fn get_processor(
|
||||
&self,
|
||||
processor_id: ProcessorId,
|
||||
) -> Result<Arc<dyn Processor>, ComputeError> {
|
||||
self.processors
|
||||
.read()
|
||||
.get(&processor_id)
|
||||
|
|
@ -266,7 +269,10 @@ impl DeviceRegistry {
|
|||
|
||||
/// Gets the next processor ID.
|
||||
pub fn next_processor_id(&self) -> ProcessorId {
|
||||
ProcessorId(self.next_processor_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst))
|
||||
ProcessorId(
|
||||
self.next_processor_id
|
||||
.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
|
||||
)
|
||||
}
|
||||
|
||||
/// Gets total number of devices.
|
||||
|
|
@ -309,7 +315,10 @@ impl DeviceRegistry {
|
|||
device.status = status;
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ComputeError::Internal(format!("Device not found: {}", device_id)))
|
||||
Err(ComputeError::Internal(format!(
|
||||
"Device not found: {}",
|
||||
device_id
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -323,7 +332,7 @@ impl Default for DeviceRegistry {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::processor::{CpuVariant, AvxSupport};
|
||||
use crate::processor::{AvxSupport, CpuVariant};
|
||||
|
||||
#[test]
|
||||
fn test_device_id() {
|
||||
|
|
|
|||
|
|
@ -67,6 +67,10 @@ pub use market::{
|
|||
ResourceType, SpotMarket, Trade,
|
||||
};
|
||||
pub use memory::{MemoryManager, TensorHandle, TransferPath, UnifiedMemory};
|
||||
pub use model::{
|
||||
ModelCategory, ModelFormat, ModelId, ModelInfo, ModelRegistry, ModelUploadRequest,
|
||||
ModelUploadResponse,
|
||||
};
|
||||
pub use processor::{
|
||||
ComputeThroughput, CpuVariant, GpuVariant, NpuVariant, Operation, OperationType, Processor,
|
||||
ProcessorCapabilities, ProcessorId, ProcessorType, TpuVersion,
|
||||
|
|
@ -78,10 +82,6 @@ pub use task::{
|
|||
ComputeTask, DecomposedWorkload, Task, TaskDecomposer, TaskId, TaskPriority, TaskResult,
|
||||
TaskStatus,
|
||||
};
|
||||
pub use model::{
|
||||
ModelCategory, ModelFormat, ModelId, ModelInfo, ModelRegistry, ModelUploadRequest,
|
||||
ModelUploadResponse,
|
||||
};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
|
@ -434,7 +434,10 @@ impl ComputeCluster {
|
|||
let jobs = self.jobs.read();
|
||||
|
||||
let total_nodes = nodes.len();
|
||||
let online_nodes = nodes.values().filter(|n| n.status == NodeStatus::Online).count();
|
||||
let online_nodes = nodes
|
||||
.values()
|
||||
.filter(|n| n.status == NodeStatus::Online)
|
||||
.count();
|
||||
|
||||
let total_gpus: usize = nodes
|
||||
.values()
|
||||
|
|
@ -515,16 +518,16 @@ pub enum GpuTier {
|
|||
impl Default for ComputePricing {
|
||||
fn default() -> Self {
|
||||
let mut gpu_hourly = HashMap::new();
|
||||
gpu_hourly.insert(GpuTier::Consumer, 100_000_000); // 0.10 SYNOR
|
||||
gpu_hourly.insert(GpuTier::Professional, 300_000_000); // 0.30 SYNOR
|
||||
gpu_hourly.insert(GpuTier::DataCenter, 2_000_000_000); // 2.00 SYNOR
|
||||
gpu_hourly.insert(GpuTier::Premium, 4_000_000_000); // 4.00 SYNOR
|
||||
gpu_hourly.insert(GpuTier::Consumer, 100_000_000); // 0.10 SYNOR
|
||||
gpu_hourly.insert(GpuTier::Professional, 300_000_000); // 0.30 SYNOR
|
||||
gpu_hourly.insert(GpuTier::DataCenter, 2_000_000_000); // 2.00 SYNOR
|
||||
gpu_hourly.insert(GpuTier::Premium, 4_000_000_000); // 4.00 SYNOR
|
||||
|
||||
Self {
|
||||
gpu_hourly,
|
||||
cpu_core_hour: 20_000_000, // 0.02 SYNOR
|
||||
memory_gb_hour: 5_000_000, // 0.005 SYNOR
|
||||
network_egress_gb: 50_000_000, // 0.05 SYNOR
|
||||
cpu_core_hour: 20_000_000, // 0.02 SYNOR
|
||||
memory_gb_hour: 5_000_000, // 0.005 SYNOR
|
||||
network_egress_gb: 50_000_000, // 0.05 SYNOR
|
||||
inference_per_million_tokens: 100_000_000, // 0.10 SYNOR
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -686,24 +686,24 @@ impl PricingEngine {
|
|||
pub fn greenest_region(&self) -> &str {
|
||||
self.regions
|
||||
.iter()
|
||||
.max_by(|a, b| {
|
||||
a.renewable_pct
|
||||
.partial_cmp(&b.renewable_pct)
|
||||
.unwrap()
|
||||
})
|
||||
.max_by(|a, b| a.renewable_pct.partial_cmp(&b.renewable_pct).unwrap())
|
||||
.map(|r| r.region.as_str())
|
||||
.unwrap_or("eu-north")
|
||||
}
|
||||
|
||||
/// Compares price to cloud providers.
|
||||
pub fn compare_to_cloud(&self, resource: &ResourceType, region: Option<&str>) -> CloudComparison {
|
||||
pub fn compare_to_cloud(
|
||||
&self,
|
||||
resource: &ResourceType,
|
||||
region: Option<&str>,
|
||||
) -> CloudComparison {
|
||||
let our_price = self.spot_price(resource, region);
|
||||
|
||||
// Approximate cloud provider prices (USD/hour for GPU)
|
||||
let (aws_price, gcp_price, azure_price) = match resource {
|
||||
ResourceType::GpuHours(GpuTier::DataCenter) => (3.06, 2.95, 3.10), // A100 equivalents
|
||||
ResourceType::GpuHours(GpuTier::Ultra) => (5.00, 4.50, 5.20), // H100 equivalents
|
||||
ResourceType::GpuHours(GpuTier::High) => (1.50, 1.40, 1.60), // T4/A10 equivalents
|
||||
ResourceType::GpuHours(GpuTier::Ultra) => (5.00, 4.50, 5.20), // H100 equivalents
|
||||
ResourceType::GpuHours(GpuTier::High) => (1.50, 1.40, 1.60), // T4/A10 equivalents
|
||||
ResourceType::CpuHours(CpuTier::Server) => (0.40, 0.35, 0.42),
|
||||
_ => (1.0, 1.0, 1.0),
|
||||
};
|
||||
|
|
@ -888,9 +888,18 @@ impl SpotMarket {
|
|||
);
|
||||
}
|
||||
|
||||
order_books.insert(ResourceType::TpuHours, OrderBook::new(ResourceType::TpuHours));
|
||||
order_books.insert(ResourceType::NpuHours, OrderBook::new(ResourceType::NpuHours));
|
||||
order_books.insert(ResourceType::LpuCredits, OrderBook::new(ResourceType::LpuCredits));
|
||||
order_books.insert(
|
||||
ResourceType::TpuHours,
|
||||
OrderBook::new(ResourceType::TpuHours),
|
||||
);
|
||||
order_books.insert(
|
||||
ResourceType::NpuHours,
|
||||
OrderBook::new(ResourceType::NpuHours),
|
||||
);
|
||||
order_books.insert(
|
||||
ResourceType::LpuCredits,
|
||||
OrderBook::new(ResourceType::LpuCredits),
|
||||
);
|
||||
|
||||
Self {
|
||||
order_books,
|
||||
|
|
@ -1074,12 +1083,21 @@ mod tests {
|
|||
fn test_pricing_engine() {
|
||||
let engine = PricingEngine::new();
|
||||
|
||||
let price = engine.spot_price(&ResourceType::GpuHours(GpuTier::DataCenter), Some("eu-north"));
|
||||
let price = engine.spot_price(
|
||||
&ResourceType::GpuHours(GpuTier::DataCenter),
|
||||
Some("eu-north"),
|
||||
);
|
||||
assert!(price > 0.0);
|
||||
|
||||
// eu-north should be cheaper (low electricity cost)
|
||||
let eu_price = engine.spot_price(&ResourceType::GpuHours(GpuTier::DataCenter), Some("eu-north"));
|
||||
let eu_west_price = engine.spot_price(&ResourceType::GpuHours(GpuTier::DataCenter), Some("eu-west"));
|
||||
let eu_price = engine.spot_price(
|
||||
&ResourceType::GpuHours(GpuTier::DataCenter),
|
||||
Some("eu-north"),
|
||||
);
|
||||
let eu_west_price = engine.spot_price(
|
||||
&ResourceType::GpuHours(GpuTier::DataCenter),
|
||||
Some("eu-west"),
|
||||
);
|
||||
|
||||
// eu-north has cheaper electricity
|
||||
assert!(eu_price < eu_west_price);
|
||||
|
|
@ -1089,7 +1107,8 @@ mod tests {
|
|||
fn test_cloud_comparison() {
|
||||
let engine = PricingEngine::new();
|
||||
|
||||
let comparison = engine.compare_to_cloud(&ResourceType::GpuHours(GpuTier::DataCenter), None);
|
||||
let comparison =
|
||||
engine.compare_to_cloud(&ResourceType::GpuHours(GpuTier::DataCenter), None);
|
||||
|
||||
// Should show significant savings
|
||||
assert!(comparison.aws_savings > 50.0);
|
||||
|
|
|
|||
|
|
@ -106,11 +106,11 @@ impl TransferPath {
|
|||
/// Returns approximate bandwidth in GB/s.
|
||||
pub fn bandwidth_gbps(&self) -> f64 {
|
||||
match self {
|
||||
TransferPath::NvLink => 900.0, // NVLink 4.0
|
||||
TransferPath::NvLink => 900.0, // NVLink 4.0
|
||||
TransferPath::PciePeerToPeer => 64.0, // PCIe 5.0 x16
|
||||
TransferPath::CpuMediated => 50.0, // DDR5
|
||||
TransferPath::CpuMediated => 50.0, // DDR5
|
||||
TransferPath::UnifiedMemory => 400.0, // Apple unified
|
||||
TransferPath::Network => 10.0, // 100Gbps network
|
||||
TransferPath::Network => 10.0, // 100Gbps network
|
||||
TransferPath::SameMemory => f64::INFINITY,
|
||||
}
|
||||
}
|
||||
|
|
@ -154,7 +154,11 @@ impl MemoryManager {
|
|||
}
|
||||
|
||||
/// Allocates a tensor.
|
||||
pub fn allocate(&self, shape: Vec<usize>, dtype: DataType) -> Result<TensorHandle, ComputeError> {
|
||||
pub fn allocate(
|
||||
&self,
|
||||
shape: Vec<usize>,
|
||||
dtype: DataType,
|
||||
) -> Result<TensorHandle, ComputeError> {
|
||||
let handle = TensorHandle::new(shape, dtype);
|
||||
self.tensors.write().insert(handle.id, handle.clone());
|
||||
Ok(handle)
|
||||
|
|
@ -223,9 +227,13 @@ impl MemoryManager {
|
|||
}
|
||||
|
||||
// Check for NVLink between NVIDIA GPUs
|
||||
if matches!(from, ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda { .. }))
|
||||
&& matches!(to, ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda { .. }))
|
||||
{
|
||||
if matches!(
|
||||
from,
|
||||
ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda { .. })
|
||||
) && matches!(
|
||||
to,
|
||||
ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda { .. })
|
||||
) {
|
||||
return TransferPath::NvLink;
|
||||
}
|
||||
|
||||
|
|
@ -244,10 +252,22 @@ impl MemoryManager {
|
|||
|
||||
match (a, b) {
|
||||
// Apple Silicon unified memory
|
||||
(ProcessorType::Cpu(CpuVariant::Arm64 { .. }), ProcessorType::Gpu(GpuVariant::AppleMetal))
|
||||
| (ProcessorType::Gpu(GpuVariant::AppleMetal), ProcessorType::Cpu(CpuVariant::Arm64 { .. }))
|
||||
| (ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }), ProcessorType::Cpu(CpuVariant::Arm64 { .. }))
|
||||
| (ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }), ProcessorType::Gpu(GpuVariant::AppleMetal)) => true,
|
||||
(
|
||||
ProcessorType::Cpu(CpuVariant::Arm64 { .. }),
|
||||
ProcessorType::Gpu(GpuVariant::AppleMetal),
|
||||
)
|
||||
| (
|
||||
ProcessorType::Gpu(GpuVariant::AppleMetal),
|
||||
ProcessorType::Cpu(CpuVariant::Arm64 { .. }),
|
||||
)
|
||||
| (
|
||||
ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }),
|
||||
ProcessorType::Cpu(CpuVariant::Arm64 { .. }),
|
||||
)
|
||||
| (
|
||||
ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }),
|
||||
ProcessorType::Gpu(GpuVariant::AppleMetal),
|
||||
) => true,
|
||||
// Same type
|
||||
_ if a == b => true,
|
||||
_ => false,
|
||||
|
|
@ -325,7 +345,9 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_transfer_path_bandwidth() {
|
||||
assert!(TransferPath::NvLink.bandwidth_gbps() > TransferPath::PciePeerToPeer.bandwidth_gbps());
|
||||
assert!(
|
||||
TransferPath::NvLink.bandwidth_gbps() > TransferPath::PciePeerToPeer.bandwidth_gbps()
|
||||
);
|
||||
assert!(TransferPath::SameMemory.bandwidth_gbps().is_infinite());
|
||||
}
|
||||
|
||||
|
|
@ -333,7 +355,9 @@ mod tests {
|
|||
fn test_memory_manager() {
|
||||
let manager = MemoryManager::new();
|
||||
|
||||
let handle = manager.allocate(vec![1024, 1024], DataType::Float32).unwrap();
|
||||
let handle = manager
|
||||
.allocate(vec![1024, 1024], DataType::Float32)
|
||||
.unwrap();
|
||||
assert_eq!(manager.tensor_count(), 1);
|
||||
|
||||
manager.free(handle.id).unwrap();
|
||||
|
|
@ -347,22 +371,26 @@ mod tests {
|
|||
let handle = manager.allocate(vec![1024], DataType::Float32).unwrap();
|
||||
|
||||
// First ensure should allocate
|
||||
let path = manager.ensure_on(
|
||||
handle.id,
|
||||
ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda {
|
||||
compute_capability: (8, 0),
|
||||
}),
|
||||
).unwrap();
|
||||
let path = manager
|
||||
.ensure_on(
|
||||
handle.id,
|
||||
ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda {
|
||||
compute_capability: (8, 0),
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(path, TransferPath::SameMemory);
|
||||
|
||||
// Second ensure to same location should be same memory
|
||||
let path = manager.ensure_on(
|
||||
handle.id,
|
||||
ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda {
|
||||
compute_capability: (8, 0),
|
||||
}),
|
||||
).unwrap();
|
||||
let path = manager
|
||||
.ensure_on(
|
||||
handle.id,
|
||||
ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda {
|
||||
compute_capability: (8, 0),
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(path, TransferPath::SameMemory);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -140,13 +140,7 @@ pub struct ModelInfo {
|
|||
|
||||
impl ModelInfo {
|
||||
/// Creates a new LLM model info.
|
||||
pub fn llm(
|
||||
alias: &str,
|
||||
name: &str,
|
||||
cid: &str,
|
||||
parameters: u64,
|
||||
context_length: u32,
|
||||
) -> Self {
|
||||
pub fn llm(alias: &str, name: &str, cid: &str, parameters: u64, context_length: u32) -> Self {
|
||||
Self {
|
||||
id: ModelId::from_alias(alias),
|
||||
name: name.to_string(),
|
||||
|
|
@ -156,7 +150,12 @@ impl ModelInfo {
|
|||
format: ModelFormat::SafeTensors,
|
||||
size_bytes: parameters * 2, // ~2 bytes per param in fp16
|
||||
parameters,
|
||||
supported_precisions: vec![Precision::Fp16, Precision::Bf16, Precision::Int8, Precision::Int4],
|
||||
supported_precisions: vec![
|
||||
Precision::Fp16,
|
||||
Precision::Bf16,
|
||||
Precision::Int8,
|
||||
Precision::Int4,
|
||||
],
|
||||
recommended_processor: ProcessorType::Lpu,
|
||||
context_length: Some(context_length),
|
||||
input_schema: None,
|
||||
|
|
@ -238,33 +237,123 @@ impl ModelRegistry {
|
|||
let default_models = vec![
|
||||
// ===== LLMs =====
|
||||
// Llama 3 family
|
||||
ModelInfo::llm("llama-3-8b", "Llama 3 8B", "QmLlama3_8B_placeholder", 8_000_000_000, 8192),
|
||||
ModelInfo::llm("llama-3-70b", "Llama 3 70B", "QmLlama3_70B_placeholder", 70_000_000_000, 8192),
|
||||
ModelInfo::llm("llama-3.1-8b", "Llama 3.1 8B", "QmLlama31_8B_placeholder", 8_000_000_000, 128000),
|
||||
ModelInfo::llm("llama-3.1-70b", "Llama 3.1 70B", "QmLlama31_70B_placeholder", 70_000_000_000, 128000),
|
||||
ModelInfo::llm("llama-3.1-405b", "Llama 3.1 405B", "QmLlama31_405B_placeholder", 405_000_000_000, 128000),
|
||||
|
||||
ModelInfo::llm(
|
||||
"llama-3-8b",
|
||||
"Llama 3 8B",
|
||||
"QmLlama3_8B_placeholder",
|
||||
8_000_000_000,
|
||||
8192,
|
||||
),
|
||||
ModelInfo::llm(
|
||||
"llama-3-70b",
|
||||
"Llama 3 70B",
|
||||
"QmLlama3_70B_placeholder",
|
||||
70_000_000_000,
|
||||
8192,
|
||||
),
|
||||
ModelInfo::llm(
|
||||
"llama-3.1-8b",
|
||||
"Llama 3.1 8B",
|
||||
"QmLlama31_8B_placeholder",
|
||||
8_000_000_000,
|
||||
128000,
|
||||
),
|
||||
ModelInfo::llm(
|
||||
"llama-3.1-70b",
|
||||
"Llama 3.1 70B",
|
||||
"QmLlama31_70B_placeholder",
|
||||
70_000_000_000,
|
||||
128000,
|
||||
),
|
||||
ModelInfo::llm(
|
||||
"llama-3.1-405b",
|
||||
"Llama 3.1 405B",
|
||||
"QmLlama31_405B_placeholder",
|
||||
405_000_000_000,
|
||||
128000,
|
||||
),
|
||||
// Mistral family
|
||||
ModelInfo::llm("mistral-7b", "Mistral 7B", "QmMistral7B_placeholder", 7_000_000_000, 32768),
|
||||
ModelInfo::llm("mixtral-8x7b", "Mixtral 8x7B", "QmMixtral8x7B_placeholder", 46_000_000_000, 32768),
|
||||
ModelInfo::llm("mixtral-8x22b", "Mixtral 8x22B", "QmMixtral8x22B_placeholder", 176_000_000_000, 65536),
|
||||
|
||||
ModelInfo::llm(
|
||||
"mistral-7b",
|
||||
"Mistral 7B",
|
||||
"QmMistral7B_placeholder",
|
||||
7_000_000_000,
|
||||
32768,
|
||||
),
|
||||
ModelInfo::llm(
|
||||
"mixtral-8x7b",
|
||||
"Mixtral 8x7B",
|
||||
"QmMixtral8x7B_placeholder",
|
||||
46_000_000_000,
|
||||
32768,
|
||||
),
|
||||
ModelInfo::llm(
|
||||
"mixtral-8x22b",
|
||||
"Mixtral 8x22B",
|
||||
"QmMixtral8x22B_placeholder",
|
||||
176_000_000_000,
|
||||
65536,
|
||||
),
|
||||
// Qwen family
|
||||
ModelInfo::llm("qwen-2.5-7b", "Qwen 2.5 7B", "QmQwen25_7B_placeholder", 7_000_000_000, 128000),
|
||||
ModelInfo::llm("qwen-2.5-72b", "Qwen 2.5 72B", "QmQwen25_72B_placeholder", 72_000_000_000, 128000),
|
||||
|
||||
ModelInfo::llm(
|
||||
"qwen-2.5-7b",
|
||||
"Qwen 2.5 7B",
|
||||
"QmQwen25_7B_placeholder",
|
||||
7_000_000_000,
|
||||
128000,
|
||||
),
|
||||
ModelInfo::llm(
|
||||
"qwen-2.5-72b",
|
||||
"Qwen 2.5 72B",
|
||||
"QmQwen25_72B_placeholder",
|
||||
72_000_000_000,
|
||||
128000,
|
||||
),
|
||||
// DeepSeek family
|
||||
ModelInfo::llm("deepseek-v2", "DeepSeek V2", "QmDeepSeekV2_placeholder", 236_000_000_000, 128000),
|
||||
ModelInfo::llm("deepseek-coder-33b", "DeepSeek Coder 33B", "QmDeepSeekCoder33B_placeholder", 33_000_000_000, 16384),
|
||||
|
||||
ModelInfo::llm(
|
||||
"deepseek-v2",
|
||||
"DeepSeek V2",
|
||||
"QmDeepSeekV2_placeholder",
|
||||
236_000_000_000,
|
||||
128000,
|
||||
),
|
||||
ModelInfo::llm(
|
||||
"deepseek-coder-33b",
|
||||
"DeepSeek Coder 33B",
|
||||
"QmDeepSeekCoder33B_placeholder",
|
||||
33_000_000_000,
|
||||
16384,
|
||||
),
|
||||
// Phi family (small/efficient)
|
||||
ModelInfo::llm("phi-3-mini", "Phi 3 Mini", "QmPhi3Mini_placeholder", 3_800_000_000, 128000),
|
||||
ModelInfo::llm("phi-3-medium", "Phi 3 Medium", "QmPhi3Medium_placeholder", 14_000_000_000, 128000),
|
||||
|
||||
ModelInfo::llm(
|
||||
"phi-3-mini",
|
||||
"Phi 3 Mini",
|
||||
"QmPhi3Mini_placeholder",
|
||||
3_800_000_000,
|
||||
128000,
|
||||
),
|
||||
ModelInfo::llm(
|
||||
"phi-3-medium",
|
||||
"Phi 3 Medium",
|
||||
"QmPhi3Medium_placeholder",
|
||||
14_000_000_000,
|
||||
128000,
|
||||
),
|
||||
// Code models
|
||||
ModelInfo::llm("codellama-34b", "Code Llama 34B", "QmCodeLlama34B_placeholder", 34_000_000_000, 16384),
|
||||
ModelInfo::llm("starcoder2-15b", "StarCoder2 15B", "QmStarCoder2_15B_placeholder", 15_000_000_000, 16384),
|
||||
|
||||
ModelInfo::llm(
|
||||
"codellama-34b",
|
||||
"Code Llama 34B",
|
||||
"QmCodeLlama34B_placeholder",
|
||||
34_000_000_000,
|
||||
16384,
|
||||
),
|
||||
ModelInfo::llm(
|
||||
"starcoder2-15b",
|
||||
"StarCoder2 15B",
|
||||
"QmStarCoder2_15B_placeholder",
|
||||
15_000_000_000,
|
||||
16384,
|
||||
),
|
||||
// ===== Embedding Models =====
|
||||
ModelInfo {
|
||||
id: ModelId::from_alias("bge-large"),
|
||||
|
|
@ -306,7 +395,6 @@ impl ModelRegistry {
|
|||
is_public: true,
|
||||
owner: None,
|
||||
},
|
||||
|
||||
// ===== Vision Models =====
|
||||
ModelInfo {
|
||||
id: ModelId::from_alias("stable-diffusion-xl"),
|
||||
|
|
@ -348,7 +436,6 @@ impl ModelRegistry {
|
|||
is_public: true,
|
||||
owner: None,
|
||||
},
|
||||
|
||||
// ===== Speech Models =====
|
||||
ModelInfo {
|
||||
id: ModelId::from_alias("whisper-large-v3"),
|
||||
|
|
@ -370,7 +457,6 @@ impl ModelRegistry {
|
|||
is_public: true,
|
||||
owner: None,
|
||||
},
|
||||
|
||||
// ===== Multi-Modal Models =====
|
||||
ModelInfo {
|
||||
id: ModelId::from_alias("llava-1.5-13b"),
|
||||
|
|
@ -555,7 +641,9 @@ mod tests {
|
|||
let registry = ModelRegistry::new();
|
||||
let results = registry.search("llama");
|
||||
assert!(!results.is_empty());
|
||||
assert!(results.iter().all(|m| m.name.to_lowercase().contains("llama")));
|
||||
assert!(results
|
||||
.iter()
|
||||
.all(|m| m.name.to_lowercase().contains("llama")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -305,7 +305,7 @@ impl ProcessorCapabilities {
|
|||
},
|
||||
memory: MemorySpecs {
|
||||
capacity_bytes: 230 * 1024 * 1024 * 1024, // 230 GB SRAM!
|
||||
bandwidth_gbps: 80_000, // 80 TB/s internal
|
||||
bandwidth_gbps: 80_000, // 80 TB/s internal
|
||||
type_: MemoryType::Sram,
|
||||
},
|
||||
operations: Self::lpu_operations(),
|
||||
|
|
@ -349,8 +349,8 @@ impl ProcessorCapabilities {
|
|||
/// Creates Apple Neural Engine capabilities.
|
||||
pub fn apple_neural_engine(cores: u32) -> Self {
|
||||
let int8_tops = match cores {
|
||||
16 => 18.0, // M3
|
||||
32 => 35.0, // M3 Max
|
||||
16 => 18.0, // M3
|
||||
32 => 35.0, // M3 Max
|
||||
_ => cores as f64 * 1.1,
|
||||
};
|
||||
|
||||
|
|
@ -542,6 +542,8 @@ mod tests {
|
|||
fn test_lpu_capabilities() {
|
||||
let caps = ProcessorCapabilities::lpu();
|
||||
assert!(caps.memory.bandwidth_gbps > 10000); // Very high internal bandwidth
|
||||
assert!(caps.optimal_for.contains(&WorkloadCharacteristic::Sequential));
|
||||
assert!(caps
|
||||
.optimal_for
|
||||
.contains(&WorkloadCharacteristic::Sequential));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -253,10 +253,22 @@ impl Processor for GenericProcessor {
|
|||
fn shares_memory_with(&self, other: &ProcessorType) -> bool {
|
||||
match (&self.processor_type, other) {
|
||||
// Apple Silicon has unified memory
|
||||
(ProcessorType::Cpu(CpuVariant::Arm64 { .. }), ProcessorType::Gpu(GpuVariant::AppleMetal))
|
||||
| (ProcessorType::Gpu(GpuVariant::AppleMetal), ProcessorType::Cpu(CpuVariant::Arm64 { .. }))
|
||||
| (ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }), ProcessorType::Cpu(CpuVariant::Arm64 { .. }))
|
||||
| (ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }), ProcessorType::Gpu(GpuVariant::AppleMetal)) => true,
|
||||
(
|
||||
ProcessorType::Cpu(CpuVariant::Arm64 { .. }),
|
||||
ProcessorType::Gpu(GpuVariant::AppleMetal),
|
||||
)
|
||||
| (
|
||||
ProcessorType::Gpu(GpuVariant::AppleMetal),
|
||||
ProcessorType::Cpu(CpuVariant::Arm64 { .. }),
|
||||
)
|
||||
| (
|
||||
ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }),
|
||||
ProcessorType::Cpu(CpuVariant::Arm64 { .. }),
|
||||
)
|
||||
| (
|
||||
ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }),
|
||||
ProcessorType::Gpu(GpuVariant::AppleMetal),
|
||||
) => true,
|
||||
// Same type always shares
|
||||
(a, b) if a == b => true,
|
||||
_ => false,
|
||||
|
|
|
|||
|
|
@ -191,10 +191,7 @@ pub enum Operation {
|
|||
},
|
||||
|
||||
/// Data loading from storage.
|
||||
DataLoad {
|
||||
bytes: usize,
|
||||
async_: bool,
|
||||
},
|
||||
DataLoad { bytes: usize, async_: bool },
|
||||
|
||||
/// Data preprocessing.
|
||||
DataPreprocess {
|
||||
|
|
@ -209,16 +206,10 @@ pub enum Operation {
|
|||
},
|
||||
|
||||
/// Detokenization.
|
||||
Detokenization {
|
||||
tokens: usize,
|
||||
vocab_size: usize,
|
||||
},
|
||||
Detokenization { tokens: usize, vocab_size: usize },
|
||||
|
||||
/// Checkpoint save.
|
||||
Checkpoint {
|
||||
bytes: usize,
|
||||
async_: bool,
|
||||
},
|
||||
Checkpoint { bytes: usize, async_: bool },
|
||||
|
||||
/// All-reduce across devices.
|
||||
AllReduce {
|
||||
|
|
@ -228,9 +219,7 @@ pub enum Operation {
|
|||
},
|
||||
|
||||
/// Backward pass for a layer.
|
||||
Backward {
|
||||
forward_op: Box<Operation>,
|
||||
},
|
||||
Backward { forward_op: Box<Operation> },
|
||||
|
||||
/// Optimizer step.
|
||||
OptimizerStep {
|
||||
|
|
@ -240,16 +229,10 @@ pub enum Operation {
|
|||
},
|
||||
|
||||
/// Transpose.
|
||||
Transpose {
|
||||
shape: Vec<usize>,
|
||||
axes: Vec<usize>,
|
||||
},
|
||||
Transpose { shape: Vec<usize>, axes: Vec<usize> },
|
||||
|
||||
/// Reshape.
|
||||
Reshape {
|
||||
from: Vec<usize>,
|
||||
to: Vec<usize>,
|
||||
},
|
||||
Reshape { from: Vec<usize>, to: Vec<usize> },
|
||||
|
||||
/// Concatenate tensors.
|
||||
Concat {
|
||||
|
|
@ -378,9 +361,7 @@ impl Operation {
|
|||
| Operation::SiLU { elements } => *elements as f64,
|
||||
|
||||
// Softmax: ~5 ops per element (exp, sum, div)
|
||||
Operation::Softmax {
|
||||
batch, seq_len, ..
|
||||
} => 5.0 * (*batch as f64) * (*seq_len as f64),
|
||||
Operation::Softmax { batch, seq_len, .. } => 5.0 * (*batch as f64) * (*seq_len as f64),
|
||||
|
||||
// Embedding: just lookup, minimal FLOPS
|
||||
Operation::Embedding {
|
||||
|
|
|
|||
|
|
@ -39,8 +39,7 @@ impl ProcessorProfiles {
|
|||
bandwidth_gbps: 460,
|
||||
type_: MemoryType::Ddr5,
|
||||
},
|
||||
operations: ProcessorCapabilities::cpu(96, 2.4, false)
|
||||
.operations,
|
||||
operations: ProcessorCapabilities::cpu(96, 2.4, false).operations,
|
||||
power: PowerCharacteristics {
|
||||
tdp_watts: 360,
|
||||
efficiency: 0.85,
|
||||
|
|
@ -70,8 +69,7 @@ impl ProcessorProfiles {
|
|||
bandwidth_gbps: 307,
|
||||
type_: MemoryType::Ddr5,
|
||||
},
|
||||
operations: ProcessorCapabilities::cpu(56, 2.9, true)
|
||||
.operations,
|
||||
operations: ProcessorCapabilities::cpu(56, 2.9, true).operations,
|
||||
power: PowerCharacteristics {
|
||||
tdp_watts: 350,
|
||||
efficiency: 0.80,
|
||||
|
|
@ -101,8 +99,7 @@ impl ProcessorProfiles {
|
|||
bandwidth_gbps: 400,
|
||||
type_: MemoryType::Unified,
|
||||
},
|
||||
operations: ProcessorCapabilities::cpu(16, 4.0, false)
|
||||
.operations,
|
||||
operations: ProcessorCapabilities::cpu(16, 4.0, false).operations,
|
||||
power: PowerCharacteristics {
|
||||
tdp_watts: 40,
|
||||
efficiency: 0.95,
|
||||
|
|
@ -141,8 +138,7 @@ impl ProcessorProfiles {
|
|||
bandwidth_gbps: 3350,
|
||||
type_: MemoryType::Hbm3,
|
||||
},
|
||||
operations: ProcessorCapabilities::nvidia_gpu(16896, 528, 80, 3350, (9, 0))
|
||||
.operations,
|
||||
operations: ProcessorCapabilities::nvidia_gpu(16896, 528, 80, 3350, (9, 0)).operations,
|
||||
power: PowerCharacteristics {
|
||||
tdp_watts: 700,
|
||||
efficiency: 0.90,
|
||||
|
|
@ -173,8 +169,7 @@ impl ProcessorProfiles {
|
|||
bandwidth_gbps: 2039,
|
||||
type_: MemoryType::Hbm2e,
|
||||
},
|
||||
operations: ProcessorCapabilities::nvidia_gpu(6912, 432, 80, 2039, (8, 0))
|
||||
.operations,
|
||||
operations: ProcessorCapabilities::nvidia_gpu(6912, 432, 80, 2039, (8, 0)).operations,
|
||||
power: PowerCharacteristics {
|
||||
tdp_watts: 400,
|
||||
efficiency: 0.88,
|
||||
|
|
@ -205,8 +200,7 @@ impl ProcessorProfiles {
|
|||
bandwidth_gbps: 1008,
|
||||
type_: MemoryType::Gddr6,
|
||||
},
|
||||
operations: ProcessorCapabilities::nvidia_gpu(16384, 512, 24, 1008, (8, 9))
|
||||
.operations,
|
||||
operations: ProcessorCapabilities::nvidia_gpu(16384, 512, 24, 1008, (8, 9)).operations,
|
||||
power: PowerCharacteristics {
|
||||
tdp_watts: 450,
|
||||
efficiency: 0.85,
|
||||
|
|
@ -236,8 +230,7 @@ impl ProcessorProfiles {
|
|||
bandwidth_gbps: 936,
|
||||
type_: MemoryType::Gddr6,
|
||||
},
|
||||
operations: ProcessorCapabilities::nvidia_gpu(10496, 328, 24, 936, (8, 6))
|
||||
.operations,
|
||||
operations: ProcessorCapabilities::nvidia_gpu(10496, 328, 24, 936, (8, 6)).operations,
|
||||
power: PowerCharacteristics {
|
||||
tdp_watts: 350,
|
||||
efficiency: 0.82,
|
||||
|
|
@ -272,8 +265,8 @@ impl ProcessorProfiles {
|
|||
type_: MemoryType::Hbm3,
|
||||
},
|
||||
operations: {
|
||||
let mut ops = ProcessorCapabilities::nvidia_gpu(16384, 512, 80, 5300, (9, 0))
|
||||
.operations;
|
||||
let mut ops =
|
||||
ProcessorCapabilities::nvidia_gpu(16384, 512, 80, 5300, (9, 0)).operations;
|
||||
ops.remove(&OperationType::FlashAttention); // Different implementation
|
||||
ops
|
||||
},
|
||||
|
|
@ -308,8 +301,8 @@ impl ProcessorProfiles {
|
|||
type_: MemoryType::Gddr6,
|
||||
},
|
||||
operations: {
|
||||
let mut ops = ProcessorCapabilities::nvidia_gpu(6144, 0, 24, 960, (8, 0))
|
||||
.operations;
|
||||
let mut ops =
|
||||
ProcessorCapabilities::nvidia_gpu(6144, 0, 24, 960, (8, 0)).operations;
|
||||
ops.remove(&OperationType::FlashAttention);
|
||||
ops
|
||||
},
|
||||
|
|
@ -318,9 +311,7 @@ impl ProcessorProfiles {
|
|||
efficiency: 0.80,
|
||||
power_tier: PowerTier::High,
|
||||
},
|
||||
optimal_for: vec![
|
||||
WorkloadCharacteristic::HighlyParallel,
|
||||
],
|
||||
optimal_for: vec![WorkloadCharacteristic::HighlyParallel],
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -429,8 +420,7 @@ impl ProcessorProfiles {
|
|||
bandwidth_gbps: 200,
|
||||
type_: MemoryType::Unified,
|
||||
},
|
||||
operations: ProcessorCapabilities::apple_neural_engine(16)
|
||||
.operations,
|
||||
operations: ProcessorCapabilities::apple_neural_engine(16).operations,
|
||||
power: PowerCharacteristics {
|
||||
tdp_watts: 8,
|
||||
efficiency: 0.98,
|
||||
|
|
@ -465,8 +455,7 @@ impl ProcessorProfiles {
|
|||
bandwidth_gbps: 77,
|
||||
type_: MemoryType::Lpddr,
|
||||
},
|
||||
operations: ProcessorCapabilities::apple_neural_engine(16)
|
||||
.operations,
|
||||
operations: ProcessorCapabilities::apple_neural_engine(16).operations,
|
||||
power: PowerCharacteristics {
|
||||
tdp_watts: 10,
|
||||
efficiency: 0.95,
|
||||
|
|
|
|||
|
|
@ -24,10 +24,7 @@ pub enum ProcessorType {
|
|||
/// WebAssembly runtime.
|
||||
Wasm,
|
||||
/// Custom/Unknown accelerator.
|
||||
Custom {
|
||||
vendor: String,
|
||||
model: String,
|
||||
},
|
||||
Custom { vendor: String, model: String },
|
||||
}
|
||||
|
||||
impl Default for ProcessorType {
|
||||
|
|
|
|||
|
|
@ -6,10 +6,10 @@
|
|||
//! - Latency-aware scheduling
|
||||
//! - Real-time utilization metrics
|
||||
|
||||
use super::TaskAssignment;
|
||||
use crate::device::DeviceRegistry;
|
||||
use crate::processor::{Operation, OperationType, ProcessorId, ProcessorType};
|
||||
use crate::task::{Task, TaskId, TaskPriority};
|
||||
use super::TaskAssignment;
|
||||
use parking_lot::RwLock;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
|
@ -127,8 +127,12 @@ impl LoadBalancer {
|
|||
/// Register a processor with its type.
|
||||
pub fn register_processor(&self, processor_id: ProcessorId, processor_type: ProcessorType) {
|
||||
self.loads.write().insert(processor_id, AtomicU64::new(0));
|
||||
self.metrics.write().insert(processor_id, ProcessorMetrics::default());
|
||||
self.processor_types.write().insert(processor_id, processor_type);
|
||||
self.metrics
|
||||
.write()
|
||||
.insert(processor_id, ProcessorMetrics::default());
|
||||
self.processor_types
|
||||
.write()
|
||||
.insert(processor_id, processor_type);
|
||||
}
|
||||
|
||||
/// Unregister a processor.
|
||||
|
|
@ -150,7 +154,8 @@ impl LoadBalancer {
|
|||
|
||||
/// Get current load for a processor.
|
||||
pub fn get_load(&self, processor_id: ProcessorId) -> u64 {
|
||||
self.loads.read()
|
||||
self.loads
|
||||
.read()
|
||||
.get(&processor_id)
|
||||
.map(|l| l.load(Ordering::Relaxed))
|
||||
.unwrap_or(0)
|
||||
|
|
@ -179,140 +184,140 @@ impl LoadBalancer {
|
|||
ProcessorType::Cpu(_) => matches!(
|
||||
op_type,
|
||||
OperationType::MatMul
|
||||
| OperationType::Conv2d
|
||||
| OperationType::Conv3d
|
||||
| OperationType::DepthwiseConv
|
||||
| OperationType::BatchNorm
|
||||
| OperationType::LayerNorm
|
||||
| OperationType::Add
|
||||
| OperationType::Mul
|
||||
| OperationType::ReLU
|
||||
| OperationType::GeLU
|
||||
| OperationType::SiLU
|
||||
| OperationType::Softmax
|
||||
| OperationType::Sum
|
||||
| OperationType::Mean
|
||||
| OperationType::Max
|
||||
| OperationType::ArgMax
|
||||
| OperationType::Embedding
|
||||
| OperationType::TopK
|
||||
| OperationType::Sampling
|
||||
| OperationType::Tokenization
|
||||
| OperationType::Detokenization
|
||||
| OperationType::DataLoad
|
||||
| OperationType::DataPreprocess
|
||||
| OperationType::Transpose
|
||||
| OperationType::Reshape
|
||||
| OperationType::Concat
|
||||
| OperationType::Split
|
||||
| OperationType::Conv2d
|
||||
| OperationType::Conv3d
|
||||
| OperationType::DepthwiseConv
|
||||
| OperationType::BatchNorm
|
||||
| OperationType::LayerNorm
|
||||
| OperationType::Add
|
||||
| OperationType::Mul
|
||||
| OperationType::ReLU
|
||||
| OperationType::GeLU
|
||||
| OperationType::SiLU
|
||||
| OperationType::Softmax
|
||||
| OperationType::Sum
|
||||
| OperationType::Mean
|
||||
| OperationType::Max
|
||||
| OperationType::ArgMax
|
||||
| OperationType::Embedding
|
||||
| OperationType::TopK
|
||||
| OperationType::Sampling
|
||||
| OperationType::Tokenization
|
||||
| OperationType::Detokenization
|
||||
| OperationType::DataLoad
|
||||
| OperationType::DataPreprocess
|
||||
| OperationType::Transpose
|
||||
| OperationType::Reshape
|
||||
| OperationType::Concat
|
||||
| OperationType::Split
|
||||
),
|
||||
|
||||
// GPUs excel at parallel operations
|
||||
ProcessorType::Gpu(_) => matches!(
|
||||
op_type,
|
||||
OperationType::MatMul
|
||||
| OperationType::Conv2d
|
||||
| OperationType::Conv3d
|
||||
| OperationType::DepthwiseConv
|
||||
| OperationType::BatchNorm
|
||||
| OperationType::LayerNorm
|
||||
| OperationType::SelfAttention
|
||||
| OperationType::CrossAttention
|
||||
| OperationType::FlashAttention
|
||||
| OperationType::Add
|
||||
| OperationType::Mul
|
||||
| OperationType::ReLU
|
||||
| OperationType::GeLU
|
||||
| OperationType::SiLU
|
||||
| OperationType::Softmax
|
||||
| OperationType::Sum
|
||||
| OperationType::Mean
|
||||
| OperationType::Max
|
||||
| OperationType::ArgMax
|
||||
| OperationType::Embedding
|
||||
| OperationType::RoPE
|
||||
| OperationType::KVCache
|
||||
| OperationType::TopK
|
||||
| OperationType::Sampling
|
||||
| OperationType::Transpose
|
||||
| OperationType::Reshape
|
||||
| OperationType::Concat
|
||||
| OperationType::Split
|
||||
| OperationType::Gather
|
||||
| OperationType::Scatter
|
||||
| OperationType::AllReduce
|
||||
| OperationType::AllGather
|
||||
| OperationType::ReduceScatter
|
||||
| OperationType::Backward
|
||||
| OperationType::OptimizerStep
|
||||
| OperationType::GradientClip
|
||||
| OperationType::Conv2d
|
||||
| OperationType::Conv3d
|
||||
| OperationType::DepthwiseConv
|
||||
| OperationType::BatchNorm
|
||||
| OperationType::LayerNorm
|
||||
| OperationType::SelfAttention
|
||||
| OperationType::CrossAttention
|
||||
| OperationType::FlashAttention
|
||||
| OperationType::Add
|
||||
| OperationType::Mul
|
||||
| OperationType::ReLU
|
||||
| OperationType::GeLU
|
||||
| OperationType::SiLU
|
||||
| OperationType::Softmax
|
||||
| OperationType::Sum
|
||||
| OperationType::Mean
|
||||
| OperationType::Max
|
||||
| OperationType::ArgMax
|
||||
| OperationType::Embedding
|
||||
| OperationType::RoPE
|
||||
| OperationType::KVCache
|
||||
| OperationType::TopK
|
||||
| OperationType::Sampling
|
||||
| OperationType::Transpose
|
||||
| OperationType::Reshape
|
||||
| OperationType::Concat
|
||||
| OperationType::Split
|
||||
| OperationType::Gather
|
||||
| OperationType::Scatter
|
||||
| OperationType::AllReduce
|
||||
| OperationType::AllGather
|
||||
| OperationType::ReduceScatter
|
||||
| OperationType::Backward
|
||||
| OperationType::OptimizerStep
|
||||
| OperationType::GradientClip
|
||||
),
|
||||
|
||||
// TPUs optimized for ML
|
||||
ProcessorType::Tpu(_) => matches!(
|
||||
op_type,
|
||||
OperationType::MatMul
|
||||
| OperationType::Conv2d
|
||||
| OperationType::BatchNorm
|
||||
| OperationType::LayerNorm
|
||||
| OperationType::SelfAttention
|
||||
| OperationType::CrossAttention
|
||||
| OperationType::FlashAttention
|
||||
| OperationType::Add
|
||||
| OperationType::Mul
|
||||
| OperationType::ReLU
|
||||
| OperationType::GeLU
|
||||
| OperationType::SiLU
|
||||
| OperationType::Softmax
|
||||
| OperationType::Sum
|
||||
| OperationType::Mean
|
||||
| OperationType::Embedding
|
||||
| OperationType::RoPE
|
||||
| OperationType::KVCache
|
||||
| OperationType::AllReduce
|
||||
| OperationType::AllGather
|
||||
| OperationType::ReduceScatter
|
||||
| OperationType::Backward
|
||||
| OperationType::OptimizerStep
|
||||
| OperationType::Conv2d
|
||||
| OperationType::BatchNorm
|
||||
| OperationType::LayerNorm
|
||||
| OperationType::SelfAttention
|
||||
| OperationType::CrossAttention
|
||||
| OperationType::FlashAttention
|
||||
| OperationType::Add
|
||||
| OperationType::Mul
|
||||
| OperationType::ReLU
|
||||
| OperationType::GeLU
|
||||
| OperationType::SiLU
|
||||
| OperationType::Softmax
|
||||
| OperationType::Sum
|
||||
| OperationType::Mean
|
||||
| OperationType::Embedding
|
||||
| OperationType::RoPE
|
||||
| OperationType::KVCache
|
||||
| OperationType::AllReduce
|
||||
| OperationType::AllGather
|
||||
| OperationType::ReduceScatter
|
||||
| OperationType::Backward
|
||||
| OperationType::OptimizerStep
|
||||
),
|
||||
|
||||
// NPUs for neural network inference
|
||||
ProcessorType::Npu(_) => matches!(
|
||||
op_type,
|
||||
OperationType::MatMul
|
||||
| OperationType::Conv2d
|
||||
| OperationType::DepthwiseConv
|
||||
| OperationType::BatchNorm
|
||||
| OperationType::LayerNorm
|
||||
| OperationType::SelfAttention
|
||||
| OperationType::Add
|
||||
| OperationType::Mul
|
||||
| OperationType::ReLU
|
||||
| OperationType::GeLU
|
||||
| OperationType::SiLU
|
||||
| OperationType::Softmax
|
||||
| OperationType::Sum
|
||||
| OperationType::Mean
|
||||
| OperationType::Conv2d
|
||||
| OperationType::DepthwiseConv
|
||||
| OperationType::BatchNorm
|
||||
| OperationType::LayerNorm
|
||||
| OperationType::SelfAttention
|
||||
| OperationType::Add
|
||||
| OperationType::Mul
|
||||
| OperationType::ReLU
|
||||
| OperationType::GeLU
|
||||
| OperationType::SiLU
|
||||
| OperationType::Softmax
|
||||
| OperationType::Sum
|
||||
| OperationType::Mean
|
||||
),
|
||||
|
||||
// LPUs for sequential inference (optimized for LLMs)
|
||||
ProcessorType::Lpu => matches!(
|
||||
op_type,
|
||||
OperationType::MatMul
|
||||
| OperationType::LayerNorm
|
||||
| OperationType::SelfAttention
|
||||
| OperationType::FlashAttention
|
||||
| OperationType::Add
|
||||
| OperationType::Mul
|
||||
| OperationType::ReLU
|
||||
| OperationType::GeLU
|
||||
| OperationType::SiLU
|
||||
| OperationType::Softmax
|
||||
| OperationType::Embedding
|
||||
| OperationType::RoPE
|
||||
| OperationType::KVCache
|
||||
| OperationType::TopK
|
||||
| OperationType::Sampling
|
||||
| OperationType::LayerNorm
|
||||
| OperationType::SelfAttention
|
||||
| OperationType::FlashAttention
|
||||
| OperationType::Add
|
||||
| OperationType::Mul
|
||||
| OperationType::ReLU
|
||||
| OperationType::GeLU
|
||||
| OperationType::SiLU
|
||||
| OperationType::Softmax
|
||||
| OperationType::Embedding
|
||||
| OperationType::RoPE
|
||||
| OperationType::KVCache
|
||||
| OperationType::TopK
|
||||
| OperationType::Sampling
|
||||
),
|
||||
|
||||
// FPGAs can be programmed for anything
|
||||
|
|
@ -322,40 +327,40 @@ impl LoadBalancer {
|
|||
ProcessorType::Dsp(_) => matches!(
|
||||
op_type,
|
||||
OperationType::Conv2d
|
||||
| OperationType::DepthwiseConv
|
||||
| OperationType::Add
|
||||
| OperationType::Mul
|
||||
| OperationType::Sum
|
||||
| OperationType::Mean
|
||||
| OperationType::Max
|
||||
| OperationType::DepthwiseConv
|
||||
| OperationType::Add
|
||||
| OperationType::Mul
|
||||
| OperationType::Sum
|
||||
| OperationType::Mean
|
||||
| OperationType::Max
|
||||
),
|
||||
|
||||
// WebGPU has limited operations
|
||||
ProcessorType::WebGpu => matches!(
|
||||
op_type,
|
||||
OperationType::MatMul
|
||||
| OperationType::Conv2d
|
||||
| OperationType::Add
|
||||
| OperationType::Mul
|
||||
| OperationType::ReLU
|
||||
| OperationType::Softmax
|
||||
| OperationType::Sum
|
||||
| OperationType::Transpose
|
||||
| OperationType::Reshape
|
||||
| OperationType::Conv2d
|
||||
| OperationType::Add
|
||||
| OperationType::Mul
|
||||
| OperationType::ReLU
|
||||
| OperationType::Softmax
|
||||
| OperationType::Sum
|
||||
| OperationType::Transpose
|
||||
| OperationType::Reshape
|
||||
),
|
||||
|
||||
// WASM for portable compute
|
||||
ProcessorType::Wasm => matches!(
|
||||
op_type,
|
||||
OperationType::MatMul
|
||||
| OperationType::Add
|
||||
| OperationType::Mul
|
||||
| OperationType::ReLU
|
||||
| OperationType::Softmax
|
||||
| OperationType::Sum
|
||||
| OperationType::Mean
|
||||
| OperationType::Tokenization
|
||||
| OperationType::Detokenization
|
||||
| OperationType::Add
|
||||
| OperationType::Mul
|
||||
| OperationType::ReLU
|
||||
| OperationType::Softmax
|
||||
| OperationType::Sum
|
||||
| OperationType::Mean
|
||||
| OperationType::Tokenization
|
||||
| OperationType::Detokenization
|
||||
),
|
||||
|
||||
// Custom processors - assume they can handle anything
|
||||
|
|
@ -381,7 +386,9 @@ impl LoadBalancer {
|
|||
}
|
||||
|
||||
// Get utilization and metrics
|
||||
let utilization = proc_metrics.map(|m| m.utilization).unwrap_or(load as f64 / 100.0);
|
||||
let utilization = proc_metrics
|
||||
.map(|m| m.utilization)
|
||||
.unwrap_or(load as f64 / 100.0);
|
||||
let power = proc_metrics.map(|m| m.power_watts).unwrap_or(100.0);
|
||||
let avg_completion = proc_metrics.map(|m| m.avg_completion_ms).unwrap_or(100.0);
|
||||
|
||||
|
|
@ -431,13 +438,13 @@ impl LoadBalancer {
|
|||
BalancingStrategy::Cost => {
|
||||
// Prioritize cheaper resources (consumer devices)
|
||||
let cost_factor = match processor_type {
|
||||
ProcessorType::Wasm => 0.1, // Cheapest (browser)
|
||||
ProcessorType::Wasm => 0.1, // Cheapest (browser)
|
||||
ProcessorType::WebGpu => 0.15,
|
||||
ProcessorType::Cpu(_) => 0.2,
|
||||
ProcessorType::Npu(_) => 0.3, // Mobile NPUs
|
||||
ProcessorType::Npu(_) => 0.3, // Mobile NPUs
|
||||
ProcessorType::Gpu(_) => 0.5,
|
||||
ProcessorType::Lpu => 0.8,
|
||||
ProcessorType::Tpu(_) => 1.0, // Most expensive
|
||||
ProcessorType::Tpu(_) => 1.0, // Most expensive
|
||||
_ => 0.5,
|
||||
};
|
||||
|
||||
|
|
@ -450,7 +457,7 @@ impl LoadBalancer {
|
|||
|
||||
// Bonus for low-latency processors
|
||||
let latency_bonus = match processor_type {
|
||||
ProcessorType::Lpu => 5.0, // Designed for low latency
|
||||
ProcessorType::Lpu => 5.0, // Designed for low latency
|
||||
ProcessorType::Npu(_) => 3.0,
|
||||
ProcessorType::Gpu(_) => 2.0,
|
||||
ProcessorType::Tpu(_) => 1.5,
|
||||
|
|
@ -550,7 +557,8 @@ impl LoadBalancer {
|
|||
let mut suggestions = Vec::new();
|
||||
let loads = self.loads.read();
|
||||
|
||||
let load_values: Vec<_> = loads.iter()
|
||||
let load_values: Vec<_> = loads
|
||||
.iter()
|
||||
.map(|(id, load)| (*id, load.load(Ordering::Relaxed)))
|
||||
.collect();
|
||||
|
||||
|
|
@ -558,16 +566,18 @@ impl LoadBalancer {
|
|||
return suggestions;
|
||||
}
|
||||
|
||||
let avg_load: f64 = load_values.iter().map(|(_, l)| *l as f64).sum::<f64>()
|
||||
/ load_values.len() as f64;
|
||||
let avg_load: f64 =
|
||||
load_values.iter().map(|(_, l)| *l as f64).sum::<f64>() / load_values.len() as f64;
|
||||
|
||||
let processor_types = self.processor_types.read();
|
||||
|
||||
let overloaded: Vec<_> = load_values.iter()
|
||||
let overloaded: Vec<_> = load_values
|
||||
.iter()
|
||||
.filter(|(_, l)| *l as f64 > avg_load * (1.0 + self.rebalance_threshold))
|
||||
.collect();
|
||||
|
||||
let underloaded: Vec<_> = load_values.iter()
|
||||
let underloaded: Vec<_> = load_values
|
||||
.iter()
|
||||
.filter(|(_, l)| (*l as f64) < avg_load * (1.0 - self.rebalance_threshold))
|
||||
.collect();
|
||||
|
||||
|
|
@ -627,7 +637,9 @@ impl LoadBalancer {
|
|||
/// Clean up old migration history.
|
||||
pub fn cleanup_history(&self, max_age: Duration) {
|
||||
let cutoff = Instant::now() - max_age;
|
||||
self.migration_history.write().retain(|r| r.timestamp > cutoff);
|
||||
self.migration_history
|
||||
.write()
|
||||
.retain(|r| r.timestamp > cutoff);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -725,7 +737,9 @@ mod tests {
|
|||
balancer.register_processor(ProcessorId(0), ProcessorType::Cpu(CpuVariant::default()));
|
||||
balancer.register_processor(
|
||||
ProcessorId(1),
|
||||
ProcessorType::Gpu(GpuVariant::NvidiaCuda { compute_capability: (8, 9) }),
|
||||
ProcessorType::Gpu(GpuVariant::NvidiaCuda {
|
||||
compute_capability: (8, 9),
|
||||
}),
|
||||
);
|
||||
|
||||
// Give CPU high load
|
||||
|
|
@ -757,7 +771,9 @@ mod tests {
|
|||
};
|
||||
|
||||
let cpu = ProcessorType::Cpu(CpuVariant::default());
|
||||
let gpu = ProcessorType::Gpu(GpuVariant::NvidiaCuda { compute_capability: (8, 9) });
|
||||
let gpu = ProcessorType::Gpu(GpuVariant::NvidiaCuda {
|
||||
compute_capability: (8, 9),
|
||||
});
|
||||
let lpu = ProcessorType::Lpu;
|
||||
|
||||
// MatMul can run on all
|
||||
|
|
@ -778,7 +794,10 @@ mod tests {
|
|||
let npu_id = ProcessorId(1);
|
||||
|
||||
balancer.register_processor(cpu_id, ProcessorType::Cpu(CpuVariant::default()));
|
||||
balancer.register_processor(npu_id, ProcessorType::Npu(crate::processor::NpuVariant::AppleNeuralEngine { cores: 16 }));
|
||||
balancer.register_processor(
|
||||
npu_id,
|
||||
ProcessorType::Npu(crate::processor::NpuVariant::AppleNeuralEngine { cores: 16 }),
|
||||
);
|
||||
|
||||
let task = create_test_task(TaskPriority::Normal);
|
||||
|
||||
|
|
|
|||
|
|
@ -69,7 +69,9 @@ impl HeterogeneousScheduler {
|
|||
let utilization = self.estimate_utilization(&schedule);
|
||||
|
||||
// 5. Store active schedule
|
||||
self.active_schedules.write().insert(schedule.id, schedule.clone());
|
||||
self.active_schedules
|
||||
.write()
|
||||
.insert(schedule.id, schedule.clone());
|
||||
|
||||
Ok(ScheduleResult {
|
||||
schedule,
|
||||
|
|
@ -89,10 +91,12 @@ impl HeterogeneousScheduler {
|
|||
let mut handles = Vec::new();
|
||||
|
||||
for task_id in &stage.tasks {
|
||||
let task = schedule.tasks.get(task_id)
|
||||
.ok_or_else(|| ComputeError::Internal(format!("Task not found: {:?}", task_id)))?;
|
||||
let processor_id = schedule.assignment.get(task_id)
|
||||
.ok_or_else(|| ComputeError::Internal(format!("No assignment for task: {:?}", task_id)))?;
|
||||
let task = schedule.tasks.get(task_id).ok_or_else(|| {
|
||||
ComputeError::Internal(format!("Task not found: {:?}", task_id))
|
||||
})?;
|
||||
let processor_id = schedule.assignment.get(task_id).ok_or_else(|| {
|
||||
ComputeError::Internal(format!("No assignment for task: {:?}", task_id))
|
||||
})?;
|
||||
|
||||
let processor = self.device_registry.get_processor(processor_id)?;
|
||||
let task_clone = task.clone();
|
||||
|
|
@ -144,8 +148,9 @@ impl HeterogeneousScheduler {
|
|||
let best_processor = self.find_best_processor(&task).await?;
|
||||
|
||||
// Check if we should rebalance
|
||||
let final_processor = self.load_balancer
|
||||
.maybe_rebalance(&task, best_processor, &assignment);
|
||||
let final_processor =
|
||||
self.load_balancer
|
||||
.maybe_rebalance(&task, best_processor, &assignment);
|
||||
|
||||
assignment.assign(task.id, final_processor);
|
||||
}
|
||||
|
|
@ -207,9 +212,7 @@ impl HeterogeneousScheduler {
|
|||
fn topological_sort(&self, tasks: &[Task], deps: &DependencyGraph) -> Vec<Task> {
|
||||
let mut sorted = Vec::new();
|
||||
let mut visited = std::collections::HashSet::new();
|
||||
let task_map: HashMap<TaskId, Task> = tasks.iter()
|
||||
.map(|t| (t.id, t.clone()))
|
||||
.collect();
|
||||
let task_map: HashMap<TaskId, Task> = tasks.iter().map(|t| (t.id, t.clone())).collect();
|
||||
|
||||
fn visit(
|
||||
task_id: TaskId,
|
||||
|
|
@ -254,9 +257,7 @@ impl HeterogeneousScheduler {
|
|||
) -> Result<Schedule, ComputeError> {
|
||||
let mut stages = Vec::new();
|
||||
let mut scheduled = std::collections::HashSet::new();
|
||||
let task_map: HashMap<TaskId, Task> = tasks.iter()
|
||||
.map(|t| (t.id, t.clone()))
|
||||
.collect();
|
||||
let task_map: HashMap<TaskId, Task> = tasks.iter().map(|t| (t.id, t.clone())).collect();
|
||||
|
||||
while scheduled.len() < tasks.len() {
|
||||
let mut stage_tasks = Vec::new();
|
||||
|
|
@ -267,8 +268,7 @@ impl HeterogeneousScheduler {
|
|||
}
|
||||
|
||||
// Check if all dependencies are satisfied
|
||||
let deps_satisfied = task.dependencies.iter()
|
||||
.all(|dep| scheduled.contains(dep));
|
||||
let deps_satisfied = task.dependencies.iter().all(|dep| scheduled.contains(dep));
|
||||
|
||||
if deps_satisfied {
|
||||
stage_tasks.push(task.id);
|
||||
|
|
@ -277,7 +277,7 @@ impl HeterogeneousScheduler {
|
|||
|
||||
if stage_tasks.is_empty() {
|
||||
return Err(ComputeError::SchedulingFailed(
|
||||
"Circular dependency detected".to_string()
|
||||
"Circular dependency detected".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -153,7 +153,10 @@ impl PriorityWorkQueue {
|
|||
TaskPriority::Normal,
|
||||
TaskPriority::Background,
|
||||
] {
|
||||
queues.insert(priority, WorkQueue::new(processor_type.clone(), capacity_per_priority));
|
||||
queues.insert(
|
||||
priority,
|
||||
WorkQueue::new(processor_type.clone(), capacity_per_priority),
|
||||
);
|
||||
}
|
||||
|
||||
Self {
|
||||
|
|
@ -223,10 +226,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_work_queue_basic() {
|
||||
let queue = WorkQueue::new(
|
||||
ProcessorType::Cpu(CpuVariant::default()),
|
||||
100,
|
||||
);
|
||||
let queue = WorkQueue::new(ProcessorType::Cpu(CpuVariant::default()), 100);
|
||||
|
||||
assert!(queue.is_empty());
|
||||
|
||||
|
|
@ -246,10 +246,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_priority_queue() {
|
||||
let queue = PriorityWorkQueue::new(
|
||||
ProcessorType::Cpu(CpuVariant::default()),
|
||||
100,
|
||||
);
|
||||
let queue = PriorityWorkQueue::new(ProcessorType::Cpu(CpuVariant::default()), 100);
|
||||
|
||||
queue.push(create_test_task(1, TaskPriority::Background));
|
||||
queue.push(create_test_task(2, TaskPriority::Critical));
|
||||
|
|
|
|||
|
|
@ -495,9 +495,9 @@ mod tests {
|
|||
compute_capability: (8, 0)
|
||||
}
|
||||
)));
|
||||
assert!(matmul_task.is_compatible_with(ProcessorType::Tpu(
|
||||
crate::processor::TpuVersion::V5p
|
||||
)));
|
||||
assert!(
|
||||
matmul_task.is_compatible_with(ProcessorType::Tpu(crate::processor::TpuVersion::V5p))
|
||||
);
|
||||
|
||||
let data_load_task = Task::new(Operation::DataLoad {
|
||||
bytes: 1000,
|
||||
|
|
@ -505,9 +505,8 @@ mod tests {
|
|||
});
|
||||
|
||||
// DataLoad should be compatible with CPU
|
||||
assert!(data_load_task.is_compatible_with(ProcessorType::Cpu(
|
||||
crate::processor::CpuVariant::default()
|
||||
)));
|
||||
assert!(data_load_task
|
||||
.is_compatible_with(ProcessorType::Cpu(crate::processor::CpuVariant::default())));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -25,8 +25,7 @@
|
|||
use std::time::Duration;
|
||||
|
||||
/// Blocks per second mode.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||
#[derive(Default)]
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)]
|
||||
pub enum BpsMode {
|
||||
/// Standard mode: 10 blocks per second (100ms block time)
|
||||
/// - Suitable for most network conditions
|
||||
|
|
@ -75,7 +74,6 @@ impl BpsMode {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
impl std::fmt::Display for BpsMode {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
|
|
@ -148,39 +146,39 @@ impl NetworkConfig {
|
|||
bps_mode: mode,
|
||||
blocks_per_second: 10,
|
||||
target_block_time_ms: 100,
|
||||
daa_window_size: 2641, // ~264s window
|
||||
ghostdag_k: 18, // For 10 BPS
|
||||
daa_window_size: 2641, // ~264s window
|
||||
ghostdag_k: 18, // For 10 BPS
|
||||
dagknight_k_min: 8,
|
||||
dagknight_k_max: 64,
|
||||
finality_depth: 864, // ~86 seconds
|
||||
pruning_depth: 864_000, // ~24 hours
|
||||
finality_depth: 864, // ~86 seconds
|
||||
pruning_depth: 864_000, // ~24 hours
|
||||
merge_set_size_limit: 180,
|
||||
expected_delay_ms: 100,
|
||||
},
|
||||
BpsMode::Fast32 => Self {
|
||||
bps_mode: mode,
|
||||
blocks_per_second: 32,
|
||||
target_block_time_ms: 31, // ~31.25ms
|
||||
daa_window_size: 8461, // ~264s window at 32 BPS
|
||||
ghostdag_k: 58, // Scaled for 32 BPS
|
||||
dagknight_k_min: 16, // Higher min for faster blocks
|
||||
dagknight_k_max: 128, // Higher max for adaptation
|
||||
finality_depth: 2765, // ~86 seconds at 32 BPS
|
||||
pruning_depth: 2_764_800, // ~24 hours at 32 BPS
|
||||
merge_set_size_limit: 576, // 32/10 * 180
|
||||
target_block_time_ms: 31, // ~31.25ms
|
||||
daa_window_size: 8461, // ~264s window at 32 BPS
|
||||
ghostdag_k: 58, // Scaled for 32 BPS
|
||||
dagknight_k_min: 16, // Higher min for faster blocks
|
||||
dagknight_k_max: 128, // Higher max for adaptation
|
||||
finality_depth: 2765, // ~86 seconds at 32 BPS
|
||||
pruning_depth: 2_764_800, // ~24 hours at 32 BPS
|
||||
merge_set_size_limit: 576, // 32/10 * 180
|
||||
expected_delay_ms: 50,
|
||||
},
|
||||
BpsMode::Ultra100 => Self {
|
||||
bps_mode: mode,
|
||||
blocks_per_second: 100,
|
||||
target_block_time_ms: 10,
|
||||
daa_window_size: 26410, // ~264s window at 100 BPS
|
||||
ghostdag_k: 180, // Scaled for 100 BPS
|
||||
dagknight_k_min: 50, // Higher min for very fast blocks
|
||||
dagknight_k_max: 255, // u8 max - very high for adaptation
|
||||
finality_depth: 8640, // ~86 seconds at 100 BPS
|
||||
pruning_depth: 8_640_000, // ~24 hours at 100 BPS
|
||||
merge_set_size_limit: 1800, // 100/10 * 180
|
||||
daa_window_size: 26410, // ~264s window at 100 BPS
|
||||
ghostdag_k: 180, // Scaled for 100 BPS
|
||||
dagknight_k_min: 50, // Higher min for very fast blocks
|
||||
dagknight_k_max: 255, // u8 max - very high for adaptation
|
||||
finality_depth: 8640, // ~86 seconds at 100 BPS
|
||||
pruning_depth: 8_640_000, // ~24 hours at 100 BPS
|
||||
merge_set_size_limit: 1800, // 100/10 * 180
|
||||
expected_delay_ms: 20,
|
||||
},
|
||||
BpsMode::Custom(bps) => {
|
||||
|
|
@ -269,7 +267,7 @@ pub fn bps_comparison_table() -> String {
|
|||
|
||||
let mut table = String::from(
|
||||
"| Property | Standard (10 BPS) | Fast (32 BPS) | Ultra (100 BPS) |\n\
|
||||
|----------|-------------------|---------------|------------------|\n"
|
||||
|----------|-------------------|---------------|------------------|\n",
|
||||
);
|
||||
|
||||
// Block Time
|
||||
|
|
@ -314,7 +312,9 @@ pub fn bps_comparison_table() -> String {
|
|||
// Estimated TPS
|
||||
table.push_str(&format!(
|
||||
"| Est. TPS @1000tx/block | {:.0} | {:.0} | {:.0} |\n",
|
||||
standard.estimate_tps(1000), fast.estimate_tps(1000), ultra.estimate_tps(1000)
|
||||
standard.estimate_tps(1000),
|
||||
fast.estimate_tps(1000),
|
||||
ultra.estimate_tps(1000)
|
||||
));
|
||||
|
||||
table
|
||||
|
|
@ -401,9 +401,9 @@ mod tests {
|
|||
fn test_latency_acceptable() {
|
||||
let config = NetworkConfig::standard(); // expects 100ms
|
||||
|
||||
assert!(config.is_latency_acceptable(50)); // Good
|
||||
assert!(config.is_latency_acceptable(100)); // OK
|
||||
assert!(config.is_latency_acceptable(200)); // Still OK (2x limit)
|
||||
assert!(config.is_latency_acceptable(50)); // Good
|
||||
assert!(config.is_latency_acceptable(100)); // OK
|
||||
assert!(config.is_latency_acceptable(200)); // Still OK (2x limit)
|
||||
assert!(!config.is_latency_acceptable(300)); // Too high
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -55,8 +55,8 @@
|
|||
//! | Layer 2 transactions | FALCON-512 (batch efficiency) |
|
||||
//! | High-value transactions | Dilithium3 (conservative choice) |
|
||||
|
||||
use pqcrypto_falcon::falcon512;
|
||||
use pqcrypto_falcon::falcon1024;
|
||||
use pqcrypto_falcon::falcon512;
|
||||
use pqcrypto_traits::sign::{
|
||||
DetachedSignature, PublicKey as PqPublicKey, SecretKey as PqSecretKey,
|
||||
};
|
||||
|
|
@ -64,8 +64,7 @@ use thiserror::Error;
|
|||
use zeroize::{Zeroize, ZeroizeOnDrop};
|
||||
|
||||
/// FALCON variant selection.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
#[derive(Default)]
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
|
||||
pub enum FalconVariant {
|
||||
/// 128-bit security, ~690 byte signatures
|
||||
#[default]
|
||||
|
|
@ -124,7 +123,6 @@ impl FalconVariant {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
/// FALCON public key.
|
||||
#[derive(Clone)]
|
||||
pub struct FalconPublicKey {
|
||||
|
|
@ -188,7 +186,10 @@ impl std::fmt::Debug for FalconPublicKey {
|
|||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("FalconPublicKey")
|
||||
.field("variant", &self.variant)
|
||||
.field("bytes", &hex::encode(&self.bytes[..8.min(self.bytes.len())]))
|
||||
.field(
|
||||
"bytes",
|
||||
&hex::encode(&self.bytes[..8.min(self.bytes.len())]),
|
||||
)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
|
@ -492,7 +493,10 @@ mod tests {
|
|||
|
||||
// Verify with wrong message should fail
|
||||
let wrong_message = b"Wrong message";
|
||||
assert!(keypair.public_key().verify(wrong_message, &signature).is_err());
|
||||
assert!(keypair
|
||||
.public_key()
|
||||
.verify(wrong_message, &signature)
|
||||
.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -150,9 +150,9 @@ impl PqAlgorithm {
|
|||
/// Default priority order (higher = more preferred)
|
||||
fn default_priority(&self) -> u8 {
|
||||
match self {
|
||||
Self::Dilithium3 => 100, // Default, well-balanced
|
||||
Self::Falcon1024 => 90, // High security, compact
|
||||
Self::Falcon512 => 85, // Compact, mobile-friendly
|
||||
Self::Dilithium3 => 100, // Default, well-balanced
|
||||
Self::Falcon1024 => 90, // High security, compact
|
||||
Self::Falcon512 => 85, // Compact, mobile-friendly
|
||||
Self::SphincsShake192s => 70, // Conservative backup
|
||||
Self::SphincsShake256s => 60, // Maximum security
|
||||
Self::SphincsShake128s => 50, // Basic SPHINCS+
|
||||
|
|
@ -270,7 +270,8 @@ impl AlgorithmCapabilities {
|
|||
|
||||
/// Decode capabilities from bytes
|
||||
pub fn decode(data: &[u8]) -> Result<Self, NegotiationError> {
|
||||
serde_json::from_slice(data).map_err(|e| NegotiationError::InvalidCapabilities(e.to_string()))
|
||||
serde_json::from_slice(data)
|
||||
.map_err(|e| NegotiationError::InvalidCapabilities(e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -384,8 +385,7 @@ impl AlgorithmNegotiator {
|
|||
// Check security level
|
||||
let meets_local_security =
|
||||
algo.security_level() >= self.local_caps.min_security_level;
|
||||
let meets_remote_security =
|
||||
algo.security_level() >= remote_caps.min_security_level;
|
||||
let meets_remote_security = algo.security_level() >= remote_caps.min_security_level;
|
||||
|
||||
// Check signature size
|
||||
let local_size_ok = self.local_caps.max_signature_size == 0
|
||||
|
|
@ -513,10 +513,7 @@ impl AlgorithmNegotiator {
|
|||
}
|
||||
|
||||
/// Quick negotiation using just algorithm names
|
||||
pub fn quick_negotiate(
|
||||
local: &[PqAlgorithm],
|
||||
remote: &[PqAlgorithm],
|
||||
) -> Option<PqAlgorithm> {
|
||||
pub fn quick_negotiate(local: &[PqAlgorithm], remote: &[PqAlgorithm]) -> Option<PqAlgorithm> {
|
||||
// Find common algorithms and return the one with highest default priority
|
||||
let local_set: HashSet<_> = local.iter().collect();
|
||||
let remote_set: HashSet<_> = remote.iter().collect();
|
||||
|
|
@ -604,7 +601,10 @@ pub enum NegotiationMessage {
|
|||
},
|
||||
|
||||
/// Acknowledge selection
|
||||
Acknowledgment { session_id: [u8; 32], accepted: bool },
|
||||
Acknowledgment {
|
||||
session_id: [u8; 32],
|
||||
accepted: bool,
|
||||
},
|
||||
|
||||
/// Request renegotiation
|
||||
Renegotiate { reason: String },
|
||||
|
|
@ -691,8 +691,10 @@ mod tests {
|
|||
let result = negotiator.negotiate(&remote_caps).unwrap();
|
||||
|
||||
// Should prefer FALCON for bandwidth-constrained scenarios
|
||||
assert!(result.algorithm == PqAlgorithm::Falcon512 ||
|
||||
result.algorithm == PqAlgorithm::Falcon1024);
|
||||
assert!(
|
||||
result.algorithm == PqAlgorithm::Falcon512
|
||||
|| result.algorithm == PqAlgorithm::Falcon1024
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -60,8 +60,7 @@ use zeroize::{Zeroize, ZeroizeOnDrop};
|
|||
/// All variants use SHAKE (SHA3-based) for hashing.
|
||||
/// 's' variants have smaller signatures but are slower.
|
||||
/// 'f' variants are faster but have larger signatures.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
#[derive(Default)]
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
|
||||
pub enum SphincsVariant {
|
||||
/// 128-bit security, small signatures (~7.8KB)
|
||||
#[default]
|
||||
|
|
@ -119,7 +118,6 @@ impl SphincsVariant {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
/// SPHINCS+ public key.
|
||||
#[derive(Clone)]
|
||||
pub struct SphincsPublicKey {
|
||||
|
|
@ -191,7 +189,10 @@ impl std::fmt::Debug for SphincsPublicKey {
|
|||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("SphincsPublicKey")
|
||||
.field("variant", &self.variant)
|
||||
.field("bytes", &hex::encode(&self.bytes[..8.min(self.bytes.len())]))
|
||||
.field(
|
||||
"bytes",
|
||||
&hex::encode(&self.bytes[..8.min(self.bytes.len())]),
|
||||
)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
|
@ -500,7 +501,10 @@ mod tests {
|
|||
|
||||
// Verify with wrong message should fail
|
||||
let wrong_message = b"Wrong message";
|
||||
assert!(keypair.public_key().verify(wrong_message, &signature).is_err());
|
||||
assert!(keypair
|
||||
.public_key()
|
||||
.verify(wrong_message, &signature)
|
||||
.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -155,10 +155,7 @@ pub struct DagKnightManager {
|
|||
|
||||
impl DagKnightManager {
|
||||
/// Creates a new DAGKnight manager with standard 10 BPS configuration.
|
||||
pub fn new(
|
||||
dag: Arc<BlockDag>,
|
||||
reachability: Arc<ReachabilityStore>,
|
||||
) -> Self {
|
||||
pub fn new(dag: Arc<BlockDag>, reachability: Arc<ReachabilityStore>) -> Self {
|
||||
Self::with_config(dag, reachability, BlockRateConfig::Standard)
|
||||
}
|
||||
|
||||
|
|
@ -269,7 +266,8 @@ impl DagKnightManager {
|
|||
let anticone_size = self.calculate_anticone_size(&block_id, parents);
|
||||
|
||||
// Record observation in latency tracker
|
||||
self.latency_tracker.record_block(block_id, block_time_ms, anticone_size);
|
||||
self.latency_tracker
|
||||
.record_block(block_id, block_time_ms, anticone_size);
|
||||
|
||||
// Process with underlying GHOSTDAG
|
||||
let data = self.ghostdag.add_block(block_id, parents)?;
|
||||
|
|
@ -292,11 +290,9 @@ impl DagKnightManager {
|
|||
for tip in tips {
|
||||
if tip != *block_id && !parents.contains(&tip) {
|
||||
// Check if tip is in the past of any parent
|
||||
let in_past = parents.iter().any(|p| {
|
||||
self.reachability
|
||||
.is_ancestor(p, &tip)
|
||||
.unwrap_or(false)
|
||||
});
|
||||
let in_past = parents
|
||||
.iter()
|
||||
.any(|p| self.reachability.is_ancestor(p, &tip).unwrap_or(false));
|
||||
|
||||
if !in_past {
|
||||
anticone_count += 1;
|
||||
|
|
@ -375,7 +371,8 @@ impl DagKnightManager {
|
|||
let sigma_multiplier = confidence.sigma_multiplier();
|
||||
|
||||
// Required depth scales with variance and confidence level
|
||||
let required_depth = (self.block_rate_bps * (mean_delay + sigma * sigma_multiplier)).ceil() as u64;
|
||||
let required_depth =
|
||||
(self.block_rate_bps * (mean_delay + sigma * sigma_multiplier)).ceil() as u64;
|
||||
|
||||
// Current confidence based on actual depth
|
||||
let current_confidence = if depth >= required_depth {
|
||||
|
|
@ -388,7 +385,8 @@ impl DagKnightManager {
|
|||
// Time to reach required depth
|
||||
let blocks_needed = required_depth.saturating_sub(depth);
|
||||
let time_per_block_ms = 1000.0 / self.block_rate_bps;
|
||||
let estimated_time = Duration::from_millis((blocks_needed as f64 * time_per_block_ms) as u64);
|
||||
let estimated_time =
|
||||
Duration::from_millis((blocks_needed as f64 * time_per_block_ms) as u64);
|
||||
|
||||
// Block is final if depth exceeds finality threshold for this block rate
|
||||
let is_final = depth >= self.finality_depth();
|
||||
|
|
@ -506,7 +504,10 @@ impl std::fmt::Debug for DagKnightManager {
|
|||
.field("block_rate_config", &self.block_rate_config)
|
||||
.field("block_rate_bps", &self.block_rate_bps)
|
||||
.field("adaptive_k", &*self.adaptive_k.read())
|
||||
.field("k_bounds", &format!("{}-{}", self.k_bounds.min_k, self.k_bounds.max_k))
|
||||
.field(
|
||||
"k_bounds",
|
||||
&format!("{}-{}", self.k_bounds.min_k, self.k_bounds.max_k),
|
||||
)
|
||||
.field("mean_delay_ms", &stats.mean_delay_ms)
|
||||
.field("sample_count", &stats.sample_count)
|
||||
.finish()
|
||||
|
|
@ -534,10 +535,7 @@ pub fn calculate_optimal_k(network_delay_ms: f64, block_rate_bps: f64) -> u8 {
|
|||
}
|
||||
|
||||
/// Calculates the optimal k for a specific block rate configuration.
|
||||
pub fn calculate_optimal_k_for_config(
|
||||
network_delay_ms: f64,
|
||||
config: BlockRateConfig,
|
||||
) -> u8 {
|
||||
pub fn calculate_optimal_k_for_config(network_delay_ms: f64, config: BlockRateConfig) -> u8 {
|
||||
let bounds = AdaptiveKBounds::for_block_rate(config);
|
||||
let delay_secs = network_delay_ms / 1000.0;
|
||||
let k = (config.bps() * delay_secs * SAFETY_MARGIN).ceil() as u16;
|
||||
|
|
@ -578,7 +576,9 @@ mod tests {
|
|||
(dag, reachability, dagknight)
|
||||
}
|
||||
|
||||
fn setup_test_dag_with_config(config: BlockRateConfig) -> (Arc<BlockDag>, Arc<ReachabilityStore>, DagKnightManager) {
|
||||
fn setup_test_dag_with_config(
|
||||
config: BlockRateConfig,
|
||||
) -> (Arc<BlockDag>, Arc<ReachabilityStore>, DagKnightManager) {
|
||||
let genesis = make_block_id(0);
|
||||
let dag = Arc::new(BlockDag::new(genesis, 0));
|
||||
let reachability = Arc::new(ReachabilityStore::new(genesis));
|
||||
|
|
@ -671,14 +671,19 @@ mod tests {
|
|||
let tps_poor = estimate_throughput(10.0, 100, 40.0);
|
||||
|
||||
// Good network should have higher throughput
|
||||
assert!(tps_good > tps_poor, "tps_good={} should be > tps_poor={}", tps_good, tps_poor);
|
||||
assert!(
|
||||
tps_good > tps_poor,
|
||||
"tps_good={} should be > tps_poor={}",
|
||||
tps_good,
|
||||
tps_poor
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_throughput_by_config() {
|
||||
// At same network conditions, higher BPS = higher theoretical TPS
|
||||
let tps_10 = estimate_throughput(10.0, 100, 20.0); // 10 BPS
|
||||
let tps_32 = estimate_throughput(32.0, 100, 20.0); // 32 BPS
|
||||
let tps_10 = estimate_throughput(10.0, 100, 20.0); // 10 BPS
|
||||
let tps_32 = estimate_throughput(32.0, 100, 20.0); // 32 BPS
|
||||
let tps_100 = estimate_throughput(100.0, 100, 20.0); // 100 BPS
|
||||
|
||||
// Higher block rates give higher TPS (with network overhead)
|
||||
|
|
@ -698,19 +703,37 @@ mod tests {
|
|||
let maximum_time_hrs = maximum.finality_depth() as f64 / 100.0 / 3600.0;
|
||||
|
||||
// Should all be approximately 2.4 hours (allow some variance)
|
||||
assert!((standard_time_hrs - 2.4).abs() < 0.1, "standard: {}", standard_time_hrs);
|
||||
assert!((enhanced_time_hrs - 2.4).abs() < 0.1, "enhanced: {}", enhanced_time_hrs);
|
||||
assert!((maximum_time_hrs - 2.4).abs() < 0.1, "maximum: {}", maximum_time_hrs);
|
||||
assert!(
|
||||
(standard_time_hrs - 2.4).abs() < 0.1,
|
||||
"standard: {}",
|
||||
standard_time_hrs
|
||||
);
|
||||
assert!(
|
||||
(enhanced_time_hrs - 2.4).abs() < 0.1,
|
||||
"enhanced: {}",
|
||||
enhanced_time_hrs
|
||||
);
|
||||
assert!(
|
||||
(maximum_time_hrs - 2.4).abs() < 0.1,
|
||||
"maximum: {}",
|
||||
maximum_time_hrs
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_confidence_levels() {
|
||||
assert!(ConfirmationConfidence::VeryHigh.sigma_multiplier()
|
||||
> ConfirmationConfidence::High.sigma_multiplier());
|
||||
assert!(ConfirmationConfidence::High.sigma_multiplier()
|
||||
> ConfirmationConfidence::Medium.sigma_multiplier());
|
||||
assert!(ConfirmationConfidence::Medium.sigma_multiplier()
|
||||
> ConfirmationConfidence::Low.sigma_multiplier());
|
||||
assert!(
|
||||
ConfirmationConfidence::VeryHigh.sigma_multiplier()
|
||||
> ConfirmationConfidence::High.sigma_multiplier()
|
||||
);
|
||||
assert!(
|
||||
ConfirmationConfidence::High.sigma_multiplier()
|
||||
> ConfirmationConfidence::Medium.sigma_multiplier()
|
||||
);
|
||||
assert!(
|
||||
ConfirmationConfidence::Medium.sigma_multiplier()
|
||||
> ConfirmationConfidence::Low.sigma_multiplier()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -98,12 +98,7 @@ impl LatencyTracker {
|
|||
/// * `block_id` - Hash of the observed block
|
||||
/// * `block_time_ms` - Timestamp from block header (Unix ms)
|
||||
/// * `anticone_size` - Number of blocks in the anticone at observation time
|
||||
pub fn record_block(
|
||||
&self,
|
||||
block_id: BlockId,
|
||||
block_time_ms: u64,
|
||||
anticone_size: usize,
|
||||
) {
|
||||
pub fn record_block(&self, block_id: BlockId, block_time_ms: u64, anticone_size: usize) {
|
||||
let local_time = Instant::now();
|
||||
let now_ms = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
|
|
@ -208,7 +203,10 @@ impl LatencyTracker {
|
|||
let anticone_growth_rate = if n > 1 {
|
||||
let first = samples.front().unwrap();
|
||||
let last = samples.back().unwrap();
|
||||
let time_span_secs = last.local_time.duration_since(first.local_time).as_secs_f64();
|
||||
let time_span_secs = last
|
||||
.local_time
|
||||
.duration_since(first.local_time)
|
||||
.as_secs_f64();
|
||||
|
||||
if time_span_secs > 0.0 {
|
||||
let total_anticone_growth: usize = samples.iter().map(|s| s.anticone_size).sum();
|
||||
|
|
|
|||
|
|
@ -32,8 +32,8 @@ pub mod reachability;
|
|||
|
||||
pub use dag::{BlockDag, BlockRelations, DagError};
|
||||
pub use dagknight::{
|
||||
calculate_optimal_k, calculate_optimal_k_for_config, estimate_throughput,
|
||||
AdaptiveKBounds, ConfirmationConfidence, ConfirmationStatus, DagKnightManager,
|
||||
calculate_optimal_k, calculate_optimal_k_for_config, estimate_throughput, AdaptiveKBounds,
|
||||
ConfirmationConfidence, ConfirmationStatus, DagKnightManager,
|
||||
};
|
||||
pub use ghostdag::{GhostdagData, GhostdagError, GhostdagManager};
|
||||
pub use latency::{LatencySample, LatencyStats, LatencyTracker};
|
||||
|
|
@ -116,27 +116,27 @@ impl BlockRateConfig {
|
|||
/// Returns the merge depth adjusted for block rate.
|
||||
pub const fn merge_depth(&self) -> u64 {
|
||||
match self {
|
||||
BlockRateConfig::Standard => 3600, // ~6 min at 10 bps
|
||||
BlockRateConfig::Enhanced => 11520, // ~6 min at 32 bps
|
||||
BlockRateConfig::Maximum => 36000, // ~6 min at 100 bps
|
||||
BlockRateConfig::Standard => 3600, // ~6 min at 10 bps
|
||||
BlockRateConfig::Enhanced => 11520, // ~6 min at 32 bps
|
||||
BlockRateConfig::Maximum => 36000, // ~6 min at 100 bps
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the finality depth adjusted for block rate.
|
||||
pub const fn finality_depth(&self) -> u64 {
|
||||
match self {
|
||||
BlockRateConfig::Standard => 86400, // ~2.4 hours at 10 bps
|
||||
BlockRateConfig::Enhanced => 276480, // ~2.4 hours at 32 bps
|
||||
BlockRateConfig::Maximum => 864000, // ~2.4 hours at 100 bps
|
||||
BlockRateConfig::Standard => 86400, // ~2.4 hours at 10 bps
|
||||
BlockRateConfig::Enhanced => 276480, // ~2.4 hours at 32 bps
|
||||
BlockRateConfig::Maximum => 864000, // ~2.4 hours at 100 bps
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the pruning depth adjusted for block rate.
|
||||
pub const fn pruning_depth(&self) -> u64 {
|
||||
match self {
|
||||
BlockRateConfig::Standard => 288_000, // ~8 hours at 10 bps
|
||||
BlockRateConfig::Enhanced => 921_600, // ~8 hours at 32 bps
|
||||
BlockRateConfig::Maximum => 2_880_000, // ~8 hours at 100 bps
|
||||
BlockRateConfig::Standard => 288_000, // ~8 hours at 10 bps
|
||||
BlockRateConfig::Enhanced => 921_600, // ~8 hours at 32 bps
|
||||
BlockRateConfig::Maximum => 2_880_000, // ~8 hours at 100 bps
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -42,7 +42,9 @@ impl DocumentId {
|
|||
let bytes = hex::decode(s)
|
||||
.map_err(|_| DatabaseError::InvalidOperation("Invalid hex string".into()))?;
|
||||
if bytes.len() != 32 {
|
||||
return Err(DatabaseError::InvalidOperation("Invalid document ID length".into()));
|
||||
return Err(DatabaseError::InvalidOperation(
|
||||
"Invalid document ID length".into(),
|
||||
));
|
||||
}
|
||||
let mut arr = [0u8; 32];
|
||||
arr.copy_from_slice(&bytes);
|
||||
|
|
@ -249,7 +251,11 @@ impl Collection {
|
|||
}
|
||||
|
||||
/// Updates documents matching a filter.
|
||||
pub fn update_many(&self, filter: &DocumentFilter, update: JsonValue) -> Result<u64, DatabaseError> {
|
||||
pub fn update_many(
|
||||
&self,
|
||||
filter: &DocumentFilter,
|
||||
update: JsonValue,
|
||||
) -> Result<u64, DatabaseError> {
|
||||
let mut docs = self.documents.write();
|
||||
let mut count = 0;
|
||||
for doc in docs.values_mut() {
|
||||
|
|
@ -321,60 +327,71 @@ enum FilterCondition {
|
|||
impl DocumentFilter {
|
||||
/// Creates a new empty filter (matches all).
|
||||
pub fn new() -> Self {
|
||||
Self { conditions: Vec::new() }
|
||||
Self {
|
||||
conditions: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Equality condition.
|
||||
pub fn eq(mut self, field: impl Into<String>, value: JsonValue) -> Self {
|
||||
self.conditions.push(FilterCondition::Eq(field.into(), value));
|
||||
self.conditions
|
||||
.push(FilterCondition::Eq(field.into(), value));
|
||||
self
|
||||
}
|
||||
|
||||
/// Not equal condition.
|
||||
pub fn ne(mut self, field: impl Into<String>, value: JsonValue) -> Self {
|
||||
self.conditions.push(FilterCondition::Ne(field.into(), value));
|
||||
self.conditions
|
||||
.push(FilterCondition::Ne(field.into(), value));
|
||||
self
|
||||
}
|
||||
|
||||
/// Greater than.
|
||||
pub fn gt(mut self, field: impl Into<String>, value: JsonValue) -> Self {
|
||||
self.conditions.push(FilterCondition::Gt(field.into(), value));
|
||||
self.conditions
|
||||
.push(FilterCondition::Gt(field.into(), value));
|
||||
self
|
||||
}
|
||||
|
||||
/// Greater than or equal.
|
||||
pub fn gte(mut self, field: impl Into<String>, value: JsonValue) -> Self {
|
||||
self.conditions.push(FilterCondition::Gte(field.into(), value));
|
||||
self.conditions
|
||||
.push(FilterCondition::Gte(field.into(), value));
|
||||
self
|
||||
}
|
||||
|
||||
/// Less than.
|
||||
pub fn lt(mut self, field: impl Into<String>, value: JsonValue) -> Self {
|
||||
self.conditions.push(FilterCondition::Lt(field.into(), value));
|
||||
self.conditions
|
||||
.push(FilterCondition::Lt(field.into(), value));
|
||||
self
|
||||
}
|
||||
|
||||
/// Less than or equal.
|
||||
pub fn lte(mut self, field: impl Into<String>, value: JsonValue) -> Self {
|
||||
self.conditions.push(FilterCondition::Lte(field.into(), value));
|
||||
self.conditions
|
||||
.push(FilterCondition::Lte(field.into(), value));
|
||||
self
|
||||
}
|
||||
|
||||
/// In array.
|
||||
pub fn in_array(mut self, field: impl Into<String>, values: Vec<JsonValue>) -> Self {
|
||||
self.conditions.push(FilterCondition::In(field.into(), values));
|
||||
self.conditions
|
||||
.push(FilterCondition::In(field.into(), values));
|
||||
self
|
||||
}
|
||||
|
||||
/// String contains.
|
||||
pub fn contains(mut self, field: impl Into<String>, substring: impl Into<String>) -> Self {
|
||||
self.conditions.push(FilterCondition::Contains(field.into(), substring.into()));
|
||||
self.conditions
|
||||
.push(FilterCondition::Contains(field.into(), substring.into()));
|
||||
self
|
||||
}
|
||||
|
||||
/// Field exists.
|
||||
pub fn exists(mut self, field: impl Into<String>, exists: bool) -> Self {
|
||||
self.conditions.push(FilterCondition::Exists(field.into(), exists));
|
||||
self.conditions
|
||||
.push(FilterCondition::Exists(field.into(), exists));
|
||||
self
|
||||
}
|
||||
|
||||
|
|
@ -396,7 +413,9 @@ impl DocumentFilter {
|
|||
return true;
|
||||
}
|
||||
|
||||
self.conditions.iter().all(|cond| self.eval_condition(cond, doc))
|
||||
self.conditions
|
||||
.iter()
|
||||
.all(|cond| self.eval_condition(cond, doc))
|
||||
}
|
||||
|
||||
fn eval_condition(&self, cond: &FilterCondition, doc: &Document) -> bool {
|
||||
|
|
@ -419,27 +438,21 @@ impl DocumentFilter {
|
|||
FilterCondition::Lte(field, value) => {
|
||||
self.compare_values(doc.get_nested(field), value, |a, b| a <= b)
|
||||
}
|
||||
FilterCondition::In(field, values) => {
|
||||
doc.get_nested(field)
|
||||
.map(|v| values.contains(v))
|
||||
.unwrap_or(false)
|
||||
}
|
||||
FilterCondition::Contains(field, substring) => {
|
||||
doc.get_nested(field)
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.contains(substring))
|
||||
.unwrap_or(false)
|
||||
}
|
||||
FilterCondition::In(field, values) => doc
|
||||
.get_nested(field)
|
||||
.map(|v| values.contains(v))
|
||||
.unwrap_or(false),
|
||||
FilterCondition::Contains(field, substring) => doc
|
||||
.get_nested(field)
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.contains(substring))
|
||||
.unwrap_or(false),
|
||||
FilterCondition::Exists(field, should_exist) => {
|
||||
let exists = doc.get_nested(field).is_some();
|
||||
exists == *should_exist
|
||||
}
|
||||
FilterCondition::And(filters) => {
|
||||
filters.iter().all(|f| f.matches(doc))
|
||||
}
|
||||
FilterCondition::Or(filters) => {
|
||||
filters.iter().any(|f| f.matches(doc))
|
||||
}
|
||||
FilterCondition::And(filters) => filters.iter().all(|f| f.matches(doc)),
|
||||
FilterCondition::Or(filters) => filters.iter().any(|f| f.matches(doc)),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -448,12 +461,10 @@ impl DocumentFilter {
|
|||
F: Fn(f64, f64) -> bool,
|
||||
{
|
||||
match (a, b) {
|
||||
(Some(JsonValue::Number(a)), JsonValue::Number(b)) => {
|
||||
match (a.as_f64(), b.as_f64()) {
|
||||
(Some(a), Some(b)) => cmp(a, b),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
(Some(JsonValue::Number(a)), JsonValue::Number(b)) => match (a.as_f64(), b.as_f64()) {
|
||||
(Some(a), Some(b)) => cmp(a, b),
|
||||
_ => false,
|
||||
},
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
|
@ -512,7 +523,11 @@ impl DocumentStore {
|
|||
}
|
||||
|
||||
/// Finds documents in a collection.
|
||||
pub fn find(&self, collection: &str, filter: &DocumentFilter) -> Result<Vec<Document>, DatabaseError> {
|
||||
pub fn find(
|
||||
&self,
|
||||
collection: &str,
|
||||
filter: &DocumentFilter,
|
||||
) -> Result<Vec<Document>, DatabaseError> {
|
||||
let collections = self.collections.read();
|
||||
let coll = collections
|
||||
.get(collection)
|
||||
|
|
@ -521,7 +536,11 @@ impl DocumentStore {
|
|||
}
|
||||
|
||||
/// Finds one document.
|
||||
pub fn find_one(&self, collection: &str, filter: &DocumentFilter) -> Result<Option<Document>, DatabaseError> {
|
||||
pub fn find_one(
|
||||
&self,
|
||||
collection: &str,
|
||||
filter: &DocumentFilter,
|
||||
) -> Result<Option<Document>, DatabaseError> {
|
||||
let collections = self.collections.read();
|
||||
let coll = collections
|
||||
.get(collection)
|
||||
|
|
@ -530,7 +549,11 @@ impl DocumentStore {
|
|||
}
|
||||
|
||||
/// Finds a document by ID.
|
||||
pub fn find_by_id(&self, collection: &str, id: &DocumentId) -> Result<Option<Document>, DatabaseError> {
|
||||
pub fn find_by_id(
|
||||
&self,
|
||||
collection: &str,
|
||||
id: &DocumentId,
|
||||
) -> Result<Option<Document>, DatabaseError> {
|
||||
let collections = self.collections.read();
|
||||
let coll = collections
|
||||
.get(collection)
|
||||
|
|
@ -539,7 +562,12 @@ impl DocumentStore {
|
|||
}
|
||||
|
||||
/// Updates a document by ID.
|
||||
pub fn update_by_id(&self, collection: &str, id: &DocumentId, update: JsonValue) -> Result<bool, DatabaseError> {
|
||||
pub fn update_by_id(
|
||||
&self,
|
||||
collection: &str,
|
||||
id: &DocumentId,
|
||||
update: JsonValue,
|
||||
) -> Result<bool, DatabaseError> {
|
||||
let collections = self.collections.read();
|
||||
let coll = collections
|
||||
.get(collection)
|
||||
|
|
@ -584,7 +612,8 @@ mod tests {
|
|||
fn test_collection_insert_find() {
|
||||
let coll = Collection::new("users");
|
||||
|
||||
coll.insert_one(json!({"name": "Alice", "age": 30})).unwrap();
|
||||
coll.insert_one(json!({"name": "Alice", "age": 30}))
|
||||
.unwrap();
|
||||
coll.insert_one(json!({"name": "Bob", "age": 25})).unwrap();
|
||||
|
||||
let filter = DocumentFilter::new().eq("name", json!("Alice"));
|
||||
|
|
@ -597,9 +626,11 @@ mod tests {
|
|||
fn test_filter_comparison() {
|
||||
let coll = Collection::new("users");
|
||||
|
||||
coll.insert_one(json!({"name": "Alice", "age": 30})).unwrap();
|
||||
coll.insert_one(json!({"name": "Alice", "age": 30}))
|
||||
.unwrap();
|
||||
coll.insert_one(json!({"name": "Bob", "age": 25})).unwrap();
|
||||
coll.insert_one(json!({"name": "Charlie", "age": 35})).unwrap();
|
||||
coll.insert_one(json!({"name": "Charlie", "age": 35}))
|
||||
.unwrap();
|
||||
|
||||
let filter = DocumentFilter::new().gte("age", json!(30));
|
||||
let results = coll.find(&filter);
|
||||
|
|
@ -622,7 +653,9 @@ mod tests {
|
|||
#[test]
|
||||
fn test_update_document() {
|
||||
let coll = Collection::new("users");
|
||||
let id = coll.insert_one(json!({"name": "Alice", "age": 30})).unwrap();
|
||||
let id = coll
|
||||
.insert_one(json!({"name": "Alice", "age": 30}))
|
||||
.unwrap();
|
||||
|
||||
coll.update_by_id(&id, json!({"age": 31})).unwrap();
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
//! Authentication and authorization for Database Gateway.
|
||||
|
||||
use parking_lot::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use parking_lot::RwLock;
|
||||
|
||||
/// API key for authentication.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
|
|
|
|||
|
|
@ -272,19 +272,13 @@ pub fn json_to_filter(json: &JsonValue) -> Option<Filter> {
|
|||
|
||||
// Handle $and
|
||||
if let Some(and_arr) = obj.get("$and").and_then(|v| v.as_array()) {
|
||||
let filters: Vec<Filter> = and_arr
|
||||
.iter()
|
||||
.filter_map(json_to_filter)
|
||||
.collect();
|
||||
let filters: Vec<Filter> = and_arr.iter().filter_map(json_to_filter).collect();
|
||||
return Some(Filter::And(filters));
|
||||
}
|
||||
|
||||
// Handle $or
|
||||
if let Some(or_arr) = obj.get("$or").and_then(|v| v.as_array()) {
|
||||
let filters: Vec<Filter> = or_arr
|
||||
.iter()
|
||||
.filter_map(json_to_filter)
|
||||
.collect();
|
||||
let filters: Vec<Filter> = or_arr.iter().filter_map(json_to_filter).collect();
|
||||
return Some(Filter::Or(filters));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -137,7 +137,12 @@ async fn kv_get(
|
|||
// For demo, use a default database
|
||||
let db = match get_default_database(&state) {
|
||||
Some(db) => db,
|
||||
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<KvGetResponse>::error("No database"))),
|
||||
None => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(ApiResponse::<KvGetResponse>::error("No database")),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
state.record_read();
|
||||
|
|
@ -153,7 +158,12 @@ async fn kv_set(
|
|||
) -> impl IntoResponse {
|
||||
let db = match get_default_database(&state) {
|
||||
Some(db) => db,
|
||||
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<KvGetResponse>::error("No database"))),
|
||||
None => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(ApiResponse::<KvGetResponse>::error("No database")),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
state.record_write(req.value.len() as u64);
|
||||
|
|
@ -168,7 +178,12 @@ async fn kv_delete(
|
|||
) -> impl IntoResponse {
|
||||
let db = match get_default_database(&state) {
|
||||
Some(db) => db,
|
||||
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<bool>::error("No database"))),
|
||||
None => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(ApiResponse::<bool>::error("No database")),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let response = handle_kv_delete(db.kv(), &key);
|
||||
|
|
@ -182,7 +197,12 @@ async fn kv_batch(
|
|||
) -> impl IntoResponse {
|
||||
let db = match get_default_database(&state) {
|
||||
Some(db) => db,
|
||||
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<KvBatchResponse>::error("No database"))),
|
||||
None => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(ApiResponse::<KvBatchResponse>::error("No database")),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let response = handle_kv_batch(db.kv(), req);
|
||||
|
|
@ -217,7 +237,9 @@ async fn list_databases(
|
|||
})
|
||||
.collect();
|
||||
|
||||
Json(ApiResponse::ok(ListDatabasesResponse { databases: response }))
|
||||
Json(ApiResponse::ok(ListDatabasesResponse {
|
||||
databases: response,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn create_database(
|
||||
|
|
@ -250,7 +272,10 @@ async fn create_database(
|
|||
})),
|
||||
)
|
||||
}
|
||||
Err(e) => (StatusCode::BAD_REQUEST, Json(ApiResponse::error(e.to_string()))),
|
||||
Err(e) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(ApiResponse::error(e.to_string())),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -302,12 +327,20 @@ async fn create_collection(
|
|||
) -> impl IntoResponse {
|
||||
let db = match get_database(&state, &db_name) {
|
||||
Some(db) => db,
|
||||
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<bool>::error("Database not found"))),
|
||||
None => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(ApiResponse::<bool>::error("Database not found")),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
match db.documents().create_collection(&req.name) {
|
||||
Ok(_) => (StatusCode::CREATED, Json(ApiResponse::ok(true))),
|
||||
Err(e) => (StatusCode::BAD_REQUEST, Json(ApiResponse::<bool>::error(e.to_string()))),
|
||||
Err(e) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(ApiResponse::<bool>::error(e.to_string())),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -349,15 +382,20 @@ async fn query_documents(
|
|||
) -> impl IntoResponse {
|
||||
let db = match get_database(&state, &db_name) {
|
||||
Some(db) => db,
|
||||
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<QueryDocumentsResponse>::error("Database not found"))),
|
||||
None => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(ApiResponse::<QueryDocumentsResponse>::error(
|
||||
"Database not found",
|
||||
)),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
state.record_read();
|
||||
|
||||
// Build query
|
||||
let mut query = Query::new(&coll_name)
|
||||
.skip(req.skip)
|
||||
.limit(req.limit);
|
||||
let mut query = Query::new(&coll_name).skip(req.skip).limit(req.limit);
|
||||
|
||||
// Add filter
|
||||
if let Some(filter_json) = &req.filter {
|
||||
|
|
@ -384,8 +422,14 @@ async fn query_documents(
|
|||
}
|
||||
|
||||
match db.query().execute(&query) {
|
||||
Ok(result) => (StatusCode::OK, Json(ApiResponse::ok(QueryDocumentsResponse::from(result)))),
|
||||
Err(e) => (StatusCode::BAD_REQUEST, Json(ApiResponse::error(e.to_string()))),
|
||||
Ok(result) => (
|
||||
StatusCode::OK,
|
||||
Json(ApiResponse::ok(QueryDocumentsResponse::from(result))),
|
||||
),
|
||||
Err(e) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(ApiResponse::error(e.to_string())),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -396,15 +440,25 @@ async fn insert_document(
|
|||
) -> impl IntoResponse {
|
||||
let db = match get_database(&state, &db_name) {
|
||||
Some(db) => db,
|
||||
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<String>::error("Database not found"))),
|
||||
None => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(ApiResponse::<String>::error("Database not found")),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let size = serde_json::to_vec(&req.document).map(|v| v.len()).unwrap_or(0);
|
||||
let size = serde_json::to_vec(&req.document)
|
||||
.map(|v| v.len())
|
||||
.unwrap_or(0);
|
||||
state.record_write(size as u64);
|
||||
|
||||
match db.documents().insert(&coll_name, req.document) {
|
||||
Ok(id) => (StatusCode::CREATED, Json(ApiResponse::ok(id.to_hex()))),
|
||||
Err(e) => (StatusCode::BAD_REQUEST, Json(ApiResponse::error(e.to_string()))),
|
||||
Err(e) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(ApiResponse::error(e.to_string())),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -415,7 +469,12 @@ async fn insert_many_documents(
|
|||
) -> impl IntoResponse {
|
||||
let db = match get_database(&state, &db_name) {
|
||||
Some(db) => db,
|
||||
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<Vec<String>>::error("Database not found"))),
|
||||
None => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(ApiResponse::<Vec<String>>::error("Database not found")),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let mut ids = Vec::with_capacity(req.documents.len());
|
||||
|
|
@ -425,7 +484,12 @@ async fn insert_many_documents(
|
|||
total_size += serde_json::to_vec(&doc).map(|v| v.len()).unwrap_or(0) as u64;
|
||||
match db.documents().insert(&coll_name, doc) {
|
||||
Ok(id) => ids.push(id.to_hex()),
|
||||
Err(e) => return (StatusCode::BAD_REQUEST, Json(ApiResponse::error(e.to_string()))),
|
||||
Err(e) => {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(ApiResponse::error(e.to_string())),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -477,7 +541,9 @@ async fn update_document(
|
|||
Err(e) => return Json(ApiResponse::error(e.to_string())),
|
||||
};
|
||||
|
||||
let update_size = serde_json::to_vec(&req.update).map(|v| v.len()).unwrap_or(0);
|
||||
let update_size = serde_json::to_vec(&req.update)
|
||||
.map(|v| v.len())
|
||||
.unwrap_or(0);
|
||||
state.record_write(update_size as u64);
|
||||
|
||||
match db.documents().update_by_id(&coll_name, &id, req.update) {
|
||||
|
|
@ -519,7 +585,12 @@ async fn insert_embeddings(
|
|||
) -> impl IntoResponse {
|
||||
let db = match get_database(&state, &db_name) {
|
||||
Some(db) => db,
|
||||
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<usize>::error("Database not found"))),
|
||||
None => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(ApiResponse::<usize>::error("Database not found")),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let mut count = 0;
|
||||
|
|
@ -533,7 +604,10 @@ async fn insert_embeddings(
|
|||
}
|
||||
|
||||
if let Err(e) = db.vectors().insert(embedding) {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error(e.to_string())));
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(ApiResponse::error(e.to_string())),
|
||||
);
|
||||
}
|
||||
count += 1;
|
||||
}
|
||||
|
|
@ -549,7 +623,14 @@ async fn vector_search(
|
|||
) -> impl IntoResponse {
|
||||
let db = match get_database(&state, &db_name) {
|
||||
Some(db) => db,
|
||||
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<VectorSearchResponse>::error("Database not found"))),
|
||||
None => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(ApiResponse::<VectorSearchResponse>::error(
|
||||
"Database not found",
|
||||
)),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
state.record_vector_search();
|
||||
|
|
@ -563,9 +644,18 @@ async fn vector_search(
|
|||
Ok(results) => {
|
||||
let count = results.len();
|
||||
let matches: Vec<VectorMatch> = results.into_iter().map(Into::into).collect();
|
||||
(StatusCode::OK, Json(ApiResponse::ok(VectorSearchResponse { results: matches, count })))
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(ApiResponse::ok(VectorSearchResponse {
|
||||
results: matches,
|
||||
count,
|
||||
})),
|
||||
)
|
||||
}
|
||||
Err(e) => (StatusCode::BAD_REQUEST, Json(ApiResponse::error(e.to_string()))),
|
||||
Err(e) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(ApiResponse::error(e.to_string())),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -114,10 +114,7 @@ impl GatewayServer {
|
|||
tracing::error!("Failed to create default database: {}", e);
|
||||
}
|
||||
|
||||
let state = Arc::new(AppState::new(
|
||||
self.db_manager.clone(),
|
||||
self.auth.clone(),
|
||||
));
|
||||
let state = Arc::new(AppState::new(self.db_manager.clone(), self.auth.clone()));
|
||||
|
||||
let app = create_router(state);
|
||||
|
||||
|
|
|
|||
|
|
@ -86,7 +86,12 @@ pub struct Edge {
|
|||
|
||||
impl Edge {
|
||||
/// Creates a new directed edge.
|
||||
pub fn new(source: NodeId, target: NodeId, edge_type: impl Into<String>, properties: JsonValue) -> Self {
|
||||
pub fn new(
|
||||
source: NodeId,
|
||||
target: NodeId,
|
||||
edge_type: impl Into<String>,
|
||||
properties: JsonValue,
|
||||
) -> Self {
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
|
|
@ -105,7 +110,12 @@ impl Edge {
|
|||
}
|
||||
|
||||
/// Creates an undirected edge.
|
||||
pub fn undirected(source: NodeId, target: NodeId, edge_type: impl Into<String>, properties: JsonValue) -> Self {
|
||||
pub fn undirected(
|
||||
source: NodeId,
|
||||
target: NodeId,
|
||||
edge_type: impl Into<String>,
|
||||
properties: JsonValue,
|
||||
) -> Self {
|
||||
let mut edge = Self::new(source, target, edge_type, properties);
|
||||
edge.directed = false;
|
||||
edge
|
||||
|
|
@ -138,8 +148,8 @@ impl Edge {
|
|||
|
||||
/// Checks if this edge connects two specific nodes.
|
||||
pub fn connects_pair(&self, a: &NodeId, b: &NodeId) -> bool {
|
||||
(&self.source == a && &self.target == b) ||
|
||||
(!self.directed && &self.source == b && &self.target == a)
|
||||
(&self.source == a && &self.target == b)
|
||||
|| (!self.directed && &self.source == b && &self.target == a)
|
||||
}
|
||||
|
||||
/// Gets a property value.
|
||||
|
|
@ -156,7 +166,9 @@ impl Edge {
|
|||
|
||||
/// Checks if the edge matches a property filter.
|
||||
pub fn matches_properties(&self, filter: &JsonValue) -> bool {
|
||||
if let (Some(filter_obj), Some(props_obj)) = (filter.as_object(), self.properties.as_object()) {
|
||||
if let (Some(filter_obj), Some(props_obj)) =
|
||||
(filter.as_object(), self.properties.as_object())
|
||||
{
|
||||
for (key, expected) in filter_obj {
|
||||
if let Some(actual) = props_obj.get(key) {
|
||||
if actual != expected {
|
||||
|
|
@ -216,7 +228,12 @@ impl EdgeBuilder {
|
|||
|
||||
/// Builds the edge.
|
||||
pub fn build(self) -> Edge {
|
||||
let mut edge = Edge::new(self.source, self.target, self.edge_type, JsonValue::Object(self.properties));
|
||||
let mut edge = Edge::new(
|
||||
self.source,
|
||||
self.target,
|
||||
self.edge_type,
|
||||
JsonValue::Object(self.properties),
|
||||
);
|
||||
edge.directed = self.directed;
|
||||
edge.weight = self.weight;
|
||||
edge
|
||||
|
|
@ -264,7 +281,10 @@ mod tests {
|
|||
|
||||
assert!(!edge.directed);
|
||||
assert_eq!(edge.weight, 2.5);
|
||||
assert_eq!(edge.get_property("percentage"), Some(&serde_json::json!(50)));
|
||||
assert_eq!(
|
||||
edge.get_property("percentage"),
|
||||
Some(&serde_json::json!(50))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -172,7 +172,9 @@ impl Node {
|
|||
|
||||
/// Checks if the node matches a property filter.
|
||||
pub fn matches_properties(&self, filter: &JsonValue) -> bool {
|
||||
if let (Some(filter_obj), Some(props_obj)) = (filter.as_object(), self.properties.as_object()) {
|
||||
if let (Some(filter_obj), Some(props_obj)) =
|
||||
(filter.as_object(), self.properties.as_object())
|
||||
{
|
||||
for (key, expected) in filter_obj {
|
||||
if let Some(actual) = props_obj.get(key) {
|
||||
if actual != expected {
|
||||
|
|
@ -258,10 +260,7 @@ mod tests {
|
|||
|
||||
assert!(node.has_label("User"));
|
||||
assert!(!node.has_label("Admin"));
|
||||
assert_eq!(
|
||||
node.get_property("name"),
|
||||
Some(&serde_json::json!("Alice"))
|
||||
);
|
||||
assert_eq!(node.get_property("name"), Some(&serde_json::json!("Alice")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -60,7 +60,10 @@ impl Eq for DijkstraState {}
|
|||
impl Ord for DijkstraState {
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
// Reverse ordering for min-heap
|
||||
other.distance.partial_cmp(&self.distance).unwrap_or(Ordering::Equal)
|
||||
other
|
||||
.distance
|
||||
.partial_cmp(&self.distance)
|
||||
.unwrap_or(Ordering::Equal)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -140,9 +143,16 @@ impl<'a> PathFinder<'a> {
|
|||
let mut visited = HashSet::new();
|
||||
|
||||
distances.insert(*from, 0.0);
|
||||
heap.push(DijkstraState { node: *from, distance: 0.0 });
|
||||
heap.push(DijkstraState {
|
||||
node: *from,
|
||||
distance: 0.0,
|
||||
});
|
||||
|
||||
while let Some(DijkstraState { node: current, distance: dist }) = heap.pop() {
|
||||
while let Some(DijkstraState {
|
||||
node: current,
|
||||
distance: dist,
|
||||
}) = heap.pop()
|
||||
{
|
||||
if ¤t == to {
|
||||
// Reconstruct path
|
||||
let mut path = vec![current];
|
||||
|
|
@ -181,12 +191,18 @@ impl<'a> PathFinder<'a> {
|
|||
}
|
||||
|
||||
let new_dist = dist + edge.weight;
|
||||
let is_shorter = distances.get(&neighbor).map(|&d| new_dist < d).unwrap_or(true);
|
||||
let is_shorter = distances
|
||||
.get(&neighbor)
|
||||
.map(|&d| new_dist < d)
|
||||
.unwrap_or(true);
|
||||
|
||||
if is_shorter {
|
||||
distances.insert(neighbor, new_dist);
|
||||
previous.insert(neighbor, (current, edge.clone()));
|
||||
heap.push(DijkstraState { node: neighbor, distance: new_dist });
|
||||
heap.push(DijkstraState {
|
||||
node: neighbor,
|
||||
distance: new_dist,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -251,7 +267,9 @@ impl<'a> PathFinder<'a> {
|
|||
path.push(neighbor);
|
||||
edges.push(edge.clone());
|
||||
|
||||
self.find_all_paths_dfs(&neighbor, target, max_length, path, edges, visited, results);
|
||||
self.find_all_paths_dfs(
|
||||
&neighbor, target, max_length, path, edges, visited, results,
|
||||
);
|
||||
|
||||
path.pop();
|
||||
edges.pop();
|
||||
|
|
@ -261,7 +279,12 @@ impl<'a> PathFinder<'a> {
|
|||
}
|
||||
|
||||
/// Finds the shortest path considering only specific edge types.
|
||||
pub fn shortest_path_by_type(&self, from: &NodeId, to: &NodeId, edge_types: &[String]) -> PathResult {
|
||||
pub fn shortest_path_by_type(
|
||||
&self,
|
||||
from: &NodeId,
|
||||
to: &NodeId,
|
||||
edge_types: &[String],
|
||||
) -> PathResult {
|
||||
if from == to {
|
||||
return PathResult::found(vec![*from], Vec::new(), 0.0);
|
||||
}
|
||||
|
|
@ -377,11 +400,21 @@ mod tests {
|
|||
let d = store.create_node(vec![], serde_json::json!({"name": "D"}));
|
||||
let e = store.create_node(vec![], serde_json::json!({"name": "E"}));
|
||||
|
||||
store.create_edge(a, b, "LINK", serde_json::json!({})).unwrap();
|
||||
store.create_edge(b, c, "LINK", serde_json::json!({})).unwrap();
|
||||
store.create_edge(c, d, "LINK", serde_json::json!({})).unwrap();
|
||||
store.create_edge(a, e, "LINK", serde_json::json!({})).unwrap();
|
||||
store.create_edge(e, d, "LINK", serde_json::json!({})).unwrap();
|
||||
store
|
||||
.create_edge(a, b, "LINK", serde_json::json!({}))
|
||||
.unwrap();
|
||||
store
|
||||
.create_edge(b, c, "LINK", serde_json::json!({}))
|
||||
.unwrap();
|
||||
store
|
||||
.create_edge(c, d, "LINK", serde_json::json!({}))
|
||||
.unwrap();
|
||||
store
|
||||
.create_edge(a, e, "LINK", serde_json::json!({}))
|
||||
.unwrap();
|
||||
store
|
||||
.create_edge(e, d, "LINK", serde_json::json!({}))
|
||||
.unwrap();
|
||||
|
||||
store
|
||||
}
|
||||
|
|
@ -392,8 +425,14 @@ mod tests {
|
|||
let finder = PathFinder::new(&store);
|
||||
|
||||
let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({}));
|
||||
let a = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("A"))).unwrap();
|
||||
let d = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("D"))).unwrap();
|
||||
let a = nodes
|
||||
.iter()
|
||||
.find(|n| n.get_property("name") == Some(&serde_json::json!("A")))
|
||||
.unwrap();
|
||||
let d = nodes
|
||||
.iter()
|
||||
.find(|n| n.get_property("name") == Some(&serde_json::json!("D")))
|
||||
.unwrap();
|
||||
|
||||
let result = finder.shortest_path_bfs(&a.id, &d.id);
|
||||
|
||||
|
|
@ -413,13 +452,19 @@ mod tests {
|
|||
// A --(3.0)--> C
|
||||
let mut edge1 = super::super::edge::Edge::new(a, b, "LINK", serde_json::json!({}));
|
||||
edge1.weight = 1.0;
|
||||
store.create_edge(a, b, "LINK", serde_json::json!({})).unwrap();
|
||||
store
|
||||
.create_edge(a, b, "LINK", serde_json::json!({}))
|
||||
.unwrap();
|
||||
|
||||
let mut edge2 = super::super::edge::Edge::new(b, c, "LINK", serde_json::json!({}));
|
||||
edge2.weight = 1.0;
|
||||
store.create_edge(b, c, "LINK", serde_json::json!({})).unwrap();
|
||||
store
|
||||
.create_edge(b, c, "LINK", serde_json::json!({}))
|
||||
.unwrap();
|
||||
|
||||
store.create_edge(a, c, "DIRECT", serde_json::json!({})).unwrap();
|
||||
store
|
||||
.create_edge(a, c, "DIRECT", serde_json::json!({}))
|
||||
.unwrap();
|
||||
|
||||
let finder = PathFinder::new(&store);
|
||||
let result = finder.shortest_path_dijkstra(&a, &c);
|
||||
|
|
@ -449,8 +494,14 @@ mod tests {
|
|||
let finder = PathFinder::new(&store);
|
||||
|
||||
let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({}));
|
||||
let a = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("A"))).unwrap();
|
||||
let d = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("D"))).unwrap();
|
||||
let a = nodes
|
||||
.iter()
|
||||
.find(|n| n.get_property("name") == Some(&serde_json::json!("A")))
|
||||
.unwrap();
|
||||
let d = nodes
|
||||
.iter()
|
||||
.find(|n| n.get_property("name") == Some(&serde_json::json!("D")))
|
||||
.unwrap();
|
||||
|
||||
let paths = finder.all_paths(&a.id, &d.id, 5);
|
||||
|
||||
|
|
@ -463,8 +514,14 @@ mod tests {
|
|||
let finder = PathFinder::new(&store);
|
||||
|
||||
let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({}));
|
||||
let a = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("A"))).unwrap();
|
||||
let d = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("D"))).unwrap();
|
||||
let a = nodes
|
||||
.iter()
|
||||
.find(|n| n.get_property("name") == Some(&serde_json::json!("A")))
|
||||
.unwrap();
|
||||
let d = nodes
|
||||
.iter()
|
||||
.find(|n| n.get_property("name") == Some(&serde_json::json!("D")))
|
||||
.unwrap();
|
||||
|
||||
assert!(finder.path_exists(&a.id, &d.id));
|
||||
}
|
||||
|
|
@ -475,9 +532,18 @@ mod tests {
|
|||
let finder = PathFinder::new(&store);
|
||||
|
||||
let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({}));
|
||||
let a = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("A"))).unwrap();
|
||||
let d = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("D"))).unwrap();
|
||||
let b = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("B"))).unwrap();
|
||||
let a = nodes
|
||||
.iter()
|
||||
.find(|n| n.get_property("name") == Some(&serde_json::json!("A")))
|
||||
.unwrap();
|
||||
let d = nodes
|
||||
.iter()
|
||||
.find(|n| n.get_property("name") == Some(&serde_json::json!("D")))
|
||||
.unwrap();
|
||||
let b = nodes
|
||||
.iter()
|
||||
.find(|n| n.get_property("name") == Some(&serde_json::json!("B")))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(finder.distance(&a.id, &b.id), Some(1));
|
||||
assert_eq!(finder.distance(&a.id, &d.id), Some(2)); // A -> E -> D
|
||||
|
|
|
|||
|
|
@ -23,7 +23,10 @@ pub enum GraphQuery {
|
|||
/// DELETE query for removing nodes/edges.
|
||||
Delete { variable: String, detach: bool },
|
||||
/// SET query for updating properties.
|
||||
Set { variable: String, properties: JsonValue },
|
||||
Set {
|
||||
variable: String,
|
||||
properties: JsonValue,
|
||||
},
|
||||
}
|
||||
|
||||
/// Pattern to match in the graph.
|
||||
|
|
@ -78,15 +81,35 @@ pub enum RelationshipDirection {
|
|||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum WhereClause {
|
||||
/// Property comparison.
|
||||
PropertyEquals { variable: String, property: String, value: JsonValue },
|
||||
PropertyEquals {
|
||||
variable: String,
|
||||
property: String,
|
||||
value: JsonValue,
|
||||
},
|
||||
/// Property comparison (not equals).
|
||||
PropertyNotEquals { variable: String, property: String, value: JsonValue },
|
||||
PropertyNotEquals {
|
||||
variable: String,
|
||||
property: String,
|
||||
value: JsonValue,
|
||||
},
|
||||
/// Property greater than.
|
||||
PropertyGt { variable: String, property: String, value: JsonValue },
|
||||
PropertyGt {
|
||||
variable: String,
|
||||
property: String,
|
||||
value: JsonValue,
|
||||
},
|
||||
/// Property less than.
|
||||
PropertyLt { variable: String, property: String, value: JsonValue },
|
||||
PropertyLt {
|
||||
variable: String,
|
||||
property: String,
|
||||
value: JsonValue,
|
||||
},
|
||||
/// Property contains (for text).
|
||||
PropertyContains { variable: String, property: String, value: String },
|
||||
PropertyContains {
|
||||
variable: String,
|
||||
property: String,
|
||||
value: String,
|
||||
},
|
||||
/// AND condition.
|
||||
And(Box<WhereClause>, Box<WhereClause>),
|
||||
/// OR condition.
|
||||
|
|
@ -105,7 +128,10 @@ pub enum ReturnItem {
|
|||
/// Return a property of a variable.
|
||||
Property { variable: String, property: String },
|
||||
/// Return with an alias.
|
||||
Alias { item: Box<ReturnItem>, alias: String },
|
||||
Alias {
|
||||
item: Box<ReturnItem>,
|
||||
alias: String,
|
||||
},
|
||||
/// Count aggregation.
|
||||
Count(Option<String>),
|
||||
}
|
||||
|
|
@ -114,7 +140,11 @@ pub enum ReturnItem {
|
|||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum CreateElement {
|
||||
/// Create a node.
|
||||
Node { variable: Option<String>, labels: Vec<String>, properties: JsonValue },
|
||||
Node {
|
||||
variable: Option<String>,
|
||||
labels: Vec<String>,
|
||||
properties: JsonValue,
|
||||
},
|
||||
/// Create a relationship.
|
||||
Relationship {
|
||||
from_var: String,
|
||||
|
|
@ -176,7 +206,10 @@ impl GraphQueryParser {
|
|||
} else if upper.starts_with("SET") {
|
||||
Self::parse_set(query)
|
||||
} else {
|
||||
Err(GraphError::InvalidOperation(format!("Unknown query type: {}", query)))
|
||||
Err(GraphError::InvalidOperation(format!(
|
||||
"Unknown query type: {}",
|
||||
query
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -185,7 +218,10 @@ impl GraphQueryParser {
|
|||
let upper = query.to_uppercase();
|
||||
|
||||
// Find MATCH, WHERE, RETURN, LIMIT positions
|
||||
let match_end = upper.find("WHERE").or_else(|| upper.find("RETURN")).unwrap_or(query.len());
|
||||
let match_end = upper
|
||||
.find("WHERE")
|
||||
.or_else(|| upper.find("RETURN"))
|
||||
.unwrap_or(query.len());
|
||||
let where_start = upper.find("WHERE");
|
||||
let return_start = upper.find("RETURN");
|
||||
let limit_start = upper.find("LIMIT");
|
||||
|
|
@ -253,7 +289,9 @@ impl GraphQueryParser {
|
|||
}
|
||||
|
||||
if nodes.is_empty() {
|
||||
return Err(GraphError::InvalidOperation("No node pattern found".to_string()));
|
||||
return Err(GraphError::InvalidOperation(
|
||||
"No node pattern found".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Combine nodes with relationships
|
||||
|
|
@ -264,10 +302,15 @@ impl GraphQueryParser {
|
|||
}
|
||||
}
|
||||
|
||||
Ok(MatchPattern { start, relationships })
|
||||
Ok(MatchPattern {
|
||||
start,
|
||||
relationships,
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_node_pattern(chars: &mut std::iter::Peekable<std::str::Chars>) -> Result<NodePattern, GraphError> {
|
||||
fn parse_node_pattern(
|
||||
chars: &mut std::iter::Peekable<std::str::Chars>,
|
||||
) -> Result<NodePattern, GraphError> {
|
||||
// Consume '('
|
||||
chars.next();
|
||||
|
||||
|
|
@ -335,10 +378,16 @@ impl GraphQueryParser {
|
|||
}
|
||||
}
|
||||
|
||||
Ok(NodePattern { variable, labels, properties })
|
||||
Ok(NodePattern {
|
||||
variable,
|
||||
labels,
|
||||
properties,
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_relationship_pattern(chars: &mut std::iter::Peekable<std::str::Chars>) -> Result<RelationshipPattern, GraphError> {
|
||||
fn parse_relationship_pattern(
|
||||
chars: &mut std::iter::Peekable<std::str::Chars>,
|
||||
) -> Result<RelationshipPattern, GraphError> {
|
||||
let mut direction = RelationshipDirection::Undirected;
|
||||
let mut edge_type = None;
|
||||
let mut variable = None;
|
||||
|
|
@ -408,7 +457,11 @@ impl GraphQueryParser {
|
|||
variable,
|
||||
edge_type,
|
||||
direction,
|
||||
target: NodePattern { variable: None, labels: Vec::new(), properties: None },
|
||||
target: NodePattern {
|
||||
variable: None,
|
||||
labels: Vec::new(),
|
||||
properties: None,
|
||||
},
|
||||
min_hops,
|
||||
max_hops,
|
||||
})
|
||||
|
|
@ -476,7 +529,9 @@ impl GraphQueryParser {
|
|||
elements.push(CreateElement::Node {
|
||||
variable: node.variable,
|
||||
labels: node.labels,
|
||||
properties: node.properties.unwrap_or(JsonValue::Object(serde_json::Map::new())),
|
||||
properties: node
|
||||
.properties
|
||||
.unwrap_or(JsonValue::Object(serde_json::Map::new())),
|
||||
});
|
||||
} else {
|
||||
break;
|
||||
|
|
@ -488,7 +543,11 @@ impl GraphQueryParser {
|
|||
|
||||
fn parse_delete(query: &str) -> Result<GraphQuery, GraphError> {
|
||||
let detach = query.to_uppercase().starts_with("DETACH");
|
||||
let start = if detach { "DETACH DELETE".len() } else { "DELETE".len() };
|
||||
let start = if detach {
|
||||
"DETACH DELETE".len()
|
||||
} else {
|
||||
"DELETE".len()
|
||||
};
|
||||
let variable = query[start..].trim().to_string();
|
||||
|
||||
Ok(GraphQuery::Delete { variable, detach })
|
||||
|
|
@ -500,19 +559,24 @@ impl GraphQueryParser {
|
|||
let parts: Vec<_> = content.split('=').collect();
|
||||
|
||||
if parts.len() != 2 {
|
||||
return Err(GraphError::InvalidOperation("Invalid SET syntax".to_string()));
|
||||
return Err(GraphError::InvalidOperation(
|
||||
"Invalid SET syntax".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let var_prop: Vec<_> = parts[0].trim().split('.').collect();
|
||||
if var_prop.len() != 2 {
|
||||
return Err(GraphError::InvalidOperation("Invalid SET variable".to_string()));
|
||||
return Err(GraphError::InvalidOperation(
|
||||
"Invalid SET variable".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let variable = var_prop[0].to_string();
|
||||
let property = var_prop[1].to_string();
|
||||
let value_str = parts[1].trim();
|
||||
|
||||
let value: JsonValue = serde_json::from_str(value_str).unwrap_or(JsonValue::String(value_str.to_string()));
|
||||
let value: JsonValue =
|
||||
serde_json::from_str(value_str).unwrap_or(JsonValue::String(value_str.to_string()));
|
||||
|
||||
Ok(GraphQuery::Set {
|
||||
variable,
|
||||
|
|
@ -535,18 +599,21 @@ impl<'a> GraphQueryExecutor<'a> {
|
|||
/// Executes a graph query.
|
||||
pub fn execute(&self, query: &GraphQuery) -> Result<QueryResult, GraphError> {
|
||||
match query {
|
||||
GraphQuery::Match { pattern, where_clause, return_items, limit } => {
|
||||
self.execute_match(pattern, where_clause.as_ref(), return_items, *limit)
|
||||
}
|
||||
GraphQuery::Create { .. } => {
|
||||
Err(GraphError::InvalidOperation("CREATE requires mutable access".to_string()))
|
||||
}
|
||||
GraphQuery::Delete { .. } => {
|
||||
Err(GraphError::InvalidOperation("DELETE requires mutable access".to_string()))
|
||||
}
|
||||
GraphQuery::Set { .. } => {
|
||||
Err(GraphError::InvalidOperation("SET requires mutable access".to_string()))
|
||||
}
|
||||
GraphQuery::Match {
|
||||
pattern,
|
||||
where_clause,
|
||||
return_items,
|
||||
limit,
|
||||
} => self.execute_match(pattern, where_clause.as_ref(), return_items, *limit),
|
||||
GraphQuery::Create { .. } => Err(GraphError::InvalidOperation(
|
||||
"CREATE requires mutable access".to_string(),
|
||||
)),
|
||||
GraphQuery::Delete { .. } => Err(GraphError::InvalidOperation(
|
||||
"DELETE requires mutable access".to_string(),
|
||||
)),
|
||||
GraphQuery::Set { .. } => Err(GraphError::InvalidOperation(
|
||||
"SET requires mutable access".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -586,7 +653,11 @@ impl<'a> GraphQueryExecutor<'a> {
|
|||
.depth(rel_pattern.max_hops)
|
||||
.direction(direction)
|
||||
.edge_types(
|
||||
rel_pattern.edge_type.clone().map(|t| vec![t]).unwrap_or_default(),
|
||||
rel_pattern
|
||||
.edge_type
|
||||
.clone()
|
||||
.map(|t| vec![t])
|
||||
.unwrap_or_default(),
|
||||
)
|
||||
.labels(rel_pattern.target.labels.clone());
|
||||
|
||||
|
|
@ -635,7 +706,10 @@ impl<'a> GraphQueryExecutor<'a> {
|
|||
|
||||
fn find_matching_nodes(&self, pattern: &NodePattern) -> Vec<Node> {
|
||||
let label = pattern.labels.first().map(|s| s.as_str());
|
||||
let filter = pattern.properties.clone().unwrap_or(JsonValue::Object(serde_json::Map::new()));
|
||||
let filter = pattern
|
||||
.properties
|
||||
.clone()
|
||||
.unwrap_or(JsonValue::Object(serde_json::Map::new()));
|
||||
self.store.find_nodes(label, &filter)
|
||||
}
|
||||
|
||||
|
|
@ -657,7 +731,11 @@ impl<'a> GraphQueryExecutor<'a> {
|
|||
})
|
||||
}
|
||||
|
||||
fn get_column_names(&self, return_items: &[ReturnItem], bindings: &[HashMap<String, JsonValue>]) -> Vec<String> {
|
||||
fn get_column_names(
|
||||
&self,
|
||||
return_items: &[ReturnItem],
|
||||
bindings: &[HashMap<String, JsonValue>],
|
||||
) -> Vec<String> {
|
||||
let mut columns = Vec::new();
|
||||
|
||||
for item in return_items {
|
||||
|
|
@ -673,7 +751,10 @@ impl<'a> GraphQueryExecutor<'a> {
|
|||
}
|
||||
ReturnItem::Alias { alias, .. } => columns.push(alias.clone()),
|
||||
ReturnItem::Count(var) => {
|
||||
columns.push(format!("count({})", var.as_ref().map(|s| s.as_str()).unwrap_or("*")));
|
||||
columns.push(format!(
|
||||
"count({})",
|
||||
var.as_ref().map(|s| s.as_str()).unwrap_or("*")
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -681,11 +762,18 @@ impl<'a> GraphQueryExecutor<'a> {
|
|||
columns
|
||||
}
|
||||
|
||||
fn extract_rows(&self, return_items: &[ReturnItem], bindings: &[HashMap<String, JsonValue>]) -> Vec<Vec<JsonValue>> {
|
||||
fn extract_rows(
|
||||
&self,
|
||||
return_items: &[ReturnItem],
|
||||
bindings: &[HashMap<String, JsonValue>],
|
||||
) -> Vec<Vec<JsonValue>> {
|
||||
let mut rows = Vec::new();
|
||||
|
||||
// Handle COUNT specially
|
||||
if return_items.iter().any(|i| matches!(i, ReturnItem::Count(_))) {
|
||||
if return_items
|
||||
.iter()
|
||||
.any(|i| matches!(i, ReturnItem::Count(_)))
|
||||
{
|
||||
rows.push(vec![JsonValue::Number(bindings.len().into())]);
|
||||
return rows;
|
||||
}
|
||||
|
|
@ -760,8 +848,14 @@ mod tests {
|
|||
if let GraphQuery::Match { pattern, .. } = parsed {
|
||||
assert_eq!(pattern.start.labels, vec!["User".to_string()]);
|
||||
assert_eq!(pattern.relationships.len(), 1);
|
||||
assert_eq!(pattern.relationships[0].edge_type, Some("FRIEND".to_string()));
|
||||
assert_eq!(pattern.relationships[0].direction, RelationshipDirection::Outgoing);
|
||||
assert_eq!(
|
||||
pattern.relationships[0].edge_type,
|
||||
Some("FRIEND".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
pattern.relationships[0].direction,
|
||||
RelationshipDirection::Outgoing
|
||||
);
|
||||
} else {
|
||||
panic!("Expected Match query");
|
||||
}
|
||||
|
|
@ -771,9 +865,14 @@ mod tests {
|
|||
fn test_execute_match() {
|
||||
let store = GraphStore::new();
|
||||
|
||||
let alice = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"}));
|
||||
let alice = store.create_node(
|
||||
vec!["User".to_string()],
|
||||
serde_json::json!({"name": "Alice"}),
|
||||
);
|
||||
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
|
||||
store.create_edge(alice, bob, "FRIEND", serde_json::json!({})).unwrap();
|
||||
store
|
||||
.create_edge(alice, bob, "FRIEND", serde_json::json!({}))
|
||||
.unwrap();
|
||||
|
||||
let query = GraphQueryParser::parse("MATCH (n:User) RETURN n").unwrap();
|
||||
let executor = GraphQueryExecutor::new(&store);
|
||||
|
|
|
|||
|
|
@ -177,18 +177,8 @@ impl GraphStore {
|
|||
/// Deletes a node and all its connected edges.
|
||||
pub fn delete_node(&self, id: &NodeId) -> Result<(), GraphError> {
|
||||
// Get connected edges
|
||||
let outgoing: Vec<EdgeId> = self
|
||||
.adjacency
|
||||
.read()
|
||||
.get(id)
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
let incoming: Vec<EdgeId> = self
|
||||
.reverse_adj
|
||||
.read()
|
||||
.get(id)
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
let outgoing: Vec<EdgeId> = self.adjacency.read().get(id).cloned().unwrap_or_default();
|
||||
let incoming: Vec<EdgeId> = self.reverse_adj.read().get(id).cloned().unwrap_or_default();
|
||||
|
||||
// Delete all connected edges
|
||||
for edge_id in outgoing.iter().chain(incoming.iter()) {
|
||||
|
|
@ -457,7 +447,12 @@ impl GraphStore {
|
|||
}
|
||||
|
||||
/// Gets the neighbor node from an edge.
|
||||
fn get_neighbor_from_edge(&self, edge: &Edge, from: &NodeId, direction: Direction) -> Option<NodeId> {
|
||||
fn get_neighbor_from_edge(
|
||||
&self,
|
||||
edge: &Edge,
|
||||
from: &NodeId,
|
||||
direction: Direction,
|
||||
) -> Option<NodeId> {
|
||||
match direction {
|
||||
Direction::Outgoing => {
|
||||
if &edge.source == from {
|
||||
|
|
@ -491,7 +486,12 @@ impl GraphStore {
|
|||
}
|
||||
|
||||
/// Gets neighbors connected by a specific edge type.
|
||||
pub fn neighbors_by_type(&self, id: &NodeId, edge_type: &str, direction: Direction) -> Vec<Node> {
|
||||
pub fn neighbors_by_type(
|
||||
&self,
|
||||
id: &NodeId,
|
||||
edge_type: &str,
|
||||
direction: Direction,
|
||||
) -> Vec<Node> {
|
||||
let edges = self.edges_of(id, direction);
|
||||
let nodes = self.nodes.read();
|
||||
|
||||
|
|
@ -565,7 +565,10 @@ mod tests {
|
|||
fn test_create_edge() {
|
||||
let store = GraphStore::new();
|
||||
|
||||
let alice = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"}));
|
||||
let alice = store.create_node(
|
||||
vec!["User".to_string()],
|
||||
serde_json::json!({"name": "Alice"}),
|
||||
);
|
||||
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
|
||||
|
||||
let edge_id = store
|
||||
|
|
@ -582,12 +585,22 @@ mod tests {
|
|||
fn test_neighbors() {
|
||||
let store = GraphStore::new();
|
||||
|
||||
let alice = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"}));
|
||||
let alice = store.create_node(
|
||||
vec!["User".to_string()],
|
||||
serde_json::json!({"name": "Alice"}),
|
||||
);
|
||||
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
|
||||
let charlie = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Charlie"}));
|
||||
let charlie = store.create_node(
|
||||
vec!["User".to_string()],
|
||||
serde_json::json!({"name": "Charlie"}),
|
||||
);
|
||||
|
||||
store.create_edge(alice, bob, "FRIEND", serde_json::json!({})).unwrap();
|
||||
store.create_edge(alice, charlie, "FRIEND", serde_json::json!({})).unwrap();
|
||||
store
|
||||
.create_edge(alice, bob, "FRIEND", serde_json::json!({}))
|
||||
.unwrap();
|
||||
store
|
||||
.create_edge(alice, charlie, "FRIEND", serde_json::json!({}))
|
||||
.unwrap();
|
||||
|
||||
let neighbors = store.neighbors(&alice, Direction::Outgoing);
|
||||
assert_eq!(neighbors.len(), 2);
|
||||
|
|
@ -597,9 +610,15 @@ mod tests {
|
|||
fn test_find_by_label() {
|
||||
let store = GraphStore::new();
|
||||
|
||||
store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"}));
|
||||
store.create_node(
|
||||
vec!["User".to_string()],
|
||||
serde_json::json!({"name": "Alice"}),
|
||||
);
|
||||
store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
|
||||
store.create_node(vec!["Product".to_string()], serde_json::json!({"name": "Widget"}));
|
||||
store.create_node(
|
||||
vec!["Product".to_string()],
|
||||
serde_json::json!({"name": "Widget"}),
|
||||
);
|
||||
|
||||
let users = store.find_nodes_by_label("User");
|
||||
assert_eq!(users.len(), 2);
|
||||
|
|
@ -615,7 +634,9 @@ mod tests {
|
|||
let alice = store.create_node(vec!["User".to_string()], serde_json::json!({}));
|
||||
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({}));
|
||||
|
||||
store.create_edge(alice, bob, "FRIEND", serde_json::json!({})).unwrap();
|
||||
store
|
||||
.create_edge(alice, bob, "FRIEND", serde_json::json!({}))
|
||||
.unwrap();
|
||||
|
||||
// Delete Alice - should also delete the edge
|
||||
store.delete_node(&alice).unwrap();
|
||||
|
|
@ -631,7 +652,9 @@ mod tests {
|
|||
let a = store.create_node(vec![], serde_json::json!({}));
|
||||
let b = store.create_node(vec![], serde_json::json!({}));
|
||||
|
||||
store.create_undirected_edge(a, b, "LINK", serde_json::json!({})).unwrap();
|
||||
store
|
||||
.create_undirected_edge(a, b, "LINK", serde_json::json!({}))
|
||||
.unwrap();
|
||||
|
||||
// Both directions should work
|
||||
let a_neighbors = store.neighbors(&a, Direction::Outgoing);
|
||||
|
|
@ -648,8 +671,12 @@ mod tests {
|
|||
let a = store.create_node(vec![], serde_json::json!({}));
|
||||
let b = store.create_node(vec![], serde_json::json!({}));
|
||||
|
||||
store.create_edge(a, b, "TYPE_A", serde_json::json!({})).unwrap();
|
||||
store.create_edge(a, b, "TYPE_B", serde_json::json!({})).unwrap();
|
||||
store
|
||||
.create_edge(a, b, "TYPE_A", serde_json::json!({}))
|
||||
.unwrap();
|
||||
store
|
||||
.create_edge(a, b, "TYPE_B", serde_json::json!({}))
|
||||
.unwrap();
|
||||
|
||||
let edges = store.edges_between(&a, &b);
|
||||
assert_eq!(edges.len(), 2);
|
||||
|
|
|
|||
|
|
@ -171,7 +171,9 @@ impl<'a> Traverser<'a> {
|
|||
|
||||
for edge in edges {
|
||||
// Check edge type filter
|
||||
if !query.edge_types.is_empty() && !query.edge_types.contains(&edge.edge_type) {
|
||||
if !query.edge_types.is_empty()
|
||||
&& !query.edge_types.contains(&edge.edge_type)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
@ -395,18 +397,35 @@ mod tests {
|
|||
fn setup_social_graph() -> GraphStore {
|
||||
let store = GraphStore::new();
|
||||
|
||||
let alice = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"}));
|
||||
let alice = store.create_node(
|
||||
vec!["User".to_string()],
|
||||
serde_json::json!({"name": "Alice"}),
|
||||
);
|
||||
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
|
||||
let charlie = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Charlie"}));
|
||||
let dave = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Dave"}));
|
||||
let charlie = store.create_node(
|
||||
vec!["User".to_string()],
|
||||
serde_json::json!({"name": "Charlie"}),
|
||||
);
|
||||
let dave = store.create_node(
|
||||
vec!["User".to_string()],
|
||||
serde_json::json!({"name": "Dave"}),
|
||||
);
|
||||
|
||||
// Alice -> Bob -> Charlie -> Dave
|
||||
store.create_edge(alice, bob, "FRIEND", serde_json::json!({})).unwrap();
|
||||
store.create_edge(bob, charlie, "FRIEND", serde_json::json!({})).unwrap();
|
||||
store.create_edge(charlie, dave, "FRIEND", serde_json::json!({})).unwrap();
|
||||
store
|
||||
.create_edge(alice, bob, "FRIEND", serde_json::json!({}))
|
||||
.unwrap();
|
||||
store
|
||||
.create_edge(bob, charlie, "FRIEND", serde_json::json!({}))
|
||||
.unwrap();
|
||||
store
|
||||
.create_edge(charlie, dave, "FRIEND", serde_json::json!({}))
|
||||
.unwrap();
|
||||
|
||||
// Alice -> Charlie (shortcut)
|
||||
store.create_edge(alice, charlie, "KNOWS", serde_json::json!({})).unwrap();
|
||||
store
|
||||
.create_edge(alice, charlie, "KNOWS", serde_json::json!({}))
|
||||
.unwrap();
|
||||
|
||||
store
|
||||
}
|
||||
|
|
@ -417,7 +436,10 @@ mod tests {
|
|||
let traverser = Traverser::new(&store);
|
||||
|
||||
let users = store.find_nodes_by_label("User");
|
||||
let alice = users.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("Alice"))).unwrap();
|
||||
let alice = users
|
||||
.iter()
|
||||
.find(|n| n.get_property("name") == Some(&serde_json::json!("Alice")))
|
||||
.unwrap();
|
||||
|
||||
let query = TraversalQuery::new().depth(2);
|
||||
let results = traverser.traverse(&alice.id, &query);
|
||||
|
|
@ -432,7 +454,10 @@ mod tests {
|
|||
let traverser = Traverser::new(&store);
|
||||
|
||||
let users = store.find_nodes_by_label("User");
|
||||
let alice = users.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("Alice"))).unwrap();
|
||||
let alice = users
|
||||
.iter()
|
||||
.find(|n| n.get_property("name") == Some(&serde_json::json!("Alice")))
|
||||
.unwrap();
|
||||
|
||||
let query = TraversalQuery::new()
|
||||
.depth(2)
|
||||
|
|
@ -440,7 +465,10 @@ mod tests {
|
|||
let results = traverser.traverse(&alice.id, &query);
|
||||
|
||||
// Following only FRIEND edges: Alice -> Bob -> Charlie
|
||||
let names: Vec<_> = results.iter().filter_map(|r| r.node.get_property("name")).collect();
|
||||
let names: Vec<_> = results
|
||||
.iter()
|
||||
.filter_map(|r| r.node.get_property("name"))
|
||||
.collect();
|
||||
assert!(names.contains(&&serde_json::json!("Bob")));
|
||||
assert!(names.contains(&&serde_json::json!("Charlie")));
|
||||
}
|
||||
|
|
@ -451,7 +479,10 @@ mod tests {
|
|||
let traverser = Traverser::new(&store);
|
||||
|
||||
let users = store.find_nodes_by_label("User");
|
||||
let alice = users.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("Alice"))).unwrap();
|
||||
let alice = users
|
||||
.iter()
|
||||
.find(|n| n.get_property("name") == Some(&serde_json::json!("Alice")))
|
||||
.unwrap();
|
||||
|
||||
let query = TraversalQuery::new().depth(1);
|
||||
let results = traverser.traverse(&alice.id, &query);
|
||||
|
|
@ -468,7 +499,10 @@ mod tests {
|
|||
let traverser = Traverser::new(&store);
|
||||
|
||||
let users = store.find_nodes_by_label("User");
|
||||
let alice = users.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("Alice"))).unwrap();
|
||||
let alice = users
|
||||
.iter()
|
||||
.find(|n| n.get_property("name") == Some(&serde_json::json!("Alice")))
|
||||
.unwrap();
|
||||
|
||||
let query = TraversalQuery::new().depth(10).limit(2);
|
||||
let results = traverser.traverse(&alice.id, &query);
|
||||
|
|
@ -486,11 +520,21 @@ mod tests {
|
|||
let mutual2 = store.create_node(vec![], serde_json::json!({"name": "Mutual2"}));
|
||||
let only_alice = store.create_node(vec![], serde_json::json!({"name": "OnlyAlice"}));
|
||||
|
||||
store.create_edge(alice, mutual1, "FRIEND", serde_json::json!({})).unwrap();
|
||||
store.create_edge(alice, mutual2, "FRIEND", serde_json::json!({})).unwrap();
|
||||
store.create_edge(alice, only_alice, "FRIEND", serde_json::json!({})).unwrap();
|
||||
store.create_edge(bob, mutual1, "FRIEND", serde_json::json!({})).unwrap();
|
||||
store.create_edge(bob, mutual2, "FRIEND", serde_json::json!({})).unwrap();
|
||||
store
|
||||
.create_edge(alice, mutual1, "FRIEND", serde_json::json!({}))
|
||||
.unwrap();
|
||||
store
|
||||
.create_edge(alice, mutual2, "FRIEND", serde_json::json!({}))
|
||||
.unwrap();
|
||||
store
|
||||
.create_edge(alice, only_alice, "FRIEND", serde_json::json!({}))
|
||||
.unwrap();
|
||||
store
|
||||
.create_edge(bob, mutual1, "FRIEND", serde_json::json!({}))
|
||||
.unwrap();
|
||||
store
|
||||
.create_edge(bob, mutual2, "FRIEND", serde_json::json!({}))
|
||||
.unwrap();
|
||||
|
||||
let traverser = Traverser::new(&store);
|
||||
let mutual = traverser.mutual_connections(&alice, &bob, Some("FRIEND"));
|
||||
|
|
|
|||
|
|
@ -174,17 +174,24 @@ impl Index {
|
|||
// Check uniqueness if required
|
||||
if self.config.unique {
|
||||
let exists = match self.config.index_type {
|
||||
IndexType::Hash | IndexType::Unique => {
|
||||
self.hash.read().get(&key).map(|s| !s.is_empty()).unwrap_or(false)
|
||||
}
|
||||
_ => {
|
||||
self.btree.read().get(&key).map(|s| !s.is_empty()).unwrap_or(false)
|
||||
}
|
||||
IndexType::Hash | IndexType::Unique => self
|
||||
.hash
|
||||
.read()
|
||||
.get(&key)
|
||||
.map(|s| !s.is_empty())
|
||||
.unwrap_or(false),
|
||||
_ => self
|
||||
.btree
|
||||
.read()
|
||||
.get(&key)
|
||||
.map(|s| !s.is_empty())
|
||||
.unwrap_or(false),
|
||||
};
|
||||
if exists {
|
||||
return Err(DatabaseError::AlreadyExists(
|
||||
format!("Unique constraint violation on index '{}'", self.config.name)
|
||||
));
|
||||
return Err(DatabaseError::AlreadyExists(format!(
|
||||
"Unique constraint violation on index '{}'",
|
||||
self.config.name
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -239,20 +246,18 @@ impl Index {
|
|||
self.stats.write().lookups += 1;
|
||||
|
||||
let result: Vec<DocumentId> = match self.config.index_type {
|
||||
IndexType::Hash | IndexType::Unique => {
|
||||
self.hash
|
||||
.read()
|
||||
.get(&key)
|
||||
.map(|s| s.iter().cloned().collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
_ => {
|
||||
self.btree
|
||||
.read()
|
||||
.get(&key)
|
||||
.map(|s| s.iter().cloned().collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
IndexType::Hash | IndexType::Unique => self
|
||||
.hash
|
||||
.read()
|
||||
.get(&key)
|
||||
.map(|s| s.iter().cloned().collect())
|
||||
.unwrap_or_default(),
|
||||
_ => self
|
||||
.btree
|
||||
.read()
|
||||
.get(&key)
|
||||
.map(|s| s.iter().cloned().collect())
|
||||
.unwrap_or_default(),
|
||||
};
|
||||
|
||||
if !result.is_empty() {
|
||||
|
|
@ -407,12 +412,7 @@ impl IndexManager {
|
|||
}
|
||||
|
||||
/// Removes a document from indexes.
|
||||
pub fn unindex_document(
|
||||
&self,
|
||||
collection: &str,
|
||||
doc_id: &DocumentId,
|
||||
document: &JsonValue,
|
||||
) {
|
||||
pub fn unindex_document(&self, collection: &str, doc_id: &DocumentId, document: &JsonValue) {
|
||||
let index_names = self.get_collection_indexes(collection);
|
||||
let indexes = self.indexes.read();
|
||||
|
||||
|
|
@ -483,7 +483,9 @@ mod tests {
|
|||
let index = Index::new(config);
|
||||
|
||||
let doc1 = DocumentId::new();
|
||||
index.insert(doc1.clone(), &json!("alice@example.com")).unwrap();
|
||||
index
|
||||
.insert(doc1.clone(), &json!("alice@example.com"))
|
||||
.unwrap();
|
||||
|
||||
let results = index.lookup(&json!("alice@example.com"));
|
||||
assert_eq!(results.len(), 1);
|
||||
|
|
@ -521,7 +523,9 @@ mod tests {
|
|||
let doc_id = DocumentId::new();
|
||||
let doc = json!({"name": "Alice", "age": 30});
|
||||
|
||||
manager.index_document("users", doc_id.clone(), &doc).unwrap();
|
||||
manager
|
||||
.index_document("users", doc_id.clone(), &doc)
|
||||
.unwrap();
|
||||
|
||||
let indexes = manager.list_indexes();
|
||||
assert_eq!(indexes.len(), 1);
|
||||
|
|
|
|||
|
|
@ -126,8 +126,7 @@ impl KeyValueStore {
|
|||
|
||||
/// Gets a value as string.
|
||||
pub fn get_string(&self, key: &str) -> Option<String> {
|
||||
self.get(key)
|
||||
.and_then(|v| String::from_utf8(v).ok())
|
||||
self.get(key).and_then(|v| String::from_utf8(v).ok())
|
||||
}
|
||||
|
||||
/// Sets a value with optional TTL.
|
||||
|
|
@ -224,8 +223,9 @@ impl KeyValueStore {
|
|||
} else {
|
||||
let s = String::from_utf8(entry.value.clone())
|
||||
.map_err(|_| DatabaseError::InvalidOperation("Value is not a string".into()))?;
|
||||
s.parse::<i64>()
|
||||
.map_err(|_| DatabaseError::InvalidOperation("Value is not an integer".into()))?
|
||||
s.parse::<i64>().map_err(|_| {
|
||||
DatabaseError::InvalidOperation("Value is not an integer".into())
|
||||
})?
|
||||
}
|
||||
} else {
|
||||
0
|
||||
|
|
@ -243,9 +243,9 @@ impl KeyValueStore {
|
|||
pub fn append(&self, key: &str, value: &[u8]) -> Result<usize, DatabaseError> {
|
||||
let mut data = self.data.write();
|
||||
|
||||
let entry = data.entry(key.to_string()).or_insert_with(|| {
|
||||
KvEntry::new(Vec::new(), 0)
|
||||
});
|
||||
let entry = data
|
||||
.entry(key.to_string())
|
||||
.or_insert_with(|| KvEntry::new(Vec::new(), 0));
|
||||
|
||||
if entry.is_expired() {
|
||||
entry.value.clear();
|
||||
|
|
@ -393,11 +393,16 @@ mod tests {
|
|||
fn test_mget_mset() {
|
||||
let store = KeyValueStore::new();
|
||||
|
||||
store.mset(&[
|
||||
("k1", b"v1".to_vec()),
|
||||
("k2", b"v2".to_vec()),
|
||||
("k3", b"v3".to_vec()),
|
||||
], 0).unwrap();
|
||||
store
|
||||
.mset(
|
||||
&[
|
||||
("k1", b"v1".to_vec()),
|
||||
("k2", b"v2".to_vec()),
|
||||
("k3", b"v3".to_vec()),
|
||||
],
|
||||
0,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let results = store.mget(&["k1", "k2", "k4"]);
|
||||
assert_eq!(results.len(), 3);
|
||||
|
|
|
|||
|
|
@ -65,12 +65,14 @@ pub use graph::{
|
|||
pub use index::{Index, IndexConfig, IndexManager, IndexType};
|
||||
pub use keyvalue::{KeyValue, KeyValueStore, KvEntry};
|
||||
pub use query::{Filter, Query, QueryEngine, QueryResult, SortOrder};
|
||||
pub use schema::{Field, FieldType, Schema, SchemaValidator};
|
||||
pub use replication::{
|
||||
ClusterConfig, Command as RaftCommand, NodeRole, RaftConfig, RaftEvent, RaftNode, RaftState,
|
||||
ReplicatedLog,
|
||||
};
|
||||
pub use sql::{QueryResult as SqlQueryResult, SqlEngine, SqlParser, SqlType, SqlValue, Table, TableDef};
|
||||
pub use schema::{Field, FieldType, Schema, SchemaValidator};
|
||||
pub use sql::{
|
||||
QueryResult as SqlQueryResult, SqlEngine, SqlParser, SqlType, SqlValue, Table, TableDef,
|
||||
};
|
||||
pub use timeseries::{DataPoint, Metric, TimeSeries, TimeSeriesStore};
|
||||
pub use vector::{Embedding, SimilarityMetric, VectorIndex, VectorStore};
|
||||
|
||||
|
|
|
|||
|
|
@ -419,10 +419,7 @@ impl QueryEngine {
|
|||
|
||||
let values: Vec<f64> = docs
|
||||
.iter()
|
||||
.filter_map(|doc| {
|
||||
doc.get(field)
|
||||
.and_then(|v| v.as_f64())
|
||||
})
|
||||
.filter_map(|doc| doc.get(field).and_then(|v| v.as_f64()))
|
||||
.collect();
|
||||
|
||||
let result = match op {
|
||||
|
|
@ -439,22 +436,18 @@ impl QueryEngine {
|
|||
serde_json::to_value(avg).unwrap_or(JsonValue::Null)
|
||||
}
|
||||
}
|
||||
AggregateOp::Min => {
|
||||
values
|
||||
.iter()
|
||||
.copied()
|
||||
.min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.map(|v| serde_json::to_value(v).unwrap_or(JsonValue::Null))
|
||||
.unwrap_or(JsonValue::Null)
|
||||
}
|
||||
AggregateOp::Max => {
|
||||
values
|
||||
.iter()
|
||||
.copied()
|
||||
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.map(|v| serde_json::to_value(v).unwrap_or(JsonValue::Null))
|
||||
.unwrap_or(JsonValue::Null)
|
||||
}
|
||||
AggregateOp::Min => values
|
||||
.iter()
|
||||
.copied()
|
||||
.min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.map(|v| serde_json::to_value(v).unwrap_or(JsonValue::Null))
|
||||
.unwrap_or(JsonValue::Null),
|
||||
AggregateOp::Max => values
|
||||
.iter()
|
||||
.copied()
|
||||
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.map(|v| serde_json::to_value(v).unwrap_or(JsonValue::Null))
|
||||
.unwrap_or(JsonValue::Null),
|
||||
};
|
||||
|
||||
Ok(result)
|
||||
|
|
@ -507,8 +500,10 @@ mod tests {
|
|||
fn test_simple_query() {
|
||||
let docs = Arc::new(DocumentStore::new());
|
||||
docs.create_collection("users").unwrap();
|
||||
docs.insert("users", json!({"name": "Alice", "age": 30})).unwrap();
|
||||
docs.insert("users", json!({"name": "Bob", "age": 25})).unwrap();
|
||||
docs.insert("users", json!({"name": "Alice", "age": 30}))
|
||||
.unwrap();
|
||||
docs.insert("users", json!({"name": "Bob", "age": 25}))
|
||||
.unwrap();
|
||||
|
||||
let vectors = Arc::new(VectorStore::new(3));
|
||||
let indexes = Arc::new(IndexManager::new());
|
||||
|
|
@ -525,8 +520,10 @@ mod tests {
|
|||
fn test_filter_query() {
|
||||
let docs = Arc::new(DocumentStore::new());
|
||||
docs.create_collection("users").unwrap();
|
||||
docs.insert("users", json!({"name": "Alice", "age": 30})).unwrap();
|
||||
docs.insert("users", json!({"name": "Bob", "age": 25})).unwrap();
|
||||
docs.insert("users", json!({"name": "Alice", "age": 30}))
|
||||
.unwrap();
|
||||
docs.insert("users", json!({"name": "Bob", "age": 25}))
|
||||
.unwrap();
|
||||
|
||||
let vectors = Arc::new(VectorStore::new(3));
|
||||
let indexes = Arc::new(IndexManager::new());
|
||||
|
|
@ -543,9 +540,12 @@ mod tests {
|
|||
fn test_sorted_query() {
|
||||
let docs = Arc::new(DocumentStore::new());
|
||||
docs.create_collection("users").unwrap();
|
||||
docs.insert("users", json!({"name": "Alice", "age": 30})).unwrap();
|
||||
docs.insert("users", json!({"name": "Bob", "age": 25})).unwrap();
|
||||
docs.insert("users", json!({"name": "Charlie", "age": 35})).unwrap();
|
||||
docs.insert("users", json!({"name": "Alice", "age": 30}))
|
||||
.unwrap();
|
||||
docs.insert("users", json!({"name": "Bob", "age": 25}))
|
||||
.unwrap();
|
||||
docs.insert("users", json!({"name": "Charlie", "age": 35}))
|
||||
.unwrap();
|
||||
|
||||
let vectors = Arc::new(VectorStore::new(3));
|
||||
let indexes = Arc::new(IndexManager::new());
|
||||
|
|
|
|||
|
|
@ -92,12 +92,7 @@ impl Election {
|
|||
|
||||
/// Creates a RequestVote message for this election.
|
||||
pub fn create_request(&self, log: &ReplicatedLog) -> RequestVote {
|
||||
RequestVote::new(
|
||||
self.term,
|
||||
self.node_id,
|
||||
log.last_index(),
|
||||
log.last_term(),
|
||||
)
|
||||
RequestVote::new(self.term, self.node_id, log.last_index(), log.last_term())
|
||||
}
|
||||
|
||||
/// Checks the current result of the election.
|
||||
|
|
@ -217,8 +212,8 @@ impl Default for ElectionTimeout {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::replication::state::Command;
|
||||
use crate::replication::log::LogEntry;
|
||||
use crate::replication::state::Command;
|
||||
|
||||
#[test]
|
||||
fn test_election_basic() {
|
||||
|
|
|
|||
|
|
@ -139,7 +139,10 @@ impl ReplicatedLog {
|
|||
|
||||
let entries = self.entries.read();
|
||||
let offset = (from_index - start) as usize;
|
||||
entries.get(offset..).map(|s| s.to_vec()).unwrap_or_default()
|
||||
entries
|
||||
.get(offset..)
|
||||
.map(|s| s.to_vec())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Appends an entry to the log.
|
||||
|
|
@ -151,7 +154,12 @@ impl ReplicatedLog {
|
|||
}
|
||||
|
||||
/// Appends multiple entries, potentially overwriting conflicting entries.
|
||||
pub fn append_entries(&self, prev_index: u64, prev_term: u64, new_entries: Vec<LogEntry>) -> bool {
|
||||
pub fn append_entries(
|
||||
&self,
|
||||
prev_index: u64,
|
||||
prev_term: u64,
|
||||
new_entries: Vec<LogEntry>,
|
||||
) -> bool {
|
||||
// Check that prev entry matches
|
||||
if prev_index > 0 {
|
||||
if let Some(prev_entry_term) = self.term_at(prev_index) {
|
||||
|
|
@ -245,7 +253,11 @@ impl ReplicatedLog {
|
|||
}
|
||||
|
||||
/// Creates entries for replication starting from a given index.
|
||||
pub fn entries_for_replication(&self, from_index: u64, max_entries: usize) -> (u64, u64, Vec<LogEntry>) {
|
||||
pub fn entries_for_replication(
|
||||
&self,
|
||||
from_index: u64,
|
||||
max_entries: usize,
|
||||
) -> (u64, u64, Vec<LogEntry>) {
|
||||
let prev_index = from_index.saturating_sub(1);
|
||||
let prev_term = self.term_at(prev_index).unwrap_or(0);
|
||||
|
||||
|
|
|
|||
|
|
@ -222,7 +222,11 @@ impl RaftNode {
|
|||
|
||||
// Create new election
|
||||
let cluster_size = self.cluster.voting_members();
|
||||
self.election = Some(Election::new(self.id, self.state.current_term, cluster_size));
|
||||
self.election = Some(Election::new(
|
||||
self.id,
|
||||
self.state.current_term,
|
||||
cluster_size,
|
||||
));
|
||||
|
||||
// Create RequestVote message
|
||||
let request = RequestVote::new(
|
||||
|
|
@ -295,9 +299,9 @@ impl RaftNode {
|
|||
return;
|
||||
}
|
||||
|
||||
let (prev_log_index, prev_log_term, entries) =
|
||||
self.log
|
||||
.entries_for_replication(next_index, self.config.max_entries_per_rpc);
|
||||
let (prev_log_index, prev_log_term, entries) = self
|
||||
.log
|
||||
.entries_for_replication(next_index, self.config.max_entries_per_rpc);
|
||||
|
||||
let request = AppendEntries::with_entries(
|
||||
self.state.current_term,
|
||||
|
|
@ -308,8 +312,10 @@ impl RaftNode {
|
|||
self.state.commit_index,
|
||||
);
|
||||
|
||||
self.events
|
||||
.push(RaftEvent::SendRpc(peer_id, RpcMessage::AppendEntries(request)));
|
||||
self.events.push(RaftEvent::SendRpc(
|
||||
peer_id,
|
||||
RpcMessage::AppendEntries(request),
|
||||
));
|
||||
}
|
||||
|
||||
fn send_install_snapshot(&mut self, peer_id: NodeId) {
|
||||
|
|
@ -332,8 +338,10 @@ impl RaftNode {
|
|||
done,
|
||||
);
|
||||
|
||||
self.events
|
||||
.push(RaftEvent::SendRpc(peer_id, RpcMessage::InstallSnapshot(request)));
|
||||
self.events.push(RaftEvent::SendRpc(
|
||||
peer_id,
|
||||
RpcMessage::InstallSnapshot(request),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -395,7 +403,11 @@ impl RaftNode {
|
|||
}
|
||||
}
|
||||
|
||||
fn handle_append_entries(&mut self, _from: NodeId, req: AppendEntries) -> AppendEntriesResponse {
|
||||
fn handle_append_entries(
|
||||
&mut self,
|
||||
_from: NodeId,
|
||||
req: AppendEntries,
|
||||
) -> AppendEntriesResponse {
|
||||
// Rule: If term > currentTerm, become follower
|
||||
if req.term > self.state.current_term {
|
||||
self.become_follower(req.term, Some(req.leader_id));
|
||||
|
|
@ -416,9 +428,9 @@ impl RaftNode {
|
|||
}
|
||||
|
||||
// Try to append entries
|
||||
let success =
|
||||
self.log
|
||||
.append_entries(req.prev_log_index, req.prev_log_term, req.entries);
|
||||
let success = self
|
||||
.log
|
||||
.append_entries(req.prev_log_index, req.prev_log_term, req.entries);
|
||||
|
||||
if success {
|
||||
// Update commit index
|
||||
|
|
@ -443,7 +455,11 @@ impl RaftNode {
|
|||
}
|
||||
conflict_index -= 1;
|
||||
}
|
||||
AppendEntriesResponse::conflict(self.state.current_term, conflict_term, conflict_index)
|
||||
AppendEntriesResponse::conflict(
|
||||
self.state.current_term,
|
||||
conflict_term,
|
||||
conflict_index,
|
||||
)
|
||||
} else {
|
||||
AppendEntriesResponse::failure(self.state.current_term)
|
||||
}
|
||||
|
|
@ -502,7 +518,11 @@ impl RaftNode {
|
|||
self.cluster.update_peer_state(from, PeerState::Reachable);
|
||||
}
|
||||
|
||||
fn handle_install_snapshot(&mut self, _from: NodeId, req: InstallSnapshot) -> InstallSnapshotResponse {
|
||||
fn handle_install_snapshot(
|
||||
&mut self,
|
||||
_from: NodeId,
|
||||
req: InstallSnapshot,
|
||||
) -> InstallSnapshotResponse {
|
||||
// Rule: If term > currentTerm, become follower
|
||||
if req.term > self.state.current_term {
|
||||
self.become_follower(req.term, Some(req.leader_id));
|
||||
|
|
@ -692,12 +712,14 @@ impl RaftNode {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use super::super::cluster::PeerAddress;
|
||||
use super::*;
|
||||
|
||||
fn create_test_cluster(node_id: NodeId, peers: &[NodeId]) -> ClusterConfig {
|
||||
let mut cluster =
|
||||
ClusterConfig::new(node_id, PeerAddress::new("127.0.0.1", 9000 + node_id as u16));
|
||||
let mut cluster = ClusterConfig::new(
|
||||
node_id,
|
||||
PeerAddress::new("127.0.0.1", 9000 + node_id as u16),
|
||||
);
|
||||
for &peer in peers {
|
||||
cluster.add_peer(super::super::cluster::PeerInfo::new(
|
||||
peer,
|
||||
|
|
|
|||
|
|
@ -176,7 +176,10 @@ impl SnapshotManager {
|
|||
/// Adds a chunk to the pending snapshot.
|
||||
pub fn add_chunk(&mut self, offset: u64, data: Vec<u8>) -> bool {
|
||||
if let Some(ref mut pending) = self.pending_snapshot {
|
||||
if offset == pending.expected_offset + pending.chunks.iter().map(|c| c.len() as u64).sum::<u64>() {
|
||||
if offset
|
||||
== pending.expected_offset
|
||||
+ pending.chunks.iter().map(|c| c.len() as u64).sum::<u64>()
|
||||
{
|
||||
pending.chunks.push(data);
|
||||
return true;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -118,31 +118,55 @@ pub enum Command {
|
|||
|
||||
// Key-Value operations
|
||||
/// Set a key-value pair.
|
||||
KvSet { key: String, value: Vec<u8>, ttl: Option<u64> },
|
||||
KvSet {
|
||||
key: String,
|
||||
value: Vec<u8>,
|
||||
ttl: Option<u64>,
|
||||
},
|
||||
/// Delete a key.
|
||||
KvDelete { key: String },
|
||||
|
||||
// Document operations
|
||||
/// Insert a document.
|
||||
DocInsert { collection: String, document: JsonValue },
|
||||
DocInsert {
|
||||
collection: String,
|
||||
document: JsonValue,
|
||||
},
|
||||
/// Update a document.
|
||||
DocUpdate { collection: String, id: String, update: JsonValue },
|
||||
DocUpdate {
|
||||
collection: String,
|
||||
id: String,
|
||||
update: JsonValue,
|
||||
},
|
||||
/// Delete a document.
|
||||
DocDelete { collection: String, id: String },
|
||||
|
||||
// Vector operations
|
||||
/// Insert a vector.
|
||||
VectorInsert { namespace: String, id: String, vector: Vec<f32>, metadata: JsonValue },
|
||||
VectorInsert {
|
||||
namespace: String,
|
||||
id: String,
|
||||
vector: Vec<f32>,
|
||||
metadata: JsonValue,
|
||||
},
|
||||
/// Delete a vector.
|
||||
VectorDelete { namespace: String, id: String },
|
||||
|
||||
// Time-series operations
|
||||
/// Record a metric data point.
|
||||
TimeSeriesRecord { metric: String, value: f64, timestamp: u64, tags: JsonValue },
|
||||
TimeSeriesRecord {
|
||||
metric: String,
|
||||
value: f64,
|
||||
timestamp: u64,
|
||||
tags: JsonValue,
|
||||
},
|
||||
|
||||
// Graph operations
|
||||
/// Create a graph node.
|
||||
GraphNodeCreate { labels: Vec<String>, properties: JsonValue },
|
||||
GraphNodeCreate {
|
||||
labels: Vec<String>,
|
||||
properties: JsonValue,
|
||||
},
|
||||
/// Delete a graph node.
|
||||
GraphNodeDelete { id: String },
|
||||
/// Create a graph edge.
|
||||
|
|
@ -161,13 +185,20 @@ pub enum Command {
|
|||
|
||||
// Schema operations
|
||||
/// Create a collection/table.
|
||||
CreateCollection { name: String, schema: Option<JsonValue> },
|
||||
CreateCollection {
|
||||
name: String,
|
||||
schema: Option<JsonValue>,
|
||||
},
|
||||
/// Drop a collection/table.
|
||||
DropCollection { name: String },
|
||||
|
||||
// Index operations
|
||||
/// Create an index.
|
||||
CreateIndex { collection: String, field: String, index_type: String },
|
||||
CreateIndex {
|
||||
collection: String,
|
||||
field: String,
|
||||
index_type: String,
|
||||
},
|
||||
/// Drop an index.
|
||||
DropIndex { name: String },
|
||||
|
||||
|
|
@ -265,7 +296,12 @@ impl LeaderState {
|
|||
}
|
||||
|
||||
/// Calculates the new commit index based on majority replication.
|
||||
pub fn calculate_commit_index(&self, current_commit: u64, current_term: u64, log_term_at: impl Fn(u64) -> Option<u64>) -> u64 {
|
||||
pub fn calculate_commit_index(
|
||||
&self,
|
||||
current_commit: u64,
|
||||
current_term: u64,
|
||||
log_term_at: impl Fn(u64) -> Option<u64>,
|
||||
) -> u64 {
|
||||
// Find the highest index that a majority have replicated
|
||||
let mut indices: Vec<u64> = self.match_index.values().cloned().collect();
|
||||
indices.sort_unstable();
|
||||
|
|
|
|||
|
|
@ -315,8 +315,8 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_vector_field() {
|
||||
let schema = Schema::new("embedding")
|
||||
.field(Field::required("vector", FieldType::Vector(3)));
|
||||
let schema =
|
||||
Schema::new("embedding").field(Field::required("vector", FieldType::Vector(3)));
|
||||
|
||||
let mut validator = SchemaValidator::new();
|
||||
validator.register(schema);
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
//! SQL query executor.
|
||||
|
||||
use super::parser::{
|
||||
BinaryOp, ParsedExpr, ParsedSelect, ParsedSelectItem,
|
||||
ParsedStatement, SqlParser,
|
||||
BinaryOp, ParsedExpr, ParsedSelect, ParsedSelectItem, ParsedStatement, SqlParser,
|
||||
};
|
||||
use super::row::{Row, RowId};
|
||||
use super::table::{ColumnDef, Table, TableDef};
|
||||
|
|
@ -192,11 +191,7 @@ impl SqlEngine {
|
|||
match a_val.partial_cmp(&b_val) {
|
||||
Some(std::cmp::Ordering::Equal) => continue,
|
||||
Some(ord) => {
|
||||
return if ob.ascending {
|
||||
ord
|
||||
} else {
|
||||
ord.reverse()
|
||||
};
|
||||
return if ob.ascending { ord } else { ord.reverse() };
|
||||
}
|
||||
None => continue,
|
||||
}
|
||||
|
|
@ -216,7 +211,11 @@ impl SqlEngine {
|
|||
}
|
||||
|
||||
// Handle aggregates
|
||||
if select.columns.iter().any(|c| matches!(c, ParsedSelectItem::Aggregate { .. })) {
|
||||
if select
|
||||
.columns
|
||||
.iter()
|
||||
.any(|c| matches!(c, ParsedSelectItem::Aggregate { .. }))
|
||||
{
|
||||
return self.execute_aggregate(select, &rows, table);
|
||||
}
|
||||
|
||||
|
|
@ -244,7 +243,9 @@ impl SqlEngine {
|
|||
ParsedSelectItem::Wildcard => table.def.column_names(),
|
||||
ParsedSelectItem::Column(name) => vec![name.clone()],
|
||||
ParsedSelectItem::ColumnAlias { alias, .. } => vec![alias.clone()],
|
||||
ParsedSelectItem::Aggregate { function, alias, .. } => {
|
||||
ParsedSelectItem::Aggregate {
|
||||
function, alias, ..
|
||||
} => {
|
||||
vec![alias.clone().unwrap_or_else(|| function.clone())]
|
||||
}
|
||||
})
|
||||
|
|
@ -328,7 +329,9 @@ impl SqlEngine {
|
|||
rows.iter()
|
||||
.map(|r| r.get_or_null(col))
|
||||
.filter(|v| !v.is_null())
|
||||
.min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.min_by(|a, b| {
|
||||
a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
|
||||
})
|
||||
.unwrap_or(SqlValue::Null)
|
||||
}
|
||||
"MAX" => {
|
||||
|
|
@ -338,12 +341,12 @@ impl SqlEngine {
|
|||
rows.iter()
|
||||
.map(|r| r.get_or_null(col))
|
||||
.filter(|v| !v.is_null())
|
||||
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.max_by(|a, b| {
|
||||
a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
|
||||
})
|
||||
.unwrap_or(SqlValue::Null)
|
||||
}
|
||||
_ => {
|
||||
return Err(SqlError::Unsupported(format!("Function: {}", function)))
|
||||
}
|
||||
_ => return Err(SqlError::Unsupported(format!("Function: {}", function))),
|
||||
};
|
||||
result_values.push(value);
|
||||
}
|
||||
|
|
@ -404,7 +407,11 @@ impl SqlEngine {
|
|||
ParsedExpr::IsNotNull(inner) => {
|
||||
SqlValue::Boolean(!self.evaluate_expr(row, inner).is_null())
|
||||
}
|
||||
ParsedExpr::InList { expr, list, negated } => {
|
||||
ParsedExpr::InList {
|
||||
expr,
|
||||
list,
|
||||
negated,
|
||||
} => {
|
||||
let val = self.evaluate_expr(row, expr);
|
||||
let in_list = list.iter().any(|item| {
|
||||
let item_val = self.evaluate_expr(row, item);
|
||||
|
|
@ -424,9 +431,7 @@ impl SqlEngine {
|
|||
let between = val >= low_val && val <= high_val;
|
||||
SqlValue::Boolean(if *negated { !between } else { between })
|
||||
}
|
||||
ParsedExpr::Function { name, args } => {
|
||||
self.evaluate_function(row, name, args)
|
||||
}
|
||||
ParsedExpr::Function { name, args } => self.evaluate_function(row, name, args),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -474,9 +479,7 @@ impl SqlEngine {
|
|||
_ => SqlValue::Null,
|
||||
},
|
||||
BinaryOp::Divide => match (left, right) {
|
||||
(SqlValue::Integer(a), SqlValue::Integer(b)) if *b != 0 => {
|
||||
SqlValue::Integer(a / b)
|
||||
}
|
||||
(SqlValue::Integer(a), SqlValue::Integer(b)) if *b != 0 => SqlValue::Integer(a / b),
|
||||
(SqlValue::Real(a), SqlValue::Real(b)) if *b != 0.0 => SqlValue::Real(a / b),
|
||||
_ => SqlValue::Null,
|
||||
},
|
||||
|
|
@ -536,9 +539,7 @@ impl SqlEngine {
|
|||
/// Matches a LIKE pattern.
|
||||
fn match_like(&self, text: &str, pattern: &str) -> bool {
|
||||
// Simple LIKE implementation: % = any chars, _ = single char
|
||||
let _regex_pattern = pattern
|
||||
.replace('%', ".*")
|
||||
.replace('_', ".");
|
||||
let _regex_pattern = pattern.replace('%', ".*").replace('_', ".");
|
||||
// For simplicity, just do case-insensitive contains for now
|
||||
if pattern.starts_with('%') && pattern.ends_with('%') {
|
||||
let inner = &pattern[1..pattern.len() - 1];
|
||||
|
|
@ -615,8 +616,7 @@ impl SqlEngine {
|
|||
.unwrap_or(true);
|
||||
|
||||
if matches {
|
||||
let updates: HashMap<String, SqlValue> =
|
||||
assignments.iter().cloned().collect();
|
||||
let updates: HashMap<String, SqlValue> = assignments.iter().cloned().collect();
|
||||
table.update(row.id, updates)?;
|
||||
count += 1;
|
||||
}
|
||||
|
|
@ -672,9 +672,9 @@ impl SqlEngine {
|
|||
.ok_or_else(|| SqlError::TableNotFound(table_name.to_string()))?;
|
||||
|
||||
// For simplicity, only support single-column indexes
|
||||
let column = columns
|
||||
.first()
|
||||
.ok_or_else(|| SqlError::InvalidOperation("Index requires at least one column".to_string()))?;
|
||||
let column = columns.first().ok_or_else(|| {
|
||||
SqlError::InvalidOperation("Index requires at least one column".to_string())
|
||||
})?;
|
||||
|
||||
table.create_index(name, column, unique)?;
|
||||
Ok(QueryResult::empty())
|
||||
|
|
@ -775,7 +775,9 @@ mod tests {
|
|||
.execute("INSERT INTO users (id, name, age) VALUES (2, 'Bob', 25)")
|
||||
.unwrap();
|
||||
|
||||
let result = engine.execute("SELECT name FROM users WHERE age > 26").unwrap();
|
||||
let result = engine
|
||||
.execute("SELECT name FROM users WHERE age > 26")
|
||||
.unwrap();
|
||||
assert_eq!(result.rows.len(), 1);
|
||||
assert_eq!(result.rows[0][0], SqlValue::Text("Alice".to_string()));
|
||||
}
|
||||
|
|
@ -806,7 +808,9 @@ mod tests {
|
|||
engine
|
||||
.execute(&format!(
|
||||
"INSERT INTO users (id, name, age) VALUES ({}, 'User{}', {})",
|
||||
i, i, 20 + i
|
||||
i,
|
||||
i,
|
||||
20 + i
|
||||
))
|
||||
.unwrap();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,10 +18,7 @@ pub enum ParsedStatement {
|
|||
if_not_exists: bool,
|
||||
},
|
||||
/// DROP TABLE statement.
|
||||
DropTable {
|
||||
name: String,
|
||||
if_exists: bool,
|
||||
},
|
||||
DropTable { name: String, if_exists: bool },
|
||||
/// SELECT statement.
|
||||
Select(ParsedSelect),
|
||||
/// INSERT statement.
|
||||
|
|
@ -179,15 +176,17 @@ impl SqlParser {
|
|||
/// Parses a SQL statement.
|
||||
pub fn parse(sql: &str) -> Result<ParsedStatement, SqlError> {
|
||||
let dialect = SQLiteDialect {};
|
||||
let statements = Parser::parse_sql(&dialect, sql)
|
||||
.map_err(|e| SqlError::Parse(e.to_string()))?;
|
||||
let statements =
|
||||
Parser::parse_sql(&dialect, sql).map_err(|e| SqlError::Parse(e.to_string()))?;
|
||||
|
||||
if statements.is_empty() {
|
||||
return Err(SqlError::Parse("Empty SQL statement".to_string()));
|
||||
}
|
||||
|
||||
if statements.len() > 1 {
|
||||
return Err(SqlError::Parse("Multiple statements not supported".to_string()));
|
||||
return Err(SqlError::Parse(
|
||||
"Multiple statements not supported".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Self::convert_statement(&statements[0])
|
||||
|
|
@ -195,25 +194,42 @@ impl SqlParser {
|
|||
|
||||
fn convert_statement(stmt: &Statement) -> Result<ParsedStatement, SqlError> {
|
||||
match stmt {
|
||||
Statement::CreateTable { name, columns, if_not_exists, constraints, .. } => {
|
||||
Self::convert_create_table(name, columns, constraints, *if_not_exists)
|
||||
}
|
||||
Statement::Drop { object_type, names, if_exists, .. } => {
|
||||
Self::convert_drop(object_type, names, *if_exists)
|
||||
}
|
||||
Statement::CreateTable {
|
||||
name,
|
||||
columns,
|
||||
if_not_exists,
|
||||
constraints,
|
||||
..
|
||||
} => Self::convert_create_table(name, columns, constraints, *if_not_exists),
|
||||
Statement::Drop {
|
||||
object_type,
|
||||
names,
|
||||
if_exists,
|
||||
..
|
||||
} => Self::convert_drop(object_type, names, *if_exists),
|
||||
Statement::Query(query) => Self::convert_query(query),
|
||||
Statement::Insert { table_name, columns, source, .. } => {
|
||||
Self::convert_insert(table_name, columns, source)
|
||||
}
|
||||
Statement::Update { table, assignments, selection, .. } => {
|
||||
Self::convert_update(table, assignments, selection)
|
||||
}
|
||||
Statement::Delete { from, selection, .. } => {
|
||||
Self::convert_delete(from, selection)
|
||||
}
|
||||
Statement::CreateIndex { name, table_name, columns, unique, .. } => {
|
||||
Self::convert_create_index(name, table_name, columns, *unique)
|
||||
}
|
||||
Statement::Insert {
|
||||
table_name,
|
||||
columns,
|
||||
source,
|
||||
..
|
||||
} => Self::convert_insert(table_name, columns, source),
|
||||
Statement::Update {
|
||||
table,
|
||||
assignments,
|
||||
selection,
|
||||
..
|
||||
} => Self::convert_update(table, assignments, selection),
|
||||
Statement::Delete {
|
||||
from, selection, ..
|
||||
} => Self::convert_delete(from, selection),
|
||||
Statement::CreateIndex {
|
||||
name,
|
||||
table_name,
|
||||
columns,
|
||||
unique,
|
||||
..
|
||||
} => Self::convert_create_index(name, table_name, columns, *unique),
|
||||
_ => Err(SqlError::Unsupported(format!("Statement not supported"))),
|
||||
}
|
||||
}
|
||||
|
|
@ -230,7 +246,12 @@ impl SqlParser {
|
|||
|
||||
// Extract primary keys from table constraints
|
||||
for constraint in constraints {
|
||||
if let sqlparser::ast::TableConstraint::Unique { columns: pk_cols, is_primary: true, .. } = constraint {
|
||||
if let sqlparser::ast::TableConstraint::Unique {
|
||||
columns: pk_cols,
|
||||
is_primary: true,
|
||||
..
|
||||
} = constraint
|
||||
{
|
||||
for col in pk_cols {
|
||||
primary_keys.push(col.value.clone());
|
||||
}
|
||||
|
|
@ -296,10 +317,9 @@ impl SqlParser {
|
|||
DataType::Real | DataType::Float(_) | DataType::Double | DataType::DoublePrecision => {
|
||||
Ok(SqlType::Real)
|
||||
}
|
||||
DataType::Varchar(_)
|
||||
| DataType::Char(_)
|
||||
| DataType::Text
|
||||
| DataType::String(_) => Ok(SqlType::Text),
|
||||
DataType::Varchar(_) | DataType::Char(_) | DataType::Text | DataType::String(_) => {
|
||||
Ok(SqlType::Text)
|
||||
}
|
||||
DataType::Binary(_) | DataType::Varbinary(_) | DataType::Blob(_) => Ok(SqlType::Blob),
|
||||
DataType::Boolean => Ok(SqlType::Boolean),
|
||||
DataType::Timestamp(_, _) | DataType::Date | DataType::Datetime(_) => {
|
||||
|
|
@ -367,10 +387,7 @@ impl SqlParser {
|
|||
.collect();
|
||||
|
||||
// Parse LIMIT/OFFSET
|
||||
let limit = query
|
||||
.limit
|
||||
.as_ref()
|
||||
.and_then(|l| Self::expr_to_usize(l));
|
||||
let limit = query.limit.as_ref().and_then(|l| Self::expr_to_usize(l));
|
||||
let offset = query
|
||||
.offset
|
||||
.as_ref()
|
||||
|
|
@ -403,16 +420,18 @@ impl SqlParser {
|
|||
Self::convert_select_expr(expr)
|
||||
}
|
||||
}
|
||||
_ => Err(SqlError::Unsupported("Select item not supported".to_string())),
|
||||
_ => Err(SqlError::Unsupported(
|
||||
"Select item not supported".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_select_expr(expr: &Expr) -> Result<ParsedSelectItem, SqlError> {
|
||||
match expr {
|
||||
Expr::Identifier(id) => Ok(ParsedSelectItem::Column(id.value.clone())),
|
||||
Expr::CompoundIdentifier(ids) => {
|
||||
Ok(ParsedSelectItem::Column(ids.last().map(|i| i.value.clone()).unwrap_or_default()))
|
||||
}
|
||||
Expr::CompoundIdentifier(ids) => Ok(ParsedSelectItem::Column(
|
||||
ids.last().map(|i| i.value.clone()).unwrap_or_default(),
|
||||
)),
|
||||
Expr::Function(func) => {
|
||||
let name = func.name.to_string().to_uppercase();
|
||||
// Try to extract column from first arg - simplified for compatibility
|
||||
|
|
@ -423,14 +442,18 @@ impl SqlParser {
|
|||
alias: None,
|
||||
})
|
||||
}
|
||||
_ => Err(SqlError::Unsupported("Select expression not supported".to_string())),
|
||||
_ => Err(SqlError::Unsupported(
|
||||
"Select expression not supported".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_table_factor(factor: &TableFactor) -> Result<String, SqlError> {
|
||||
match factor {
|
||||
TableFactor::Table { name, .. } => Ok(name.to_string()),
|
||||
_ => Err(SqlError::Unsupported("Table factor not supported".to_string())),
|
||||
_ => Err(SqlError::Unsupported(
|
||||
"Table factor not supported".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -461,9 +484,9 @@ impl SqlParser {
|
|||
fn convert_expr(expr: &Expr) -> Result<ParsedExpr, SqlError> {
|
||||
match expr {
|
||||
Expr::Identifier(id) => Ok(ParsedExpr::Column(id.value.clone())),
|
||||
Expr::CompoundIdentifier(ids) => {
|
||||
Ok(ParsedExpr::Column(ids.last().map(|i| i.value.clone()).unwrap_or_default()))
|
||||
}
|
||||
Expr::CompoundIdentifier(ids) => Ok(ParsedExpr::Column(
|
||||
ids.last().map(|i| i.value.clone()).unwrap_or_default(),
|
||||
)),
|
||||
Expr::Value(v) => Ok(ParsedExpr::Literal(Self::convert_value(v)?)),
|
||||
Expr::BinaryOp { left, op, right } => {
|
||||
let left = Box::new(Self::convert_expr(left)?);
|
||||
|
|
@ -471,17 +494,30 @@ impl SqlParser {
|
|||
let op = Self::convert_binary_op(op)?;
|
||||
Ok(ParsedExpr::BinaryOp { left, op, right })
|
||||
}
|
||||
Expr::UnaryOp { op: sqlparser::ast::UnaryOperator::Not, expr } => {
|
||||
Ok(ParsedExpr::Not(Box::new(Self::convert_expr(expr)?)))
|
||||
}
|
||||
Expr::UnaryOp {
|
||||
op: sqlparser::ast::UnaryOperator::Not,
|
||||
expr,
|
||||
} => Ok(ParsedExpr::Not(Box::new(Self::convert_expr(expr)?))),
|
||||
Expr::IsNull(expr) => Ok(ParsedExpr::IsNull(Box::new(Self::convert_expr(expr)?))),
|
||||
Expr::IsNotNull(expr) => Ok(ParsedExpr::IsNotNull(Box::new(Self::convert_expr(expr)?))),
|
||||
Expr::InList { expr, list, negated } => Ok(ParsedExpr::InList {
|
||||
Expr::InList {
|
||||
expr,
|
||||
list,
|
||||
negated,
|
||||
} => Ok(ParsedExpr::InList {
|
||||
expr: Box::new(Self::convert_expr(expr)?),
|
||||
list: list.iter().map(Self::convert_expr).collect::<Result<_, _>>()?,
|
||||
list: list
|
||||
.iter()
|
||||
.map(Self::convert_expr)
|
||||
.collect::<Result<_, _>>()?,
|
||||
negated: *negated,
|
||||
}),
|
||||
Expr::Between { expr, low, high, negated } => Ok(ParsedExpr::Between {
|
||||
Expr::Between {
|
||||
expr,
|
||||
low,
|
||||
high,
|
||||
negated,
|
||||
} => Ok(ParsedExpr::Between {
|
||||
expr: Box::new(Self::convert_expr(expr)?),
|
||||
low: Box::new(Self::convert_expr(low)?),
|
||||
high: Box::new(Self::convert_expr(high)?),
|
||||
|
|
@ -490,10 +526,16 @@ impl SqlParser {
|
|||
Expr::Like { expr, pattern, .. } => {
|
||||
let left = Box::new(Self::convert_expr(expr)?);
|
||||
let right = Box::new(Self::convert_expr(pattern)?);
|
||||
Ok(ParsedExpr::BinaryOp { left, op: BinaryOp::Like, right })
|
||||
Ok(ParsedExpr::BinaryOp {
|
||||
left,
|
||||
op: BinaryOp::Like,
|
||||
right,
|
||||
})
|
||||
}
|
||||
Expr::Nested(inner) => Self::convert_expr(inner),
|
||||
_ => Err(SqlError::Unsupported("Expression not supported".to_string())),
|
||||
_ => Err(SqlError::Unsupported(
|
||||
"Expression not supported".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -587,7 +629,11 @@ impl SqlParser {
|
|||
let parsed_assignments: Vec<(String, SqlValue)> = assignments
|
||||
.iter()
|
||||
.map(|a| {
|
||||
let col = a.id.iter().map(|i| i.value.clone()).collect::<Vec<_>>().join(".");
|
||||
let col =
|
||||
a.id.iter()
|
||||
.map(|i| i.value.clone())
|
||||
.collect::<Vec<_>>()
|
||||
.join(".");
|
||||
let val = Self::convert_value_expr(&a.value)?;
|
||||
Ok((col, val))
|
||||
})
|
||||
|
|
@ -633,10 +679,7 @@ impl SqlParser {
|
|||
|
||||
let table = table_name.to_string();
|
||||
|
||||
let cols: Vec<String> = columns
|
||||
.iter()
|
||||
.map(|c| c.expr.to_string())
|
||||
.collect();
|
||||
let cols: Vec<String> = columns.iter().map(|c| c.expr.to_string()).collect();
|
||||
|
||||
Ok(ParsedStatement::CreateIndex {
|
||||
name: index_name,
|
||||
|
|
@ -694,7 +737,12 @@ mod tests {
|
|||
let sql = "INSERT INTO users (name, age) VALUES ('Alice', 30), ('Bob', 25)";
|
||||
let stmt = SqlParser::parse(sql).unwrap();
|
||||
|
||||
if let ParsedStatement::Insert { table, columns, values } = stmt {
|
||||
if let ParsedStatement::Insert {
|
||||
table,
|
||||
columns,
|
||||
values,
|
||||
} = stmt
|
||||
{
|
||||
assert_eq!(table, "users");
|
||||
assert_eq!(columns, vec!["name", "age"]);
|
||||
assert_eq!(values.len(), 2);
|
||||
|
|
@ -708,7 +756,12 @@ mod tests {
|
|||
let sql = "UPDATE users SET age = 31 WHERE name = 'Alice'";
|
||||
let stmt = SqlParser::parse(sql).unwrap();
|
||||
|
||||
if let ParsedStatement::Update { table, assignments, where_clause } = stmt {
|
||||
if let ParsedStatement::Update {
|
||||
table,
|
||||
assignments,
|
||||
where_clause,
|
||||
} = stmt
|
||||
{
|
||||
assert_eq!(table, "users");
|
||||
assert_eq!(assignments.len(), 1);
|
||||
assert!(where_clause.is_some());
|
||||
|
|
@ -722,7 +775,11 @@ mod tests {
|
|||
let sql = "DELETE FROM users WHERE age < 18";
|
||||
let stmt = SqlParser::parse(sql).unwrap();
|
||||
|
||||
if let ParsedStatement::Delete { table, where_clause } = stmt {
|
||||
if let ParsedStatement::Delete {
|
||||
table,
|
||||
where_clause,
|
||||
} = stmt
|
||||
{
|
||||
assert_eq!(table, "users");
|
||||
assert!(where_clause.is_some());
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -85,7 +85,10 @@ impl Row {
|
|||
|
||||
/// Returns all values in column order.
|
||||
pub fn values(&self) -> Vec<&SqlValue> {
|
||||
self.columns.iter().map(|c| self.values.get(c).unwrap()).collect()
|
||||
self.columns
|
||||
.iter()
|
||||
.map(|c| self.values.get(c).unwrap())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Returns the number of columns.
|
||||
|
|
|
|||
|
|
@ -299,7 +299,10 @@ impl Table {
|
|||
|
||||
let mut indexes = self.indexes.write();
|
||||
if indexes.contains_key(&name) {
|
||||
return Err(SqlError::InvalidOperation(format!("Index '{}' already exists", name)));
|
||||
return Err(SqlError::InvalidOperation(format!(
|
||||
"Index '{}' already exists",
|
||||
name
|
||||
)));
|
||||
}
|
||||
|
||||
let mut index = TableIndex::new(&name, &column, unique);
|
||||
|
|
@ -319,7 +322,10 @@ impl Table {
|
|||
pub fn drop_index(&self, name: &str) -> Result<(), SqlError> {
|
||||
let mut indexes = self.indexes.write();
|
||||
if indexes.remove(name).is_none() {
|
||||
return Err(SqlError::InvalidOperation(format!("Index '{}' not found", name)));
|
||||
return Err(SqlError::InvalidOperation(format!(
|
||||
"Index '{}' not found",
|
||||
name
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -371,9 +377,9 @@ impl Table {
|
|||
/// Updates a row.
|
||||
pub fn update(&self, id: RowId, updates: HashMap<String, SqlValue>) -> Result<(), SqlError> {
|
||||
let mut rows = self.rows.write();
|
||||
let row = rows.get_mut(&id).ok_or_else(|| {
|
||||
SqlError::InvalidOperation(format!("Row {} not found", id))
|
||||
})?;
|
||||
let row = rows
|
||||
.get_mut(&id)
|
||||
.ok_or_else(|| SqlError::InvalidOperation(format!("Row {} not found", id)))?;
|
||||
|
||||
let old_values: HashMap<String, SqlValue> = updates
|
||||
.keys()
|
||||
|
|
@ -392,7 +398,10 @@ impl Table {
|
|||
let mut indexes = self.indexes.write();
|
||||
for (_, index) in indexes.iter_mut() {
|
||||
if let Some(new_value) = updates.get(&index.column) {
|
||||
let old_value = old_values.get(&index.column).cloned().unwrap_or(SqlValue::Null);
|
||||
let old_value = old_values
|
||||
.get(&index.column)
|
||||
.cloned()
|
||||
.unwrap_or(SqlValue::Null);
|
||||
index.remove(&old_value, &id);
|
||||
index.insert(new_value.clone(), id)?;
|
||||
}
|
||||
|
|
@ -480,7 +489,10 @@ mod tests {
|
|||
values.insert("id".to_string(), SqlValue::Integer(1));
|
||||
values.insert("name".to_string(), SqlValue::Text("Alice".to_string()));
|
||||
values.insert("age".to_string(), SqlValue::Integer(30));
|
||||
values.insert("email".to_string(), SqlValue::Text("alice@example.com".to_string()));
|
||||
values.insert(
|
||||
"email".to_string(),
|
||||
SqlValue::Text("alice@example.com".to_string()),
|
||||
);
|
||||
|
||||
let row_id = table.insert(values).unwrap();
|
||||
assert_eq!(table.count(), 1);
|
||||
|
|
@ -508,13 +520,19 @@ mod tests {
|
|||
let mut values1 = HashMap::new();
|
||||
values1.insert("id".to_string(), SqlValue::Integer(1));
|
||||
values1.insert("name".to_string(), SqlValue::Text("Alice".to_string()));
|
||||
values1.insert("email".to_string(), SqlValue::Text("test@example.com".to_string()));
|
||||
values1.insert(
|
||||
"email".to_string(),
|
||||
SqlValue::Text("test@example.com".to_string()),
|
||||
);
|
||||
table.insert(values1).unwrap();
|
||||
|
||||
let mut values2 = HashMap::new();
|
||||
values2.insert("id".to_string(), SqlValue::Integer(2));
|
||||
values2.insert("name".to_string(), SqlValue::Text("Bob".to_string()));
|
||||
values2.insert("email".to_string(), SqlValue::Text("test@example.com".to_string()));
|
||||
values2.insert(
|
||||
"email".to_string(),
|
||||
SqlValue::Text("test@example.com".to_string()),
|
||||
);
|
||||
|
||||
let result = table.insert(values2);
|
||||
assert!(result.is_err()); // Duplicate email
|
||||
|
|
|
|||
|
|
@ -124,7 +124,12 @@ impl Transaction {
|
|||
}
|
||||
|
||||
/// Records an insert operation.
|
||||
pub fn record_insert(&mut self, table: String, row_id: RowId, values: HashMap<String, SqlValue>) {
|
||||
pub fn record_insert(
|
||||
&mut self,
|
||||
table: String,
|
||||
row_id: RowId,
|
||||
values: HashMap<String, SqlValue>,
|
||||
) {
|
||||
self.operations.push(TransactionOp::Insert {
|
||||
table,
|
||||
row_id,
|
||||
|
|
@ -149,7 +154,12 @@ impl Transaction {
|
|||
}
|
||||
|
||||
/// Records a delete operation.
|
||||
pub fn record_delete(&mut self, table: String, row_id: RowId, old_values: HashMap<String, SqlValue>) {
|
||||
pub fn record_delete(
|
||||
&mut self,
|
||||
table: String,
|
||||
row_id: RowId,
|
||||
old_values: HashMap<String, SqlValue>,
|
||||
) {
|
||||
self.operations.push(TransactionOp::Delete {
|
||||
table,
|
||||
row_id,
|
||||
|
|
@ -213,7 +223,10 @@ impl TransactionManager {
|
|||
.ok_or_else(|| SqlError::Transaction(format!("Transaction {} not found", id)))?;
|
||||
|
||||
if !txn.is_active() {
|
||||
return Err(SqlError::Transaction(format!("Transaction {} is not active", id)));
|
||||
return Err(SqlError::Transaction(format!(
|
||||
"Transaction {} is not active",
|
||||
id
|
||||
)));
|
||||
}
|
||||
|
||||
txn.operations.push(op);
|
||||
|
|
@ -228,7 +241,10 @@ impl TransactionManager {
|
|||
.ok_or_else(|| SqlError::Transaction(format!("Transaction {} not found", id)))?;
|
||||
|
||||
if !txn.is_active() {
|
||||
return Err(SqlError::Transaction(format!("Transaction {} is not active", id)));
|
||||
return Err(SqlError::Transaction(format!(
|
||||
"Transaction {} is not active",
|
||||
id
|
||||
)));
|
||||
}
|
||||
|
||||
txn.mark_committed();
|
||||
|
|
@ -245,7 +261,10 @@ impl TransactionManager {
|
|||
.ok_or_else(|| SqlError::Transaction(format!("Transaction {} not found", id)))?;
|
||||
|
||||
if !txn.is_active() {
|
||||
return Err(SqlError::Transaction(format!("Transaction {} is not active", id)));
|
||||
return Err(SqlError::Transaction(format!(
|
||||
"Transaction {} is not active",
|
||||
id
|
||||
)));
|
||||
}
|
||||
|
||||
txn.mark_rolled_back();
|
||||
|
|
|
|||
|
|
@ -233,12 +233,8 @@ impl Ord for SqlValue {
|
|||
(SqlValue::Blob(a), SqlValue::Blob(b)) => a.cmp(b),
|
||||
(SqlValue::Boolean(a), SqlValue::Boolean(b)) => a.cmp(b),
|
||||
(SqlValue::Timestamp(a), SqlValue::Timestamp(b)) => a.cmp(b),
|
||||
(SqlValue::Integer(a), SqlValue::Real(b)) => {
|
||||
(*a as f64).to_bits().cmp(&b.to_bits())
|
||||
}
|
||||
(SqlValue::Real(a), SqlValue::Integer(b)) => {
|
||||
a.to_bits().cmp(&(*b as f64).to_bits())
|
||||
}
|
||||
(SqlValue::Integer(a), SqlValue::Real(b)) => (*a as f64).to_bits().cmp(&b.to_bits()),
|
||||
(SqlValue::Real(a), SqlValue::Integer(b)) => a.to_bits().cmp(&(*b as f64).to_bits()),
|
||||
// Different types: order by type discriminant
|
||||
_ => self.type_order().cmp(&other.type_order()),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -158,11 +158,7 @@ impl Metric {
|
|||
|
||||
/// Calculates sum in a time range.
|
||||
pub fn sum(&self, start: u64, end: u64) -> f64 {
|
||||
self.data
|
||||
.read()
|
||||
.range(start..=end)
|
||||
.map(|(_, &v)| v)
|
||||
.sum()
|
||||
self.data.read().range(start..=end).map(|(_, &v)| v).sum()
|
||||
}
|
||||
|
||||
/// Counts data points in a time range.
|
||||
|
|
|
|||
|
|
@ -207,9 +207,7 @@ impl VectorIndex {
|
|||
let embeddings = self.embeddings.read();
|
||||
let mut results: Vec<VectorSearchResult> = embeddings
|
||||
.values()
|
||||
.filter(|e| {
|
||||
namespace.map(|ns| e.namespace == ns).unwrap_or(true)
|
||||
})
|
||||
.filter(|e| namespace.map(|ns| e.namespace == ns).unwrap_or(true))
|
||||
.map(|e| {
|
||||
let score = self.calculate_similarity(&e.vector, query);
|
||||
VectorSearchResult {
|
||||
|
|
@ -217,14 +215,14 @@ impl VectorIndex {
|
|||
score,
|
||||
}
|
||||
})
|
||||
.filter(|r| {
|
||||
threshold.map(|t| r.score >= t).unwrap_or(true)
|
||||
})
|
||||
.filter(|r| threshold.map(|t| r.score >= t).unwrap_or(true))
|
||||
.collect();
|
||||
|
||||
// Sort by score descending
|
||||
results.sort_by(|a, b| {
|
||||
b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal)
|
||||
b.score
|
||||
.partial_cmp(&a.score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
// Apply limit
|
||||
|
|
@ -234,8 +232,9 @@ impl VectorIndex {
|
|||
let elapsed = start.elapsed().as_millis() as f64;
|
||||
let mut stats = self.stats.write();
|
||||
stats.searches += 1;
|
||||
stats.avg_search_time_ms =
|
||||
(stats.avg_search_time_ms * (stats.searches - 1) as f64 + elapsed) / stats.searches as f64;
|
||||
stats.avg_search_time_ms = (stats.avg_search_time_ms * (stats.searches - 1) as f64
|
||||
+ elapsed)
|
||||
/ stats.searches as f64;
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
|
@ -329,7 +328,8 @@ impl VectorStore {
|
|||
namespace: Option<&str>,
|
||||
threshold: Option<f32>,
|
||||
) -> Result<Vec<VectorSearchResult>, DatabaseError> {
|
||||
self.default_index.search(query, limit, namespace, threshold)
|
||||
self.default_index
|
||||
.search(query, limit, namespace, threshold)
|
||||
}
|
||||
|
||||
/// Gets an embedding by ID.
|
||||
|
|
@ -388,10 +388,7 @@ pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
|
|||
|
||||
/// Manhattan distance (L1) between two vectors.
|
||||
pub fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
|
||||
a.iter()
|
||||
.zip(b.iter())
|
||||
.map(|(x, y)| (x - y).abs())
|
||||
.sum()
|
||||
a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -414,9 +411,15 @@ mod tests {
|
|||
fn test_vector_insert_search() {
|
||||
let store = VectorStore::new(3);
|
||||
|
||||
store.insert(Embedding::new("a", vec![1.0, 0.0, 0.0])).unwrap();
|
||||
store.insert(Embedding::new("b", vec![0.9, 0.1, 0.0])).unwrap();
|
||||
store.insert(Embedding::new("c", vec![0.0, 1.0, 0.0])).unwrap();
|
||||
store
|
||||
.insert(Embedding::new("a", vec![1.0, 0.0, 0.0]))
|
||||
.unwrap();
|
||||
store
|
||||
.insert(Embedding::new("b", vec![0.9, 0.1, 0.0]))
|
||||
.unwrap();
|
||||
store
|
||||
.insert(Embedding::new("c", vec![0.0, 1.0, 0.0]))
|
||||
.unwrap();
|
||||
|
||||
let results = store.search(&[1.0, 0.0, 0.0], 2, None, None).unwrap();
|
||||
|
||||
|
|
@ -429,8 +432,12 @@ mod tests {
|
|||
fn test_similarity_threshold() {
|
||||
let store = VectorStore::new(3);
|
||||
|
||||
store.insert(Embedding::new("a", vec![1.0, 0.0, 0.0])).unwrap();
|
||||
store.insert(Embedding::new("b", vec![0.0, 1.0, 0.0])).unwrap();
|
||||
store
|
||||
.insert(Embedding::new("a", vec![1.0, 0.0, 0.0]))
|
||||
.unwrap();
|
||||
store
|
||||
.insert(Embedding::new("b", vec![0.0, 1.0, 0.0]))
|
||||
.unwrap();
|
||||
|
||||
let results = store.search(&[1.0, 0.0, 0.0], 10, None, Some(0.5)).unwrap();
|
||||
|
||||
|
|
@ -443,14 +450,16 @@ mod tests {
|
|||
fn test_namespace_filter() {
|
||||
let store = VectorStore::new(3);
|
||||
|
||||
store.insert(
|
||||
Embedding::new("a", vec![1.0, 0.0, 0.0]).with_namespace("ns1")
|
||||
).unwrap();
|
||||
store.insert(
|
||||
Embedding::new("b", vec![1.0, 0.0, 0.0]).with_namespace("ns2")
|
||||
).unwrap();
|
||||
store
|
||||
.insert(Embedding::new("a", vec![1.0, 0.0, 0.0]).with_namespace("ns1"))
|
||||
.unwrap();
|
||||
store
|
||||
.insert(Embedding::new("b", vec![1.0, 0.0, 0.0]).with_namespace("ns2"))
|
||||
.unwrap();
|
||||
|
||||
let results = store.search(&[1.0, 0.0, 0.0], 10, Some("ns1"), None).unwrap();
|
||||
let results = store
|
||||
.search(&[1.0, 0.0, 0.0], 10, Some("ns1"), None)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].embedding.id, "a");
|
||||
|
|
|
|||
|
|
@ -184,9 +184,7 @@ impl Credit {
|
|||
|
||||
/// Check if credit is expired
|
||||
pub fn is_expired(&self) -> bool {
|
||||
self.expires_at
|
||||
.map(|exp| Utc::now() > exp)
|
||||
.unwrap_or(false)
|
||||
self.expires_at.map(|exp| Utc::now() > exp).unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Get remaining amount
|
||||
|
|
@ -241,14 +239,21 @@ impl std::fmt::Display for CreditError {
|
|||
match self {
|
||||
CreditError::CreditInactive => write!(f, "Credit is no longer active"),
|
||||
CreditError::CreditExpired => write!(f, "Credit has expired"),
|
||||
CreditError::InsufficientCredit { requested, available } => {
|
||||
CreditError::InsufficientCredit {
|
||||
requested,
|
||||
available,
|
||||
} => {
|
||||
write!(
|
||||
f,
|
||||
"Insufficient credit: requested {}, available {}",
|
||||
requested, available
|
||||
)
|
||||
}
|
||||
CreditError::ExceedsMaxCredit { current, requested, maximum } => {
|
||||
CreditError::ExceedsMaxCredit {
|
||||
current,
|
||||
requested,
|
||||
maximum,
|
||||
} => {
|
||||
write!(
|
||||
f,
|
||||
"Credit exceeds maximum: current {}, requested {}, maximum {}",
|
||||
|
|
@ -279,9 +284,9 @@ pub struct CreditPolicy {
|
|||
impl Default for CreditPolicy {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
welcome_amount: Decimal::new(10, 0), // 10 SYNOR
|
||||
welcome_amount: Decimal::new(10, 0), // 10 SYNOR
|
||||
referral_referrer_amount: Decimal::new(25, 0), // 25 SYNOR
|
||||
referral_referee_amount: Decimal::new(10, 0), // 10 SYNOR
|
||||
referral_referee_amount: Decimal::new(10, 0), // 10 SYNOR
|
||||
max_credit_per_account: Decimal::new(1000, 0), // 1000 SYNOR
|
||||
default_expiry_days: 365,
|
||||
}
|
||||
|
|
@ -334,12 +339,20 @@ impl CreditManager {
|
|||
let referee_id = referee_id.into();
|
||||
|
||||
// Credit for the referrer
|
||||
let referrer_credit = Credit::referral(&referrer_id, self.policy.referral_referrer_amount, &referee_id)
|
||||
.with_expiry_days(self.policy.default_expiry_days);
|
||||
let referrer_credit = Credit::referral(
|
||||
&referrer_id,
|
||||
self.policy.referral_referrer_amount,
|
||||
&referee_id,
|
||||
)
|
||||
.with_expiry_days(self.policy.default_expiry_days);
|
||||
|
||||
// Credit for the referee
|
||||
let referee_credit = Credit::referral(&referee_id, self.policy.referral_referee_amount, &referrer_id)
|
||||
.with_expiry_days(self.policy.default_expiry_days);
|
||||
let referee_credit = Credit::referral(
|
||||
&referee_id,
|
||||
self.policy.referral_referee_amount,
|
||||
&referrer_id,
|
||||
)
|
||||
.with_expiry_days(self.policy.default_expiry_days);
|
||||
|
||||
self.credits
|
||||
.entry(referrer_id)
|
||||
|
|
@ -448,13 +461,11 @@ impl CreditManager {
|
|||
let mut remaining = amount;
|
||||
|
||||
// Sort by expiry date (soonest first) for FIFO
|
||||
credits.sort_by(|a, b| {
|
||||
match (&a.expires_at, &b.expires_at) {
|
||||
(Some(a_exp), Some(b_exp)) => a_exp.cmp(b_exp),
|
||||
(Some(_), None) => std::cmp::Ordering::Less,
|
||||
(None, Some(_)) => std::cmp::Ordering::Greater,
|
||||
(None, None) => a.created_at.cmp(&b.created_at),
|
||||
}
|
||||
credits.sort_by(|a, b| match (&a.expires_at, &b.expires_at) {
|
||||
(Some(a_exp), Some(b_exp)) => a_exp.cmp(b_exp),
|
||||
(Some(_), None) => std::cmp::Ordering::Less,
|
||||
(None, Some(_)) => std::cmp::Ordering::Greater,
|
||||
(None, None) => a.created_at.cmp(&b.created_at),
|
||||
});
|
||||
|
||||
for credit in credits.iter_mut() {
|
||||
|
|
|
|||
|
|
@ -319,12 +319,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_line_item() {
|
||||
let item = InvoiceLineItem::new(
|
||||
"Storage L2",
|
||||
ServiceType::Storage,
|
||||
dec!(10),
|
||||
dec!(0.02),
|
||||
);
|
||||
let item = InvoiceLineItem::new("Storage L2", ServiceType::Storage, dec!(10), dec!(0.02));
|
||||
|
||||
assert_eq!(item.amount, dec!(0.20));
|
||||
}
|
||||
|
|
@ -332,8 +327,18 @@ mod tests {
|
|||
#[test]
|
||||
fn test_invoice_calculate() {
|
||||
let mut invoice = Invoice::new("test")
|
||||
.add_line_item(InvoiceLineItem::new("Storage", ServiceType::Storage, dec!(100), dec!(0.02)))
|
||||
.add_line_item(InvoiceLineItem::new("Compute", ServiceType::Compute, dec!(10), dec!(0.50)));
|
||||
.add_line_item(InvoiceLineItem::new(
|
||||
"Storage",
|
||||
ServiceType::Storage,
|
||||
dec!(100),
|
||||
dec!(0.02),
|
||||
))
|
||||
.add_line_item(InvoiceLineItem::new(
|
||||
"Compute",
|
||||
ServiceType::Compute,
|
||||
dec!(10),
|
||||
dec!(0.50),
|
||||
));
|
||||
|
||||
invoice.discount = dec!(1);
|
||||
invoice.calculate();
|
||||
|
|
|
|||
|
|
@ -164,17 +164,12 @@ impl BillingEngine {
|
|||
let outstanding: Vec<_> = account
|
||||
.invoice_ids
|
||||
.iter()
|
||||
.filter(|id| {
|
||||
invoices
|
||||
.get(*id)
|
||||
.map(|inv| !inv.is_paid())
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.filter(|id| invoices.get(*id).map(|inv| !inv.is_paid()).unwrap_or(false))
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
let next_invoice = account.billing_cycle_start
|
||||
+ Duration::days(self.config.billing_cycle_days as i64);
|
||||
let next_invoice =
|
||||
account.billing_cycle_start + Duration::days(self.config.billing_cycle_days as i64);
|
||||
|
||||
Ok(AccountBillingInfo {
|
||||
account_id: account_id.to_string(),
|
||||
|
|
@ -198,11 +193,7 @@ impl BillingEngine {
|
|||
|
||||
account.prepaid_balance += amount;
|
||||
|
||||
tracing::info!(
|
||||
"Added {} SYNOR prepaid to account {}",
|
||||
amount,
|
||||
account_id
|
||||
);
|
||||
tracing::info!("Added {} SYNOR prepaid to account {}", amount, account_id);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -378,7 +369,9 @@ impl BillingEngine {
|
|||
PaymentMethod::CreditBalance => {
|
||||
// Deduct from credit balance
|
||||
if account.credit_balance < payment.amount {
|
||||
return Err(EconomicsError::PaymentFailed("Insufficient credit balance".to_string()));
|
||||
return Err(EconomicsError::PaymentFailed(
|
||||
"Insufficient credit balance".to_string(),
|
||||
));
|
||||
}
|
||||
account.credit_balance -= payment.amount;
|
||||
payment.mark_completed();
|
||||
|
|
@ -484,7 +477,10 @@ impl BillingEngine {
|
|||
/// Get unpaid invoices for an account
|
||||
pub async fn get_unpaid_invoices(&self, account_id: &str) -> Result<Vec<Invoice>> {
|
||||
let all_invoices = self.get_account_invoices(account_id).await?;
|
||||
Ok(all_invoices.into_iter().filter(|inv| !inv.is_paid()).collect())
|
||||
Ok(all_invoices
|
||||
.into_iter()
|
||||
.filter(|inv| !inv.is_paid())
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Get detailed account information including creation date
|
||||
|
|
@ -617,7 +613,10 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn test_register_account() {
|
||||
let engine = setup_engine().await;
|
||||
engine.register_account("test_account", "standard").await.unwrap();
|
||||
engine
|
||||
.register_account("test_account", "standard")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let info = engine.get_account_details("test_account").await.unwrap();
|
||||
assert_eq!(info.account_id, "test_account");
|
||||
|
|
@ -627,7 +626,10 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn test_add_prepaid() {
|
||||
let engine = setup_engine().await;
|
||||
engine.register_account("prepaid_test", "standard").await.unwrap();
|
||||
engine
|
||||
.register_account("prepaid_test", "standard")
|
||||
.await
|
||||
.unwrap();
|
||||
engine.add_prepaid("prepaid_test", dec!(100)).await.unwrap();
|
||||
|
||||
let info = engine.get_account_info("prepaid_test").await.unwrap();
|
||||
|
|
@ -637,7 +639,10 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn test_add_credit() {
|
||||
let engine = setup_engine().await;
|
||||
engine.register_account("credit_test", "standard").await.unwrap();
|
||||
engine
|
||||
.register_account("credit_test", "standard")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let credit = Credit::new("credit_test", dec!(50), "Welcome bonus");
|
||||
engine.add_credit("credit_test", credit).await.unwrap();
|
||||
|
|
|
|||
|
|
@ -210,17 +210,23 @@ impl PaymentProcessor {
|
|||
payment.mark_processing();
|
||||
|
||||
// Simulate transaction
|
||||
let tx_hash = format!("0x{:x}000000000000000000000000000000000000000000000000000000000000",
|
||||
let tx_hash = format!(
|
||||
"0x{:x}000000000000000000000000000000000000000000000000000000000000",
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs());
|
||||
.as_secs()
|
||||
);
|
||||
|
||||
payment.mark_confirmed(tx_hash);
|
||||
|
||||
// Add addresses to metadata
|
||||
payment.metadata.insert("from".to_string(), from_address.to_string());
|
||||
payment.metadata.insert("to".to_string(), to_address.to_string());
|
||||
payment
|
||||
.metadata
|
||||
.insert("from".to_string(), from_address.to_string());
|
||||
payment
|
||||
.metadata
|
||||
.insert("to".to_string(), to_address.to_string());
|
||||
|
||||
payment.mark_completed();
|
||||
|
||||
|
|
@ -325,7 +331,10 @@ mod tests {
|
|||
|
||||
payment.mark_failed("Insufficient funds");
|
||||
assert!(!payment.is_complete());
|
||||
assert_eq!(payment.failure_reason, Some("Insufficient funds".to_string()));
|
||||
assert_eq!(
|
||||
payment.failure_reason,
|
||||
Some("Insufficient funds".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
|
|||
|
|
@ -177,7 +177,10 @@ impl CostEstimator {
|
|||
|
||||
/// Estimate cost for a usage projection
|
||||
pub async fn estimate(&self, projection: UsageProjection) -> Result<CostEstimate> {
|
||||
let tier_name = projection.tier.clone().unwrap_or_else(|| "free".to_string());
|
||||
let tier_name = projection
|
||||
.tier
|
||||
.clone()
|
||||
.unwrap_or_else(|| "free".to_string());
|
||||
let months = projection.duration_months.max(1);
|
||||
|
||||
let mut by_service = HashMap::new();
|
||||
|
|
|
|||
|
|
@ -22,10 +22,7 @@ pub enum EconomicsError {
|
|||
|
||||
/// Insufficient balance
|
||||
#[error("Insufficient balance: required {required}, available {available}")]
|
||||
InsufficientBalance {
|
||||
required: String,
|
||||
available: String,
|
||||
},
|
||||
InsufficientBalance { required: String, available: String },
|
||||
|
||||
/// Insufficient funds (with Decimal values)
|
||||
#[error("Insufficient funds: required {required}, available {available}")]
|
||||
|
|
@ -36,10 +33,7 @@ pub enum EconomicsError {
|
|||
|
||||
/// Stale price with specific asset
|
||||
#[error("Price stale for {asset}: {age_seconds} seconds old")]
|
||||
StalePrice {
|
||||
asset: String,
|
||||
age_seconds: i64,
|
||||
},
|
||||
StalePrice { asset: String, age_seconds: i64 },
|
||||
|
||||
/// Account not found
|
||||
#[error("Account not found: {0}")]
|
||||
|
|
|
|||
|
|
@ -251,9 +251,7 @@ impl EconomicsManager {
|
|||
use rust_decimal_macros::dec;
|
||||
|
||||
// Default to development oracle with mock feeds at $1.50 base price
|
||||
let oracle = Arc::new(RwLock::new(
|
||||
oracle::OracleFactory::development(dec!(1.50))
|
||||
));
|
||||
let oracle = Arc::new(RwLock::new(oracle::OracleFactory::development(dec!(1.50))));
|
||||
let pricing = Arc::new(PricingEngine::new());
|
||||
let metering = Arc::new(MeteringService::new(pricing.clone()));
|
||||
let billing = Arc::new(BillingEngine::new(metering.clone(), pricing.clone()));
|
||||
|
|
@ -270,9 +268,7 @@ impl EconomicsManager {
|
|||
|
||||
/// Create an economics manager with production oracle configuration
|
||||
pub fn with_production_oracle(config: oracle::ProductionOracleConfig) -> Self {
|
||||
let oracle = Arc::new(RwLock::new(
|
||||
oracle::OracleFactory::production(config)
|
||||
));
|
||||
let oracle = Arc::new(RwLock::new(oracle::OracleFactory::production(config)));
|
||||
let pricing = Arc::new(PricingEngine::new());
|
||||
let metering = Arc::new(MeteringService::new(pricing.clone()));
|
||||
let billing = Arc::new(BillingEngine::new(metering.clone(), pricing.clone()));
|
||||
|
|
|
|||
|
|
@ -209,23 +209,22 @@ impl MeteringService {
|
|||
}
|
||||
|
||||
// Calculate cost for this event
|
||||
let cost = self.pricing.calculate_cost(
|
||||
event.service_type,
|
||||
event.resource_unit,
|
||||
event.amount,
|
||||
)?;
|
||||
let cost =
|
||||
self.pricing
|
||||
.calculate_cost(event.service_type, event.resource_unit, event.amount)?;
|
||||
|
||||
// Update current usage
|
||||
{
|
||||
let mut usage = self.current_usage.write().await;
|
||||
let account_usage = usage.entry(event.account_id.clone()).or_insert_with(|| {
|
||||
AccountUsage {
|
||||
account_id: event.account_id.clone(),
|
||||
by_service: HashMap::new(),
|
||||
current_period_start: Utc::now(),
|
||||
last_event: None,
|
||||
}
|
||||
});
|
||||
let account_usage =
|
||||
usage
|
||||
.entry(event.account_id.clone())
|
||||
.or_insert_with(|| AccountUsage {
|
||||
account_id: event.account_id.clone(),
|
||||
by_service: HashMap::new(),
|
||||
current_period_start: Utc::now(),
|
||||
last_event: None,
|
||||
});
|
||||
|
||||
*account_usage
|
||||
.by_service
|
||||
|
|
@ -263,7 +262,8 @@ impl MeteringService {
|
|||
ServiceType::Storage,
|
||||
ResourceUnit::Bytes,
|
||||
Decimal::from(usage.bytes_stored),
|
||||
)).await?;
|
||||
))
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Storage: bytes retrieved
|
||||
|
|
@ -273,7 +273,8 @@ impl MeteringService {
|
|||
ServiceType::Storage,
|
||||
ResourceUnit::BandwidthGb,
|
||||
Decimal::from(usage.bytes_retrieved) / Decimal::from(1_073_741_824u64), // to GB
|
||||
)).await?;
|
||||
))
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
|
@ -288,7 +289,8 @@ impl MeteringService {
|
|||
ServiceType::Hosting,
|
||||
ResourceUnit::BandwidthGb,
|
||||
Decimal::from(usage.bandwidth_bytes) / Decimal::from(1_073_741_824u64),
|
||||
)).await?;
|
||||
))
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Custom domains
|
||||
|
|
@ -298,7 +300,8 @@ impl MeteringService {
|
|||
ServiceType::Hosting,
|
||||
ResourceUnit::Domains,
|
||||
Decimal::from(usage.custom_domains),
|
||||
)).await?;
|
||||
))
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
|
@ -313,7 +316,8 @@ impl MeteringService {
|
|||
ServiceType::Database,
|
||||
ResourceUnit::Queries,
|
||||
Decimal::from(usage.queries),
|
||||
)).await?;
|
||||
))
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Vector searches
|
||||
|
|
@ -323,7 +327,8 @@ impl MeteringService {
|
|||
ServiceType::Database,
|
||||
ResourceUnit::VectorSearches,
|
||||
Decimal::from(usage.vector_searches),
|
||||
)).await?;
|
||||
))
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Storage
|
||||
|
|
@ -333,7 +338,8 @@ impl MeteringService {
|
|||
ServiceType::Database,
|
||||
ResourceUnit::GbMonth,
|
||||
Decimal::from(usage.storage_bytes) / Decimal::from(1_073_741_824u64),
|
||||
)).await?;
|
||||
))
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
|
@ -348,7 +354,8 @@ impl MeteringService {
|
|||
ServiceType::Compute,
|
||||
ResourceUnit::CpuCoreHours,
|
||||
Decimal::from(usage.cpu_core_seconds) / Decimal::from(3600),
|
||||
)).await?;
|
||||
))
|
||||
.await?;
|
||||
}
|
||||
|
||||
// GPU hours
|
||||
|
|
@ -358,7 +365,8 @@ impl MeteringService {
|
|||
ServiceType::Compute,
|
||||
ResourceUnit::GpuHours,
|
||||
Decimal::from(usage.gpu_seconds) / Decimal::from(3600),
|
||||
)).await?;
|
||||
))
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Memory GB hours
|
||||
|
|
@ -368,7 +376,8 @@ impl MeteringService {
|
|||
ServiceType::Compute,
|
||||
ResourceUnit::MemoryGbHours,
|
||||
Decimal::from(usage.memory_gb_seconds) / Decimal::from(3600),
|
||||
)).await?;
|
||||
))
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Invocations (serverless)
|
||||
|
|
@ -378,7 +387,8 @@ impl MeteringService {
|
|||
ServiceType::Compute,
|
||||
ResourceUnit::Invocations,
|
||||
Decimal::from(usage.invocations),
|
||||
)).await?;
|
||||
))
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
|
@ -393,7 +403,8 @@ impl MeteringService {
|
|||
ServiceType::Network,
|
||||
ResourceUnit::BandwidthGb,
|
||||
Decimal::from(total_bytes) / Decimal::from(1_073_741_824u64),
|
||||
)).await?;
|
||||
))
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
|
@ -421,10 +432,7 @@ impl MeteringService {
|
|||
// Check buffered events
|
||||
let buffer = self.event_buffer.read().await;
|
||||
for event in buffer.iter() {
|
||||
if event.account_id == account_id
|
||||
&& event.timestamp >= start
|
||||
&& event.timestamp < end
|
||||
{
|
||||
if event.account_id == account_id && event.timestamp >= start && event.timestamp < end {
|
||||
let cost = self.pricing.calculate_cost(
|
||||
event.service_type,
|
||||
event.resource_unit,
|
||||
|
|
|
|||
|
|
@ -224,9 +224,8 @@ impl IsolationTree {
|
|||
// Random split point
|
||||
let split = min_val + (max_val - min_val) * 0.5;
|
||||
|
||||
let (left_data, right_data): (Vec<_>, Vec<_>) = data.iter()
|
||||
.cloned()
|
||||
.partition(|row| row[feature] < split);
|
||||
let (left_data, right_data): (Vec<_>, Vec<_>) =
|
||||
data.iter().cloned().partition(|row| row[feature] < split);
|
||||
|
||||
Some(Self {
|
||||
split_feature: feature,
|
||||
|
|
@ -280,7 +279,8 @@ impl IsolationForest {
|
|||
let trees: Vec<_> = (0..n_trees)
|
||||
.filter_map(|i| {
|
||||
// Subsample with deterministic "randomness" based on tree index
|
||||
let sample: Vec<_> = data.iter()
|
||||
let sample: Vec<_> = data
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(j, _)| (i + j) % 3 != 0)
|
||||
.take(sample_size)
|
||||
|
|
@ -299,9 +299,12 @@ impl IsolationForest {
|
|||
return 0.5;
|
||||
}
|
||||
|
||||
let avg_path: f64 = self.trees.iter()
|
||||
let avg_path: f64 = self
|
||||
.trees
|
||||
.iter()
|
||||
.map(|tree| tree.path_length(point, 0.0))
|
||||
.sum::<f64>() / self.trees.len() as f64;
|
||||
.sum::<f64>()
|
||||
/ self.trees.len() as f64;
|
||||
|
||||
let c = c_factor(self.sample_size);
|
||||
if c < f64::EPSILON {
|
||||
|
|
@ -365,17 +368,28 @@ impl PairDetector {
|
|||
|
||||
// Track addresses
|
||||
if !point.addresses.is_empty() {
|
||||
self.recent_addresses.push_back((point.timestamp, point.addresses.clone()));
|
||||
self.recent_addresses
|
||||
.push_back((point.timestamp, point.addresses.clone()));
|
||||
}
|
||||
|
||||
self.price_history.push_back(point);
|
||||
|
||||
// Cleanup old data
|
||||
let cutoff = Utc::now() - Duration::hours(24);
|
||||
while self.price_history.front().map(|p| p.timestamp < cutoff).unwrap_or(false) {
|
||||
while self
|
||||
.price_history
|
||||
.front()
|
||||
.map(|p| p.timestamp < cutoff)
|
||||
.unwrap_or(false)
|
||||
{
|
||||
self.price_history.pop_front();
|
||||
}
|
||||
while self.recent_addresses.front().map(|(t, _)| *t < cutoff).unwrap_or(false) {
|
||||
while self
|
||||
.recent_addresses
|
||||
.front()
|
||||
.map(|(t, _)| *t < cutoff)
|
||||
.unwrap_or(false)
|
||||
{
|
||||
self.recent_addresses.pop_front();
|
||||
}
|
||||
}
|
||||
|
|
@ -386,7 +400,9 @@ impl PairDetector {
|
|||
}
|
||||
|
||||
// Build feature vectors: [price, volume, return, bid/ask ratio]
|
||||
let data: Vec<Vec<f64>> = self.price_history.iter()
|
||||
let data: Vec<Vec<f64>> = self
|
||||
.price_history
|
||||
.iter()
|
||||
.skip(1)
|
||||
.zip(self.price_history.iter())
|
||||
.map(|(curr, prev)| {
|
||||
|
|
@ -403,7 +419,11 @@ impl PairDetector {
|
|||
(Some(bid), Some(ask)) => {
|
||||
let bid_f = bid.to_string().parse::<f64>().unwrap_or(0.0);
|
||||
let ask_f = ask.to_string().parse::<f64>().unwrap_or(1.0);
|
||||
if ask_f > 0.0 { bid_f / ask_f } else { 1.0 }
|
||||
if ask_f > 0.0 {
|
||||
bid_f / ask_f
|
||||
} else {
|
||||
1.0
|
||||
}
|
||||
}
|
||||
_ => 1.0,
|
||||
};
|
||||
|
|
@ -462,19 +482,34 @@ impl AnomalyDetector {
|
|||
|
||||
// Run all detectors using the immutable reference first
|
||||
if let Some(detector) = self.detectors.get(pair) {
|
||||
if let Some(a) = Self::detect_price_outlier_impl(pair, &data, detector, min_data_points, z_score_threshold) {
|
||||
if let Some(a) = Self::detect_price_outlier_impl(
|
||||
pair,
|
||||
&data,
|
||||
detector,
|
||||
min_data_points,
|
||||
z_score_threshold,
|
||||
) {
|
||||
anomalies.push(a);
|
||||
}
|
||||
if let Some(a) = Self::detect_volume_spike_impl(pair, &data, detector, min_data_points, volume_spike_multiplier) {
|
||||
if let Some(a) = Self::detect_volume_spike_impl(
|
||||
pair,
|
||||
&data,
|
||||
detector,
|
||||
min_data_points,
|
||||
volume_spike_multiplier,
|
||||
) {
|
||||
anomalies.push(a);
|
||||
}
|
||||
if let Some(a) = Self::detect_wash_trading_impl(pair, &data, detector, wash_trading_window) {
|
||||
if let Some(a) =
|
||||
Self::detect_wash_trading_impl(pair, &data, detector, wash_trading_window)
|
||||
{
|
||||
anomalies.push(a);
|
||||
}
|
||||
if let Some(a) = Self::detect_pump_dump_impl(pair, detector, pump_dump_window) {
|
||||
anomalies.push(a);
|
||||
}
|
||||
if let Some(a) = Self::detect_flash_loan_impl(pair, &data, detector, flash_loan_window) {
|
||||
if let Some(a) = Self::detect_flash_loan_impl(pair, &data, detector, flash_loan_window)
|
||||
{
|
||||
anomalies.push(a);
|
||||
}
|
||||
if ml_enabled {
|
||||
|
|
@ -493,7 +528,13 @@ impl AnomalyDetector {
|
|||
anomalies
|
||||
}
|
||||
|
||||
fn detect_price_outlier_impl(pair: &str, data: &MarketDataPoint, detector: &PairDetector, min_data_points: usize, z_score_threshold: f64) -> Option<Anomaly> {
|
||||
fn detect_price_outlier_impl(
|
||||
pair: &str,
|
||||
data: &MarketDataPoint,
|
||||
detector: &PairDetector,
|
||||
min_data_points: usize,
|
||||
z_score_threshold: f64,
|
||||
) -> Option<Anomaly> {
|
||||
if detector.price_stats.count < min_data_points {
|
||||
return None;
|
||||
}
|
||||
|
|
@ -531,7 +572,13 @@ impl AnomalyDetector {
|
|||
}
|
||||
}
|
||||
|
||||
fn detect_volume_spike_impl(pair: &str, data: &MarketDataPoint, detector: &PairDetector, min_data_points: usize, volume_spike_multiplier: f64) -> Option<Anomaly> {
|
||||
fn detect_volume_spike_impl(
|
||||
pair: &str,
|
||||
data: &MarketDataPoint,
|
||||
detector: &PairDetector,
|
||||
min_data_points: usize,
|
||||
volume_spike_multiplier: f64,
|
||||
) -> Option<Anomaly> {
|
||||
if detector.volume_stats.count < min_data_points {
|
||||
return None;
|
||||
}
|
||||
|
|
@ -550,7 +597,9 @@ impl AnomalyDetector {
|
|||
confidence: 0.75,
|
||||
description: format!(
|
||||
"Volume {} is {:.1}x the average {:.2}",
|
||||
data.volume, volume_f64 / mean, mean
|
||||
data.volume,
|
||||
volume_f64 / mean,
|
||||
mean
|
||||
),
|
||||
data: AnomalyData {
|
||||
current_value: data.volume,
|
||||
|
|
@ -566,7 +615,12 @@ impl AnomalyDetector {
|
|||
}
|
||||
}
|
||||
|
||||
fn detect_wash_trading_impl(pair: &str, data: &MarketDataPoint, detector: &PairDetector, wash_trading_window: i64) -> Option<Anomaly> {
|
||||
fn detect_wash_trading_impl(
|
||||
pair: &str,
|
||||
data: &MarketDataPoint,
|
||||
detector: &PairDetector,
|
||||
wash_trading_window: i64,
|
||||
) -> Option<Anomaly> {
|
||||
if data.addresses.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
|
@ -617,14 +671,20 @@ impl AnomalyDetector {
|
|||
None
|
||||
}
|
||||
|
||||
fn detect_pump_dump_impl(pair: &str, detector: &PairDetector, pump_dump_window: i64) -> Option<Anomaly> {
|
||||
fn detect_pump_dump_impl(
|
||||
pair: &str,
|
||||
detector: &PairDetector,
|
||||
pump_dump_window: i64,
|
||||
) -> Option<Anomaly> {
|
||||
// Need enough history
|
||||
if detector.price_history.len() < 10 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let window_start = Utc::now() - Duration::minutes(pump_dump_window);
|
||||
let prices: Vec<_> = detector.price_history.iter()
|
||||
let prices: Vec<_> = detector
|
||||
.price_history
|
||||
.iter()
|
||||
.filter(|p| p.timestamp >= window_start)
|
||||
.map(|p| p.price.to_string().parse::<f64>().unwrap_or(0.0))
|
||||
.collect();
|
||||
|
|
@ -635,7 +695,9 @@ impl AnomalyDetector {
|
|||
|
||||
// Find max and check for reversal
|
||||
let max_price = prices.iter().copied().fold(f64::MIN, f64::max);
|
||||
let max_idx = prices.iter().position(|&p| (p - max_price).abs() < f64::EPSILON)?;
|
||||
let max_idx = prices
|
||||
.iter()
|
||||
.position(|&p| (p - max_price).abs() < f64::EPSILON)?;
|
||||
let first_price = prices.first()?;
|
||||
let last_price = prices.last()?;
|
||||
|
||||
|
|
@ -672,10 +734,17 @@ impl AnomalyDetector {
|
|||
None
|
||||
}
|
||||
|
||||
fn detect_flash_loan_impl(pair: &str, data: &MarketDataPoint, detector: &PairDetector, flash_loan_window: i64) -> Option<Anomaly> {
|
||||
fn detect_flash_loan_impl(
|
||||
pair: &str,
|
||||
data: &MarketDataPoint,
|
||||
detector: &PairDetector,
|
||||
flash_loan_window: i64,
|
||||
) -> Option<Anomaly> {
|
||||
// Flash loan signature: huge volume spike + quick price movement + reversal
|
||||
let window_start = Utc::now() - Duration::seconds(flash_loan_window);
|
||||
let recent: Vec<_> = detector.price_history.iter()
|
||||
let recent: Vec<_> = detector
|
||||
.price_history
|
||||
.iter()
|
||||
.filter(|p| p.timestamp >= window_start)
|
||||
.collect();
|
||||
|
||||
|
|
@ -683,11 +752,13 @@ impl AnomalyDetector {
|
|||
return None;
|
||||
}
|
||||
|
||||
let volumes: Vec<f64> = recent.iter()
|
||||
let volumes: Vec<f64> = recent
|
||||
.iter()
|
||||
.map(|p| p.volume.to_string().parse::<f64>().unwrap_or(0.0))
|
||||
.collect();
|
||||
|
||||
let prices: Vec<f64> = recent.iter()
|
||||
let prices: Vec<f64> = recent
|
||||
.iter()
|
||||
.map(|p| p.price.to_string().parse::<f64>().unwrap_or(0.0))
|
||||
.collect();
|
||||
|
||||
|
|
@ -706,7 +777,10 @@ impl AnomalyDetector {
|
|||
// Big spike and quick reversal
|
||||
if spike > 10.0 && reversal > 8.0 {
|
||||
let mut context = HashMap::new();
|
||||
context.insert("volume_spike".to_string(), format!("{:.0}x", max_volume / avg_volume));
|
||||
context.insert(
|
||||
"volume_spike".to_string(),
|
||||
format!("{:.0}x", max_volume / avg_volume),
|
||||
);
|
||||
context.insert("price_spike".to_string(), format!("{:.1}%", spike));
|
||||
|
||||
return Some(Anomaly {
|
||||
|
|
@ -717,7 +791,8 @@ impl AnomalyDetector {
|
|||
confidence: 0.65,
|
||||
description: format!(
|
||||
"Suspected flash loan attack: {}x volume, {:.1}% price spike",
|
||||
max_volume / avg_volume, spike
|
||||
max_volume / avg_volume,
|
||||
spike
|
||||
),
|
||||
data: AnomalyData {
|
||||
current_value: data.volume,
|
||||
|
|
@ -734,7 +809,11 @@ impl AnomalyDetector {
|
|||
None
|
||||
}
|
||||
|
||||
fn detect_ml_anomaly_impl(pair: &str, data: &MarketDataPoint, detector: &PairDetector) -> Option<Anomaly> {
|
||||
fn detect_ml_anomaly_impl(
|
||||
pair: &str,
|
||||
data: &MarketDataPoint,
|
||||
detector: &PairDetector,
|
||||
) -> Option<Anomaly> {
|
||||
let forest = detector.isolation_forest.as_ref()?;
|
||||
|
||||
if detector.price_history.is_empty() {
|
||||
|
|
@ -756,7 +835,11 @@ impl AnomalyDetector {
|
|||
(Some(bid), Some(ask)) => {
|
||||
let bid_f = bid.to_string().parse::<f64>().ok()?;
|
||||
let ask_f = ask.to_string().parse::<f64>().ok()?;
|
||||
if ask_f > 0.0 { bid_f / ask_f } else { 1.0 }
|
||||
if ask_f > 0.0 {
|
||||
bid_f / ask_f
|
||||
} else {
|
||||
1.0
|
||||
}
|
||||
}
|
||||
_ => 1.0,
|
||||
};
|
||||
|
|
@ -774,10 +857,7 @@ impl AnomalyDetector {
|
|||
detected_at: Utc::now(),
|
||||
severity: score,
|
||||
confidence: 0.6,
|
||||
description: format!(
|
||||
"ML model detected anomaly with score {:.3}",
|
||||
score
|
||||
),
|
||||
description: format!("ML model detected anomaly with score {:.3}", score),
|
||||
data: AnomalyData {
|
||||
current_value: data.price,
|
||||
expected_value: Decimal::from_f64_retain(detector.price_stats.mean)?,
|
||||
|
|
@ -798,7 +878,8 @@ impl AnomalyDetector {
|
|||
|
||||
/// Get recent anomalies for a pair
|
||||
pub fn get_anomalies(&self, pair: &str) -> Vec<Anomaly> {
|
||||
self.detectors.get(pair)
|
||||
self.detectors
|
||||
.get(pair)
|
||||
.map(|d| d.anomalies.clone())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
|
@ -807,11 +888,14 @@ impl AnomalyDetector {
|
|||
pub fn get_stats(&self, pair: &str) -> Option<AnomalyStats> {
|
||||
let detector = self.detectors.get(pair)?;
|
||||
|
||||
let by_type: HashMap<AnomalyType, usize> = detector.anomalies.iter()
|
||||
.fold(HashMap::new(), |mut acc, a| {
|
||||
*acc.entry(a.anomaly_type.clone()).or_insert(0) += 1;
|
||||
acc
|
||||
});
|
||||
let by_type: HashMap<AnomalyType, usize> =
|
||||
detector
|
||||
.anomalies
|
||||
.iter()
|
||||
.fold(HashMap::new(), |mut acc, a| {
|
||||
*acc.entry(a.anomaly_type.clone()).or_insert(0) += 1;
|
||||
acc
|
||||
});
|
||||
|
||||
Some(AnomalyStats {
|
||||
total_anomalies: detector.anomalies.len(),
|
||||
|
|
@ -913,7 +997,9 @@ mod tests {
|
|||
|
||||
let anomalies = detector.process("SYNOR/USD", outlier);
|
||||
assert!(!anomalies.is_empty());
|
||||
assert!(anomalies.iter().any(|a| a.anomaly_type == AnomalyType::PriceOutlier));
|
||||
assert!(anomalies
|
||||
.iter()
|
||||
.any(|a| a.anomaly_type == AnomalyType::PriceOutlier));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -946,6 +1032,8 @@ mod tests {
|
|||
};
|
||||
|
||||
let anomalies = detector.process("SYNOR/USD", spike);
|
||||
assert!(anomalies.iter().any(|a| a.anomaly_type == AnomalyType::VolumeSpike));
|
||||
assert!(anomalies
|
||||
.iter()
|
||||
.any(|a| a.anomaly_type == AnomalyType::VolumeSpike));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -40,17 +40,11 @@ pub enum TriggerReason {
|
|||
threshold: SynorDecimal,
|
||||
},
|
||||
/// Multiple oracle sources disagree
|
||||
OracleDisagreement {
|
||||
spread_percent: Decimal,
|
||||
},
|
||||
OracleDisagreement { spread_percent: Decimal },
|
||||
/// Manual trigger by admin
|
||||
ManualHalt {
|
||||
reason: String,
|
||||
},
|
||||
ManualHalt { reason: String },
|
||||
/// Cascade from related market
|
||||
CascadeTrigger {
|
||||
source_pair: String,
|
||||
},
|
||||
CascadeTrigger { source_pair: String },
|
||||
}
|
||||
|
||||
/// Circuit breaker event
|
||||
|
|
@ -98,12 +92,12 @@ pub struct CircuitBreakerConfig {
|
|||
impl Default for CircuitBreakerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_1m_change: Decimal::new(10, 2), // 10%
|
||||
max_5m_change: Decimal::new(20, 2), // 20%
|
||||
max_1h_change: Decimal::new(50, 2), // 50%
|
||||
max_1m_change: Decimal::new(10, 2), // 10%
|
||||
max_5m_change: Decimal::new(20, 2), // 20%
|
||||
max_1h_change: Decimal::new(50, 2), // 50%
|
||||
max_twap_deviation: Decimal::new(30, 2), // 30%
|
||||
min_liquidity: Decimal::new(10000, 0), // $10k
|
||||
max_oracle_spread: Decimal::new(5, 2), // 5%
|
||||
min_liquidity: Decimal::new(10000, 0), // $10k
|
||||
max_oracle_spread: Decimal::new(5, 2), // 5%
|
||||
cooldown_duration: Duration::minutes(5),
|
||||
recovery_checks: 3,
|
||||
cascade_enabled: true,
|
||||
|
|
@ -173,7 +167,12 @@ impl PairCircuitBreaker {
|
|||
|
||||
// Keep only last 24 hours
|
||||
let cutoff = Utc::now() - Duration::hours(24);
|
||||
while self.price_history.front().map(|s| s.timestamp < cutoff).unwrap_or(false) {
|
||||
while self
|
||||
.price_history
|
||||
.front()
|
||||
.map(|s| s.timestamp < cutoff)
|
||||
.unwrap_or(false)
|
||||
{
|
||||
self.price_history.pop_front();
|
||||
}
|
||||
|
||||
|
|
@ -192,7 +191,8 @@ impl PairCircuitBreaker {
|
|||
|
||||
fn get_price_at(&self, seconds_ago: i64) -> Option<SynorDecimal> {
|
||||
let target = Utc::now() - Duration::seconds(seconds_ago);
|
||||
self.price_history.iter()
|
||||
self.price_history
|
||||
.iter()
|
||||
.rev()
|
||||
.find(|s| s.timestamp <= target)
|
||||
.map(|s| s.price)
|
||||
|
|
@ -232,7 +232,9 @@ impl CircuitBreakerManager {
|
|||
price: SynorDecimal,
|
||||
liquidity: Option<SynorDecimal>,
|
||||
) -> Result<CircuitState> {
|
||||
let breaker = self.breakers.entry(pair.to_string())
|
||||
let breaker = self
|
||||
.breakers
|
||||
.entry(pair.to_string())
|
||||
.or_insert_with(PairCircuitBreaker::new);
|
||||
|
||||
// Use the convenience method for real-time price recording
|
||||
|
|
@ -256,7 +258,9 @@ impl CircuitBreakerManager {
|
|||
liquidity: Option<SynorDecimal>,
|
||||
timestamp: DateTime<Utc>,
|
||||
) -> Result<CircuitState> {
|
||||
let breaker = self.breakers.entry(pair.to_string())
|
||||
let breaker = self
|
||||
.breakers
|
||||
.entry(pair.to_string())
|
||||
.or_insert_with(PairCircuitBreaker::new);
|
||||
|
||||
breaker.record_price_at(price, liquidity, timestamp);
|
||||
|
|
@ -273,22 +277,26 @@ impl CircuitBreakerManager {
|
|||
|
||||
/// Check all trigger conditions
|
||||
fn check_triggers(&mut self, pair: &str) -> Result<()> {
|
||||
let breaker = self.breakers.get(pair).ok_or_else(||
|
||||
EconomicsError::PriceFeedUnavailable(pair.to_string())
|
||||
)?;
|
||||
let breaker = self
|
||||
.breakers
|
||||
.get(pair)
|
||||
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
|
||||
|
||||
let current = breaker.current_price().ok_or_else(||
|
||||
EconomicsError::PriceFeedUnavailable(pair.to_string())
|
||||
)?;
|
||||
let current = breaker
|
||||
.current_price()
|
||||
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
|
||||
|
||||
// Check 1-minute change
|
||||
if let Some(price_1m) = breaker.get_price_at(60) {
|
||||
let change = ((current - price_1m) / price_1m).abs();
|
||||
if change > self.config.max_1m_change {
|
||||
return self.trigger_breaker(pair, TriggerReason::RapidPriceChange {
|
||||
change_percent: change * Decimal::ONE_HUNDRED,
|
||||
window_seconds: 60,
|
||||
});
|
||||
return self.trigger_breaker(
|
||||
pair,
|
||||
TriggerReason::RapidPriceChange {
|
||||
change_percent: change * Decimal::ONE_HUNDRED,
|
||||
window_seconds: 60,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -296,10 +304,13 @@ impl CircuitBreakerManager {
|
|||
if let Some(price_5m) = breaker.get_price_at(300) {
|
||||
let change = ((current - price_5m) / price_5m).abs();
|
||||
if change > self.config.max_5m_change {
|
||||
return self.trigger_breaker(pair, TriggerReason::RapidPriceChange {
|
||||
change_percent: change * Decimal::ONE_HUNDRED,
|
||||
window_seconds: 300,
|
||||
});
|
||||
return self.trigger_breaker(
|
||||
pair,
|
||||
TriggerReason::RapidPriceChange {
|
||||
change_percent: change * Decimal::ONE_HUNDRED,
|
||||
window_seconds: 300,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -307,10 +318,13 @@ impl CircuitBreakerManager {
|
|||
if let Some(price_1h) = breaker.get_price_at(3600) {
|
||||
let change = ((current - price_1h) / price_1h).abs();
|
||||
if change > self.config.max_1h_change {
|
||||
return self.trigger_breaker(pair, TriggerReason::RapidPriceChange {
|
||||
change_percent: change * Decimal::ONE_HUNDRED,
|
||||
window_seconds: 3600,
|
||||
});
|
||||
return self.trigger_breaker(
|
||||
pair,
|
||||
TriggerReason::RapidPriceChange {
|
||||
change_percent: change * Decimal::ONE_HUNDRED,
|
||||
window_seconds: 3600,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -318,20 +332,26 @@ impl CircuitBreakerManager {
|
|||
if let Some(twap) = breaker.twap_24h {
|
||||
let deviation = ((current - twap) / twap).abs();
|
||||
if deviation > self.config.max_twap_deviation {
|
||||
return self.trigger_breaker(pair, TriggerReason::ExcessiveDeviation {
|
||||
deviation_percent: deviation * Decimal::ONE_HUNDRED,
|
||||
reference_price: twap,
|
||||
});
|
||||
return self.trigger_breaker(
|
||||
pair,
|
||||
TriggerReason::ExcessiveDeviation {
|
||||
deviation_percent: deviation * Decimal::ONE_HUNDRED,
|
||||
reference_price: twap,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Check liquidity
|
||||
if let Some(liquidity) = breaker.current_liquidity() {
|
||||
if liquidity < self.config.min_liquidity {
|
||||
return self.trigger_breaker(pair, TriggerReason::LowLiquidity {
|
||||
current: liquidity,
|
||||
threshold: self.config.min_liquidity,
|
||||
});
|
||||
return self.trigger_breaker(
|
||||
pair,
|
||||
TriggerReason::LowLiquidity {
|
||||
current: liquidity,
|
||||
threshold: self.config.min_liquidity,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -340,9 +360,10 @@ impl CircuitBreakerManager {
|
|||
|
||||
/// Trigger the circuit breaker
|
||||
fn trigger_breaker(&mut self, pair: &str, reason: TriggerReason) -> Result<()> {
|
||||
let breaker = self.breakers.get_mut(pair).ok_or_else(||
|
||||
EconomicsError::PriceFeedUnavailable(pair.to_string())
|
||||
)?;
|
||||
let breaker = self
|
||||
.breakers
|
||||
.get_mut(pair)
|
||||
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
|
||||
|
||||
let event = CircuitEvent {
|
||||
pair: pair.to_string(),
|
||||
|
|
@ -361,7 +382,10 @@ impl CircuitBreakerManager {
|
|||
|
||||
// Check cascade triggers
|
||||
if self.config.cascade_enabled {
|
||||
let cascades: Vec<_> = self.config.cascade_pairs.iter()
|
||||
let cascades: Vec<_> = self
|
||||
.config
|
||||
.cascade_pairs
|
||||
.iter()
|
||||
.filter(|(source, _)| source == pair)
|
||||
.map(|(_, target)| target.clone())
|
||||
.collect();
|
||||
|
|
@ -407,10 +431,15 @@ impl CircuitBreakerManager {
|
|||
|
||||
// Get current state first (immutable borrow)
|
||||
let (current_state, triggered_at, trigger_reason) = {
|
||||
let breaker = self.breakers.get(pair).ok_or_else(||
|
||||
EconomicsError::PriceFeedUnavailable(pair.to_string())
|
||||
)?;
|
||||
(breaker.state, breaker.triggered_at, breaker.trigger_reason.clone())
|
||||
let breaker = self
|
||||
.breakers
|
||||
.get(pair)
|
||||
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
|
||||
(
|
||||
breaker.state,
|
||||
breaker.triggered_at,
|
||||
breaker.trigger_reason.clone(),
|
||||
)
|
||||
};
|
||||
|
||||
// Check stability for half-open state (immutable borrow)
|
||||
|
|
@ -421,9 +450,10 @@ impl CircuitBreakerManager {
|
|||
};
|
||||
|
||||
// Now get mutable reference for updates
|
||||
let breaker = self.breakers.get_mut(pair).ok_or_else(||
|
||||
EconomicsError::PriceFeedUnavailable(pair.to_string())
|
||||
)?;
|
||||
let breaker = self
|
||||
.breakers
|
||||
.get_mut(pair)
|
||||
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
|
||||
|
||||
match current_state {
|
||||
CircuitState::Open => {
|
||||
|
|
@ -435,9 +465,9 @@ impl CircuitBreakerManager {
|
|||
pair: pair.to_string(),
|
||||
from_state: CircuitState::Open,
|
||||
to_state: CircuitState::HalfOpen,
|
||||
reason: trigger_reason.clone().unwrap_or(
|
||||
TriggerReason::ManualHalt { reason: "Unknown".into() }
|
||||
),
|
||||
reason: trigger_reason.clone().unwrap_or(TriggerReason::ManualHalt {
|
||||
reason: "Unknown".into(),
|
||||
}),
|
||||
timestamp: Utc::now(),
|
||||
cooldown: None,
|
||||
};
|
||||
|
|
@ -457,9 +487,9 @@ impl CircuitBreakerManager {
|
|||
pair: pair.to_string(),
|
||||
from_state: CircuitState::HalfOpen,
|
||||
to_state: CircuitState::Closed,
|
||||
reason: trigger_reason.unwrap_or(
|
||||
TriggerReason::ManualHalt { reason: "Recovery".into() }
|
||||
),
|
||||
reason: trigger_reason.unwrap_or(TriggerReason::ManualHalt {
|
||||
reason: "Recovery".into(),
|
||||
}),
|
||||
timestamp: Utc::now(),
|
||||
cooldown: None,
|
||||
};
|
||||
|
|
@ -482,9 +512,10 @@ impl CircuitBreakerManager {
|
|||
|
||||
/// Check if market conditions are stable
|
||||
fn is_stable(&self, pair: &str) -> Result<bool> {
|
||||
let breaker = self.breakers.get(pair).ok_or_else(||
|
||||
EconomicsError::PriceFeedUnavailable(pair.to_string())
|
||||
)?;
|
||||
let breaker = self
|
||||
.breakers
|
||||
.get(pair)
|
||||
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
|
||||
|
||||
let current = match breaker.current_price() {
|
||||
Some(p) => p,
|
||||
|
|
@ -511,7 +542,8 @@ impl CircuitBreakerManager {
|
|||
|
||||
/// Get current state for a pair
|
||||
pub fn get_state(&self, pair: &str) -> CircuitState {
|
||||
self.breakers.get(pair)
|
||||
self.breakers
|
||||
.get(pair)
|
||||
.map(|b| b.state)
|
||||
.unwrap_or(CircuitState::Closed)
|
||||
}
|
||||
|
|
@ -523,25 +555,32 @@ impl CircuitBreakerManager {
|
|||
|
||||
/// Manually trigger circuit breaker
|
||||
pub fn manual_halt(&mut self, pair: &str, reason: impl Into<String>) -> Result<()> {
|
||||
self.breakers.entry(pair.to_string())
|
||||
self.breakers
|
||||
.entry(pair.to_string())
|
||||
.or_insert_with(PairCircuitBreaker::new);
|
||||
|
||||
self.trigger_breaker(pair, TriggerReason::ManualHalt {
|
||||
reason: reason.into(),
|
||||
})
|
||||
self.trigger_breaker(
|
||||
pair,
|
||||
TriggerReason::ManualHalt {
|
||||
reason: reason.into(),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Manually reset circuit breaker
|
||||
pub fn manual_reset(&mut self, pair: &str) -> Result<()> {
|
||||
let breaker = self.breakers.get_mut(pair).ok_or_else(||
|
||||
EconomicsError::PriceFeedUnavailable(pair.to_string())
|
||||
)?;
|
||||
let breaker = self
|
||||
.breakers
|
||||
.get_mut(pair)
|
||||
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
|
||||
|
||||
let event = CircuitEvent {
|
||||
pair: pair.to_string(),
|
||||
from_state: breaker.state,
|
||||
to_state: CircuitState::Closed,
|
||||
reason: TriggerReason::ManualHalt { reason: "Manual reset".into() },
|
||||
reason: TriggerReason::ManualHalt {
|
||||
reason: "Manual reset".into(),
|
||||
},
|
||||
timestamp: Utc::now(),
|
||||
cooldown: None,
|
||||
};
|
||||
|
|
@ -558,26 +597,32 @@ impl CircuitBreakerManager {
|
|||
/// Record oracle disagreement
|
||||
pub fn record_oracle_spread(&mut self, pair: &str, spread: Decimal) -> Result<()> {
|
||||
if spread > self.config.max_oracle_spread {
|
||||
self.breakers.entry(pair.to_string())
|
||||
self.breakers
|
||||
.entry(pair.to_string())
|
||||
.or_insert_with(PairCircuitBreaker::new);
|
||||
|
||||
self.trigger_breaker(pair, TriggerReason::OracleDisagreement {
|
||||
spread_percent: spread * Decimal::ONE_HUNDRED,
|
||||
})?;
|
||||
self.trigger_breaker(
|
||||
pair,
|
||||
TriggerReason::OracleDisagreement {
|
||||
spread_percent: spread * Decimal::ONE_HUNDRED,
|
||||
},
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get event history for a pair
|
||||
pub fn get_events(&self, pair: &str) -> Vec<CircuitEvent> {
|
||||
self.breakers.get(pair)
|
||||
self.breakers
|
||||
.get(pair)
|
||||
.map(|b| b.events.clone())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Get all currently halted pairs
|
||||
pub fn get_halted_pairs(&self) -> Vec<(String, CircuitState, Option<TriggerReason>)> {
|
||||
self.breakers.iter()
|
||||
self.breakers
|
||||
.iter()
|
||||
.filter(|(_, b)| b.state != CircuitState::Closed)
|
||||
.map(|(pair, b)| (pair.clone(), b.state, b.trigger_reason.clone()))
|
||||
.collect()
|
||||
|
|
@ -586,8 +631,16 @@ impl CircuitBreakerManager {
|
|||
/// Get summary statistics
|
||||
pub fn get_stats(&self) -> CircuitBreakerStats {
|
||||
let total = self.breakers.len();
|
||||
let open = self.breakers.values().filter(|b| b.state == CircuitState::Open).count();
|
||||
let half_open = self.breakers.values().filter(|b| b.state == CircuitState::HalfOpen).count();
|
||||
let open = self
|
||||
.breakers
|
||||
.values()
|
||||
.filter(|b| b.state == CircuitState::Open)
|
||||
.count();
|
||||
let half_open = self
|
||||
.breakers
|
||||
.values()
|
||||
.filter(|b| b.state == CircuitState::HalfOpen)
|
||||
.count();
|
||||
let total_events: usize = self.breakers.values().map(|b| b.events.len()).sum();
|
||||
|
||||
CircuitBreakerStats {
|
||||
|
|
@ -628,7 +681,9 @@ mod tests {
|
|||
// Normal price movements should not trigger
|
||||
for i in 0..10 {
|
||||
let price = dec!(100) + Decimal::from(i);
|
||||
let state = manager.record_price("SYNOR/USD", price, Some(dec!(100000))).unwrap();
|
||||
let state = manager
|
||||
.record_price("SYNOR/USD", price, Some(dec!(100000)))
|
||||
.unwrap();
|
||||
assert_eq!(state, CircuitState::Closed);
|
||||
}
|
||||
}
|
||||
|
|
@ -641,10 +696,19 @@ mod tests {
|
|||
let now = Utc::now();
|
||||
|
||||
// Record baseline 2 minutes ago
|
||||
manager.record_price_at("SYNOR/USD", dec!(100), Some(dec!(100000)), now - Duration::minutes(2)).unwrap();
|
||||
manager
|
||||
.record_price_at(
|
||||
"SYNOR/USD",
|
||||
dec!(100),
|
||||
Some(dec!(100000)),
|
||||
now - Duration::minutes(2),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Simulate 15% drop (exceeds 10% 1-minute threshold)
|
||||
let state = manager.record_price_at("SYNOR/USD", dec!(85), Some(dec!(100000)), now).unwrap();
|
||||
let state = manager
|
||||
.record_price_at("SYNOR/USD", dec!(85), Some(dec!(100000)), now)
|
||||
.unwrap();
|
||||
assert_eq!(state, CircuitState::Open);
|
||||
}
|
||||
|
||||
|
|
@ -653,7 +717,9 @@ mod tests {
|
|||
let mut manager = CircuitBreakerManager::new();
|
||||
|
||||
// Record with very low liquidity
|
||||
let state = manager.record_price("SYNOR/USD", dec!(100), Some(dec!(100))).unwrap();
|
||||
let state = manager
|
||||
.record_price("SYNOR/USD", dec!(100), Some(dec!(100)))
|
||||
.unwrap();
|
||||
assert_eq!(state, CircuitState::Open);
|
||||
}
|
||||
|
||||
|
|
@ -661,11 +727,15 @@ mod tests {
|
|||
fn test_manual_halt_and_reset() {
|
||||
let mut manager = CircuitBreakerManager::new();
|
||||
|
||||
manager.record_price("SYNOR/USD", dec!(100), Some(dec!(100000))).unwrap();
|
||||
manager
|
||||
.record_price("SYNOR/USD", dec!(100), Some(dec!(100000)))
|
||||
.unwrap();
|
||||
assert!(manager.is_trading_allowed("SYNOR/USD"));
|
||||
|
||||
// Manual halt
|
||||
manager.manual_halt("SYNOR/USD", "Scheduled maintenance").unwrap();
|
||||
manager
|
||||
.manual_halt("SYNOR/USD", "Scheduled maintenance")
|
||||
.unwrap();
|
||||
assert!(!manager.is_trading_allowed("SYNOR/USD"));
|
||||
|
||||
// Manual reset
|
||||
|
|
@ -678,10 +748,14 @@ mod tests {
|
|||
let mut manager = CircuitBreakerManager::new();
|
||||
|
||||
// Initialize
|
||||
manager.record_price("SYNOR/USD", dec!(100), Some(dec!(100000))).unwrap();
|
||||
manager
|
||||
.record_price("SYNOR/USD", dec!(100), Some(dec!(100000)))
|
||||
.unwrap();
|
||||
|
||||
// Record 10% spread (exceeds 5% threshold)
|
||||
manager.record_oracle_spread("SYNOR/USD", dec!(0.10)).unwrap();
|
||||
manager
|
||||
.record_oracle_spread("SYNOR/USD", dec!(0.10))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(manager.get_state("SYNOR/USD"), CircuitState::Open);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -200,7 +200,10 @@ pub struct CrossChainConfig {
|
|||
impl Default for CrossChainConfig {
|
||||
fn default() -> Self {
|
||||
let mut tracked = HashMap::new();
|
||||
tracked.insert(ChainNetwork::Ethereum, vec!["ETH".to_string(), "USDC".to_string(), "USDT".to_string()]);
|
||||
tracked.insert(
|
||||
ChainNetwork::Ethereum,
|
||||
vec!["ETH".to_string(), "USDC".to_string(), "USDT".to_string()],
|
||||
);
|
||||
tracked.insert(ChainNetwork::Bitcoin, vec!["BTC".to_string()]);
|
||||
tracked.insert(ChainNetwork::Cosmos, vec!["ATOM".to_string()]);
|
||||
tracked.insert(ChainNetwork::Osmosis, vec!["OSMO".to_string()]);
|
||||
|
|
@ -305,16 +308,21 @@ impl CrossChainOracle {
|
|||
};
|
||||
|
||||
if !verified {
|
||||
return Err(EconomicsError::InvalidPrice("Packet verification failed".into()));
|
||||
return Err(EconomicsError::InvalidPrice(
|
||||
"Packet verification failed".into(),
|
||||
));
|
||||
}
|
||||
|
||||
// Cache the price
|
||||
let pair_key = format!("{}/{}", packet.token, packet.quote);
|
||||
self.cache.insert(pair_key, CrossChainPrice {
|
||||
packet,
|
||||
received_at: Utc::now(),
|
||||
verified,
|
||||
});
|
||||
self.cache.insert(
|
||||
pair_key,
|
||||
CrossChainPrice {
|
||||
packet,
|
||||
received_at: Utc::now(),
|
||||
verified,
|
||||
},
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -326,7 +334,9 @@ impl CrossChainOracle {
|
|||
token: &str,
|
||||
quote: &str,
|
||||
) -> Result<CrossChainPricePacket> {
|
||||
let fetcher = self.fetchers.get(&chain)
|
||||
let fetcher = self
|
||||
.fetchers
|
||||
.get(&chain)
|
||||
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(format!("{:?}", chain)))?;
|
||||
|
||||
let packet = fetcher.fetch_price(token, quote).await?;
|
||||
|
|
@ -334,11 +344,14 @@ impl CrossChainOracle {
|
|||
// Verify and cache
|
||||
if fetcher.verify_packet(&packet) {
|
||||
let pair_key = format!("{}/{}", token, quote);
|
||||
self.cache.insert(pair_key.clone(), CrossChainPrice {
|
||||
packet: packet.clone(),
|
||||
received_at: Utc::now(),
|
||||
verified: true,
|
||||
});
|
||||
self.cache.insert(
|
||||
pair_key.clone(),
|
||||
CrossChainPrice {
|
||||
packet: packet.clone(),
|
||||
received_at: Utc::now(),
|
||||
verified: true,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
Ok(packet)
|
||||
|
|
@ -347,7 +360,8 @@ impl CrossChainOracle {
|
|||
/// Get cached price for a token pair
|
||||
pub fn get_price(&self, token: &str, quote: &str) -> Option<SynorDecimal> {
|
||||
let pair_key = format!("{}/{}", token, quote);
|
||||
self.cache.get(&pair_key)
|
||||
self.cache
|
||||
.get(&pair_key)
|
||||
.filter(|c| c.verified)
|
||||
.filter(|c| (Utc::now() - c.received_at).num_seconds() < self.config.max_packet_age)
|
||||
.map(|c| c.packet.price)
|
||||
|
|
@ -356,7 +370,8 @@ impl CrossChainOracle {
|
|||
/// Get price with full packet info
|
||||
pub fn get_price_with_info(&self, token: &str, quote: &str) -> Option<&CrossChainPricePacket> {
|
||||
let pair_key = format!("{}/{}", token, quote);
|
||||
self.cache.get(&pair_key)
|
||||
self.cache
|
||||
.get(&pair_key)
|
||||
.filter(|c| c.verified)
|
||||
.map(|c| &c.packet)
|
||||
}
|
||||
|
|
@ -408,7 +423,8 @@ impl CrossChainOracle {
|
|||
|
||||
/// Get all cached prices
|
||||
pub fn get_all_prices(&self) -> Vec<TokenPrice> {
|
||||
self.cache.values()
|
||||
self.cache
|
||||
.values()
|
||||
.filter(|c| c.verified)
|
||||
.map(|c| c.packet.to_token_price())
|
||||
.collect()
|
||||
|
|
@ -417,9 +433,8 @@ impl CrossChainOracle {
|
|||
/// Clear stale cache entries
|
||||
pub fn cleanup_cache(&mut self) {
|
||||
let max_age = self.config.max_packet_age;
|
||||
self.cache.retain(|_, v| {
|
||||
(Utc::now() - v.received_at).num_seconds() < max_age
|
||||
});
|
||||
self.cache
|
||||
.retain(|_, v| (Utc::now() - v.received_at).num_seconds() < max_age);
|
||||
}
|
||||
|
||||
/// Send an IBC price request and track pending packet
|
||||
|
|
@ -450,7 +465,11 @@ impl CrossChainOracle {
|
|||
}
|
||||
|
||||
/// Confirm a pending packet was received
|
||||
pub fn confirm_pending_packet(&mut self, channel: &str, sequence: u64) -> Option<PendingPacket> {
|
||||
pub fn confirm_pending_packet(
|
||||
&mut self,
|
||||
channel: &str,
|
||||
sequence: u64,
|
||||
) -> Option<PendingPacket> {
|
||||
if let Some(idx) = self
|
||||
.pending_packets
|
||||
.iter()
|
||||
|
|
@ -469,9 +488,8 @@ impl CrossChainOracle {
|
|||
|
||||
/// Cleanup timed out pending packets
|
||||
pub fn cleanup_pending(&mut self, timeout_secs: i64) {
|
||||
self.pending_packets.retain(|p| {
|
||||
(Utc::now() - p.sent_at).num_seconds() < timeout_secs
|
||||
});
|
||||
self.pending_packets
|
||||
.retain(|p| (Utc::now() - p.sent_at).num_seconds() < timeout_secs);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -502,9 +520,18 @@ pub struct EthereumPriceFetcher {
|
|||
impl EthereumPriceFetcher {
|
||||
pub fn new(rpc_url: impl Into<String>) -> Self {
|
||||
let mut feeds = HashMap::new();
|
||||
feeds.insert("ETH/USD".to_string(), "0x5f4eC3Df9cbd43714FE2740f5E3616155c5b8419".to_string());
|
||||
feeds.insert("BTC/USD".to_string(), "0xF4030086522a5bEEa4988F8cA5B36dbC97BeE88c".to_string());
|
||||
feeds.insert("USDC/USD".to_string(), "0x8fFfFfd4AfB6115b954Bd326cbe7B4BA576818f6".to_string());
|
||||
feeds.insert(
|
||||
"ETH/USD".to_string(),
|
||||
"0x5f4eC3Df9cbd43714FE2740f5E3616155c5b8419".to_string(),
|
||||
);
|
||||
feeds.insert(
|
||||
"BTC/USD".to_string(),
|
||||
"0xF4030086522a5bEEa4988F8cA5B36dbC97BeE88c".to_string(),
|
||||
);
|
||||
feeds.insert(
|
||||
"USDC/USD".to_string(),
|
||||
"0x8fFfFfd4AfB6115b954Bd326cbe7B4BA576818f6".to_string(),
|
||||
);
|
||||
|
||||
Self {
|
||||
rpc_url: rpc_url.into(),
|
||||
|
|
@ -526,7 +553,9 @@ impl ChainPriceFetcher for EthereumPriceFetcher {
|
|||
|
||||
async fn fetch_price(&self, token: &str, quote: &str) -> Result<CrossChainPricePacket> {
|
||||
let pair = format!("{}/{}", token, quote);
|
||||
let _feed_addr = self.chainlink_feeds.get(&pair)
|
||||
let _feed_addr = self
|
||||
.chainlink_feeds
|
||||
.get(&pair)
|
||||
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.clone()))?;
|
||||
|
||||
// In production: Call Chainlink aggregator via ethers-rs
|
||||
|
|
@ -539,19 +568,16 @@ impl ChainPriceFetcher for EthereumPriceFetcher {
|
|||
source_block: 19000000,
|
||||
source_timestamp: Utc::now(),
|
||||
proof: None,
|
||||
signatures: vec![
|
||||
OracleSignature {
|
||||
signer: "chainlink".to_string(),
|
||||
signature: vec![0; 65],
|
||||
timestamp: Utc::now(),
|
||||
},
|
||||
],
|
||||
signatures: vec![OracleSignature {
|
||||
signer: "chainlink".to_string(),
|
||||
signature: vec![0; 65],
|
||||
timestamp: Utc::now(),
|
||||
}],
|
||||
})
|
||||
}
|
||||
|
||||
fn verify_packet(&self, packet: &CrossChainPricePacket) -> bool {
|
||||
packet.source_chain == ChainNetwork::Ethereum
|
||||
&& !packet.signatures.is_empty()
|
||||
packet.source_chain == ChainNetwork::Ethereum && !packet.signatures.is_empty()
|
||||
}
|
||||
|
||||
fn supported_tokens(&self) -> Vec<String> {
|
||||
|
|
@ -611,8 +637,7 @@ impl ChainPriceFetcher for CosmosPriceFetcher {
|
|||
}
|
||||
|
||||
fn verify_packet(&self, packet: &CrossChainPricePacket) -> bool {
|
||||
packet.source_chain == ChainNetwork::Cosmos
|
||||
&& packet.proof.is_some()
|
||||
packet.source_chain == ChainNetwork::Cosmos && packet.proof.is_some()
|
||||
}
|
||||
|
||||
fn supported_tokens(&self) -> Vec<String> {
|
||||
|
|
@ -650,7 +675,11 @@ impl CrossChainOracleBuilder {
|
|||
}
|
||||
|
||||
/// Add Cosmos/IBC price fetcher
|
||||
pub fn with_cosmos(mut self, light_client_id: impl Into<String>, chain_id: impl Into<String>) -> Self {
|
||||
pub fn with_cosmos(
|
||||
mut self,
|
||||
light_client_id: impl Into<String>,
|
||||
chain_id: impl Into<String>,
|
||||
) -> Self {
|
||||
self.cosmos_light_client = Some((light_client_id.into(), chain_id.into()));
|
||||
self
|
||||
}
|
||||
|
|
@ -693,8 +722,7 @@ impl CrossChainOracleFactory {
|
|||
|
||||
/// Create a production oracle with real endpoints
|
||||
pub fn production(config: CrossChainProductionConfig) -> CrossChainOracle {
|
||||
let mut builder = CrossChainOracleBuilder::new()
|
||||
.with_config(config.cross_chain_config);
|
||||
let mut builder = CrossChainOracleBuilder::new().with_config(config.cross_chain_config);
|
||||
|
||||
if let Some(eth_rpc) = config.ethereum_rpc_url {
|
||||
builder = builder.with_ethereum(eth_rpc);
|
||||
|
|
@ -715,7 +743,10 @@ impl CrossChainOracleFactory {
|
|||
}
|
||||
|
||||
/// Create an oracle with only Cosmos/IBC support
|
||||
pub fn cosmos_only(light_client_id: impl Into<String>, chain_id: impl Into<String>) -> CrossChainOracle {
|
||||
pub fn cosmos_only(
|
||||
light_client_id: impl Into<String>,
|
||||
chain_id: impl Into<String>,
|
||||
) -> CrossChainOracle {
|
||||
CrossChainOracleBuilder::new()
|
||||
.with_cosmos(light_client_id, chain_id)
|
||||
.build()
|
||||
|
|
|
|||
|
|
@ -112,7 +112,9 @@ impl AggregationRound {
|
|||
/// Add a submission to this round
|
||||
pub fn add_submission(&mut self, submission: PriceSubmission) -> Result<()> {
|
||||
if self.finalized {
|
||||
return Err(EconomicsError::InvalidPrice("Round already finalized".into()));
|
||||
return Err(EconomicsError::InvalidPrice(
|
||||
"Round already finalized".into(),
|
||||
));
|
||||
}
|
||||
if Utc::now() >= self.deadline {
|
||||
return Err(EconomicsError::InvalidPrice("Round deadline passed".into()));
|
||||
|
|
@ -122,7 +124,11 @@ impl AggregationRound {
|
|||
}
|
||||
|
||||
// Check for duplicate submission from same node
|
||||
if self.submissions.iter().any(|s| s.node_id == submission.node_id) {
|
||||
if self
|
||||
.submissions
|
||||
.iter()
|
||||
.any(|s| s.node_id == submission.node_id)
|
||||
{
|
||||
return Err(EconomicsError::InvalidPrice("Duplicate submission".into()));
|
||||
}
|
||||
|
||||
|
|
@ -231,7 +237,9 @@ impl DecentralizedOracle {
|
|||
|
||||
/// Update node heartbeat
|
||||
pub fn heartbeat(&mut self, node_id: &str) -> Result<()> {
|
||||
let node = self.nodes.get_mut(node_id)
|
||||
let node = self
|
||||
.nodes
|
||||
.get_mut(node_id)
|
||||
.ok_or_else(|| EconomicsError::InvalidPrice(format!("Unknown node: {}", node_id)))?;
|
||||
node.last_heartbeat = Utc::now();
|
||||
Ok(())
|
||||
|
|
@ -241,11 +249,8 @@ impl DecentralizedOracle {
|
|||
pub fn start_round(&mut self, pair: impl Into<String>) -> u64 {
|
||||
let pair = pair.into();
|
||||
self.round_counter += 1;
|
||||
let round = AggregationRound::new(
|
||||
self.round_counter,
|
||||
pair.clone(),
|
||||
self.config.round_duration
|
||||
);
|
||||
let round =
|
||||
AggregationRound::new(self.round_counter, pair.clone(), self.config.round_duration);
|
||||
self.current_rounds.insert(pair, round);
|
||||
self.round_counter
|
||||
}
|
||||
|
|
@ -253,7 +258,9 @@ impl DecentralizedOracle {
|
|||
/// Submit a price for the current round
|
||||
pub fn submit_price(&mut self, submission: PriceSubmission) -> Result<()> {
|
||||
// Verify node exists and is eligible
|
||||
let node = self.nodes.get(&submission.node_id)
|
||||
let node = self
|
||||
.nodes
|
||||
.get(&submission.node_id)
|
||||
.ok_or_else(|| EconomicsError::InvalidPrice("Unknown node".into()))?;
|
||||
|
||||
if !node.is_eligible(self.config.min_stake, self.config.min_reputation) {
|
||||
|
|
@ -264,7 +271,9 @@ impl DecentralizedOracle {
|
|||
// For now, we trust the submission
|
||||
|
||||
// Add to current round
|
||||
let round = self.current_rounds.get_mut(&submission.pair)
|
||||
let round = self
|
||||
.current_rounds
|
||||
.get_mut(&submission.pair)
|
||||
.ok_or_else(|| EconomicsError::InvalidPrice("No active round for pair".into()))?;
|
||||
|
||||
round.add_submission(submission)
|
||||
|
|
@ -274,13 +283,15 @@ impl DecentralizedOracle {
|
|||
pub fn finalize_round(&mut self, pair: &str) -> Result<SynorDecimal> {
|
||||
// First check state and get submissions (immutable borrow)
|
||||
let (is_finalized, existing_price, submissions) = {
|
||||
let round = self.current_rounds.get(pair)
|
||||
let round = self
|
||||
.current_rounds
|
||||
.get(pair)
|
||||
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
|
||||
|
||||
if round.finalized {
|
||||
return round.final_price.ok_or_else(||
|
||||
EconomicsError::InvalidPrice("Round has no price".into())
|
||||
);
|
||||
return round
|
||||
.final_price
|
||||
.ok_or_else(|| EconomicsError::InvalidPrice("Round has no price".into()));
|
||||
}
|
||||
|
||||
// Check minimum submissions
|
||||
|
|
@ -292,13 +303,16 @@ impl DecentralizedOracle {
|
|||
)));
|
||||
}
|
||||
|
||||
(round.finalized, round.final_price, round.submissions.clone())
|
||||
(
|
||||
round.finalized,
|
||||
round.final_price,
|
||||
round.submissions.clone(),
|
||||
)
|
||||
};
|
||||
|
||||
if is_finalized {
|
||||
return existing_price.ok_or_else(||
|
||||
EconomicsError::InvalidPrice("Round has no price".into())
|
||||
);
|
||||
return existing_price
|
||||
.ok_or_else(|| EconomicsError::InvalidPrice("Round has no price".into()));
|
||||
}
|
||||
|
||||
// Filter outliers and aggregate (using cloned submissions)
|
||||
|
|
@ -326,7 +340,11 @@ impl DecentralizedOracle {
|
|||
}
|
||||
|
||||
/// Aggregate prices from a vector of submissions (owned)
|
||||
fn aggregate_prices_from_vec(&self, pair: &str, submissions: &[PriceSubmission]) -> Result<SynorDecimal> {
|
||||
fn aggregate_prices_from_vec(
|
||||
&self,
|
||||
pair: &str,
|
||||
submissions: &[PriceSubmission],
|
||||
) -> Result<SynorDecimal> {
|
||||
if submissions.is_empty() {
|
||||
return Err(EconomicsError::PriceFeedUnavailable(pair.to_string()));
|
||||
}
|
||||
|
|
@ -334,15 +352,21 @@ impl DecentralizedOracle {
|
|||
// Filter outliers first
|
||||
let filtered = self.filter_outliers_vec(submissions);
|
||||
if filtered.is_empty() {
|
||||
return Err(EconomicsError::InvalidPrice("All submissions were outliers".into()));
|
||||
return Err(EconomicsError::InvalidPrice(
|
||||
"All submissions were outliers".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let filtered_refs: Vec<_> = filtered.iter().collect();
|
||||
match self.strategy {
|
||||
AggregationStrategy::Median => self.calculate_median(&filtered_refs),
|
||||
AggregationStrategy::StakeWeightedMedian => self.calculate_stake_weighted_median(&filtered_refs),
|
||||
AggregationStrategy::StakeWeightedMedian => {
|
||||
self.calculate_stake_weighted_median(&filtered_refs)
|
||||
}
|
||||
AggregationStrategy::TrimmedMean => self.calculate_trimmed_mean(&filtered_refs),
|
||||
AggregationStrategy::ReputationWeighted => self.calculate_reputation_weighted(&filtered_refs),
|
||||
AggregationStrategy::ReputationWeighted => {
|
||||
self.calculate_reputation_weighted(&filtered_refs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -376,13 +400,14 @@ impl DecentralizedOracle {
|
|||
}
|
||||
|
||||
/// Calculate stake-weighted median
|
||||
fn calculate_stake_weighted_median(&self, submissions: &[&PriceSubmission]) -> Result<SynorDecimal> {
|
||||
fn calculate_stake_weighted_median(
|
||||
&self,
|
||||
submissions: &[&PriceSubmission],
|
||||
) -> Result<SynorDecimal> {
|
||||
// Get stake for each submission
|
||||
let mut weighted: Vec<(SynorDecimal, SynorDecimal)> = submissions
|
||||
.iter()
|
||||
.filter_map(|s| {
|
||||
self.nodes.get(&s.node_id).map(|n| (s.price, n.stake))
|
||||
})
|
||||
.filter_map(|s| self.nodes.get(&s.node_id).map(|n| (s.price, n.stake)))
|
||||
.collect();
|
||||
|
||||
if weighted.is_empty() {
|
||||
|
|
@ -424,12 +449,17 @@ impl DecentralizedOracle {
|
|||
}
|
||||
|
||||
/// Calculate reputation-weighted average
|
||||
fn calculate_reputation_weighted(&self, submissions: &[&PriceSubmission]) -> Result<SynorDecimal> {
|
||||
fn calculate_reputation_weighted(
|
||||
&self,
|
||||
submissions: &[&PriceSubmission],
|
||||
) -> Result<SynorDecimal> {
|
||||
let mut weighted_sum = Decimal::ZERO;
|
||||
let mut total_weight = Decimal::ZERO;
|
||||
|
||||
for sub in submissions {
|
||||
let reputation = self.nodes.get(&sub.node_id)
|
||||
let reputation = self
|
||||
.nodes
|
||||
.get(&sub.node_id)
|
||||
.map(|n| n.reputation)
|
||||
.unwrap_or(0.5);
|
||||
|
||||
|
|
@ -448,7 +478,9 @@ impl DecentralizedOracle {
|
|||
/// Update node reputations based on submission accuracy
|
||||
fn update_reputations(&mut self, _pair: &str, final_price: SynorDecimal) {
|
||||
// Get submissions from current round before it was moved
|
||||
let submissions: Vec<_> = self.history.last()
|
||||
let submissions: Vec<_> = self
|
||||
.history
|
||||
.last()
|
||||
.map(|r| r.submissions.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
|
|
@ -457,7 +489,8 @@ impl DecentralizedOracle {
|
|||
let deviation = (sub.price - final_price).abs() / final_price;
|
||||
|
||||
// Increase reputation for accurate submissions, decrease for inaccurate
|
||||
if deviation <= Decimal::new(1, 2) { // Within 1%
|
||||
if deviation <= Decimal::new(1, 2) {
|
||||
// Within 1%
|
||||
node.reputation = (node.reputation + 0.01).min(1.0);
|
||||
} else if deviation > self.config.max_deviation {
|
||||
node.reputation = (node.reputation - 0.05).max(0.0);
|
||||
|
|
@ -473,7 +506,8 @@ impl DecentralizedOracle {
|
|||
|
||||
/// Get number of active nodes
|
||||
pub fn active_node_count(&self) -> usize {
|
||||
self.nodes.values()
|
||||
self.nodes
|
||||
.values()
|
||||
.filter(|n| n.is_eligible(self.config.min_stake, self.config.min_reputation))
|
||||
.count()
|
||||
}
|
||||
|
|
@ -492,20 +526,23 @@ impl DecentralizedOracle {
|
|||
|
||||
/// Convert finalized price to TokenPrice
|
||||
pub fn to_token_price(&self, pair: &str) -> Option<TokenPrice> {
|
||||
self.history.iter()
|
||||
self.history
|
||||
.iter()
|
||||
.rev()
|
||||
.find(|r| r.pair == pair && r.finalized)
|
||||
.and_then(|r| r.final_price.map(|price| {
|
||||
let parts: Vec<_> = pair.split('/').collect();
|
||||
TokenPrice {
|
||||
token: parts.get(0).unwrap_or(&"").to_string(),
|
||||
quote: parts.get(1).unwrap_or(&"").to_string(),
|
||||
price,
|
||||
timestamp: r.deadline,
|
||||
source: PriceSource::Aggregated,
|
||||
confidence: 1.0,
|
||||
}
|
||||
}))
|
||||
.and_then(|r| {
|
||||
r.final_price.map(|price| {
|
||||
let parts: Vec<_> = pair.split('/').collect();
|
||||
TokenPrice {
|
||||
token: parts.get(0).unwrap_or(&"").to_string(),
|
||||
quote: parts.get(1).unwrap_or(&"").to_string(),
|
||||
price,
|
||||
timestamp: r.deadline,
|
||||
source: PriceSource::Aggregated,
|
||||
confidence: 1.0,
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -521,13 +558,15 @@ mod tests {
|
|||
use rust_decimal_macros::dec;
|
||||
|
||||
fn create_test_nodes() -> Vec<OracleNode> {
|
||||
(0..5).map(|i| {
|
||||
OracleNode::new(
|
||||
format!("node_{}", i),
|
||||
vec![i as u8; 32],
|
||||
dec!(10000), // 10k stake
|
||||
)
|
||||
}).collect()
|
||||
(0..5)
|
||||
.map(|i| {
|
||||
OracleNode::new(
|
||||
format!("node_{}", i),
|
||||
vec![i as u8; 32],
|
||||
dec!(10000), // 10k stake
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@
|
|||
use crate::error::{EconomicsError, Result};
|
||||
use crate::SynorDecimal;
|
||||
use chrono::{DateTime, Duration, Timelike, Utc};
|
||||
use rust_decimal::Decimal;
|
||||
use rust_decimal::prelude::ToPrimitive;
|
||||
use rust_decimal::Decimal;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::f64::consts::PI;
|
||||
|
|
@ -178,7 +178,12 @@ impl BlackScholes {
|
|||
}
|
||||
|
||||
/// Price a European option
|
||||
pub fn price(&self, contract: &OptionContract, spot: SynorDecimal, vol: f64) -> Result<OptionPricing> {
|
||||
pub fn price(
|
||||
&self,
|
||||
contract: &OptionContract,
|
||||
spot: SynorDecimal,
|
||||
vol: f64,
|
||||
) -> Result<OptionPricing> {
|
||||
if contract.is_expired() {
|
||||
// At expiration, option is worth intrinsic value
|
||||
let intrinsic = contract.intrinsic_value(spot);
|
||||
|
|
@ -208,12 +213,13 @@ impl BlackScholes {
|
|||
});
|
||||
}
|
||||
|
||||
let s = spot.to_f64().ok_or_else(||
|
||||
EconomicsError::InvalidPrice("Invalid spot price".into())
|
||||
)?;
|
||||
let k = contract.strike.to_f64().ok_or_else(||
|
||||
EconomicsError::InvalidPrice("Invalid strike price".into())
|
||||
)?;
|
||||
let s = spot
|
||||
.to_f64()
|
||||
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid spot price".into()))?;
|
||||
let k = contract
|
||||
.strike
|
||||
.to_f64()
|
||||
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid strike price".into()))?;
|
||||
let t = contract.time_to_expiry();
|
||||
|
||||
if t <= 0.0 || vol <= 0.0 {
|
||||
|
|
@ -286,13 +292,10 @@ impl BlackScholes {
|
|||
let theta_common = -(s * vol * (-q * t).exp() * n_prime_d1) / (2.0 * sqrt_t);
|
||||
let theta = match contract.option_type {
|
||||
OptionType::Call => {
|
||||
theta_common
|
||||
+ q * s * (-q * t).exp() * n_d1
|
||||
- r * k * (-r * t).exp() * n_d2
|
||||
theta_common + q * s * (-q * t).exp() * n_d1 - r * k * (-r * t).exp() * n_d2
|
||||
}
|
||||
OptionType::Put => {
|
||||
theta_common
|
||||
- q * s * (-q * t).exp() * (1.0 - n_d1)
|
||||
theta_common - q * s * (-q * t).exp() * (1.0 - n_d1)
|
||||
+ r * k * (-r * t).exp() * (1.0 - n_d2)
|
||||
}
|
||||
} / 365.0; // Per day
|
||||
|
|
@ -327,16 +330,16 @@ impl BlackScholes {
|
|||
spot: SynorDecimal,
|
||||
market_price: SynorDecimal,
|
||||
) -> Result<f64> {
|
||||
let target = market_price.to_f64().ok_or_else(||
|
||||
EconomicsError::InvalidPrice("Invalid market price".into())
|
||||
)?;
|
||||
let target = market_price
|
||||
.to_f64()
|
||||
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid market price".into()))?;
|
||||
|
||||
// Initial guess based on time value
|
||||
let intrinsic = contract.intrinsic_value(spot).to_f64().unwrap_or(0.0);
|
||||
let time_value = (target - intrinsic).max(0.0);
|
||||
let s = spot.to_f64().ok_or_else(||
|
||||
EconomicsError::InvalidPrice("Invalid spot".into())
|
||||
)?;
|
||||
let s = spot
|
||||
.to_f64()
|
||||
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid spot".into()))?;
|
||||
let t = contract.time_to_expiry();
|
||||
|
||||
// Brenner-Subrahmanyam approximation for initial guess
|
||||
|
|
@ -381,7 +384,11 @@ impl BlackScholes {
|
|||
|
||||
for _ in 0..100 {
|
||||
let mid = (low + high) / 2.0;
|
||||
let price = self.price(contract, spot, mid)?.price.to_f64().unwrap_or(0.0);
|
||||
let price = self
|
||||
.price(contract, spot, mid)?
|
||||
.price
|
||||
.to_f64()
|
||||
.unwrap_or(0.0);
|
||||
|
||||
if (price - target).abs() < 0.0001 {
|
||||
return Ok(mid);
|
||||
|
|
@ -493,9 +500,9 @@ impl FuturesModel {
|
|||
/// F = S * e^((r + u - y) * T)
|
||||
/// where r = risk-free rate, u = storage cost, y = convenience yield
|
||||
pub fn price(&self, contract: &FuturesContract, spot: SynorDecimal) -> Result<FuturesPricing> {
|
||||
let s = spot.to_f64().ok_or_else(||
|
||||
EconomicsError::InvalidPrice("Invalid spot price".into())
|
||||
)?;
|
||||
let s = spot
|
||||
.to_f64()
|
||||
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid spot price".into()))?;
|
||||
|
||||
let t = contract.time_to_expiry();
|
||||
if t < 0.0 {
|
||||
|
|
@ -515,8 +522,7 @@ impl FuturesModel {
|
|||
0.0
|
||||
};
|
||||
|
||||
let coc = Decimal::from_f64_retain(cost_of_carry * t * s)
|
||||
.unwrap_or(Decimal::ZERO);
|
||||
let coc = Decimal::from_f64_retain(cost_of_carry * t * s).unwrap_or(Decimal::ZERO);
|
||||
|
||||
Ok(FuturesPricing {
|
||||
fair_value,
|
||||
|
|
@ -529,13 +535,18 @@ impl FuturesModel {
|
|||
|
||||
/// Calculate implied repo rate from futures price
|
||||
/// R = (F/S - 1) / T
|
||||
pub fn implied_repo_rate(&self, contract: &FuturesContract, spot: SynorDecimal, futures_price: SynorDecimal) -> Result<f64> {
|
||||
let s = spot.to_f64().ok_or_else(||
|
||||
EconomicsError::InvalidPrice("Invalid spot".into())
|
||||
)?;
|
||||
let f = futures_price.to_f64().ok_or_else(||
|
||||
EconomicsError::InvalidPrice("Invalid futures price".into())
|
||||
)?;
|
||||
pub fn implied_repo_rate(
|
||||
&self,
|
||||
contract: &FuturesContract,
|
||||
spot: SynorDecimal,
|
||||
futures_price: SynorDecimal,
|
||||
) -> Result<f64> {
|
||||
let s = spot
|
||||
.to_f64()
|
||||
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid spot".into()))?;
|
||||
let f = futures_price
|
||||
.to_f64()
|
||||
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid futures price".into()))?;
|
||||
let t = contract.time_to_expiry();
|
||||
|
||||
if t <= 0.0 || s <= 0.0 {
|
||||
|
|
@ -595,12 +606,12 @@ impl PerpetualModel {
|
|||
mark_price: SynorDecimal,
|
||||
index_price: SynorDecimal,
|
||||
) -> Result<f64> {
|
||||
let mark = mark_price.to_f64().ok_or_else(||
|
||||
EconomicsError::InvalidPrice("Invalid mark price".into())
|
||||
)?;
|
||||
let index = index_price.to_f64().ok_or_else(||
|
||||
EconomicsError::InvalidPrice("Invalid index price".into())
|
||||
)?;
|
||||
let mark = mark_price
|
||||
.to_f64()
|
||||
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid mark price".into()))?;
|
||||
let index = index_price
|
||||
.to_f64()
|
||||
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid index price".into()))?;
|
||||
|
||||
if index <= 0.0 {
|
||||
return Err(EconomicsError::InvalidPrice("Invalid index".into()));
|
||||
|
|
@ -640,7 +651,8 @@ impl PerpetualModel {
|
|||
let hours_since_midnight = now.time().hour();
|
||||
let next_funding_hour = ((hours_since_midnight / self.funding_interval_hours) + 1)
|
||||
* self.funding_interval_hours;
|
||||
let next_funding = now.date_naive()
|
||||
let next_funding = now
|
||||
.date_naive()
|
||||
.and_hms_opt(next_funding_hour % 24, 0, 0)
|
||||
.map(|dt| DateTime::from_naive_utc_and_offset(dt, Utc))
|
||||
.unwrap_or(now + Duration::hours(self.funding_interval_hours as i64));
|
||||
|
|
@ -721,7 +733,8 @@ impl DerivativesOracle {
|
|||
|
||||
/// Set volatility surface for an underlying
|
||||
pub fn set_vol_surface(&mut self, surface: VolatilitySurface) {
|
||||
self.vol_surfaces.insert(surface.underlying.clone(), surface);
|
||||
self.vol_surfaces
|
||||
.insert(surface.underlying.clone(), surface);
|
||||
}
|
||||
|
||||
/// Price an option using the volatility surface
|
||||
|
|
@ -730,12 +743,13 @@ impl DerivativesOracle {
|
|||
contract: &OptionContract,
|
||||
spot: SynorDecimal,
|
||||
) -> Result<OptionPricing> {
|
||||
let s = spot.to_f64().ok_or_else(||
|
||||
EconomicsError::InvalidPrice("Invalid spot".into())
|
||||
)?;
|
||||
let k = contract.strike.to_f64().ok_or_else(||
|
||||
EconomicsError::InvalidPrice("Invalid strike".into())
|
||||
)?;
|
||||
let s = spot
|
||||
.to_f64()
|
||||
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid spot".into()))?;
|
||||
let k = contract
|
||||
.strike
|
||||
.to_f64()
|
||||
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid strike".into()))?;
|
||||
|
||||
// Get volatility from surface or use default
|
||||
let vol = if let Some(surface) = self.vol_surfaces.get(&contract.underlying) {
|
||||
|
|
@ -765,7 +779,8 @@ impl DerivativesOracle {
|
|||
mark_price: SynorDecimal,
|
||||
open_interest: SynorDecimal,
|
||||
) -> Result<PerpetualPricing> {
|
||||
self.perpetual_model.price(index_price, mark_price, open_interest)
|
||||
self.perpetual_model
|
||||
.price(index_price, mark_price, open_interest)
|
||||
}
|
||||
|
||||
/// Calculate implied vol from market price
|
||||
|
|
@ -775,7 +790,8 @@ impl DerivativesOracle {
|
|||
spot: SynorDecimal,
|
||||
market_price: SynorDecimal,
|
||||
) -> Result<f64> {
|
||||
self.options_model.implied_volatility(contract, spot, market_price)
|
||||
self.options_model
|
||||
.implied_volatility(contract, spot, market_price)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -792,8 +808,18 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_option_intrinsic_value() {
|
||||
let call = OptionContract::new("ETH", dec!(2000), Utc::now() + Duration::days(30), OptionType::Call);
|
||||
let put = OptionContract::new("ETH", dec!(2000), Utc::now() + Duration::days(30), OptionType::Put);
|
||||
let call = OptionContract::new(
|
||||
"ETH",
|
||||
dec!(2000),
|
||||
Utc::now() + Duration::days(30),
|
||||
OptionType::Call,
|
||||
);
|
||||
let put = OptionContract::new(
|
||||
"ETH",
|
||||
dec!(2000),
|
||||
Utc::now() + Duration::days(30),
|
||||
OptionType::Put,
|
||||
);
|
||||
|
||||
// ITM call
|
||||
assert_eq!(call.intrinsic_value(dec!(2100)), dec!(100));
|
||||
|
|
@ -880,7 +906,9 @@ mod tests {
|
|||
let pricing = model.price(&contract, dec!(2000), vol).unwrap();
|
||||
|
||||
// Calculate IV from that price
|
||||
let iv = model.implied_volatility(&contract, dec!(2000), pricing.price).unwrap();
|
||||
let iv = model
|
||||
.implied_volatility(&contract, dec!(2000), pricing.price)
|
||||
.unwrap();
|
||||
|
||||
// Should match original vol
|
||||
assert!((iv - vol).abs() < 0.01);
|
||||
|
|
|
|||
|
|
@ -40,12 +40,12 @@ impl CollateralAsset {
|
|||
pub fn standard(symbol: impl Into<String>) -> Self {
|
||||
Self {
|
||||
symbol: symbol.into(),
|
||||
collateral_factor: Decimal::new(75, 2), // 75%
|
||||
liquidation_threshold: Decimal::new(80, 2), // 80%
|
||||
liquidation_bonus: Decimal::new(5, 2), // 5%
|
||||
collateral_factor: Decimal::new(75, 2), // 75%
|
||||
liquidation_threshold: Decimal::new(80, 2), // 80%
|
||||
liquidation_bonus: Decimal::new(5, 2), // 5%
|
||||
supply_cap: None,
|
||||
borrow_enabled: true,
|
||||
reserve_factor: Decimal::new(10, 2), // 10%
|
||||
reserve_factor: Decimal::new(10, 2), // 10%
|
||||
volatility_multiplier: Decimal::ONE,
|
||||
}
|
||||
}
|
||||
|
|
@ -54,12 +54,12 @@ impl CollateralAsset {
|
|||
pub fn stablecoin(symbol: impl Into<String>) -> Self {
|
||||
Self {
|
||||
symbol: symbol.into(),
|
||||
collateral_factor: Decimal::new(90, 2), // 90%
|
||||
liquidation_threshold: Decimal::new(95, 2), // 95%
|
||||
liquidation_bonus: Decimal::new(2, 2), // 2%
|
||||
collateral_factor: Decimal::new(90, 2), // 90%
|
||||
liquidation_threshold: Decimal::new(95, 2), // 95%
|
||||
liquidation_bonus: Decimal::new(2, 2), // 2%
|
||||
supply_cap: None,
|
||||
borrow_enabled: true,
|
||||
reserve_factor: Decimal::new(5, 2), // 5%
|
||||
reserve_factor: Decimal::new(5, 2), // 5%
|
||||
volatility_multiplier: Decimal::ONE,
|
||||
}
|
||||
}
|
||||
|
|
@ -68,13 +68,13 @@ impl CollateralAsset {
|
|||
pub fn volatile(symbol: impl Into<String>) -> Self {
|
||||
Self {
|
||||
symbol: symbol.into(),
|
||||
collateral_factor: Decimal::new(50, 2), // 50%
|
||||
liquidation_threshold: Decimal::new(65, 2), // 65%
|
||||
liquidation_bonus: Decimal::new(10, 2), // 10%
|
||||
collateral_factor: Decimal::new(50, 2), // 50%
|
||||
liquidation_threshold: Decimal::new(65, 2), // 65%
|
||||
liquidation_bonus: Decimal::new(10, 2), // 10%
|
||||
supply_cap: None,
|
||||
borrow_enabled: true,
|
||||
reserve_factor: Decimal::new(20, 2), // 20%
|
||||
volatility_multiplier: Decimal::new(12, 1), // 1.2x
|
||||
reserve_factor: Decimal::new(20, 2), // 20%
|
||||
volatility_multiplier: Decimal::new(12, 1), // 1.2x
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -131,7 +131,11 @@ impl LendingPosition {
|
|||
/// Withdraw collateral
|
||||
pub fn withdraw(&mut self, asset: impl Into<String>, amount: SynorDecimal) -> Result<()> {
|
||||
let asset = asset.into();
|
||||
let current = self.collateral.get(&asset).copied().unwrap_or(Decimal::ZERO);
|
||||
let current = self
|
||||
.collateral
|
||||
.get(&asset)
|
||||
.copied()
|
||||
.unwrap_or(Decimal::ZERO);
|
||||
if amount > current {
|
||||
return Err(EconomicsError::InsufficientFunds {
|
||||
required: amount,
|
||||
|
|
@ -233,10 +237,10 @@ pub struct LiquidationOracleConfig {
|
|||
impl Default for LiquidationOracleConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_price_age: 60, // 1 minute (stricter than general oracle)
|
||||
max_price_age: 60, // 1 minute (stricter than general oracle)
|
||||
min_confidence: 0.9,
|
||||
min_sources: 2,
|
||||
liquidation_grace_period: 300, // 5 minutes
|
||||
liquidation_grace_period: 300, // 5 minutes
|
||||
min_liquidation_amount: Decimal::new(10, 0), // $10
|
||||
max_liquidation_pct: Decimal::new(50, 2), // 50% at a time
|
||||
partial_liquidation: true,
|
||||
|
|
@ -289,7 +293,8 @@ impl LiquidationOracle {
|
|||
/// Create a new position
|
||||
pub fn create_position(&mut self, account_id: impl Into<String>) -> &mut LendingPosition {
|
||||
let account_id = account_id.into();
|
||||
self.positions.entry(account_id.clone())
|
||||
self.positions
|
||||
.entry(account_id.clone())
|
||||
.or_insert_with(|| LendingPosition::new(account_id))
|
||||
}
|
||||
|
||||
|
|
@ -320,7 +325,8 @@ impl LiquidationOracle {
|
|||
freshness_remaining: self.config.max_price_age - age,
|
||||
};
|
||||
|
||||
self.price_cache.insert(asset.to_string(), liq_price.clone());
|
||||
self.price_cache
|
||||
.insert(asset.to_string(), liq_price.clone());
|
||||
Ok(liq_price)
|
||||
}
|
||||
|
||||
|
|
@ -328,7 +334,9 @@ impl LiquidationOracle {
|
|||
pub fn calculate_health(&mut self, account_id: &str) -> Result<HealthStatus> {
|
||||
// Clone position data to avoid borrow conflicts with get_liquidation_price
|
||||
let (collateral, borrows, interest_owed) = {
|
||||
let position = self.positions.get(account_id)
|
||||
let position = self
|
||||
.positions
|
||||
.get(account_id)
|
||||
.ok_or_else(|| EconomicsError::AccountNotFound(account_id.to_string()))?;
|
||||
(
|
||||
position.collateral.clone(),
|
||||
|
|
@ -352,7 +360,9 @@ impl LiquidationOracle {
|
|||
});
|
||||
}
|
||||
|
||||
let asset_config = self.assets.get(asset)
|
||||
let asset_config = self
|
||||
.assets
|
||||
.get(asset)
|
||||
.ok_or_else(|| EconomicsError::InvalidPrice(format!("Unknown asset: {}", asset)))?;
|
||||
|
||||
let value = *amount * price.price;
|
||||
|
|
@ -426,24 +436,37 @@ impl LiquidationOracle {
|
|||
return Err(EconomicsError::InvalidPrice("Position is healthy".into()));
|
||||
}
|
||||
|
||||
let position = self.positions.get(account_id)
|
||||
let position = self
|
||||
.positions
|
||||
.get(account_id)
|
||||
.ok_or_else(|| EconomicsError::AccountNotFound(account_id.to_string()))?;
|
||||
|
||||
let debt_amount = position.borrows.get(debt_asset).copied().unwrap_or(Decimal::ZERO);
|
||||
let collateral_amount = position.collateral.get(collateral_asset).copied().unwrap_or(Decimal::ZERO);
|
||||
let debt_amount = position
|
||||
.borrows
|
||||
.get(debt_asset)
|
||||
.copied()
|
||||
.unwrap_or(Decimal::ZERO);
|
||||
let collateral_amount = position
|
||||
.collateral
|
||||
.get(collateral_asset)
|
||||
.copied()
|
||||
.unwrap_or(Decimal::ZERO);
|
||||
|
||||
if debt_amount == Decimal::ZERO {
|
||||
return Err(EconomicsError::InvalidPrice("No debt to repay".into()));
|
||||
}
|
||||
if collateral_amount == Decimal::ZERO {
|
||||
return Err(EconomicsError::InvalidPrice("No collateral to seize".into()));
|
||||
return Err(EconomicsError::InvalidPrice(
|
||||
"No collateral to seize".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let debt_price = self.get_liquidation_price(debt_asset)?;
|
||||
let collateral_price = self.get_liquidation_price(collateral_asset)?;
|
||||
|
||||
let collateral_config = self.assets.get(collateral_asset)
|
||||
.ok_or_else(|| EconomicsError::InvalidPrice(format!("Unknown asset: {}", collateral_asset)))?;
|
||||
let collateral_config = self.assets.get(collateral_asset).ok_or_else(|| {
|
||||
EconomicsError::InvalidPrice(format!("Unknown asset: {}", collateral_asset))
|
||||
})?;
|
||||
|
||||
// Max debt repayable = close_factor * total_debt
|
||||
let max_debt_repay = debt_amount * self.config.close_factor;
|
||||
|
|
@ -458,12 +481,14 @@ impl LiquidationOracle {
|
|||
let actual_collateral_seized = collateral_to_seize.min(collateral_amount);
|
||||
let actual_debt_repaid = if actual_collateral_seized < collateral_to_seize {
|
||||
// Partial liquidation
|
||||
(actual_collateral_seized * collateral_price.price) / (bonus_multiplier * debt_price.price)
|
||||
(actual_collateral_seized * collateral_price.price)
|
||||
/ (bonus_multiplier * debt_price.price)
|
||||
} else {
|
||||
max_debt_repay
|
||||
};
|
||||
|
||||
let bonus_amount = actual_collateral_seized * collateral_config.liquidation_bonus / bonus_multiplier;
|
||||
let bonus_amount =
|
||||
actual_collateral_seized * collateral_config.liquidation_bonus / bonus_multiplier;
|
||||
|
||||
Ok(LiquidationCalculation {
|
||||
account_id: account_id.to_string(),
|
||||
|
|
@ -489,7 +514,9 @@ impl LiquidationOracle {
|
|||
let calc = self.calculate_liquidation(account_id, debt_asset, collateral_asset)?;
|
||||
|
||||
// Update position
|
||||
let position = self.positions.get_mut(account_id)
|
||||
let position = self
|
||||
.positions
|
||||
.get_mut(account_id)
|
||||
.ok_or_else(|| EconomicsError::AccountNotFound(account_id.to_string()))?;
|
||||
|
||||
// Reduce debt
|
||||
|
|
@ -549,7 +576,9 @@ impl LiquidationOracle {
|
|||
// Protocol gets a portion of the liquidation bonus
|
||||
if let Some(asset_config) = self.assets.get(&event.collateral_asset) {
|
||||
let protocol_share = event.bonus_amount * asset_config.reserve_factor;
|
||||
*reserves.entry(event.collateral_asset.clone()).or_insert(Decimal::ZERO) += protocol_share;
|
||||
*reserves
|
||||
.entry(event.collateral_asset.clone())
|
||||
.or_insert(Decimal::ZERO) += protocol_share;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -561,15 +590,18 @@ impl LiquidationOracle {
|
|||
let total_positions = self.positions.len();
|
||||
let total_liquidations = self.liquidation_history.len();
|
||||
|
||||
let total_debt_liquidated: SynorDecimal = self.liquidation_history.iter()
|
||||
.map(|e| e.debt_amount)
|
||||
.sum();
|
||||
let total_debt_liquidated: SynorDecimal =
|
||||
self.liquidation_history.iter().map(|e| e.debt_amount).sum();
|
||||
|
||||
let total_collateral_seized: SynorDecimal = self.liquidation_history.iter()
|
||||
let total_collateral_seized: SynorDecimal = self
|
||||
.liquidation_history
|
||||
.iter()
|
||||
.map(|e| e.collateral_amount)
|
||||
.sum();
|
||||
|
||||
let unique_liquidated: std::collections::HashSet<_> = self.liquidation_history.iter()
|
||||
let unique_liquidated: std::collections::HashSet<_> = self
|
||||
.liquidation_history
|
||||
.iter()
|
||||
.map(|e| &e.account_id)
|
||||
.collect();
|
||||
|
||||
|
|
@ -617,18 +649,60 @@ mod tests {
|
|||
let mut price_oracle = PriceOracle::with_config(OracleConfig::default());
|
||||
|
||||
// Add prices from multiple sources for test validity
|
||||
price_oracle.update_price(TokenPrice::new("ETH", "USD", dec!(2000), PriceSource::Internal)).unwrap();
|
||||
price_oracle.update_price(TokenPrice::new("ETH", "USD", dec!(2000), PriceSource::Aggregated)).unwrap();
|
||||
price_oracle.update_price(TokenPrice::new("SYNOR", "USD", dec!(1), PriceSource::Internal)).unwrap();
|
||||
price_oracle.update_price(TokenPrice::new("SYNOR", "USD", dec!(1), PriceSource::Aggregated)).unwrap();
|
||||
price_oracle.update_price(TokenPrice::new("USDC", "USD", dec!(1), PriceSource::Internal)).unwrap();
|
||||
price_oracle.update_price(TokenPrice::new("USDC", "USD", dec!(1), PriceSource::Aggregated)).unwrap();
|
||||
price_oracle
|
||||
.update_price(TokenPrice::new(
|
||||
"ETH",
|
||||
"USD",
|
||||
dec!(2000),
|
||||
PriceSource::Internal,
|
||||
))
|
||||
.unwrap();
|
||||
price_oracle
|
||||
.update_price(TokenPrice::new(
|
||||
"ETH",
|
||||
"USD",
|
||||
dec!(2000),
|
||||
PriceSource::Aggregated,
|
||||
))
|
||||
.unwrap();
|
||||
price_oracle
|
||||
.update_price(TokenPrice::new(
|
||||
"SYNOR",
|
||||
"USD",
|
||||
dec!(1),
|
||||
PriceSource::Internal,
|
||||
))
|
||||
.unwrap();
|
||||
price_oracle
|
||||
.update_price(TokenPrice::new(
|
||||
"SYNOR",
|
||||
"USD",
|
||||
dec!(1),
|
||||
PriceSource::Aggregated,
|
||||
))
|
||||
.unwrap();
|
||||
price_oracle
|
||||
.update_price(TokenPrice::new(
|
||||
"USDC",
|
||||
"USD",
|
||||
dec!(1),
|
||||
PriceSource::Internal,
|
||||
))
|
||||
.unwrap();
|
||||
price_oracle
|
||||
.update_price(TokenPrice::new(
|
||||
"USDC",
|
||||
"USD",
|
||||
dec!(1),
|
||||
PriceSource::Aggregated,
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
// Use test config with relaxed freshness requirements
|
||||
let test_config = LiquidationOracleConfig {
|
||||
max_price_age: 3600, // 1 hour for tests
|
||||
min_confidence: 0.5, // Lower confidence threshold
|
||||
min_sources: 1, // Single source OK for tests
|
||||
max_price_age: 3600, // 1 hour for tests
|
||||
min_confidence: 0.5, // Lower confidence threshold
|
||||
min_sources: 1, // Single source OK for tests
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
|
|
@ -660,8 +734,8 @@ mod tests {
|
|||
|
||||
// Create position with good health
|
||||
let pos = oracle.create_position("user1");
|
||||
pos.deposit("ETH", dec!(1)); // $2000 worth
|
||||
pos.borrow("USDC", dec!(500)); // Borrow $500
|
||||
pos.deposit("ETH", dec!(1)); // $2000 worth
|
||||
pos.borrow("USDC", dec!(500)); // Borrow $500
|
||||
|
||||
let health = oracle.calculate_health("user1").unwrap();
|
||||
|
||||
|
|
@ -678,8 +752,8 @@ mod tests {
|
|||
|
||||
// Create position close to liquidation
|
||||
let pos = oracle.create_position("user2");
|
||||
pos.deposit("ETH", dec!(1)); // $2000 worth
|
||||
pos.borrow("USDC", dec!(1500)); // Borrow $1500
|
||||
pos.deposit("ETH", dec!(1)); // $2000 worth
|
||||
pos.borrow("USDC", dec!(1500)); // Borrow $1500
|
||||
|
||||
let health = oracle.calculate_health("user2").unwrap();
|
||||
|
||||
|
|
@ -699,7 +773,9 @@ mod tests {
|
|||
pos.deposit("ETH", dec!(1));
|
||||
pos.borrow("USDC", dec!(1500));
|
||||
|
||||
let calc = oracle.calculate_liquidation("user3", "USDC", "ETH").unwrap();
|
||||
let calc = oracle
|
||||
.calculate_liquidation("user3", "USDC", "ETH")
|
||||
.unwrap();
|
||||
|
||||
// Should be able to liquidate
|
||||
assert!(calc.debt_to_repay > Decimal::ZERO);
|
||||
|
|
|
|||
|
|
@ -49,8 +49,7 @@ pub use price_feed::{
|
|||
PriceSource,
|
||||
};
|
||||
pub use twap::{
|
||||
OnChainTwap, OnChainTwapFactory, TwapCalculator, TwapConfig, TwapObservation,
|
||||
TwapOracleBuilder,
|
||||
OnChainTwap, OnChainTwapFactory, TwapCalculator, TwapConfig, TwapObservation, TwapOracleBuilder,
|
||||
};
|
||||
|
||||
use crate::error::{EconomicsError, Result};
|
||||
|
|
@ -241,7 +240,10 @@ impl PriceOracle {
|
|||
/// Get price history for a pair
|
||||
pub fn get_price_history(&self, token: &str, quote: &str) -> Vec<TokenPrice> {
|
||||
let pair_key = format!("{}/{}", token, quote);
|
||||
self.price_history.get(&pair_key).cloned().unwrap_or_default()
|
||||
self.price_history
|
||||
.get(&pair_key)
|
||||
.cloned()
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Fetch prices from all configured feeds
|
||||
|
|
@ -403,7 +405,9 @@ impl PriceOracle {
|
|||
}
|
||||
|
||||
let healthy = !pairs_status.is_empty()
|
||||
&& pairs_status.values().all(|s| !s.is_stale && s.price_count > 0);
|
||||
&& pairs_status
|
||||
.values()
|
||||
.all(|s| !s.is_stale && s.price_count > 0);
|
||||
|
||||
OracleHealthStatus {
|
||||
healthy,
|
||||
|
|
@ -443,7 +447,10 @@ impl PriceOracleBuilder {
|
|||
/// Add a mock price feed (for testing)
|
||||
pub fn with_mock_feed(mut self, base_price: SynorDecimal) -> Self {
|
||||
use price_feed::MockPriceFeed;
|
||||
self.feeds.push(Box::new(MockPriceFeed::new(PriceSource::Internal, base_price)));
|
||||
self.feeds.push(Box::new(MockPriceFeed::new(
|
||||
PriceSource::Internal,
|
||||
base_price,
|
||||
)));
|
||||
self
|
||||
}
|
||||
|
||||
|
|
@ -455,9 +462,14 @@ impl PriceOracleBuilder {
|
|||
}
|
||||
|
||||
/// Add Chainlink oracle feed
|
||||
pub fn with_chainlink(mut self, contract_address: impl Into<String>, rpc_url: impl Into<String>) -> Self {
|
||||
pub fn with_chainlink(
|
||||
mut self,
|
||||
contract_address: impl Into<String>,
|
||||
rpc_url: impl Into<String>,
|
||||
) -> Self {
|
||||
use price_feed::ChainlinkFeed;
|
||||
self.feeds.push(Box::new(ChainlinkFeed::new(contract_address, rpc_url)));
|
||||
self.feeds
|
||||
.push(Box::new(ChainlinkFeed::new(contract_address, rpc_url)));
|
||||
self
|
||||
}
|
||||
|
||||
|
|
@ -508,8 +520,14 @@ impl OracleFactory {
|
|||
let mut oracle = PriceOracle::new();
|
||||
|
||||
// Add multiple mock feeds with slight variations for testing
|
||||
oracle.add_feed(Box::new(MockPriceFeed::new(PriceSource::Internal, base_price)));
|
||||
oracle.add_feed(Box::new(MockPriceFeed::new(PriceSource::SynorDex, base_price)));
|
||||
oracle.add_feed(Box::new(MockPriceFeed::new(
|
||||
PriceSource::Internal,
|
||||
base_price,
|
||||
)));
|
||||
oracle.add_feed(Box::new(MockPriceFeed::new(
|
||||
PriceSource::SynorDex,
|
||||
base_price,
|
||||
)));
|
||||
|
||||
oracle
|
||||
}
|
||||
|
|
|
|||
|
|
@ -110,7 +110,8 @@ impl PriceFeed for MockPriceFeed {
|
|||
async fn fetch_price(&self, token: &str, quote: &str) -> Result<TokenPrice> {
|
||||
// Add small random variance
|
||||
let variance = (rand_simple() * 2.0 - 1.0) * self.volatility;
|
||||
let price = self.base_price * (Decimal::ONE + Decimal::from_f64_retain(variance).unwrap_or_default());
|
||||
let price = self.base_price
|
||||
* (Decimal::ONE + Decimal::from_f64_retain(variance).unwrap_or_default());
|
||||
|
||||
Ok(TokenPrice {
|
||||
token: token.to_string(),
|
||||
|
|
@ -258,9 +259,12 @@ impl PriceFeed for CoinGeckoFeed {
|
|||
"SYNOR" => "synor", // Would need actual CoinGecko ID
|
||||
"BTC" => "bitcoin",
|
||||
"ETH" => "ethereum",
|
||||
_ => return Err(EconomicsError::PriceFeedUnavailable(
|
||||
format!("Token {} not supported on CoinGecko", token)
|
||||
)),
|
||||
_ => {
|
||||
return Err(EconomicsError::PriceFeedUnavailable(format!(
|
||||
"Token {} not supported on CoinGecko",
|
||||
token
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
let quote_currency = quote.to_lowercase();
|
||||
|
|
@ -294,8 +298,7 @@ impl PriceFeed for CoinGeckoFeed {
|
|||
Ok(TokenPrice {
|
||||
token: token.to_string(),
|
||||
quote: quote.to_string(),
|
||||
price: Decimal::from_f64_retain(price)
|
||||
.unwrap_or_default(),
|
||||
price: Decimal::from_f64_retain(price).unwrap_or_default(),
|
||||
timestamp: Utc::now(),
|
||||
source: PriceSource::CoinGecko,
|
||||
confidence: 0.90,
|
||||
|
|
|
|||
|
|
@ -116,8 +116,8 @@ impl TwapCalculator {
|
|||
|
||||
let duration = (interval_end - interval_start).num_seconds() as f64;
|
||||
if duration > 0.0 {
|
||||
let weight = Decimal::from_f64_retain(duration / total_duration)
|
||||
.unwrap_or(Decimal::ZERO);
|
||||
let weight =
|
||||
Decimal::from_f64_retain(duration / total_duration).unwrap_or(Decimal::ZERO);
|
||||
|
||||
weighted_sum += price.price * weight;
|
||||
total_weight += weight;
|
||||
|
|
@ -343,7 +343,8 @@ impl OnChainTwap {
|
|||
/// Apply the pending cardinality increase (called during next observation)
|
||||
pub fn apply_cardinality_growth(&mut self) {
|
||||
if self.cardinality_next > self.cardinality {
|
||||
self.observations.reserve(self.cardinality_next - self.cardinality);
|
||||
self.observations
|
||||
.reserve(self.cardinality_next - self.cardinality);
|
||||
self.cardinality = self.cardinality_next;
|
||||
}
|
||||
}
|
||||
|
|
@ -396,7 +397,12 @@ impl TwapOracleBuilder {
|
|||
}
|
||||
|
||||
/// Add an initial observation
|
||||
pub fn with_observation(mut self, timestamp: DateTime<Utc>, price_cumulative: SynorDecimal, spl_cumulative: SynorDecimal) -> Self {
|
||||
pub fn with_observation(
|
||||
mut self,
|
||||
timestamp: DateTime<Utc>,
|
||||
price_cumulative: SynorDecimal,
|
||||
spl_cumulative: SynorDecimal,
|
||||
) -> Self {
|
||||
self.initial_observations.push(TwapObservation {
|
||||
timestamp,
|
||||
price_cumulative,
|
||||
|
|
|
|||
|
|
@ -76,7 +76,11 @@ pub struct Discount {
|
|||
|
||||
impl Discount {
|
||||
/// Create a new percentage discount
|
||||
pub fn percentage(code: impl Into<String>, name: impl Into<String>, percentage: SynorDecimal) -> Self {
|
||||
pub fn percentage(
|
||||
code: impl Into<String>,
|
||||
name: impl Into<String>,
|
||||
percentage: SynorDecimal,
|
||||
) -> Self {
|
||||
Self {
|
||||
code: code.into(),
|
||||
name: name.into(),
|
||||
|
|
@ -96,7 +100,11 @@ impl Discount {
|
|||
}
|
||||
|
||||
/// Create a fixed amount discount
|
||||
pub fn fixed_amount(code: impl Into<String>, name: impl Into<String>, amount: SynorDecimal) -> Self {
|
||||
pub fn fixed_amount(
|
||||
code: impl Into<String>,
|
||||
name: impl Into<String>,
|
||||
amount: SynorDecimal,
|
||||
) -> Self {
|
||||
Self {
|
||||
code: code.into(),
|
||||
name: name.into(),
|
||||
|
|
@ -120,7 +128,10 @@ impl Discount {
|
|||
Self {
|
||||
code: format!("VOLUME_{}", min_spend),
|
||||
name: format!("Volume Discount ({} SYNOR+)", min_spend),
|
||||
description: format!("{}% off when spending {} SYNOR or more", percentage, min_spend),
|
||||
description: format!(
|
||||
"{}% off when spending {} SYNOR or more",
|
||||
percentage, min_spend
|
||||
),
|
||||
discount_type: DiscountType::Volume,
|
||||
value: percentage,
|
||||
min_spend: Some(min_spend),
|
||||
|
|
@ -260,9 +271,11 @@ impl Discount {
|
|||
}
|
||||
|
||||
let discount = match self.discount_type {
|
||||
DiscountType::Percentage | DiscountType::Volume | DiscountType::Loyalty | DiscountType::Referral | DiscountType::Partner => {
|
||||
amount * (self.value / Decimal::ONE_HUNDRED)
|
||||
}
|
||||
DiscountType::Percentage
|
||||
| DiscountType::Volume
|
||||
| DiscountType::Loyalty
|
||||
| DiscountType::Referral
|
||||
| DiscountType::Partner => amount * (self.value / Decimal::ONE_HUNDRED),
|
||||
DiscountType::FixedAmount | DiscountType::Promotional => {
|
||||
self.value.min(amount) // Can't discount more than amount
|
||||
}
|
||||
|
|
@ -298,7 +311,7 @@ impl Discount {
|
|||
/// Volume discount tiers
|
||||
pub fn standard_volume_discounts() -> Vec<Discount> {
|
||||
vec![
|
||||
Discount::volume(Decimal::new(100, 0), Decimal::new(5, 0)), // 5% at 100+ SYNOR
|
||||
Discount::volume(Decimal::new(100, 0), Decimal::new(5, 0)), // 5% at 100+ SYNOR
|
||||
Discount::volume(Decimal::new(500, 0), Decimal::new(10, 0)), // 10% at 500+ SYNOR
|
||||
Discount::volume(Decimal::new(1000, 0), Decimal::new(15, 0)), // 15% at 1000+ SYNOR
|
||||
Discount::volume(Decimal::new(5000, 0), Decimal::new(20, 0)), // 20% at 5000+ SYNOR
|
||||
|
|
@ -309,11 +322,7 @@ pub fn standard_volume_discounts() -> Vec<Discount> {
|
|||
pub fn find_best_volume_discount(amount: SynorDecimal) -> Option<Discount> {
|
||||
standard_volume_discounts()
|
||||
.into_iter()
|
||||
.filter(|d| {
|
||||
d.min_spend
|
||||
.map(|min| amount >= min)
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.filter(|d| d.min_spend.map(|min| amount >= min).unwrap_or(false))
|
||||
.max_by(|a, b| a.value.cmp(&b.value))
|
||||
}
|
||||
|
||||
|
|
@ -378,8 +387,8 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_discount_usage_limit() {
|
||||
let mut discount = Discount::percentage("LIMITED", "Limited Use", dec!(10))
|
||||
.with_max_uses(2);
|
||||
let mut discount =
|
||||
Discount::percentage("LIMITED", "Limited Use", dec!(10)).with_max_uses(2);
|
||||
|
||||
assert!(discount.use_discount());
|
||||
assert!(discount.use_discount());
|
||||
|
|
|
|||
|
|
@ -198,15 +198,9 @@ impl PricingEngine {
|
|||
.get(&service_type)
|
||||
.ok_or_else(|| EconomicsError::ServiceNotConfigured(service_type.to_string()))?;
|
||||
|
||||
let unit_price = pricing
|
||||
.base_prices
|
||||
.get(&resource_unit)
|
||||
.ok_or_else(|| {
|
||||
EconomicsError::ServiceNotConfigured(format!(
|
||||
"{} - {}",
|
||||
service_type, resource_unit
|
||||
))
|
||||
})?;
|
||||
let unit_price = pricing.base_prices.get(&resource_unit).ok_or_else(|| {
|
||||
EconomicsError::ServiceNotConfigured(format!("{} - {}", service_type, resource_unit))
|
||||
})?;
|
||||
|
||||
let cost = amount * unit_price;
|
||||
|
||||
|
|
@ -357,7 +351,8 @@ impl PricingEngine {
|
|||
storage: ServicePricingSummary {
|
||||
gb_month: self.get_base_price(ServiceType::Storage, ResourceUnit::GbMonth),
|
||||
retrieval_gb: self.get_base_price(ServiceType::Storage, ResourceUnit::BandwidthGb),
|
||||
free_storage_gb: self.get_free_allocation(ServiceType::Storage, ResourceUnit::GbMonth),
|
||||
free_storage_gb: self
|
||||
.get_free_allocation(ServiceType::Storage, ResourceUnit::GbMonth),
|
||||
},
|
||||
hosting: HostingPricingSummary {
|
||||
bandwidth_gb: self.get_base_price(ServiceType::Hosting, ResourceUnit::BandwidthGb),
|
||||
|
|
@ -377,7 +372,8 @@ impl PricingEngine {
|
|||
.get_free_allocation(ServiceType::Database, ResourceUnit::Queries),
|
||||
},
|
||||
compute: ComputePricingSummary {
|
||||
cpu_core_hour: self.get_base_price(ServiceType::Compute, ResourceUnit::CpuCoreHours),
|
||||
cpu_core_hour: self
|
||||
.get_base_price(ServiceType::Compute, ResourceUnit::CpuCoreHours),
|
||||
gpu_hour: self.get_base_price(ServiceType::Compute, ResourceUnit::GpuHours),
|
||||
memory_gb_hour: self
|
||||
.get_base_price(ServiceType::Compute, ResourceUnit::MemoryGbHours),
|
||||
|
|
@ -472,7 +468,11 @@ mod tests {
|
|||
|
||||
// 10 million queries
|
||||
let cost = engine
|
||||
.calculate_cost(ServiceType::Database, ResourceUnit::Queries, dec!(10_000_000))
|
||||
.calculate_cost(
|
||||
ServiceType::Database,
|
||||
ResourceUnit::Queries,
|
||||
dec!(10_000_000),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(cost, dec!(0.10)); // 10M * 0.00000001
|
||||
|
|
@ -484,7 +484,9 @@ mod tests {
|
|||
|
||||
// Premium tier gets 20% discount
|
||||
let base_cost = dec!(100);
|
||||
let discount = engine.calculate_tier_discount("premium", base_cost).unwrap();
|
||||
let discount = engine
|
||||
.calculate_tier_discount("premium", base_cost)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(discount, dec!(20)); // 20%
|
||||
}
|
||||
|
|
|
|||
|
|
@ -108,8 +108,8 @@ impl PricingTier {
|
|||
discount_percentage: Decimal::new(30, 0), // 30% discount
|
||||
priority_support: true,
|
||||
sla_percentage: Decimal::new(9999, 2), // 99.99% SLA
|
||||
custom_domain_limit: 0, // Unlimited
|
||||
api_rate_limit: 0, // Unlimited
|
||||
custom_domain_limit: 0, // Unlimited
|
||||
api_rate_limit: 0, // Unlimited
|
||||
features: vec![
|
||||
"Everything in Premium".to_string(),
|
||||
"30%+ Usage Discount".to_string(),
|
||||
|
|
@ -147,7 +147,9 @@ impl PricingTier {
|
|||
|
||||
/// Check if this tier has a feature
|
||||
pub fn has_feature(&self, feature: &str) -> bool {
|
||||
self.features.iter().any(|f| f.to_lowercase().contains(&feature.to_lowercase()))
|
||||
self.features
|
||||
.iter()
|
||||
.any(|f| f.to_lowercase().contains(&feature.to_lowercase()))
|
||||
}
|
||||
|
||||
/// Calculate effective monthly cost including usage discount
|
||||
|
|
@ -163,7 +165,9 @@ impl PricingTier {
|
|||
let other_cost = other.effective_cost(monthly_usage);
|
||||
|
||||
// Upgrade if other tier is cheaper or offers significant benefits
|
||||
other_cost < current_cost || (other.sla_percentage > self.sla_percentage && other_cost <= current_cost * Decimal::new(12, 1))
|
||||
other_cost < current_cost
|
||||
|| (other.sla_percentage > self.sla_percentage
|
||||
&& other_cost <= current_cost * Decimal::new(12, 1))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -181,7 +181,12 @@ impl AuthService {
|
|||
}
|
||||
|
||||
/// Generate a JWT token.
|
||||
pub fn generate_token(&self, user_id: &str, tier: ApiKeyTier, permissions: Permissions) -> Result<String, ApiError> {
|
||||
pub fn generate_token(
|
||||
&self,
|
||||
user_id: &str,
|
||||
tier: ApiKeyTier,
|
||||
permissions: Permissions,
|
||||
) -> Result<String, ApiError> {
|
||||
let now = Utc::now();
|
||||
let exp = now + self.jwt_expiration;
|
||||
|
||||
|
|
@ -215,8 +220,7 @@ impl AuthService {
|
|||
})?;
|
||||
|
||||
let claims = token_data.claims;
|
||||
let expires_at = DateTime::from_timestamp(claims.exp, 0)
|
||||
.map(|dt| dt.with_timezone(&Utc));
|
||||
let expires_at = DateTime::from_timestamp(claims.exp, 0).map(|dt| dt.with_timezone(&Utc));
|
||||
|
||||
Ok(AuthContext {
|
||||
user_id: claims.sub,
|
||||
|
|
@ -278,9 +282,7 @@ impl AuthService {
|
|||
|
||||
// Try API key header
|
||||
if let Some(api_key) = headers.get("X-API-Key") {
|
||||
let key = api_key
|
||||
.to_str()
|
||||
.map_err(|_| ApiError::InvalidApiKey)?;
|
||||
let key = api_key.to_str().map_err(|_| ApiError::InvalidApiKey)?;
|
||||
return self.validate_api_key(key).await;
|
||||
}
|
||||
|
||||
|
|
@ -295,8 +297,7 @@ impl AuthService {
|
|||
let decoded = BASE64
|
||||
.decode(encoded)
|
||||
.map_err(|_| ApiError::InvalidApiKey)?;
|
||||
let key = String::from_utf8(decoded)
|
||||
.map_err(|_| ApiError::InvalidApiKey)?;
|
||||
let key = String::from_utf8(decoded).map_err(|_| ApiError::InvalidApiKey)?;
|
||||
return self.validate_api_key(&key).await;
|
||||
}
|
||||
}
|
||||
|
|
@ -318,7 +319,9 @@ where
|
|||
fn from_request_parts<'life0, 'life1, 'async_trait>(
|
||||
parts: &'life0 mut Parts,
|
||||
_state: &'life1 S,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self, Self::Rejection>> + Send + 'async_trait>>
|
||||
) -> std::pin::Pin<
|
||||
Box<dyn std::future::Future<Output = Result<Self, Self::Rejection>> + Send + 'async_trait>,
|
||||
>
|
||||
where
|
||||
'life0: 'async_trait,
|
||||
'life1: 'async_trait,
|
||||
|
|
@ -351,7 +354,9 @@ where
|
|||
fn from_request_parts<'life0, 'life1, 'async_trait>(
|
||||
parts: &'life0 mut Parts,
|
||||
_state: &'life1 S,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self, Self::Rejection>> + Send + 'async_trait>>
|
||||
) -> std::pin::Pin<
|
||||
Box<dyn std::future::Future<Output = Result<Self, Self::Rejection>> + Send + 'async_trait>,
|
||||
>
|
||||
where
|
||||
'life0: 'async_trait,
|
||||
'life1: 'async_trait,
|
||||
|
|
@ -359,10 +364,7 @@ where
|
|||
{
|
||||
Box::pin(async move {
|
||||
// Get auth service from extensions
|
||||
let auth_service = parts
|
||||
.extensions
|
||||
.get::<AuthService>()
|
||||
.cloned();
|
||||
let auth_service = parts.extensions.get::<AuthService>().cloned();
|
||||
|
||||
if let Some(auth_service) = auth_service {
|
||||
match auth_service.authenticate(&parts.headers).await {
|
||||
|
|
@ -377,10 +379,7 @@ where
|
|||
}
|
||||
|
||||
/// Require specific permissions.
|
||||
pub fn require_permission(
|
||||
context: &AuthContext,
|
||||
permission: &str,
|
||||
) -> Result<(), ApiError> {
|
||||
pub fn require_permission(context: &AuthContext, permission: &str) -> Result<(), ApiError> {
|
||||
let has_permission = match permission {
|
||||
"read" => context.can_read(),
|
||||
"write" => context.can_write(),
|
||||
|
|
@ -397,10 +396,7 @@ pub fn require_permission(
|
|||
}
|
||||
|
||||
/// Require access to a specific service.
|
||||
pub fn require_service_access(
|
||||
context: &AuthContext,
|
||||
service: &str,
|
||||
) -> Result<(), ApiError> {
|
||||
pub fn require_service_access(context: &AuthContext, service: &str) -> Result<(), ApiError> {
|
||||
if context.can_access_service(service) {
|
||||
Ok(())
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -160,10 +160,26 @@ pub struct RateLimitTiers {
|
|||
impl Default for RateLimitTiers {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
free: TierConfig { rpm: 60, burst: 10, concurrent: 5 },
|
||||
developer: TierConfig { rpm: 600, burst: 100, concurrent: 20 },
|
||||
pro: TierConfig { rpm: 6000, burst: 1000, concurrent: 100 },
|
||||
enterprise: TierConfig { rpm: 0, burst: 0, concurrent: 0 }, // Unlimited
|
||||
free: TierConfig {
|
||||
rpm: 60,
|
||||
burst: 10,
|
||||
concurrent: 5,
|
||||
},
|
||||
developer: TierConfig {
|
||||
rpm: 600,
|
||||
burst: 100,
|
||||
concurrent: 20,
|
||||
},
|
||||
pro: TierConfig {
|
||||
rpm: 6000,
|
||||
burst: 1000,
|
||||
concurrent: 100,
|
||||
},
|
||||
enterprise: TierConfig {
|
||||
rpm: 0,
|
||||
burst: 0,
|
||||
concurrent: 0,
|
||||
}, // Unlimited
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -170,9 +170,7 @@ impl ApiError {
|
|||
| Self::ContractError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||
|
||||
// 429 Too Many Requests
|
||||
Self::RateLimitExceeded | Self::TooManyRequests { .. } => {
|
||||
StatusCode::TOO_MANY_REQUESTS
|
||||
}
|
||||
Self::RateLimitExceeded | Self::TooManyRequests { .. } => StatusCode::TOO_MANY_REQUESTS,
|
||||
|
||||
// 500 Internal Server Error
|
||||
Self::InternalError | Self::Custom(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
|
|
@ -222,7 +220,10 @@ impl ApiError {
|
|||
/// Build error details with optional extra information.
|
||||
pub fn to_details(&self) -> ErrorDetails {
|
||||
let details = match self {
|
||||
Self::InsufficientBalance { required, available } => Some(serde_json::json!({
|
||||
Self::InsufficientBalance {
|
||||
required,
|
||||
available,
|
||||
} => Some(serde_json::json!({
|
||||
"required": required,
|
||||
"available": available
|
||||
})),
|
||||
|
|
@ -257,10 +258,9 @@ impl IntoResponse for ApiError {
|
|||
|
||||
// Add rate limit headers for 429 errors
|
||||
if let Self::TooManyRequests { retry_after } = &self {
|
||||
response.headers_mut().insert(
|
||||
"Retry-After",
|
||||
retry_after.to_string().parse().unwrap(),
|
||||
);
|
||||
response
|
||||
.headers_mut()
|
||||
.insert("Retry-After", retry_after.to_string().parse().unwrap());
|
||||
}
|
||||
|
||||
response
|
||||
|
|
|
|||
|
|
@ -144,7 +144,8 @@ pub async fn timing_middleware(request: Request, next: Next) -> Response {
|
|||
|
||||
// Update metrics
|
||||
metrics::counter!("http_requests_total", "method" => method.to_string(), "status" => status.as_u16().to_string()).increment(1);
|
||||
metrics::histogram!("http_request_duration_seconds", "method" => method.to_string()).record(duration.as_secs_f64());
|
||||
metrics::histogram!("http_request_duration_seconds", "method" => method.to_string())
|
||||
.record(duration.as_secs_f64());
|
||||
|
||||
response
|
||||
}
|
||||
|
|
@ -169,7 +170,10 @@ impl RateLimiterState {
|
|||
}
|
||||
|
||||
/// Get or create a rate limiter for an IP.
|
||||
pub async fn get_ip_limiter(&self, ip: &str) -> Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>> {
|
||||
pub async fn get_ip_limiter(
|
||||
&self,
|
||||
ip: &str,
|
||||
) -> Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>> {
|
||||
{
|
||||
let limiters = self.ip_limiters.read().await;
|
||||
if let Some(limiter) = limiters.get(ip) {
|
||||
|
|
@ -189,7 +193,11 @@ impl RateLimiterState {
|
|||
}
|
||||
|
||||
/// Get or create a rate limiter for an API key.
|
||||
pub async fn get_key_limiter(&self, key_id: &str, tier: ApiKeyTier) -> Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>> {
|
||||
pub async fn get_key_limiter(
|
||||
&self,
|
||||
key_id: &str,
|
||||
tier: ApiKeyTier,
|
||||
) -> Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>> {
|
||||
{
|
||||
let limiters = self.key_limiters.read().await;
|
||||
if let Some(limiter) = limiters.get(key_id) {
|
||||
|
|
@ -255,9 +263,7 @@ pub async fn rate_limit_middleware(
|
|||
// Use a fixed retry time since we can't easily convert to quanta's instant
|
||||
let retry_after = 60; // Default to 60 seconds
|
||||
|
||||
Err(ApiError::TooManyRequests {
|
||||
retry_after,
|
||||
})
|
||||
Err(ApiError::TooManyRequests { retry_after })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -274,10 +280,7 @@ pub async fn auth_middleware(
|
|||
}
|
||||
|
||||
/// API version middleware - validates version prefix.
|
||||
pub async fn version_middleware(
|
||||
request: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, ApiError> {
|
||||
pub async fn version_middleware(request: Request, next: Next) -> Result<Response, ApiError> {
|
||||
let path = request.uri().path();
|
||||
|
||||
// Skip version check for health, metrics, and docs
|
||||
|
|
@ -307,22 +310,13 @@ pub async fn security_headers_middleware(request: Request, next: Next) -> Respon
|
|||
let headers = response.headers_mut();
|
||||
|
||||
// Prevent XSS
|
||||
headers.insert(
|
||||
"X-Content-Type-Options",
|
||||
"nosniff".parse().unwrap(),
|
||||
);
|
||||
headers.insert("X-Content-Type-Options", "nosniff".parse().unwrap());
|
||||
|
||||
// Prevent clickjacking
|
||||
headers.insert(
|
||||
"X-Frame-Options",
|
||||
"DENY".parse().unwrap(),
|
||||
);
|
||||
headers.insert("X-Frame-Options", "DENY".parse().unwrap());
|
||||
|
||||
// Enable XSS filter
|
||||
headers.insert(
|
||||
"X-XSS-Protection",
|
||||
"1; mode=block".parse().unwrap(),
|
||||
);
|
||||
headers.insert("X-XSS-Protection", "1; mode=block".parse().unwrap());
|
||||
|
||||
// Strict transport security (HTTPS)
|
||||
headers.insert(
|
||||
|
|
|
|||
|
|
@ -6,11 +6,7 @@
|
|||
//! - Contract analysis and validation
|
||||
//! - Security scanning
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
routing::post,
|
||||
Json, Router,
|
||||
};
|
||||
use axum::{extract::State, routing::post, Json, Router};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
|
|
@ -44,7 +40,7 @@ pub fn router() -> Router<AppState> {
|
|||
// Types
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CompileRequest {
|
||||
pub wasm: String, // base64 encoded WASM
|
||||
pub wasm: String, // base64 encoded WASM
|
||||
pub optimization_level: Option<String>, // none, basic, size, aggressive
|
||||
pub strip_debug: Option<bool>,
|
||||
pub strip_names: Option<bool>,
|
||||
|
|
@ -216,16 +212,14 @@ async fn compile_contract(
|
|||
abi: Some(ContractAbi {
|
||||
name: "MyContract".to_string(),
|
||||
version: "1.0.0".to_string(),
|
||||
functions: vec![
|
||||
AbiFunction {
|
||||
name: "init".to_string(),
|
||||
selector: "0x12345678".to_string(),
|
||||
inputs: vec![],
|
||||
outputs: vec![],
|
||||
view: false,
|
||||
payable: false,
|
||||
},
|
||||
],
|
||||
functions: vec![AbiFunction {
|
||||
name: "init".to_string(),
|
||||
selector: "0x12345678".to_string(),
|
||||
inputs: vec![],
|
||||
outputs: vec![],
|
||||
view: false,
|
||||
payable: false,
|
||||
}],
|
||||
events: vec![],
|
||||
errors: vec![],
|
||||
}),
|
||||
|
|
@ -342,23 +336,19 @@ async fn analyze_contract(
|
|||
imports: 100,
|
||||
total: 5000,
|
||||
},
|
||||
functions: vec![
|
||||
FunctionAnalysis {
|
||||
name: "init".to_string(),
|
||||
size: 500,
|
||||
instruction_count: 50,
|
||||
local_count: 3,
|
||||
exported: true,
|
||||
estimated_gas: 10000,
|
||||
},
|
||||
],
|
||||
imports: vec![
|
||||
ImportInfo {
|
||||
module: "env".to_string(),
|
||||
name: "memory".to_string(),
|
||||
kind: "memory".to_string(),
|
||||
},
|
||||
],
|
||||
functions: vec![FunctionAnalysis {
|
||||
name: "init".to_string(),
|
||||
size: 500,
|
||||
instruction_count: 50,
|
||||
local_count: 3,
|
||||
exported: true,
|
||||
estimated_gas: 10000,
|
||||
}],
|
||||
imports: vec![ImportInfo {
|
||||
module: "env".to_string(),
|
||||
name: "memory".to_string(),
|
||||
kind: "memory".to_string(),
|
||||
}],
|
||||
gas_analysis: GasAnalysis {
|
||||
deployment_gas: 100000,
|
||||
memory_init_gas: 5000,
|
||||
|
|
@ -378,14 +368,12 @@ async fn security_scan(
|
|||
|
||||
let result = SecurityScanResult {
|
||||
score: 85,
|
||||
issues: vec![
|
||||
SecurityIssue {
|
||||
severity: "low".to_string(),
|
||||
issue_type: "unbounded_loop".to_string(),
|
||||
description: "Potential unbounded loop detected".to_string(),
|
||||
location: Some("function:process".to_string()),
|
||||
},
|
||||
],
|
||||
issues: vec![SecurityIssue {
|
||||
severity: "low".to_string(),
|
||||
issue_type: "unbounded_loop".to_string(),
|
||||
description: "Potential unbounded loop detected".to_string(),
|
||||
location: Some("function:process".to_string()),
|
||||
}],
|
||||
recommendations: vec![
|
||||
"Add loop iteration limits".to_string(),
|
||||
"Consider using checked arithmetic".to_string(),
|
||||
|
|
|
|||
|
|
@ -122,17 +122,15 @@ async fn list_markets(
|
|||
) -> ApiResult<Json<ApiResponse<Vec<Market>>>> {
|
||||
require_permission(&auth, "read")?;
|
||||
|
||||
let markets = vec![
|
||||
Market {
|
||||
symbol: "ETH-USDC".to_string(),
|
||||
base_asset: "ETH".to_string(),
|
||||
quote_asset: "USDC".to_string(),
|
||||
last_price: "3000.00".to_string(),
|
||||
change_24h: "2.5".to_string(),
|
||||
volume_24h: "10000000".to_string(),
|
||||
status: "active".to_string(),
|
||||
},
|
||||
];
|
||||
let markets = vec![Market {
|
||||
symbol: "ETH-USDC".to_string(),
|
||||
base_asset: "ETH".to_string(),
|
||||
quote_asset: "USDC".to_string(),
|
||||
last_price: "3000.00".to_string(),
|
||||
change_24h: "2.5".to_string(),
|
||||
volume_24h: "10000000".to_string(),
|
||||
status: "active".to_string(),
|
||||
}];
|
||||
|
||||
Ok(Json(ApiResponse::success(markets)))
|
||||
}
|
||||
|
|
@ -165,8 +163,14 @@ async fn get_orderbook(
|
|||
require_permission(&auth, "read")?;
|
||||
|
||||
let orderbook = Orderbook {
|
||||
bids: vec![OrderbookEntry { price: "2999.00".to_string(), quantity: "1.5".to_string() }],
|
||||
asks: vec![OrderbookEntry { price: "3001.00".to_string(), quantity: "2.0".to_string() }],
|
||||
bids: vec![OrderbookEntry {
|
||||
price: "2999.00".to_string(),
|
||||
quantity: "1.5".to_string(),
|
||||
}],
|
||||
asks: vec![OrderbookEntry {
|
||||
price: "3001.00".to_string(),
|
||||
quantity: "2.0".to_string(),
|
||||
}],
|
||||
spread: "2.00".to_string(),
|
||||
};
|
||||
|
||||
|
|
@ -286,7 +290,9 @@ async fn place_perp_order(
|
|||
Json(req): Json<serde_json::Value>,
|
||||
) -> ApiResult<Json<ApiResponse<serde_json::Value>>> {
|
||||
require_permission(&auth, "write")?;
|
||||
Ok(Json(ApiResponse::success(serde_json::json!({"order_id": "perp_123"}))))
|
||||
Ok(Json(ApiResponse::success(
|
||||
serde_json::json!({"order_id": "perp_123"}),
|
||||
)))
|
||||
}
|
||||
|
||||
async fn list_pools(
|
||||
|
|
@ -295,18 +301,16 @@ async fn list_pools(
|
|||
) -> ApiResult<Json<ApiResponse<Vec<Pool>>>> {
|
||||
require_permission(&auth, "read")?;
|
||||
|
||||
let pools = vec![
|
||||
Pool {
|
||||
pool_id: "ETH-USDC".to_string(),
|
||||
name: "ETH/USDC".to_string(),
|
||||
token_a: "ETH".to_string(),
|
||||
token_b: "USDC".to_string(),
|
||||
reserve_a: "1000".to_string(),
|
||||
reserve_b: "3000000".to_string(),
|
||||
tvl: "6000000".to_string(),
|
||||
apr: "15.5".to_string(),
|
||||
},
|
||||
];
|
||||
let pools = vec![Pool {
|
||||
pool_id: "ETH-USDC".to_string(),
|
||||
name: "ETH/USDC".to_string(),
|
||||
token_a: "ETH".to_string(),
|
||||
token_b: "USDC".to_string(),
|
||||
reserve_a: "1000".to_string(),
|
||||
reserve_b: "3000000".to_string(),
|
||||
tvl: "6000000".to_string(),
|
||||
apr: "15.5".to_string(),
|
||||
}];
|
||||
|
||||
Ok(Json(ApiResponse::success(pools)))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -128,16 +128,14 @@ async fn list_chains(
|
|||
) -> ApiResult<Json<ApiResponse<Vec<Chain>>>> {
|
||||
require_permission(&auth, "read")?;
|
||||
|
||||
let chains = vec![
|
||||
Chain {
|
||||
chain_id: "cosmoshub-4".to_string(),
|
||||
name: "Cosmos Hub".to_string(),
|
||||
status: "active".to_string(),
|
||||
rpc_endpoint: "https://rpc.cosmos.network".to_string(),
|
||||
latest_height: 18000000,
|
||||
active_channels: 50,
|
||||
},
|
||||
];
|
||||
let chains = vec![Chain {
|
||||
chain_id: "cosmoshub-4".to_string(),
|
||||
name: "Cosmos Hub".to_string(),
|
||||
status: "active".to_string(),
|
||||
rpc_endpoint: "https://rpc.cosmos.network".to_string(),
|
||||
latest_height: 18000000,
|
||||
active_channels: 50,
|
||||
}];
|
||||
|
||||
Ok(Json(ApiResponse::success(chains)))
|
||||
}
|
||||
|
|
@ -280,15 +278,13 @@ async fn get_routes(
|
|||
) -> ApiResult<Json<ApiResponse<Vec<TransferRoute>>>> {
|
||||
require_permission(&auth, "read")?;
|
||||
|
||||
let routes = vec![
|
||||
TransferRoute {
|
||||
source_chain: "cosmoshub-4".to_string(),
|
||||
dest_chain: "synor-mainnet".to_string(),
|
||||
channel_id: "channel-0".to_string(),
|
||||
estimated_time: "30s".to_string(),
|
||||
fee: "0.001 ATOM".to_string(),
|
||||
},
|
||||
];
|
||||
let routes = vec![TransferRoute {
|
||||
source_chain: "cosmoshub-4".to_string(),
|
||||
dest_chain: "synor-mainnet".to_string(),
|
||||
channel_id: "channel-0".to_string(),
|
||||
estimated_time: "30s".to_string(),
|
||||
fee: "0.001 ATOM".to_string(),
|
||||
}];
|
||||
|
||||
Ok(Json(ApiResponse::success(routes)))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -303,13 +303,11 @@ async fn get_peers(
|
|||
) -> ApiResult<Json<ApiResponse<Vec<serde_json::Value>>>> {
|
||||
require_permission(&auth, "read")?;
|
||||
|
||||
let peers = vec![
|
||||
serde_json::json!({
|
||||
"id": "peer1",
|
||||
"address": "192.168.1.1:16100",
|
||||
"connected_since": 1705312200
|
||||
})
|
||||
];
|
||||
let peers = vec![serde_json::json!({
|
||||
"id": "peer1",
|
||||
"address": "192.168.1.1:16100",
|
||||
"connected_since": 1705312200
|
||||
})];
|
||||
|
||||
Ok(Json(ApiResponse::success(peers)))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -247,14 +247,12 @@ async fn list_directory(
|
|||
) -> ApiResult<Json<ApiResponse<Vec<DirectoryEntry>>>> {
|
||||
require_permission(&auth, "read")?;
|
||||
|
||||
let entries = vec![
|
||||
DirectoryEntry {
|
||||
name: "file1.txt".to_string(),
|
||||
cid: "bafyfile1...".to_string(),
|
||||
size: 1024,
|
||||
is_directory: false,
|
||||
},
|
||||
];
|
||||
let entries = vec![DirectoryEntry {
|
||||
name: "file1.txt".to_string(),
|
||||
cid: "bafyfile1...".to_string(),
|
||||
size: 1024,
|
||||
is_directory: false,
|
||||
}];
|
||||
|
||||
Ok(Json(ApiResponse::success(entries)))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -409,7 +409,10 @@ async fn list_addresses(
|
|||
];
|
||||
|
||||
let pagination_meta = pagination.to_meta(addresses.len() as u64);
|
||||
Ok(Json(ApiResponse::success_paginated(addresses, pagination_meta)))
|
||||
Ok(Json(ApiResponse::success_paginated(
|
||||
addresses,
|
||||
pagination_meta,
|
||||
)))
|
||||
}
|
||||
|
||||
/// Generate a stealth address.
|
||||
|
|
@ -477,7 +480,9 @@ async fn get_balances(
|
|||
require_permission(&auth, "read")?;
|
||||
|
||||
if req.addresses.is_empty() {
|
||||
return Err(ApiError::ValidationError("addresses cannot be empty".to_string()));
|
||||
return Err(ApiError::ValidationError(
|
||||
"addresses cannot be empty".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if req.addresses.len() > 100 {
|
||||
|
|
@ -585,12 +590,15 @@ async fn send_transaction(
|
|||
}
|
||||
|
||||
// Validate amount
|
||||
let amount: f64 = req.amount.parse().map_err(|_| {
|
||||
ApiError::ValidationError("Invalid amount format".to_string())
|
||||
})?;
|
||||
let amount: f64 = req
|
||||
.amount
|
||||
.parse()
|
||||
.map_err(|_| ApiError::ValidationError("Invalid amount format".to_string()))?;
|
||||
|
||||
if amount <= 0.0 {
|
||||
return Err(ApiError::ValidationError("Amount must be positive".to_string()));
|
||||
return Err(ApiError::ValidationError(
|
||||
"Amount must be positive".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// In production, build, sign, and broadcast the transaction
|
||||
|
|
|
|||
|
|
@ -139,16 +139,14 @@ async fn list_circuits(
|
|||
) -> ApiResult<Json<ApiResponse<Vec<Circuit>>>> {
|
||||
require_permission(&auth, "read")?;
|
||||
|
||||
let circuits = vec![
|
||||
Circuit {
|
||||
circuit_id: "multiplier-v1".to_string(),
|
||||
name: "Multiplier".to_string(),
|
||||
constraints: 1,
|
||||
public_inputs: 1,
|
||||
private_inputs: 2,
|
||||
outputs: 1,
|
||||
},
|
||||
];
|
||||
let circuits = vec![Circuit {
|
||||
circuit_id: "multiplier-v1".to_string(),
|
||||
name: "Multiplier".to_string(),
|
||||
constraints: 1,
|
||||
public_inputs: 1,
|
||||
private_inputs: 2,
|
||||
outputs: 1,
|
||||
}];
|
||||
|
||||
let meta = pagination.to_meta(circuits.len() as u64);
|
||||
Ok(Json(ApiResponse::success_paginated(circuits, meta)))
|
||||
|
|
|
|||
|
|
@ -17,15 +17,9 @@ use axum::{
|
|||
Router,
|
||||
};
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use tokio::{
|
||||
net::TcpListener,
|
||||
signal,
|
||||
sync::oneshot,
|
||||
};
|
||||
use tokio::{net::TcpListener, signal, sync::oneshot};
|
||||
use tower_http::{
|
||||
compression::CompressionLayer,
|
||||
limit::RequestBodyLimitLayer,
|
||||
timeout::TimeoutLayer,
|
||||
compression::CompressionLayer, limit::RequestBodyLimitLayer, timeout::TimeoutLayer,
|
||||
trace::TraceLayer,
|
||||
};
|
||||
use tracing::info;
|
||||
|
|
|
|||
|
|
@ -164,7 +164,8 @@ impl VersionRegistry {
|
|||
|
||||
/// Register a version.
|
||||
pub fn register(&mut self, info: VersionInfo) {
|
||||
self.versions.insert((info.version.major, info.version.minor), info);
|
||||
self.versions
|
||||
.insert((info.version.major, info.version.minor), info);
|
||||
}
|
||||
|
||||
/// Get version info.
|
||||
|
|
@ -330,10 +331,7 @@ pub async fn version_middleware(req: Request, next: Next) -> Response {
|
|||
// Add deprecation headers if needed
|
||||
if let Some(info) = registry.get(&extracted.version) {
|
||||
if info.is_deprecated {
|
||||
headers.insert(
|
||||
X_API_DEPRECATED.clone(),
|
||||
HeaderValue::from_static("true"),
|
||||
);
|
||||
headers.insert(X_API_DEPRECATED.clone(), HeaderValue::from_static("true"));
|
||||
|
||||
if let Some(deprecated_at) = &info.deprecated_at {
|
||||
if let Ok(v) = HeaderValue::from_str(&deprecated_at.to_rfc3339()) {
|
||||
|
|
@ -427,8 +425,8 @@ impl VersionsResponse {
|
|||
// Routes
|
||||
// ============================================================================
|
||||
|
||||
use axum::{routing::get, Json, Router};
|
||||
use crate::routes::AppState;
|
||||
use axum::{routing::get, Json, Router};
|
||||
|
||||
/// Build version routes.
|
||||
pub fn router() -> Router<AppState> {
|
||||
|
|
|
|||
|
|
@ -593,11 +593,7 @@ async fn ws_blocks_handler(
|
|||
ws.on_upgrade(move |socket| handle_blocks_socket(socket, state, auth))
|
||||
}
|
||||
|
||||
async fn handle_blocks_socket(
|
||||
socket: WebSocket,
|
||||
state: AppState,
|
||||
_auth: Option<AuthContext>,
|
||||
) {
|
||||
async fn handle_blocks_socket(socket: WebSocket, state: AppState, _auth: Option<AuthContext>) {
|
||||
let ws_state = &state.websocket;
|
||||
ws_state.broadcaster.add_connection().await;
|
||||
|
||||
|
|
@ -834,11 +830,7 @@ async fn ws_markets_handler(
|
|||
ws.on_upgrade(move |socket| handle_markets_socket(socket, state, auth))
|
||||
}
|
||||
|
||||
async fn handle_markets_socket(
|
||||
socket: WebSocket,
|
||||
state: AppState,
|
||||
_auth: Option<AuthContext>,
|
||||
) {
|
||||
async fn handle_markets_socket(socket: WebSocket, state: AppState, _auth: Option<AuthContext>) {
|
||||
let ws_state = &state.websocket;
|
||||
ws_state.broadcaster.add_connection().await;
|
||||
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@
|
|||
//! hosting-gateway --domain synor.cc --storage-url http://localhost:8180
|
||||
//! hosting-gateway --config /path/to/config.toml
|
||||
|
||||
use synor_hosting::{HostingGateway, GatewayConfig};
|
||||
use std::net::SocketAddr;
|
||||
use synor_hosting::{GatewayConfig, HostingGateway};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
|
|
@ -19,28 +19,32 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
// Parse command line arguments
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
|
||||
let listen_addr: SocketAddr = args.iter()
|
||||
let listen_addr: SocketAddr = args
|
||||
.iter()
|
||||
.position(|a| a == "--listen" || a == "-l")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.or_else(|| std::env::var("LISTEN_ADDR").ok()?.parse().ok())
|
||||
.unwrap_or_else(|| "0.0.0.0:8080".parse().unwrap());
|
||||
|
||||
let hosting_domain = args.iter()
|
||||
let hosting_domain = args
|
||||
.iter()
|
||||
.position(|a| a == "--domain" || a == "-d")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.cloned()
|
||||
.or_else(|| std::env::var("HOSTING_DOMAIN").ok())
|
||||
.unwrap_or_else(|| "synor.cc".to_string());
|
||||
|
||||
let storage_url = args.iter()
|
||||
let storage_url = args
|
||||
.iter()
|
||||
.position(|a| a == "--storage-url" || a == "-s")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.cloned()
|
||||
.or_else(|| std::env::var("STORAGE_GATEWAY_URL").ok())
|
||||
.unwrap_or_else(|| "http://localhost:8180".to_string());
|
||||
|
||||
let rate_limit: u32 = args.iter()
|
||||
let rate_limit: u32 = args
|
||||
.iter()
|
||||
.position(|a| a == "--rate-limit")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
|
|
|
|||
|
|
@ -233,11 +233,7 @@ impl EdgeCompute {
|
|||
}
|
||||
|
||||
/// Run AI inference at the edge.
|
||||
pub async fn inference(
|
||||
&self,
|
||||
_model: &str,
|
||||
_input: &[u8],
|
||||
) -> Result<Vec<u8>, EdgeError> {
|
||||
pub async fn inference(&self, _model: &str, _input: &[u8]) -> Result<Vec<u8>, EdgeError> {
|
||||
if !self.enabled {
|
||||
return Err(EdgeError::NotEnabled);
|
||||
}
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue