style: apply cargo fmt formatting
This commit is contained in:
parent
5126c33113
commit
dcd1cccc67
170 changed files with 4463 additions and 2837 deletions
|
|
@ -256,7 +256,11 @@ pub async fn handle(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
CompilerCommands::Encode { function, args, abi } => {
|
CompilerCommands::Encode {
|
||||||
|
function,
|
||||||
|
args,
|
||||||
|
abi,
|
||||||
|
} => {
|
||||||
output::print_info(&format!("Encoding call to: {}", function));
|
output::print_info(&format!("Encoding call to: {}", function));
|
||||||
output::print_kv("Arguments", &args);
|
output::print_kv("Arguments", &args);
|
||||||
if let Some(a) = abi {
|
if let Some(a) = abi {
|
||||||
|
|
@ -268,7 +272,11 @@ pub async fn handle(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
CompilerCommands::Decode { data, function, abi } => {
|
CompilerCommands::Decode {
|
||||||
|
data,
|
||||||
|
function,
|
||||||
|
abi,
|
||||||
|
} => {
|
||||||
output::print_info(&format!("Decoding result for: {}", function));
|
output::print_info(&format!("Decoding result for: {}", function));
|
||||||
output::print_kv("Data", &data);
|
output::print_kv("Data", &data);
|
||||||
if let Some(a) = abi {
|
if let Some(a) = abi {
|
||||||
|
|
@ -314,7 +322,11 @@ pub async fn handle(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
CompilerCommands::SecurityScan { wasm, min_severity, format: _ } => {
|
CompilerCommands::SecurityScan {
|
||||||
|
wasm,
|
||||||
|
min_severity,
|
||||||
|
format: _,
|
||||||
|
} => {
|
||||||
output::print_info(&format!("Security scan: {}", wasm.display()));
|
output::print_info(&format!("Security scan: {}", wasm.display()));
|
||||||
output::print_kv("Min severity", &min_severity);
|
output::print_kv("Min severity", &min_severity);
|
||||||
|
|
||||||
|
|
@ -344,7 +356,11 @@ pub async fn handle(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
CompilerCommands::Validate { wasm, exports, max_memory } => {
|
CompilerCommands::Validate {
|
||||||
|
wasm,
|
||||||
|
exports,
|
||||||
|
max_memory,
|
||||||
|
} => {
|
||||||
output::print_info(&format!("Validating: {}", wasm.display()));
|
output::print_info(&format!("Validating: {}", wasm.display()));
|
||||||
|
|
||||||
if let Some(e) = exports {
|
if let Some(e) = exports {
|
||||||
|
|
|
||||||
|
|
@ -196,8 +196,7 @@ pub async fn deploy(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine output directory
|
// Determine output directory
|
||||||
let output_path = output_dir
|
let output_path = output_dir.unwrap_or_else(|| cwd.join(config.output_dir()));
|
||||||
.unwrap_or_else(|| cwd.join(config.output_dir()));
|
|
||||||
|
|
||||||
if !output_path.exists() {
|
if !output_path.exists() {
|
||||||
return Err(anyhow!(
|
return Err(anyhow!(
|
||||||
|
|
@ -270,7 +269,10 @@ fn validate_name(name: &str) -> Result<()> {
|
||||||
if name.len() > 63 {
|
if name.len() > 63 {
|
||||||
return Err(anyhow!("Name must be 63 characters or less"));
|
return Err(anyhow!("Name must be 63 characters or less"));
|
||||||
}
|
}
|
||||||
if !name.chars().all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-') {
|
if !name
|
||||||
|
.chars()
|
||||||
|
.all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-')
|
||||||
|
{
|
||||||
return Err(anyhow!(
|
return Err(anyhow!(
|
||||||
"Name must contain only lowercase letters, numbers, and hyphens"
|
"Name must contain only lowercase letters, numbers, and hyphens"
|
||||||
));
|
));
|
||||||
|
|
@ -281,8 +283,8 @@ fn validate_name(name: &str) -> Result<()> {
|
||||||
|
|
||||||
// Reserved names
|
// Reserved names
|
||||||
const RESERVED: &[&str] = &[
|
const RESERVED: &[&str] = &[
|
||||||
"www", "api", "app", "admin", "mail", "ftp", "ssh", "cdn",
|
"www", "api", "app", "admin", "mail", "ftp", "ssh", "cdn", "storage", "gateway", "hosting",
|
||||||
"storage", "gateway", "hosting", "node", "synor",
|
"node", "synor",
|
||||||
];
|
];
|
||||||
if RESERVED.contains(&name) {
|
if RESERVED.contains(&name) {
|
||||||
return Err(anyhow!("Name '{}' is reserved", name));
|
return Err(anyhow!("Name '{}' is reserved", name));
|
||||||
|
|
@ -397,11 +399,7 @@ fn guess_content_type(path: &Path) -> String {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Upload files to Synor Storage.
|
/// Upload files to Synor Storage.
|
||||||
async fn upload_files(
|
async fn upload_files(base_dir: &Path, files: &[DeployFile], gateway_url: &str) -> Result<String> {
|
||||||
base_dir: &Path,
|
|
||||||
files: &[DeployFile],
|
|
||||||
gateway_url: &str,
|
|
||||||
) -> Result<String> {
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
// Create a multipart form with all files
|
// Create a multipart form with all files
|
||||||
|
|
@ -445,11 +443,7 @@ async fn upload_files(
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Register the deployment with the hosting gateway.
|
/// Register the deployment with the hosting gateway.
|
||||||
async fn register_deployment(
|
async fn register_deployment(name: &str, cid: &str, gateway_url: &str) -> Result<String> {
|
||||||
name: &str,
|
|
||||||
cid: &str,
|
|
||||||
gateway_url: &str,
|
|
||||||
) -> Result<String> {
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
|
|
@ -662,7 +656,11 @@ pub async fn delete(name: &str, gateway_url: &str, format: OutputFormat) -> Resu
|
||||||
if !response.status().is_success() {
|
if !response.status().is_success() {
|
||||||
let status = response.status();
|
let status = response.status();
|
||||||
let body = response.text().await.unwrap_or_default();
|
let body = response.text().await.unwrap_or_default();
|
||||||
return Err(anyhow!("Failed to delete deployment: {} - {}", status, body));
|
return Err(anyhow!(
|
||||||
|
"Failed to delete deployment: {} - {}",
|
||||||
|
status,
|
||||||
|
body
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
match format {
|
match format {
|
||||||
|
|
@ -707,22 +705,13 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_guess_content_type() {
|
fn test_guess_content_type() {
|
||||||
assert_eq!(
|
assert_eq!(guess_content_type(Path::new("index.html")), "text/html");
|
||||||
guess_content_type(Path::new("index.html")),
|
assert_eq!(guess_content_type(Path::new("style.css")), "text/css");
|
||||||
"text/html"
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
guess_content_type(Path::new("style.css")),
|
|
||||||
"text/css"
|
|
||||||
);
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
guess_content_type(Path::new("app.js")),
|
guess_content_type(Path::new("app.js")),
|
||||||
"application/javascript"
|
"application/javascript"
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(guess_content_type(Path::new("image.png")), "image/png");
|
||||||
guess_content_type(Path::new("image.png")),
|
|
||||||
"image/png"
|
|
||||||
);
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
guess_content_type(Path::new("data.wasm")),
|
guess_content_type(Path::new("data.wasm")),
|
||||||
"application/wasm"
|
"application/wasm"
|
||||||
|
|
|
||||||
|
|
@ -246,7 +246,13 @@ pub async fn handle(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
DexCommands::PlaceOrder { market, side, price, quantity, wallet } => {
|
DexCommands::PlaceOrder {
|
||||||
|
market,
|
||||||
|
side,
|
||||||
|
price,
|
||||||
|
quantity,
|
||||||
|
wallet,
|
||||||
|
} => {
|
||||||
output::print_info("Placing limit order...");
|
output::print_info("Placing limit order...");
|
||||||
output::print_kv("Market", &market);
|
output::print_kv("Market", &market);
|
||||||
output::print_kv("Side", &side);
|
output::print_kv("Side", &side);
|
||||||
|
|
@ -257,7 +263,12 @@ pub async fn handle(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
DexCommands::MarketOrder { market, side, quantity, wallet } => {
|
DexCommands::MarketOrder {
|
||||||
|
market,
|
||||||
|
side,
|
||||||
|
quantity,
|
||||||
|
wallet,
|
||||||
|
} => {
|
||||||
output::print_info("Placing market order...");
|
output::print_info("Placing market order...");
|
||||||
output::print_kv("Market", &market);
|
output::print_kv("Market", &market);
|
||||||
output::print_kv("Side", &side);
|
output::print_kv("Side", &side);
|
||||||
|
|
@ -275,7 +286,10 @@ pub async fn handle(
|
||||||
|
|
||||||
DexCommands::CancelAll { market, wallet } => {
|
DexCommands::CancelAll { market, wallet } => {
|
||||||
let scope = market.unwrap_or_else(|| "all markets".to_string());
|
let scope = market.unwrap_or_else(|| "all markets".to_string());
|
||||||
output::print_info(&format!("Cancelling all orders in {} for {}", scope, wallet));
|
output::print_info(&format!(
|
||||||
|
"Cancelling all orders in {} for {}",
|
||||||
|
scope, wallet
|
||||||
|
));
|
||||||
output::print_success("3 orders cancelled");
|
output::print_success("3 orders cancelled");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
@ -317,7 +331,12 @@ pub async fn handle(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
DexCommands::AddLiquidity { pool_id, amount_a, amount_b, wallet } => {
|
DexCommands::AddLiquidity {
|
||||||
|
pool_id,
|
||||||
|
amount_a,
|
||||||
|
amount_b,
|
||||||
|
wallet,
|
||||||
|
} => {
|
||||||
output::print_info("Adding liquidity...");
|
output::print_info("Adding liquidity...");
|
||||||
output::print_kv("Pool", &pool_id);
|
output::print_kv("Pool", &pool_id);
|
||||||
output::print_kv("Amount A", &amount_a);
|
output::print_kv("Amount A", &amount_a);
|
||||||
|
|
@ -327,7 +346,11 @@ pub async fn handle(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
DexCommands::RemoveLiquidity { pool_id, lp_amount, wallet } => {
|
DexCommands::RemoveLiquidity {
|
||||||
|
pool_id,
|
||||||
|
lp_amount,
|
||||||
|
wallet,
|
||||||
|
} => {
|
||||||
output::print_info("Removing liquidity...");
|
output::print_info("Removing liquidity...");
|
||||||
output::print_kv("Pool", &pool_id);
|
output::print_kv("Pool", &pool_id);
|
||||||
output::print_kv("LP Amount", &lp_amount);
|
output::print_kv("LP Amount", &lp_amount);
|
||||||
|
|
|
||||||
|
|
@ -169,11 +169,7 @@ pub enum ZkCommands {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Handle ZK commands.
|
/// Handle ZK commands.
|
||||||
pub async fn handle(
|
pub async fn handle(_client: &RpcClient, command: ZkCommands, _format: OutputFormat) -> Result<()> {
|
||||||
_client: &RpcClient,
|
|
||||||
command: ZkCommands,
|
|
||||||
_format: OutputFormat,
|
|
||||||
) -> Result<()> {
|
|
||||||
match command {
|
match command {
|
||||||
ZkCommands::Compile { circuit, output } => {
|
ZkCommands::Compile { circuit, output } => {
|
||||||
output::print_info(&format!("Compiling circuit: {}", circuit.display()));
|
output::print_info(&format!("Compiling circuit: {}", circuit.display()));
|
||||||
|
|
@ -211,8 +207,16 @@ pub async fn handle(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
ZkCommands::ProveGroth16 { circuit, witness, proving_key: _, output } => {
|
ZkCommands::ProveGroth16 {
|
||||||
output::print_info(&format!("Generating Groth16 proof for circuit: {}", circuit));
|
circuit,
|
||||||
|
witness,
|
||||||
|
proving_key: _,
|
||||||
|
output,
|
||||||
|
} => {
|
||||||
|
output::print_info(&format!(
|
||||||
|
"Generating Groth16 proof for circuit: {}",
|
||||||
|
circuit
|
||||||
|
));
|
||||||
output::print_info(&format!("Witness: {}", witness.display()));
|
output::print_info(&format!("Witness: {}", witness.display()));
|
||||||
output::print_info("Computing witness...");
|
output::print_info("Computing witness...");
|
||||||
output::print_info("Generating proof...");
|
output::print_info("Generating proof...");
|
||||||
|
|
@ -226,7 +230,11 @@ pub async fn handle(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
ZkCommands::ProvePlonk { circuit, witness, output } => {
|
ZkCommands::ProvePlonk {
|
||||||
|
circuit,
|
||||||
|
witness,
|
||||||
|
output,
|
||||||
|
} => {
|
||||||
output::print_info(&format!("Generating PLONK proof for circuit: {}", circuit));
|
output::print_info(&format!("Generating PLONK proof for circuit: {}", circuit));
|
||||||
output::print_info(&format!("Witness: {}", witness.display()));
|
output::print_info(&format!("Witness: {}", witness.display()));
|
||||||
output::print_info("Computing witness...");
|
output::print_info("Computing witness...");
|
||||||
|
|
@ -240,7 +248,11 @@ pub async fn handle(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
ZkCommands::ProveStark { circuit, witness, output } => {
|
ZkCommands::ProveStark {
|
||||||
|
circuit,
|
||||||
|
witness,
|
||||||
|
output,
|
||||||
|
} => {
|
||||||
output::print_info(&format!("Generating STARK proof for circuit: {}", circuit));
|
output::print_info(&format!("Generating STARK proof for circuit: {}", circuit));
|
||||||
output::print_info(&format!("Witness: {}", witness.display()));
|
output::print_info(&format!("Witness: {}", witness.display()));
|
||||||
output::print_info("Computing execution trace...");
|
output::print_info("Computing execution trace...");
|
||||||
|
|
@ -256,7 +268,11 @@ pub async fn handle(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
ZkCommands::Verify { proof, verification_key: _, public_inputs: _ } => {
|
ZkCommands::Verify {
|
||||||
|
proof,
|
||||||
|
verification_key: _,
|
||||||
|
public_inputs: _,
|
||||||
|
} => {
|
||||||
output::print_info(&format!("Verifying proof: {}", proof.display()));
|
output::print_info(&format!("Verifying proof: {}", proof.display()));
|
||||||
output::print_info("Loading proof...");
|
output::print_info("Loading proof...");
|
||||||
output::print_info("Verifying...");
|
output::print_info("Verifying...");
|
||||||
|
|
@ -265,8 +281,15 @@ pub async fn handle(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
ZkCommands::Setup { circuit, system, output } => {
|
ZkCommands::Setup {
|
||||||
output::print_info(&format!("Generating {} keys for circuit: {}", system, circuit));
|
circuit,
|
||||||
|
system,
|
||||||
|
output,
|
||||||
|
} => {
|
||||||
|
output::print_info(&format!(
|
||||||
|
"Generating {} keys for circuit: {}",
|
||||||
|
system, circuit
|
||||||
|
));
|
||||||
output::print_info("This may take a while for large circuits...");
|
output::print_info("This may take a while for large circuits...");
|
||||||
output::print_info("Generating proving key...");
|
output::print_info("Generating proving key...");
|
||||||
output::print_info("Deriving verification key...");
|
output::print_info("Deriving verification key...");
|
||||||
|
|
|
||||||
|
|
@ -469,7 +469,11 @@ enum DeployCommands {
|
||||||
output: Option<PathBuf>,
|
output: Option<PathBuf>,
|
||||||
|
|
||||||
/// Hosting gateway URL
|
/// Hosting gateway URL
|
||||||
#[arg(long, env = "SYNOR_HOSTING_URL", default_value = "http://127.0.0.1:8280")]
|
#[arg(
|
||||||
|
long,
|
||||||
|
env = "SYNOR_HOSTING_URL",
|
||||||
|
default_value = "http://127.0.0.1:8280"
|
||||||
|
)]
|
||||||
gateway: String,
|
gateway: String,
|
||||||
|
|
||||||
/// Skip running the build command
|
/// Skip running the build command
|
||||||
|
|
@ -495,7 +499,11 @@ enum DeployCommands {
|
||||||
/// List deployments
|
/// List deployments
|
||||||
List {
|
List {
|
||||||
/// Hosting gateway URL
|
/// Hosting gateway URL
|
||||||
#[arg(long, env = "SYNOR_HOSTING_URL", default_value = "http://127.0.0.1:8280")]
|
#[arg(
|
||||||
|
long,
|
||||||
|
env = "SYNOR_HOSTING_URL",
|
||||||
|
default_value = "http://127.0.0.1:8280"
|
||||||
|
)]
|
||||||
gateway: String,
|
gateway: String,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
|
@ -505,7 +513,11 @@ enum DeployCommands {
|
||||||
name: String,
|
name: String,
|
||||||
|
|
||||||
/// Hosting gateway URL
|
/// Hosting gateway URL
|
||||||
#[arg(long, env = "SYNOR_HOSTING_URL", default_value = "http://127.0.0.1:8280")]
|
#[arg(
|
||||||
|
long,
|
||||||
|
env = "SYNOR_HOSTING_URL",
|
||||||
|
default_value = "http://127.0.0.1:8280"
|
||||||
|
)]
|
||||||
gateway: String,
|
gateway: String,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
|
@ -515,7 +527,11 @@ enum DeployCommands {
|
||||||
name: String,
|
name: String,
|
||||||
|
|
||||||
/// Hosting gateway URL
|
/// Hosting gateway URL
|
||||||
#[arg(long, env = "SYNOR_HOSTING_URL", default_value = "http://127.0.0.1:8280")]
|
#[arg(
|
||||||
|
long,
|
||||||
|
env = "SYNOR_HOSTING_URL",
|
||||||
|
default_value = "http://127.0.0.1:8280"
|
||||||
|
)]
|
||||||
gateway: String,
|
gateway: String,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
@ -591,9 +607,11 @@ async fn main() {
|
||||||
gateway,
|
gateway,
|
||||||
skip_build,
|
skip_build,
|
||||||
} => commands::deploy::deploy(name, out_dir, &gateway, skip_build, output).await,
|
} => commands::deploy::deploy(name, out_dir, &gateway, skip_build, output).await,
|
||||||
DeployCommands::Init { name, spa, output: out_dir } => {
|
DeployCommands::Init {
|
||||||
commands::deploy::init(name, spa, out_dir, output)
|
name,
|
||||||
}
|
spa,
|
||||||
|
output: out_dir,
|
||||||
|
} => commands::deploy::init(name, spa, out_dir, output),
|
||||||
DeployCommands::List { gateway } => commands::deploy::list(&gateway, output).await,
|
DeployCommands::List { gateway } => commands::deploy::list(&gateway, output).await,
|
||||||
DeployCommands::Delete { name, gateway } => {
|
DeployCommands::Delete { name, gateway } => {
|
||||||
commands::deploy::delete(&name, &gateway, output).await
|
commands::deploy::delete(&name, &gateway, output).await
|
||||||
|
|
|
||||||
|
|
@ -676,7 +676,9 @@ async fn get_blocks(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch blocks by blue score (most recent first)
|
// Fetch blocks by blue score (most recent first)
|
||||||
let start_score = score.blue_score.saturating_sub((params.page.saturating_sub(1) * limit) as u64);
|
let start_score = score
|
||||||
|
.blue_score
|
||||||
|
.saturating_sub((params.page.saturating_sub(1) * limit) as u64);
|
||||||
let blocks_data: Vec<serde_json::Value> = state
|
let blocks_data: Vec<serde_json::Value> = state
|
||||||
.rpc_call("synor_getBlocksByBlueScore", (start_score, true))
|
.rpc_call("synor_getBlocksByBlueScore", (start_score, true))
|
||||||
.await
|
.await
|
||||||
|
|
@ -697,17 +699,28 @@ async fn get_blocks(
|
||||||
parent_hashes: header
|
parent_hashes: header
|
||||||
.get("parents")
|
.get("parents")
|
||||||
.and_then(|p| p.as_array())
|
.and_then(|p| p.as_array())
|
||||||
.map(|a| a.iter().filter_map(|v| v.as_str().map(String::from)).collect())
|
.map(|a| {
|
||||||
|
a.iter()
|
||||||
|
.filter_map(|v| v.as_str().map(String::from))
|
||||||
|
.collect()
|
||||||
|
})
|
||||||
.unwrap_or_default(),
|
.unwrap_or_default(),
|
||||||
timestamp,
|
timestamp,
|
||||||
timestamp_human: format_timestamp(timestamp),
|
timestamp_human: format_timestamp(timestamp),
|
||||||
bits: header.get("bits")?.as_u64()? as u32,
|
bits: header.get("bits")?.as_u64()? as u32,
|
||||||
nonce: header.get("nonce")?.as_u64()?,
|
nonce: header.get("nonce")?.as_u64()?,
|
||||||
daa_score: header.get("blueScore").and_then(|v| v.as_u64()).unwrap_or(0),
|
daa_score: header
|
||||||
blue_score: header.get("blueScore").and_then(|v| v.as_u64()).unwrap_or(0),
|
.get("blueScore")
|
||||||
|
.and_then(|v| v.as_u64())
|
||||||
|
.unwrap_or(0),
|
||||||
|
blue_score: header
|
||||||
|
.get("blueScore")
|
||||||
|
.and_then(|v| v.as_u64())
|
||||||
|
.unwrap_or(0),
|
||||||
blue_work: String::new(),
|
blue_work: String::new(),
|
||||||
difficulty: 0.0,
|
difficulty: 0.0,
|
||||||
transaction_count: b.get("transactions")
|
transaction_count: b
|
||||||
|
.get("transactions")
|
||||||
.and_then(|t| t.as_array())
|
.and_then(|t| t.as_array())
|
||||||
.map(|a| a.len())
|
.map(|a| a.len())
|
||||||
.unwrap_or(0),
|
.unwrap_or(0),
|
||||||
|
|
@ -1102,9 +1115,7 @@ async fn estimate_gas(
|
||||||
};
|
};
|
||||||
|
|
||||||
// Call the node's contract_estimateGas RPC method
|
// Call the node's contract_estimateGas RPC method
|
||||||
let gas_used: u64 = state
|
let gas_used: u64 = state.rpc_call("contract_estimateGas", rpc_request).await?;
|
||||||
.rpc_call("contract_estimateGas", rpc_request)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
// Calculate recommended gas limit with 20% safety margin
|
// Calculate recommended gas limit with 20% safety margin
|
||||||
let gas_limit_recommended = ((gas_used as f64) * 1.2).ceil() as u64;
|
let gas_limit_recommended = ((gas_used as f64) * 1.2).ceil() as u64;
|
||||||
|
|
@ -1494,8 +1505,7 @@ async fn main() -> anyhow::Result<()> {
|
||||||
let app = if let Some(ref static_dir) = config.static_dir {
|
let app = if let Some(ref static_dir) = config.static_dir {
|
||||||
// Serve static files with SPA fallback (index.html for client-side routing)
|
// Serve static files with SPA fallback (index.html for client-side routing)
|
||||||
let index_path = format!("{}/index.html", static_dir);
|
let index_path = format!("{}/index.html", static_dir);
|
||||||
let serve_dir = ServeDir::new(static_dir)
|
let serve_dir = ServeDir::new(static_dir).not_found_service(ServeFile::new(&index_path));
|
||||||
.not_found_service(ServeFile::new(&index_path));
|
|
||||||
|
|
||||||
api_router
|
api_router
|
||||||
.fallback_service(serve_dir)
|
.fallback_service(serve_dir)
|
||||||
|
|
|
||||||
|
|
@ -684,10 +684,12 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn test_all_paths_are_distinct() {
|
fn test_all_paths_are_distinct() {
|
||||||
let config = NodeConfig::for_network("mainnet").unwrap();
|
let config = NodeConfig::for_network("mainnet").unwrap();
|
||||||
let paths = [config.blocks_path(),
|
let paths = [
|
||||||
|
config.blocks_path(),
|
||||||
config.chainstate_path(),
|
config.chainstate_path(),
|
||||||
config.contracts_path(),
|
config.contracts_path(),
|
||||||
config.keys_path()];
|
config.keys_path(),
|
||||||
|
];
|
||||||
|
|
||||||
for i in 0..paths.len() {
|
for i in 0..paths.len() {
|
||||||
for j in (i + 1)..paths.len() {
|
for j in (i + 1)..paths.len() {
|
||||||
|
|
@ -794,9 +796,11 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_with_mining_enabled() {
|
fn test_with_mining_enabled() {
|
||||||
let config = NodeConfig::for_network("mainnet")
|
let config = NodeConfig::for_network("mainnet").unwrap().with_mining(
|
||||||
.unwrap()
|
true,
|
||||||
.with_mining(true, Some("synor:test_address".to_string()), 4);
|
Some("synor:test_address".to_string()),
|
||||||
|
4,
|
||||||
|
);
|
||||||
|
|
||||||
assert!(config.mining.enabled);
|
assert!(config.mining.enabled);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
|
@ -828,9 +832,10 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn test_with_p2p() {
|
fn test_with_p2p() {
|
||||||
let seeds = vec!["seed1.example.com:30303".to_string()];
|
let seeds = vec!["seed1.example.com:30303".to_string()];
|
||||||
let config = NodeConfig::for_network("mainnet")
|
let config =
|
||||||
.unwrap()
|
NodeConfig::for_network("mainnet")
|
||||||
.with_p2p("0.0.0.0", 30303, seeds.clone());
|
.unwrap()
|
||||||
|
.with_p2p("0.0.0.0", 30303, seeds.clone());
|
||||||
|
|
||||||
assert_eq!(config.p2p.listen_addr, "0.0.0.0:30303");
|
assert_eq!(config.p2p.listen_addr, "0.0.0.0:30303");
|
||||||
assert_eq!(config.p2p.seeds, seeds);
|
assert_eq!(config.p2p.seeds, seeds);
|
||||||
|
|
@ -1027,7 +1032,10 @@ mod tests {
|
||||||
let loaded = NodeConfig::load(&path).unwrap();
|
let loaded = NodeConfig::load(&path).unwrap();
|
||||||
|
|
||||||
assert_eq!(loaded.mining.enabled, config.mining.enabled);
|
assert_eq!(loaded.mining.enabled, config.mining.enabled);
|
||||||
assert_eq!(loaded.mining.coinbase_address, config.mining.coinbase_address);
|
assert_eq!(
|
||||||
|
loaded.mining.coinbase_address,
|
||||||
|
config.mining.coinbase_address
|
||||||
|
);
|
||||||
assert_eq!(loaded.mining.threads, config.mining.threads);
|
assert_eq!(loaded.mining.threads, config.mining.threads);
|
||||||
assert_eq!(loaded.storage.cache_size_mb, config.storage.cache_size_mb);
|
assert_eq!(loaded.storage.cache_size_mb, config.storage.cache_size_mb);
|
||||||
assert_eq!(loaded.logging.level, config.logging.level);
|
assert_eq!(loaded.logging.level, config.logging.level);
|
||||||
|
|
|
||||||
|
|
@ -425,11 +425,13 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_node_state_all_variants_are_distinct() {
|
fn test_node_state_all_variants_are_distinct() {
|
||||||
let states = [NodeState::Starting,
|
let states = [
|
||||||
|
NodeState::Starting,
|
||||||
NodeState::Syncing,
|
NodeState::Syncing,
|
||||||
NodeState::Running,
|
NodeState::Running,
|
||||||
NodeState::Stopping,
|
NodeState::Stopping,
|
||||||
NodeState::Stopped];
|
NodeState::Stopped,
|
||||||
|
];
|
||||||
|
|
||||||
for i in 0..states.len() {
|
for i in 0..states.len() {
|
||||||
for j in (i + 1)..states.len() {
|
for j in (i + 1)..states.len() {
|
||||||
|
|
@ -605,7 +607,10 @@ mod tests {
|
||||||
.with_mining(true, Some("synor:test".to_string()), 4);
|
.with_mining(true, Some("synor:test".to_string()), 4);
|
||||||
|
|
||||||
assert!(config.mining.enabled);
|
assert!(config.mining.enabled);
|
||||||
assert_eq!(config.mining.coinbase_address, Some("synor:test".to_string()));
|
assert_eq!(
|
||||||
|
config.mining.coinbase_address,
|
||||||
|
Some("synor:test".to_string())
|
||||||
|
);
|
||||||
assert_eq!(config.mining.threads, 4);
|
assert_eq!(config.mining.threads, 4);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,9 +12,12 @@ use synor_mining::{
|
||||||
MinerCommand, MinerConfig, MinerEvent, MiningResult, MiningStats as CrateMiningStats,
|
MinerCommand, MinerConfig, MinerEvent, MiningResult, MiningStats as CrateMiningStats,
|
||||||
TemplateTransaction,
|
TemplateTransaction,
|
||||||
};
|
};
|
||||||
use synor_types::{Address, Amount, Block, BlockHeader, BlockId, BlueScore, Hash256, Network, Timestamp, Transaction, TxOutput};
|
|
||||||
use synor_types::block::BlockBody;
|
use synor_types::block::BlockBody;
|
||||||
use synor_types::transaction::ScriptPubKey;
|
use synor_types::transaction::ScriptPubKey;
|
||||||
|
use synor_types::{
|
||||||
|
Address, Amount, Block, BlockHeader, BlockId, BlueScore, Hash256, Network, Timestamp,
|
||||||
|
Transaction, TxOutput,
|
||||||
|
};
|
||||||
|
|
||||||
use crate::config::NodeConfig;
|
use crate::config::NodeConfig;
|
||||||
use crate::services::{ConsensusService, MempoolService};
|
use crate::services::{ConsensusService, MempoolService};
|
||||||
|
|
@ -473,10 +476,7 @@ impl MinerService {
|
||||||
extra_data.extend_from_slice(&result.nonce.to_le_bytes());
|
extra_data.extend_from_slice(&result.nonce.to_le_bytes());
|
||||||
extra_data.extend_from_slice(&template.coinbase_data.extra_data);
|
extra_data.extend_from_slice(&template.coinbase_data.extra_data);
|
||||||
|
|
||||||
let coinbase_tx = Transaction::coinbase(
|
let coinbase_tx = Transaction::coinbase(vec![coinbase_output], extra_data);
|
||||||
vec![coinbase_output],
|
|
||||||
extra_data,
|
|
||||||
);
|
|
||||||
|
|
||||||
// Start with coinbase transaction
|
// Start with coinbase transaction
|
||||||
let mut transactions = vec![coinbase_tx];
|
let mut transactions = vec![coinbase_tx];
|
||||||
|
|
@ -522,8 +522,7 @@ impl MinerService {
|
||||||
let block = Block { header, body };
|
let block = Block { header, body };
|
||||||
|
|
||||||
// Serialize with Borsh
|
// Serialize with Borsh
|
||||||
borsh::to_vec(&block)
|
borsh::to_vec(&block).map_err(|e| anyhow::anyhow!("Failed to serialize block: {}", e))
|
||||||
.map_err(|e| anyhow::anyhow!("Failed to serialize block: {}", e))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Submits a mined block (for external submission via RPC).
|
/// Submits a mined block (for external submission via RPC).
|
||||||
|
|
|
||||||
|
|
@ -234,7 +234,10 @@ mod network_partition_tests {
|
||||||
|
|
||||||
// Node 0 should have fewer peers after isolation
|
// Node 0 should have fewer peers after isolation
|
||||||
let isolated_peers = network.nodes[0].network().peer_count().await;
|
let isolated_peers = network.nodes[0].network().peer_count().await;
|
||||||
info!(isolated_peers = isolated_peers, "Node 0 peers after isolation");
|
info!(
|
||||||
|
isolated_peers = isolated_peers,
|
||||||
|
"Node 0 peers after isolation"
|
||||||
|
);
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
isolated_peers < initial_peer_counts[0] || initial_peer_counts[0] == 0,
|
isolated_peers < initial_peer_counts[0] || initial_peer_counts[0] == 0,
|
||||||
|
|
@ -271,7 +274,10 @@ mod network_partition_tests {
|
||||||
|
|
||||||
// After healing, nodes should have peers
|
// After healing, nodes should have peers
|
||||||
let total_peers = network.total_peer_count().await;
|
let total_peers = network.total_peer_count().await;
|
||||||
info!(total_peers = total_peers, "Total peers after partition recovery");
|
info!(
|
||||||
|
total_peers = total_peers,
|
||||||
|
"Total peers after partition recovery"
|
||||||
|
);
|
||||||
|
|
||||||
// Consensus state should converge
|
// Consensus state should converge
|
||||||
let consensus0 = network.nodes[0].consensus();
|
let consensus0 = network.nodes[0].consensus();
|
||||||
|
|
@ -287,7 +293,10 @@ mod network_partition_tests {
|
||||||
);
|
);
|
||||||
|
|
||||||
// Both should have some consensus state
|
// Both should have some consensus state
|
||||||
assert!(vsp0.is_some() || vsp1.is_some(), "At least one node should have VSP");
|
assert!(
|
||||||
|
vsp0.is_some() || vsp1.is_some(),
|
||||||
|
"At least one node should have VSP"
|
||||||
|
);
|
||||||
|
|
||||||
network.stop_all().await.unwrap();
|
network.stop_all().await.unwrap();
|
||||||
}
|
}
|
||||||
|
|
@ -351,10 +360,12 @@ mod network_partition_tests {
|
||||||
|
|
||||||
// Record blue scores from each partition
|
// Record blue scores from each partition
|
||||||
let scores_before: Vec<u64> = futures::future::join_all(
|
let scores_before: Vec<u64> = futures::future::join_all(
|
||||||
network.nodes.iter().map(|n| async {
|
network
|
||||||
n.consensus().current_blue_score().await
|
.nodes
|
||||||
})
|
.iter()
|
||||||
).await;
|
.map(|n| async { n.consensus().current_blue_score().await }),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
info!(scores_before = ?scores_before, "Blue scores before healing");
|
info!(scores_before = ?scores_before, "Blue scores before healing");
|
||||||
|
|
||||||
|
|
@ -368,10 +379,12 @@ mod network_partition_tests {
|
||||||
|
|
||||||
// Blue scores should converge
|
// Blue scores should converge
|
||||||
let scores_after: Vec<u64> = futures::future::join_all(
|
let scores_after: Vec<u64> = futures::future::join_all(
|
||||||
network.nodes.iter().map(|n| async {
|
network
|
||||||
n.consensus().current_blue_score().await
|
.nodes
|
||||||
})
|
.iter()
|
||||||
).await;
|
.map(|n| async { n.consensus().current_blue_score().await }),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
info!(scores_after = ?scores_after, "Blue scores after healing");
|
info!(scores_after = ?scores_after, "Blue scores after healing");
|
||||||
|
|
||||||
|
|
@ -380,7 +393,9 @@ mod network_partition_tests {
|
||||||
assert!(
|
assert!(
|
||||||
after >= before,
|
after >= before,
|
||||||
"Node {} blue score should not decrease: {} -> {}",
|
"Node {} blue score should not decrease: {} -> {}",
|
||||||
i, before, after
|
i,
|
||||||
|
before,
|
||||||
|
after
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -407,10 +422,7 @@ mod double_spend_tests {
|
||||||
let mempool = network.nodes[0].mempool();
|
let mempool = network.nodes[0].mempool();
|
||||||
let initial_size = mempool.size().await;
|
let initial_size = mempool.size().await;
|
||||||
|
|
||||||
info!(
|
info!(initial_mempool_size = initial_size, "Initial mempool state");
|
||||||
initial_mempool_size = initial_size,
|
|
||||||
"Initial mempool state"
|
|
||||||
);
|
|
||||||
|
|
||||||
// In production, we would:
|
// In production, we would:
|
||||||
// 1. Create two transactions spending the same UTXO
|
// 1. Create two transactions spending the same UTXO
|
||||||
|
|
@ -420,7 +432,7 @@ mod double_spend_tests {
|
||||||
// For now, verify mempool API is working
|
// For now, verify mempool API is working
|
||||||
// and handles empty/invalid data gracefully
|
// and handles empty/invalid data gracefully
|
||||||
let _invalid_tx = vec![0u8; 50]; // Invalid transaction bytes (for future use)
|
let _invalid_tx = vec![0u8; 50]; // Invalid transaction bytes (for future use)
|
||||||
// Submitting invalid tx should fail gracefully
|
// Submitting invalid tx should fail gracefully
|
||||||
|
|
||||||
// Mempool should maintain integrity
|
// Mempool should maintain integrity
|
||||||
let final_size = mempool.size().await;
|
let final_size = mempool.size().await;
|
||||||
|
|
@ -653,7 +665,11 @@ mod invalid_block_rejection_tests {
|
||||||
|
|
||||||
// All valid tips should have known parents in the DAG
|
// All valid tips should have known parents in the DAG
|
||||||
for tip in &tips {
|
for tip in &tips {
|
||||||
let has_parents = consensus.get_block_info(tip).await.map(|info| !info.parents.is_empty()).unwrap_or(false);
|
let has_parents = consensus
|
||||||
|
.get_block_info(tip)
|
||||||
|
.await
|
||||||
|
.map(|info| !info.parents.is_empty())
|
||||||
|
.unwrap_or(false);
|
||||||
info!(
|
info!(
|
||||||
block = hex::encode(&tip[..8]),
|
block = hex::encode(&tip[..8]),
|
||||||
has_parents = has_parents,
|
has_parents = has_parents,
|
||||||
|
|
@ -687,16 +703,22 @@ mod sybil_attack_tests {
|
||||||
|
|
||||||
// Track blue scores - honest nodes should maintain correct view
|
// Track blue scores - honest nodes should maintain correct view
|
||||||
let honest_scores: Vec<u64> = futures::future::join_all(
|
let honest_scores: Vec<u64> = futures::future::join_all(
|
||||||
network.nodes.iter().take(3).map(|n| async {
|
network
|
||||||
n.consensus().current_blue_score().await
|
.nodes
|
||||||
})
|
.iter()
|
||||||
).await;
|
.take(3)
|
||||||
|
.map(|n| async { n.consensus().current_blue_score().await }),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
let sybil_scores: Vec<u64> = futures::future::join_all(
|
let sybil_scores: Vec<u64> = futures::future::join_all(
|
||||||
network.nodes.iter().skip(3).map(|n| async {
|
network
|
||||||
n.consensus().current_blue_score().await
|
.nodes
|
||||||
})
|
.iter()
|
||||||
).await;
|
.skip(3)
|
||||||
|
.map(|n| async { n.consensus().current_blue_score().await }),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
honest_scores = ?honest_scores,
|
honest_scores = ?honest_scores,
|
||||||
|
|
@ -805,7 +827,10 @@ mod eclipse_attack_tests {
|
||||||
sleep(Duration::from_secs(1)).await;
|
sleep(Duration::from_secs(1)).await;
|
||||||
|
|
||||||
let after_eclipse_peers = victim_network.peer_count().await;
|
let after_eclipse_peers = victim_network.peer_count().await;
|
||||||
info!(after_eclipse_peers = after_eclipse_peers, "Peers after eclipse attempt");
|
info!(
|
||||||
|
after_eclipse_peers = after_eclipse_peers,
|
||||||
|
"Peers after eclipse attempt"
|
||||||
|
);
|
||||||
|
|
||||||
// In a real implementation, the node would:
|
// In a real implementation, the node would:
|
||||||
// 1. Detect low peer diversity
|
// 1. Detect low peer diversity
|
||||||
|
|
@ -863,7 +888,10 @@ mod eclipse_attack_tests {
|
||||||
sleep(Duration::from_secs(1)).await;
|
sleep(Duration::from_secs(1)).await;
|
||||||
|
|
||||||
let eclipsed_peers = network.nodes[0].network().peer_count().await;
|
let eclipsed_peers = network.nodes[0].network().peer_count().await;
|
||||||
info!(eclipsed_peers = eclipsed_peers, "Node 0 peers during eclipse");
|
info!(
|
||||||
|
eclipsed_peers = eclipsed_peers,
|
||||||
|
"Node 0 peers during eclipse"
|
||||||
|
);
|
||||||
|
|
||||||
// Manually reconnect (simulating recovery mechanism)
|
// Manually reconnect (simulating recovery mechanism)
|
||||||
network.connect_nodes(0, 1).await.unwrap();
|
network.connect_nodes(0, 1).await.unwrap();
|
||||||
|
|
@ -871,7 +899,10 @@ mod eclipse_attack_tests {
|
||||||
sleep(Duration::from_secs(2)).await;
|
sleep(Duration::from_secs(2)).await;
|
||||||
|
|
||||||
let recovered_peers = network.nodes[0].network().peer_count().await;
|
let recovered_peers = network.nodes[0].network().peer_count().await;
|
||||||
info!(recovered_peers = recovered_peers, "Node 0 peers after recovery");
|
info!(
|
||||||
|
recovered_peers = recovered_peers,
|
||||||
|
"Node 0 peers after recovery"
|
||||||
|
);
|
||||||
|
|
||||||
// Should have reconnected
|
// Should have reconnected
|
||||||
assert!(
|
assert!(
|
||||||
|
|
@ -1038,10 +1069,12 @@ mod dag_reorg_tests {
|
||||||
|
|
||||||
// Record divergent states
|
// Record divergent states
|
||||||
let states_before: Vec<u64> = futures::future::join_all(
|
let states_before: Vec<u64> = futures::future::join_all(
|
||||||
network.nodes.iter().map(|n| async {
|
network
|
||||||
n.consensus().current_blue_score().await
|
.nodes
|
||||||
})
|
.iter()
|
||||||
).await;
|
.map(|n| async { n.consensus().current_blue_score().await }),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
info!(states_before = ?states_before, "States before reconnection");
|
info!(states_before = ?states_before, "States before reconnection");
|
||||||
|
|
||||||
|
|
@ -1052,10 +1085,12 @@ mod dag_reorg_tests {
|
||||||
|
|
||||||
// Get converged states
|
// Get converged states
|
||||||
let states_after: Vec<u64> = futures::future::join_all(
|
let states_after: Vec<u64> = futures::future::join_all(
|
||||||
network.nodes.iter().map(|n| async {
|
network
|
||||||
n.consensus().current_blue_score().await
|
.nodes
|
||||||
})
|
.iter()
|
||||||
).await;
|
.map(|n| async { n.consensus().current_blue_score().await }),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
info!(states_after = ?states_after, "States after reconnection");
|
info!(states_after = ?states_after, "States after reconnection");
|
||||||
|
|
||||||
|
|
@ -1064,7 +1099,9 @@ mod dag_reorg_tests {
|
||||||
assert!(
|
assert!(
|
||||||
after >= before,
|
after >= before,
|
||||||
"Node {} blue score regression: {} -> {}",
|
"Node {} blue score regression: {} -> {}",
|
||||||
i, before, after
|
i,
|
||||||
|
before,
|
||||||
|
after
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1192,10 +1229,12 @@ mod parallel_blocks_tests {
|
||||||
|
|
||||||
// Collect blue scores from all nodes
|
// Collect blue scores from all nodes
|
||||||
let blue_scores: Vec<u64> = futures::future::join_all(
|
let blue_scores: Vec<u64> = futures::future::join_all(
|
||||||
network.nodes.iter().map(|n| async {
|
network
|
||||||
n.consensus().current_blue_score().await
|
.nodes
|
||||||
})
|
.iter()
|
||||||
).await;
|
.map(|n| async { n.consensus().current_blue_score().await }),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
info!(blue_scores = ?blue_scores, "Blue scores across nodes");
|
info!(blue_scores = ?blue_scores, "Blue scores across nodes");
|
||||||
|
|
||||||
|
|
@ -1206,7 +1245,8 @@ mod parallel_blocks_tests {
|
||||||
assert!(
|
assert!(
|
||||||
max_score - min_score <= 2,
|
max_score - min_score <= 2,
|
||||||
"Blue scores should be consistent: {} - {} > 2",
|
"Blue scores should be consistent: {} - {} > 2",
|
||||||
max_score, min_score
|
max_score,
|
||||||
|
min_score
|
||||||
);
|
);
|
||||||
|
|
||||||
network.stop_all().await.unwrap();
|
network.stop_all().await.unwrap();
|
||||||
|
|
@ -1264,10 +1304,12 @@ mod parallel_blocks_tests {
|
||||||
|
|
||||||
// Get selected chains from all nodes
|
// Get selected chains from all nodes
|
||||||
let chains: Vec<Vec<[u8; 32]>> = futures::future::join_all(
|
let chains: Vec<Vec<[u8; 32]>> = futures::future::join_all(
|
||||||
network.nodes.iter().map(|n| async {
|
network
|
||||||
n.consensus().get_selected_chain(10).await
|
.nodes
|
||||||
})
|
.iter()
|
||||||
).await;
|
.map(|n| async { n.consensus().get_selected_chain(10).await }),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
chain_lengths = ?chains.iter().map(|c| c.len()).collect::<Vec<_>>(),
|
chain_lengths = ?chains.iter().map(|c| c.len()).collect::<Vec<_>>(),
|
||||||
|
|
@ -1276,7 +1318,8 @@ mod parallel_blocks_tests {
|
||||||
|
|
||||||
// All nodes should have the same selected chain (after sync)
|
// All nodes should have the same selected chain (after sync)
|
||||||
// Check that genesis (first block) matches
|
// Check that genesis (first block) matches
|
||||||
let genesis_blocks: Vec<_> = chains.iter()
|
let genesis_blocks: Vec<_> = chains
|
||||||
|
.iter()
|
||||||
.filter(|c| !c.is_empty())
|
.filter(|c| !c.is_empty())
|
||||||
.map(|c| c[0])
|
.map(|c| c[0])
|
||||||
.collect();
|
.collect();
|
||||||
|
|
@ -1353,10 +1396,13 @@ mod bft_threshold_tests {
|
||||||
|
|
||||||
// Honest nodes (0, 1, 2) should maintain consensus
|
// Honest nodes (0, 1, 2) should maintain consensus
|
||||||
let honest_scores: Vec<u64> = futures::future::join_all(
|
let honest_scores: Vec<u64> = futures::future::join_all(
|
||||||
network.nodes.iter().take(3).map(|n| async {
|
network
|
||||||
n.consensus().current_blue_score().await
|
.nodes
|
||||||
})
|
.iter()
|
||||||
).await;
|
.take(3)
|
||||||
|
.map(|n| async { n.consensus().current_blue_score().await }),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
info!(honest_scores = ?honest_scores, "Honest node blue scores");
|
info!(honest_scores = ?honest_scores, "Honest node blue scores");
|
||||||
|
|
||||||
|
|
@ -1399,10 +1445,7 @@ mod bft_threshold_tests {
|
||||||
|
|
||||||
// Blue score should not decrease
|
// Blue score should not decrease
|
||||||
let final_blue = network.nodes[0].consensus().current_blue_score().await;
|
let final_blue = network.nodes[0].consensus().current_blue_score().await;
|
||||||
assert!(
|
assert!(final_blue >= initial_blue, "Blue score should not decrease");
|
||||||
final_blue >= initial_blue,
|
|
||||||
"Blue score should not decrease"
|
|
||||||
);
|
|
||||||
|
|
||||||
// Stop remaining nodes
|
// Stop remaining nodes
|
||||||
for node in network.nodes.iter().take(3) {
|
for node in network.nodes.iter().take(3) {
|
||||||
|
|
@ -1615,10 +1658,12 @@ mod integration_tests {
|
||||||
|
|
||||||
// Record initial state
|
// Record initial state
|
||||||
let initial_scores: Vec<u64> = futures::future::join_all(
|
let initial_scores: Vec<u64> = futures::future::join_all(
|
||||||
network.nodes.iter().map(|n| async {
|
network
|
||||||
n.consensus().current_blue_score().await
|
.nodes
|
||||||
})
|
.iter()
|
||||||
).await;
|
.map(|n| async { n.consensus().current_blue_score().await }),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
info!(initial_scores = ?initial_scores, "Initial blue scores");
|
info!(initial_scores = ?initial_scores, "Initial blue scores");
|
||||||
|
|
||||||
info!("Phase 2: Simulate 2 Byzantine nodes (partition)");
|
info!("Phase 2: Simulate 2 Byzantine nodes (partition)");
|
||||||
|
|
@ -1640,18 +1685,24 @@ mod integration_tests {
|
||||||
|
|
||||||
info!("Phase 4: Verify convergence");
|
info!("Phase 4: Verify convergence");
|
||||||
let final_scores: Vec<u64> = futures::future::join_all(
|
let final_scores: Vec<u64> = futures::future::join_all(
|
||||||
network.nodes.iter().map(|n| async {
|
network
|
||||||
n.consensus().current_blue_score().await
|
.nodes
|
||||||
})
|
.iter()
|
||||||
).await;
|
.map(|n| async { n.consensus().current_blue_score().await }),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
info!(final_scores = ?final_scores, "Final blue scores");
|
info!(final_scores = ?final_scores, "Final blue scores");
|
||||||
|
|
||||||
// All nodes should have non-decreasing blue scores
|
// All nodes should have non-decreasing blue scores
|
||||||
for (i, (&initial, &final_score)) in initial_scores.iter().zip(final_scores.iter()).enumerate() {
|
for (i, (&initial, &final_score)) in
|
||||||
|
initial_scores.iter().zip(final_scores.iter()).enumerate()
|
||||||
|
{
|
||||||
assert!(
|
assert!(
|
||||||
final_score >= initial,
|
final_score >= initial,
|
||||||
"Node {} score regression: {} -> {}",
|
"Node {} score regression: {} -> {}",
|
||||||
i, initial, final_score
|
i,
|
||||||
|
initial,
|
||||||
|
final_score
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,8 +17,8 @@
|
||||||
//! 3. Vault contract verifies proof and unlocks original tokens
|
//! 3. Vault contract verifies proof and unlocks original tokens
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
AssetId, Bridge, BridgeAddress, BridgeError, BridgeResult, BridgeTransfer, ChainType, TransferId, TransferManager, TransferStatus, VaultManager,
|
AssetId, Bridge, BridgeAddress, BridgeError, BridgeResult, BridgeTransfer, ChainType,
|
||||||
ETH_MIN_CONFIRMATIONS,
|
TransferId, TransferManager, TransferStatus, VaultManager, ETH_MIN_CONFIRMATIONS,
|
||||||
};
|
};
|
||||||
use alloy_primitives::{Address, B256, U256};
|
use alloy_primitives::{Address, B256, U256};
|
||||||
use alloy_sol_types::sol;
|
use alloy_sol_types::sol;
|
||||||
|
|
@ -281,9 +281,9 @@ impl EthereumBridge {
|
||||||
// Check for replay
|
// Check for replay
|
||||||
let event_hash = event.hash();
|
let event_hash = event.hash();
|
||||||
if self.processed_events.read().contains_key(&event_hash) {
|
if self.processed_events.read().contains_key(&event_hash) {
|
||||||
return Err(BridgeError::TransferAlreadyExists(
|
return Err(BridgeError::TransferAlreadyExists(hex::encode(
|
||||||
hex::encode(event_hash.as_slice()),
|
event_hash.as_slice(),
|
||||||
));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify token is supported
|
// Verify token is supported
|
||||||
|
|
@ -393,18 +393,15 @@ impl EthereumBridge {
|
||||||
// Collect matching transfer IDs first
|
// Collect matching transfer IDs first
|
||||||
let matching_transfer_id = {
|
let matching_transfer_id = {
|
||||||
let transfers = self.transfers.read();
|
let transfers = self.transfers.read();
|
||||||
transfers
|
transfers.pending_transfers().iter().find_map(|transfer| {
|
||||||
.pending_transfers()
|
transfer.source_tx_hash.as_ref().and_then(|tx_hash| {
|
||||||
.iter()
|
if tx_hash.as_slice() == event_hash.as_slice() {
|
||||||
.find_map(|transfer| {
|
Some(transfer.id.clone())
|
||||||
transfer.source_tx_hash.as_ref().and_then(|tx_hash| {
|
} else {
|
||||||
if tx_hash.as_slice() == event_hash.as_slice() {
|
None
|
||||||
Some(transfer.id.clone())
|
}
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
|
})
|
||||||
};
|
};
|
||||||
|
|
||||||
// Now update the transfer if found
|
// Now update the transfer if found
|
||||||
|
|
@ -457,7 +454,9 @@ impl EthereumBridge {
|
||||||
.map_err(|e| BridgeError::InvalidAddress(e.to_string()))?;
|
.map_err(|e| BridgeError::InvalidAddress(e.to_string()))?;
|
||||||
|
|
||||||
if bytes.len() != 20 {
|
if bytes.len() != 20 {
|
||||||
return Err(BridgeError::InvalidAddress("invalid address length".to_string()));
|
return Err(BridgeError::InvalidAddress(
|
||||||
|
"invalid address length".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
Address::from_slice(&bytes)
|
Address::from_slice(&bytes)
|
||||||
};
|
};
|
||||||
|
|
@ -801,7 +800,10 @@ mod tests {
|
||||||
wrapped.mint(1000);
|
wrapped.mint(1000);
|
||||||
|
|
||||||
let result = wrapped.burn(1500);
|
let result = wrapped.burn(1500);
|
||||||
assert!(matches!(result, Err(BridgeError::InsufficientBalance { .. })));
|
assert!(matches!(
|
||||||
|
result,
|
||||||
|
Err(BridgeError::InsufficientBalance { .. })
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -896,7 +898,9 @@ mod tests {
|
||||||
let current_time = 1700000000;
|
let current_time = 1700000000;
|
||||||
|
|
||||||
let event = create_lock_event(0);
|
let event = create_lock_event(0);
|
||||||
bridge.process_lock_event(event.clone(), current_time).unwrap();
|
bridge
|
||||||
|
.process_lock_event(event.clone(), current_time)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let result = bridge.process_lock_event(event, current_time + 100);
|
let result = bridge.process_lock_event(event, current_time + 100);
|
||||||
assert!(matches!(result, Err(BridgeError::TransferAlreadyExists(_))));
|
assert!(matches!(result, Err(BridgeError::TransferAlreadyExists(_))));
|
||||||
|
|
@ -949,8 +953,12 @@ mod tests {
|
||||||
let event_hash = B256::from([0x11; 32]);
|
let event_hash = B256::from([0x11; 32]);
|
||||||
let unauthorized_relayer = Address::from([0x99; 20]);
|
let unauthorized_relayer = Address::from([0x99; 20]);
|
||||||
|
|
||||||
let result = bridge.submit_relayer_signature(event_hash, unauthorized_relayer, vec![0x00; 65]);
|
let result =
|
||||||
assert!(matches!(result, Err(BridgeError::SignatureVerificationFailed(_))));
|
bridge.submit_relayer_signature(event_hash, unauthorized_relayer, vec![0x00; 65]);
|
||||||
|
assert!(matches!(
|
||||||
|
result,
|
||||||
|
Err(BridgeError::SignatureVerificationFailed(_))
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -964,7 +972,9 @@ mod tests {
|
||||||
});
|
});
|
||||||
|
|
||||||
let event_hash = B256::from([0x11; 32]);
|
let event_hash = B256::from([0x11; 32]);
|
||||||
let result = bridge.submit_relayer_signature(event_hash, relayer, vec![0x00; 65]).unwrap();
|
let result = bridge
|
||||||
|
.submit_relayer_signature(event_hash, relayer, vec![0x00; 65])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert!(result);
|
assert!(result);
|
||||||
}
|
}
|
||||||
|
|
@ -983,7 +993,9 @@ mod tests {
|
||||||
.update_confirmations(&transfer_id, 12, current_time + 100)
|
.update_confirmations(&transfer_id, 12, current_time + 100)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
bridge.mint_wrapped_tokens(&transfer_id, current_time + 200).unwrap();
|
bridge
|
||||||
|
.mint_wrapped_tokens(&transfer_id, current_time + 200)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let wrapped = bridge.get_wrapped_token(Address::ZERO).unwrap();
|
let wrapped = bridge.get_wrapped_token(Address::ZERO).unwrap();
|
||||||
assert_eq!(wrapped.total_supply, 1000);
|
assert_eq!(wrapped.total_supply, 1000);
|
||||||
|
|
@ -1022,13 +1034,7 @@ mod tests {
|
||||||
|
|
||||||
let asset = AssetId::wrapped(&AssetId::eth());
|
let asset = AssetId::wrapped(&AssetId::eth());
|
||||||
let transfer_id = bridge
|
let transfer_id = bridge
|
||||||
.initiate_burn(
|
.initiate_burn(asset, 1000, test_recipient(), test_sender(), current_time)
|
||||||
asset,
|
|
||||||
1000,
|
|
||||||
test_recipient(),
|
|
||||||
test_sender(),
|
|
||||||
current_time,
|
|
||||||
)
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let transfers = bridge.transfers.read();
|
let transfers = bridge.transfers.read();
|
||||||
|
|
@ -1053,15 +1059,13 @@ mod tests {
|
||||||
drop(wrapped_tokens);
|
drop(wrapped_tokens);
|
||||||
|
|
||||||
let asset = AssetId::wrapped(&AssetId::eth());
|
let asset = AssetId::wrapped(&AssetId::eth());
|
||||||
let result = bridge.initiate_burn(
|
let result =
|
||||||
asset,
|
bridge.initiate_burn(asset, 1000, test_recipient(), test_sender(), current_time);
|
||||||
1000,
|
|
||||||
test_recipient(),
|
|
||||||
test_sender(),
|
|
||||||
current_time,
|
|
||||||
);
|
|
||||||
|
|
||||||
assert!(matches!(result, Err(BridgeError::InsufficientBalance { .. })));
|
assert!(matches!(
|
||||||
|
result,
|
||||||
|
Err(BridgeError::InsufficientBalance { .. })
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -1069,13 +1073,7 @@ mod tests {
|
||||||
let bridge = EthereumBridge::new(EthereumBridgeConfig::default());
|
let bridge = EthereumBridge::new(EthereumBridgeConfig::default());
|
||||||
|
|
||||||
let asset = AssetId::wrapped(&AssetId::eth());
|
let asset = AssetId::wrapped(&AssetId::eth());
|
||||||
let result = bridge.initiate_burn(
|
let result = bridge.initiate_burn(asset, 1000, test_recipient(), test_sender(), 0);
|
||||||
asset,
|
|
||||||
1000,
|
|
||||||
test_recipient(),
|
|
||||||
test_sender(),
|
|
||||||
0,
|
|
||||||
);
|
|
||||||
|
|
||||||
assert!(matches!(result, Err(BridgeError::AssetNotSupported(_))));
|
assert!(matches!(result, Err(BridgeError::AssetNotSupported(_))));
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -79,7 +79,9 @@ pub const ETH_MIN_CONFIRMATIONS: u64 = 12;
|
||||||
pub const BTC_MIN_CONFIRMATIONS: u64 = 6;
|
pub const BTC_MIN_CONFIRMATIONS: u64 = 6;
|
||||||
|
|
||||||
/// Bridge chain identifier
|
/// Bridge chain identifier
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize)]
|
#[derive(
|
||||||
|
Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
|
||||||
|
)]
|
||||||
pub enum ChainType {
|
pub enum ChainType {
|
||||||
/// Synor native chain
|
/// Synor native chain
|
||||||
Synor,
|
Synor,
|
||||||
|
|
@ -128,7 +130,9 @@ impl fmt::Display for ChainType {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Asset identifier across chains
|
/// Asset identifier across chains
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize)]
|
#[derive(
|
||||||
|
Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
|
||||||
|
)]
|
||||||
pub struct AssetId {
|
pub struct AssetId {
|
||||||
/// Chain where the asset originates
|
/// Chain where the asset originates
|
||||||
pub chain: ChainType,
|
pub chain: ChainType,
|
||||||
|
|
@ -199,7 +203,9 @@ impl fmt::Display for AssetId {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Bridge address (unified format for cross-chain addresses)
|
/// Bridge address (unified format for cross-chain addresses)
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize)]
|
#[derive(
|
||||||
|
Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
|
||||||
|
)]
|
||||||
pub struct BridgeAddress {
|
pub struct BridgeAddress {
|
||||||
/// Chain type
|
/// Chain type
|
||||||
pub chain: ChainType,
|
pub chain: ChainType,
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,9 @@ use std::collections::HashMap;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
/// Unique transfer identifier
|
/// Unique transfer identifier
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize)]
|
#[derive(
|
||||||
|
Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
|
||||||
|
)]
|
||||||
pub struct TransferId(pub String);
|
pub struct TransferId(pub String);
|
||||||
|
|
||||||
impl TransferId {
|
impl TransferId {
|
||||||
|
|
@ -48,7 +50,9 @@ impl fmt::Display for TransferId {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Transfer direction
|
/// Transfer direction
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, BorshSerialize, BorshDeserialize)]
|
#[derive(
|
||||||
|
Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
|
||||||
|
)]
|
||||||
pub enum TransferDirection {
|
pub enum TransferDirection {
|
||||||
/// From external chain to Synor (Lock → Mint)
|
/// From external chain to Synor (Lock → Mint)
|
||||||
Inbound,
|
Inbound,
|
||||||
|
|
@ -66,7 +70,9 @@ impl fmt::Display for TransferDirection {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Transfer status
|
/// Transfer status
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, BorshSerialize, BorshDeserialize)]
|
#[derive(
|
||||||
|
Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
|
||||||
|
)]
|
||||||
pub enum TransferStatus {
|
pub enum TransferStatus {
|
||||||
/// Transfer initiated, awaiting lock confirmation
|
/// Transfer initiated, awaiting lock confirmation
|
||||||
Pending,
|
Pending,
|
||||||
|
|
@ -1030,7 +1036,10 @@ mod tests {
|
||||||
transfer.fail("Proof verification failed", current_time + 50);
|
transfer.fail("Proof verification failed", current_time + 50);
|
||||||
|
|
||||||
assert_eq!(transfer.status, TransferStatus::Failed);
|
assert_eq!(transfer.status, TransferStatus::Failed);
|
||||||
assert_eq!(transfer.error, Some("Proof verification failed".to_string()));
|
assert_eq!(
|
||||||
|
transfer.error,
|
||||||
|
Some("Proof verification failed".to_string())
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -1331,8 +1340,12 @@ mod tests {
|
||||||
|
|
||||||
assert_eq!(manager.pending_transfers().len(), 1);
|
assert_eq!(manager.pending_transfers().len(), 1);
|
||||||
|
|
||||||
manager.confirm_lock(&id, vec![0x11; 32], 100, current_time + 10).unwrap();
|
manager
|
||||||
manager.update_confirmations(&id, 12, current_time + 120).unwrap();
|
.confirm_lock(&id, vec![0x11; 32], 100, current_time + 10)
|
||||||
|
.unwrap();
|
||||||
|
manager
|
||||||
|
.update_confirmations(&id, 12, current_time + 120)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(manager.pending_transfers().len(), 0);
|
assert_eq!(manager.pending_transfers().len(), 0);
|
||||||
}
|
}
|
||||||
|
|
@ -1356,8 +1369,12 @@ mod tests {
|
||||||
|
|
||||||
assert_eq!(manager.ready_for_confirmation().len(), 0);
|
assert_eq!(manager.ready_for_confirmation().len(), 0);
|
||||||
|
|
||||||
manager.confirm_lock(&id, vec![0x11; 32], 100, current_time + 10).unwrap();
|
manager
|
||||||
manager.update_confirmations(&id, 12, current_time + 120).unwrap();
|
.confirm_lock(&id, vec![0x11; 32], 100, current_time + 10)
|
||||||
|
.unwrap();
|
||||||
|
manager
|
||||||
|
.update_confirmations(&id, 12, current_time + 120)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(manager.ready_for_confirmation().len(), 1);
|
assert_eq!(manager.ready_for_confirmation().len(), 1);
|
||||||
}
|
}
|
||||||
|
|
@ -1410,7 +1427,9 @@ mod tests {
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
manager.fail_transfer(&id, "Verification failed", current_time + 50).unwrap();
|
manager
|
||||||
|
.fail_transfer(&id, "Verification failed", current_time + 50)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let transfer = manager.get(&id).unwrap();
|
let transfer = manager.get(&id).unwrap();
|
||||||
assert_eq!(transfer.status, TransferStatus::Failed);
|
assert_eq!(transfer.status, TransferStatus::Failed);
|
||||||
|
|
@ -1452,9 +1471,15 @@ mod tests {
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
manager.confirm_lock(&id1, vec![0x11; 32], 100, current_time).unwrap();
|
manager
|
||||||
manager.update_confirmations(&id1, 12, current_time).unwrap();
|
.confirm_lock(&id1, vec![0x11; 32], 100, current_time)
|
||||||
manager.confirm_mint(&id1, vec![0x22; 32], current_time).unwrap();
|
.unwrap();
|
||||||
|
manager
|
||||||
|
.update_confirmations(&id1, 12, current_time)
|
||||||
|
.unwrap();
|
||||||
|
manager
|
||||||
|
.confirm_mint(&id1, vec![0x22; 32], current_time)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let stats = manager.stats();
|
let stats = manager.stats();
|
||||||
assert_eq!(stats.total_count, 2);
|
assert_eq!(stats.total_count, 2);
|
||||||
|
|
@ -1484,19 +1509,27 @@ mod tests {
|
||||||
let transfer = manager.get(&id).unwrap();
|
let transfer = manager.get(&id).unwrap();
|
||||||
assert_eq!(transfer.status, TransferStatus::Pending);
|
assert_eq!(transfer.status, TransferStatus::Pending);
|
||||||
|
|
||||||
manager.confirm_lock(&id, vec![0x11; 32], 100, current_time + 60).unwrap();
|
manager
|
||||||
|
.confirm_lock(&id, vec![0x11; 32], 100, current_time + 60)
|
||||||
|
.unwrap();
|
||||||
let transfer = manager.get(&id).unwrap();
|
let transfer = manager.get(&id).unwrap();
|
||||||
assert_eq!(transfer.status, TransferStatus::Locked);
|
assert_eq!(transfer.status, TransferStatus::Locked);
|
||||||
|
|
||||||
manager.update_confirmations(&id, 6, current_time + 120).unwrap();
|
manager
|
||||||
|
.update_confirmations(&id, 6, current_time + 120)
|
||||||
|
.unwrap();
|
||||||
let transfer = manager.get(&id).unwrap();
|
let transfer = manager.get(&id).unwrap();
|
||||||
assert_eq!(transfer.status, TransferStatus::Locked);
|
assert_eq!(transfer.status, TransferStatus::Locked);
|
||||||
|
|
||||||
manager.update_confirmations(&id, 12, current_time + 180).unwrap();
|
manager
|
||||||
|
.update_confirmations(&id, 12, current_time + 180)
|
||||||
|
.unwrap();
|
||||||
let transfer = manager.get(&id).unwrap();
|
let transfer = manager.get(&id).unwrap();
|
||||||
assert_eq!(transfer.status, TransferStatus::Confirmed);
|
assert_eq!(transfer.status, TransferStatus::Confirmed);
|
||||||
|
|
||||||
manager.confirm_mint(&id, vec![0x22; 32], current_time + 240).unwrap();
|
manager
|
||||||
|
.confirm_mint(&id, vec![0x22; 32], current_time + 240)
|
||||||
|
.unwrap();
|
||||||
let transfer = manager.get(&id).unwrap();
|
let transfer = manager.get(&id).unwrap();
|
||||||
assert_eq!(transfer.status, TransferStatus::Completed);
|
assert_eq!(transfer.status, TransferStatus::Completed);
|
||||||
}
|
}
|
||||||
|
|
@ -1521,15 +1554,21 @@ mod tests {
|
||||||
let transfer = manager.get(&id).unwrap();
|
let transfer = manager.get(&id).unwrap();
|
||||||
assert_eq!(transfer.status, TransferStatus::Pending);
|
assert_eq!(transfer.status, TransferStatus::Pending);
|
||||||
|
|
||||||
manager.confirm_lock(&id, vec![0x11; 32], 100, current_time + 60).unwrap();
|
manager
|
||||||
|
.confirm_lock(&id, vec![0x11; 32], 100, current_time + 60)
|
||||||
|
.unwrap();
|
||||||
let transfer = manager.get(&id).unwrap();
|
let transfer = manager.get(&id).unwrap();
|
||||||
assert_eq!(transfer.status, TransferStatus::Locked);
|
assert_eq!(transfer.status, TransferStatus::Locked);
|
||||||
|
|
||||||
manager.update_confirmations(&id, 6, current_time + 120).unwrap();
|
manager
|
||||||
|
.update_confirmations(&id, 6, current_time + 120)
|
||||||
|
.unwrap();
|
||||||
let transfer = manager.get(&id).unwrap();
|
let transfer = manager.get(&id).unwrap();
|
||||||
assert_eq!(transfer.status, TransferStatus::Confirmed);
|
assert_eq!(transfer.status, TransferStatus::Confirmed);
|
||||||
|
|
||||||
manager.confirm_unlock(&id, vec![0x33; 32], current_time + 180).unwrap();
|
manager
|
||||||
|
.confirm_unlock(&id, vec![0x33; 32], current_time + 180)
|
||||||
|
.unwrap();
|
||||||
let transfer = manager.get(&id).unwrap();
|
let transfer = manager.get(&id).unwrap();
|
||||||
assert_eq!(transfer.status, TransferStatus::Completed);
|
assert_eq!(transfer.status, TransferStatus::Completed);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,9 @@ use std::collections::HashMap;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
/// Unique vault identifier
|
/// Unique vault identifier
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize)]
|
#[derive(
|
||||||
|
Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
|
||||||
|
)]
|
||||||
pub struct VaultId(pub String);
|
pub struct VaultId(pub String);
|
||||||
|
|
||||||
impl VaultId {
|
impl VaultId {
|
||||||
|
|
@ -198,13 +200,7 @@ impl Vault {
|
||||||
return Err(BridgeError::TransferAlreadyExists(lock_id));
|
return Err(BridgeError::TransferAlreadyExists(lock_id));
|
||||||
}
|
}
|
||||||
|
|
||||||
let locked = LockedAsset::new(
|
let locked = LockedAsset::new(self.asset.clone(), amount, owner, recipient, current_time);
|
||||||
self.asset.clone(),
|
|
||||||
amount,
|
|
||||||
owner,
|
|
||||||
recipient,
|
|
||||||
current_time,
|
|
||||||
);
|
|
||||||
|
|
||||||
self.locked_assets.insert(lock_id, locked);
|
self.locked_assets.insert(lock_id, locked);
|
||||||
self.total_locked += amount;
|
self.total_locked += amount;
|
||||||
|
|
@ -283,7 +279,10 @@ impl Vault {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get expired locked assets
|
/// Get expired locked assets
|
||||||
pub fn expired_locked(&self, current_time: u64) -> impl Iterator<Item = (&String, &LockedAsset)> {
|
pub fn expired_locked(
|
||||||
|
&self,
|
||||||
|
current_time: u64,
|
||||||
|
) -> impl Iterator<Item = (&String, &LockedAsset)> {
|
||||||
self.locked_assets
|
self.locked_assets
|
||||||
.iter()
|
.iter()
|
||||||
.filter(move |(_, l)| !l.released && l.is_expired(current_time))
|
.filter(move |(_, l)| !l.released && l.is_expired(current_time))
|
||||||
|
|
@ -511,14 +510,8 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_locked_asset_expiry() {
|
fn test_locked_asset_expiry() {
|
||||||
let locked = LockedAsset::new(
|
let locked = LockedAsset::new(AssetId::eth(), 1000, test_owner(), test_recipient(), 1000)
|
||||||
AssetId::eth(),
|
.with_expiry(2000);
|
||||||
1000,
|
|
||||||
test_owner(),
|
|
||||||
test_recipient(),
|
|
||||||
1000,
|
|
||||||
)
|
|
||||||
.with_expiry(2000);
|
|
||||||
|
|
||||||
assert!(!locked.is_expired(1500));
|
assert!(!locked.is_expired(1500));
|
||||||
assert!(locked.is_expired(2000));
|
assert!(locked.is_expired(2000));
|
||||||
|
|
@ -624,11 +617,7 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_lock_unlock() {
|
fn test_lock_unlock() {
|
||||||
let mut vault = Vault::new(
|
let mut vault = Vault::new(VaultId::new("test"), ChainType::Ethereum, AssetId::eth());
|
||||||
VaultId::new("test"),
|
|
||||||
ChainType::Ethereum,
|
|
||||||
AssetId::eth(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let current_time = 1700000000;
|
let current_time = 1700000000;
|
||||||
|
|
||||||
|
|
@ -654,20 +643,28 @@ mod tests {
|
||||||
);
|
);
|
||||||
|
|
||||||
let current_time = 1700000000;
|
let current_time = 1700000000;
|
||||||
vault.lock("lock-1", 1000, test_owner(), test_recipient(), current_time).unwrap();
|
vault
|
||||||
vault.lock("lock-2", 2000, test_owner(), test_recipient(), current_time).unwrap();
|
.lock("lock-1", 1000, test_owner(), test_recipient(), current_time)
|
||||||
vault.lock("lock-3", 500, test_owner_alt(), test_recipient(), current_time).unwrap();
|
.unwrap();
|
||||||
|
vault
|
||||||
|
.lock("lock-2", 2000, test_owner(), test_recipient(), current_time)
|
||||||
|
.unwrap();
|
||||||
|
vault
|
||||||
|
.lock(
|
||||||
|
"lock-3",
|
||||||
|
500,
|
||||||
|
test_owner_alt(),
|
||||||
|
test_recipient(),
|
||||||
|
current_time,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(vault.total_locked, 3500);
|
assert_eq!(vault.total_locked, 3500);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_duplicate_lock() {
|
fn test_duplicate_lock() {
|
||||||
let mut vault = Vault::new(
|
let mut vault = Vault::new(VaultId::new("test"), ChainType::Ethereum, AssetId::eth());
|
||||||
VaultId::new("test"),
|
|
||||||
ChainType::Ethereum,
|
|
||||||
AssetId::eth(),
|
|
||||||
);
|
|
||||||
|
|
||||||
vault
|
vault
|
||||||
.lock("lock1", 1000, test_owner(), test_recipient(), 0)
|
.lock("lock1", 1000, test_owner(), test_recipient(), 0)
|
||||||
|
|
@ -697,20 +694,21 @@ mod tests {
|
||||||
AssetId::eth(),
|
AssetId::eth(),
|
||||||
);
|
);
|
||||||
|
|
||||||
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap();
|
vault
|
||||||
|
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
|
||||||
|
.unwrap();
|
||||||
vault.unlock("lock-1").unwrap();
|
vault.unlock("lock-1").unwrap();
|
||||||
|
|
||||||
let result = vault.unlock("lock-1");
|
let result = vault.unlock("lock-1");
|
||||||
assert!(matches!(result, Err(BridgeError::TransferAlreadyCompleted(_))));
|
assert!(matches!(
|
||||||
|
result,
|
||||||
|
Err(BridgeError::TransferAlreadyCompleted(_))
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_vault_pause() {
|
fn test_vault_pause() {
|
||||||
let mut vault = Vault::new(
|
let mut vault = Vault::new(VaultId::new("test"), ChainType::Ethereum, AssetId::eth());
|
||||||
VaultId::new("test"),
|
|
||||||
ChainType::Ethereum,
|
|
||||||
AssetId::eth(),
|
|
||||||
);
|
|
||||||
|
|
||||||
vault.pause();
|
vault.pause();
|
||||||
|
|
||||||
|
|
@ -730,7 +728,9 @@ mod tests {
|
||||||
vault.resume();
|
vault.resume();
|
||||||
|
|
||||||
assert_eq!(vault.state, VaultState::Active);
|
assert_eq!(vault.state, VaultState::Active);
|
||||||
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap();
|
vault
|
||||||
|
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
|
||||||
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -750,12 +750,8 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_daily_limit() {
|
fn test_daily_limit() {
|
||||||
let mut vault = Vault::new(
|
let mut vault = Vault::new(VaultId::new("test"), ChainType::Ethereum, AssetId::eth())
|
||||||
VaultId::new("test"),
|
.with_daily_limit(1000);
|
||||||
ChainType::Ethereum,
|
|
||||||
AssetId::eth(),
|
|
||||||
)
|
|
||||||
.with_daily_limit(1000);
|
|
||||||
|
|
||||||
let current_time = 86400 * 100;
|
let current_time = 86400 * 100;
|
||||||
|
|
||||||
|
|
@ -781,8 +777,24 @@ mod tests {
|
||||||
);
|
);
|
||||||
|
|
||||||
let current_time = 0;
|
let current_time = 0;
|
||||||
vault.lock("lock-1", 1000000000, test_owner(), test_recipient(), current_time).unwrap();
|
vault
|
||||||
vault.lock("lock-2", 1000000000, test_owner(), test_recipient(), current_time).unwrap();
|
.lock(
|
||||||
|
"lock-1",
|
||||||
|
1000000000,
|
||||||
|
test_owner(),
|
||||||
|
test_recipient(),
|
||||||
|
current_time,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
vault
|
||||||
|
.lock(
|
||||||
|
"lock-2",
|
||||||
|
1000000000,
|
||||||
|
test_owner(),
|
||||||
|
test_recipient(),
|
||||||
|
current_time,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(vault.total_locked, 2000000000);
|
assert_eq!(vault.total_locked, 2000000000);
|
||||||
}
|
}
|
||||||
|
|
@ -795,7 +807,9 @@ mod tests {
|
||||||
AssetId::eth(),
|
AssetId::eth(),
|
||||||
);
|
);
|
||||||
|
|
||||||
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap();
|
vault
|
||||||
|
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert!(vault.get_locked("lock-1").is_some());
|
assert!(vault.get_locked("lock-1").is_some());
|
||||||
assert!(vault.get_locked("nonexistent").is_none());
|
assert!(vault.get_locked("nonexistent").is_none());
|
||||||
|
|
@ -809,8 +823,12 @@ mod tests {
|
||||||
AssetId::eth(),
|
AssetId::eth(),
|
||||||
);
|
);
|
||||||
|
|
||||||
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap();
|
vault
|
||||||
vault.lock("lock-2", 2000, test_owner(), test_recipient(), 0).unwrap();
|
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
|
||||||
|
.unwrap();
|
||||||
|
vault
|
||||||
|
.lock("lock-2", 2000, test_owner(), test_recipient(), 0)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let all: Vec<_> = vault.all_locked().collect();
|
let all: Vec<_> = vault.all_locked().collect();
|
||||||
assert_eq!(all.len(), 2);
|
assert_eq!(all.len(), 2);
|
||||||
|
|
@ -824,8 +842,12 @@ mod tests {
|
||||||
AssetId::eth(),
|
AssetId::eth(),
|
||||||
);
|
);
|
||||||
|
|
||||||
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap();
|
vault
|
||||||
vault.lock("lock-2", 2000, test_owner(), test_recipient(), 0).unwrap();
|
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
|
||||||
|
.unwrap();
|
||||||
|
vault
|
||||||
|
.lock("lock-2", 2000, test_owner(), test_recipient(), 0)
|
||||||
|
.unwrap();
|
||||||
vault.unlock("lock-1").unwrap();
|
vault.unlock("lock-1").unwrap();
|
||||||
|
|
||||||
let active: Vec<_> = vault.active_locked().collect();
|
let active: Vec<_> = vault.active_locked().collect();
|
||||||
|
|
@ -858,7 +880,9 @@ mod tests {
|
||||||
assert!(manager.find_vault(&ChainType::Ethereum, ð).is_some());
|
assert!(manager.find_vault(&ChainType::Ethereum, ð).is_some());
|
||||||
|
|
||||||
let vault = manager.get_or_create_vault(ChainType::Ethereum, eth.clone());
|
let vault = manager.get_or_create_vault(ChainType::Ethereum, eth.clone());
|
||||||
vault.lock("lock1", 100, test_owner(), test_recipient(), 0).unwrap();
|
vault
|
||||||
|
.lock("lock1", 100, test_owner(), test_recipient(), 0)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(manager.total_locked(), 100);
|
assert_eq!(manager.total_locked(), 100);
|
||||||
}
|
}
|
||||||
|
|
@ -881,7 +905,9 @@ mod tests {
|
||||||
|
|
||||||
{
|
{
|
||||||
let vault = manager.get_vault_mut(&vault_id).unwrap();
|
let vault = manager.get_vault_mut(&vault_id).unwrap();
|
||||||
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap();
|
vault
|
||||||
|
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
|
||||||
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
let vault = manager.get_vault(&vault_id).unwrap();
|
let vault = manager.get_vault(&vault_id).unwrap();
|
||||||
|
|
@ -902,7 +928,9 @@ mod tests {
|
||||||
manager.create_vault(ChainType::Ethereum, eth.clone());
|
manager.create_vault(ChainType::Ethereum, eth.clone());
|
||||||
|
|
||||||
let vault = manager.find_vault_mut(&ChainType::Ethereum, ð).unwrap();
|
let vault = manager.find_vault_mut(&ChainType::Ethereum, ð).unwrap();
|
||||||
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap();
|
vault
|
||||||
|
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(manager.total_locked(), 1000);
|
assert_eq!(manager.total_locked(), 1000);
|
||||||
}
|
}
|
||||||
|
|
@ -913,7 +941,9 @@ mod tests {
|
||||||
let eth = AssetId::eth();
|
let eth = AssetId::eth();
|
||||||
|
|
||||||
let vault = manager.get_or_create_vault(ChainType::Ethereum, eth.clone());
|
let vault = manager.get_or_create_vault(ChainType::Ethereum, eth.clone());
|
||||||
vault.lock("lock-1", 1000, test_owner(), test_recipient(), 0).unwrap();
|
vault
|
||||||
|
.lock("lock-1", 1000, test_owner(), test_recipient(), 0)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(manager.vault_ids().len(), 1);
|
assert_eq!(manager.vault_ids().len(), 1);
|
||||||
assert_eq!(manager.total_locked(), 1000);
|
assert_eq!(manager.total_locked(), 1000);
|
||||||
|
|
|
||||||
|
|
@ -241,7 +241,10 @@ impl DeviceRegistry {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Gets a processor by ID.
|
/// Gets a processor by ID.
|
||||||
pub fn get_processor(&self, processor_id: ProcessorId) -> Result<Arc<dyn Processor>, ComputeError> {
|
pub fn get_processor(
|
||||||
|
&self,
|
||||||
|
processor_id: ProcessorId,
|
||||||
|
) -> Result<Arc<dyn Processor>, ComputeError> {
|
||||||
self.processors
|
self.processors
|
||||||
.read()
|
.read()
|
||||||
.get(&processor_id)
|
.get(&processor_id)
|
||||||
|
|
@ -266,7 +269,10 @@ impl DeviceRegistry {
|
||||||
|
|
||||||
/// Gets the next processor ID.
|
/// Gets the next processor ID.
|
||||||
pub fn next_processor_id(&self) -> ProcessorId {
|
pub fn next_processor_id(&self) -> ProcessorId {
|
||||||
ProcessorId(self.next_processor_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst))
|
ProcessorId(
|
||||||
|
self.next_processor_id
|
||||||
|
.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Gets total number of devices.
|
/// Gets total number of devices.
|
||||||
|
|
@ -309,7 +315,10 @@ impl DeviceRegistry {
|
||||||
device.status = status;
|
device.status = status;
|
||||||
Ok(())
|
Ok(())
|
||||||
} else {
|
} else {
|
||||||
Err(ComputeError::Internal(format!("Device not found: {}", device_id)))
|
Err(ComputeError::Internal(format!(
|
||||||
|
"Device not found: {}",
|
||||||
|
device_id
|
||||||
|
)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -323,7 +332,7 @@ impl Default for DeviceRegistry {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::processor::{CpuVariant, AvxSupport};
|
use crate::processor::{AvxSupport, CpuVariant};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_device_id() {
|
fn test_device_id() {
|
||||||
|
|
|
||||||
|
|
@ -67,6 +67,10 @@ pub use market::{
|
||||||
ResourceType, SpotMarket, Trade,
|
ResourceType, SpotMarket, Trade,
|
||||||
};
|
};
|
||||||
pub use memory::{MemoryManager, TensorHandle, TransferPath, UnifiedMemory};
|
pub use memory::{MemoryManager, TensorHandle, TransferPath, UnifiedMemory};
|
||||||
|
pub use model::{
|
||||||
|
ModelCategory, ModelFormat, ModelId, ModelInfo, ModelRegistry, ModelUploadRequest,
|
||||||
|
ModelUploadResponse,
|
||||||
|
};
|
||||||
pub use processor::{
|
pub use processor::{
|
||||||
ComputeThroughput, CpuVariant, GpuVariant, NpuVariant, Operation, OperationType, Processor,
|
ComputeThroughput, CpuVariant, GpuVariant, NpuVariant, Operation, OperationType, Processor,
|
||||||
ProcessorCapabilities, ProcessorId, ProcessorType, TpuVersion,
|
ProcessorCapabilities, ProcessorId, ProcessorType, TpuVersion,
|
||||||
|
|
@ -78,10 +82,6 @@ pub use task::{
|
||||||
ComputeTask, DecomposedWorkload, Task, TaskDecomposer, TaskId, TaskPriority, TaskResult,
|
ComputeTask, DecomposedWorkload, Task, TaskDecomposer, TaskId, TaskPriority, TaskResult,
|
||||||
TaskStatus,
|
TaskStatus,
|
||||||
};
|
};
|
||||||
pub use model::{
|
|
||||||
ModelCategory, ModelFormat, ModelId, ModelInfo, ModelRegistry, ModelUploadRequest,
|
|
||||||
ModelUploadResponse,
|
|
||||||
};
|
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
@ -434,7 +434,10 @@ impl ComputeCluster {
|
||||||
let jobs = self.jobs.read();
|
let jobs = self.jobs.read();
|
||||||
|
|
||||||
let total_nodes = nodes.len();
|
let total_nodes = nodes.len();
|
||||||
let online_nodes = nodes.values().filter(|n| n.status == NodeStatus::Online).count();
|
let online_nodes = nodes
|
||||||
|
.values()
|
||||||
|
.filter(|n| n.status == NodeStatus::Online)
|
||||||
|
.count();
|
||||||
|
|
||||||
let total_gpus: usize = nodes
|
let total_gpus: usize = nodes
|
||||||
.values()
|
.values()
|
||||||
|
|
@ -515,16 +518,16 @@ pub enum GpuTier {
|
||||||
impl Default for ComputePricing {
|
impl Default for ComputePricing {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
let mut gpu_hourly = HashMap::new();
|
let mut gpu_hourly = HashMap::new();
|
||||||
gpu_hourly.insert(GpuTier::Consumer, 100_000_000); // 0.10 SYNOR
|
gpu_hourly.insert(GpuTier::Consumer, 100_000_000); // 0.10 SYNOR
|
||||||
gpu_hourly.insert(GpuTier::Professional, 300_000_000); // 0.30 SYNOR
|
gpu_hourly.insert(GpuTier::Professional, 300_000_000); // 0.30 SYNOR
|
||||||
gpu_hourly.insert(GpuTier::DataCenter, 2_000_000_000); // 2.00 SYNOR
|
gpu_hourly.insert(GpuTier::DataCenter, 2_000_000_000); // 2.00 SYNOR
|
||||||
gpu_hourly.insert(GpuTier::Premium, 4_000_000_000); // 4.00 SYNOR
|
gpu_hourly.insert(GpuTier::Premium, 4_000_000_000); // 4.00 SYNOR
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
gpu_hourly,
|
gpu_hourly,
|
||||||
cpu_core_hour: 20_000_000, // 0.02 SYNOR
|
cpu_core_hour: 20_000_000, // 0.02 SYNOR
|
||||||
memory_gb_hour: 5_000_000, // 0.005 SYNOR
|
memory_gb_hour: 5_000_000, // 0.005 SYNOR
|
||||||
network_egress_gb: 50_000_000, // 0.05 SYNOR
|
network_egress_gb: 50_000_000, // 0.05 SYNOR
|
||||||
inference_per_million_tokens: 100_000_000, // 0.10 SYNOR
|
inference_per_million_tokens: 100_000_000, // 0.10 SYNOR
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -686,24 +686,24 @@ impl PricingEngine {
|
||||||
pub fn greenest_region(&self) -> &str {
|
pub fn greenest_region(&self) -> &str {
|
||||||
self.regions
|
self.regions
|
||||||
.iter()
|
.iter()
|
||||||
.max_by(|a, b| {
|
.max_by(|a, b| a.renewable_pct.partial_cmp(&b.renewable_pct).unwrap())
|
||||||
a.renewable_pct
|
|
||||||
.partial_cmp(&b.renewable_pct)
|
|
||||||
.unwrap()
|
|
||||||
})
|
|
||||||
.map(|r| r.region.as_str())
|
.map(|r| r.region.as_str())
|
||||||
.unwrap_or("eu-north")
|
.unwrap_or("eu-north")
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Compares price to cloud providers.
|
/// Compares price to cloud providers.
|
||||||
pub fn compare_to_cloud(&self, resource: &ResourceType, region: Option<&str>) -> CloudComparison {
|
pub fn compare_to_cloud(
|
||||||
|
&self,
|
||||||
|
resource: &ResourceType,
|
||||||
|
region: Option<&str>,
|
||||||
|
) -> CloudComparison {
|
||||||
let our_price = self.spot_price(resource, region);
|
let our_price = self.spot_price(resource, region);
|
||||||
|
|
||||||
// Approximate cloud provider prices (USD/hour for GPU)
|
// Approximate cloud provider prices (USD/hour for GPU)
|
||||||
let (aws_price, gcp_price, azure_price) = match resource {
|
let (aws_price, gcp_price, azure_price) = match resource {
|
||||||
ResourceType::GpuHours(GpuTier::DataCenter) => (3.06, 2.95, 3.10), // A100 equivalents
|
ResourceType::GpuHours(GpuTier::DataCenter) => (3.06, 2.95, 3.10), // A100 equivalents
|
||||||
ResourceType::GpuHours(GpuTier::Ultra) => (5.00, 4.50, 5.20), // H100 equivalents
|
ResourceType::GpuHours(GpuTier::Ultra) => (5.00, 4.50, 5.20), // H100 equivalents
|
||||||
ResourceType::GpuHours(GpuTier::High) => (1.50, 1.40, 1.60), // T4/A10 equivalents
|
ResourceType::GpuHours(GpuTier::High) => (1.50, 1.40, 1.60), // T4/A10 equivalents
|
||||||
ResourceType::CpuHours(CpuTier::Server) => (0.40, 0.35, 0.42),
|
ResourceType::CpuHours(CpuTier::Server) => (0.40, 0.35, 0.42),
|
||||||
_ => (1.0, 1.0, 1.0),
|
_ => (1.0, 1.0, 1.0),
|
||||||
};
|
};
|
||||||
|
|
@ -888,9 +888,18 @@ impl SpotMarket {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
order_books.insert(ResourceType::TpuHours, OrderBook::new(ResourceType::TpuHours));
|
order_books.insert(
|
||||||
order_books.insert(ResourceType::NpuHours, OrderBook::new(ResourceType::NpuHours));
|
ResourceType::TpuHours,
|
||||||
order_books.insert(ResourceType::LpuCredits, OrderBook::new(ResourceType::LpuCredits));
|
OrderBook::new(ResourceType::TpuHours),
|
||||||
|
);
|
||||||
|
order_books.insert(
|
||||||
|
ResourceType::NpuHours,
|
||||||
|
OrderBook::new(ResourceType::NpuHours),
|
||||||
|
);
|
||||||
|
order_books.insert(
|
||||||
|
ResourceType::LpuCredits,
|
||||||
|
OrderBook::new(ResourceType::LpuCredits),
|
||||||
|
);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
order_books,
|
order_books,
|
||||||
|
|
@ -1074,12 +1083,21 @@ mod tests {
|
||||||
fn test_pricing_engine() {
|
fn test_pricing_engine() {
|
||||||
let engine = PricingEngine::new();
|
let engine = PricingEngine::new();
|
||||||
|
|
||||||
let price = engine.spot_price(&ResourceType::GpuHours(GpuTier::DataCenter), Some("eu-north"));
|
let price = engine.spot_price(
|
||||||
|
&ResourceType::GpuHours(GpuTier::DataCenter),
|
||||||
|
Some("eu-north"),
|
||||||
|
);
|
||||||
assert!(price > 0.0);
|
assert!(price > 0.0);
|
||||||
|
|
||||||
// eu-north should be cheaper (low electricity cost)
|
// eu-north should be cheaper (low electricity cost)
|
||||||
let eu_price = engine.spot_price(&ResourceType::GpuHours(GpuTier::DataCenter), Some("eu-north"));
|
let eu_price = engine.spot_price(
|
||||||
let eu_west_price = engine.spot_price(&ResourceType::GpuHours(GpuTier::DataCenter), Some("eu-west"));
|
&ResourceType::GpuHours(GpuTier::DataCenter),
|
||||||
|
Some("eu-north"),
|
||||||
|
);
|
||||||
|
let eu_west_price = engine.spot_price(
|
||||||
|
&ResourceType::GpuHours(GpuTier::DataCenter),
|
||||||
|
Some("eu-west"),
|
||||||
|
);
|
||||||
|
|
||||||
// eu-north has cheaper electricity
|
// eu-north has cheaper electricity
|
||||||
assert!(eu_price < eu_west_price);
|
assert!(eu_price < eu_west_price);
|
||||||
|
|
@ -1089,7 +1107,8 @@ mod tests {
|
||||||
fn test_cloud_comparison() {
|
fn test_cloud_comparison() {
|
||||||
let engine = PricingEngine::new();
|
let engine = PricingEngine::new();
|
||||||
|
|
||||||
let comparison = engine.compare_to_cloud(&ResourceType::GpuHours(GpuTier::DataCenter), None);
|
let comparison =
|
||||||
|
engine.compare_to_cloud(&ResourceType::GpuHours(GpuTier::DataCenter), None);
|
||||||
|
|
||||||
// Should show significant savings
|
// Should show significant savings
|
||||||
assert!(comparison.aws_savings > 50.0);
|
assert!(comparison.aws_savings > 50.0);
|
||||||
|
|
|
||||||
|
|
@ -106,11 +106,11 @@ impl TransferPath {
|
||||||
/// Returns approximate bandwidth in GB/s.
|
/// Returns approximate bandwidth in GB/s.
|
||||||
pub fn bandwidth_gbps(&self) -> f64 {
|
pub fn bandwidth_gbps(&self) -> f64 {
|
||||||
match self {
|
match self {
|
||||||
TransferPath::NvLink => 900.0, // NVLink 4.0
|
TransferPath::NvLink => 900.0, // NVLink 4.0
|
||||||
TransferPath::PciePeerToPeer => 64.0, // PCIe 5.0 x16
|
TransferPath::PciePeerToPeer => 64.0, // PCIe 5.0 x16
|
||||||
TransferPath::CpuMediated => 50.0, // DDR5
|
TransferPath::CpuMediated => 50.0, // DDR5
|
||||||
TransferPath::UnifiedMemory => 400.0, // Apple unified
|
TransferPath::UnifiedMemory => 400.0, // Apple unified
|
||||||
TransferPath::Network => 10.0, // 100Gbps network
|
TransferPath::Network => 10.0, // 100Gbps network
|
||||||
TransferPath::SameMemory => f64::INFINITY,
|
TransferPath::SameMemory => f64::INFINITY,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -154,7 +154,11 @@ impl MemoryManager {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Allocates a tensor.
|
/// Allocates a tensor.
|
||||||
pub fn allocate(&self, shape: Vec<usize>, dtype: DataType) -> Result<TensorHandle, ComputeError> {
|
pub fn allocate(
|
||||||
|
&self,
|
||||||
|
shape: Vec<usize>,
|
||||||
|
dtype: DataType,
|
||||||
|
) -> Result<TensorHandle, ComputeError> {
|
||||||
let handle = TensorHandle::new(shape, dtype);
|
let handle = TensorHandle::new(shape, dtype);
|
||||||
self.tensors.write().insert(handle.id, handle.clone());
|
self.tensors.write().insert(handle.id, handle.clone());
|
||||||
Ok(handle)
|
Ok(handle)
|
||||||
|
|
@ -223,9 +227,13 @@ impl MemoryManager {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for NVLink between NVIDIA GPUs
|
// Check for NVLink between NVIDIA GPUs
|
||||||
if matches!(from, ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda { .. }))
|
if matches!(
|
||||||
&& matches!(to, ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda { .. }))
|
from,
|
||||||
{
|
ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda { .. })
|
||||||
|
) && matches!(
|
||||||
|
to,
|
||||||
|
ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda { .. })
|
||||||
|
) {
|
||||||
return TransferPath::NvLink;
|
return TransferPath::NvLink;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -244,10 +252,22 @@ impl MemoryManager {
|
||||||
|
|
||||||
match (a, b) {
|
match (a, b) {
|
||||||
// Apple Silicon unified memory
|
// Apple Silicon unified memory
|
||||||
(ProcessorType::Cpu(CpuVariant::Arm64 { .. }), ProcessorType::Gpu(GpuVariant::AppleMetal))
|
(
|
||||||
| (ProcessorType::Gpu(GpuVariant::AppleMetal), ProcessorType::Cpu(CpuVariant::Arm64 { .. }))
|
ProcessorType::Cpu(CpuVariant::Arm64 { .. }),
|
||||||
| (ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }), ProcessorType::Cpu(CpuVariant::Arm64 { .. }))
|
ProcessorType::Gpu(GpuVariant::AppleMetal),
|
||||||
| (ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }), ProcessorType::Gpu(GpuVariant::AppleMetal)) => true,
|
)
|
||||||
|
| (
|
||||||
|
ProcessorType::Gpu(GpuVariant::AppleMetal),
|
||||||
|
ProcessorType::Cpu(CpuVariant::Arm64 { .. }),
|
||||||
|
)
|
||||||
|
| (
|
||||||
|
ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }),
|
||||||
|
ProcessorType::Cpu(CpuVariant::Arm64 { .. }),
|
||||||
|
)
|
||||||
|
| (
|
||||||
|
ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }),
|
||||||
|
ProcessorType::Gpu(GpuVariant::AppleMetal),
|
||||||
|
) => true,
|
||||||
// Same type
|
// Same type
|
||||||
_ if a == b => true,
|
_ if a == b => true,
|
||||||
_ => false,
|
_ => false,
|
||||||
|
|
@ -325,7 +345,9 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_transfer_path_bandwidth() {
|
fn test_transfer_path_bandwidth() {
|
||||||
assert!(TransferPath::NvLink.bandwidth_gbps() > TransferPath::PciePeerToPeer.bandwidth_gbps());
|
assert!(
|
||||||
|
TransferPath::NvLink.bandwidth_gbps() > TransferPath::PciePeerToPeer.bandwidth_gbps()
|
||||||
|
);
|
||||||
assert!(TransferPath::SameMemory.bandwidth_gbps().is_infinite());
|
assert!(TransferPath::SameMemory.bandwidth_gbps().is_infinite());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -333,7 +355,9 @@ mod tests {
|
||||||
fn test_memory_manager() {
|
fn test_memory_manager() {
|
||||||
let manager = MemoryManager::new();
|
let manager = MemoryManager::new();
|
||||||
|
|
||||||
let handle = manager.allocate(vec![1024, 1024], DataType::Float32).unwrap();
|
let handle = manager
|
||||||
|
.allocate(vec![1024, 1024], DataType::Float32)
|
||||||
|
.unwrap();
|
||||||
assert_eq!(manager.tensor_count(), 1);
|
assert_eq!(manager.tensor_count(), 1);
|
||||||
|
|
||||||
manager.free(handle.id).unwrap();
|
manager.free(handle.id).unwrap();
|
||||||
|
|
@ -347,22 +371,26 @@ mod tests {
|
||||||
let handle = manager.allocate(vec![1024], DataType::Float32).unwrap();
|
let handle = manager.allocate(vec![1024], DataType::Float32).unwrap();
|
||||||
|
|
||||||
// First ensure should allocate
|
// First ensure should allocate
|
||||||
let path = manager.ensure_on(
|
let path = manager
|
||||||
handle.id,
|
.ensure_on(
|
||||||
ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda {
|
handle.id,
|
||||||
compute_capability: (8, 0),
|
ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda {
|
||||||
}),
|
compute_capability: (8, 0),
|
||||||
).unwrap();
|
}),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(path, TransferPath::SameMemory);
|
assert_eq!(path, TransferPath::SameMemory);
|
||||||
|
|
||||||
// Second ensure to same location should be same memory
|
// Second ensure to same location should be same memory
|
||||||
let path = manager.ensure_on(
|
let path = manager
|
||||||
handle.id,
|
.ensure_on(
|
||||||
ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda {
|
handle.id,
|
||||||
compute_capability: (8, 0),
|
ProcessorType::Gpu(crate::processor::GpuVariant::NvidiaCuda {
|
||||||
}),
|
compute_capability: (8, 0),
|
||||||
).unwrap();
|
}),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(path, TransferPath::SameMemory);
|
assert_eq!(path, TransferPath::SameMemory);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -140,13 +140,7 @@ pub struct ModelInfo {
|
||||||
|
|
||||||
impl ModelInfo {
|
impl ModelInfo {
|
||||||
/// Creates a new LLM model info.
|
/// Creates a new LLM model info.
|
||||||
pub fn llm(
|
pub fn llm(alias: &str, name: &str, cid: &str, parameters: u64, context_length: u32) -> Self {
|
||||||
alias: &str,
|
|
||||||
name: &str,
|
|
||||||
cid: &str,
|
|
||||||
parameters: u64,
|
|
||||||
context_length: u32,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
Self {
|
||||||
id: ModelId::from_alias(alias),
|
id: ModelId::from_alias(alias),
|
||||||
name: name.to_string(),
|
name: name.to_string(),
|
||||||
|
|
@ -156,7 +150,12 @@ impl ModelInfo {
|
||||||
format: ModelFormat::SafeTensors,
|
format: ModelFormat::SafeTensors,
|
||||||
size_bytes: parameters * 2, // ~2 bytes per param in fp16
|
size_bytes: parameters * 2, // ~2 bytes per param in fp16
|
||||||
parameters,
|
parameters,
|
||||||
supported_precisions: vec![Precision::Fp16, Precision::Bf16, Precision::Int8, Precision::Int4],
|
supported_precisions: vec![
|
||||||
|
Precision::Fp16,
|
||||||
|
Precision::Bf16,
|
||||||
|
Precision::Int8,
|
||||||
|
Precision::Int4,
|
||||||
|
],
|
||||||
recommended_processor: ProcessorType::Lpu,
|
recommended_processor: ProcessorType::Lpu,
|
||||||
context_length: Some(context_length),
|
context_length: Some(context_length),
|
||||||
input_schema: None,
|
input_schema: None,
|
||||||
|
|
@ -238,33 +237,123 @@ impl ModelRegistry {
|
||||||
let default_models = vec![
|
let default_models = vec![
|
||||||
// ===== LLMs =====
|
// ===== LLMs =====
|
||||||
// Llama 3 family
|
// Llama 3 family
|
||||||
ModelInfo::llm("llama-3-8b", "Llama 3 8B", "QmLlama3_8B_placeholder", 8_000_000_000, 8192),
|
ModelInfo::llm(
|
||||||
ModelInfo::llm("llama-3-70b", "Llama 3 70B", "QmLlama3_70B_placeholder", 70_000_000_000, 8192),
|
"llama-3-8b",
|
||||||
ModelInfo::llm("llama-3.1-8b", "Llama 3.1 8B", "QmLlama31_8B_placeholder", 8_000_000_000, 128000),
|
"Llama 3 8B",
|
||||||
ModelInfo::llm("llama-3.1-70b", "Llama 3.1 70B", "QmLlama31_70B_placeholder", 70_000_000_000, 128000),
|
"QmLlama3_8B_placeholder",
|
||||||
ModelInfo::llm("llama-3.1-405b", "Llama 3.1 405B", "QmLlama31_405B_placeholder", 405_000_000_000, 128000),
|
8_000_000_000,
|
||||||
|
8192,
|
||||||
|
),
|
||||||
|
ModelInfo::llm(
|
||||||
|
"llama-3-70b",
|
||||||
|
"Llama 3 70B",
|
||||||
|
"QmLlama3_70B_placeholder",
|
||||||
|
70_000_000_000,
|
||||||
|
8192,
|
||||||
|
),
|
||||||
|
ModelInfo::llm(
|
||||||
|
"llama-3.1-8b",
|
||||||
|
"Llama 3.1 8B",
|
||||||
|
"QmLlama31_8B_placeholder",
|
||||||
|
8_000_000_000,
|
||||||
|
128000,
|
||||||
|
),
|
||||||
|
ModelInfo::llm(
|
||||||
|
"llama-3.1-70b",
|
||||||
|
"Llama 3.1 70B",
|
||||||
|
"QmLlama31_70B_placeholder",
|
||||||
|
70_000_000_000,
|
||||||
|
128000,
|
||||||
|
),
|
||||||
|
ModelInfo::llm(
|
||||||
|
"llama-3.1-405b",
|
||||||
|
"Llama 3.1 405B",
|
||||||
|
"QmLlama31_405B_placeholder",
|
||||||
|
405_000_000_000,
|
||||||
|
128000,
|
||||||
|
),
|
||||||
// Mistral family
|
// Mistral family
|
||||||
ModelInfo::llm("mistral-7b", "Mistral 7B", "QmMistral7B_placeholder", 7_000_000_000, 32768),
|
ModelInfo::llm(
|
||||||
ModelInfo::llm("mixtral-8x7b", "Mixtral 8x7B", "QmMixtral8x7B_placeholder", 46_000_000_000, 32768),
|
"mistral-7b",
|
||||||
ModelInfo::llm("mixtral-8x22b", "Mixtral 8x22B", "QmMixtral8x22B_placeholder", 176_000_000_000, 65536),
|
"Mistral 7B",
|
||||||
|
"QmMistral7B_placeholder",
|
||||||
|
7_000_000_000,
|
||||||
|
32768,
|
||||||
|
),
|
||||||
|
ModelInfo::llm(
|
||||||
|
"mixtral-8x7b",
|
||||||
|
"Mixtral 8x7B",
|
||||||
|
"QmMixtral8x7B_placeholder",
|
||||||
|
46_000_000_000,
|
||||||
|
32768,
|
||||||
|
),
|
||||||
|
ModelInfo::llm(
|
||||||
|
"mixtral-8x22b",
|
||||||
|
"Mixtral 8x22B",
|
||||||
|
"QmMixtral8x22B_placeholder",
|
||||||
|
176_000_000_000,
|
||||||
|
65536,
|
||||||
|
),
|
||||||
// Qwen family
|
// Qwen family
|
||||||
ModelInfo::llm("qwen-2.5-7b", "Qwen 2.5 7B", "QmQwen25_7B_placeholder", 7_000_000_000, 128000),
|
ModelInfo::llm(
|
||||||
ModelInfo::llm("qwen-2.5-72b", "Qwen 2.5 72B", "QmQwen25_72B_placeholder", 72_000_000_000, 128000),
|
"qwen-2.5-7b",
|
||||||
|
"Qwen 2.5 7B",
|
||||||
|
"QmQwen25_7B_placeholder",
|
||||||
|
7_000_000_000,
|
||||||
|
128000,
|
||||||
|
),
|
||||||
|
ModelInfo::llm(
|
||||||
|
"qwen-2.5-72b",
|
||||||
|
"Qwen 2.5 72B",
|
||||||
|
"QmQwen25_72B_placeholder",
|
||||||
|
72_000_000_000,
|
||||||
|
128000,
|
||||||
|
),
|
||||||
// DeepSeek family
|
// DeepSeek family
|
||||||
ModelInfo::llm("deepseek-v2", "DeepSeek V2", "QmDeepSeekV2_placeholder", 236_000_000_000, 128000),
|
ModelInfo::llm(
|
||||||
ModelInfo::llm("deepseek-coder-33b", "DeepSeek Coder 33B", "QmDeepSeekCoder33B_placeholder", 33_000_000_000, 16384),
|
"deepseek-v2",
|
||||||
|
"DeepSeek V2",
|
||||||
|
"QmDeepSeekV2_placeholder",
|
||||||
|
236_000_000_000,
|
||||||
|
128000,
|
||||||
|
),
|
||||||
|
ModelInfo::llm(
|
||||||
|
"deepseek-coder-33b",
|
||||||
|
"DeepSeek Coder 33B",
|
||||||
|
"QmDeepSeekCoder33B_placeholder",
|
||||||
|
33_000_000_000,
|
||||||
|
16384,
|
||||||
|
),
|
||||||
// Phi family (small/efficient)
|
// Phi family (small/efficient)
|
||||||
ModelInfo::llm("phi-3-mini", "Phi 3 Mini", "QmPhi3Mini_placeholder", 3_800_000_000, 128000),
|
ModelInfo::llm(
|
||||||
ModelInfo::llm("phi-3-medium", "Phi 3 Medium", "QmPhi3Medium_placeholder", 14_000_000_000, 128000),
|
"phi-3-mini",
|
||||||
|
"Phi 3 Mini",
|
||||||
|
"QmPhi3Mini_placeholder",
|
||||||
|
3_800_000_000,
|
||||||
|
128000,
|
||||||
|
),
|
||||||
|
ModelInfo::llm(
|
||||||
|
"phi-3-medium",
|
||||||
|
"Phi 3 Medium",
|
||||||
|
"QmPhi3Medium_placeholder",
|
||||||
|
14_000_000_000,
|
||||||
|
128000,
|
||||||
|
),
|
||||||
// Code models
|
// Code models
|
||||||
ModelInfo::llm("codellama-34b", "Code Llama 34B", "QmCodeLlama34B_placeholder", 34_000_000_000, 16384),
|
ModelInfo::llm(
|
||||||
ModelInfo::llm("starcoder2-15b", "StarCoder2 15B", "QmStarCoder2_15B_placeholder", 15_000_000_000, 16384),
|
"codellama-34b",
|
||||||
|
"Code Llama 34B",
|
||||||
|
"QmCodeLlama34B_placeholder",
|
||||||
|
34_000_000_000,
|
||||||
|
16384,
|
||||||
|
),
|
||||||
|
ModelInfo::llm(
|
||||||
|
"starcoder2-15b",
|
||||||
|
"StarCoder2 15B",
|
||||||
|
"QmStarCoder2_15B_placeholder",
|
||||||
|
15_000_000_000,
|
||||||
|
16384,
|
||||||
|
),
|
||||||
// ===== Embedding Models =====
|
// ===== Embedding Models =====
|
||||||
ModelInfo {
|
ModelInfo {
|
||||||
id: ModelId::from_alias("bge-large"),
|
id: ModelId::from_alias("bge-large"),
|
||||||
|
|
@ -306,7 +395,6 @@ impl ModelRegistry {
|
||||||
is_public: true,
|
is_public: true,
|
||||||
owner: None,
|
owner: None,
|
||||||
},
|
},
|
||||||
|
|
||||||
// ===== Vision Models =====
|
// ===== Vision Models =====
|
||||||
ModelInfo {
|
ModelInfo {
|
||||||
id: ModelId::from_alias("stable-diffusion-xl"),
|
id: ModelId::from_alias("stable-diffusion-xl"),
|
||||||
|
|
@ -348,7 +436,6 @@ impl ModelRegistry {
|
||||||
is_public: true,
|
is_public: true,
|
||||||
owner: None,
|
owner: None,
|
||||||
},
|
},
|
||||||
|
|
||||||
// ===== Speech Models =====
|
// ===== Speech Models =====
|
||||||
ModelInfo {
|
ModelInfo {
|
||||||
id: ModelId::from_alias("whisper-large-v3"),
|
id: ModelId::from_alias("whisper-large-v3"),
|
||||||
|
|
@ -370,7 +457,6 @@ impl ModelRegistry {
|
||||||
is_public: true,
|
is_public: true,
|
||||||
owner: None,
|
owner: None,
|
||||||
},
|
},
|
||||||
|
|
||||||
// ===== Multi-Modal Models =====
|
// ===== Multi-Modal Models =====
|
||||||
ModelInfo {
|
ModelInfo {
|
||||||
id: ModelId::from_alias("llava-1.5-13b"),
|
id: ModelId::from_alias("llava-1.5-13b"),
|
||||||
|
|
@ -555,7 +641,9 @@ mod tests {
|
||||||
let registry = ModelRegistry::new();
|
let registry = ModelRegistry::new();
|
||||||
let results = registry.search("llama");
|
let results = registry.search("llama");
|
||||||
assert!(!results.is_empty());
|
assert!(!results.is_empty());
|
||||||
assert!(results.iter().all(|m| m.name.to_lowercase().contains("llama")));
|
assert!(results
|
||||||
|
.iter()
|
||||||
|
.all(|m| m.name.to_lowercase().contains("llama")));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
|
|
@ -305,7 +305,7 @@ impl ProcessorCapabilities {
|
||||||
},
|
},
|
||||||
memory: MemorySpecs {
|
memory: MemorySpecs {
|
||||||
capacity_bytes: 230 * 1024 * 1024 * 1024, // 230 GB SRAM!
|
capacity_bytes: 230 * 1024 * 1024 * 1024, // 230 GB SRAM!
|
||||||
bandwidth_gbps: 80_000, // 80 TB/s internal
|
bandwidth_gbps: 80_000, // 80 TB/s internal
|
||||||
type_: MemoryType::Sram,
|
type_: MemoryType::Sram,
|
||||||
},
|
},
|
||||||
operations: Self::lpu_operations(),
|
operations: Self::lpu_operations(),
|
||||||
|
|
@ -349,8 +349,8 @@ impl ProcessorCapabilities {
|
||||||
/// Creates Apple Neural Engine capabilities.
|
/// Creates Apple Neural Engine capabilities.
|
||||||
pub fn apple_neural_engine(cores: u32) -> Self {
|
pub fn apple_neural_engine(cores: u32) -> Self {
|
||||||
let int8_tops = match cores {
|
let int8_tops = match cores {
|
||||||
16 => 18.0, // M3
|
16 => 18.0, // M3
|
||||||
32 => 35.0, // M3 Max
|
32 => 35.0, // M3 Max
|
||||||
_ => cores as f64 * 1.1,
|
_ => cores as f64 * 1.1,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -542,6 +542,8 @@ mod tests {
|
||||||
fn test_lpu_capabilities() {
|
fn test_lpu_capabilities() {
|
||||||
let caps = ProcessorCapabilities::lpu();
|
let caps = ProcessorCapabilities::lpu();
|
||||||
assert!(caps.memory.bandwidth_gbps > 10000); // Very high internal bandwidth
|
assert!(caps.memory.bandwidth_gbps > 10000); // Very high internal bandwidth
|
||||||
assert!(caps.optimal_for.contains(&WorkloadCharacteristic::Sequential));
|
assert!(caps
|
||||||
|
.optimal_for
|
||||||
|
.contains(&WorkloadCharacteristic::Sequential));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -253,10 +253,22 @@ impl Processor for GenericProcessor {
|
||||||
fn shares_memory_with(&self, other: &ProcessorType) -> bool {
|
fn shares_memory_with(&self, other: &ProcessorType) -> bool {
|
||||||
match (&self.processor_type, other) {
|
match (&self.processor_type, other) {
|
||||||
// Apple Silicon has unified memory
|
// Apple Silicon has unified memory
|
||||||
(ProcessorType::Cpu(CpuVariant::Arm64 { .. }), ProcessorType::Gpu(GpuVariant::AppleMetal))
|
(
|
||||||
| (ProcessorType::Gpu(GpuVariant::AppleMetal), ProcessorType::Cpu(CpuVariant::Arm64 { .. }))
|
ProcessorType::Cpu(CpuVariant::Arm64 { .. }),
|
||||||
| (ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }), ProcessorType::Cpu(CpuVariant::Arm64 { .. }))
|
ProcessorType::Gpu(GpuVariant::AppleMetal),
|
||||||
| (ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }), ProcessorType::Gpu(GpuVariant::AppleMetal)) => true,
|
)
|
||||||
|
| (
|
||||||
|
ProcessorType::Gpu(GpuVariant::AppleMetal),
|
||||||
|
ProcessorType::Cpu(CpuVariant::Arm64 { .. }),
|
||||||
|
)
|
||||||
|
| (
|
||||||
|
ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }),
|
||||||
|
ProcessorType::Cpu(CpuVariant::Arm64 { .. }),
|
||||||
|
)
|
||||||
|
| (
|
||||||
|
ProcessorType::Npu(NpuVariant::AppleNeuralEngine { .. }),
|
||||||
|
ProcessorType::Gpu(GpuVariant::AppleMetal),
|
||||||
|
) => true,
|
||||||
// Same type always shares
|
// Same type always shares
|
||||||
(a, b) if a == b => true,
|
(a, b) if a == b => true,
|
||||||
_ => false,
|
_ => false,
|
||||||
|
|
|
||||||
|
|
@ -191,10 +191,7 @@ pub enum Operation {
|
||||||
},
|
},
|
||||||
|
|
||||||
/// Data loading from storage.
|
/// Data loading from storage.
|
||||||
DataLoad {
|
DataLoad { bytes: usize, async_: bool },
|
||||||
bytes: usize,
|
|
||||||
async_: bool,
|
|
||||||
},
|
|
||||||
|
|
||||||
/// Data preprocessing.
|
/// Data preprocessing.
|
||||||
DataPreprocess {
|
DataPreprocess {
|
||||||
|
|
@ -209,16 +206,10 @@ pub enum Operation {
|
||||||
},
|
},
|
||||||
|
|
||||||
/// Detokenization.
|
/// Detokenization.
|
||||||
Detokenization {
|
Detokenization { tokens: usize, vocab_size: usize },
|
||||||
tokens: usize,
|
|
||||||
vocab_size: usize,
|
|
||||||
},
|
|
||||||
|
|
||||||
/// Checkpoint save.
|
/// Checkpoint save.
|
||||||
Checkpoint {
|
Checkpoint { bytes: usize, async_: bool },
|
||||||
bytes: usize,
|
|
||||||
async_: bool,
|
|
||||||
},
|
|
||||||
|
|
||||||
/// All-reduce across devices.
|
/// All-reduce across devices.
|
||||||
AllReduce {
|
AllReduce {
|
||||||
|
|
@ -228,9 +219,7 @@ pub enum Operation {
|
||||||
},
|
},
|
||||||
|
|
||||||
/// Backward pass for a layer.
|
/// Backward pass for a layer.
|
||||||
Backward {
|
Backward { forward_op: Box<Operation> },
|
||||||
forward_op: Box<Operation>,
|
|
||||||
},
|
|
||||||
|
|
||||||
/// Optimizer step.
|
/// Optimizer step.
|
||||||
OptimizerStep {
|
OptimizerStep {
|
||||||
|
|
@ -240,16 +229,10 @@ pub enum Operation {
|
||||||
},
|
},
|
||||||
|
|
||||||
/// Transpose.
|
/// Transpose.
|
||||||
Transpose {
|
Transpose { shape: Vec<usize>, axes: Vec<usize> },
|
||||||
shape: Vec<usize>,
|
|
||||||
axes: Vec<usize>,
|
|
||||||
},
|
|
||||||
|
|
||||||
/// Reshape.
|
/// Reshape.
|
||||||
Reshape {
|
Reshape { from: Vec<usize>, to: Vec<usize> },
|
||||||
from: Vec<usize>,
|
|
||||||
to: Vec<usize>,
|
|
||||||
},
|
|
||||||
|
|
||||||
/// Concatenate tensors.
|
/// Concatenate tensors.
|
||||||
Concat {
|
Concat {
|
||||||
|
|
@ -378,9 +361,7 @@ impl Operation {
|
||||||
| Operation::SiLU { elements } => *elements as f64,
|
| Operation::SiLU { elements } => *elements as f64,
|
||||||
|
|
||||||
// Softmax: ~5 ops per element (exp, sum, div)
|
// Softmax: ~5 ops per element (exp, sum, div)
|
||||||
Operation::Softmax {
|
Operation::Softmax { batch, seq_len, .. } => 5.0 * (*batch as f64) * (*seq_len as f64),
|
||||||
batch, seq_len, ..
|
|
||||||
} => 5.0 * (*batch as f64) * (*seq_len as f64),
|
|
||||||
|
|
||||||
// Embedding: just lookup, minimal FLOPS
|
// Embedding: just lookup, minimal FLOPS
|
||||||
Operation::Embedding {
|
Operation::Embedding {
|
||||||
|
|
|
||||||
|
|
@ -39,8 +39,7 @@ impl ProcessorProfiles {
|
||||||
bandwidth_gbps: 460,
|
bandwidth_gbps: 460,
|
||||||
type_: MemoryType::Ddr5,
|
type_: MemoryType::Ddr5,
|
||||||
},
|
},
|
||||||
operations: ProcessorCapabilities::cpu(96, 2.4, false)
|
operations: ProcessorCapabilities::cpu(96, 2.4, false).operations,
|
||||||
.operations,
|
|
||||||
power: PowerCharacteristics {
|
power: PowerCharacteristics {
|
||||||
tdp_watts: 360,
|
tdp_watts: 360,
|
||||||
efficiency: 0.85,
|
efficiency: 0.85,
|
||||||
|
|
@ -70,8 +69,7 @@ impl ProcessorProfiles {
|
||||||
bandwidth_gbps: 307,
|
bandwidth_gbps: 307,
|
||||||
type_: MemoryType::Ddr5,
|
type_: MemoryType::Ddr5,
|
||||||
},
|
},
|
||||||
operations: ProcessorCapabilities::cpu(56, 2.9, true)
|
operations: ProcessorCapabilities::cpu(56, 2.9, true).operations,
|
||||||
.operations,
|
|
||||||
power: PowerCharacteristics {
|
power: PowerCharacteristics {
|
||||||
tdp_watts: 350,
|
tdp_watts: 350,
|
||||||
efficiency: 0.80,
|
efficiency: 0.80,
|
||||||
|
|
@ -101,8 +99,7 @@ impl ProcessorProfiles {
|
||||||
bandwidth_gbps: 400,
|
bandwidth_gbps: 400,
|
||||||
type_: MemoryType::Unified,
|
type_: MemoryType::Unified,
|
||||||
},
|
},
|
||||||
operations: ProcessorCapabilities::cpu(16, 4.0, false)
|
operations: ProcessorCapabilities::cpu(16, 4.0, false).operations,
|
||||||
.operations,
|
|
||||||
power: PowerCharacteristics {
|
power: PowerCharacteristics {
|
||||||
tdp_watts: 40,
|
tdp_watts: 40,
|
||||||
efficiency: 0.95,
|
efficiency: 0.95,
|
||||||
|
|
@ -141,8 +138,7 @@ impl ProcessorProfiles {
|
||||||
bandwidth_gbps: 3350,
|
bandwidth_gbps: 3350,
|
||||||
type_: MemoryType::Hbm3,
|
type_: MemoryType::Hbm3,
|
||||||
},
|
},
|
||||||
operations: ProcessorCapabilities::nvidia_gpu(16896, 528, 80, 3350, (9, 0))
|
operations: ProcessorCapabilities::nvidia_gpu(16896, 528, 80, 3350, (9, 0)).operations,
|
||||||
.operations,
|
|
||||||
power: PowerCharacteristics {
|
power: PowerCharacteristics {
|
||||||
tdp_watts: 700,
|
tdp_watts: 700,
|
||||||
efficiency: 0.90,
|
efficiency: 0.90,
|
||||||
|
|
@ -173,8 +169,7 @@ impl ProcessorProfiles {
|
||||||
bandwidth_gbps: 2039,
|
bandwidth_gbps: 2039,
|
||||||
type_: MemoryType::Hbm2e,
|
type_: MemoryType::Hbm2e,
|
||||||
},
|
},
|
||||||
operations: ProcessorCapabilities::nvidia_gpu(6912, 432, 80, 2039, (8, 0))
|
operations: ProcessorCapabilities::nvidia_gpu(6912, 432, 80, 2039, (8, 0)).operations,
|
||||||
.operations,
|
|
||||||
power: PowerCharacteristics {
|
power: PowerCharacteristics {
|
||||||
tdp_watts: 400,
|
tdp_watts: 400,
|
||||||
efficiency: 0.88,
|
efficiency: 0.88,
|
||||||
|
|
@ -205,8 +200,7 @@ impl ProcessorProfiles {
|
||||||
bandwidth_gbps: 1008,
|
bandwidth_gbps: 1008,
|
||||||
type_: MemoryType::Gddr6,
|
type_: MemoryType::Gddr6,
|
||||||
},
|
},
|
||||||
operations: ProcessorCapabilities::nvidia_gpu(16384, 512, 24, 1008, (8, 9))
|
operations: ProcessorCapabilities::nvidia_gpu(16384, 512, 24, 1008, (8, 9)).operations,
|
||||||
.operations,
|
|
||||||
power: PowerCharacteristics {
|
power: PowerCharacteristics {
|
||||||
tdp_watts: 450,
|
tdp_watts: 450,
|
||||||
efficiency: 0.85,
|
efficiency: 0.85,
|
||||||
|
|
@ -236,8 +230,7 @@ impl ProcessorProfiles {
|
||||||
bandwidth_gbps: 936,
|
bandwidth_gbps: 936,
|
||||||
type_: MemoryType::Gddr6,
|
type_: MemoryType::Gddr6,
|
||||||
},
|
},
|
||||||
operations: ProcessorCapabilities::nvidia_gpu(10496, 328, 24, 936, (8, 6))
|
operations: ProcessorCapabilities::nvidia_gpu(10496, 328, 24, 936, (8, 6)).operations,
|
||||||
.operations,
|
|
||||||
power: PowerCharacteristics {
|
power: PowerCharacteristics {
|
||||||
tdp_watts: 350,
|
tdp_watts: 350,
|
||||||
efficiency: 0.82,
|
efficiency: 0.82,
|
||||||
|
|
@ -272,8 +265,8 @@ impl ProcessorProfiles {
|
||||||
type_: MemoryType::Hbm3,
|
type_: MemoryType::Hbm3,
|
||||||
},
|
},
|
||||||
operations: {
|
operations: {
|
||||||
let mut ops = ProcessorCapabilities::nvidia_gpu(16384, 512, 80, 5300, (9, 0))
|
let mut ops =
|
||||||
.operations;
|
ProcessorCapabilities::nvidia_gpu(16384, 512, 80, 5300, (9, 0)).operations;
|
||||||
ops.remove(&OperationType::FlashAttention); // Different implementation
|
ops.remove(&OperationType::FlashAttention); // Different implementation
|
||||||
ops
|
ops
|
||||||
},
|
},
|
||||||
|
|
@ -308,8 +301,8 @@ impl ProcessorProfiles {
|
||||||
type_: MemoryType::Gddr6,
|
type_: MemoryType::Gddr6,
|
||||||
},
|
},
|
||||||
operations: {
|
operations: {
|
||||||
let mut ops = ProcessorCapabilities::nvidia_gpu(6144, 0, 24, 960, (8, 0))
|
let mut ops =
|
||||||
.operations;
|
ProcessorCapabilities::nvidia_gpu(6144, 0, 24, 960, (8, 0)).operations;
|
||||||
ops.remove(&OperationType::FlashAttention);
|
ops.remove(&OperationType::FlashAttention);
|
||||||
ops
|
ops
|
||||||
},
|
},
|
||||||
|
|
@ -318,9 +311,7 @@ impl ProcessorProfiles {
|
||||||
efficiency: 0.80,
|
efficiency: 0.80,
|
||||||
power_tier: PowerTier::High,
|
power_tier: PowerTier::High,
|
||||||
},
|
},
|
||||||
optimal_for: vec![
|
optimal_for: vec![WorkloadCharacteristic::HighlyParallel],
|
||||||
WorkloadCharacteristic::HighlyParallel,
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -429,8 +420,7 @@ impl ProcessorProfiles {
|
||||||
bandwidth_gbps: 200,
|
bandwidth_gbps: 200,
|
||||||
type_: MemoryType::Unified,
|
type_: MemoryType::Unified,
|
||||||
},
|
},
|
||||||
operations: ProcessorCapabilities::apple_neural_engine(16)
|
operations: ProcessorCapabilities::apple_neural_engine(16).operations,
|
||||||
.operations,
|
|
||||||
power: PowerCharacteristics {
|
power: PowerCharacteristics {
|
||||||
tdp_watts: 8,
|
tdp_watts: 8,
|
||||||
efficiency: 0.98,
|
efficiency: 0.98,
|
||||||
|
|
@ -465,8 +455,7 @@ impl ProcessorProfiles {
|
||||||
bandwidth_gbps: 77,
|
bandwidth_gbps: 77,
|
||||||
type_: MemoryType::Lpddr,
|
type_: MemoryType::Lpddr,
|
||||||
},
|
},
|
||||||
operations: ProcessorCapabilities::apple_neural_engine(16)
|
operations: ProcessorCapabilities::apple_neural_engine(16).operations,
|
||||||
.operations,
|
|
||||||
power: PowerCharacteristics {
|
power: PowerCharacteristics {
|
||||||
tdp_watts: 10,
|
tdp_watts: 10,
|
||||||
efficiency: 0.95,
|
efficiency: 0.95,
|
||||||
|
|
|
||||||
|
|
@ -24,10 +24,7 @@ pub enum ProcessorType {
|
||||||
/// WebAssembly runtime.
|
/// WebAssembly runtime.
|
||||||
Wasm,
|
Wasm,
|
||||||
/// Custom/Unknown accelerator.
|
/// Custom/Unknown accelerator.
|
||||||
Custom {
|
Custom { vendor: String, model: String },
|
||||||
vendor: String,
|
|
||||||
model: String,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for ProcessorType {
|
impl Default for ProcessorType {
|
||||||
|
|
|
||||||
|
|
@ -6,10 +6,10 @@
|
||||||
//! - Latency-aware scheduling
|
//! - Latency-aware scheduling
|
||||||
//! - Real-time utilization metrics
|
//! - Real-time utilization metrics
|
||||||
|
|
||||||
|
use super::TaskAssignment;
|
||||||
use crate::device::DeviceRegistry;
|
use crate::device::DeviceRegistry;
|
||||||
use crate::processor::{Operation, OperationType, ProcessorId, ProcessorType};
|
use crate::processor::{Operation, OperationType, ProcessorId, ProcessorType};
|
||||||
use crate::task::{Task, TaskId, TaskPriority};
|
use crate::task::{Task, TaskId, TaskPriority};
|
||||||
use super::TaskAssignment;
|
|
||||||
use parking_lot::RwLock;
|
use parking_lot::RwLock;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::atomic::{AtomicU64, Ordering};
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
|
|
@ -127,8 +127,12 @@ impl LoadBalancer {
|
||||||
/// Register a processor with its type.
|
/// Register a processor with its type.
|
||||||
pub fn register_processor(&self, processor_id: ProcessorId, processor_type: ProcessorType) {
|
pub fn register_processor(&self, processor_id: ProcessorId, processor_type: ProcessorType) {
|
||||||
self.loads.write().insert(processor_id, AtomicU64::new(0));
|
self.loads.write().insert(processor_id, AtomicU64::new(0));
|
||||||
self.metrics.write().insert(processor_id, ProcessorMetrics::default());
|
self.metrics
|
||||||
self.processor_types.write().insert(processor_id, processor_type);
|
.write()
|
||||||
|
.insert(processor_id, ProcessorMetrics::default());
|
||||||
|
self.processor_types
|
||||||
|
.write()
|
||||||
|
.insert(processor_id, processor_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Unregister a processor.
|
/// Unregister a processor.
|
||||||
|
|
@ -150,7 +154,8 @@ impl LoadBalancer {
|
||||||
|
|
||||||
/// Get current load for a processor.
|
/// Get current load for a processor.
|
||||||
pub fn get_load(&self, processor_id: ProcessorId) -> u64 {
|
pub fn get_load(&self, processor_id: ProcessorId) -> u64 {
|
||||||
self.loads.read()
|
self.loads
|
||||||
|
.read()
|
||||||
.get(&processor_id)
|
.get(&processor_id)
|
||||||
.map(|l| l.load(Ordering::Relaxed))
|
.map(|l| l.load(Ordering::Relaxed))
|
||||||
.unwrap_or(0)
|
.unwrap_or(0)
|
||||||
|
|
@ -179,140 +184,140 @@ impl LoadBalancer {
|
||||||
ProcessorType::Cpu(_) => matches!(
|
ProcessorType::Cpu(_) => matches!(
|
||||||
op_type,
|
op_type,
|
||||||
OperationType::MatMul
|
OperationType::MatMul
|
||||||
| OperationType::Conv2d
|
| OperationType::Conv2d
|
||||||
| OperationType::Conv3d
|
| OperationType::Conv3d
|
||||||
| OperationType::DepthwiseConv
|
| OperationType::DepthwiseConv
|
||||||
| OperationType::BatchNorm
|
| OperationType::BatchNorm
|
||||||
| OperationType::LayerNorm
|
| OperationType::LayerNorm
|
||||||
| OperationType::Add
|
| OperationType::Add
|
||||||
| OperationType::Mul
|
| OperationType::Mul
|
||||||
| OperationType::ReLU
|
| OperationType::ReLU
|
||||||
| OperationType::GeLU
|
| OperationType::GeLU
|
||||||
| OperationType::SiLU
|
| OperationType::SiLU
|
||||||
| OperationType::Softmax
|
| OperationType::Softmax
|
||||||
| OperationType::Sum
|
| OperationType::Sum
|
||||||
| OperationType::Mean
|
| OperationType::Mean
|
||||||
| OperationType::Max
|
| OperationType::Max
|
||||||
| OperationType::ArgMax
|
| OperationType::ArgMax
|
||||||
| OperationType::Embedding
|
| OperationType::Embedding
|
||||||
| OperationType::TopK
|
| OperationType::TopK
|
||||||
| OperationType::Sampling
|
| OperationType::Sampling
|
||||||
| OperationType::Tokenization
|
| OperationType::Tokenization
|
||||||
| OperationType::Detokenization
|
| OperationType::Detokenization
|
||||||
| OperationType::DataLoad
|
| OperationType::DataLoad
|
||||||
| OperationType::DataPreprocess
|
| OperationType::DataPreprocess
|
||||||
| OperationType::Transpose
|
| OperationType::Transpose
|
||||||
| OperationType::Reshape
|
| OperationType::Reshape
|
||||||
| OperationType::Concat
|
| OperationType::Concat
|
||||||
| OperationType::Split
|
| OperationType::Split
|
||||||
),
|
),
|
||||||
|
|
||||||
// GPUs excel at parallel operations
|
// GPUs excel at parallel operations
|
||||||
ProcessorType::Gpu(_) => matches!(
|
ProcessorType::Gpu(_) => matches!(
|
||||||
op_type,
|
op_type,
|
||||||
OperationType::MatMul
|
OperationType::MatMul
|
||||||
| OperationType::Conv2d
|
| OperationType::Conv2d
|
||||||
| OperationType::Conv3d
|
| OperationType::Conv3d
|
||||||
| OperationType::DepthwiseConv
|
| OperationType::DepthwiseConv
|
||||||
| OperationType::BatchNorm
|
| OperationType::BatchNorm
|
||||||
| OperationType::LayerNorm
|
| OperationType::LayerNorm
|
||||||
| OperationType::SelfAttention
|
| OperationType::SelfAttention
|
||||||
| OperationType::CrossAttention
|
| OperationType::CrossAttention
|
||||||
| OperationType::FlashAttention
|
| OperationType::FlashAttention
|
||||||
| OperationType::Add
|
| OperationType::Add
|
||||||
| OperationType::Mul
|
| OperationType::Mul
|
||||||
| OperationType::ReLU
|
| OperationType::ReLU
|
||||||
| OperationType::GeLU
|
| OperationType::GeLU
|
||||||
| OperationType::SiLU
|
| OperationType::SiLU
|
||||||
| OperationType::Softmax
|
| OperationType::Softmax
|
||||||
| OperationType::Sum
|
| OperationType::Sum
|
||||||
| OperationType::Mean
|
| OperationType::Mean
|
||||||
| OperationType::Max
|
| OperationType::Max
|
||||||
| OperationType::ArgMax
|
| OperationType::ArgMax
|
||||||
| OperationType::Embedding
|
| OperationType::Embedding
|
||||||
| OperationType::RoPE
|
| OperationType::RoPE
|
||||||
| OperationType::KVCache
|
| OperationType::KVCache
|
||||||
| OperationType::TopK
|
| OperationType::TopK
|
||||||
| OperationType::Sampling
|
| OperationType::Sampling
|
||||||
| OperationType::Transpose
|
| OperationType::Transpose
|
||||||
| OperationType::Reshape
|
| OperationType::Reshape
|
||||||
| OperationType::Concat
|
| OperationType::Concat
|
||||||
| OperationType::Split
|
| OperationType::Split
|
||||||
| OperationType::Gather
|
| OperationType::Gather
|
||||||
| OperationType::Scatter
|
| OperationType::Scatter
|
||||||
| OperationType::AllReduce
|
| OperationType::AllReduce
|
||||||
| OperationType::AllGather
|
| OperationType::AllGather
|
||||||
| OperationType::ReduceScatter
|
| OperationType::ReduceScatter
|
||||||
| OperationType::Backward
|
| OperationType::Backward
|
||||||
| OperationType::OptimizerStep
|
| OperationType::OptimizerStep
|
||||||
| OperationType::GradientClip
|
| OperationType::GradientClip
|
||||||
),
|
),
|
||||||
|
|
||||||
// TPUs optimized for ML
|
// TPUs optimized for ML
|
||||||
ProcessorType::Tpu(_) => matches!(
|
ProcessorType::Tpu(_) => matches!(
|
||||||
op_type,
|
op_type,
|
||||||
OperationType::MatMul
|
OperationType::MatMul
|
||||||
| OperationType::Conv2d
|
| OperationType::Conv2d
|
||||||
| OperationType::BatchNorm
|
| OperationType::BatchNorm
|
||||||
| OperationType::LayerNorm
|
| OperationType::LayerNorm
|
||||||
| OperationType::SelfAttention
|
| OperationType::SelfAttention
|
||||||
| OperationType::CrossAttention
|
| OperationType::CrossAttention
|
||||||
| OperationType::FlashAttention
|
| OperationType::FlashAttention
|
||||||
| OperationType::Add
|
| OperationType::Add
|
||||||
| OperationType::Mul
|
| OperationType::Mul
|
||||||
| OperationType::ReLU
|
| OperationType::ReLU
|
||||||
| OperationType::GeLU
|
| OperationType::GeLU
|
||||||
| OperationType::SiLU
|
| OperationType::SiLU
|
||||||
| OperationType::Softmax
|
| OperationType::Softmax
|
||||||
| OperationType::Sum
|
| OperationType::Sum
|
||||||
| OperationType::Mean
|
| OperationType::Mean
|
||||||
| OperationType::Embedding
|
| OperationType::Embedding
|
||||||
| OperationType::RoPE
|
| OperationType::RoPE
|
||||||
| OperationType::KVCache
|
| OperationType::KVCache
|
||||||
| OperationType::AllReduce
|
| OperationType::AllReduce
|
||||||
| OperationType::AllGather
|
| OperationType::AllGather
|
||||||
| OperationType::ReduceScatter
|
| OperationType::ReduceScatter
|
||||||
| OperationType::Backward
|
| OperationType::Backward
|
||||||
| OperationType::OptimizerStep
|
| OperationType::OptimizerStep
|
||||||
),
|
),
|
||||||
|
|
||||||
// NPUs for neural network inference
|
// NPUs for neural network inference
|
||||||
ProcessorType::Npu(_) => matches!(
|
ProcessorType::Npu(_) => matches!(
|
||||||
op_type,
|
op_type,
|
||||||
OperationType::MatMul
|
OperationType::MatMul
|
||||||
| OperationType::Conv2d
|
| OperationType::Conv2d
|
||||||
| OperationType::DepthwiseConv
|
| OperationType::DepthwiseConv
|
||||||
| OperationType::BatchNorm
|
| OperationType::BatchNorm
|
||||||
| OperationType::LayerNorm
|
| OperationType::LayerNorm
|
||||||
| OperationType::SelfAttention
|
| OperationType::SelfAttention
|
||||||
| OperationType::Add
|
| OperationType::Add
|
||||||
| OperationType::Mul
|
| OperationType::Mul
|
||||||
| OperationType::ReLU
|
| OperationType::ReLU
|
||||||
| OperationType::GeLU
|
| OperationType::GeLU
|
||||||
| OperationType::SiLU
|
| OperationType::SiLU
|
||||||
| OperationType::Softmax
|
| OperationType::Softmax
|
||||||
| OperationType::Sum
|
| OperationType::Sum
|
||||||
| OperationType::Mean
|
| OperationType::Mean
|
||||||
),
|
),
|
||||||
|
|
||||||
// LPUs for sequential inference (optimized for LLMs)
|
// LPUs for sequential inference (optimized for LLMs)
|
||||||
ProcessorType::Lpu => matches!(
|
ProcessorType::Lpu => matches!(
|
||||||
op_type,
|
op_type,
|
||||||
OperationType::MatMul
|
OperationType::MatMul
|
||||||
| OperationType::LayerNorm
|
| OperationType::LayerNorm
|
||||||
| OperationType::SelfAttention
|
| OperationType::SelfAttention
|
||||||
| OperationType::FlashAttention
|
| OperationType::FlashAttention
|
||||||
| OperationType::Add
|
| OperationType::Add
|
||||||
| OperationType::Mul
|
| OperationType::Mul
|
||||||
| OperationType::ReLU
|
| OperationType::ReLU
|
||||||
| OperationType::GeLU
|
| OperationType::GeLU
|
||||||
| OperationType::SiLU
|
| OperationType::SiLU
|
||||||
| OperationType::Softmax
|
| OperationType::Softmax
|
||||||
| OperationType::Embedding
|
| OperationType::Embedding
|
||||||
| OperationType::RoPE
|
| OperationType::RoPE
|
||||||
| OperationType::KVCache
|
| OperationType::KVCache
|
||||||
| OperationType::TopK
|
| OperationType::TopK
|
||||||
| OperationType::Sampling
|
| OperationType::Sampling
|
||||||
),
|
),
|
||||||
|
|
||||||
// FPGAs can be programmed for anything
|
// FPGAs can be programmed for anything
|
||||||
|
|
@ -322,40 +327,40 @@ impl LoadBalancer {
|
||||||
ProcessorType::Dsp(_) => matches!(
|
ProcessorType::Dsp(_) => matches!(
|
||||||
op_type,
|
op_type,
|
||||||
OperationType::Conv2d
|
OperationType::Conv2d
|
||||||
| OperationType::DepthwiseConv
|
| OperationType::DepthwiseConv
|
||||||
| OperationType::Add
|
| OperationType::Add
|
||||||
| OperationType::Mul
|
| OperationType::Mul
|
||||||
| OperationType::Sum
|
| OperationType::Sum
|
||||||
| OperationType::Mean
|
| OperationType::Mean
|
||||||
| OperationType::Max
|
| OperationType::Max
|
||||||
),
|
),
|
||||||
|
|
||||||
// WebGPU has limited operations
|
// WebGPU has limited operations
|
||||||
ProcessorType::WebGpu => matches!(
|
ProcessorType::WebGpu => matches!(
|
||||||
op_type,
|
op_type,
|
||||||
OperationType::MatMul
|
OperationType::MatMul
|
||||||
| OperationType::Conv2d
|
| OperationType::Conv2d
|
||||||
| OperationType::Add
|
| OperationType::Add
|
||||||
| OperationType::Mul
|
| OperationType::Mul
|
||||||
| OperationType::ReLU
|
| OperationType::ReLU
|
||||||
| OperationType::Softmax
|
| OperationType::Softmax
|
||||||
| OperationType::Sum
|
| OperationType::Sum
|
||||||
| OperationType::Transpose
|
| OperationType::Transpose
|
||||||
| OperationType::Reshape
|
| OperationType::Reshape
|
||||||
),
|
),
|
||||||
|
|
||||||
// WASM for portable compute
|
// WASM for portable compute
|
||||||
ProcessorType::Wasm => matches!(
|
ProcessorType::Wasm => matches!(
|
||||||
op_type,
|
op_type,
|
||||||
OperationType::MatMul
|
OperationType::MatMul
|
||||||
| OperationType::Add
|
| OperationType::Add
|
||||||
| OperationType::Mul
|
| OperationType::Mul
|
||||||
| OperationType::ReLU
|
| OperationType::ReLU
|
||||||
| OperationType::Softmax
|
| OperationType::Softmax
|
||||||
| OperationType::Sum
|
| OperationType::Sum
|
||||||
| OperationType::Mean
|
| OperationType::Mean
|
||||||
| OperationType::Tokenization
|
| OperationType::Tokenization
|
||||||
| OperationType::Detokenization
|
| OperationType::Detokenization
|
||||||
),
|
),
|
||||||
|
|
||||||
// Custom processors - assume they can handle anything
|
// Custom processors - assume they can handle anything
|
||||||
|
|
@ -381,7 +386,9 @@ impl LoadBalancer {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get utilization and metrics
|
// Get utilization and metrics
|
||||||
let utilization = proc_metrics.map(|m| m.utilization).unwrap_or(load as f64 / 100.0);
|
let utilization = proc_metrics
|
||||||
|
.map(|m| m.utilization)
|
||||||
|
.unwrap_or(load as f64 / 100.0);
|
||||||
let power = proc_metrics.map(|m| m.power_watts).unwrap_or(100.0);
|
let power = proc_metrics.map(|m| m.power_watts).unwrap_or(100.0);
|
||||||
let avg_completion = proc_metrics.map(|m| m.avg_completion_ms).unwrap_or(100.0);
|
let avg_completion = proc_metrics.map(|m| m.avg_completion_ms).unwrap_or(100.0);
|
||||||
|
|
||||||
|
|
@ -431,13 +438,13 @@ impl LoadBalancer {
|
||||||
BalancingStrategy::Cost => {
|
BalancingStrategy::Cost => {
|
||||||
// Prioritize cheaper resources (consumer devices)
|
// Prioritize cheaper resources (consumer devices)
|
||||||
let cost_factor = match processor_type {
|
let cost_factor = match processor_type {
|
||||||
ProcessorType::Wasm => 0.1, // Cheapest (browser)
|
ProcessorType::Wasm => 0.1, // Cheapest (browser)
|
||||||
ProcessorType::WebGpu => 0.15,
|
ProcessorType::WebGpu => 0.15,
|
||||||
ProcessorType::Cpu(_) => 0.2,
|
ProcessorType::Cpu(_) => 0.2,
|
||||||
ProcessorType::Npu(_) => 0.3, // Mobile NPUs
|
ProcessorType::Npu(_) => 0.3, // Mobile NPUs
|
||||||
ProcessorType::Gpu(_) => 0.5,
|
ProcessorType::Gpu(_) => 0.5,
|
||||||
ProcessorType::Lpu => 0.8,
|
ProcessorType::Lpu => 0.8,
|
||||||
ProcessorType::Tpu(_) => 1.0, // Most expensive
|
ProcessorType::Tpu(_) => 1.0, // Most expensive
|
||||||
_ => 0.5,
|
_ => 0.5,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -450,7 +457,7 @@ impl LoadBalancer {
|
||||||
|
|
||||||
// Bonus for low-latency processors
|
// Bonus for low-latency processors
|
||||||
let latency_bonus = match processor_type {
|
let latency_bonus = match processor_type {
|
||||||
ProcessorType::Lpu => 5.0, // Designed for low latency
|
ProcessorType::Lpu => 5.0, // Designed for low latency
|
||||||
ProcessorType::Npu(_) => 3.0,
|
ProcessorType::Npu(_) => 3.0,
|
||||||
ProcessorType::Gpu(_) => 2.0,
|
ProcessorType::Gpu(_) => 2.0,
|
||||||
ProcessorType::Tpu(_) => 1.5,
|
ProcessorType::Tpu(_) => 1.5,
|
||||||
|
|
@ -550,7 +557,8 @@ impl LoadBalancer {
|
||||||
let mut suggestions = Vec::new();
|
let mut suggestions = Vec::new();
|
||||||
let loads = self.loads.read();
|
let loads = self.loads.read();
|
||||||
|
|
||||||
let load_values: Vec<_> = loads.iter()
|
let load_values: Vec<_> = loads
|
||||||
|
.iter()
|
||||||
.map(|(id, load)| (*id, load.load(Ordering::Relaxed)))
|
.map(|(id, load)| (*id, load.load(Ordering::Relaxed)))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
|
|
@ -558,16 +566,18 @@ impl LoadBalancer {
|
||||||
return suggestions;
|
return suggestions;
|
||||||
}
|
}
|
||||||
|
|
||||||
let avg_load: f64 = load_values.iter().map(|(_, l)| *l as f64).sum::<f64>()
|
let avg_load: f64 =
|
||||||
/ load_values.len() as f64;
|
load_values.iter().map(|(_, l)| *l as f64).sum::<f64>() / load_values.len() as f64;
|
||||||
|
|
||||||
let processor_types = self.processor_types.read();
|
let processor_types = self.processor_types.read();
|
||||||
|
|
||||||
let overloaded: Vec<_> = load_values.iter()
|
let overloaded: Vec<_> = load_values
|
||||||
|
.iter()
|
||||||
.filter(|(_, l)| *l as f64 > avg_load * (1.0 + self.rebalance_threshold))
|
.filter(|(_, l)| *l as f64 > avg_load * (1.0 + self.rebalance_threshold))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let underloaded: Vec<_> = load_values.iter()
|
let underloaded: Vec<_> = load_values
|
||||||
|
.iter()
|
||||||
.filter(|(_, l)| (*l as f64) < avg_load * (1.0 - self.rebalance_threshold))
|
.filter(|(_, l)| (*l as f64) < avg_load * (1.0 - self.rebalance_threshold))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
|
|
@ -627,7 +637,9 @@ impl LoadBalancer {
|
||||||
/// Clean up old migration history.
|
/// Clean up old migration history.
|
||||||
pub fn cleanup_history(&self, max_age: Duration) {
|
pub fn cleanup_history(&self, max_age: Duration) {
|
||||||
let cutoff = Instant::now() - max_age;
|
let cutoff = Instant::now() - max_age;
|
||||||
self.migration_history.write().retain(|r| r.timestamp > cutoff);
|
self.migration_history
|
||||||
|
.write()
|
||||||
|
.retain(|r| r.timestamp > cutoff);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -725,7 +737,9 @@ mod tests {
|
||||||
balancer.register_processor(ProcessorId(0), ProcessorType::Cpu(CpuVariant::default()));
|
balancer.register_processor(ProcessorId(0), ProcessorType::Cpu(CpuVariant::default()));
|
||||||
balancer.register_processor(
|
balancer.register_processor(
|
||||||
ProcessorId(1),
|
ProcessorId(1),
|
||||||
ProcessorType::Gpu(GpuVariant::NvidiaCuda { compute_capability: (8, 9) }),
|
ProcessorType::Gpu(GpuVariant::NvidiaCuda {
|
||||||
|
compute_capability: (8, 9),
|
||||||
|
}),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Give CPU high load
|
// Give CPU high load
|
||||||
|
|
@ -757,7 +771,9 @@ mod tests {
|
||||||
};
|
};
|
||||||
|
|
||||||
let cpu = ProcessorType::Cpu(CpuVariant::default());
|
let cpu = ProcessorType::Cpu(CpuVariant::default());
|
||||||
let gpu = ProcessorType::Gpu(GpuVariant::NvidiaCuda { compute_capability: (8, 9) });
|
let gpu = ProcessorType::Gpu(GpuVariant::NvidiaCuda {
|
||||||
|
compute_capability: (8, 9),
|
||||||
|
});
|
||||||
let lpu = ProcessorType::Lpu;
|
let lpu = ProcessorType::Lpu;
|
||||||
|
|
||||||
// MatMul can run on all
|
// MatMul can run on all
|
||||||
|
|
@ -778,7 +794,10 @@ mod tests {
|
||||||
let npu_id = ProcessorId(1);
|
let npu_id = ProcessorId(1);
|
||||||
|
|
||||||
balancer.register_processor(cpu_id, ProcessorType::Cpu(CpuVariant::default()));
|
balancer.register_processor(cpu_id, ProcessorType::Cpu(CpuVariant::default()));
|
||||||
balancer.register_processor(npu_id, ProcessorType::Npu(crate::processor::NpuVariant::AppleNeuralEngine { cores: 16 }));
|
balancer.register_processor(
|
||||||
|
npu_id,
|
||||||
|
ProcessorType::Npu(crate::processor::NpuVariant::AppleNeuralEngine { cores: 16 }),
|
||||||
|
);
|
||||||
|
|
||||||
let task = create_test_task(TaskPriority::Normal);
|
let task = create_test_task(TaskPriority::Normal);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -69,7 +69,9 @@ impl HeterogeneousScheduler {
|
||||||
let utilization = self.estimate_utilization(&schedule);
|
let utilization = self.estimate_utilization(&schedule);
|
||||||
|
|
||||||
// 5. Store active schedule
|
// 5. Store active schedule
|
||||||
self.active_schedules.write().insert(schedule.id, schedule.clone());
|
self.active_schedules
|
||||||
|
.write()
|
||||||
|
.insert(schedule.id, schedule.clone());
|
||||||
|
|
||||||
Ok(ScheduleResult {
|
Ok(ScheduleResult {
|
||||||
schedule,
|
schedule,
|
||||||
|
|
@ -89,10 +91,12 @@ impl HeterogeneousScheduler {
|
||||||
let mut handles = Vec::new();
|
let mut handles = Vec::new();
|
||||||
|
|
||||||
for task_id in &stage.tasks {
|
for task_id in &stage.tasks {
|
||||||
let task = schedule.tasks.get(task_id)
|
let task = schedule.tasks.get(task_id).ok_or_else(|| {
|
||||||
.ok_or_else(|| ComputeError::Internal(format!("Task not found: {:?}", task_id)))?;
|
ComputeError::Internal(format!("Task not found: {:?}", task_id))
|
||||||
let processor_id = schedule.assignment.get(task_id)
|
})?;
|
||||||
.ok_or_else(|| ComputeError::Internal(format!("No assignment for task: {:?}", task_id)))?;
|
let processor_id = schedule.assignment.get(task_id).ok_or_else(|| {
|
||||||
|
ComputeError::Internal(format!("No assignment for task: {:?}", task_id))
|
||||||
|
})?;
|
||||||
|
|
||||||
let processor = self.device_registry.get_processor(processor_id)?;
|
let processor = self.device_registry.get_processor(processor_id)?;
|
||||||
let task_clone = task.clone();
|
let task_clone = task.clone();
|
||||||
|
|
@ -144,8 +148,9 @@ impl HeterogeneousScheduler {
|
||||||
let best_processor = self.find_best_processor(&task).await?;
|
let best_processor = self.find_best_processor(&task).await?;
|
||||||
|
|
||||||
// Check if we should rebalance
|
// Check if we should rebalance
|
||||||
let final_processor = self.load_balancer
|
let final_processor =
|
||||||
.maybe_rebalance(&task, best_processor, &assignment);
|
self.load_balancer
|
||||||
|
.maybe_rebalance(&task, best_processor, &assignment);
|
||||||
|
|
||||||
assignment.assign(task.id, final_processor);
|
assignment.assign(task.id, final_processor);
|
||||||
}
|
}
|
||||||
|
|
@ -207,9 +212,7 @@ impl HeterogeneousScheduler {
|
||||||
fn topological_sort(&self, tasks: &[Task], deps: &DependencyGraph) -> Vec<Task> {
|
fn topological_sort(&self, tasks: &[Task], deps: &DependencyGraph) -> Vec<Task> {
|
||||||
let mut sorted = Vec::new();
|
let mut sorted = Vec::new();
|
||||||
let mut visited = std::collections::HashSet::new();
|
let mut visited = std::collections::HashSet::new();
|
||||||
let task_map: HashMap<TaskId, Task> = tasks.iter()
|
let task_map: HashMap<TaskId, Task> = tasks.iter().map(|t| (t.id, t.clone())).collect();
|
||||||
.map(|t| (t.id, t.clone()))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
fn visit(
|
fn visit(
|
||||||
task_id: TaskId,
|
task_id: TaskId,
|
||||||
|
|
@ -254,9 +257,7 @@ impl HeterogeneousScheduler {
|
||||||
) -> Result<Schedule, ComputeError> {
|
) -> Result<Schedule, ComputeError> {
|
||||||
let mut stages = Vec::new();
|
let mut stages = Vec::new();
|
||||||
let mut scheduled = std::collections::HashSet::new();
|
let mut scheduled = std::collections::HashSet::new();
|
||||||
let task_map: HashMap<TaskId, Task> = tasks.iter()
|
let task_map: HashMap<TaskId, Task> = tasks.iter().map(|t| (t.id, t.clone())).collect();
|
||||||
.map(|t| (t.id, t.clone()))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
while scheduled.len() < tasks.len() {
|
while scheduled.len() < tasks.len() {
|
||||||
let mut stage_tasks = Vec::new();
|
let mut stage_tasks = Vec::new();
|
||||||
|
|
@ -267,8 +268,7 @@ impl HeterogeneousScheduler {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if all dependencies are satisfied
|
// Check if all dependencies are satisfied
|
||||||
let deps_satisfied = task.dependencies.iter()
|
let deps_satisfied = task.dependencies.iter().all(|dep| scheduled.contains(dep));
|
||||||
.all(|dep| scheduled.contains(dep));
|
|
||||||
|
|
||||||
if deps_satisfied {
|
if deps_satisfied {
|
||||||
stage_tasks.push(task.id);
|
stage_tasks.push(task.id);
|
||||||
|
|
@ -277,7 +277,7 @@ impl HeterogeneousScheduler {
|
||||||
|
|
||||||
if stage_tasks.is_empty() {
|
if stage_tasks.is_empty() {
|
||||||
return Err(ComputeError::SchedulingFailed(
|
return Err(ComputeError::SchedulingFailed(
|
||||||
"Circular dependency detected".to_string()
|
"Circular dependency detected".to_string(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -153,7 +153,10 @@ impl PriorityWorkQueue {
|
||||||
TaskPriority::Normal,
|
TaskPriority::Normal,
|
||||||
TaskPriority::Background,
|
TaskPriority::Background,
|
||||||
] {
|
] {
|
||||||
queues.insert(priority, WorkQueue::new(processor_type.clone(), capacity_per_priority));
|
queues.insert(
|
||||||
|
priority,
|
||||||
|
WorkQueue::new(processor_type.clone(), capacity_per_priority),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
|
|
@ -223,10 +226,7 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_work_queue_basic() {
|
fn test_work_queue_basic() {
|
||||||
let queue = WorkQueue::new(
|
let queue = WorkQueue::new(ProcessorType::Cpu(CpuVariant::default()), 100);
|
||||||
ProcessorType::Cpu(CpuVariant::default()),
|
|
||||||
100,
|
|
||||||
);
|
|
||||||
|
|
||||||
assert!(queue.is_empty());
|
assert!(queue.is_empty());
|
||||||
|
|
||||||
|
|
@ -246,10 +246,7 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_priority_queue() {
|
fn test_priority_queue() {
|
||||||
let queue = PriorityWorkQueue::new(
|
let queue = PriorityWorkQueue::new(ProcessorType::Cpu(CpuVariant::default()), 100);
|
||||||
ProcessorType::Cpu(CpuVariant::default()),
|
|
||||||
100,
|
|
||||||
);
|
|
||||||
|
|
||||||
queue.push(create_test_task(1, TaskPriority::Background));
|
queue.push(create_test_task(1, TaskPriority::Background));
|
||||||
queue.push(create_test_task(2, TaskPriority::Critical));
|
queue.push(create_test_task(2, TaskPriority::Critical));
|
||||||
|
|
|
||||||
|
|
@ -495,9 +495,9 @@ mod tests {
|
||||||
compute_capability: (8, 0)
|
compute_capability: (8, 0)
|
||||||
}
|
}
|
||||||
)));
|
)));
|
||||||
assert!(matmul_task.is_compatible_with(ProcessorType::Tpu(
|
assert!(
|
||||||
crate::processor::TpuVersion::V5p
|
matmul_task.is_compatible_with(ProcessorType::Tpu(crate::processor::TpuVersion::V5p))
|
||||||
)));
|
);
|
||||||
|
|
||||||
let data_load_task = Task::new(Operation::DataLoad {
|
let data_load_task = Task::new(Operation::DataLoad {
|
||||||
bytes: 1000,
|
bytes: 1000,
|
||||||
|
|
@ -505,9 +505,8 @@ mod tests {
|
||||||
});
|
});
|
||||||
|
|
||||||
// DataLoad should be compatible with CPU
|
// DataLoad should be compatible with CPU
|
||||||
assert!(data_load_task.is_compatible_with(ProcessorType::Cpu(
|
assert!(data_load_task
|
||||||
crate::processor::CpuVariant::default()
|
.is_compatible_with(ProcessorType::Cpu(crate::processor::CpuVariant::default())));
|
||||||
)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
|
|
@ -25,8 +25,7 @@
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
/// Blocks per second mode.
|
/// Blocks per second mode.
|
||||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)]
|
||||||
#[derive(Default)]
|
|
||||||
pub enum BpsMode {
|
pub enum BpsMode {
|
||||||
/// Standard mode: 10 blocks per second (100ms block time)
|
/// Standard mode: 10 blocks per second (100ms block time)
|
||||||
/// - Suitable for most network conditions
|
/// - Suitable for most network conditions
|
||||||
|
|
@ -75,7 +74,6 @@ impl BpsMode {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
impl std::fmt::Display for BpsMode {
|
impl std::fmt::Display for BpsMode {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
|
|
@ -148,39 +146,39 @@ impl NetworkConfig {
|
||||||
bps_mode: mode,
|
bps_mode: mode,
|
||||||
blocks_per_second: 10,
|
blocks_per_second: 10,
|
||||||
target_block_time_ms: 100,
|
target_block_time_ms: 100,
|
||||||
daa_window_size: 2641, // ~264s window
|
daa_window_size: 2641, // ~264s window
|
||||||
ghostdag_k: 18, // For 10 BPS
|
ghostdag_k: 18, // For 10 BPS
|
||||||
dagknight_k_min: 8,
|
dagknight_k_min: 8,
|
||||||
dagknight_k_max: 64,
|
dagknight_k_max: 64,
|
||||||
finality_depth: 864, // ~86 seconds
|
finality_depth: 864, // ~86 seconds
|
||||||
pruning_depth: 864_000, // ~24 hours
|
pruning_depth: 864_000, // ~24 hours
|
||||||
merge_set_size_limit: 180,
|
merge_set_size_limit: 180,
|
||||||
expected_delay_ms: 100,
|
expected_delay_ms: 100,
|
||||||
},
|
},
|
||||||
BpsMode::Fast32 => Self {
|
BpsMode::Fast32 => Self {
|
||||||
bps_mode: mode,
|
bps_mode: mode,
|
||||||
blocks_per_second: 32,
|
blocks_per_second: 32,
|
||||||
target_block_time_ms: 31, // ~31.25ms
|
target_block_time_ms: 31, // ~31.25ms
|
||||||
daa_window_size: 8461, // ~264s window at 32 BPS
|
daa_window_size: 8461, // ~264s window at 32 BPS
|
||||||
ghostdag_k: 58, // Scaled for 32 BPS
|
ghostdag_k: 58, // Scaled for 32 BPS
|
||||||
dagknight_k_min: 16, // Higher min for faster blocks
|
dagknight_k_min: 16, // Higher min for faster blocks
|
||||||
dagknight_k_max: 128, // Higher max for adaptation
|
dagknight_k_max: 128, // Higher max for adaptation
|
||||||
finality_depth: 2765, // ~86 seconds at 32 BPS
|
finality_depth: 2765, // ~86 seconds at 32 BPS
|
||||||
pruning_depth: 2_764_800, // ~24 hours at 32 BPS
|
pruning_depth: 2_764_800, // ~24 hours at 32 BPS
|
||||||
merge_set_size_limit: 576, // 32/10 * 180
|
merge_set_size_limit: 576, // 32/10 * 180
|
||||||
expected_delay_ms: 50,
|
expected_delay_ms: 50,
|
||||||
},
|
},
|
||||||
BpsMode::Ultra100 => Self {
|
BpsMode::Ultra100 => Self {
|
||||||
bps_mode: mode,
|
bps_mode: mode,
|
||||||
blocks_per_second: 100,
|
blocks_per_second: 100,
|
||||||
target_block_time_ms: 10,
|
target_block_time_ms: 10,
|
||||||
daa_window_size: 26410, // ~264s window at 100 BPS
|
daa_window_size: 26410, // ~264s window at 100 BPS
|
||||||
ghostdag_k: 180, // Scaled for 100 BPS
|
ghostdag_k: 180, // Scaled for 100 BPS
|
||||||
dagknight_k_min: 50, // Higher min for very fast blocks
|
dagknight_k_min: 50, // Higher min for very fast blocks
|
||||||
dagknight_k_max: 255, // u8 max - very high for adaptation
|
dagknight_k_max: 255, // u8 max - very high for adaptation
|
||||||
finality_depth: 8640, // ~86 seconds at 100 BPS
|
finality_depth: 8640, // ~86 seconds at 100 BPS
|
||||||
pruning_depth: 8_640_000, // ~24 hours at 100 BPS
|
pruning_depth: 8_640_000, // ~24 hours at 100 BPS
|
||||||
merge_set_size_limit: 1800, // 100/10 * 180
|
merge_set_size_limit: 1800, // 100/10 * 180
|
||||||
expected_delay_ms: 20,
|
expected_delay_ms: 20,
|
||||||
},
|
},
|
||||||
BpsMode::Custom(bps) => {
|
BpsMode::Custom(bps) => {
|
||||||
|
|
@ -269,7 +267,7 @@ pub fn bps_comparison_table() -> String {
|
||||||
|
|
||||||
let mut table = String::from(
|
let mut table = String::from(
|
||||||
"| Property | Standard (10 BPS) | Fast (32 BPS) | Ultra (100 BPS) |\n\
|
"| Property | Standard (10 BPS) | Fast (32 BPS) | Ultra (100 BPS) |\n\
|
||||||
|----------|-------------------|---------------|------------------|\n"
|
|----------|-------------------|---------------|------------------|\n",
|
||||||
);
|
);
|
||||||
|
|
||||||
// Block Time
|
// Block Time
|
||||||
|
|
@ -314,7 +312,9 @@ pub fn bps_comparison_table() -> String {
|
||||||
// Estimated TPS
|
// Estimated TPS
|
||||||
table.push_str(&format!(
|
table.push_str(&format!(
|
||||||
"| Est. TPS @1000tx/block | {:.0} | {:.0} | {:.0} |\n",
|
"| Est. TPS @1000tx/block | {:.0} | {:.0} | {:.0} |\n",
|
||||||
standard.estimate_tps(1000), fast.estimate_tps(1000), ultra.estimate_tps(1000)
|
standard.estimate_tps(1000),
|
||||||
|
fast.estimate_tps(1000),
|
||||||
|
ultra.estimate_tps(1000)
|
||||||
));
|
));
|
||||||
|
|
||||||
table
|
table
|
||||||
|
|
@ -401,9 +401,9 @@ mod tests {
|
||||||
fn test_latency_acceptable() {
|
fn test_latency_acceptable() {
|
||||||
let config = NetworkConfig::standard(); // expects 100ms
|
let config = NetworkConfig::standard(); // expects 100ms
|
||||||
|
|
||||||
assert!(config.is_latency_acceptable(50)); // Good
|
assert!(config.is_latency_acceptable(50)); // Good
|
||||||
assert!(config.is_latency_acceptable(100)); // OK
|
assert!(config.is_latency_acceptable(100)); // OK
|
||||||
assert!(config.is_latency_acceptable(200)); // Still OK (2x limit)
|
assert!(config.is_latency_acceptable(200)); // Still OK (2x limit)
|
||||||
assert!(!config.is_latency_acceptable(300)); // Too high
|
assert!(!config.is_latency_acceptable(300)); // Too high
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -55,8 +55,8 @@
|
||||||
//! | Layer 2 transactions | FALCON-512 (batch efficiency) |
|
//! | Layer 2 transactions | FALCON-512 (batch efficiency) |
|
||||||
//! | High-value transactions | Dilithium3 (conservative choice) |
|
//! | High-value transactions | Dilithium3 (conservative choice) |
|
||||||
|
|
||||||
use pqcrypto_falcon::falcon512;
|
|
||||||
use pqcrypto_falcon::falcon1024;
|
use pqcrypto_falcon::falcon1024;
|
||||||
|
use pqcrypto_falcon::falcon512;
|
||||||
use pqcrypto_traits::sign::{
|
use pqcrypto_traits::sign::{
|
||||||
DetachedSignature, PublicKey as PqPublicKey, SecretKey as PqSecretKey,
|
DetachedSignature, PublicKey as PqPublicKey, SecretKey as PqSecretKey,
|
||||||
};
|
};
|
||||||
|
|
@ -64,8 +64,7 @@ use thiserror::Error;
|
||||||
use zeroize::{Zeroize, ZeroizeOnDrop};
|
use zeroize::{Zeroize, ZeroizeOnDrop};
|
||||||
|
|
||||||
/// FALCON variant selection.
|
/// FALCON variant selection.
|
||||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
|
||||||
#[derive(Default)]
|
|
||||||
pub enum FalconVariant {
|
pub enum FalconVariant {
|
||||||
/// 128-bit security, ~690 byte signatures
|
/// 128-bit security, ~690 byte signatures
|
||||||
#[default]
|
#[default]
|
||||||
|
|
@ -124,7 +123,6 @@ impl FalconVariant {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// FALCON public key.
|
/// FALCON public key.
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct FalconPublicKey {
|
pub struct FalconPublicKey {
|
||||||
|
|
@ -188,7 +186,10 @@ impl std::fmt::Debug for FalconPublicKey {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
f.debug_struct("FalconPublicKey")
|
f.debug_struct("FalconPublicKey")
|
||||||
.field("variant", &self.variant)
|
.field("variant", &self.variant)
|
||||||
.field("bytes", &hex::encode(&self.bytes[..8.min(self.bytes.len())]))
|
.field(
|
||||||
|
"bytes",
|
||||||
|
&hex::encode(&self.bytes[..8.min(self.bytes.len())]),
|
||||||
|
)
|
||||||
.finish()
|
.finish()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -492,7 +493,10 @@ mod tests {
|
||||||
|
|
||||||
// Verify with wrong message should fail
|
// Verify with wrong message should fail
|
||||||
let wrong_message = b"Wrong message";
|
let wrong_message = b"Wrong message";
|
||||||
assert!(keypair.public_key().verify(wrong_message, &signature).is_err());
|
assert!(keypair
|
||||||
|
.public_key()
|
||||||
|
.verify(wrong_message, &signature)
|
||||||
|
.is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
|
|
@ -150,9 +150,9 @@ impl PqAlgorithm {
|
||||||
/// Default priority order (higher = more preferred)
|
/// Default priority order (higher = more preferred)
|
||||||
fn default_priority(&self) -> u8 {
|
fn default_priority(&self) -> u8 {
|
||||||
match self {
|
match self {
|
||||||
Self::Dilithium3 => 100, // Default, well-balanced
|
Self::Dilithium3 => 100, // Default, well-balanced
|
||||||
Self::Falcon1024 => 90, // High security, compact
|
Self::Falcon1024 => 90, // High security, compact
|
||||||
Self::Falcon512 => 85, // Compact, mobile-friendly
|
Self::Falcon512 => 85, // Compact, mobile-friendly
|
||||||
Self::SphincsShake192s => 70, // Conservative backup
|
Self::SphincsShake192s => 70, // Conservative backup
|
||||||
Self::SphincsShake256s => 60, // Maximum security
|
Self::SphincsShake256s => 60, // Maximum security
|
||||||
Self::SphincsShake128s => 50, // Basic SPHINCS+
|
Self::SphincsShake128s => 50, // Basic SPHINCS+
|
||||||
|
|
@ -270,7 +270,8 @@ impl AlgorithmCapabilities {
|
||||||
|
|
||||||
/// Decode capabilities from bytes
|
/// Decode capabilities from bytes
|
||||||
pub fn decode(data: &[u8]) -> Result<Self, NegotiationError> {
|
pub fn decode(data: &[u8]) -> Result<Self, NegotiationError> {
|
||||||
serde_json::from_slice(data).map_err(|e| NegotiationError::InvalidCapabilities(e.to_string()))
|
serde_json::from_slice(data)
|
||||||
|
.map_err(|e| NegotiationError::InvalidCapabilities(e.to_string()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -384,8 +385,7 @@ impl AlgorithmNegotiator {
|
||||||
// Check security level
|
// Check security level
|
||||||
let meets_local_security =
|
let meets_local_security =
|
||||||
algo.security_level() >= self.local_caps.min_security_level;
|
algo.security_level() >= self.local_caps.min_security_level;
|
||||||
let meets_remote_security =
|
let meets_remote_security = algo.security_level() >= remote_caps.min_security_level;
|
||||||
algo.security_level() >= remote_caps.min_security_level;
|
|
||||||
|
|
||||||
// Check signature size
|
// Check signature size
|
||||||
let local_size_ok = self.local_caps.max_signature_size == 0
|
let local_size_ok = self.local_caps.max_signature_size == 0
|
||||||
|
|
@ -513,10 +513,7 @@ impl AlgorithmNegotiator {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Quick negotiation using just algorithm names
|
/// Quick negotiation using just algorithm names
|
||||||
pub fn quick_negotiate(
|
pub fn quick_negotiate(local: &[PqAlgorithm], remote: &[PqAlgorithm]) -> Option<PqAlgorithm> {
|
||||||
local: &[PqAlgorithm],
|
|
||||||
remote: &[PqAlgorithm],
|
|
||||||
) -> Option<PqAlgorithm> {
|
|
||||||
// Find common algorithms and return the one with highest default priority
|
// Find common algorithms and return the one with highest default priority
|
||||||
let local_set: HashSet<_> = local.iter().collect();
|
let local_set: HashSet<_> = local.iter().collect();
|
||||||
let remote_set: HashSet<_> = remote.iter().collect();
|
let remote_set: HashSet<_> = remote.iter().collect();
|
||||||
|
|
@ -604,7 +601,10 @@ pub enum NegotiationMessage {
|
||||||
},
|
},
|
||||||
|
|
||||||
/// Acknowledge selection
|
/// Acknowledge selection
|
||||||
Acknowledgment { session_id: [u8; 32], accepted: bool },
|
Acknowledgment {
|
||||||
|
session_id: [u8; 32],
|
||||||
|
accepted: bool,
|
||||||
|
},
|
||||||
|
|
||||||
/// Request renegotiation
|
/// Request renegotiation
|
||||||
Renegotiate { reason: String },
|
Renegotiate { reason: String },
|
||||||
|
|
@ -691,8 +691,10 @@ mod tests {
|
||||||
let result = negotiator.negotiate(&remote_caps).unwrap();
|
let result = negotiator.negotiate(&remote_caps).unwrap();
|
||||||
|
|
||||||
// Should prefer FALCON for bandwidth-constrained scenarios
|
// Should prefer FALCON for bandwidth-constrained scenarios
|
||||||
assert!(result.algorithm == PqAlgorithm::Falcon512 ||
|
assert!(
|
||||||
result.algorithm == PqAlgorithm::Falcon1024);
|
result.algorithm == PqAlgorithm::Falcon512
|
||||||
|
|| result.algorithm == PqAlgorithm::Falcon1024
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
|
|
@ -60,8 +60,7 @@ use zeroize::{Zeroize, ZeroizeOnDrop};
|
||||||
/// All variants use SHAKE (SHA3-based) for hashing.
|
/// All variants use SHAKE (SHA3-based) for hashing.
|
||||||
/// 's' variants have smaller signatures but are slower.
|
/// 's' variants have smaller signatures but are slower.
|
||||||
/// 'f' variants are faster but have larger signatures.
|
/// 'f' variants are faster but have larger signatures.
|
||||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
|
||||||
#[derive(Default)]
|
|
||||||
pub enum SphincsVariant {
|
pub enum SphincsVariant {
|
||||||
/// 128-bit security, small signatures (~7.8KB)
|
/// 128-bit security, small signatures (~7.8KB)
|
||||||
#[default]
|
#[default]
|
||||||
|
|
@ -119,7 +118,6 @@ impl SphincsVariant {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// SPHINCS+ public key.
|
/// SPHINCS+ public key.
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct SphincsPublicKey {
|
pub struct SphincsPublicKey {
|
||||||
|
|
@ -191,7 +189,10 @@ impl std::fmt::Debug for SphincsPublicKey {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
f.debug_struct("SphincsPublicKey")
|
f.debug_struct("SphincsPublicKey")
|
||||||
.field("variant", &self.variant)
|
.field("variant", &self.variant)
|
||||||
.field("bytes", &hex::encode(&self.bytes[..8.min(self.bytes.len())]))
|
.field(
|
||||||
|
"bytes",
|
||||||
|
&hex::encode(&self.bytes[..8.min(self.bytes.len())]),
|
||||||
|
)
|
||||||
.finish()
|
.finish()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -500,7 +501,10 @@ mod tests {
|
||||||
|
|
||||||
// Verify with wrong message should fail
|
// Verify with wrong message should fail
|
||||||
let wrong_message = b"Wrong message";
|
let wrong_message = b"Wrong message";
|
||||||
assert!(keypair.public_key().verify(wrong_message, &signature).is_err());
|
assert!(keypair
|
||||||
|
.public_key()
|
||||||
|
.verify(wrong_message, &signature)
|
||||||
|
.is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
|
|
@ -155,10 +155,7 @@ pub struct DagKnightManager {
|
||||||
|
|
||||||
impl DagKnightManager {
|
impl DagKnightManager {
|
||||||
/// Creates a new DAGKnight manager with standard 10 BPS configuration.
|
/// Creates a new DAGKnight manager with standard 10 BPS configuration.
|
||||||
pub fn new(
|
pub fn new(dag: Arc<BlockDag>, reachability: Arc<ReachabilityStore>) -> Self {
|
||||||
dag: Arc<BlockDag>,
|
|
||||||
reachability: Arc<ReachabilityStore>,
|
|
||||||
) -> Self {
|
|
||||||
Self::with_config(dag, reachability, BlockRateConfig::Standard)
|
Self::with_config(dag, reachability, BlockRateConfig::Standard)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -269,7 +266,8 @@ impl DagKnightManager {
|
||||||
let anticone_size = self.calculate_anticone_size(&block_id, parents);
|
let anticone_size = self.calculate_anticone_size(&block_id, parents);
|
||||||
|
|
||||||
// Record observation in latency tracker
|
// Record observation in latency tracker
|
||||||
self.latency_tracker.record_block(block_id, block_time_ms, anticone_size);
|
self.latency_tracker
|
||||||
|
.record_block(block_id, block_time_ms, anticone_size);
|
||||||
|
|
||||||
// Process with underlying GHOSTDAG
|
// Process with underlying GHOSTDAG
|
||||||
let data = self.ghostdag.add_block(block_id, parents)?;
|
let data = self.ghostdag.add_block(block_id, parents)?;
|
||||||
|
|
@ -292,11 +290,9 @@ impl DagKnightManager {
|
||||||
for tip in tips {
|
for tip in tips {
|
||||||
if tip != *block_id && !parents.contains(&tip) {
|
if tip != *block_id && !parents.contains(&tip) {
|
||||||
// Check if tip is in the past of any parent
|
// Check if tip is in the past of any parent
|
||||||
let in_past = parents.iter().any(|p| {
|
let in_past = parents
|
||||||
self.reachability
|
.iter()
|
||||||
.is_ancestor(p, &tip)
|
.any(|p| self.reachability.is_ancestor(p, &tip).unwrap_or(false));
|
||||||
.unwrap_or(false)
|
|
||||||
});
|
|
||||||
|
|
||||||
if !in_past {
|
if !in_past {
|
||||||
anticone_count += 1;
|
anticone_count += 1;
|
||||||
|
|
@ -375,7 +371,8 @@ impl DagKnightManager {
|
||||||
let sigma_multiplier = confidence.sigma_multiplier();
|
let sigma_multiplier = confidence.sigma_multiplier();
|
||||||
|
|
||||||
// Required depth scales with variance and confidence level
|
// Required depth scales with variance and confidence level
|
||||||
let required_depth = (self.block_rate_bps * (mean_delay + sigma * sigma_multiplier)).ceil() as u64;
|
let required_depth =
|
||||||
|
(self.block_rate_bps * (mean_delay + sigma * sigma_multiplier)).ceil() as u64;
|
||||||
|
|
||||||
// Current confidence based on actual depth
|
// Current confidence based on actual depth
|
||||||
let current_confidence = if depth >= required_depth {
|
let current_confidence = if depth >= required_depth {
|
||||||
|
|
@ -388,7 +385,8 @@ impl DagKnightManager {
|
||||||
// Time to reach required depth
|
// Time to reach required depth
|
||||||
let blocks_needed = required_depth.saturating_sub(depth);
|
let blocks_needed = required_depth.saturating_sub(depth);
|
||||||
let time_per_block_ms = 1000.0 / self.block_rate_bps;
|
let time_per_block_ms = 1000.0 / self.block_rate_bps;
|
||||||
let estimated_time = Duration::from_millis((blocks_needed as f64 * time_per_block_ms) as u64);
|
let estimated_time =
|
||||||
|
Duration::from_millis((blocks_needed as f64 * time_per_block_ms) as u64);
|
||||||
|
|
||||||
// Block is final if depth exceeds finality threshold for this block rate
|
// Block is final if depth exceeds finality threshold for this block rate
|
||||||
let is_final = depth >= self.finality_depth();
|
let is_final = depth >= self.finality_depth();
|
||||||
|
|
@ -506,7 +504,10 @@ impl std::fmt::Debug for DagKnightManager {
|
||||||
.field("block_rate_config", &self.block_rate_config)
|
.field("block_rate_config", &self.block_rate_config)
|
||||||
.field("block_rate_bps", &self.block_rate_bps)
|
.field("block_rate_bps", &self.block_rate_bps)
|
||||||
.field("adaptive_k", &*self.adaptive_k.read())
|
.field("adaptive_k", &*self.adaptive_k.read())
|
||||||
.field("k_bounds", &format!("{}-{}", self.k_bounds.min_k, self.k_bounds.max_k))
|
.field(
|
||||||
|
"k_bounds",
|
||||||
|
&format!("{}-{}", self.k_bounds.min_k, self.k_bounds.max_k),
|
||||||
|
)
|
||||||
.field("mean_delay_ms", &stats.mean_delay_ms)
|
.field("mean_delay_ms", &stats.mean_delay_ms)
|
||||||
.field("sample_count", &stats.sample_count)
|
.field("sample_count", &stats.sample_count)
|
||||||
.finish()
|
.finish()
|
||||||
|
|
@ -534,10 +535,7 @@ pub fn calculate_optimal_k(network_delay_ms: f64, block_rate_bps: f64) -> u8 {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Calculates the optimal k for a specific block rate configuration.
|
/// Calculates the optimal k for a specific block rate configuration.
|
||||||
pub fn calculate_optimal_k_for_config(
|
pub fn calculate_optimal_k_for_config(network_delay_ms: f64, config: BlockRateConfig) -> u8 {
|
||||||
network_delay_ms: f64,
|
|
||||||
config: BlockRateConfig,
|
|
||||||
) -> u8 {
|
|
||||||
let bounds = AdaptiveKBounds::for_block_rate(config);
|
let bounds = AdaptiveKBounds::for_block_rate(config);
|
||||||
let delay_secs = network_delay_ms / 1000.0;
|
let delay_secs = network_delay_ms / 1000.0;
|
||||||
let k = (config.bps() * delay_secs * SAFETY_MARGIN).ceil() as u16;
|
let k = (config.bps() * delay_secs * SAFETY_MARGIN).ceil() as u16;
|
||||||
|
|
@ -578,7 +576,9 @@ mod tests {
|
||||||
(dag, reachability, dagknight)
|
(dag, reachability, dagknight)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn setup_test_dag_with_config(config: BlockRateConfig) -> (Arc<BlockDag>, Arc<ReachabilityStore>, DagKnightManager) {
|
fn setup_test_dag_with_config(
|
||||||
|
config: BlockRateConfig,
|
||||||
|
) -> (Arc<BlockDag>, Arc<ReachabilityStore>, DagKnightManager) {
|
||||||
let genesis = make_block_id(0);
|
let genesis = make_block_id(0);
|
||||||
let dag = Arc::new(BlockDag::new(genesis, 0));
|
let dag = Arc::new(BlockDag::new(genesis, 0));
|
||||||
let reachability = Arc::new(ReachabilityStore::new(genesis));
|
let reachability = Arc::new(ReachabilityStore::new(genesis));
|
||||||
|
|
@ -671,14 +671,19 @@ mod tests {
|
||||||
let tps_poor = estimate_throughput(10.0, 100, 40.0);
|
let tps_poor = estimate_throughput(10.0, 100, 40.0);
|
||||||
|
|
||||||
// Good network should have higher throughput
|
// Good network should have higher throughput
|
||||||
assert!(tps_good > tps_poor, "tps_good={} should be > tps_poor={}", tps_good, tps_poor);
|
assert!(
|
||||||
|
tps_good > tps_poor,
|
||||||
|
"tps_good={} should be > tps_poor={}",
|
||||||
|
tps_good,
|
||||||
|
tps_poor
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_throughput_by_config() {
|
fn test_throughput_by_config() {
|
||||||
// At same network conditions, higher BPS = higher theoretical TPS
|
// At same network conditions, higher BPS = higher theoretical TPS
|
||||||
let tps_10 = estimate_throughput(10.0, 100, 20.0); // 10 BPS
|
let tps_10 = estimate_throughput(10.0, 100, 20.0); // 10 BPS
|
||||||
let tps_32 = estimate_throughput(32.0, 100, 20.0); // 32 BPS
|
let tps_32 = estimate_throughput(32.0, 100, 20.0); // 32 BPS
|
||||||
let tps_100 = estimate_throughput(100.0, 100, 20.0); // 100 BPS
|
let tps_100 = estimate_throughput(100.0, 100, 20.0); // 100 BPS
|
||||||
|
|
||||||
// Higher block rates give higher TPS (with network overhead)
|
// Higher block rates give higher TPS (with network overhead)
|
||||||
|
|
@ -698,19 +703,37 @@ mod tests {
|
||||||
let maximum_time_hrs = maximum.finality_depth() as f64 / 100.0 / 3600.0;
|
let maximum_time_hrs = maximum.finality_depth() as f64 / 100.0 / 3600.0;
|
||||||
|
|
||||||
// Should all be approximately 2.4 hours (allow some variance)
|
// Should all be approximately 2.4 hours (allow some variance)
|
||||||
assert!((standard_time_hrs - 2.4).abs() < 0.1, "standard: {}", standard_time_hrs);
|
assert!(
|
||||||
assert!((enhanced_time_hrs - 2.4).abs() < 0.1, "enhanced: {}", enhanced_time_hrs);
|
(standard_time_hrs - 2.4).abs() < 0.1,
|
||||||
assert!((maximum_time_hrs - 2.4).abs() < 0.1, "maximum: {}", maximum_time_hrs);
|
"standard: {}",
|
||||||
|
standard_time_hrs
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
(enhanced_time_hrs - 2.4).abs() < 0.1,
|
||||||
|
"enhanced: {}",
|
||||||
|
enhanced_time_hrs
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
(maximum_time_hrs - 2.4).abs() < 0.1,
|
||||||
|
"maximum: {}",
|
||||||
|
maximum_time_hrs
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_confidence_levels() {
|
fn test_confidence_levels() {
|
||||||
assert!(ConfirmationConfidence::VeryHigh.sigma_multiplier()
|
assert!(
|
||||||
> ConfirmationConfidence::High.sigma_multiplier());
|
ConfirmationConfidence::VeryHigh.sigma_multiplier()
|
||||||
assert!(ConfirmationConfidence::High.sigma_multiplier()
|
> ConfirmationConfidence::High.sigma_multiplier()
|
||||||
> ConfirmationConfidence::Medium.sigma_multiplier());
|
);
|
||||||
assert!(ConfirmationConfidence::Medium.sigma_multiplier()
|
assert!(
|
||||||
> ConfirmationConfidence::Low.sigma_multiplier());
|
ConfirmationConfidence::High.sigma_multiplier()
|
||||||
|
> ConfirmationConfidence::Medium.sigma_multiplier()
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
ConfirmationConfidence::Medium.sigma_multiplier()
|
||||||
|
> ConfirmationConfidence::Low.sigma_multiplier()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
|
|
@ -98,12 +98,7 @@ impl LatencyTracker {
|
||||||
/// * `block_id` - Hash of the observed block
|
/// * `block_id` - Hash of the observed block
|
||||||
/// * `block_time_ms` - Timestamp from block header (Unix ms)
|
/// * `block_time_ms` - Timestamp from block header (Unix ms)
|
||||||
/// * `anticone_size` - Number of blocks in the anticone at observation time
|
/// * `anticone_size` - Number of blocks in the anticone at observation time
|
||||||
pub fn record_block(
|
pub fn record_block(&self, block_id: BlockId, block_time_ms: u64, anticone_size: usize) {
|
||||||
&self,
|
|
||||||
block_id: BlockId,
|
|
||||||
block_time_ms: u64,
|
|
||||||
anticone_size: usize,
|
|
||||||
) {
|
|
||||||
let local_time = Instant::now();
|
let local_time = Instant::now();
|
||||||
let now_ms = std::time::SystemTime::now()
|
let now_ms = std::time::SystemTime::now()
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
|
@ -208,7 +203,10 @@ impl LatencyTracker {
|
||||||
let anticone_growth_rate = if n > 1 {
|
let anticone_growth_rate = if n > 1 {
|
||||||
let first = samples.front().unwrap();
|
let first = samples.front().unwrap();
|
||||||
let last = samples.back().unwrap();
|
let last = samples.back().unwrap();
|
||||||
let time_span_secs = last.local_time.duration_since(first.local_time).as_secs_f64();
|
let time_span_secs = last
|
||||||
|
.local_time
|
||||||
|
.duration_since(first.local_time)
|
||||||
|
.as_secs_f64();
|
||||||
|
|
||||||
if time_span_secs > 0.0 {
|
if time_span_secs > 0.0 {
|
||||||
let total_anticone_growth: usize = samples.iter().map(|s| s.anticone_size).sum();
|
let total_anticone_growth: usize = samples.iter().map(|s| s.anticone_size).sum();
|
||||||
|
|
|
||||||
|
|
@ -32,8 +32,8 @@ pub mod reachability;
|
||||||
|
|
||||||
pub use dag::{BlockDag, BlockRelations, DagError};
|
pub use dag::{BlockDag, BlockRelations, DagError};
|
||||||
pub use dagknight::{
|
pub use dagknight::{
|
||||||
calculate_optimal_k, calculate_optimal_k_for_config, estimate_throughput,
|
calculate_optimal_k, calculate_optimal_k_for_config, estimate_throughput, AdaptiveKBounds,
|
||||||
AdaptiveKBounds, ConfirmationConfidence, ConfirmationStatus, DagKnightManager,
|
ConfirmationConfidence, ConfirmationStatus, DagKnightManager,
|
||||||
};
|
};
|
||||||
pub use ghostdag::{GhostdagData, GhostdagError, GhostdagManager};
|
pub use ghostdag::{GhostdagData, GhostdagError, GhostdagManager};
|
||||||
pub use latency::{LatencySample, LatencyStats, LatencyTracker};
|
pub use latency::{LatencySample, LatencyStats, LatencyTracker};
|
||||||
|
|
@ -116,27 +116,27 @@ impl BlockRateConfig {
|
||||||
/// Returns the merge depth adjusted for block rate.
|
/// Returns the merge depth adjusted for block rate.
|
||||||
pub const fn merge_depth(&self) -> u64 {
|
pub const fn merge_depth(&self) -> u64 {
|
||||||
match self {
|
match self {
|
||||||
BlockRateConfig::Standard => 3600, // ~6 min at 10 bps
|
BlockRateConfig::Standard => 3600, // ~6 min at 10 bps
|
||||||
BlockRateConfig::Enhanced => 11520, // ~6 min at 32 bps
|
BlockRateConfig::Enhanced => 11520, // ~6 min at 32 bps
|
||||||
BlockRateConfig::Maximum => 36000, // ~6 min at 100 bps
|
BlockRateConfig::Maximum => 36000, // ~6 min at 100 bps
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the finality depth adjusted for block rate.
|
/// Returns the finality depth adjusted for block rate.
|
||||||
pub const fn finality_depth(&self) -> u64 {
|
pub const fn finality_depth(&self) -> u64 {
|
||||||
match self {
|
match self {
|
||||||
BlockRateConfig::Standard => 86400, // ~2.4 hours at 10 bps
|
BlockRateConfig::Standard => 86400, // ~2.4 hours at 10 bps
|
||||||
BlockRateConfig::Enhanced => 276480, // ~2.4 hours at 32 bps
|
BlockRateConfig::Enhanced => 276480, // ~2.4 hours at 32 bps
|
||||||
BlockRateConfig::Maximum => 864000, // ~2.4 hours at 100 bps
|
BlockRateConfig::Maximum => 864000, // ~2.4 hours at 100 bps
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the pruning depth adjusted for block rate.
|
/// Returns the pruning depth adjusted for block rate.
|
||||||
pub const fn pruning_depth(&self) -> u64 {
|
pub const fn pruning_depth(&self) -> u64 {
|
||||||
match self {
|
match self {
|
||||||
BlockRateConfig::Standard => 288_000, // ~8 hours at 10 bps
|
BlockRateConfig::Standard => 288_000, // ~8 hours at 10 bps
|
||||||
BlockRateConfig::Enhanced => 921_600, // ~8 hours at 32 bps
|
BlockRateConfig::Enhanced => 921_600, // ~8 hours at 32 bps
|
||||||
BlockRateConfig::Maximum => 2_880_000, // ~8 hours at 100 bps
|
BlockRateConfig::Maximum => 2_880_000, // ~8 hours at 100 bps
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,9 @@ impl DocumentId {
|
||||||
let bytes = hex::decode(s)
|
let bytes = hex::decode(s)
|
||||||
.map_err(|_| DatabaseError::InvalidOperation("Invalid hex string".into()))?;
|
.map_err(|_| DatabaseError::InvalidOperation("Invalid hex string".into()))?;
|
||||||
if bytes.len() != 32 {
|
if bytes.len() != 32 {
|
||||||
return Err(DatabaseError::InvalidOperation("Invalid document ID length".into()));
|
return Err(DatabaseError::InvalidOperation(
|
||||||
|
"Invalid document ID length".into(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
let mut arr = [0u8; 32];
|
let mut arr = [0u8; 32];
|
||||||
arr.copy_from_slice(&bytes);
|
arr.copy_from_slice(&bytes);
|
||||||
|
|
@ -249,7 +251,11 @@ impl Collection {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Updates documents matching a filter.
|
/// Updates documents matching a filter.
|
||||||
pub fn update_many(&self, filter: &DocumentFilter, update: JsonValue) -> Result<u64, DatabaseError> {
|
pub fn update_many(
|
||||||
|
&self,
|
||||||
|
filter: &DocumentFilter,
|
||||||
|
update: JsonValue,
|
||||||
|
) -> Result<u64, DatabaseError> {
|
||||||
let mut docs = self.documents.write();
|
let mut docs = self.documents.write();
|
||||||
let mut count = 0;
|
let mut count = 0;
|
||||||
for doc in docs.values_mut() {
|
for doc in docs.values_mut() {
|
||||||
|
|
@ -321,60 +327,71 @@ enum FilterCondition {
|
||||||
impl DocumentFilter {
|
impl DocumentFilter {
|
||||||
/// Creates a new empty filter (matches all).
|
/// Creates a new empty filter (matches all).
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self { conditions: Vec::new() }
|
Self {
|
||||||
|
conditions: Vec::new(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Equality condition.
|
/// Equality condition.
|
||||||
pub fn eq(mut self, field: impl Into<String>, value: JsonValue) -> Self {
|
pub fn eq(mut self, field: impl Into<String>, value: JsonValue) -> Self {
|
||||||
self.conditions.push(FilterCondition::Eq(field.into(), value));
|
self.conditions
|
||||||
|
.push(FilterCondition::Eq(field.into(), value));
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Not equal condition.
|
/// Not equal condition.
|
||||||
pub fn ne(mut self, field: impl Into<String>, value: JsonValue) -> Self {
|
pub fn ne(mut self, field: impl Into<String>, value: JsonValue) -> Self {
|
||||||
self.conditions.push(FilterCondition::Ne(field.into(), value));
|
self.conditions
|
||||||
|
.push(FilterCondition::Ne(field.into(), value));
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Greater than.
|
/// Greater than.
|
||||||
pub fn gt(mut self, field: impl Into<String>, value: JsonValue) -> Self {
|
pub fn gt(mut self, field: impl Into<String>, value: JsonValue) -> Self {
|
||||||
self.conditions.push(FilterCondition::Gt(field.into(), value));
|
self.conditions
|
||||||
|
.push(FilterCondition::Gt(field.into(), value));
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Greater than or equal.
|
/// Greater than or equal.
|
||||||
pub fn gte(mut self, field: impl Into<String>, value: JsonValue) -> Self {
|
pub fn gte(mut self, field: impl Into<String>, value: JsonValue) -> Self {
|
||||||
self.conditions.push(FilterCondition::Gte(field.into(), value));
|
self.conditions
|
||||||
|
.push(FilterCondition::Gte(field.into(), value));
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Less than.
|
/// Less than.
|
||||||
pub fn lt(mut self, field: impl Into<String>, value: JsonValue) -> Self {
|
pub fn lt(mut self, field: impl Into<String>, value: JsonValue) -> Self {
|
||||||
self.conditions.push(FilterCondition::Lt(field.into(), value));
|
self.conditions
|
||||||
|
.push(FilterCondition::Lt(field.into(), value));
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Less than or equal.
|
/// Less than or equal.
|
||||||
pub fn lte(mut self, field: impl Into<String>, value: JsonValue) -> Self {
|
pub fn lte(mut self, field: impl Into<String>, value: JsonValue) -> Self {
|
||||||
self.conditions.push(FilterCondition::Lte(field.into(), value));
|
self.conditions
|
||||||
|
.push(FilterCondition::Lte(field.into(), value));
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// In array.
|
/// In array.
|
||||||
pub fn in_array(mut self, field: impl Into<String>, values: Vec<JsonValue>) -> Self {
|
pub fn in_array(mut self, field: impl Into<String>, values: Vec<JsonValue>) -> Self {
|
||||||
self.conditions.push(FilterCondition::In(field.into(), values));
|
self.conditions
|
||||||
|
.push(FilterCondition::In(field.into(), values));
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// String contains.
|
/// String contains.
|
||||||
pub fn contains(mut self, field: impl Into<String>, substring: impl Into<String>) -> Self {
|
pub fn contains(mut self, field: impl Into<String>, substring: impl Into<String>) -> Self {
|
||||||
self.conditions.push(FilterCondition::Contains(field.into(), substring.into()));
|
self.conditions
|
||||||
|
.push(FilterCondition::Contains(field.into(), substring.into()));
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Field exists.
|
/// Field exists.
|
||||||
pub fn exists(mut self, field: impl Into<String>, exists: bool) -> Self {
|
pub fn exists(mut self, field: impl Into<String>, exists: bool) -> Self {
|
||||||
self.conditions.push(FilterCondition::Exists(field.into(), exists));
|
self.conditions
|
||||||
|
.push(FilterCondition::Exists(field.into(), exists));
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -396,7 +413,9 @@ impl DocumentFilter {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
self.conditions.iter().all(|cond| self.eval_condition(cond, doc))
|
self.conditions
|
||||||
|
.iter()
|
||||||
|
.all(|cond| self.eval_condition(cond, doc))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn eval_condition(&self, cond: &FilterCondition, doc: &Document) -> bool {
|
fn eval_condition(&self, cond: &FilterCondition, doc: &Document) -> bool {
|
||||||
|
|
@ -419,27 +438,21 @@ impl DocumentFilter {
|
||||||
FilterCondition::Lte(field, value) => {
|
FilterCondition::Lte(field, value) => {
|
||||||
self.compare_values(doc.get_nested(field), value, |a, b| a <= b)
|
self.compare_values(doc.get_nested(field), value, |a, b| a <= b)
|
||||||
}
|
}
|
||||||
FilterCondition::In(field, values) => {
|
FilterCondition::In(field, values) => doc
|
||||||
doc.get_nested(field)
|
.get_nested(field)
|
||||||
.map(|v| values.contains(v))
|
.map(|v| values.contains(v))
|
||||||
.unwrap_or(false)
|
.unwrap_or(false),
|
||||||
}
|
FilterCondition::Contains(field, substring) => doc
|
||||||
FilterCondition::Contains(field, substring) => {
|
.get_nested(field)
|
||||||
doc.get_nested(field)
|
.and_then(|v| v.as_str())
|
||||||
.and_then(|v| v.as_str())
|
.map(|s| s.contains(substring))
|
||||||
.map(|s| s.contains(substring))
|
.unwrap_or(false),
|
||||||
.unwrap_or(false)
|
|
||||||
}
|
|
||||||
FilterCondition::Exists(field, should_exist) => {
|
FilterCondition::Exists(field, should_exist) => {
|
||||||
let exists = doc.get_nested(field).is_some();
|
let exists = doc.get_nested(field).is_some();
|
||||||
exists == *should_exist
|
exists == *should_exist
|
||||||
}
|
}
|
||||||
FilterCondition::And(filters) => {
|
FilterCondition::And(filters) => filters.iter().all(|f| f.matches(doc)),
|
||||||
filters.iter().all(|f| f.matches(doc))
|
FilterCondition::Or(filters) => filters.iter().any(|f| f.matches(doc)),
|
||||||
}
|
|
||||||
FilterCondition::Or(filters) => {
|
|
||||||
filters.iter().any(|f| f.matches(doc))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -448,12 +461,10 @@ impl DocumentFilter {
|
||||||
F: Fn(f64, f64) -> bool,
|
F: Fn(f64, f64) -> bool,
|
||||||
{
|
{
|
||||||
match (a, b) {
|
match (a, b) {
|
||||||
(Some(JsonValue::Number(a)), JsonValue::Number(b)) => {
|
(Some(JsonValue::Number(a)), JsonValue::Number(b)) => match (a.as_f64(), b.as_f64()) {
|
||||||
match (a.as_f64(), b.as_f64()) {
|
(Some(a), Some(b)) => cmp(a, b),
|
||||||
(Some(a), Some(b)) => cmp(a, b),
|
_ => false,
|
||||||
_ => false,
|
},
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => false,
|
_ => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -512,7 +523,11 @@ impl DocumentStore {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Finds documents in a collection.
|
/// Finds documents in a collection.
|
||||||
pub fn find(&self, collection: &str, filter: &DocumentFilter) -> Result<Vec<Document>, DatabaseError> {
|
pub fn find(
|
||||||
|
&self,
|
||||||
|
collection: &str,
|
||||||
|
filter: &DocumentFilter,
|
||||||
|
) -> Result<Vec<Document>, DatabaseError> {
|
||||||
let collections = self.collections.read();
|
let collections = self.collections.read();
|
||||||
let coll = collections
|
let coll = collections
|
||||||
.get(collection)
|
.get(collection)
|
||||||
|
|
@ -521,7 +536,11 @@ impl DocumentStore {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Finds one document.
|
/// Finds one document.
|
||||||
pub fn find_one(&self, collection: &str, filter: &DocumentFilter) -> Result<Option<Document>, DatabaseError> {
|
pub fn find_one(
|
||||||
|
&self,
|
||||||
|
collection: &str,
|
||||||
|
filter: &DocumentFilter,
|
||||||
|
) -> Result<Option<Document>, DatabaseError> {
|
||||||
let collections = self.collections.read();
|
let collections = self.collections.read();
|
||||||
let coll = collections
|
let coll = collections
|
||||||
.get(collection)
|
.get(collection)
|
||||||
|
|
@ -530,7 +549,11 @@ impl DocumentStore {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Finds a document by ID.
|
/// Finds a document by ID.
|
||||||
pub fn find_by_id(&self, collection: &str, id: &DocumentId) -> Result<Option<Document>, DatabaseError> {
|
pub fn find_by_id(
|
||||||
|
&self,
|
||||||
|
collection: &str,
|
||||||
|
id: &DocumentId,
|
||||||
|
) -> Result<Option<Document>, DatabaseError> {
|
||||||
let collections = self.collections.read();
|
let collections = self.collections.read();
|
||||||
let coll = collections
|
let coll = collections
|
||||||
.get(collection)
|
.get(collection)
|
||||||
|
|
@ -539,7 +562,12 @@ impl DocumentStore {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Updates a document by ID.
|
/// Updates a document by ID.
|
||||||
pub fn update_by_id(&self, collection: &str, id: &DocumentId, update: JsonValue) -> Result<bool, DatabaseError> {
|
pub fn update_by_id(
|
||||||
|
&self,
|
||||||
|
collection: &str,
|
||||||
|
id: &DocumentId,
|
||||||
|
update: JsonValue,
|
||||||
|
) -> Result<bool, DatabaseError> {
|
||||||
let collections = self.collections.read();
|
let collections = self.collections.read();
|
||||||
let coll = collections
|
let coll = collections
|
||||||
.get(collection)
|
.get(collection)
|
||||||
|
|
@ -584,7 +612,8 @@ mod tests {
|
||||||
fn test_collection_insert_find() {
|
fn test_collection_insert_find() {
|
||||||
let coll = Collection::new("users");
|
let coll = Collection::new("users");
|
||||||
|
|
||||||
coll.insert_one(json!({"name": "Alice", "age": 30})).unwrap();
|
coll.insert_one(json!({"name": "Alice", "age": 30}))
|
||||||
|
.unwrap();
|
||||||
coll.insert_one(json!({"name": "Bob", "age": 25})).unwrap();
|
coll.insert_one(json!({"name": "Bob", "age": 25})).unwrap();
|
||||||
|
|
||||||
let filter = DocumentFilter::new().eq("name", json!("Alice"));
|
let filter = DocumentFilter::new().eq("name", json!("Alice"));
|
||||||
|
|
@ -597,9 +626,11 @@ mod tests {
|
||||||
fn test_filter_comparison() {
|
fn test_filter_comparison() {
|
||||||
let coll = Collection::new("users");
|
let coll = Collection::new("users");
|
||||||
|
|
||||||
coll.insert_one(json!({"name": "Alice", "age": 30})).unwrap();
|
coll.insert_one(json!({"name": "Alice", "age": 30}))
|
||||||
|
.unwrap();
|
||||||
coll.insert_one(json!({"name": "Bob", "age": 25})).unwrap();
|
coll.insert_one(json!({"name": "Bob", "age": 25})).unwrap();
|
||||||
coll.insert_one(json!({"name": "Charlie", "age": 35})).unwrap();
|
coll.insert_one(json!({"name": "Charlie", "age": 35}))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let filter = DocumentFilter::new().gte("age", json!(30));
|
let filter = DocumentFilter::new().gte("age", json!(30));
|
||||||
let results = coll.find(&filter);
|
let results = coll.find(&filter);
|
||||||
|
|
@ -622,7 +653,9 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn test_update_document() {
|
fn test_update_document() {
|
||||||
let coll = Collection::new("users");
|
let coll = Collection::new("users");
|
||||||
let id = coll.insert_one(json!({"name": "Alice", "age": 30})).unwrap();
|
let id = coll
|
||||||
|
.insert_one(json!({"name": "Alice", "age": 30}))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
coll.update_by_id(&id, json!({"age": 31})).unwrap();
|
coll.update_by_id(&id, json!({"age": 31})).unwrap();
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
//! Authentication and authorization for Database Gateway.
|
//! Authentication and authorization for Database Gateway.
|
||||||
|
|
||||||
|
use parking_lot::RwLock;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use parking_lot::RwLock;
|
|
||||||
|
|
||||||
/// API key for authentication.
|
/// API key for authentication.
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
|
|
||||||
|
|
@ -272,19 +272,13 @@ pub fn json_to_filter(json: &JsonValue) -> Option<Filter> {
|
||||||
|
|
||||||
// Handle $and
|
// Handle $and
|
||||||
if let Some(and_arr) = obj.get("$and").and_then(|v| v.as_array()) {
|
if let Some(and_arr) = obj.get("$and").and_then(|v| v.as_array()) {
|
||||||
let filters: Vec<Filter> = and_arr
|
let filters: Vec<Filter> = and_arr.iter().filter_map(json_to_filter).collect();
|
||||||
.iter()
|
|
||||||
.filter_map(json_to_filter)
|
|
||||||
.collect();
|
|
||||||
return Some(Filter::And(filters));
|
return Some(Filter::And(filters));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle $or
|
// Handle $or
|
||||||
if let Some(or_arr) = obj.get("$or").and_then(|v| v.as_array()) {
|
if let Some(or_arr) = obj.get("$or").and_then(|v| v.as_array()) {
|
||||||
let filters: Vec<Filter> = or_arr
|
let filters: Vec<Filter> = or_arr.iter().filter_map(json_to_filter).collect();
|
||||||
.iter()
|
|
||||||
.filter_map(json_to_filter)
|
|
||||||
.collect();
|
|
||||||
return Some(Filter::Or(filters));
|
return Some(Filter::Or(filters));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -137,7 +137,12 @@ async fn kv_get(
|
||||||
// For demo, use a default database
|
// For demo, use a default database
|
||||||
let db = match get_default_database(&state) {
|
let db = match get_default_database(&state) {
|
||||||
Some(db) => db,
|
Some(db) => db,
|
||||||
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<KvGetResponse>::error("No database"))),
|
None => {
|
||||||
|
return (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(ApiResponse::<KvGetResponse>::error("No database")),
|
||||||
|
)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
state.record_read();
|
state.record_read();
|
||||||
|
|
@ -153,7 +158,12 @@ async fn kv_set(
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
let db = match get_default_database(&state) {
|
let db = match get_default_database(&state) {
|
||||||
Some(db) => db,
|
Some(db) => db,
|
||||||
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<KvGetResponse>::error("No database"))),
|
None => {
|
||||||
|
return (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(ApiResponse::<KvGetResponse>::error("No database")),
|
||||||
|
)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
state.record_write(req.value.len() as u64);
|
state.record_write(req.value.len() as u64);
|
||||||
|
|
@ -168,7 +178,12 @@ async fn kv_delete(
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
let db = match get_default_database(&state) {
|
let db = match get_default_database(&state) {
|
||||||
Some(db) => db,
|
Some(db) => db,
|
||||||
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<bool>::error("No database"))),
|
None => {
|
||||||
|
return (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(ApiResponse::<bool>::error("No database")),
|
||||||
|
)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let response = handle_kv_delete(db.kv(), &key);
|
let response = handle_kv_delete(db.kv(), &key);
|
||||||
|
|
@ -182,7 +197,12 @@ async fn kv_batch(
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
let db = match get_default_database(&state) {
|
let db = match get_default_database(&state) {
|
||||||
Some(db) => db,
|
Some(db) => db,
|
||||||
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<KvBatchResponse>::error("No database"))),
|
None => {
|
||||||
|
return (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(ApiResponse::<KvBatchResponse>::error("No database")),
|
||||||
|
)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let response = handle_kv_batch(db.kv(), req);
|
let response = handle_kv_batch(db.kv(), req);
|
||||||
|
|
@ -217,7 +237,9 @@ async fn list_databases(
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
Json(ApiResponse::ok(ListDatabasesResponse { databases: response }))
|
Json(ApiResponse::ok(ListDatabasesResponse {
|
||||||
|
databases: response,
|
||||||
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn create_database(
|
async fn create_database(
|
||||||
|
|
@ -250,7 +272,10 @@ async fn create_database(
|
||||||
})),
|
})),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
Err(e) => (StatusCode::BAD_REQUEST, Json(ApiResponse::error(e.to_string()))),
|
Err(e) => (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(ApiResponse::error(e.to_string())),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -302,12 +327,20 @@ async fn create_collection(
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
let db = match get_database(&state, &db_name) {
|
let db = match get_database(&state, &db_name) {
|
||||||
Some(db) => db,
|
Some(db) => db,
|
||||||
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<bool>::error("Database not found"))),
|
None => {
|
||||||
|
return (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(ApiResponse::<bool>::error("Database not found")),
|
||||||
|
)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
match db.documents().create_collection(&req.name) {
|
match db.documents().create_collection(&req.name) {
|
||||||
Ok(_) => (StatusCode::CREATED, Json(ApiResponse::ok(true))),
|
Ok(_) => (StatusCode::CREATED, Json(ApiResponse::ok(true))),
|
||||||
Err(e) => (StatusCode::BAD_REQUEST, Json(ApiResponse::<bool>::error(e.to_string()))),
|
Err(e) => (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(ApiResponse::<bool>::error(e.to_string())),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -349,15 +382,20 @@ async fn query_documents(
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
let db = match get_database(&state, &db_name) {
|
let db = match get_database(&state, &db_name) {
|
||||||
Some(db) => db,
|
Some(db) => db,
|
||||||
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<QueryDocumentsResponse>::error("Database not found"))),
|
None => {
|
||||||
|
return (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(ApiResponse::<QueryDocumentsResponse>::error(
|
||||||
|
"Database not found",
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
state.record_read();
|
state.record_read();
|
||||||
|
|
||||||
// Build query
|
// Build query
|
||||||
let mut query = Query::new(&coll_name)
|
let mut query = Query::new(&coll_name).skip(req.skip).limit(req.limit);
|
||||||
.skip(req.skip)
|
|
||||||
.limit(req.limit);
|
|
||||||
|
|
||||||
// Add filter
|
// Add filter
|
||||||
if let Some(filter_json) = &req.filter {
|
if let Some(filter_json) = &req.filter {
|
||||||
|
|
@ -384,8 +422,14 @@ async fn query_documents(
|
||||||
}
|
}
|
||||||
|
|
||||||
match db.query().execute(&query) {
|
match db.query().execute(&query) {
|
||||||
Ok(result) => (StatusCode::OK, Json(ApiResponse::ok(QueryDocumentsResponse::from(result)))),
|
Ok(result) => (
|
||||||
Err(e) => (StatusCode::BAD_REQUEST, Json(ApiResponse::error(e.to_string()))),
|
StatusCode::OK,
|
||||||
|
Json(ApiResponse::ok(QueryDocumentsResponse::from(result))),
|
||||||
|
),
|
||||||
|
Err(e) => (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(ApiResponse::error(e.to_string())),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -396,15 +440,25 @@ async fn insert_document(
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
let db = match get_database(&state, &db_name) {
|
let db = match get_database(&state, &db_name) {
|
||||||
Some(db) => db,
|
Some(db) => db,
|
||||||
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<String>::error("Database not found"))),
|
None => {
|
||||||
|
return (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(ApiResponse::<String>::error("Database not found")),
|
||||||
|
)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let size = serde_json::to_vec(&req.document).map(|v| v.len()).unwrap_or(0);
|
let size = serde_json::to_vec(&req.document)
|
||||||
|
.map(|v| v.len())
|
||||||
|
.unwrap_or(0);
|
||||||
state.record_write(size as u64);
|
state.record_write(size as u64);
|
||||||
|
|
||||||
match db.documents().insert(&coll_name, req.document) {
|
match db.documents().insert(&coll_name, req.document) {
|
||||||
Ok(id) => (StatusCode::CREATED, Json(ApiResponse::ok(id.to_hex()))),
|
Ok(id) => (StatusCode::CREATED, Json(ApiResponse::ok(id.to_hex()))),
|
||||||
Err(e) => (StatusCode::BAD_REQUEST, Json(ApiResponse::error(e.to_string()))),
|
Err(e) => (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(ApiResponse::error(e.to_string())),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -415,7 +469,12 @@ async fn insert_many_documents(
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
let db = match get_database(&state, &db_name) {
|
let db = match get_database(&state, &db_name) {
|
||||||
Some(db) => db,
|
Some(db) => db,
|
||||||
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<Vec<String>>::error("Database not found"))),
|
None => {
|
||||||
|
return (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(ApiResponse::<Vec<String>>::error("Database not found")),
|
||||||
|
)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut ids = Vec::with_capacity(req.documents.len());
|
let mut ids = Vec::with_capacity(req.documents.len());
|
||||||
|
|
@ -425,7 +484,12 @@ async fn insert_many_documents(
|
||||||
total_size += serde_json::to_vec(&doc).map(|v| v.len()).unwrap_or(0) as u64;
|
total_size += serde_json::to_vec(&doc).map(|v| v.len()).unwrap_or(0) as u64;
|
||||||
match db.documents().insert(&coll_name, doc) {
|
match db.documents().insert(&coll_name, doc) {
|
||||||
Ok(id) => ids.push(id.to_hex()),
|
Ok(id) => ids.push(id.to_hex()),
|
||||||
Err(e) => return (StatusCode::BAD_REQUEST, Json(ApiResponse::error(e.to_string()))),
|
Err(e) => {
|
||||||
|
return (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(ApiResponse::error(e.to_string())),
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -477,7 +541,9 @@ async fn update_document(
|
||||||
Err(e) => return Json(ApiResponse::error(e.to_string())),
|
Err(e) => return Json(ApiResponse::error(e.to_string())),
|
||||||
};
|
};
|
||||||
|
|
||||||
let update_size = serde_json::to_vec(&req.update).map(|v| v.len()).unwrap_or(0);
|
let update_size = serde_json::to_vec(&req.update)
|
||||||
|
.map(|v| v.len())
|
||||||
|
.unwrap_or(0);
|
||||||
state.record_write(update_size as u64);
|
state.record_write(update_size as u64);
|
||||||
|
|
||||||
match db.documents().update_by_id(&coll_name, &id, req.update) {
|
match db.documents().update_by_id(&coll_name, &id, req.update) {
|
||||||
|
|
@ -519,7 +585,12 @@ async fn insert_embeddings(
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
let db = match get_database(&state, &db_name) {
|
let db = match get_database(&state, &db_name) {
|
||||||
Some(db) => db,
|
Some(db) => db,
|
||||||
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<usize>::error("Database not found"))),
|
None => {
|
||||||
|
return (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(ApiResponse::<usize>::error("Database not found")),
|
||||||
|
)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut count = 0;
|
let mut count = 0;
|
||||||
|
|
@ -533,7 +604,10 @@ async fn insert_embeddings(
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Err(e) = db.vectors().insert(embedding) {
|
if let Err(e) = db.vectors().insert(embedding) {
|
||||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error(e.to_string())));
|
return (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(ApiResponse::error(e.to_string())),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
count += 1;
|
count += 1;
|
||||||
}
|
}
|
||||||
|
|
@ -549,7 +623,14 @@ async fn vector_search(
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
let db = match get_database(&state, &db_name) {
|
let db = match get_database(&state, &db_name) {
|
||||||
Some(db) => db,
|
Some(db) => db,
|
||||||
None => return (StatusCode::NOT_FOUND, Json(ApiResponse::<VectorSearchResponse>::error("Database not found"))),
|
None => {
|
||||||
|
return (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(ApiResponse::<VectorSearchResponse>::error(
|
||||||
|
"Database not found",
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
state.record_vector_search();
|
state.record_vector_search();
|
||||||
|
|
@ -563,9 +644,18 @@ async fn vector_search(
|
||||||
Ok(results) => {
|
Ok(results) => {
|
||||||
let count = results.len();
|
let count = results.len();
|
||||||
let matches: Vec<VectorMatch> = results.into_iter().map(Into::into).collect();
|
let matches: Vec<VectorMatch> = results.into_iter().map(Into::into).collect();
|
||||||
(StatusCode::OK, Json(ApiResponse::ok(VectorSearchResponse { results: matches, count })))
|
(
|
||||||
|
StatusCode::OK,
|
||||||
|
Json(ApiResponse::ok(VectorSearchResponse {
|
||||||
|
results: matches,
|
||||||
|
count,
|
||||||
|
})),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
Err(e) => (StatusCode::BAD_REQUEST, Json(ApiResponse::error(e.to_string()))),
|
Err(e) => (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(ApiResponse::error(e.to_string())),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -114,10 +114,7 @@ impl GatewayServer {
|
||||||
tracing::error!("Failed to create default database: {}", e);
|
tracing::error!("Failed to create default database: {}", e);
|
||||||
}
|
}
|
||||||
|
|
||||||
let state = Arc::new(AppState::new(
|
let state = Arc::new(AppState::new(self.db_manager.clone(), self.auth.clone()));
|
||||||
self.db_manager.clone(),
|
|
||||||
self.auth.clone(),
|
|
||||||
));
|
|
||||||
|
|
||||||
let app = create_router(state);
|
let app = create_router(state);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -86,7 +86,12 @@ pub struct Edge {
|
||||||
|
|
||||||
impl Edge {
|
impl Edge {
|
||||||
/// Creates a new directed edge.
|
/// Creates a new directed edge.
|
||||||
pub fn new(source: NodeId, target: NodeId, edge_type: impl Into<String>, properties: JsonValue) -> Self {
|
pub fn new(
|
||||||
|
source: NodeId,
|
||||||
|
target: NodeId,
|
||||||
|
edge_type: impl Into<String>,
|
||||||
|
properties: JsonValue,
|
||||||
|
) -> Self {
|
||||||
let now = std::time::SystemTime::now()
|
let now = std::time::SystemTime::now()
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
|
@ -105,7 +110,12 @@ impl Edge {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates an undirected edge.
|
/// Creates an undirected edge.
|
||||||
pub fn undirected(source: NodeId, target: NodeId, edge_type: impl Into<String>, properties: JsonValue) -> Self {
|
pub fn undirected(
|
||||||
|
source: NodeId,
|
||||||
|
target: NodeId,
|
||||||
|
edge_type: impl Into<String>,
|
||||||
|
properties: JsonValue,
|
||||||
|
) -> Self {
|
||||||
let mut edge = Self::new(source, target, edge_type, properties);
|
let mut edge = Self::new(source, target, edge_type, properties);
|
||||||
edge.directed = false;
|
edge.directed = false;
|
||||||
edge
|
edge
|
||||||
|
|
@ -138,8 +148,8 @@ impl Edge {
|
||||||
|
|
||||||
/// Checks if this edge connects two specific nodes.
|
/// Checks if this edge connects two specific nodes.
|
||||||
pub fn connects_pair(&self, a: &NodeId, b: &NodeId) -> bool {
|
pub fn connects_pair(&self, a: &NodeId, b: &NodeId) -> bool {
|
||||||
(&self.source == a && &self.target == b) ||
|
(&self.source == a && &self.target == b)
|
||||||
(!self.directed && &self.source == b && &self.target == a)
|
|| (!self.directed && &self.source == b && &self.target == a)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Gets a property value.
|
/// Gets a property value.
|
||||||
|
|
@ -156,7 +166,9 @@ impl Edge {
|
||||||
|
|
||||||
/// Checks if the edge matches a property filter.
|
/// Checks if the edge matches a property filter.
|
||||||
pub fn matches_properties(&self, filter: &JsonValue) -> bool {
|
pub fn matches_properties(&self, filter: &JsonValue) -> bool {
|
||||||
if let (Some(filter_obj), Some(props_obj)) = (filter.as_object(), self.properties.as_object()) {
|
if let (Some(filter_obj), Some(props_obj)) =
|
||||||
|
(filter.as_object(), self.properties.as_object())
|
||||||
|
{
|
||||||
for (key, expected) in filter_obj {
|
for (key, expected) in filter_obj {
|
||||||
if let Some(actual) = props_obj.get(key) {
|
if let Some(actual) = props_obj.get(key) {
|
||||||
if actual != expected {
|
if actual != expected {
|
||||||
|
|
@ -216,7 +228,12 @@ impl EdgeBuilder {
|
||||||
|
|
||||||
/// Builds the edge.
|
/// Builds the edge.
|
||||||
pub fn build(self) -> Edge {
|
pub fn build(self) -> Edge {
|
||||||
let mut edge = Edge::new(self.source, self.target, self.edge_type, JsonValue::Object(self.properties));
|
let mut edge = Edge::new(
|
||||||
|
self.source,
|
||||||
|
self.target,
|
||||||
|
self.edge_type,
|
||||||
|
JsonValue::Object(self.properties),
|
||||||
|
);
|
||||||
edge.directed = self.directed;
|
edge.directed = self.directed;
|
||||||
edge.weight = self.weight;
|
edge.weight = self.weight;
|
||||||
edge
|
edge
|
||||||
|
|
@ -264,7 +281,10 @@ mod tests {
|
||||||
|
|
||||||
assert!(!edge.directed);
|
assert!(!edge.directed);
|
||||||
assert_eq!(edge.weight, 2.5);
|
assert_eq!(edge.weight, 2.5);
|
||||||
assert_eq!(edge.get_property("percentage"), Some(&serde_json::json!(50)));
|
assert_eq!(
|
||||||
|
edge.get_property("percentage"),
|
||||||
|
Some(&serde_json::json!(50))
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
|
|
@ -172,7 +172,9 @@ impl Node {
|
||||||
|
|
||||||
/// Checks if the node matches a property filter.
|
/// Checks if the node matches a property filter.
|
||||||
pub fn matches_properties(&self, filter: &JsonValue) -> bool {
|
pub fn matches_properties(&self, filter: &JsonValue) -> bool {
|
||||||
if let (Some(filter_obj), Some(props_obj)) = (filter.as_object(), self.properties.as_object()) {
|
if let (Some(filter_obj), Some(props_obj)) =
|
||||||
|
(filter.as_object(), self.properties.as_object())
|
||||||
|
{
|
||||||
for (key, expected) in filter_obj {
|
for (key, expected) in filter_obj {
|
||||||
if let Some(actual) = props_obj.get(key) {
|
if let Some(actual) = props_obj.get(key) {
|
||||||
if actual != expected {
|
if actual != expected {
|
||||||
|
|
@ -258,10 +260,7 @@ mod tests {
|
||||||
|
|
||||||
assert!(node.has_label("User"));
|
assert!(node.has_label("User"));
|
||||||
assert!(!node.has_label("Admin"));
|
assert!(!node.has_label("Admin"));
|
||||||
assert_eq!(
|
assert_eq!(node.get_property("name"), Some(&serde_json::json!("Alice")));
|
||||||
node.get_property("name"),
|
|
||||||
Some(&serde_json::json!("Alice"))
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
|
|
@ -60,7 +60,10 @@ impl Eq for DijkstraState {}
|
||||||
impl Ord for DijkstraState {
|
impl Ord for DijkstraState {
|
||||||
fn cmp(&self, other: &Self) -> Ordering {
|
fn cmp(&self, other: &Self) -> Ordering {
|
||||||
// Reverse ordering for min-heap
|
// Reverse ordering for min-heap
|
||||||
other.distance.partial_cmp(&self.distance).unwrap_or(Ordering::Equal)
|
other
|
||||||
|
.distance
|
||||||
|
.partial_cmp(&self.distance)
|
||||||
|
.unwrap_or(Ordering::Equal)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -140,9 +143,16 @@ impl<'a> PathFinder<'a> {
|
||||||
let mut visited = HashSet::new();
|
let mut visited = HashSet::new();
|
||||||
|
|
||||||
distances.insert(*from, 0.0);
|
distances.insert(*from, 0.0);
|
||||||
heap.push(DijkstraState { node: *from, distance: 0.0 });
|
heap.push(DijkstraState {
|
||||||
|
node: *from,
|
||||||
|
distance: 0.0,
|
||||||
|
});
|
||||||
|
|
||||||
while let Some(DijkstraState { node: current, distance: dist }) = heap.pop() {
|
while let Some(DijkstraState {
|
||||||
|
node: current,
|
||||||
|
distance: dist,
|
||||||
|
}) = heap.pop()
|
||||||
|
{
|
||||||
if ¤t == to {
|
if ¤t == to {
|
||||||
// Reconstruct path
|
// Reconstruct path
|
||||||
let mut path = vec![current];
|
let mut path = vec![current];
|
||||||
|
|
@ -181,12 +191,18 @@ impl<'a> PathFinder<'a> {
|
||||||
}
|
}
|
||||||
|
|
||||||
let new_dist = dist + edge.weight;
|
let new_dist = dist + edge.weight;
|
||||||
let is_shorter = distances.get(&neighbor).map(|&d| new_dist < d).unwrap_or(true);
|
let is_shorter = distances
|
||||||
|
.get(&neighbor)
|
||||||
|
.map(|&d| new_dist < d)
|
||||||
|
.unwrap_or(true);
|
||||||
|
|
||||||
if is_shorter {
|
if is_shorter {
|
||||||
distances.insert(neighbor, new_dist);
|
distances.insert(neighbor, new_dist);
|
||||||
previous.insert(neighbor, (current, edge.clone()));
|
previous.insert(neighbor, (current, edge.clone()));
|
||||||
heap.push(DijkstraState { node: neighbor, distance: new_dist });
|
heap.push(DijkstraState {
|
||||||
|
node: neighbor,
|
||||||
|
distance: new_dist,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -251,7 +267,9 @@ impl<'a> PathFinder<'a> {
|
||||||
path.push(neighbor);
|
path.push(neighbor);
|
||||||
edges.push(edge.clone());
|
edges.push(edge.clone());
|
||||||
|
|
||||||
self.find_all_paths_dfs(&neighbor, target, max_length, path, edges, visited, results);
|
self.find_all_paths_dfs(
|
||||||
|
&neighbor, target, max_length, path, edges, visited, results,
|
||||||
|
);
|
||||||
|
|
||||||
path.pop();
|
path.pop();
|
||||||
edges.pop();
|
edges.pop();
|
||||||
|
|
@ -261,7 +279,12 @@ impl<'a> PathFinder<'a> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Finds the shortest path considering only specific edge types.
|
/// Finds the shortest path considering only specific edge types.
|
||||||
pub fn shortest_path_by_type(&self, from: &NodeId, to: &NodeId, edge_types: &[String]) -> PathResult {
|
pub fn shortest_path_by_type(
|
||||||
|
&self,
|
||||||
|
from: &NodeId,
|
||||||
|
to: &NodeId,
|
||||||
|
edge_types: &[String],
|
||||||
|
) -> PathResult {
|
||||||
if from == to {
|
if from == to {
|
||||||
return PathResult::found(vec![*from], Vec::new(), 0.0);
|
return PathResult::found(vec![*from], Vec::new(), 0.0);
|
||||||
}
|
}
|
||||||
|
|
@ -377,11 +400,21 @@ mod tests {
|
||||||
let d = store.create_node(vec![], serde_json::json!({"name": "D"}));
|
let d = store.create_node(vec![], serde_json::json!({"name": "D"}));
|
||||||
let e = store.create_node(vec![], serde_json::json!({"name": "E"}));
|
let e = store.create_node(vec![], serde_json::json!({"name": "E"}));
|
||||||
|
|
||||||
store.create_edge(a, b, "LINK", serde_json::json!({})).unwrap();
|
store
|
||||||
store.create_edge(b, c, "LINK", serde_json::json!({})).unwrap();
|
.create_edge(a, b, "LINK", serde_json::json!({}))
|
||||||
store.create_edge(c, d, "LINK", serde_json::json!({})).unwrap();
|
.unwrap();
|
||||||
store.create_edge(a, e, "LINK", serde_json::json!({})).unwrap();
|
store
|
||||||
store.create_edge(e, d, "LINK", serde_json::json!({})).unwrap();
|
.create_edge(b, c, "LINK", serde_json::json!({}))
|
||||||
|
.unwrap();
|
||||||
|
store
|
||||||
|
.create_edge(c, d, "LINK", serde_json::json!({}))
|
||||||
|
.unwrap();
|
||||||
|
store
|
||||||
|
.create_edge(a, e, "LINK", serde_json::json!({}))
|
||||||
|
.unwrap();
|
||||||
|
store
|
||||||
|
.create_edge(e, d, "LINK", serde_json::json!({}))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
store
|
store
|
||||||
}
|
}
|
||||||
|
|
@ -392,8 +425,14 @@ mod tests {
|
||||||
let finder = PathFinder::new(&store);
|
let finder = PathFinder::new(&store);
|
||||||
|
|
||||||
let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({}));
|
let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({}));
|
||||||
let a = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("A"))).unwrap();
|
let a = nodes
|
||||||
let d = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("D"))).unwrap();
|
.iter()
|
||||||
|
.find(|n| n.get_property("name") == Some(&serde_json::json!("A")))
|
||||||
|
.unwrap();
|
||||||
|
let d = nodes
|
||||||
|
.iter()
|
||||||
|
.find(|n| n.get_property("name") == Some(&serde_json::json!("D")))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let result = finder.shortest_path_bfs(&a.id, &d.id);
|
let result = finder.shortest_path_bfs(&a.id, &d.id);
|
||||||
|
|
||||||
|
|
@ -413,13 +452,19 @@ mod tests {
|
||||||
// A --(3.0)--> C
|
// A --(3.0)--> C
|
||||||
let mut edge1 = super::super::edge::Edge::new(a, b, "LINK", serde_json::json!({}));
|
let mut edge1 = super::super::edge::Edge::new(a, b, "LINK", serde_json::json!({}));
|
||||||
edge1.weight = 1.0;
|
edge1.weight = 1.0;
|
||||||
store.create_edge(a, b, "LINK", serde_json::json!({})).unwrap();
|
store
|
||||||
|
.create_edge(a, b, "LINK", serde_json::json!({}))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let mut edge2 = super::super::edge::Edge::new(b, c, "LINK", serde_json::json!({}));
|
let mut edge2 = super::super::edge::Edge::new(b, c, "LINK", serde_json::json!({}));
|
||||||
edge2.weight = 1.0;
|
edge2.weight = 1.0;
|
||||||
store.create_edge(b, c, "LINK", serde_json::json!({})).unwrap();
|
store
|
||||||
|
.create_edge(b, c, "LINK", serde_json::json!({}))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
store.create_edge(a, c, "DIRECT", serde_json::json!({})).unwrap();
|
store
|
||||||
|
.create_edge(a, c, "DIRECT", serde_json::json!({}))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let finder = PathFinder::new(&store);
|
let finder = PathFinder::new(&store);
|
||||||
let result = finder.shortest_path_dijkstra(&a, &c);
|
let result = finder.shortest_path_dijkstra(&a, &c);
|
||||||
|
|
@ -449,8 +494,14 @@ mod tests {
|
||||||
let finder = PathFinder::new(&store);
|
let finder = PathFinder::new(&store);
|
||||||
|
|
||||||
let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({}));
|
let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({}));
|
||||||
let a = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("A"))).unwrap();
|
let a = nodes
|
||||||
let d = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("D"))).unwrap();
|
.iter()
|
||||||
|
.find(|n| n.get_property("name") == Some(&serde_json::json!("A")))
|
||||||
|
.unwrap();
|
||||||
|
let d = nodes
|
||||||
|
.iter()
|
||||||
|
.find(|n| n.get_property("name") == Some(&serde_json::json!("D")))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let paths = finder.all_paths(&a.id, &d.id, 5);
|
let paths = finder.all_paths(&a.id, &d.id, 5);
|
||||||
|
|
||||||
|
|
@ -463,8 +514,14 @@ mod tests {
|
||||||
let finder = PathFinder::new(&store);
|
let finder = PathFinder::new(&store);
|
||||||
|
|
||||||
let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({}));
|
let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({}));
|
||||||
let a = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("A"))).unwrap();
|
let a = nodes
|
||||||
let d = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("D"))).unwrap();
|
.iter()
|
||||||
|
.find(|n| n.get_property("name") == Some(&serde_json::json!("A")))
|
||||||
|
.unwrap();
|
||||||
|
let d = nodes
|
||||||
|
.iter()
|
||||||
|
.find(|n| n.get_property("name") == Some(&serde_json::json!("D")))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert!(finder.path_exists(&a.id, &d.id));
|
assert!(finder.path_exists(&a.id, &d.id));
|
||||||
}
|
}
|
||||||
|
|
@ -475,9 +532,18 @@ mod tests {
|
||||||
let finder = PathFinder::new(&store);
|
let finder = PathFinder::new(&store);
|
||||||
|
|
||||||
let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({}));
|
let nodes: Vec<_> = store.find_nodes(None, &serde_json::json!({}));
|
||||||
let a = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("A"))).unwrap();
|
let a = nodes
|
||||||
let d = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("D"))).unwrap();
|
.iter()
|
||||||
let b = nodes.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("B"))).unwrap();
|
.find(|n| n.get_property("name") == Some(&serde_json::json!("A")))
|
||||||
|
.unwrap();
|
||||||
|
let d = nodes
|
||||||
|
.iter()
|
||||||
|
.find(|n| n.get_property("name") == Some(&serde_json::json!("D")))
|
||||||
|
.unwrap();
|
||||||
|
let b = nodes
|
||||||
|
.iter()
|
||||||
|
.find(|n| n.get_property("name") == Some(&serde_json::json!("B")))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(finder.distance(&a.id, &b.id), Some(1));
|
assert_eq!(finder.distance(&a.id, &b.id), Some(1));
|
||||||
assert_eq!(finder.distance(&a.id, &d.id), Some(2)); // A -> E -> D
|
assert_eq!(finder.distance(&a.id, &d.id), Some(2)); // A -> E -> D
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,10 @@ pub enum GraphQuery {
|
||||||
/// DELETE query for removing nodes/edges.
|
/// DELETE query for removing nodes/edges.
|
||||||
Delete { variable: String, detach: bool },
|
Delete { variable: String, detach: bool },
|
||||||
/// SET query for updating properties.
|
/// SET query for updating properties.
|
||||||
Set { variable: String, properties: JsonValue },
|
Set {
|
||||||
|
variable: String,
|
||||||
|
properties: JsonValue,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Pattern to match in the graph.
|
/// Pattern to match in the graph.
|
||||||
|
|
@ -78,15 +81,35 @@ pub enum RelationshipDirection {
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
pub enum WhereClause {
|
pub enum WhereClause {
|
||||||
/// Property comparison.
|
/// Property comparison.
|
||||||
PropertyEquals { variable: String, property: String, value: JsonValue },
|
PropertyEquals {
|
||||||
|
variable: String,
|
||||||
|
property: String,
|
||||||
|
value: JsonValue,
|
||||||
|
},
|
||||||
/// Property comparison (not equals).
|
/// Property comparison (not equals).
|
||||||
PropertyNotEquals { variable: String, property: String, value: JsonValue },
|
PropertyNotEquals {
|
||||||
|
variable: String,
|
||||||
|
property: String,
|
||||||
|
value: JsonValue,
|
||||||
|
},
|
||||||
/// Property greater than.
|
/// Property greater than.
|
||||||
PropertyGt { variable: String, property: String, value: JsonValue },
|
PropertyGt {
|
||||||
|
variable: String,
|
||||||
|
property: String,
|
||||||
|
value: JsonValue,
|
||||||
|
},
|
||||||
/// Property less than.
|
/// Property less than.
|
||||||
PropertyLt { variable: String, property: String, value: JsonValue },
|
PropertyLt {
|
||||||
|
variable: String,
|
||||||
|
property: String,
|
||||||
|
value: JsonValue,
|
||||||
|
},
|
||||||
/// Property contains (for text).
|
/// Property contains (for text).
|
||||||
PropertyContains { variable: String, property: String, value: String },
|
PropertyContains {
|
||||||
|
variable: String,
|
||||||
|
property: String,
|
||||||
|
value: String,
|
||||||
|
},
|
||||||
/// AND condition.
|
/// AND condition.
|
||||||
And(Box<WhereClause>, Box<WhereClause>),
|
And(Box<WhereClause>, Box<WhereClause>),
|
||||||
/// OR condition.
|
/// OR condition.
|
||||||
|
|
@ -105,7 +128,10 @@ pub enum ReturnItem {
|
||||||
/// Return a property of a variable.
|
/// Return a property of a variable.
|
||||||
Property { variable: String, property: String },
|
Property { variable: String, property: String },
|
||||||
/// Return with an alias.
|
/// Return with an alias.
|
||||||
Alias { item: Box<ReturnItem>, alias: String },
|
Alias {
|
||||||
|
item: Box<ReturnItem>,
|
||||||
|
alias: String,
|
||||||
|
},
|
||||||
/// Count aggregation.
|
/// Count aggregation.
|
||||||
Count(Option<String>),
|
Count(Option<String>),
|
||||||
}
|
}
|
||||||
|
|
@ -114,7 +140,11 @@ pub enum ReturnItem {
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
pub enum CreateElement {
|
pub enum CreateElement {
|
||||||
/// Create a node.
|
/// Create a node.
|
||||||
Node { variable: Option<String>, labels: Vec<String>, properties: JsonValue },
|
Node {
|
||||||
|
variable: Option<String>,
|
||||||
|
labels: Vec<String>,
|
||||||
|
properties: JsonValue,
|
||||||
|
},
|
||||||
/// Create a relationship.
|
/// Create a relationship.
|
||||||
Relationship {
|
Relationship {
|
||||||
from_var: String,
|
from_var: String,
|
||||||
|
|
@ -176,7 +206,10 @@ impl GraphQueryParser {
|
||||||
} else if upper.starts_with("SET") {
|
} else if upper.starts_with("SET") {
|
||||||
Self::parse_set(query)
|
Self::parse_set(query)
|
||||||
} else {
|
} else {
|
||||||
Err(GraphError::InvalidOperation(format!("Unknown query type: {}", query)))
|
Err(GraphError::InvalidOperation(format!(
|
||||||
|
"Unknown query type: {}",
|
||||||
|
query
|
||||||
|
)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -185,7 +218,10 @@ impl GraphQueryParser {
|
||||||
let upper = query.to_uppercase();
|
let upper = query.to_uppercase();
|
||||||
|
|
||||||
// Find MATCH, WHERE, RETURN, LIMIT positions
|
// Find MATCH, WHERE, RETURN, LIMIT positions
|
||||||
let match_end = upper.find("WHERE").or_else(|| upper.find("RETURN")).unwrap_or(query.len());
|
let match_end = upper
|
||||||
|
.find("WHERE")
|
||||||
|
.or_else(|| upper.find("RETURN"))
|
||||||
|
.unwrap_or(query.len());
|
||||||
let where_start = upper.find("WHERE");
|
let where_start = upper.find("WHERE");
|
||||||
let return_start = upper.find("RETURN");
|
let return_start = upper.find("RETURN");
|
||||||
let limit_start = upper.find("LIMIT");
|
let limit_start = upper.find("LIMIT");
|
||||||
|
|
@ -253,7 +289,9 @@ impl GraphQueryParser {
|
||||||
}
|
}
|
||||||
|
|
||||||
if nodes.is_empty() {
|
if nodes.is_empty() {
|
||||||
return Err(GraphError::InvalidOperation("No node pattern found".to_string()));
|
return Err(GraphError::InvalidOperation(
|
||||||
|
"No node pattern found".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Combine nodes with relationships
|
// Combine nodes with relationships
|
||||||
|
|
@ -264,10 +302,15 @@ impl GraphQueryParser {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(MatchPattern { start, relationships })
|
Ok(MatchPattern {
|
||||||
|
start,
|
||||||
|
relationships,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_node_pattern(chars: &mut std::iter::Peekable<std::str::Chars>) -> Result<NodePattern, GraphError> {
|
fn parse_node_pattern(
|
||||||
|
chars: &mut std::iter::Peekable<std::str::Chars>,
|
||||||
|
) -> Result<NodePattern, GraphError> {
|
||||||
// Consume '('
|
// Consume '('
|
||||||
chars.next();
|
chars.next();
|
||||||
|
|
||||||
|
|
@ -335,10 +378,16 @@ impl GraphQueryParser {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(NodePattern { variable, labels, properties })
|
Ok(NodePattern {
|
||||||
|
variable,
|
||||||
|
labels,
|
||||||
|
properties,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_relationship_pattern(chars: &mut std::iter::Peekable<std::str::Chars>) -> Result<RelationshipPattern, GraphError> {
|
fn parse_relationship_pattern(
|
||||||
|
chars: &mut std::iter::Peekable<std::str::Chars>,
|
||||||
|
) -> Result<RelationshipPattern, GraphError> {
|
||||||
let mut direction = RelationshipDirection::Undirected;
|
let mut direction = RelationshipDirection::Undirected;
|
||||||
let mut edge_type = None;
|
let mut edge_type = None;
|
||||||
let mut variable = None;
|
let mut variable = None;
|
||||||
|
|
@ -408,7 +457,11 @@ impl GraphQueryParser {
|
||||||
variable,
|
variable,
|
||||||
edge_type,
|
edge_type,
|
||||||
direction,
|
direction,
|
||||||
target: NodePattern { variable: None, labels: Vec::new(), properties: None },
|
target: NodePattern {
|
||||||
|
variable: None,
|
||||||
|
labels: Vec::new(),
|
||||||
|
properties: None,
|
||||||
|
},
|
||||||
min_hops,
|
min_hops,
|
||||||
max_hops,
|
max_hops,
|
||||||
})
|
})
|
||||||
|
|
@ -476,7 +529,9 @@ impl GraphQueryParser {
|
||||||
elements.push(CreateElement::Node {
|
elements.push(CreateElement::Node {
|
||||||
variable: node.variable,
|
variable: node.variable,
|
||||||
labels: node.labels,
|
labels: node.labels,
|
||||||
properties: node.properties.unwrap_or(JsonValue::Object(serde_json::Map::new())),
|
properties: node
|
||||||
|
.properties
|
||||||
|
.unwrap_or(JsonValue::Object(serde_json::Map::new())),
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
break;
|
break;
|
||||||
|
|
@ -488,7 +543,11 @@ impl GraphQueryParser {
|
||||||
|
|
||||||
fn parse_delete(query: &str) -> Result<GraphQuery, GraphError> {
|
fn parse_delete(query: &str) -> Result<GraphQuery, GraphError> {
|
||||||
let detach = query.to_uppercase().starts_with("DETACH");
|
let detach = query.to_uppercase().starts_with("DETACH");
|
||||||
let start = if detach { "DETACH DELETE".len() } else { "DELETE".len() };
|
let start = if detach {
|
||||||
|
"DETACH DELETE".len()
|
||||||
|
} else {
|
||||||
|
"DELETE".len()
|
||||||
|
};
|
||||||
let variable = query[start..].trim().to_string();
|
let variable = query[start..].trim().to_string();
|
||||||
|
|
||||||
Ok(GraphQuery::Delete { variable, detach })
|
Ok(GraphQuery::Delete { variable, detach })
|
||||||
|
|
@ -500,19 +559,24 @@ impl GraphQueryParser {
|
||||||
let parts: Vec<_> = content.split('=').collect();
|
let parts: Vec<_> = content.split('=').collect();
|
||||||
|
|
||||||
if parts.len() != 2 {
|
if parts.len() != 2 {
|
||||||
return Err(GraphError::InvalidOperation("Invalid SET syntax".to_string()));
|
return Err(GraphError::InvalidOperation(
|
||||||
|
"Invalid SET syntax".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let var_prop: Vec<_> = parts[0].trim().split('.').collect();
|
let var_prop: Vec<_> = parts[0].trim().split('.').collect();
|
||||||
if var_prop.len() != 2 {
|
if var_prop.len() != 2 {
|
||||||
return Err(GraphError::InvalidOperation("Invalid SET variable".to_string()));
|
return Err(GraphError::InvalidOperation(
|
||||||
|
"Invalid SET variable".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let variable = var_prop[0].to_string();
|
let variable = var_prop[0].to_string();
|
||||||
let property = var_prop[1].to_string();
|
let property = var_prop[1].to_string();
|
||||||
let value_str = parts[1].trim();
|
let value_str = parts[1].trim();
|
||||||
|
|
||||||
let value: JsonValue = serde_json::from_str(value_str).unwrap_or(JsonValue::String(value_str.to_string()));
|
let value: JsonValue =
|
||||||
|
serde_json::from_str(value_str).unwrap_or(JsonValue::String(value_str.to_string()));
|
||||||
|
|
||||||
Ok(GraphQuery::Set {
|
Ok(GraphQuery::Set {
|
||||||
variable,
|
variable,
|
||||||
|
|
@ -535,18 +599,21 @@ impl<'a> GraphQueryExecutor<'a> {
|
||||||
/// Executes a graph query.
|
/// Executes a graph query.
|
||||||
pub fn execute(&self, query: &GraphQuery) -> Result<QueryResult, GraphError> {
|
pub fn execute(&self, query: &GraphQuery) -> Result<QueryResult, GraphError> {
|
||||||
match query {
|
match query {
|
||||||
GraphQuery::Match { pattern, where_clause, return_items, limit } => {
|
GraphQuery::Match {
|
||||||
self.execute_match(pattern, where_clause.as_ref(), return_items, *limit)
|
pattern,
|
||||||
}
|
where_clause,
|
||||||
GraphQuery::Create { .. } => {
|
return_items,
|
||||||
Err(GraphError::InvalidOperation("CREATE requires mutable access".to_string()))
|
limit,
|
||||||
}
|
} => self.execute_match(pattern, where_clause.as_ref(), return_items, *limit),
|
||||||
GraphQuery::Delete { .. } => {
|
GraphQuery::Create { .. } => Err(GraphError::InvalidOperation(
|
||||||
Err(GraphError::InvalidOperation("DELETE requires mutable access".to_string()))
|
"CREATE requires mutable access".to_string(),
|
||||||
}
|
)),
|
||||||
GraphQuery::Set { .. } => {
|
GraphQuery::Delete { .. } => Err(GraphError::InvalidOperation(
|
||||||
Err(GraphError::InvalidOperation("SET requires mutable access".to_string()))
|
"DELETE requires mutable access".to_string(),
|
||||||
}
|
)),
|
||||||
|
GraphQuery::Set { .. } => Err(GraphError::InvalidOperation(
|
||||||
|
"SET requires mutable access".to_string(),
|
||||||
|
)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -586,7 +653,11 @@ impl<'a> GraphQueryExecutor<'a> {
|
||||||
.depth(rel_pattern.max_hops)
|
.depth(rel_pattern.max_hops)
|
||||||
.direction(direction)
|
.direction(direction)
|
||||||
.edge_types(
|
.edge_types(
|
||||||
rel_pattern.edge_type.clone().map(|t| vec![t]).unwrap_or_default(),
|
rel_pattern
|
||||||
|
.edge_type
|
||||||
|
.clone()
|
||||||
|
.map(|t| vec![t])
|
||||||
|
.unwrap_or_default(),
|
||||||
)
|
)
|
||||||
.labels(rel_pattern.target.labels.clone());
|
.labels(rel_pattern.target.labels.clone());
|
||||||
|
|
||||||
|
|
@ -635,7 +706,10 @@ impl<'a> GraphQueryExecutor<'a> {
|
||||||
|
|
||||||
fn find_matching_nodes(&self, pattern: &NodePattern) -> Vec<Node> {
|
fn find_matching_nodes(&self, pattern: &NodePattern) -> Vec<Node> {
|
||||||
let label = pattern.labels.first().map(|s| s.as_str());
|
let label = pattern.labels.first().map(|s| s.as_str());
|
||||||
let filter = pattern.properties.clone().unwrap_or(JsonValue::Object(serde_json::Map::new()));
|
let filter = pattern
|
||||||
|
.properties
|
||||||
|
.clone()
|
||||||
|
.unwrap_or(JsonValue::Object(serde_json::Map::new()));
|
||||||
self.store.find_nodes(label, &filter)
|
self.store.find_nodes(label, &filter)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -657,7 +731,11 @@ impl<'a> GraphQueryExecutor<'a> {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_column_names(&self, return_items: &[ReturnItem], bindings: &[HashMap<String, JsonValue>]) -> Vec<String> {
|
fn get_column_names(
|
||||||
|
&self,
|
||||||
|
return_items: &[ReturnItem],
|
||||||
|
bindings: &[HashMap<String, JsonValue>],
|
||||||
|
) -> Vec<String> {
|
||||||
let mut columns = Vec::new();
|
let mut columns = Vec::new();
|
||||||
|
|
||||||
for item in return_items {
|
for item in return_items {
|
||||||
|
|
@ -673,7 +751,10 @@ impl<'a> GraphQueryExecutor<'a> {
|
||||||
}
|
}
|
||||||
ReturnItem::Alias { alias, .. } => columns.push(alias.clone()),
|
ReturnItem::Alias { alias, .. } => columns.push(alias.clone()),
|
||||||
ReturnItem::Count(var) => {
|
ReturnItem::Count(var) => {
|
||||||
columns.push(format!("count({})", var.as_ref().map(|s| s.as_str()).unwrap_or("*")));
|
columns.push(format!(
|
||||||
|
"count({})",
|
||||||
|
var.as_ref().map(|s| s.as_str()).unwrap_or("*")
|
||||||
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -681,11 +762,18 @@ impl<'a> GraphQueryExecutor<'a> {
|
||||||
columns
|
columns
|
||||||
}
|
}
|
||||||
|
|
||||||
fn extract_rows(&self, return_items: &[ReturnItem], bindings: &[HashMap<String, JsonValue>]) -> Vec<Vec<JsonValue>> {
|
fn extract_rows(
|
||||||
|
&self,
|
||||||
|
return_items: &[ReturnItem],
|
||||||
|
bindings: &[HashMap<String, JsonValue>],
|
||||||
|
) -> Vec<Vec<JsonValue>> {
|
||||||
let mut rows = Vec::new();
|
let mut rows = Vec::new();
|
||||||
|
|
||||||
// Handle COUNT specially
|
// Handle COUNT specially
|
||||||
if return_items.iter().any(|i| matches!(i, ReturnItem::Count(_))) {
|
if return_items
|
||||||
|
.iter()
|
||||||
|
.any(|i| matches!(i, ReturnItem::Count(_)))
|
||||||
|
{
|
||||||
rows.push(vec![JsonValue::Number(bindings.len().into())]);
|
rows.push(vec![JsonValue::Number(bindings.len().into())]);
|
||||||
return rows;
|
return rows;
|
||||||
}
|
}
|
||||||
|
|
@ -760,8 +848,14 @@ mod tests {
|
||||||
if let GraphQuery::Match { pattern, .. } = parsed {
|
if let GraphQuery::Match { pattern, .. } = parsed {
|
||||||
assert_eq!(pattern.start.labels, vec!["User".to_string()]);
|
assert_eq!(pattern.start.labels, vec!["User".to_string()]);
|
||||||
assert_eq!(pattern.relationships.len(), 1);
|
assert_eq!(pattern.relationships.len(), 1);
|
||||||
assert_eq!(pattern.relationships[0].edge_type, Some("FRIEND".to_string()));
|
assert_eq!(
|
||||||
assert_eq!(pattern.relationships[0].direction, RelationshipDirection::Outgoing);
|
pattern.relationships[0].edge_type,
|
||||||
|
Some("FRIEND".to_string())
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
pattern.relationships[0].direction,
|
||||||
|
RelationshipDirection::Outgoing
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
panic!("Expected Match query");
|
panic!("Expected Match query");
|
||||||
}
|
}
|
||||||
|
|
@ -771,9 +865,14 @@ mod tests {
|
||||||
fn test_execute_match() {
|
fn test_execute_match() {
|
||||||
let store = GraphStore::new();
|
let store = GraphStore::new();
|
||||||
|
|
||||||
let alice = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"}));
|
let alice = store.create_node(
|
||||||
|
vec!["User".to_string()],
|
||||||
|
serde_json::json!({"name": "Alice"}),
|
||||||
|
);
|
||||||
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
|
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
|
||||||
store.create_edge(alice, bob, "FRIEND", serde_json::json!({})).unwrap();
|
store
|
||||||
|
.create_edge(alice, bob, "FRIEND", serde_json::json!({}))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let query = GraphQueryParser::parse("MATCH (n:User) RETURN n").unwrap();
|
let query = GraphQueryParser::parse("MATCH (n:User) RETURN n").unwrap();
|
||||||
let executor = GraphQueryExecutor::new(&store);
|
let executor = GraphQueryExecutor::new(&store);
|
||||||
|
|
|
||||||
|
|
@ -177,18 +177,8 @@ impl GraphStore {
|
||||||
/// Deletes a node and all its connected edges.
|
/// Deletes a node and all its connected edges.
|
||||||
pub fn delete_node(&self, id: &NodeId) -> Result<(), GraphError> {
|
pub fn delete_node(&self, id: &NodeId) -> Result<(), GraphError> {
|
||||||
// Get connected edges
|
// Get connected edges
|
||||||
let outgoing: Vec<EdgeId> = self
|
let outgoing: Vec<EdgeId> = self.adjacency.read().get(id).cloned().unwrap_or_default();
|
||||||
.adjacency
|
let incoming: Vec<EdgeId> = self.reverse_adj.read().get(id).cloned().unwrap_or_default();
|
||||||
.read()
|
|
||||||
.get(id)
|
|
||||||
.cloned()
|
|
||||||
.unwrap_or_default();
|
|
||||||
let incoming: Vec<EdgeId> = self
|
|
||||||
.reverse_adj
|
|
||||||
.read()
|
|
||||||
.get(id)
|
|
||||||
.cloned()
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
// Delete all connected edges
|
// Delete all connected edges
|
||||||
for edge_id in outgoing.iter().chain(incoming.iter()) {
|
for edge_id in outgoing.iter().chain(incoming.iter()) {
|
||||||
|
|
@ -457,7 +447,12 @@ impl GraphStore {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Gets the neighbor node from an edge.
|
/// Gets the neighbor node from an edge.
|
||||||
fn get_neighbor_from_edge(&self, edge: &Edge, from: &NodeId, direction: Direction) -> Option<NodeId> {
|
fn get_neighbor_from_edge(
|
||||||
|
&self,
|
||||||
|
edge: &Edge,
|
||||||
|
from: &NodeId,
|
||||||
|
direction: Direction,
|
||||||
|
) -> Option<NodeId> {
|
||||||
match direction {
|
match direction {
|
||||||
Direction::Outgoing => {
|
Direction::Outgoing => {
|
||||||
if &edge.source == from {
|
if &edge.source == from {
|
||||||
|
|
@ -491,7 +486,12 @@ impl GraphStore {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Gets neighbors connected by a specific edge type.
|
/// Gets neighbors connected by a specific edge type.
|
||||||
pub fn neighbors_by_type(&self, id: &NodeId, edge_type: &str, direction: Direction) -> Vec<Node> {
|
pub fn neighbors_by_type(
|
||||||
|
&self,
|
||||||
|
id: &NodeId,
|
||||||
|
edge_type: &str,
|
||||||
|
direction: Direction,
|
||||||
|
) -> Vec<Node> {
|
||||||
let edges = self.edges_of(id, direction);
|
let edges = self.edges_of(id, direction);
|
||||||
let nodes = self.nodes.read();
|
let nodes = self.nodes.read();
|
||||||
|
|
||||||
|
|
@ -565,7 +565,10 @@ mod tests {
|
||||||
fn test_create_edge() {
|
fn test_create_edge() {
|
||||||
let store = GraphStore::new();
|
let store = GraphStore::new();
|
||||||
|
|
||||||
let alice = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"}));
|
let alice = store.create_node(
|
||||||
|
vec!["User".to_string()],
|
||||||
|
serde_json::json!({"name": "Alice"}),
|
||||||
|
);
|
||||||
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
|
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
|
||||||
|
|
||||||
let edge_id = store
|
let edge_id = store
|
||||||
|
|
@ -582,12 +585,22 @@ mod tests {
|
||||||
fn test_neighbors() {
|
fn test_neighbors() {
|
||||||
let store = GraphStore::new();
|
let store = GraphStore::new();
|
||||||
|
|
||||||
let alice = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"}));
|
let alice = store.create_node(
|
||||||
|
vec!["User".to_string()],
|
||||||
|
serde_json::json!({"name": "Alice"}),
|
||||||
|
);
|
||||||
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
|
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
|
||||||
let charlie = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Charlie"}));
|
let charlie = store.create_node(
|
||||||
|
vec!["User".to_string()],
|
||||||
|
serde_json::json!({"name": "Charlie"}),
|
||||||
|
);
|
||||||
|
|
||||||
store.create_edge(alice, bob, "FRIEND", serde_json::json!({})).unwrap();
|
store
|
||||||
store.create_edge(alice, charlie, "FRIEND", serde_json::json!({})).unwrap();
|
.create_edge(alice, bob, "FRIEND", serde_json::json!({}))
|
||||||
|
.unwrap();
|
||||||
|
store
|
||||||
|
.create_edge(alice, charlie, "FRIEND", serde_json::json!({}))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let neighbors = store.neighbors(&alice, Direction::Outgoing);
|
let neighbors = store.neighbors(&alice, Direction::Outgoing);
|
||||||
assert_eq!(neighbors.len(), 2);
|
assert_eq!(neighbors.len(), 2);
|
||||||
|
|
@ -597,9 +610,15 @@ mod tests {
|
||||||
fn test_find_by_label() {
|
fn test_find_by_label() {
|
||||||
let store = GraphStore::new();
|
let store = GraphStore::new();
|
||||||
|
|
||||||
store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"}));
|
store.create_node(
|
||||||
|
vec!["User".to_string()],
|
||||||
|
serde_json::json!({"name": "Alice"}),
|
||||||
|
);
|
||||||
store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
|
store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
|
||||||
store.create_node(vec!["Product".to_string()], serde_json::json!({"name": "Widget"}));
|
store.create_node(
|
||||||
|
vec!["Product".to_string()],
|
||||||
|
serde_json::json!({"name": "Widget"}),
|
||||||
|
);
|
||||||
|
|
||||||
let users = store.find_nodes_by_label("User");
|
let users = store.find_nodes_by_label("User");
|
||||||
assert_eq!(users.len(), 2);
|
assert_eq!(users.len(), 2);
|
||||||
|
|
@ -615,7 +634,9 @@ mod tests {
|
||||||
let alice = store.create_node(vec!["User".to_string()], serde_json::json!({}));
|
let alice = store.create_node(vec!["User".to_string()], serde_json::json!({}));
|
||||||
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({}));
|
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({}));
|
||||||
|
|
||||||
store.create_edge(alice, bob, "FRIEND", serde_json::json!({})).unwrap();
|
store
|
||||||
|
.create_edge(alice, bob, "FRIEND", serde_json::json!({}))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
// Delete Alice - should also delete the edge
|
// Delete Alice - should also delete the edge
|
||||||
store.delete_node(&alice).unwrap();
|
store.delete_node(&alice).unwrap();
|
||||||
|
|
@ -631,7 +652,9 @@ mod tests {
|
||||||
let a = store.create_node(vec![], serde_json::json!({}));
|
let a = store.create_node(vec![], serde_json::json!({}));
|
||||||
let b = store.create_node(vec![], serde_json::json!({}));
|
let b = store.create_node(vec![], serde_json::json!({}));
|
||||||
|
|
||||||
store.create_undirected_edge(a, b, "LINK", serde_json::json!({})).unwrap();
|
store
|
||||||
|
.create_undirected_edge(a, b, "LINK", serde_json::json!({}))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
// Both directions should work
|
// Both directions should work
|
||||||
let a_neighbors = store.neighbors(&a, Direction::Outgoing);
|
let a_neighbors = store.neighbors(&a, Direction::Outgoing);
|
||||||
|
|
@ -648,8 +671,12 @@ mod tests {
|
||||||
let a = store.create_node(vec![], serde_json::json!({}));
|
let a = store.create_node(vec![], serde_json::json!({}));
|
||||||
let b = store.create_node(vec![], serde_json::json!({}));
|
let b = store.create_node(vec![], serde_json::json!({}));
|
||||||
|
|
||||||
store.create_edge(a, b, "TYPE_A", serde_json::json!({})).unwrap();
|
store
|
||||||
store.create_edge(a, b, "TYPE_B", serde_json::json!({})).unwrap();
|
.create_edge(a, b, "TYPE_A", serde_json::json!({}))
|
||||||
|
.unwrap();
|
||||||
|
store
|
||||||
|
.create_edge(a, b, "TYPE_B", serde_json::json!({}))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let edges = store.edges_between(&a, &b);
|
let edges = store.edges_between(&a, &b);
|
||||||
assert_eq!(edges.len(), 2);
|
assert_eq!(edges.len(), 2);
|
||||||
|
|
|
||||||
|
|
@ -171,7 +171,9 @@ impl<'a> Traverser<'a> {
|
||||||
|
|
||||||
for edge in edges {
|
for edge in edges {
|
||||||
// Check edge type filter
|
// Check edge type filter
|
||||||
if !query.edge_types.is_empty() && !query.edge_types.contains(&edge.edge_type) {
|
if !query.edge_types.is_empty()
|
||||||
|
&& !query.edge_types.contains(&edge.edge_type)
|
||||||
|
{
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -395,18 +397,35 @@ mod tests {
|
||||||
fn setup_social_graph() -> GraphStore {
|
fn setup_social_graph() -> GraphStore {
|
||||||
let store = GraphStore::new();
|
let store = GraphStore::new();
|
||||||
|
|
||||||
let alice = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Alice"}));
|
let alice = store.create_node(
|
||||||
|
vec!["User".to_string()],
|
||||||
|
serde_json::json!({"name": "Alice"}),
|
||||||
|
);
|
||||||
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
|
let bob = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Bob"}));
|
||||||
let charlie = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Charlie"}));
|
let charlie = store.create_node(
|
||||||
let dave = store.create_node(vec!["User".to_string()], serde_json::json!({"name": "Dave"}));
|
vec!["User".to_string()],
|
||||||
|
serde_json::json!({"name": "Charlie"}),
|
||||||
|
);
|
||||||
|
let dave = store.create_node(
|
||||||
|
vec!["User".to_string()],
|
||||||
|
serde_json::json!({"name": "Dave"}),
|
||||||
|
);
|
||||||
|
|
||||||
// Alice -> Bob -> Charlie -> Dave
|
// Alice -> Bob -> Charlie -> Dave
|
||||||
store.create_edge(alice, bob, "FRIEND", serde_json::json!({})).unwrap();
|
store
|
||||||
store.create_edge(bob, charlie, "FRIEND", serde_json::json!({})).unwrap();
|
.create_edge(alice, bob, "FRIEND", serde_json::json!({}))
|
||||||
store.create_edge(charlie, dave, "FRIEND", serde_json::json!({})).unwrap();
|
.unwrap();
|
||||||
|
store
|
||||||
|
.create_edge(bob, charlie, "FRIEND", serde_json::json!({}))
|
||||||
|
.unwrap();
|
||||||
|
store
|
||||||
|
.create_edge(charlie, dave, "FRIEND", serde_json::json!({}))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
// Alice -> Charlie (shortcut)
|
// Alice -> Charlie (shortcut)
|
||||||
store.create_edge(alice, charlie, "KNOWS", serde_json::json!({})).unwrap();
|
store
|
||||||
|
.create_edge(alice, charlie, "KNOWS", serde_json::json!({}))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
store
|
store
|
||||||
}
|
}
|
||||||
|
|
@ -417,7 +436,10 @@ mod tests {
|
||||||
let traverser = Traverser::new(&store);
|
let traverser = Traverser::new(&store);
|
||||||
|
|
||||||
let users = store.find_nodes_by_label("User");
|
let users = store.find_nodes_by_label("User");
|
||||||
let alice = users.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("Alice"))).unwrap();
|
let alice = users
|
||||||
|
.iter()
|
||||||
|
.find(|n| n.get_property("name") == Some(&serde_json::json!("Alice")))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let query = TraversalQuery::new().depth(2);
|
let query = TraversalQuery::new().depth(2);
|
||||||
let results = traverser.traverse(&alice.id, &query);
|
let results = traverser.traverse(&alice.id, &query);
|
||||||
|
|
@ -432,7 +454,10 @@ mod tests {
|
||||||
let traverser = Traverser::new(&store);
|
let traverser = Traverser::new(&store);
|
||||||
|
|
||||||
let users = store.find_nodes_by_label("User");
|
let users = store.find_nodes_by_label("User");
|
||||||
let alice = users.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("Alice"))).unwrap();
|
let alice = users
|
||||||
|
.iter()
|
||||||
|
.find(|n| n.get_property("name") == Some(&serde_json::json!("Alice")))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let query = TraversalQuery::new()
|
let query = TraversalQuery::new()
|
||||||
.depth(2)
|
.depth(2)
|
||||||
|
|
@ -440,7 +465,10 @@ mod tests {
|
||||||
let results = traverser.traverse(&alice.id, &query);
|
let results = traverser.traverse(&alice.id, &query);
|
||||||
|
|
||||||
// Following only FRIEND edges: Alice -> Bob -> Charlie
|
// Following only FRIEND edges: Alice -> Bob -> Charlie
|
||||||
let names: Vec<_> = results.iter().filter_map(|r| r.node.get_property("name")).collect();
|
let names: Vec<_> = results
|
||||||
|
.iter()
|
||||||
|
.filter_map(|r| r.node.get_property("name"))
|
||||||
|
.collect();
|
||||||
assert!(names.contains(&&serde_json::json!("Bob")));
|
assert!(names.contains(&&serde_json::json!("Bob")));
|
||||||
assert!(names.contains(&&serde_json::json!("Charlie")));
|
assert!(names.contains(&&serde_json::json!("Charlie")));
|
||||||
}
|
}
|
||||||
|
|
@ -451,7 +479,10 @@ mod tests {
|
||||||
let traverser = Traverser::new(&store);
|
let traverser = Traverser::new(&store);
|
||||||
|
|
||||||
let users = store.find_nodes_by_label("User");
|
let users = store.find_nodes_by_label("User");
|
||||||
let alice = users.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("Alice"))).unwrap();
|
let alice = users
|
||||||
|
.iter()
|
||||||
|
.find(|n| n.get_property("name") == Some(&serde_json::json!("Alice")))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let query = TraversalQuery::new().depth(1);
|
let query = TraversalQuery::new().depth(1);
|
||||||
let results = traverser.traverse(&alice.id, &query);
|
let results = traverser.traverse(&alice.id, &query);
|
||||||
|
|
@ -468,7 +499,10 @@ mod tests {
|
||||||
let traverser = Traverser::new(&store);
|
let traverser = Traverser::new(&store);
|
||||||
|
|
||||||
let users = store.find_nodes_by_label("User");
|
let users = store.find_nodes_by_label("User");
|
||||||
let alice = users.iter().find(|n| n.get_property("name") == Some(&serde_json::json!("Alice"))).unwrap();
|
let alice = users
|
||||||
|
.iter()
|
||||||
|
.find(|n| n.get_property("name") == Some(&serde_json::json!("Alice")))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let query = TraversalQuery::new().depth(10).limit(2);
|
let query = TraversalQuery::new().depth(10).limit(2);
|
||||||
let results = traverser.traverse(&alice.id, &query);
|
let results = traverser.traverse(&alice.id, &query);
|
||||||
|
|
@ -486,11 +520,21 @@ mod tests {
|
||||||
let mutual2 = store.create_node(vec![], serde_json::json!({"name": "Mutual2"}));
|
let mutual2 = store.create_node(vec![], serde_json::json!({"name": "Mutual2"}));
|
||||||
let only_alice = store.create_node(vec![], serde_json::json!({"name": "OnlyAlice"}));
|
let only_alice = store.create_node(vec![], serde_json::json!({"name": "OnlyAlice"}));
|
||||||
|
|
||||||
store.create_edge(alice, mutual1, "FRIEND", serde_json::json!({})).unwrap();
|
store
|
||||||
store.create_edge(alice, mutual2, "FRIEND", serde_json::json!({})).unwrap();
|
.create_edge(alice, mutual1, "FRIEND", serde_json::json!({}))
|
||||||
store.create_edge(alice, only_alice, "FRIEND", serde_json::json!({})).unwrap();
|
.unwrap();
|
||||||
store.create_edge(bob, mutual1, "FRIEND", serde_json::json!({})).unwrap();
|
store
|
||||||
store.create_edge(bob, mutual2, "FRIEND", serde_json::json!({})).unwrap();
|
.create_edge(alice, mutual2, "FRIEND", serde_json::json!({}))
|
||||||
|
.unwrap();
|
||||||
|
store
|
||||||
|
.create_edge(alice, only_alice, "FRIEND", serde_json::json!({}))
|
||||||
|
.unwrap();
|
||||||
|
store
|
||||||
|
.create_edge(bob, mutual1, "FRIEND", serde_json::json!({}))
|
||||||
|
.unwrap();
|
||||||
|
store
|
||||||
|
.create_edge(bob, mutual2, "FRIEND", serde_json::json!({}))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let traverser = Traverser::new(&store);
|
let traverser = Traverser::new(&store);
|
||||||
let mutual = traverser.mutual_connections(&alice, &bob, Some("FRIEND"));
|
let mutual = traverser.mutual_connections(&alice, &bob, Some("FRIEND"));
|
||||||
|
|
|
||||||
|
|
@ -174,17 +174,24 @@ impl Index {
|
||||||
// Check uniqueness if required
|
// Check uniqueness if required
|
||||||
if self.config.unique {
|
if self.config.unique {
|
||||||
let exists = match self.config.index_type {
|
let exists = match self.config.index_type {
|
||||||
IndexType::Hash | IndexType::Unique => {
|
IndexType::Hash | IndexType::Unique => self
|
||||||
self.hash.read().get(&key).map(|s| !s.is_empty()).unwrap_or(false)
|
.hash
|
||||||
}
|
.read()
|
||||||
_ => {
|
.get(&key)
|
||||||
self.btree.read().get(&key).map(|s| !s.is_empty()).unwrap_or(false)
|
.map(|s| !s.is_empty())
|
||||||
}
|
.unwrap_or(false),
|
||||||
|
_ => self
|
||||||
|
.btree
|
||||||
|
.read()
|
||||||
|
.get(&key)
|
||||||
|
.map(|s| !s.is_empty())
|
||||||
|
.unwrap_or(false),
|
||||||
};
|
};
|
||||||
if exists {
|
if exists {
|
||||||
return Err(DatabaseError::AlreadyExists(
|
return Err(DatabaseError::AlreadyExists(format!(
|
||||||
format!("Unique constraint violation on index '{}'", self.config.name)
|
"Unique constraint violation on index '{}'",
|
||||||
));
|
self.config.name
|
||||||
|
)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -239,20 +246,18 @@ impl Index {
|
||||||
self.stats.write().lookups += 1;
|
self.stats.write().lookups += 1;
|
||||||
|
|
||||||
let result: Vec<DocumentId> = match self.config.index_type {
|
let result: Vec<DocumentId> = match self.config.index_type {
|
||||||
IndexType::Hash | IndexType::Unique => {
|
IndexType::Hash | IndexType::Unique => self
|
||||||
self.hash
|
.hash
|
||||||
.read()
|
.read()
|
||||||
.get(&key)
|
.get(&key)
|
||||||
.map(|s| s.iter().cloned().collect())
|
.map(|s| s.iter().cloned().collect())
|
||||||
.unwrap_or_default()
|
.unwrap_or_default(),
|
||||||
}
|
_ => self
|
||||||
_ => {
|
.btree
|
||||||
self.btree
|
.read()
|
||||||
.read()
|
.get(&key)
|
||||||
.get(&key)
|
.map(|s| s.iter().cloned().collect())
|
||||||
.map(|s| s.iter().cloned().collect())
|
.unwrap_or_default(),
|
||||||
.unwrap_or_default()
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if !result.is_empty() {
|
if !result.is_empty() {
|
||||||
|
|
@ -407,12 +412,7 @@ impl IndexManager {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Removes a document from indexes.
|
/// Removes a document from indexes.
|
||||||
pub fn unindex_document(
|
pub fn unindex_document(&self, collection: &str, doc_id: &DocumentId, document: &JsonValue) {
|
||||||
&self,
|
|
||||||
collection: &str,
|
|
||||||
doc_id: &DocumentId,
|
|
||||||
document: &JsonValue,
|
|
||||||
) {
|
|
||||||
let index_names = self.get_collection_indexes(collection);
|
let index_names = self.get_collection_indexes(collection);
|
||||||
let indexes = self.indexes.read();
|
let indexes = self.indexes.read();
|
||||||
|
|
||||||
|
|
@ -483,7 +483,9 @@ mod tests {
|
||||||
let index = Index::new(config);
|
let index = Index::new(config);
|
||||||
|
|
||||||
let doc1 = DocumentId::new();
|
let doc1 = DocumentId::new();
|
||||||
index.insert(doc1.clone(), &json!("alice@example.com")).unwrap();
|
index
|
||||||
|
.insert(doc1.clone(), &json!("alice@example.com"))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let results = index.lookup(&json!("alice@example.com"));
|
let results = index.lookup(&json!("alice@example.com"));
|
||||||
assert_eq!(results.len(), 1);
|
assert_eq!(results.len(), 1);
|
||||||
|
|
@ -521,7 +523,9 @@ mod tests {
|
||||||
let doc_id = DocumentId::new();
|
let doc_id = DocumentId::new();
|
||||||
let doc = json!({"name": "Alice", "age": 30});
|
let doc = json!({"name": "Alice", "age": 30});
|
||||||
|
|
||||||
manager.index_document("users", doc_id.clone(), &doc).unwrap();
|
manager
|
||||||
|
.index_document("users", doc_id.clone(), &doc)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let indexes = manager.list_indexes();
|
let indexes = manager.list_indexes();
|
||||||
assert_eq!(indexes.len(), 1);
|
assert_eq!(indexes.len(), 1);
|
||||||
|
|
|
||||||
|
|
@ -126,8 +126,7 @@ impl KeyValueStore {
|
||||||
|
|
||||||
/// Gets a value as string.
|
/// Gets a value as string.
|
||||||
pub fn get_string(&self, key: &str) -> Option<String> {
|
pub fn get_string(&self, key: &str) -> Option<String> {
|
||||||
self.get(key)
|
self.get(key).and_then(|v| String::from_utf8(v).ok())
|
||||||
.and_then(|v| String::from_utf8(v).ok())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sets a value with optional TTL.
|
/// Sets a value with optional TTL.
|
||||||
|
|
@ -224,8 +223,9 @@ impl KeyValueStore {
|
||||||
} else {
|
} else {
|
||||||
let s = String::from_utf8(entry.value.clone())
|
let s = String::from_utf8(entry.value.clone())
|
||||||
.map_err(|_| DatabaseError::InvalidOperation("Value is not a string".into()))?;
|
.map_err(|_| DatabaseError::InvalidOperation("Value is not a string".into()))?;
|
||||||
s.parse::<i64>()
|
s.parse::<i64>().map_err(|_| {
|
||||||
.map_err(|_| DatabaseError::InvalidOperation("Value is not an integer".into()))?
|
DatabaseError::InvalidOperation("Value is not an integer".into())
|
||||||
|
})?
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
0
|
0
|
||||||
|
|
@ -243,9 +243,9 @@ impl KeyValueStore {
|
||||||
pub fn append(&self, key: &str, value: &[u8]) -> Result<usize, DatabaseError> {
|
pub fn append(&self, key: &str, value: &[u8]) -> Result<usize, DatabaseError> {
|
||||||
let mut data = self.data.write();
|
let mut data = self.data.write();
|
||||||
|
|
||||||
let entry = data.entry(key.to_string()).or_insert_with(|| {
|
let entry = data
|
||||||
KvEntry::new(Vec::new(), 0)
|
.entry(key.to_string())
|
||||||
});
|
.or_insert_with(|| KvEntry::new(Vec::new(), 0));
|
||||||
|
|
||||||
if entry.is_expired() {
|
if entry.is_expired() {
|
||||||
entry.value.clear();
|
entry.value.clear();
|
||||||
|
|
@ -393,11 +393,16 @@ mod tests {
|
||||||
fn test_mget_mset() {
|
fn test_mget_mset() {
|
||||||
let store = KeyValueStore::new();
|
let store = KeyValueStore::new();
|
||||||
|
|
||||||
store.mset(&[
|
store
|
||||||
("k1", b"v1".to_vec()),
|
.mset(
|
||||||
("k2", b"v2".to_vec()),
|
&[
|
||||||
("k3", b"v3".to_vec()),
|
("k1", b"v1".to_vec()),
|
||||||
], 0).unwrap();
|
("k2", b"v2".to_vec()),
|
||||||
|
("k3", b"v3".to_vec()),
|
||||||
|
],
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let results = store.mget(&["k1", "k2", "k4"]);
|
let results = store.mget(&["k1", "k2", "k4"]);
|
||||||
assert_eq!(results.len(), 3);
|
assert_eq!(results.len(), 3);
|
||||||
|
|
|
||||||
|
|
@ -65,12 +65,14 @@ pub use graph::{
|
||||||
pub use index::{Index, IndexConfig, IndexManager, IndexType};
|
pub use index::{Index, IndexConfig, IndexManager, IndexType};
|
||||||
pub use keyvalue::{KeyValue, KeyValueStore, KvEntry};
|
pub use keyvalue::{KeyValue, KeyValueStore, KvEntry};
|
||||||
pub use query::{Filter, Query, QueryEngine, QueryResult, SortOrder};
|
pub use query::{Filter, Query, QueryEngine, QueryResult, SortOrder};
|
||||||
pub use schema::{Field, FieldType, Schema, SchemaValidator};
|
|
||||||
pub use replication::{
|
pub use replication::{
|
||||||
ClusterConfig, Command as RaftCommand, NodeRole, RaftConfig, RaftEvent, RaftNode, RaftState,
|
ClusterConfig, Command as RaftCommand, NodeRole, RaftConfig, RaftEvent, RaftNode, RaftState,
|
||||||
ReplicatedLog,
|
ReplicatedLog,
|
||||||
};
|
};
|
||||||
pub use sql::{QueryResult as SqlQueryResult, SqlEngine, SqlParser, SqlType, SqlValue, Table, TableDef};
|
pub use schema::{Field, FieldType, Schema, SchemaValidator};
|
||||||
|
pub use sql::{
|
||||||
|
QueryResult as SqlQueryResult, SqlEngine, SqlParser, SqlType, SqlValue, Table, TableDef,
|
||||||
|
};
|
||||||
pub use timeseries::{DataPoint, Metric, TimeSeries, TimeSeriesStore};
|
pub use timeseries::{DataPoint, Metric, TimeSeries, TimeSeriesStore};
|
||||||
pub use vector::{Embedding, SimilarityMetric, VectorIndex, VectorStore};
|
pub use vector::{Embedding, SimilarityMetric, VectorIndex, VectorStore};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -419,10 +419,7 @@ impl QueryEngine {
|
||||||
|
|
||||||
let values: Vec<f64> = docs
|
let values: Vec<f64> = docs
|
||||||
.iter()
|
.iter()
|
||||||
.filter_map(|doc| {
|
.filter_map(|doc| doc.get(field).and_then(|v| v.as_f64()))
|
||||||
doc.get(field)
|
|
||||||
.and_then(|v| v.as_f64())
|
|
||||||
})
|
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let result = match op {
|
let result = match op {
|
||||||
|
|
@ -439,22 +436,18 @@ impl QueryEngine {
|
||||||
serde_json::to_value(avg).unwrap_or(JsonValue::Null)
|
serde_json::to_value(avg).unwrap_or(JsonValue::Null)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
AggregateOp::Min => {
|
AggregateOp::Min => values
|
||||||
values
|
.iter()
|
||||||
.iter()
|
.copied()
|
||||||
.copied()
|
.min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||||
.min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
.map(|v| serde_json::to_value(v).unwrap_or(JsonValue::Null))
|
||||||
.map(|v| serde_json::to_value(v).unwrap_or(JsonValue::Null))
|
.unwrap_or(JsonValue::Null),
|
||||||
.unwrap_or(JsonValue::Null)
|
AggregateOp::Max => values
|
||||||
}
|
.iter()
|
||||||
AggregateOp::Max => {
|
.copied()
|
||||||
values
|
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||||
.iter()
|
.map(|v| serde_json::to_value(v).unwrap_or(JsonValue::Null))
|
||||||
.copied()
|
.unwrap_or(JsonValue::Null),
|
||||||
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
|
||||||
.map(|v| serde_json::to_value(v).unwrap_or(JsonValue::Null))
|
|
||||||
.unwrap_or(JsonValue::Null)
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(result)
|
Ok(result)
|
||||||
|
|
@ -507,8 +500,10 @@ mod tests {
|
||||||
fn test_simple_query() {
|
fn test_simple_query() {
|
||||||
let docs = Arc::new(DocumentStore::new());
|
let docs = Arc::new(DocumentStore::new());
|
||||||
docs.create_collection("users").unwrap();
|
docs.create_collection("users").unwrap();
|
||||||
docs.insert("users", json!({"name": "Alice", "age": 30})).unwrap();
|
docs.insert("users", json!({"name": "Alice", "age": 30}))
|
||||||
docs.insert("users", json!({"name": "Bob", "age": 25})).unwrap();
|
.unwrap();
|
||||||
|
docs.insert("users", json!({"name": "Bob", "age": 25}))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let vectors = Arc::new(VectorStore::new(3));
|
let vectors = Arc::new(VectorStore::new(3));
|
||||||
let indexes = Arc::new(IndexManager::new());
|
let indexes = Arc::new(IndexManager::new());
|
||||||
|
|
@ -525,8 +520,10 @@ mod tests {
|
||||||
fn test_filter_query() {
|
fn test_filter_query() {
|
||||||
let docs = Arc::new(DocumentStore::new());
|
let docs = Arc::new(DocumentStore::new());
|
||||||
docs.create_collection("users").unwrap();
|
docs.create_collection("users").unwrap();
|
||||||
docs.insert("users", json!({"name": "Alice", "age": 30})).unwrap();
|
docs.insert("users", json!({"name": "Alice", "age": 30}))
|
||||||
docs.insert("users", json!({"name": "Bob", "age": 25})).unwrap();
|
.unwrap();
|
||||||
|
docs.insert("users", json!({"name": "Bob", "age": 25}))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let vectors = Arc::new(VectorStore::new(3));
|
let vectors = Arc::new(VectorStore::new(3));
|
||||||
let indexes = Arc::new(IndexManager::new());
|
let indexes = Arc::new(IndexManager::new());
|
||||||
|
|
@ -543,9 +540,12 @@ mod tests {
|
||||||
fn test_sorted_query() {
|
fn test_sorted_query() {
|
||||||
let docs = Arc::new(DocumentStore::new());
|
let docs = Arc::new(DocumentStore::new());
|
||||||
docs.create_collection("users").unwrap();
|
docs.create_collection("users").unwrap();
|
||||||
docs.insert("users", json!({"name": "Alice", "age": 30})).unwrap();
|
docs.insert("users", json!({"name": "Alice", "age": 30}))
|
||||||
docs.insert("users", json!({"name": "Bob", "age": 25})).unwrap();
|
.unwrap();
|
||||||
docs.insert("users", json!({"name": "Charlie", "age": 35})).unwrap();
|
docs.insert("users", json!({"name": "Bob", "age": 25}))
|
||||||
|
.unwrap();
|
||||||
|
docs.insert("users", json!({"name": "Charlie", "age": 35}))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let vectors = Arc::new(VectorStore::new(3));
|
let vectors = Arc::new(VectorStore::new(3));
|
||||||
let indexes = Arc::new(IndexManager::new());
|
let indexes = Arc::new(IndexManager::new());
|
||||||
|
|
|
||||||
|
|
@ -92,12 +92,7 @@ impl Election {
|
||||||
|
|
||||||
/// Creates a RequestVote message for this election.
|
/// Creates a RequestVote message for this election.
|
||||||
pub fn create_request(&self, log: &ReplicatedLog) -> RequestVote {
|
pub fn create_request(&self, log: &ReplicatedLog) -> RequestVote {
|
||||||
RequestVote::new(
|
RequestVote::new(self.term, self.node_id, log.last_index(), log.last_term())
|
||||||
self.term,
|
|
||||||
self.node_id,
|
|
||||||
log.last_index(),
|
|
||||||
log.last_term(),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Checks the current result of the election.
|
/// Checks the current result of the election.
|
||||||
|
|
@ -217,8 +212,8 @@ impl Default for ElectionTimeout {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::replication::state::Command;
|
|
||||||
use crate::replication::log::LogEntry;
|
use crate::replication::log::LogEntry;
|
||||||
|
use crate::replication::state::Command;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_election_basic() {
|
fn test_election_basic() {
|
||||||
|
|
|
||||||
|
|
@ -139,7 +139,10 @@ impl ReplicatedLog {
|
||||||
|
|
||||||
let entries = self.entries.read();
|
let entries = self.entries.read();
|
||||||
let offset = (from_index - start) as usize;
|
let offset = (from_index - start) as usize;
|
||||||
entries.get(offset..).map(|s| s.to_vec()).unwrap_or_default()
|
entries
|
||||||
|
.get(offset..)
|
||||||
|
.map(|s| s.to_vec())
|
||||||
|
.unwrap_or_default()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Appends an entry to the log.
|
/// Appends an entry to the log.
|
||||||
|
|
@ -151,7 +154,12 @@ impl ReplicatedLog {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Appends multiple entries, potentially overwriting conflicting entries.
|
/// Appends multiple entries, potentially overwriting conflicting entries.
|
||||||
pub fn append_entries(&self, prev_index: u64, prev_term: u64, new_entries: Vec<LogEntry>) -> bool {
|
pub fn append_entries(
|
||||||
|
&self,
|
||||||
|
prev_index: u64,
|
||||||
|
prev_term: u64,
|
||||||
|
new_entries: Vec<LogEntry>,
|
||||||
|
) -> bool {
|
||||||
// Check that prev entry matches
|
// Check that prev entry matches
|
||||||
if prev_index > 0 {
|
if prev_index > 0 {
|
||||||
if let Some(prev_entry_term) = self.term_at(prev_index) {
|
if let Some(prev_entry_term) = self.term_at(prev_index) {
|
||||||
|
|
@ -245,7 +253,11 @@ impl ReplicatedLog {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates entries for replication starting from a given index.
|
/// Creates entries for replication starting from a given index.
|
||||||
pub fn entries_for_replication(&self, from_index: u64, max_entries: usize) -> (u64, u64, Vec<LogEntry>) {
|
pub fn entries_for_replication(
|
||||||
|
&self,
|
||||||
|
from_index: u64,
|
||||||
|
max_entries: usize,
|
||||||
|
) -> (u64, u64, Vec<LogEntry>) {
|
||||||
let prev_index = from_index.saturating_sub(1);
|
let prev_index = from_index.saturating_sub(1);
|
||||||
let prev_term = self.term_at(prev_index).unwrap_or(0);
|
let prev_term = self.term_at(prev_index).unwrap_or(0);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -222,7 +222,11 @@ impl RaftNode {
|
||||||
|
|
||||||
// Create new election
|
// Create new election
|
||||||
let cluster_size = self.cluster.voting_members();
|
let cluster_size = self.cluster.voting_members();
|
||||||
self.election = Some(Election::new(self.id, self.state.current_term, cluster_size));
|
self.election = Some(Election::new(
|
||||||
|
self.id,
|
||||||
|
self.state.current_term,
|
||||||
|
cluster_size,
|
||||||
|
));
|
||||||
|
|
||||||
// Create RequestVote message
|
// Create RequestVote message
|
||||||
let request = RequestVote::new(
|
let request = RequestVote::new(
|
||||||
|
|
@ -295,9 +299,9 @@ impl RaftNode {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let (prev_log_index, prev_log_term, entries) =
|
let (prev_log_index, prev_log_term, entries) = self
|
||||||
self.log
|
.log
|
||||||
.entries_for_replication(next_index, self.config.max_entries_per_rpc);
|
.entries_for_replication(next_index, self.config.max_entries_per_rpc);
|
||||||
|
|
||||||
let request = AppendEntries::with_entries(
|
let request = AppendEntries::with_entries(
|
||||||
self.state.current_term,
|
self.state.current_term,
|
||||||
|
|
@ -308,8 +312,10 @@ impl RaftNode {
|
||||||
self.state.commit_index,
|
self.state.commit_index,
|
||||||
);
|
);
|
||||||
|
|
||||||
self.events
|
self.events.push(RaftEvent::SendRpc(
|
||||||
.push(RaftEvent::SendRpc(peer_id, RpcMessage::AppendEntries(request)));
|
peer_id,
|
||||||
|
RpcMessage::AppendEntries(request),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
fn send_install_snapshot(&mut self, peer_id: NodeId) {
|
fn send_install_snapshot(&mut self, peer_id: NodeId) {
|
||||||
|
|
@ -332,8 +338,10 @@ impl RaftNode {
|
||||||
done,
|
done,
|
||||||
);
|
);
|
||||||
|
|
||||||
self.events
|
self.events.push(RaftEvent::SendRpc(
|
||||||
.push(RaftEvent::SendRpc(peer_id, RpcMessage::InstallSnapshot(request)));
|
peer_id,
|
||||||
|
RpcMessage::InstallSnapshot(request),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -395,7 +403,11 @@ impl RaftNode {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_append_entries(&mut self, _from: NodeId, req: AppendEntries) -> AppendEntriesResponse {
|
fn handle_append_entries(
|
||||||
|
&mut self,
|
||||||
|
_from: NodeId,
|
||||||
|
req: AppendEntries,
|
||||||
|
) -> AppendEntriesResponse {
|
||||||
// Rule: If term > currentTerm, become follower
|
// Rule: If term > currentTerm, become follower
|
||||||
if req.term > self.state.current_term {
|
if req.term > self.state.current_term {
|
||||||
self.become_follower(req.term, Some(req.leader_id));
|
self.become_follower(req.term, Some(req.leader_id));
|
||||||
|
|
@ -416,9 +428,9 @@ impl RaftNode {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to append entries
|
// Try to append entries
|
||||||
let success =
|
let success = self
|
||||||
self.log
|
.log
|
||||||
.append_entries(req.prev_log_index, req.prev_log_term, req.entries);
|
.append_entries(req.prev_log_index, req.prev_log_term, req.entries);
|
||||||
|
|
||||||
if success {
|
if success {
|
||||||
// Update commit index
|
// Update commit index
|
||||||
|
|
@ -443,7 +455,11 @@ impl RaftNode {
|
||||||
}
|
}
|
||||||
conflict_index -= 1;
|
conflict_index -= 1;
|
||||||
}
|
}
|
||||||
AppendEntriesResponse::conflict(self.state.current_term, conflict_term, conflict_index)
|
AppendEntriesResponse::conflict(
|
||||||
|
self.state.current_term,
|
||||||
|
conflict_term,
|
||||||
|
conflict_index,
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
AppendEntriesResponse::failure(self.state.current_term)
|
AppendEntriesResponse::failure(self.state.current_term)
|
||||||
}
|
}
|
||||||
|
|
@ -502,7 +518,11 @@ impl RaftNode {
|
||||||
self.cluster.update_peer_state(from, PeerState::Reachable);
|
self.cluster.update_peer_state(from, PeerState::Reachable);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_install_snapshot(&mut self, _from: NodeId, req: InstallSnapshot) -> InstallSnapshotResponse {
|
fn handle_install_snapshot(
|
||||||
|
&mut self,
|
||||||
|
_from: NodeId,
|
||||||
|
req: InstallSnapshot,
|
||||||
|
) -> InstallSnapshotResponse {
|
||||||
// Rule: If term > currentTerm, become follower
|
// Rule: If term > currentTerm, become follower
|
||||||
if req.term > self.state.current_term {
|
if req.term > self.state.current_term {
|
||||||
self.become_follower(req.term, Some(req.leader_id));
|
self.become_follower(req.term, Some(req.leader_id));
|
||||||
|
|
@ -692,12 +712,14 @@ impl RaftNode {
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
|
||||||
use super::super::cluster::PeerAddress;
|
use super::super::cluster::PeerAddress;
|
||||||
|
use super::*;
|
||||||
|
|
||||||
fn create_test_cluster(node_id: NodeId, peers: &[NodeId]) -> ClusterConfig {
|
fn create_test_cluster(node_id: NodeId, peers: &[NodeId]) -> ClusterConfig {
|
||||||
let mut cluster =
|
let mut cluster = ClusterConfig::new(
|
||||||
ClusterConfig::new(node_id, PeerAddress::new("127.0.0.1", 9000 + node_id as u16));
|
node_id,
|
||||||
|
PeerAddress::new("127.0.0.1", 9000 + node_id as u16),
|
||||||
|
);
|
||||||
for &peer in peers {
|
for &peer in peers {
|
||||||
cluster.add_peer(super::super::cluster::PeerInfo::new(
|
cluster.add_peer(super::super::cluster::PeerInfo::new(
|
||||||
peer,
|
peer,
|
||||||
|
|
|
||||||
|
|
@ -176,7 +176,10 @@ impl SnapshotManager {
|
||||||
/// Adds a chunk to the pending snapshot.
|
/// Adds a chunk to the pending snapshot.
|
||||||
pub fn add_chunk(&mut self, offset: u64, data: Vec<u8>) -> bool {
|
pub fn add_chunk(&mut self, offset: u64, data: Vec<u8>) -> bool {
|
||||||
if let Some(ref mut pending) = self.pending_snapshot {
|
if let Some(ref mut pending) = self.pending_snapshot {
|
||||||
if offset == pending.expected_offset + pending.chunks.iter().map(|c| c.len() as u64).sum::<u64>() {
|
if offset
|
||||||
|
== pending.expected_offset
|
||||||
|
+ pending.chunks.iter().map(|c| c.len() as u64).sum::<u64>()
|
||||||
|
{
|
||||||
pending.chunks.push(data);
|
pending.chunks.push(data);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -118,31 +118,55 @@ pub enum Command {
|
||||||
|
|
||||||
// Key-Value operations
|
// Key-Value operations
|
||||||
/// Set a key-value pair.
|
/// Set a key-value pair.
|
||||||
KvSet { key: String, value: Vec<u8>, ttl: Option<u64> },
|
KvSet {
|
||||||
|
key: String,
|
||||||
|
value: Vec<u8>,
|
||||||
|
ttl: Option<u64>,
|
||||||
|
},
|
||||||
/// Delete a key.
|
/// Delete a key.
|
||||||
KvDelete { key: String },
|
KvDelete { key: String },
|
||||||
|
|
||||||
// Document operations
|
// Document operations
|
||||||
/// Insert a document.
|
/// Insert a document.
|
||||||
DocInsert { collection: String, document: JsonValue },
|
DocInsert {
|
||||||
|
collection: String,
|
||||||
|
document: JsonValue,
|
||||||
|
},
|
||||||
/// Update a document.
|
/// Update a document.
|
||||||
DocUpdate { collection: String, id: String, update: JsonValue },
|
DocUpdate {
|
||||||
|
collection: String,
|
||||||
|
id: String,
|
||||||
|
update: JsonValue,
|
||||||
|
},
|
||||||
/// Delete a document.
|
/// Delete a document.
|
||||||
DocDelete { collection: String, id: String },
|
DocDelete { collection: String, id: String },
|
||||||
|
|
||||||
// Vector operations
|
// Vector operations
|
||||||
/// Insert a vector.
|
/// Insert a vector.
|
||||||
VectorInsert { namespace: String, id: String, vector: Vec<f32>, metadata: JsonValue },
|
VectorInsert {
|
||||||
|
namespace: String,
|
||||||
|
id: String,
|
||||||
|
vector: Vec<f32>,
|
||||||
|
metadata: JsonValue,
|
||||||
|
},
|
||||||
/// Delete a vector.
|
/// Delete a vector.
|
||||||
VectorDelete { namespace: String, id: String },
|
VectorDelete { namespace: String, id: String },
|
||||||
|
|
||||||
// Time-series operations
|
// Time-series operations
|
||||||
/// Record a metric data point.
|
/// Record a metric data point.
|
||||||
TimeSeriesRecord { metric: String, value: f64, timestamp: u64, tags: JsonValue },
|
TimeSeriesRecord {
|
||||||
|
metric: String,
|
||||||
|
value: f64,
|
||||||
|
timestamp: u64,
|
||||||
|
tags: JsonValue,
|
||||||
|
},
|
||||||
|
|
||||||
// Graph operations
|
// Graph operations
|
||||||
/// Create a graph node.
|
/// Create a graph node.
|
||||||
GraphNodeCreate { labels: Vec<String>, properties: JsonValue },
|
GraphNodeCreate {
|
||||||
|
labels: Vec<String>,
|
||||||
|
properties: JsonValue,
|
||||||
|
},
|
||||||
/// Delete a graph node.
|
/// Delete a graph node.
|
||||||
GraphNodeDelete { id: String },
|
GraphNodeDelete { id: String },
|
||||||
/// Create a graph edge.
|
/// Create a graph edge.
|
||||||
|
|
@ -161,13 +185,20 @@ pub enum Command {
|
||||||
|
|
||||||
// Schema operations
|
// Schema operations
|
||||||
/// Create a collection/table.
|
/// Create a collection/table.
|
||||||
CreateCollection { name: String, schema: Option<JsonValue> },
|
CreateCollection {
|
||||||
|
name: String,
|
||||||
|
schema: Option<JsonValue>,
|
||||||
|
},
|
||||||
/// Drop a collection/table.
|
/// Drop a collection/table.
|
||||||
DropCollection { name: String },
|
DropCollection { name: String },
|
||||||
|
|
||||||
// Index operations
|
// Index operations
|
||||||
/// Create an index.
|
/// Create an index.
|
||||||
CreateIndex { collection: String, field: String, index_type: String },
|
CreateIndex {
|
||||||
|
collection: String,
|
||||||
|
field: String,
|
||||||
|
index_type: String,
|
||||||
|
},
|
||||||
/// Drop an index.
|
/// Drop an index.
|
||||||
DropIndex { name: String },
|
DropIndex { name: String },
|
||||||
|
|
||||||
|
|
@ -265,7 +296,12 @@ impl LeaderState {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Calculates the new commit index based on majority replication.
|
/// Calculates the new commit index based on majority replication.
|
||||||
pub fn calculate_commit_index(&self, current_commit: u64, current_term: u64, log_term_at: impl Fn(u64) -> Option<u64>) -> u64 {
|
pub fn calculate_commit_index(
|
||||||
|
&self,
|
||||||
|
current_commit: u64,
|
||||||
|
current_term: u64,
|
||||||
|
log_term_at: impl Fn(u64) -> Option<u64>,
|
||||||
|
) -> u64 {
|
||||||
// Find the highest index that a majority have replicated
|
// Find the highest index that a majority have replicated
|
||||||
let mut indices: Vec<u64> = self.match_index.values().cloned().collect();
|
let mut indices: Vec<u64> = self.match_index.values().cloned().collect();
|
||||||
indices.sort_unstable();
|
indices.sort_unstable();
|
||||||
|
|
|
||||||
|
|
@ -315,8 +315,8 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_vector_field() {
|
fn test_vector_field() {
|
||||||
let schema = Schema::new("embedding")
|
let schema =
|
||||||
.field(Field::required("vector", FieldType::Vector(3)));
|
Schema::new("embedding").field(Field::required("vector", FieldType::Vector(3)));
|
||||||
|
|
||||||
let mut validator = SchemaValidator::new();
|
let mut validator = SchemaValidator::new();
|
||||||
validator.register(schema);
|
validator.register(schema);
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,7 @@
|
||||||
//! SQL query executor.
|
//! SQL query executor.
|
||||||
|
|
||||||
use super::parser::{
|
use super::parser::{
|
||||||
BinaryOp, ParsedExpr, ParsedSelect, ParsedSelectItem,
|
BinaryOp, ParsedExpr, ParsedSelect, ParsedSelectItem, ParsedStatement, SqlParser,
|
||||||
ParsedStatement, SqlParser,
|
|
||||||
};
|
};
|
||||||
use super::row::{Row, RowId};
|
use super::row::{Row, RowId};
|
||||||
use super::table::{ColumnDef, Table, TableDef};
|
use super::table::{ColumnDef, Table, TableDef};
|
||||||
|
|
@ -192,11 +191,7 @@ impl SqlEngine {
|
||||||
match a_val.partial_cmp(&b_val) {
|
match a_val.partial_cmp(&b_val) {
|
||||||
Some(std::cmp::Ordering::Equal) => continue,
|
Some(std::cmp::Ordering::Equal) => continue,
|
||||||
Some(ord) => {
|
Some(ord) => {
|
||||||
return if ob.ascending {
|
return if ob.ascending { ord } else { ord.reverse() };
|
||||||
ord
|
|
||||||
} else {
|
|
||||||
ord.reverse()
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
None => continue,
|
None => continue,
|
||||||
}
|
}
|
||||||
|
|
@ -216,7 +211,11 @@ impl SqlEngine {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle aggregates
|
// Handle aggregates
|
||||||
if select.columns.iter().any(|c| matches!(c, ParsedSelectItem::Aggregate { .. })) {
|
if select
|
||||||
|
.columns
|
||||||
|
.iter()
|
||||||
|
.any(|c| matches!(c, ParsedSelectItem::Aggregate { .. }))
|
||||||
|
{
|
||||||
return self.execute_aggregate(select, &rows, table);
|
return self.execute_aggregate(select, &rows, table);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -244,7 +243,9 @@ impl SqlEngine {
|
||||||
ParsedSelectItem::Wildcard => table.def.column_names(),
|
ParsedSelectItem::Wildcard => table.def.column_names(),
|
||||||
ParsedSelectItem::Column(name) => vec![name.clone()],
|
ParsedSelectItem::Column(name) => vec![name.clone()],
|
||||||
ParsedSelectItem::ColumnAlias { alias, .. } => vec![alias.clone()],
|
ParsedSelectItem::ColumnAlias { alias, .. } => vec![alias.clone()],
|
||||||
ParsedSelectItem::Aggregate { function, alias, .. } => {
|
ParsedSelectItem::Aggregate {
|
||||||
|
function, alias, ..
|
||||||
|
} => {
|
||||||
vec![alias.clone().unwrap_or_else(|| function.clone())]
|
vec![alias.clone().unwrap_or_else(|| function.clone())]
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
@ -328,7 +329,9 @@ impl SqlEngine {
|
||||||
rows.iter()
|
rows.iter()
|
||||||
.map(|r| r.get_or_null(col))
|
.map(|r| r.get_or_null(col))
|
||||||
.filter(|v| !v.is_null())
|
.filter(|v| !v.is_null())
|
||||||
.min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
.min_by(|a, b| {
|
||||||
|
a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
|
||||||
|
})
|
||||||
.unwrap_or(SqlValue::Null)
|
.unwrap_or(SqlValue::Null)
|
||||||
}
|
}
|
||||||
"MAX" => {
|
"MAX" => {
|
||||||
|
|
@ -338,12 +341,12 @@ impl SqlEngine {
|
||||||
rows.iter()
|
rows.iter()
|
||||||
.map(|r| r.get_or_null(col))
|
.map(|r| r.get_or_null(col))
|
||||||
.filter(|v| !v.is_null())
|
.filter(|v| !v.is_null())
|
||||||
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
.max_by(|a, b| {
|
||||||
|
a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
|
||||||
|
})
|
||||||
.unwrap_or(SqlValue::Null)
|
.unwrap_or(SqlValue::Null)
|
||||||
}
|
}
|
||||||
_ => {
|
_ => return Err(SqlError::Unsupported(format!("Function: {}", function))),
|
||||||
return Err(SqlError::Unsupported(format!("Function: {}", function)))
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
result_values.push(value);
|
result_values.push(value);
|
||||||
}
|
}
|
||||||
|
|
@ -404,7 +407,11 @@ impl SqlEngine {
|
||||||
ParsedExpr::IsNotNull(inner) => {
|
ParsedExpr::IsNotNull(inner) => {
|
||||||
SqlValue::Boolean(!self.evaluate_expr(row, inner).is_null())
|
SqlValue::Boolean(!self.evaluate_expr(row, inner).is_null())
|
||||||
}
|
}
|
||||||
ParsedExpr::InList { expr, list, negated } => {
|
ParsedExpr::InList {
|
||||||
|
expr,
|
||||||
|
list,
|
||||||
|
negated,
|
||||||
|
} => {
|
||||||
let val = self.evaluate_expr(row, expr);
|
let val = self.evaluate_expr(row, expr);
|
||||||
let in_list = list.iter().any(|item| {
|
let in_list = list.iter().any(|item| {
|
||||||
let item_val = self.evaluate_expr(row, item);
|
let item_val = self.evaluate_expr(row, item);
|
||||||
|
|
@ -424,9 +431,7 @@ impl SqlEngine {
|
||||||
let between = val >= low_val && val <= high_val;
|
let between = val >= low_val && val <= high_val;
|
||||||
SqlValue::Boolean(if *negated { !between } else { between })
|
SqlValue::Boolean(if *negated { !between } else { between })
|
||||||
}
|
}
|
||||||
ParsedExpr::Function { name, args } => {
|
ParsedExpr::Function { name, args } => self.evaluate_function(row, name, args),
|
||||||
self.evaluate_function(row, name, args)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -474,9 +479,7 @@ impl SqlEngine {
|
||||||
_ => SqlValue::Null,
|
_ => SqlValue::Null,
|
||||||
},
|
},
|
||||||
BinaryOp::Divide => match (left, right) {
|
BinaryOp::Divide => match (left, right) {
|
||||||
(SqlValue::Integer(a), SqlValue::Integer(b)) if *b != 0 => {
|
(SqlValue::Integer(a), SqlValue::Integer(b)) if *b != 0 => SqlValue::Integer(a / b),
|
||||||
SqlValue::Integer(a / b)
|
|
||||||
}
|
|
||||||
(SqlValue::Real(a), SqlValue::Real(b)) if *b != 0.0 => SqlValue::Real(a / b),
|
(SqlValue::Real(a), SqlValue::Real(b)) if *b != 0.0 => SqlValue::Real(a / b),
|
||||||
_ => SqlValue::Null,
|
_ => SqlValue::Null,
|
||||||
},
|
},
|
||||||
|
|
@ -536,9 +539,7 @@ impl SqlEngine {
|
||||||
/// Matches a LIKE pattern.
|
/// Matches a LIKE pattern.
|
||||||
fn match_like(&self, text: &str, pattern: &str) -> bool {
|
fn match_like(&self, text: &str, pattern: &str) -> bool {
|
||||||
// Simple LIKE implementation: % = any chars, _ = single char
|
// Simple LIKE implementation: % = any chars, _ = single char
|
||||||
let _regex_pattern = pattern
|
let _regex_pattern = pattern.replace('%', ".*").replace('_', ".");
|
||||||
.replace('%', ".*")
|
|
||||||
.replace('_', ".");
|
|
||||||
// For simplicity, just do case-insensitive contains for now
|
// For simplicity, just do case-insensitive contains for now
|
||||||
if pattern.starts_with('%') && pattern.ends_with('%') {
|
if pattern.starts_with('%') && pattern.ends_with('%') {
|
||||||
let inner = &pattern[1..pattern.len() - 1];
|
let inner = &pattern[1..pattern.len() - 1];
|
||||||
|
|
@ -615,8 +616,7 @@ impl SqlEngine {
|
||||||
.unwrap_or(true);
|
.unwrap_or(true);
|
||||||
|
|
||||||
if matches {
|
if matches {
|
||||||
let updates: HashMap<String, SqlValue> =
|
let updates: HashMap<String, SqlValue> = assignments.iter().cloned().collect();
|
||||||
assignments.iter().cloned().collect();
|
|
||||||
table.update(row.id, updates)?;
|
table.update(row.id, updates)?;
|
||||||
count += 1;
|
count += 1;
|
||||||
}
|
}
|
||||||
|
|
@ -672,9 +672,9 @@ impl SqlEngine {
|
||||||
.ok_or_else(|| SqlError::TableNotFound(table_name.to_string()))?;
|
.ok_or_else(|| SqlError::TableNotFound(table_name.to_string()))?;
|
||||||
|
|
||||||
// For simplicity, only support single-column indexes
|
// For simplicity, only support single-column indexes
|
||||||
let column = columns
|
let column = columns.first().ok_or_else(|| {
|
||||||
.first()
|
SqlError::InvalidOperation("Index requires at least one column".to_string())
|
||||||
.ok_or_else(|| SqlError::InvalidOperation("Index requires at least one column".to_string()))?;
|
})?;
|
||||||
|
|
||||||
table.create_index(name, column, unique)?;
|
table.create_index(name, column, unique)?;
|
||||||
Ok(QueryResult::empty())
|
Ok(QueryResult::empty())
|
||||||
|
|
@ -775,7 +775,9 @@ mod tests {
|
||||||
.execute("INSERT INTO users (id, name, age) VALUES (2, 'Bob', 25)")
|
.execute("INSERT INTO users (id, name, age) VALUES (2, 'Bob', 25)")
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let result = engine.execute("SELECT name FROM users WHERE age > 26").unwrap();
|
let result = engine
|
||||||
|
.execute("SELECT name FROM users WHERE age > 26")
|
||||||
|
.unwrap();
|
||||||
assert_eq!(result.rows.len(), 1);
|
assert_eq!(result.rows.len(), 1);
|
||||||
assert_eq!(result.rows[0][0], SqlValue::Text("Alice".to_string()));
|
assert_eq!(result.rows[0][0], SqlValue::Text("Alice".to_string()));
|
||||||
}
|
}
|
||||||
|
|
@ -806,7 +808,9 @@ mod tests {
|
||||||
engine
|
engine
|
||||||
.execute(&format!(
|
.execute(&format!(
|
||||||
"INSERT INTO users (id, name, age) VALUES ({}, 'User{}', {})",
|
"INSERT INTO users (id, name, age) VALUES ({}, 'User{}', {})",
|
||||||
i, i, 20 + i
|
i,
|
||||||
|
i,
|
||||||
|
20 + i
|
||||||
))
|
))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -18,10 +18,7 @@ pub enum ParsedStatement {
|
||||||
if_not_exists: bool,
|
if_not_exists: bool,
|
||||||
},
|
},
|
||||||
/// DROP TABLE statement.
|
/// DROP TABLE statement.
|
||||||
DropTable {
|
DropTable { name: String, if_exists: bool },
|
||||||
name: String,
|
|
||||||
if_exists: bool,
|
|
||||||
},
|
|
||||||
/// SELECT statement.
|
/// SELECT statement.
|
||||||
Select(ParsedSelect),
|
Select(ParsedSelect),
|
||||||
/// INSERT statement.
|
/// INSERT statement.
|
||||||
|
|
@ -179,15 +176,17 @@ impl SqlParser {
|
||||||
/// Parses a SQL statement.
|
/// Parses a SQL statement.
|
||||||
pub fn parse(sql: &str) -> Result<ParsedStatement, SqlError> {
|
pub fn parse(sql: &str) -> Result<ParsedStatement, SqlError> {
|
||||||
let dialect = SQLiteDialect {};
|
let dialect = SQLiteDialect {};
|
||||||
let statements = Parser::parse_sql(&dialect, sql)
|
let statements =
|
||||||
.map_err(|e| SqlError::Parse(e.to_string()))?;
|
Parser::parse_sql(&dialect, sql).map_err(|e| SqlError::Parse(e.to_string()))?;
|
||||||
|
|
||||||
if statements.is_empty() {
|
if statements.is_empty() {
|
||||||
return Err(SqlError::Parse("Empty SQL statement".to_string()));
|
return Err(SqlError::Parse("Empty SQL statement".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
if statements.len() > 1 {
|
if statements.len() > 1 {
|
||||||
return Err(SqlError::Parse("Multiple statements not supported".to_string()));
|
return Err(SqlError::Parse(
|
||||||
|
"Multiple statements not supported".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
Self::convert_statement(&statements[0])
|
Self::convert_statement(&statements[0])
|
||||||
|
|
@ -195,25 +194,42 @@ impl SqlParser {
|
||||||
|
|
||||||
fn convert_statement(stmt: &Statement) -> Result<ParsedStatement, SqlError> {
|
fn convert_statement(stmt: &Statement) -> Result<ParsedStatement, SqlError> {
|
||||||
match stmt {
|
match stmt {
|
||||||
Statement::CreateTable { name, columns, if_not_exists, constraints, .. } => {
|
Statement::CreateTable {
|
||||||
Self::convert_create_table(name, columns, constraints, *if_not_exists)
|
name,
|
||||||
}
|
columns,
|
||||||
Statement::Drop { object_type, names, if_exists, .. } => {
|
if_not_exists,
|
||||||
Self::convert_drop(object_type, names, *if_exists)
|
constraints,
|
||||||
}
|
..
|
||||||
|
} => Self::convert_create_table(name, columns, constraints, *if_not_exists),
|
||||||
|
Statement::Drop {
|
||||||
|
object_type,
|
||||||
|
names,
|
||||||
|
if_exists,
|
||||||
|
..
|
||||||
|
} => Self::convert_drop(object_type, names, *if_exists),
|
||||||
Statement::Query(query) => Self::convert_query(query),
|
Statement::Query(query) => Self::convert_query(query),
|
||||||
Statement::Insert { table_name, columns, source, .. } => {
|
Statement::Insert {
|
||||||
Self::convert_insert(table_name, columns, source)
|
table_name,
|
||||||
}
|
columns,
|
||||||
Statement::Update { table, assignments, selection, .. } => {
|
source,
|
||||||
Self::convert_update(table, assignments, selection)
|
..
|
||||||
}
|
} => Self::convert_insert(table_name, columns, source),
|
||||||
Statement::Delete { from, selection, .. } => {
|
Statement::Update {
|
||||||
Self::convert_delete(from, selection)
|
table,
|
||||||
}
|
assignments,
|
||||||
Statement::CreateIndex { name, table_name, columns, unique, .. } => {
|
selection,
|
||||||
Self::convert_create_index(name, table_name, columns, *unique)
|
..
|
||||||
}
|
} => Self::convert_update(table, assignments, selection),
|
||||||
|
Statement::Delete {
|
||||||
|
from, selection, ..
|
||||||
|
} => Self::convert_delete(from, selection),
|
||||||
|
Statement::CreateIndex {
|
||||||
|
name,
|
||||||
|
table_name,
|
||||||
|
columns,
|
||||||
|
unique,
|
||||||
|
..
|
||||||
|
} => Self::convert_create_index(name, table_name, columns, *unique),
|
||||||
_ => Err(SqlError::Unsupported(format!("Statement not supported"))),
|
_ => Err(SqlError::Unsupported(format!("Statement not supported"))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -230,7 +246,12 @@ impl SqlParser {
|
||||||
|
|
||||||
// Extract primary keys from table constraints
|
// Extract primary keys from table constraints
|
||||||
for constraint in constraints {
|
for constraint in constraints {
|
||||||
if let sqlparser::ast::TableConstraint::Unique { columns: pk_cols, is_primary: true, .. } = constraint {
|
if let sqlparser::ast::TableConstraint::Unique {
|
||||||
|
columns: pk_cols,
|
||||||
|
is_primary: true,
|
||||||
|
..
|
||||||
|
} = constraint
|
||||||
|
{
|
||||||
for col in pk_cols {
|
for col in pk_cols {
|
||||||
primary_keys.push(col.value.clone());
|
primary_keys.push(col.value.clone());
|
||||||
}
|
}
|
||||||
|
|
@ -296,10 +317,9 @@ impl SqlParser {
|
||||||
DataType::Real | DataType::Float(_) | DataType::Double | DataType::DoublePrecision => {
|
DataType::Real | DataType::Float(_) | DataType::Double | DataType::DoublePrecision => {
|
||||||
Ok(SqlType::Real)
|
Ok(SqlType::Real)
|
||||||
}
|
}
|
||||||
DataType::Varchar(_)
|
DataType::Varchar(_) | DataType::Char(_) | DataType::Text | DataType::String(_) => {
|
||||||
| DataType::Char(_)
|
Ok(SqlType::Text)
|
||||||
| DataType::Text
|
}
|
||||||
| DataType::String(_) => Ok(SqlType::Text),
|
|
||||||
DataType::Binary(_) | DataType::Varbinary(_) | DataType::Blob(_) => Ok(SqlType::Blob),
|
DataType::Binary(_) | DataType::Varbinary(_) | DataType::Blob(_) => Ok(SqlType::Blob),
|
||||||
DataType::Boolean => Ok(SqlType::Boolean),
|
DataType::Boolean => Ok(SqlType::Boolean),
|
||||||
DataType::Timestamp(_, _) | DataType::Date | DataType::Datetime(_) => {
|
DataType::Timestamp(_, _) | DataType::Date | DataType::Datetime(_) => {
|
||||||
|
|
@ -367,10 +387,7 @@ impl SqlParser {
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// Parse LIMIT/OFFSET
|
// Parse LIMIT/OFFSET
|
||||||
let limit = query
|
let limit = query.limit.as_ref().and_then(|l| Self::expr_to_usize(l));
|
||||||
.limit
|
|
||||||
.as_ref()
|
|
||||||
.and_then(|l| Self::expr_to_usize(l));
|
|
||||||
let offset = query
|
let offset = query
|
||||||
.offset
|
.offset
|
||||||
.as_ref()
|
.as_ref()
|
||||||
|
|
@ -403,16 +420,18 @@ impl SqlParser {
|
||||||
Self::convert_select_expr(expr)
|
Self::convert_select_expr(expr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => Err(SqlError::Unsupported("Select item not supported".to_string())),
|
_ => Err(SqlError::Unsupported(
|
||||||
|
"Select item not supported".to_string(),
|
||||||
|
)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn convert_select_expr(expr: &Expr) -> Result<ParsedSelectItem, SqlError> {
|
fn convert_select_expr(expr: &Expr) -> Result<ParsedSelectItem, SqlError> {
|
||||||
match expr {
|
match expr {
|
||||||
Expr::Identifier(id) => Ok(ParsedSelectItem::Column(id.value.clone())),
|
Expr::Identifier(id) => Ok(ParsedSelectItem::Column(id.value.clone())),
|
||||||
Expr::CompoundIdentifier(ids) => {
|
Expr::CompoundIdentifier(ids) => Ok(ParsedSelectItem::Column(
|
||||||
Ok(ParsedSelectItem::Column(ids.last().map(|i| i.value.clone()).unwrap_or_default()))
|
ids.last().map(|i| i.value.clone()).unwrap_or_default(),
|
||||||
}
|
)),
|
||||||
Expr::Function(func) => {
|
Expr::Function(func) => {
|
||||||
let name = func.name.to_string().to_uppercase();
|
let name = func.name.to_string().to_uppercase();
|
||||||
// Try to extract column from first arg - simplified for compatibility
|
// Try to extract column from first arg - simplified for compatibility
|
||||||
|
|
@ -423,14 +442,18 @@ impl SqlParser {
|
||||||
alias: None,
|
alias: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
_ => Err(SqlError::Unsupported("Select expression not supported".to_string())),
|
_ => Err(SqlError::Unsupported(
|
||||||
|
"Select expression not supported".to_string(),
|
||||||
|
)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn convert_table_factor(factor: &TableFactor) -> Result<String, SqlError> {
|
fn convert_table_factor(factor: &TableFactor) -> Result<String, SqlError> {
|
||||||
match factor {
|
match factor {
|
||||||
TableFactor::Table { name, .. } => Ok(name.to_string()),
|
TableFactor::Table { name, .. } => Ok(name.to_string()),
|
||||||
_ => Err(SqlError::Unsupported("Table factor not supported".to_string())),
|
_ => Err(SqlError::Unsupported(
|
||||||
|
"Table factor not supported".to_string(),
|
||||||
|
)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -461,9 +484,9 @@ impl SqlParser {
|
||||||
fn convert_expr(expr: &Expr) -> Result<ParsedExpr, SqlError> {
|
fn convert_expr(expr: &Expr) -> Result<ParsedExpr, SqlError> {
|
||||||
match expr {
|
match expr {
|
||||||
Expr::Identifier(id) => Ok(ParsedExpr::Column(id.value.clone())),
|
Expr::Identifier(id) => Ok(ParsedExpr::Column(id.value.clone())),
|
||||||
Expr::CompoundIdentifier(ids) => {
|
Expr::CompoundIdentifier(ids) => Ok(ParsedExpr::Column(
|
||||||
Ok(ParsedExpr::Column(ids.last().map(|i| i.value.clone()).unwrap_or_default()))
|
ids.last().map(|i| i.value.clone()).unwrap_or_default(),
|
||||||
}
|
)),
|
||||||
Expr::Value(v) => Ok(ParsedExpr::Literal(Self::convert_value(v)?)),
|
Expr::Value(v) => Ok(ParsedExpr::Literal(Self::convert_value(v)?)),
|
||||||
Expr::BinaryOp { left, op, right } => {
|
Expr::BinaryOp { left, op, right } => {
|
||||||
let left = Box::new(Self::convert_expr(left)?);
|
let left = Box::new(Self::convert_expr(left)?);
|
||||||
|
|
@ -471,17 +494,30 @@ impl SqlParser {
|
||||||
let op = Self::convert_binary_op(op)?;
|
let op = Self::convert_binary_op(op)?;
|
||||||
Ok(ParsedExpr::BinaryOp { left, op, right })
|
Ok(ParsedExpr::BinaryOp { left, op, right })
|
||||||
}
|
}
|
||||||
Expr::UnaryOp { op: sqlparser::ast::UnaryOperator::Not, expr } => {
|
Expr::UnaryOp {
|
||||||
Ok(ParsedExpr::Not(Box::new(Self::convert_expr(expr)?)))
|
op: sqlparser::ast::UnaryOperator::Not,
|
||||||
}
|
expr,
|
||||||
|
} => Ok(ParsedExpr::Not(Box::new(Self::convert_expr(expr)?))),
|
||||||
Expr::IsNull(expr) => Ok(ParsedExpr::IsNull(Box::new(Self::convert_expr(expr)?))),
|
Expr::IsNull(expr) => Ok(ParsedExpr::IsNull(Box::new(Self::convert_expr(expr)?))),
|
||||||
Expr::IsNotNull(expr) => Ok(ParsedExpr::IsNotNull(Box::new(Self::convert_expr(expr)?))),
|
Expr::IsNotNull(expr) => Ok(ParsedExpr::IsNotNull(Box::new(Self::convert_expr(expr)?))),
|
||||||
Expr::InList { expr, list, negated } => Ok(ParsedExpr::InList {
|
Expr::InList {
|
||||||
|
expr,
|
||||||
|
list,
|
||||||
|
negated,
|
||||||
|
} => Ok(ParsedExpr::InList {
|
||||||
expr: Box::new(Self::convert_expr(expr)?),
|
expr: Box::new(Self::convert_expr(expr)?),
|
||||||
list: list.iter().map(Self::convert_expr).collect::<Result<_, _>>()?,
|
list: list
|
||||||
|
.iter()
|
||||||
|
.map(Self::convert_expr)
|
||||||
|
.collect::<Result<_, _>>()?,
|
||||||
negated: *negated,
|
negated: *negated,
|
||||||
}),
|
}),
|
||||||
Expr::Between { expr, low, high, negated } => Ok(ParsedExpr::Between {
|
Expr::Between {
|
||||||
|
expr,
|
||||||
|
low,
|
||||||
|
high,
|
||||||
|
negated,
|
||||||
|
} => Ok(ParsedExpr::Between {
|
||||||
expr: Box::new(Self::convert_expr(expr)?),
|
expr: Box::new(Self::convert_expr(expr)?),
|
||||||
low: Box::new(Self::convert_expr(low)?),
|
low: Box::new(Self::convert_expr(low)?),
|
||||||
high: Box::new(Self::convert_expr(high)?),
|
high: Box::new(Self::convert_expr(high)?),
|
||||||
|
|
@ -490,10 +526,16 @@ impl SqlParser {
|
||||||
Expr::Like { expr, pattern, .. } => {
|
Expr::Like { expr, pattern, .. } => {
|
||||||
let left = Box::new(Self::convert_expr(expr)?);
|
let left = Box::new(Self::convert_expr(expr)?);
|
||||||
let right = Box::new(Self::convert_expr(pattern)?);
|
let right = Box::new(Self::convert_expr(pattern)?);
|
||||||
Ok(ParsedExpr::BinaryOp { left, op: BinaryOp::Like, right })
|
Ok(ParsedExpr::BinaryOp {
|
||||||
|
left,
|
||||||
|
op: BinaryOp::Like,
|
||||||
|
right,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
Expr::Nested(inner) => Self::convert_expr(inner),
|
Expr::Nested(inner) => Self::convert_expr(inner),
|
||||||
_ => Err(SqlError::Unsupported("Expression not supported".to_string())),
|
_ => Err(SqlError::Unsupported(
|
||||||
|
"Expression not supported".to_string(),
|
||||||
|
)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -587,7 +629,11 @@ impl SqlParser {
|
||||||
let parsed_assignments: Vec<(String, SqlValue)> = assignments
|
let parsed_assignments: Vec<(String, SqlValue)> = assignments
|
||||||
.iter()
|
.iter()
|
||||||
.map(|a| {
|
.map(|a| {
|
||||||
let col = a.id.iter().map(|i| i.value.clone()).collect::<Vec<_>>().join(".");
|
let col =
|
||||||
|
a.id.iter()
|
||||||
|
.map(|i| i.value.clone())
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(".");
|
||||||
let val = Self::convert_value_expr(&a.value)?;
|
let val = Self::convert_value_expr(&a.value)?;
|
||||||
Ok((col, val))
|
Ok((col, val))
|
||||||
})
|
})
|
||||||
|
|
@ -633,10 +679,7 @@ impl SqlParser {
|
||||||
|
|
||||||
let table = table_name.to_string();
|
let table = table_name.to_string();
|
||||||
|
|
||||||
let cols: Vec<String> = columns
|
let cols: Vec<String> = columns.iter().map(|c| c.expr.to_string()).collect();
|
||||||
.iter()
|
|
||||||
.map(|c| c.expr.to_string())
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
Ok(ParsedStatement::CreateIndex {
|
Ok(ParsedStatement::CreateIndex {
|
||||||
name: index_name,
|
name: index_name,
|
||||||
|
|
@ -694,7 +737,12 @@ mod tests {
|
||||||
let sql = "INSERT INTO users (name, age) VALUES ('Alice', 30), ('Bob', 25)";
|
let sql = "INSERT INTO users (name, age) VALUES ('Alice', 30), ('Bob', 25)";
|
||||||
let stmt = SqlParser::parse(sql).unwrap();
|
let stmt = SqlParser::parse(sql).unwrap();
|
||||||
|
|
||||||
if let ParsedStatement::Insert { table, columns, values } = stmt {
|
if let ParsedStatement::Insert {
|
||||||
|
table,
|
||||||
|
columns,
|
||||||
|
values,
|
||||||
|
} = stmt
|
||||||
|
{
|
||||||
assert_eq!(table, "users");
|
assert_eq!(table, "users");
|
||||||
assert_eq!(columns, vec!["name", "age"]);
|
assert_eq!(columns, vec!["name", "age"]);
|
||||||
assert_eq!(values.len(), 2);
|
assert_eq!(values.len(), 2);
|
||||||
|
|
@ -708,7 +756,12 @@ mod tests {
|
||||||
let sql = "UPDATE users SET age = 31 WHERE name = 'Alice'";
|
let sql = "UPDATE users SET age = 31 WHERE name = 'Alice'";
|
||||||
let stmt = SqlParser::parse(sql).unwrap();
|
let stmt = SqlParser::parse(sql).unwrap();
|
||||||
|
|
||||||
if let ParsedStatement::Update { table, assignments, where_clause } = stmt {
|
if let ParsedStatement::Update {
|
||||||
|
table,
|
||||||
|
assignments,
|
||||||
|
where_clause,
|
||||||
|
} = stmt
|
||||||
|
{
|
||||||
assert_eq!(table, "users");
|
assert_eq!(table, "users");
|
||||||
assert_eq!(assignments.len(), 1);
|
assert_eq!(assignments.len(), 1);
|
||||||
assert!(where_clause.is_some());
|
assert!(where_clause.is_some());
|
||||||
|
|
@ -722,7 +775,11 @@ mod tests {
|
||||||
let sql = "DELETE FROM users WHERE age < 18";
|
let sql = "DELETE FROM users WHERE age < 18";
|
||||||
let stmt = SqlParser::parse(sql).unwrap();
|
let stmt = SqlParser::parse(sql).unwrap();
|
||||||
|
|
||||||
if let ParsedStatement::Delete { table, where_clause } = stmt {
|
if let ParsedStatement::Delete {
|
||||||
|
table,
|
||||||
|
where_clause,
|
||||||
|
} = stmt
|
||||||
|
{
|
||||||
assert_eq!(table, "users");
|
assert_eq!(table, "users");
|
||||||
assert!(where_clause.is_some());
|
assert!(where_clause.is_some());
|
||||||
} else {
|
} else {
|
||||||
|
|
|
||||||
|
|
@ -85,7 +85,10 @@ impl Row {
|
||||||
|
|
||||||
/// Returns all values in column order.
|
/// Returns all values in column order.
|
||||||
pub fn values(&self) -> Vec<&SqlValue> {
|
pub fn values(&self) -> Vec<&SqlValue> {
|
||||||
self.columns.iter().map(|c| self.values.get(c).unwrap()).collect()
|
self.columns
|
||||||
|
.iter()
|
||||||
|
.map(|c| self.values.get(c).unwrap())
|
||||||
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the number of columns.
|
/// Returns the number of columns.
|
||||||
|
|
|
||||||
|
|
@ -299,7 +299,10 @@ impl Table {
|
||||||
|
|
||||||
let mut indexes = self.indexes.write();
|
let mut indexes = self.indexes.write();
|
||||||
if indexes.contains_key(&name) {
|
if indexes.contains_key(&name) {
|
||||||
return Err(SqlError::InvalidOperation(format!("Index '{}' already exists", name)));
|
return Err(SqlError::InvalidOperation(format!(
|
||||||
|
"Index '{}' already exists",
|
||||||
|
name
|
||||||
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut index = TableIndex::new(&name, &column, unique);
|
let mut index = TableIndex::new(&name, &column, unique);
|
||||||
|
|
@ -319,7 +322,10 @@ impl Table {
|
||||||
pub fn drop_index(&self, name: &str) -> Result<(), SqlError> {
|
pub fn drop_index(&self, name: &str) -> Result<(), SqlError> {
|
||||||
let mut indexes = self.indexes.write();
|
let mut indexes = self.indexes.write();
|
||||||
if indexes.remove(name).is_none() {
|
if indexes.remove(name).is_none() {
|
||||||
return Err(SqlError::InvalidOperation(format!("Index '{}' not found", name)));
|
return Err(SqlError::InvalidOperation(format!(
|
||||||
|
"Index '{}' not found",
|
||||||
|
name
|
||||||
|
)));
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
@ -371,9 +377,9 @@ impl Table {
|
||||||
/// Updates a row.
|
/// Updates a row.
|
||||||
pub fn update(&self, id: RowId, updates: HashMap<String, SqlValue>) -> Result<(), SqlError> {
|
pub fn update(&self, id: RowId, updates: HashMap<String, SqlValue>) -> Result<(), SqlError> {
|
||||||
let mut rows = self.rows.write();
|
let mut rows = self.rows.write();
|
||||||
let row = rows.get_mut(&id).ok_or_else(|| {
|
let row = rows
|
||||||
SqlError::InvalidOperation(format!("Row {} not found", id))
|
.get_mut(&id)
|
||||||
})?;
|
.ok_or_else(|| SqlError::InvalidOperation(format!("Row {} not found", id)))?;
|
||||||
|
|
||||||
let old_values: HashMap<String, SqlValue> = updates
|
let old_values: HashMap<String, SqlValue> = updates
|
||||||
.keys()
|
.keys()
|
||||||
|
|
@ -392,7 +398,10 @@ impl Table {
|
||||||
let mut indexes = self.indexes.write();
|
let mut indexes = self.indexes.write();
|
||||||
for (_, index) in indexes.iter_mut() {
|
for (_, index) in indexes.iter_mut() {
|
||||||
if let Some(new_value) = updates.get(&index.column) {
|
if let Some(new_value) = updates.get(&index.column) {
|
||||||
let old_value = old_values.get(&index.column).cloned().unwrap_or(SqlValue::Null);
|
let old_value = old_values
|
||||||
|
.get(&index.column)
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or(SqlValue::Null);
|
||||||
index.remove(&old_value, &id);
|
index.remove(&old_value, &id);
|
||||||
index.insert(new_value.clone(), id)?;
|
index.insert(new_value.clone(), id)?;
|
||||||
}
|
}
|
||||||
|
|
@ -480,7 +489,10 @@ mod tests {
|
||||||
values.insert("id".to_string(), SqlValue::Integer(1));
|
values.insert("id".to_string(), SqlValue::Integer(1));
|
||||||
values.insert("name".to_string(), SqlValue::Text("Alice".to_string()));
|
values.insert("name".to_string(), SqlValue::Text("Alice".to_string()));
|
||||||
values.insert("age".to_string(), SqlValue::Integer(30));
|
values.insert("age".to_string(), SqlValue::Integer(30));
|
||||||
values.insert("email".to_string(), SqlValue::Text("alice@example.com".to_string()));
|
values.insert(
|
||||||
|
"email".to_string(),
|
||||||
|
SqlValue::Text("alice@example.com".to_string()),
|
||||||
|
);
|
||||||
|
|
||||||
let row_id = table.insert(values).unwrap();
|
let row_id = table.insert(values).unwrap();
|
||||||
assert_eq!(table.count(), 1);
|
assert_eq!(table.count(), 1);
|
||||||
|
|
@ -508,13 +520,19 @@ mod tests {
|
||||||
let mut values1 = HashMap::new();
|
let mut values1 = HashMap::new();
|
||||||
values1.insert("id".to_string(), SqlValue::Integer(1));
|
values1.insert("id".to_string(), SqlValue::Integer(1));
|
||||||
values1.insert("name".to_string(), SqlValue::Text("Alice".to_string()));
|
values1.insert("name".to_string(), SqlValue::Text("Alice".to_string()));
|
||||||
values1.insert("email".to_string(), SqlValue::Text("test@example.com".to_string()));
|
values1.insert(
|
||||||
|
"email".to_string(),
|
||||||
|
SqlValue::Text("test@example.com".to_string()),
|
||||||
|
);
|
||||||
table.insert(values1).unwrap();
|
table.insert(values1).unwrap();
|
||||||
|
|
||||||
let mut values2 = HashMap::new();
|
let mut values2 = HashMap::new();
|
||||||
values2.insert("id".to_string(), SqlValue::Integer(2));
|
values2.insert("id".to_string(), SqlValue::Integer(2));
|
||||||
values2.insert("name".to_string(), SqlValue::Text("Bob".to_string()));
|
values2.insert("name".to_string(), SqlValue::Text("Bob".to_string()));
|
||||||
values2.insert("email".to_string(), SqlValue::Text("test@example.com".to_string()));
|
values2.insert(
|
||||||
|
"email".to_string(),
|
||||||
|
SqlValue::Text("test@example.com".to_string()),
|
||||||
|
);
|
||||||
|
|
||||||
let result = table.insert(values2);
|
let result = table.insert(values2);
|
||||||
assert!(result.is_err()); // Duplicate email
|
assert!(result.is_err()); // Duplicate email
|
||||||
|
|
|
||||||
|
|
@ -124,7 +124,12 @@ impl Transaction {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Records an insert operation.
|
/// Records an insert operation.
|
||||||
pub fn record_insert(&mut self, table: String, row_id: RowId, values: HashMap<String, SqlValue>) {
|
pub fn record_insert(
|
||||||
|
&mut self,
|
||||||
|
table: String,
|
||||||
|
row_id: RowId,
|
||||||
|
values: HashMap<String, SqlValue>,
|
||||||
|
) {
|
||||||
self.operations.push(TransactionOp::Insert {
|
self.operations.push(TransactionOp::Insert {
|
||||||
table,
|
table,
|
||||||
row_id,
|
row_id,
|
||||||
|
|
@ -149,7 +154,12 @@ impl Transaction {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Records a delete operation.
|
/// Records a delete operation.
|
||||||
pub fn record_delete(&mut self, table: String, row_id: RowId, old_values: HashMap<String, SqlValue>) {
|
pub fn record_delete(
|
||||||
|
&mut self,
|
||||||
|
table: String,
|
||||||
|
row_id: RowId,
|
||||||
|
old_values: HashMap<String, SqlValue>,
|
||||||
|
) {
|
||||||
self.operations.push(TransactionOp::Delete {
|
self.operations.push(TransactionOp::Delete {
|
||||||
table,
|
table,
|
||||||
row_id,
|
row_id,
|
||||||
|
|
@ -213,7 +223,10 @@ impl TransactionManager {
|
||||||
.ok_or_else(|| SqlError::Transaction(format!("Transaction {} not found", id)))?;
|
.ok_or_else(|| SqlError::Transaction(format!("Transaction {} not found", id)))?;
|
||||||
|
|
||||||
if !txn.is_active() {
|
if !txn.is_active() {
|
||||||
return Err(SqlError::Transaction(format!("Transaction {} is not active", id)));
|
return Err(SqlError::Transaction(format!(
|
||||||
|
"Transaction {} is not active",
|
||||||
|
id
|
||||||
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
txn.operations.push(op);
|
txn.operations.push(op);
|
||||||
|
|
@ -228,7 +241,10 @@ impl TransactionManager {
|
||||||
.ok_or_else(|| SqlError::Transaction(format!("Transaction {} not found", id)))?;
|
.ok_or_else(|| SqlError::Transaction(format!("Transaction {} not found", id)))?;
|
||||||
|
|
||||||
if !txn.is_active() {
|
if !txn.is_active() {
|
||||||
return Err(SqlError::Transaction(format!("Transaction {} is not active", id)));
|
return Err(SqlError::Transaction(format!(
|
||||||
|
"Transaction {} is not active",
|
||||||
|
id
|
||||||
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
txn.mark_committed();
|
txn.mark_committed();
|
||||||
|
|
@ -245,7 +261,10 @@ impl TransactionManager {
|
||||||
.ok_or_else(|| SqlError::Transaction(format!("Transaction {} not found", id)))?;
|
.ok_or_else(|| SqlError::Transaction(format!("Transaction {} not found", id)))?;
|
||||||
|
|
||||||
if !txn.is_active() {
|
if !txn.is_active() {
|
||||||
return Err(SqlError::Transaction(format!("Transaction {} is not active", id)));
|
return Err(SqlError::Transaction(format!(
|
||||||
|
"Transaction {} is not active",
|
||||||
|
id
|
||||||
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
txn.mark_rolled_back();
|
txn.mark_rolled_back();
|
||||||
|
|
|
||||||
|
|
@ -233,12 +233,8 @@ impl Ord for SqlValue {
|
||||||
(SqlValue::Blob(a), SqlValue::Blob(b)) => a.cmp(b),
|
(SqlValue::Blob(a), SqlValue::Blob(b)) => a.cmp(b),
|
||||||
(SqlValue::Boolean(a), SqlValue::Boolean(b)) => a.cmp(b),
|
(SqlValue::Boolean(a), SqlValue::Boolean(b)) => a.cmp(b),
|
||||||
(SqlValue::Timestamp(a), SqlValue::Timestamp(b)) => a.cmp(b),
|
(SqlValue::Timestamp(a), SqlValue::Timestamp(b)) => a.cmp(b),
|
||||||
(SqlValue::Integer(a), SqlValue::Real(b)) => {
|
(SqlValue::Integer(a), SqlValue::Real(b)) => (*a as f64).to_bits().cmp(&b.to_bits()),
|
||||||
(*a as f64).to_bits().cmp(&b.to_bits())
|
(SqlValue::Real(a), SqlValue::Integer(b)) => a.to_bits().cmp(&(*b as f64).to_bits()),
|
||||||
}
|
|
||||||
(SqlValue::Real(a), SqlValue::Integer(b)) => {
|
|
||||||
a.to_bits().cmp(&(*b as f64).to_bits())
|
|
||||||
}
|
|
||||||
// Different types: order by type discriminant
|
// Different types: order by type discriminant
|
||||||
_ => self.type_order().cmp(&other.type_order()),
|
_ => self.type_order().cmp(&other.type_order()),
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -158,11 +158,7 @@ impl Metric {
|
||||||
|
|
||||||
/// Calculates sum in a time range.
|
/// Calculates sum in a time range.
|
||||||
pub fn sum(&self, start: u64, end: u64) -> f64 {
|
pub fn sum(&self, start: u64, end: u64) -> f64 {
|
||||||
self.data
|
self.data.read().range(start..=end).map(|(_, &v)| v).sum()
|
||||||
.read()
|
|
||||||
.range(start..=end)
|
|
||||||
.map(|(_, &v)| v)
|
|
||||||
.sum()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Counts data points in a time range.
|
/// Counts data points in a time range.
|
||||||
|
|
|
||||||
|
|
@ -207,9 +207,7 @@ impl VectorIndex {
|
||||||
let embeddings = self.embeddings.read();
|
let embeddings = self.embeddings.read();
|
||||||
let mut results: Vec<VectorSearchResult> = embeddings
|
let mut results: Vec<VectorSearchResult> = embeddings
|
||||||
.values()
|
.values()
|
||||||
.filter(|e| {
|
.filter(|e| namespace.map(|ns| e.namespace == ns).unwrap_or(true))
|
||||||
namespace.map(|ns| e.namespace == ns).unwrap_or(true)
|
|
||||||
})
|
|
||||||
.map(|e| {
|
.map(|e| {
|
||||||
let score = self.calculate_similarity(&e.vector, query);
|
let score = self.calculate_similarity(&e.vector, query);
|
||||||
VectorSearchResult {
|
VectorSearchResult {
|
||||||
|
|
@ -217,14 +215,14 @@ impl VectorIndex {
|
||||||
score,
|
score,
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.filter(|r| {
|
.filter(|r| threshold.map(|t| r.score >= t).unwrap_or(true))
|
||||||
threshold.map(|t| r.score >= t).unwrap_or(true)
|
|
||||||
})
|
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// Sort by score descending
|
// Sort by score descending
|
||||||
results.sort_by(|a, b| {
|
results.sort_by(|a, b| {
|
||||||
b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal)
|
b.score
|
||||||
|
.partial_cmp(&a.score)
|
||||||
|
.unwrap_or(std::cmp::Ordering::Equal)
|
||||||
});
|
});
|
||||||
|
|
||||||
// Apply limit
|
// Apply limit
|
||||||
|
|
@ -234,8 +232,9 @@ impl VectorIndex {
|
||||||
let elapsed = start.elapsed().as_millis() as f64;
|
let elapsed = start.elapsed().as_millis() as f64;
|
||||||
let mut stats = self.stats.write();
|
let mut stats = self.stats.write();
|
||||||
stats.searches += 1;
|
stats.searches += 1;
|
||||||
stats.avg_search_time_ms =
|
stats.avg_search_time_ms = (stats.avg_search_time_ms * (stats.searches - 1) as f64
|
||||||
(stats.avg_search_time_ms * (stats.searches - 1) as f64 + elapsed) / stats.searches as f64;
|
+ elapsed)
|
||||||
|
/ stats.searches as f64;
|
||||||
|
|
||||||
Ok(results)
|
Ok(results)
|
||||||
}
|
}
|
||||||
|
|
@ -329,7 +328,8 @@ impl VectorStore {
|
||||||
namespace: Option<&str>,
|
namespace: Option<&str>,
|
||||||
threshold: Option<f32>,
|
threshold: Option<f32>,
|
||||||
) -> Result<Vec<VectorSearchResult>, DatabaseError> {
|
) -> Result<Vec<VectorSearchResult>, DatabaseError> {
|
||||||
self.default_index.search(query, limit, namespace, threshold)
|
self.default_index
|
||||||
|
.search(query, limit, namespace, threshold)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Gets an embedding by ID.
|
/// Gets an embedding by ID.
|
||||||
|
|
@ -388,10 +388,7 @@ pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
|
||||||
|
|
||||||
/// Manhattan distance (L1) between two vectors.
|
/// Manhattan distance (L1) between two vectors.
|
||||||
pub fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
|
pub fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
|
||||||
a.iter()
|
a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum()
|
||||||
.zip(b.iter())
|
|
||||||
.map(|(x, y)| (x - y).abs())
|
|
||||||
.sum()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
@ -414,9 +411,15 @@ mod tests {
|
||||||
fn test_vector_insert_search() {
|
fn test_vector_insert_search() {
|
||||||
let store = VectorStore::new(3);
|
let store = VectorStore::new(3);
|
||||||
|
|
||||||
store.insert(Embedding::new("a", vec![1.0, 0.0, 0.0])).unwrap();
|
store
|
||||||
store.insert(Embedding::new("b", vec![0.9, 0.1, 0.0])).unwrap();
|
.insert(Embedding::new("a", vec![1.0, 0.0, 0.0]))
|
||||||
store.insert(Embedding::new("c", vec![0.0, 1.0, 0.0])).unwrap();
|
.unwrap();
|
||||||
|
store
|
||||||
|
.insert(Embedding::new("b", vec![0.9, 0.1, 0.0]))
|
||||||
|
.unwrap();
|
||||||
|
store
|
||||||
|
.insert(Embedding::new("c", vec![0.0, 1.0, 0.0]))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let results = store.search(&[1.0, 0.0, 0.0], 2, None, None).unwrap();
|
let results = store.search(&[1.0, 0.0, 0.0], 2, None, None).unwrap();
|
||||||
|
|
||||||
|
|
@ -429,8 +432,12 @@ mod tests {
|
||||||
fn test_similarity_threshold() {
|
fn test_similarity_threshold() {
|
||||||
let store = VectorStore::new(3);
|
let store = VectorStore::new(3);
|
||||||
|
|
||||||
store.insert(Embedding::new("a", vec![1.0, 0.0, 0.0])).unwrap();
|
store
|
||||||
store.insert(Embedding::new("b", vec![0.0, 1.0, 0.0])).unwrap();
|
.insert(Embedding::new("a", vec![1.0, 0.0, 0.0]))
|
||||||
|
.unwrap();
|
||||||
|
store
|
||||||
|
.insert(Embedding::new("b", vec![0.0, 1.0, 0.0]))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let results = store.search(&[1.0, 0.0, 0.0], 10, None, Some(0.5)).unwrap();
|
let results = store.search(&[1.0, 0.0, 0.0], 10, None, Some(0.5)).unwrap();
|
||||||
|
|
||||||
|
|
@ -443,14 +450,16 @@ mod tests {
|
||||||
fn test_namespace_filter() {
|
fn test_namespace_filter() {
|
||||||
let store = VectorStore::new(3);
|
let store = VectorStore::new(3);
|
||||||
|
|
||||||
store.insert(
|
store
|
||||||
Embedding::new("a", vec![1.0, 0.0, 0.0]).with_namespace("ns1")
|
.insert(Embedding::new("a", vec![1.0, 0.0, 0.0]).with_namespace("ns1"))
|
||||||
).unwrap();
|
.unwrap();
|
||||||
store.insert(
|
store
|
||||||
Embedding::new("b", vec![1.0, 0.0, 0.0]).with_namespace("ns2")
|
.insert(Embedding::new("b", vec![1.0, 0.0, 0.0]).with_namespace("ns2"))
|
||||||
).unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let results = store.search(&[1.0, 0.0, 0.0], 10, Some("ns1"), None).unwrap();
|
let results = store
|
||||||
|
.search(&[1.0, 0.0, 0.0], 10, Some("ns1"), None)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(results.len(), 1);
|
assert_eq!(results.len(), 1);
|
||||||
assert_eq!(results[0].embedding.id, "a");
|
assert_eq!(results[0].embedding.id, "a");
|
||||||
|
|
|
||||||
|
|
@ -184,9 +184,7 @@ impl Credit {
|
||||||
|
|
||||||
/// Check if credit is expired
|
/// Check if credit is expired
|
||||||
pub fn is_expired(&self) -> bool {
|
pub fn is_expired(&self) -> bool {
|
||||||
self.expires_at
|
self.expires_at.map(|exp| Utc::now() > exp).unwrap_or(false)
|
||||||
.map(|exp| Utc::now() > exp)
|
|
||||||
.unwrap_or(false)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get remaining amount
|
/// Get remaining amount
|
||||||
|
|
@ -241,14 +239,21 @@ impl std::fmt::Display for CreditError {
|
||||||
match self {
|
match self {
|
||||||
CreditError::CreditInactive => write!(f, "Credit is no longer active"),
|
CreditError::CreditInactive => write!(f, "Credit is no longer active"),
|
||||||
CreditError::CreditExpired => write!(f, "Credit has expired"),
|
CreditError::CreditExpired => write!(f, "Credit has expired"),
|
||||||
CreditError::InsufficientCredit { requested, available } => {
|
CreditError::InsufficientCredit {
|
||||||
|
requested,
|
||||||
|
available,
|
||||||
|
} => {
|
||||||
write!(
|
write!(
|
||||||
f,
|
f,
|
||||||
"Insufficient credit: requested {}, available {}",
|
"Insufficient credit: requested {}, available {}",
|
||||||
requested, available
|
requested, available
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
CreditError::ExceedsMaxCredit { current, requested, maximum } => {
|
CreditError::ExceedsMaxCredit {
|
||||||
|
current,
|
||||||
|
requested,
|
||||||
|
maximum,
|
||||||
|
} => {
|
||||||
write!(
|
write!(
|
||||||
f,
|
f,
|
||||||
"Credit exceeds maximum: current {}, requested {}, maximum {}",
|
"Credit exceeds maximum: current {}, requested {}, maximum {}",
|
||||||
|
|
@ -279,9 +284,9 @@ pub struct CreditPolicy {
|
||||||
impl Default for CreditPolicy {
|
impl Default for CreditPolicy {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
welcome_amount: Decimal::new(10, 0), // 10 SYNOR
|
welcome_amount: Decimal::new(10, 0), // 10 SYNOR
|
||||||
referral_referrer_amount: Decimal::new(25, 0), // 25 SYNOR
|
referral_referrer_amount: Decimal::new(25, 0), // 25 SYNOR
|
||||||
referral_referee_amount: Decimal::new(10, 0), // 10 SYNOR
|
referral_referee_amount: Decimal::new(10, 0), // 10 SYNOR
|
||||||
max_credit_per_account: Decimal::new(1000, 0), // 1000 SYNOR
|
max_credit_per_account: Decimal::new(1000, 0), // 1000 SYNOR
|
||||||
default_expiry_days: 365,
|
default_expiry_days: 365,
|
||||||
}
|
}
|
||||||
|
|
@ -334,12 +339,20 @@ impl CreditManager {
|
||||||
let referee_id = referee_id.into();
|
let referee_id = referee_id.into();
|
||||||
|
|
||||||
// Credit for the referrer
|
// Credit for the referrer
|
||||||
let referrer_credit = Credit::referral(&referrer_id, self.policy.referral_referrer_amount, &referee_id)
|
let referrer_credit = Credit::referral(
|
||||||
.with_expiry_days(self.policy.default_expiry_days);
|
&referrer_id,
|
||||||
|
self.policy.referral_referrer_amount,
|
||||||
|
&referee_id,
|
||||||
|
)
|
||||||
|
.with_expiry_days(self.policy.default_expiry_days);
|
||||||
|
|
||||||
// Credit for the referee
|
// Credit for the referee
|
||||||
let referee_credit = Credit::referral(&referee_id, self.policy.referral_referee_amount, &referrer_id)
|
let referee_credit = Credit::referral(
|
||||||
.with_expiry_days(self.policy.default_expiry_days);
|
&referee_id,
|
||||||
|
self.policy.referral_referee_amount,
|
||||||
|
&referrer_id,
|
||||||
|
)
|
||||||
|
.with_expiry_days(self.policy.default_expiry_days);
|
||||||
|
|
||||||
self.credits
|
self.credits
|
||||||
.entry(referrer_id)
|
.entry(referrer_id)
|
||||||
|
|
@ -448,13 +461,11 @@ impl CreditManager {
|
||||||
let mut remaining = amount;
|
let mut remaining = amount;
|
||||||
|
|
||||||
// Sort by expiry date (soonest first) for FIFO
|
// Sort by expiry date (soonest first) for FIFO
|
||||||
credits.sort_by(|a, b| {
|
credits.sort_by(|a, b| match (&a.expires_at, &b.expires_at) {
|
||||||
match (&a.expires_at, &b.expires_at) {
|
(Some(a_exp), Some(b_exp)) => a_exp.cmp(b_exp),
|
||||||
(Some(a_exp), Some(b_exp)) => a_exp.cmp(b_exp),
|
(Some(_), None) => std::cmp::Ordering::Less,
|
||||||
(Some(_), None) => std::cmp::Ordering::Less,
|
(None, Some(_)) => std::cmp::Ordering::Greater,
|
||||||
(None, Some(_)) => std::cmp::Ordering::Greater,
|
(None, None) => a.created_at.cmp(&b.created_at),
|
||||||
(None, None) => a.created_at.cmp(&b.created_at),
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
for credit in credits.iter_mut() {
|
for credit in credits.iter_mut() {
|
||||||
|
|
|
||||||
|
|
@ -319,12 +319,7 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_line_item() {
|
fn test_line_item() {
|
||||||
let item = InvoiceLineItem::new(
|
let item = InvoiceLineItem::new("Storage L2", ServiceType::Storage, dec!(10), dec!(0.02));
|
||||||
"Storage L2",
|
|
||||||
ServiceType::Storage,
|
|
||||||
dec!(10),
|
|
||||||
dec!(0.02),
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(item.amount, dec!(0.20));
|
assert_eq!(item.amount, dec!(0.20));
|
||||||
}
|
}
|
||||||
|
|
@ -332,8 +327,18 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn test_invoice_calculate() {
|
fn test_invoice_calculate() {
|
||||||
let mut invoice = Invoice::new("test")
|
let mut invoice = Invoice::new("test")
|
||||||
.add_line_item(InvoiceLineItem::new("Storage", ServiceType::Storage, dec!(100), dec!(0.02)))
|
.add_line_item(InvoiceLineItem::new(
|
||||||
.add_line_item(InvoiceLineItem::new("Compute", ServiceType::Compute, dec!(10), dec!(0.50)));
|
"Storage",
|
||||||
|
ServiceType::Storage,
|
||||||
|
dec!(100),
|
||||||
|
dec!(0.02),
|
||||||
|
))
|
||||||
|
.add_line_item(InvoiceLineItem::new(
|
||||||
|
"Compute",
|
||||||
|
ServiceType::Compute,
|
||||||
|
dec!(10),
|
||||||
|
dec!(0.50),
|
||||||
|
));
|
||||||
|
|
||||||
invoice.discount = dec!(1);
|
invoice.discount = dec!(1);
|
||||||
invoice.calculate();
|
invoice.calculate();
|
||||||
|
|
|
||||||
|
|
@ -164,17 +164,12 @@ impl BillingEngine {
|
||||||
let outstanding: Vec<_> = account
|
let outstanding: Vec<_> = account
|
||||||
.invoice_ids
|
.invoice_ids
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|id| {
|
.filter(|id| invoices.get(*id).map(|inv| !inv.is_paid()).unwrap_or(false))
|
||||||
invoices
|
|
||||||
.get(*id)
|
|
||||||
.map(|inv| !inv.is_paid())
|
|
||||||
.unwrap_or(false)
|
|
||||||
})
|
|
||||||
.cloned()
|
.cloned()
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let next_invoice = account.billing_cycle_start
|
let next_invoice =
|
||||||
+ Duration::days(self.config.billing_cycle_days as i64);
|
account.billing_cycle_start + Duration::days(self.config.billing_cycle_days as i64);
|
||||||
|
|
||||||
Ok(AccountBillingInfo {
|
Ok(AccountBillingInfo {
|
||||||
account_id: account_id.to_string(),
|
account_id: account_id.to_string(),
|
||||||
|
|
@ -198,11 +193,7 @@ impl BillingEngine {
|
||||||
|
|
||||||
account.prepaid_balance += amount;
|
account.prepaid_balance += amount;
|
||||||
|
|
||||||
tracing::info!(
|
tracing::info!("Added {} SYNOR prepaid to account {}", amount, account_id);
|
||||||
"Added {} SYNOR prepaid to account {}",
|
|
||||||
amount,
|
|
||||||
account_id
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
@ -378,7 +369,9 @@ impl BillingEngine {
|
||||||
PaymentMethod::CreditBalance => {
|
PaymentMethod::CreditBalance => {
|
||||||
// Deduct from credit balance
|
// Deduct from credit balance
|
||||||
if account.credit_balance < payment.amount {
|
if account.credit_balance < payment.amount {
|
||||||
return Err(EconomicsError::PaymentFailed("Insufficient credit balance".to_string()));
|
return Err(EconomicsError::PaymentFailed(
|
||||||
|
"Insufficient credit balance".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
account.credit_balance -= payment.amount;
|
account.credit_balance -= payment.amount;
|
||||||
payment.mark_completed();
|
payment.mark_completed();
|
||||||
|
|
@ -484,7 +477,10 @@ impl BillingEngine {
|
||||||
/// Get unpaid invoices for an account
|
/// Get unpaid invoices for an account
|
||||||
pub async fn get_unpaid_invoices(&self, account_id: &str) -> Result<Vec<Invoice>> {
|
pub async fn get_unpaid_invoices(&self, account_id: &str) -> Result<Vec<Invoice>> {
|
||||||
let all_invoices = self.get_account_invoices(account_id).await?;
|
let all_invoices = self.get_account_invoices(account_id).await?;
|
||||||
Ok(all_invoices.into_iter().filter(|inv| !inv.is_paid()).collect())
|
Ok(all_invoices
|
||||||
|
.into_iter()
|
||||||
|
.filter(|inv| !inv.is_paid())
|
||||||
|
.collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get detailed account information including creation date
|
/// Get detailed account information including creation date
|
||||||
|
|
@ -617,7 +613,10 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_register_account() {
|
async fn test_register_account() {
|
||||||
let engine = setup_engine().await;
|
let engine = setup_engine().await;
|
||||||
engine.register_account("test_account", "standard").await.unwrap();
|
engine
|
||||||
|
.register_account("test_account", "standard")
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let info = engine.get_account_details("test_account").await.unwrap();
|
let info = engine.get_account_details("test_account").await.unwrap();
|
||||||
assert_eq!(info.account_id, "test_account");
|
assert_eq!(info.account_id, "test_account");
|
||||||
|
|
@ -627,7 +626,10 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_add_prepaid() {
|
async fn test_add_prepaid() {
|
||||||
let engine = setup_engine().await;
|
let engine = setup_engine().await;
|
||||||
engine.register_account("prepaid_test", "standard").await.unwrap();
|
engine
|
||||||
|
.register_account("prepaid_test", "standard")
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
engine.add_prepaid("prepaid_test", dec!(100)).await.unwrap();
|
engine.add_prepaid("prepaid_test", dec!(100)).await.unwrap();
|
||||||
|
|
||||||
let info = engine.get_account_info("prepaid_test").await.unwrap();
|
let info = engine.get_account_info("prepaid_test").await.unwrap();
|
||||||
|
|
@ -637,7 +639,10 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_add_credit() {
|
async fn test_add_credit() {
|
||||||
let engine = setup_engine().await;
|
let engine = setup_engine().await;
|
||||||
engine.register_account("credit_test", "standard").await.unwrap();
|
engine
|
||||||
|
.register_account("credit_test", "standard")
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let credit = Credit::new("credit_test", dec!(50), "Welcome bonus");
|
let credit = Credit::new("credit_test", dec!(50), "Welcome bonus");
|
||||||
engine.add_credit("credit_test", credit).await.unwrap();
|
engine.add_credit("credit_test", credit).await.unwrap();
|
||||||
|
|
|
||||||
|
|
@ -210,17 +210,23 @@ impl PaymentProcessor {
|
||||||
payment.mark_processing();
|
payment.mark_processing();
|
||||||
|
|
||||||
// Simulate transaction
|
// Simulate transaction
|
||||||
let tx_hash = format!("0x{:x}000000000000000000000000000000000000000000000000000000000000",
|
let tx_hash = format!(
|
||||||
|
"0x{:x}000000000000000000000000000000000000000000000000000000000000",
|
||||||
std::time::SystemTime::now()
|
std::time::SystemTime::now()
|
||||||
.duration_since(UNIX_EPOCH)
|
.duration_since(UNIX_EPOCH)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.as_secs());
|
.as_secs()
|
||||||
|
);
|
||||||
|
|
||||||
payment.mark_confirmed(tx_hash);
|
payment.mark_confirmed(tx_hash);
|
||||||
|
|
||||||
// Add addresses to metadata
|
// Add addresses to metadata
|
||||||
payment.metadata.insert("from".to_string(), from_address.to_string());
|
payment
|
||||||
payment.metadata.insert("to".to_string(), to_address.to_string());
|
.metadata
|
||||||
|
.insert("from".to_string(), from_address.to_string());
|
||||||
|
payment
|
||||||
|
.metadata
|
||||||
|
.insert("to".to_string(), to_address.to_string());
|
||||||
|
|
||||||
payment.mark_completed();
|
payment.mark_completed();
|
||||||
|
|
||||||
|
|
@ -325,7 +331,10 @@ mod tests {
|
||||||
|
|
||||||
payment.mark_failed("Insufficient funds");
|
payment.mark_failed("Insufficient funds");
|
||||||
assert!(!payment.is_complete());
|
assert!(!payment.is_complete());
|
||||||
assert_eq!(payment.failure_reason, Some("Insufficient funds".to_string()));
|
assert_eq!(
|
||||||
|
payment.failure_reason,
|
||||||
|
Some("Insufficient funds".to_string())
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|
|
||||||
|
|
@ -177,7 +177,10 @@ impl CostEstimator {
|
||||||
|
|
||||||
/// Estimate cost for a usage projection
|
/// Estimate cost for a usage projection
|
||||||
pub async fn estimate(&self, projection: UsageProjection) -> Result<CostEstimate> {
|
pub async fn estimate(&self, projection: UsageProjection) -> Result<CostEstimate> {
|
||||||
let tier_name = projection.tier.clone().unwrap_or_else(|| "free".to_string());
|
let tier_name = projection
|
||||||
|
.tier
|
||||||
|
.clone()
|
||||||
|
.unwrap_or_else(|| "free".to_string());
|
||||||
let months = projection.duration_months.max(1);
|
let months = projection.duration_months.max(1);
|
||||||
|
|
||||||
let mut by_service = HashMap::new();
|
let mut by_service = HashMap::new();
|
||||||
|
|
|
||||||
|
|
@ -22,10 +22,7 @@ pub enum EconomicsError {
|
||||||
|
|
||||||
/// Insufficient balance
|
/// Insufficient balance
|
||||||
#[error("Insufficient balance: required {required}, available {available}")]
|
#[error("Insufficient balance: required {required}, available {available}")]
|
||||||
InsufficientBalance {
|
InsufficientBalance { required: String, available: String },
|
||||||
required: String,
|
|
||||||
available: String,
|
|
||||||
},
|
|
||||||
|
|
||||||
/// Insufficient funds (with Decimal values)
|
/// Insufficient funds (with Decimal values)
|
||||||
#[error("Insufficient funds: required {required}, available {available}")]
|
#[error("Insufficient funds: required {required}, available {available}")]
|
||||||
|
|
@ -36,10 +33,7 @@ pub enum EconomicsError {
|
||||||
|
|
||||||
/// Stale price with specific asset
|
/// Stale price with specific asset
|
||||||
#[error("Price stale for {asset}: {age_seconds} seconds old")]
|
#[error("Price stale for {asset}: {age_seconds} seconds old")]
|
||||||
StalePrice {
|
StalePrice { asset: String, age_seconds: i64 },
|
||||||
asset: String,
|
|
||||||
age_seconds: i64,
|
|
||||||
},
|
|
||||||
|
|
||||||
/// Account not found
|
/// Account not found
|
||||||
#[error("Account not found: {0}")]
|
#[error("Account not found: {0}")]
|
||||||
|
|
|
||||||
|
|
@ -251,9 +251,7 @@ impl EconomicsManager {
|
||||||
use rust_decimal_macros::dec;
|
use rust_decimal_macros::dec;
|
||||||
|
|
||||||
// Default to development oracle with mock feeds at $1.50 base price
|
// Default to development oracle with mock feeds at $1.50 base price
|
||||||
let oracle = Arc::new(RwLock::new(
|
let oracle = Arc::new(RwLock::new(oracle::OracleFactory::development(dec!(1.50))));
|
||||||
oracle::OracleFactory::development(dec!(1.50))
|
|
||||||
));
|
|
||||||
let pricing = Arc::new(PricingEngine::new());
|
let pricing = Arc::new(PricingEngine::new());
|
||||||
let metering = Arc::new(MeteringService::new(pricing.clone()));
|
let metering = Arc::new(MeteringService::new(pricing.clone()));
|
||||||
let billing = Arc::new(BillingEngine::new(metering.clone(), pricing.clone()));
|
let billing = Arc::new(BillingEngine::new(metering.clone(), pricing.clone()));
|
||||||
|
|
@ -270,9 +268,7 @@ impl EconomicsManager {
|
||||||
|
|
||||||
/// Create an economics manager with production oracle configuration
|
/// Create an economics manager with production oracle configuration
|
||||||
pub fn with_production_oracle(config: oracle::ProductionOracleConfig) -> Self {
|
pub fn with_production_oracle(config: oracle::ProductionOracleConfig) -> Self {
|
||||||
let oracle = Arc::new(RwLock::new(
|
let oracle = Arc::new(RwLock::new(oracle::OracleFactory::production(config)));
|
||||||
oracle::OracleFactory::production(config)
|
|
||||||
));
|
|
||||||
let pricing = Arc::new(PricingEngine::new());
|
let pricing = Arc::new(PricingEngine::new());
|
||||||
let metering = Arc::new(MeteringService::new(pricing.clone()));
|
let metering = Arc::new(MeteringService::new(pricing.clone()));
|
||||||
let billing = Arc::new(BillingEngine::new(metering.clone(), pricing.clone()));
|
let billing = Arc::new(BillingEngine::new(metering.clone(), pricing.clone()));
|
||||||
|
|
|
||||||
|
|
@ -209,23 +209,22 @@ impl MeteringService {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate cost for this event
|
// Calculate cost for this event
|
||||||
let cost = self.pricing.calculate_cost(
|
let cost =
|
||||||
event.service_type,
|
self.pricing
|
||||||
event.resource_unit,
|
.calculate_cost(event.service_type, event.resource_unit, event.amount)?;
|
||||||
event.amount,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
// Update current usage
|
// Update current usage
|
||||||
{
|
{
|
||||||
let mut usage = self.current_usage.write().await;
|
let mut usage = self.current_usage.write().await;
|
||||||
let account_usage = usage.entry(event.account_id.clone()).or_insert_with(|| {
|
let account_usage =
|
||||||
AccountUsage {
|
usage
|
||||||
account_id: event.account_id.clone(),
|
.entry(event.account_id.clone())
|
||||||
by_service: HashMap::new(),
|
.or_insert_with(|| AccountUsage {
|
||||||
current_period_start: Utc::now(),
|
account_id: event.account_id.clone(),
|
||||||
last_event: None,
|
by_service: HashMap::new(),
|
||||||
}
|
current_period_start: Utc::now(),
|
||||||
});
|
last_event: None,
|
||||||
|
});
|
||||||
|
|
||||||
*account_usage
|
*account_usage
|
||||||
.by_service
|
.by_service
|
||||||
|
|
@ -263,7 +262,8 @@ impl MeteringService {
|
||||||
ServiceType::Storage,
|
ServiceType::Storage,
|
||||||
ResourceUnit::Bytes,
|
ResourceUnit::Bytes,
|
||||||
Decimal::from(usage.bytes_stored),
|
Decimal::from(usage.bytes_stored),
|
||||||
)).await?;
|
))
|
||||||
|
.await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Storage: bytes retrieved
|
// Storage: bytes retrieved
|
||||||
|
|
@ -273,7 +273,8 @@ impl MeteringService {
|
||||||
ServiceType::Storage,
|
ServiceType::Storage,
|
||||||
ResourceUnit::BandwidthGb,
|
ResourceUnit::BandwidthGb,
|
||||||
Decimal::from(usage.bytes_retrieved) / Decimal::from(1_073_741_824u64), // to GB
|
Decimal::from(usage.bytes_retrieved) / Decimal::from(1_073_741_824u64), // to GB
|
||||||
)).await?;
|
))
|
||||||
|
.await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
@ -288,7 +289,8 @@ impl MeteringService {
|
||||||
ServiceType::Hosting,
|
ServiceType::Hosting,
|
||||||
ResourceUnit::BandwidthGb,
|
ResourceUnit::BandwidthGb,
|
||||||
Decimal::from(usage.bandwidth_bytes) / Decimal::from(1_073_741_824u64),
|
Decimal::from(usage.bandwidth_bytes) / Decimal::from(1_073_741_824u64),
|
||||||
)).await?;
|
))
|
||||||
|
.await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Custom domains
|
// Custom domains
|
||||||
|
|
@ -298,7 +300,8 @@ impl MeteringService {
|
||||||
ServiceType::Hosting,
|
ServiceType::Hosting,
|
||||||
ResourceUnit::Domains,
|
ResourceUnit::Domains,
|
||||||
Decimal::from(usage.custom_domains),
|
Decimal::from(usage.custom_domains),
|
||||||
)).await?;
|
))
|
||||||
|
.await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
@ -313,7 +316,8 @@ impl MeteringService {
|
||||||
ServiceType::Database,
|
ServiceType::Database,
|
||||||
ResourceUnit::Queries,
|
ResourceUnit::Queries,
|
||||||
Decimal::from(usage.queries),
|
Decimal::from(usage.queries),
|
||||||
)).await?;
|
))
|
||||||
|
.await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Vector searches
|
// Vector searches
|
||||||
|
|
@ -323,7 +327,8 @@ impl MeteringService {
|
||||||
ServiceType::Database,
|
ServiceType::Database,
|
||||||
ResourceUnit::VectorSearches,
|
ResourceUnit::VectorSearches,
|
||||||
Decimal::from(usage.vector_searches),
|
Decimal::from(usage.vector_searches),
|
||||||
)).await?;
|
))
|
||||||
|
.await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Storage
|
// Storage
|
||||||
|
|
@ -333,7 +338,8 @@ impl MeteringService {
|
||||||
ServiceType::Database,
|
ServiceType::Database,
|
||||||
ResourceUnit::GbMonth,
|
ResourceUnit::GbMonth,
|
||||||
Decimal::from(usage.storage_bytes) / Decimal::from(1_073_741_824u64),
|
Decimal::from(usage.storage_bytes) / Decimal::from(1_073_741_824u64),
|
||||||
)).await?;
|
))
|
||||||
|
.await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
@ -348,7 +354,8 @@ impl MeteringService {
|
||||||
ServiceType::Compute,
|
ServiceType::Compute,
|
||||||
ResourceUnit::CpuCoreHours,
|
ResourceUnit::CpuCoreHours,
|
||||||
Decimal::from(usage.cpu_core_seconds) / Decimal::from(3600),
|
Decimal::from(usage.cpu_core_seconds) / Decimal::from(3600),
|
||||||
)).await?;
|
))
|
||||||
|
.await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// GPU hours
|
// GPU hours
|
||||||
|
|
@ -358,7 +365,8 @@ impl MeteringService {
|
||||||
ServiceType::Compute,
|
ServiceType::Compute,
|
||||||
ResourceUnit::GpuHours,
|
ResourceUnit::GpuHours,
|
||||||
Decimal::from(usage.gpu_seconds) / Decimal::from(3600),
|
Decimal::from(usage.gpu_seconds) / Decimal::from(3600),
|
||||||
)).await?;
|
))
|
||||||
|
.await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Memory GB hours
|
// Memory GB hours
|
||||||
|
|
@ -368,7 +376,8 @@ impl MeteringService {
|
||||||
ServiceType::Compute,
|
ServiceType::Compute,
|
||||||
ResourceUnit::MemoryGbHours,
|
ResourceUnit::MemoryGbHours,
|
||||||
Decimal::from(usage.memory_gb_seconds) / Decimal::from(3600),
|
Decimal::from(usage.memory_gb_seconds) / Decimal::from(3600),
|
||||||
)).await?;
|
))
|
||||||
|
.await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Invocations (serverless)
|
// Invocations (serverless)
|
||||||
|
|
@ -378,7 +387,8 @@ impl MeteringService {
|
||||||
ServiceType::Compute,
|
ServiceType::Compute,
|
||||||
ResourceUnit::Invocations,
|
ResourceUnit::Invocations,
|
||||||
Decimal::from(usage.invocations),
|
Decimal::from(usage.invocations),
|
||||||
)).await?;
|
))
|
||||||
|
.await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
@ -393,7 +403,8 @@ impl MeteringService {
|
||||||
ServiceType::Network,
|
ServiceType::Network,
|
||||||
ResourceUnit::BandwidthGb,
|
ResourceUnit::BandwidthGb,
|
||||||
Decimal::from(total_bytes) / Decimal::from(1_073_741_824u64),
|
Decimal::from(total_bytes) / Decimal::from(1_073_741_824u64),
|
||||||
)).await?;
|
))
|
||||||
|
.await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
@ -421,10 +432,7 @@ impl MeteringService {
|
||||||
// Check buffered events
|
// Check buffered events
|
||||||
let buffer = self.event_buffer.read().await;
|
let buffer = self.event_buffer.read().await;
|
||||||
for event in buffer.iter() {
|
for event in buffer.iter() {
|
||||||
if event.account_id == account_id
|
if event.account_id == account_id && event.timestamp >= start && event.timestamp < end {
|
||||||
&& event.timestamp >= start
|
|
||||||
&& event.timestamp < end
|
|
||||||
{
|
|
||||||
let cost = self.pricing.calculate_cost(
|
let cost = self.pricing.calculate_cost(
|
||||||
event.service_type,
|
event.service_type,
|
||||||
event.resource_unit,
|
event.resource_unit,
|
||||||
|
|
|
||||||
|
|
@ -224,9 +224,8 @@ impl IsolationTree {
|
||||||
// Random split point
|
// Random split point
|
||||||
let split = min_val + (max_val - min_val) * 0.5;
|
let split = min_val + (max_val - min_val) * 0.5;
|
||||||
|
|
||||||
let (left_data, right_data): (Vec<_>, Vec<_>) = data.iter()
|
let (left_data, right_data): (Vec<_>, Vec<_>) =
|
||||||
.cloned()
|
data.iter().cloned().partition(|row| row[feature] < split);
|
||||||
.partition(|row| row[feature] < split);
|
|
||||||
|
|
||||||
Some(Self {
|
Some(Self {
|
||||||
split_feature: feature,
|
split_feature: feature,
|
||||||
|
|
@ -280,7 +279,8 @@ impl IsolationForest {
|
||||||
let trees: Vec<_> = (0..n_trees)
|
let trees: Vec<_> = (0..n_trees)
|
||||||
.filter_map(|i| {
|
.filter_map(|i| {
|
||||||
// Subsample with deterministic "randomness" based on tree index
|
// Subsample with deterministic "randomness" based on tree index
|
||||||
let sample: Vec<_> = data.iter()
|
let sample: Vec<_> = data
|
||||||
|
.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.filter(|(j, _)| (i + j) % 3 != 0)
|
.filter(|(j, _)| (i + j) % 3 != 0)
|
||||||
.take(sample_size)
|
.take(sample_size)
|
||||||
|
|
@ -299,9 +299,12 @@ impl IsolationForest {
|
||||||
return 0.5;
|
return 0.5;
|
||||||
}
|
}
|
||||||
|
|
||||||
let avg_path: f64 = self.trees.iter()
|
let avg_path: f64 = self
|
||||||
|
.trees
|
||||||
|
.iter()
|
||||||
.map(|tree| tree.path_length(point, 0.0))
|
.map(|tree| tree.path_length(point, 0.0))
|
||||||
.sum::<f64>() / self.trees.len() as f64;
|
.sum::<f64>()
|
||||||
|
/ self.trees.len() as f64;
|
||||||
|
|
||||||
let c = c_factor(self.sample_size);
|
let c = c_factor(self.sample_size);
|
||||||
if c < f64::EPSILON {
|
if c < f64::EPSILON {
|
||||||
|
|
@ -365,17 +368,28 @@ impl PairDetector {
|
||||||
|
|
||||||
// Track addresses
|
// Track addresses
|
||||||
if !point.addresses.is_empty() {
|
if !point.addresses.is_empty() {
|
||||||
self.recent_addresses.push_back((point.timestamp, point.addresses.clone()));
|
self.recent_addresses
|
||||||
|
.push_back((point.timestamp, point.addresses.clone()));
|
||||||
}
|
}
|
||||||
|
|
||||||
self.price_history.push_back(point);
|
self.price_history.push_back(point);
|
||||||
|
|
||||||
// Cleanup old data
|
// Cleanup old data
|
||||||
let cutoff = Utc::now() - Duration::hours(24);
|
let cutoff = Utc::now() - Duration::hours(24);
|
||||||
while self.price_history.front().map(|p| p.timestamp < cutoff).unwrap_or(false) {
|
while self
|
||||||
|
.price_history
|
||||||
|
.front()
|
||||||
|
.map(|p| p.timestamp < cutoff)
|
||||||
|
.unwrap_or(false)
|
||||||
|
{
|
||||||
self.price_history.pop_front();
|
self.price_history.pop_front();
|
||||||
}
|
}
|
||||||
while self.recent_addresses.front().map(|(t, _)| *t < cutoff).unwrap_or(false) {
|
while self
|
||||||
|
.recent_addresses
|
||||||
|
.front()
|
||||||
|
.map(|(t, _)| *t < cutoff)
|
||||||
|
.unwrap_or(false)
|
||||||
|
{
|
||||||
self.recent_addresses.pop_front();
|
self.recent_addresses.pop_front();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -386,7 +400,9 @@ impl PairDetector {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build feature vectors: [price, volume, return, bid/ask ratio]
|
// Build feature vectors: [price, volume, return, bid/ask ratio]
|
||||||
let data: Vec<Vec<f64>> = self.price_history.iter()
|
let data: Vec<Vec<f64>> = self
|
||||||
|
.price_history
|
||||||
|
.iter()
|
||||||
.skip(1)
|
.skip(1)
|
||||||
.zip(self.price_history.iter())
|
.zip(self.price_history.iter())
|
||||||
.map(|(curr, prev)| {
|
.map(|(curr, prev)| {
|
||||||
|
|
@ -403,7 +419,11 @@ impl PairDetector {
|
||||||
(Some(bid), Some(ask)) => {
|
(Some(bid), Some(ask)) => {
|
||||||
let bid_f = bid.to_string().parse::<f64>().unwrap_or(0.0);
|
let bid_f = bid.to_string().parse::<f64>().unwrap_or(0.0);
|
||||||
let ask_f = ask.to_string().parse::<f64>().unwrap_or(1.0);
|
let ask_f = ask.to_string().parse::<f64>().unwrap_or(1.0);
|
||||||
if ask_f > 0.0 { bid_f / ask_f } else { 1.0 }
|
if ask_f > 0.0 {
|
||||||
|
bid_f / ask_f
|
||||||
|
} else {
|
||||||
|
1.0
|
||||||
|
}
|
||||||
}
|
}
|
||||||
_ => 1.0,
|
_ => 1.0,
|
||||||
};
|
};
|
||||||
|
|
@ -462,19 +482,34 @@ impl AnomalyDetector {
|
||||||
|
|
||||||
// Run all detectors using the immutable reference first
|
// Run all detectors using the immutable reference first
|
||||||
if let Some(detector) = self.detectors.get(pair) {
|
if let Some(detector) = self.detectors.get(pair) {
|
||||||
if let Some(a) = Self::detect_price_outlier_impl(pair, &data, detector, min_data_points, z_score_threshold) {
|
if let Some(a) = Self::detect_price_outlier_impl(
|
||||||
|
pair,
|
||||||
|
&data,
|
||||||
|
detector,
|
||||||
|
min_data_points,
|
||||||
|
z_score_threshold,
|
||||||
|
) {
|
||||||
anomalies.push(a);
|
anomalies.push(a);
|
||||||
}
|
}
|
||||||
if let Some(a) = Self::detect_volume_spike_impl(pair, &data, detector, min_data_points, volume_spike_multiplier) {
|
if let Some(a) = Self::detect_volume_spike_impl(
|
||||||
|
pair,
|
||||||
|
&data,
|
||||||
|
detector,
|
||||||
|
min_data_points,
|
||||||
|
volume_spike_multiplier,
|
||||||
|
) {
|
||||||
anomalies.push(a);
|
anomalies.push(a);
|
||||||
}
|
}
|
||||||
if let Some(a) = Self::detect_wash_trading_impl(pair, &data, detector, wash_trading_window) {
|
if let Some(a) =
|
||||||
|
Self::detect_wash_trading_impl(pair, &data, detector, wash_trading_window)
|
||||||
|
{
|
||||||
anomalies.push(a);
|
anomalies.push(a);
|
||||||
}
|
}
|
||||||
if let Some(a) = Self::detect_pump_dump_impl(pair, detector, pump_dump_window) {
|
if let Some(a) = Self::detect_pump_dump_impl(pair, detector, pump_dump_window) {
|
||||||
anomalies.push(a);
|
anomalies.push(a);
|
||||||
}
|
}
|
||||||
if let Some(a) = Self::detect_flash_loan_impl(pair, &data, detector, flash_loan_window) {
|
if let Some(a) = Self::detect_flash_loan_impl(pair, &data, detector, flash_loan_window)
|
||||||
|
{
|
||||||
anomalies.push(a);
|
anomalies.push(a);
|
||||||
}
|
}
|
||||||
if ml_enabled {
|
if ml_enabled {
|
||||||
|
|
@ -493,7 +528,13 @@ impl AnomalyDetector {
|
||||||
anomalies
|
anomalies
|
||||||
}
|
}
|
||||||
|
|
||||||
fn detect_price_outlier_impl(pair: &str, data: &MarketDataPoint, detector: &PairDetector, min_data_points: usize, z_score_threshold: f64) -> Option<Anomaly> {
|
fn detect_price_outlier_impl(
|
||||||
|
pair: &str,
|
||||||
|
data: &MarketDataPoint,
|
||||||
|
detector: &PairDetector,
|
||||||
|
min_data_points: usize,
|
||||||
|
z_score_threshold: f64,
|
||||||
|
) -> Option<Anomaly> {
|
||||||
if detector.price_stats.count < min_data_points {
|
if detector.price_stats.count < min_data_points {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
@ -531,7 +572,13 @@ impl AnomalyDetector {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn detect_volume_spike_impl(pair: &str, data: &MarketDataPoint, detector: &PairDetector, min_data_points: usize, volume_spike_multiplier: f64) -> Option<Anomaly> {
|
fn detect_volume_spike_impl(
|
||||||
|
pair: &str,
|
||||||
|
data: &MarketDataPoint,
|
||||||
|
detector: &PairDetector,
|
||||||
|
min_data_points: usize,
|
||||||
|
volume_spike_multiplier: f64,
|
||||||
|
) -> Option<Anomaly> {
|
||||||
if detector.volume_stats.count < min_data_points {
|
if detector.volume_stats.count < min_data_points {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
@ -550,7 +597,9 @@ impl AnomalyDetector {
|
||||||
confidence: 0.75,
|
confidence: 0.75,
|
||||||
description: format!(
|
description: format!(
|
||||||
"Volume {} is {:.1}x the average {:.2}",
|
"Volume {} is {:.1}x the average {:.2}",
|
||||||
data.volume, volume_f64 / mean, mean
|
data.volume,
|
||||||
|
volume_f64 / mean,
|
||||||
|
mean
|
||||||
),
|
),
|
||||||
data: AnomalyData {
|
data: AnomalyData {
|
||||||
current_value: data.volume,
|
current_value: data.volume,
|
||||||
|
|
@ -566,7 +615,12 @@ impl AnomalyDetector {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn detect_wash_trading_impl(pair: &str, data: &MarketDataPoint, detector: &PairDetector, wash_trading_window: i64) -> Option<Anomaly> {
|
fn detect_wash_trading_impl(
|
||||||
|
pair: &str,
|
||||||
|
data: &MarketDataPoint,
|
||||||
|
detector: &PairDetector,
|
||||||
|
wash_trading_window: i64,
|
||||||
|
) -> Option<Anomaly> {
|
||||||
if data.addresses.is_empty() {
|
if data.addresses.is_empty() {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
@ -617,14 +671,20 @@ impl AnomalyDetector {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
fn detect_pump_dump_impl(pair: &str, detector: &PairDetector, pump_dump_window: i64) -> Option<Anomaly> {
|
fn detect_pump_dump_impl(
|
||||||
|
pair: &str,
|
||||||
|
detector: &PairDetector,
|
||||||
|
pump_dump_window: i64,
|
||||||
|
) -> Option<Anomaly> {
|
||||||
// Need enough history
|
// Need enough history
|
||||||
if detector.price_history.len() < 10 {
|
if detector.price_history.len() < 10 {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
let window_start = Utc::now() - Duration::minutes(pump_dump_window);
|
let window_start = Utc::now() - Duration::minutes(pump_dump_window);
|
||||||
let prices: Vec<_> = detector.price_history.iter()
|
let prices: Vec<_> = detector
|
||||||
|
.price_history
|
||||||
|
.iter()
|
||||||
.filter(|p| p.timestamp >= window_start)
|
.filter(|p| p.timestamp >= window_start)
|
||||||
.map(|p| p.price.to_string().parse::<f64>().unwrap_or(0.0))
|
.map(|p| p.price.to_string().parse::<f64>().unwrap_or(0.0))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
@ -635,7 +695,9 @@ impl AnomalyDetector {
|
||||||
|
|
||||||
// Find max and check for reversal
|
// Find max and check for reversal
|
||||||
let max_price = prices.iter().copied().fold(f64::MIN, f64::max);
|
let max_price = prices.iter().copied().fold(f64::MIN, f64::max);
|
||||||
let max_idx = prices.iter().position(|&p| (p - max_price).abs() < f64::EPSILON)?;
|
let max_idx = prices
|
||||||
|
.iter()
|
||||||
|
.position(|&p| (p - max_price).abs() < f64::EPSILON)?;
|
||||||
let first_price = prices.first()?;
|
let first_price = prices.first()?;
|
||||||
let last_price = prices.last()?;
|
let last_price = prices.last()?;
|
||||||
|
|
||||||
|
|
@ -672,10 +734,17 @@ impl AnomalyDetector {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
fn detect_flash_loan_impl(pair: &str, data: &MarketDataPoint, detector: &PairDetector, flash_loan_window: i64) -> Option<Anomaly> {
|
fn detect_flash_loan_impl(
|
||||||
|
pair: &str,
|
||||||
|
data: &MarketDataPoint,
|
||||||
|
detector: &PairDetector,
|
||||||
|
flash_loan_window: i64,
|
||||||
|
) -> Option<Anomaly> {
|
||||||
// Flash loan signature: huge volume spike + quick price movement + reversal
|
// Flash loan signature: huge volume spike + quick price movement + reversal
|
||||||
let window_start = Utc::now() - Duration::seconds(flash_loan_window);
|
let window_start = Utc::now() - Duration::seconds(flash_loan_window);
|
||||||
let recent: Vec<_> = detector.price_history.iter()
|
let recent: Vec<_> = detector
|
||||||
|
.price_history
|
||||||
|
.iter()
|
||||||
.filter(|p| p.timestamp >= window_start)
|
.filter(|p| p.timestamp >= window_start)
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
|
|
@ -683,11 +752,13 @@ impl AnomalyDetector {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
let volumes: Vec<f64> = recent.iter()
|
let volumes: Vec<f64> = recent
|
||||||
|
.iter()
|
||||||
.map(|p| p.volume.to_string().parse::<f64>().unwrap_or(0.0))
|
.map(|p| p.volume.to_string().parse::<f64>().unwrap_or(0.0))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let prices: Vec<f64> = recent.iter()
|
let prices: Vec<f64> = recent
|
||||||
|
.iter()
|
||||||
.map(|p| p.price.to_string().parse::<f64>().unwrap_or(0.0))
|
.map(|p| p.price.to_string().parse::<f64>().unwrap_or(0.0))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
|
|
@ -706,7 +777,10 @@ impl AnomalyDetector {
|
||||||
// Big spike and quick reversal
|
// Big spike and quick reversal
|
||||||
if spike > 10.0 && reversal > 8.0 {
|
if spike > 10.0 && reversal > 8.0 {
|
||||||
let mut context = HashMap::new();
|
let mut context = HashMap::new();
|
||||||
context.insert("volume_spike".to_string(), format!("{:.0}x", max_volume / avg_volume));
|
context.insert(
|
||||||
|
"volume_spike".to_string(),
|
||||||
|
format!("{:.0}x", max_volume / avg_volume),
|
||||||
|
);
|
||||||
context.insert("price_spike".to_string(), format!("{:.1}%", spike));
|
context.insert("price_spike".to_string(), format!("{:.1}%", spike));
|
||||||
|
|
||||||
return Some(Anomaly {
|
return Some(Anomaly {
|
||||||
|
|
@ -717,7 +791,8 @@ impl AnomalyDetector {
|
||||||
confidence: 0.65,
|
confidence: 0.65,
|
||||||
description: format!(
|
description: format!(
|
||||||
"Suspected flash loan attack: {}x volume, {:.1}% price spike",
|
"Suspected flash loan attack: {}x volume, {:.1}% price spike",
|
||||||
max_volume / avg_volume, spike
|
max_volume / avg_volume,
|
||||||
|
spike
|
||||||
),
|
),
|
||||||
data: AnomalyData {
|
data: AnomalyData {
|
||||||
current_value: data.volume,
|
current_value: data.volume,
|
||||||
|
|
@ -734,7 +809,11 @@ impl AnomalyDetector {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
fn detect_ml_anomaly_impl(pair: &str, data: &MarketDataPoint, detector: &PairDetector) -> Option<Anomaly> {
|
fn detect_ml_anomaly_impl(
|
||||||
|
pair: &str,
|
||||||
|
data: &MarketDataPoint,
|
||||||
|
detector: &PairDetector,
|
||||||
|
) -> Option<Anomaly> {
|
||||||
let forest = detector.isolation_forest.as_ref()?;
|
let forest = detector.isolation_forest.as_ref()?;
|
||||||
|
|
||||||
if detector.price_history.is_empty() {
|
if detector.price_history.is_empty() {
|
||||||
|
|
@ -756,7 +835,11 @@ impl AnomalyDetector {
|
||||||
(Some(bid), Some(ask)) => {
|
(Some(bid), Some(ask)) => {
|
||||||
let bid_f = bid.to_string().parse::<f64>().ok()?;
|
let bid_f = bid.to_string().parse::<f64>().ok()?;
|
||||||
let ask_f = ask.to_string().parse::<f64>().ok()?;
|
let ask_f = ask.to_string().parse::<f64>().ok()?;
|
||||||
if ask_f > 0.0 { bid_f / ask_f } else { 1.0 }
|
if ask_f > 0.0 {
|
||||||
|
bid_f / ask_f
|
||||||
|
} else {
|
||||||
|
1.0
|
||||||
|
}
|
||||||
}
|
}
|
||||||
_ => 1.0,
|
_ => 1.0,
|
||||||
};
|
};
|
||||||
|
|
@ -774,10 +857,7 @@ impl AnomalyDetector {
|
||||||
detected_at: Utc::now(),
|
detected_at: Utc::now(),
|
||||||
severity: score,
|
severity: score,
|
||||||
confidence: 0.6,
|
confidence: 0.6,
|
||||||
description: format!(
|
description: format!("ML model detected anomaly with score {:.3}", score),
|
||||||
"ML model detected anomaly with score {:.3}",
|
|
||||||
score
|
|
||||||
),
|
|
||||||
data: AnomalyData {
|
data: AnomalyData {
|
||||||
current_value: data.price,
|
current_value: data.price,
|
||||||
expected_value: Decimal::from_f64_retain(detector.price_stats.mean)?,
|
expected_value: Decimal::from_f64_retain(detector.price_stats.mean)?,
|
||||||
|
|
@ -798,7 +878,8 @@ impl AnomalyDetector {
|
||||||
|
|
||||||
/// Get recent anomalies for a pair
|
/// Get recent anomalies for a pair
|
||||||
pub fn get_anomalies(&self, pair: &str) -> Vec<Anomaly> {
|
pub fn get_anomalies(&self, pair: &str) -> Vec<Anomaly> {
|
||||||
self.detectors.get(pair)
|
self.detectors
|
||||||
|
.get(pair)
|
||||||
.map(|d| d.anomalies.clone())
|
.map(|d| d.anomalies.clone())
|
||||||
.unwrap_or_default()
|
.unwrap_or_default()
|
||||||
}
|
}
|
||||||
|
|
@ -807,11 +888,14 @@ impl AnomalyDetector {
|
||||||
pub fn get_stats(&self, pair: &str) -> Option<AnomalyStats> {
|
pub fn get_stats(&self, pair: &str) -> Option<AnomalyStats> {
|
||||||
let detector = self.detectors.get(pair)?;
|
let detector = self.detectors.get(pair)?;
|
||||||
|
|
||||||
let by_type: HashMap<AnomalyType, usize> = detector.anomalies.iter()
|
let by_type: HashMap<AnomalyType, usize> =
|
||||||
.fold(HashMap::new(), |mut acc, a| {
|
detector
|
||||||
*acc.entry(a.anomaly_type.clone()).or_insert(0) += 1;
|
.anomalies
|
||||||
acc
|
.iter()
|
||||||
});
|
.fold(HashMap::new(), |mut acc, a| {
|
||||||
|
*acc.entry(a.anomaly_type.clone()).or_insert(0) += 1;
|
||||||
|
acc
|
||||||
|
});
|
||||||
|
|
||||||
Some(AnomalyStats {
|
Some(AnomalyStats {
|
||||||
total_anomalies: detector.anomalies.len(),
|
total_anomalies: detector.anomalies.len(),
|
||||||
|
|
@ -913,7 +997,9 @@ mod tests {
|
||||||
|
|
||||||
let anomalies = detector.process("SYNOR/USD", outlier);
|
let anomalies = detector.process("SYNOR/USD", outlier);
|
||||||
assert!(!anomalies.is_empty());
|
assert!(!anomalies.is_empty());
|
||||||
assert!(anomalies.iter().any(|a| a.anomaly_type == AnomalyType::PriceOutlier));
|
assert!(anomalies
|
||||||
|
.iter()
|
||||||
|
.any(|a| a.anomaly_type == AnomalyType::PriceOutlier));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -946,6 +1032,8 @@ mod tests {
|
||||||
};
|
};
|
||||||
|
|
||||||
let anomalies = detector.process("SYNOR/USD", spike);
|
let anomalies = detector.process("SYNOR/USD", spike);
|
||||||
assert!(anomalies.iter().any(|a| a.anomaly_type == AnomalyType::VolumeSpike));
|
assert!(anomalies
|
||||||
|
.iter()
|
||||||
|
.any(|a| a.anomaly_type == AnomalyType::VolumeSpike));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -40,17 +40,11 @@ pub enum TriggerReason {
|
||||||
threshold: SynorDecimal,
|
threshold: SynorDecimal,
|
||||||
},
|
},
|
||||||
/// Multiple oracle sources disagree
|
/// Multiple oracle sources disagree
|
||||||
OracleDisagreement {
|
OracleDisagreement { spread_percent: Decimal },
|
||||||
spread_percent: Decimal,
|
|
||||||
},
|
|
||||||
/// Manual trigger by admin
|
/// Manual trigger by admin
|
||||||
ManualHalt {
|
ManualHalt { reason: String },
|
||||||
reason: String,
|
|
||||||
},
|
|
||||||
/// Cascade from related market
|
/// Cascade from related market
|
||||||
CascadeTrigger {
|
CascadeTrigger { source_pair: String },
|
||||||
source_pair: String,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Circuit breaker event
|
/// Circuit breaker event
|
||||||
|
|
@ -98,12 +92,12 @@ pub struct CircuitBreakerConfig {
|
||||||
impl Default for CircuitBreakerConfig {
|
impl Default for CircuitBreakerConfig {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
max_1m_change: Decimal::new(10, 2), // 10%
|
max_1m_change: Decimal::new(10, 2), // 10%
|
||||||
max_5m_change: Decimal::new(20, 2), // 20%
|
max_5m_change: Decimal::new(20, 2), // 20%
|
||||||
max_1h_change: Decimal::new(50, 2), // 50%
|
max_1h_change: Decimal::new(50, 2), // 50%
|
||||||
max_twap_deviation: Decimal::new(30, 2), // 30%
|
max_twap_deviation: Decimal::new(30, 2), // 30%
|
||||||
min_liquidity: Decimal::new(10000, 0), // $10k
|
min_liquidity: Decimal::new(10000, 0), // $10k
|
||||||
max_oracle_spread: Decimal::new(5, 2), // 5%
|
max_oracle_spread: Decimal::new(5, 2), // 5%
|
||||||
cooldown_duration: Duration::minutes(5),
|
cooldown_duration: Duration::minutes(5),
|
||||||
recovery_checks: 3,
|
recovery_checks: 3,
|
||||||
cascade_enabled: true,
|
cascade_enabled: true,
|
||||||
|
|
@ -173,7 +167,12 @@ impl PairCircuitBreaker {
|
||||||
|
|
||||||
// Keep only last 24 hours
|
// Keep only last 24 hours
|
||||||
let cutoff = Utc::now() - Duration::hours(24);
|
let cutoff = Utc::now() - Duration::hours(24);
|
||||||
while self.price_history.front().map(|s| s.timestamp < cutoff).unwrap_or(false) {
|
while self
|
||||||
|
.price_history
|
||||||
|
.front()
|
||||||
|
.map(|s| s.timestamp < cutoff)
|
||||||
|
.unwrap_or(false)
|
||||||
|
{
|
||||||
self.price_history.pop_front();
|
self.price_history.pop_front();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -192,7 +191,8 @@ impl PairCircuitBreaker {
|
||||||
|
|
||||||
fn get_price_at(&self, seconds_ago: i64) -> Option<SynorDecimal> {
|
fn get_price_at(&self, seconds_ago: i64) -> Option<SynorDecimal> {
|
||||||
let target = Utc::now() - Duration::seconds(seconds_ago);
|
let target = Utc::now() - Duration::seconds(seconds_ago);
|
||||||
self.price_history.iter()
|
self.price_history
|
||||||
|
.iter()
|
||||||
.rev()
|
.rev()
|
||||||
.find(|s| s.timestamp <= target)
|
.find(|s| s.timestamp <= target)
|
||||||
.map(|s| s.price)
|
.map(|s| s.price)
|
||||||
|
|
@ -232,7 +232,9 @@ impl CircuitBreakerManager {
|
||||||
price: SynorDecimal,
|
price: SynorDecimal,
|
||||||
liquidity: Option<SynorDecimal>,
|
liquidity: Option<SynorDecimal>,
|
||||||
) -> Result<CircuitState> {
|
) -> Result<CircuitState> {
|
||||||
let breaker = self.breakers.entry(pair.to_string())
|
let breaker = self
|
||||||
|
.breakers
|
||||||
|
.entry(pair.to_string())
|
||||||
.or_insert_with(PairCircuitBreaker::new);
|
.or_insert_with(PairCircuitBreaker::new);
|
||||||
|
|
||||||
// Use the convenience method for real-time price recording
|
// Use the convenience method for real-time price recording
|
||||||
|
|
@ -256,7 +258,9 @@ impl CircuitBreakerManager {
|
||||||
liquidity: Option<SynorDecimal>,
|
liquidity: Option<SynorDecimal>,
|
||||||
timestamp: DateTime<Utc>,
|
timestamp: DateTime<Utc>,
|
||||||
) -> Result<CircuitState> {
|
) -> Result<CircuitState> {
|
||||||
let breaker = self.breakers.entry(pair.to_string())
|
let breaker = self
|
||||||
|
.breakers
|
||||||
|
.entry(pair.to_string())
|
||||||
.or_insert_with(PairCircuitBreaker::new);
|
.or_insert_with(PairCircuitBreaker::new);
|
||||||
|
|
||||||
breaker.record_price_at(price, liquidity, timestamp);
|
breaker.record_price_at(price, liquidity, timestamp);
|
||||||
|
|
@ -273,22 +277,26 @@ impl CircuitBreakerManager {
|
||||||
|
|
||||||
/// Check all trigger conditions
|
/// Check all trigger conditions
|
||||||
fn check_triggers(&mut self, pair: &str) -> Result<()> {
|
fn check_triggers(&mut self, pair: &str) -> Result<()> {
|
||||||
let breaker = self.breakers.get(pair).ok_or_else(||
|
let breaker = self
|
||||||
EconomicsError::PriceFeedUnavailable(pair.to_string())
|
.breakers
|
||||||
)?;
|
.get(pair)
|
||||||
|
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
|
||||||
|
|
||||||
let current = breaker.current_price().ok_or_else(||
|
let current = breaker
|
||||||
EconomicsError::PriceFeedUnavailable(pair.to_string())
|
.current_price()
|
||||||
)?;
|
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
|
||||||
|
|
||||||
// Check 1-minute change
|
// Check 1-minute change
|
||||||
if let Some(price_1m) = breaker.get_price_at(60) {
|
if let Some(price_1m) = breaker.get_price_at(60) {
|
||||||
let change = ((current - price_1m) / price_1m).abs();
|
let change = ((current - price_1m) / price_1m).abs();
|
||||||
if change > self.config.max_1m_change {
|
if change > self.config.max_1m_change {
|
||||||
return self.trigger_breaker(pair, TriggerReason::RapidPriceChange {
|
return self.trigger_breaker(
|
||||||
change_percent: change * Decimal::ONE_HUNDRED,
|
pair,
|
||||||
window_seconds: 60,
|
TriggerReason::RapidPriceChange {
|
||||||
});
|
change_percent: change * Decimal::ONE_HUNDRED,
|
||||||
|
window_seconds: 60,
|
||||||
|
},
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -296,10 +304,13 @@ impl CircuitBreakerManager {
|
||||||
if let Some(price_5m) = breaker.get_price_at(300) {
|
if let Some(price_5m) = breaker.get_price_at(300) {
|
||||||
let change = ((current - price_5m) / price_5m).abs();
|
let change = ((current - price_5m) / price_5m).abs();
|
||||||
if change > self.config.max_5m_change {
|
if change > self.config.max_5m_change {
|
||||||
return self.trigger_breaker(pair, TriggerReason::RapidPriceChange {
|
return self.trigger_breaker(
|
||||||
change_percent: change * Decimal::ONE_HUNDRED,
|
pair,
|
||||||
window_seconds: 300,
|
TriggerReason::RapidPriceChange {
|
||||||
});
|
change_percent: change * Decimal::ONE_HUNDRED,
|
||||||
|
window_seconds: 300,
|
||||||
|
},
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -307,10 +318,13 @@ impl CircuitBreakerManager {
|
||||||
if let Some(price_1h) = breaker.get_price_at(3600) {
|
if let Some(price_1h) = breaker.get_price_at(3600) {
|
||||||
let change = ((current - price_1h) / price_1h).abs();
|
let change = ((current - price_1h) / price_1h).abs();
|
||||||
if change > self.config.max_1h_change {
|
if change > self.config.max_1h_change {
|
||||||
return self.trigger_breaker(pair, TriggerReason::RapidPriceChange {
|
return self.trigger_breaker(
|
||||||
change_percent: change * Decimal::ONE_HUNDRED,
|
pair,
|
||||||
window_seconds: 3600,
|
TriggerReason::RapidPriceChange {
|
||||||
});
|
change_percent: change * Decimal::ONE_HUNDRED,
|
||||||
|
window_seconds: 3600,
|
||||||
|
},
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -318,20 +332,26 @@ impl CircuitBreakerManager {
|
||||||
if let Some(twap) = breaker.twap_24h {
|
if let Some(twap) = breaker.twap_24h {
|
||||||
let deviation = ((current - twap) / twap).abs();
|
let deviation = ((current - twap) / twap).abs();
|
||||||
if deviation > self.config.max_twap_deviation {
|
if deviation > self.config.max_twap_deviation {
|
||||||
return self.trigger_breaker(pair, TriggerReason::ExcessiveDeviation {
|
return self.trigger_breaker(
|
||||||
deviation_percent: deviation * Decimal::ONE_HUNDRED,
|
pair,
|
||||||
reference_price: twap,
|
TriggerReason::ExcessiveDeviation {
|
||||||
});
|
deviation_percent: deviation * Decimal::ONE_HUNDRED,
|
||||||
|
reference_price: twap,
|
||||||
|
},
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check liquidity
|
// Check liquidity
|
||||||
if let Some(liquidity) = breaker.current_liquidity() {
|
if let Some(liquidity) = breaker.current_liquidity() {
|
||||||
if liquidity < self.config.min_liquidity {
|
if liquidity < self.config.min_liquidity {
|
||||||
return self.trigger_breaker(pair, TriggerReason::LowLiquidity {
|
return self.trigger_breaker(
|
||||||
current: liquidity,
|
pair,
|
||||||
threshold: self.config.min_liquidity,
|
TriggerReason::LowLiquidity {
|
||||||
});
|
current: liquidity,
|
||||||
|
threshold: self.config.min_liquidity,
|
||||||
|
},
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -340,9 +360,10 @@ impl CircuitBreakerManager {
|
||||||
|
|
||||||
/// Trigger the circuit breaker
|
/// Trigger the circuit breaker
|
||||||
fn trigger_breaker(&mut self, pair: &str, reason: TriggerReason) -> Result<()> {
|
fn trigger_breaker(&mut self, pair: &str, reason: TriggerReason) -> Result<()> {
|
||||||
let breaker = self.breakers.get_mut(pair).ok_or_else(||
|
let breaker = self
|
||||||
EconomicsError::PriceFeedUnavailable(pair.to_string())
|
.breakers
|
||||||
)?;
|
.get_mut(pair)
|
||||||
|
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
|
||||||
|
|
||||||
let event = CircuitEvent {
|
let event = CircuitEvent {
|
||||||
pair: pair.to_string(),
|
pair: pair.to_string(),
|
||||||
|
|
@ -361,7 +382,10 @@ impl CircuitBreakerManager {
|
||||||
|
|
||||||
// Check cascade triggers
|
// Check cascade triggers
|
||||||
if self.config.cascade_enabled {
|
if self.config.cascade_enabled {
|
||||||
let cascades: Vec<_> = self.config.cascade_pairs.iter()
|
let cascades: Vec<_> = self
|
||||||
|
.config
|
||||||
|
.cascade_pairs
|
||||||
|
.iter()
|
||||||
.filter(|(source, _)| source == pair)
|
.filter(|(source, _)| source == pair)
|
||||||
.map(|(_, target)| target.clone())
|
.map(|(_, target)| target.clone())
|
||||||
.collect();
|
.collect();
|
||||||
|
|
@ -407,10 +431,15 @@ impl CircuitBreakerManager {
|
||||||
|
|
||||||
// Get current state first (immutable borrow)
|
// Get current state first (immutable borrow)
|
||||||
let (current_state, triggered_at, trigger_reason) = {
|
let (current_state, triggered_at, trigger_reason) = {
|
||||||
let breaker = self.breakers.get(pair).ok_or_else(||
|
let breaker = self
|
||||||
EconomicsError::PriceFeedUnavailable(pair.to_string())
|
.breakers
|
||||||
)?;
|
.get(pair)
|
||||||
(breaker.state, breaker.triggered_at, breaker.trigger_reason.clone())
|
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
|
||||||
|
(
|
||||||
|
breaker.state,
|
||||||
|
breaker.triggered_at,
|
||||||
|
breaker.trigger_reason.clone(),
|
||||||
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
// Check stability for half-open state (immutable borrow)
|
// Check stability for half-open state (immutable borrow)
|
||||||
|
|
@ -421,9 +450,10 @@ impl CircuitBreakerManager {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Now get mutable reference for updates
|
// Now get mutable reference for updates
|
||||||
let breaker = self.breakers.get_mut(pair).ok_or_else(||
|
let breaker = self
|
||||||
EconomicsError::PriceFeedUnavailable(pair.to_string())
|
.breakers
|
||||||
)?;
|
.get_mut(pair)
|
||||||
|
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
|
||||||
|
|
||||||
match current_state {
|
match current_state {
|
||||||
CircuitState::Open => {
|
CircuitState::Open => {
|
||||||
|
|
@ -435,9 +465,9 @@ impl CircuitBreakerManager {
|
||||||
pair: pair.to_string(),
|
pair: pair.to_string(),
|
||||||
from_state: CircuitState::Open,
|
from_state: CircuitState::Open,
|
||||||
to_state: CircuitState::HalfOpen,
|
to_state: CircuitState::HalfOpen,
|
||||||
reason: trigger_reason.clone().unwrap_or(
|
reason: trigger_reason.clone().unwrap_or(TriggerReason::ManualHalt {
|
||||||
TriggerReason::ManualHalt { reason: "Unknown".into() }
|
reason: "Unknown".into(),
|
||||||
),
|
}),
|
||||||
timestamp: Utc::now(),
|
timestamp: Utc::now(),
|
||||||
cooldown: None,
|
cooldown: None,
|
||||||
};
|
};
|
||||||
|
|
@ -457,9 +487,9 @@ impl CircuitBreakerManager {
|
||||||
pair: pair.to_string(),
|
pair: pair.to_string(),
|
||||||
from_state: CircuitState::HalfOpen,
|
from_state: CircuitState::HalfOpen,
|
||||||
to_state: CircuitState::Closed,
|
to_state: CircuitState::Closed,
|
||||||
reason: trigger_reason.unwrap_or(
|
reason: trigger_reason.unwrap_or(TriggerReason::ManualHalt {
|
||||||
TriggerReason::ManualHalt { reason: "Recovery".into() }
|
reason: "Recovery".into(),
|
||||||
),
|
}),
|
||||||
timestamp: Utc::now(),
|
timestamp: Utc::now(),
|
||||||
cooldown: None,
|
cooldown: None,
|
||||||
};
|
};
|
||||||
|
|
@ -482,9 +512,10 @@ impl CircuitBreakerManager {
|
||||||
|
|
||||||
/// Check if market conditions are stable
|
/// Check if market conditions are stable
|
||||||
fn is_stable(&self, pair: &str) -> Result<bool> {
|
fn is_stable(&self, pair: &str) -> Result<bool> {
|
||||||
let breaker = self.breakers.get(pair).ok_or_else(||
|
let breaker = self
|
||||||
EconomicsError::PriceFeedUnavailable(pair.to_string())
|
.breakers
|
||||||
)?;
|
.get(pair)
|
||||||
|
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
|
||||||
|
|
||||||
let current = match breaker.current_price() {
|
let current = match breaker.current_price() {
|
||||||
Some(p) => p,
|
Some(p) => p,
|
||||||
|
|
@ -511,7 +542,8 @@ impl CircuitBreakerManager {
|
||||||
|
|
||||||
/// Get current state for a pair
|
/// Get current state for a pair
|
||||||
pub fn get_state(&self, pair: &str) -> CircuitState {
|
pub fn get_state(&self, pair: &str) -> CircuitState {
|
||||||
self.breakers.get(pair)
|
self.breakers
|
||||||
|
.get(pair)
|
||||||
.map(|b| b.state)
|
.map(|b| b.state)
|
||||||
.unwrap_or(CircuitState::Closed)
|
.unwrap_or(CircuitState::Closed)
|
||||||
}
|
}
|
||||||
|
|
@ -523,25 +555,32 @@ impl CircuitBreakerManager {
|
||||||
|
|
||||||
/// Manually trigger circuit breaker
|
/// Manually trigger circuit breaker
|
||||||
pub fn manual_halt(&mut self, pair: &str, reason: impl Into<String>) -> Result<()> {
|
pub fn manual_halt(&mut self, pair: &str, reason: impl Into<String>) -> Result<()> {
|
||||||
self.breakers.entry(pair.to_string())
|
self.breakers
|
||||||
|
.entry(pair.to_string())
|
||||||
.or_insert_with(PairCircuitBreaker::new);
|
.or_insert_with(PairCircuitBreaker::new);
|
||||||
|
|
||||||
self.trigger_breaker(pair, TriggerReason::ManualHalt {
|
self.trigger_breaker(
|
||||||
reason: reason.into(),
|
pair,
|
||||||
})
|
TriggerReason::ManualHalt {
|
||||||
|
reason: reason.into(),
|
||||||
|
},
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Manually reset circuit breaker
|
/// Manually reset circuit breaker
|
||||||
pub fn manual_reset(&mut self, pair: &str) -> Result<()> {
|
pub fn manual_reset(&mut self, pair: &str) -> Result<()> {
|
||||||
let breaker = self.breakers.get_mut(pair).ok_or_else(||
|
let breaker = self
|
||||||
EconomicsError::PriceFeedUnavailable(pair.to_string())
|
.breakers
|
||||||
)?;
|
.get_mut(pair)
|
||||||
|
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
|
||||||
|
|
||||||
let event = CircuitEvent {
|
let event = CircuitEvent {
|
||||||
pair: pair.to_string(),
|
pair: pair.to_string(),
|
||||||
from_state: breaker.state,
|
from_state: breaker.state,
|
||||||
to_state: CircuitState::Closed,
|
to_state: CircuitState::Closed,
|
||||||
reason: TriggerReason::ManualHalt { reason: "Manual reset".into() },
|
reason: TriggerReason::ManualHalt {
|
||||||
|
reason: "Manual reset".into(),
|
||||||
|
},
|
||||||
timestamp: Utc::now(),
|
timestamp: Utc::now(),
|
||||||
cooldown: None,
|
cooldown: None,
|
||||||
};
|
};
|
||||||
|
|
@ -558,26 +597,32 @@ impl CircuitBreakerManager {
|
||||||
/// Record oracle disagreement
|
/// Record oracle disagreement
|
||||||
pub fn record_oracle_spread(&mut self, pair: &str, spread: Decimal) -> Result<()> {
|
pub fn record_oracle_spread(&mut self, pair: &str, spread: Decimal) -> Result<()> {
|
||||||
if spread > self.config.max_oracle_spread {
|
if spread > self.config.max_oracle_spread {
|
||||||
self.breakers.entry(pair.to_string())
|
self.breakers
|
||||||
|
.entry(pair.to_string())
|
||||||
.or_insert_with(PairCircuitBreaker::new);
|
.or_insert_with(PairCircuitBreaker::new);
|
||||||
|
|
||||||
self.trigger_breaker(pair, TriggerReason::OracleDisagreement {
|
self.trigger_breaker(
|
||||||
spread_percent: spread * Decimal::ONE_HUNDRED,
|
pair,
|
||||||
})?;
|
TriggerReason::OracleDisagreement {
|
||||||
|
spread_percent: spread * Decimal::ONE_HUNDRED,
|
||||||
|
},
|
||||||
|
)?;
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get event history for a pair
|
/// Get event history for a pair
|
||||||
pub fn get_events(&self, pair: &str) -> Vec<CircuitEvent> {
|
pub fn get_events(&self, pair: &str) -> Vec<CircuitEvent> {
|
||||||
self.breakers.get(pair)
|
self.breakers
|
||||||
|
.get(pair)
|
||||||
.map(|b| b.events.clone())
|
.map(|b| b.events.clone())
|
||||||
.unwrap_or_default()
|
.unwrap_or_default()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get all currently halted pairs
|
/// Get all currently halted pairs
|
||||||
pub fn get_halted_pairs(&self) -> Vec<(String, CircuitState, Option<TriggerReason>)> {
|
pub fn get_halted_pairs(&self) -> Vec<(String, CircuitState, Option<TriggerReason>)> {
|
||||||
self.breakers.iter()
|
self.breakers
|
||||||
|
.iter()
|
||||||
.filter(|(_, b)| b.state != CircuitState::Closed)
|
.filter(|(_, b)| b.state != CircuitState::Closed)
|
||||||
.map(|(pair, b)| (pair.clone(), b.state, b.trigger_reason.clone()))
|
.map(|(pair, b)| (pair.clone(), b.state, b.trigger_reason.clone()))
|
||||||
.collect()
|
.collect()
|
||||||
|
|
@ -586,8 +631,16 @@ impl CircuitBreakerManager {
|
||||||
/// Get summary statistics
|
/// Get summary statistics
|
||||||
pub fn get_stats(&self) -> CircuitBreakerStats {
|
pub fn get_stats(&self) -> CircuitBreakerStats {
|
||||||
let total = self.breakers.len();
|
let total = self.breakers.len();
|
||||||
let open = self.breakers.values().filter(|b| b.state == CircuitState::Open).count();
|
let open = self
|
||||||
let half_open = self.breakers.values().filter(|b| b.state == CircuitState::HalfOpen).count();
|
.breakers
|
||||||
|
.values()
|
||||||
|
.filter(|b| b.state == CircuitState::Open)
|
||||||
|
.count();
|
||||||
|
let half_open = self
|
||||||
|
.breakers
|
||||||
|
.values()
|
||||||
|
.filter(|b| b.state == CircuitState::HalfOpen)
|
||||||
|
.count();
|
||||||
let total_events: usize = self.breakers.values().map(|b| b.events.len()).sum();
|
let total_events: usize = self.breakers.values().map(|b| b.events.len()).sum();
|
||||||
|
|
||||||
CircuitBreakerStats {
|
CircuitBreakerStats {
|
||||||
|
|
@ -628,7 +681,9 @@ mod tests {
|
||||||
// Normal price movements should not trigger
|
// Normal price movements should not trigger
|
||||||
for i in 0..10 {
|
for i in 0..10 {
|
||||||
let price = dec!(100) + Decimal::from(i);
|
let price = dec!(100) + Decimal::from(i);
|
||||||
let state = manager.record_price("SYNOR/USD", price, Some(dec!(100000))).unwrap();
|
let state = manager
|
||||||
|
.record_price("SYNOR/USD", price, Some(dec!(100000)))
|
||||||
|
.unwrap();
|
||||||
assert_eq!(state, CircuitState::Closed);
|
assert_eq!(state, CircuitState::Closed);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -641,10 +696,19 @@ mod tests {
|
||||||
let now = Utc::now();
|
let now = Utc::now();
|
||||||
|
|
||||||
// Record baseline 2 minutes ago
|
// Record baseline 2 minutes ago
|
||||||
manager.record_price_at("SYNOR/USD", dec!(100), Some(dec!(100000)), now - Duration::minutes(2)).unwrap();
|
manager
|
||||||
|
.record_price_at(
|
||||||
|
"SYNOR/USD",
|
||||||
|
dec!(100),
|
||||||
|
Some(dec!(100000)),
|
||||||
|
now - Duration::minutes(2),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
// Simulate 15% drop (exceeds 10% 1-minute threshold)
|
// Simulate 15% drop (exceeds 10% 1-minute threshold)
|
||||||
let state = manager.record_price_at("SYNOR/USD", dec!(85), Some(dec!(100000)), now).unwrap();
|
let state = manager
|
||||||
|
.record_price_at("SYNOR/USD", dec!(85), Some(dec!(100000)), now)
|
||||||
|
.unwrap();
|
||||||
assert_eq!(state, CircuitState::Open);
|
assert_eq!(state, CircuitState::Open);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -653,7 +717,9 @@ mod tests {
|
||||||
let mut manager = CircuitBreakerManager::new();
|
let mut manager = CircuitBreakerManager::new();
|
||||||
|
|
||||||
// Record with very low liquidity
|
// Record with very low liquidity
|
||||||
let state = manager.record_price("SYNOR/USD", dec!(100), Some(dec!(100))).unwrap();
|
let state = manager
|
||||||
|
.record_price("SYNOR/USD", dec!(100), Some(dec!(100)))
|
||||||
|
.unwrap();
|
||||||
assert_eq!(state, CircuitState::Open);
|
assert_eq!(state, CircuitState::Open);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -661,11 +727,15 @@ mod tests {
|
||||||
fn test_manual_halt_and_reset() {
|
fn test_manual_halt_and_reset() {
|
||||||
let mut manager = CircuitBreakerManager::new();
|
let mut manager = CircuitBreakerManager::new();
|
||||||
|
|
||||||
manager.record_price("SYNOR/USD", dec!(100), Some(dec!(100000))).unwrap();
|
manager
|
||||||
|
.record_price("SYNOR/USD", dec!(100), Some(dec!(100000)))
|
||||||
|
.unwrap();
|
||||||
assert!(manager.is_trading_allowed("SYNOR/USD"));
|
assert!(manager.is_trading_allowed("SYNOR/USD"));
|
||||||
|
|
||||||
// Manual halt
|
// Manual halt
|
||||||
manager.manual_halt("SYNOR/USD", "Scheduled maintenance").unwrap();
|
manager
|
||||||
|
.manual_halt("SYNOR/USD", "Scheduled maintenance")
|
||||||
|
.unwrap();
|
||||||
assert!(!manager.is_trading_allowed("SYNOR/USD"));
|
assert!(!manager.is_trading_allowed("SYNOR/USD"));
|
||||||
|
|
||||||
// Manual reset
|
// Manual reset
|
||||||
|
|
@ -678,10 +748,14 @@ mod tests {
|
||||||
let mut manager = CircuitBreakerManager::new();
|
let mut manager = CircuitBreakerManager::new();
|
||||||
|
|
||||||
// Initialize
|
// Initialize
|
||||||
manager.record_price("SYNOR/USD", dec!(100), Some(dec!(100000))).unwrap();
|
manager
|
||||||
|
.record_price("SYNOR/USD", dec!(100), Some(dec!(100000)))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
// Record 10% spread (exceeds 5% threshold)
|
// Record 10% spread (exceeds 5% threshold)
|
||||||
manager.record_oracle_spread("SYNOR/USD", dec!(0.10)).unwrap();
|
manager
|
||||||
|
.record_oracle_spread("SYNOR/USD", dec!(0.10))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(manager.get_state("SYNOR/USD"), CircuitState::Open);
|
assert_eq!(manager.get_state("SYNOR/USD"), CircuitState::Open);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -200,7 +200,10 @@ pub struct CrossChainConfig {
|
||||||
impl Default for CrossChainConfig {
|
impl Default for CrossChainConfig {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
let mut tracked = HashMap::new();
|
let mut tracked = HashMap::new();
|
||||||
tracked.insert(ChainNetwork::Ethereum, vec!["ETH".to_string(), "USDC".to_string(), "USDT".to_string()]);
|
tracked.insert(
|
||||||
|
ChainNetwork::Ethereum,
|
||||||
|
vec!["ETH".to_string(), "USDC".to_string(), "USDT".to_string()],
|
||||||
|
);
|
||||||
tracked.insert(ChainNetwork::Bitcoin, vec!["BTC".to_string()]);
|
tracked.insert(ChainNetwork::Bitcoin, vec!["BTC".to_string()]);
|
||||||
tracked.insert(ChainNetwork::Cosmos, vec!["ATOM".to_string()]);
|
tracked.insert(ChainNetwork::Cosmos, vec!["ATOM".to_string()]);
|
||||||
tracked.insert(ChainNetwork::Osmosis, vec!["OSMO".to_string()]);
|
tracked.insert(ChainNetwork::Osmosis, vec!["OSMO".to_string()]);
|
||||||
|
|
@ -305,16 +308,21 @@ impl CrossChainOracle {
|
||||||
};
|
};
|
||||||
|
|
||||||
if !verified {
|
if !verified {
|
||||||
return Err(EconomicsError::InvalidPrice("Packet verification failed".into()));
|
return Err(EconomicsError::InvalidPrice(
|
||||||
|
"Packet verification failed".into(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cache the price
|
// Cache the price
|
||||||
let pair_key = format!("{}/{}", packet.token, packet.quote);
|
let pair_key = format!("{}/{}", packet.token, packet.quote);
|
||||||
self.cache.insert(pair_key, CrossChainPrice {
|
self.cache.insert(
|
||||||
packet,
|
pair_key,
|
||||||
received_at: Utc::now(),
|
CrossChainPrice {
|
||||||
verified,
|
packet,
|
||||||
});
|
received_at: Utc::now(),
|
||||||
|
verified,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
@ -326,7 +334,9 @@ impl CrossChainOracle {
|
||||||
token: &str,
|
token: &str,
|
||||||
quote: &str,
|
quote: &str,
|
||||||
) -> Result<CrossChainPricePacket> {
|
) -> Result<CrossChainPricePacket> {
|
||||||
let fetcher = self.fetchers.get(&chain)
|
let fetcher = self
|
||||||
|
.fetchers
|
||||||
|
.get(&chain)
|
||||||
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(format!("{:?}", chain)))?;
|
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(format!("{:?}", chain)))?;
|
||||||
|
|
||||||
let packet = fetcher.fetch_price(token, quote).await?;
|
let packet = fetcher.fetch_price(token, quote).await?;
|
||||||
|
|
@ -334,11 +344,14 @@ impl CrossChainOracle {
|
||||||
// Verify and cache
|
// Verify and cache
|
||||||
if fetcher.verify_packet(&packet) {
|
if fetcher.verify_packet(&packet) {
|
||||||
let pair_key = format!("{}/{}", token, quote);
|
let pair_key = format!("{}/{}", token, quote);
|
||||||
self.cache.insert(pair_key.clone(), CrossChainPrice {
|
self.cache.insert(
|
||||||
packet: packet.clone(),
|
pair_key.clone(),
|
||||||
received_at: Utc::now(),
|
CrossChainPrice {
|
||||||
verified: true,
|
packet: packet.clone(),
|
||||||
});
|
received_at: Utc::now(),
|
||||||
|
verified: true,
|
||||||
|
},
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(packet)
|
Ok(packet)
|
||||||
|
|
@ -347,7 +360,8 @@ impl CrossChainOracle {
|
||||||
/// Get cached price for a token pair
|
/// Get cached price for a token pair
|
||||||
pub fn get_price(&self, token: &str, quote: &str) -> Option<SynorDecimal> {
|
pub fn get_price(&self, token: &str, quote: &str) -> Option<SynorDecimal> {
|
||||||
let pair_key = format!("{}/{}", token, quote);
|
let pair_key = format!("{}/{}", token, quote);
|
||||||
self.cache.get(&pair_key)
|
self.cache
|
||||||
|
.get(&pair_key)
|
||||||
.filter(|c| c.verified)
|
.filter(|c| c.verified)
|
||||||
.filter(|c| (Utc::now() - c.received_at).num_seconds() < self.config.max_packet_age)
|
.filter(|c| (Utc::now() - c.received_at).num_seconds() < self.config.max_packet_age)
|
||||||
.map(|c| c.packet.price)
|
.map(|c| c.packet.price)
|
||||||
|
|
@ -356,7 +370,8 @@ impl CrossChainOracle {
|
||||||
/// Get price with full packet info
|
/// Get price with full packet info
|
||||||
pub fn get_price_with_info(&self, token: &str, quote: &str) -> Option<&CrossChainPricePacket> {
|
pub fn get_price_with_info(&self, token: &str, quote: &str) -> Option<&CrossChainPricePacket> {
|
||||||
let pair_key = format!("{}/{}", token, quote);
|
let pair_key = format!("{}/{}", token, quote);
|
||||||
self.cache.get(&pair_key)
|
self.cache
|
||||||
|
.get(&pair_key)
|
||||||
.filter(|c| c.verified)
|
.filter(|c| c.verified)
|
||||||
.map(|c| &c.packet)
|
.map(|c| &c.packet)
|
||||||
}
|
}
|
||||||
|
|
@ -408,7 +423,8 @@ impl CrossChainOracle {
|
||||||
|
|
||||||
/// Get all cached prices
|
/// Get all cached prices
|
||||||
pub fn get_all_prices(&self) -> Vec<TokenPrice> {
|
pub fn get_all_prices(&self) -> Vec<TokenPrice> {
|
||||||
self.cache.values()
|
self.cache
|
||||||
|
.values()
|
||||||
.filter(|c| c.verified)
|
.filter(|c| c.verified)
|
||||||
.map(|c| c.packet.to_token_price())
|
.map(|c| c.packet.to_token_price())
|
||||||
.collect()
|
.collect()
|
||||||
|
|
@ -417,9 +433,8 @@ impl CrossChainOracle {
|
||||||
/// Clear stale cache entries
|
/// Clear stale cache entries
|
||||||
pub fn cleanup_cache(&mut self) {
|
pub fn cleanup_cache(&mut self) {
|
||||||
let max_age = self.config.max_packet_age;
|
let max_age = self.config.max_packet_age;
|
||||||
self.cache.retain(|_, v| {
|
self.cache
|
||||||
(Utc::now() - v.received_at).num_seconds() < max_age
|
.retain(|_, v| (Utc::now() - v.received_at).num_seconds() < max_age);
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Send an IBC price request and track pending packet
|
/// Send an IBC price request and track pending packet
|
||||||
|
|
@ -450,7 +465,11 @@ impl CrossChainOracle {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Confirm a pending packet was received
|
/// Confirm a pending packet was received
|
||||||
pub fn confirm_pending_packet(&mut self, channel: &str, sequence: u64) -> Option<PendingPacket> {
|
pub fn confirm_pending_packet(
|
||||||
|
&mut self,
|
||||||
|
channel: &str,
|
||||||
|
sequence: u64,
|
||||||
|
) -> Option<PendingPacket> {
|
||||||
if let Some(idx) = self
|
if let Some(idx) = self
|
||||||
.pending_packets
|
.pending_packets
|
||||||
.iter()
|
.iter()
|
||||||
|
|
@ -469,9 +488,8 @@ impl CrossChainOracle {
|
||||||
|
|
||||||
/// Cleanup timed out pending packets
|
/// Cleanup timed out pending packets
|
||||||
pub fn cleanup_pending(&mut self, timeout_secs: i64) {
|
pub fn cleanup_pending(&mut self, timeout_secs: i64) {
|
||||||
self.pending_packets.retain(|p| {
|
self.pending_packets
|
||||||
(Utc::now() - p.sent_at).num_seconds() < timeout_secs
|
.retain(|p| (Utc::now() - p.sent_at).num_seconds() < timeout_secs);
|
||||||
});
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -502,9 +520,18 @@ pub struct EthereumPriceFetcher {
|
||||||
impl EthereumPriceFetcher {
|
impl EthereumPriceFetcher {
|
||||||
pub fn new(rpc_url: impl Into<String>) -> Self {
|
pub fn new(rpc_url: impl Into<String>) -> Self {
|
||||||
let mut feeds = HashMap::new();
|
let mut feeds = HashMap::new();
|
||||||
feeds.insert("ETH/USD".to_string(), "0x5f4eC3Df9cbd43714FE2740f5E3616155c5b8419".to_string());
|
feeds.insert(
|
||||||
feeds.insert("BTC/USD".to_string(), "0xF4030086522a5bEEa4988F8cA5B36dbC97BeE88c".to_string());
|
"ETH/USD".to_string(),
|
||||||
feeds.insert("USDC/USD".to_string(), "0x8fFfFfd4AfB6115b954Bd326cbe7B4BA576818f6".to_string());
|
"0x5f4eC3Df9cbd43714FE2740f5E3616155c5b8419".to_string(),
|
||||||
|
);
|
||||||
|
feeds.insert(
|
||||||
|
"BTC/USD".to_string(),
|
||||||
|
"0xF4030086522a5bEEa4988F8cA5B36dbC97BeE88c".to_string(),
|
||||||
|
);
|
||||||
|
feeds.insert(
|
||||||
|
"USDC/USD".to_string(),
|
||||||
|
"0x8fFfFfd4AfB6115b954Bd326cbe7B4BA576818f6".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
rpc_url: rpc_url.into(),
|
rpc_url: rpc_url.into(),
|
||||||
|
|
@ -526,7 +553,9 @@ impl ChainPriceFetcher for EthereumPriceFetcher {
|
||||||
|
|
||||||
async fn fetch_price(&self, token: &str, quote: &str) -> Result<CrossChainPricePacket> {
|
async fn fetch_price(&self, token: &str, quote: &str) -> Result<CrossChainPricePacket> {
|
||||||
let pair = format!("{}/{}", token, quote);
|
let pair = format!("{}/{}", token, quote);
|
||||||
let _feed_addr = self.chainlink_feeds.get(&pair)
|
let _feed_addr = self
|
||||||
|
.chainlink_feeds
|
||||||
|
.get(&pair)
|
||||||
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.clone()))?;
|
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.clone()))?;
|
||||||
|
|
||||||
// In production: Call Chainlink aggregator via ethers-rs
|
// In production: Call Chainlink aggregator via ethers-rs
|
||||||
|
|
@ -539,19 +568,16 @@ impl ChainPriceFetcher for EthereumPriceFetcher {
|
||||||
source_block: 19000000,
|
source_block: 19000000,
|
||||||
source_timestamp: Utc::now(),
|
source_timestamp: Utc::now(),
|
||||||
proof: None,
|
proof: None,
|
||||||
signatures: vec![
|
signatures: vec![OracleSignature {
|
||||||
OracleSignature {
|
signer: "chainlink".to_string(),
|
||||||
signer: "chainlink".to_string(),
|
signature: vec![0; 65],
|
||||||
signature: vec![0; 65],
|
timestamp: Utc::now(),
|
||||||
timestamp: Utc::now(),
|
}],
|
||||||
},
|
|
||||||
],
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn verify_packet(&self, packet: &CrossChainPricePacket) -> bool {
|
fn verify_packet(&self, packet: &CrossChainPricePacket) -> bool {
|
||||||
packet.source_chain == ChainNetwork::Ethereum
|
packet.source_chain == ChainNetwork::Ethereum && !packet.signatures.is_empty()
|
||||||
&& !packet.signatures.is_empty()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn supported_tokens(&self) -> Vec<String> {
|
fn supported_tokens(&self) -> Vec<String> {
|
||||||
|
|
@ -611,8 +637,7 @@ impl ChainPriceFetcher for CosmosPriceFetcher {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn verify_packet(&self, packet: &CrossChainPricePacket) -> bool {
|
fn verify_packet(&self, packet: &CrossChainPricePacket) -> bool {
|
||||||
packet.source_chain == ChainNetwork::Cosmos
|
packet.source_chain == ChainNetwork::Cosmos && packet.proof.is_some()
|
||||||
&& packet.proof.is_some()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn supported_tokens(&self) -> Vec<String> {
|
fn supported_tokens(&self) -> Vec<String> {
|
||||||
|
|
@ -650,7 +675,11 @@ impl CrossChainOracleBuilder {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add Cosmos/IBC price fetcher
|
/// Add Cosmos/IBC price fetcher
|
||||||
pub fn with_cosmos(mut self, light_client_id: impl Into<String>, chain_id: impl Into<String>) -> Self {
|
pub fn with_cosmos(
|
||||||
|
mut self,
|
||||||
|
light_client_id: impl Into<String>,
|
||||||
|
chain_id: impl Into<String>,
|
||||||
|
) -> Self {
|
||||||
self.cosmos_light_client = Some((light_client_id.into(), chain_id.into()));
|
self.cosmos_light_client = Some((light_client_id.into(), chain_id.into()));
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
@ -693,8 +722,7 @@ impl CrossChainOracleFactory {
|
||||||
|
|
||||||
/// Create a production oracle with real endpoints
|
/// Create a production oracle with real endpoints
|
||||||
pub fn production(config: CrossChainProductionConfig) -> CrossChainOracle {
|
pub fn production(config: CrossChainProductionConfig) -> CrossChainOracle {
|
||||||
let mut builder = CrossChainOracleBuilder::new()
|
let mut builder = CrossChainOracleBuilder::new().with_config(config.cross_chain_config);
|
||||||
.with_config(config.cross_chain_config);
|
|
||||||
|
|
||||||
if let Some(eth_rpc) = config.ethereum_rpc_url {
|
if let Some(eth_rpc) = config.ethereum_rpc_url {
|
||||||
builder = builder.with_ethereum(eth_rpc);
|
builder = builder.with_ethereum(eth_rpc);
|
||||||
|
|
@ -715,7 +743,10 @@ impl CrossChainOracleFactory {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create an oracle with only Cosmos/IBC support
|
/// Create an oracle with only Cosmos/IBC support
|
||||||
pub fn cosmos_only(light_client_id: impl Into<String>, chain_id: impl Into<String>) -> CrossChainOracle {
|
pub fn cosmos_only(
|
||||||
|
light_client_id: impl Into<String>,
|
||||||
|
chain_id: impl Into<String>,
|
||||||
|
) -> CrossChainOracle {
|
||||||
CrossChainOracleBuilder::new()
|
CrossChainOracleBuilder::new()
|
||||||
.with_cosmos(light_client_id, chain_id)
|
.with_cosmos(light_client_id, chain_id)
|
||||||
.build()
|
.build()
|
||||||
|
|
|
||||||
|
|
@ -112,7 +112,9 @@ impl AggregationRound {
|
||||||
/// Add a submission to this round
|
/// Add a submission to this round
|
||||||
pub fn add_submission(&mut self, submission: PriceSubmission) -> Result<()> {
|
pub fn add_submission(&mut self, submission: PriceSubmission) -> Result<()> {
|
||||||
if self.finalized {
|
if self.finalized {
|
||||||
return Err(EconomicsError::InvalidPrice("Round already finalized".into()));
|
return Err(EconomicsError::InvalidPrice(
|
||||||
|
"Round already finalized".into(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
if Utc::now() >= self.deadline {
|
if Utc::now() >= self.deadline {
|
||||||
return Err(EconomicsError::InvalidPrice("Round deadline passed".into()));
|
return Err(EconomicsError::InvalidPrice("Round deadline passed".into()));
|
||||||
|
|
@ -122,7 +124,11 @@ impl AggregationRound {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for duplicate submission from same node
|
// Check for duplicate submission from same node
|
||||||
if self.submissions.iter().any(|s| s.node_id == submission.node_id) {
|
if self
|
||||||
|
.submissions
|
||||||
|
.iter()
|
||||||
|
.any(|s| s.node_id == submission.node_id)
|
||||||
|
{
|
||||||
return Err(EconomicsError::InvalidPrice("Duplicate submission".into()));
|
return Err(EconomicsError::InvalidPrice("Duplicate submission".into()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -231,7 +237,9 @@ impl DecentralizedOracle {
|
||||||
|
|
||||||
/// Update node heartbeat
|
/// Update node heartbeat
|
||||||
pub fn heartbeat(&mut self, node_id: &str) -> Result<()> {
|
pub fn heartbeat(&mut self, node_id: &str) -> Result<()> {
|
||||||
let node = self.nodes.get_mut(node_id)
|
let node = self
|
||||||
|
.nodes
|
||||||
|
.get_mut(node_id)
|
||||||
.ok_or_else(|| EconomicsError::InvalidPrice(format!("Unknown node: {}", node_id)))?;
|
.ok_or_else(|| EconomicsError::InvalidPrice(format!("Unknown node: {}", node_id)))?;
|
||||||
node.last_heartbeat = Utc::now();
|
node.last_heartbeat = Utc::now();
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
@ -241,11 +249,8 @@ impl DecentralizedOracle {
|
||||||
pub fn start_round(&mut self, pair: impl Into<String>) -> u64 {
|
pub fn start_round(&mut self, pair: impl Into<String>) -> u64 {
|
||||||
let pair = pair.into();
|
let pair = pair.into();
|
||||||
self.round_counter += 1;
|
self.round_counter += 1;
|
||||||
let round = AggregationRound::new(
|
let round =
|
||||||
self.round_counter,
|
AggregationRound::new(self.round_counter, pair.clone(), self.config.round_duration);
|
||||||
pair.clone(),
|
|
||||||
self.config.round_duration
|
|
||||||
);
|
|
||||||
self.current_rounds.insert(pair, round);
|
self.current_rounds.insert(pair, round);
|
||||||
self.round_counter
|
self.round_counter
|
||||||
}
|
}
|
||||||
|
|
@ -253,7 +258,9 @@ impl DecentralizedOracle {
|
||||||
/// Submit a price for the current round
|
/// Submit a price for the current round
|
||||||
pub fn submit_price(&mut self, submission: PriceSubmission) -> Result<()> {
|
pub fn submit_price(&mut self, submission: PriceSubmission) -> Result<()> {
|
||||||
// Verify node exists and is eligible
|
// Verify node exists and is eligible
|
||||||
let node = self.nodes.get(&submission.node_id)
|
let node = self
|
||||||
|
.nodes
|
||||||
|
.get(&submission.node_id)
|
||||||
.ok_or_else(|| EconomicsError::InvalidPrice("Unknown node".into()))?;
|
.ok_or_else(|| EconomicsError::InvalidPrice("Unknown node".into()))?;
|
||||||
|
|
||||||
if !node.is_eligible(self.config.min_stake, self.config.min_reputation) {
|
if !node.is_eligible(self.config.min_stake, self.config.min_reputation) {
|
||||||
|
|
@ -264,7 +271,9 @@ impl DecentralizedOracle {
|
||||||
// For now, we trust the submission
|
// For now, we trust the submission
|
||||||
|
|
||||||
// Add to current round
|
// Add to current round
|
||||||
let round = self.current_rounds.get_mut(&submission.pair)
|
let round = self
|
||||||
|
.current_rounds
|
||||||
|
.get_mut(&submission.pair)
|
||||||
.ok_or_else(|| EconomicsError::InvalidPrice("No active round for pair".into()))?;
|
.ok_or_else(|| EconomicsError::InvalidPrice("No active round for pair".into()))?;
|
||||||
|
|
||||||
round.add_submission(submission)
|
round.add_submission(submission)
|
||||||
|
|
@ -274,13 +283,15 @@ impl DecentralizedOracle {
|
||||||
pub fn finalize_round(&mut self, pair: &str) -> Result<SynorDecimal> {
|
pub fn finalize_round(&mut self, pair: &str) -> Result<SynorDecimal> {
|
||||||
// First check state and get submissions (immutable borrow)
|
// First check state and get submissions (immutable borrow)
|
||||||
let (is_finalized, existing_price, submissions) = {
|
let (is_finalized, existing_price, submissions) = {
|
||||||
let round = self.current_rounds.get(pair)
|
let round = self
|
||||||
|
.current_rounds
|
||||||
|
.get(pair)
|
||||||
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
|
.ok_or_else(|| EconomicsError::PriceFeedUnavailable(pair.to_string()))?;
|
||||||
|
|
||||||
if round.finalized {
|
if round.finalized {
|
||||||
return round.final_price.ok_or_else(||
|
return round
|
||||||
EconomicsError::InvalidPrice("Round has no price".into())
|
.final_price
|
||||||
);
|
.ok_or_else(|| EconomicsError::InvalidPrice("Round has no price".into()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check minimum submissions
|
// Check minimum submissions
|
||||||
|
|
@ -292,13 +303,16 @@ impl DecentralizedOracle {
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
(round.finalized, round.final_price, round.submissions.clone())
|
(
|
||||||
|
round.finalized,
|
||||||
|
round.final_price,
|
||||||
|
round.submissions.clone(),
|
||||||
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
if is_finalized {
|
if is_finalized {
|
||||||
return existing_price.ok_or_else(||
|
return existing_price
|
||||||
EconomicsError::InvalidPrice("Round has no price".into())
|
.ok_or_else(|| EconomicsError::InvalidPrice("Round has no price".into()));
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter outliers and aggregate (using cloned submissions)
|
// Filter outliers and aggregate (using cloned submissions)
|
||||||
|
|
@ -326,7 +340,11 @@ impl DecentralizedOracle {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Aggregate prices from a vector of submissions (owned)
|
/// Aggregate prices from a vector of submissions (owned)
|
||||||
fn aggregate_prices_from_vec(&self, pair: &str, submissions: &[PriceSubmission]) -> Result<SynorDecimal> {
|
fn aggregate_prices_from_vec(
|
||||||
|
&self,
|
||||||
|
pair: &str,
|
||||||
|
submissions: &[PriceSubmission],
|
||||||
|
) -> Result<SynorDecimal> {
|
||||||
if submissions.is_empty() {
|
if submissions.is_empty() {
|
||||||
return Err(EconomicsError::PriceFeedUnavailable(pair.to_string()));
|
return Err(EconomicsError::PriceFeedUnavailable(pair.to_string()));
|
||||||
}
|
}
|
||||||
|
|
@ -334,15 +352,21 @@ impl DecentralizedOracle {
|
||||||
// Filter outliers first
|
// Filter outliers first
|
||||||
let filtered = self.filter_outliers_vec(submissions);
|
let filtered = self.filter_outliers_vec(submissions);
|
||||||
if filtered.is_empty() {
|
if filtered.is_empty() {
|
||||||
return Err(EconomicsError::InvalidPrice("All submissions were outliers".into()));
|
return Err(EconomicsError::InvalidPrice(
|
||||||
|
"All submissions were outliers".into(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let filtered_refs: Vec<_> = filtered.iter().collect();
|
let filtered_refs: Vec<_> = filtered.iter().collect();
|
||||||
match self.strategy {
|
match self.strategy {
|
||||||
AggregationStrategy::Median => self.calculate_median(&filtered_refs),
|
AggregationStrategy::Median => self.calculate_median(&filtered_refs),
|
||||||
AggregationStrategy::StakeWeightedMedian => self.calculate_stake_weighted_median(&filtered_refs),
|
AggregationStrategy::StakeWeightedMedian => {
|
||||||
|
self.calculate_stake_weighted_median(&filtered_refs)
|
||||||
|
}
|
||||||
AggregationStrategy::TrimmedMean => self.calculate_trimmed_mean(&filtered_refs),
|
AggregationStrategy::TrimmedMean => self.calculate_trimmed_mean(&filtered_refs),
|
||||||
AggregationStrategy::ReputationWeighted => self.calculate_reputation_weighted(&filtered_refs),
|
AggregationStrategy::ReputationWeighted => {
|
||||||
|
self.calculate_reputation_weighted(&filtered_refs)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -376,13 +400,14 @@ impl DecentralizedOracle {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Calculate stake-weighted median
|
/// Calculate stake-weighted median
|
||||||
fn calculate_stake_weighted_median(&self, submissions: &[&PriceSubmission]) -> Result<SynorDecimal> {
|
fn calculate_stake_weighted_median(
|
||||||
|
&self,
|
||||||
|
submissions: &[&PriceSubmission],
|
||||||
|
) -> Result<SynorDecimal> {
|
||||||
// Get stake for each submission
|
// Get stake for each submission
|
||||||
let mut weighted: Vec<(SynorDecimal, SynorDecimal)> = submissions
|
let mut weighted: Vec<(SynorDecimal, SynorDecimal)> = submissions
|
||||||
.iter()
|
.iter()
|
||||||
.filter_map(|s| {
|
.filter_map(|s| self.nodes.get(&s.node_id).map(|n| (s.price, n.stake)))
|
||||||
self.nodes.get(&s.node_id).map(|n| (s.price, n.stake))
|
|
||||||
})
|
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
if weighted.is_empty() {
|
if weighted.is_empty() {
|
||||||
|
|
@ -424,12 +449,17 @@ impl DecentralizedOracle {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Calculate reputation-weighted average
|
/// Calculate reputation-weighted average
|
||||||
fn calculate_reputation_weighted(&self, submissions: &[&PriceSubmission]) -> Result<SynorDecimal> {
|
fn calculate_reputation_weighted(
|
||||||
|
&self,
|
||||||
|
submissions: &[&PriceSubmission],
|
||||||
|
) -> Result<SynorDecimal> {
|
||||||
let mut weighted_sum = Decimal::ZERO;
|
let mut weighted_sum = Decimal::ZERO;
|
||||||
let mut total_weight = Decimal::ZERO;
|
let mut total_weight = Decimal::ZERO;
|
||||||
|
|
||||||
for sub in submissions {
|
for sub in submissions {
|
||||||
let reputation = self.nodes.get(&sub.node_id)
|
let reputation = self
|
||||||
|
.nodes
|
||||||
|
.get(&sub.node_id)
|
||||||
.map(|n| n.reputation)
|
.map(|n| n.reputation)
|
||||||
.unwrap_or(0.5);
|
.unwrap_or(0.5);
|
||||||
|
|
||||||
|
|
@ -448,7 +478,9 @@ impl DecentralizedOracle {
|
||||||
/// Update node reputations based on submission accuracy
|
/// Update node reputations based on submission accuracy
|
||||||
fn update_reputations(&mut self, _pair: &str, final_price: SynorDecimal) {
|
fn update_reputations(&mut self, _pair: &str, final_price: SynorDecimal) {
|
||||||
// Get submissions from current round before it was moved
|
// Get submissions from current round before it was moved
|
||||||
let submissions: Vec<_> = self.history.last()
|
let submissions: Vec<_> = self
|
||||||
|
.history
|
||||||
|
.last()
|
||||||
.map(|r| r.submissions.clone())
|
.map(|r| r.submissions.clone())
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
|
@ -457,7 +489,8 @@ impl DecentralizedOracle {
|
||||||
let deviation = (sub.price - final_price).abs() / final_price;
|
let deviation = (sub.price - final_price).abs() / final_price;
|
||||||
|
|
||||||
// Increase reputation for accurate submissions, decrease for inaccurate
|
// Increase reputation for accurate submissions, decrease for inaccurate
|
||||||
if deviation <= Decimal::new(1, 2) { // Within 1%
|
if deviation <= Decimal::new(1, 2) {
|
||||||
|
// Within 1%
|
||||||
node.reputation = (node.reputation + 0.01).min(1.0);
|
node.reputation = (node.reputation + 0.01).min(1.0);
|
||||||
} else if deviation > self.config.max_deviation {
|
} else if deviation > self.config.max_deviation {
|
||||||
node.reputation = (node.reputation - 0.05).max(0.0);
|
node.reputation = (node.reputation - 0.05).max(0.0);
|
||||||
|
|
@ -473,7 +506,8 @@ impl DecentralizedOracle {
|
||||||
|
|
||||||
/// Get number of active nodes
|
/// Get number of active nodes
|
||||||
pub fn active_node_count(&self) -> usize {
|
pub fn active_node_count(&self) -> usize {
|
||||||
self.nodes.values()
|
self.nodes
|
||||||
|
.values()
|
||||||
.filter(|n| n.is_eligible(self.config.min_stake, self.config.min_reputation))
|
.filter(|n| n.is_eligible(self.config.min_stake, self.config.min_reputation))
|
||||||
.count()
|
.count()
|
||||||
}
|
}
|
||||||
|
|
@ -492,20 +526,23 @@ impl DecentralizedOracle {
|
||||||
|
|
||||||
/// Convert finalized price to TokenPrice
|
/// Convert finalized price to TokenPrice
|
||||||
pub fn to_token_price(&self, pair: &str) -> Option<TokenPrice> {
|
pub fn to_token_price(&self, pair: &str) -> Option<TokenPrice> {
|
||||||
self.history.iter()
|
self.history
|
||||||
|
.iter()
|
||||||
.rev()
|
.rev()
|
||||||
.find(|r| r.pair == pair && r.finalized)
|
.find(|r| r.pair == pair && r.finalized)
|
||||||
.and_then(|r| r.final_price.map(|price| {
|
.and_then(|r| {
|
||||||
let parts: Vec<_> = pair.split('/').collect();
|
r.final_price.map(|price| {
|
||||||
TokenPrice {
|
let parts: Vec<_> = pair.split('/').collect();
|
||||||
token: parts.get(0).unwrap_or(&"").to_string(),
|
TokenPrice {
|
||||||
quote: parts.get(1).unwrap_or(&"").to_string(),
|
token: parts.get(0).unwrap_or(&"").to_string(),
|
||||||
price,
|
quote: parts.get(1).unwrap_or(&"").to_string(),
|
||||||
timestamp: r.deadline,
|
price,
|
||||||
source: PriceSource::Aggregated,
|
timestamp: r.deadline,
|
||||||
confidence: 1.0,
|
source: PriceSource::Aggregated,
|
||||||
}
|
confidence: 1.0,
|
||||||
}))
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -521,13 +558,15 @@ mod tests {
|
||||||
use rust_decimal_macros::dec;
|
use rust_decimal_macros::dec;
|
||||||
|
|
||||||
fn create_test_nodes() -> Vec<OracleNode> {
|
fn create_test_nodes() -> Vec<OracleNode> {
|
||||||
(0..5).map(|i| {
|
(0..5)
|
||||||
OracleNode::new(
|
.map(|i| {
|
||||||
format!("node_{}", i),
|
OracleNode::new(
|
||||||
vec![i as u8; 32],
|
format!("node_{}", i),
|
||||||
dec!(10000), // 10k stake
|
vec![i as u8; 32],
|
||||||
)
|
dec!(10000), // 10k stake
|
||||||
}).collect()
|
)
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,8 @@
|
||||||
use crate::error::{EconomicsError, Result};
|
use crate::error::{EconomicsError, Result};
|
||||||
use crate::SynorDecimal;
|
use crate::SynorDecimal;
|
||||||
use chrono::{DateTime, Duration, Timelike, Utc};
|
use chrono::{DateTime, Duration, Timelike, Utc};
|
||||||
use rust_decimal::Decimal;
|
|
||||||
use rust_decimal::prelude::ToPrimitive;
|
use rust_decimal::prelude::ToPrimitive;
|
||||||
|
use rust_decimal::Decimal;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::f64::consts::PI;
|
use std::f64::consts::PI;
|
||||||
|
|
@ -178,7 +178,12 @@ impl BlackScholes {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Price a European option
|
/// Price a European option
|
||||||
pub fn price(&self, contract: &OptionContract, spot: SynorDecimal, vol: f64) -> Result<OptionPricing> {
|
pub fn price(
|
||||||
|
&self,
|
||||||
|
contract: &OptionContract,
|
||||||
|
spot: SynorDecimal,
|
||||||
|
vol: f64,
|
||||||
|
) -> Result<OptionPricing> {
|
||||||
if contract.is_expired() {
|
if contract.is_expired() {
|
||||||
// At expiration, option is worth intrinsic value
|
// At expiration, option is worth intrinsic value
|
||||||
let intrinsic = contract.intrinsic_value(spot);
|
let intrinsic = contract.intrinsic_value(spot);
|
||||||
|
|
@ -208,12 +213,13 @@ impl BlackScholes {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
let s = spot.to_f64().ok_or_else(||
|
let s = spot
|
||||||
EconomicsError::InvalidPrice("Invalid spot price".into())
|
.to_f64()
|
||||||
)?;
|
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid spot price".into()))?;
|
||||||
let k = contract.strike.to_f64().ok_or_else(||
|
let k = contract
|
||||||
EconomicsError::InvalidPrice("Invalid strike price".into())
|
.strike
|
||||||
)?;
|
.to_f64()
|
||||||
|
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid strike price".into()))?;
|
||||||
let t = contract.time_to_expiry();
|
let t = contract.time_to_expiry();
|
||||||
|
|
||||||
if t <= 0.0 || vol <= 0.0 {
|
if t <= 0.0 || vol <= 0.0 {
|
||||||
|
|
@ -286,13 +292,10 @@ impl BlackScholes {
|
||||||
let theta_common = -(s * vol * (-q * t).exp() * n_prime_d1) / (2.0 * sqrt_t);
|
let theta_common = -(s * vol * (-q * t).exp() * n_prime_d1) / (2.0 * sqrt_t);
|
||||||
let theta = match contract.option_type {
|
let theta = match contract.option_type {
|
||||||
OptionType::Call => {
|
OptionType::Call => {
|
||||||
theta_common
|
theta_common + q * s * (-q * t).exp() * n_d1 - r * k * (-r * t).exp() * n_d2
|
||||||
+ q * s * (-q * t).exp() * n_d1
|
|
||||||
- r * k * (-r * t).exp() * n_d2
|
|
||||||
}
|
}
|
||||||
OptionType::Put => {
|
OptionType::Put => {
|
||||||
theta_common
|
theta_common - q * s * (-q * t).exp() * (1.0 - n_d1)
|
||||||
- q * s * (-q * t).exp() * (1.0 - n_d1)
|
|
||||||
+ r * k * (-r * t).exp() * (1.0 - n_d2)
|
+ r * k * (-r * t).exp() * (1.0 - n_d2)
|
||||||
}
|
}
|
||||||
} / 365.0; // Per day
|
} / 365.0; // Per day
|
||||||
|
|
@ -327,16 +330,16 @@ impl BlackScholes {
|
||||||
spot: SynorDecimal,
|
spot: SynorDecimal,
|
||||||
market_price: SynorDecimal,
|
market_price: SynorDecimal,
|
||||||
) -> Result<f64> {
|
) -> Result<f64> {
|
||||||
let target = market_price.to_f64().ok_or_else(||
|
let target = market_price
|
||||||
EconomicsError::InvalidPrice("Invalid market price".into())
|
.to_f64()
|
||||||
)?;
|
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid market price".into()))?;
|
||||||
|
|
||||||
// Initial guess based on time value
|
// Initial guess based on time value
|
||||||
let intrinsic = contract.intrinsic_value(spot).to_f64().unwrap_or(0.0);
|
let intrinsic = contract.intrinsic_value(spot).to_f64().unwrap_or(0.0);
|
||||||
let time_value = (target - intrinsic).max(0.0);
|
let time_value = (target - intrinsic).max(0.0);
|
||||||
let s = spot.to_f64().ok_or_else(||
|
let s = spot
|
||||||
EconomicsError::InvalidPrice("Invalid spot".into())
|
.to_f64()
|
||||||
)?;
|
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid spot".into()))?;
|
||||||
let t = contract.time_to_expiry();
|
let t = contract.time_to_expiry();
|
||||||
|
|
||||||
// Brenner-Subrahmanyam approximation for initial guess
|
// Brenner-Subrahmanyam approximation for initial guess
|
||||||
|
|
@ -381,7 +384,11 @@ impl BlackScholes {
|
||||||
|
|
||||||
for _ in 0..100 {
|
for _ in 0..100 {
|
||||||
let mid = (low + high) / 2.0;
|
let mid = (low + high) / 2.0;
|
||||||
let price = self.price(contract, spot, mid)?.price.to_f64().unwrap_or(0.0);
|
let price = self
|
||||||
|
.price(contract, spot, mid)?
|
||||||
|
.price
|
||||||
|
.to_f64()
|
||||||
|
.unwrap_or(0.0);
|
||||||
|
|
||||||
if (price - target).abs() < 0.0001 {
|
if (price - target).abs() < 0.0001 {
|
||||||
return Ok(mid);
|
return Ok(mid);
|
||||||
|
|
@ -493,9 +500,9 @@ impl FuturesModel {
|
||||||
/// F = S * e^((r + u - y) * T)
|
/// F = S * e^((r + u - y) * T)
|
||||||
/// where r = risk-free rate, u = storage cost, y = convenience yield
|
/// where r = risk-free rate, u = storage cost, y = convenience yield
|
||||||
pub fn price(&self, contract: &FuturesContract, spot: SynorDecimal) -> Result<FuturesPricing> {
|
pub fn price(&self, contract: &FuturesContract, spot: SynorDecimal) -> Result<FuturesPricing> {
|
||||||
let s = spot.to_f64().ok_or_else(||
|
let s = spot
|
||||||
EconomicsError::InvalidPrice("Invalid spot price".into())
|
.to_f64()
|
||||||
)?;
|
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid spot price".into()))?;
|
||||||
|
|
||||||
let t = contract.time_to_expiry();
|
let t = contract.time_to_expiry();
|
||||||
if t < 0.0 {
|
if t < 0.0 {
|
||||||
|
|
@ -515,8 +522,7 @@ impl FuturesModel {
|
||||||
0.0
|
0.0
|
||||||
};
|
};
|
||||||
|
|
||||||
let coc = Decimal::from_f64_retain(cost_of_carry * t * s)
|
let coc = Decimal::from_f64_retain(cost_of_carry * t * s).unwrap_or(Decimal::ZERO);
|
||||||
.unwrap_or(Decimal::ZERO);
|
|
||||||
|
|
||||||
Ok(FuturesPricing {
|
Ok(FuturesPricing {
|
||||||
fair_value,
|
fair_value,
|
||||||
|
|
@ -529,13 +535,18 @@ impl FuturesModel {
|
||||||
|
|
||||||
/// Calculate implied repo rate from futures price
|
/// Calculate implied repo rate from futures price
|
||||||
/// R = (F/S - 1) / T
|
/// R = (F/S - 1) / T
|
||||||
pub fn implied_repo_rate(&self, contract: &FuturesContract, spot: SynorDecimal, futures_price: SynorDecimal) -> Result<f64> {
|
pub fn implied_repo_rate(
|
||||||
let s = spot.to_f64().ok_or_else(||
|
&self,
|
||||||
EconomicsError::InvalidPrice("Invalid spot".into())
|
contract: &FuturesContract,
|
||||||
)?;
|
spot: SynorDecimal,
|
||||||
let f = futures_price.to_f64().ok_or_else(||
|
futures_price: SynorDecimal,
|
||||||
EconomicsError::InvalidPrice("Invalid futures price".into())
|
) -> Result<f64> {
|
||||||
)?;
|
let s = spot
|
||||||
|
.to_f64()
|
||||||
|
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid spot".into()))?;
|
||||||
|
let f = futures_price
|
||||||
|
.to_f64()
|
||||||
|
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid futures price".into()))?;
|
||||||
let t = contract.time_to_expiry();
|
let t = contract.time_to_expiry();
|
||||||
|
|
||||||
if t <= 0.0 || s <= 0.0 {
|
if t <= 0.0 || s <= 0.0 {
|
||||||
|
|
@ -595,12 +606,12 @@ impl PerpetualModel {
|
||||||
mark_price: SynorDecimal,
|
mark_price: SynorDecimal,
|
||||||
index_price: SynorDecimal,
|
index_price: SynorDecimal,
|
||||||
) -> Result<f64> {
|
) -> Result<f64> {
|
||||||
let mark = mark_price.to_f64().ok_or_else(||
|
let mark = mark_price
|
||||||
EconomicsError::InvalidPrice("Invalid mark price".into())
|
.to_f64()
|
||||||
)?;
|
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid mark price".into()))?;
|
||||||
let index = index_price.to_f64().ok_or_else(||
|
let index = index_price
|
||||||
EconomicsError::InvalidPrice("Invalid index price".into())
|
.to_f64()
|
||||||
)?;
|
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid index price".into()))?;
|
||||||
|
|
||||||
if index <= 0.0 {
|
if index <= 0.0 {
|
||||||
return Err(EconomicsError::InvalidPrice("Invalid index".into()));
|
return Err(EconomicsError::InvalidPrice("Invalid index".into()));
|
||||||
|
|
@ -640,7 +651,8 @@ impl PerpetualModel {
|
||||||
let hours_since_midnight = now.time().hour();
|
let hours_since_midnight = now.time().hour();
|
||||||
let next_funding_hour = ((hours_since_midnight / self.funding_interval_hours) + 1)
|
let next_funding_hour = ((hours_since_midnight / self.funding_interval_hours) + 1)
|
||||||
* self.funding_interval_hours;
|
* self.funding_interval_hours;
|
||||||
let next_funding = now.date_naive()
|
let next_funding = now
|
||||||
|
.date_naive()
|
||||||
.and_hms_opt(next_funding_hour % 24, 0, 0)
|
.and_hms_opt(next_funding_hour % 24, 0, 0)
|
||||||
.map(|dt| DateTime::from_naive_utc_and_offset(dt, Utc))
|
.map(|dt| DateTime::from_naive_utc_and_offset(dt, Utc))
|
||||||
.unwrap_or(now + Duration::hours(self.funding_interval_hours as i64));
|
.unwrap_or(now + Duration::hours(self.funding_interval_hours as i64));
|
||||||
|
|
@ -721,7 +733,8 @@ impl DerivativesOracle {
|
||||||
|
|
||||||
/// Set volatility surface for an underlying
|
/// Set volatility surface for an underlying
|
||||||
pub fn set_vol_surface(&mut self, surface: VolatilitySurface) {
|
pub fn set_vol_surface(&mut self, surface: VolatilitySurface) {
|
||||||
self.vol_surfaces.insert(surface.underlying.clone(), surface);
|
self.vol_surfaces
|
||||||
|
.insert(surface.underlying.clone(), surface);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Price an option using the volatility surface
|
/// Price an option using the volatility surface
|
||||||
|
|
@ -730,12 +743,13 @@ impl DerivativesOracle {
|
||||||
contract: &OptionContract,
|
contract: &OptionContract,
|
||||||
spot: SynorDecimal,
|
spot: SynorDecimal,
|
||||||
) -> Result<OptionPricing> {
|
) -> Result<OptionPricing> {
|
||||||
let s = spot.to_f64().ok_or_else(||
|
let s = spot
|
||||||
EconomicsError::InvalidPrice("Invalid spot".into())
|
.to_f64()
|
||||||
)?;
|
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid spot".into()))?;
|
||||||
let k = contract.strike.to_f64().ok_or_else(||
|
let k = contract
|
||||||
EconomicsError::InvalidPrice("Invalid strike".into())
|
.strike
|
||||||
)?;
|
.to_f64()
|
||||||
|
.ok_or_else(|| EconomicsError::InvalidPrice("Invalid strike".into()))?;
|
||||||
|
|
||||||
// Get volatility from surface or use default
|
// Get volatility from surface or use default
|
||||||
let vol = if let Some(surface) = self.vol_surfaces.get(&contract.underlying) {
|
let vol = if let Some(surface) = self.vol_surfaces.get(&contract.underlying) {
|
||||||
|
|
@ -765,7 +779,8 @@ impl DerivativesOracle {
|
||||||
mark_price: SynorDecimal,
|
mark_price: SynorDecimal,
|
||||||
open_interest: SynorDecimal,
|
open_interest: SynorDecimal,
|
||||||
) -> Result<PerpetualPricing> {
|
) -> Result<PerpetualPricing> {
|
||||||
self.perpetual_model.price(index_price, mark_price, open_interest)
|
self.perpetual_model
|
||||||
|
.price(index_price, mark_price, open_interest)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Calculate implied vol from market price
|
/// Calculate implied vol from market price
|
||||||
|
|
@ -775,7 +790,8 @@ impl DerivativesOracle {
|
||||||
spot: SynorDecimal,
|
spot: SynorDecimal,
|
||||||
market_price: SynorDecimal,
|
market_price: SynorDecimal,
|
||||||
) -> Result<f64> {
|
) -> Result<f64> {
|
||||||
self.options_model.implied_volatility(contract, spot, market_price)
|
self.options_model
|
||||||
|
.implied_volatility(contract, spot, market_price)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -792,8 +808,18 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_option_intrinsic_value() {
|
fn test_option_intrinsic_value() {
|
||||||
let call = OptionContract::new("ETH", dec!(2000), Utc::now() + Duration::days(30), OptionType::Call);
|
let call = OptionContract::new(
|
||||||
let put = OptionContract::new("ETH", dec!(2000), Utc::now() + Duration::days(30), OptionType::Put);
|
"ETH",
|
||||||
|
dec!(2000),
|
||||||
|
Utc::now() + Duration::days(30),
|
||||||
|
OptionType::Call,
|
||||||
|
);
|
||||||
|
let put = OptionContract::new(
|
||||||
|
"ETH",
|
||||||
|
dec!(2000),
|
||||||
|
Utc::now() + Duration::days(30),
|
||||||
|
OptionType::Put,
|
||||||
|
);
|
||||||
|
|
||||||
// ITM call
|
// ITM call
|
||||||
assert_eq!(call.intrinsic_value(dec!(2100)), dec!(100));
|
assert_eq!(call.intrinsic_value(dec!(2100)), dec!(100));
|
||||||
|
|
@ -880,7 +906,9 @@ mod tests {
|
||||||
let pricing = model.price(&contract, dec!(2000), vol).unwrap();
|
let pricing = model.price(&contract, dec!(2000), vol).unwrap();
|
||||||
|
|
||||||
// Calculate IV from that price
|
// Calculate IV from that price
|
||||||
let iv = model.implied_volatility(&contract, dec!(2000), pricing.price).unwrap();
|
let iv = model
|
||||||
|
.implied_volatility(&contract, dec!(2000), pricing.price)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
// Should match original vol
|
// Should match original vol
|
||||||
assert!((iv - vol).abs() < 0.01);
|
assert!((iv - vol).abs() < 0.01);
|
||||||
|
|
|
||||||
|
|
@ -40,12 +40,12 @@ impl CollateralAsset {
|
||||||
pub fn standard(symbol: impl Into<String>) -> Self {
|
pub fn standard(symbol: impl Into<String>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
symbol: symbol.into(),
|
symbol: symbol.into(),
|
||||||
collateral_factor: Decimal::new(75, 2), // 75%
|
collateral_factor: Decimal::new(75, 2), // 75%
|
||||||
liquidation_threshold: Decimal::new(80, 2), // 80%
|
liquidation_threshold: Decimal::new(80, 2), // 80%
|
||||||
liquidation_bonus: Decimal::new(5, 2), // 5%
|
liquidation_bonus: Decimal::new(5, 2), // 5%
|
||||||
supply_cap: None,
|
supply_cap: None,
|
||||||
borrow_enabled: true,
|
borrow_enabled: true,
|
||||||
reserve_factor: Decimal::new(10, 2), // 10%
|
reserve_factor: Decimal::new(10, 2), // 10%
|
||||||
volatility_multiplier: Decimal::ONE,
|
volatility_multiplier: Decimal::ONE,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -54,12 +54,12 @@ impl CollateralAsset {
|
||||||
pub fn stablecoin(symbol: impl Into<String>) -> Self {
|
pub fn stablecoin(symbol: impl Into<String>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
symbol: symbol.into(),
|
symbol: symbol.into(),
|
||||||
collateral_factor: Decimal::new(90, 2), // 90%
|
collateral_factor: Decimal::new(90, 2), // 90%
|
||||||
liquidation_threshold: Decimal::new(95, 2), // 95%
|
liquidation_threshold: Decimal::new(95, 2), // 95%
|
||||||
liquidation_bonus: Decimal::new(2, 2), // 2%
|
liquidation_bonus: Decimal::new(2, 2), // 2%
|
||||||
supply_cap: None,
|
supply_cap: None,
|
||||||
borrow_enabled: true,
|
borrow_enabled: true,
|
||||||
reserve_factor: Decimal::new(5, 2), // 5%
|
reserve_factor: Decimal::new(5, 2), // 5%
|
||||||
volatility_multiplier: Decimal::ONE,
|
volatility_multiplier: Decimal::ONE,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -68,13 +68,13 @@ impl CollateralAsset {
|
||||||
pub fn volatile(symbol: impl Into<String>) -> Self {
|
pub fn volatile(symbol: impl Into<String>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
symbol: symbol.into(),
|
symbol: symbol.into(),
|
||||||
collateral_factor: Decimal::new(50, 2), // 50%
|
collateral_factor: Decimal::new(50, 2), // 50%
|
||||||
liquidation_threshold: Decimal::new(65, 2), // 65%
|
liquidation_threshold: Decimal::new(65, 2), // 65%
|
||||||
liquidation_bonus: Decimal::new(10, 2), // 10%
|
liquidation_bonus: Decimal::new(10, 2), // 10%
|
||||||
supply_cap: None,
|
supply_cap: None,
|
||||||
borrow_enabled: true,
|
borrow_enabled: true,
|
||||||
reserve_factor: Decimal::new(20, 2), // 20%
|
reserve_factor: Decimal::new(20, 2), // 20%
|
||||||
volatility_multiplier: Decimal::new(12, 1), // 1.2x
|
volatility_multiplier: Decimal::new(12, 1), // 1.2x
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -131,7 +131,11 @@ impl LendingPosition {
|
||||||
/// Withdraw collateral
|
/// Withdraw collateral
|
||||||
pub fn withdraw(&mut self, asset: impl Into<String>, amount: SynorDecimal) -> Result<()> {
|
pub fn withdraw(&mut self, asset: impl Into<String>, amount: SynorDecimal) -> Result<()> {
|
||||||
let asset = asset.into();
|
let asset = asset.into();
|
||||||
let current = self.collateral.get(&asset).copied().unwrap_or(Decimal::ZERO);
|
let current = self
|
||||||
|
.collateral
|
||||||
|
.get(&asset)
|
||||||
|
.copied()
|
||||||
|
.unwrap_or(Decimal::ZERO);
|
||||||
if amount > current {
|
if amount > current {
|
||||||
return Err(EconomicsError::InsufficientFunds {
|
return Err(EconomicsError::InsufficientFunds {
|
||||||
required: amount,
|
required: amount,
|
||||||
|
|
@ -233,10 +237,10 @@ pub struct LiquidationOracleConfig {
|
||||||
impl Default for LiquidationOracleConfig {
|
impl Default for LiquidationOracleConfig {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
max_price_age: 60, // 1 minute (stricter than general oracle)
|
max_price_age: 60, // 1 minute (stricter than general oracle)
|
||||||
min_confidence: 0.9,
|
min_confidence: 0.9,
|
||||||
min_sources: 2,
|
min_sources: 2,
|
||||||
liquidation_grace_period: 300, // 5 minutes
|
liquidation_grace_period: 300, // 5 minutes
|
||||||
min_liquidation_amount: Decimal::new(10, 0), // $10
|
min_liquidation_amount: Decimal::new(10, 0), // $10
|
||||||
max_liquidation_pct: Decimal::new(50, 2), // 50% at a time
|
max_liquidation_pct: Decimal::new(50, 2), // 50% at a time
|
||||||
partial_liquidation: true,
|
partial_liquidation: true,
|
||||||
|
|
@ -289,7 +293,8 @@ impl LiquidationOracle {
|
||||||
/// Create a new position
|
/// Create a new position
|
||||||
pub fn create_position(&mut self, account_id: impl Into<String>) -> &mut LendingPosition {
|
pub fn create_position(&mut self, account_id: impl Into<String>) -> &mut LendingPosition {
|
||||||
let account_id = account_id.into();
|
let account_id = account_id.into();
|
||||||
self.positions.entry(account_id.clone())
|
self.positions
|
||||||
|
.entry(account_id.clone())
|
||||||
.or_insert_with(|| LendingPosition::new(account_id))
|
.or_insert_with(|| LendingPosition::new(account_id))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -320,7 +325,8 @@ impl LiquidationOracle {
|
||||||
freshness_remaining: self.config.max_price_age - age,
|
freshness_remaining: self.config.max_price_age - age,
|
||||||
};
|
};
|
||||||
|
|
||||||
self.price_cache.insert(asset.to_string(), liq_price.clone());
|
self.price_cache
|
||||||
|
.insert(asset.to_string(), liq_price.clone());
|
||||||
Ok(liq_price)
|
Ok(liq_price)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -328,7 +334,9 @@ impl LiquidationOracle {
|
||||||
pub fn calculate_health(&mut self, account_id: &str) -> Result<HealthStatus> {
|
pub fn calculate_health(&mut self, account_id: &str) -> Result<HealthStatus> {
|
||||||
// Clone position data to avoid borrow conflicts with get_liquidation_price
|
// Clone position data to avoid borrow conflicts with get_liquidation_price
|
||||||
let (collateral, borrows, interest_owed) = {
|
let (collateral, borrows, interest_owed) = {
|
||||||
let position = self.positions.get(account_id)
|
let position = self
|
||||||
|
.positions
|
||||||
|
.get(account_id)
|
||||||
.ok_or_else(|| EconomicsError::AccountNotFound(account_id.to_string()))?;
|
.ok_or_else(|| EconomicsError::AccountNotFound(account_id.to_string()))?;
|
||||||
(
|
(
|
||||||
position.collateral.clone(),
|
position.collateral.clone(),
|
||||||
|
|
@ -352,7 +360,9 @@ impl LiquidationOracle {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
let asset_config = self.assets.get(asset)
|
let asset_config = self
|
||||||
|
.assets
|
||||||
|
.get(asset)
|
||||||
.ok_or_else(|| EconomicsError::InvalidPrice(format!("Unknown asset: {}", asset)))?;
|
.ok_or_else(|| EconomicsError::InvalidPrice(format!("Unknown asset: {}", asset)))?;
|
||||||
|
|
||||||
let value = *amount * price.price;
|
let value = *amount * price.price;
|
||||||
|
|
@ -426,24 +436,37 @@ impl LiquidationOracle {
|
||||||
return Err(EconomicsError::InvalidPrice("Position is healthy".into()));
|
return Err(EconomicsError::InvalidPrice("Position is healthy".into()));
|
||||||
}
|
}
|
||||||
|
|
||||||
let position = self.positions.get(account_id)
|
let position = self
|
||||||
|
.positions
|
||||||
|
.get(account_id)
|
||||||
.ok_or_else(|| EconomicsError::AccountNotFound(account_id.to_string()))?;
|
.ok_or_else(|| EconomicsError::AccountNotFound(account_id.to_string()))?;
|
||||||
|
|
||||||
let debt_amount = position.borrows.get(debt_asset).copied().unwrap_or(Decimal::ZERO);
|
let debt_amount = position
|
||||||
let collateral_amount = position.collateral.get(collateral_asset).copied().unwrap_or(Decimal::ZERO);
|
.borrows
|
||||||
|
.get(debt_asset)
|
||||||
|
.copied()
|
||||||
|
.unwrap_or(Decimal::ZERO);
|
||||||
|
let collateral_amount = position
|
||||||
|
.collateral
|
||||||
|
.get(collateral_asset)
|
||||||
|
.copied()
|
||||||
|
.unwrap_or(Decimal::ZERO);
|
||||||
|
|
||||||
if debt_amount == Decimal::ZERO {
|
if debt_amount == Decimal::ZERO {
|
||||||
return Err(EconomicsError::InvalidPrice("No debt to repay".into()));
|
return Err(EconomicsError::InvalidPrice("No debt to repay".into()));
|
||||||
}
|
}
|
||||||
if collateral_amount == Decimal::ZERO {
|
if collateral_amount == Decimal::ZERO {
|
||||||
return Err(EconomicsError::InvalidPrice("No collateral to seize".into()));
|
return Err(EconomicsError::InvalidPrice(
|
||||||
|
"No collateral to seize".into(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let debt_price = self.get_liquidation_price(debt_asset)?;
|
let debt_price = self.get_liquidation_price(debt_asset)?;
|
||||||
let collateral_price = self.get_liquidation_price(collateral_asset)?;
|
let collateral_price = self.get_liquidation_price(collateral_asset)?;
|
||||||
|
|
||||||
let collateral_config = self.assets.get(collateral_asset)
|
let collateral_config = self.assets.get(collateral_asset).ok_or_else(|| {
|
||||||
.ok_or_else(|| EconomicsError::InvalidPrice(format!("Unknown asset: {}", collateral_asset)))?;
|
EconomicsError::InvalidPrice(format!("Unknown asset: {}", collateral_asset))
|
||||||
|
})?;
|
||||||
|
|
||||||
// Max debt repayable = close_factor * total_debt
|
// Max debt repayable = close_factor * total_debt
|
||||||
let max_debt_repay = debt_amount * self.config.close_factor;
|
let max_debt_repay = debt_amount * self.config.close_factor;
|
||||||
|
|
@ -458,12 +481,14 @@ impl LiquidationOracle {
|
||||||
let actual_collateral_seized = collateral_to_seize.min(collateral_amount);
|
let actual_collateral_seized = collateral_to_seize.min(collateral_amount);
|
||||||
let actual_debt_repaid = if actual_collateral_seized < collateral_to_seize {
|
let actual_debt_repaid = if actual_collateral_seized < collateral_to_seize {
|
||||||
// Partial liquidation
|
// Partial liquidation
|
||||||
(actual_collateral_seized * collateral_price.price) / (bonus_multiplier * debt_price.price)
|
(actual_collateral_seized * collateral_price.price)
|
||||||
|
/ (bonus_multiplier * debt_price.price)
|
||||||
} else {
|
} else {
|
||||||
max_debt_repay
|
max_debt_repay
|
||||||
};
|
};
|
||||||
|
|
||||||
let bonus_amount = actual_collateral_seized * collateral_config.liquidation_bonus / bonus_multiplier;
|
let bonus_amount =
|
||||||
|
actual_collateral_seized * collateral_config.liquidation_bonus / bonus_multiplier;
|
||||||
|
|
||||||
Ok(LiquidationCalculation {
|
Ok(LiquidationCalculation {
|
||||||
account_id: account_id.to_string(),
|
account_id: account_id.to_string(),
|
||||||
|
|
@ -489,7 +514,9 @@ impl LiquidationOracle {
|
||||||
let calc = self.calculate_liquidation(account_id, debt_asset, collateral_asset)?;
|
let calc = self.calculate_liquidation(account_id, debt_asset, collateral_asset)?;
|
||||||
|
|
||||||
// Update position
|
// Update position
|
||||||
let position = self.positions.get_mut(account_id)
|
let position = self
|
||||||
|
.positions
|
||||||
|
.get_mut(account_id)
|
||||||
.ok_or_else(|| EconomicsError::AccountNotFound(account_id.to_string()))?;
|
.ok_or_else(|| EconomicsError::AccountNotFound(account_id.to_string()))?;
|
||||||
|
|
||||||
// Reduce debt
|
// Reduce debt
|
||||||
|
|
@ -549,7 +576,9 @@ impl LiquidationOracle {
|
||||||
// Protocol gets a portion of the liquidation bonus
|
// Protocol gets a portion of the liquidation bonus
|
||||||
if let Some(asset_config) = self.assets.get(&event.collateral_asset) {
|
if let Some(asset_config) = self.assets.get(&event.collateral_asset) {
|
||||||
let protocol_share = event.bonus_amount * asset_config.reserve_factor;
|
let protocol_share = event.bonus_amount * asset_config.reserve_factor;
|
||||||
*reserves.entry(event.collateral_asset.clone()).or_insert(Decimal::ZERO) += protocol_share;
|
*reserves
|
||||||
|
.entry(event.collateral_asset.clone())
|
||||||
|
.or_insert(Decimal::ZERO) += protocol_share;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -561,15 +590,18 @@ impl LiquidationOracle {
|
||||||
let total_positions = self.positions.len();
|
let total_positions = self.positions.len();
|
||||||
let total_liquidations = self.liquidation_history.len();
|
let total_liquidations = self.liquidation_history.len();
|
||||||
|
|
||||||
let total_debt_liquidated: SynorDecimal = self.liquidation_history.iter()
|
let total_debt_liquidated: SynorDecimal =
|
||||||
.map(|e| e.debt_amount)
|
self.liquidation_history.iter().map(|e| e.debt_amount).sum();
|
||||||
.sum();
|
|
||||||
|
|
||||||
let total_collateral_seized: SynorDecimal = self.liquidation_history.iter()
|
let total_collateral_seized: SynorDecimal = self
|
||||||
|
.liquidation_history
|
||||||
|
.iter()
|
||||||
.map(|e| e.collateral_amount)
|
.map(|e| e.collateral_amount)
|
||||||
.sum();
|
.sum();
|
||||||
|
|
||||||
let unique_liquidated: std::collections::HashSet<_> = self.liquidation_history.iter()
|
let unique_liquidated: std::collections::HashSet<_> = self
|
||||||
|
.liquidation_history
|
||||||
|
.iter()
|
||||||
.map(|e| &e.account_id)
|
.map(|e| &e.account_id)
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
|
|
@ -617,18 +649,60 @@ mod tests {
|
||||||
let mut price_oracle = PriceOracle::with_config(OracleConfig::default());
|
let mut price_oracle = PriceOracle::with_config(OracleConfig::default());
|
||||||
|
|
||||||
// Add prices from multiple sources for test validity
|
// Add prices from multiple sources for test validity
|
||||||
price_oracle.update_price(TokenPrice::new("ETH", "USD", dec!(2000), PriceSource::Internal)).unwrap();
|
price_oracle
|
||||||
price_oracle.update_price(TokenPrice::new("ETH", "USD", dec!(2000), PriceSource::Aggregated)).unwrap();
|
.update_price(TokenPrice::new(
|
||||||
price_oracle.update_price(TokenPrice::new("SYNOR", "USD", dec!(1), PriceSource::Internal)).unwrap();
|
"ETH",
|
||||||
price_oracle.update_price(TokenPrice::new("SYNOR", "USD", dec!(1), PriceSource::Aggregated)).unwrap();
|
"USD",
|
||||||
price_oracle.update_price(TokenPrice::new("USDC", "USD", dec!(1), PriceSource::Internal)).unwrap();
|
dec!(2000),
|
||||||
price_oracle.update_price(TokenPrice::new("USDC", "USD", dec!(1), PriceSource::Aggregated)).unwrap();
|
PriceSource::Internal,
|
||||||
|
))
|
||||||
|
.unwrap();
|
||||||
|
price_oracle
|
||||||
|
.update_price(TokenPrice::new(
|
||||||
|
"ETH",
|
||||||
|
"USD",
|
||||||
|
dec!(2000),
|
||||||
|
PriceSource::Aggregated,
|
||||||
|
))
|
||||||
|
.unwrap();
|
||||||
|
price_oracle
|
||||||
|
.update_price(TokenPrice::new(
|
||||||
|
"SYNOR",
|
||||||
|
"USD",
|
||||||
|
dec!(1),
|
||||||
|
PriceSource::Internal,
|
||||||
|
))
|
||||||
|
.unwrap();
|
||||||
|
price_oracle
|
||||||
|
.update_price(TokenPrice::new(
|
||||||
|
"SYNOR",
|
||||||
|
"USD",
|
||||||
|
dec!(1),
|
||||||
|
PriceSource::Aggregated,
|
||||||
|
))
|
||||||
|
.unwrap();
|
||||||
|
price_oracle
|
||||||
|
.update_price(TokenPrice::new(
|
||||||
|
"USDC",
|
||||||
|
"USD",
|
||||||
|
dec!(1),
|
||||||
|
PriceSource::Internal,
|
||||||
|
))
|
||||||
|
.unwrap();
|
||||||
|
price_oracle
|
||||||
|
.update_price(TokenPrice::new(
|
||||||
|
"USDC",
|
||||||
|
"USD",
|
||||||
|
dec!(1),
|
||||||
|
PriceSource::Aggregated,
|
||||||
|
))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
// Use test config with relaxed freshness requirements
|
// Use test config with relaxed freshness requirements
|
||||||
let test_config = LiquidationOracleConfig {
|
let test_config = LiquidationOracleConfig {
|
||||||
max_price_age: 3600, // 1 hour for tests
|
max_price_age: 3600, // 1 hour for tests
|
||||||
min_confidence: 0.5, // Lower confidence threshold
|
min_confidence: 0.5, // Lower confidence threshold
|
||||||
min_sources: 1, // Single source OK for tests
|
min_sources: 1, // Single source OK for tests
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -660,8 +734,8 @@ mod tests {
|
||||||
|
|
||||||
// Create position with good health
|
// Create position with good health
|
||||||
let pos = oracle.create_position("user1");
|
let pos = oracle.create_position("user1");
|
||||||
pos.deposit("ETH", dec!(1)); // $2000 worth
|
pos.deposit("ETH", dec!(1)); // $2000 worth
|
||||||
pos.borrow("USDC", dec!(500)); // Borrow $500
|
pos.borrow("USDC", dec!(500)); // Borrow $500
|
||||||
|
|
||||||
let health = oracle.calculate_health("user1").unwrap();
|
let health = oracle.calculate_health("user1").unwrap();
|
||||||
|
|
||||||
|
|
@ -678,8 +752,8 @@ mod tests {
|
||||||
|
|
||||||
// Create position close to liquidation
|
// Create position close to liquidation
|
||||||
let pos = oracle.create_position("user2");
|
let pos = oracle.create_position("user2");
|
||||||
pos.deposit("ETH", dec!(1)); // $2000 worth
|
pos.deposit("ETH", dec!(1)); // $2000 worth
|
||||||
pos.borrow("USDC", dec!(1500)); // Borrow $1500
|
pos.borrow("USDC", dec!(1500)); // Borrow $1500
|
||||||
|
|
||||||
let health = oracle.calculate_health("user2").unwrap();
|
let health = oracle.calculate_health("user2").unwrap();
|
||||||
|
|
||||||
|
|
@ -699,7 +773,9 @@ mod tests {
|
||||||
pos.deposit("ETH", dec!(1));
|
pos.deposit("ETH", dec!(1));
|
||||||
pos.borrow("USDC", dec!(1500));
|
pos.borrow("USDC", dec!(1500));
|
||||||
|
|
||||||
let calc = oracle.calculate_liquidation("user3", "USDC", "ETH").unwrap();
|
let calc = oracle
|
||||||
|
.calculate_liquidation("user3", "USDC", "ETH")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
// Should be able to liquidate
|
// Should be able to liquidate
|
||||||
assert!(calc.debt_to_repay > Decimal::ZERO);
|
assert!(calc.debt_to_repay > Decimal::ZERO);
|
||||||
|
|
|
||||||
|
|
@ -49,8 +49,7 @@ pub use price_feed::{
|
||||||
PriceSource,
|
PriceSource,
|
||||||
};
|
};
|
||||||
pub use twap::{
|
pub use twap::{
|
||||||
OnChainTwap, OnChainTwapFactory, TwapCalculator, TwapConfig, TwapObservation,
|
OnChainTwap, OnChainTwapFactory, TwapCalculator, TwapConfig, TwapObservation, TwapOracleBuilder,
|
||||||
TwapOracleBuilder,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::error::{EconomicsError, Result};
|
use crate::error::{EconomicsError, Result};
|
||||||
|
|
@ -241,7 +240,10 @@ impl PriceOracle {
|
||||||
/// Get price history for a pair
|
/// Get price history for a pair
|
||||||
pub fn get_price_history(&self, token: &str, quote: &str) -> Vec<TokenPrice> {
|
pub fn get_price_history(&self, token: &str, quote: &str) -> Vec<TokenPrice> {
|
||||||
let pair_key = format!("{}/{}", token, quote);
|
let pair_key = format!("{}/{}", token, quote);
|
||||||
self.price_history.get(&pair_key).cloned().unwrap_or_default()
|
self.price_history
|
||||||
|
.get(&pair_key)
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or_default()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Fetch prices from all configured feeds
|
/// Fetch prices from all configured feeds
|
||||||
|
|
@ -403,7 +405,9 @@ impl PriceOracle {
|
||||||
}
|
}
|
||||||
|
|
||||||
let healthy = !pairs_status.is_empty()
|
let healthy = !pairs_status.is_empty()
|
||||||
&& pairs_status.values().all(|s| !s.is_stale && s.price_count > 0);
|
&& pairs_status
|
||||||
|
.values()
|
||||||
|
.all(|s| !s.is_stale && s.price_count > 0);
|
||||||
|
|
||||||
OracleHealthStatus {
|
OracleHealthStatus {
|
||||||
healthy,
|
healthy,
|
||||||
|
|
@ -443,7 +447,10 @@ impl PriceOracleBuilder {
|
||||||
/// Add a mock price feed (for testing)
|
/// Add a mock price feed (for testing)
|
||||||
pub fn with_mock_feed(mut self, base_price: SynorDecimal) -> Self {
|
pub fn with_mock_feed(mut self, base_price: SynorDecimal) -> Self {
|
||||||
use price_feed::MockPriceFeed;
|
use price_feed::MockPriceFeed;
|
||||||
self.feeds.push(Box::new(MockPriceFeed::new(PriceSource::Internal, base_price)));
|
self.feeds.push(Box::new(MockPriceFeed::new(
|
||||||
|
PriceSource::Internal,
|
||||||
|
base_price,
|
||||||
|
)));
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -455,9 +462,14 @@ impl PriceOracleBuilder {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add Chainlink oracle feed
|
/// Add Chainlink oracle feed
|
||||||
pub fn with_chainlink(mut self, contract_address: impl Into<String>, rpc_url: impl Into<String>) -> Self {
|
pub fn with_chainlink(
|
||||||
|
mut self,
|
||||||
|
contract_address: impl Into<String>,
|
||||||
|
rpc_url: impl Into<String>,
|
||||||
|
) -> Self {
|
||||||
use price_feed::ChainlinkFeed;
|
use price_feed::ChainlinkFeed;
|
||||||
self.feeds.push(Box::new(ChainlinkFeed::new(contract_address, rpc_url)));
|
self.feeds
|
||||||
|
.push(Box::new(ChainlinkFeed::new(contract_address, rpc_url)));
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -508,8 +520,14 @@ impl OracleFactory {
|
||||||
let mut oracle = PriceOracle::new();
|
let mut oracle = PriceOracle::new();
|
||||||
|
|
||||||
// Add multiple mock feeds with slight variations for testing
|
// Add multiple mock feeds with slight variations for testing
|
||||||
oracle.add_feed(Box::new(MockPriceFeed::new(PriceSource::Internal, base_price)));
|
oracle.add_feed(Box::new(MockPriceFeed::new(
|
||||||
oracle.add_feed(Box::new(MockPriceFeed::new(PriceSource::SynorDex, base_price)));
|
PriceSource::Internal,
|
||||||
|
base_price,
|
||||||
|
)));
|
||||||
|
oracle.add_feed(Box::new(MockPriceFeed::new(
|
||||||
|
PriceSource::SynorDex,
|
||||||
|
base_price,
|
||||||
|
)));
|
||||||
|
|
||||||
oracle
|
oracle
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -110,7 +110,8 @@ impl PriceFeed for MockPriceFeed {
|
||||||
async fn fetch_price(&self, token: &str, quote: &str) -> Result<TokenPrice> {
|
async fn fetch_price(&self, token: &str, quote: &str) -> Result<TokenPrice> {
|
||||||
// Add small random variance
|
// Add small random variance
|
||||||
let variance = (rand_simple() * 2.0 - 1.0) * self.volatility;
|
let variance = (rand_simple() * 2.0 - 1.0) * self.volatility;
|
||||||
let price = self.base_price * (Decimal::ONE + Decimal::from_f64_retain(variance).unwrap_or_default());
|
let price = self.base_price
|
||||||
|
* (Decimal::ONE + Decimal::from_f64_retain(variance).unwrap_or_default());
|
||||||
|
|
||||||
Ok(TokenPrice {
|
Ok(TokenPrice {
|
||||||
token: token.to_string(),
|
token: token.to_string(),
|
||||||
|
|
@ -258,9 +259,12 @@ impl PriceFeed for CoinGeckoFeed {
|
||||||
"SYNOR" => "synor", // Would need actual CoinGecko ID
|
"SYNOR" => "synor", // Would need actual CoinGecko ID
|
||||||
"BTC" => "bitcoin",
|
"BTC" => "bitcoin",
|
||||||
"ETH" => "ethereum",
|
"ETH" => "ethereum",
|
||||||
_ => return Err(EconomicsError::PriceFeedUnavailable(
|
_ => {
|
||||||
format!("Token {} not supported on CoinGecko", token)
|
return Err(EconomicsError::PriceFeedUnavailable(format!(
|
||||||
)),
|
"Token {} not supported on CoinGecko",
|
||||||
|
token
|
||||||
|
)))
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let quote_currency = quote.to_lowercase();
|
let quote_currency = quote.to_lowercase();
|
||||||
|
|
@ -294,8 +298,7 @@ impl PriceFeed for CoinGeckoFeed {
|
||||||
Ok(TokenPrice {
|
Ok(TokenPrice {
|
||||||
token: token.to_string(),
|
token: token.to_string(),
|
||||||
quote: quote.to_string(),
|
quote: quote.to_string(),
|
||||||
price: Decimal::from_f64_retain(price)
|
price: Decimal::from_f64_retain(price).unwrap_or_default(),
|
||||||
.unwrap_or_default(),
|
|
||||||
timestamp: Utc::now(),
|
timestamp: Utc::now(),
|
||||||
source: PriceSource::CoinGecko,
|
source: PriceSource::CoinGecko,
|
||||||
confidence: 0.90,
|
confidence: 0.90,
|
||||||
|
|
|
||||||
|
|
@ -116,8 +116,8 @@ impl TwapCalculator {
|
||||||
|
|
||||||
let duration = (interval_end - interval_start).num_seconds() as f64;
|
let duration = (interval_end - interval_start).num_seconds() as f64;
|
||||||
if duration > 0.0 {
|
if duration > 0.0 {
|
||||||
let weight = Decimal::from_f64_retain(duration / total_duration)
|
let weight =
|
||||||
.unwrap_or(Decimal::ZERO);
|
Decimal::from_f64_retain(duration / total_duration).unwrap_or(Decimal::ZERO);
|
||||||
|
|
||||||
weighted_sum += price.price * weight;
|
weighted_sum += price.price * weight;
|
||||||
total_weight += weight;
|
total_weight += weight;
|
||||||
|
|
@ -343,7 +343,8 @@ impl OnChainTwap {
|
||||||
/// Apply the pending cardinality increase (called during next observation)
|
/// Apply the pending cardinality increase (called during next observation)
|
||||||
pub fn apply_cardinality_growth(&mut self) {
|
pub fn apply_cardinality_growth(&mut self) {
|
||||||
if self.cardinality_next > self.cardinality {
|
if self.cardinality_next > self.cardinality {
|
||||||
self.observations.reserve(self.cardinality_next - self.cardinality);
|
self.observations
|
||||||
|
.reserve(self.cardinality_next - self.cardinality);
|
||||||
self.cardinality = self.cardinality_next;
|
self.cardinality = self.cardinality_next;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -396,7 +397,12 @@ impl TwapOracleBuilder {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add an initial observation
|
/// Add an initial observation
|
||||||
pub fn with_observation(mut self, timestamp: DateTime<Utc>, price_cumulative: SynorDecimal, spl_cumulative: SynorDecimal) -> Self {
|
pub fn with_observation(
|
||||||
|
mut self,
|
||||||
|
timestamp: DateTime<Utc>,
|
||||||
|
price_cumulative: SynorDecimal,
|
||||||
|
spl_cumulative: SynorDecimal,
|
||||||
|
) -> Self {
|
||||||
self.initial_observations.push(TwapObservation {
|
self.initial_observations.push(TwapObservation {
|
||||||
timestamp,
|
timestamp,
|
||||||
price_cumulative,
|
price_cumulative,
|
||||||
|
|
|
||||||
|
|
@ -76,7 +76,11 @@ pub struct Discount {
|
||||||
|
|
||||||
impl Discount {
|
impl Discount {
|
||||||
/// Create a new percentage discount
|
/// Create a new percentage discount
|
||||||
pub fn percentage(code: impl Into<String>, name: impl Into<String>, percentage: SynorDecimal) -> Self {
|
pub fn percentage(
|
||||||
|
code: impl Into<String>,
|
||||||
|
name: impl Into<String>,
|
||||||
|
percentage: SynorDecimal,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
code: code.into(),
|
code: code.into(),
|
||||||
name: name.into(),
|
name: name.into(),
|
||||||
|
|
@ -96,7 +100,11 @@ impl Discount {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a fixed amount discount
|
/// Create a fixed amount discount
|
||||||
pub fn fixed_amount(code: impl Into<String>, name: impl Into<String>, amount: SynorDecimal) -> Self {
|
pub fn fixed_amount(
|
||||||
|
code: impl Into<String>,
|
||||||
|
name: impl Into<String>,
|
||||||
|
amount: SynorDecimal,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
code: code.into(),
|
code: code.into(),
|
||||||
name: name.into(),
|
name: name.into(),
|
||||||
|
|
@ -120,7 +128,10 @@ impl Discount {
|
||||||
Self {
|
Self {
|
||||||
code: format!("VOLUME_{}", min_spend),
|
code: format!("VOLUME_{}", min_spend),
|
||||||
name: format!("Volume Discount ({} SYNOR+)", min_spend),
|
name: format!("Volume Discount ({} SYNOR+)", min_spend),
|
||||||
description: format!("{}% off when spending {} SYNOR or more", percentage, min_spend),
|
description: format!(
|
||||||
|
"{}% off when spending {} SYNOR or more",
|
||||||
|
percentage, min_spend
|
||||||
|
),
|
||||||
discount_type: DiscountType::Volume,
|
discount_type: DiscountType::Volume,
|
||||||
value: percentage,
|
value: percentage,
|
||||||
min_spend: Some(min_spend),
|
min_spend: Some(min_spend),
|
||||||
|
|
@ -260,9 +271,11 @@ impl Discount {
|
||||||
}
|
}
|
||||||
|
|
||||||
let discount = match self.discount_type {
|
let discount = match self.discount_type {
|
||||||
DiscountType::Percentage | DiscountType::Volume | DiscountType::Loyalty | DiscountType::Referral | DiscountType::Partner => {
|
DiscountType::Percentage
|
||||||
amount * (self.value / Decimal::ONE_HUNDRED)
|
| DiscountType::Volume
|
||||||
}
|
| DiscountType::Loyalty
|
||||||
|
| DiscountType::Referral
|
||||||
|
| DiscountType::Partner => amount * (self.value / Decimal::ONE_HUNDRED),
|
||||||
DiscountType::FixedAmount | DiscountType::Promotional => {
|
DiscountType::FixedAmount | DiscountType::Promotional => {
|
||||||
self.value.min(amount) // Can't discount more than amount
|
self.value.min(amount) // Can't discount more than amount
|
||||||
}
|
}
|
||||||
|
|
@ -298,7 +311,7 @@ impl Discount {
|
||||||
/// Volume discount tiers
|
/// Volume discount tiers
|
||||||
pub fn standard_volume_discounts() -> Vec<Discount> {
|
pub fn standard_volume_discounts() -> Vec<Discount> {
|
||||||
vec![
|
vec![
|
||||||
Discount::volume(Decimal::new(100, 0), Decimal::new(5, 0)), // 5% at 100+ SYNOR
|
Discount::volume(Decimal::new(100, 0), Decimal::new(5, 0)), // 5% at 100+ SYNOR
|
||||||
Discount::volume(Decimal::new(500, 0), Decimal::new(10, 0)), // 10% at 500+ SYNOR
|
Discount::volume(Decimal::new(500, 0), Decimal::new(10, 0)), // 10% at 500+ SYNOR
|
||||||
Discount::volume(Decimal::new(1000, 0), Decimal::new(15, 0)), // 15% at 1000+ SYNOR
|
Discount::volume(Decimal::new(1000, 0), Decimal::new(15, 0)), // 15% at 1000+ SYNOR
|
||||||
Discount::volume(Decimal::new(5000, 0), Decimal::new(20, 0)), // 20% at 5000+ SYNOR
|
Discount::volume(Decimal::new(5000, 0), Decimal::new(20, 0)), // 20% at 5000+ SYNOR
|
||||||
|
|
@ -309,11 +322,7 @@ pub fn standard_volume_discounts() -> Vec<Discount> {
|
||||||
pub fn find_best_volume_discount(amount: SynorDecimal) -> Option<Discount> {
|
pub fn find_best_volume_discount(amount: SynorDecimal) -> Option<Discount> {
|
||||||
standard_volume_discounts()
|
standard_volume_discounts()
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.filter(|d| {
|
.filter(|d| d.min_spend.map(|min| amount >= min).unwrap_or(false))
|
||||||
d.min_spend
|
|
||||||
.map(|min| amount >= min)
|
|
||||||
.unwrap_or(false)
|
|
||||||
})
|
|
||||||
.max_by(|a, b| a.value.cmp(&b.value))
|
.max_by(|a, b| a.value.cmp(&b.value))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -378,8 +387,8 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_discount_usage_limit() {
|
fn test_discount_usage_limit() {
|
||||||
let mut discount = Discount::percentage("LIMITED", "Limited Use", dec!(10))
|
let mut discount =
|
||||||
.with_max_uses(2);
|
Discount::percentage("LIMITED", "Limited Use", dec!(10)).with_max_uses(2);
|
||||||
|
|
||||||
assert!(discount.use_discount());
|
assert!(discount.use_discount());
|
||||||
assert!(discount.use_discount());
|
assert!(discount.use_discount());
|
||||||
|
|
|
||||||
|
|
@ -198,15 +198,9 @@ impl PricingEngine {
|
||||||
.get(&service_type)
|
.get(&service_type)
|
||||||
.ok_or_else(|| EconomicsError::ServiceNotConfigured(service_type.to_string()))?;
|
.ok_or_else(|| EconomicsError::ServiceNotConfigured(service_type.to_string()))?;
|
||||||
|
|
||||||
let unit_price = pricing
|
let unit_price = pricing.base_prices.get(&resource_unit).ok_or_else(|| {
|
||||||
.base_prices
|
EconomicsError::ServiceNotConfigured(format!("{} - {}", service_type, resource_unit))
|
||||||
.get(&resource_unit)
|
})?;
|
||||||
.ok_or_else(|| {
|
|
||||||
EconomicsError::ServiceNotConfigured(format!(
|
|
||||||
"{} - {}",
|
|
||||||
service_type, resource_unit
|
|
||||||
))
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let cost = amount * unit_price;
|
let cost = amount * unit_price;
|
||||||
|
|
||||||
|
|
@ -357,7 +351,8 @@ impl PricingEngine {
|
||||||
storage: ServicePricingSummary {
|
storage: ServicePricingSummary {
|
||||||
gb_month: self.get_base_price(ServiceType::Storage, ResourceUnit::GbMonth),
|
gb_month: self.get_base_price(ServiceType::Storage, ResourceUnit::GbMonth),
|
||||||
retrieval_gb: self.get_base_price(ServiceType::Storage, ResourceUnit::BandwidthGb),
|
retrieval_gb: self.get_base_price(ServiceType::Storage, ResourceUnit::BandwidthGb),
|
||||||
free_storage_gb: self.get_free_allocation(ServiceType::Storage, ResourceUnit::GbMonth),
|
free_storage_gb: self
|
||||||
|
.get_free_allocation(ServiceType::Storage, ResourceUnit::GbMonth),
|
||||||
},
|
},
|
||||||
hosting: HostingPricingSummary {
|
hosting: HostingPricingSummary {
|
||||||
bandwidth_gb: self.get_base_price(ServiceType::Hosting, ResourceUnit::BandwidthGb),
|
bandwidth_gb: self.get_base_price(ServiceType::Hosting, ResourceUnit::BandwidthGb),
|
||||||
|
|
@ -377,7 +372,8 @@ impl PricingEngine {
|
||||||
.get_free_allocation(ServiceType::Database, ResourceUnit::Queries),
|
.get_free_allocation(ServiceType::Database, ResourceUnit::Queries),
|
||||||
},
|
},
|
||||||
compute: ComputePricingSummary {
|
compute: ComputePricingSummary {
|
||||||
cpu_core_hour: self.get_base_price(ServiceType::Compute, ResourceUnit::CpuCoreHours),
|
cpu_core_hour: self
|
||||||
|
.get_base_price(ServiceType::Compute, ResourceUnit::CpuCoreHours),
|
||||||
gpu_hour: self.get_base_price(ServiceType::Compute, ResourceUnit::GpuHours),
|
gpu_hour: self.get_base_price(ServiceType::Compute, ResourceUnit::GpuHours),
|
||||||
memory_gb_hour: self
|
memory_gb_hour: self
|
||||||
.get_base_price(ServiceType::Compute, ResourceUnit::MemoryGbHours),
|
.get_base_price(ServiceType::Compute, ResourceUnit::MemoryGbHours),
|
||||||
|
|
@ -472,7 +468,11 @@ mod tests {
|
||||||
|
|
||||||
// 10 million queries
|
// 10 million queries
|
||||||
let cost = engine
|
let cost = engine
|
||||||
.calculate_cost(ServiceType::Database, ResourceUnit::Queries, dec!(10_000_000))
|
.calculate_cost(
|
||||||
|
ServiceType::Database,
|
||||||
|
ResourceUnit::Queries,
|
||||||
|
dec!(10_000_000),
|
||||||
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(cost, dec!(0.10)); // 10M * 0.00000001
|
assert_eq!(cost, dec!(0.10)); // 10M * 0.00000001
|
||||||
|
|
@ -484,7 +484,9 @@ mod tests {
|
||||||
|
|
||||||
// Premium tier gets 20% discount
|
// Premium tier gets 20% discount
|
||||||
let base_cost = dec!(100);
|
let base_cost = dec!(100);
|
||||||
let discount = engine.calculate_tier_discount("premium", base_cost).unwrap();
|
let discount = engine
|
||||||
|
.calculate_tier_discount("premium", base_cost)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(discount, dec!(20)); // 20%
|
assert_eq!(discount, dec!(20)); // 20%
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -108,8 +108,8 @@ impl PricingTier {
|
||||||
discount_percentage: Decimal::new(30, 0), // 30% discount
|
discount_percentage: Decimal::new(30, 0), // 30% discount
|
||||||
priority_support: true,
|
priority_support: true,
|
||||||
sla_percentage: Decimal::new(9999, 2), // 99.99% SLA
|
sla_percentage: Decimal::new(9999, 2), // 99.99% SLA
|
||||||
custom_domain_limit: 0, // Unlimited
|
custom_domain_limit: 0, // Unlimited
|
||||||
api_rate_limit: 0, // Unlimited
|
api_rate_limit: 0, // Unlimited
|
||||||
features: vec![
|
features: vec![
|
||||||
"Everything in Premium".to_string(),
|
"Everything in Premium".to_string(),
|
||||||
"30%+ Usage Discount".to_string(),
|
"30%+ Usage Discount".to_string(),
|
||||||
|
|
@ -147,7 +147,9 @@ impl PricingTier {
|
||||||
|
|
||||||
/// Check if this tier has a feature
|
/// Check if this tier has a feature
|
||||||
pub fn has_feature(&self, feature: &str) -> bool {
|
pub fn has_feature(&self, feature: &str) -> bool {
|
||||||
self.features.iter().any(|f| f.to_lowercase().contains(&feature.to_lowercase()))
|
self.features
|
||||||
|
.iter()
|
||||||
|
.any(|f| f.to_lowercase().contains(&feature.to_lowercase()))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Calculate effective monthly cost including usage discount
|
/// Calculate effective monthly cost including usage discount
|
||||||
|
|
@ -163,7 +165,9 @@ impl PricingTier {
|
||||||
let other_cost = other.effective_cost(monthly_usage);
|
let other_cost = other.effective_cost(monthly_usage);
|
||||||
|
|
||||||
// Upgrade if other tier is cheaper or offers significant benefits
|
// Upgrade if other tier is cheaper or offers significant benefits
|
||||||
other_cost < current_cost || (other.sla_percentage > self.sla_percentage && other_cost <= current_cost * Decimal::new(12, 1))
|
other_cost < current_cost
|
||||||
|
|| (other.sla_percentage > self.sla_percentage
|
||||||
|
&& other_cost <= current_cost * Decimal::new(12, 1))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -181,7 +181,12 @@ impl AuthService {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate a JWT token.
|
/// Generate a JWT token.
|
||||||
pub fn generate_token(&self, user_id: &str, tier: ApiKeyTier, permissions: Permissions) -> Result<String, ApiError> {
|
pub fn generate_token(
|
||||||
|
&self,
|
||||||
|
user_id: &str,
|
||||||
|
tier: ApiKeyTier,
|
||||||
|
permissions: Permissions,
|
||||||
|
) -> Result<String, ApiError> {
|
||||||
let now = Utc::now();
|
let now = Utc::now();
|
||||||
let exp = now + self.jwt_expiration;
|
let exp = now + self.jwt_expiration;
|
||||||
|
|
||||||
|
|
@ -215,8 +220,7 @@ impl AuthService {
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let claims = token_data.claims;
|
let claims = token_data.claims;
|
||||||
let expires_at = DateTime::from_timestamp(claims.exp, 0)
|
let expires_at = DateTime::from_timestamp(claims.exp, 0).map(|dt| dt.with_timezone(&Utc));
|
||||||
.map(|dt| dt.with_timezone(&Utc));
|
|
||||||
|
|
||||||
Ok(AuthContext {
|
Ok(AuthContext {
|
||||||
user_id: claims.sub,
|
user_id: claims.sub,
|
||||||
|
|
@ -278,9 +282,7 @@ impl AuthService {
|
||||||
|
|
||||||
// Try API key header
|
// Try API key header
|
||||||
if let Some(api_key) = headers.get("X-API-Key") {
|
if let Some(api_key) = headers.get("X-API-Key") {
|
||||||
let key = api_key
|
let key = api_key.to_str().map_err(|_| ApiError::InvalidApiKey)?;
|
||||||
.to_str()
|
|
||||||
.map_err(|_| ApiError::InvalidApiKey)?;
|
|
||||||
return self.validate_api_key(key).await;
|
return self.validate_api_key(key).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -295,8 +297,7 @@ impl AuthService {
|
||||||
let decoded = BASE64
|
let decoded = BASE64
|
||||||
.decode(encoded)
|
.decode(encoded)
|
||||||
.map_err(|_| ApiError::InvalidApiKey)?;
|
.map_err(|_| ApiError::InvalidApiKey)?;
|
||||||
let key = String::from_utf8(decoded)
|
let key = String::from_utf8(decoded).map_err(|_| ApiError::InvalidApiKey)?;
|
||||||
.map_err(|_| ApiError::InvalidApiKey)?;
|
|
||||||
return self.validate_api_key(&key).await;
|
return self.validate_api_key(&key).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -318,7 +319,9 @@ where
|
||||||
fn from_request_parts<'life0, 'life1, 'async_trait>(
|
fn from_request_parts<'life0, 'life1, 'async_trait>(
|
||||||
parts: &'life0 mut Parts,
|
parts: &'life0 mut Parts,
|
||||||
_state: &'life1 S,
|
_state: &'life1 S,
|
||||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self, Self::Rejection>> + Send + 'async_trait>>
|
) -> std::pin::Pin<
|
||||||
|
Box<dyn std::future::Future<Output = Result<Self, Self::Rejection>> + Send + 'async_trait>,
|
||||||
|
>
|
||||||
where
|
where
|
||||||
'life0: 'async_trait,
|
'life0: 'async_trait,
|
||||||
'life1: 'async_trait,
|
'life1: 'async_trait,
|
||||||
|
|
@ -351,7 +354,9 @@ where
|
||||||
fn from_request_parts<'life0, 'life1, 'async_trait>(
|
fn from_request_parts<'life0, 'life1, 'async_trait>(
|
||||||
parts: &'life0 mut Parts,
|
parts: &'life0 mut Parts,
|
||||||
_state: &'life1 S,
|
_state: &'life1 S,
|
||||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self, Self::Rejection>> + Send + 'async_trait>>
|
) -> std::pin::Pin<
|
||||||
|
Box<dyn std::future::Future<Output = Result<Self, Self::Rejection>> + Send + 'async_trait>,
|
||||||
|
>
|
||||||
where
|
where
|
||||||
'life0: 'async_trait,
|
'life0: 'async_trait,
|
||||||
'life1: 'async_trait,
|
'life1: 'async_trait,
|
||||||
|
|
@ -359,10 +364,7 @@ where
|
||||||
{
|
{
|
||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
// Get auth service from extensions
|
// Get auth service from extensions
|
||||||
let auth_service = parts
|
let auth_service = parts.extensions.get::<AuthService>().cloned();
|
||||||
.extensions
|
|
||||||
.get::<AuthService>()
|
|
||||||
.cloned();
|
|
||||||
|
|
||||||
if let Some(auth_service) = auth_service {
|
if let Some(auth_service) = auth_service {
|
||||||
match auth_service.authenticate(&parts.headers).await {
|
match auth_service.authenticate(&parts.headers).await {
|
||||||
|
|
@ -377,10 +379,7 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Require specific permissions.
|
/// Require specific permissions.
|
||||||
pub fn require_permission(
|
pub fn require_permission(context: &AuthContext, permission: &str) -> Result<(), ApiError> {
|
||||||
context: &AuthContext,
|
|
||||||
permission: &str,
|
|
||||||
) -> Result<(), ApiError> {
|
|
||||||
let has_permission = match permission {
|
let has_permission = match permission {
|
||||||
"read" => context.can_read(),
|
"read" => context.can_read(),
|
||||||
"write" => context.can_write(),
|
"write" => context.can_write(),
|
||||||
|
|
@ -397,10 +396,7 @@ pub fn require_permission(
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Require access to a specific service.
|
/// Require access to a specific service.
|
||||||
pub fn require_service_access(
|
pub fn require_service_access(context: &AuthContext, service: &str) -> Result<(), ApiError> {
|
||||||
context: &AuthContext,
|
|
||||||
service: &str,
|
|
||||||
) -> Result<(), ApiError> {
|
|
||||||
if context.can_access_service(service) {
|
if context.can_access_service(service) {
|
||||||
Ok(())
|
Ok(())
|
||||||
} else {
|
} else {
|
||||||
|
|
|
||||||
|
|
@ -160,10 +160,26 @@ pub struct RateLimitTiers {
|
||||||
impl Default for RateLimitTiers {
|
impl Default for RateLimitTiers {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
free: TierConfig { rpm: 60, burst: 10, concurrent: 5 },
|
free: TierConfig {
|
||||||
developer: TierConfig { rpm: 600, burst: 100, concurrent: 20 },
|
rpm: 60,
|
||||||
pro: TierConfig { rpm: 6000, burst: 1000, concurrent: 100 },
|
burst: 10,
|
||||||
enterprise: TierConfig { rpm: 0, burst: 0, concurrent: 0 }, // Unlimited
|
concurrent: 5,
|
||||||
|
},
|
||||||
|
developer: TierConfig {
|
||||||
|
rpm: 600,
|
||||||
|
burst: 100,
|
||||||
|
concurrent: 20,
|
||||||
|
},
|
||||||
|
pro: TierConfig {
|
||||||
|
rpm: 6000,
|
||||||
|
burst: 1000,
|
||||||
|
concurrent: 100,
|
||||||
|
},
|
||||||
|
enterprise: TierConfig {
|
||||||
|
rpm: 0,
|
||||||
|
burst: 0,
|
||||||
|
concurrent: 0,
|
||||||
|
}, // Unlimited
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -170,9 +170,7 @@ impl ApiError {
|
||||||
| Self::ContractError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
| Self::ContractError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||||
|
|
||||||
// 429 Too Many Requests
|
// 429 Too Many Requests
|
||||||
Self::RateLimitExceeded | Self::TooManyRequests { .. } => {
|
Self::RateLimitExceeded | Self::TooManyRequests { .. } => StatusCode::TOO_MANY_REQUESTS,
|
||||||
StatusCode::TOO_MANY_REQUESTS
|
|
||||||
}
|
|
||||||
|
|
||||||
// 500 Internal Server Error
|
// 500 Internal Server Error
|
||||||
Self::InternalError | Self::Custom(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
Self::InternalError | Self::Custom(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
|
@ -222,7 +220,10 @@ impl ApiError {
|
||||||
/// Build error details with optional extra information.
|
/// Build error details with optional extra information.
|
||||||
pub fn to_details(&self) -> ErrorDetails {
|
pub fn to_details(&self) -> ErrorDetails {
|
||||||
let details = match self {
|
let details = match self {
|
||||||
Self::InsufficientBalance { required, available } => Some(serde_json::json!({
|
Self::InsufficientBalance {
|
||||||
|
required,
|
||||||
|
available,
|
||||||
|
} => Some(serde_json::json!({
|
||||||
"required": required,
|
"required": required,
|
||||||
"available": available
|
"available": available
|
||||||
})),
|
})),
|
||||||
|
|
@ -257,10 +258,9 @@ impl IntoResponse for ApiError {
|
||||||
|
|
||||||
// Add rate limit headers for 429 errors
|
// Add rate limit headers for 429 errors
|
||||||
if let Self::TooManyRequests { retry_after } = &self {
|
if let Self::TooManyRequests { retry_after } = &self {
|
||||||
response.headers_mut().insert(
|
response
|
||||||
"Retry-After",
|
.headers_mut()
|
||||||
retry_after.to_string().parse().unwrap(),
|
.insert("Retry-After", retry_after.to_string().parse().unwrap());
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
response
|
response
|
||||||
|
|
|
||||||
|
|
@ -144,7 +144,8 @@ pub async fn timing_middleware(request: Request, next: Next) -> Response {
|
||||||
|
|
||||||
// Update metrics
|
// Update metrics
|
||||||
metrics::counter!("http_requests_total", "method" => method.to_string(), "status" => status.as_u16().to_string()).increment(1);
|
metrics::counter!("http_requests_total", "method" => method.to_string(), "status" => status.as_u16().to_string()).increment(1);
|
||||||
metrics::histogram!("http_request_duration_seconds", "method" => method.to_string()).record(duration.as_secs_f64());
|
metrics::histogram!("http_request_duration_seconds", "method" => method.to_string())
|
||||||
|
.record(duration.as_secs_f64());
|
||||||
|
|
||||||
response
|
response
|
||||||
}
|
}
|
||||||
|
|
@ -169,7 +170,10 @@ impl RateLimiterState {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get or create a rate limiter for an IP.
|
/// Get or create a rate limiter for an IP.
|
||||||
pub async fn get_ip_limiter(&self, ip: &str) -> Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>> {
|
pub async fn get_ip_limiter(
|
||||||
|
&self,
|
||||||
|
ip: &str,
|
||||||
|
) -> Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>> {
|
||||||
{
|
{
|
||||||
let limiters = self.ip_limiters.read().await;
|
let limiters = self.ip_limiters.read().await;
|
||||||
if let Some(limiter) = limiters.get(ip) {
|
if let Some(limiter) = limiters.get(ip) {
|
||||||
|
|
@ -189,7 +193,11 @@ impl RateLimiterState {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get or create a rate limiter for an API key.
|
/// Get or create a rate limiter for an API key.
|
||||||
pub async fn get_key_limiter(&self, key_id: &str, tier: ApiKeyTier) -> Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>> {
|
pub async fn get_key_limiter(
|
||||||
|
&self,
|
||||||
|
key_id: &str,
|
||||||
|
tier: ApiKeyTier,
|
||||||
|
) -> Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>> {
|
||||||
{
|
{
|
||||||
let limiters = self.key_limiters.read().await;
|
let limiters = self.key_limiters.read().await;
|
||||||
if let Some(limiter) = limiters.get(key_id) {
|
if let Some(limiter) = limiters.get(key_id) {
|
||||||
|
|
@ -255,9 +263,7 @@ pub async fn rate_limit_middleware(
|
||||||
// Use a fixed retry time since we can't easily convert to quanta's instant
|
// Use a fixed retry time since we can't easily convert to quanta's instant
|
||||||
let retry_after = 60; // Default to 60 seconds
|
let retry_after = 60; // Default to 60 seconds
|
||||||
|
|
||||||
Err(ApiError::TooManyRequests {
|
Err(ApiError::TooManyRequests { retry_after })
|
||||||
retry_after,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -274,10 +280,7 @@ pub async fn auth_middleware(
|
||||||
}
|
}
|
||||||
|
|
||||||
/// API version middleware - validates version prefix.
|
/// API version middleware - validates version prefix.
|
||||||
pub async fn version_middleware(
|
pub async fn version_middleware(request: Request, next: Next) -> Result<Response, ApiError> {
|
||||||
request: Request,
|
|
||||||
next: Next,
|
|
||||||
) -> Result<Response, ApiError> {
|
|
||||||
let path = request.uri().path();
|
let path = request.uri().path();
|
||||||
|
|
||||||
// Skip version check for health, metrics, and docs
|
// Skip version check for health, metrics, and docs
|
||||||
|
|
@ -307,22 +310,13 @@ pub async fn security_headers_middleware(request: Request, next: Next) -> Respon
|
||||||
let headers = response.headers_mut();
|
let headers = response.headers_mut();
|
||||||
|
|
||||||
// Prevent XSS
|
// Prevent XSS
|
||||||
headers.insert(
|
headers.insert("X-Content-Type-Options", "nosniff".parse().unwrap());
|
||||||
"X-Content-Type-Options",
|
|
||||||
"nosniff".parse().unwrap(),
|
|
||||||
);
|
|
||||||
|
|
||||||
// Prevent clickjacking
|
// Prevent clickjacking
|
||||||
headers.insert(
|
headers.insert("X-Frame-Options", "DENY".parse().unwrap());
|
||||||
"X-Frame-Options",
|
|
||||||
"DENY".parse().unwrap(),
|
|
||||||
);
|
|
||||||
|
|
||||||
// Enable XSS filter
|
// Enable XSS filter
|
||||||
headers.insert(
|
headers.insert("X-XSS-Protection", "1; mode=block".parse().unwrap());
|
||||||
"X-XSS-Protection",
|
|
||||||
"1; mode=block".parse().unwrap(),
|
|
||||||
);
|
|
||||||
|
|
||||||
// Strict transport security (HTTPS)
|
// Strict transport security (HTTPS)
|
||||||
headers.insert(
|
headers.insert(
|
||||||
|
|
|
||||||
|
|
@ -6,11 +6,7 @@
|
||||||
//! - Contract analysis and validation
|
//! - Contract analysis and validation
|
||||||
//! - Security scanning
|
//! - Security scanning
|
||||||
|
|
||||||
use axum::{
|
use axum::{extract::State, routing::post, Json, Router};
|
||||||
extract::State,
|
|
||||||
routing::post,
|
|
||||||
Json, Router,
|
|
||||||
};
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
|
@ -44,7 +40,7 @@ pub fn router() -> Router<AppState> {
|
||||||
// Types
|
// Types
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
pub struct CompileRequest {
|
pub struct CompileRequest {
|
||||||
pub wasm: String, // base64 encoded WASM
|
pub wasm: String, // base64 encoded WASM
|
||||||
pub optimization_level: Option<String>, // none, basic, size, aggressive
|
pub optimization_level: Option<String>, // none, basic, size, aggressive
|
||||||
pub strip_debug: Option<bool>,
|
pub strip_debug: Option<bool>,
|
||||||
pub strip_names: Option<bool>,
|
pub strip_names: Option<bool>,
|
||||||
|
|
@ -216,16 +212,14 @@ async fn compile_contract(
|
||||||
abi: Some(ContractAbi {
|
abi: Some(ContractAbi {
|
||||||
name: "MyContract".to_string(),
|
name: "MyContract".to_string(),
|
||||||
version: "1.0.0".to_string(),
|
version: "1.0.0".to_string(),
|
||||||
functions: vec![
|
functions: vec![AbiFunction {
|
||||||
AbiFunction {
|
name: "init".to_string(),
|
||||||
name: "init".to_string(),
|
selector: "0x12345678".to_string(),
|
||||||
selector: "0x12345678".to_string(),
|
inputs: vec![],
|
||||||
inputs: vec![],
|
outputs: vec![],
|
||||||
outputs: vec![],
|
view: false,
|
||||||
view: false,
|
payable: false,
|
||||||
payable: false,
|
}],
|
||||||
},
|
|
||||||
],
|
|
||||||
events: vec![],
|
events: vec![],
|
||||||
errors: vec![],
|
errors: vec![],
|
||||||
}),
|
}),
|
||||||
|
|
@ -342,23 +336,19 @@ async fn analyze_contract(
|
||||||
imports: 100,
|
imports: 100,
|
||||||
total: 5000,
|
total: 5000,
|
||||||
},
|
},
|
||||||
functions: vec![
|
functions: vec![FunctionAnalysis {
|
||||||
FunctionAnalysis {
|
name: "init".to_string(),
|
||||||
name: "init".to_string(),
|
size: 500,
|
||||||
size: 500,
|
instruction_count: 50,
|
||||||
instruction_count: 50,
|
local_count: 3,
|
||||||
local_count: 3,
|
exported: true,
|
||||||
exported: true,
|
estimated_gas: 10000,
|
||||||
estimated_gas: 10000,
|
}],
|
||||||
},
|
imports: vec![ImportInfo {
|
||||||
],
|
module: "env".to_string(),
|
||||||
imports: vec![
|
name: "memory".to_string(),
|
||||||
ImportInfo {
|
kind: "memory".to_string(),
|
||||||
module: "env".to_string(),
|
}],
|
||||||
name: "memory".to_string(),
|
|
||||||
kind: "memory".to_string(),
|
|
||||||
},
|
|
||||||
],
|
|
||||||
gas_analysis: GasAnalysis {
|
gas_analysis: GasAnalysis {
|
||||||
deployment_gas: 100000,
|
deployment_gas: 100000,
|
||||||
memory_init_gas: 5000,
|
memory_init_gas: 5000,
|
||||||
|
|
@ -378,14 +368,12 @@ async fn security_scan(
|
||||||
|
|
||||||
let result = SecurityScanResult {
|
let result = SecurityScanResult {
|
||||||
score: 85,
|
score: 85,
|
||||||
issues: vec![
|
issues: vec![SecurityIssue {
|
||||||
SecurityIssue {
|
severity: "low".to_string(),
|
||||||
severity: "low".to_string(),
|
issue_type: "unbounded_loop".to_string(),
|
||||||
issue_type: "unbounded_loop".to_string(),
|
description: "Potential unbounded loop detected".to_string(),
|
||||||
description: "Potential unbounded loop detected".to_string(),
|
location: Some("function:process".to_string()),
|
||||||
location: Some("function:process".to_string()),
|
}],
|
||||||
},
|
|
||||||
],
|
|
||||||
recommendations: vec![
|
recommendations: vec![
|
||||||
"Add loop iteration limits".to_string(),
|
"Add loop iteration limits".to_string(),
|
||||||
"Consider using checked arithmetic".to_string(),
|
"Consider using checked arithmetic".to_string(),
|
||||||
|
|
|
||||||
|
|
@ -122,17 +122,15 @@ async fn list_markets(
|
||||||
) -> ApiResult<Json<ApiResponse<Vec<Market>>>> {
|
) -> ApiResult<Json<ApiResponse<Vec<Market>>>> {
|
||||||
require_permission(&auth, "read")?;
|
require_permission(&auth, "read")?;
|
||||||
|
|
||||||
let markets = vec![
|
let markets = vec![Market {
|
||||||
Market {
|
symbol: "ETH-USDC".to_string(),
|
||||||
symbol: "ETH-USDC".to_string(),
|
base_asset: "ETH".to_string(),
|
||||||
base_asset: "ETH".to_string(),
|
quote_asset: "USDC".to_string(),
|
||||||
quote_asset: "USDC".to_string(),
|
last_price: "3000.00".to_string(),
|
||||||
last_price: "3000.00".to_string(),
|
change_24h: "2.5".to_string(),
|
||||||
change_24h: "2.5".to_string(),
|
volume_24h: "10000000".to_string(),
|
||||||
volume_24h: "10000000".to_string(),
|
status: "active".to_string(),
|
||||||
status: "active".to_string(),
|
}];
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
Ok(Json(ApiResponse::success(markets)))
|
Ok(Json(ApiResponse::success(markets)))
|
||||||
}
|
}
|
||||||
|
|
@ -165,8 +163,14 @@ async fn get_orderbook(
|
||||||
require_permission(&auth, "read")?;
|
require_permission(&auth, "read")?;
|
||||||
|
|
||||||
let orderbook = Orderbook {
|
let orderbook = Orderbook {
|
||||||
bids: vec![OrderbookEntry { price: "2999.00".to_string(), quantity: "1.5".to_string() }],
|
bids: vec![OrderbookEntry {
|
||||||
asks: vec![OrderbookEntry { price: "3001.00".to_string(), quantity: "2.0".to_string() }],
|
price: "2999.00".to_string(),
|
||||||
|
quantity: "1.5".to_string(),
|
||||||
|
}],
|
||||||
|
asks: vec![OrderbookEntry {
|
||||||
|
price: "3001.00".to_string(),
|
||||||
|
quantity: "2.0".to_string(),
|
||||||
|
}],
|
||||||
spread: "2.00".to_string(),
|
spread: "2.00".to_string(),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -286,7 +290,9 @@ async fn place_perp_order(
|
||||||
Json(req): Json<serde_json::Value>,
|
Json(req): Json<serde_json::Value>,
|
||||||
) -> ApiResult<Json<ApiResponse<serde_json::Value>>> {
|
) -> ApiResult<Json<ApiResponse<serde_json::Value>>> {
|
||||||
require_permission(&auth, "write")?;
|
require_permission(&auth, "write")?;
|
||||||
Ok(Json(ApiResponse::success(serde_json::json!({"order_id": "perp_123"}))))
|
Ok(Json(ApiResponse::success(
|
||||||
|
serde_json::json!({"order_id": "perp_123"}),
|
||||||
|
)))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn list_pools(
|
async fn list_pools(
|
||||||
|
|
@ -295,18 +301,16 @@ async fn list_pools(
|
||||||
) -> ApiResult<Json<ApiResponse<Vec<Pool>>>> {
|
) -> ApiResult<Json<ApiResponse<Vec<Pool>>>> {
|
||||||
require_permission(&auth, "read")?;
|
require_permission(&auth, "read")?;
|
||||||
|
|
||||||
let pools = vec![
|
let pools = vec![Pool {
|
||||||
Pool {
|
pool_id: "ETH-USDC".to_string(),
|
||||||
pool_id: "ETH-USDC".to_string(),
|
name: "ETH/USDC".to_string(),
|
||||||
name: "ETH/USDC".to_string(),
|
token_a: "ETH".to_string(),
|
||||||
token_a: "ETH".to_string(),
|
token_b: "USDC".to_string(),
|
||||||
token_b: "USDC".to_string(),
|
reserve_a: "1000".to_string(),
|
||||||
reserve_a: "1000".to_string(),
|
reserve_b: "3000000".to_string(),
|
||||||
reserve_b: "3000000".to_string(),
|
tvl: "6000000".to_string(),
|
||||||
tvl: "6000000".to_string(),
|
apr: "15.5".to_string(),
|
||||||
apr: "15.5".to_string(),
|
}];
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
Ok(Json(ApiResponse::success(pools)))
|
Ok(Json(ApiResponse::success(pools)))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -128,16 +128,14 @@ async fn list_chains(
|
||||||
) -> ApiResult<Json<ApiResponse<Vec<Chain>>>> {
|
) -> ApiResult<Json<ApiResponse<Vec<Chain>>>> {
|
||||||
require_permission(&auth, "read")?;
|
require_permission(&auth, "read")?;
|
||||||
|
|
||||||
let chains = vec![
|
let chains = vec![Chain {
|
||||||
Chain {
|
chain_id: "cosmoshub-4".to_string(),
|
||||||
chain_id: "cosmoshub-4".to_string(),
|
name: "Cosmos Hub".to_string(),
|
||||||
name: "Cosmos Hub".to_string(),
|
status: "active".to_string(),
|
||||||
status: "active".to_string(),
|
rpc_endpoint: "https://rpc.cosmos.network".to_string(),
|
||||||
rpc_endpoint: "https://rpc.cosmos.network".to_string(),
|
latest_height: 18000000,
|
||||||
latest_height: 18000000,
|
active_channels: 50,
|
||||||
active_channels: 50,
|
}];
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
Ok(Json(ApiResponse::success(chains)))
|
Ok(Json(ApiResponse::success(chains)))
|
||||||
}
|
}
|
||||||
|
|
@ -280,15 +278,13 @@ async fn get_routes(
|
||||||
) -> ApiResult<Json<ApiResponse<Vec<TransferRoute>>>> {
|
) -> ApiResult<Json<ApiResponse<Vec<TransferRoute>>>> {
|
||||||
require_permission(&auth, "read")?;
|
require_permission(&auth, "read")?;
|
||||||
|
|
||||||
let routes = vec![
|
let routes = vec![TransferRoute {
|
||||||
TransferRoute {
|
source_chain: "cosmoshub-4".to_string(),
|
||||||
source_chain: "cosmoshub-4".to_string(),
|
dest_chain: "synor-mainnet".to_string(),
|
||||||
dest_chain: "synor-mainnet".to_string(),
|
channel_id: "channel-0".to_string(),
|
||||||
channel_id: "channel-0".to_string(),
|
estimated_time: "30s".to_string(),
|
||||||
estimated_time: "30s".to_string(),
|
fee: "0.001 ATOM".to_string(),
|
||||||
fee: "0.001 ATOM".to_string(),
|
}];
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
Ok(Json(ApiResponse::success(routes)))
|
Ok(Json(ApiResponse::success(routes)))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -303,13 +303,11 @@ async fn get_peers(
|
||||||
) -> ApiResult<Json<ApiResponse<Vec<serde_json::Value>>>> {
|
) -> ApiResult<Json<ApiResponse<Vec<serde_json::Value>>>> {
|
||||||
require_permission(&auth, "read")?;
|
require_permission(&auth, "read")?;
|
||||||
|
|
||||||
let peers = vec![
|
let peers = vec![serde_json::json!({
|
||||||
serde_json::json!({
|
"id": "peer1",
|
||||||
"id": "peer1",
|
"address": "192.168.1.1:16100",
|
||||||
"address": "192.168.1.1:16100",
|
"connected_since": 1705312200
|
||||||
"connected_since": 1705312200
|
})];
|
||||||
})
|
|
||||||
];
|
|
||||||
|
|
||||||
Ok(Json(ApiResponse::success(peers)))
|
Ok(Json(ApiResponse::success(peers)))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -247,14 +247,12 @@ async fn list_directory(
|
||||||
) -> ApiResult<Json<ApiResponse<Vec<DirectoryEntry>>>> {
|
) -> ApiResult<Json<ApiResponse<Vec<DirectoryEntry>>>> {
|
||||||
require_permission(&auth, "read")?;
|
require_permission(&auth, "read")?;
|
||||||
|
|
||||||
let entries = vec![
|
let entries = vec![DirectoryEntry {
|
||||||
DirectoryEntry {
|
name: "file1.txt".to_string(),
|
||||||
name: "file1.txt".to_string(),
|
cid: "bafyfile1...".to_string(),
|
||||||
cid: "bafyfile1...".to_string(),
|
size: 1024,
|
||||||
size: 1024,
|
is_directory: false,
|
||||||
is_directory: false,
|
}];
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
Ok(Json(ApiResponse::success(entries)))
|
Ok(Json(ApiResponse::success(entries)))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -409,7 +409,10 @@ async fn list_addresses(
|
||||||
];
|
];
|
||||||
|
|
||||||
let pagination_meta = pagination.to_meta(addresses.len() as u64);
|
let pagination_meta = pagination.to_meta(addresses.len() as u64);
|
||||||
Ok(Json(ApiResponse::success_paginated(addresses, pagination_meta)))
|
Ok(Json(ApiResponse::success_paginated(
|
||||||
|
addresses,
|
||||||
|
pagination_meta,
|
||||||
|
)))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate a stealth address.
|
/// Generate a stealth address.
|
||||||
|
|
@ -477,7 +480,9 @@ async fn get_balances(
|
||||||
require_permission(&auth, "read")?;
|
require_permission(&auth, "read")?;
|
||||||
|
|
||||||
if req.addresses.is_empty() {
|
if req.addresses.is_empty() {
|
||||||
return Err(ApiError::ValidationError("addresses cannot be empty".to_string()));
|
return Err(ApiError::ValidationError(
|
||||||
|
"addresses cannot be empty".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.addresses.len() > 100 {
|
if req.addresses.len() > 100 {
|
||||||
|
|
@ -585,12 +590,15 @@ async fn send_transaction(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate amount
|
// Validate amount
|
||||||
let amount: f64 = req.amount.parse().map_err(|_| {
|
let amount: f64 = req
|
||||||
ApiError::ValidationError("Invalid amount format".to_string())
|
.amount
|
||||||
})?;
|
.parse()
|
||||||
|
.map_err(|_| ApiError::ValidationError("Invalid amount format".to_string()))?;
|
||||||
|
|
||||||
if amount <= 0.0 {
|
if amount <= 0.0 {
|
||||||
return Err(ApiError::ValidationError("Amount must be positive".to_string()));
|
return Err(ApiError::ValidationError(
|
||||||
|
"Amount must be positive".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
// In production, build, sign, and broadcast the transaction
|
// In production, build, sign, and broadcast the transaction
|
||||||
|
|
|
||||||
|
|
@ -139,16 +139,14 @@ async fn list_circuits(
|
||||||
) -> ApiResult<Json<ApiResponse<Vec<Circuit>>>> {
|
) -> ApiResult<Json<ApiResponse<Vec<Circuit>>>> {
|
||||||
require_permission(&auth, "read")?;
|
require_permission(&auth, "read")?;
|
||||||
|
|
||||||
let circuits = vec![
|
let circuits = vec![Circuit {
|
||||||
Circuit {
|
circuit_id: "multiplier-v1".to_string(),
|
||||||
circuit_id: "multiplier-v1".to_string(),
|
name: "Multiplier".to_string(),
|
||||||
name: "Multiplier".to_string(),
|
constraints: 1,
|
||||||
constraints: 1,
|
public_inputs: 1,
|
||||||
public_inputs: 1,
|
private_inputs: 2,
|
||||||
private_inputs: 2,
|
outputs: 1,
|
||||||
outputs: 1,
|
}];
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
let meta = pagination.to_meta(circuits.len() as u64);
|
let meta = pagination.to_meta(circuits.len() as u64);
|
||||||
Ok(Json(ApiResponse::success_paginated(circuits, meta)))
|
Ok(Json(ApiResponse::success_paginated(circuits, meta)))
|
||||||
|
|
|
||||||
|
|
@ -17,15 +17,9 @@ use axum::{
|
||||||
Router,
|
Router,
|
||||||
};
|
};
|
||||||
use std::{net::SocketAddr, sync::Arc};
|
use std::{net::SocketAddr, sync::Arc};
|
||||||
use tokio::{
|
use tokio::{net::TcpListener, signal, sync::oneshot};
|
||||||
net::TcpListener,
|
|
||||||
signal,
|
|
||||||
sync::oneshot,
|
|
||||||
};
|
|
||||||
use tower_http::{
|
use tower_http::{
|
||||||
compression::CompressionLayer,
|
compression::CompressionLayer, limit::RequestBodyLimitLayer, timeout::TimeoutLayer,
|
||||||
limit::RequestBodyLimitLayer,
|
|
||||||
timeout::TimeoutLayer,
|
|
||||||
trace::TraceLayer,
|
trace::TraceLayer,
|
||||||
};
|
};
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
|
||||||
|
|
@ -164,7 +164,8 @@ impl VersionRegistry {
|
||||||
|
|
||||||
/// Register a version.
|
/// Register a version.
|
||||||
pub fn register(&mut self, info: VersionInfo) {
|
pub fn register(&mut self, info: VersionInfo) {
|
||||||
self.versions.insert((info.version.major, info.version.minor), info);
|
self.versions
|
||||||
|
.insert((info.version.major, info.version.minor), info);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get version info.
|
/// Get version info.
|
||||||
|
|
@ -330,10 +331,7 @@ pub async fn version_middleware(req: Request, next: Next) -> Response {
|
||||||
// Add deprecation headers if needed
|
// Add deprecation headers if needed
|
||||||
if let Some(info) = registry.get(&extracted.version) {
|
if let Some(info) = registry.get(&extracted.version) {
|
||||||
if info.is_deprecated {
|
if info.is_deprecated {
|
||||||
headers.insert(
|
headers.insert(X_API_DEPRECATED.clone(), HeaderValue::from_static("true"));
|
||||||
X_API_DEPRECATED.clone(),
|
|
||||||
HeaderValue::from_static("true"),
|
|
||||||
);
|
|
||||||
|
|
||||||
if let Some(deprecated_at) = &info.deprecated_at {
|
if let Some(deprecated_at) = &info.deprecated_at {
|
||||||
if let Ok(v) = HeaderValue::from_str(&deprecated_at.to_rfc3339()) {
|
if let Ok(v) = HeaderValue::from_str(&deprecated_at.to_rfc3339()) {
|
||||||
|
|
@ -427,8 +425,8 @@ impl VersionsResponse {
|
||||||
// Routes
|
// Routes
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
|
|
||||||
use axum::{routing::get, Json, Router};
|
|
||||||
use crate::routes::AppState;
|
use crate::routes::AppState;
|
||||||
|
use axum::{routing::get, Json, Router};
|
||||||
|
|
||||||
/// Build version routes.
|
/// Build version routes.
|
||||||
pub fn router() -> Router<AppState> {
|
pub fn router() -> Router<AppState> {
|
||||||
|
|
|
||||||
|
|
@ -593,11 +593,7 @@ async fn ws_blocks_handler(
|
||||||
ws.on_upgrade(move |socket| handle_blocks_socket(socket, state, auth))
|
ws.on_upgrade(move |socket| handle_blocks_socket(socket, state, auth))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_blocks_socket(
|
async fn handle_blocks_socket(socket: WebSocket, state: AppState, _auth: Option<AuthContext>) {
|
||||||
socket: WebSocket,
|
|
||||||
state: AppState,
|
|
||||||
_auth: Option<AuthContext>,
|
|
||||||
) {
|
|
||||||
let ws_state = &state.websocket;
|
let ws_state = &state.websocket;
|
||||||
ws_state.broadcaster.add_connection().await;
|
ws_state.broadcaster.add_connection().await;
|
||||||
|
|
||||||
|
|
@ -834,11 +830,7 @@ async fn ws_markets_handler(
|
||||||
ws.on_upgrade(move |socket| handle_markets_socket(socket, state, auth))
|
ws.on_upgrade(move |socket| handle_markets_socket(socket, state, auth))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_markets_socket(
|
async fn handle_markets_socket(socket: WebSocket, state: AppState, _auth: Option<AuthContext>) {
|
||||||
socket: WebSocket,
|
|
||||||
state: AppState,
|
|
||||||
_auth: Option<AuthContext>,
|
|
||||||
) {
|
|
||||||
let ws_state = &state.websocket;
|
let ws_state = &state.websocket;
|
||||||
ws_state.broadcaster.add_connection().await;
|
ws_state.broadcaster.add_connection().await;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,8 +7,8 @@
|
||||||
//! hosting-gateway --domain synor.cc --storage-url http://localhost:8180
|
//! hosting-gateway --domain synor.cc --storage-url http://localhost:8180
|
||||||
//! hosting-gateway --config /path/to/config.toml
|
//! hosting-gateway --config /path/to/config.toml
|
||||||
|
|
||||||
use synor_hosting::{HostingGateway, GatewayConfig};
|
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
|
use synor_hosting::{GatewayConfig, HostingGateway};
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
|
@ -19,28 +19,32 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
// Parse command line arguments
|
// Parse command line arguments
|
||||||
let args: Vec<String> = std::env::args().collect();
|
let args: Vec<String> = std::env::args().collect();
|
||||||
|
|
||||||
let listen_addr: SocketAddr = args.iter()
|
let listen_addr: SocketAddr = args
|
||||||
|
.iter()
|
||||||
.position(|a| a == "--listen" || a == "-l")
|
.position(|a| a == "--listen" || a == "-l")
|
||||||
.and_then(|i| args.get(i + 1))
|
.and_then(|i| args.get(i + 1))
|
||||||
.and_then(|s| s.parse().ok())
|
.and_then(|s| s.parse().ok())
|
||||||
.or_else(|| std::env::var("LISTEN_ADDR").ok()?.parse().ok())
|
.or_else(|| std::env::var("LISTEN_ADDR").ok()?.parse().ok())
|
||||||
.unwrap_or_else(|| "0.0.0.0:8080".parse().unwrap());
|
.unwrap_or_else(|| "0.0.0.0:8080".parse().unwrap());
|
||||||
|
|
||||||
let hosting_domain = args.iter()
|
let hosting_domain = args
|
||||||
|
.iter()
|
||||||
.position(|a| a == "--domain" || a == "-d")
|
.position(|a| a == "--domain" || a == "-d")
|
||||||
.and_then(|i| args.get(i + 1))
|
.and_then(|i| args.get(i + 1))
|
||||||
.cloned()
|
.cloned()
|
||||||
.or_else(|| std::env::var("HOSTING_DOMAIN").ok())
|
.or_else(|| std::env::var("HOSTING_DOMAIN").ok())
|
||||||
.unwrap_or_else(|| "synor.cc".to_string());
|
.unwrap_or_else(|| "synor.cc".to_string());
|
||||||
|
|
||||||
let storage_url = args.iter()
|
let storage_url = args
|
||||||
|
.iter()
|
||||||
.position(|a| a == "--storage-url" || a == "-s")
|
.position(|a| a == "--storage-url" || a == "-s")
|
||||||
.and_then(|i| args.get(i + 1))
|
.and_then(|i| args.get(i + 1))
|
||||||
.cloned()
|
.cloned()
|
||||||
.or_else(|| std::env::var("STORAGE_GATEWAY_URL").ok())
|
.or_else(|| std::env::var("STORAGE_GATEWAY_URL").ok())
|
||||||
.unwrap_or_else(|| "http://localhost:8180".to_string());
|
.unwrap_or_else(|| "http://localhost:8180".to_string());
|
||||||
|
|
||||||
let rate_limit: u32 = args.iter()
|
let rate_limit: u32 = args
|
||||||
|
.iter()
|
||||||
.position(|a| a == "--rate-limit")
|
.position(|a| a == "--rate-limit")
|
||||||
.and_then(|i| args.get(i + 1))
|
.and_then(|i| args.get(i + 1))
|
||||||
.and_then(|s| s.parse().ok())
|
.and_then(|s| s.parse().ok())
|
||||||
|
|
|
||||||
|
|
@ -233,11 +233,7 @@ impl EdgeCompute {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Run AI inference at the edge.
|
/// Run AI inference at the edge.
|
||||||
pub async fn inference(
|
pub async fn inference(&self, _model: &str, _input: &[u8]) -> Result<Vec<u8>, EdgeError> {
|
||||||
&self,
|
|
||||||
_model: &str,
|
|
||||||
_input: &[u8],
|
|
||||||
) -> Result<Vec<u8>, EdgeError> {
|
|
||||||
if !self.enabled {
|
if !self.enabled {
|
||||||
return Err(EdgeError::NotEnabled);
|
return Err(EdgeError::NotEnabled);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue