160 lines
4.7 KiB
Python
160 lines
4.7 KiB
Python
"""Tasks API router: history query, stats, detail, and retry."""
|
|
|
|
import asyncio
|
|
from typing import Optional
|
|
|
|
import httpx
|
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
|
|
|
from ..auth.dependencies import get_current_user
|
|
from ..services import db_schema
|
|
|
|
# Router for the task-history endpoints; every route below is mounted under
# the ``/api/tasks`` prefix and grouped under the "tasks" tag.
router = APIRouter(tags=["tasks"], prefix="/api/tasks")
|
|
|
|
# Mapping from task name to the processing endpoint that retries it.
|
|
_RETRY_ROUTE_MAP = {
|
|
"批量OCR识别": "/api/processing/ocr-batch",
|
|
"Excel标准化处理": "/api/processing/excel",
|
|
"合并采购单": "/api/processing/merge",
|
|
"一键全流程处理": "/api/processing/pipeline",
|
|
}
|
|
|
|
|
|
@router.get("/stats")
async def task_stats(
    current_user: dict = Depends(get_current_user),
):
    """Return aggregate task statistics.

    The blocking ``db_schema.query_task_stats`` call runs in the default
    thread-pool executor so it does not stall the event loop.

    Args:
        current_user: Authenticated user resolved by ``get_current_user``;
            only used to enforce authentication.

    Returns:
        ``{"total", "completed", "failed", "running"}`` counts, each
        defaulting to 0 when the stats result does not contain it.
    """
    # get_running_loop() is the documented choice inside a coroutine.
    loop = asyncio.get_running_loop()
    # Tolerate an empty/None result so the response shape is always complete.
    stats = await loop.run_in_executor(None, db_schema.query_task_stats) or {}

    # Always emit every expected key, in a fixed order.
    return {
        key: stats.get(key, 0)
        for key in ("total", "completed", "failed", "running")
    }
|
|
|
|
|
|
def _count_task_history(status: Optional[str], name: Optional[str]) -> int:
    """Count ``task_history`` rows matching ``status`` (exact) and ``name`` (substring).

    NOTE(review): this opens the SQLite file directly and re-implements the
    filters, so it must stay in sync with ``db_schema.query_task_history``
    (assumed: exact ``status`` match and ``name LIKE %name%``) — TODO move
    the count into ``db_schema`` so both queries share one definition.
    """
    import sqlite3
    from contextlib import closing
    from pathlib import Path

    # ``data/web_data.db`` located four ``parent`` hops above this module file.
    db_path = Path(__file__).resolve().parents[3] / "data" / "web_data.db"

    clauses: list[str] = []
    params: list[str] = []
    if status:
        clauses.append("status = ?")
        params.append(status)
    if name:
        clauses.append("name LIKE ?")
        params.append(f"%{name}%")
    # The WHERE text is assembled only from fixed clause strings; all user
    # input is bound through ``?`` placeholders.
    where = (" WHERE " + " AND ".join(clauses)) if clauses else ""

    with closing(sqlite3.connect(db_path)) as conn:
        row = conn.execute(
            f"SELECT COUNT(*) AS cnt FROM task_history{where}",
            params,
        ).fetchone()
    return row[0] if row else 0


@router.get("")
async def list_tasks(
    page: int = 1,
    page_size: int = 50,
    status: Optional[str] = None,
    name: Optional[str] = None,
    search: Optional[str] = None,
    current_user: dict = Depends(get_current_user),
):
    """List tasks with optional filters and pagination.

    Args:
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Rows per page, clamped to the range 1..200.
        status: Filter on the task status.
        name: Text filter on the task name.
        search: General text filter (matches name). When both ``search`` and
            ``name`` are given, ``search`` takes precedence.
        current_user: Authenticated user resolved by ``get_current_user``;
            only used to enforce authentication.

    Returns:
        ``{"items": [...], "total": int}`` where ``total`` counts every row
        matching the filters, independent of pagination.
    """
    # Clamp to 1..200. Without the lower bound a non-positive value reaches
    # the query's limit; presumably that becomes a SQL LIMIT, where 0 returns
    # nothing and a negative value disables the limit in SQLite.
    page_size = max(1, min(page_size, 200))
    page = max(page, 1)
    offset = (page - 1) * page_size

    # ``search`` maps to the ``name`` filter in the DB layer.
    effective_name = search or name

    loop = asyncio.get_running_loop()
    items = await loop.run_in_executor(
        None,
        lambda: db_schema.query_task_history(
            status=status,
            name=effective_name,
            limit=page_size,
            offset=offset,
        ),
    )
    # Total for pagination, using the same filters but no limit/offset.
    total = await loop.run_in_executor(
        None, _count_task_history, status, effective_name
    )

    return {"items": items, "total": total}
|
|
|
|
|
|
@router.get("/{task_id}")
async def get_task(
    task_id: str,
    current_user: dict = Depends(get_current_user),
):
    """Get full task detail including log_lines and result_files.

    Args:
        task_id: Identifier passed to ``db_schema.query_task_by_id``.
        current_user: Authenticated user resolved by ``get_current_user``;
            only used to enforce authentication.

    Returns:
        The task record returned by ``db_schema.query_task_by_id``.

    Raises:
        HTTPException: 404 when no task with ``task_id`` exists.
    """
    # get_running_loop() is the documented choice inside a coroutine; the
    # blocking DB lookup runs in the default executor.
    loop = asyncio.get_running_loop()
    task = await loop.run_in_executor(None, db_schema.query_task_by_id, task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="任务不存在")
    return task
|
|
|
|
|
|
@router.post("/{task_id}/retry")
async def retry_task(
    task_id: str,
    request: Request,
    current_user: dict = Depends(get_current_user),
):
    """Retry a failed task by re-invoking its processing endpoint.

    Only tasks with status ``failed`` may be retried. The endpoint is looked
    up in ``_RETRY_ROUTE_MAP`` by the task's ``name``. It is called over HTTP
    with an empty POST body, and the caller's ``Authorization`` header is
    forwarded.

    Returns:
        The processing endpoint's JSON response.

    Raises:
        HTTPException: 404 if the task does not exist.
        HTTPException: 400 if the task is not failed, or if its name has no
            retry route.
        HTTPException: the upstream status code (with its ``detail`` or body
            text) if the processing endpoint returns an error status.
        HTTPException: 502 if the endpoint cannot be reached or returns a
            non-JSON body.
        HTTPException: 504 if the call times out.
    """
    loop = asyncio.get_running_loop()
    task = await loop.run_in_executor(None, db_schema.query_task_by_id, task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="任务不存在")
    if task.get("status") != "failed":
        raise HTTPException(
            status_code=400,
            detail="只有失败的任务才能重试",
        )

    task_name = task.get("name", "")
    endpoint = _RETRY_ROUTE_MAP.get(task_name)
    if endpoint is None:
        raise HTTPException(
            status_code=400,
            detail=f"未知的任务类型: {task_name}",
        )

    # ``request.base_url`` keeps the incoming scheme and omits the port when
    # it is the default one, so no "host:None" URL is produced.
    # NOTE(review): base_url is derived from the client's Host header, so the
    # target host is client-influenced — consider a configured internal URL.
    url = str(request.base_url).rstrip("/") + endpoint

    # Forward the Authorization header so the processing endpoint can
    # authenticate the request.
    headers: dict[str, str] = {}
    auth_header = request.headers.get("authorization")
    if auth_header:
        headers["authorization"] = auth_header

    try:
        async with httpx.AsyncClient() as client:
            resp = await client.post(url, headers=headers)
    except httpx.TimeoutException as exc:
        # TimeoutException is a RequestError subclass, so it must come first.
        raise HTTPException(status_code=504, detail="重试请求超时") from exc
    except httpx.RequestError as exc:
        raise HTTPException(status_code=502, detail="无法连接处理服务") from exc

    if resp.status_code >= 400:
        # Surface the upstream failure instead of wrapping it in a 200.
        try:
            body = resp.json()
        except ValueError:
            body = None
        detail = body.get("detail") if isinstance(body, dict) else None
        raise HTTPException(
            status_code=resp.status_code,
            detail=detail or resp.text or "重试请求失败",
        )

    try:
        return resp.json()
    except ValueError as exc:
        raise HTTPException(
            status_code=502,
            detail="处理服务返回了无效的响应",
        ) from exc
|