//! Work queue with thread-safe task management. use crate::processor::ProcessorType; use crate::task::{Task, TaskId, TaskPriority}; use crossbeam_channel::{bounded, Receiver, Sender, TryRecvError}; use std::collections::HashMap; use std::sync::atomic::{AtomicU64, Ordering}; /// Work queue for a specific processor type. pub struct WorkQueue { /// Task sender (for producers). sender: Sender, /// Task receiver (for consumers). receiver: Receiver, /// Processor type this queue is for. processor_type: ProcessorType, /// Current queue size. size: AtomicU64, /// Total tasks processed. processed: AtomicU64, } impl WorkQueue { /// Creates a new work queue for a processor type. pub fn new(processor_type: ProcessorType, capacity: usize) -> Self { let (sender, receiver) = bounded(capacity.max(1024)); Self { sender, receiver, processor_type, size: AtomicU64::new(0), processed: AtomicU64::new(0), } } /// Push a task to the queue. pub fn push(&self, task: Task) { if self.sender.try_send(task).is_ok() { self.size.fetch_add(1, Ordering::Relaxed); } } /// Pop a task from the queue (ignores worker_id for compatibility). pub fn pop(&self, _worker_id: usize) -> Option { self.pop_any() } /// Pop any task from the queue. pub fn pop_any(&self) -> Option { match self.receiver.try_recv() { Ok(task) => { self.size.fetch_sub(1, Ordering::Relaxed); self.processed.fetch_add(1, Ordering::Relaxed); Some(task) } Err(TryRecvError::Empty) | Err(TryRecvError::Disconnected) => None, } } /// Pop from global queue (alias for pop_any). pub fn pop_global(&self) -> Option { self.pop_any() } /// Steal a batch of tasks from another queue. pub fn steal_batch_from(&self, other: &WorkQueue, max_tasks: usize) -> Vec { let mut stolen = Vec::new(); while stolen.len() < max_tasks { if let Some(task) = other.pop_any() { stolen.push(task); } else { break; } } // Push stolen tasks to this queue for task in &stolen { // Tasks are already accounted for in `other`, just push to self if self.sender.try_send(task.clone()).is_ok() { self.size.fetch_add(1, Ordering::Relaxed); } } stolen } /// Get current queue size. pub fn len(&self) -> usize { self.size.load(Ordering::Relaxed) as usize } /// Check if queue is empty. pub fn is_empty(&self) -> bool { self.len() == 0 } /// Get number of tasks processed. pub fn processed_count(&self) -> u64 { self.processed.load(Ordering::Relaxed) } /// Get processor type for this queue. pub fn processor_type(&self) -> ProcessorType { self.processor_type.clone() } /// Get utilization estimate (0.0 - 1.0). pub fn utilization(&self) -> f64 { let size = self.size.load(Ordering::Relaxed) as f64; let capacity = self.sender.capacity().unwrap_or(1024) as f64; (size / capacity).min(1.0) } /// Get a stealer for cross-queue work stealing. pub fn get_stealer(&self) -> QueueStealer { QueueStealer { receiver: self.receiver.clone(), } } } /// Stealer handle for cross-queue work stealing. #[derive(Clone)] pub struct QueueStealer { receiver: Receiver, } impl QueueStealer { /// Try to steal a task. pub fn steal(&self) -> Option { self.receiver.try_recv().ok() } } /// Priority queue wrapper for tasks. pub struct PriorityWorkQueue { /// Queues by priority level. queues: HashMap, /// Processor type. processor_type: ProcessorType, } impl PriorityWorkQueue { /// Creates a new priority work queue. pub fn new(processor_type: ProcessorType, capacity_per_priority: usize) -> Self { let mut queues = HashMap::new(); for priority in [ TaskPriority::Critical, TaskPriority::High, TaskPriority::Normal, TaskPriority::Background, ] { queues.insert(priority, WorkQueue::new(processor_type.clone(), capacity_per_priority)); } Self { queues, processor_type, } } /// Push a task with its priority. pub fn push(&self, task: Task) { let priority = task.priority; if let Some(queue) = self.queues.get(&priority) { queue.push(task); } } /// Pop highest priority task available. pub fn pop(&self, worker_id: usize) -> Option { // Try priorities in order: Critical > High > Normal > Background for priority in [ TaskPriority::Critical, TaskPriority::High, TaskPriority::Normal, TaskPriority::Background, ] { if let Some(queue) = self.queues.get(&priority) { if let Some(task) = queue.pop(worker_id) { return Some(task); } } } None } /// Get total queue size. pub fn len(&self) -> usize { self.queues.values().map(|q| q.len()).sum() } /// Check if all queues are empty. pub fn is_empty(&self) -> bool { self.queues.values().all(|q| q.is_empty()) } } #[cfg(test)] mod tests { use super::*; use crate::processor::{CpuVariant, Operation, Precision}; use crate::task::TaskStatus; fn create_test_task(id: u64, priority: TaskPriority) -> Task { Task { id: TaskId(id), operation: Operation::MatMul { m: 1024, n: 1024, k: 1024, precision: Precision::Fp32, }, priority, dependencies: vec![], status: TaskStatus::Pending, deadline: None, } } #[test] fn test_work_queue_basic() { let queue = WorkQueue::new( ProcessorType::Cpu(CpuVariant::default()), 100, ); assert!(queue.is_empty()); queue.push(create_test_task(1, TaskPriority::Normal)); queue.push(create_test_task(2, TaskPriority::Normal)); assert_eq!(queue.len(), 2); let task1 = queue.pop(0); assert!(task1.is_some()); assert_eq!(queue.len(), 1); let task2 = queue.pop(0); assert!(task2.is_some()); assert!(queue.is_empty()); } #[test] fn test_priority_queue() { let queue = PriorityWorkQueue::new( ProcessorType::Cpu(CpuVariant::default()), 100, ); queue.push(create_test_task(1, TaskPriority::Background)); queue.push(create_test_task(2, TaskPriority::Critical)); queue.push(create_test_task(3, TaskPriority::Normal)); // Should get Critical first let task = queue.pop(0).unwrap(); assert_eq!(task.id, TaskId(2)); assert_eq!(task.priority, TaskPriority::Critical); // Then Normal let task = queue.pop(0).unwrap(); assert_eq!(task.id, TaskId(3)); // Then Background let task = queue.pop(0).unwrap(); assert_eq!(task.id, TaskId(1)); } }