"""
句子生成器管理器模块
负责发现、加载和管理引擎特定的生成器
"""
import importlib
import pkgutil
import inspect
from typing import List, Dict, Type
from core.base_sentence_generator import BaseSentenceGenerator
from core.param_translator import ParamTranslator
from core.config_manager import EngineConfig
from core.logger import get_logger
from core.exceptions import GeneratorError
logger = get_logger()
[文档]
class SentenceGeneratorManager:
"""句子生成器管理器"""
def __init__(self, engine_type: str):
self.engine_type = engine_type
self.generator_classes: List[Type[BaseSentenceGenerator]] = []
self.param_configs: Dict = {}
self._loaded = False
[文档]
def load(self):
"""加载所有生成器类和参数配置"""
if self._loaded:
return
logger.info(f"开始加载 {self.engine_type} 引擎的生成器")
self._discover_generator_classes()
self._collect_param_configs()
self._loaded = True
def _discover_generator_classes(self):
"""发现指定引擎的所有生成器类"""
generators_package = f"engines.{self.engine_type}.sentence_generators"
try:
package = importlib.import_module(generators_package)
for _, module_name, is_pkg in pkgutil.iter_modules(package.__path__):
if is_pkg or not module_name.endswith('_generator'):
continue
try:
full_module_name = f"{generators_package}.{module_name}"
module = importlib.import_module(full_module_name)
for name, obj in inspect.getmembers(module, inspect.isclass):
if (self._is_generator_class(obj) and
obj.__module__ == module.__name__):
self.generator_classes.append(obj)
logger.debug(f"发现生成器: {obj.__name__}")
except Exception as e:
logger.error(f"导入模块 {module_name} 时出错: {e}")
except ImportError as e:
logger.error(f"导入生成器包 {generators_package} 时出错: {e}")
raise GeneratorError(f"无法加载引擎 {self.engine_type} 的生成器") from e
def _is_generator_class(self, obj) -> bool:
"""检查是否为有效的生成器类"""
try:
return (inspect.isclass(obj) and
issubclass(obj, BaseSentenceGenerator) and
obj != BaseSentenceGenerator)
except TypeError:
return False
def _collect_param_configs(self):
"""收集所有生成器的参数配置"""
total_params = 0
for generator_class in self.generator_classes:
param_config = getattr(generator_class, 'param_config', {})
if param_config and isinstance(param_config, dict):
self.param_configs.update(param_config)
total_params += len(param_config)
logger.info(f"从 {len(self.generator_classes)} 个生成器收集了 {total_params} 个参数配置")
[文档]
def create_generator_instances(
self,
translator: ParamTranslator,
engine_config: EngineConfig
) -> List[BaseSentenceGenerator]:
"""创建生成器实例"""
self.load()
instances = []
for generator_class in self.generator_classes:
try:
instance = generator_class(translator, engine_config)
instances.append(instance)
except Exception as e:
logger.error(f"创建 {generator_class.__name__} 实例失败: {e}")
instances.sort(key=lambda g: g.priority)
logger.info(f"共创建 {len(instances)} 个生成器")
for i, generator in enumerate(instances, 1):
logger.info(f" {i}. {generator.__class__.__name__} (优先级: {generator.priority})")
return instances
[文档]
def get_all_param_names(self) -> List[str]:
"""获取所有参数名称"""
self.load()
return sorted(list(self.param_configs.keys()))
[文档]
def get_validate_params(self) -> Dict[str, list[str]]:
"""
收集所有参数配置中的translate_type值和validate_type值
Returns:
Dict[str, list[str]]: 去重后的数据验证参数列表词典
"""
translate_types = set()
validate_types = set()
result = {}
self.load()
for generator_class in self.generator_classes:
param_config = getattr(generator_class, 'param_config', {})
if param_config and isinstance(param_config, dict):
for param_name, config in param_config.items():
if isinstance(config, dict):
if "translate_type" in config:
translate_type = config["translate_type"]
if translate_type:
translate_types.add(translate_type)
elif "validate_type" in config:
validate_type = config["validate_type"]
if validate_type:
validate_types.add(validate_type)
# 转换为排序后的列表
result["translate_types"] = sorted(list(translate_types))
result["validate_types"] = sorted(list(validate_types))
logger.info(f"收集到 {len(translate_types)} 个翻译参数: {result['translate_types']}")
logger.info(f"收集到 {len(validate_types)} 个验证参数: {result['validate_types']}")
return result