diff --git a/web/backend/routers/tasks.py b/web/backend/routers/tasks.py new file mode 100644 index 0000000..19b151a --- /dev/null +++ b/web/backend/routers/tasks.py @@ -0,0 +1,159 @@ +"""Tasks API router: history query, stats, detail, and retry.""" + +import asyncio +from typing import Optional + +import httpx +from fastapi import APIRouter, Depends, HTTPException, Request + +from ..auth.dependencies import get_current_user +from ..services import db_schema + +router = APIRouter(prefix="/api/tasks", tags=["tasks"]) + +# Mapping from task name to the processing endpoint that retries it. +_RETRY_ROUTE_MAP = { + "批量OCR识别": "/api/processing/ocr-batch", + "Excel标准化处理": "/api/processing/excel", + "合并采购单": "/api/processing/merge", + "一键全流程处理": "/api/processing/pipeline", +} + + +@router.get("/stats") +async def task_stats( + current_user: dict = Depends(get_current_user), +): + """Return aggregate task statistics.""" + loop = asyncio.get_event_loop() + stats = await loop.run_in_executor(None, db_schema.query_task_stats) + # Ensure all expected keys are present. + return { + "total": stats.get("total", 0), + "completed": stats.get("completed", 0), + "failed": stats.get("failed", 0), + "running": stats.get("running", 0), + } + + +@router.get("") +async def list_tasks( + page: int = 1, + page_size: int = 50, + status: Optional[str] = None, + name: Optional[str] = None, + search: Optional[str] = None, + current_user: dict = Depends(get_current_user), +): + """List tasks with optional filters and pagination. + + ``search`` is applied as a general text filter (matches name). + """ + page_size = min(page_size, 200) + page = max(page, 1) + offset = (page - 1) * page_size + + # ``search`` maps to the ``name`` filter in the DB layer. + effective_name = search or name + + loop = asyncio.get_event_loop() + items = await loop.run_in_executor( + None, + lambda: db_schema.query_task_history( + status=status, + name=effective_name, + limit=page_size, + offset=offset, + ), + ) + + # Obtain total count for pagination. Re-run a lightweight count query. + def _count(): + import sqlite3 + from pathlib import Path + + db_path = Path(__file__).resolve().parent.parent.parent.parent / "data" / "web_data.db" + conn = sqlite3.connect(db_path) + try: + clauses: list[str] = [] + params: list = [] + if status: + clauses.append("status = ?") + params.append(status) + if effective_name: + clauses.append("name LIKE ?") + params.append(f"%{effective_name}%") + where = (" WHERE " + " AND ".join(clauses)) if clauses else "" + row = conn.execute( + f"SELECT COUNT(*) as cnt FROM task_history{where}", + params, + ).fetchone() + return row[0] if row else 0 + finally: + conn.close() + + total = await loop.run_in_executor(None, _count) + + return {"items": items, "total": total} + + +@router.get("/{task_id}") +async def get_task( + task_id: str, + current_user: dict = Depends(get_current_user), +): + """Get full task detail including log_lines and result_files.""" + loop = asyncio.get_event_loop() + task = await loop.run_in_executor( + None, lambda: db_schema.query_task_by_id(task_id), + ) + if task is None: + raise HTTPException(status_code=404, detail="任务不存在") + return task + + +@router.post("/{task_id}/retry") +async def retry_task( + task_id: str, + request: Request, + current_user: dict = Depends(get_current_user), +): + """Retry a failed task by re-invoking its processing endpoint. + + Only tasks with status ``failed`` may be retried. + """ + loop = asyncio.get_event_loop() + task = await loop.run_in_executor( + None, lambda: db_schema.query_task_by_id(task_id), + ) + if task is None: + raise HTTPException(status_code=404, detail="任务不存在") + if task.get("status") != "failed": + raise HTTPException( + status_code=400, + detail="只有失败的任务才能重试", + ) + + task_name = task.get("name", "") + endpoint = _RETRY_ROUTE_MAP.get(task_name) + if endpoint is None: + raise HTTPException( + status_code=400, + detail=f"未知的任务类型: {task_name}", + ) + + # Build the internal URL to the processing endpoint. + base_url = f"http://{request.url.hostname}:{request.url.port}" + url = f"{base_url}{endpoint}" + + # Forward the Authorization header so the processing endpoint can + # authenticate the request. + auth_header = request.headers.get("authorization") + headers: dict[str, str] = {} + if auth_header: + headers["authorization"] = auth_header + + async with httpx.AsyncClient() as client: + resp = await client.post(url, headers=headers) + + return resp.json()