"""Background task tracking + WebSocket broadcast""" import json import uuid import asyncio from datetime import datetime from enum import Enum from typing import Dict, List, Optional, Set from dataclasses import dataclass, field from web.backend.services.db_schema import insert_task, update_task class TaskStatus(str, Enum): PENDING = "pending" RUNNING = "running" COMPLETED = "completed" FAILED = "failed" @dataclass class Task: id: str name: str status: TaskStatus = TaskStatus.PENDING progress: int = 0 message: str = "" result_files: List[str] = field(default_factory=list) error: Optional[str] = None log_lines: List[str] = field(default_factory=list) metadata: Optional[dict] = None def to_dict(self) -> dict: d = { "task_id": self.id, "name": self.name, "status": self.status.value, "progress": self.progress, "message": self.message, "result_files": self.result_files, "error": self.error, "log_lines": self.log_lines[-100:], } if self.metadata: d["metadata"] = self.metadata return d class TaskManager: def __init__(self): self._tasks: Dict[str, Task] = {} self._connections: Dict[str, Set] = {} self._db = None # type: ignore self._loop = None # captured event loop def set_db_pool(self, db_pool): """Set the DBPool reference for database persistence.""" self._db = db_pool try: self._loop = asyncio.get_running_loop() except RuntimeError: pass def _schedule(self, coro): """Schedule a coroutine from either async or thread context.""" try: loop = asyncio.get_running_loop() asyncio.ensure_future(coro, loop=loop) except RuntimeError: # No running loop — we're in a thread; schedule onto the main loop if self._loop and self._loop.is_running(): asyncio.run_coroutine_threadsafe(coro, self._loop) def create_task(self, name: str) -> Task: task_id = str(uuid.uuid4())[:8] task = Task(id=task_id, name=name) self._tasks[task_id] = task self._connections[task_id] = set() if self._db: self._schedule( self._db.execute_write(insert_task, task_id, name, TaskStatus.PENDING.value) ) return task def get_task(self, task_id: str) -> Optional[Task]: return self._tasks.get(task_id) def update_progress(self, task_id: str, progress: int, message: str = ""): task = self._tasks.get(task_id) if not task: return # Auto-transition from PENDING to RUNNING on first progress update if task.status == TaskStatus.PENDING: task.status = TaskStatus.RUNNING task.progress = progress task.message = message if self._db: self._schedule( self._db.execute_write( update_task, task_id, status=task.status.value, progress=progress, message=message, ) ) self._schedule(self._broadcast(task_id)) def add_log(self, task_id: str, line: str): task = self._tasks.get(task_id) if not task: return task.log_lines.append(line) if self._db: self._schedule( self._db.execute_write( update_task, task_id, log_lines=json.dumps(task.log_lines[-200:]), ) ) self._schedule(self._broadcast(task_id)) def set_completed(self, task_id: str, result_files: List[str] = None, message: str = ""): task = self._tasks.get(task_id) if not task: return task.status = TaskStatus.COMPLETED task.progress = 100 task.message = message or "处理完成" if result_files: task.result_files = result_files now = datetime.now().isoformat() if self._db: self._schedule( self._db.execute_write( update_task, task_id, status=TaskStatus.COMPLETED.value, progress=100, message=task.message, result_files=json.dumps(task.result_files), completed_at=now, ) ) self._schedule(self._broadcast(task_id)) def retry_task(self, task_id: str) -> Optional[Task]: """Create a new task to retry a failed task with its original parameters. Returns the new task if the original was failed and retryable, else None. The caller is responsible for dispatching the actual work based on ``new_task.metadata``. """ original = self._tasks.get(task_id) if not original or original.status != TaskStatus.FAILED: return None new_task = self.create_task(original.name) if original.metadata: new_task.metadata = dict(original.metadata) return new_task def set_failed(self, task_id: str, error: str): task = self._tasks.get(task_id) if not task: return task.status = TaskStatus.FAILED task.error = error task.message = f"处理失败: {error}" now = datetime.now().isoformat() if self._db: self._schedule( self._db.execute_write( update_task, task_id, status=TaskStatus.FAILED.value, error=error, message=task.message, completed_at=now, ) ) self._schedule(self._broadcast(task_id)) def subscribe(self, task_id: str, websocket): if task_id in self._connections: self._connections[task_id].add(websocket) def unsubscribe(self, task_id: str, websocket): if task_id in self._connections: self._connections[task_id].discard(websocket) async def _broadcast(self, task_id: str): task = self._tasks.get(task_id) if not task: return data = task.to_dict() dead = set() for ws in self._connections.get(task_id, set()): try: await ws.send_json(data) except Exception: dead.add(ws) for ws in dead: self._connections[task_id].discard(ws)