Adds unit tests covering tensor operations, type enums, client functionality, and serialization for all 12 SDK implementations:

- JavaScript (Vitest): tensor, types, client tests
- Python (pytest): tensor, types, client tests
- Go: standard library tests with httptest
- Flutter (flutter_test): tensor, types tests
- Java (JUnit 5): tensor, types tests
- Rust: embedded module tests
- Ruby (minitest): tensor, types tests
- C# (xUnit): tensor, types tests

Tests cover:

- Tensor creation (zeros, ones, random, randn, eye, arange, linspace)
- Tensor operations (reshape, transpose, indexing)
- Reductions (sum, mean, std, min, max)
- Activations (relu, sigmoid, softmax)
- Serialization/deserialization
- Type enums and configuration
- Client request building
- Error handling
416 lines
11 KiB
Go
416 lines
11 KiB
Go
package synor
|
|
|
|
import (
	"context"
	"encoding/json"
	"errors"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"
)
|
|
|
|
// TestClientCreation tests client initialization.
|
|
func TestClientCreation(t *testing.T) {
|
|
t.Run("creates client with API key", func(t *testing.T) {
|
|
client := NewClient("test-api-key")
|
|
if client == nil {
|
|
t.Fatal("expected non-nil client")
|
|
}
|
|
})
|
|
|
|
t.Run("creates client with config", func(t *testing.T) {
|
|
config := Config{
|
|
APIKey: "test-api-key",
|
|
Endpoint: "https://custom.api.com",
|
|
Strategy: Cost,
|
|
Precision: FP16,
|
|
Timeout: 60 * time.Second,
|
|
}
|
|
client := NewClientWithConfig(config)
|
|
if client == nil {
|
|
t.Fatal("expected non-nil client")
|
|
}
|
|
if client.config.Endpoint != "https://custom.api.com" {
|
|
t.Errorf("expected custom endpoint, got %s", client.config.Endpoint)
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestDefaultConfig tests default configuration.
|
|
func TestDefaultConfig(t *testing.T) {
|
|
config := DefaultConfig("my-key")
|
|
|
|
if config.APIKey != "my-key" {
|
|
t.Errorf("expected API key 'my-key', got %s", config.APIKey)
|
|
}
|
|
if config.Endpoint != DefaultEndpoint {
|
|
t.Errorf("expected default endpoint, got %s", config.Endpoint)
|
|
}
|
|
if config.Strategy != Balanced {
|
|
t.Errorf("expected Balanced strategy, got %s", config.Strategy)
|
|
}
|
|
if config.Precision != FP32 {
|
|
t.Errorf("expected FP32 precision, got %s", config.Precision)
|
|
}
|
|
if config.Timeout != 30*time.Second {
|
|
t.Errorf("expected 30s timeout, got %v", config.Timeout)
|
|
}
|
|
}
|
|
|
|
// TestTensorCreation tests tensor creation.
|
|
func TestTensorCreation(t *testing.T) {
|
|
t.Run("creates tensor from data", func(t *testing.T) {
|
|
data := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}
|
|
shape := []int{2, 3}
|
|
tensor := NewTensor(data, shape, FP32)
|
|
|
|
if tensor == nil {
|
|
t.Fatal("expected non-nil tensor")
|
|
}
|
|
if len(tensor.Data) != 6 {
|
|
t.Errorf("expected 6 elements, got %d", len(tensor.Data))
|
|
}
|
|
if len(tensor.Shape) != 2 {
|
|
t.Errorf("expected 2 dimensions, got %d", len(tensor.Shape))
|
|
}
|
|
if tensor.DType != FP32 {
|
|
t.Errorf("expected FP32 dtype, got %s", tensor.DType)
|
|
}
|
|
})
|
|
|
|
t.Run("creates zeros tensor", func(t *testing.T) {
|
|
tensor := Zeros([]int{3, 4}, FP32)
|
|
|
|
if tensor == nil {
|
|
t.Fatal("expected non-nil tensor")
|
|
}
|
|
if len(tensor.Data) != 12 {
|
|
t.Errorf("expected 12 elements, got %d", len(tensor.Data))
|
|
}
|
|
for i, v := range tensor.Data {
|
|
if v != 0.0 {
|
|
t.Errorf("expected 0 at index %d, got %f", i, v)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestTensorSerialize tests tensor serialization.
|
|
func TestTensorSerialize(t *testing.T) {
|
|
data := []float32{1.0, 2.0, 3.0, 4.0}
|
|
shape := []int{2, 2}
|
|
tensor := NewTensor(data, shape, FP32)
|
|
|
|
serialized := tensor.Serialize()
|
|
|
|
if _, ok := serialized["data"]; !ok {
|
|
t.Error("expected 'data' field in serialized tensor")
|
|
}
|
|
if _, ok := serialized["shape"]; !ok {
|
|
t.Error("expected 'shape' field in serialized tensor")
|
|
}
|
|
if _, ok := serialized["dtype"]; !ok {
|
|
t.Error("expected 'dtype' field in serialized tensor")
|
|
}
|
|
if serialized["dtype"] != FP32 {
|
|
t.Errorf("expected FP32 dtype, got %v", serialized["dtype"])
|
|
}
|
|
}
|
|
|
|
// TestProcessorTypes tests processor type constants.
|
|
func TestProcessorTypes(t *testing.T) {
|
|
types := []ProcessorType{CPU, GPU, TPU, NPU, LPU, FPGA, WASM, WebGPU}
|
|
expected := []string{"cpu", "gpu", "tpu", "npu", "lpu", "fpga", "wasm", "webgpu"}
|
|
|
|
for i, pt := range types {
|
|
if string(pt) != expected[i] {
|
|
t.Errorf("expected %s, got %s", expected[i], pt)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestPrecisionTypes tests precision type constants.
|
|
func TestPrecisionTypes(t *testing.T) {
|
|
precisions := []Precision{FP64, FP32, FP16, BF16, INT8, INT4}
|
|
expected := []string{"fp64", "fp32", "fp16", "bf16", "int8", "int4"}
|
|
|
|
for i, p := range precisions {
|
|
if string(p) != expected[i] {
|
|
t.Errorf("expected %s, got %s", expected[i], p)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestJobStatus tests job status constants.
|
|
func TestJobStatus(t *testing.T) {
|
|
statuses := []JobStatus{Pending, Queued, Running, Completed, Failed, Cancelled}
|
|
expected := []string{"pending", "queued", "running", "completed", "failed", "cancelled"}
|
|
|
|
for i, s := range statuses {
|
|
if string(s) != expected[i] {
|
|
t.Errorf("expected %s, got %s", expected[i], s)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestBalancingStrategies tests balancing strategy constants.
|
|
func TestBalancingStrategies(t *testing.T) {
|
|
strategies := []BalancingStrategy{Speed, Cost, Energy, Latency, Balanced}
|
|
expected := []string{"speed", "cost", "energy", "latency", "balanced"}
|
|
|
|
for i, s := range strategies {
|
|
if string(s) != expected[i] {
|
|
t.Errorf("expected %s, got %s", expected[i], s)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestSubmitJob tests job submission with mock server.
|
|
func TestSubmitJob(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path == "/jobs" && r.Method == "POST" {
|
|
w.WriteHeader(http.StatusOK)
|
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
"job_id": "job-123",
|
|
})
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusNotFound)
|
|
}))
|
|
defer server.Close()
|
|
|
|
config := Config{
|
|
APIKey: "test-key",
|
|
Endpoint: server.URL,
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
client := NewClientWithConfig(config)
|
|
|
|
ctx := context.Background()
|
|
job, err := client.SubmitJob(ctx, "matmul", map[string]interface{}{
|
|
"a": map[string]interface{}{"data": "base64", "shape": []int{2, 2}},
|
|
"b": map[string]interface{}{"data": "base64", "shape": []int{2, 2}},
|
|
})
|
|
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if job.ID != "job-123" {
|
|
t.Errorf("expected job ID 'job-123', got %s", job.ID)
|
|
}
|
|
}
|
|
|
|
// TestGetJobStatus tests getting job status with mock server.
|
|
func TestGetJobStatus(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path == "/jobs/job-123" && r.Method == "GET" {
|
|
w.WriteHeader(http.StatusOK)
|
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
"job_id": "job-123",
|
|
"status": "completed",
|
|
"data": map[string]interface{}{"result": "success"},
|
|
})
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusNotFound)
|
|
}))
|
|
defer server.Close()
|
|
|
|
config := Config{
|
|
APIKey: "test-key",
|
|
Endpoint: server.URL,
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
client := NewClientWithConfig(config)
|
|
|
|
ctx := context.Background()
|
|
result, err := client.GetJobStatus(ctx, "job-123")
|
|
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if result.Status != Completed {
|
|
t.Errorf("expected Completed status, got %s", result.Status)
|
|
}
|
|
}
|
|
|
|
// TestGetPricing tests getting pricing information with mock server.
|
|
func TestGetPricing(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path == "/pricing" && r.Method == "GET" {
|
|
w.WriteHeader(http.StatusOK)
|
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
"pricing": []map[string]interface{}{
|
|
{
|
|
"processor_type": "gpu",
|
|
"spot_price": 0.0001,
|
|
"avg_price_24h": 0.00012,
|
|
"aws_equivalent": 0.001,
|
|
"savings_percent": 90.0,
|
|
},
|
|
{
|
|
"processor_type": "tpu",
|
|
"spot_price": 0.0002,
|
|
"avg_price_24h": 0.00022,
|
|
"aws_equivalent": 0.002,
|
|
"savings_percent": 89.0,
|
|
},
|
|
},
|
|
})
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusNotFound)
|
|
}))
|
|
defer server.Close()
|
|
|
|
config := Config{
|
|
APIKey: "test-key",
|
|
Endpoint: server.URL,
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
client := NewClientWithConfig(config)
|
|
|
|
ctx := context.Background()
|
|
pricing, err := client.GetPricing(ctx)
|
|
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if len(pricing) != 2 {
|
|
t.Errorf("expected 2 pricing entries, got %d", len(pricing))
|
|
}
|
|
}
|
|
|
|
// TestSynorError tests error formatting.
|
|
func TestSynorError(t *testing.T) {
|
|
err := &SynorError{
|
|
Message: "Invalid API key",
|
|
StatusCode: 401,
|
|
}
|
|
|
|
expected := "synor: Invalid API key (status 401)"
|
|
if err.Error() != expected {
|
|
t.Errorf("expected error message '%s', got '%s'", expected, err.Error())
|
|
}
|
|
}
|
|
|
|
// TestAPIError tests handling of API errors.
|
|
func TestAPIError(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
"message": "Invalid API key",
|
|
})
|
|
}))
|
|
defer server.Close()
|
|
|
|
config := Config{
|
|
APIKey: "bad-key",
|
|
Endpoint: server.URL,
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
client := NewClientWithConfig(config)
|
|
|
|
ctx := context.Background()
|
|
_, err := client.SubmitJob(ctx, "test", nil)
|
|
|
|
if err == nil {
|
|
t.Fatal("expected error, got nil")
|
|
}
|
|
|
|
synorErr, ok := err.(*SynorError)
|
|
if !ok {
|
|
t.Fatalf("expected SynorError, got %T", err)
|
|
}
|
|
if synorErr.StatusCode != 401 {
|
|
t.Errorf("expected status 401, got %d", synorErr.StatusCode)
|
|
}
|
|
}
|
|
|
|
// TestCancelJob tests job cancellation with mock server.
|
|
func TestCancelJob(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path == "/jobs/job-123" && r.Method == "DELETE" {
|
|
w.WriteHeader(http.StatusOK)
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusNotFound)
|
|
}))
|
|
defer server.Close()
|
|
|
|
config := Config{
|
|
APIKey: "test-key",
|
|
Endpoint: server.URL,
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
client := NewClientWithConfig(config)
|
|
|
|
ctx := context.Background()
|
|
err := client.CancelJob(ctx, "job-123")
|
|
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestJobMetrics tests job metrics structure.
|
|
func TestJobMetrics(t *testing.T) {
|
|
metrics := JobMetrics{
|
|
ExecutionTimeMs: 150.5,
|
|
QueueTimeMs: 10.2,
|
|
ProcessorType: GPU,
|
|
ProcessorID: "gpu-001",
|
|
FLOPS: 1e12,
|
|
MemoryBytes: 1073741824,
|
|
CostMicro: 100,
|
|
EnergyMJ: 5.5,
|
|
}
|
|
|
|
if metrics.ExecutionTimeMs != 150.5 {
|
|
t.Errorf("expected execution time 150.5, got %f", metrics.ExecutionTimeMs)
|
|
}
|
|
if metrics.ProcessorType != GPU {
|
|
t.Errorf("expected GPU processor, got %s", metrics.ProcessorType)
|
|
}
|
|
}
|
|
|
|
// TestPricingInfo tests pricing info structure.
|
|
func TestPricingInfo(t *testing.T) {
|
|
pricing := PricingInfo{
|
|
ProcessorType: GPU,
|
|
SpotPrice: 0.0001,
|
|
AvgPrice24h: 0.00012,
|
|
AWSEquivalent: 0.001,
|
|
SavingsPercent: 90.0,
|
|
}
|
|
|
|
if pricing.ProcessorType != GPU {
|
|
t.Errorf("expected GPU processor, got %s", pricing.ProcessorType)
|
|
}
|
|
if pricing.SavingsPercent != 90.0 {
|
|
t.Errorf("expected 90%% savings, got %f%%", pricing.SavingsPercent)
|
|
}
|
|
}
|
|
|
|
// TestJobResult tests job result structure.
|
|
func TestJobResult(t *testing.T) {
|
|
result := JobResult{
|
|
JobID: "job-123",
|
|
Status: Completed,
|
|
Data: map[string]interface{}{"output": "success"},
|
|
Metrics: &JobMetrics{
|
|
ExecutionTimeMs: 100,
|
|
ProcessorType: GPU,
|
|
},
|
|
}
|
|
|
|
if result.JobID != "job-123" {
|
|
t.Errorf("expected job ID 'job-123', got %s", result.JobID)
|
|
}
|
|
if result.Status != Completed {
|
|
t.Errorf("expected Completed status, got %s", result.Status)
|
|
}
|
|
if result.Metrics == nil {
|
|
t.Error("expected non-nil metrics")
|
|
}
|
|
}
|