feat: 益选 OCR 订单处理系统初始提交

- 智能供应商识别(蓉城易购/烟草/杨碧月/通用)
- 百度 OCR 表格识别集成
- 规则引擎(列映射/数据清洗/单位转换/规格推断)
- 条码映射管理与云端同步(Gitea REST API)
- 云端同步支持:条码映射、供应商配置、商品资料、采购模板
- 拖拽一键处理(图片→OCR→Excel→合并)
- 191 个单元测试
- 移除无用的模板管理功能
- 清理 IDE 产物目录

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-04 19:51:13 +08:00
commit e4d62df7e3
78 changed files with 15257 additions and 0 deletions
View File
+222
View File
@@ -0,0 +1,222 @@
"""app.core.handlers.calculator 单元测试"""
import pytest
import pandas as pd
import numpy as np
from app.core.handlers.calculator import DataCalculator
@pytest.fixture
def sample_df():
    """Provide a three-row frame with two numeric columns and one text column."""
    columns = {
        'price': [10.0, 20.0, 30.0],
        'quantity': [2, 5, 10],
        'name': ['A', 'B', 'C'],
    }
    return pd.DataFrame(columns)
class TestMultiply:
    """Tests for the 'multiply' rule of DataCalculator."""

    @staticmethod
    def _apply(frame, **params):
        """Run a single 'multiply' rule built from *params* over *frame*."""
        calculator = DataCalculator()
        calculator.add_rule('multiply', **params)
        return calculator.calculate(frame)

    def test_basic_multiply(self, sample_df):
        """A factor of 2 is expected to double 'price' into 'total'."""
        out = self._apply(sample_df, source_column='price', target_column='total', factor=2)
        assert list(out['total']) == [20.0, 40.0, 60.0]

    def test_multiply_missing_source(self, sample_df):
        """No target column is expected when the source column is absent."""
        out = self._apply(sample_df, source_column='nonexistent', target_column='total', factor=2)
        assert 'total' not in out.columns

    def test_multiply_default_factor(self, sample_df):
        """A factor of 1 is expected to copy the source values unchanged.

        NOTE(review): factor=1 is passed explicitly here, so whatever default
        the rule has is not actually exercised by this test.
        """
        out = self._apply(sample_df, source_column='price', target_column='copy', factor=1)
        assert list(out['copy']) == [10.0, 20.0, 30.0]

    def test_convenience_method(self, sample_df):
        """The multiply() shortcut registers the same kind of rule."""
        calculator = DataCalculator()
        calculator.multiply('price', 'total', 3)
        out = calculator.calculate(sample_df)
        assert list(out['total']) == [30.0, 60.0, 90.0]
class TestDivide:
    """Tests for the 'divide' rule of DataCalculator."""

    @staticmethod
    def _apply(frame, **params):
        """Run a single 'divide' rule built from *params* over *frame*."""
        calculator = DataCalculator()
        calculator.add_rule('divide', **params)
        return calculator.calculate(frame)

    def test_basic_divide(self, sample_df):
        """A divisor of 2 is expected to halve 'price' into 'half'."""
        out = self._apply(sample_df, source_column='price', target_column='half', divisor=2)
        assert list(out['half']) == [5.0, 10.0, 15.0]

    def test_divide_by_zero(self, sample_df):
        """A zero divisor is expected to leave the target column uncreated."""
        out = self._apply(sample_df, source_column='price', target_column='half', divisor=0)
        assert 'half' not in out.columns

    def test_divide_missing_source(self, sample_df):
        """No target column is expected when the source column is absent."""
        out = self._apply(sample_df, source_column='nonexistent', target_column='x', divisor=2)
        assert 'x' not in out.columns
class TestAdd:
    """Tests for the 'add' rule of DataCalculator."""

    @staticmethod
    def _apply(frame, **params):
        """Run a single 'add' rule built from *params* over *frame*."""
        calculator = DataCalculator()
        calculator.add_rule('add', **params)
        return calculator.calculate(frame)

    def test_add_columns(self):
        """Two columns are expected to be summed row by row."""
        frame = pd.DataFrame({'a': [1, 2, 3], 'b': [10, 20, 30]})
        out = self._apply(frame, columns=['a', 'b'], target_column='sum')
        assert list(out['sum']) == [11, 22, 33]

    def test_add_constant(self):
        """A constant alone is expected to be added onto the target column."""
        frame = pd.DataFrame({'a': [1, 2, 3]})
        out = self._apply(frame, target_column='a', constant=100)
        assert list(out['a']) == [101, 102, 103]

    def test_add_columns_with_constant(self):
        """Columns and a constant are expected to be combined."""
        frame = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
        out = self._apply(frame, columns=['a', 'b'], target_column='total', constant=10)
        assert list(out['total']) == [14, 16]

    def test_add_string_column(self):
        """A bare column name (not a list) is accepted for 'columns'."""
        frame = pd.DataFrame({'a': [1, 2]})
        out = self._apply(frame, columns='a', target_column='total')
        assert list(out['total']) == [1, 2]
class TestSubtract:
    """Tests for the 'subtract' rule of DataCalculator."""

    @staticmethod
    def _apply(frame, **params):
        """Run a single 'subtract' rule built from *params* over *frame*."""
        calculator = DataCalculator()
        calculator.add_rule('subtract', **params)
        return calculator.calculate(frame)

    def test_subtract_two_columns(self):
        """The subtrahend column is expected to be taken from the minuend column."""
        frame = pd.DataFrame({'income': [100, 200], 'cost': [30, 80]})
        out = self._apply(frame, minuend='income', subtrahend='cost', target_column='profit')
        assert list(out['profit']) == [70, 120]

    def test_subtract_constant(self):
        """A constant is expected to be taken from the minuend column."""
        frame = pd.DataFrame({'price': [100, 200]})
        out = self._apply(frame, minuend='price', target_column='discounted', constant=10)
        assert list(out['discounted']) == [90, 190]

    def test_subtract_missing_minuend(self):
        """No target column is expected when the minuend column is absent."""
        frame = pd.DataFrame({'a': [1, 2]})
        out = self._apply(frame, minuend='nonexistent', target_column='x', constant=1)
        assert 'x' not in out.columns
class TestFormula:
    """Tests for the 'formula' rule of DataCalculator."""

    @staticmethod
    def _apply(frame, **params):
        """Run a single 'formula' rule built from *params* over *frame*."""
        calculator = DataCalculator()
        calculator.add_rule('formula', **params)
        return calculator.calculate(frame)

    def test_basic_formula(self, sample_df):
        """A product of two columns is expected in the target column."""
        out = self._apply(sample_df, formula='price * quantity', target_column='total')
        assert list(out['total']) == [20.0, 100.0, 300.0]

    def test_invalid_formula(self, sample_df):
        """A formula naming an unknown column must not create the target."""
        out = self._apply(sample_df, formula='nonexistent + 1', target_column='x')
        # the formula fails, so the frame is expected back without 'x'
        assert 'x' not in out.columns

    def test_formula_missing_target(self, sample_df):
        """Without a target_column the existing data must stay untouched."""
        out = self._apply(sample_df, formula='price * 2')
        # no target_column was given, so nothing is expected to change
        assert list(out['price']) == [10.0, 20.0, 30.0]
class TestRound:
    """Tests for the 'round' rule of DataCalculator."""

    @staticmethod
    def _apply(frame, **params):
        """Run a single 'round' rule built from *params* over *frame*."""
        calculator = DataCalculator()
        calculator.add_rule('round', **params)
        return calculator.calculate(frame)

    def test_round_specific_columns(self):
        """Only the listed column is expected to be rounded."""
        frame = pd.DataFrame({'a': [1.234, 2.567], 'b': [3.1, 4.9]})
        out = self._apply(frame, columns=['a'], decimals=1)
        assert list(out['a']) == [1.2, 2.6]
        assert list(out['b']) == [3.1, 4.9]  # not listed, so left as-is

    def test_round_all_numeric(self):
        """With no column list, every numeric column is expected to be rounded."""
        frame = pd.DataFrame({'a': [1.234, 2.567], 'b': [3.111, 4.999]})
        out = self._apply(frame, decimals=0)
        assert list(out['a']) == [1.0, 3.0]
        assert list(out['b']) == [3.0, 5.0]

    def test_round_string_column_skipped(self):
        """A text column in the list must not prevent rounding the numeric one."""
        frame = pd.DataFrame({'name': ['a', 'b'], 'val': [1.5, 2.5]})
        out = self._apply(frame, columns=['name', 'val'], decimals=0)
        # expected values reflect round-half-to-even: 1.5 -> 2.0, 2.5 -> 2.0
        assert list(out['val']) == [2.0, 2.0]
class TestSum:
    """Tests for the 'sum' rule of DataCalculator."""

    @staticmethod
    def _apply(frame, **params):
        """Run a single 'sum' rule built from *params* over *frame*."""
        calculator = DataCalculator()
        calculator.add_rule('sum', **params)
        return calculator.calculate(frame)

    def test_sum_columns_to_target(self):
        """Only the listed columns are expected to contribute to the row total."""
        frame = pd.DataFrame({'a': [1, 2], 'b': [3, 4], 'c': [5, 6]})
        out = self._apply(frame, columns=['a', 'b'], target_column='total')
        assert list(out['total']) == [4, 6]

    def test_sum_missing_columns(self):
        """Columns absent from the frame are expected to be ignored."""
        frame = pd.DataFrame({'a': [1, 2]})
        out = self._apply(frame, columns=['a', 'missing'], target_column='total')
        assert list(out['total']) == [1, 2]
class TestChaining:
    """Tests that several DataCalculator rules run in registration order."""

    def test_multiple_rules(self, sample_df):
        """A later rule is expected to see the column an earlier rule created."""
        calculator = DataCalculator()
        calculator.add_rule('multiply', source_column='price', target_column='total', factor=2)
        calculator.add_rule('add', columns=['total', 'quantity'], target_column='grand')
        out = calculator.calculate(sample_df)
        assert list(out['total']) == [20.0, 40.0, 60.0]
        assert list(out['grand']) == [22.0, 45.0, 70.0]

    def test_chaining_convenience(self, sample_df):
        """The convenience methods can be chained fluently."""
        calculator = DataCalculator()
        calculator.multiply('price', 'total', 2).round_columns('total', 0)
        out = calculator.calculate(sample_df)
        assert list(out['total']) == [20.0, 40.0, 60.0]
class TestEdgeCases:
    """Boundary-condition tests for DataCalculator."""

    def test_empty_dataframe(self):
        """An empty frame is expected to stay empty after a rule runs."""
        empty = pd.DataFrame({'a': pd.Series([], dtype=float)})
        calculator = DataCalculator()
        calculator.add_rule('multiply', source_column='a', target_column='b', factor=2)
        out = calculator.calculate(empty)
        assert len(out) == 0

    def test_no_rules(self, sample_df):
        """With no rules registered the data is expected back unchanged."""
        out = DataCalculator().calculate(sample_df)
        assert list(out['price']) == [10.0, 20.0, 30.0]

    def test_unknown_rule_type(self, sample_df):
        """An unrecognised rule type must not alter existing columns."""
        calculator = DataCalculator()
        calculator.add_rule('unknown_op', source_column='price', target_column='x')
        out = calculator.calculate(sample_df)
        # the unknown rule is expected to be skipped, leaving 'price' as it was
        assert list(out['price']) == [10.0, 20.0, 30.0]

    def test_rule_failure_continues(self, sample_df):
        """A failing rule must not stop the rules registered after it."""
        calculator = DataCalculator()
        calculator.add_rule('formula', formula='nonexistent + 1', target_column='x')
        calculator.add_rule('multiply', source_column='price', target_column='y', factor=2)
        out = calculator.calculate(sample_df)
        assert list(out['y']) == [20.0, 40.0, 60.0]
+154
View File
@@ -0,0 +1,154 @@
"""app.core.handlers.column_mapper 单元测试"""
import pytest
import pandas as pd
from app.core.handlers.column_mapper import ColumnMapper
class TestStandardColumns:
    """Completeness checks for ColumnMapper.STANDARD_COLUMNS."""

    def test_has_all_standard_fields(self):
        """The mapping's keys must be exactly the expected standard fields."""
        expected = {
            'barcode', 'name', 'specification', 'quantity', 'unit',
            'unit_price', 'total_price', 'gift_quantity',
            'category', 'brand', 'supplier',
        }
        actual = set(ColumnMapper.STANDARD_COLUMNS.keys())
        assert actual == expected

    def test_no_empty_alias_lists(self):
        """Every standard field must carry at least one alias."""
        for field, aliases in ColumnMapper.STANDARD_COLUMNS.items():
            assert len(aliases) > 0, f"{field} has no aliases"

    def test_barcode_includes_key_names(self):
        """The barcode aliases must include the key Chinese and English names."""
        barcode_aliases = ColumnMapper.STANDARD_COLUMNS['barcode']
        for alias in ('条码', '商品条码', 'barcode'):
            assert alias in barcode_aliases

    def test_gift_quantity_includes_common_names(self):
        """The gift-quantity aliases must include the common names."""
        gift_aliases = ColumnMapper.STANDARD_COLUMNS['gift_quantity']
        for alias in ('赠送量', '赠品数量'):
            assert alias in gift_aliases
class TestFindColumn:
    """Column lookup tests for ColumnMapper.find_column."""

    def test_exact_match(self):
        """A header equal to a known alias is returned."""
        headers = ['商品条码', '商品名称', '数量', '单价']
        found = ColumnMapper.find_column(headers, 'barcode')
        assert found == '商品条码'

    def test_exact_match_standard_english(self):
        """A header equal to the standard English name is returned."""
        headers = ['barcode', 'name', 'quantity']
        found = ColumnMapper.find_column(headers, 'barcode')
        assert found == 'barcode'

    def test_whitespace_match(self):
        """Headers containing spaces should still match."""
        headers = ['名 称', '数 量']
        expectations = (('name', '名 称'), ('quantity', '数 量'))
        for std_name, header in expectations:
            assert ColumnMapper.find_column(headers, std_name) == header

    def test_partial_match_substring(self):
        """A header that contains a candidate name should match."""
        headers = ['商品条码(小条码)', '商品名称']
        found = ColumnMapper.find_column(headers, 'barcode')
        assert found == '商品条码(小条码)'

    def test_not_found_returns_none(self):
        """None is expected when no header matches."""
        headers = ['日期', '备注', '编号']
        assert ColumnMapper.find_column(headers, 'barcode') is None

    def test_unknown_standard_name_returns_none(self):
        """None is expected for a standard name that is not defined."""
        headers = ['商品条码']
        assert ColumnMapper.find_column(headers, 'nonexistent_field') is None

    def test_first_match_wins(self):
        """When several headers could match, the first one is returned."""
        headers = ['条码', '商品条码', 'barcode']
        found = ColumnMapper.find_column(headers, 'barcode')
        assert found == '条码'

    def test_case_insensitive(self):
        """Matching ignores letter case but returns the original header."""
        headers = ['Barcode', 'Name']
        found = ColumnMapper.find_column(headers, 'barcode')
        assert found == 'Barcode'

    def test_all_fields_matchable(self):
        """Every standard field can find at least one matching header."""
        headers = [
            '商品条码', '商品名称', '规格', '数量', '单位',
            '单价', '金额', '赠送量', '类别', '品牌', '供应商',
        ]
        for std_name in ColumnMapper.STANDARD_COLUMNS:
            found = ColumnMapper.find_column(headers, std_name)
            assert found is not None, f"Could not find {std_name} in {headers}"
class TestDetectHeaderRow:
    """Header-row detection tests for ColumnMapper.detect_header_row."""

    def test_header_on_first_row(self):
        """Header labels in row 0 are expected to yield index 0."""
        frame = pd.DataFrame({
            'A': ['条码', '123456', '789012'],
            'B': ['数量', '10', '20'],
            'C': ['单价', '5.5', '3.0'],
        })
        detected = ColumnMapper.detect_header_row(frame, min_matches=2)
        assert detected == 0

    def test_header_on_second_row(self):
        """Header labels in row 1 are expected to yield index 1."""
        frame = pd.DataFrame({
            'A': ['备注', '条码', '123456'],
            'B': ['日期', '数量', '10'],
            'C': ['时间', '单价', '5.5'],
        })
        detected = ColumnMapper.detect_header_row(frame, min_matches=2)
        assert detected == 1

    def test_no_header_returns_minus_one(self):
        """-1 is expected when no row reaches the required match count."""
        frame = pd.DataFrame({
            'A': ['aaa', 'bbb', 'ccc'],
            'B': ['ddd', 'eee', 'fff'],
        })
        detected = ColumnMapper.detect_header_row(frame, min_matches=3)
        assert detected == -1

    def test_empty_dataframe(self):
        """-1 is expected for a completely empty frame."""
        assert ColumnMapper.detect_header_row(pd.DataFrame()) == -1

    def test_max_rows_limits_scan(self):
        """A header at row index 10 with max_rows=5 should yield -1."""
        cells = {f'col{i}': ['x'] * 15 for i in range(3)}
        # plant the header labels at row index 10, beyond the scan limit
        for key, label in (('col0', '条码'), ('col1', '数量'), ('col2', '单价')):
            cells[key][10] = label
        frame = pd.DataFrame(cells)
        detected = ColumnMapper.detect_header_row(frame, max_rows=5, min_matches=2)
        assert detected == -1
class TestColumnMapperInstance:
    """Tests for the instance-level API of ColumnMapper."""

    def test_init_with_no_config(self):
        """Without arguments, mapping_config is expected to be an empty dict."""
        assert ColumnMapper().mapping_config == {}

    def test_init_with_custom_config(self):
        """Aliases passed at construction are expected in custom_mappings."""
        config = {'barcode': ['我的条码']}
        mapper = ColumnMapper(mapping_config=config)
        assert '我的条码' in mapper.custom_mappings

    def test_map_columns_renames(self):
        """Recognised source headers are renamed to their standard names."""
        source = pd.DataFrame({'商品条码': ['123'], '商品名称': ['测试'], '数量': [10]})
        mapped = ColumnMapper().map_columns(source, target_columns=['barcode', 'name', 'quantity'])
        for column in ('barcode', 'name', 'quantity'):
            assert column in mapped.columns

    def test_map_columns_fills_missing(self):
        """A requested column with no source is created with a default value."""
        source = pd.DataFrame({'商品条码': ['123']})
        mapped = ColumnMapper().map_columns(source, target_columns=['barcode', 'quantity'])
        assert 'barcode' in mapped.columns
        assert 'quantity' in mapped.columns
        assert mapped['quantity'].iloc[0] == 0  # the default this test expects

    def test_add_custom_mapping(self):
        """A custom alias is expected to be registered in reverse_mapping."""
        mapper = ColumnMapper()
        mapper.add_custom_mapping('barcode', '自定义条码列')
        assert '自定义条码列' in mapper.reverse_mapping
        assert mapper.reverse_mapping['自定义条码列'] == 'barcode'
+236
View File
@@ -0,0 +1,236 @@
"""app.core.handlers.data_cleaner 单元测试"""
import pytest
import pandas as pd
from app.core.handlers.data_cleaner import DataCleaner
@pytest.fixture
def sample_df():
    """Provide a four-row frame with padded text, a missing age and plain numbers."""
    columns = {
        'name': [' Alice ', 'Bob', 'Charlie', 'Dave'],
        'age': [25, 30, None, 40],
        'score': [80.5, 90.0, 70.0, 85.0],
        'city': ['Beijing', 'Shanghai', 'Beijing', 'Guangzhou'],
    }
    return pd.DataFrame(columns)
class TestFillNa:
    """Tests for the 'fill_na' rule of DataCleaner."""

    @staticmethod
    def _apply(frame, **params):
        """Run a single 'fill_na' rule built from *params* over *frame*."""
        cleaner = DataCleaner()
        cleaner.add_rule('fill_na', **params)
        return cleaner.clean(frame)

    def test_fill_na_with_value(self, sample_df):
        """Missing values in the listed column are replaced by the value."""
        cleaned = self._apply(sample_df, columns=['age'], value=0)
        assert cleaned['age'].isna().sum() == 0
        assert cleaned.loc[2, 'age'] == 0

    def test_fill_na_all_columns(self, sample_df):
        """With no column list, no missing value may remain anywhere."""
        cleaned = self._apply(sample_df, value=-1)
        assert cleaned.isna().sum().sum() == 0

    def test_fill_na_string_column(self):
        """A text placeholder is accepted for a text column."""
        frame = pd.DataFrame({'a': ['x', None, 'z']})
        cleaned = self._apply(frame, columns=['a'], value='unknown')
        assert cleaned.loc[1, 'a'] == 'unknown'

    def test_convenience_method(self, sample_df):
        """The fill_na() shortcut registers the same kind of rule."""
        cleaner = DataCleaner()
        cleaner.fill_na(columns='age', value=99)
        cleaned = cleaner.clean(sample_df)
        assert cleaned.loc[2, 'age'] == 99
class TestRemoveDuplicates:
    """Tests for the 'remove_duplicates' rule of DataCleaner."""

    @staticmethod
    def _apply(frame, **params):
        """Run a single 'remove_duplicates' rule built from *params* over *frame*."""
        cleaner = DataCleaner()
        cleaner.add_rule('remove_duplicates', **params)
        return cleaner.clean(frame)

    def test_remove_by_subset(self):
        """Duplicates judged on 'name' keep only the first occurrence."""
        frame = pd.DataFrame({
            'name': ['A', 'B', 'A', 'C'],
            'val': [1, 2, 3, 4],
        })
        cleaned = self._apply(frame, subset=['name'], keep='first')
        assert len(cleaned) == 3
        assert list(cleaned['name']) == ['A', 'B', 'C']

    def test_remove_all_columns(self):
        """With no subset, fully identical rows are collapsed."""
        frame = pd.DataFrame({
            'a': [1, 1, 2],
            'b': [10, 10, 20],
        })
        cleaned = self._apply(frame)
        assert len(cleaned) == 2

    def test_no_duplicates(self, sample_df):
        """A frame without duplicates keeps all of its rows."""
        cleaned = self._apply(sample_df, subset=['name'])
        assert len(cleaned) == 4
class TestRemoveRows:
    """Tests for the 'remove_rows' rule of DataCleaner."""

    @staticmethod
    def _apply(frame, **params):
        """Run a single 'remove_rows' rule built from *params* over *frame*."""
        cleaner = DataCleaner()
        cleaner.add_rule('remove_rows', **params)
        return cleaner.clean(frame)

    def test_remove_by_condition(self, sample_df):
        """The condition 'age > 25' is expected to leave two rows."""
        cleaned = self._apply(sample_df, condition='age > 25')
        assert len(cleaned) == 2

    def test_remove_by_values(self, sample_df):
        """Rows whose 'city' is one of the listed values are dropped."""
        cleaned = self._apply(sample_df, columns=['city'], values=['Beijing'])
        assert len(cleaned) == 2
        assert 'Beijing' not in cleaned['city'].values

    def test_remove_no_match(self, sample_df):
        """A condition that matches no row is expected to leave zero rows."""
        cleaned = self._apply(sample_df, condition='age > 100')
        assert len(cleaned) == 0  # condition filter: no rows match age > 100

    def test_convenience_method(self, sample_df):
        """The remove_rows() shortcut registers the same kind of rule."""
        cleaner = DataCleaner()
        cleaner.remove_rows(condition='score < 75')
        cleaned = cleaner.clean(sample_df)
        assert len(cleaned) == 1  # condition filter: keeps only Charlie (score=70.0)
class TestConvertType:
    """Tests for the 'convert_type' rule of DataCleaner."""

    @staticmethod
    def _apply(frame, **params):
        """Run a single 'convert_type' rule built from *params* over *frame*."""
        cleaner = DataCleaner()
        cleaner.add_rule('convert_type', **params)
        return cleaner.clean(frame)

    def test_to_float(self):
        """Numeric text becomes float; unparseable text becomes missing."""
        frame = pd.DataFrame({'val': ['1.5', '2.7', 'abc']})
        cleaned = self._apply(frame, columns=['val'], target_type='float')
        assert cleaned['val'].dtype.kind == 'f'
        assert cleaned.loc[0, 'val'] == 1.5
        assert pd.isna(cleaned.loc[2, 'val'])

    def test_to_int(self):
        """Integer text is converted to a value equal to the integer."""
        frame = pd.DataFrame({'val': ['1', '2', '3']})
        cleaned = self._apply(frame, columns=['val'], target_type='int')
        assert cleaned.loc[0, 'val'] == 1

    def test_to_string(self):
        """Integers are converted to their text form."""
        frame = pd.DataFrame({'val': [1, 2, 3]})
        cleaned = self._apply(frame, columns=['val'], target_type='string')
        assert cleaned.loc[0, 'val'] == '1'

    def test_missing_column_skipped(self, sample_df):
        """A column that does not exist must not affect the row count."""
        cleaned = self._apply(sample_df, columns=['nonexistent'], target_type='float')
        assert len(cleaned) == 4
class TestStripWhitespace:
    """Tests for the 'strip_whitespace' rule of DataCleaner."""

    @staticmethod
    def _apply(frame, **params):
        """Run a single 'strip_whitespace' rule built from *params* over *frame*."""
        cleaner = DataCleaner()
        cleaner.add_rule('strip_whitespace', **params)
        return cleaner.clean(frame)

    def test_strip_specific_columns(self, sample_df):
        """Surrounding spaces are removed from the listed column."""
        cleaned = self._apply(sample_df, columns=['name'])
        assert cleaned.loc[0, 'name'] == 'Alice'

    def test_strip_all_text(self, sample_df):
        """With no column list, the text column is still stripped."""
        cleaned = self._apply(sample_df)
        assert cleaned.loc[0, 'name'] == 'Alice'

    def test_strip_non_text_skipped(self):
        """A numeric column listed for stripping keeps its values."""
        frame = pd.DataFrame({'val': [1, 2, 3]})
        cleaned = self._apply(frame, columns=['val'])
        assert list(cleaned['val']) == [1, 2, 3]
class TestNormalizeText:
    """Tests for the 'normalize_text' rule of DataCleaner."""

    @staticmethod
    def _apply(frame, **params):
        """Run a single 'normalize_text' rule built from *params* over *frame*."""
        cleaner = DataCleaner()
        cleaner.add_rule('normalize_text', **params)
        return cleaner.clean(frame)

    def test_lowercase(self):
        """lowercase=True is expected to lower-case the listed column."""
        frame = pd.DataFrame({'name': ['ALICE', 'BOB']})
        cleaned = self._apply(frame, columns=['name'], lowercase=True)
        assert list(cleaned['name']) == ['alice', 'bob']

    def test_uppercase(self):
        """uppercase=True is expected to upper-case the listed column."""
        frame = pd.DataFrame({'name': ['alice', 'bob']})
        cleaned = self._apply(frame, columns=['name'], uppercase=True)
        assert list(cleaned['name']) == ['ALICE', 'BOB']

    def test_replace_map(self):
        """replace_map entries are expected to be substituted into the column."""
        frame = pd.DataFrame({'city': ['BJ', 'SH']})
        substitutions = {'BJ': 'Beijing', 'SH': 'Shanghai'}
        cleaned = self._apply(frame, columns=['city'], replace_map=substitutions)
        assert list(cleaned['city']) == ['Beijing', 'Shanghai']
class TestValidateData:
    """Tests for the 'validate_data' rule of DataCleaner."""

    @staticmethod
    def _apply(frame, **params):
        """Run a single 'validate_data' rule built from *params* over *frame*."""
        cleaner = DataCleaner()
        cleaner.add_rule('validate_data', **params)
        return cleaner.clean(frame)

    def test_validate_logs_but_does_not_modify(self, sample_df):
        """A range check must not remove any rows."""
        cleaned = self._apply(sample_df, columns=['score'], min_value=0, max_value=100)
        assert len(cleaned) == 4

    def test_validate_required(self, sample_df):
        """A required-value check must not remove any rows either."""
        cleaned = self._apply(sample_df, columns=['age'], required=True)
        assert len(cleaned) == 4
class TestChaining:
    """Tests that several DataCleaner rules can be combined."""

    def test_multiple_rules(self, sample_df):
        """Strip, fill and convert rules are all expected to take effect."""
        cleaner = DataCleaner()
        pipeline = (
            ('strip_whitespace', {'columns': ['name']}),
            ('fill_na', {'columns': ['age'], 'value': 0}),
            ('convert_type', {'columns': ['age'], 'target_type': 'int'}),
        )
        for rule_type, params in pipeline:
            cleaner.add_rule(rule_type, **params)
        cleaned = cleaner.clean(sample_df)
        assert cleaned.loc[0, 'name'] == 'Alice'
        assert cleaned['age'].isna().sum() == 0
        assert cleaned.loc[2, 'age'] == 0

    def test_convenience_chaining(self, sample_df):
        """The convenience methods can be chained fluently."""
        cleaner = DataCleaner()
        cleaner.strip_whitespace('name').fill_na('age', value=0)
        cleaned = cleaner.clean(sample_df)
        assert cleaned.loc[0, 'name'] == 'Alice'
        assert cleaned.loc[2, 'age'] == 0
class TestEdgeCases:
    """Boundary-condition tests for DataCleaner."""

    def test_empty_dataframe(self):
        """An empty frame is expected to stay empty after a rule runs."""
        empty = pd.DataFrame({'a': pd.Series([], dtype=float)})
        cleaner = DataCleaner()
        cleaner.add_rule('fill_na', value=0)
        cleaned = cleaner.clean(empty)
        assert len(cleaned) == 0

    def test_no_rules(self, sample_df):
        """With no rules registered every row is expected back."""
        cleaned = DataCleaner().clean(sample_df)
        assert len(cleaned) == 4

    def test_unknown_rule_type(self, sample_df):
        """An unrecognised rule type must not drop any rows."""
        cleaner = DataCleaner()
        cleaner.add_rule('unknown_op', columns=['name'])
        cleaned = cleaner.clean(sample_df)
        assert len(cleaned) == 4

    def test_rule_failure_continues(self, sample_df):
        """A failing rule should not block subsequent rules."""
        cleaner = DataCleaner()
        cleaner.add_rule('convert_type', columns=['nonexistent'], target_type='float')
        cleaner.add_rule('fill_na', columns=['age'], value=0)
        cleaned = cleaner.clean(sample_df)
        assert cleaned.loc[2, 'age'] == 0
+187
View File
@@ -0,0 +1,187 @@
"""app.core.db.product_db 单元测试"""
import os
import tempfile
import pytest
import pandas as pd
from app.core.db.product_db import ProductDatabase
@pytest.fixture
def db_dir():
    """Yield the path of a temporary directory that is cleaned up afterwards."""
    with tempfile.TemporaryDirectory() as scratch:
        yield scratch
@pytest.fixture
def sample_excel(db_dir):
    """Write a three-product Excel workbook into *db_dir* and return its path."""
    # NOTE(review): the empty-string unit values may be characters lost when
    # this file was extracted — confirm against the repository.
    rows = {
        '商品条码': ['6920584471055', '6901028001133', '6925303800013'],
        '商品名称': ['农夫山泉550ml', '蒙牛纯牛奶', '可口可乐330ml'],
        '进货价': [1.2, 3.5, 1.8],
        '单位': ['', '', ''],
    }
    workbook_path = os.path.join(db_dir, '商品资料.xlsx')
    pd.DataFrame(rows).to_excel(workbook_path, index=False)
    return workbook_path
@pytest.fixture
def db_with_data(db_dir, sample_excel):
    """Provide a ProductDatabase built from the sample workbook."""
    cache_file = os.path.join(db_dir, 'product_cache.db')
    return ProductDatabase(cache_file, sample_excel)
class TestProductDatabaseInit:
    """Database initialisation tests."""

    def test_auto_import_on_first_run(self, db_dir, sample_excel):
        """The first run imports from the Excel workbook automatically."""
        cache_file = os.path.join(db_dir, 'product_cache.db')
        assert not os.path.exists(cache_file)

        database = ProductDatabase(cache_file, sample_excel)

        assert os.path.exists(cache_file)
        assert database.count() == 3

    def test_no_reimport_on_existing_db(self, db_dir, sample_excel):
        """An existing database is not re-imported."""
        cache_file = os.path.join(db_dir, 'product_cache.db')
        first = ProductDatabase(cache_file, sample_excel)
        assert first.count() == 3

        # the existing database must still open after the workbook is deleted
        os.remove(sample_excel)
        second = ProductDatabase(cache_file, sample_excel)
        assert second.count() == 3

    def test_missing_excel_creates_empty_db(self, db_dir):
        """An empty database is created when the workbook does not exist."""
        cache_file = os.path.join(db_dir, 'product_cache.db')
        absent_workbook = os.path.join(db_dir, '不存在.xlsx')

        database = ProductDatabase(cache_file, absent_workbook)

        assert os.path.exists(cache_file)
        assert database.count() == 0

    def test_missing_dir_created(self, db_dir, sample_excel):
        """A missing database directory is created automatically."""
        cache_file = os.path.join(db_dir, 'subdir', 'product_cache.db')

        database = ProductDatabase(cache_file, sample_excel)

        assert os.path.exists(cache_file)
        assert database.count() == 3
class TestGetPrice:
    """Single-barcode lookup tests."""

    def test_existing_barcode(self, db_with_data):
        """A stored barcode returns its purchase price."""
        found = db_with_data.get_price('6920584471055')
        assert found == pytest.approx(1.2)

    def test_nonexistent_barcode(self, db_with_data):
        """An unknown barcode returns None."""
        found = db_with_data.get_price('0000000000000')
        assert found is None

    def test_empty_barcode(self, db_with_data):
        """An empty barcode returns None."""
        found = db_with_data.get_price('')
        assert found is None

    def test_barcode_with_spaces(self, db_with_data):
        """A barcode with surrounding spaces should still match."""
        found = db_with_data.get_price(' 6920584471055 ')
        assert found == pytest.approx(1.2)
class TestGetPrices:
    """Batch lookup tests."""

    def test_multiple_barcodes(self, db_with_data):
        """Each stored barcode is returned with its price."""
        wanted = ['6920584471055', '6901028001133']
        prices = db_with_data.get_prices(wanted)
        assert len(prices) == 2
        assert prices['6920584471055'] == pytest.approx(1.2)
        assert prices['6901028001133'] == pytest.approx(3.5)

    def test_partial_match(self, db_with_data):
        """Some barcodes exist and some do not."""
        wanted = ['6920584471055', '0000000000000']
        prices = db_with_data.get_prices(wanted)
        assert len(prices) == 1
        assert '6920584471055' in prices

    def test_empty_list(self, db_with_data):
        """An empty request returns an empty mapping."""
        prices = db_with_data.get_prices([])
        assert prices == {}

    def test_all_nonexistent(self, db_with_data):
        """A request of only unknown barcodes returns an empty mapping."""
        wanted = ['0000000000000', '1111111111111']
        prices = db_with_data.get_prices(wanted)
        assert prices == {}
class TestReimport:
    """Re-import tests."""

    def test_reimport_clears_and_reloads(self, db_dir, sample_excel):
        """reimport() is expected to pick up a row added to the workbook."""
        cache_file = os.path.join(db_dir, 'product_cache.db')
        database = ProductDatabase(cache_file, sample_excel)
        assert database.count() == 3

        # modify the workbook by appending one row
        # NOTE(review): the empty-string unit value may be a character lost in
        # extraction — confirm against the repository.
        extra_row = pd.DataFrame({
            '商品条码': ['6954365200123'],
            '商品名称': ['测试商品'],
            '进货价': [5.0],
            '单位': [''],
        })
        existing = pd.read_excel(sample_excel)
        combined = pd.concat([existing, extra_row])
        combined.to_excel(sample_excel, index=False)

        imported = database.reimport()

        assert imported == 4
        assert database.count() == 4
        assert database.get_price('6954365200123') == pytest.approx(5.0)
class TestEdgeCases:
    """Boundary-condition tests."""

    def test_excel_with_missing_price(self, db_dir):
        """A workbook row whose price cell is empty."""
        workbook_path = os.path.join(db_dir, '商品资料.xlsx')
        rows = {
            '商品条码': ['6920584471055', '6901028001133'],
            '商品名称': ['商品A', '商品B'],
            '进货价': [1.5, None],
        }
        pd.DataFrame(rows).to_excel(workbook_path, index=False)

        cache_file = os.path.join(db_dir, 'product_cache.db')
        database = ProductDatabase(cache_file, workbook_path)

        assert database.count() == 2
        assert database.get_price('6920584471055') == pytest.approx(1.5)
        # the row with the empty price is expected to be stored as 0.0
        assert database.get_price('6901028001133') == pytest.approx(0.0)

    def test_excel_with_duplicate_barcodes(self, db_dir):
        """Duplicate barcodes keep the last row (INSERT OR REPLACE)."""
        workbook_path = os.path.join(db_dir, '商品资料.xlsx')
        rows = {
            '商品条码': ['6920584471055', '6920584471055'],
            '商品名称': ['商品A', '商品A-新'],
            '进货价': [1.0, 2.0],
        }
        pd.DataFrame(rows).to_excel(workbook_path, index=False)

        cache_file = os.path.join(db_dir, 'product_cache.db')
        database = ProductDatabase(cache_file, workbook_path)

        assert database.count() == 1
        assert database.get_price('6920584471055') == pytest.approx(2.0)
+223
View File
@@ -0,0 +1,223 @@
"""app.core.handlers.rule_engine 单元测试"""
import pytest
import pandas as pd
from app.core.handlers.rule_engine import (
apply_rules,
_split_quantity_unit,
_extract_spec_from_name,
_normalize_unit,
_compute_quantity_from_total,
_fill_missing,
_mark_gift,
)
@pytest.fixture
def sample_df():
    """Provide a three-row order frame with names, raw quantities and prices."""
    columns = {
        'name': ['农夫山泉550ml*24', '蒙牛纯牛奶', '可口可乐330ml*6'],
        'quantity_raw': ['2箱', '5', '3提'],
        'unit_price': [28.8, 3.5, 10.8],
        'total_price': [57.6, 17.5, 32.4],
    }
    return pd.DataFrame(columns)
class TestSplitQuantityUnit:
    """Tests for _split_quantity_unit.

    NOTE(review): several expected unit values below are empty strings; they
    may be single characters lost when this file was extracted — confirm
    against the repository.
    """

    def test_split_with_unit(self):
        """Values such as '2箱' are split into a number and a unit."""
        frame = pd.DataFrame({'quantity_raw': ['2箱', '5瓶', '3提']})
        out = _split_quantity_unit(frame, 'quantity_raw')
        assert list(out['quantity']) == [2.0, 5.0, 3.0]
        assert list(out['unit']) == ['', '', '']

    def test_split_number_only(self):
        """Bare numbers are parsed into the quantity column."""
        frame = pd.DataFrame({'quantity_raw': ['10', '20']})
        out = _split_quantity_unit(frame, 'quantity_raw')
        assert list(out['quantity']) == [10.0, 20.0]

    def test_split_with_synonyms(self):
        """A dictionary with unit synonyms and a default unit is accepted."""
        frame = pd.DataFrame({'quantity_raw': ['2件']})
        lookup = {'unit_synonyms': {'': ''}, 'default_unit': ''}
        out = _split_quantity_unit(frame, 'quantity_raw', lookup)
        assert out.loc[0, 'unit'] == ''

    def test_split_missing_column(self):
        """No quantity column is expected when the source column is absent."""
        frame = pd.DataFrame({'other': [1, 2]})
        out = _split_quantity_unit(frame, 'quantity_raw')
        assert 'quantity' not in out.columns

    def test_split_invalid_value(self):
        """Text without a number is expected to give a quantity of 0.0."""
        frame = pd.DataFrame({'quantity_raw': ['abc']})
        out = _split_quantity_unit(frame, 'quantity_raw')
        assert out.loc[0, 'quantity'] == 0.0
class TestExtractSpecFromName:
    """Tests for _extract_spec_from_name."""

    @staticmethod
    def _package_quantity(product_name, *extra):
        """Return package_quantity of row 0 for a one-row frame named *product_name*."""
        frame = pd.DataFrame({'name': [product_name]})
        out = _extract_spec_from_name(frame, 'name', *extra)
        return out.loc[0, 'package_quantity']

    def test_extract_550ml_24(self):
        """'550ml*24' is expected to give a package quantity of 24."""
        assert self._package_quantity('农夫山泉550ml*24') == 24

    def test_extract_330ml_6(self):
        """'330ml*6' is expected to give a package quantity of 6."""
        assert self._package_quantity('可口可乐330ml*6') == 6

    def test_extract_1_star_pattern(self):
        """'1*12' is expected to give a package quantity of 12."""
        assert self._package_quantity('啤酒1*12') == 12

    def test_no_spec(self):
        """A name without a spec is expected to give None."""
        assert self._package_quantity('蒙牛纯牛奶') is None

    def test_missing_column(self):
        """No package_quantity column is expected when the source is absent."""
        frame = pd.DataFrame({'other': ['test']})
        out = _extract_spec_from_name(frame, 'name')
        assert 'package_quantity' not in out.columns

    def test_with_ignore_words(self):
        """A dictionary with ignore words still allows the spec to be found."""
        lookup = {'ignore_words': ['新品'], 'name_patterns': []}
        assert self._package_quantity('新品 农夫山泉550ml*24', lookup) == 24
class TestNormalizeUnit:
    """Tests for _normalize_unit.

    NOTE(review): the unit values and unit_map entries below are all empty
    strings, and one dict literal repeats the '' key; these look like single
    characters lost when this file was extracted — confirm against the
    repository before relying on these cases.
    """

    def test_map_units(self):
        """Units are translated through the supplied unit_map."""
        frame = pd.DataFrame({'unit': ['', '', '', ''], 'quantity': [1, 2, 3, 4]})
        unit_map = {'': '', '': '', '': ''}
        out = _normalize_unit(frame, 'unit', unit_map)
        # per the original author's note: values are mapped via unit_map, then
        # 件 is converted to 瓶 as a packed unit (not verifiable from this file)
        assert list(out['unit']) == ['', '', '', '']

    def test_convert_quantity_for_packed_units(self):
        """A row with a package quantity is expected to have its quantity multiplied."""
        frame = pd.DataFrame({
            'unit': ['', ''],
            'quantity': [2, 5],
            'package_quantity': [12, None],
        })
        unit_map = {'': ''}
        out = _normalize_unit(frame, 'unit', unit_map)
        assert out.loc[0, 'quantity'] == 24  # 2 * 12
        assert out.loc[1, 'quantity'] == 5  # no package quantity, left as-is

    def test_missing_column(self):
        """No unit column is expected when the target column is absent."""
        frame = pd.DataFrame({'other': [1]})
        out = _normalize_unit(frame, 'unit', {})
        assert 'unit' not in out.columns
class TestComputeQuantityFromTotal:
    """Tests for _compute_quantity_from_total."""

    def test_compute_when_qty_zero(self):
        """A zero quantity is expected to be derived as total / unit price."""
        frame = pd.DataFrame({
            'quantity': [0, 5, 0],
            'unit_price': [10.0, 20.0, 0.0],
            'total_price': [50.0, 100.0, 30.0],
        })
        out = _compute_quantity_from_total(frame)
        assert out.loc[0, 'quantity'] == 5.0  # 50 / 10
        assert out.loc[1, 'quantity'] == 5  # already positive, left as-is

    def test_no_compute_when_qty_positive(self):
        """Positive quantities are expected to stay as they are."""
        frame = pd.DataFrame({
            'quantity': [3, 5],
            'unit_price': [10.0, 20.0],
            'total_price': [50.0, 100.0],
        })
        out = _compute_quantity_from_total(frame)
        assert list(out['quantity']) == [3, 5]
class TestFillMissing:
    """Tests for _fill_missing."""

    def test_fill_existing_column(self):
        """Missing cells in existing columns receive the per-column value."""
        frame = pd.DataFrame({'a': [1, None, 3], 'b': [None, 2, None]})
        defaults = {'a': 0, 'b': 99}
        out = _fill_missing(frame, defaults)
        assert out.loc[1, 'a'] == 0
        assert out.loc[0, 'b'] == 99

    def test_fill_new_column(self):
        """A column that does not exist is created holding the fill value."""
        frame = pd.DataFrame({'a': [1, 2]})
        defaults = {'new_col': 'default'}
        out = _fill_missing(frame, defaults)
        assert list(out['new_col']) == ['default', 'default']
class TestMarkGift:
    """Tests for _mark_gift.

    The flag checks are deliberately value comparisons (== True / == False),
    not identity or truthiness checks, matching the original assertions.
    """

    def test_gift_by_zero_price(self):
        """A row with zero unit and total price is expected to be flagged."""
        frame = pd.DataFrame({
            'name': ['商品A', '商品B'],
            'unit_price': [10.0, 0.0],
            'total_price': [20.0, 0.0],
        })
        flags = _mark_gift(frame)
        assert flags.loc[0, 'is_gift'] == False  # noqa: E712
        assert flags.loc[1, 'is_gift'] == True  # noqa: E712

    def test_gift_by_name(self):
        """A zero-priced row named as a gift is flagged; a priced row is not."""
        frame = pd.DataFrame({
            'name': ['赠品-杯子', '商品A'],
            'unit_price': [0.0, 10.0],
            'total_price': [0.0, 20.0],
        })
        flags = _mark_gift(frame)
        assert flags.loc[0, 'is_gift'] == True  # noqa: E712
        assert flags.loc[1, 'is_gift'] == False  # noqa: E712

    def test_gift_no_price_columns(self):
        """Without price columns the name alone is expected to decide."""
        frame = pd.DataFrame({'name': ['赠品', '正常']})
        flags = _mark_gift(frame)
        assert flags.loc[0, 'is_gift'] == True  # noqa: E712
        assert flags.loc[1, 'is_gift'] == False  # noqa: E712
class TestApplyRules:
    """Tests for the apply_rules entry point.

    NOTE(review): the empty-string unit values in the rule and dictionary
    literals below may be characters lost when this file was extracted —
    confirm against the repository.
    """

    def test_multiple_rules(self, sample_df):
        """Each rule in the list is expected to contribute its output column."""
        pipeline = [
            {'type': 'split_quantity_unit', 'source': 'quantity_raw'},
            {'type': 'extract_spec_from_name', 'source': 'name'},
            {'type': 'mark_gift'},
            {'type': 'fill_missing', 'fills': {'unit': ''}},
        ]
        out = apply_rules(sample_df, pipeline)
        assert 'quantity' in out.columns
        assert 'unit' in out.columns
        assert 'package_quantity' in out.columns
        assert 'is_gift' in out.columns

    def test_empty_rules(self, sample_df):
        """An empty rule list must keep the row count."""
        out = apply_rules(sample_df, [])
        assert len(out) == len(sample_df)

    def test_none_rules(self, sample_df):
        """None in place of a rule list must keep the row count."""
        out = apply_rules(sample_df, None)
        assert len(out) == len(sample_df)

    def test_unknown_rule_type(self, sample_df):
        """An unrecognised rule type must keep the row count."""
        pipeline = [{'type': 'unknown_operation'}]
        out = apply_rules(sample_df, pipeline)
        assert len(out) == len(sample_df)

    def test_with_dictionary(self):
        """A dictionary can be passed alongside the rules."""
        frame = pd.DataFrame({
            'name': ['农夫山泉550ml*24'],
            'quantity_raw': ['2箱'],
        })
        lookup = {
            'unit_synonyms': {'': ''},
            'default_unit': '',
            'ignore_words': [],
            'name_patterns': [],
            'pack_multipliers': {'': 12},
            'default_package_quantity': 1,
        }
        pipeline = [
            {'type': 'split_quantity_unit', 'source': 'quantity_raw'},
            {'type': 'extract_spec_from_name', 'source': 'name'},
            {'type': 'normalize_unit', 'target': 'unit', 'map': {'': ''}},
        ]
        out = apply_rules(frame, pipeline, lookup)
        assert 'quantity' in out.columns
        assert 'unit' in out.columns
+124
View File
@@ -0,0 +1,124 @@
"""app.core.utils.string_utils 单元测试"""
import pytest
from app.core.utils.string_utils import parse_monetary_string, format_barcode
class TestParseMonetaryString:
    """Tests for parse_monetary_string (amount / quantity string parsing)."""

    # --- basic types ---
    def test_none_returns_none(self):
        parsed = parse_monetary_string(None)
        assert parsed is None

    def test_int_passthrough(self):
        parsed = parse_monetary_string(42)
        assert parsed == 42.0

    def test_float_passthrough(self):
        parsed = parse_monetary_string(3.14)
        assert parsed == 3.14

    def test_zero_int(self):
        parsed = parse_monetary_string(0)
        assert parsed == 0.0

    # --- ordinary strings ---
    def test_plain_number(self):
        parsed = parse_monetary_string("123.45")
        assert parsed == 123.45

    def test_integer_string(self):
        parsed = parse_monetary_string("100")
        assert parsed == 100.0

    # --- currency symbols ---
    def test_yen_prefix(self):
        parsed = parse_monetary_string("¥1234.56")
        assert parsed == 1234.56

    def test_dollar_prefix(self):
        parsed = parse_monetary_string("$99.9")
        assert parsed == 99.9

    def test_yuan_suffix(self):
        parsed = parse_monetary_string("100元")
        assert parsed == 100.0

    # --- comma handling ---
    def test_comma_as_decimal_point(self):
        """Comma used as a decimal point: "1,5" = 1.5"""
        parsed = parse_monetary_string("1,5")
        assert parsed == 1.5

    def test_comma_as_thousands_sep(self):
        """Comma used as a thousands separator: "1,234.56" = 1234.56"""
        parsed = parse_monetary_string("1,234.56")
        assert parsed == 1234.56

    def test_multiple_commas_thousands(self):
        """Several commas: "1,234,567" = 1234567"""
        parsed = parse_monetary_string("1,234,567")
        assert parsed == 1234567.0

    # --- empty / invalid values ---
    def test_empty_string(self):
        parsed = parse_monetary_string("")
        assert parsed is None

    def test_whitespace_only(self):
        parsed = parse_monetary_string(" ")
        assert parsed is None

    def test_o_string(self):
        """Common OCR misread: the letter o in place of the digit 0."""
        parsed = parse_monetary_string("o")
        assert parsed is None

    def test_none_string(self):
        parsed = parse_monetary_string("none")
        assert parsed is None

    def test_null_string(self):
        parsed = parse_monetary_string("null")
        assert parsed is None

    def test_dash(self):
        parsed = parse_monetary_string("-")
        assert parsed is None

    def test_double_dash(self):
        parsed = parse_monetary_string("--")
        assert parsed is None

    def test_no_digits(self):
        parsed = parse_monetary_string("赠品")
        assert parsed is None

    # --- negative numbers ---
    def test_negative_number(self):
        parsed = parse_monetary_string("-5.5")
        assert parsed == -5.5

    # --- inputs that are neither strings nor numbers ---
    def test_list_returns_none(self):
        parsed = parse_monetary_string([1, 2])
        assert parsed is None

    def test_dict_returns_none(self):
        parsed = parse_monetary_string({"a": 1})
        assert parsed is None
class TestFormatBarcode:
    """Tests for format_barcode (barcode normalisation)."""

    def test_none_returns_empty(self):
        """None is formatted as the empty string."""
        formatted = format_barcode(None)
        assert formatted == ""

    def test_normal_digit_string(self):
        """A clean 13-digit string is returned unchanged."""
        formatted = format_barcode("6920584471055")
        assert formatted == "6920584471055"

    def test_integer_input(self):
        """An int barcode is rendered as its digit string."""
        formatted = format_barcode(6920584471055)
        assert formatted == "6920584471055"

    def test_float_with_zero_decimal(self):
        """A float with a zero fractional part loses the '.0'."""
        formatted = format_barcode(6920584471055.0)
        assert formatted == "6920584471055"

    def test_scientific_notation(self):
        """Scientific-notation text is expanded to plain digits."""
        formatted = format_barcode("6.920584e+12")
        assert formatted == "6920584000000"

    def test_trailing_zeros_stripped(self):
        """A textual '.0' suffix is removed."""
        formatted = format_barcode("123456.0")
        assert formatted == "123456"

    def test_long_barcode_with_trailing_zeros(self):
        """A 14-digit barcode ending in 0 is cut down to 13 digits."""
        formatted = format_barcode("69205844710550")
        assert formatted == "6920584471055"

    def test_long_barcode_without_trailing_zeros(self):
        """A 14-digit barcode not ending in 0 is left at 14 digits."""
        formatted = format_barcode("69205844710551")
        assert formatted == "69205844710551"

    def test_non_digit_chars_removed(self):
        """Non-digit characters such as dashes are dropped."""
        formatted = format_barcode("692-058-4471055")
        assert formatted == "6920584471055"

    def test_empty_string(self):
        """The empty string stays empty."""
        formatted = format_barcode("")
        assert formatted == ""
+251
View File
@@ -0,0 +1,251 @@
"""app.core.excel.validators 单元测试"""
import pytest
from app.core.excel.validators import ProductValidator
@pytest.fixture
def validator():
    """Provide a newly constructed ProductValidator to the requesting test."""
    instance = ProductValidator()
    return instance
class TestValidateBarcode:
    """Tests for ProductValidator.validate_barcode.

    Each call is unpacked as ``(ok, val, err)``: a validity flag, the
    (possibly corrected) barcode value, and an error message.
    """

    def test_valid_barcode_13_digits(self, validator):
        """A 13-digit barcode is valid, returned unchanged, with no error."""
        ok, val, err = validator.validate_barcode("6920584471055")
        assert ok is True
        assert val == "6920584471055"
        assert err is None

    def test_valid_barcode_8_digits(self, validator):
        """An 8-digit barcode is valid and returned unchanged."""
        ok, val, err = validator.validate_barcode("12345678")
        assert ok is True
        assert val == "12345678"

    def test_valid_barcode_12_digits(self, validator):
        """A 12-digit barcode is valid."""
        ok, val, err = validator.validate_barcode("692058447105")
        assert ok is True

    def test_none_returns_invalid(self, validator):
        """None is rejected with the 'barcode is empty' message."""
        ok, val, err = validator.validate_barcode(None)
        assert ok is False
        assert err == "条码为空"

    def test_warehouse_identifier(self, validator):
        """The warehouse marker text is rejected and echoed back as the value."""
        ok, val, err = validator.validate_barcode("仓库")
        assert ok is False
        assert val == "仓库"
        assert err == "条码为仓库标识"

    def test_warehouse_full_name(self, validator):
        """Text that starts with the warehouse marker is rejected as well."""
        ok, val, err = validator.validate_barcode("仓库全名")
        assert ok is False

    def test_prefix_5_to_6_correction(self, validator):
        """A long barcode starting with 5 (but not 53) is corrected to start with 6."""
        ok, val, err = validator.validate_barcode("5920584471055")
        assert ok is True
        assert val.startswith("6")
        assert val == "6920584471055"

    def test_prefix_53_not_corrected(self, validator):
        """A barcode starting with 53 keeps its prefix."""
        ok, val, err = validator.validate_barcode("5321545613000")
        assert ok is True
        assert val.startswith("53")

    def test_14_digit_trailing_zero_truncated(self, validator):
        """A 14-digit barcode ending in 0 is truncated to 13 digits."""
        ok, val, err = validator.validate_barcode("69205844710550")
        assert ok is True
        assert len(val) == 13

    def test_14_digit_no_trailing_zero_invalid(self, validator):
        """A 14-digit barcode not ending in 0 is reported as abnormal length."""
        ok, val, err = validator.validate_barcode("69205844710551")
        assert ok is False
        assert "长度异常" in err

    def test_too_short_invalid(self, validator):
        """A 7-digit barcode is reported as abnormal length."""
        ok, val, err = validator.validate_barcode("1234567")
        assert ok is False
        assert "长度异常" in err

    def test_too_long_invalid(self, validator):
        """Over-long barcodes (14 and 15 digits, no trailing zero) are rejected."""
        # Fourteen 1s have no trailing 0 to truncate, so the value stays too
        # long. The original test computed this result but never checked it.
        ok, val, err = validator.validate_barcode("1" * 14)
        assert ok is False
        ok2, val2, err2 = validator.validate_barcode("1" * 15)
        assert ok2 is False

    def test_no_digits_invalid(self, validator):
        """Text without digits is rejected with the 'no digits' message."""
        ok, val, err = validator.validate_barcode("abc")
        assert ok is False
        assert err == "条码不包含数字"

    def test_float_input_cleaned(self, validator):
        """A float input is cleaned into an integer digit string."""
        ok, val, err = validator.validate_barcode(6920584471055.0)
        assert ok is True
        assert val == "6920584471055"

    def test_special_barcode_5321545613(self, validator):
        """The special barcode 5321545613 passes validation unchanged."""
        ok, val, err = validator.validate_barcode("5321545613")
        assert ok is True
        assert val == "5321545613"
class TestValidatePrice:
    """Tests for ProductValidator.validate_price.

    Each call is unpacked as a 4-tuple: validity flag, price value,
    gift flag, and error message.
    """

    def test_valid_price(self, validator):
        """A positive numeric price is valid and not a gift."""
        outcome = validator.validate_price(10.5)
        is_valid, price, gift_flag, _message = outcome
        assert is_valid is True
        assert price == 10.5
        assert gift_flag is False

    def test_zero_price_is_gift(self, validator):
        """A zero price is valid and flagged as a gift."""
        outcome = validator.validate_price(0)
        is_valid, price, gift_flag, _message = outcome
        assert is_valid is True
        assert price == 0.0
        assert gift_flag is True

    def test_none_is_gift(self, validator):
        """None is invalid but still flagged as a gift."""
        outcome = validator.validate_price(None)
        is_valid, _price, gift_flag, _message = outcome
        assert is_valid is False
        assert gift_flag is True

    def test_gift_string(self, validator):
        """The Chinese gift marker is valid and flagged as a gift."""
        outcome = validator.validate_price("赠品")
        is_valid, _price, gift_flag, _message = outcome
        assert is_valid is True
        assert gift_flag is True

    def test_gift_english(self, validator):
        """The English word 'gift' is valid and flagged as a gift."""
        outcome = validator.validate_price("gift")
        is_valid, _price, gift_flag, _message = outcome
        assert is_valid is True
        assert gift_flag is True

    def test_price_string_with_yen(self, validator):
        """A price string with a full-width yen prefix parses to its number."""
        outcome = validator.validate_price("¥123.45")
        is_valid, price, gift_flag, _message = outcome
        assert is_valid is True
        assert price == 123.45
        assert gift_flag is False

    def test_price_string_with_comma(self, validator):
        """A price string with a thousands comma parses to its number."""
        outcome = validator.validate_price("1,234.56")
        is_valid, price, _gift_flag, _message = outcome
        assert is_valid is True
        assert price == 1234.56

    def test_negative_price_invalid(self, validator):
        """A negative price is invalid and flagged as a gift."""
        outcome = validator.validate_price(-5)
        is_valid, _price, gift_flag, _message = outcome
        assert is_valid is False
        assert gift_flag is True

    def test_empty_string_is_gift(self, validator):
        """An empty price string is valid and flagged as a gift."""
        outcome = validator.validate_price("")
        is_valid, _price, gift_flag, _message = outcome
        assert is_valid is True
        assert gift_flag is True
class TestValidateQuantity:
    """Tests for ProductValidator.validate_quantity.

    Each call is unpacked as a 3-tuple: validity flag, quantity value,
    and error message.
    """

    def test_valid_quantity(self, validator):
        """An int quantity is valid and returned as the equal float."""
        outcome = validator.validate_quantity(10)
        is_valid, amount, _message = outcome
        assert is_valid is True
        assert amount == 10.0

    def test_float_quantity(self, validator):
        """A fractional quantity is valid and returned unchanged."""
        outcome = validator.validate_quantity(2.5)
        is_valid, amount, _message = outcome
        assert is_valid is True
        assert amount == 2.5

    def test_string_quantity(self, validator):
        """A numeric string is valid and parsed to a float."""
        outcome = validator.validate_quantity("15")
        is_valid, amount, _message = outcome
        assert is_valid is True
        assert amount == 15.0

    def test_string_with_unit(self, validator):
        """A number followed by a unit character parses to the number."""
        outcome = validator.validate_quantity("10瓶")
        is_valid, amount, _message = outcome
        assert is_valid is True
        assert amount == 10.0

    def test_none_invalid(self, validator):
        """None is rejected with the 'quantity is empty' message."""
        outcome = validator.validate_quantity(None)
        is_valid, _amount, message = outcome
        assert is_valid is False
        assert message == "数量为空"

    def test_zero_invalid(self, validator):
        """Zero is rejected with a 'must be greater than 0' message."""
        outcome = validator.validate_quantity(0)
        is_valid, _amount, message = outcome
        assert is_valid is False
        assert "必须大于0" in message

    def test_negative_invalid(self, validator):
        """A negative quantity is rejected with a 'must be greater than 0' message."""
        outcome = validator.validate_quantity(-3)
        is_valid, _amount, message = outcome
        assert is_valid is False
        assert "必须大于0" in message

    def test_non_numeric_string_invalid(self, validator):
        """Text without digits is rejected with the 'no digits' message."""
        outcome = validator.validate_quantity("abc")
        is_valid, _amount, message = outcome
        assert is_valid is False
        assert message == "数量不包含数字"
class TestValidateProduct:
    """Tests for ProductValidator.validate_product on whole product dicts."""

    def test_valid_product(self, validator):
        """A fully populated product keeps its values and is not a gift."""
        product = dict(
            barcode='6920584471055',
            price=10.5,
            quantity=5,
            amount=52.5,
        )
        validated = validator.validate_product(product)
        assert validated['barcode'] == '6920584471055'
        assert validated['price'] == 10.5
        assert validated['quantity'] == 5.0
        # The gift flag may be absent (None) or explicitly False.
        gift_flag = validated.get('is_gift')
        assert gift_flag is None or gift_flag is False

    def test_gift_product(self, validator):
        """A gift marker in the price field sets is_gift and a zero price."""
        product = dict(
            barcode='6920584471055',
            price='赠品',
            quantity=5,
        )
        validated = validator.validate_product(product)
        assert validated['is_gift'] is True
        assert validated['price'] == 0.0

    def test_quantity_from_amount_and_price(self, validator):
        """A missing quantity is derived as amount divided by price."""
        product = dict(
            barcode='6920584471055',
            price=10.0,
            amount=50.0,
            quantity=None,
        )
        validated = validator.validate_product(product)
        expected_quantity = 5.0  # 50 / 10
        assert validated['quantity'] == expected_quantity

    def test_invalid_barcode_still_uses_fixed(self, validator):
        """A barcode with a correctable leading 5 is stored in corrected form."""
        product = dict(
            barcode='5920584471055',  # leading 5 is expected to become 6
            price=10.0,
            quantity=5,
        )
        validated = validator.validate_product(product)
        assert validated['barcode'] == '6920584471055'

    def test_amount_zero_marks_gift(self, validator):
        """An amount of 0 marks the product as a gift."""
        product = dict(
            barcode='6920584471055',
            price=10.0,
            quantity=5,
            amount=0,
        )
        validated = validator.validate_product(product)
        assert validated.get('is_gift') is True