feat: add task history query and retry API endpoints

This commit is contained in:
2026-05-05 11:33:40 +08:00
parent c49105a678
commit 79522d8356
+159
View File
@@ -0,0 +1,159 @@
"""Tasks API router: history query, stats, detail, and retry."""
import asyncio
from typing import Optional
import httpx
from fastapi import APIRouter, Depends, HTTPException, Request
from ..auth.dependencies import get_current_user
from ..services import db_schema
router = APIRouter(prefix="/api/tasks", tags=["tasks"])
# Mapping from task name to the processing endpoint that retries it.
_RETRY_ROUTE_MAP = {
"批量OCR识别": "/api/processing/ocr-batch",
"Excel标准化处理": "/api/processing/excel",
"合并采购单": "/api/processing/merge",
"一键全流程处理": "/api/processing/pipeline",
}
@router.get("/stats")
async def task_stats(
current_user: dict = Depends(get_current_user),
):
"""Return aggregate task statistics."""
loop = asyncio.get_event_loop()
stats = await loop.run_in_executor(None, db_schema.query_task_stats)
# Ensure all expected keys are present.
return {
"total": stats.get("total", 0),
"completed": stats.get("completed", 0),
"failed": stats.get("failed", 0),
"running": stats.get("running", 0),
}
@router.get("")
async def list_tasks(
page: int = 1,
page_size: int = 50,
status: Optional[str] = None,
name: Optional[str] = None,
search: Optional[str] = None,
current_user: dict = Depends(get_current_user),
):
"""List tasks with optional filters and pagination.
``search`` is applied as a general text filter (matches name).
"""
page_size = min(page_size, 200)
page = max(page, 1)
offset = (page - 1) * page_size
# ``search`` maps to the ``name`` filter in the DB layer.
effective_name = search or name
loop = asyncio.get_event_loop()
items = await loop.run_in_executor(
None,
lambda: db_schema.query_task_history(
status=status,
name=effective_name,
limit=page_size,
offset=offset,
),
)
# Obtain total count for pagination. Re-run a lightweight count query.
def _count():
import sqlite3
from pathlib import Path
db_path = Path(__file__).resolve().parent.parent.parent.parent / "data" / "web_data.db"
conn = sqlite3.connect(db_path)
try:
clauses: list[str] = []
params: list = []
if status:
clauses.append("status = ?")
params.append(status)
if effective_name:
clauses.append("name LIKE ?")
params.append(f"%{effective_name}%")
where = (" WHERE " + " AND ".join(clauses)) if clauses else ""
row = conn.execute(
f"SELECT COUNT(*) as cnt FROM task_history{where}",
params,
).fetchone()
return row[0] if row else 0
finally:
conn.close()
total = await loop.run_in_executor(None, _count)
return {"items": items, "total": total}
@router.get("/{task_id}")
async def get_task(
task_id: str,
current_user: dict = Depends(get_current_user),
):
"""Get full task detail including log_lines and result_files."""
loop = asyncio.get_event_loop()
task = await loop.run_in_executor(
None, lambda: db_schema.query_task_by_id(task_id),
)
if task is None:
raise HTTPException(status_code=404, detail="任务不存在")
return task
@router.post("/{task_id}/retry")
async def retry_task(
task_id: str,
request: Request,
current_user: dict = Depends(get_current_user),
):
"""Retry a failed task by re-invoking its processing endpoint.
Only tasks with status ``failed`` may be retried.
"""
loop = asyncio.get_event_loop()
task = await loop.run_in_executor(
None, lambda: db_schema.query_task_by_id(task_id),
)
if task is None:
raise HTTPException(status_code=404, detail="任务不存在")
if task.get("status") != "failed":
raise HTTPException(
status_code=400,
detail="只有失败的任务才能重试",
)
task_name = task.get("name", "")
endpoint = _RETRY_ROUTE_MAP.get(task_name)
if endpoint is None:
raise HTTPException(
status_code=400,
detail=f"未知的任务类型: {task_name}",
)
# Build the internal URL to the processing endpoint.
base_url = f"http://{request.url.hostname}:{request.url.port}"
url = f"{base_url}{endpoint}"
# Forward the Authorization header so the processing endpoint can
# authenticate the request.
auth_header = request.headers.get("authorization")
headers: dict[str, str] = {}
if auth_header:
headers["authorization"] = auth_header
async with httpx.AsyncClient() as client:
resp = await client.post(url, headers=headers)
return resp.json()