package synor import ( "context" "encoding/json" "net/http" "net/http/httptest" "testing" "time" ) // TestClientCreation tests client initialization. func TestClientCreation(t *testing.T) { t.Run("creates client with API key", func(t *testing.T) { client := NewClient("test-api-key") if client == nil { t.Fatal("expected non-nil client") } }) t.Run("creates client with config", func(t *testing.T) { config := Config{ APIKey: "test-api-key", Endpoint: "https://custom.api.com", Strategy: Cost, Precision: FP16, Timeout: 60 * time.Second, } client := NewClientWithConfig(config) if client == nil { t.Fatal("expected non-nil client") } if client.config.Endpoint != "https://custom.api.com" { t.Errorf("expected custom endpoint, got %s", client.config.Endpoint) } }) } // TestDefaultConfig tests default configuration. func TestDefaultConfig(t *testing.T) { config := DefaultConfig("my-key") if config.APIKey != "my-key" { t.Errorf("expected API key 'my-key', got %s", config.APIKey) } if config.Endpoint != DefaultEndpoint { t.Errorf("expected default endpoint, got %s", config.Endpoint) } if config.Strategy != Balanced { t.Errorf("expected Balanced strategy, got %s", config.Strategy) } if config.Precision != FP32 { t.Errorf("expected FP32 precision, got %s", config.Precision) } if config.Timeout != 30*time.Second { t.Errorf("expected 30s timeout, got %v", config.Timeout) } } // TestTensorCreation tests tensor creation. func TestTensorCreation(t *testing.T) { t.Run("creates tensor from data", func(t *testing.T) { data := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0} shape := []int{2, 3} tensor := NewTensor(data, shape, FP32) if tensor == nil { t.Fatal("expected non-nil tensor") } if len(tensor.Data) != 6 { t.Errorf("expected 6 elements, got %d", len(tensor.Data)) } if len(tensor.Shape) != 2 { t.Errorf("expected 2 dimensions, got %d", len(tensor.Shape)) } if tensor.DType != FP32 { t.Errorf("expected FP32 dtype, got %s", tensor.DType) } }) t.Run("creates zeros tensor", func(t *testing.T) { tensor := Zeros([]int{3, 4}, FP32) if tensor == nil { t.Fatal("expected non-nil tensor") } if len(tensor.Data) != 12 { t.Errorf("expected 12 elements, got %d", len(tensor.Data)) } for i, v := range tensor.Data { if v != 0.0 { t.Errorf("expected 0 at index %d, got %f", i, v) } } }) } // TestTensorSerialize tests tensor serialization. func TestTensorSerialize(t *testing.T) { data := []float32{1.0, 2.0, 3.0, 4.0} shape := []int{2, 2} tensor := NewTensor(data, shape, FP32) serialized := tensor.Serialize() if _, ok := serialized["data"]; !ok { t.Error("expected 'data' field in serialized tensor") } if _, ok := serialized["shape"]; !ok { t.Error("expected 'shape' field in serialized tensor") } if _, ok := serialized["dtype"]; !ok { t.Error("expected 'dtype' field in serialized tensor") } if serialized["dtype"] != FP32 { t.Errorf("expected FP32 dtype, got %v", serialized["dtype"]) } } // TestProcessorTypes tests processor type constants. func TestProcessorTypes(t *testing.T) { types := []ProcessorType{CPU, GPU, TPU, NPU, LPU, FPGA, WASM, WebGPU} expected := []string{"cpu", "gpu", "tpu", "npu", "lpu", "fpga", "wasm", "webgpu"} for i, pt := range types { if string(pt) != expected[i] { t.Errorf("expected %s, got %s", expected[i], pt) } } } // TestPrecisionTypes tests precision type constants. func TestPrecisionTypes(t *testing.T) { precisions := []Precision{FP64, FP32, FP16, BF16, INT8, INT4} expected := []string{"fp64", "fp32", "fp16", "bf16", "int8", "int4"} for i, p := range precisions { if string(p) != expected[i] { t.Errorf("expected %s, got %s", expected[i], p) } } } // TestJobStatus tests job status constants. func TestJobStatus(t *testing.T) { statuses := []JobStatus{Pending, Queued, Running, Completed, Failed, Cancelled} expected := []string{"pending", "queued", "running", "completed", "failed", "cancelled"} for i, s := range statuses { if string(s) != expected[i] { t.Errorf("expected %s, got %s", expected[i], s) } } } // TestBalancingStrategies tests balancing strategy constants. func TestBalancingStrategies(t *testing.T) { strategies := []BalancingStrategy{Speed, Cost, Energy, Latency, Balanced} expected := []string{"speed", "cost", "energy", "latency", "balanced"} for i, s := range strategies { if string(s) != expected[i] { t.Errorf("expected %s, got %s", expected[i], s) } } } // TestSubmitJob tests job submission with mock server. func TestSubmitJob(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/jobs" && r.Method == "POST" { w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(map[string]interface{}{ "job_id": "job-123", }) return } w.WriteHeader(http.StatusNotFound) })) defer server.Close() config := Config{ APIKey: "test-key", Endpoint: server.URL, Timeout: 5 * time.Second, } client := NewClientWithConfig(config) ctx := context.Background() job, err := client.SubmitJob(ctx, "matmul", map[string]interface{}{ "a": map[string]interface{}{"data": "base64", "shape": []int{2, 2}}, "b": map[string]interface{}{"data": "base64", "shape": []int{2, 2}}, }) if err != nil { t.Fatalf("unexpected error: %v", err) } if job.ID != "job-123" { t.Errorf("expected job ID 'job-123', got %s", job.ID) } } // TestGetJobStatus tests getting job status with mock server. func TestGetJobStatus(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/jobs/job-123" && r.Method == "GET" { w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(map[string]interface{}{ "job_id": "job-123", "status": "completed", "data": map[string]interface{}{"result": "success"}, }) return } w.WriteHeader(http.StatusNotFound) })) defer server.Close() config := Config{ APIKey: "test-key", Endpoint: server.URL, Timeout: 5 * time.Second, } client := NewClientWithConfig(config) ctx := context.Background() result, err := client.GetJobStatus(ctx, "job-123") if err != nil { t.Fatalf("unexpected error: %v", err) } if result.Status != Completed { t.Errorf("expected Completed status, got %s", result.Status) } } // TestGetPricing tests getting pricing information with mock server. func TestGetPricing(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/pricing" && r.Method == "GET" { w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(map[string]interface{}{ "pricing": []map[string]interface{}{ { "processor_type": "gpu", "spot_price": 0.0001, "avg_price_24h": 0.00012, "aws_equivalent": 0.001, "savings_percent": 90.0, }, { "processor_type": "tpu", "spot_price": 0.0002, "avg_price_24h": 0.00022, "aws_equivalent": 0.002, "savings_percent": 89.0, }, }, }) return } w.WriteHeader(http.StatusNotFound) })) defer server.Close() config := Config{ APIKey: "test-key", Endpoint: server.URL, Timeout: 5 * time.Second, } client := NewClientWithConfig(config) ctx := context.Background() pricing, err := client.GetPricing(ctx) if err != nil { t.Fatalf("unexpected error: %v", err) } if len(pricing) != 2 { t.Errorf("expected 2 pricing entries, got %d", len(pricing)) } } // TestSynorError tests error formatting. func TestSynorError(t *testing.T) { err := &SynorError{ Message: "Invalid API key", StatusCode: 401, } expected := "synor: Invalid API key (status 401)" if err.Error() != expected { t.Errorf("expected error message '%s', got '%s'", expected, err.Error()) } } // TestAPIError tests handling of API errors. func TestAPIError(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) json.NewEncoder(w).Encode(map[string]interface{}{ "message": "Invalid API key", }) })) defer server.Close() config := Config{ APIKey: "bad-key", Endpoint: server.URL, Timeout: 5 * time.Second, } client := NewClientWithConfig(config) ctx := context.Background() _, err := client.SubmitJob(ctx, "test", nil) if err == nil { t.Fatal("expected error, got nil") } synorErr, ok := err.(*SynorError) if !ok { t.Fatalf("expected SynorError, got %T", err) } if synorErr.StatusCode != 401 { t.Errorf("expected status 401, got %d", synorErr.StatusCode) } } // TestCancelJob tests job cancellation with mock server. func TestCancelJob(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/jobs/job-123" && r.Method == "DELETE" { w.WriteHeader(http.StatusOK) return } w.WriteHeader(http.StatusNotFound) })) defer server.Close() config := Config{ APIKey: "test-key", Endpoint: server.URL, Timeout: 5 * time.Second, } client := NewClientWithConfig(config) ctx := context.Background() err := client.CancelJob(ctx, "job-123") if err != nil { t.Fatalf("unexpected error: %v", err) } } // TestJobMetrics tests job metrics structure. func TestJobMetrics(t *testing.T) { metrics := JobMetrics{ ExecutionTimeMs: 150.5, QueueTimeMs: 10.2, ProcessorType: GPU, ProcessorID: "gpu-001", FLOPS: 1e12, MemoryBytes: 1073741824, CostMicro: 100, EnergyMJ: 5.5, } if metrics.ExecutionTimeMs != 150.5 { t.Errorf("expected execution time 150.5, got %f", metrics.ExecutionTimeMs) } if metrics.ProcessorType != GPU { t.Errorf("expected GPU processor, got %s", metrics.ProcessorType) } } // TestPricingInfo tests pricing info structure. func TestPricingInfo(t *testing.T) { pricing := PricingInfo{ ProcessorType: GPU, SpotPrice: 0.0001, AvgPrice24h: 0.00012, AWSEquivalent: 0.001, SavingsPercent: 90.0, } if pricing.ProcessorType != GPU { t.Errorf("expected GPU processor, got %s", pricing.ProcessorType) } if pricing.SavingsPercent != 90.0 { t.Errorf("expected 90%% savings, got %f%%", pricing.SavingsPercent) } } // TestJobResult tests job result structure. func TestJobResult(t *testing.T) { result := JobResult{ JobID: "job-123", Status: Completed, Data: map[string]interface{}{"output": "success"}, Metrics: &JobMetrics{ ExecutionTimeMs: 100, ProcessorType: GPU, }, } if result.JobID != "job-123" { t.Errorf("expected job ID 'job-123', got %s", result.JobID) } if result.Status != Completed { t.Errorf("expected Completed status, got %s", result.Status) } if result.Metrics == nil { t.Error("expected non-nil metrics") } }