feat(compute): add model registry and training APIs
Adds comprehensive model management and training capabilities: synor-compute (Rust): - ModelRegistry with pre-registered popular models - LLMs: Llama 3/3.1, Mistral, Mixtral, Qwen, DeepSeek, Phi, CodeLlama - Embedding: BGE, E5 - Image: Stable Diffusion XL, FLUX.1 - Speech: Whisper - Multi-modal: LLaVA - ModelInfo with parameters, format, precision, context length - Custom model upload and registration - Model search by name/category Flutter SDK: - Model registry APIs: listModels, getModel, searchModels - Custom model upload with multipart upload - Training APIs: train(), fineTune(), trainStream() - TrainingOptions: framework, epochs, batch_size, learning_rate - TrainingProgress for real-time updates - ModelUploadOptions and ModelUploadResult Example code for: - Listing available models by category - Fine-tuning pre-trained models - Uploading custom Python/ONNX models - Streaming training progress This enables users to: 1. Use pre-registered models like 'llama-3-70b' 2. Upload their own custom models 3. Fine-tune models on custom datasets 4. Track training progress in real-time
This commit is contained in:
parent
62ec3c92da
commit
89fc542da4
7 changed files with 1293 additions and 1 deletion
|
|
@ -77,6 +77,18 @@ pub enum ComputeError {
|
|||
/// Internal error.
|
||||
#[error("Internal error: {0}")]
|
||||
Internal(String),
|
||||
|
||||
/// Model not found.
|
||||
#[error("Model not found: {0}")]
|
||||
ModelNotFound(String),
|
||||
|
||||
/// Model upload failed.
|
||||
#[error("Model upload failed: {0}")]
|
||||
ModelUploadFailed(String),
|
||||
|
||||
/// Invalid model format.
|
||||
#[error("Invalid model format: {0}")]
|
||||
InvalidModelFormat(String),
|
||||
}
|
||||
|
||||
impl From<bincode::Error> for ComputeError {
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ pub mod device;
|
|||
pub mod error;
|
||||
pub mod market;
|
||||
pub mod memory;
|
||||
pub mod model;
|
||||
pub mod processor;
|
||||
pub mod scheduler;
|
||||
pub mod task;
|
||||
|
|
@ -77,6 +78,10 @@ pub use task::{
|
|||
ComputeTask, DecomposedWorkload, Task, TaskDecomposer, TaskId, TaskPriority, TaskResult,
|
||||
TaskStatus,
|
||||
};
|
||||
pub use model::{
|
||||
ModelCategory, ModelFormat, ModelId, ModelInfo, ModelRegistry, ModelUploadRequest,
|
||||
ModelUploadResponse,
|
||||
};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
|
|
|||
588
crates/synor-compute/src/model/mod.rs
Normal file
588
crates/synor-compute/src/model/mod.rs
Normal file
|
|
@ -0,0 +1,588 @@
|
|||
//! Model registry and management for Synor Compute.
|
||||
//!
|
||||
//! Provides:
|
||||
//! - Pre-registered popular models (LLMs, vision, audio)
|
||||
//! - Custom model uploads
|
||||
//! - Model format conversion
|
||||
//! - Model caching and warm-up
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use parking_lot::RwLock;
|
||||
|
||||
use crate::error::ComputeError;
|
||||
use crate::processor::{Precision, ProcessorType};
|
||||
|
||||
/// Model identifier (CID or alias).
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct ModelId(pub String);
|
||||
|
||||
impl ModelId {
|
||||
/// Creates from a CID.
|
||||
pub fn from_cid(cid: &str) -> Self {
|
||||
Self(cid.to_string())
|
||||
}
|
||||
|
||||
/// Creates from an alias.
|
||||
pub fn from_alias(alias: &str) -> Self {
|
||||
Self(alias.to_string())
|
||||
}
|
||||
|
||||
/// Checks if this is a CID (starts with Qm or bafy).
|
||||
pub fn is_cid(&self) -> bool {
|
||||
self.0.starts_with("Qm") || self.0.starts_with("bafy")
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ModelId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for ModelId {
|
||||
fn from(s: &str) -> Self {
|
||||
Self(s.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// Model format specification.
///
/// Identifies the on-disk serialization of a model's weights/graph so the
/// runtime can choose a compatible execution backend.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelFormat {
    /// ONNX Runtime format.
    Onnx,
    /// PyTorch checkpoint (.pt, .pth).
    PyTorch,
    /// PyTorch TorchScript (.pt).
    TorchScript,
    /// TensorFlow SavedModel.
    TensorFlow,
    /// TensorFlow Lite.
    TFLite,
    /// SafeTensors format.
    SafeTensors,
    /// GGUF format (llama.cpp).
    Gguf,
    /// GGML format (legacy predecessor of GGUF).
    Ggml,
    /// Custom binary format (opaque to the runtime).
    Custom,
}
|
||||
|
||||
/// Model category.
///
/// Coarse task taxonomy used for filtering in the registry
/// (see `ModelRegistry::list_by_category`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelCategory {
    /// Large language models.
    Llm,
    /// Text embedding models.
    Embedding,
    /// Image classification.
    ImageClassification,
    /// Object detection.
    ObjectDetection,
    /// Image segmentation.
    Segmentation,
    /// Image generation (diffusion).
    ImageGeneration,
    /// Speech-to-text.
    SpeechToText,
    /// Text-to-speech.
    TextToSpeech,
    /// Video generation.
    VideoGeneration,
    /// Multi-modal (vision-language).
    MultiModal,
    /// Custom/other (default bucket for user uploads).
    Custom,
}
|
||||
|
||||
/// Model metadata.
///
/// Full descriptive record for one registered model; this is what the
/// registry stores and what lookup/search calls return.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Model ID (alias or CID).
    pub id: ModelId,
    /// Human-readable name.
    pub name: String,
    /// Description (free-form, used by registry search).
    pub description: String,
    /// Model category.
    pub category: ModelCategory,
    /// Storage CID (actual content location; `id` may be an alias for it).
    pub cid: String,
    /// Model format.
    pub format: ModelFormat,
    /// Model size in bytes.
    pub size_bytes: u64,
    /// Parameter count (0 when unknown, e.g. custom uploads).
    pub parameters: u64,
    /// Supported precisions.
    pub supported_precisions: Vec<Precision>,
    /// Recommended processor type.
    pub recommended_processor: ProcessorType,
    /// Context length in tokens (for LLMs; `None` otherwise).
    pub context_length: Option<u32>,
    /// Input schema (JSON Schema), if declared.
    pub input_schema: Option<String>,
    /// Output schema (JSON Schema), if declared.
    pub output_schema: Option<String>,
    /// License.
    pub license: String,
    /// Provider/author.
    pub provider: String,
    /// Version.
    pub version: String,
    /// Is public (anyone can use).
    pub is_public: bool,
    /// Owner address (set for private models; `None` for public presets).
    pub owner: Option<[u8; 32]>,
}
|
||||
|
||||
impl ModelInfo {
|
||||
/// Creates a new LLM model info.
|
||||
pub fn llm(
|
||||
alias: &str,
|
||||
name: &str,
|
||||
cid: &str,
|
||||
parameters: u64,
|
||||
context_length: u32,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: ModelId::from_alias(alias),
|
||||
name: name.to_string(),
|
||||
description: format!("{} - {} parameter LLM", name, format_params(parameters)),
|
||||
category: ModelCategory::Llm,
|
||||
cid: cid.to_string(),
|
||||
format: ModelFormat::SafeTensors,
|
||||
size_bytes: parameters * 2, // ~2 bytes per param in fp16
|
||||
parameters,
|
||||
supported_precisions: vec![Precision::Fp16, Precision::Bf16, Precision::Int8, Precision::Int4],
|
||||
recommended_processor: ProcessorType::Lpu,
|
||||
context_length: Some(context_length),
|
||||
input_schema: None,
|
||||
output_schema: None,
|
||||
license: "Apache-2.0".to_string(),
|
||||
provider: "Synor".to_string(),
|
||||
version: "1.0".to_string(),
|
||||
is_public: true,
|
||||
owner: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a custom model info.
|
||||
pub fn custom(
|
||||
cid: &str,
|
||||
name: &str,
|
||||
category: ModelCategory,
|
||||
format: ModelFormat,
|
||||
size_bytes: u64,
|
||||
owner: [u8; 32],
|
||||
) -> Self {
|
||||
Self {
|
||||
id: ModelId::from_cid(cid),
|
||||
name: name.to_string(),
|
||||
description: format!("Custom model: {}", name),
|
||||
category,
|
||||
cid: cid.to_string(),
|
||||
format,
|
||||
size_bytes,
|
||||
parameters: 0,
|
||||
supported_precisions: vec![Precision::Fp32, Precision::Fp16],
|
||||
recommended_processor: ProcessorType::Cpu(crate::processor::CpuVariant::default()),
|
||||
context_length: None,
|
||||
input_schema: None,
|
||||
output_schema: None,
|
||||
license: "Custom".to_string(),
|
||||
provider: "User".to_string(),
|
||||
version: "1.0".to_string(),
|
||||
is_public: false,
|
||||
owner: Some(owner),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Format parameter count for display (e.g. `70_000_000_000` -> "70.0B").
///
/// Uses T/B/M suffixes with one decimal place; counts below one million
/// are printed verbatim.
fn format_params(params: u64) -> String {
    const SCALES: [(u64, &str); 3] = [
        (1_000_000_000_000, "T"),
        (1_000_000_000, "B"),
        (1_000_000, "M"),
    ];
    for (scale, suffix) in SCALES {
        if params >= scale {
            return format!("{:.1}{}", params as f64 / scale as f64, suffix);
        }
    }
    params.to_string()
}
|
||||
|
||||
/// Model registry.
///
/// In-memory, thread-safe index of model metadata. Each model is stored in
/// `models` under both its alias and its storage CID so either form
/// resolves directly.
pub struct ModelRegistry {
    /// Registered models keyed by ID/alias and by CID (two keys per model).
    models: RwLock<HashMap<String, ModelInfo>>,
    /// Alias -> CID mapping (aliases only; bare-CID registrations are not
    /// recorded here).
    aliases: RwLock<HashMap<String, String>>,
}
|
||||
|
||||
impl ModelRegistry {
|
||||
/// Creates a new model registry with default models.
|
||||
pub fn new() -> Self {
|
||||
let registry = Self {
|
||||
models: RwLock::new(HashMap::new()),
|
||||
aliases: RwLock::new(HashMap::new()),
|
||||
};
|
||||
registry.register_default_models();
|
||||
registry
|
||||
}
|
||||
|
||||
/// Registers default/popular models.
|
||||
fn register_default_models(&self) {
|
||||
let default_models = vec![
|
||||
// ===== LLMs =====
|
||||
// Llama 3 family
|
||||
ModelInfo::llm("llama-3-8b", "Llama 3 8B", "QmLlama3_8B_placeholder", 8_000_000_000, 8192),
|
||||
ModelInfo::llm("llama-3-70b", "Llama 3 70B", "QmLlama3_70B_placeholder", 70_000_000_000, 8192),
|
||||
ModelInfo::llm("llama-3.1-8b", "Llama 3.1 8B", "QmLlama31_8B_placeholder", 8_000_000_000, 128000),
|
||||
ModelInfo::llm("llama-3.1-70b", "Llama 3.1 70B", "QmLlama31_70B_placeholder", 70_000_000_000, 128000),
|
||||
ModelInfo::llm("llama-3.1-405b", "Llama 3.1 405B", "QmLlama31_405B_placeholder", 405_000_000_000, 128000),
|
||||
|
||||
// Mistral family
|
||||
ModelInfo::llm("mistral-7b", "Mistral 7B", "QmMistral7B_placeholder", 7_000_000_000, 32768),
|
||||
ModelInfo::llm("mixtral-8x7b", "Mixtral 8x7B", "QmMixtral8x7B_placeholder", 46_000_000_000, 32768),
|
||||
ModelInfo::llm("mixtral-8x22b", "Mixtral 8x22B", "QmMixtral8x22B_placeholder", 176_000_000_000, 65536),
|
||||
|
||||
// Qwen family
|
||||
ModelInfo::llm("qwen-2.5-7b", "Qwen 2.5 7B", "QmQwen25_7B_placeholder", 7_000_000_000, 128000),
|
||||
ModelInfo::llm("qwen-2.5-72b", "Qwen 2.5 72B", "QmQwen25_72B_placeholder", 72_000_000_000, 128000),
|
||||
|
||||
// DeepSeek family
|
||||
ModelInfo::llm("deepseek-v2", "DeepSeek V2", "QmDeepSeekV2_placeholder", 236_000_000_000, 128000),
|
||||
ModelInfo::llm("deepseek-coder-33b", "DeepSeek Coder 33B", "QmDeepSeekCoder33B_placeholder", 33_000_000_000, 16384),
|
||||
|
||||
// Phi family (small/efficient)
|
||||
ModelInfo::llm("phi-3-mini", "Phi 3 Mini", "QmPhi3Mini_placeholder", 3_800_000_000, 128000),
|
||||
ModelInfo::llm("phi-3-medium", "Phi 3 Medium", "QmPhi3Medium_placeholder", 14_000_000_000, 128000),
|
||||
|
||||
// Code models
|
||||
ModelInfo::llm("codellama-34b", "Code Llama 34B", "QmCodeLlama34B_placeholder", 34_000_000_000, 16384),
|
||||
ModelInfo::llm("starcoder2-15b", "StarCoder2 15B", "QmStarCoder2_15B_placeholder", 15_000_000_000, 16384),
|
||||
|
||||
// ===== Embedding Models =====
|
||||
ModelInfo {
|
||||
id: ModelId::from_alias("bge-large"),
|
||||
name: "BGE Large".to_string(),
|
||||
description: "BAAI General Embedding - Large".to_string(),
|
||||
category: ModelCategory::Embedding,
|
||||
cid: "QmBGELarge_placeholder".to_string(),
|
||||
format: ModelFormat::SafeTensors,
|
||||
size_bytes: 1_300_000_000,
|
||||
parameters: 335_000_000,
|
||||
supported_precisions: vec![Precision::Fp32, Precision::Fp16],
|
||||
recommended_processor: ProcessorType::Gpu(crate::processor::GpuVariant::default()),
|
||||
context_length: Some(512),
|
||||
input_schema: None,
|
||||
output_schema: None,
|
||||
license: "MIT".to_string(),
|
||||
provider: "BAAI".to_string(),
|
||||
version: "1.5".to_string(),
|
||||
is_public: true,
|
||||
owner: None,
|
||||
},
|
||||
ModelInfo {
|
||||
id: ModelId::from_alias("e5-large-v2"),
|
||||
name: "E5 Large v2".to_string(),
|
||||
description: "Microsoft E5 Embedding - Large".to_string(),
|
||||
category: ModelCategory::Embedding,
|
||||
cid: "QmE5LargeV2_placeholder".to_string(),
|
||||
format: ModelFormat::SafeTensors,
|
||||
size_bytes: 1_300_000_000,
|
||||
parameters: 335_000_000,
|
||||
supported_precisions: vec![Precision::Fp32, Precision::Fp16],
|
||||
recommended_processor: ProcessorType::Gpu(crate::processor::GpuVariant::default()),
|
||||
context_length: Some(512),
|
||||
input_schema: None,
|
||||
output_schema: None,
|
||||
license: "MIT".to_string(),
|
||||
provider: "Microsoft".to_string(),
|
||||
version: "2.0".to_string(),
|
||||
is_public: true,
|
||||
owner: None,
|
||||
},
|
||||
|
||||
// ===== Vision Models =====
|
||||
ModelInfo {
|
||||
id: ModelId::from_alias("stable-diffusion-xl"),
|
||||
name: "Stable Diffusion XL".to_string(),
|
||||
description: "SDXL 1.0 - High quality image generation".to_string(),
|
||||
category: ModelCategory::ImageGeneration,
|
||||
cid: "QmSDXL_placeholder".to_string(),
|
||||
format: ModelFormat::SafeTensors,
|
||||
size_bytes: 6_900_000_000,
|
||||
parameters: 2_600_000_000,
|
||||
supported_precisions: vec![Precision::Fp16, Precision::Bf16],
|
||||
recommended_processor: ProcessorType::Gpu(crate::processor::GpuVariant::default()),
|
||||
context_length: None,
|
||||
input_schema: None,
|
||||
output_schema: None,
|
||||
license: "CreativeML Open RAIL++-M".to_string(),
|
||||
provider: "Stability AI".to_string(),
|
||||
version: "1.0".to_string(),
|
||||
is_public: true,
|
||||
owner: None,
|
||||
},
|
||||
ModelInfo {
|
||||
id: ModelId::from_alias("flux-1-dev"),
|
||||
name: "FLUX.1 Dev".to_string(),
|
||||
description: "FLUX.1 Development - State of the art image generation".to_string(),
|
||||
category: ModelCategory::ImageGeneration,
|
||||
cid: "QmFLUX1Dev_placeholder".to_string(),
|
||||
format: ModelFormat::SafeTensors,
|
||||
size_bytes: 23_000_000_000,
|
||||
parameters: 12_000_000_000,
|
||||
supported_precisions: vec![Precision::Fp16, Precision::Bf16],
|
||||
recommended_processor: ProcessorType::Gpu(crate::processor::GpuVariant::default()),
|
||||
context_length: None,
|
||||
input_schema: None,
|
||||
output_schema: None,
|
||||
license: "FLUX.1-dev Non-Commercial".to_string(),
|
||||
provider: "Black Forest Labs".to_string(),
|
||||
version: "1.0".to_string(),
|
||||
is_public: true,
|
||||
owner: None,
|
||||
},
|
||||
|
||||
// ===== Speech Models =====
|
||||
ModelInfo {
|
||||
id: ModelId::from_alias("whisper-large-v3"),
|
||||
name: "Whisper Large v3".to_string(),
|
||||
description: "OpenAI Whisper - Speech recognition".to_string(),
|
||||
category: ModelCategory::SpeechToText,
|
||||
cid: "QmWhisperLargeV3_placeholder".to_string(),
|
||||
format: ModelFormat::SafeTensors,
|
||||
size_bytes: 3_100_000_000,
|
||||
parameters: 1_550_000_000,
|
||||
supported_precisions: vec![Precision::Fp32, Precision::Fp16],
|
||||
recommended_processor: ProcessorType::Gpu(crate::processor::GpuVariant::default()),
|
||||
context_length: None,
|
||||
input_schema: None,
|
||||
output_schema: None,
|
||||
license: "MIT".to_string(),
|
||||
provider: "OpenAI".to_string(),
|
||||
version: "3.0".to_string(),
|
||||
is_public: true,
|
||||
owner: None,
|
||||
},
|
||||
|
||||
// ===== Multi-Modal Models =====
|
||||
ModelInfo {
|
||||
id: ModelId::from_alias("llava-1.5-13b"),
|
||||
name: "LLaVA 1.5 13B".to_string(),
|
||||
description: "Large Language and Vision Assistant".to_string(),
|
||||
category: ModelCategory::MultiModal,
|
||||
cid: "QmLLaVA15_13B_placeholder".to_string(),
|
||||
format: ModelFormat::SafeTensors,
|
||||
size_bytes: 26_000_000_000,
|
||||
parameters: 13_000_000_000,
|
||||
supported_precisions: vec![Precision::Fp16, Precision::Int8],
|
||||
recommended_processor: ProcessorType::Gpu(crate::processor::GpuVariant::default()),
|
||||
context_length: Some(4096),
|
||||
input_schema: None,
|
||||
output_schema: None,
|
||||
license: "Llama 2".to_string(),
|
||||
provider: "LLaVA".to_string(),
|
||||
version: "1.5".to_string(),
|
||||
is_public: true,
|
||||
owner: None,
|
||||
},
|
||||
];
|
||||
|
||||
let mut models = self.models.write();
|
||||
let mut aliases = self.aliases.write();
|
||||
|
||||
for model in default_models {
|
||||
let alias = model.id.0.clone();
|
||||
let cid = model.cid.clone();
|
||||
|
||||
models.insert(alias.clone(), model.clone());
|
||||
models.insert(cid.clone(), model);
|
||||
aliases.insert(alias, cid);
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolves a model ID (alias or CID) to model info.
|
||||
pub fn resolve(&self, id: &str) -> Result<ModelInfo, ComputeError> {
|
||||
let models = self.models.read();
|
||||
|
||||
// Try direct lookup
|
||||
if let Some(model) = models.get(id) {
|
||||
return Ok(model.clone());
|
||||
}
|
||||
|
||||
// Try alias lookup
|
||||
let aliases = self.aliases.read();
|
||||
if let Some(cid) = aliases.get(id) {
|
||||
if let Some(model) = models.get(cid) {
|
||||
return Ok(model.clone());
|
||||
}
|
||||
}
|
||||
|
||||
Err(ComputeError::ModelNotFound(id.to_string()))
|
||||
}
|
||||
|
||||
/// Registers a custom model.
|
||||
pub fn register(&self, model: ModelInfo) -> Result<(), ComputeError> {
|
||||
let mut models = self.models.write();
|
||||
let mut aliases = self.aliases.write();
|
||||
|
||||
let id = model.id.0.clone();
|
||||
let cid = model.cid.clone();
|
||||
|
||||
// Register by both ID and CID
|
||||
models.insert(id.clone(), model.clone());
|
||||
models.insert(cid.clone(), model);
|
||||
|
||||
if !id.starts_with("Qm") && !id.starts_with("bafy") {
|
||||
aliases.insert(id, cid);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Lists all available models.
|
||||
pub fn list(&self) -> Vec<ModelInfo> {
|
||||
let models = self.models.read();
|
||||
let aliases = self.aliases.read();
|
||||
|
||||
// Return only alias entries to avoid duplicates
|
||||
aliases
|
||||
.keys()
|
||||
.filter_map(|alias| models.get(alias).cloned())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Lists models by category.
|
||||
pub fn list_by_category(&self, category: ModelCategory) -> Vec<ModelInfo> {
|
||||
self.list()
|
||||
.into_iter()
|
||||
.filter(|m| m.category == category)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Searches models by name/description.
|
||||
pub fn search(&self, query: &str) -> Vec<ModelInfo> {
|
||||
let query_lower = query.to_lowercase();
|
||||
self.list()
|
||||
.into_iter()
|
||||
.filter(|m| {
|
||||
m.name.to_lowercase().contains(&query_lower)
|
||||
|| m.description.to_lowercase().contains(&query_lower)
|
||||
|| m.id.0.to_lowercase().contains(&query_lower)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Gets model by alias.
|
||||
pub fn get_by_alias(&self, alias: &str) -> Option<ModelInfo> {
|
||||
self.resolve(alias).ok()
|
||||
}
|
||||
|
||||
/// Checks if a model exists.
|
||||
pub fn exists(&self, id: &str) -> bool {
|
||||
self.resolve(id).is_ok()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ModelRegistry {
    /// Equivalent to [`ModelRegistry::new`]: the default registry already
    /// contains the built-in model catalog, not an empty index.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
/// Model upload request.
///
/// Metadata plus raw bytes for a user-supplied model. The payload itself is
/// excluded from serde (see `data`) — presumably it travels out-of-band,
/// e.g. as a multipart body; confirm against the upload endpoint.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ModelUploadRequest {
    /// Model name.
    pub name: String,
    /// Model description.
    pub description: Option<String>,
    /// Model category.
    pub category: ModelCategory,
    /// Model format.
    pub format: ModelFormat,
    /// Model file data (bytes). Skipped by serde: deserialization yields an
    /// empty vec, and the bytes are never embedded in the JSON metadata.
    #[serde(skip)]
    pub data: Vec<u8>,
    /// Optional alias (must be unique).
    pub alias: Option<String>,
    /// Is public.
    pub is_public: bool,
    /// License.
    pub license: Option<String>,
}
|
||||
|
||||
/// Model upload response.
///
/// Returned after a successful upload; the CID is the permanent storage
/// handle for the stored artifact.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ModelUploadResponse {
    /// Assigned model ID (alias if one was requested, otherwise the CID).
    pub model_id: ModelId,
    /// Storage CID.
    pub cid: String,
    /// Size in bytes of the stored model.
    pub size_bytes: u64,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Default catalog entries resolve by alias with the expected metadata.
    #[test]
    fn test_model_registry() {
        let registry = ModelRegistry::new();

        // Test default models exist
        let llama = registry.resolve("llama-3-70b").unwrap();
        assert_eq!(llama.parameters, 70_000_000_000);
        assert_eq!(llama.category, ModelCategory::Llm);
    }

    /// A freshly constructed registry is never empty.
    #[test]
    fn test_model_list() {
        let registry = ModelRegistry::new();
        let models = registry.list();
        assert!(!models.is_empty());
    }

    /// Search matches the Llama family; all hits carry "llama" in the name.
    // NOTE(review): search also matches description/id, so this second
    // assertion over-constrains relative to the search contract — it holds
    // for the current catalog only.
    #[test]
    fn test_model_search() {
        let registry = ModelRegistry::new();
        let results = registry.search("llama");
        assert!(!results.is_empty());
        assert!(results.iter().all(|m| m.name.to_lowercase().contains("llama")));
    }

    /// Category filtering returns a non-empty, homogeneous list.
    #[test]
    fn test_model_by_category() {
        let registry = ModelRegistry::new();
        let llms = registry.list_by_category(ModelCategory::Llm);
        assert!(!llms.is_empty());
        assert!(llms.iter().all(|m| m.category == ModelCategory::Llm));
    }

    /// A registered custom model resolves by its bare CID.
    #[test]
    fn test_custom_model() {
        let registry = ModelRegistry::new();

        let custom = ModelInfo::custom(
            "QmCustomModel123",
            "My Custom Model",
            ModelCategory::Custom,
            ModelFormat::Onnx,
            1_000_000,
            [1u8; 32],
        );

        registry.register(custom).unwrap();

        let resolved = registry.resolve("QmCustomModel123").unwrap();
        assert_eq!(resolved.name, "My Custom Model");
    }
}
|
||||
|
|
@ -32,6 +32,15 @@ void main() async {
|
|||
|
||||
// Example 5: Pricing and usage
|
||||
await pricingExample(client);
|
||||
|
||||
// Example 6: List available models
|
||||
await modelRegistryExample(client);
|
||||
|
||||
// Example 7: Training a model
|
||||
await trainingExample(client);
|
||||
|
||||
// Example 8: Custom model upload
|
||||
await customModelExample(client);
|
||||
} finally {
|
||||
// Always dispose client to release resources
|
||||
client.dispose();
|
||||
|
|
@ -176,3 +185,178 @@ Future<void> pricingExample(SynorCompute client) async {
|
|||
print(' Total cost: \$${usage.totalCost.toStringAsFixed(4)}');
|
||||
print('');
|
||||
}
|
||||
|
||||
/// Model registry example - list available models.
///
/// Walks the read-only registry APIs: [SynorCompute.listModels] (with and
/// without a category filter), [SynorCompute.searchModels], and
/// [SynorCompute.getModel], printing a short report for each.
Future<void> modelRegistryExample(SynorCompute client) async {
  print('=== Model Registry ===');

  // List all available models
  final allModels = await client.listModels();
  print('Total available models: ${allModels.length}');

  // List only LLMs (first five, aligned into columns)
  final llms = await client.listModels(category: ModelCategory.llm);
  print('\nAvailable LLMs:');
  for (final model in llms.take(5)) {
    print(' ${model.id.padRight(20)} ${model.formattedParameters.padRight(8)} '
        '${model.name}');
  }

  // Search for a specific model by keyword
  final searchResults = await client.searchModels('llama');
  print('\nSearch "llama": ${searchResults.length} results');

  // Get specific model info by alias
  final modelInfo = await client.getModel('llama-3-70b');
  print('\nModel details for ${modelInfo.name}:');
  print(' Parameters: ${modelInfo.formattedParameters}');
  print(' Context length: ${modelInfo.contextLength}');
  print(' Format: ${modelInfo.format.value}');
  print(' Recommended processor: ${modelInfo.recommendedProcessor.value}');
  print(' License: ${modelInfo.license}');

  // List embedding models
  final embeddings = await client.listModels(category: ModelCategory.embedding);
  print('\nAvailable embedding models:');
  for (final model in embeddings) {
    print(' ${model.id} - ${model.name}');
  }

  // List image generation models
  final imageGen =
      await client.listModels(category: ModelCategory.imageGeneration);
  print('\nAvailable image generation models:');
  for (final model in imageGen) {
    print(' ${model.id} - ${model.name}');
  }

  print('');
}
|
||||
|
||||
/// Training example - train/fine-tune a model.
///
/// Demonstrates [SynorCompute.fineTune] with full [TrainingOptions] (epochs,
/// batch size, learning rate, optimizer, extra hyperparameters), then reuses
/// the trained model CID for inference, and finally shows streaming progress
/// via [SynorCompute.trainStream].
Future<void> trainingExample(SynorCompute client) async {
  print('=== Model Training ===');

  // Example: Fine-tune Llama 3 8B on custom dataset
  print('Fine-tuning llama-3-8b on custom dataset...');

  // Note: In practice, you'd upload your dataset first:
  // final datasetCid = await client.uploadTensor(datasetTensor);

  final result = await client.fineTune(
    baseModel: 'llama-3-8b', // Use model alias
    datasetCid: 'QmYourDatasetCID', // Your uploaded dataset
    outputAlias: 'my-custom-llama', // Optional: alias for trained model
    options: TrainingOptions(
      framework: MlFramework.pytorch,
      epochs: 3,
      batchSize: 8,
      learningRate: 0.00002,
      optimizer: 'adamw',
      hyperparameters: {
        'weight_decay': 0.01,
        'warmup_steps': 100,
        'gradient_accumulation_steps': 4,
      },
      checkpointEvery: 500, // Save checkpoint every 500 steps
      processor: ProcessorType.gpu,
      priority: Priority.high,
    ),
  );

  if (result.isSuccess) {
    final training = result.result!;
    print('Training completed!');
    print(' New model CID: ${training.modelCid}');
    print(' Final loss: ${training.finalLoss.toStringAsFixed(4)}');
    print(' Duration: ${training.durationMs / 1000}s');
    print(' Cost: \$${training.cost.toStringAsFixed(4)}');
    print(' Metrics: ${training.metrics}');

    // Now use your trained model for inference
    print('\nUsing trained model for inference:');
    final inference = await client.inference(
      training.modelCid, // Use the CID of your trained model
      'Hello, how are you?',
      options: InferenceOptions(maxTokens: 50),
    );
    print('Response: ${inference.result}');
  } else {
    print('Training failed: ${result.error}');
  }

  print('');

  // Example: Streaming training progress (one update per event;
  // '\r' overwrites the previous console line in place)
  print('Training with streaming progress...');
  await for (final progress in client.trainStream(
    modelCid: 'llama-3-8b',
    datasetCid: 'QmYourDatasetCID',
    options: TrainingOptions(epochs: 1, batchSize: 16),
  )) {
    // Update UI with progress
    stdout.write('\r${progress.progressText} - '
        '${progress.samplesPerSecond} samples/s');
  }
  print('\nTraining complete!');

  print('');
}
|
||||
|
||||
/// Custom model upload example.
///
/// Prints a walkthrough (as console text, not executed code) of exporting a
/// Python-trained model to ONNX and uploading it via
/// [SynorCompute.uploadModel], then enumerates the supported formats and
/// categories.
Future<void> customModelExample(SynorCompute client) async {
  print('=== Custom Model Upload ===');

  // Example: Upload a custom ONNX model
  // In practice, you'd read this from a file:
  // final modelBytes = await File('my_model.onnx').readAsBytes();

  // For demonstration, we'll show the API structure
  print('To upload your own Python-trained model:');
  print('''
1. Train your model in Python:

import torch
model = MyModel()
# ... train model ...
torch.onnx.export(model, dummy_input, "my_model.onnx")

2. Upload to Synor Compute:

final modelBytes = await File('my_model.onnx').readAsBytes();
final result = await client.uploadModel(
modelBytes,
ModelUploadOptions(
name: 'my-custom-model',
description: 'My custom trained model',
category: ModelCategory.custom,
format: ModelFormat.onnx,
alias: 'my-model', // Optional shortcut name
isPublic: false, // Keep private
license: 'Proprietary',
),
);
print('Uploaded! CID: \${result.cid}');

3. Use for inference:

final result = await client.inference(
result.cid, // or 'my-model' if you set an alias
'Your input data',
);
''');

  // Supported model formats
  print('Supported model formats:');
  for (final format in ModelFormat.values) {
    print(' - ${format.value}');
  }

  // Supported categories
  print('\nSupported model categories:');
  for (final category in ModelCategory.values) {
    print(' - ${category.value}');
  }

  print('');
}
|
||||
|
|
|
|||
|
|
@ -426,6 +426,192 @@ class SynorCompute {
|
|||
}
|
||||
}
|
||||
|
||||
// ==================== Model Registry ====================
|
||||
|
||||
/// List all available models.
|
||||
Future<List<ModelInfo>> listModels({ModelCategory? category}) async {
|
||||
_checkDisposed();
|
||||
|
||||
final params = <String, String>{
|
||||
if (category != null) 'category': category.value,
|
||||
};
|
||||
|
||||
final response = await _get('/models', params);
|
||||
final models = response['models'] as List;
|
||||
return models
|
||||
.map((m) => ModelInfo.fromJson(m as Map<String, dynamic>))
|
||||
.toList();
|
||||
}
|
||||
|
||||
/// Get model info by ID or alias.
|
||||
Future<ModelInfo> getModel(String modelId) async {
|
||||
_checkDisposed();
|
||||
|
||||
final response = await _get('/models/$modelId');
|
||||
return ModelInfo.fromJson(response);
|
||||
}
|
||||
|
||||
/// Search models by query.
|
||||
Future<List<ModelInfo>> searchModels(String query) async {
|
||||
_checkDisposed();
|
||||
|
||||
final response = await _get('/models/search', {'q': query});
|
||||
final models = response['models'] as List;
|
||||
return models
|
||||
.map((m) => ModelInfo.fromJson(m as Map<String, dynamic>))
|
||||
.toList();
|
||||
}
|
||||
|
||||
  /// Upload a custom model.
  ///
  /// Sends [modelData] as a multipart POST to `/models/upload`, with the
  /// [options] metadata flattened into form fields (each value stringified)
  /// and the bytes attached as the `model` file part. The filename is
  /// derived from the model name and format.
  ///
  /// Throws [SynorException] on any non-200 response.
  // NOTE(review): only 200 is accepted — if the server returns 201 Created
  // for uploads this will throw; confirm against the endpoint.
  Future<ModelUploadResult> uploadModel(
    List<int> modelData,
    ModelUploadOptions options,
  ) async {
    _checkDisposed();

    // For large files, use multipart upload
    final uri = Uri.parse('${_config.baseUrl}/models/upload');
    final request = http.MultipartRequest('POST', uri)
      ..headers.addAll(_headers)
      ..fields.addAll(options.toJson().map((k, v) => MapEntry(k, v.toString())))
      ..files.add(http.MultipartFile.fromBytes(
        'model',
        modelData,
        filename: '${options.name}.${options.format.value}',
      ));

    final streamedResponse = await _httpClient.send(request);
    final response = await http.Response.fromStream(streamedResponse);

    if (response.statusCode != 200) {
      throw SynorException(
        'Model upload failed',
        statusCode: response.statusCode,
      );
    }

    final json = jsonDecode(response.body) as Map<String, dynamic>;
    return ModelUploadResult.fromJson(json);
  }
|
||||
|
||||
  /// Delete a custom model (only owner can delete).
  ///
  /// Issues `DELETE /models/{modelId}`; ownership is enforced server-side.
  /// Throws if the client has been disposed.
  Future<void> deleteModel(String modelId) async {
    _checkDisposed();

    await _delete('/models/$modelId');
  }
|
||||
|
||||
// ==================== Training ====================
|
||||
|
||||
/// Train a model on a dataset.
///
/// Submits a training job built from [modelCid] (base model),
/// [datasetCid] (training data), and [options], then waits for the job
/// to finish and returns its [TrainingResult].
///
/// Example:
/// ```dart
/// final result = await client.train(
///   modelCid: 'QmBaseModelCID',   // Base model to fine-tune
///   datasetCid: 'QmDatasetCID',   // Training dataset CID
///   options: TrainingOptions(
///     framework: MlFramework.pytorch,
///     epochs: 10,
///     batchSize: 32,
///     learningRate: 0.0001,
///   ),
/// );
/// print('Trained model CID: ${result.modelCid}');
/// ```
Future<JobResult<TrainingResult>> train({
  required String modelCid,
  required String datasetCid,
  TrainingOptions? options,
}) async {
  _checkDisposed();

  // Fall back to default hyperparameters when none were supplied.
  final effectiveOptions = options ?? const TrainingOptions();
  final payload = <String, dynamic>{
    'operation': 'training',
    'model_cid': modelCid,
    'dataset_cid': datasetCid,
    'options': effectiveOptions.toJson(),
  };

  return _submitAndWait<TrainingResult>(
    payload,
    (raw) => TrainingResult.fromJson(raw as Map<String, dynamic>),
  );
}
|
||||
|
||||
/// Stream training progress updates.
///
/// Submits the training job to `/train/stream` and yields one
/// [TrainingProgress] per server-sent `data:` event. The stream
/// completes when the server emits the `[DONE]` sentinel or closes the
/// connection.
///
/// Throws a [SynorException] if the server rejects the request.
Stream<TrainingProgress> trainStream({
  required String modelCid,
  required String datasetCid,
  TrainingOptions? options,
}) async* {
  _checkDisposed();

  final opts = options ?? const TrainingOptions();
  final body = {
    'operation': 'training',
    'model_cid': modelCid,
    'dataset_cid': datasetCid,
    'options': {
      ...opts.toJson(),
      'stream': true,
    },
  };

  final request = http.Request('POST', Uri.parse('${_config.baseUrl}/train/stream'))
    ..headers.addAll(_headers)
    ..body = jsonEncode(body);

  final streamedResponse = await _httpClient.send(request);

  if (streamedResponse.statusCode != 200) {
    throw SynorException(
      'Training stream failed',
      statusCode: streamedResponse.statusCode,
    );
  }

  // Split on complete lines, not raw network chunks: a single SSE
  // `data:` line can straddle a chunk boundary, and splitting each
  // chunk independently would drop or corrupt that event.
  // LineSplitter buffers partial lines across chunks.
  final lines = streamedResponse.stream
      .transform(utf8.decoder)
      .transform(const LineSplitter());

  await for (final line in lines) {
    if (line.startsWith('data: ')) {
      final data = line.substring(6);
      if (data == '[DONE]') return;
      try {
        final json = jsonDecode(data) as Map<String, dynamic>;
        yield TrainingProgress.fromJson(json);
      } catch (e) {
        // Skip malformed JSON events rather than aborting the stream.
      }
    }
  }
}
|
||||
|
||||
/// Fine-tune a pre-trained model.
///
/// [baseModel] identifies the registered base model (presumably an id
/// or alias such as 'llama-3-70b' — see the model registry), and
/// [datasetCid] points at the training data. When [outputAlias] is
/// supplied it is forwarded so the resulting model can be addressed by
/// that alias. Waits for completion and returns the job's
/// [TrainingResult].
Future<JobResult<TrainingResult>> fineTune({
  required String baseModel,
  required String datasetCid,
  String? outputAlias,
  TrainingOptions? options,
}) async {
  _checkDisposed();

  final effectiveOptions = options ?? const TrainingOptions();

  // Build the payload field-by-field so key order matches the wire
  // format: output_alias (when present) precedes options.
  final payload = <String, dynamic>{
    'operation': 'fine_tune',
    'base_model': baseModel,
    'dataset_cid': datasetCid,
  };
  if (outputAlias != null) {
    payload['output_alias'] = outputAlias;
  }
  payload['options'] = effectiveOptions.toJson();

  return _submitAndWait<TrainingResult>(
    payload,
    (raw) => TrainingResult.fromJson(raw as Map<String, dynamic>),
  );
}
|
||||
|
||||
// Internal HTTP methods
|
||||
|
||||
Future<JobResult<T>> _submitAndWait<T>(
|
||||
|
|
|
|||
|
|
@ -345,6 +345,312 @@ class UsageStats {
|
|||
}
|
||||
}
|
||||
|
||||
/// Category of a model in the registry (LLM, embedding, vision, …).
enum ModelCategory {
  llm('llm'),
  embedding('embedding'),
  imageClassification('image_classification'),
  objectDetection('object_detection'),
  segmentation('segmentation'),
  imageGeneration('image_generation'),
  speechToText('speech_to_text'),
  textToSpeech('text_to_speech'),
  videoGeneration('video_generation'),
  multiModal('multi_modal'),
  custom('custom');

  const ModelCategory(this.value);

  /// Wire-format string for this category.
  final String value;

  /// Parses a wire-format string; unrecognized values fall back to
  /// [custom].
  static ModelCategory fromString(String s) {
    for (final category in ModelCategory.values) {
      if (category.value == s) {
        return category;
      }
    }
    return custom;
  }
}
|
||||
|
||||
/// On-disk serialization format of a model artifact.
enum ModelFormat {
  onnx('onnx'),
  pytorch('pytorch'),
  torchScript('torchscript'),
  tensorflow('tensorflow'),
  tfLite('tflite'),
  safeTensors('safetensors'),
  gguf('gguf'),
  ggml('ggml'),
  custom('custom');

  const ModelFormat(this.value);

  /// Wire-format string for this format.
  final String value;

  /// Parses a wire-format string; unrecognized values fall back to
  /// [custom].
  static ModelFormat fromString(String s) {
    for (final format in ModelFormat.values) {
      if (format.value == s) {
        return format;
      }
    }
    return custom;
  }
}
|
||||
|
||||
/// ML framework used to execute a training job.
enum MlFramework {
  pytorch('pytorch'),
  tensorflow('tensorflow'),
  jax('jax'),
  onnx('onnx');

  const MlFramework(this.value);

  /// Wire-format string for this framework.
  final String value;

  /// Parses a wire-format string; unrecognized values fall back to
  /// [pytorch] (the default framework).
  static MlFramework fromString(String s) {
    for (final framework in MlFramework.values) {
      if (framework.value == s) {
        return framework;
      }
    }
    return pytorch;
  }
}
|
||||
|
||||
/// Model information.
///
/// Immutable registry record describing one model: identity, storage
/// location (content id), format, size, and serving recommendations.
/// Deserialized from the server's snake_case JSON via [ModelInfo.fromJson].
class ModelInfo {
  // Registry identifier (distinct from [cid], which addresses the artifact).
  final String id;
  // Human-readable model name.
  final String name;
  // Free-text description.
  final String description;
  // Category (LLM, embedding, image generation, …).
  final ModelCategory category;
  // Content identifier of the stored model artifact.
  final String cid;
  // Serialization format of the artifact.
  final ModelFormat format;
  // Artifact size in bytes.
  final int sizeBytes;
  // Total parameter count (see [formattedParameters] for display).
  final int parameters;
  // Precisions the model can run at — element semantics defined by the
  // project's Precision enum (declared elsewhere).
  final List<Precision> supportedPrecisions;
  // Processor type recommended for serving this model.
  final ProcessorType recommendedProcessor;
  // Context window length; null for non-sequence models.
  final int? contextLength;
  // License identifier string.
  final String license;
  // Publishing organization/provider.
  final String provider;
  // Model version string.
  final String version;
  // Whether the model is publicly visible; absent in JSON defaults to true.
  final bool isPublic;

  const ModelInfo({
    required this.id,
    required this.name,
    required this.description,
    required this.category,
    required this.cid,
    required this.format,
    required this.sizeBytes,
    required this.parameters,
    required this.supportedPrecisions,
    required this.recommendedProcessor,
    this.contextLength,
    required this.license,
    required this.provider,
    required this.version,
    required this.isPublic,
  });

  /// Builds a [ModelInfo] from the server's snake_case JSON map.
  ///
  /// All fields except `context_length` and `is_public` are required;
  /// missing required keys throw a cast error.
  factory ModelInfo.fromJson(Map<String, dynamic> json) => ModelInfo(
        id: json['id'] as String,
        name: json['name'] as String,
        description: json['description'] as String,
        category: ModelCategory.fromString(json['category'] as String),
        cid: json['cid'] as String,
        format: ModelFormat.fromString(json['format'] as String),
        sizeBytes: json['size_bytes'] as int,
        parameters: json['parameters'] as int,
        supportedPrecisions: (json['supported_precisions'] as List)
            .map((p) => Precision.fromString(p as String))
            .toList(),
        recommendedProcessor:
            ProcessorType.fromString(json['recommended_processor'] as String),
        contextLength: json['context_length'] as int?,
        license: json['license'] as String,
        provider: json['provider'] as String,
        version: json['version'] as String,
        // Visibility defaults to public when the server omits the flag.
        isPublic: json['is_public'] as bool? ?? true,
      );

  /// Format parameters for display (e.g., "70B", "7B").
  ///
  /// Uses decimal thresholds (1e6/1e9/1e12) with one fractional digit;
  /// counts below one million are rendered verbatim.
  String get formattedParameters {
    if (parameters >= 1000000000000) {
      return '${(parameters / 1e12).toStringAsFixed(1)}T';
    } else if (parameters >= 1000000000) {
      return '${(parameters / 1e9).toStringAsFixed(1)}B';
    } else if (parameters >= 1000000) {
      return '${(parameters / 1e6).toStringAsFixed(1)}M';
    }
    return '$parameters';
  }
}
|
||||
|
||||
/// Options controlling a training or fine-tuning job.
///
/// All fields have sensible defaults, so `const TrainingOptions()` is a
/// valid minimal configuration.
class TrainingOptions {
  /// Framework used to run the job (defaults to PyTorch).
  final MlFramework framework;

  /// Number of passes over the dataset.
  final int epochs;

  /// Samples per optimization step.
  final int batchSize;

  /// Optimizer learning rate.
  final double learningRate;

  /// Optimizer name; omitted from the payload when null.
  final String? optimizer;

  /// Extra framework-specific hyperparameters; omitted when null.
  final Map<String, dynamic>? hyperparameters;

  /// Whether to train across multiple workers.
  final bool distributed;

  /// Checkpoint interval; omitted when null.
  final int? checkpointEvery;

  /// Requested processor type; omitted when null.
  final ProcessorType? processor;

  /// Scheduling priority; omitted when null.
  final Priority? priority;

  const TrainingOptions({
    this.framework = MlFramework.pytorch,
    this.epochs = 1,
    this.batchSize = 32,
    this.learningRate = 0.001,
    this.optimizer,
    this.hyperparameters,
    this.distributed = false,
    this.checkpointEvery,
    this.processor,
    this.priority,
  });

  /// Serializes to the snake_case wire format; null-valued optional
  /// fields are omitted entirely rather than sent as null.
  Map<String, dynamic> toJson() {
    // Insertion order deliberately matches the wire format.
    final json = <String, dynamic>{
      'framework': framework.value,
      'epochs': epochs,
      'batch_size': batchSize,
      'learning_rate': learningRate,
    };
    if (optimizer != null) {
      json['optimizer'] = optimizer;
    }
    if (hyperparameters != null) {
      json['hyperparameters'] = hyperparameters;
    }
    json['distributed'] = distributed;
    if (checkpointEvery != null) {
      json['checkpoint_every'] = checkpointEvery;
    }
    if (processor != null) {
      json['processor'] = processor!.value;
    }
    if (priority != null) {
      json['priority'] = priority!.value;
    }
    return json;
  }
}
|
||||
|
||||
/// Final result of a completed training or fine-tuning job.
class TrainingResult {
  /// Id of the compute job that produced this result.
  final String jobId;

  /// Content id of the trained model artifact.
  final String modelCid;

  /// Number of epochs actually run.
  final int epochs;

  /// Training loss at the end of the run.
  final double finalLoss;

  /// Named evaluation metrics reported by the job.
  final Map<String, double> metrics;

  /// Wall-clock duration of the job in milliseconds.
  final int durationMs;

  /// Total cost charged for the job.
  final double cost;

  /// Content ids of intermediate checkpoints, when checkpointing was on.
  final List<String>? checkpointCids;

  const TrainingResult({
    required this.jobId,
    required this.modelCid,
    required this.epochs,
    required this.finalLoss,
    required this.metrics,
    required this.durationMs,
    required this.cost,
    this.checkpointCids,
  });

  /// Builds a [TrainingResult] from the server's snake_case JSON map.
  /// Numeric fields arrive as `num` and are widened to `double`.
  factory TrainingResult.fromJson(Map<String, dynamic> json) {
    final rawMetrics = json['metrics'] as Map<String, dynamic>;
    final parsedMetrics = <String, double>{};
    rawMetrics.forEach((key, value) {
      parsedMetrics[key] = (value as num).toDouble();
    });

    return TrainingResult(
      jobId: json['job_id'] as String,
      modelCid: json['model_cid'] as String,
      epochs: json['epochs'] as int,
      finalLoss: (json['final_loss'] as num).toDouble(),
      metrics: parsedMetrics,
      durationMs: json['duration_ms'] as int,
      cost: (json['cost'] as num).toDouble(),
      checkpointCids: (json['checkpoint_cids'] as List?)?.cast<String>(),
    );
  }
}
|
||||
|
||||
/// Metadata supplied when uploading a custom model.
class ModelUploadOptions {
  /// Display name; also used for the uploaded filename.
  final String name;

  /// Optional free-text description; omitted from the payload when null.
  final String? description;

  /// Model category (defaults to [ModelCategory.custom]).
  final ModelCategory category;

  /// Artifact format (defaults to [ModelFormat.onnx]).
  final ModelFormat format;

  /// Optional alias to register the model under; omitted when null.
  final String? alias;

  /// Whether the model should be publicly visible (defaults to false).
  final bool isPublic;

  /// Optional license identifier; omitted when null.
  final String? license;

  const ModelUploadOptions({
    required this.name,
    this.description,
    this.category = ModelCategory.custom,
    this.format = ModelFormat.onnx,
    this.alias,
    this.isPublic = false,
    this.license,
  });

  /// Serializes to the snake_case wire format; null-valued optional
  /// fields are omitted entirely.
  Map<String, dynamic> toJson() {
    // Insertion order deliberately matches the wire format.
    final json = <String, dynamic>{'name': name};
    if (description != null) {
      json['description'] = description;
    }
    json['category'] = category.value;
    json['format'] = format.value;
    if (alias != null) {
      json['alias'] = alias;
    }
    json['is_public'] = isPublic;
    if (license != null) {
      json['license'] = license;
    }
    return json;
  }
}
|
||||
|
||||
/// Server response to a successful model upload.
class ModelUploadResult {
  /// Registry id assigned to the uploaded model.
  final String modelId;

  /// Content id of the stored artifact.
  final String cid;

  /// Stored artifact size in bytes.
  final int sizeBytes;

  const ModelUploadResult({
    required this.modelId,
    required this.cid,
    required this.sizeBytes,
  });

  /// Builds a [ModelUploadResult] from the server's snake_case JSON map.
  factory ModelUploadResult.fromJson(Map<String, dynamic> json) {
    return ModelUploadResult(
      modelId: json['model_id'] as String,
      cid: json['cid'] as String,
      sizeBytes: json['size_bytes'] as int,
    );
  }
}
|
||||
|
||||
/// One progress update emitted during a streaming training job.
class TrainingProgress {
  /// Id of the running training job.
  final String jobId;

  /// Zero-based current epoch (displayed one-based by [progressText]).
  final int epoch;

  /// Total number of epochs in the run.
  final int totalEpochs;

  /// Current step within the run.
  final int step;

  /// Total number of steps in the run.
  final int totalSteps;

  /// Current training loss.
  final double loss;

  /// Additional named metrics; empty when the server sent none.
  final Map<String, double> metrics;

  /// Current learning rate, when reported.
  final double? learningRate;

  /// Current gradient norm, when reported.
  final double? gradientNorm;

  /// Throughput; 0 when the server omitted the field.
  final int samplesPerSecond;

  /// Estimated time remaining in milliseconds, when reported.
  final int? estimatedRemainingMs;

  const TrainingProgress({
    required this.jobId,
    required this.epoch,
    required this.totalEpochs,
    required this.step,
    required this.totalSteps,
    required this.loss,
    required this.metrics,
    this.learningRate,
    this.gradientNorm,
    required this.samplesPerSecond,
    this.estimatedRemainingMs,
  });

  /// Builds a [TrainingProgress] from the server's snake_case JSON map.
  /// Numeric fields arrive as `num` and are widened to `double`;
  /// `metrics` and `samples_per_second` tolerate absence.
  factory TrainingProgress.fromJson(Map<String, dynamic> json) {
    final rawMetrics = json['metrics'] as Map<String, dynamic>?;
    final parsedMetrics = <String, double>{};
    if (rawMetrics != null) {
      rawMetrics.forEach((key, value) {
        parsedMetrics[key] = (value as num).toDouble();
      });
    }

    return TrainingProgress(
      jobId: json['job_id'] as String,
      epoch: json['epoch'] as int,
      totalEpochs: json['total_epochs'] as int,
      step: json['step'] as int,
      totalSteps: json['total_steps'] as int,
      loss: (json['loss'] as num).toDouble(),
      metrics: parsedMetrics,
      learningRate: (json['learning_rate'] as num?)?.toDouble(),
      gradientNorm: (json['gradient_norm'] as num?)?.toDouble(),
      samplesPerSecond: json['samples_per_second'] as int? ?? 0,
      estimatedRemainingMs: json['estimated_remaining_ms'] as int?,
    );
  }

  /// Progress percentage (0.0 - 1.0); 0.0 when [totalSteps] is zero.
  double get progress => totalSteps == 0 ? 0.0 : step / totalSteps;

  /// Formatted progress string (epoch shown one-based).
  String get progressText =>
      'Epoch ${epoch + 1}/$totalEpochs - Step $step/$totalSteps - Loss: ${loss.toStringAsFixed(4)}';
}
|
||||
|
||||
/// Exception thrown by Synor Compute operations
|
||||
class SynorException implements Exception {
|
||||
final String message;
|
||||
|
|
|
|||
|
|
@ -83,7 +83,18 @@ export 'src/types.dart'
|
|||
InferenceOptions,
|
||||
PricingInfo,
|
||||
UsageStats,
|
||||
SynorException;
|
||||
SynorException,
|
||||
// Model types
|
||||
ModelCategory,
|
||||
ModelFormat,
|
||||
MlFramework,
|
||||
ModelInfo,
|
||||
ModelUploadOptions,
|
||||
ModelUploadResult,
|
||||
// Training types
|
||||
TrainingOptions,
|
||||
TrainingResult,
|
||||
TrainingProgress;
|
||||
|
||||
export 'src/tensor.dart' show Tensor;
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue