"""WebSocket endpoint for real-time task progress."""

import logging

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from ..auth.jwt_handler import decode_token
from jose import JWTError

logger = logging.getLogger(__name__)

router = APIRouter(tags=["websocket"])


@router.websocket("/ws/task/{task_id}")
async def task_websocket(
    websocket: WebSocket,
    task_id: str,
    token: str = Query(...),
):
    """Stream progress updates for one task over a WebSocket.

    The connection is authenticated with the ``token`` query parameter. Once
    accepted, the socket is registered with the application's task manager
    (``websocket.app.state.task_manager``), the task's current state is sent
    as JSON, and the handler then answers ``"ping"`` text frames with
    ``"pong"`` until the connection ends.

    Args:
        websocket: The incoming WebSocket connection.
        task_id: Identifier of the task to follow, taken from the URL path.
        token: Token passed as a required query parameter; it is decoded with
            ``decode_token`` and must yield a non-empty ``"sub"`` value.
    """
    # Authenticate before accepting the connection.
    # NOTE(review): closing before ``accept()`` rejects the handshake, so the
    # client may not actually observe the 4001 code/reason -- confirm against
    # the Starlette version in use if clients rely on it.
    try:
        payload = decode_token(token)
        username = payload.get("sub")
        if not username:
            await websocket.close(code=4001, reason="Invalid token")
            return
    except (JWTError, Exception):
        # JWTError is the expected failure; the broad catch is deliberate so
        # that any other decoding problem also rejects the connection instead
        # of crashing the handshake.
        await websocket.close(code=4001, reason="Invalid token")
        return

    await websocket.accept()

    tm = websocket.app.state.task_manager
    task = tm.get_task(task_id)
    if not task:
        # Runtime message is user-facing ("task does not exist"); keep as is.
        await websocket.send_json({"error": "任务不存在"})
        await websocket.close()
        return

    tm.subscribe(task_id, websocket)
    try:
        # Send the current snapshot first so the client has a baseline.
        await websocket.send_json(task.to_dict())

        # Keep the connection open; the only inbound message handled here is
        # the "ping" keep-alive.
        while True:
            data = await websocket.receive_text()
            if data == "ping":
                await websocket.send_text("pong")
    except WebSocketDisconnect:
        # Normal termination: the client went away.
        pass
    except Exception:
        # Best-effort endpoint: do not propagate, but leave a trace instead of
        # silently dropping the error.
        logger.exception(
            "Unexpected error on task websocket (task_id=%s)", task_id
        )
    finally:
        # Always drop the subscription -- including when the initial send
        # fails or the handler is cancelled -- so the task manager never keeps
        # a reference to a dead socket.
        tm.unsubscribe(task_id, websocket)