feat: 商品记忆库 — 从OCR结果学习,逐步替代OCR识别

- 扩展 product_db.py: schema迁移(specification/source/confidence/usage_count/last_seen)
  + 学习逻辑(learn_from_product)、置信度系统、批量查询、导入导出、云端同步
- 注入处理管线: processor.py 在提取产品后调用 _apply_memory() 用记忆补全OCR
  + _is_spec_suspicious() 检测OCR规格质量,处理完后自动学习
- order_service.py 创建共享 ProductDatabase 实例
- dialog_utils.py 新增商品记忆库云端同步条目
- 新建 memory_editor.py: Treeview查看/编辑/搜索/删除/重新导入
- main_window.py 系统设置区新增"商品记忆库"按钮
- build_exe.py 添加 memory_editor 到 hidden_imports
@
This commit is contained in:
2026-05-05 02:40:48 +08:00
parent 5cf9a98d9a
commit d267a1d1fa
8 changed files with 656 additions and 44 deletions
+342 -26
View File
@@ -1,11 +1,18 @@
"""
商品资料 SQLite 数据库
商品资料 SQLite 数据库 + 商品记忆库
将商品资料 (条码/名称/进货价/单位) 存储在 SQLite 中,
支持从 Excel 自动导入按条码快速查询。
将商品资料 (条码/名称/进货价/单位/规格) 存储在 SQLite 中,
支持从 Excel 自动导入按条码快速查询、以及从 OCR 处理结果中学习
记忆库功能:
- 处理完每单后自动学习商品数据
- 下次处理时用记忆库补全 OCR 缺失/错误的字段
- 通过置信度系统控制数据质量
- 支持云端同步
"""
import os
import json
import sqlite3
from datetime import datetime
from typing import Dict, List, Optional
@@ -20,7 +27,7 @@ logger = get_logger(__name__)
class ProductDatabase:
"""商品资料 SQLite 数据库"""
"""商品资料 SQLite 数据库 + 商品记忆库"""
SCHEMA = """
CREATE TABLE IF NOT EXISTS products (
@@ -28,10 +35,24 @@ class ProductDatabase:
name TEXT DEFAULT '',
price REAL DEFAULT 0.0,
unit TEXT DEFAULT '',
updated_at TEXT
updated_at TEXT,
specification TEXT DEFAULT '',
source TEXT DEFAULT 'template',
confidence INTEGER DEFAULT 0,
usage_count INTEGER DEFAULT 0,
last_seen TEXT
);
"""
# 新增列定义(用于迁移)
_NEW_COLUMNS = {
'specification': "TEXT DEFAULT ''",
'source': "TEXT DEFAULT 'template'",
'confidence': 'INTEGER DEFAULT 0',
'usage_count': 'INTEGER DEFAULT 0',
'last_seen': 'TEXT',
}
def __init__(self, db_path: str, excel_source: str):
"""初始化数据库,如果 SQLite 不存在则自动从 Excel 导入
@@ -49,6 +70,7 @@ class ProductDatabase:
def _ensure_db(self):
"""确保数据库存在,不存在则从 Excel 导入"""
if os.path.exists(self.db_path):
self._migrate_schema()
return
if not os.path.exists(self.excel_source):
@@ -71,8 +93,24 @@ class ProductDatabase:
finally:
conn.close()
def _migrate_schema(self):
"""幂等迁移:为已有数据库添加新列"""
conn = self._connect()
try:
cursor = conn.execute("PRAGMA table_info(products)")
existing_cols = {row[1] for row in cursor.fetchall()}
for col_name, col_type in self._NEW_COLUMNS.items():
if col_name not in existing_cols:
conn.execute(f"ALTER TABLE products ADD COLUMN {col_name} {col_type}")
logger.info(f"数据库迁移: 添加列 {col_name}")
conn.commit()
finally:
conn.close()
def import_from_excel(self, excel_path: str) -> int:
"""从 Excel 导入商品资料
"""从 Excel 导入商品资料source=template, confidence=100
Args:
excel_path: Excel 文件路径
@@ -101,9 +139,10 @@ class ProductDatabase:
price_col = col
break
# 查找名称列单位列 (可选)
# 查找名称列单位列、规格列 (可选)
name_col = ColumnMapper.find_column(list(df.columns), 'name')
unit_col = ColumnMapper.find_column(list(df.columns), 'unit')
spec_col = ColumnMapper.find_column(list(df.columns), 'specification')
now = datetime.now().isoformat()
rows = []
@@ -127,8 +166,11 @@ class ProductDatabase:
unit = str(row.get(unit_col, '')).strip() if unit_col else ''
if unit == 'nan':
unit = ''
spec = str(row.get(spec_col, '')).strip() if spec_col else ''
if spec == 'nan':
spec = ''
rows.append((barcode, name, price, unit, now))
rows.append((barcode, name, price, unit, now, spec, 'template', 100, 0, now))
if not rows:
logger.warning(f"Excel 中未解析出有效记录: {excel_path}")
@@ -137,8 +179,9 @@ class ProductDatabase:
conn = self._connect()
try:
conn.executemany(
"INSERT OR REPLACE INTO products (barcode, name, price, unit, updated_at) "
"VALUES (?, ?, ?, ?, ?)",
"INSERT OR REPLACE INTO products "
"(barcode, name, price, unit, updated_at, specification, source, confidence, usage_count, last_seen) "
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
rows
)
conn.commit()
@@ -161,15 +204,10 @@ class ProductDatabase:
conn.close()
return self.import_from_excel(self.excel_source)
# ── 基础查询(保持兼容) ──────────────────────────────────
def get_price(self, barcode: str) -> Optional[float]:
"""按条码查询进货价
Args:
barcode: 商品条码
Returns:
进货价,未找到返回 None
"""
"""按条码查询进货价"""
conn = self._connect()
try:
cursor = conn.execute(
@@ -182,14 +220,7 @@ class ProductDatabase:
conn.close()
def get_prices(self, barcodes: List[str]) -> Dict[str, float]:
"""批量查询进货价
Args:
barcodes: 条码列表
Returns:
{条码: 进货价} 字典,未找到的不包含
"""
"""批量查询进货价"""
if not barcodes:
return {}
@@ -212,3 +243,288 @@ class ProductDatabase:
return cursor.fetchone()[0]
finally:
conn.close()
# ── 记忆库查询 ────────────────────────────────────────────
def get_memory(self, barcode: str) -> Optional[Dict]:
"""查询单条商品记忆"""
conn = self._connect()
conn.row_factory = sqlite3.Row
try:
cursor = conn.execute(
"SELECT * FROM products WHERE barcode = ?",
(str(barcode).strip(),)
)
row = cursor.fetchone()
if row:
return dict(row)
return None
finally:
conn.close()
def get_memories(self, barcodes: List[str]) -> Dict[str, Dict]:
"""批量查询商品记忆"""
if not barcodes:
return {}
conn = self._connect()
conn.row_factory = sqlite3.Row
try:
placeholders = ','.join('?' * len(barcodes))
cursor = conn.execute(
f"SELECT * FROM products WHERE barcode IN ({placeholders})",
[str(b).strip() for b in barcodes]
)
return {row['barcode']: dict(row) for row in cursor.fetchall()}
finally:
conn.close()
def get_all_memories(self) -> List[Dict]:
"""返回全部记录(UI 用)"""
conn = self._connect()
conn.row_factory = sqlite3.Row
try:
cursor = conn.execute(
"SELECT * FROM products ORDER BY usage_count DESC, barcode"
)
return [dict(row) for row in cursor.fetchall()]
finally:
conn.close()
# ── 学习逻辑 ──────────────────────────────────────────────
def learn_from_product(self, product: Dict, source: str = 'ocr') -> None:
"""从处理结果中学习单条商品数据
Args:
product: 商品字典 (barcode, name, specification, unit, price, ...)
source: 数据来源 ('template', 'ocr', 'user_confirmed')
"""
barcode = str(product.get('barcode', '')).strip()
if not barcode:
return
now = datetime.now().isoformat()
name = str(product.get('name', ''))
spec = str(product.get('specification', ''))
unit = str(product.get('unit', ''))
price = float(product.get('price', 0))
conn = self._connect()
try:
cursor = conn.execute(
"SELECT confidence, usage_count FROM products WHERE barcode = ?",
(barcode,)
)
row = cursor.fetchone()
if row is None:
# 新记录
conf = {'template': 100, 'user_confirmed': 90}.get(source, 50)
conn.execute(
"INSERT INTO products "
"(barcode, name, specification, unit, price, source, confidence, usage_count, last_seen, updated_at) "
"VALUES (?, ?, ?, ?, ?, ?, ?, 1, ?, ?)",
(barcode, name, spec, unit, price, source, conf, now, now)
)
else:
old_conf, old_count = row
new_count = old_count + 1
if source == 'template':
new_conf = 100
elif source == 'user_confirmed':
new_conf = 90
else: # ocr
new_conf = min(80, old_conf + 10) if old_conf < 80 else old_conf
if source in ('template', 'user_confirmed'):
# 高权威来源:全字段覆盖
conn.execute(
"UPDATE products SET name=?, specification=?, unit=?, price=?, "
"source=?, confidence=?, usage_count=?, last_seen=?, updated_at=? "
"WHERE barcode=?",
(name, spec, unit, price, source, new_conf, new_count, now, now, barcode)
)
else:
# OCR:仅填充空字段,不更新 price
conn.execute(
"UPDATE products SET "
"name = CASE WHEN name='' THEN ? ELSE name END, "
"specification = CASE WHEN specification='' THEN ? ELSE specification END, "
"unit = CASE WHEN unit='' THEN ? ELSE unit END, "
"source=?, confidence=?, usage_count=?, last_seen=?, updated_at=? "
"WHERE barcode=?",
(name, spec, unit, source, new_conf, new_count, now, now, barcode)
)
conn.commit()
finally:
conn.close()
def learn_from_products(self, products: List[Dict], source: str = 'ocr') -> int:
"""批量学习,返回更新条数"""
count = 0
for p in products:
try:
self.learn_from_product(p, source)
count += 1
except Exception as e:
logger.warning(f"学习商品记忆失败: {e}")
return count
def update_memory(self, barcode: str, fields: Dict) -> bool:
"""手动编辑记录(UI 用,source→user_confirmed, confidence→90"""
barcode = str(barcode).strip()
if not barcode:
return False
allowed = {'name', 'specification', 'unit', 'price'}
updates = {k: v for k, v in fields.items() if k in allowed}
if not updates:
return False
now = datetime.now().isoformat()
set_clause = ', '.join(f"{k}=?" for k in updates)
values = list(updates.values())
conn = self._connect()
try:
conn.execute(
f"UPDATE products SET {set_clause}, source='user_confirmed', confidence=90, "
"updated_at=? WHERE barcode=?",
values + [now, barcode]
)
conn.commit()
return conn.total_changes > 0
finally:
conn.close()
def delete_memory(self, barcode: str) -> bool:
"""删除记录"""
conn = self._connect()
try:
conn.execute("DELETE FROM products WHERE barcode=?", (str(barcode).strip(),))
conn.commit()
return conn.total_changes > 0
finally:
conn.close()
# ── 云端同步 ──────────────────────────────────────────────
def export_for_sync(self) -> Dict:
"""导出全部记录为 JSON-serializable dict(按条码索引)"""
conn = self._connect()
try:
cursor = conn.execute(
"SELECT barcode, name, specification, unit, price, source, "
"confidence, usage_count, last_seen FROM products"
)
result = {}
for row in cursor.fetchall():
result[row[0]] = {
'name': row[1],
'specification': row[2],
'unit': row[3],
'price': row[4],
'source': row[5],
'confidence': row[6],
'usage_count': row[7],
'last_seen': row[8],
}
return result
finally:
conn.close()
def import_from_sync(self, data: Dict) -> int:
"""从云端 JSON 导入,高置信度优先合并
Args:
data: {barcode: {name, specification, unit, price, source, confidence, ...}}
Returns:
导入/更新的记录数
"""
now = datetime.now().isoformat()
count = 0
conn = self._connect()
try:
for barcode, info in data.items():
barcode = str(barcode).strip()
if not barcode:
continue
name = str(info.get('name', ''))
spec = str(info.get('specification', ''))
unit = str(info.get('unit', ''))
price = float(info.get('price', 0))
remote_source = str(info.get('source', 'ocr'))
remote_conf = int(info.get('confidence', 50))
remote_count = int(info.get('usage_count', 1))
remote_seen = str(info.get('last_seen', now))
cursor = conn.execute(
"SELECT confidence FROM products WHERE barcode = ?",
(barcode,)
)
row = cursor.fetchone()
if row is None:
# 新记录,直接插入
conn.execute(
"INSERT INTO products "
"(barcode, name, specification, unit, price, source, confidence, usage_count, last_seen, updated_at) "
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(barcode, name, spec, unit, price, remote_source, remote_conf, remote_count, remote_seen, now)
)
count += 1
else:
local_conf = row[0]
if remote_conf > local_conf:
# 云端置信度更高,覆盖
conn.execute(
"UPDATE products SET name=?, specification=?, unit=?, price=?, "
"source=?, confidence=?, usage_count=?, last_seen=?, updated_at=? "
"WHERE barcode=?",
(name, spec, unit, price, remote_source, remote_conf, remote_count, remote_seen, now, barcode)
)
count += 1
elif remote_conf == local_conf:
# 置信度相同,填充空字段
conn.execute(
"UPDATE products SET "
"name = CASE WHEN name='' THEN ? ELSE name END, "
"specification = CASE WHEN specification='' THEN ? ELSE specification END, "
"unit = CASE WHEN unit='' THEN ? ELSE unit END, "
"usage_count = MAX(usage_count, ?), "
"updated_at=? WHERE barcode=?",
(name, spec, unit, remote_count, now, barcode)
)
count += 1
conn.commit()
finally:
conn.close()
return count
def _export_memory_json(self, json_path: str = None) -> str:
"""导出记忆库为本地 JSON 文件
Args:
json_path: 输出路径,默认 data/product_memory.json
Returns:
写入的文件路径
"""
if json_path is None:
json_path = os.path.join(os.path.dirname(self.db_path), 'product_memory.json')
data = self.export_for_sync()
os.makedirs(os.path.dirname(json_path), exist_ok=True)
with open(json_path, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2)
logger.debug(f"商品记忆库已导出: {json_path} ({len(data)} 条)")
return json_path
+81 -4
View File
@@ -40,12 +40,13 @@ class ExcelProcessor:
提取条码、单价和数量,并按照采购单模板的格式填充
"""
def __init__(self, config):
def __init__(self, config, product_db=None):
"""
初始化Excel处理器
Args:
config: 配置信息
product_db: 商品数据库实例(可选,由外部传入以共享)
"""
self.config = config
@@ -74,6 +75,18 @@ class ExcelProcessor:
# 加载单位转换器和配置
self.unit_converter = UnitConverter()
# 商品记忆库
if product_db is not None:
self.product_db = product_db
else:
from ..db.product_db import ProductDatabase
db_path = config.get_path('Paths', 'product_db', fallback='data/product_cache.db') if hasattr(config, 'get_path') else 'data/product_cache.db'
tpl_folder = config.get('Paths', 'template_folder', fallback='templates')
item_data = config.get('Templates', 'item_data', fallback='商品资料.xlsx')
tpl_path = os.path.join(tpl_folder, item_data)
self.product_db = ProductDatabase(db_path, tpl_path)
logger.info(f"初始化ExcelProcessor完成,模板文件: {self.template_path}")
except Exception as e:
logger.error(f"初始化ExcelProcessor失败: {e}")
@@ -371,14 +384,70 @@ class ExcelProcessor:
except Exception as e:
logger.warning(f"通过金额和单价计算数量失败: {e}")
# 应用记忆库补全
product = self._apply_memory(product)
products.append(product)
except Exception as e:
logger.error(f"提取第{idx+1}行商品信息时出错: {e}", exc_info=True)
continue
logger.info(f"提取到 {len(products)} 个商品信息")
return products
def _apply_memory(self, product: Dict) -> Dict:
"""查记忆库,补全 OCR 缺失/错误的字段"""
barcode = product.get('barcode', '')
if not barcode:
return product
try:
memory = self.product_db.get_memory(barcode)
except Exception:
return product
if memory is None or memory.get('confidence', 0) < 80:
return product
# 补全规格
ocr_spec = product.get('specification', '')
mem_spec = memory.get('specification', '') or ''
if mem_spec and (not ocr_spec or self._is_spec_suspicious(ocr_spec)):
product['specification'] = mem_spec
logger.info(f"记忆修正规格: {barcode} '{ocr_spec}' -> '{mem_spec}'")
# 补全名称
ocr_name = product.get('name', '')
mem_name = memory.get('name', '') or ''
if mem_name and not ocr_name:
product['name'] = mem_name
logger.info(f"记忆修正名称: {barcode} -> '{mem_name}'")
# 补全单位
ocr_unit = product.get('unit', '')
mem_unit = memory.get('unit', '') or ''
if mem_unit and not ocr_unit:
product['unit'] = mem_unit
logger.info(f"记忆修正单位: {barcode} -> '{mem_unit}'")
# 不改数量和单价(每单不同)
return product
def _is_spec_suspicious(self, spec: str) -> bool:
"""检测规格是否像 OCR 垃圾"""
if not spec:
return True
# IL*12I 和 1 混淆)
if re.search(r'^[Ii][Ll*]', spec):
return True
# 4.51*4L 被识别为 1
if re.search(r'\d+\.\d+1\*\d+', spec):
return True
# 包含非常规字符(排除常见规格字符)
if re.search(r'[^\d.*xX\-LlKkGgMm升毫瓶桶盒箱件提\s]', spec):
return True
return False
def fill_template(self, products: List[Dict], output_file_path: str) -> bool:
"""
填充采购单模板
@@ -599,6 +668,14 @@ class ExcelProcessor:
# 填充模板并保存
if self.fill_template(products, output_file):
# 从处理结果中学习商品记忆
try:
self.product_db.learn_from_products(products, source='ocr')
self.product_db._export_memory_json()
logger.info(f"已从处理结果学习 {len(products)} 条商品记忆")
except Exception as e:
logger.warning(f"学习商品记忆失败: {e}")
# 记录已处理文件
self.processed_files[file_path] = output_file
self._save_processed_files()
+19
View File
@@ -830,6 +830,12 @@ SYNC_FILES = [
"local": "templates/银豹-采购单模板.xls",
"type": "binary",
},
{
"name": "商品记忆库",
"remote": "product_memory.json",
"local": "data/product_memory.json",
"type": "json",
},
]
@@ -1068,6 +1074,19 @@ def show_cloud_sync_dialog(parent=None):
ProcessorService(ConfigManager()).reload_processors()
except Exception:
pass
elif entry["remote"] == "product_memory.json":
try:
from app.core.db.product_db import ProductDatabase
cfg = ConfigManager()
db_path = cfg.get_path('Paths', 'product_db', fallback='data/product_cache.db') if hasattr(cfg, 'get_path') else 'data/product_cache.db'
tpl_folder = cfg.get('Paths', 'template_folder', fallback='templates')
item_data = cfg.get('Templates', 'item_data', fallback='商品资料.xlsx')
tpl_path = os.path.join(tpl_folder, item_data)
db = ProductDatabase(db_path, tpl_path)
count = db.import_from_sync(data)
logger.info(f"从云端导入商品记忆: {count}")
except Exception:
pass
def push_all():
ok, fail = 0, 0