Adds comprehensive model management and training capabilities: synor-compute (Rust): - ModelRegistry with pre-registered popular models - LLMs: Llama 3/3.1, Mistral, Mixtral, Qwen, DeepSeek, Phi, CodeLlama - Embedding: BGE, E5 - Image: Stable Diffusion XL, FLUX.1 - Speech: Whisper - Multi-modal: LLaVA - ModelInfo with parameters, format, precision, context length - Custom model upload and registration - Model search by name/category Flutter SDK: - Model registry APIs: listModels, getModel, searchModels - Custom model upload with multipart upload - Training APIs: train(), fineTune(), trainStream() - TrainingOptions: framework, epochs, batch_size, learning_rate - TrainingProgress for real-time updates - ModelUploadOptions and ModelUploadResult Example code for: - Listing available models by category - Fine-tuning pre-trained models - Uploading custom Python/ONNX models - Streaming training progress This enables users to: 1. Use pre-registered models like 'llama-3-70b' 2. Upload their own custom models 3. Fine-tune models on custom datasets 4. Track training progress in real-time
362 lines · 10 KiB · Dart
import 'dart:io';
|
|
|
|
import 'package:synor_compute/synor_compute.dart';
|
|
|
|
/// Example usage of Synor Compute SDK for Flutter/Dart
///
/// Builds a client, verifies the service is reachable, then walks through
/// every SDK capability in turn. The client is always disposed, even when a
/// demo throws.
void main() async {
  // Initialize client with API key (taken from the environment when set).
  final client = SynorCompute(
    apiKey: Platform.environment['SYNOR_API_KEY'] ?? 'your-api-key',
    // Optional: customize defaults
    defaultProcessor: ProcessorType.auto,
    defaultPrecision: Precision.fp32,
    defaultPriority: Priority.normal,
  );

  try {
    // Check service health before running any demo.
    final isHealthy = await client.healthCheck();
    print('Service healthy: $isHealthy\n');

    // Run the demos sequentially; each prints its own section header.
    await matrixMultiplicationExample(client); // Example 1: Matrix multiplication
    await tensorOperationsExample(client);     // Example 2: Tensor operations
    await llmInferenceExample(client);         // Example 3: LLM inference
    await streamingInferenceExample(client);   // Example 4: Streaming inference
    await pricingExample(client);              // Example 5: Pricing and usage
    await modelRegistryExample(client);        // Example 6: List available models
    await trainingExample(client);             // Example 7: Training a model
    await customModelExample(client);          // Example 8: Custom model upload
  } finally {
    // Always dispose client to release resources
    client.dispose();
  }
}
|
|
|
|
/// Matrix multiplication example
///
/// Multiplies two random matrices remotely (GPU, FP16, high priority) and
/// reports the result shape, execution time, cost, and processor used.
Future<void> matrixMultiplicationExample(SynorCompute client) async {
  print('=== Matrix Multiplication ===');

  // Random operands with a matching inner dimension (512).
  final lhs = Tensor.rand([256, 512]);
  final rhs = Tensor.rand([512, 256]);

  print('A: ${lhs.shape}');
  print('B: ${rhs.shape}');

  // Perform multiplication on GPU with FP16 precision.
  final opts = MatMulOptions(
    precision: Precision.fp16,
    processor: ProcessorType.gpu,
    priority: Priority.high,
  );
  final result = await client.matmul(lhs, rhs, options: opts);

  if (!result.isSuccess) {
    print('Error: ${result.error}');
  } else {
    print('Result: ${result.result!.shape}');
    print('Execution time: ${result.executionTimeMs}ms');
    print('Cost: \$${result.cost?.toStringAsFixed(6)}');
    print('Processor: ${result.processor?.value}');
  }
  print('');
}
|
|
|
|
/// Local tensor operations example
///
/// Demonstrates client-side tensor construction, reshaping, and activation
/// functions. Everything here runs locally — no remote calls are made.
Future<void> tensorOperationsExample(SynorCompute client) async {
  print('=== Tensor Operations ===');

  // Draw from a standard normal distribution and report sample statistics.
  final sample = Tensor.randn([100], mean: 0.0, std: 1.0);
  print('Random normal tensor: mean=${sample.mean().toStringAsFixed(4)}, '
      'std=${sample.std().toStringAsFixed(4)}');

  // Create identity matrix.
  final identity = Tensor.eye(4);
  print('Identity matrix:\n${identity.toNestedList()}');

  // Create linspace.
  final ramp = Tensor.linspace(0, 10, 5);
  print('Linspace [0, 10, 5]: ${ramp.toNestedList()}');

  // Reshape operations.
  final grid = Tensor.arange(0, 12).reshape([3, 4]);
  print('Reshaped [0..12] to [3,4]:\n${grid.toNestedList()}');

  // Transpose.
  final flipped = grid.transpose();
  print('Transposed to ${flipped.shape}');

  // Activations applied element-wise to a small hand-built tensor.
  final input = Tensor(shape: [5], data: [-2.0, -1.0, 0.0, 1.0, 2.0]);
  print('ReLU of $input: ${input.relu().toNestedList()}');
  print('Sigmoid of $input: ${input.sigmoid().toNestedList()}');

  // Softmax normalizes the logits into a probability distribution.
  final logits = Tensor(shape: [4], data: [1.0, 2.0, 3.0, 4.0]);
  print('Softmax of $logits: ${logits.softmax().toNestedList()}');

  print('');
}
|
|
|
|
/// LLM inference example
///
/// Sends a short factual prompt to a pre-registered model with a low
/// temperature and a tiny token budget, then prints the completion.
Future<void> llmInferenceExample(SynorCompute client) async {
  print('=== LLM Inference ===');

  final opts = InferenceOptions(
    maxTokens: 10,
    temperature: 0.1,
    processor: ProcessorType.lpu, // Use LPU for LLM
  );
  final result = await client.inference(
    'llama-3-70b',
    'What is the capital of France? Answer in one word.',
    options: opts,
  );

  if (!result.isSuccess) {
    print('Error: ${result.error}');
  } else {
    print('Response: ${result.result}');
    print('Time: ${result.executionTimeMs}ms');
  }
  print('');
}
|
|
|
|
/// Streaming inference example
///
/// Consumes the token stream from [SynorCompute.inferenceStream] and writes
/// each token to stdout as it arrives, so output appears incrementally.
Future<void> streamingInferenceExample(SynorCompute client) async {
  print('=== Streaming Inference ===');
  print('Response: ');

  final tokens = client.inferenceStream(
    'llama-3-70b',
    'Write a short poem about distributed computing.',
    options: InferenceOptions(maxTokens: 100, temperature: 0.7),
  );

  // stdout.write (not print) keeps tokens on one flowing line.
  await for (final token in tokens) {
    stdout.write(token);
  }

  print('\n');
}
|
|
|
|
/// Pricing and usage example
///
/// Prints the current spot price per processor type, then the caller's
/// aggregate usage statistics (job counts, compute time, and total cost).
Future<void> pricingExample(SynorCompute client) async {
  print('=== Pricing Information ===');

  // One entry per processor type: price, availability, and utilization.
  final pricing = await client.getPricing();

  print('Current spot prices:');
  for (final entry in pricing) {
    print(' ${entry.processor.value.toUpperCase().padRight(8)}: '
        '\$${entry.pricePerSecond.toStringAsFixed(6)}/sec, '
        '${entry.availableUnits} units available, '
        '${entry.utilizationPercent.toStringAsFixed(1)}% utilized');
  }

  print('');

  // Get usage stats.
  final usage = await client.getUsage();
  print('Usage Statistics:');
  print(' Total jobs: ${usage.totalJobs}');
  print(' Completed: ${usage.completedJobs}');
  print(' Failed: ${usage.failedJobs}');
  print(' Total compute time: ${usage.totalComputeSeconds.toStringAsFixed(2)}s');
  print(' Total cost: \$${usage.totalCost.toStringAsFixed(4)}');
  print('');
}
|
|
|
|
/// Model registry example - list available models
///
/// Shows the registry surface: listing everything, filtering by category,
/// free-text search, and fetching a single model's metadata.
Future<void> modelRegistryExample(SynorCompute client) async {
  print('=== Model Registry ===');

  // Prints one "id - name" line per model (shared by two sections below).
  void printIdAndName(Iterable<dynamic> models) {
    for (final model in models) {
      print(' ${model.id} - ${model.name}');
    }
  }

  // List all available models.
  final allModels = await client.listModels();
  print('Total available models: ${allModels.length}');

  // List only LLMs (first five, aligned columns).
  final llms = await client.listModels(category: ModelCategory.llm);
  print('\nAvailable LLMs:');
  for (final model in llms.take(5)) {
    print(' ${model.id.padRight(20)} ${model.formattedParameters.padRight(8)} '
        '${model.name}');
  }

  // Search for a specific model by substring.
  final searchResults = await client.searchModels('llama');
  print('\nSearch "llama": ${searchResults.length} results');

  // Get specific model info.
  final modelInfo = await client.getModel('llama-3-70b');
  print('\nModel details for ${modelInfo.name}:');
  print(' Parameters: ${modelInfo.formattedParameters}');
  print(' Context length: ${modelInfo.contextLength}');
  print(' Format: ${modelInfo.format.value}');
  print(' Recommended processor: ${modelInfo.recommendedProcessor.value}');
  print(' License: ${modelInfo.license}');

  // List embedding models.
  final embeddings = await client.listModels(category: ModelCategory.embedding);
  print('\nAvailable embedding models:');
  printIdAndName(embeddings);

  // List image generation models.
  final imageGen =
      await client.listModels(category: ModelCategory.imageGeneration);
  print('\nAvailable image generation models:');
  printIdAndName(imageGen);

  print('');
}
|
|
|
|
/// Training example - train/fine-tune a model
///
/// First runs a blocking fine-tune of a base LLM on an uploaded dataset and
/// reuses the resulting model CID for inference; then demonstrates the
/// streaming variant that yields progress updates while training runs.
Future<void> trainingExample(SynorCompute client) async {
  print('=== Model Training ===');

  // Example: Fine-tune Llama 3 8B on custom dataset.
  print('Fine-tuning llama-3-8b on custom dataset...');

  // Note: In practice, you'd upload your dataset first:
  // final datasetCid = await client.uploadTensor(datasetTensor);

  final tuneOptions = TrainingOptions(
    framework: MlFramework.pytorch,
    epochs: 3,
    batchSize: 8,
    learningRate: 0.00002,
    optimizer: 'adamw',
    hyperparameters: {
      'weight_decay': 0.01,
      'warmup_steps': 100,
      'gradient_accumulation_steps': 4,
    },
    checkpointEvery: 500, // Save checkpoint every 500 steps
    processor: ProcessorType.gpu,
    priority: Priority.high,
  );

  final tuneResult = await client.fineTune(
    baseModel: 'llama-3-8b', // Use model alias
    datasetCid: 'QmYourDatasetCID', // Your uploaded dataset
    outputAlias: 'my-custom-llama', // Optional: alias for trained model
    options: tuneOptions,
  );

  if (!tuneResult.isSuccess) {
    print('Training failed: ${tuneResult.error}');
  } else {
    final training = tuneResult.result!;
    print('Training completed!');
    print(' New model CID: ${training.modelCid}');
    print(' Final loss: ${training.finalLoss.toStringAsFixed(4)}');
    print(' Duration: ${training.durationMs / 1000}s');
    print(' Cost: \$${training.cost.toStringAsFixed(4)}');
    print(' Metrics: ${training.metrics}');

    // Now use your trained model for inference.
    print('\nUsing trained model for inference:');
    final inference = await client.inference(
      training.modelCid, // Use the CID of your trained model
      'Hello, how are you?',
      options: InferenceOptions(maxTokens: 50),
    );
    print('Response: ${inference.result}');
  }

  print('');

  // Example: Streaming training progress.
  print('Training with streaming progress...');
  final progressStream = client.trainStream(
    modelCid: 'llama-3-8b',
    datasetCid: 'QmYourDatasetCID',
    options: TrainingOptions(epochs: 1, batchSize: 16),
  );
  await for (final progress in progressStream) {
    // Update UI with progress (\r rewrites the same console line).
    stdout.write('\r${progress.progressText} - '
        '${progress.samplesPerSecond} samples/s');
  }
  print('\nTraining complete!');

  print('');
}
|
|
|
|
/// Custom model upload example
///
/// Prints a step-by-step walkthrough of exporting a Python-trained model to
/// ONNX, uploading it via [SynorCompute], and running inference against it,
/// then lists the supported model formats and categories.
///
/// Fix: the embedded snippet previously declared
/// `final result = await client.inference(result.cid, ...)`, referencing
/// `result` inside its own initializer and shadowing the upload result — the
/// example as shown could never compile. The upload result is now named
/// `upload` so every step is consistent.
Future<void> customModelExample(SynorCompute client) async {
  print('=== Custom Model Upload ===');

  // Example: Upload a custom ONNX model.
  // In practice, you'd read this from a file:
  // final modelBytes = await File('my_model.onnx').readAsBytes();

  // For demonstration, we'll show the API structure.
  print('To upload your own Python-trained model:');
  print('''
1. Train your model in Python:

   import torch
   model = MyModel()
   # ... train model ...
   torch.onnx.export(model, dummy_input, "my_model.onnx")

2. Upload to Synor Compute:

   final modelBytes = await File('my_model.onnx').readAsBytes();
   final upload = await client.uploadModel(
     modelBytes,
     ModelUploadOptions(
       name: 'my-custom-model',
       description: 'My custom trained model',
       category: ModelCategory.custom,
       format: ModelFormat.onnx,
       alias: 'my-model', // Optional shortcut name
       isPublic: false, // Keep private
       license: 'Proprietary',
     ),
   );
   print('Uploaded! CID: \${upload.cid}');

3. Use for inference:

   final result = await client.inference(
     upload.cid, // or 'my-model' if you set an alias
     'Your input data',
   );
''');

  // Supported model formats.
  print('Supported model formats:');
  for (final format in ModelFormat.values) {
    print(' - ${format.value}');
  }

  // Supported categories.
  print('\nSupported model categories:');
  for (final category in ModelCategory.values) {
    print(' - ${category.value}');
  }

  print('');
}
|