"""Compute Job management for Synor Compute SDK.""" from typing import TYPE_CHECKING, Any, Callable, Optional import asyncio from .types import JobStatus, JobMetrics, JobResult if TYPE_CHECKING: from .client import SynorCompute POLL_INTERVAL_SEC = 0.5 MAX_POLL_ATTEMPTS = 120 # 60 seconds max class ComputeJob: """ Represents a compute job on the Synor network. Example: >>> job = await compute.submit_job("matmul", {...}) >>> result = await job.wait() >>> print(result.data) """ def __init__(self, job_id: str, client: "SynorCompute"): self.job_id = job_id self._client = client self._status = JobStatus.PENDING self._result: Optional[Any] = None self._error: Optional[str] = None self._metrics: Optional[JobMetrics] = None self._callbacks: dict[JobStatus, list[Callable[["ComputeJob"], None]]] = {} @property def status(self) -> JobStatus: """Get current job status.""" return self._status def is_complete(self) -> bool: """Check if job is complete.""" return self._status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED) async def wait( self, timeout: Optional[float] = None, poll_interval: float = POLL_INTERVAL_SEC, ) -> JobResult: """ Wait for job to complete. Args: timeout: Maximum wait time in seconds poll_interval: Time between status checks Returns: JobResult with data or error """ timeout = timeout or MAX_POLL_ATTEMPTS * POLL_INTERVAL_SEC max_attempts = int(timeout / poll_interval) attempts = 0 while not self.is_complete() and attempts < max_attempts: await self.refresh() if not self.is_complete(): await asyncio.sleep(poll_interval) attempts += 1 if not self.is_complete(): raise JobTimeoutError(self.job_id, timeout) return JobResult( job_id=self.job_id, status=self._status, data=self._result, error=self._error, metrics=self._metrics, ) async def refresh(self) -> None: """Refresh job status from server.""" result = await self._client.get_job_status(self.job_id) previous_status = self._status self._status = result.status self._result = result.data self._error = result.error self._metrics = result.metrics if previous_status != self._status: self._trigger_callbacks(self._status) async def cancel(self) -> None: """Cancel the job.""" await self._client.cancel_job(self.job_id) self._status = JobStatus.CANCELLED self._trigger_callbacks(JobStatus.CANCELLED) def on(self, status: JobStatus, callback: Callable[["ComputeJob"], None]) -> "ComputeJob": """Register a callback for status changes.""" if status not in self._callbacks: self._callbacks[status] = [] self._callbacks[status].append(callback) return self def get_result(self) -> Any: """Get the result (raises if not complete).""" if self._status != JobStatus.COMPLETED: raise ValueError(f"Job {self.job_id} is not completed (status: {self._status})") return self._result def get_error(self) -> Optional[str]: """Get error message (if failed).""" return self._error def get_metrics(self) -> Optional[JobMetrics]: """Get execution metrics (if available).""" return self._metrics def _trigger_callbacks(self, status: JobStatus) -> None: """Trigger callbacks for a status.""" for callback in self._callbacks.get(status, []): try: callback(self) except Exception as e: print(f"Error in job callback for status {status}: {e}") def __repr__(self) -> str: return f"ComputeJob(id={self.job_id}, status={self._status.value})" class JobTimeoutError(Exception): """Error raised when a job times out.""" def __init__(self, job_id: str, timeout: float): super().__init__(f"Job {job_id} timed out after {timeout}s") self.job_id = job_id self.timeout = timeout class JobFailedError(Exception): """Error raised when a job fails.""" def __init__(self, job_id: str, reason: str): super().__init__(f"Job {job_id} failed: {reason}") self.job_id = job_id self.reason = reason class JobBatch: """ Batch multiple jobs for efficient execution. Example: >>> batch = JobBatch() >>> batch.add(job1).add(job2).add(job3) >>> results = await batch.wait_all() """ def __init__(self) -> None: self._jobs: list[ComputeJob] = [] def add(self, job: ComputeJob) -> "JobBatch": """Add a job to the batch.""" self._jobs.append(job) return self @property def jobs(self) -> list[ComputeJob]: """Get all jobs.""" return list(self._jobs) async def wait_all(self, timeout: Optional[float] = None) -> list[JobResult]: """Wait for all jobs to complete.""" return await asyncio.gather( *[job.wait(timeout=timeout) for job in self._jobs] ) async def wait_any(self, timeout: Optional[float] = None) -> JobResult: """Wait for any job to complete.""" done, _ = await asyncio.wait( [asyncio.create_task(job.wait(timeout=timeout)) for job in self._jobs], return_when=asyncio.FIRST_COMPLETED, ) return done.pop().result() async def cancel_all(self) -> None: """Cancel all jobs.""" await asyncio.gather(*[job.cancel() for job in self._jobs]) def get_status_counts(self) -> dict[JobStatus, int]: """Get count of jobs by status.""" counts = {status: 0 for status in JobStatus} for job in self._jobs: counts[job.status] += 1 return counts def __len__(self) -> int: return len(self._jobs)