增强版v2-初始化仓库,验证好了ocr部分,先备份一次
This commit is contained in:
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
OCR订单处理系统 - OCR核心模块
|
||||
---------------------------
|
||||
提供OCR识别相关功能,包括图片预处理、文字识别和表格识别。
|
||||
"""
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,344 @@
|
||||
"""
|
||||
百度OCR客户端模块
|
||||
---------------
|
||||
提供百度OCR API的访问和调用功能。
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import base64
|
||||
import requests
|
||||
import logging
|
||||
from typing import Dict, Optional, Any, Union
|
||||
|
||||
from ...config.settings import ConfigManager
|
||||
from ..utils.log_utils import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class TokenManager:
|
||||
"""
|
||||
令牌管理类,负责获取和刷新百度API访问令牌
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: str, secret_key: str, max_retries: int = 3, retry_delay: int = 2):
|
||||
"""
|
||||
初始化令牌管理器
|
||||
|
||||
Args:
|
||||
api_key: 百度API Key
|
||||
secret_key: 百度Secret Key
|
||||
max_retries: 最大重试次数
|
||||
retry_delay: 重试延迟(秒)
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.secret_key = secret_key
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
self.access_token = None
|
||||
self.token_expiry = 0
|
||||
|
||||
def get_token(self) -> Optional[str]:
|
||||
"""
|
||||
获取访问令牌,如果令牌已过期则刷新
|
||||
|
||||
Returns:
|
||||
访问令牌,如果获取失败则返回None
|
||||
"""
|
||||
if self.is_token_valid():
|
||||
return self.access_token
|
||||
|
||||
return self.refresh_token()
|
||||
|
||||
def is_token_valid(self) -> bool:
|
||||
"""
|
||||
检查令牌是否有效
|
||||
|
||||
Returns:
|
||||
令牌是否有效
|
||||
"""
|
||||
return (
|
||||
self.access_token is not None and
|
||||
self.token_expiry > time.time() + 60 # 提前1分钟刷新
|
||||
)
|
||||
|
||||
def refresh_token(self) -> Optional[str]:
|
||||
"""
|
||||
刷新访问令牌
|
||||
|
||||
Returns:
|
||||
新的访问令牌,如果获取失败则返回None
|
||||
"""
|
||||
url = "https://aip.baidubce.com/oauth/2.0/token"
|
||||
params = {
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": self.api_key,
|
||||
"client_secret": self.secret_key
|
||||
}
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
response = requests.post(url, params=params, timeout=10)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
if "access_token" in result:
|
||||
self.access_token = result["access_token"]
|
||||
# 设置令牌过期时间(默认30天,提前1小时过期以确保安全)
|
||||
self.token_expiry = time.time() + result.get("expires_in", 2592000) - 3600
|
||||
logger.info("成功获取访问令牌")
|
||||
return self.access_token
|
||||
|
||||
logger.warning(f"获取访问令牌失败 (尝试 {attempt+1}/{self.max_retries}): {response.text}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"获取访问令牌时发生错误 (尝试 {attempt+1}/{self.max_retries}): {e}")
|
||||
|
||||
# 如果不是最后一次尝试,则等待后重试
|
||||
if attempt < self.max_retries - 1:
|
||||
time.sleep(self.retry_delay * (attempt + 1)) # 指数退避
|
||||
|
||||
logger.error("无法获取访问令牌")
|
||||
return None
|
||||
|
||||
class BaiduOCRClient:
|
||||
"""
|
||||
百度OCR API客户端
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[ConfigManager] = None):
|
||||
"""
|
||||
初始化百度OCR客户端
|
||||
|
||||
Args:
|
||||
config: 配置管理器,如果为None则创建新的
|
||||
"""
|
||||
self.config = config or ConfigManager()
|
||||
|
||||
# 获取配置
|
||||
self.api_key = self.config.get('API', 'api_key')
|
||||
self.secret_key = self.config.get('API', 'secret_key')
|
||||
self.timeout = self.config.getint('API', 'timeout', 30)
|
||||
self.max_retries = self.config.getint('API', 'max_retries', 3)
|
||||
self.retry_delay = self.config.getint('API', 'retry_delay', 2)
|
||||
self.api_url = self.config.get('API', 'api_url', 'https://aip.baidubce.com/rest/2.0/ocr/v1/table')
|
||||
|
||||
# 创建令牌管理器
|
||||
self.token_manager = TokenManager(
|
||||
self.api_key,
|
||||
self.secret_key,
|
||||
self.max_retries,
|
||||
self.retry_delay
|
||||
)
|
||||
|
||||
# 验证API配置
|
||||
if not self.api_key or not self.secret_key:
|
||||
logger.warning("API密钥未设置,请在配置文件中设置API密钥")
|
||||
|
||||
def read_image(self, image_path: str) -> Optional[bytes]:
|
||||
"""
|
||||
读取图片文件为二进制数据
|
||||
|
||||
Args:
|
||||
image_path: 图片文件路径
|
||||
|
||||
Returns:
|
||||
图片二进制数据,如果读取失败则返回None
|
||||
"""
|
||||
try:
|
||||
with open(image_path, 'rb') as f:
|
||||
return f.read()
|
||||
except Exception as e:
|
||||
logger.error(f"读取图片文件失败: {image_path}, 错误: {e}")
|
||||
return None
|
||||
|
||||
def recognize_table(self, image_data: Union[str, bytes]) -> Optional[Dict]:
|
||||
"""
|
||||
识别表格
|
||||
|
||||
Args:
|
||||
image_data: 图片数据,可以是文件路径或二进制数据
|
||||
|
||||
Returns:
|
||||
识别结果字典,如果识别失败则返回None
|
||||
"""
|
||||
# 获取访问令牌
|
||||
access_token = self.token_manager.get_token()
|
||||
if not access_token:
|
||||
logger.error("无法获取访问令牌,无法进行表格识别")
|
||||
return None
|
||||
|
||||
# 如果是文件路径,读取图片数据
|
||||
if isinstance(image_data, str):
|
||||
image_data = self.read_image(image_data)
|
||||
if image_data is None:
|
||||
return None
|
||||
|
||||
# 准备请求参数
|
||||
url = f"{self.api_url}?access_token={access_token}"
|
||||
image_base64 = base64.b64encode(image_data).decode('utf-8')
|
||||
|
||||
# 请求参数 - 添加return_excel参数,与v1版本保持一致
|
||||
payload = {
|
||||
'image': image_base64,
|
||||
'is_sync': 'true', # 同步请求
|
||||
'request_type': 'excel', # 输出为Excel
|
||||
'return_excel': 'true' # 直接返回Excel数据
|
||||
}
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
'Accept': 'application/json'
|
||||
}
|
||||
|
||||
# 发送请求
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
data=payload,
|
||||
headers=headers,
|
||||
timeout=self.timeout
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
# 打印返回结果以便调试
|
||||
logger.debug(f"百度OCR API返回结果: {result}")
|
||||
|
||||
if 'error_code' in result:
|
||||
error_msg = result.get('error_msg', '未知错误')
|
||||
logger.error(f"百度OCR API错误: {error_msg}")
|
||||
# 如果是授权错误,尝试刷新令牌
|
||||
if result.get('error_code') in [110, 111]: # 授权相关错误码
|
||||
logger.info("尝试刷新访问令牌...")
|
||||
self.token_manager.refresh_token()
|
||||
return None
|
||||
|
||||
# 兼容不同的返回结构
|
||||
# 这是最关键的修改部分: 直接返回整个结果,不强制要求特定结构
|
||||
return result
|
||||
else:
|
||||
logger.warning(f"表格识别请求失败 (尝试 {attempt+1}/{self.max_retries}): {response.text}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"表格识别时发生错误 (尝试 {attempt+1}/{self.max_retries}): {e}")
|
||||
|
||||
# 如果不是最后一次尝试,则等待后重试
|
||||
if attempt < self.max_retries - 1:
|
||||
wait_time = self.retry_delay * (2 ** attempt) # 指数退避
|
||||
logger.info(f"将在 {wait_time} 秒后重试...")
|
||||
time.sleep(wait_time)
|
||||
|
||||
logger.error("表格识别失败")
|
||||
return None
|
||||
|
||||
def get_excel_result(self, request_id_or_result: Union[str, Dict]) -> Optional[bytes]:
|
||||
"""
|
||||
获取Excel结果
|
||||
|
||||
Args:
|
||||
request_id_or_result: 请求ID或完整的识别结果
|
||||
|
||||
Returns:
|
||||
Excel二进制数据,如果获取失败则返回None
|
||||
"""
|
||||
# 获取访问令牌
|
||||
access_token = self.token_manager.get_token()
|
||||
if not access_token:
|
||||
logger.error("无法获取访问令牌,无法获取Excel结果")
|
||||
return None
|
||||
|
||||
# 处理直接传入结果对象的情况
|
||||
request_id = request_id_or_result
|
||||
if isinstance(request_id_or_result, dict):
|
||||
# v1版本兼容处理:如果结果中直接包含Excel数据
|
||||
if 'result' in request_id_or_result:
|
||||
# 如果是同步返回的Excel结果(某些API版本会直接返回)
|
||||
if 'result_data' in request_id_or_result['result']:
|
||||
excel_content = request_id_or_result['result']['result_data']
|
||||
if excel_content:
|
||||
try:
|
||||
return base64.b64decode(excel_content)
|
||||
except Exception as e:
|
||||
logger.error(f"解析Excel数据失败: {e}")
|
||||
|
||||
# 提取request_id
|
||||
if 'request_id' in request_id_or_result['result']:
|
||||
request_id = request_id_or_result['result']['request_id']
|
||||
logger.debug(f"从result子对象中提取request_id: {request_id}")
|
||||
elif 'tables_result' in request_id_or_result['result'] and len(request_id_or_result['result']['tables_result']) > 0:
|
||||
# 某些版本API可能直接返回表格内容,此时可能没有request_id
|
||||
logger.info("检测到API直接返回了表格内容,但没有request_id")
|
||||
return None
|
||||
# 有些版本可能request_id在顶层
|
||||
elif 'request_id' in request_id_or_result:
|
||||
request_id = request_id_or_result['request_id']
|
||||
logger.debug(f"从顶层对象中提取request_id: {request_id}")
|
||||
|
||||
# 如果没有有效的request_id,无法获取结果
|
||||
if not isinstance(request_id, str):
|
||||
logger.error(f"无法从结果中提取有效的request_id: {request_id_or_result}")
|
||||
return None
|
||||
|
||||
url = f"https://aip.baidubce.com/rest/2.0/solution/v1/form_ocr/get_request_result?access_token={access_token}"
|
||||
|
||||
payload = {
|
||||
'request_id': request_id,
|
||||
'result_type': 'excel'
|
||||
}
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
'Accept': 'application/json'
|
||||
}
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
data=payload,
|
||||
headers=headers,
|
||||
timeout=self.timeout
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
try:
|
||||
result = response.json()
|
||||
logger.debug(f"获取Excel结果返回: {result}")
|
||||
|
||||
# 检查是否还在处理中
|
||||
if result.get('result', {}).get('ret_code') == 3:
|
||||
logger.info(f"Excel结果正在处理中,等待后重试 (尝试 {attempt+1}/{self.max_retries})")
|
||||
time.sleep(2)
|
||||
continue
|
||||
|
||||
# 检查是否有错误
|
||||
if 'error_code' in result or result.get('result', {}).get('ret_code') != 0:
|
||||
error_msg = result.get('error_msg') or result.get('result', {}).get('ret_msg', '未知错误')
|
||||
logger.error(f"获取Excel结果失败: {error_msg}")
|
||||
return None
|
||||
|
||||
# 获取Excel内容
|
||||
excel_content = result.get('result', {}).get('result_data')
|
||||
if excel_content:
|
||||
return base64.b64decode(excel_content)
|
||||
else:
|
||||
logger.error("Excel结果为空")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析Excel结果时出错: {e}")
|
||||
return None
|
||||
|
||||
else:
|
||||
logger.warning(f"获取Excel结果请求失败 (尝试 {attempt+1}/{self.max_retries}): {response.text}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"获取Excel结果时发生错误 (尝试 {attempt+1}/{self.max_retries}): {e}")
|
||||
|
||||
# 如果不是最后一次尝试,则等待后重试
|
||||
if attempt < self.max_retries - 1:
|
||||
time.sleep(self.retry_delay * (attempt + 1))
|
||||
|
||||
logger.error("获取Excel结果失败")
|
||||
return None
|
||||
@@ -0,0 +1,334 @@
|
||||
"""
|
||||
表格OCR处理模块
|
||||
-------------
|
||||
处理图片并提取表格内容,保存为Excel文件。
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
import base64
|
||||
from datetime import datetime
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Dict, List, Optional, Tuple, Union, Any
|
||||
|
||||
from ...config.settings import ConfigManager
|
||||
from ..utils.log_utils import get_logger
|
||||
from ..utils.file_utils import (
|
||||
ensure_dir,
|
||||
get_file_extension,
|
||||
get_files_by_extensions,
|
||||
generate_timestamp_filename,
|
||||
is_file_size_valid,
|
||||
load_json,
|
||||
save_json
|
||||
)
|
||||
from .baidu_ocr import BaiduOCRClient
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class ProcessedRecordManager:
|
||||
"""处理记录管理器,用于跟踪已处理的文件"""
|
||||
|
||||
def __init__(self, record_file: str):
|
||||
"""
|
||||
初始化处理记录管理器
|
||||
|
||||
Args:
|
||||
record_file: 记录文件路径
|
||||
"""
|
||||
self.record_file = record_file
|
||||
self.processed_files = self._load_record()
|
||||
|
||||
def _load_record(self) -> Dict[str, str]:
|
||||
"""
|
||||
加载处理记录
|
||||
|
||||
Returns:
|
||||
处理记录字典,键为输入文件路径,值为输出文件路径
|
||||
"""
|
||||
return load_json(self.record_file, {})
|
||||
|
||||
def save_record(self) -> None:
|
||||
"""保存处理记录"""
|
||||
save_json(self.processed_files, self.record_file)
|
||||
|
||||
def is_processed(self, image_file: str) -> bool:
|
||||
"""
|
||||
检查图片是否已处理
|
||||
|
||||
Args:
|
||||
image_file: 图片文件路径
|
||||
|
||||
Returns:
|
||||
是否已处理
|
||||
"""
|
||||
return image_file in self.processed_files
|
||||
|
||||
def mark_as_processed(self, image_file: str, output_file: str) -> None:
|
||||
"""
|
||||
标记图片为已处理
|
||||
|
||||
Args:
|
||||
image_file: 图片文件路径
|
||||
output_file: 输出文件路径
|
||||
"""
|
||||
self.processed_files[image_file] = output_file
|
||||
self.save_record()
|
||||
|
||||
def get_output_file(self, image_file: str) -> Optional[str]:
|
||||
"""
|
||||
获取图片的输出文件路径
|
||||
|
||||
Args:
|
||||
image_file: 图片文件路径
|
||||
|
||||
Returns:
|
||||
输出文件路径,如果不存在则返回None
|
||||
"""
|
||||
return self.processed_files.get(image_file)
|
||||
|
||||
def get_unprocessed_files(self, files: List[str]) -> List[str]:
|
||||
"""
|
||||
获取未处理的文件列表
|
||||
|
||||
Args:
|
||||
files: 文件列表
|
||||
|
||||
Returns:
|
||||
未处理的文件列表
|
||||
"""
|
||||
return [file for file in files if not self.is_processed(file)]
|
||||
|
||||
class OCRProcessor:
|
||||
"""
|
||||
OCR处理器,用于表格识别与处理
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[ConfigManager] = None):
|
||||
"""
|
||||
初始化OCR处理器
|
||||
|
||||
Args:
|
||||
config: 配置管理器,如果为None则创建新的
|
||||
"""
|
||||
self.config = config or ConfigManager()
|
||||
|
||||
# 创建百度OCR客户端
|
||||
self.ocr_client = BaiduOCRClient(self.config)
|
||||
|
||||
# 获取配置
|
||||
self.input_folder = self.config.get_path('Paths', 'input_folder', 'data/input', create=True)
|
||||
self.output_folder = self.config.get_path('Paths', 'output_folder', 'data/output', create=True)
|
||||
self.temp_folder = self.config.get_path('Paths', 'temp_folder', 'data/temp', create=True)
|
||||
|
||||
# 确保目录结构正确
|
||||
for folder in [self.input_folder, self.output_folder, self.temp_folder]:
|
||||
if not os.path.exists(folder):
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
logger.info(f"创建目录: {folder}")
|
||||
|
||||
# 记录实际路径
|
||||
logger.info(f"使用输入目录: {os.path.abspath(self.input_folder)}")
|
||||
logger.info(f"使用输出目录: {os.path.abspath(self.output_folder)}")
|
||||
logger.info(f"使用临时目录: {os.path.abspath(self.temp_folder)}")
|
||||
|
||||
self.allowed_extensions = self.config.get_list('File', 'allowed_extensions', '.jpg,.jpeg,.png,.bmp')
|
||||
self.max_file_size_mb = self.config.getfloat('File', 'max_file_size_mb', 4.0)
|
||||
self.excel_extension = self.config.get('File', 'excel_extension', '.xlsx')
|
||||
|
||||
# 处理性能配置
|
||||
self.max_workers = self.config.getint('Performance', 'max_workers', 4)
|
||||
self.batch_size = self.config.getint('Performance', 'batch_size', 5)
|
||||
self.skip_existing = self.config.getboolean('Performance', 'skip_existing', True)
|
||||
|
||||
# 初始化处理记录管理器
|
||||
record_file = self.config.get('Paths', 'processed_record', 'data/processed_files.json')
|
||||
self.record_manager = ProcessedRecordManager(record_file)
|
||||
|
||||
logger.info(f"OCR处理器初始化完成,输入目录: {self.input_folder}, 输出目录: {self.output_folder}")
|
||||
|
||||
def get_unprocessed_images(self) -> List[str]:
|
||||
"""
|
||||
获取未处理的图片列表
|
||||
|
||||
Returns:
|
||||
未处理的图片文件路径列表
|
||||
"""
|
||||
# 获取所有图片文件
|
||||
image_files = get_files_by_extensions(self.input_folder, self.allowed_extensions)
|
||||
|
||||
# 如果需要跳过已存在的文件
|
||||
if self.skip_existing:
|
||||
# 过滤已处理的文件
|
||||
unprocessed_files = self.record_manager.get_unprocessed_files(image_files)
|
||||
logger.info(f"找到 {len(image_files)} 个图片文件,其中 {len(unprocessed_files)} 个未处理")
|
||||
return unprocessed_files
|
||||
|
||||
logger.info(f"找到 {len(image_files)} 个图片文件(不跳过已处理的文件)")
|
||||
return image_files
|
||||
|
||||
def validate_image(self, image_path: str) -> bool:
|
||||
"""
|
||||
验证图片是否有效
|
||||
|
||||
Args:
|
||||
image_path: 图片文件路径
|
||||
|
||||
Returns:
|
||||
图片是否有效
|
||||
"""
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(image_path):
|
||||
logger.warning(f"图片文件不存在: {image_path}")
|
||||
return False
|
||||
|
||||
# 检查文件扩展名
|
||||
ext = get_file_extension(image_path)
|
||||
if ext not in self.allowed_extensions:
|
||||
logger.warning(f"不支持的文件类型: {ext}, 文件: {image_path}")
|
||||
return False
|
||||
|
||||
# 检查文件大小
|
||||
if not is_file_size_valid(image_path, self.max_file_size_mb):
|
||||
logger.warning(f"文件大小超过限制 ({self.max_file_size_mb}MB): {image_path}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def process_image(self, image_path: str) -> Optional[str]:
|
||||
"""
|
||||
处理单个图片
|
||||
|
||||
Args:
|
||||
image_path: 图片文件路径
|
||||
|
||||
Returns:
|
||||
输出Excel文件路径,如果处理失败则返回None
|
||||
"""
|
||||
# 验证图片
|
||||
if not self.validate_image(image_path):
|
||||
return None
|
||||
|
||||
# 如果需要跳过已处理的文件
|
||||
if self.skip_existing and self.record_manager.is_processed(image_path):
|
||||
output_file = self.record_manager.get_output_file(image_path)
|
||||
logger.info(f"图片已处理,跳过: {image_path}, 输出文件: {output_file}")
|
||||
return output_file
|
||||
|
||||
logger.info(f"开始处理图片: {image_path}")
|
||||
|
||||
try:
|
||||
# 生成输出文件路径
|
||||
file_name = os.path.splitext(os.path.basename(image_path))[0]
|
||||
output_file = os.path.join(self.output_folder, f"{file_name}{self.excel_extension}")
|
||||
|
||||
# 检查是否已存在对应的Excel文件
|
||||
if os.path.exists(output_file) and self.skip_existing:
|
||||
logger.info(f"已存在对应的Excel文件,跳过处理: {os.path.basename(image_path)} -> {os.path.basename(output_file)}")
|
||||
# 记录处理结果
|
||||
self.record_manager.mark_as_processed(image_path, output_file)
|
||||
return output_file
|
||||
|
||||
# 进行OCR识别
|
||||
ocr_result = self.ocr_client.recognize_table(image_path)
|
||||
if not ocr_result:
|
||||
logger.error(f"OCR识别失败: {image_path}")
|
||||
return None
|
||||
|
||||
# 保存Excel文件 - 按照v1版本逻辑提取Excel数据
|
||||
excel_base64 = None
|
||||
|
||||
# 从不同可能的字段中尝试获取Excel数据
|
||||
if 'excel_file' in ocr_result:
|
||||
excel_base64 = ocr_result['excel_file']
|
||||
logger.debug("从excel_file字段获取Excel数据")
|
||||
elif 'result' in ocr_result:
|
||||
if 'result_data' in ocr_result['result']:
|
||||
excel_base64 = ocr_result['result']['result_data']
|
||||
logger.debug("从result.result_data字段获取Excel数据")
|
||||
elif 'excel_file' in ocr_result['result']:
|
||||
excel_base64 = ocr_result['result']['excel_file']
|
||||
logger.debug("从result.excel_file字段获取Excel数据")
|
||||
elif 'tables_result' in ocr_result['result'] and ocr_result['result']['tables_result']:
|
||||
for table in ocr_result['result']['tables_result']:
|
||||
if 'excel_file' in table:
|
||||
excel_base64 = table['excel_file']
|
||||
logger.debug("从tables_result中获取Excel数据")
|
||||
break
|
||||
|
||||
# 如果还是没有找到Excel数据,尝试通过get_excel_result获取
|
||||
if not excel_base64:
|
||||
logger.info("无法从直接返回中获取Excel数据,尝试通过API获取...")
|
||||
excel_data = self.ocr_client.get_excel_result(ocr_result)
|
||||
if not excel_data:
|
||||
logger.error(f"获取Excel结果失败: {image_path}")
|
||||
return None
|
||||
|
||||
# 保存Excel文件
|
||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||
with open(output_file, 'wb') as f:
|
||||
f.write(excel_data)
|
||||
else:
|
||||
# 解码并保存Excel文件
|
||||
try:
|
||||
excel_data = base64.b64decode(excel_base64)
|
||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||
with open(output_file, 'wb') as f:
|
||||
f.write(excel_data)
|
||||
except Exception as e:
|
||||
logger.error(f"解码或保存Excel数据时出错: {e}")
|
||||
return None
|
||||
|
||||
logger.info(f"图片处理成功: {image_path}, 输出文件: {output_file}")
|
||||
|
||||
# 标记为已处理
|
||||
self.record_manager.mark_as_processed(image_path, output_file)
|
||||
|
||||
return output_file
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片时出错: {image_path}, 错误: {e}")
|
||||
return None
|
||||
|
||||
def process_images_batch(self, batch_size: int = None, max_workers: int = None) -> Tuple[int, int]:
|
||||
"""
|
||||
批量处理图片
|
||||
|
||||
Args:
|
||||
batch_size: 批处理大小,如果为None则使用配置值
|
||||
max_workers: 最大线程数,如果为None则使用配置值
|
||||
|
||||
Returns:
|
||||
(总处理数, 成功处理数)元组
|
||||
"""
|
||||
# 使用配置值或参数值
|
||||
batch_size = batch_size or self.batch_size
|
||||
max_workers = max_workers or self.max_workers
|
||||
|
||||
# 获取未处理的图片
|
||||
unprocessed_images = self.get_unprocessed_images()
|
||||
if not unprocessed_images:
|
||||
logger.warning("没有需要处理的图片")
|
||||
return 0, 0
|
||||
|
||||
total = len(unprocessed_images)
|
||||
success = 0
|
||||
|
||||
# 按批次处理
|
||||
for i in range(0, total, batch_size):
|
||||
batch = unprocessed_images[i:i + batch_size]
|
||||
logger.info(f"处理批次 {i//batch_size + 1}/{(total-1)//batch_size + 1}, 大小: {len(batch)}")
|
||||
|
||||
# 使用线程池并行处理
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
results = list(executor.map(self.process_image, batch))
|
||||
|
||||
# 统计成功数
|
||||
success += sum(1 for result in results if result is not None)
|
||||
|
||||
logger.info(f"批次处理完成, 成功: {sum(1 for result in results if result is not None)}/{len(batch)}")
|
||||
|
||||
logger.info(f"所有图片处理完成, 总计: {total}, 成功: {success}")
|
||||
return total, success
|
||||
Reference in New Issue
Block a user