tests.core.test_sentence_generator_manager 源代码

"""
测试 sentence_generator_manager 模块
"""
import pytest
from unittest.mock import Mock, patch, MagicMock
from core.sentence_generator_manager import SentenceGeneratorManager
from core.base_sentence_generator import BaseSentenceGenerator
from core.param_translator import ParamTranslator
from core.config_manager import EngineConfig
from core.exceptions import GeneratorError


# 创建测试用的生成器类
[文档] class MockGenerator1(BaseSentenceGenerator): """测试生成器 1""" param_config = { "Music": {"format": "play music {value}"}, "Sound": {"format": "play sound {value}"} } @property def category(self): return "mock1" @property def priority(self): return 10
[文档] def process(self, data): return ["mock1 command"]
[文档] class MockGenerator2(BaseSentenceGenerator): """测试生成器 2""" param_config = { "Background": {"format": "scene {value}"}, "Character": {"format": "show {value}"} } @property def category(self): return "mock2" @property def priority(self): return 5
[文档] def process(self, data): return ["mock2 command"]
[文档] class MockGenerator3(BaseSentenceGenerator): """测试生成器 3(没有 param_config)""" @property def category(self): return "mock3" @property def priority(self): return 20
[文档] def process(self, data): return ["mock3 command"]
[文档] class TestSentenceGeneratorManager: """测试 SentenceGeneratorManager 类"""
[文档] @pytest.fixture def manager(self): """创建管理器实例""" return SentenceGeneratorManager("test_engine")
[文档] def test_init(self, manager): """测试初始化""" assert manager.engine_type == "test_engine" assert manager.generator_classes == [] assert manager.param_configs == {} assert manager._loaded is False
[文档] def test_load_only_once(self, manager): """测试 load 方法只执行一次""" with patch.object(manager, '_discover_generator_classes') as mock_discover: with patch.object(manager, '_collect_param_configs') as mock_collect: # 第一次调用 manager.load() assert manager._loaded is True assert mock_discover.call_count == 1 assert mock_collect.call_count == 1 # 第二次调用(应该被跳过) manager.load() assert mock_discover.call_count == 1 assert mock_collect.call_count == 1
[文档] @pytest.mark.parametrize("obj,expected", [ # 有效的生成器类 (MockGenerator1, True), (MockGenerator2, True), # 基类(应该被过滤) (BaseSentenceGenerator, False), # 非类对象 ("not a class", False), (123, False), (None, False), ]) def test_is_generator_class(self, manager, obj, expected): """测试 _is_generator_class 方法""" assert manager._is_generator_class(obj) is expected
[文档] def test_is_generator_class_not_subclass(self, manager): """测试非生成器子类""" class NotAGenerator: pass assert manager._is_generator_class(NotAGenerator) is False
[文档] def test_collect_param_configs(self, manager): """测试收集参数配置""" manager.generator_classes = [MockGenerator1, MockGenerator2, MockGenerator3] manager._collect_param_configs() # 验证收集到的配置 assert "Music" in manager.param_configs assert "Sound" in manager.param_configs assert "Background" in manager.param_configs assert "Character" in manager.param_configs assert len(manager.param_configs) == 4
[文档] def test_collect_param_configs_empty(self, manager): """测试收集空配置""" manager.generator_classes = [] manager._collect_param_configs() assert manager.param_configs == {}
[文档] def test_collect_param_configs_no_param_config(self, manager): """测试生成器没有 param_config 属性""" manager.generator_classes = [MockGenerator3] manager._collect_param_configs() # MockGenerator3 没有 param_config,应该为空 assert manager.param_configs == {}
[文档] def test_get_all_param_names(self, manager): """测试获取所有参数名称""" manager.generator_classes = [MockGenerator1, MockGenerator2] manager._loaded = True manager._collect_param_configs() param_names = manager.get_all_param_names() # 应该按字母顺序排序 assert param_names == ["Background", "Character", "Music", "Sound"]
[文档] def test_get_all_param_names_empty(self, manager): """测试获取空参数列表""" manager._loaded = True param_names = manager.get_all_param_names() assert param_names == []
[文档] def test_get_all_param_names_calls_load(self, manager): """测试 get_all_param_names 会调用 load""" with patch.object(manager, 'load') as mock_load: manager.get_all_param_names() mock_load.assert_called_once()
[文档] class TestCreateGeneratorInstances: """测试创建生成器实例"""
[文档] @pytest.fixture def manager(self): """创建管理器实例""" return SentenceGeneratorManager("test_engine")
[文档] @pytest.fixture def mock_translator(self): """创建模拟翻译器""" return Mock(spec=ParamTranslator)
[文档] @pytest.fixture def mock_config(self): """创建模拟配置""" return Mock(spec=EngineConfig)
[文档] def test_create_generator_instances(self, manager, mock_translator, mock_config): """测试创建生成器实例""" manager.generator_classes = [MockGenerator1, MockGenerator2, MockGenerator3] manager._loaded = True instances = manager.create_generator_instances(mock_translator, mock_config) # 验证创建了 3 个实例 assert len(instances) == 3 # 验证所有实例都是 BaseSentenceGenerator 的子类 for instance in instances: assert isinstance(instance, BaseSentenceGenerator)
[文档] def test_create_generator_instances_sorted_by_priority(self, manager, mock_translator, mock_config): """测试生成器按优先级排序""" # MockGenerator1: priority=10 # MockGenerator2: priority=5 # MockGenerator3: priority=20 manager.generator_classes = [MockGenerator1, MockGenerator2, MockGenerator3] manager._loaded = True instances = manager.create_generator_instances(mock_translator, mock_config) # 应该按优先级从小到大排序 assert instances[0].priority == 5 # MockGenerator2 assert instances[1].priority == 10 # MockGenerator1 assert instances[2].priority == 20 # MockGenerator3
[文档] def test_create_generator_instances_calls_load(self, manager, mock_translator, mock_config): """测试 create_generator_instances 会调用 load""" with patch.object(manager, 'load') as mock_load: manager._loaded = True manager.generator_classes = [] manager.create_generator_instances(mock_translator, mock_config) mock_load.assert_called_once()
[文档] def test_create_generator_instances_empty(self, manager, mock_translator, mock_config): """测试没有生成器类时""" manager.generator_classes = [] manager._loaded = True instances = manager.create_generator_instances(mock_translator, mock_config) assert instances == []
[文档] def test_create_generator_instances_with_error(self, manager, mock_translator, mock_config): """测试创建实例时出错""" # 创建一个会抛出异常的生成器类 class BrokenGenerator(BaseSentenceGenerator): def __init__(self, translator, config): raise ValueError("Initialization failed") @property def category(self): return "broken" def process(self, data): return [] manager.generator_classes = [MockGenerator1, BrokenGenerator, MockGenerator2] manager._loaded = True instances = manager.create_generator_instances(mock_translator, mock_config) # 应该只创建成功的实例(跳过失败的) assert len(instances) == 2 assert all(isinstance(i, BaseSentenceGenerator) for i in instances)
[文档] class TestDiscoverGeneratorClasses: """测试发现生成器类"""
[文档] @pytest.fixture def manager(self): """创建管理器实例""" return SentenceGeneratorManager("test_engine")
[文档] def test_discover_generator_classes_manual_add(self, manager): """测试手动添加生成器类(模拟发现过程)""" # 直接测试发现后的结果,而不是 mock 整个发现过程 # 这是一个更实用的测试方法 manager.generator_classes = [MockGenerator1, MockGenerator2] # 验证生成器类已添加 assert len(manager.generator_classes) == 2 assert MockGenerator1 in manager.generator_classes assert MockGenerator2 in manager.generator_classes
[文档] def test_discover_generator_classes_skip_packages(self, manager): """测试跳过包(只处理模块)""" mock_package = MagicMock() mock_package.__path__ = [] with patch('importlib.import_module') as mock_import: with patch('pkgutil.iter_modules') as mock_iter: mock_import.return_value = mock_package # is_pkg=True 表示是包,应该被跳过 mock_iter.return_value = [ (None, "subpackage", True), (None, "not_generator", False) # 不以 _generator 结尾 ] manager._discover_generator_classes() # 不应该发现任何生成器 assert manager.generator_classes == []
[文档] def test_discover_generator_classes_skip_non_generator_modules(self, manager): """测试跳过不以 _generator 结尾的模块""" mock_package = MagicMock() mock_package.__path__ = [] with patch('importlib.import_module') as mock_import: with patch('pkgutil.iter_modules') as mock_iter: mock_import.return_value = mock_package mock_iter.return_value = [ (None, "utils", False), (None, "helpers", False), (None, "__init__", False) ] manager._discover_generator_classes() assert manager.generator_classes == []
[文档] def test_discover_generator_classes_import_error(self, manager): """测试导入包失败""" with patch('importlib.import_module') as mock_import: mock_import.side_effect = ImportError("Package not found") with pytest.raises(GeneratorError, match="无法加载引擎 test_engine 的生成器"): manager._discover_generator_classes()
[文档] def test_discover_generator_classes_module_import_error(self, manager): """测试导入模块失败(应该记录错误但继续)""" mock_package = MagicMock() mock_package.__path__ = [] with patch('importlib.import_module') as mock_import: with patch('pkgutil.iter_modules') as mock_iter: # 第一次调用返回包,第二次调用抛出异常 mock_import.side_effect = [ mock_package, ImportError("Module not found") ] mock_iter.return_value = [ (None, "broken_generator", False) ] # 应该不抛出异常,只记录错误 manager._discover_generator_classes() # 不应该发现任何生成器 assert manager.generator_classes == []
[文档] class TestSentenceGeneratorManagerIntegration: """集成测试:测试完整的工作流程"""
[文档] @pytest.fixture def manager(self): """创建管理器实例""" return SentenceGeneratorManager("test_engine")
[文档] @pytest.fixture def mock_translator(self): """创建模拟翻译器""" return Mock(spec=ParamTranslator)
[文档] @pytest.fixture def mock_config(self): """创建模拟配置""" return Mock(spec=EngineConfig)
[文档] def test_full_workflow(self, manager, mock_translator, mock_config): """测试完整工作流程""" # 模拟发现生成器 manager.generator_classes = [MockGenerator1, MockGenerator2, MockGenerator3] manager._loaded = True # 1. 收集参数配置 manager._collect_param_configs() assert len(manager.param_configs) == 4 # 2. 获取所有参数名称 param_names = manager.get_all_param_names() assert len(param_names) == 4 assert param_names == ["Background", "Character", "Music", "Sound"] # 3. 创建生成器实例 instances = manager.create_generator_instances(mock_translator, mock_config) assert len(instances) == 3 # 4. 验证排序 assert instances[0].priority < instances[1].priority < instances[2].priority # 5. 验证实例可以使用 for instance in instances: assert instance.translator == mock_translator assert instance.engine_config == mock_config result = instance.process({"test": "data"}) assert result is not None
[文档] def test_load_idempotent(self, manager): """测试 load 方法的幂等性""" manager.generator_classes = [MockGenerator1] manager._loaded = True # 多次调用 load manager.load() manager.load() manager.load() # 状态应该保持一致 assert manager._loaded is True
[文档] def test_param_configs_merge(self, manager): """测试参数配置合并""" # 创建两个有重叠配置的生成器 class Generator1(BaseSentenceGenerator): param_config = {"Music": {"format": "play {value}"}} @property def category(self): return "gen1" def process(self, data): return [] class Generator2(BaseSentenceGenerator): param_config = { "Music": {"format": "music {value}"}, # 重复的键 "Sound": {"format": "sound {value}"} } @property def category(self): return "gen2" def process(self, data): return [] manager.generator_classes = [Generator1, Generator2] manager._collect_param_configs() # 后面的配置应该覆盖前面的 assert "Music" in manager.param_configs assert "Sound" in manager.param_configs assert manager.param_configs["Music"]["format"] == "music {value}"