update_param 源代码

"""
参数映射更新工具
从 Excel 参数文件生成 Python 参数映射模块
"""
import pandas as pd
from pathlib import Path
from typing import Dict, List
from openpyxl import load_workbook
from openpyxl.styles import Alignment
from openpyxl.utils import get_column_letter
from openpyxl.workbook.defined_name import DefinedName
from core.sentence_generator_manager import SentenceGeneratorManager
from core.config_manager import AppConfig
from core.logger import get_logger

logger = get_logger()


[文档] class ParamUpdater: """参数映射更新器""" def __init__(self, config: AppConfig): """ 初始化参数更新器 Args: config: 应用配置 """ self.config = config self.engine_type = config.engine.engine_type
[文档] def read_param_file(self, param_file: Path, skip_template: bool = True) -> Dict[str, Dict[str, str]]: """ 读取参数文件并生成映射 Args: param_file: 参数文件路径 skip_template: 是否跳过模板工作表 Returns: Dict[str, Dict[str, str]]: 参数映射字典 """ if not param_file.exists(): logger.error(f"参数文件不存在: {param_file}") return {} try: # 读取所有工作表 sheets = pd.read_excel(param_file, sheet_name=None, dtype=str, keep_default_na=False, na_values=['']) logger.info(f"读取到 {len(sheets)} 个工作表") mappings = {} for sheet_name, df in sheets.items(): # 根据参数决定是否跳过模板工作表 if skip_template and "模板" in sheet_name: logger.debug(f"跳过模板工作表: {sheet_name}") continue # 检查必需的列 if "ExcelParam" not in df.columns or "ScenarioParam" not in df.columns: logger.warning(f"工作表 {sheet_name} 缺少必需的列,跳过") continue # 构建映射 sheet_mapping = {} for _, row in df.iterrows(): excel_param = row["ExcelParam"] scenario_param = row["ScenarioParam"] if sheet_name == "Transition": print(scenario_param) if pd.notna(excel_param) and pd.notna(scenario_param): sheet_mapping[str(excel_param)] = str(scenario_param) # 对于差分参数文件,保留空映射(包括模板) if not skip_template or sheet_mapping: mappings[sheet_name] = sheet_mapping if sheet_mapping: logger.info(f"工作表 {sheet_name}: {len(sheet_mapping)} 个映射") return mappings except Exception as e: logger.error(f"读取参数文件失败: {e}", exc_info=True) return {}
[文档] def generate_mappings_file(self, mappings: Dict[str, Dict[str, str]], output_file: Path): """ 生成参数映射 Python 文件 Args: mappings: 参数映射字典 output_file: 输出文件路径 """ try: # 根据文件名确定变量名 variable_name = "VARIENT_MAPPINGS" if "varient" in output_file.name else "PARAM_MAPPINGS" with open(output_file, "w", encoding="utf-8") as f: f.write("# 自动生成的参数映射文件\n") f.write("# 请不要手动编辑此文件\n") f.write(f"# 引擎类型: {self.engine_type}\n\n") f.write(f"{variable_name} = ") # 使用 repr 生成格式化的字典 import pprint f.write(pprint.pformat(mappings, width=100, sort_dicts=False)) f.write("\n") logger.info(f"参数映射已保存到: {output_file}") except Exception as e: logger.error(f"保存参数映射文件失败: {e}", exc_info=True)
[文档] def collect_validation_data(self, param_file: Path, varient_file: Path = None) -> Dict[str, List[str]]: """ 收集所有验证数据(基础参数 + 差分参数) Args: param_file: 基础参数文件路径 varient_file: 差分参数文件路径(可选) Returns: Dict[str, List[str]]: 参数名 -> 参数值列表 """ validation_data = {} # 1. 收集基础参数 try: base_sheets = pd.read_excel(param_file, sheet_name=None) for sheet_name, df in base_sheets.items(): # 检查是否有 ExcelParam 列 if "ExcelParam" not in df.columns: continue # 提取参数值 params = [] for _, row in df.iterrows(): excel_param = row["ExcelParam"] if pd.notna(excel_param): param_str = str(excel_param).strip() if param_str: params.append(param_str) # 只保存非空的参数列表 if params: validation_data[sheet_name] = params logger.debug(f"收集参数 {sheet_name}: {len(params)} 个值") except Exception as e: logger.error(f"读取基础参数文件失败: {e}", exc_info=True) return {} # 2. 收集差分参数 if varient_file and varient_file.exists(): try: varient_sheets = pd.read_excel(varient_file, sheet_name=None) # 收集所有差分参数名 all_varient_params = set() for sheet_name, df in varient_sheets.items(): if "ExcelParam" not in df.columns: continue for _, row in df.iterrows(): excel_param = row["ExcelParam"] if pd.notna(excel_param): param_str = str(excel_param).strip() if param_str: all_varient_params.add(param_str) # 将差分参数合并到 Varient 列 if all_varient_params: if "Varient" in validation_data: # 合并去重 existing_varients = set(validation_data["Varient"]) combined_varients = existing_varients.union(all_varient_params) validation_data["Varient"] = sorted(list(combined_varients)) else: validation_data["Varient"] = sorted(list(all_varient_params)) logger.debug(f"收集差分参数: {len(all_varient_params)} 个值") except Exception as e: logger.warning(f"读取差分参数文件失败: {e}") return validation_data
[文档] def get_all_validate_params(self) -> Dict[str, List[str]]: """ 获取所有句子生成器的参数翻译类型 Returns: Dict[str, Dict[str, list[str]]]: 去重后的数据验证参数类型列表 """ try: # 创建管理器实例 manager = SentenceGeneratorManager(self.engine_type) # 调用我们之前写的方法 return manager.get_validate_params() except Exception as e: logger.error(f"获取数据验证参数类型时发生错误: {e}") return {}
[文档] def update_scenario_param_sheets(self, validation_data: Dict[str, List[str]]) -> bool: """ 更新演出表格中的参数表工作表 Args: validation_data: 参数验证数据 Returns: bool: 是否成功 """ if not validation_data: logger.error("没有收集到验证数据") return False # 获取参数翻译类型列表 param_types = self.get_all_validate_params() if not param_types: logger.error("无法获取参数翻译类型词典") return False translate_params = param_types.get("translate_types", []) validate_params = param_types.get("validate_types", []) all_params = sorted(validate_params + translate_params) # 获取 input 目录 input_dir = Path(self.config.paths.input_dir) if not input_dir.exists(): logger.error(f"输入目录不存在: {input_dir}") return False # 查找所有 Excel 文件 excel_files = list(input_dir.glob("*.xlsx")) excel_files = [f for f in excel_files if not f.name.startswith("~")] if not excel_files: logger.warning(f"在 {input_dir} 中没有找到 Excel 文件") return True logger.info(f"找到 {len(excel_files)} 个演出表格文件") success_count = 0 for excel_file in excel_files: try: logger.info(f"处理文件: {excel_file.name}") # 加载工作簿 wb = load_workbook(excel_file) # 检查是否存在"参数表"工作表 if "参数表" not in wb.sheetnames: logger.info(f" 创建新的'参数表'工作表") validation_ws = wb.create_sheet("参数表") # logger.warning(f" {excel_file.name} 中没有'参数表'工作表,跳过") # wb.close() # continue else: validation_ws = wb["参数表"] # 创建居中对齐样式 center_alignment = Alignment(horizontal='center', vertical='center') # # 获取当前参数表的列顺序 # current_headers = [] # for col_idx in range(1, validation_ws.max_column + 1): # header = validation_ws.cell(row=1, column=col_idx).value # if header: # current_headers.append(header) # logger.info(f" 当前参数表列: {current_headers}") # 跟踪是否有任何更新 has_updates = False # 按照当前列顺序更新参数表 for col_idx, param_type in enumerate(all_params, 1): # 获取这个参数的数据 params = validation_data.get(param_type, []) # 检查表头是否匹配 current_header = validation_ws.cell(row=1, column=col_idx).value needs_update = False if current_header != param_type: needs_update = True logger.info(f" 表头不匹配: 当前'{current_header}',期望'{param_type}'") else: # 检查参数内容是否匹配 current_params = [] read_row = 2 while read_row <= validation_ws.max_row: cell_value = validation_ws.cell(row=read_row, column=col_idx).value if cell_value is None: break current_params.append(str(cell_value)) read_row += 1 # 如果参数数量或内容有变化,需要更新 if current_params != [str(p) for p in params]: needs_update = True logger.info(f" 更新 {param_type}: 当前 {len(current_params)} 个,新 {len(params)} 个") # 如果需要更新,重新写入这一列 if needs_update: has_updates = True # 清除这一列的内容(从第1行开始) for clear_row in range(1, validation_ws.max_row + 1): validation_ws.cell(row=clear_row, column=col_idx).value = None # 写入表头 header_cell = validation_ws.cell(row=1, column=col_idx, value=param_type) header_cell.alignment = center_alignment # 写入参数数据 for write_row, param_value in enumerate(params, 2): cell = validation_ws.cell(row=write_row, column=col_idx, value=param_value) cell.alignment = center_alignment logger.info(f" 更新 {param_type} 参数 写入 {len(params)} 个参数") else: logger.debug(f" {param_type} 列无需更新") # 创建或更新命名区域 range_name = f"{param_type}List" col_letter = get_column_letter(col_idx) expected_dynamic_range = f"OFFSET(参数表!${col_letter}$2,0,0,COUNTA(参数表!${col_letter}:${col_letter})-1,1)" # 检查命名区域是否需要更新 needs_update_range = True if range_name in wb.defined_names: existing_named_range = wb.defined_names[range_name] existing_range = existing_named_range.attr_text # 如果现有区域已经是正确的动态公式,则跳过 if existing_range == expected_dynamic_range: needs_update_range = False # 创建或更新命名区域 if needs_update_range: try: # 删除已存在的区域(如果存在) if range_name in wb.defined_names: del wb.defined_names[range_name] # 创建新的命名区域 wb.defined_names[range_name] = DefinedName( name=range_name, attr_text=expected_dynamic_range ) logger.debug(f" 创建命名区域: {range_name}") has_updates = True except Exception as e: logger.warning(f" 处理命名区域 {range_name} 时发生错误: {e}") # 清理多余的列(如果参数翻译类型列表比当前列少) current_column_count = validation_ws.max_column if current_column_count > len(all_params): logger.info(f" 清理多余的列: 当前{current_column_count}列,期望{len(all_params)}列") has_updates = True # 删除多余的列 for col_idx in range(len(all_params) + 1, current_column_count + 1): for row_idx in range(1, validation_ws.max_row + 1): validation_ws.cell(row=row_idx, column=col_idx).value = None # 保存工作簿(仅当有更新时) if has_updates: try: wb.save(excel_file) logger.info(f" 成功保存文件: {excel_file.name}") success_count += 1 except Exception as e: logger.error(f" 保存失败: {e}") finally: wb.close() else: logger.info(f" {excel_file.name} 无需更新") wb.close() except Exception as e: logger.error(f" 处理文件 {excel_file.name} 时发生错误: {e}", exc_info=True) logger.info(f"处理完成,成功更新 {success_count}/{len(excel_files)} 个文件") return success_count > 0
[文档] def update_mappings(self): """更新参数映射""" logger.info("=" * 60) logger.info(f"开始更新参数映射 (引擎: {self.engine_type})") logger.info("=" * 60) # 阶段1: 生成基础参数映射 param_file = Path(self.config.paths.param_config_dir) / f"param_data_{self.engine_type}.xlsx" if not param_file.exists(): logger.error(f"参数文件不存在: {param_file}") logger.info(f"请确保参数文件存在") return False logger.info(f"读取参数文件: {param_file}") mappings = self.read_param_file(param_file) if not mappings: logger.error("未能读取到任何参数映射") return False output_file = self.config.paths.param_config_dir / "param_mappings.py" logger.info(f"生成参数映射文件: {output_file}") self.generate_mappings_file(mappings, output_file) total_mappings = sum(len(m) for m in mappings.values()) logger.info(f"基础参数映射: {len(mappings)} 个工作表, {total_mappings} 个映射") # 阶段2: 生成差分参数映射 varient_file = Path(self.config.paths.param_config_dir) / "varient_data.xlsx" if varient_file.exists(): logger.info(f"读取差分参数文件: {varient_file}") # 差分参数文件不跳过模板工作表,保持与原项目一致 varient_mappings = self.read_param_file(varient_file, skip_template=False) # 生成差分映射文件(保持与原项目一致,包含空映射) varient_output = self.config.paths.param_config_dir / "varient_mappings.py" logger.info(f"生成差分参数映射文件: {varient_output}") self.generate_mappings_file(varient_mappings, varient_output) # 统计有效映射(排除模板) valid_mappings = {k: v for k, v in varient_mappings.items() if v and "模板" not in k} if valid_mappings: total_varient = sum(len(m) for m in valid_mappings.values()) logger.info(f"差分参数映射: {len(valid_mappings)} 个角色, {total_varient} 个映射") else: logger.info("差分参数文件中没有有效的角色映射") else: varient_file = None logger.info("差分参数文件不存在,跳过") # 阶段3: 更新演出表格的参数表 logger.info("=" * 60) logger.info("更新演出表格参数表") logger.info("=" * 60) validation_data = self.collect_validation_data(param_file, varient_file) if validation_data: logger.info(f"收集到 {len(validation_data)} 个参数类型的验证数据") self.update_scenario_param_sheets(validation_data) else: logger.warning("未能收集到验证数据,跳过演出表格更新") logger.info("=" * 60) logger.info(f"参数映射更新完成") logger.info("=" * 60) return True
[文档] def main(): """主函数""" try: # 加载配置 config_path = Path("config.yaml") if not config_path.exists(): logger.error("配置文件不存在: config.yaml") return config = AppConfig.from_file(config_path) # 创建更新器 updater = ParamUpdater(config) # 执行更新 success = updater.update_mappings() if success: logger.info("参数映射更新成功") else: logger.error("参数映射更新失败") except Exception as e: logger.critical(f"参数更新过程失败: {e}", exc_info=True) raise
if __name__ == "__main__": main()