Add comprehensive dataset management to the Flutter SDK including: - Dataset formats: JSONL, CSV, Parquet, Arrow, HuggingFace, TFRecord, WebDataset, Text, ImageFolder, Custom - Dataset types: text completion, instruction tuning, chat, Q&A, classification, NER, vision, audio - Upload methods: uploadDataset, uploadDatasetFromFile, createDatasetFromRecords - Management APIs: listDatasets, getDataset, deleteDataset - Dataset preprocessing: splitting, shuffling, deduplication, tokenization - Complete examples showing all formats and use cases
885 lines
24 KiB
Dart
/// Main client for Synor Compute SDK
|
|
library synor_compute.client;
|
|
|
|
import 'dart:async';
|
|
import 'dart:convert';
|
|
|
|
import 'package:http/http.dart' as http;
|
|
|
|
import 'job.dart';
|
|
import 'tensor.dart';
|
|
import 'types.dart';
|
|
|
|
/// Main client for interacting with Synor Compute
|
|
class SynorCompute {
  final SynorConfig _config;
  final http.Client _httpClient;

  // Set once by [dispose]; guards every public entry point.
  bool _isDisposed = false;

  /// Creates a new Synor Compute client.
  ///
  /// [apiKey] is required; all other settings fall back to sensible defaults.
  /// Pass [httpClient] to inject a custom client (e.g. for testing).
  SynorCompute({
    required String apiKey,
    String baseUrl = 'https://compute.synor.io',
    Duration timeout = const Duration(seconds: 30),
    int maxRetries = 3,
    ProcessorType defaultProcessor = ProcessorType.auto,
    Precision defaultPrecision = Precision.fp32,
    Priority defaultPriority = Priority.normal,
    http.Client? httpClient,
  })  : _config = SynorConfig(
          apiKey: apiKey,
          baseUrl: baseUrl,
          timeout: timeout,
          maxRetries: maxRetries,
          defaultProcessor: defaultProcessor,
          defaultPrecision: defaultPrecision,
          defaultPriority: defaultPriority,
        ),
        _httpClient = httpClient ?? http.Client();

  /// Creates a client from an existing [SynorConfig].
  SynorCompute.fromConfig(SynorConfig config, {http.Client? httpClient})
      : _config = config,
        _httpClient = httpClient ?? http.Client();

  /// Standard headers for JSON requests.
  Map<String, String> get _headers => {
        'Authorization': 'Bearer ${_config.apiKey}',
        'Content-Type': 'application/json',
        'Accept': 'application/json',
        'X-SDK-Version': 'flutter/0.1.0',
      };

  /// Headers for multipart uploads.
  ///
  /// The JSON `Content-Type` is removed so that the
  /// `multipart/form-data; boundary=...` content type written by
  /// [http.MultipartRequest] on finalize is the one that takes effect.
  Map<String, String> get _multipartHeaders =>
      Map<String, String>.from(_headers)..remove('Content-Type');

  /// Throws [StateError] if [dispose] has already been called.
  void _checkDisposed() {
    if (_isDisposed) {
      throw StateError('Client has been disposed');
    }
  }

  /// Perform matrix multiplication of [a] and [b].
  Future<JobResult<Tensor>> matmul(
    Tensor a,
    Tensor b, {
    MatMulOptions? options,
  }) async {
    _checkDisposed();

    final opts = options ?? const MatMulOptions();
    final body = {
      'operation': 'matmul',
      'inputs': {
        'a': a.toJson(),
        'b': b.toJson(),
      },
      'options': {
        'precision': (opts.precision ?? _config.defaultPrecision).value,
        'processor': (opts.processor ?? _config.defaultProcessor).value,
        'priority': (opts.priority ?? _config.defaultPriority).value,
        'transpose_a': opts.transposeA,
        'transpose_b': opts.transposeB,
      },
    };

    return _submitAndWait<Tensor>(
      body,
      (result) => Tensor.fromJson(result as Map<String, dynamic>),
    );
  }

  /// Perform 2D convolution of [input] with [kernel].
  Future<JobResult<Tensor>> conv2d(
    Tensor input,
    Tensor kernel, {
    Conv2dOptions? options,
  }) async {
    _checkDisposed();

    final opts = options ?? const Conv2dOptions();
    final body = {
      'operation': 'conv2d',
      'inputs': {
        'input': input.toJson(),
        'kernel': kernel.toJson(),
      },
      'options': {
        ...opts.toJson(),
        // Explicit values override whatever opts.toJson() may carry.
        'precision': (opts.precision ?? _config.defaultPrecision).value,
        'processor': (opts.processor ?? _config.defaultProcessor).value,
        'priority': (opts.priority ?? _config.defaultPriority).value,
      },
    };

    return _submitAndWait<Tensor>(
      body,
      (result) => Tensor.fromJson(result as Map<String, dynamic>),
    );
  }

  /// Perform flash attention over [query], [key] and [value].
  Future<JobResult<Tensor>> attention(
    Tensor query,
    Tensor key,
    Tensor value, {
    required AttentionOptions options,
  }) async {
    _checkDisposed();

    final body = {
      'operation': 'flash_attention',
      'inputs': {
        'query': query.toJson(),
        'key': key.toJson(),
        'value': value.toJson(),
      },
      'options': {
        ...options.toJson(),
        'precision': (options.precision ?? _config.defaultPrecision).value,
        'processor': (options.processor ?? _config.defaultProcessor).value,
        'priority': (options.priority ?? _config.defaultPriority).value,
      },
    };

    return _submitAndWait<Tensor>(
      body,
      (result) => Tensor.fromJson(result as Map<String, dynamic>),
    );
  }

  /// Run LLM inference on [model] with prompt [input] and return the
  /// completed text once the job finishes.
  Future<JobResult<String>> inference(
    String model,
    String input, {
    InferenceOptions? options,
  }) async {
    _checkDisposed();

    final opts = options ?? const InferenceOptions();
    final body = {
      'operation': 'inference',
      'model': model,
      'input': input,
      'options': {
        ...opts.toJson(),
        'processor': (opts.processor ?? _config.defaultProcessor).value,
        'priority': (opts.priority ?? _config.defaultPriority).value,
      },
    };

    return _submitAndWait<String>(
      body,
      (result) => result['text'] as String,
    );
  }

  /// Stream LLM inference with real-time token output.
  ///
  /// Yields each generated token as it arrives over the SSE connection.
  Stream<String> inferenceStream(
    String model,
    String input, {
    InferenceOptions? options,
  }) async* {
    _checkDisposed();

    final opts = options ?? const InferenceOptions();
    final body = {
      'operation': 'inference',
      'model': model,
      'input': input,
      'options': {
        ...opts.toJson(),
        'stream': true,
        'processor': (opts.processor ?? _config.defaultProcessor).value,
        'priority': (opts.priority ?? _config.defaultPriority).value,
      },
    };

    await for (final data
        in _sseDataEvents('/stream', body, 'Streaming request failed')) {
      try {
        final json = jsonDecode(data) as Map<String, dynamic>;
        if (json['token'] != null) {
          yield json['token'] as String;
        }
      } catch (_) {
        // Skip malformed event payloads.
      }
    }
  }

  /// Apply an element-wise [operation] (e.g. add, mul, relu) to [input],
  /// optionally combined with [other] or a [scalar].
  Future<JobResult<Tensor>> elementwise(
    String operation,
    Tensor input, {
    Tensor? other,
    double? scalar,
    Precision? precision,
    ProcessorType? processor,
    Priority? priority,
  }) async {
    _checkDisposed();

    final body = {
      'operation': 'elementwise',
      'op': operation,
      'inputs': {
        'input': input.toJson(),
        if (other != null) 'other': other.toJson(),
        if (scalar != null) 'scalar': scalar,
      },
      'options': {
        'precision': (precision ?? _config.defaultPrecision).value,
        'processor': (processor ?? _config.defaultProcessor).value,
        'priority': (priority ?? _config.defaultPriority).value,
      },
    };

    return _submitAndWait<Tensor>(
      body,
      (result) => Tensor.fromJson(result as Map<String, dynamic>),
    );
  }

  /// Reduce [input] with [operation] (sum, mean, max, min, etc.) along
  /// [axes], or over all axes when [axes] is null.
  Future<JobResult<Tensor>> reduce(
    String operation,
    Tensor input, {
    List<int>? axes,
    bool keepDims = false,
    Precision? precision,
    ProcessorType? processor,
    Priority? priority,
  }) async {
    _checkDisposed();

    final body = {
      'operation': 'reduce',
      'op': operation,
      'inputs': {
        'input': input.toJson(),
      },
      'options': {
        if (axes != null) 'axes': axes,
        'keep_dims': keepDims,
        'precision': (precision ?? _config.defaultPrecision).value,
        'processor': (processor ?? _config.defaultProcessor).value,
        'priority': (priority ?? _config.defaultPriority).value,
      },
    };

    return _submitAndWait<Tensor>(
      body,
      (result) => Tensor.fromJson(result as Map<String, dynamic>),
    );
  }

  /// Submit a custom [operation] without waiting for completion.
  ///
  /// Returns a [Job] handle the caller can poll, await, or cancel.
  Future<Job<T>> submit<T>(
    Map<String, dynamic> operation,
    T Function(dynamic) resultParser,
  ) async {
    _checkDisposed();

    final response = await _post('/jobs', operation);
    final jobId = response['job_id'] as String;

    return Job<T>(
      jobId: jobId,
      baseUrl: _config.baseUrl,
      headers: _headers,
      resultParser: resultParser,
    );
  }

  /// Get the current state of a job by [jobId].
  Future<JobResult<T>> getJob<T>(
    String jobId, {
    T Function(dynamic)? resultParser,
  }) async {
    _checkDisposed();

    final response = await _get('/jobs/$jobId');
    return JobResult<T>.fromJson(response, resultParser);
  }

  /// Cancel a job.
  ///
  /// Returns `true` when the cancel request succeeded, `false` otherwise
  /// (best-effort: failures are swallowed by design).
  Future<bool> cancelJob(String jobId) async {
    _checkDisposed();

    try {
      await _post('/jobs/$jobId/cancel', {});
      return true;
    } catch (_) {
      return false;
    }
  }

  /// List jobs, optionally filtered by [status], with pagination.
  Future<List<JobResult<dynamic>>> listJobs({
    JobStatus? status,
    int limit = 20,
    int offset = 0,
  }) async {
    _checkDisposed();

    final params = <String, String>{
      'limit': limit.toString(),
      'offset': offset.toString(),
      if (status != null) 'status': status.value,
    };

    final response = await _get('/jobs', params);
    final jobs = response['jobs'] as List;
    return jobs
        .map((j) => JobResult<dynamic>.fromJson(j as Map<String, dynamic>, null))
        .toList();
  }

  /// Get current pricing information for all processors.
  Future<List<PricingInfo>> getPricing() async {
    _checkDisposed();

    final response = await _get('/pricing');
    final pricing = response['pricing'] as List;
    return pricing
        .map((p) => PricingInfo.fromJson(p as Map<String, dynamic>))
        .toList();
  }

  /// Get pricing for a specific [processor].
  ///
  /// Throws [SynorException] if no pricing entry exists for it.
  Future<PricingInfo> getPricingFor(ProcessorType processor) async {
    _checkDisposed();

    final allPricing = await getPricing();
    return allPricing.firstWhere(
      (p) => p.processor == processor,
      orElse: () => throw SynorException(
        'No pricing available for processor ${processor.value}',
      ),
    );
  }

  /// Get account usage statistics, optionally restricted to [from]..[to].
  Future<UsageStats> getUsage({DateTime? from, DateTime? to}) async {
    _checkDisposed();

    final params = <String, String>{
      if (from != null) 'from': from.toIso8601String(),
      if (to != null) 'to': to.toIso8601String(),
    };

    final response = await _get('/usage', params);
    return UsageStats.fromJson(response);
  }

  /// Upload a [tensor] for reuse across jobs; returns its server-side ID.
  Future<String> uploadTensor(Tensor tensor, {String? name}) async {
    _checkDisposed();

    final body = {
      'tensor': tensor.toJson(),
      if (name != null) 'name': name,
    };

    final response = await _post('/tensors', body);
    return response['tensor_id'] as String;
  }

  /// Download a previously uploaded tensor by [tensorId].
  Future<Tensor> downloadTensor(String tensorId) async {
    _checkDisposed();

    final response = await _get('/tensors/$tensorId');
    return Tensor.fromJson(response);
  }

  /// Delete a tensor by [tensorId].
  Future<void> deleteTensor(String tensorId) async {
    _checkDisposed();

    await _delete('/tensors/$tensorId');
  }

  /// Health check; returns `false` on any failure instead of throwing.
  Future<bool> healthCheck() async {
    try {
      final response = await _get('/health');
      return response['status'] == 'healthy';
    } catch (_) {
      return false;
    }
  }

  // ==================== Model Registry ====================

  /// List all available models, optionally filtered by [category].
  Future<List<ModelInfo>> listModels({ModelCategory? category}) async {
    _checkDisposed();

    final params = <String, String>{
      if (category != null) 'category': category.value,
    };

    final response = await _get('/models', params);
    final models = response['models'] as List;
    return models
        .map((m) => ModelInfo.fromJson(m as Map<String, dynamic>))
        .toList();
  }

  /// Get model info by ID or alias.
  Future<ModelInfo> getModel(String modelId) async {
    _checkDisposed();

    final response = await _get('/models/$modelId');
    return ModelInfo.fromJson(response);
  }

  /// Search models by free-text [query].
  Future<List<ModelInfo>> searchModels(String query) async {
    _checkDisposed();

    final response = await _get('/models/search', {'q': query});
    final models = response['models'] as List;
    return models
        .map((m) => ModelInfo.fromJson(m as Map<String, dynamic>))
        .toList();
  }

  /// Upload a custom model as a multipart request.
  Future<ModelUploadResult> uploadModel(
    List<int> modelData,
    ModelUploadOptions options,
  ) async {
    _checkDisposed();

    final request = http.MultipartRequest(
        'POST', Uri.parse('${_config.baseUrl}/models/upload'))
      ..headers.addAll(_multipartHeaders)
      ..fields.addAll(_encodeMultipartFields(options.toJson()))
      ..files.add(http.MultipartFile.fromBytes(
        'model',
        modelData,
        filename: '${options.name}.${options.format.value}',
      ));

    final json = await _sendMultipart(request, 'Model upload failed');
    return ModelUploadResult.fromJson(json);
  }

  /// Delete a custom model (only the owner can delete).
  Future<void> deleteModel(String modelId) async {
    _checkDisposed();

    await _delete('/models/$modelId');
  }

  // ==================== Dataset Management ====================

  /// Upload a dataset for training.
  ///
  /// Supports multiple formats: JSONL, CSV, Parquet, Arrow, HuggingFace, etc.
  ///
  /// Example (JSONL format):
  /// ```dart
  /// final jsonlData = '''
  /// {"prompt": "What is 2+2?", "completion": "4"}
  /// {"prompt": "Capital of France?", "completion": "Paris"}
  /// {"prompt": "Hello", "completion": "Hi there!"}
  /// ''';
  ///
  /// final dataset = await client.uploadDataset(
  ///   utf8.encode(jsonlData),
  ///   DatasetUploadOptions(
  ///     name: 'my-qa-dataset',
  ///     format: DatasetFormat.jsonl,
  ///     type: DatasetType.textCompletion,
  ///     split: DatasetSplit(train: 0.8, validation: 0.1, test: 0.1),
  ///   ),
  /// );
  /// print('Dataset CID: ${dataset.cid}');
  /// print('Total samples: ${dataset.totalSamples}');
  /// ```
  Future<DatasetUploadResult> uploadDataset(
    List<int> data,
    DatasetUploadOptions options,
  ) async {
    _checkDisposed();

    final request = http.MultipartRequest(
        'POST', Uri.parse('${_config.baseUrl}/datasets/upload'))
      ..headers.addAll(_multipartHeaders)
      ..fields.addAll(_encodeMultipartFields(options.toJson()))
      ..files.add(http.MultipartFile.fromBytes(
        'dataset',
        data,
        filename: '${options.name}.${options.format.value}',
      ));

    final json = await _sendMultipart(request, 'Dataset upload failed');
    return DatasetUploadResult.fromJson(json);
  }

  /// Upload a dataset from a file at [filePath].
  Future<DatasetUploadResult> uploadDatasetFromFile(
    String filePath,
    DatasetUploadOptions options,
  ) async {
    _checkDisposed();

    final request = http.MultipartRequest(
        'POST', Uri.parse('${_config.baseUrl}/datasets/upload'))
      ..headers.addAll(_multipartHeaders)
      ..fields.addAll(_encodeMultipartFields(options.toJson()))
      ..files.add(await http.MultipartFile.fromPath('dataset', filePath));

    final json = await _sendMultipart(request, 'Dataset upload failed');
    return DatasetUploadResult.fromJson(json);
  }

  /// List uploaded datasets, optionally filtered by [type].
  Future<List<DatasetInfo>> listDatasets({DatasetType? type}) async {
    _checkDisposed();

    final params = <String, String>{
      if (type != null) 'type': type.value,
    };

    final response = await _get('/datasets', params);
    final datasets = response['datasets'] as List;
    return datasets
        .map((d) => DatasetInfo.fromJson(d as Map<String, dynamic>))
        .toList();
  }

  /// Get dataset info by ID or CID.
  Future<DatasetInfo> getDataset(String datasetId) async {
    _checkDisposed();

    final response = await _get('/datasets/$datasetId');
    return DatasetInfo.fromJson(response);
  }

  /// Delete a dataset.
  Future<void> deleteDataset(String datasetId) async {
    _checkDisposed();

    await _delete('/datasets/$datasetId');
  }

  /// Create a dataset from inline [records] (convenience method).
  ///
  /// Records are serialized as JSONL, one JSON object per line.
  ///
  /// Example (instruction tuning):
  /// ```dart
  /// final dataset = await client.createDatasetFromRecords(
  ///   name: 'instruction-dataset',
  ///   records: [
  ///     {'instruction': 'Summarize:', 'input': 'Long text...', 'output': 'Summary'},
  ///     {'instruction': 'Translate:', 'input': 'Hello', 'output': 'Hola'},
  ///   ],
  ///   type: DatasetType.instructionTuning,
  /// );
  /// ```
  Future<DatasetUploadResult> createDatasetFromRecords({
    required String name,
    required List<Map<String, dynamic>> records,
    DatasetType type = DatasetType.textCompletion,
    DatasetSplit? split,
  }) async {
    _checkDisposed();

    // Convert to JSONL format: one encoded record per line.
    final jsonlLines = records.map(jsonEncode).join('\n');
    final data = utf8.encode(jsonlLines);

    return uploadDataset(
      data,
      DatasetUploadOptions(
        name: name,
        format: DatasetFormat.jsonl,
        type: type,
        split: split,
      ),
    );
  }

  // ==================== Training ====================

  /// Train a model on a dataset.
  ///
  /// Example:
  /// ```dart
  /// final result = await client.train(
  ///   modelCid: 'QmBaseModelCID',   // Base model to fine-tune
  ///   datasetCid: 'QmDatasetCID',   // Training dataset CID
  ///   options: TrainingOptions(
  ///     framework: MlFramework.pytorch,
  ///     epochs: 10,
  ///     batchSize: 32,
  ///     learningRate: 0.0001,
  ///   ),
  /// );
  /// print('Trained model CID: ${result.modelCid}');
  /// ```
  Future<JobResult<TrainingResult>> train({
    required String modelCid,
    required String datasetCid,
    TrainingOptions? options,
  }) async {
    _checkDisposed();

    final opts = options ?? const TrainingOptions();
    final body = {
      'operation': 'training',
      'model_cid': modelCid,
      'dataset_cid': datasetCid,
      'options': opts.toJson(),
    };

    return _submitAndWait<TrainingResult>(
      body,
      (result) => TrainingResult.fromJson(result as Map<String, dynamic>),
    );
  }

  /// Stream training progress updates as they arrive over SSE.
  Stream<TrainingProgress> trainStream({
    required String modelCid,
    required String datasetCid,
    TrainingOptions? options,
  }) async* {
    _checkDisposed();

    final opts = options ?? const TrainingOptions();
    final body = {
      'operation': 'training',
      'model_cid': modelCid,
      'dataset_cid': datasetCid,
      'options': {
        ...opts.toJson(),
        'stream': true,
      },
    };

    await for (final data
        in _sseDataEvents('/train/stream', body, 'Training stream failed')) {
      try {
        yield TrainingProgress.fromJson(
            jsonDecode(data) as Map<String, dynamic>);
      } catch (_) {
        // Skip malformed event payloads.
      }
    }
  }

  /// Fine-tune a pre-trained [baseModel] on [datasetCid], optionally
  /// registering the result under [outputAlias].
  Future<JobResult<TrainingResult>> fineTune({
    required String baseModel,
    required String datasetCid,
    String? outputAlias,
    TrainingOptions? options,
  }) async {
    _checkDisposed();

    final opts = options ?? const TrainingOptions();
    final body = {
      'operation': 'fine_tune',
      'base_model': baseModel,
      'dataset_cid': datasetCid,
      if (outputAlias != null) 'output_alias': outputAlias,
      'options': opts.toJson(),
    };

    return _submitAndWait<TrainingResult>(
      body,
      (result) => TrainingResult.fromJson(result as Map<String, dynamic>),
    );
  }

  // ==================== Internal helpers ====================

  /// Converts an options JSON map into multipart string fields.
  ///
  /// Nested maps/lists are JSON-encoded (Map.toString() would not produce
  /// valid JSON for the server to parse); everything else uses toString().
  Map<String, String> _encodeMultipartFields(Map<String, dynamic> json) =>
      json.map((k, v) =>
          MapEntry(k, v is Map || v is List ? jsonEncode(v) : v.toString()));

  /// Sends a multipart [request] and decodes the JSON response body.
  ///
  /// Throws [SynorException] with [errorMessage] on any non-2xx status.
  Future<Map<String, dynamic>> _sendMultipart(
    http.MultipartRequest request,
    String errorMessage,
  ) async {
    final streamedResponse = await _httpClient.send(request);
    final response = await http.Response.fromStream(streamedResponse);

    if (response.statusCode < 200 || response.statusCode >= 300) {
      throw SynorException(
        errorMessage,
        statusCode: response.statusCode,
      );
    }

    return jsonDecode(response.body) as Map<String, dynamic>;
  }

  /// POSTs [body] to [path] and yields the payload of each SSE `data:` line
  /// until the `[DONE]` sentinel or end of stream.
  ///
  /// Decodes through [LineSplitter] so a line that straddles two network
  /// chunks is reassembled instead of being dropped (splitting each raw
  /// chunk on '\n' loses partial lines at chunk boundaries).
  Stream<String> _sseDataEvents(
    String path,
    Map<String, dynamic> body,
    String errorMessage,
  ) async* {
    final request = http.Request('POST', Uri.parse('${_config.baseUrl}$path'))
      ..headers.addAll(_headers)
      ..body = jsonEncode(body);

    final streamedResponse = await _httpClient.send(request);

    if (streamedResponse.statusCode != 200) {
      final responseBody = await streamedResponse.stream.bytesToString();
      throw SynorException(
        errorMessage,
        statusCode: streamedResponse.statusCode,
        details: {'response': responseBody},
      );
    }

    await for (final line in streamedResponse.stream
        .transform(utf8.decoder)
        .transform(const LineSplitter())) {
      if (!line.startsWith('data: ')) continue;
      final data = line.substring(6);
      if (data == '[DONE]') return;
      yield data;
    }
  }

  /// Submits a job and blocks until it reaches a terminal state.
  Future<JobResult<T>> _submitAndWait<T>(
    Map<String, dynamic> body,
    T Function(dynamic) resultParser,
  ) async {
    final response = await _post('/jobs', body);
    final jobId = response['job_id'] as String;

    // Poll directly; no Job handle is needed for a blocking wait.
    return _pollJob<T>(jobId, resultParser);
  }

  /// Polls a job every [interval] until terminal, failing after [timeout].
  Future<JobResult<T>> _pollJob<T>(
    String jobId,
    T Function(dynamic) resultParser, {
    Duration interval = const Duration(milliseconds: 500),
    Duration timeout = const Duration(minutes: 5),
  }) async {
    final endTime = DateTime.now().add(timeout);

    while (DateTime.now().isBefore(endTime)) {
      final response = await _get('/jobs/$jobId');
      final result = JobResult<T>.fromJson(response, resultParser);

      if (result.status.isTerminal) {
        return result;
      }

      await Future.delayed(interval);
    }

    throw SynorException('Job polling timed out after $timeout');
  }

  /// GET [path] with optional [queryParams]; returns the decoded JSON body.
  Future<Map<String, dynamic>> _get(
    String path, [
    Map<String, String>? queryParams,
  ]) async {
    var uri = Uri.parse('${_config.baseUrl}$path');
    if (queryParams != null && queryParams.isNotEmpty) {
      uri = uri.replace(queryParameters: queryParams);
    }

    final response = await _httpClient
        .get(uri, headers: _headers)
        .timeout(_config.timeout);

    return _handleResponse(response);
  }

  /// POST a JSON [body] to [path]; returns the decoded JSON response.
  Future<Map<String, dynamic>> _post(
    String path,
    Map<String, dynamic> body,
  ) async {
    final uri = Uri.parse('${_config.baseUrl}$path');
    final response = await _httpClient
        .post(uri, headers: _headers, body: jsonEncode(body))
        .timeout(_config.timeout);

    return _handleResponse(response);
  }

  /// DELETE [path]; 200 and 204 are both treated as success.
  Future<void> _delete(String path) async {
    final uri = Uri.parse('${_config.baseUrl}$path');
    final response = await _httpClient
        .delete(uri, headers: _headers)
        .timeout(_config.timeout);

    if (response.statusCode != 200 && response.statusCode != 204) {
      _handleResponse(response);
    }
  }

  /// Decodes a 2xx response body (empty body -> empty map); otherwise
  /// throws [SynorException] built from the server error payload.
  Map<String, dynamic> _handleResponse(http.Response response) {
    if (response.statusCode >= 200 && response.statusCode < 300) {
      if (response.body.isEmpty) {
        return {};
      }
      return jsonDecode(response.body) as Map<String, dynamic>;
    }

    Map<String, dynamic>? errorBody;
    try {
      errorBody = jsonDecode(response.body) as Map<String, dynamic>;
    } catch (_) {
      // Body is not JSON; fall through with null errorBody.
    }

    throw SynorException(
      errorBody?['message'] as String? ?? 'Request failed',
      code: errorBody?['code'] as String?,
      statusCode: response.statusCode,
      details: errorBody,
    );
  }

  /// Dispose the client and release resources. Safe to call more than once.
  void dispose() {
    if (_isDisposed) return;
    _isDisposed = true;
    _httpClient.close();
  }
}
|