synor/crates/synor-compute/src/scheduler/work_queue.rs
2026-01-26 22:03:40 +05:30

271 lines
7.4 KiB
Rust

//! Work queue with thread-safe task management.
use crate::processor::ProcessorType;
use crate::task::{Task, TaskId, TaskPriority};
use crossbeam_channel::{bounded, Receiver, Sender, TryRecvError};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
/// Work queue for a specific processor type.
///
/// Wraps a bounded `crossbeam_channel` MPMC channel plus atomic counters.
/// The counters are updated with `Ordering::Relaxed`, so they are
/// best-effort statistics rather than exact synchronized values.
pub struct WorkQueue {
    /// Task sender (for producers).
    sender: Sender<Task>,
    /// Task receiver (for consumers); also cloned into `QueueStealer`s.
    receiver: Receiver<Task>,
    /// Processor type this queue is for.
    processor_type: ProcessorType,
    /// Approximate number of tasks currently queued.
    size: AtomicU64,
    /// Total tasks popped from this queue via `pop_any`.
    processed: AtomicU64,
}
impl WorkQueue {
    /// Creates a new work queue for a processor type.
    ///
    /// The bounded channel is sized at `capacity`, floored at 1024 so a
    /// tiny or zero requested capacity cannot make the queue unusable.
    pub fn new(processor_type: ProcessorType, capacity: usize) -> Self {
        let (sender, receiver) = bounded(capacity.max(1024));
        Self {
            sender,
            receiver,
            processor_type,
            size: AtomicU64::new(0),
            processed: AtomicU64::new(0),
        }
    }

    /// Push a task to the queue.
    ///
    /// Returns `true` if the task was enqueued. Returns `false` when the
    /// bounded channel is full, in which case the task is dropped — the
    /// previous version dropped it silently with no way for the caller to
    /// detect the loss.
    pub fn push(&self, task: Task) -> bool {
        if self.sender.try_send(task).is_ok() {
            self.size.fetch_add(1, Ordering::Relaxed);
            true
        } else {
            false
        }
    }

    /// Pop a task from the queue (`worker_id` is ignored; kept for
    /// signature compatibility with callers that pass one).
    pub fn pop(&self, _worker_id: usize) -> Option<Task> {
        self.pop_any()
    }

    /// Pop any task from the queue, updating the `size` and `processed`
    /// counters on success.
    pub fn pop_any(&self) -> Option<Task> {
        match self.receiver.try_recv() {
            Ok(task) => {
                self.size.fetch_sub(1, Ordering::Relaxed);
                self.processed.fetch_add(1, Ordering::Relaxed);
                Some(task)
            }
            // Empty: nothing queued. Disconnected: cannot occur while
            // `self.sender` is alive, but treated the same defensively.
            Err(TryRecvError::Empty) | Err(TryRecvError::Disconnected) => None,
        }
    }

    /// Pop from global queue (alias for `pop_any`).
    pub fn pop_global(&self) -> Option<Task> {
        self.pop_any()
    }

    /// Steal up to `max_tasks` tasks from `other`, enqueue clones of them
    /// into this queue, and return the stolen originals.
    ///
    /// Fix: steals drain `other`'s channel directly instead of going
    /// through `other.pop_any()`, which wrongly counted every stolen task
    /// as "processed" by the victim queue and inflated its statistics.
    ///
    /// NOTE(review): the returned tasks are duplicates of the clones
    /// enqueued here; a caller that executes the returned tasks while
    /// workers also drain this queue would run them twice — confirm
    /// callers treat the return value as informational only.
    pub fn steal_batch_from(&self, other: &WorkQueue, max_tasks: usize) -> Vec<Task> {
        let mut stolen = Vec::with_capacity(max_tasks);
        while stolen.len() < max_tasks {
            match other.receiver.try_recv() {
                Ok(task) => {
                    // The task left the victim's channel: shrink its size
                    // counter, but do NOT bump its `processed` counter.
                    other.size.fetch_sub(1, Ordering::Relaxed);
                    stolen.push(task);
                }
                Err(_) => break,
            }
        }
        // Re-enqueue clones locally; if this queue is full the clone is
        // dropped, but the original still reaches the caller via `stolen`.
        for task in &stolen {
            if self.sender.try_send(task.clone()).is_ok() {
                self.size.fetch_add(1, Ordering::Relaxed);
            }
        }
        stolen
    }

    /// Get current queue size (best-effort; counter uses Relaxed ordering).
    pub fn len(&self) -> usize {
        self.size.load(Ordering::Relaxed) as usize
    }

    /// Check if queue is empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Get number of tasks popped (processed) from this queue.
    pub fn processed_count(&self) -> u64 {
        self.processed.load(Ordering::Relaxed)
    }

    /// Get processor type for this queue.
    pub fn processor_type(&self) -> ProcessorType {
        self.processor_type.clone()
    }

    /// Get utilization estimate (0.0 - 1.0) relative to channel capacity.
    pub fn utilization(&self) -> f64 {
        let size = self.size.load(Ordering::Relaxed) as f64;
        // `bounded` channels always report Some(capacity); 1024 is a
        // defensive fallback matching the minimum applied in `new`.
        let capacity = self.sender.capacity().unwrap_or(1024) as f64;
        (size / capacity).min(1.0)
    }

    /// Get a stealer handle sharing this queue's receiver.
    ///
    /// NOTE(review): tasks taken via the stealer bypass this queue's
    /// `size`/`processed` counters, so those statistics drift after steals.
    pub fn get_stealer(&self) -> QueueStealer {
        QueueStealer {
            receiver: self.receiver.clone(),
        }
    }
}
/// Stealer handle for cross-queue work stealing.
///
/// Holds a clone of the owning `WorkQueue`'s receiver. NOTE(review):
/// consuming through this handle does not update that queue's
/// `size`/`processed` counters, so those statistics drift after steals —
/// confirm callers only rely on them as approximations.
#[derive(Clone)]
pub struct QueueStealer {
    receiver: Receiver<Task>,
}
impl QueueStealer {
    /// Try to steal a task from the shared channel.
    ///
    /// Returns `None` when the channel is empty or disconnected.
    pub fn steal(&self) -> Option<Task> {
        match self.receiver.try_recv() {
            Ok(task) => Some(task),
            Err(_) => None,
        }
    }
}
/// Priority queue wrapper for tasks.
///
/// Maintains one independent `WorkQueue` per `TaskPriority` level and
/// serves pops from the highest-priority non-empty sub-queue.
pub struct PriorityWorkQueue {
    /// Sub-queues keyed by priority level.
    queues: HashMap<TaskPriority, WorkQueue>,
    /// Processor type all sub-queues are created for.
    processor_type: ProcessorType,
}
impl PriorityWorkQueue {
/// Creates a new priority work queue.
pub fn new(processor_type: ProcessorType, capacity_per_priority: usize) -> Self {
let mut queues = HashMap::new();
for priority in [
TaskPriority::Critical,
TaskPriority::High,
TaskPriority::Normal,
TaskPriority::Background,
] {
queues.insert(priority, WorkQueue::new(processor_type.clone(), capacity_per_priority));
}
Self {
queues,
processor_type,
}
}
/// Push a task with its priority.
pub fn push(&self, task: Task) {
let priority = task.priority;
if let Some(queue) = self.queues.get(&priority) {
queue.push(task);
}
}
/// Pop highest priority task available.
pub fn pop(&self, worker_id: usize) -> Option<Task> {
// Try priorities in order: Critical > High > Normal > Background
for priority in [
TaskPriority::Critical,
TaskPriority::High,
TaskPriority::Normal,
TaskPriority::Background,
] {
if let Some(queue) = self.queues.get(&priority) {
if let Some(task) = queue.pop(worker_id) {
return Some(task);
}
}
}
None
}
/// Get total queue size.
pub fn len(&self) -> usize {
self.queues.values().map(|q| q.len()).sum()
}
/// Check if all queues are empty.
pub fn is_empty(&self) -> bool {
self.queues.values().all(|q| q.is_empty())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::processor::{CpuVariant, Operation, Precision};
    use crate::task::TaskStatus;

    /// Builds a minimal pending MatMul task with the given id and priority.
    fn create_test_task(id: u64, priority: TaskPriority) -> Task {
        Task {
            id: TaskId(id),
            operation: Operation::MatMul {
                m: 1024,
                n: 1024,
                k: 1024,
                precision: Precision::Fp32,
            },
            priority,
            dependencies: vec![],
            status: TaskStatus::Pending,
            deadline: None,
        }
    }

    #[test]
    fn test_work_queue_basic() {
        let queue = WorkQueue::new(
            ProcessorType::Cpu(CpuVariant::default()),
            100,
        );
        assert!(queue.is_empty());
        queue.push(create_test_task(1, TaskPriority::Normal));
        queue.push(create_test_task(2, TaskPriority::Normal));
        assert_eq!(queue.len(), 2);
        let task1 = queue.pop(0);
        assert!(task1.is_some());
        assert_eq!(queue.len(), 1);
        let task2 = queue.pop(0);
        assert!(task2.is_some());
        assert!(queue.is_empty());
        // Popping from an empty queue must return None and must not
        // underflow the size counter.
        assert!(queue.pop(0).is_none());
        assert!(queue.is_empty());
        // Both successful pops should be counted as processed.
        assert_eq!(queue.processed_count(), 2);
    }

    #[test]
    fn test_priority_queue() {
        let queue = PriorityWorkQueue::new(
            ProcessorType::Cpu(CpuVariant::default()),
            100,
        );
        // Insert out of order so the pop order proves priority handling.
        queue.push(create_test_task(1, TaskPriority::Background));
        queue.push(create_test_task(2, TaskPriority::Critical));
        queue.push(create_test_task(3, TaskPriority::Normal));
        // Previously the High level was never exercised by any test.
        queue.push(create_test_task(4, TaskPriority::High));
        assert_eq!(queue.len(), 4);
        // Expected order: Critical > High > Normal > Background.
        let task = queue.pop(0).unwrap();
        assert_eq!(task.id, TaskId(2));
        assert_eq!(task.priority, TaskPriority::Critical);
        let task = queue.pop(0).unwrap();
        assert_eq!(task.id, TaskId(4));
        assert_eq!(task.priority, TaskPriority::High);
        let task = queue.pop(0).unwrap();
        assert_eq!(task.id, TaskId(3));
        let task = queue.pop(0).unwrap();
        assert_eq!(task.id, TaskId(1));
        assert!(queue.is_empty());
        assert!(queue.pop(0).is_none());
    }
}