增强版v2-初始化仓库,验证好了ocr部分,先备份一次

This commit is contained in:
2025-05-02 17:25:47 +08:00
commit 0035cd1893
88 changed files with 9031 additions and 0 deletions
+5
View File
@@ -0,0 +1,5 @@
"""
OCR订单处理系统 - OCR核心模块
---------------------------
提供OCR识别相关功能,包括图片预处理、文字识别和表格识别。
"""
Binary file not shown.
Binary file not shown.
Binary file not shown.
+344
View File
@@ -0,0 +1,344 @@
"""
百度OCR客户端模块
---------------
提供百度OCR API的访问和调用功能。
"""
import os
import time
import base64
import requests
import logging
from typing import Dict, Optional, Any, Union
from ...config.settings import ConfigManager
from ..utils.log_utils import get_logger
logger = get_logger(__name__)
class TokenManager:
"""
令牌管理类,负责获取和刷新百度API访问令牌
"""
def __init__(self, api_key: str, secret_key: str, max_retries: int = 3, retry_delay: int = 2):
"""
初始化令牌管理器
Args:
api_key: 百度API Key
secret_key: 百度Secret Key
max_retries: 最大重试次数
retry_delay: 重试延迟(秒)
"""
self.api_key = api_key
self.secret_key = secret_key
self.max_retries = max_retries
self.retry_delay = retry_delay
self.access_token = None
self.token_expiry = 0
def get_token(self) -> Optional[str]:
"""
获取访问令牌,如果令牌已过期则刷新
Returns:
访问令牌,如果获取失败则返回None
"""
if self.is_token_valid():
return self.access_token
return self.refresh_token()
def is_token_valid(self) -> bool:
"""
检查令牌是否有效
Returns:
令牌是否有效
"""
return (
self.access_token is not None and
self.token_expiry > time.time() + 60 # 提前1分钟刷新
)
def refresh_token(self) -> Optional[str]:
"""
刷新访问令牌
Returns:
新的访问令牌,如果获取失败则返回None
"""
url = "https://aip.baidubce.com/oauth/2.0/token"
params = {
"grant_type": "client_credentials",
"client_id": self.api_key,
"client_secret": self.secret_key
}
for attempt in range(self.max_retries):
try:
response = requests.post(url, params=params, timeout=10)
if response.status_code == 200:
result = response.json()
if "access_token" in result:
self.access_token = result["access_token"]
# 设置令牌过期时间(默认30天,提前1小时过期以确保安全)
self.token_expiry = time.time() + result.get("expires_in", 2592000) - 3600
logger.info("成功获取访问令牌")
return self.access_token
logger.warning(f"获取访问令牌失败 (尝试 {attempt+1}/{self.max_retries}): {response.text}")
except Exception as e:
logger.warning(f"获取访问令牌时发生错误 (尝试 {attempt+1}/{self.max_retries}): {e}")
# 如果不是最后一次尝试,则等待后重试
if attempt < self.max_retries - 1:
time.sleep(self.retry_delay * (attempt + 1)) # 指数退避
logger.error("无法获取访问令牌")
return None
class BaiduOCRClient:
"""
百度OCR API客户端
"""
def __init__(self, config: Optional[ConfigManager] = None):
"""
初始化百度OCR客户端
Args:
config: 配置管理器,如果为None则创建新的
"""
self.config = config or ConfigManager()
# 获取配置
self.api_key = self.config.get('API', 'api_key')
self.secret_key = self.config.get('API', 'secret_key')
self.timeout = self.config.getint('API', 'timeout', 30)
self.max_retries = self.config.getint('API', 'max_retries', 3)
self.retry_delay = self.config.getint('API', 'retry_delay', 2)
self.api_url = self.config.get('API', 'api_url', 'https://aip.baidubce.com/rest/2.0/ocr/v1/table')
# 创建令牌管理器
self.token_manager = TokenManager(
self.api_key,
self.secret_key,
self.max_retries,
self.retry_delay
)
# 验证API配置
if not self.api_key or not self.secret_key:
logger.warning("API密钥未设置,请在配置文件中设置API密钥")
def read_image(self, image_path: str) -> Optional[bytes]:
"""
读取图片文件为二进制数据
Args:
image_path: 图片文件路径
Returns:
图片二进制数据,如果读取失败则返回None
"""
try:
with open(image_path, 'rb') as f:
return f.read()
except Exception as e:
logger.error(f"读取图片文件失败: {image_path}, 错误: {e}")
return None
def recognize_table(self, image_data: Union[str, bytes]) -> Optional[Dict]:
"""
识别表格
Args:
image_data: 图片数据,可以是文件路径或二进制数据
Returns:
识别结果字典,如果识别失败则返回None
"""
# 获取访问令牌
access_token = self.token_manager.get_token()
if not access_token:
logger.error("无法获取访问令牌,无法进行表格识别")
return None
# 如果是文件路径,读取图片数据
if isinstance(image_data, str):
image_data = self.read_image(image_data)
if image_data is None:
return None
# 准备请求参数
url = f"{self.api_url}?access_token={access_token}"
image_base64 = base64.b64encode(image_data).decode('utf-8')
# 请求参数 - 添加return_excel参数,与v1版本保持一致
payload = {
'image': image_base64,
'is_sync': 'true', # 同步请求
'request_type': 'excel', # 输出为Excel
'return_excel': 'true' # 直接返回Excel数据
}
headers = {
'Content-Type': 'application/x-www-form-urlencoded',
'Accept': 'application/json'
}
# 发送请求
for attempt in range(self.max_retries):
try:
response = requests.post(
url,
data=payload,
headers=headers,
timeout=self.timeout
)
if response.status_code == 200:
result = response.json()
# 打印返回结果以便调试
logger.debug(f"百度OCR API返回结果: {result}")
if 'error_code' in result:
error_msg = result.get('error_msg', '未知错误')
logger.error(f"百度OCR API错误: {error_msg}")
# 如果是授权错误,尝试刷新令牌
if result.get('error_code') in [110, 111]: # 授权相关错误码
logger.info("尝试刷新访问令牌...")
self.token_manager.refresh_token()
return None
# 兼容不同的返回结构
# 这是最关键的修改部分: 直接返回整个结果,不强制要求特定结构
return result
else:
logger.warning(f"表格识别请求失败 (尝试 {attempt+1}/{self.max_retries}): {response.text}")
except Exception as e:
logger.warning(f"表格识别时发生错误 (尝试 {attempt+1}/{self.max_retries}): {e}")
# 如果不是最后一次尝试,则等待后重试
if attempt < self.max_retries - 1:
wait_time = self.retry_delay * (2 ** attempt) # 指数退避
logger.info(f"将在 {wait_time} 秒后重试...")
time.sleep(wait_time)
logger.error("表格识别失败")
return None
def get_excel_result(self, request_id_or_result: Union[str, Dict]) -> Optional[bytes]:
"""
获取Excel结果
Args:
request_id_or_result: 请求ID或完整的识别结果
Returns:
Excel二进制数据,如果获取失败则返回None
"""
# 获取访问令牌
access_token = self.token_manager.get_token()
if not access_token:
logger.error("无法获取访问令牌,无法获取Excel结果")
return None
# 处理直接传入结果对象的情况
request_id = request_id_or_result
if isinstance(request_id_or_result, dict):
# v1版本兼容处理:如果结果中直接包含Excel数据
if 'result' in request_id_or_result:
# 如果是同步返回的Excel结果(某些API版本会直接返回)
if 'result_data' in request_id_or_result['result']:
excel_content = request_id_or_result['result']['result_data']
if excel_content:
try:
return base64.b64decode(excel_content)
except Exception as e:
logger.error(f"解析Excel数据失败: {e}")
# 提取request_id
if 'request_id' in request_id_or_result['result']:
request_id = request_id_or_result['result']['request_id']
logger.debug(f"从result子对象中提取request_id: {request_id}")
elif 'tables_result' in request_id_or_result['result'] and len(request_id_or_result['result']['tables_result']) > 0:
# 某些版本API可能直接返回表格内容,此时可能没有request_id
logger.info("检测到API直接返回了表格内容,但没有request_id")
return None
# 有些版本可能request_id在顶层
elif 'request_id' in request_id_or_result:
request_id = request_id_or_result['request_id']
logger.debug(f"从顶层对象中提取request_id: {request_id}")
# 如果没有有效的request_id,无法获取结果
if not isinstance(request_id, str):
logger.error(f"无法从结果中提取有效的request_id: {request_id_or_result}")
return None
url = f"https://aip.baidubce.com/rest/2.0/solution/v1/form_ocr/get_request_result?access_token={access_token}"
payload = {
'request_id': request_id,
'result_type': 'excel'
}
headers = {
'Content-Type': 'application/x-www-form-urlencoded',
'Accept': 'application/json'
}
for attempt in range(self.max_retries):
try:
response = requests.post(
url,
data=payload,
headers=headers,
timeout=self.timeout
)
if response.status_code == 200:
try:
result = response.json()
logger.debug(f"获取Excel结果返回: {result}")
# 检查是否还在处理中
if result.get('result', {}).get('ret_code') == 3:
logger.info(f"Excel结果正在处理中,等待后重试 (尝试 {attempt+1}/{self.max_retries})")
time.sleep(2)
continue
# 检查是否有错误
if 'error_code' in result or result.get('result', {}).get('ret_code') != 0:
error_msg = result.get('error_msg') or result.get('result', {}).get('ret_msg', '未知错误')
logger.error(f"获取Excel结果失败: {error_msg}")
return None
# 获取Excel内容
excel_content = result.get('result', {}).get('result_data')
if excel_content:
return base64.b64decode(excel_content)
else:
logger.error("Excel结果为空")
return None
except Exception as e:
logger.error(f"解析Excel结果时出错: {e}")
return None
else:
logger.warning(f"获取Excel结果请求失败 (尝试 {attempt+1}/{self.max_retries}): {response.text}")
except Exception as e:
logger.warning(f"获取Excel结果时发生错误 (尝试 {attempt+1}/{self.max_retries}): {e}")
# 如果不是最后一次尝试,则等待后重试
if attempt < self.max_retries - 1:
time.sleep(self.retry_delay * (attempt + 1))
logger.error("获取Excel结果失败")
return None
+334
View File
@@ -0,0 +1,334 @@
"""
表格OCR处理模块
-------------
处理图片并提取表格内容,保存为Excel文件。
"""
import os
import sys
import time
import json
import base64
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Tuple, Union, Any
from ...config.settings import ConfigManager
from ..utils.log_utils import get_logger
from ..utils.file_utils import (
ensure_dir,
get_file_extension,
get_files_by_extensions,
generate_timestamp_filename,
is_file_size_valid,
load_json,
save_json
)
from .baidu_ocr import BaiduOCRClient
logger = get_logger(__name__)
class ProcessedRecordManager:
"""处理记录管理器,用于跟踪已处理的文件"""
def __init__(self, record_file: str):
"""
初始化处理记录管理器
Args:
record_file: 记录文件路径
"""
self.record_file = record_file
self.processed_files = self._load_record()
def _load_record(self) -> Dict[str, str]:
"""
加载处理记录
Returns:
处理记录字典,键为输入文件路径,值为输出文件路径
"""
return load_json(self.record_file, {})
def save_record(self) -> None:
"""保存处理记录"""
save_json(self.processed_files, self.record_file)
def is_processed(self, image_file: str) -> bool:
"""
检查图片是否已处理
Args:
image_file: 图片文件路径
Returns:
是否已处理
"""
return image_file in self.processed_files
def mark_as_processed(self, image_file: str, output_file: str) -> None:
"""
标记图片为已处理
Args:
image_file: 图片文件路径
output_file: 输出文件路径
"""
self.processed_files[image_file] = output_file
self.save_record()
def get_output_file(self, image_file: str) -> Optional[str]:
"""
获取图片的输出文件路径
Args:
image_file: 图片文件路径
Returns:
输出文件路径,如果不存在则返回None
"""
return self.processed_files.get(image_file)
def get_unprocessed_files(self, files: List[str]) -> List[str]:
"""
获取未处理的文件列表
Args:
files: 文件列表
Returns:
未处理的文件列表
"""
return [file for file in files if not self.is_processed(file)]
class OCRProcessor:
"""
OCR处理器,用于表格识别与处理
"""
def __init__(self, config: Optional[ConfigManager] = None):
"""
初始化OCR处理器
Args:
config: 配置管理器,如果为None则创建新的
"""
self.config = config or ConfigManager()
# 创建百度OCR客户端
self.ocr_client = BaiduOCRClient(self.config)
# 获取配置
self.input_folder = self.config.get_path('Paths', 'input_folder', 'data/input', create=True)
self.output_folder = self.config.get_path('Paths', 'output_folder', 'data/output', create=True)
self.temp_folder = self.config.get_path('Paths', 'temp_folder', 'data/temp', create=True)
# 确保目录结构正确
for folder in [self.input_folder, self.output_folder, self.temp_folder]:
if not os.path.exists(folder):
os.makedirs(folder, exist_ok=True)
logger.info(f"创建目录: {folder}")
# 记录实际路径
logger.info(f"使用输入目录: {os.path.abspath(self.input_folder)}")
logger.info(f"使用输出目录: {os.path.abspath(self.output_folder)}")
logger.info(f"使用临时目录: {os.path.abspath(self.temp_folder)}")
self.allowed_extensions = self.config.get_list('File', 'allowed_extensions', '.jpg,.jpeg,.png,.bmp')
self.max_file_size_mb = self.config.getfloat('File', 'max_file_size_mb', 4.0)
self.excel_extension = self.config.get('File', 'excel_extension', '.xlsx')
# 处理性能配置
self.max_workers = self.config.getint('Performance', 'max_workers', 4)
self.batch_size = self.config.getint('Performance', 'batch_size', 5)
self.skip_existing = self.config.getboolean('Performance', 'skip_existing', True)
# 初始化处理记录管理器
record_file = self.config.get('Paths', 'processed_record', 'data/processed_files.json')
self.record_manager = ProcessedRecordManager(record_file)
logger.info(f"OCR处理器初始化完成,输入目录: {self.input_folder}, 输出目录: {self.output_folder}")
def get_unprocessed_images(self) -> List[str]:
"""
获取未处理的图片列表
Returns:
未处理的图片文件路径列表
"""
# 获取所有图片文件
image_files = get_files_by_extensions(self.input_folder, self.allowed_extensions)
# 如果需要跳过已存在的文件
if self.skip_existing:
# 过滤已处理的文件
unprocessed_files = self.record_manager.get_unprocessed_files(image_files)
logger.info(f"找到 {len(image_files)} 个图片文件,其中 {len(unprocessed_files)} 个未处理")
return unprocessed_files
logger.info(f"找到 {len(image_files)} 个图片文件(不跳过已处理的文件)")
return image_files
def validate_image(self, image_path: str) -> bool:
"""
验证图片是否有效
Args:
image_path: 图片文件路径
Returns:
图片是否有效
"""
# 检查文件是否存在
if not os.path.exists(image_path):
logger.warning(f"图片文件不存在: {image_path}")
return False
# 检查文件扩展名
ext = get_file_extension(image_path)
if ext not in self.allowed_extensions:
logger.warning(f"不支持的文件类型: {ext}, 文件: {image_path}")
return False
# 检查文件大小
if not is_file_size_valid(image_path, self.max_file_size_mb):
logger.warning(f"文件大小超过限制 ({self.max_file_size_mb}MB): {image_path}")
return False
return True
def process_image(self, image_path: str) -> Optional[str]:
"""
处理单个图片
Args:
image_path: 图片文件路径
Returns:
输出Excel文件路径,如果处理失败则返回None
"""
# 验证图片
if not self.validate_image(image_path):
return None
# 如果需要跳过已处理的文件
if self.skip_existing and self.record_manager.is_processed(image_path):
output_file = self.record_manager.get_output_file(image_path)
logger.info(f"图片已处理,跳过: {image_path}, 输出文件: {output_file}")
return output_file
logger.info(f"开始处理图片: {image_path}")
try:
# 生成输出文件路径
file_name = os.path.splitext(os.path.basename(image_path))[0]
output_file = os.path.join(self.output_folder, f"{file_name}{self.excel_extension}")
# 检查是否已存在对应的Excel文件
if os.path.exists(output_file) and self.skip_existing:
logger.info(f"已存在对应的Excel文件,跳过处理: {os.path.basename(image_path)} -> {os.path.basename(output_file)}")
# 记录处理结果
self.record_manager.mark_as_processed(image_path, output_file)
return output_file
# 进行OCR识别
ocr_result = self.ocr_client.recognize_table(image_path)
if not ocr_result:
logger.error(f"OCR识别失败: {image_path}")
return None
# 保存Excel文件 - 按照v1版本逻辑提取Excel数据
excel_base64 = None
# 从不同可能的字段中尝试获取Excel数据
if 'excel_file' in ocr_result:
excel_base64 = ocr_result['excel_file']
logger.debug("从excel_file字段获取Excel数据")
elif 'result' in ocr_result:
if 'result_data' in ocr_result['result']:
excel_base64 = ocr_result['result']['result_data']
logger.debug("从result.result_data字段获取Excel数据")
elif 'excel_file' in ocr_result['result']:
excel_base64 = ocr_result['result']['excel_file']
logger.debug("从result.excel_file字段获取Excel数据")
elif 'tables_result' in ocr_result['result'] and ocr_result['result']['tables_result']:
for table in ocr_result['result']['tables_result']:
if 'excel_file' in table:
excel_base64 = table['excel_file']
logger.debug("从tables_result中获取Excel数据")
break
# 如果还是没有找到Excel数据,尝试通过get_excel_result获取
if not excel_base64:
logger.info("无法从直接返回中获取Excel数据,尝试通过API获取...")
excel_data = self.ocr_client.get_excel_result(ocr_result)
if not excel_data:
logger.error(f"获取Excel结果失败: {image_path}")
return None
# 保存Excel文件
os.makedirs(os.path.dirname(output_file), exist_ok=True)
with open(output_file, 'wb') as f:
f.write(excel_data)
else:
# 解码并保存Excel文件
try:
excel_data = base64.b64decode(excel_base64)
os.makedirs(os.path.dirname(output_file), exist_ok=True)
with open(output_file, 'wb') as f:
f.write(excel_data)
except Exception as e:
logger.error(f"解码或保存Excel数据时出错: {e}")
return None
logger.info(f"图片处理成功: {image_path}, 输出文件: {output_file}")
# 标记为已处理
self.record_manager.mark_as_processed(image_path, output_file)
return output_file
except Exception as e:
logger.error(f"处理图片时出错: {image_path}, 错误: {e}")
return None
def process_images_batch(self, batch_size: int = None, max_workers: int = None) -> Tuple[int, int]:
"""
批量处理图片
Args:
batch_size: 批处理大小,如果为None则使用配置值
max_workers: 最大线程数,如果为None则使用配置值
Returns:
(总处理数, 成功处理数)元组
"""
# 使用配置值或参数值
batch_size = batch_size or self.batch_size
max_workers = max_workers or self.max_workers
# 获取未处理的图片
unprocessed_images = self.get_unprocessed_images()
if not unprocessed_images:
logger.warning("没有需要处理的图片")
return 0, 0
total = len(unprocessed_images)
success = 0
# 按批次处理
for i in range(0, total, batch_size):
batch = unprocessed_images[i:i + batch_size]
logger.info(f"处理批次 {i//batch_size + 1}/{(total-1)//batch_size + 1}, 大小: {len(batch)}")
# 使用线程池并行处理
with ThreadPoolExecutor(max_workers=max_workers) as executor:
results = list(executor.map(self.process_image, batch))
# 统计成功数
success += sum(1 for result in results if result is not None)
logger.info(f"批次处理完成, 成功: {sum(1 for result in results if result is not None)}/{len(batch)}")
logger.info(f"所有图片处理完成, 总计: {total}, 成功: {success}")
return total, success