diff --git a/crates/synor-compute/src/error.rs b/crates/synor-compute/src/error.rs index 33a34ee..60f3605 100644 --- a/crates/synor-compute/src/error.rs +++ b/crates/synor-compute/src/error.rs @@ -77,6 +77,18 @@ pub enum ComputeError { /// Internal error. #[error("Internal error: {0}")] Internal(String), + + /// Model not found. + #[error("Model not found: {0}")] + ModelNotFound(String), + + /// Model upload failed. + #[error("Model upload failed: {0}")] + ModelUploadFailed(String), + + /// Invalid model format. + #[error("Invalid model format: {0}")] + InvalidModelFormat(String), } impl From for ComputeError { diff --git a/crates/synor-compute/src/lib.rs b/crates/synor-compute/src/lib.rs index 6baef4c..4e9c4a2 100644 --- a/crates/synor-compute/src/lib.rs +++ b/crates/synor-compute/src/lib.rs @@ -52,6 +52,7 @@ pub mod device; pub mod error; pub mod market; pub mod memory; +pub mod model; pub mod processor; pub mod scheduler; pub mod task; @@ -77,6 +78,10 @@ pub use task::{ ComputeTask, DecomposedWorkload, Task, TaskDecomposer, TaskId, TaskPriority, TaskResult, TaskStatus, }; +pub use model::{ + ModelCategory, ModelFormat, ModelId, ModelInfo, ModelRegistry, ModelUploadRequest, + ModelUploadResponse, +}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; diff --git a/crates/synor-compute/src/model/mod.rs b/crates/synor-compute/src/model/mod.rs new file mode 100644 index 0000000..4c1b9cd --- /dev/null +++ b/crates/synor-compute/src/model/mod.rs @@ -0,0 +1,588 @@ +//! Model registry and management for Synor Compute. +//! +//! Provides: +//! - Pre-registered popular models (LLMs, vision, audio) +//! - Custom model uploads +//! - Model format conversion +//! - Model caching and warm-up + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::Arc; + +use parking_lot::RwLock; + +use crate::error::ComputeError; +use crate::processor::{Precision, ProcessorType}; + +/// Model identifier (CID or alias). +#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct ModelId(pub String); + +impl ModelId { + /// Creates from a CID. + pub fn from_cid(cid: &str) -> Self { + Self(cid.to_string()) + } + + /// Creates from an alias. + pub fn from_alias(alias: &str) -> Self { + Self(alias.to_string()) + } + + /// Checks if this is a CID (starts with Qm or bafy). + pub fn is_cid(&self) -> bool { + self.0.starts_with("Qm") || self.0.starts_with("bafy") + } +} + +impl std::fmt::Display for ModelId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From<&str> for ModelId { + fn from(s: &str) -> Self { + Self(s.to_string()) + } +} + +/// Model format specification. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub enum ModelFormat { + /// ONNX Runtime format. + Onnx, + /// PyTorch checkpoint (.pt, .pth). + PyTorch, + /// PyTorch TorchScript (.pt). + TorchScript, + /// TensorFlow SavedModel. + TensorFlow, + /// TensorFlow Lite. + TFLite, + /// SafeTensors format. + SafeTensors, + /// GGUF format (llama.cpp). + Gguf, + /// GGML format (legacy). + Ggml, + /// Custom binary format. + Custom, +} + +/// Model category. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub enum ModelCategory { + /// Large language models. + Llm, + /// Text embedding models. + Embedding, + /// Image classification. + ImageClassification, + /// Object detection. + ObjectDetection, + /// Image segmentation. + Segmentation, + /// Image generation (diffusion). + ImageGeneration, + /// Speech-to-text. + SpeechToText, + /// Text-to-speech. + TextToSpeech, + /// Video generation. + VideoGeneration, + /// Multi-modal (vision-language). + MultiModal, + /// Custom/other. + Custom, +} + +/// Model metadata. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ModelInfo { + /// Model ID (alias or CID). + pub id: ModelId, + /// Human-readable name. + pub name: String, + /// Description. + pub description: String, + /// Model category. + pub category: ModelCategory, + /// Storage CID (actual location). + pub cid: String, + /// Model format. + pub format: ModelFormat, + /// Model size in bytes. + pub size_bytes: u64, + /// Parameter count. + pub parameters: u64, + /// Supported precisions. + pub supported_precisions: Vec, + /// Recommended processor type. + pub recommended_processor: ProcessorType, + /// Context length (for LLMs). + pub context_length: Option, + /// Input schema (JSON Schema). + pub input_schema: Option, + /// Output schema (JSON Schema). + pub output_schema: Option, + /// License. + pub license: String, + /// Provider/author. + pub provider: String, + /// Version. + pub version: String, + /// Is public (anyone can use). + pub is_public: bool, + /// Owner address (for private models). + pub owner: Option<[u8; 32]>, +} + +impl ModelInfo { + /// Creates a new LLM model info. + pub fn llm( + alias: &str, + name: &str, + cid: &str, + parameters: u64, + context_length: u32, + ) -> Self { + Self { + id: ModelId::from_alias(alias), + name: name.to_string(), + description: format!("{} - {} parameter LLM", name, format_params(parameters)), + category: ModelCategory::Llm, + cid: cid.to_string(), + format: ModelFormat::SafeTensors, + size_bytes: parameters * 2, // ~2 bytes per param in fp16 + parameters, + supported_precisions: vec![Precision::Fp16, Precision::Bf16, Precision::Int8, Precision::Int4], + recommended_processor: ProcessorType::Lpu, + context_length: Some(context_length), + input_schema: None, + output_schema: None, + license: "Apache-2.0".to_string(), + provider: "Synor".to_string(), + version: "1.0".to_string(), + is_public: true, + owner: None, + } + } + + /// Creates a custom model info. + pub fn custom( + cid: &str, + name: &str, + category: ModelCategory, + format: ModelFormat, + size_bytes: u64, + owner: [u8; 32], + ) -> Self { + Self { + id: ModelId::from_cid(cid), + name: name.to_string(), + description: format!("Custom model: {}", name), + category, + cid: cid.to_string(), + format, + size_bytes, + parameters: 0, + supported_precisions: vec![Precision::Fp32, Precision::Fp16], + recommended_processor: ProcessorType::Cpu(crate::processor::CpuVariant::default()), + context_length: None, + input_schema: None, + output_schema: None, + license: "Custom".to_string(), + provider: "User".to_string(), + version: "1.0".to_string(), + is_public: false, + owner: Some(owner), + } + } +} + +/// Format parameter count for display. +fn format_params(params: u64) -> String { + if params >= 1_000_000_000_000 { + format!("{:.1}T", params as f64 / 1e12) + } else if params >= 1_000_000_000 { + format!("{:.1}B", params as f64 / 1e9) + } else if params >= 1_000_000 { + format!("{:.1}M", params as f64 / 1e6) + } else { + format!("{}", params) + } +} + +/// Model registry. +pub struct ModelRegistry { + /// Registered models by ID/alias. + models: RwLock>, + /// Alias to CID mapping. + aliases: RwLock>, +} + +impl ModelRegistry { + /// Creates a new model registry with default models. + pub fn new() -> Self { + let registry = Self { + models: RwLock::new(HashMap::new()), + aliases: RwLock::new(HashMap::new()), + }; + registry.register_default_models(); + registry + } + + /// Registers default/popular models. + fn register_default_models(&self) { + let default_models = vec![ + // ===== LLMs ===== + // Llama 3 family + ModelInfo::llm("llama-3-8b", "Llama 3 8B", "QmLlama3_8B_placeholder", 8_000_000_000, 8192), + ModelInfo::llm("llama-3-70b", "Llama 3 70B", "QmLlama3_70B_placeholder", 70_000_000_000, 8192), + ModelInfo::llm("llama-3.1-8b", "Llama 3.1 8B", "QmLlama31_8B_placeholder", 8_000_000_000, 128000), + ModelInfo::llm("llama-3.1-70b", "Llama 3.1 70B", "QmLlama31_70B_placeholder", 70_000_000_000, 128000), + ModelInfo::llm("llama-3.1-405b", "Llama 3.1 405B", "QmLlama31_405B_placeholder", 405_000_000_000, 128000), + + // Mistral family + ModelInfo::llm("mistral-7b", "Mistral 7B", "QmMistral7B_placeholder", 7_000_000_000, 32768), + ModelInfo::llm("mixtral-8x7b", "Mixtral 8x7B", "QmMixtral8x7B_placeholder", 46_000_000_000, 32768), + ModelInfo::llm("mixtral-8x22b", "Mixtral 8x22B", "QmMixtral8x22B_placeholder", 176_000_000_000, 65536), + + // Qwen family + ModelInfo::llm("qwen-2.5-7b", "Qwen 2.5 7B", "QmQwen25_7B_placeholder", 7_000_000_000, 128000), + ModelInfo::llm("qwen-2.5-72b", "Qwen 2.5 72B", "QmQwen25_72B_placeholder", 72_000_000_000, 128000), + + // DeepSeek family + ModelInfo::llm("deepseek-v2", "DeepSeek V2", "QmDeepSeekV2_placeholder", 236_000_000_000, 128000), + ModelInfo::llm("deepseek-coder-33b", "DeepSeek Coder 33B", "QmDeepSeekCoder33B_placeholder", 33_000_000_000, 16384), + + // Phi family (small/efficient) + ModelInfo::llm("phi-3-mini", "Phi 3 Mini", "QmPhi3Mini_placeholder", 3_800_000_000, 128000), + ModelInfo::llm("phi-3-medium", "Phi 3 Medium", "QmPhi3Medium_placeholder", 14_000_000_000, 128000), + + // Code models + ModelInfo::llm("codellama-34b", "Code Llama 34B", "QmCodeLlama34B_placeholder", 34_000_000_000, 16384), + ModelInfo::llm("starcoder2-15b", "StarCoder2 15B", "QmStarCoder2_15B_placeholder", 15_000_000_000, 16384), + + // ===== Embedding Models ===== + ModelInfo { + id: ModelId::from_alias("bge-large"), + name: "BGE Large".to_string(), + description: "BAAI General Embedding - Large".to_string(), + category: ModelCategory::Embedding, + cid: "QmBGELarge_placeholder".to_string(), + format: ModelFormat::SafeTensors, + size_bytes: 1_300_000_000, + parameters: 335_000_000, + supported_precisions: vec![Precision::Fp32, Precision::Fp16], + recommended_processor: ProcessorType::Gpu(crate::processor::GpuVariant::default()), + context_length: Some(512), + input_schema: None, + output_schema: None, + license: "MIT".to_string(), + provider: "BAAI".to_string(), + version: "1.5".to_string(), + is_public: true, + owner: None, + }, + ModelInfo { + id: ModelId::from_alias("e5-large-v2"), + name: "E5 Large v2".to_string(), + description: "Microsoft E5 Embedding - Large".to_string(), + category: ModelCategory::Embedding, + cid: "QmE5LargeV2_placeholder".to_string(), + format: ModelFormat::SafeTensors, + size_bytes: 1_300_000_000, + parameters: 335_000_000, + supported_precisions: vec![Precision::Fp32, Precision::Fp16], + recommended_processor: ProcessorType::Gpu(crate::processor::GpuVariant::default()), + context_length: Some(512), + input_schema: None, + output_schema: None, + license: "MIT".to_string(), + provider: "Microsoft".to_string(), + version: "2.0".to_string(), + is_public: true, + owner: None, + }, + + // ===== Vision Models ===== + ModelInfo { + id: ModelId::from_alias("stable-diffusion-xl"), + name: "Stable Diffusion XL".to_string(), + description: "SDXL 1.0 - High quality image generation".to_string(), + category: ModelCategory::ImageGeneration, + cid: "QmSDXL_placeholder".to_string(), + format: ModelFormat::SafeTensors, + size_bytes: 6_900_000_000, + parameters: 2_600_000_000, + supported_precisions: vec![Precision::Fp16, Precision::Bf16], + recommended_processor: ProcessorType::Gpu(crate::processor::GpuVariant::default()), + context_length: None, + input_schema: None, + output_schema: None, + license: "CreativeML Open RAIL++-M".to_string(), + provider: "Stability AI".to_string(), + version: "1.0".to_string(), + is_public: true, + owner: None, + }, + ModelInfo { + id: ModelId::from_alias("flux-1-dev"), + name: "FLUX.1 Dev".to_string(), + description: "FLUX.1 Development - State of the art image generation".to_string(), + category: ModelCategory::ImageGeneration, + cid: "QmFLUX1Dev_placeholder".to_string(), + format: ModelFormat::SafeTensors, + size_bytes: 23_000_000_000, + parameters: 12_000_000_000, + supported_precisions: vec![Precision::Fp16, Precision::Bf16], + recommended_processor: ProcessorType::Gpu(crate::processor::GpuVariant::default()), + context_length: None, + input_schema: None, + output_schema: None, + license: "FLUX.1-dev Non-Commercial".to_string(), + provider: "Black Forest Labs".to_string(), + version: "1.0".to_string(), + is_public: true, + owner: None, + }, + + // ===== Speech Models ===== + ModelInfo { + id: ModelId::from_alias("whisper-large-v3"), + name: "Whisper Large v3".to_string(), + description: "OpenAI Whisper - Speech recognition".to_string(), + category: ModelCategory::SpeechToText, + cid: "QmWhisperLargeV3_placeholder".to_string(), + format: ModelFormat::SafeTensors, + size_bytes: 3_100_000_000, + parameters: 1_550_000_000, + supported_precisions: vec![Precision::Fp32, Precision::Fp16], + recommended_processor: ProcessorType::Gpu(crate::processor::GpuVariant::default()), + context_length: None, + input_schema: None, + output_schema: None, + license: "MIT".to_string(), + provider: "OpenAI".to_string(), + version: "3.0".to_string(), + is_public: true, + owner: None, + }, + + // ===== Multi-Modal Models ===== + ModelInfo { + id: ModelId::from_alias("llava-1.5-13b"), + name: "LLaVA 1.5 13B".to_string(), + description: "Large Language and Vision Assistant".to_string(), + category: ModelCategory::MultiModal, + cid: "QmLLaVA15_13B_placeholder".to_string(), + format: ModelFormat::SafeTensors, + size_bytes: 26_000_000_000, + parameters: 13_000_000_000, + supported_precisions: vec![Precision::Fp16, Precision::Int8], + recommended_processor: ProcessorType::Gpu(crate::processor::GpuVariant::default()), + context_length: Some(4096), + input_schema: None, + output_schema: None, + license: "Llama 2".to_string(), + provider: "LLaVA".to_string(), + version: "1.5".to_string(), + is_public: true, + owner: None, + }, + ]; + + let mut models = self.models.write(); + let mut aliases = self.aliases.write(); + + for model in default_models { + let alias = model.id.0.clone(); + let cid = model.cid.clone(); + + models.insert(alias.clone(), model.clone()); + models.insert(cid.clone(), model); + aliases.insert(alias, cid); + } + } + + /// Resolves a model ID (alias or CID) to model info. + pub fn resolve(&self, id: &str) -> Result { + let models = self.models.read(); + + // Try direct lookup + if let Some(model) = models.get(id) { + return Ok(model.clone()); + } + + // Try alias lookup + let aliases = self.aliases.read(); + if let Some(cid) = aliases.get(id) { + if let Some(model) = models.get(cid) { + return Ok(model.clone()); + } + } + + Err(ComputeError::ModelNotFound(id.to_string())) + } + + /// Registers a custom model. + pub fn register(&self, model: ModelInfo) -> Result<(), ComputeError> { + let mut models = self.models.write(); + let mut aliases = self.aliases.write(); + + let id = model.id.0.clone(); + let cid = model.cid.clone(); + + // Register by both ID and CID + models.insert(id.clone(), model.clone()); + models.insert(cid.clone(), model); + + if !id.starts_with("Qm") && !id.starts_with("bafy") { + aliases.insert(id, cid); + } + + Ok(()) + } + + /// Lists all available models. + pub fn list(&self) -> Vec { + let models = self.models.read(); + let aliases = self.aliases.read(); + + // Return only alias entries to avoid duplicates + aliases + .keys() + .filter_map(|alias| models.get(alias).cloned()) + .collect() + } + + /// Lists models by category. + pub fn list_by_category(&self, category: ModelCategory) -> Vec { + self.list() + .into_iter() + .filter(|m| m.category == category) + .collect() + } + + /// Searches models by name/description. + pub fn search(&self, query: &str) -> Vec { + let query_lower = query.to_lowercase(); + self.list() + .into_iter() + .filter(|m| { + m.name.to_lowercase().contains(&query_lower) + || m.description.to_lowercase().contains(&query_lower) + || m.id.0.to_lowercase().contains(&query_lower) + }) + .collect() + } + + /// Gets model by alias. + pub fn get_by_alias(&self, alias: &str) -> Option { + self.resolve(alias).ok() + } + + /// Checks if a model exists. + pub fn exists(&self, id: &str) -> bool { + self.resolve(id).is_ok() + } +} + +impl Default for ModelRegistry { + fn default() -> Self { + Self::new() + } +} + +/// Model upload request. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ModelUploadRequest { + /// Model name. + pub name: String, + /// Model description. + pub description: Option, + /// Model category. + pub category: ModelCategory, + /// Model format. + pub format: ModelFormat, + /// Model file data (bytes). + #[serde(skip)] + pub data: Vec, + /// Optional alias (must be unique). + pub alias: Option, + /// Is public. + pub is_public: bool, + /// License. + pub license: Option, +} + +/// Model upload response. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ModelUploadResponse { + /// Assigned model ID. + pub model_id: ModelId, + /// Storage CID. + pub cid: String, + /// Size in bytes. + pub size_bytes: u64, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_model_registry() { + let registry = ModelRegistry::new(); + + // Test default models exist + let llama = registry.resolve("llama-3-70b").unwrap(); + assert_eq!(llama.parameters, 70_000_000_000); + assert_eq!(llama.category, ModelCategory::Llm); + } + + #[test] + fn test_model_list() { + let registry = ModelRegistry::new(); + let models = registry.list(); + assert!(!models.is_empty()); + } + + #[test] + fn test_model_search() { + let registry = ModelRegistry::new(); + let results = registry.search("llama"); + assert!(!results.is_empty()); + assert!(results.iter().all(|m| m.name.to_lowercase().contains("llama"))); + } + + #[test] + fn test_model_by_category() { + let registry = ModelRegistry::new(); + let llms = registry.list_by_category(ModelCategory::Llm); + assert!(!llms.is_empty()); + assert!(llms.iter().all(|m| m.category == ModelCategory::Llm)); + } + + #[test] + fn test_custom_model() { + let registry = ModelRegistry::new(); + + let custom = ModelInfo::custom( + "QmCustomModel123", + "My Custom Model", + ModelCategory::Custom, + ModelFormat::Onnx, + 1_000_000, + [1u8; 32], + ); + + registry.register(custom).unwrap(); + + let resolved = registry.resolve("QmCustomModel123").unwrap(); + assert_eq!(resolved.name, "My Custom Model"); + } +} diff --git a/sdk/flutter/example/example.dart b/sdk/flutter/example/example.dart index 524e757..a349933 100644 --- a/sdk/flutter/example/example.dart +++ b/sdk/flutter/example/example.dart @@ -32,6 +32,15 @@ void main() async { // Example 5: Pricing and usage await pricingExample(client); + + // Example 6: List available models + await modelRegistryExample(client); + + // Example 7: Training a model + await trainingExample(client); + + // Example 8: Custom model upload + await customModelExample(client); } finally { // Always dispose client to release resources client.dispose(); @@ -176,3 +185,178 @@ Future pricingExample(SynorCompute client) async { print(' Total cost: \$${usage.totalCost.toStringAsFixed(4)}'); print(''); } + +/// Model registry example - list available models +Future modelRegistryExample(SynorCompute client) async { + print('=== Model Registry ==='); + + // List all available models + final allModels = await client.listModels(); + print('Total available models: ${allModels.length}'); + + // List only LLMs + final llms = await client.listModels(category: ModelCategory.llm); + print('\nAvailable LLMs:'); + for (final model in llms.take(5)) { + print(' ${model.id.padRight(20)} ${model.formattedParameters.padRight(8)} ' + '${model.name}'); + } + + // Search for a specific model + final searchResults = await client.searchModels('llama'); + print('\nSearch "llama": ${searchResults.length} results'); + + // Get specific model info + final modelInfo = await client.getModel('llama-3-70b'); + print('\nModel details for ${modelInfo.name}:'); + print(' Parameters: ${modelInfo.formattedParameters}'); + print(' Context length: ${modelInfo.contextLength}'); + print(' Format: ${modelInfo.format.value}'); + print(' Recommended processor: ${modelInfo.recommendedProcessor.value}'); + print(' License: ${modelInfo.license}'); + + // List embedding models + final embeddings = await client.listModels(category: ModelCategory.embedding); + print('\nAvailable embedding models:'); + for (final model in embeddings) { + print(' ${model.id} - ${model.name}'); + } + + // List image generation models + final imageGen = + await client.listModels(category: ModelCategory.imageGeneration); + print('\nAvailable image generation models:'); + for (final model in imageGen) { + print(' ${model.id} - ${model.name}'); + } + + print(''); +} + +/// Training example - train/fine-tune a model +Future trainingExample(SynorCompute client) async { + print('=== Model Training ==='); + + // Example: Fine-tune Llama 3 8B on custom dataset + print('Fine-tuning llama-3-8b on custom dataset...'); + + // Note: In practice, you'd upload your dataset first: + // final datasetCid = await client.uploadTensor(datasetTensor); + + final result = await client.fineTune( + baseModel: 'llama-3-8b', // Use model alias + datasetCid: 'QmYourDatasetCID', // Your uploaded dataset + outputAlias: 'my-custom-llama', // Optional: alias for trained model + options: TrainingOptions( + framework: MlFramework.pytorch, + epochs: 3, + batchSize: 8, + learningRate: 0.00002, + optimizer: 'adamw', + hyperparameters: { + 'weight_decay': 0.01, + 'warmup_steps': 100, + 'gradient_accumulation_steps': 4, + }, + checkpointEvery: 500, // Save checkpoint every 500 steps + processor: ProcessorType.gpu, + priority: Priority.high, + ), + ); + + if (result.isSuccess) { + final training = result.result!; + print('Training completed!'); + print(' New model CID: ${training.modelCid}'); + print(' Final loss: ${training.finalLoss.toStringAsFixed(4)}'); + print(' Duration: ${training.durationMs / 1000}s'); + print(' Cost: \$${training.cost.toStringAsFixed(4)}'); + print(' Metrics: ${training.metrics}'); + + // Now use your trained model for inference + print('\nUsing trained model for inference:'); + final inference = await client.inference( + training.modelCid, // Use the CID of your trained model + 'Hello, how are you?', + options: InferenceOptions(maxTokens: 50), + ); + print('Response: ${inference.result}'); + } else { + print('Training failed: ${result.error}'); + } + + print(''); + + // Example: Streaming training progress + print('Training with streaming progress...'); + await for (final progress in client.trainStream( + modelCid: 'llama-3-8b', + datasetCid: 'QmYourDatasetCID', + options: TrainingOptions(epochs: 1, batchSize: 16), + )) { + // Update UI with progress + stdout.write('\r${progress.progressText} - ' + '${progress.samplesPerSecond} samples/s'); + } + print('\nTraining complete!'); + + print(''); +} + +/// Custom model upload example +Future customModelExample(SynorCompute client) async { + print('=== Custom Model Upload ==='); + + // Example: Upload a custom ONNX model + // In practice, you'd read this from a file: + // final modelBytes = await File('my_model.onnx').readAsBytes(); + + // For demonstration, we'll show the API structure + print('To upload your own Python-trained model:'); + print(''' + 1. Train your model in Python: + + import torch + model = MyModel() + # ... train model ... + torch.onnx.export(model, dummy_input, "my_model.onnx") + + 2. Upload to Synor Compute: + + final modelBytes = await File('my_model.onnx').readAsBytes(); + final result = await client.uploadModel( + modelBytes, + ModelUploadOptions( + name: 'my-custom-model', + description: 'My custom trained model', + category: ModelCategory.custom, + format: ModelFormat.onnx, + alias: 'my-model', // Optional shortcut name + isPublic: false, // Keep private + license: 'Proprietary', + ), + ); + print('Uploaded! CID: \${result.cid}'); + + 3. Use for inference: + + final result = await client.inference( + result.cid, // or 'my-model' if you set an alias + 'Your input data', + ); +'''); + + // Supported model formats + print('Supported model formats:'); + for (final format in ModelFormat.values) { + print(' - ${format.value}'); + } + + // Supported categories + print('\nSupported model categories:'); + for (final category in ModelCategory.values) { + print(' - ${category.value}'); + } + + print(''); +} diff --git a/sdk/flutter/lib/src/client.dart b/sdk/flutter/lib/src/client.dart index 5e401a5..1ba8563 100644 --- a/sdk/flutter/lib/src/client.dart +++ b/sdk/flutter/lib/src/client.dart @@ -426,6 +426,192 @@ class SynorCompute { } } + // ==================== Model Registry ==================== + + /// List all available models. + Future> listModels({ModelCategory? category}) async { + _checkDisposed(); + + final params = { + if (category != null) 'category': category.value, + }; + + final response = await _get('/models', params); + final models = response['models'] as List; + return models + .map((m) => ModelInfo.fromJson(m as Map)) + .toList(); + } + + /// Get model info by ID or alias. + Future getModel(String modelId) async { + _checkDisposed(); + + final response = await _get('/models/$modelId'); + return ModelInfo.fromJson(response); + } + + /// Search models by query. + Future> searchModels(String query) async { + _checkDisposed(); + + final response = await _get('/models/search', {'q': query}); + final models = response['models'] as List; + return models + .map((m) => ModelInfo.fromJson(m as Map)) + .toList(); + } + + /// Upload a custom model. + Future uploadModel( + List modelData, + ModelUploadOptions options, + ) async { + _checkDisposed(); + + // For large files, use multipart upload + final uri = Uri.parse('${_config.baseUrl}/models/upload'); + final request = http.MultipartRequest('POST', uri) + ..headers.addAll(_headers) + ..fields.addAll(options.toJson().map((k, v) => MapEntry(k, v.toString()))) + ..files.add(http.MultipartFile.fromBytes( + 'model', + modelData, + filename: '${options.name}.${options.format.value}', + )); + + final streamedResponse = await _httpClient.send(request); + final response = await http.Response.fromStream(streamedResponse); + + if (response.statusCode != 200) { + throw SynorException( + 'Model upload failed', + statusCode: response.statusCode, + ); + } + + final json = jsonDecode(response.body) as Map; + return ModelUploadResult.fromJson(json); + } + + /// Delete a custom model (only owner can delete). + Future deleteModel(String modelId) async { + _checkDisposed(); + + await _delete('/models/$modelId'); + } + + // ==================== Training ==================== + + /// Train a model on a dataset. + /// + /// Example: + /// ```dart + /// final result = await client.train( + /// modelCid: 'QmBaseModelCID', // Base model to fine-tune + /// datasetCid: 'QmDatasetCID', // Training dataset CID + /// options: TrainingOptions( + /// framework: MlFramework.pytorch, + /// epochs: 10, + /// batchSize: 32, + /// learningRate: 0.0001, + /// ), + /// ); + /// print('Trained model CID: ${result.modelCid}'); + /// ``` + Future> train({ + required String modelCid, + required String datasetCid, + TrainingOptions? options, + }) async { + _checkDisposed(); + + final opts = options ?? const TrainingOptions(); + final body = { + 'operation': 'training', + 'model_cid': modelCid, + 'dataset_cid': datasetCid, + 'options': opts.toJson(), + }; + + return _submitAndWait( + body, + (result) => TrainingResult.fromJson(result as Map), + ); + } + + /// Stream training progress updates. + Stream trainStream({ + required String modelCid, + required String datasetCid, + TrainingOptions? options, + }) async* { + _checkDisposed(); + + final opts = options ?? const TrainingOptions(); + final body = { + 'operation': 'training', + 'model_cid': modelCid, + 'dataset_cid': datasetCid, + 'options': { + ...opts.toJson(), + 'stream': true, + }, + }; + + final request = http.Request('POST', Uri.parse('${_config.baseUrl}/train/stream')) + ..headers.addAll(_headers) + ..body = jsonEncode(body); + + final streamedResponse = await _httpClient.send(request); + + if (streamedResponse.statusCode != 200) { + throw SynorException( + 'Training stream failed', + statusCode: streamedResponse.statusCode, + ); + } + + await for (final chunk in streamedResponse.stream.transform(utf8.decoder)) { + for (final line in chunk.split('\n')) { + if (line.startsWith('data: ')) { + final data = line.substring(6); + if (data == '[DONE]') return; + try { + final json = jsonDecode(data) as Map; + yield TrainingProgress.fromJson(json); + } catch (e) { + // Skip malformed JSON + } + } + } + } + } + + /// Fine-tune a pre-trained model. + Future> fineTune({ + required String baseModel, + required String datasetCid, + String? outputAlias, + TrainingOptions? options, + }) async { + _checkDisposed(); + + final opts = options ?? const TrainingOptions(); + final body = { + 'operation': 'fine_tune', + 'base_model': baseModel, + 'dataset_cid': datasetCid, + if (outputAlias != null) 'output_alias': outputAlias, + 'options': opts.toJson(), + }; + + return _submitAndWait( + body, + (result) => TrainingResult.fromJson(result as Map), + ); + } + // Internal HTTP methods Future> _submitAndWait( diff --git a/sdk/flutter/lib/src/types.dart b/sdk/flutter/lib/src/types.dart index f4c934d..c692b91 100644 --- a/sdk/flutter/lib/src/types.dart +++ b/sdk/flutter/lib/src/types.dart @@ -345,6 +345,312 @@ class UsageStats { } } +/// Model category. +enum ModelCategory { + llm('llm'), + embedding('embedding'), + imageClassification('image_classification'), + objectDetection('object_detection'), + segmentation('segmentation'), + imageGeneration('image_generation'), + speechToText('speech_to_text'), + textToSpeech('text_to_speech'), + videoGeneration('video_generation'), + multiModal('multi_modal'), + custom('custom'); + + const ModelCategory(this.value); + final String value; + + static ModelCategory fromString(String s) => + ModelCategory.values.firstWhere((c) => c.value == s, orElse: () => custom); +} + +/// Model format. +enum ModelFormat { + onnx('onnx'), + pytorch('pytorch'), + torchScript('torchscript'), + tensorflow('tensorflow'), + tfLite('tflite'), + safeTensors('safetensors'), + gguf('gguf'), + ggml('ggml'), + custom('custom'); + + const ModelFormat(this.value); + final String value; + + static ModelFormat fromString(String s) => + ModelFormat.values.firstWhere((f) => f.value == s, orElse: () => custom); +} + +/// ML framework for training. +enum MlFramework { + pytorch('pytorch'), + tensorflow('tensorflow'), + jax('jax'), + onnx('onnx'); + + const MlFramework(this.value); + final String value; + + static MlFramework fromString(String s) => + MlFramework.values.firstWhere((f) => f.value == s, orElse: () => pytorch); +} + +/// Model information. +class ModelInfo { + final String id; + final String name; + final String description; + final ModelCategory category; + final String cid; + final ModelFormat format; + final int sizeBytes; + final int parameters; + final List supportedPrecisions; + final ProcessorType recommendedProcessor; + final int? contextLength; + final String license; + final String provider; + final String version; + final bool isPublic; + + const ModelInfo({ + required this.id, + required this.name, + required this.description, + required this.category, + required this.cid, + required this.format, + required this.sizeBytes, + required this.parameters, + required this.supportedPrecisions, + required this.recommendedProcessor, + this.contextLength, + required this.license, + required this.provider, + required this.version, + required this.isPublic, + }); + + factory ModelInfo.fromJson(Map json) => ModelInfo( + id: json['id'] as String, + name: json['name'] as String, + description: json['description'] as String, + category: ModelCategory.fromString(json['category'] as String), + cid: json['cid'] as String, + format: ModelFormat.fromString(json['format'] as String), + sizeBytes: json['size_bytes'] as int, + parameters: json['parameters'] as int, + supportedPrecisions: (json['supported_precisions'] as List) + .map((p) => Precision.fromString(p as String)) + .toList(), + recommendedProcessor: + ProcessorType.fromString(json['recommended_processor'] as String), + contextLength: json['context_length'] as int?, + license: json['license'] as String, + provider: json['provider'] as String, + version: json['version'] as String, + isPublic: json['is_public'] as bool? ?? true, + ); + + /// Format parameters for display (e.g., "70B", "7B"). + String get formattedParameters { + if (parameters >= 1000000000000) { + return '${(parameters / 1e12).toStringAsFixed(1)}T'; + } else if (parameters >= 1000000000) { + return '${(parameters / 1e9).toStringAsFixed(1)}B'; + } else if (parameters >= 1000000) { + return '${(parameters / 1e6).toStringAsFixed(1)}M'; + } + return '$parameters'; + } +} + +/// Training options. +class TrainingOptions { + final MlFramework framework; + final int epochs; + final int batchSize; + final double learningRate; + final String? optimizer; + final Map? hyperparameters; + final bool distributed; + final int? checkpointEvery; + final ProcessorType? processor; + final Priority? priority; + + const TrainingOptions({ + this.framework = MlFramework.pytorch, + this.epochs = 1, + this.batchSize = 32, + this.learningRate = 0.001, + this.optimizer, + this.hyperparameters, + this.distributed = false, + this.checkpointEvery, + this.processor, + this.priority, + }); + + Map toJson() => { + 'framework': framework.value, + 'epochs': epochs, + 'batch_size': batchSize, + 'learning_rate': learningRate, + if (optimizer != null) 'optimizer': optimizer, + if (hyperparameters != null) 'hyperparameters': hyperparameters, + 'distributed': distributed, + if (checkpointEvery != null) 'checkpoint_every': checkpointEvery, + if (processor != null) 'processor': processor!.value, + if (priority != null) 'priority': priority!.value, + }; +} + +/// Training job result. +class TrainingResult { + final String jobId; + final String modelCid; + final int epochs; + final double finalLoss; + final Map metrics; + final int durationMs; + final double cost; + final List? checkpointCids; + + const TrainingResult({ + required this.jobId, + required this.modelCid, + required this.epochs, + required this.finalLoss, + required this.metrics, + required this.durationMs, + required this.cost, + this.checkpointCids, + }); + + factory TrainingResult.fromJson(Map json) => TrainingResult( + jobId: json['job_id'] as String, + modelCid: json['model_cid'] as String, + epochs: json['epochs'] as int, + finalLoss: (json['final_loss'] as num).toDouble(), + metrics: (json['metrics'] as Map) + .map((k, v) => MapEntry(k, (v as num).toDouble())), + durationMs: json['duration_ms'] as int, + cost: (json['cost'] as num).toDouble(), + checkpointCids: (json['checkpoint_cids'] as List?)?.cast(), + ); +} + +/// Model upload options. +class ModelUploadOptions { + final String name; + final String? description; + final ModelCategory category; + final ModelFormat format; + final String? alias; + final bool isPublic; + final String? license; + + const ModelUploadOptions({ + required this.name, + this.description, + this.category = ModelCategory.custom, + this.format = ModelFormat.onnx, + this.alias, + this.isPublic = false, + this.license, + }); + + Map toJson() => { + 'name': name, + if (description != null) 'description': description, + 'category': category.value, + 'format': format.value, + if (alias != null) 'alias': alias, + 'is_public': isPublic, + if (license != null) 'license': license, + }; +} + +/// Model upload result. +class ModelUploadResult { + final String modelId; + final String cid; + final int sizeBytes; + + const ModelUploadResult({ + required this.modelId, + required this.cid, + required this.sizeBytes, + }); + + factory ModelUploadResult.fromJson(Map json) => + ModelUploadResult( + modelId: json['model_id'] as String, + cid: json['cid'] as String, + sizeBytes: json['size_bytes'] as int, + ); +} + +/// Training progress update. +class TrainingProgress { + final String jobId; + final int epoch; + final int totalEpochs; + final int step; + final int totalSteps; + final double loss; + final Map metrics; + final double? learningRate; + final double? gradientNorm; + final int samplesPerSecond; + final int? estimatedRemainingMs; + + const TrainingProgress({ + required this.jobId, + required this.epoch, + required this.totalEpochs, + required this.step, + required this.totalSteps, + required this.loss, + required this.metrics, + this.learningRate, + this.gradientNorm, + required this.samplesPerSecond, + this.estimatedRemainingMs, + }); + + factory TrainingProgress.fromJson(Map json) => + TrainingProgress( + jobId: json['job_id'] as String, + epoch: json['epoch'] as int, + totalEpochs: json['total_epochs'] as int, + step: json['step'] as int, + totalSteps: json['total_steps'] as int, + loss: (json['loss'] as num).toDouble(), + metrics: (json['metrics'] as Map?) + ?.map((k, v) => MapEntry(k, (v as num).toDouble())) ?? + {}, + learningRate: (json['learning_rate'] as num?)?.toDouble(), + gradientNorm: (json['gradient_norm'] as num?)?.toDouble(), + samplesPerSecond: json['samples_per_second'] as int? ?? 0, + estimatedRemainingMs: json['estimated_remaining_ms'] as int?, + ); + + /// Progress percentage (0.0 - 1.0). + double get progress { + if (totalSteps == 0) return 0.0; + return step / totalSteps; + } + + /// Formatted progress string. + String get progressText => + 'Epoch ${epoch + 1}/$totalEpochs - Step $step/$totalSteps - Loss: ${loss.toStringAsFixed(4)}'; +} + /// Exception thrown by Synor Compute operations class SynorException implements Exception { final String message; diff --git a/sdk/flutter/lib/synor_compute.dart b/sdk/flutter/lib/synor_compute.dart index 81f5228..dafc7b3 100644 --- a/sdk/flutter/lib/synor_compute.dart +++ b/sdk/flutter/lib/synor_compute.dart @@ -83,7 +83,18 @@ export 'src/types.dart' InferenceOptions, PricingInfo, UsageStats, - SynorException; + SynorException, + // Model types + ModelCategory, + ModelFormat, + MlFramework, + ModelInfo, + ModelUploadOptions, + ModelUploadResult, + // Training types + TrainingOptions, + TrainingResult, + TrainingProgress; export 'src/tensor.dart' show Tensor;