mirror of
https://gitee.com/houhuan/TrendRadar.git
synced 2025-12-21 12:47:16 +08:00
469 lines
17 KiB
Python
469 lines
17 KiB
Python
# coding=utf-8
|
||
"""
|
||
存储同步工具
|
||
|
||
实现从远程存储拉取数据到本地、获取存储状态、列出可用日期等功能。
|
||
"""
|
||
|
||
import os
|
||
import re
|
||
from pathlib import Path
|
||
from datetime import datetime, timedelta
|
||
from typing import Dict, List, Optional
|
||
|
||
import yaml
|
||
|
||
from ..utils.errors import MCPError
|
||
|
||
|
||
class StorageSyncTools:
    """
    Storage sync tools.

    Pulls data from remote (S3-compatible) storage into the local data
    directory, reports storage configuration/status, and lists the dates
    available locally and remotely.
    """

    # Recognized date-folder name formats, tried in order. ``re.match`` anchors
    # only at the start of the name, so anything after the date is ignored.
    _DATE_FOLDER_PATTERNS = (
        re.compile(r'(\d{4})-(\d{2})-(\d{2})'),    # ISO: YYYY-MM-DD
        re.compile(r'(\d{4})年(\d{2})月(\d{2})日'),  # Chinese: YYYY年MM月DD日
    )

    # Remote-config key -> environment variable used when the config file
    # leaves that key empty or unset.
    _REMOTE_ENV_VARS = (
        ("endpoint_url", "S3_ENDPOINT_URL"),
        ("bucket_name", "S3_BUCKET_NAME"),
        ("access_key_id", "S3_ACCESS_KEY_ID"),
        ("secret_access_key", "S3_SECRET_ACCESS_KEY"),
        ("region", "S3_REGION"),
    )

    def __init__(self, project_root: Optional[str] = None):
        """
        Initialize the storage sync tools.

        Args:
            project_root: Project root directory. When omitted, the directory
                three levels above this file is used.
        """
        if project_root:
            self.project_root = Path(project_root)
        else:
            self.project_root = Path(__file__).parent.parent.parent

        # Lazily populated caches (see _load_config / _get_remote_backend).
        self._config = None
        self._remote_backend = None

    @staticmethod
    def _as_dict(value) -> dict:
        """Return ``value`` if it is a dict, otherwise an empty dict.

        YAML sections written as ``key:`` with no body parse to ``None``; this
        lets callers chain ``.get`` safely.
        """
        return value if isinstance(value, dict) else {}

    @staticmethod
    def _error_result(code: str, message: str, suggestion: Optional[str] = None) -> Dict:
        """Build the standard ``{"success": False, "error": {...}}`` response."""
        error = {"code": code, "message": message}
        if suggestion is not None:
            error["suggestion"] = suggestion
        return {"success": False, "error": error}

    @staticmethod
    def _summarize_dates(dates: List[str]) -> Dict:
        """Summarize a newest-first date list as dates/count/earliest/latest."""
        return {
            "dates": dates,
            "count": len(dates),
            "earliest": dates[-1] if dates else None,
            "latest": dates[0] if dates else None,
        }

    def _load_config(self) -> dict:
        """
        Load and cache ``config/config.yaml`` under the project root.

        Returns:
            The parsed config mapping, or an empty dict when the file is
            missing, empty, or not a mapping at the top level.
        """
        if self._config is None:
            config_path = self.project_root / "config" / "config.yaml"
            loaded = None
            if config_path.exists():
                with open(config_path, "r", encoding="utf-8") as f:
                    loaded = yaml.safe_load(f)
            # safe_load returns None for an empty file. Normalize so callers can
            # always call .get() and so the result is actually cached.
            self._config = loaded if isinstance(loaded, dict) else {}
        return self._config

    def _get_storage_config(self) -> dict:
        """Return the ``storage`` config section (empty dict if absent or null)."""
        return self._as_dict(self._load_config().get("storage"))

    def _get_timezone(self):
        """Return ``app.timezone`` from the config, defaulting to ``Asia/Shanghai``."""
        app_config = self._as_dict(self._load_config().get("app"))
        return app_config.get("timezone", "Asia/Shanghai")

    def _get_remote_config(self) -> dict:
        """
        Return the remote storage settings.

        Each value comes from ``storage.remote`` in the config file when it is
        non-empty, otherwise from the matching ``S3_*`` environment variable,
        otherwise ``""``.
        """
        remote_config = self._as_dict(self._get_storage_config().get("remote"))
        return {
            key: remote_config.get(key) or os.environ.get(env_var, "")
            for key, env_var in self._REMOTE_ENV_VARS
        }

    def _has_remote_config(self) -> bool:
        """Return True when bucket, both credentials and endpoint are all set."""
        config = self._get_remote_config()
        required = ("bucket_name", "access_key_id", "secret_access_key", "endpoint_url")
        return all(config.get(key) for key in required)

    def _get_remote_backend(self):
        """
        Return a cached ``RemoteStorageBackend`` instance, creating it on first use.

        Returns:
            The backend, or None when remote storage is not configured or the
            backend could not be created (the reason is printed). Failures are
            not cached, so a later call retries.
        """
        if self._remote_backend is not None:
            return self._remote_backend

        if not self._has_remote_config():
            return None

        try:
            from trendradar.storage.remote import RemoteStorageBackend

            remote_config = self._get_remote_config()
            self._remote_backend = RemoteStorageBackend(
                bucket_name=remote_config["bucket_name"],
                access_key_id=remote_config["access_key_id"],
                secret_access_key=remote_config["secret_access_key"],
                endpoint_url=remote_config["endpoint_url"],
                region=remote_config.get("region", ""),
                timezone=self._get_timezone(),
            )
            return self._remote_backend
        except ImportError:
            # The user-facing hint attributes any ImportError here to boto3.
            print("[存储同步] 远程存储后端需要安装 boto3: pip install boto3")
            return None
        except Exception as e:
            print(f"[存储同步] 创建远程后端失败: {e}")
            return None

    @staticmethod
    def _list_remote_dates(remote_backend) -> List[str]:
        """List the backend's dates newest-first, regardless of the backend's own order.

        Assumes ``YYYY-MM-DD`` strings, which sort chronologically as text.
        """
        return sorted(remote_backend.list_remote_dates(), reverse=True)

    def _get_local_data_dir(self) -> Path:
        """Return the local data directory (``storage.local.data_dir``, default ``output``)."""
        local_config = self._as_dict(self._get_storage_config().get("local"))
        return self.project_root / local_config.get("data_dir", "output")

    def _parse_date_folder_name(self, folder_name: str) -> Optional[datetime]:
        """
        Parse a date folder name in either the ISO or the Chinese format.

        Supported formats:
            - ISO: ``YYYY-MM-DD``
            - Chinese: ``YYYY年MM月DD日``

        Returns:
            A naive ``datetime`` at midnight, or None when the name does not
            start with a valid date in either format.
        """
        for pattern in self._DATE_FOLDER_PATTERNS:
            match = pattern.match(folder_name)
            if not match:
                continue
            try:
                return datetime(*(int(part) for part in match.groups()))
            except ValueError:
                # Date-shaped but not a real date (e.g. 2025-02-30); try the next format.
                continue
        return None

    def _get_local_dates(self) -> List[str]:
        """
        List dates that have a folder in the local data directory.

        Hidden folders and folders whose names are not dates are ignored. A day
        stored under both naming formats is reported once.

        Returns:
            ``YYYY-MM-DD`` strings, newest first.
        """
        local_dir = self._get_local_data_dir()
        if not local_dir.exists():
            return []

        dates = set()
        for item in local_dir.iterdir():
            if not item.is_dir() or item.name.startswith('.'):
                continue
            folder_date = self._parse_date_folder_name(item.name)
            if folder_date:
                dates.add(folder_date.strftime("%Y-%m-%d"))

        return sorted(dates, reverse=True)

    def _calculate_dir_size(self, path: Path) -> int:
        """Return the total size in bytes of all files under ``path`` (0 if it does not exist)."""
        if not path.exists():
            return 0
        return sum(item.stat().st_size for item in path.rglob("*") if item.is_file())

    @staticmethod
    def _remove_failed_download(date_dir: Path, db_path: Path) -> None:
        """Best-effort removal of a partial file and the date folder created for it."""
        try:
            if db_path.exists():
                db_path.unlink()
            date_dir.rmdir()  # only succeeds if nothing else is in the folder
        except OSError:
            # Cleanup is best-effort; the original download error is still reported.
            pass

    def _download_date(self, remote_backend, local_dir: Path, date_str: str) -> None:
        """
        Download ``news/<date>.db`` from the bucket into ``<local_dir>/<date>/news.db``.

        Raises:
            Whatever the S3 client raises. Before re-raising, a date folder this
            call created is removed. Otherwise the leftover empty folder would be
            counted as local data and the date would be skipped on every later
            sync.
        """
        local_date_dir = local_dir / date_str
        local_db_path = local_date_dir / "news.db"
        remote_key = f"news/{date_str}.db"
        created_dir = not local_date_dir.exists()

        try:
            local_date_dir.mkdir(parents=True, exist_ok=True)
            remote_backend.s3_client.download_file(
                remote_backend.bucket_name,
                remote_key,
                str(local_db_path),
            )
        except Exception:
            if created_dir:
                self._remove_failed_download(local_date_dir, local_db_path)
            raise

    def sync_from_remote(self, days: int = 7) -> Dict:
        """
        Pull data from remote storage into the local data directory.

        Only dates within the last ``days`` days (today included, in the
        configured timezone) that exist remotely are considered. Dates that
        already have a local folder are skipped.

        Args:
            days: Pull the most recent N days of data (default 7).

        Returns:
            On success: ``success``, ``synced_files``, ``synced_dates``,
            ``skipped_dates``, ``failed_dates`` (``{"date", "error"}`` entries)
            and a human-readable ``message``. On failure: ``success=False`` plus
            an ``error`` dict.
        """
        try:
            if not self._has_remote_config():
                return self._error_result(
                    "REMOTE_NOT_CONFIGURED",
                    "未配置远程存储",
                    "请在 config/config.yaml 中配置 storage.remote 或设置环境变量",
                )

            remote_backend = self._get_remote_backend()
            if remote_backend is None:
                return self._error_result(
                    "REMOTE_BACKEND_FAILED",
                    "无法创建远程存储后端",
                    "请检查远程存储配置和 boto3 是否已安装",
                )

            local_dir = self._get_local_data_dir()
            local_dir.mkdir(parents=True, exist_ok=True)

            # Sets: only membership is needed below.
            remote_dates = set(remote_backend.list_remote_dates())
            local_dates = set(self._get_local_dates())

            from trendradar.utils.time import get_configured_time

            now = get_configured_time(self._get_timezone())

            # Candidate dates, newest first: the last `days` days that exist remotely.
            target_dates = []
            for offset in range(days):
                date_str = (now - timedelta(days=offset)).strftime("%Y-%m-%d")
                if date_str in remote_dates:
                    target_dates.append(date_str)

            synced_dates = []
            skipped_dates = []
            failed_dates = []

            for date_str in target_dates:
                if date_str in local_dates:
                    skipped_dates.append(date_str)
                    continue

                try:
                    self._download_date(remote_backend, local_dir, date_str)
                except Exception as e:
                    failed_dates.append({"date": date_str, "error": str(e)})
                    print(f"[存储同步] 拉取失败 ({date_str}): {e}")
                    continue

                synced_dates.append(date_str)
                print(f"[存储同步] 已拉取: {date_str}")

            message = f"成功同步 {len(synced_dates)} 天数据"
            if skipped_dates:
                message += f",跳过 {len(skipped_dates)} 天(本地已存在)"
            if failed_dates:
                message += f",失败 {len(failed_dates)} 天"

            return {
                "success": True,
                "synced_files": len(synced_dates),
                "synced_dates": synced_dates,
                "skipped_dates": skipped_dates,
                "failed_dates": failed_dates,
                "message": message,
            }

        except MCPError as e:
            return {"success": False, "error": e.to_dict()}
        except Exception as e:
            return self._error_result("INTERNAL_ERROR", str(e))

    def get_storage_status(self) -> Dict:
        """
        Report storage configuration and status.

        Returns:
            ``success``, ``backend`` (``storage.backend``, default ``auto``),
            ``local`` (data dir, retention, size, date range), ``remote``
            (configured flag, retention, and when configured: endpoint, bucket,
            date range or ``error``) and ``pull`` (enabled flag, days). On
            failure: ``success=False`` plus an ``error`` dict.
        """
        try:
            storage_config = self._get_storage_config()

            # Local storage status
            local_config = self._as_dict(storage_config.get("local"))
            local_dir = self._get_local_data_dir()
            local_size = self._calculate_dir_size(local_dir)
            local_dates = self._get_local_dates()

            local_status = {
                "data_dir": local_config.get("data_dir", "output"),
                "retention_days": local_config.get("retention_days", 0),
                "total_size": f"{local_size / 1024 / 1024:.2f} MB",
                "total_size_bytes": local_size,
                "date_count": len(local_dates),
                "earliest_date": local_dates[-1] if local_dates else None,
                "latest_date": local_dates[0] if local_dates else None,
            }

            # Remote storage status
            remote_config = self._as_dict(storage_config.get("remote"))
            has_remote = self._has_remote_config()

            remote_status = {
                "configured": has_remote,
                "retention_days": remote_config.get("retention_days", 0),
            }

            if has_remote:
                merged_config = self._get_remote_config()
                # Only non-secret connection details are reported; credentials never are.
                remote_status["endpoint_url"] = merged_config.get("endpoint_url", "")
                remote_status["bucket_name"] = merged_config.get("bucket_name", "")

                remote_backend = self._get_remote_backend()
                if remote_backend is None:
                    remote_status["error"] = "无法创建远程存储后端"
                else:
                    try:
                        remote_dates = self._list_remote_dates(remote_backend)
                        remote_status["date_count"] = len(remote_dates)
                        remote_status["earliest_date"] = remote_dates[-1] if remote_dates else None
                        remote_status["latest_date"] = remote_dates[0] if remote_dates else None
                    except Exception as e:
                        remote_status["error"] = str(e)

            # Pull configuration status
            pull_config = self._as_dict(storage_config.get("pull"))
            pull_status = {
                "enabled": pull_config.get("enabled", False),
                "days": pull_config.get("days", 7),
            }

            return {
                "success": True,
                "backend": storage_config.get("backend", "auto"),
                "local": local_status,
                "remote": remote_status,
                "pull": pull_status,
            }

        except MCPError as e:
            return {"success": False, "error": e.to_dict()}
        except Exception as e:
            return self._error_result("INTERNAL_ERROR", str(e))

    def _remote_dates_summary(self) -> Dict:
        """
        Build the ``remote`` section for ``list_available_dates``.

        Backend problems (not configured, cannot be created, listing failed)
        are reported in an ``error`` field rather than raised.
        """
        def unavailable(configured: bool, error: str) -> Dict:
            return {"configured": configured, **self._summarize_dates([]), "error": error}

        if not self._has_remote_config():
            return unavailable(False, "未配置远程存储")

        remote_backend = self._get_remote_backend()
        if remote_backend is None:
            return unavailable(True, "无法创建远程存储后端")

        try:
            remote_dates = self._list_remote_dates(remote_backend)
        except Exception as e:
            return unavailable(True, str(e))

        return {"configured": True, **self._summarize_dates(remote_dates)}

    def list_available_dates(self, source: str = "both") -> Dict:
        """
        List the available date ranges.

        Args:
            source: Data source to inspect:
                - "local": local only
                - "remote": remote only
                - "both": both (default), plus a ``comparison`` of the two sets

        Returns:
            ``success`` plus ``local`` and/or ``remote`` summaries (dates
            newest-first, count, earliest, latest). An unrecognized ``source``
            yields ``success=False`` with code ``INVALID_PARAMETER``.
        """
        if source not in ("local", "remote", "both"):
            return self._error_result(
                "INVALID_PARAMETER",
                f"无效的 source 参数: {source}",
                "可选值: local, remote, both",
            )

        try:
            result = {"success": True}

            if source in ("local", "both"):
                result["local"] = self._summarize_dates(self._get_local_dates())

            if source in ("remote", "both"):
                result["remote"] = self._remote_dates_summary()

            # When both sides were queried, report how they differ.
            if source == "both":
                local_set = set(result["local"]["dates"])
                remote_set = set(result["remote"]["dates"])
                result["comparison"] = {
                    "only_local": sorted(local_set - remote_set, reverse=True),
                    "only_remote": sorted(remote_set - local_set, reverse=True),
                    "both": sorted(local_set & remote_set, reverse=True),
                }

            return result

        except MCPError as e:
            return {"success": False, "error": e.to_dict()}
        except Exception as e:
            return self._error_result("INTERNAL_ERROR", str(e))