"""Unit tests for app.core.db.product_db."""

import os
import tempfile

import pandas as pd
import pytest

from app.core.db.product_db import ProductDatabase

# File name of the SQLite cache used by every test.
DB_FILENAME = 'product_cache.db'


def _db_path(db_dir, *subdirs):
    """Return the cache database path under db_dir (optionally nested in subdirs)."""
    return os.path.join(db_dir, *subdirs, DB_FILENAME)


@pytest.fixture
def db_dir():
    """Yield a temporary directory that is deleted after the test."""
    # NOTE(review): if ProductDatabase keeps an SQLite connection open, cleanup
    # of this directory can fail on Windows — consider closing the db in tests.
    with tempfile.TemporaryDirectory() as d:
        yield d


@pytest.fixture
def sample_excel(db_dir):
    """Write a three-product Excel workbook into db_dir and return its path."""
    path = os.path.join(db_dir, '商品资料.xlsx')
    df = pd.DataFrame({
        '商品条码': ['6920584471055', '6901028001133', '6925303800013'],
        '商品名称': ['农夫山泉550ml', '蒙牛纯牛奶', '可口可乐330ml'],
        '进货价': [1.2, 3.5, 1.8],
        '单位': ['瓶', '盒', '罐'],
    })
    df.to_excel(path, index=False)
    return path


@pytest.fixture
def db_with_data(db_dir, sample_excel):
    """Return a ProductDatabase that has been populated from sample_excel."""
    return ProductDatabase(_db_path(db_dir), sample_excel)


class TestProductDatabaseInit:
    """Database initialisation tests."""

    def test_auto_import_on_first_run(self, db_dir, sample_excel):
        """On first run the database is created and filled from the Excel file."""
        db_path = _db_path(db_dir)
        assert not os.path.exists(db_path)
        db = ProductDatabase(db_path, sample_excel)
        assert os.path.exists(db_path)
        assert db.count() == 3

    def test_no_reimport_on_existing_db(self, db_dir, sample_excel):
        """An existing database is reused instead of being re-imported."""
        db_path = _db_path(db_dir)
        db1 = ProductDatabase(db_path, sample_excel)
        assert db1.count() == 3
        # With the Excel source gone, the existing database must still open intact.
        os.remove(sample_excel)
        db2 = ProductDatabase(db_path, sample_excel)
        assert db2.count() == 3

    def test_missing_excel_creates_empty_db(self, db_dir):
        """A missing Excel file yields an empty (but created) database."""
        db_path = _db_path(db_dir)
        fake_excel = os.path.join(db_dir, '不存在.xlsx')
        db = ProductDatabase(db_path, fake_excel)
        assert os.path.exists(db_path)
        assert db.count() == 0

    def test_missing_dir_created(self, db_dir, sample_excel):
        """The database's parent directory is created when it does not exist."""
        db_path = _db_path(db_dir, 'subdir')
        db = ProductDatabase(db_path, sample_excel)
        assert os.path.exists(db_path)
        assert db.count() == 3


class TestGetPrice:
    """Single-barcode lookup tests."""

    def test_existing_barcode(self, db_with_data):
        price = db_with_data.get_price('6920584471055')
        assert price == pytest.approx(1.2)

    def test_nonexistent_barcode(self, db_with_data):
        price = db_with_data.get_price('0000000000000')
        assert price is None

    def test_empty_barcode(self, db_with_data):
        price = db_with_data.get_price('')
        assert price is None

    def test_barcode_with_spaces(self, db_with_data):
        """Leading/trailing whitespace around the barcode must still match."""
        price = db_with_data.get_price(' 6920584471055 ')
        assert price == pytest.approx(1.2)


class TestGetPrices:
    """Batch lookup tests."""

    def test_multiple_barcodes(self, db_with_data):
        result = db_with_data.get_prices(['6920584471055', '6901028001133'])
        assert len(result) == 2
        assert result['6920584471055'] == pytest.approx(1.2)
        assert result['6901028001133'] == pytest.approx(3.5)

    def test_partial_match(self, db_with_data):
        """Only barcodes that exist appear in the result."""
        result = db_with_data.get_prices(['6920584471055', '0000000000000'])
        assert len(result) == 1
        assert result['6920584471055'] == pytest.approx(1.2)
        assert '0000000000000' not in result

    def test_empty_list(self, db_with_data):
        result = db_with_data.get_prices([])
        assert result == {}

    def test_all_nonexistent(self, db_with_data):
        result = db_with_data.get_prices(['0000000000000', '1111111111111'])
        assert result == {}


class TestReimport:
    """Re-import tests."""

    def test_reimport_clears_and_reloads(self, db_dir, sample_excel):
        db = ProductDatabase(_db_path(db_dir), sample_excel)
        assert db.count() == 3

        # Append one row to the Excel file. Read barcodes back as text so they
        # are rewritten unchanged (no numeric coercion / lost leading zeros).
        df = pd.read_excel(sample_excel, dtype={'商品条码': str})
        df = pd.concat([df, pd.DataFrame({
            '商品条码': ['6954365200123'],
            '商品名称': ['测试商品'],
            '进货价': [5.0],
            '单位': ['个'],
        })], ignore_index=True)
        df.to_excel(sample_excel, index=False)

        count = db.reimport()
        assert count == 4
        assert db.count() == 4
        assert db.get_price('6954365200123') == pytest.approx(5.0)
        # Previously imported rows must survive the reload.
        assert db.get_price('6920584471055') == pytest.approx(1.2)


class TestEdgeCases:
    """Boundary-condition tests."""

    def test_excel_with_missing_price(self, db_dir):
        """A row whose price cell is empty is imported with price 0.0."""
        path = os.path.join(db_dir, '商品资料.xlsx')
        df = pd.DataFrame({
            '商品条码': ['6920584471055', '6901028001133'],
            '商品名称': ['商品A', '商品B'],
            '进货价': [1.5, None],
        })
        df.to_excel(path, index=False)
        db = ProductDatabase(_db_path(db_dir), path)
        assert db.count() == 2
        assert db.get_price('6920584471055') == pytest.approx(1.5)
        assert db.get_price('6901028001133') == pytest.approx(0.0)

    def test_excel_with_duplicate_barcodes(self, db_dir):
        """For duplicate barcodes the last row wins (presumably INSERT OR REPLACE)."""
        path = os.path.join(db_dir, '商品资料.xlsx')
        df = pd.DataFrame({
            '商品条码': ['6920584471055', '6920584471055'],
            '商品名称': ['商品A', '商品A-新'],
            '进货价': [1.0, 2.0],
        })
        df.to_excel(path, index=False)
        db = ProductDatabase(_db_path(db_dir), path)
        assert db.count() == 1
        assert db.get_price('6920584471055') == pytest.approx(2.0)