Adds unit tests covering tensor operations, type enums, client functionality, and serialization for all 12 SDK implementations: - JavaScript (Vitest): tensor, types, client tests - Python (pytest): tensor, types, client tests - Go: standard library tests with httptest - Flutter (flutter_test): tensor, types tests - Java (JUnit 5): tensor, types tests - Rust: embedded module tests - Ruby (minitest): tensor, types tests - C# (xUnit): tensor, types tests Tests cover: - Tensor creation (zeros, ones, random, randn, eye, arange, linspace) - Tensor operations (reshape, transpose, indexing) - Reductions (sum, mean, std, min, max) - Activations (relu, sigmoid, softmax) - Serialization/deserialization - Type enums and configuration - Client request building - Error handling
382 lines
13 KiB
Python
382 lines
13 KiB
Python
"""Unit tests for SynorCompute client."""
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
import numpy as np
|
|
|
|
from synor_compute import SynorCompute, SynorError, Tensor
|
|
from synor_compute.types import (
|
|
ProcessorType,
|
|
Precision,
|
|
JobStatus,
|
|
SynorConfig,
|
|
)
|
|
|
|
|
|
@pytest.fixture
def mock_response():
    """Factory fixture producing mocked async HTTP responses.

    Returns a callable ``(data, status=200)`` that builds an AsyncMock
    whose ``json``/``text`` coroutines resolve to the given payload.
    """
    def _build(data, status=200):
        mock = AsyncMock()
        mock.status = status
        mock.json = AsyncMock(return_value=data)
        mock.text = AsyncMock(return_value=str(data))
        return mock

    return _build
|
|
|
|
|
@pytest.fixture
def client():
    """Provide a SynorCompute client backed by a dummy API key."""
    test_client = SynorCompute(api_key="test-api-key")
    return test_client
|
|
|
|
|
|
class TestClientInitialization:
    """Tests for client construction and credential validation."""

    def test_create_with_api_key(self):
        """A client built from a bare API key should be constructed."""
        instance = SynorCompute(api_key="my-key")
        assert instance is not None

    def test_create_with_config(self):
        """A client built from a full SynorConfig should be constructed."""
        cfg = SynorConfig(
            api_key="my-key",
            base_url="https://custom.api.com",
            timeout=60,
        )
        instance = SynorCompute(config=cfg)
        assert instance is not None

    def test_raise_without_api_key(self):
        """An empty API key must be rejected at construction time."""
        with pytest.raises((ValueError, TypeError)):
            SynorCompute(api_key="")
|
|
|
|
|
class TestMatMul:
    """Tests for matrix multiplication requests issued by the client."""

    @pytest.mark.asyncio
    async def test_matmul_request(self, client):
        """Should send a matmul request and surface job metadata."""
        # NOTE: previously also requested the ``mock_response`` fixture
        # but never used it — the test stubs ``client._request`` directly.
        response_data = {
            "job_id": "job-123",
            "status": "completed",
            "result": {"data": "base64", "shape": [2, 4], "dtype": "fp32"},
            "execution_time_ms": 15,
            "processor": "gpu",
        }

        with patch.object(client, "_request", new_callable=AsyncMock) as mock_req:
            mock_req.return_value = response_data

            a = Tensor.random((2, 3))
            b = Tensor.random((3, 4))
            result = await client.matmul(a, b)

            mock_req.assert_called_once()
            assert result.status == JobStatus.COMPLETED
            assert result.job_id == "job-123"

    @pytest.mark.asyncio
    async def test_matmul_with_precision(self, client):
        """Should forward the requested precision to the request layer."""
        with patch.object(client, "_request", new_callable=AsyncMock) as mock_req:
            mock_req.return_value = {
                "job_id": "job-456",
                "status": "completed",
                "result": {"data": "base64", "shape": [2, 2], "dtype": "fp16"},
            }

            a = Tensor.random((2, 2))
            b = Tensor.random((2, 2))
            await client.matmul(a, b, precision=Precision.FP16)

            # Precision may travel as a keyword or inside the second
            # positional payload, depending on the client internals.
            call_args = mock_req.call_args
            assert call_args[1].get("precision") == "fp16" or \
                call_args[0][1].get("precision") == "fp16"

    @pytest.mark.asyncio
    async def test_matmul_with_processor(self, client):
        """Should accept an explicit processor selection."""
        with patch.object(client, "_request", new_callable=AsyncMock) as mock_req:
            mock_req.return_value = {
                "job_id": "job-789",
                "status": "completed",
                "result": {"data": "base64", "shape": [2, 2], "dtype": "fp32"},
            }

            a = Tensor.random((2, 2))
            b = Tensor.random((2, 2))
            await client.matmul(a, b, processor=ProcessorType.TPU)

            mock_req.assert_called_once()
|
|
|
|
|
class TestConv2d:
    """Tests for 2D convolution requests."""

    @pytest.mark.asyncio
    async def test_conv2d_request(self, client):
        """A conv2d call should dispatch a request and report completion."""
        with patch.object(client, "_request", new_callable=AsyncMock) as req_mock:
            req_mock.return_value = {
                "job_id": "conv-123",
                "status": "completed",
                "result": {"data": "base64", "shape": [1, 32, 30, 30], "dtype": "fp32"},
            }

            image = Tensor.random((1, 3, 32, 32))
            filters = Tensor.random((32, 3, 3, 3))
            outcome = await client.conv2d(image, filters)

            assert outcome.status == JobStatus.COMPLETED

    @pytest.mark.asyncio
    async def test_conv2d_with_options(self, client):
        """Stride and padding keyword options should be accepted."""
        with patch.object(client, "_request", new_callable=AsyncMock) as req_mock:
            req_mock.return_value = {
                "job_id": "conv-456",
                "status": "completed",
                "result": {"data": "base64", "shape": [1, 32, 16, 16], "dtype": "fp32"},
            }

            image = Tensor.random((1, 3, 32, 32))
            filters = Tensor.random((32, 3, 3, 3))
            await client.conv2d(image, filters, stride=(2, 2), padding=(1, 1))

            req_mock.assert_called_once()
|
|
|
|
|
|
class TestAttention:
    """Tests for the attention operation."""

    @pytest.mark.asyncio
    async def test_attention_request(self, client):
        """An attention call should dispatch a request and complete."""
        with patch.object(client, "_request", new_callable=AsyncMock) as req_mock:
            req_mock.return_value = {
                "job_id": "attn-123",
                "status": "completed",
                "result": {"data": "base64", "shape": [1, 8, 64, 64], "dtype": "fp16"},
            }

            q, k, v = (Tensor.random((1, 8, 64, 64)) for _ in range(3))
            outcome = await client.attention(q, k, v)

            assert outcome.status == JobStatus.COMPLETED

    @pytest.mark.asyncio
    async def test_attention_flash(self, client):
        """The flash-attention flag should be accepted."""
        with patch.object(client, "_request", new_callable=AsyncMock) as req_mock:
            req_mock.return_value = {
                "job_id": "attn-456",
                "status": "completed",
                "result": {"data": "base64", "shape": [1, 8, 64, 64], "dtype": "fp16"},
            }

            q, k, v = (Tensor.random((1, 8, 64, 64)) for _ in range(3))
            await client.attention(q, k, v, flash=True)

            req_mock.assert_called_once()
|
|
|
|
|
|
class TestInference:
    """Tests for LLM inference requests."""

    @pytest.mark.asyncio
    async def test_inference_request(self, client):
        """An inference call should complete and carry the generated text."""
        with patch.object(client, "_request", new_callable=AsyncMock) as req_mock:
            req_mock.return_value = {
                "job_id": "inf-123",
                "status": "completed",
                "result": "The answer to your question is...",
                "execution_time_ms": 2500,
                "processor": "gpu",
            }

            outcome = await client.inference(
                model="llama-3-70b",
                prompt="What is the capital of France?",
            )

            assert outcome.status == JobStatus.COMPLETED
            assert "answer" in outcome.result

    @pytest.mark.asyncio
    async def test_inference_with_options(self, client):
        """Sampling options should be forwarded without error."""
        with patch.object(client, "_request", new_callable=AsyncMock) as req_mock:
            req_mock.return_value = {
                "job_id": "inf-456",
                "status": "completed",
                "result": "Generated text...",
            }

            await client.inference(
                model="llama-3-70b",
                prompt="Hello",
                max_tokens=512,
                temperature=0.8,
                top_p=0.95,
                top_k=40,
            )

            req_mock.assert_called_once()
|
|
|
|
|
|
class TestModelRegistry:
    """Tests for model registry operations."""

    @pytest.mark.asyncio
    async def test_list_models(self, client):
        """Listing models should return every catalog entry."""
        with patch.object(client, "_request", new_callable=AsyncMock) as req_mock:
            req_mock.return_value = {
                "models": [
                    {"id": "llama-3-70b", "name": "Llama 3 70B", "category": "llm"},
                    {"id": "mistral-7b", "name": "Mistral 7B", "category": "llm"},
                ]
            }

            catalog = await client.list_models()

            assert len(catalog) == 2
            assert catalog[0]["id"] == "llama-3-70b"

    @pytest.mark.asyncio
    async def test_list_models_by_category(self, client):
        """Listing with a category should pass the filter through."""
        with patch.object(client, "_request", new_callable=AsyncMock) as req_mock:
            req_mock.return_value = {
                "models": [
                    {"id": "sd-xl", "name": "Stable Diffusion XL", "category": "image_generation"},
                ]
            }

            catalog = await client.list_models(category="image_generation")

            assert len(catalog) == 1
            assert catalog[0]["category"] == "image_generation"

    @pytest.mark.asyncio
    async def test_get_model(self, client):
        """Fetching a single model should return its metadata."""
        with patch.object(client, "_request", new_callable=AsyncMock) as req_mock:
            req_mock.return_value = {
                "id": "llama-3-70b",
                "name": "Llama 3 70B",
                "category": "llm",
                "parameters": 70000000000,
                "context_length": 8192,
            }

            entry = await client.get_model("llama-3-70b")

            assert entry["id"] == "llama-3-70b"
            assert entry["parameters"] == 70000000000
|
|
|
|
|
|
class TestPricingAndUsage:
    """Tests for pricing and usage APIs."""

    @pytest.mark.asyncio
    async def test_get_pricing(self, client):
        """Pricing lookup should return one entry per processor."""
        with patch.object(client, "_request", new_callable=AsyncMock) as req_mock:
            req_mock.return_value = {
                "pricing": [
                    {"processor": "gpu", "price_per_second": 0.0001, "available_units": 100},
                    {"processor": "tpu", "price_per_second": 0.0002, "available_units": 50},
                ]
            }

            rates = await client.get_pricing()

            assert len(rates) == 2
            assert rates[0]["processor"] == "gpu"

    @pytest.mark.asyncio
    async def test_get_usage(self, client):
        """Usage lookup should return aggregate job statistics."""
        with patch.object(client, "_request", new_callable=AsyncMock) as req_mock:
            req_mock.return_value = {
                "total_jobs": 1000,
                "completed_jobs": 950,
                "failed_jobs": 50,
                "total_compute_seconds": 3600,
                "total_cost": 0.36,
            }

            stats = await client.get_usage()

            assert stats["total_jobs"] == 1000
            assert stats["completed_jobs"] == 950
|
|
|
|
|
|
class TestErrorHandling:
    """Tests for error propagation from the request layer."""

    @pytest.mark.asyncio
    async def test_api_error(self, client):
        """A SynorError raised by the transport should propagate intact."""
        with patch.object(client, "_request", new_callable=AsyncMock) as req_mock:
            req_mock.side_effect = SynorError("Invalid API key", status_code=401)

            lhs = Tensor.random((2, 2))
            rhs = Tensor.random((2, 2))

            with pytest.raises(SynorError) as exc_info:
                await client.matmul(lhs, rhs)

            assert exc_info.value.status_code == 401

    @pytest.mark.asyncio
    async def test_network_error(self, client):
        """A generic transport failure should propagate to the caller."""
        with patch.object(client, "_request", new_callable=AsyncMock) as req_mock:
            req_mock.side_effect = Exception("Network error")

            lhs = Tensor.random((2, 2))
            rhs = Tensor.random((2, 2))

            with pytest.raises(Exception, match="Network error"):
                await client.matmul(lhs, rhs)
|
|
|
|
|
|
class TestHealthCheck:
    """Tests for the health check endpoint wrapper."""

    @pytest.mark.asyncio
    async def test_healthy_service(self, client):
        """A successful status response should report the service healthy."""
        with patch.object(client, "_request", new_callable=AsyncMock) as req_mock:
            req_mock.return_value = {"status": "healthy"}

            assert await client.health_check() is True

    @pytest.mark.asyncio
    async def test_unhealthy_service(self, client):
        """A transport failure should be reported as unhealthy, not raised."""
        with patch.object(client, "_request", new_callable=AsyncMock) as req_mock:
            req_mock.side_effect = Exception("Service unavailable")

            assert await client.health_check() is False
|