feat: persist task lifecycle to SQLite via TaskManager
This commit is contained in:
@@ -0,0 +1,161 @@
|
|||||||
|
"""Background task tracking + WebSocket broadcast"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Dict, List, Optional, Set
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from web.backend.services.db_schema import insert_task, update_task
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatus(str, Enum):
|
||||||
|
PENDING = "pending"
|
||||||
|
RUNNING = "running"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
FAILED = "failed"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Task:
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
status: TaskStatus = TaskStatus.PENDING
|
||||||
|
progress: int = 0
|
||||||
|
message: str = ""
|
||||||
|
result_files: List[str] = field(default_factory=list)
|
||||||
|
error: Optional[str] = None
|
||||||
|
log_lines: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"task_id": self.id,
|
||||||
|
"name": self.name,
|
||||||
|
"status": self.status.value,
|
||||||
|
"progress": self.progress,
|
||||||
|
"message": self.message,
|
||||||
|
"result_files": self.result_files,
|
||||||
|
"error": self.error,
|
||||||
|
"log_lines": self.log_lines[-100:],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TaskManager:
|
||||||
|
def __init__(self):
|
||||||
|
self._tasks: Dict[str, Task] = {}
|
||||||
|
self._connections: Dict[str, Set] = {}
|
||||||
|
self._db = None # type: ignore
|
||||||
|
|
||||||
|
def set_db_pool(self, db_pool):
|
||||||
|
"""Set the DBPool reference for database persistence."""
|
||||||
|
self._db = db_pool
|
||||||
|
|
||||||
|
def create_task(self, name: str) -> Task:
|
||||||
|
task_id = str(uuid.uuid4())[:8]
|
||||||
|
task = Task(id=task_id, name=name)
|
||||||
|
self._tasks[task_id] = task
|
||||||
|
self._connections[task_id] = set()
|
||||||
|
if self._db:
|
||||||
|
asyncio.create_task(
|
||||||
|
self._db.execute_write(insert_task, task_id, name, TaskStatus.PENDING.value)
|
||||||
|
)
|
||||||
|
return task
|
||||||
|
|
||||||
|
def get_task(self, task_id: str) -> Optional[Task]:
|
||||||
|
return self._tasks.get(task_id)
|
||||||
|
|
||||||
|
def update_progress(self, task_id: str, progress: int, message: str = ""):
|
||||||
|
task = self._tasks.get(task_id)
|
||||||
|
if not task:
|
||||||
|
return
|
||||||
|
# Auto-transition from PENDING to RUNNING on first progress update
|
||||||
|
if task.status == TaskStatus.PENDING:
|
||||||
|
task.status = TaskStatus.RUNNING
|
||||||
|
task.progress = progress
|
||||||
|
task.message = message
|
||||||
|
if self._db:
|
||||||
|
asyncio.create_task(
|
||||||
|
self._db.execute_write(
|
||||||
|
update_task, task_id,
|
||||||
|
status=task.status.value, progress=progress, message=message,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
asyncio.create_task(self._broadcast(task_id))
|
||||||
|
|
||||||
|
def add_log(self, task_id: str, line: str):
|
||||||
|
task = self._tasks.get(task_id)
|
||||||
|
if not task:
|
||||||
|
return
|
||||||
|
task.log_lines.append(line)
|
||||||
|
if self._db:
|
||||||
|
asyncio.create_task(
|
||||||
|
self._db.execute_write(
|
||||||
|
update_task, task_id,
|
||||||
|
log_lines=json.dumps(task.log_lines[-200:]),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
asyncio.create_task(self._broadcast(task_id))
|
||||||
|
|
||||||
|
def set_completed(self, task_id: str, result_files: List[str] = None, message: str = ""):
|
||||||
|
task = self._tasks.get(task_id)
|
||||||
|
if not task:
|
||||||
|
return
|
||||||
|
task.status = TaskStatus.COMPLETED
|
||||||
|
task.progress = 100
|
||||||
|
task.message = message or "处理完成"
|
||||||
|
if result_files:
|
||||||
|
task.result_files = result_files
|
||||||
|
now = datetime.now().isoformat()
|
||||||
|
if self._db:
|
||||||
|
asyncio.create_task(
|
||||||
|
self._db.execute_write(
|
||||||
|
update_task, task_id,
|
||||||
|
status=TaskStatus.COMPLETED.value, progress=100,
|
||||||
|
message=task.message,
|
||||||
|
result_files=json.dumps(task.result_files),
|
||||||
|
completed_at=now,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
asyncio.create_task(self._broadcast(task_id))
|
||||||
|
|
||||||
|
def set_failed(self, task_id: str, error: str):
|
||||||
|
task = self._tasks.get(task_id)
|
||||||
|
if not task:
|
||||||
|
return
|
||||||
|
task.status = TaskStatus.FAILED
|
||||||
|
task.error = error
|
||||||
|
task.message = f"处理失败: {error}"
|
||||||
|
now = datetime.now().isoformat()
|
||||||
|
if self._db:
|
||||||
|
asyncio.create_task(
|
||||||
|
self._db.execute_write(
|
||||||
|
update_task, task_id,
|
||||||
|
status=TaskStatus.FAILED.value, error=error,
|
||||||
|
message=task.message, completed_at=now,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
asyncio.create_task(self._broadcast(task_id))
|
||||||
|
|
||||||
|
def subscribe(self, task_id: str, websocket):
|
||||||
|
if task_id in self._connections:
|
||||||
|
self._connections[task_id].add(websocket)
|
||||||
|
|
||||||
|
def unsubscribe(self, task_id: str, websocket):
|
||||||
|
if task_id in self._connections:
|
||||||
|
self._connections[task_id].discard(websocket)
|
||||||
|
|
||||||
|
async def _broadcast(self, task_id: str):
|
||||||
|
task = self._tasks.get(task_id)
|
||||||
|
if not task:
|
||||||
|
return
|
||||||
|
data = task.to_dict()
|
||||||
|
dead = set()
|
||||||
|
for ws in self._connections.get(task_id, set()):
|
||||||
|
try:
|
||||||
|
await ws.send_json(data)
|
||||||
|
except Exception:
|
||||||
|
dead.add(ws)
|
||||||
|
for ws in dead:
|
||||||
|
self._connections[task_id].discard(ws)
|
||||||
Reference in New Issue
Block a user