TrendRadar/mcp_server/tools/storage_sync.py
2025-12-13 13:44:35 +08:00

469 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# coding=utf-8
"""
存储同步工具
实现从远程存储拉取数据到本地、获取存储状态、列出可用日期等功能。
"""
import os
import re
from pathlib import Path
from datetime import datetime, timedelta
from typing import Dict, List, Optional
import yaml
from ..utils.errors import MCPError
class StorageSyncTools:
    """
    Storage sync tools.

    Pulls data from remote storage into the local data directory, reports
    storage configuration/status, and lists the dates available locally
    and remotely.
    """

    # Date folder name formats, tried in order. ``match`` anchors only at the
    # start of the name, so trailing text after the date is tolerated.
    _DATE_FOLDER_PATTERNS = (
        re.compile(r'(\d{4})-(\d{2})-(\d{2})'),     # ISO: YYYY-MM-DD
        re.compile(r'(\d{4})年(\d{2})月(\d{2})日'),  # Chinese: YYYY年MM月DD日
    )

    def __init__(self, project_root: Optional[str] = None):
        """
        Initialize the storage sync tools.

        Args:
            project_root: Project root directory. When omitted, the directory
                three levels above this file is used.
        """
        if project_root:
            self.project_root = Path(project_root)
        else:
            self.project_root = Path(__file__).parent.parent.parent
        # Lazily populated caches (see _load_config / _get_remote_backend).
        self._config: Optional[dict] = None
        self._remote_backend = None

    @staticmethod
    def _as_dict(value) -> dict:
        """
        Return ``value`` if it is a dict, otherwise an empty dict.

        YAML sections written as ``key:`` with no body load as ``None``;
        normalizing them keeps ``.get`` chains safe.
        """
        return value if isinstance(value, dict) else {}

    def _load_config(self) -> dict:
        """
        Load and cache ``<project_root>/config/config.yaml``.

        Returns:
            The parsed config, or an empty dict when the file is missing,
            empty, or does not contain a mapping.
        """
        if self._config is None:
            config_path = self.project_root / "config" / "config.yaml"
            if config_path.exists():
                with open(config_path, "r", encoding="utf-8") as f:
                    # safe_load returns None for an empty file; normalize so
                    # callers can always use .get and the result stays cached.
                    self._config = self._as_dict(yaml.safe_load(f))
            else:
                self._config = {}
        return self._config

    def _get_storage_config(self) -> dict:
        """Return the ``storage`` section of the config (empty dict if absent)."""
        return self._as_dict(self._load_config().get("storage"))

    def _get_timezone(self) -> str:
        """Return ``app.timezone`` from the config, defaulting to Asia/Shanghai."""
        app_config = self._as_dict(self._load_config().get("app"))
        return app_config.get("timezone", "Asia/Shanghai")

    def _now(self) -> datetime:
        """Return the current time in the configured timezone."""
        from trendradar.utils.time import get_configured_time
        return get_configured_time(self._get_timezone())

    def _get_remote_config(self) -> dict:
        """
        Return the remote storage config, merging config file and environment.

        For each key a non-empty value in ``storage.remote`` wins; otherwise
        the matching ``S3_*`` environment variable is used (default "").
        """
        remote_config = self._as_dict(self._get_storage_config().get("remote"))
        env_names = {
            "endpoint_url": "S3_ENDPOINT_URL",
            "bucket_name": "S3_BUCKET_NAME",
            "access_key_id": "S3_ACCESS_KEY_ID",
            "secret_access_key": "S3_SECRET_ACCESS_KEY",
            "region": "S3_REGION",
        }
        return {
            key: remote_config.get(key) or os.environ.get(env_name, "")
            for key, env_name in env_names.items()
        }

    def _has_remote_config(self) -> bool:
        """Return True when bucket, both credentials and endpoint are all set."""
        config = self._get_remote_config()
        required = ("bucket_name", "access_key_id", "secret_access_key", "endpoint_url")
        return all(config.get(key) for key in required)

    def _get_remote_backend(self):
        """
        Return a cached remote storage backend instance.

        Returns:
            The backend, or None when remote storage is not configured or the
            backend cannot be created (failures are printed, not raised, and
            creation is retried on the next call).
        """
        if self._remote_backend is not None:
            return self._remote_backend
        if not self._has_remote_config():
            return None
        try:
            from trendradar.storage.remote import RemoteStorageBackend
            remote_config = self._get_remote_config()
            self._remote_backend = RemoteStorageBackend(
                bucket_name=remote_config["bucket_name"],
                access_key_id=remote_config["access_key_id"],
                secret_access_key=remote_config["secret_access_key"],
                endpoint_url=remote_config["endpoint_url"],
                region=remote_config.get("region", ""),
                timezone=self._get_timezone(),
            )
            return self._remote_backend
        except ImportError:
            print("[存储同步] 远程存储后端需要安装 boto3: pip install boto3")
            return None
        except Exception as e:
            print(f"[存储同步] 创建远程后端失败: {e}")
            return None

    def _get_local_data_dir(self) -> Path:
        """Return the local data directory (``storage.local.data_dir``, default "output")."""
        local_config = self._as_dict(self._get_storage_config().get("local"))
        data_dir = local_config.get("data_dir", "output")
        return self.project_root / data_dir

    def _parse_date_folder_name(self, folder_name: str) -> Optional[datetime]:
        """
        Parse a date folder name in ISO or Chinese format.

        Supported formats (ISO is tried first):
            - ISO format: YYYY-MM-DD
            - Chinese format: YYYY年MM月DD日

        Returns:
            A naive datetime at midnight, or None when the name matches no
            format or encodes an impossible date (e.g. 2025-02-30).
        """
        for pattern in self._DATE_FOLDER_PATTERNS:
            match = pattern.match(folder_name)
            if not match:
                continue
            year, month, day = (int(part) for part in match.groups())
            try:
                return datetime(year, month, day)
            except ValueError:
                # Matched the shape but not a real calendar date; try the next format.
                continue
        return None

    def _get_local_dates(self) -> List[str]:
        """
        Return the distinct local dates as "YYYY-MM-DD" strings, newest first.

        Non-hidden sub-folders of the data directory are parsed with
        ``_parse_date_folder_name``. A date present under both folder formats
        is reported once.
        """
        local_dir = self._get_local_data_dir()
        if not local_dir.exists():
            return []
        dates = set()
        for item in local_dir.iterdir():
            if item.is_dir() and not item.name.startswith('.'):
                folder_date = self._parse_date_folder_name(item.name)
                if folder_date:
                    dates.add(folder_date.strftime("%Y-%m-%d"))
        return sorted(dates, reverse=True)

    def _calculate_dir_size(self, path: Path) -> int:
        """
        Return the total size in bytes of all regular files under ``path``.

        Returns 0 when ``path`` does not exist. Files that vanish or become
        unreadable during the walk are skipped instead of aborting the scan.
        """
        total_size = 0
        if not path.exists():
            return total_size
        for item in path.rglob("*"):
            try:
                if item.is_file():
                    total_size += item.stat().st_size
            except OSError:
                # e.g. the file was removed between rglob() and stat()
                continue
        return total_size

    @staticmethod
    def _discard_dir(path: Path) -> None:
        """Best-effort recursive removal of a directory created for a failed pull."""
        import shutil
        shutil.rmtree(path, ignore_errors=True)

    def sync_from_remote(self, days: int = 7) -> Dict:
        """
        Pull data from remote storage into the local data directory.

        Candidate dates are today and the previous ``days - 1`` days in the
        configured timezone, restricted to dates the remote reports. A date
        is skipped when any local folder (ISO or Chinese name) already maps
        to it. Each pulled date is downloaded from ``news/<date>.db`` to
        ``<data_dir>/<date>/news.db``.

        Args:
            days: Number of most recent days to pull; defaults to 7.

        Returns:
            A result dict. On success it contains ``synced_files``,
            ``synced_dates``, ``skipped_dates``, ``failed_dates`` (list of
            ``{"date", "error"}``) and a summary ``message``. On failure it
            contains ``success: False`` and an ``error`` dict.
        """
        try:
            if not self._has_remote_config():
                return {
                    "success": False,
                    "error": {
                        "code": "REMOTE_NOT_CONFIGURED",
                        "message": "未配置远程存储",
                        "suggestion": "请在 config/config.yaml 中配置 storage.remote 或设置环境变量"
                    }
                }

            remote_backend = self._get_remote_backend()
            if remote_backend is None:
                return {
                    "success": False,
                    "error": {
                        "code": "REMOTE_BACKEND_FAILED",
                        "message": "无法创建远程存储后端",
                        "suggestion": "请检查远程存储配置和 boto3 是否已安装"
                    }
                }

            local_dir = self._get_local_data_dir()
            local_dir.mkdir(parents=True, exist_ok=True)

            # Sets give O(1) membership checks inside the loops below.
            remote_dates = set(remote_backend.list_remote_dates())
            local_dates = set(self._get_local_dates())

            now = self._now()
            target_dates = []
            for offset in range(days):
                date_str = (now - timedelta(days=offset)).strftime("%Y-%m-%d")
                if date_str in remote_dates:
                    target_dates.append(date_str)

            synced_dates = []
            skipped_dates = []
            failed_dates = []
            for date_str in target_dates:
                if date_str in local_dates:
                    skipped_dates.append(date_str)
                    continue

                local_date_dir = local_dir / date_str
                created_dir = not local_date_dir.exists()
                try:
                    local_date_dir.mkdir(parents=True, exist_ok=True)
                    remote_backend.s3_client.download_file(
                        remote_backend.bucket_name,
                        f"news/{date_str}.db",
                        str(local_date_dir / "news.db"),
                    )
                    synced_dates.append(date_str)
                    print(f"[存储同步] 已拉取: {date_str}")
                except Exception as e:
                    # Remove the folder created for this attempt. A leftover
                    # folder would be listed as a local date, and the date
                    # would then be skipped as "already present" forever.
                    if created_dir:
                        self._discard_dir(local_date_dir)
                    failed_dates.append({"date": date_str, "error": str(e)})
                    print(f"[存储同步] 拉取失败 ({date_str}): {e}")

            message = f"成功同步 {len(synced_dates)} 天数据"
            if skipped_dates:
                message += f",跳过 {len(skipped_dates)} 天(本地已存在)"
            if failed_dates:
                message += f",失败 {len(failed_dates)} 天"

            return {
                "success": True,
                "synced_files": len(synced_dates),
                "synced_dates": synced_dates,
                "skipped_dates": skipped_dates,
                "failed_dates": failed_dates,
                "message": message,
            }

        except MCPError as e:
            return {
                "success": False,
                "error": e.to_dict()
            }
        except Exception as e:
            return {
                "success": False,
                "error": {
                    "code": "INTERNAL_ERROR",
                    "message": str(e)
                }
            }

    def get_storage_status(self) -> Dict:
        """
        Report storage configuration and status.

        Returns:
            A dict with ``backend`` plus ``local``, ``remote`` and ``pull``
            sections. ``remote`` carries an ``error`` key when the backend
            cannot be created or listing remote dates fails.
        """
        try:
            storage_config = self._get_storage_config()

            # Local storage status
            local_config = self._as_dict(storage_config.get("local"))
            local_dir = self._get_local_data_dir()
            local_size = self._calculate_dir_size(local_dir)
            local_dates = self._get_local_dates()
            local_status = {
                "data_dir": local_config.get("data_dir", "output"),
                "retention_days": local_config.get("retention_days", 0),
                "total_size": f"{local_size / 1024 / 1024:.2f} MB",
                "total_size_bytes": local_size,
                "date_count": len(local_dates),
                # local_dates is sorted newest first
                "earliest_date": local_dates[-1] if local_dates else None,
                "latest_date": local_dates[0] if local_dates else None,
            }

            # Remote storage status
            remote_config = self._as_dict(storage_config.get("remote"))
            has_remote = self._has_remote_config()
            remote_status = {
                "configured": has_remote,
                "retention_days": remote_config.get("retention_days", 0),
            }

            if has_remote:
                merged_config = self._get_remote_config()
                # Only non-secret connection details are reported; the access
                # keys are intentionally left out of the status payload.
                remote_status["endpoint_url"] = merged_config.get("endpoint_url", "")
                remote_status["bucket_name"] = merged_config.get("bucket_name", "")

                remote_backend = self._get_remote_backend()
                if remote_backend is None:
                    remote_status["error"] = "无法创建远程存储后端"
                else:
                    try:
                        remote_dates = remote_backend.list_remote_dates()
                        remote_status["date_count"] = len(remote_dates)
                        # NOTE(review): assumes list_remote_dates() returns
                        # dates newest first, like _get_local_dates — confirm.
                        remote_status["earliest_date"] = remote_dates[-1] if remote_dates else None
                        remote_status["latest_date"] = remote_dates[0] if remote_dates else None
                    except Exception as e:
                        remote_status["error"] = str(e)

            # Pull configuration status
            pull_config = self._as_dict(storage_config.get("pull"))
            pull_status = {
                "enabled": pull_config.get("enabled", False),
                "days": pull_config.get("days", 7),
            }

            return {
                "success": True,
                "backend": storage_config.get("backend", "auto"),
                "local": local_status,
                "remote": remote_status,
                "pull": pull_status,
            }

        except MCPError as e:
            return {
                "success": False,
                "error": e.to_dict()
            }
        except Exception as e:
            return {
                "success": False,
                "error": {
                    "code": "INTERNAL_ERROR",
                    "message": str(e)
                }
            }

    def _remote_dates_summary(self) -> dict:
        """Build the ``remote`` section of ``list_available_dates``' result."""
        if not self._has_remote_config():
            return {
                "configured": False,
                "dates": [],
                "count": 0,
                "earliest": None,
                "latest": None,
                "error": "未配置远程存储"
            }

        remote_backend = self._get_remote_backend()
        if remote_backend is None:
            return {
                "configured": True,
                "dates": [],
                "count": 0,
                "earliest": None,
                "latest": None,
                "error": "无法创建远程存储后端"
            }

        try:
            remote_dates = remote_backend.list_remote_dates()
            return {
                "configured": True,
                "dates": remote_dates,
                "count": len(remote_dates),
                "earliest": remote_dates[-1] if remote_dates else None,
                "latest": remote_dates[0] if remote_dates else None,
            }
        except Exception as e:
            return {
                "configured": True,
                "dates": [],
                "count": 0,
                "earliest": None,
                "latest": None,
                "error": str(e)
            }

    def list_available_dates(self, source: str = "both") -> Dict:
        """
        List the available dates.

        Args:
            source: Where to look:
                - "local": local data only
                - "remote": remote storage only
                - "both": both (default); also adds a ``comparison`` section
                  with ``only_local`` / ``only_remote`` / ``both`` lists

        Returns:
            A result dict with ``local`` and/or ``remote`` sections. An
            unknown ``source`` yields ``success: False`` with code
            ``INVALID_PARAMETER``.
        """
        try:
            if source not in ("local", "remote", "both"):
                return {
                    "success": False,
                    "error": {
                        "code": "INVALID_PARAMETER",
                        "message": f"无效的 source 参数: {source}",
                        "suggestion": "source 必须为 local、remote 或 both"
                    }
                }

            result = {
                "success": True,
            }

            if source in ("local", "both"):
                local_dates = self._get_local_dates()
                result["local"] = {
                    "dates": local_dates,
                    "count": len(local_dates),
                    "earliest": local_dates[-1] if local_dates else None,
                    "latest": local_dates[0] if local_dates else None,
                }

            if source in ("remote", "both"):
                result["remote"] = self._remote_dates_summary()

            if source == "both":
                local_set = set(result["local"]["dates"])
                remote_set = set(result["remote"].get("dates", []))
                result["comparison"] = {
                    "only_local": sorted(local_set - remote_set, reverse=True),
                    "only_remote": sorted(remote_set - local_set, reverse=True),
                    "both": sorted(local_set & remote_set, reverse=True),
                }

            return result

        except MCPError as e:
            return {
                "success": False,
                "error": e.to_dict()
            }
        except Exception as e:
            return {
                "success": False,
                "error": {
                    "code": "INTERNAL_ERROR",
                    "message": str(e)
                }
            }