feat: complete web application — FastAPI backend + Vue 3 SPA frontend
- Full FastAPI backend with JWT auth, file management, processing pipeline, memory CRUD, barcode mappings, config management, cloud sync - Vue 3 + Element Plus frontend with dashboard, task history, HTTP logs, memory editor, barcode editor, config editor, sync page - HTTP request logging middleware with SQLite persistence - Task history tracking with progress and retry support - File metadata recording for upload/download operations - WebAuth section in config.ini for bcrypt password storage - Bug fix: logs.py count query returns tuple not dict Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,58 @@
|
||||
"""FastAPI auth dependencies"""
|
||||
|
||||
from fastapi import Depends, HTTPException, status, Query, Request
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
|
||||
from .jwt_handler import decode_token
|
||||
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
) -> dict:
|
||||
try:
|
||||
payload = decode_token(credentials.credentials)
|
||||
username = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
return {"username": username}
|
||||
except Exception:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的认证凭据")
|
||||
|
||||
|
||||
async def get_current_user_ws(token: str = Query(...)) -> dict:
|
||||
"""WebSocket auth via query parameter"""
|
||||
try:
|
||||
payload = decode_token(token)
|
||||
username = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
return {"username": username}
|
||||
except Exception:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的认证凭据")
|
||||
|
||||
|
||||
async def get_current_user_flexible(
|
||||
request: Request,
|
||||
credentials: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
|
||||
token: str = Query(None),
|
||||
) -> dict:
|
||||
"""Auth from header OR query param (for file downloads in browser)."""
|
||||
token_str = None
|
||||
if credentials:
|
||||
token_str = credentials.credentials
|
||||
elif token:
|
||||
token_str = token
|
||||
|
||||
if not token_str:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="未提供认证凭据")
|
||||
|
||||
try:
|
||||
payload = decode_token(token_str)
|
||||
username = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
return {"username": username}
|
||||
except Exception:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的认证凭据")
|
||||
@@ -0,0 +1,19 @@
|
||||
"""JWT token creation and validation"""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
|
||||
from jose import jwt, JWTError
|
||||
|
||||
from ..config import get_or_generate_secret, JWT_ALGORITHM, JWT_EXPIRE_HOURS
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
to_encode = data.copy()
|
||||
expire = datetime.now(timezone.utc) + (expires_delta or timedelta(hours=JWT_EXPIRE_HOURS))
|
||||
to_encode.update({"exp": expire})
|
||||
return jwt.encode(to_encode, get_or_generate_secret(), algorithm=JWT_ALGORITHM)
|
||||
|
||||
|
||||
def decode_token(token: str) -> dict:
|
||||
return jwt.decode(token, get_or_generate_secret(), algorithms=[JWT_ALGORITHM])
|
||||
@@ -0,0 +1,89 @@
|
||||
"""Auth API endpoints"""
|
||||
|
||||
import os
|
||||
import bcrypt
|
||||
from fastapi import APIRouter, HTTPException, Depends, status
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .jwt_handler import create_access_token
|
||||
from .dependencies import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
|
||||
# Default credentials (should be changed on first login)
|
||||
DEFAULT_USERNAME = "admin"
|
||||
DEFAULT_PASSWORD = "admin123"
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
access_token: str
|
||||
token_type: str = "bearer"
|
||||
|
||||
|
||||
def _get_credentials() -> tuple[str, bytes]:
|
||||
"""Get username and password hash from config or defaults"""
|
||||
try:
|
||||
from app.config.settings import ConfigManager
|
||||
cfg = ConfigManager()
|
||||
username = cfg.get('WebAuth', 'username', fallback=DEFAULT_USERNAME)
|
||||
pw_hash = cfg.get('WebAuth', 'password_hash', fallback='')
|
||||
if not pw_hash:
|
||||
# First run: store default password hash
|
||||
pw_hash = bcrypt.hashpw(DEFAULT_PASSWORD.encode(), bcrypt.gensalt()).decode()
|
||||
try:
|
||||
cfg.update('WebAuth', 'username', DEFAULT_USERNAME)
|
||||
cfg.update('WebAuth', 'password_hash', pw_hash)
|
||||
cfg.save_config()
|
||||
except Exception:
|
||||
pass
|
||||
return username, pw_hash.encode()
|
||||
except Exception:
|
||||
return DEFAULT_USERNAME, bcrypt.hashpw(DEFAULT_PASSWORD.encode(), bcrypt.gensalt())
|
||||
|
||||
|
||||
@router.post("/login", response_model=LoginResponse)
|
||||
async def login(req: LoginRequest):
|
||||
stored_username, stored_hash = _get_credentials()
|
||||
|
||||
if req.username != stored_username:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户名或密码错误")
|
||||
|
||||
if not bcrypt.checkpw(req.password.encode(), stored_hash):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户名或密码错误")
|
||||
|
||||
token = create_access_token({"sub": req.username})
|
||||
return LoginResponse(access_token=token)
|
||||
|
||||
|
||||
@router.get("/me")
|
||||
async def me(current_user: dict = Depends(get_current_user)):
|
||||
return current_user
|
||||
|
||||
|
||||
class ChangePasswordRequest(BaseModel):
|
||||
old_password: str
|
||||
new_password: str
|
||||
|
||||
|
||||
@router.post("/change-password")
|
||||
async def change_password(req: ChangePasswordRequest, current_user: dict = Depends(get_current_user)):
|
||||
_, stored_hash = _get_credentials()
|
||||
|
||||
if not bcrypt.checkpw(req.old_password.encode(), stored_hash):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="旧密码错误")
|
||||
|
||||
new_hash = bcrypt.hashpw(req.new_password.encode(), bcrypt.gensalt()).decode()
|
||||
try:
|
||||
from app.config.settings import ConfigManager
|
||||
cfg = ConfigManager()
|
||||
cfg.update('WebAuth', 'password_hash', new_hash)
|
||||
cfg.save_config()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"保存密码失败: {e}")
|
||||
|
||||
return {"message": "密码修改成功"}
|
||||
@@ -0,0 +1,40 @@
|
||||
"""Web-specific configuration"""
|
||||
|
||||
import os
|
||||
import secrets
|
||||
|
||||
# JWT
|
||||
JWT_SECRET_KEY = os.getenv("JWT_SECRET_KEY", "")
|
||||
JWT_ALGORITHM = "HS256"
|
||||
JWT_EXPIRE_HOURS = 24
|
||||
|
||||
# File upload
|
||||
MAX_UPLOAD_SIZE = 50 * 1024 * 1024 # 50MB
|
||||
ALLOWED_IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp'}
|
||||
ALLOWED_EXCEL_EXTENSIONS = {'.xlsx', '.xls'}
|
||||
ALLOWED_EXTENSIONS = ALLOWED_IMAGE_EXTENSIONS | ALLOWED_EXCEL_EXTENSIONS
|
||||
|
||||
# CORS
|
||||
CORS_ORIGINS = os.getenv("CORS_ORIGINS", "*").split(",")
|
||||
|
||||
# Auth rate limit
|
||||
LOGIN_RATE_LIMIT = 5 # per minute
|
||||
|
||||
|
||||
def get_or_generate_secret() -> str:
|
||||
"""Get JWT secret from env or auto-generate on first run"""
|
||||
global JWT_SECRET_KEY
|
||||
if not JWT_SECRET_KEY:
|
||||
secret_file = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
|
||||
'data', '.jwt_secret'
|
||||
)
|
||||
if os.path.exists(secret_file):
|
||||
with open(secret_file, 'r') as f:
|
||||
JWT_SECRET_KEY = f.read().strip()
|
||||
if not JWT_SECRET_KEY:
|
||||
JWT_SECRET_KEY = secrets.token_urlsafe(48)
|
||||
os.makedirs(os.path.dirname(secret_file), exist_ok=True)
|
||||
with open(secret_file, 'w') as f:
|
||||
f.write(JWT_SECRET_KEY)
|
||||
return JWT_SECRET_KEY
|
||||
@@ -0,0 +1,109 @@
|
||||
"""FastAPI application entry point for the web-based OCR order processing system."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
# Ensure app/ is importable
|
||||
_web_dir = Path(__file__).resolve().parent.parent # web/
|
||||
_project_root = _web_dir.parent # project root
|
||||
if str(_project_root) not in sys.path:
|
||||
sys.path.insert(0, str(_project_root))
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from .config import get_or_generate_secret # noqa: trigger secret generation
|
||||
from .services.task_manager import TaskManager
|
||||
from .services.db_pool import DBPool
|
||||
from .auth.router import router as auth_router
|
||||
from .routers.files import router as files_router
|
||||
from .routers.processing import router as processing_router
|
||||
from .routers.memory import router as memory_router
|
||||
from .routers.config_api import router as config_router
|
||||
from .routers.barcodes import router as barcodes_router
|
||||
from .routers.sync import router as sync_router
|
||||
from .routers.websocket import router as ws_router
|
||||
from .routers.logs import router as logs_router
|
||||
from .routers.tasks import router as tasks_router
|
||||
from .middleware.logging import LoggingMiddleware
|
||||
|
||||
# Shared singletons
|
||||
task_manager = TaskManager()
|
||||
db_pool = DBPool()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Initialize shared resources on startup."""
|
||||
from app.config.settings import ConfigManager
|
||||
ConfigManager()
|
||||
|
||||
# Initialize DB and cleanup old records
|
||||
from .services.db_schema import init_db, cleanup_old_records
|
||||
init_db()
|
||||
cleanup_old_records()
|
||||
|
||||
# Wire up DB pool to task manager
|
||||
task_manager.set_db_pool(db_pool)
|
||||
|
||||
app.state.task_manager = task_manager
|
||||
app.state.db_pool = db_pool
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="益选 OCR 订单处理系统",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# CORS
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["http://localhost:5173", "http://127.0.0.1:5173", "http://localhost:8000"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# HTTP logging middleware (after CORS, before routes)
|
||||
app.add_middleware(LoggingMiddleware)
|
||||
|
||||
# Make task_manager and db_pool accessible via request.state
|
||||
@app.middleware("http")
|
||||
async def inject_services(request, call_next):
|
||||
request.state.task_manager = task_manager
|
||||
request.state.db_pool = db_pool
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
# Mount routers
|
||||
app.include_router(auth_router)
|
||||
app.include_router(files_router)
|
||||
app.include_router(processing_router)
|
||||
app.include_router(memory_router)
|
||||
app.include_router(config_router)
|
||||
app.include_router(barcodes_router)
|
||||
app.include_router(sync_router)
|
||||
app.include_router(ws_router)
|
||||
app.include_router(logs_router)
|
||||
app.include_router(tasks_router)
|
||||
|
||||
|
||||
# Serve Vue SPA static files
|
||||
_static_dir = Path(__file__).resolve().parent / "static"
|
||||
if _static_dir.is_dir():
|
||||
app.mount("/assets", StaticFiles(directory=str(_static_dir / "assets")), name="assets")
|
||||
|
||||
@app.get("/{full_path:path}")
|
||||
async def serve_spa(full_path: str):
|
||||
"""Catch-all: serve index.html for Vue Router history mode."""
|
||||
file_path = _static_dir / full_path
|
||||
if file_path.is_file():
|
||||
return FileResponse(str(file_path))
|
||||
return FileResponse(str(_static_dir / "index.html"))
|
||||
@@ -0,0 +1,124 @@
|
||||
"""Barcode mapping CRUD endpoints."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, List
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..auth.dependencies import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/barcodes", tags=["barcodes"])
|
||||
|
||||
_project_root = Path(__file__).resolve().parent.parent.parent.parent
|
||||
_mappings_file = _project_root / "config" / "barcode_mappings.json"
|
||||
|
||||
|
||||
class BarcodeMapping(BaseModel):
|
||||
barcode: str
|
||||
target: str
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class BarcodeMappingUpdate(BaseModel):
|
||||
target: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
def _load_mappings() -> Dict:
|
||||
if not _mappings_file.is_file():
|
||||
return {}
|
||||
try:
|
||||
return json.loads(_mappings_file.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def _save_mappings(data: Dict):
|
||||
_mappings_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
_mappings_file.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_barcodes(
|
||||
search: str = "",
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
mappings = _load_mappings()
|
||||
items = []
|
||||
for barcode, info in mappings.items():
|
||||
if isinstance(info, dict):
|
||||
target = info.get("map_to", info.get("target", ""))
|
||||
desc = info.get("description", "")
|
||||
else:
|
||||
target = str(info)
|
||||
desc = ""
|
||||
if search and search not in barcode and search not in target and search not in desc:
|
||||
continue
|
||||
items.append({"barcode": barcode, "target": target, "description": desc})
|
||||
return {"items": items, "total": len(items)}
|
||||
|
||||
|
||||
@router.get("/{barcode}")
|
||||
async def get_barcode(
|
||||
barcode: str,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
mappings = _load_mappings()
|
||||
if barcode not in mappings:
|
||||
raise HTTPException(404, f"未找到条码映射 {barcode}")
|
||||
info = mappings[barcode]
|
||||
if isinstance(info, dict):
|
||||
return {"barcode": barcode, "target": info.get("map_to", info.get("target", "")), "description": info.get("description", "")}
|
||||
return {"barcode": barcode, "target": str(info), "description": ""}
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_barcode(
|
||||
body: BarcodeMapping,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
mappings = _load_mappings()
|
||||
if body.barcode in mappings:
|
||||
raise HTTPException(409, f"条码 {body.barcode} 已存在")
|
||||
mappings[body.barcode] = {"map_to": body.target, "description": body.description or ""}
|
||||
_save_mappings(mappings)
|
||||
return {"message": f"已创建映射 {body.barcode} → {body.target}"}
|
||||
|
||||
|
||||
@router.put("/{barcode}")
|
||||
async def update_barcode(
|
||||
barcode: str,
|
||||
body: BarcodeMappingUpdate,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
mappings = _load_mappings()
|
||||
if barcode not in mappings:
|
||||
raise HTTPException(404, f"未找到条码映射 {barcode}")
|
||||
|
||||
existing = mappings[barcode]
|
||||
if not isinstance(existing, dict):
|
||||
existing = {"map_to": str(existing), "description": ""}
|
||||
|
||||
if body.target is not None:
|
||||
existing["map_to"] = body.target
|
||||
if body.description is not None:
|
||||
existing["description"] = body.description
|
||||
|
||||
mappings[barcode] = existing
|
||||
_save_mappings(mappings)
|
||||
return {"message": f"已更新映射 {barcode}"}
|
||||
|
||||
|
||||
@router.delete("/{barcode}")
|
||||
async def delete_barcode(
|
||||
barcode: str,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
mappings = _load_mappings()
|
||||
if barcode not in mappings:
|
||||
raise HTTPException(404, f"未找到条码映射 {barcode}")
|
||||
del mappings[barcode]
|
||||
_save_mappings(mappings)
|
||||
return {"message": f"已删除映射 {barcode}"}
|
||||
@@ -0,0 +1,98 @@
|
||||
"""Configuration read/write endpoints."""
|
||||
|
||||
from typing import Dict, Optional, Any
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..auth.dependencies import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/config", tags=["config"])
|
||||
|
||||
# Keys that should be masked in GET responses
|
||||
_SENSITIVE_KEYS = {"api_key", "secret_key", "token", "password", "api_secret", "access_key"}
|
||||
|
||||
# Sections to expose (match actual config.ini)
|
||||
_ALLOWED_SECTIONS = {"API", "Paths", "Performance", "File", "Templates", "Gitea", "WebAuth"}
|
||||
|
||||
|
||||
class ConfigUpdate(BaseModel):
|
||||
section: str
|
||||
key: str
|
||||
value: str
|
||||
|
||||
|
||||
class ConfigBulkUpdate(BaseModel):
|
||||
updates: list[ConfigUpdate]
|
||||
|
||||
|
||||
def _get_config():
|
||||
from app.config.settings import ConfigManager
|
||||
return ConfigManager()
|
||||
|
||||
|
||||
def _mask_value(key: str, value: str) -> str:
|
||||
if any(s in key.lower() for s in _SENSITIVE_KEYS):
|
||||
if len(value) > 4:
|
||||
return value[:2] + "*" * (len(value) - 4) + value[-2:]
|
||||
return "****"
|
||||
return value
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_config(
|
||||
section: Optional[str] = None,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
cfg = _get_config()
|
||||
if section:
|
||||
if section not in _ALLOWED_SECTIONS and section != "DEFAULT":
|
||||
raise HTTPException(403, f"不允许访问配置节: {section}")
|
||||
items = {}
|
||||
for key, value in cfg.config.items(section):
|
||||
items[key] = _mask_value(key, value)
|
||||
return {"section": section, "items": items}
|
||||
|
||||
result = {}
|
||||
for sec in _ALLOWED_SECTIONS:
|
||||
try:
|
||||
items = {}
|
||||
for key, value in cfg.config.items(sec):
|
||||
items[key] = _mask_value(key, value)
|
||||
result[sec] = items
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
|
||||
|
||||
@router.put("")
|
||||
async def update_config(
|
||||
body: ConfigUpdate,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
if body.section not in _ALLOWED_SECTIONS:
|
||||
raise HTTPException(403, f"不允许修改配置节: {body.section}")
|
||||
|
||||
cfg = _get_config()
|
||||
try:
|
||||
cfg.update(body.section, body.key, body.value)
|
||||
cfg.save_config()
|
||||
return {"message": f"已更新 [{body.section}] {body.key}"}
|
||||
except Exception as e:
|
||||
raise HTTPException(500, f"保存失败: {e}")
|
||||
|
||||
|
||||
@router.put("/bulk")
|
||||
async def bulk_update_config(
|
||||
body: ConfigBulkUpdate,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
cfg = _get_config()
|
||||
updated = []
|
||||
for item in body.updates:
|
||||
if item.section not in _ALLOWED_SECTIONS:
|
||||
continue
|
||||
cfg.update(item.section, item.key, item.value)
|
||||
updated.append(f"[{item.section}] {item.key}")
|
||||
|
||||
cfg.save_config()
|
||||
return {"message": f"已更新 {len(updated)} 项", "updated": updated}
|
||||
@@ -49,7 +49,7 @@ def _count_http_logs(
|
||||
row = conn.execute(
|
||||
f"SELECT COUNT(*) as cnt FROM http_logs{where}", params
|
||||
).fetchone()
|
||||
return row["cnt"] if row else 0
|
||||
return row[0] if row else 0
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
@@ -0,0 +1,165 @@
|
||||
"""Product memory CRUD endpoints."""
|
||||
|
||||
from typing import Optional, List, Dict
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..auth.dependencies import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/memory", tags=["memory"])
|
||||
|
||||
_project_root = Path(__file__).resolve().parent.parent.parent.parent
|
||||
_db_path = str(_project_root / "data" / "product_cache.db")
|
||||
_excel_source = str(_project_root / "templates" / "商品资料.xlsx")
|
||||
|
||||
|
||||
class MemoryItem(BaseModel):
|
||||
barcode: str
|
||||
name: str
|
||||
spec: Optional[str] = None
|
||||
unit: Optional[str] = None
|
||||
price: Optional[float] = None
|
||||
confidence: int = 0
|
||||
source: str = "ocr"
|
||||
last_used: Optional[str] = None
|
||||
use_count: int = 0
|
||||
|
||||
|
||||
class MemoryUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
spec: Optional[str] = None
|
||||
unit: Optional[str] = None
|
||||
price: Optional[float] = None
|
||||
confidence: Optional[int] = None
|
||||
|
||||
|
||||
class MemoryListResponse(BaseModel):
|
||||
items: List[MemoryItem]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
|
||||
|
||||
def _get_db():
|
||||
from app.core.db.product_db import ProductDatabase
|
||||
return ProductDatabase(_db_path, _excel_source)
|
||||
|
||||
|
||||
def _row_to_item(row: Dict) -> MemoryItem:
|
||||
return MemoryItem(
|
||||
barcode=row.get("barcode", ""),
|
||||
name=row.get("name", ""),
|
||||
spec=row.get("spec"),
|
||||
unit=row.get("unit"),
|
||||
price=row.get("price"),
|
||||
confidence=row.get("confidence", 0),
|
||||
source=row.get("source", "ocr"),
|
||||
last_used=row.get("last_used"),
|
||||
use_count=row.get("use_count", 0),
|
||||
)
|
||||
|
||||
|
||||
@router.get("", response_model=MemoryListResponse)
|
||||
async def list_memory(
|
||||
search: str = "",
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=200),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
db = _get_db()
|
||||
results = db.get_all_memories()
|
||||
|
||||
if search:
|
||||
s = search.lower()
|
||||
results = [r for r in results if s in r.get("barcode", "").lower() or s in r.get("name", "").lower()]
|
||||
|
||||
total = len(results)
|
||||
start = (page - 1) * page_size
|
||||
page_items = results[start:start + page_size]
|
||||
|
||||
return MemoryListResponse(
|
||||
items=[_row_to_item(r) for r in page_items],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{barcode}")
|
||||
async def get_memory(
|
||||
barcode: str,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
db = _get_db()
|
||||
product = db.get_memory(barcode)
|
||||
if not product:
|
||||
raise HTTPException(404, f"未找到条码 {barcode} 的记忆记录")
|
||||
return product
|
||||
|
||||
|
||||
@router.put("/{barcode}")
|
||||
async def update_memory(
|
||||
barcode: str,
|
||||
body: MemoryUpdate,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
db = _get_db()
|
||||
existing = db.get_memory(barcode)
|
||||
if not existing:
|
||||
raise HTTPException(404, f"未找到条码 {barcode}")
|
||||
|
||||
update_data = body.model_dump(exclude_none=True)
|
||||
if not update_data:
|
||||
raise HTTPException(400, "没有提供更新数据")
|
||||
|
||||
db.update_memory(barcode, update_data)
|
||||
return {"message": f"已更新 {barcode}", "updated_fields": list(update_data.keys())}
|
||||
|
||||
|
||||
@router.delete("/{barcode}")
|
||||
async def delete_memory(
|
||||
barcode: str,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
db = _get_db()
|
||||
existing = db.get_memory(barcode)
|
||||
if not existing:
|
||||
raise HTTPException(404, f"未找到条码 {barcode}")
|
||||
db.delete_memory(barcode)
|
||||
return {"message": f"已删除 {barcode}"}
|
||||
|
||||
|
||||
@router.post("/reimport")
|
||||
async def reimport_memory(
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
db = _get_db()
|
||||
try:
|
||||
count = db.reimport()
|
||||
return {"message": f"重新导入完成,共导入 {count} 条记录", "count": count}
|
||||
except Exception as e:
|
||||
raise HTTPException(500, f"导入失败: {e}")
|
||||
|
||||
|
||||
@router.get("/export/sync")
|
||||
async def export_memory(
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
db = _get_db()
|
||||
data = db.export_for_sync()
|
||||
return {"data": data, "count": len(data)}
|
||||
|
||||
|
||||
@router.post("/import/sync")
|
||||
async def import_memory(
|
||||
data: dict,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
db = _get_db()
|
||||
try:
|
||||
count = db.import_from_sync(data.get("data", []))
|
||||
return {"message": f"导入完成,共 {count} 条", "count": count}
|
||||
except Exception as e:
|
||||
raise HTTPException(500, f"导入失败: {e}")
|
||||
@@ -0,0 +1,250 @@
|
||||
"""Processing endpoints: OCR, Excel conversion, merge, and full pipeline."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..auth.dependencies import get_current_user
|
||||
from ..services.service_wrapper import ServiceWrapper
|
||||
|
||||
router = APIRouter(prefix="/api/processing", tags=["processing"])
|
||||
|
||||
_wrapper = ServiceWrapper(max_workers=3)
|
||||
|
||||
_project_root = Path(__file__).resolve().parent.parent.parent.parent
|
||||
_input_dir = _project_root / "data" / "input"
|
||||
_output_dir = _project_root / "data" / "output"
|
||||
_result_dir = _project_root / "data" / "result"
|
||||
|
||||
|
||||
class PipelineRequest(BaseModel):
|
||||
files: Optional[List[str]] = None # specific files, or None = all in input/
|
||||
supplier: Optional[str] = None # force supplier type
|
||||
|
||||
|
||||
class TaskResponse(BaseModel):
|
||||
task_id: str
|
||||
status: str
|
||||
message: str
|
||||
|
||||
|
||||
def _get_task_manager(request: Request):
|
||||
return request.state.task_manager
|
||||
|
||||
|
||||
def _list_input_files(filter_ext: Optional[List[str]] = None) -> List[Path]:
|
||||
if not _input_dir.is_dir():
|
||||
return []
|
||||
files = []
|
||||
for f in sorted(_input_dir.iterdir()):
|
||||
if f.is_file():
|
||||
if filter_ext is None or f.suffix.lower() in filter_ext:
|
||||
files.append(f)
|
||||
return files
|
||||
|
||||
|
||||
@router.post("/ocr-batch", response_model=TaskResponse)
|
||||
async def ocr_batch(
|
||||
request: Request,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
"""Run OCR on all images in input/."""
|
||||
tm = _get_task_manager(request)
|
||||
task = tm.create_task("批量OCR识别")
|
||||
|
||||
image_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif'}
|
||||
files = _list_input_files(filter_ext=list(image_exts))
|
||||
if not files:
|
||||
raise HTTPException(400, "input/ 目录中没有图片文件")
|
||||
|
||||
async def _run():
|
||||
try:
|
||||
from app.services.ocr_service import OCRService
|
||||
svc = OCRService()
|
||||
total = len(files)
|
||||
for i, f in enumerate(files):
|
||||
tm.update_progress(task.id, int((i / total) * 100), f"正在识别: {f.name}")
|
||||
tm.add_log(task.id, f"[OCR] 处理 {f.name}")
|
||||
try:
|
||||
svc.process_single(str(f), str(_output_dir))
|
||||
tm.add_log(task.id, f"[OCR] 完成: {f.name}")
|
||||
except Exception as e:
|
||||
tm.add_log(task.id, f"[OCR] 失败: {f.name} - {e}")
|
||||
result_files = [f.name for f in _output_dir.iterdir() if f.is_file()]
|
||||
tm.set_completed(task.id, result_files=result_files, message=f"OCR完成,共处理 {total} 个文件")
|
||||
except Exception as e:
|
||||
tm.set_failed(task.id, str(e))
|
||||
|
||||
import asyncio
|
||||
asyncio.create_task(_run())
|
||||
|
||||
return TaskResponse(task_id=task.id, status="accepted", message="OCR任务已创建")
|
||||
|
||||
|
||||
@router.post("/excel", response_model=TaskResponse)
|
||||
async def process_excel(
|
||||
request: Request,
|
||||
body: PipelineRequest = PipelineRequest(),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
"""Convert OCR output Excel files to standardized format."""
|
||||
tm = _get_task_manager(request)
|
||||
task = tm.create_task("Excel标准化处理")
|
||||
|
||||
excel_exts = {'.xls', '.xlsx'}
|
||||
if body.files:
|
||||
files = [_output_dir / f for f in body.files if (_output_dir / f).is_file()]
|
||||
else:
|
||||
files = _list_input_files(filter_ext=list(excel_exts))
|
||||
if not files:
|
||||
files = _list_input_files_from(_output_dir, filter_ext=list(excel_exts))
|
||||
|
||||
if not files:
|
||||
raise HTTPException(400, "没有找到Excel文件")
|
||||
|
||||
async def _run():
|
||||
try:
|
||||
from app.services.order_service import OrderService
|
||||
svc = OrderService()
|
||||
total = len(files)
|
||||
for i, f in enumerate(files):
|
||||
tm.update_progress(task.id, int((i / total) * 100), f"正在处理: {f.name}")
|
||||
tm.add_log(task.id, f"[Excel] 处理 {f.name}")
|
||||
try:
|
||||
svc.process_excel(str(f), str(_result_dir))
|
||||
tm.add_log(task.id, f"[Excel] 完成: {f.name}")
|
||||
except Exception as e:
|
||||
tm.add_log(task.id, f"[Excel] 失败: {f.name} - {e}")
|
||||
result_files = [f.name for f in _result_dir.iterdir() if f.is_file()]
|
||||
tm.set_completed(task.id, result_files=result_files, message=f"Excel处理完成,共 {total} 个文件")
|
||||
except Exception as e:
|
||||
tm.set_failed(task.id, str(e))
|
||||
|
||||
import asyncio
|
||||
asyncio.create_task(_run())
|
||||
|
||||
return TaskResponse(task_id=task.id, status="accepted", message="Excel处理任务已创建")
|
||||
|
||||
|
||||
@router.post("/merge", response_model=TaskResponse)
|
||||
async def merge_orders(
|
||||
request: Request,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
"""Merge all processed Excel files into a single purchase order."""
|
||||
tm = _get_task_manager(request)
|
||||
task = tm.create_task("合并采购单")
|
||||
|
||||
async def _run():
|
||||
try:
|
||||
from app.services.order_service import OrderService
|
||||
svc = OrderService()
|
||||
tm.update_progress(task.id, 20, "正在合并采购单...")
|
||||
tm.add_log(task.id, "[合并] 开始合并")
|
||||
result = svc.merge_orders(str(_result_dir))
|
||||
tm.add_log(task.id, f"[合并] 完成: {result}")
|
||||
tm.set_completed(task.id, result_files=[result] if result else [], message="合并完成")
|
||||
except Exception as e:
|
||||
tm.set_failed(task.id, str(e))
|
||||
|
||||
import asyncio
|
||||
asyncio.create_task(_run())
|
||||
|
||||
return TaskResponse(task_id=task.id, status="accepted", message="合并任务已创建")
|
||||
|
||||
|
||||
@router.post("/pipeline", response_model=TaskResponse)
|
||||
async def full_pipeline(
|
||||
request: Request,
|
||||
body: PipelineRequest = PipelineRequest(),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
"""Run the full pipeline: OCR → Excel → Merge."""
|
||||
tm = _get_task_manager(request)
|
||||
task = tm.create_task("一键全流程处理")
|
||||
|
||||
async def _run():
|
||||
try:
|
||||
# Step 1: OCR
|
||||
tm.update_progress(task.id, 0, "步骤 1/3: OCR识别")
|
||||
tm.add_log(task.id, "[Pipeline] 开始OCR识别")
|
||||
from app.services.ocr_service import OCRService
|
||||
ocr_svc = OCRService()
|
||||
|
||||
image_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif'}
|
||||
images = _list_input_files(filter_ext=list(image_exts))
|
||||
for i, f in enumerate(images):
|
||||
pct = int((i / max(len(images), 1)) * 30)
|
||||
tm.update_progress(task.id, pct, f"OCR: {f.name}")
|
||||
try:
|
||||
ocr_svc.process_single(str(f), str(_output_dir))
|
||||
tm.add_log(task.id, f"[OCR] 完成: {f.name}")
|
||||
except Exception as e:
|
||||
tm.add_log(task.id, f"[OCR] 失败: {f.name} - {e}")
|
||||
|
||||
# Step 2: Excel conversion
|
||||
tm.update_progress(task.id, 35, "步骤 2/3: Excel标准化")
|
||||
tm.add_log(task.id, "[Pipeline] 开始Excel处理")
|
||||
from app.services.order_service import OrderService
|
||||
order_svc = OrderService()
|
||||
|
||||
excel_files = list(_output_dir.glob("*.xls")) + list(_output_dir.glob("*.xlsx"))
|
||||
for i, f in enumerate(excel_files):
|
||||
pct = 35 + int((i / max(len(excel_files), 1)) * 35)
|
||||
tm.update_progress(task.id, pct, f"Excel: {f.name}")
|
||||
try:
|
||||
order_svc.process_excel(str(f), str(_result_dir))
|
||||
tm.add_log(task.id, f"[Excel] 完成: {f.name}")
|
||||
except Exception as e:
|
||||
tm.add_log(task.id, f"[Excel] 失败: {f.name} - {e}")
|
||||
|
||||
# Step 3: Merge
|
||||
tm.update_progress(task.id, 75, "步骤 3/3: 合并采购单")
|
||||
tm.add_log(task.id, "[Pipeline] 开始合并")
|
||||
try:
|
||||
result = order_svc.merge_orders(str(_result_dir))
|
||||
tm.add_log(task.id, f"[合并] 完成: {result}")
|
||||
except Exception as e:
|
||||
tm.add_log(task.id, f"[合并] 失败: {e}")
|
||||
result = None
|
||||
|
||||
result_files = [f.name for f in _result_dir.iterdir() if f.is_file()]
|
||||
tm.set_completed(task.id, result_files=result_files, message="全流程处理完成")
|
||||
except Exception as e:
|
||||
tb = traceback.format_exc()
|
||||
tm.add_log(task.id, f"[错误] {tb}")
|
||||
tm.set_failed(task.id, str(e))
|
||||
|
||||
import asyncio
|
||||
asyncio.create_task(_run())
|
||||
|
||||
return TaskResponse(task_id=task.id, status="accepted", message="全流程任务已创建")
|
||||
|
||||
|
||||
@router.get("/status/{task_id}")
|
||||
async def get_task_status(
|
||||
task_id: str,
|
||||
request: Request,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
tm = _get_task_manager(request)
|
||||
task = tm.get_task(task_id)
|
||||
if not task:
|
||||
raise HTTPException(404, "任务不存在")
|
||||
return task.to_dict()
|
||||
|
||||
|
||||
def _list_input_files_from(directory: Path, filter_ext: List[str] = None) -> List[Path]:
|
||||
if not directory.is_dir():
|
||||
return []
|
||||
files = []
|
||||
for f in sorted(directory.iterdir()):
|
||||
if f.is_file():
|
||||
if filter_ext is None or f.suffix.lower() in filter_ext:
|
||||
files.append(f)
|
||||
return files
|
||||
@@ -0,0 +1,93 @@
|
||||
"""Cloud sync endpoints (Gitea-based)."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..auth.dependencies import get_current_user
|
||||
from ..services.task_manager import TaskManager
|
||||
|
||||
router = APIRouter(prefix="/api/sync", tags=["sync"])
|
||||
|
||||
_project_root = Path(__file__).resolve().parent.parent.parent.parent
|
||||
|
||||
|
||||
class SyncResponse(BaseModel):
|
||||
task_id: str
|
||||
status: str
|
||||
message: str
|
||||
|
||||
|
||||
def _get_sync():
|
||||
from app.core.utils.cloud_sync import GiteaSync
|
||||
from app.config.settings import ConfigManager
|
||||
cfg = ConfigManager()
|
||||
return GiteaSync(cfg)
|
||||
|
||||
|
||||
@router.post("/push", response_model=SyncResponse)
|
||||
async def sync_push(
|
||||
request: Request,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
tm = request.state.task_manager
|
||||
task = tm.create_task("推送到云端")
|
||||
|
||||
async def _run():
|
||||
try:
|
||||
tm.update_progress(task.id, 10, "正在初始化同步...")
|
||||
sync = _get_sync()
|
||||
tm.update_progress(task.id, 30, "正在推送文件...")
|
||||
tm.add_log(task.id, "[Push] 开始推送")
|
||||
result = sync.push()
|
||||
tm.add_log(task.id, f"[Push] 完成: {result}")
|
||||
tm.set_completed(task.id, message="推送完成")
|
||||
except Exception as e:
|
||||
tm.set_failed(task.id, str(e))
|
||||
|
||||
import asyncio
|
||||
asyncio.create_task(_run())
|
||||
return SyncResponse(task_id=task.id, status="accepted", message="推送任务已创建")
|
||||
|
||||
|
||||
@router.post("/pull", response_model=SyncResponse)
|
||||
async def sync_pull(
|
||||
request: Request,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
tm = request.state.task_manager
|
||||
task = tm.create_task("从云端拉取")
|
||||
|
||||
async def _run():
|
||||
try:
|
||||
tm.update_progress(task.id, 10, "正在初始化同步...")
|
||||
sync = _get_sync()
|
||||
tm.update_progress(task.id, 30, "正在拉取文件...")
|
||||
tm.add_log(task.id, "[Pull] 开始拉取")
|
||||
result = sync.pull()
|
||||
tm.add_log(task.id, f"[Pull] 完成: {result}")
|
||||
tm.set_completed(task.id, message="拉取完成")
|
||||
except Exception as e:
|
||||
tm.set_failed(task.id, str(e))
|
||||
|
||||
import asyncio
|
||||
asyncio.create_task(_run())
|
||||
return SyncResponse(task_id=task.id, status="accepted", message="拉取任务已创建")
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
async def sync_status(
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
try:
|
||||
from app.config.settings import ConfigManager
|
||||
cfg = ConfigManager()
|
||||
base_url = cfg.get("Gitea", "base_url", fallback="")
|
||||
owner = cfg.get("Gitea", "owner", fallback="")
|
||||
repo = cfg.get("Gitea", "repo", fallback="")
|
||||
enabled = bool(base_url and owner and repo)
|
||||
repo_url = f"{base_url}/{owner}/{repo}" if enabled else ""
|
||||
return {"enabled": enabled, "repo_url": repo_url}
|
||||
except Exception:
|
||||
return {"enabled": False, "repo_url": ""}
|
||||
@@ -0,0 +1,47 @@
|
||||
"""WebSocket endpoint for real-time task progress."""
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||
from ..auth.jwt_handler import decode_token
|
||||
from jose import JWTError
|
||||
|
||||
router = APIRouter(tags=["websocket"])
|
||||
|
||||
|
||||
@router.websocket("/ws/task/{task_id}")
|
||||
async def task_websocket(
|
||||
websocket: WebSocket,
|
||||
task_id: str,
|
||||
token: str = Query(...),
|
||||
):
|
||||
"""WebSocket for real-time task progress updates."""
|
||||
try:
|
||||
payload = decode_token(token)
|
||||
username = payload.get("sub")
|
||||
if not username:
|
||||
await websocket.close(code=4001, reason="Invalid token")
|
||||
return
|
||||
except (JWTError, Exception):
|
||||
await websocket.close(code=4001, reason="Invalid token")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
|
||||
tm = websocket.app.state.task_manager
|
||||
task = tm.get_task(task_id)
|
||||
if not task:
|
||||
await websocket.send_json({"error": "任务不存在"})
|
||||
await websocket.close()
|
||||
return
|
||||
|
||||
tm.subscribe(task_id, websocket)
|
||||
await websocket.send_json(task.to_dict())
|
||||
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
if data == "ping":
|
||||
await websocket.send_text("pong")
|
||||
except WebSocketDisconnect:
|
||||
tm.unsubscribe(task_id, websocket)
|
||||
except Exception:
|
||||
tm.unsubscribe(task_id, websocket)
|
||||
@@ -0,0 +1,20 @@
|
||||
"""SQLite write serialization for async context"""
|
||||
|
||||
import asyncio
|
||||
from typing import Callable, Any
|
||||
|
||||
|
||||
class DBPool:
|
||||
"""Serializes SQLite writes via asyncio.Lock. Reads are concurrent."""
|
||||
|
||||
def __init__(self):
|
||||
self._write_lock = asyncio.Lock()
|
||||
|
||||
async def execute_write(self, fn: Callable, *args, **kwargs) -> Any:
|
||||
async with self._write_lock:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, lambda: fn(*args, **kwargs))
|
||||
|
||||
async def execute_read(self, fn: Callable, *args, **kwargs) -> Any:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, lambda: fn(*args, **kwargs))
|
||||
@@ -0,0 +1,19 @@
|
||||
"""Async wrapper for synchronous app/ services"""
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Callable, Any
|
||||
|
||||
|
||||
class ServiceWrapper:
|
||||
"""Wraps synchronous services for async FastAPI endpoints."""
|
||||
|
||||
def __init__(self, max_workers: int = 3):
|
||||
self._executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
|
||||
async def run_sync(self, fn: Callable, *args, **kwargs) -> Any:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
self._executor,
|
||||
lambda: fn(*args, **kwargs)
|
||||
)
|
||||
Reference in New Issue
Block a user