//! Work queue with thread-safe task management.
|
|
|
|
use crate::processor::ProcessorType;
|
|
use crate::task::{Task, TaskId, TaskPriority};
|
|
use crossbeam_channel::{bounded, Receiver, Sender, TryRecvError};
|
|
use std::collections::HashMap;
|
|
use std::sync::atomic::{AtomicU64, Ordering};
|
|
|
|
/// Work queue for a specific processor type.
|
|
pub struct WorkQueue {
|
|
/// Task sender (for producers).
|
|
sender: Sender<Task>,
|
|
/// Task receiver (for consumers).
|
|
receiver: Receiver<Task>,
|
|
/// Processor type this queue is for.
|
|
processor_type: ProcessorType,
|
|
/// Current queue size.
|
|
size: AtomicU64,
|
|
/// Total tasks processed.
|
|
processed: AtomicU64,
|
|
}
|
|
|
|
impl WorkQueue {
|
|
/// Creates a new work queue for a processor type.
|
|
pub fn new(processor_type: ProcessorType, capacity: usize) -> Self {
|
|
let (sender, receiver) = bounded(capacity.max(1024));
|
|
|
|
Self {
|
|
sender,
|
|
receiver,
|
|
processor_type,
|
|
size: AtomicU64::new(0),
|
|
processed: AtomicU64::new(0),
|
|
}
|
|
}
|
|
|
|
/// Push a task to the queue.
|
|
pub fn push(&self, task: Task) {
|
|
if self.sender.try_send(task).is_ok() {
|
|
self.size.fetch_add(1, Ordering::Relaxed);
|
|
}
|
|
}
|
|
|
|
/// Pop a task from the queue (ignores worker_id for compatibility).
|
|
pub fn pop(&self, _worker_id: usize) -> Option<Task> {
|
|
self.pop_any()
|
|
}
|
|
|
|
/// Pop any task from the queue.
|
|
pub fn pop_any(&self) -> Option<Task> {
|
|
match self.receiver.try_recv() {
|
|
Ok(task) => {
|
|
self.size.fetch_sub(1, Ordering::Relaxed);
|
|
self.processed.fetch_add(1, Ordering::Relaxed);
|
|
Some(task)
|
|
}
|
|
Err(TryRecvError::Empty) | Err(TryRecvError::Disconnected) => None,
|
|
}
|
|
}
|
|
|
|
/// Pop from global queue (alias for pop_any).
|
|
pub fn pop_global(&self) -> Option<Task> {
|
|
self.pop_any()
|
|
}
|
|
|
|
/// Steal a batch of tasks from another queue.
|
|
pub fn steal_batch_from(&self, other: &WorkQueue, max_tasks: usize) -> Vec<Task> {
|
|
let mut stolen = Vec::new();
|
|
|
|
while stolen.len() < max_tasks {
|
|
if let Some(task) = other.pop_any() {
|
|
stolen.push(task);
|
|
} else {
|
|
break;
|
|
}
|
|
}
|
|
|
|
// Push stolen tasks to this queue
|
|
for task in &stolen {
|
|
// Tasks are already accounted for in `other`, just push to self
|
|
if self.sender.try_send(task.clone()).is_ok() {
|
|
self.size.fetch_add(1, Ordering::Relaxed);
|
|
}
|
|
}
|
|
|
|
stolen
|
|
}
|
|
|
|
/// Get current queue size.
|
|
pub fn len(&self) -> usize {
|
|
self.size.load(Ordering::Relaxed) as usize
|
|
}
|
|
|
|
/// Check if queue is empty.
|
|
pub fn is_empty(&self) -> bool {
|
|
self.len() == 0
|
|
}
|
|
|
|
/// Get number of tasks processed.
|
|
pub fn processed_count(&self) -> u64 {
|
|
self.processed.load(Ordering::Relaxed)
|
|
}
|
|
|
|
/// Get processor type for this queue.
|
|
pub fn processor_type(&self) -> ProcessorType {
|
|
self.processor_type.clone()
|
|
}
|
|
|
|
/// Get utilization estimate (0.0 - 1.0).
|
|
pub fn utilization(&self) -> f64 {
|
|
let size = self.size.load(Ordering::Relaxed) as f64;
|
|
let capacity = self.sender.capacity().unwrap_or(1024) as f64;
|
|
(size / capacity).min(1.0)
|
|
}
|
|
|
|
/// Get a stealer for cross-queue work stealing.
|
|
pub fn get_stealer(&self) -> QueueStealer {
|
|
QueueStealer {
|
|
receiver: self.receiver.clone(),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Stealer handle for cross-queue work stealing.
|
|
#[derive(Clone)]
|
|
pub struct QueueStealer {
|
|
receiver: Receiver<Task>,
|
|
}
|
|
|
|
impl QueueStealer {
|
|
/// Try to steal a task.
|
|
pub fn steal(&self) -> Option<Task> {
|
|
self.receiver.try_recv().ok()
|
|
}
|
|
}
|
|
|
|
/// Priority queue wrapper for tasks.
|
|
pub struct PriorityWorkQueue {
|
|
/// Queues by priority level.
|
|
queues: HashMap<TaskPriority, WorkQueue>,
|
|
/// Processor type.
|
|
processor_type: ProcessorType,
|
|
}
|
|
|
|
impl PriorityWorkQueue {
|
|
/// Creates a new priority work queue.
|
|
pub fn new(processor_type: ProcessorType, capacity_per_priority: usize) -> Self {
|
|
let mut queues = HashMap::new();
|
|
|
|
for priority in [
|
|
TaskPriority::Critical,
|
|
TaskPriority::High,
|
|
TaskPriority::Normal,
|
|
TaskPriority::Background,
|
|
] {
|
|
queues.insert(priority, WorkQueue::new(processor_type.clone(), capacity_per_priority));
|
|
}
|
|
|
|
Self {
|
|
queues,
|
|
processor_type,
|
|
}
|
|
}
|
|
|
|
/// Push a task with its priority.
|
|
pub fn push(&self, task: Task) {
|
|
let priority = task.priority;
|
|
if let Some(queue) = self.queues.get(&priority) {
|
|
queue.push(task);
|
|
}
|
|
}
|
|
|
|
/// Pop highest priority task available.
|
|
pub fn pop(&self, worker_id: usize) -> Option<Task> {
|
|
// Try priorities in order: Critical > High > Normal > Background
|
|
for priority in [
|
|
TaskPriority::Critical,
|
|
TaskPriority::High,
|
|
TaskPriority::Normal,
|
|
TaskPriority::Background,
|
|
] {
|
|
if let Some(queue) = self.queues.get(&priority) {
|
|
if let Some(task) = queue.pop(worker_id) {
|
|
return Some(task);
|
|
}
|
|
}
|
|
}
|
|
None
|
|
}
|
|
|
|
/// Get total queue size.
|
|
pub fn len(&self) -> usize {
|
|
self.queues.values().map(|q| q.len()).sum()
|
|
}
|
|
|
|
/// Check if all queues are empty.
|
|
pub fn is_empty(&self) -> bool {
|
|
self.queues.values().all(|q| q.is_empty())
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::processor::{CpuVariant, Operation, Precision};
    use crate::task::TaskStatus;

    /// Builds a pending 1024x1024x1024 FP32 matmul task with the given id and
    /// priority, no dependencies and no deadline.
    fn create_test_task(id: u64, priority: TaskPriority) -> Task {
        let operation = Operation::MatMul {
            m: 1024,
            n: 1024,
            k: 1024,
            precision: Precision::Fp32,
        };

        Task {
            id: TaskId(id),
            operation,
            priority,
            dependencies: vec![],
            status: TaskStatus::Pending,
            deadline: None,
        }
    }

    /// Pushing two tasks and popping them again updates the length each time.
    #[test]
    fn test_work_queue_basic() {
        let processor = ProcessorType::Cpu(CpuVariant::default());
        let queue = WorkQueue::new(processor, 100);
        assert!(queue.is_empty());

        for id in [1, 2] {
            queue.push(create_test_task(id, TaskPriority::Normal));
        }
        assert_eq!(queue.len(), 2);

        let task1 = queue.pop(0);
        assert!(task1.is_some());
        assert_eq!(queue.len(), 1);

        let task2 = queue.pop(0);
        assert!(task2.is_some());
        assert!(queue.is_empty());
    }

    /// Tasks come out ordered by priority, not by insertion order.
    #[test]
    fn test_priority_queue() {
        let processor = ProcessorType::Cpu(CpuVariant::default());
        let queue = PriorityWorkQueue::new(processor, 100);

        queue.push(create_test_task(1, TaskPriority::Background));
        queue.push(create_test_task(2, TaskPriority::Critical));
        queue.push(create_test_task(3, TaskPriority::Normal));

        // Critical is served first.
        let first = queue.pop(0).unwrap();
        assert_eq!(first.id, TaskId(2));
        assert_eq!(first.priority, TaskPriority::Critical);

        // Normal follows.
        let second = queue.pop(0).unwrap();
        assert_eq!(second.id, TaskId(3));

        // Background comes last.
        let third = queue.pop(0).unwrap();
        assert_eq!(third.id, TaskId(1));
    }
}
|