diff --git a/web/backend/services/task_manager.py b/web/backend/services/task_manager.py new file mode 100644 index 0000000..05642e7 --- /dev/null +++ b/web/backend/services/task_manager.py @@ -0,0 +1,161 @@ +"""Background task tracking + WebSocket broadcast""" + +import json +import uuid +import asyncio +from datetime import datetime +from enum import Enum +from typing import Dict, List, Optional, Set +from dataclasses import dataclass, field + +from web.backend.services.db_schema import insert_task, update_task + + +class TaskStatus(str, Enum): + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + + +@dataclass +class Task: + id: str + name: str + status: TaskStatus = TaskStatus.PENDING + progress: int = 0 + message: str = "" + result_files: List[str] = field(default_factory=list) + error: Optional[str] = None + log_lines: List[str] = field(default_factory=list) + + def to_dict(self) -> dict: + return { + "task_id": self.id, + "name": self.name, + "status": self.status.value, + "progress": self.progress, + "message": self.message, + "result_files": self.result_files, + "error": self.error, + "log_lines": self.log_lines[-100:], + } + + +class TaskManager: + def __init__(self): + self._tasks: Dict[str, Task] = {} + self._connections: Dict[str, Set] = {} + self._db = None # type: ignore + + def set_db_pool(self, db_pool): + """Set the DBPool reference for database persistence.""" + self._db = db_pool + + def create_task(self, name: str) -> Task: + task_id = str(uuid.uuid4())[:8] + task = Task(id=task_id, name=name) + self._tasks[task_id] = task + self._connections[task_id] = set() + if self._db: + asyncio.create_task( + self._db.execute_write(insert_task, task_id, name, TaskStatus.PENDING.value) + ) + return task + + def get_task(self, task_id: str) -> Optional[Task]: + return self._tasks.get(task_id) + + def update_progress(self, task_id: str, progress: int, message: str = ""): + task = self._tasks.get(task_id) + if not task: + return + # Auto-transition from PENDING to RUNNING on first progress update + if task.status == TaskStatus.PENDING: + task.status = TaskStatus.RUNNING + task.progress = progress + task.message = message + if self._db: + asyncio.create_task( + self._db.execute_write( + update_task, task_id, + status=task.status.value, progress=progress, message=message, + ) + ) + asyncio.create_task(self._broadcast(task_id)) + + def add_log(self, task_id: str, line: str): + task = self._tasks.get(task_id) + if not task: + return + task.log_lines.append(line) + if self._db: + asyncio.create_task( + self._db.execute_write( + update_task, task_id, + log_lines=json.dumps(task.log_lines[-200:]), + ) + ) + asyncio.create_task(self._broadcast(task_id)) + + def set_completed(self, task_id: str, result_files: List[str] = None, message: str = ""): + task = self._tasks.get(task_id) + if not task: + return + task.status = TaskStatus.COMPLETED + task.progress = 100 + task.message = message or "处理完成" + if result_files: + task.result_files = result_files + now = datetime.now().isoformat() + if self._db: + asyncio.create_task( + self._db.execute_write( + update_task, task_id, + status=TaskStatus.COMPLETED.value, progress=100, + message=task.message, + result_files=json.dumps(task.result_files), + completed_at=now, + ) + ) + asyncio.create_task(self._broadcast(task_id)) + + def set_failed(self, task_id: str, error: str): + task = self._tasks.get(task_id) + if not task: + return + task.status = TaskStatus.FAILED + task.error = error + task.message = f"处理失败: {error}" + now = datetime.now().isoformat() + if self._db: + asyncio.create_task( + self._db.execute_write( + update_task, task_id, + status=TaskStatus.FAILED.value, error=error, + message=task.message, completed_at=now, + ) + ) + asyncio.create_task(self._broadcast(task_id)) + + def subscribe(self, task_id: str, websocket): + if task_id in self._connections: + self._connections[task_id].add(websocket) + + def unsubscribe(self, task_id: str, websocket): + if task_id in self._connections: + self._connections[task_id].discard(websocket) + + async def _broadcast(self, task_id: str): + task = self._tasks.get(task_id) + if not task: + return + data = task.to_dict() + dead = set() + for ws in self._connections.get(task_id, set()): + try: + await ws.send_json(data) + except Exception: + dead.add(ws) + for ws in dead: + self._connections[task_id].discard(ws)