synor/sdk/rust/src/tests.rs
Gulshan Yadav e2a3b66123 test(sdk): add comprehensive unit tests for all SDKs
Adds unit tests covering tensor operations, type enums, client
functionality, and serialization across the 12 SDK implementations,
including:

- JavaScript (Vitest): tensor, types, client tests
- Python (pytest): tensor, types, client tests
- Go: standard library tests with httptest
- Flutter (flutter_test): tensor, types tests
- Java (JUnit 5): tensor, types tests
- Rust: embedded module tests
- Ruby (minitest): tensor, types tests
- C# (xUnit): tensor, types tests

Tests cover:
- Tensor creation (zeros, ones, random, randn, eye, arange, linspace)
- Tensor operations (reshape, transpose, indexing)
- Reductions (sum, mean, std, min, max)
- Activations (relu, sigmoid, softmax)
- Serialization/deserialization
- Type enums and configuration
- Client request building
- Error handling
2026-01-11 17:56:11 +05:30

334 lines
9.5 KiB
Rust

//! Comprehensive tests for the Synor Compute Rust SDK.
#[cfg(test)]
mod tensor_tests {
    // Unit tests for `crate::tensor::Tensor`: construction helpers, shape
    // manipulation, reductions, activations, and dtype tagging.
    use crate::tensor::Tensor;
    use crate::types::Precision;

    /// `Tensor::new` records the given shape and derives size and rank from it.
    #[test]
    fn test_tensor_creation() {
        let t = Tensor::new(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.size(), 6);
        assert_eq!(t.ndim(), 2);
    }

    /// `Tensor::zeros` produces the requested shape filled with 0.0.
    #[test]
    fn test_tensor_zeros() {
        let t = Tensor::zeros(&[3, 3]);
        assert_eq!(t.shape(), &[3, 3]);
        assert!(t.data().iter().all(|&x| x == 0.0));
    }

    /// `Tensor::ones` produces the requested shape filled with 1.0.
    #[test]
    fn test_tensor_ones() {
        let t = Tensor::ones(&[2, 2]);
        assert_eq!(t.shape(), &[2, 2]);
        assert!(t.data().iter().all(|&x| x == 1.0));
    }

    /// `Tensor::rand` samples every element from the half-open range [0, 1).
    #[test]
    fn test_tensor_rand() {
        let t = Tensor::rand(&[10, 10]);
        assert_eq!(t.size(), 100);
        assert!(t.data().iter().all(|&x| x >= 0.0 && x < 1.0));
    }

    /// `Tensor::randn` should look like a standard normal sample.
    ///
    /// This is a statistical check on unseeded data. With n = 1000 draws from
    /// N(0, 1), the standard error of the mean is about 0.032, so the 0.2 bound
    /// is roughly 6 sigma. The standard error of the std is about 0.022, so the
    /// ±0.2 band is roughly 9 sigma. Spurious failures should be vanishingly rare.
    #[test]
    fn test_tensor_randn() {
        let t = Tensor::randn(&[1000]);
        let mean = t.mean();
        let std = t.std();
        assert!(mean.abs() < 0.2, "Mean should be close to 0");
        assert!(std > 0.8 && std < 1.2, "Std should be close to 1");
    }

    /// `Tensor::eye(n)` is an n×n identity matrix.
    #[test]
    fn test_tensor_eye() {
        let t = Tensor::eye(3);
        assert_eq!(t.shape(), &[3, 3]);
        assert_eq!(t.get(&[0, 0]), 1.0);
        assert_eq!(t.get(&[1, 1]), 1.0);
        assert_eq!(t.get(&[2, 2]), 1.0);
        assert_eq!(t.get(&[0, 1]), 0.0);
    }

    /// `Tensor::arange(start, stop, step)` excludes `stop`.
    #[test]
    fn test_tensor_arange() {
        let t = Tensor::arange(0.0, 5.0, 1.0);
        assert_eq!(t.shape(), &[5]);
        assert_eq!(t.data()[0], 0.0);
        assert_eq!(t.data()[4], 4.0);
    }

    /// `Tensor::linspace(start, stop, n)` includes both endpoints.
    #[test]
    fn test_tensor_linspace() {
        let t = Tensor::linspace(0.0, 10.0, 11);
        assert_eq!(t.shape(), &[11]);
        assert!((t.data()[0] - 0.0).abs() < 1e-10);
        assert!((t.data()[10] - 10.0).abs() < 1e-10);
    }

    /// Reshaping keeps the element count and the flat element order.
    #[test]
    fn test_tensor_reshape() {
        let t = Tensor::new(&[6], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let reshaped = t.reshape(&[2, 3]);
        assert_eq!(reshaped.shape(), &[2, 3]);
        assert_eq!(reshaped.size(), 6);
        // Assumes row-major layout (as `Tensor::new(&[2, 3], ..)` elsewhere implies):
        // row 0 holds 1..=3 and row 1 holds 4..=6.
        assert_eq!(reshaped.get(&[0, 2]), 3.0);
        assert_eq!(reshaped.get(&[1, 0]), 4.0);
    }

    /// Reshaping to a shape with a different element count (4 vs 6) panics.
    #[test]
    #[should_panic]
    fn test_tensor_invalid_reshape() {
        let t = Tensor::zeros(&[4]);
        // The result is intentionally discarded; only the panic matters.
        let _ = t.reshape(&[2, 3]);
    }

    /// Transposing a 2×3 matrix yields a 3×2 matrix with `out[i][j] == in[j][i]`.
    #[test]
    fn test_tensor_transpose() {
        let t = Tensor::new(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let transposed = t.transpose();
        assert_eq!(transposed.shape(), &[3, 2]);
        // Check the values were actually permuted, not just the shape relabelled.
        assert_eq!(transposed.get(&[0, 0]), 1.0);
        assert_eq!(transposed.get(&[0, 1]), 4.0);
        assert_eq!(transposed.get(&[1, 1]), 5.0);
        assert_eq!(transposed.get(&[2, 0]), 3.0);
    }

    /// `mean` is the arithmetic mean of all elements.
    #[test]
    fn test_tensor_mean() {
        let t = Tensor::new(&[4], vec![1.0, 2.0, 3.0, 4.0]);
        assert!((t.mean() - 2.5).abs() < 1e-10);
    }

    /// `sum` adds every element.
    #[test]
    fn test_tensor_sum() {
        let t = Tensor::new(&[4], vec![1.0, 2.0, 3.0, 4.0]);
        assert!((t.sum() - 10.0).abs() < 1e-10);
    }

    /// `std` is the population standard deviation (divides by n).
    ///
    /// For [1, 2, 3, 4] that is sqrt(1.25) ≈ 1.118. The sample std (divide by
    /// n - 1) would be ≈ 1.291, which this tolerance rejects.
    #[test]
    fn test_tensor_std() {
        let t = Tensor::new(&[4], vec![1.0, 2.0, 3.0, 4.0]);
        assert!((t.std() - 1.118).abs() < 0.01);
    }

    /// `min` / `max` find the extreme elements regardless of position.
    #[test]
    fn test_tensor_min_max() {
        let t = Tensor::new(&[4], vec![3.0, 1.0, 4.0, 2.0]);
        assert_eq!(t.min(), 1.0);
        assert_eq!(t.max(), 4.0);
    }

    /// `relu` clamps negatives to zero and leaves non-negatives unchanged.
    #[test]
    fn test_tensor_relu() {
        let t = Tensor::new(&[5], vec![-2.0, -1.0, 0.0, 1.0, 2.0]);
        let result = t.relu();
        assert_eq!(result.data(), &[0.0, 0.0, 0.0, 1.0, 2.0]);
    }

    /// `sigmoid(0) == 0.5`.
    #[test]
    fn test_tensor_sigmoid() {
        let t = Tensor::new(&[1], vec![0.0]);
        let result = t.sigmoid();
        assert!((result.data()[0] - 0.5).abs() < 1e-10);
    }

    /// `softmax` keeps the shape, sums to 1, and preserves input ordering.
    #[test]
    fn test_tensor_softmax() {
        let t = Tensor::new(&[3], vec![1.0, 2.0, 3.0]);
        let result = t.softmax();
        assert_eq!(result.shape(), &[3]);
        assert!((result.sum() - 1.0).abs() < 1e-10);
        // softmax is monotonic: larger logits map to larger probabilities.
        assert!(result.data()[1] > result.data()[0]);
        assert!(result.data()[2] > result.data()[1]);
    }

    /// `with_dtype` tags the tensor with the requested precision.
    #[test]
    fn test_tensor_dtype() {
        let t = Tensor::zeros(&[2, 2]).with_dtype(Precision::FP16);
        assert_eq!(t.dtype(), Precision::FP16);
    }
}
#[cfg(test)]
mod types_tests {
    // Checks for the enums, `Default` impls, and config/option structs
    // exported from `crate::types`.
    use crate::types::*;

    /// Each `ProcessorType` variant maps to its lowercase string name.
    #[test]
    fn test_processor_type_values() {
        assert_eq!("cpu", ProcessorType::Cpu.as_str());
        assert_eq!("gpu", ProcessorType::Gpu.as_str());
        assert_eq!("tpu", ProcessorType::Tpu.as_str());
        assert_eq!("npu", ProcessorType::Npu.as_str());
        assert_eq!("lpu", ProcessorType::Lpu.as_str());
        assert_eq!("auto", ProcessorType::Auto.as_str());
    }

    /// Each `Precision` variant maps to its lowercase string name.
    #[test]
    fn test_precision_values() {
        assert_eq!("fp64", Precision::FP64.as_str());
        assert_eq!("fp32", Precision::FP32.as_str());
        assert_eq!("fp16", Precision::FP16.as_str());
        assert_eq!("bf16", Precision::BF16.as_str());
        assert_eq!("int8", Precision::INT8.as_str());
        assert_eq!("int4", Precision::INT4.as_str());
    }

    /// `Priority` defaults to `Normal`.
    #[test]
    fn test_priority_default() {
        let default_priority: Priority = Default::default();
        assert_eq!(Priority::Normal, default_priority);
    }

    /// `JobStatus` defaults to `Pending`.
    #[test]
    fn test_job_status_default() {
        let default_status: JobStatus = Default::default();
        assert_eq!(JobStatus::Pending, default_status);
    }

    /// `Config::new` stores the key and fills in the documented defaults.
    #[test]
    fn test_config_builder() {
        let cfg = Config::new("test-api-key");

        assert_eq!(cfg.api_key, "test-api-key");
        assert_eq!(cfg.base_url, "https://api.synor.io/compute/v1");
        assert_eq!(30, cfg.timeout_secs);
        assert!(!cfg.debug);
    }

    /// Each builder setter overrides exactly the field it names.
    #[test]
    fn test_config_custom() {
        let cfg = Config::new("test-key");
        let cfg = cfg.base_url("https://custom.api.com");
        let cfg = cfg.default_processor(ProcessorType::Gpu);
        let cfg = cfg.default_precision(Precision::FP16);
        let cfg = cfg.timeout_secs(60);
        let cfg = cfg.debug(true);

        assert_eq!(cfg.base_url, "https://custom.api.com");
        assert_eq!(ProcessorType::Gpu, cfg.default_processor);
        assert_eq!(Precision::FP16, cfg.default_precision);
        assert_eq!(60, cfg.timeout_secs);
        assert!(cfg.debug);
    }

    /// Matrix-multiply options default to FP32, automatic processor, normal priority.
    #[test]
    fn test_matmul_options_default() {
        let opts = MatMulOptions::default();
        assert_eq!(Precision::FP32, opts.precision);
        assert_eq!(ProcessorType::Auto, opts.processor);
        assert_eq!(Priority::Normal, opts.priority);
    }

    /// 2-D convolution options default to unit stride and no padding.
    #[test]
    fn test_conv2d_options_default() {
        let opts = Conv2dOptions::default();
        assert_eq!((1, 1), opts.stride);
        assert_eq!((0, 0), opts.padding);
    }

    /// Attention options default to 8 heads, flash enabled, FP16.
    #[test]
    fn test_attention_options_default() {
        let opts = AttentionOptions::default();
        assert_eq!(8, opts.num_heads);
        assert!(opts.flash);
        assert_eq!(Precision::FP16, opts.precision);
    }

    /// Inference sampling defaults: 256 tokens, temperature 0.7, top-p 0.9, top-k 50.
    #[test]
    fn test_inference_options_default() {
        let opts = InferenceOptions::default();
        assert_eq!(256, opts.max_tokens);
        assert!((opts.temperature - 0.7).abs() < 0.001);
        assert!((opts.top_p - 0.9).abs() < 0.001);
        assert_eq!(50, opts.top_k);
    }

    /// A `Completed` job reports success and not failure.
    #[test]
    fn test_job_result_success() {
        let completed = JobResult::<String> {
            job_id: Some("job-123".to_string()),
            status: JobStatus::Completed,
            result: Some("output".to_string()),
            error: None,
            execution_time_ms: Some(100),
            processor: Some(ProcessorType::Gpu),
            cost: Some(0.01),
        };

        assert!(completed.is_success());
        assert!(!completed.is_failed());
    }

    /// A `Failed` job reports failure and not success.
    #[test]
    fn test_job_result_failure() {
        let failed = JobResult::<String> {
            job_id: Some("job-456".to_string()),
            status: JobStatus::Failed,
            result: None,
            error: Some("Error message".to_string()),
            execution_time_ms: None,
            processor: None,
            cost: None,
        };

        assert!(!failed.is_success());
        assert!(failed.is_failed());
    }

    /// Parameter counts in the billions are rendered with a `B` suffix.
    #[test]
    fn test_model_info_formatted_parameters() {
        let info = ModelInfo {
            id: "test".to_string(),
            name: "Test Model".to_string(),
            description: None,
            category: "llm".to_string(),
            parameters: Some(70_000_000_000),
            context_length: Some(8192),
            format: None,
            recommended_processor: None,
            license: None,
            cid: None,
        };

        let expected = "70B";
        assert_eq!(info.formatted_parameters(), expected);
    }

    /// Parameter counts in the millions are rendered with an `M` suffix.
    #[test]
    fn test_model_info_formatted_parameters_millions() {
        let info = ModelInfo {
            id: "test".to_string(),
            name: "Test Model".to_string(),
            description: None,
            category: "embedding".to_string(),
            parameters: Some(350_000_000),
            context_length: None,
            format: None,
            recommended_processor: None,
            license: None,
            cid: None,
        };

        let expected = "350M";
        assert_eq!(info.formatted_parameters(), expected);
    }
}
#[cfg(test)]
mod error_tests {
    // Checks the `Display` output of `crate::error::Error` variants.
    use crate::error::Error;

    /// `Error::ClientClosed` renders a fixed message.
    #[test]
    fn test_client_closed_error() {
        let err = Error::ClientClosed;
        assert_eq!(err.to_string(), "Client has been closed");
    }

    /// `Error::Api` includes both the status code and the server-supplied message.
    #[test]
    fn test_api_error() {
        let err = Error::Api {
            status_code: 401,
            message: "Invalid API key".to_string(),
        };
        // Render once and reuse, so a failure shows the actual text.
        let rendered = err.to_string();
        assert!(
            rendered.contains("401"),
            "status code missing from {:?}",
            rendered
        );
        assert!(
            rendered.contains("Invalid API key"),
            "message missing from {:?}",
            rendered
        );
    }

    /// `Error::InvalidArgument` includes the offending argument description.
    #[test]
    fn test_invalid_argument_error() {
        let err = Error::InvalidArgument("Bad parameter".to_string());
        let rendered = err.to_string();
        assert!(
            rendered.contains("Bad parameter"),
            "argument description missing from {:?}",
            rendered
        );
    }
}