""" Prompt 加载器 用于加载和管理 prompt 模板,实现 prompts 与代码分离 """ import os import yaml from pathlib import Path from typing import Dict, Optional from utils.logger import logger class PromptLoader: """Prompt 模板加载器""" def __init__(self, config_path: str = "config/prompt_config.yaml", base_dir: str = None): """ 初始化 Prompt 加载器 Args: config_path: prompt配置文件路径(相对于base_dir) base_dir: 基础目录,默认为项目根目录 """ self.base_dir = base_dir or os.path.dirname(os.path.dirname(os.path.abspath(__file__))) self.config_path = os.path.join(self.base_dir, config_path) self.config = {} self.cache = {} # prompt内容缓存 self.cache_enabled = True # 加载配置 self._load_config() def _load_config(self): """加载配置文件""" try: if not os.path.exists(self.config_path): logger.warning(f"Prompt配置文件不存在: {self.config_path}") return with open(self.config_path, 'r', encoding='utf-8') as f: self.config = yaml.safe_load(f) or {} # 读取默认配置 defaults = self.config.get('defaults', {}) self.cache_enabled = defaults.get('cache_enabled', True) logger.info(f"Prompt配置加载成功,共 {len(self.config.get('prompts', {}))} 个模板") except Exception as e: logger.error(f"加载Prompt配置失败: {e}") self.config = {} def _read_prompt_file(self, file_path: str, encoding: str = 'utf-8') -> str: """ 读取prompt文件内容 Args: file_path: 文件路径(相对于base_dir) encoding: 文件编码 Returns: 文件内容 """ full_path = os.path.join(self.base_dir, file_path) try: if not os.path.exists(full_path): logger.error(f"Prompt文件不存在: {full_path}") return "" with open(full_path, 'r', encoding=encoding) as f: content = f.read() return content except Exception as e: logger.error(f"读取Prompt文件失败 {full_path}: {e}") return "" def get_prompt(self, prompt_key: str, **variables) -> str: """ 获取prompt模板并替换变量 Args: prompt_key: prompt配置中的key **variables: 要替换的变量,如 context="xxx", question="xxx" Returns: 处理后的prompt内容 """ # 检查配置 prompts_config = self.config.get('prompts', {}) if prompt_key not in prompts_config: logger.error(f"未找到prompt配置: {prompt_key}") return "" prompt_info = prompts_config[prompt_key] file_path = prompt_info.get('file', '') encoding = prompt_info.get('encoding', 'utf-8') # 从缓存或文件读取内容 if self.cache_enabled and prompt_key in self.cache: content = self.cache[prompt_key] else: content = self._read_prompt_file(file_path, encoding) if self.cache_enabled: self.cache[prompt_key] = content # 替换变量 if variables: content = self._replace_variables(content, variables) return content def _replace_variables(self, content: str, variables: Dict) -> str: """ 替换prompt中的变量 支持格式: - {variable_name} - ${variable_name} Args: content: prompt内容 variables: 变量字典 Returns: 替换后的内容 """ result = content for key, value in variables.items(): # 转换为字符串 value_str = str(value) if value is not None else "" # 替换 {key} 格式 result = result.replace(f"{{{key}}}", value_str) # 替换 ${key} 格式(可选) result = result.replace(f"${{{key}}}", value_str) return result def reload_prompt(self, prompt_key: str): """ 重新加载指定prompt(清除缓存) Args: prompt_key: prompt配置中的key """ if prompt_key in self.cache: del self.cache[prompt_key] logger.info(f"已清除prompt缓存: {prompt_key}") def reload_all(self): """重新加载所有prompts(清除所有缓存)""" self.cache.clear() self._load_config() logger.info("已重新加载所有prompt配置和缓存") def get_prompt_info(self, prompt_key: str) -> Optional[Dict]: """ 获取prompt的配置信息 Args: prompt_key: prompt配置中的key Returns: 配置信息字典 """ prompts_config = self.config.get('prompts', {}) return prompts_config.get(prompt_key) def list_prompts(self) -> Dict: """ 列出所有可用的prompt Returns: prompt配置字典 """ return self.config.get('prompts', {}) # 全局单例 _prompt_loader = None def get_prompt_loader(config_path: str = "config/prompt_config.yaml") -> PromptLoader: """ 获取全局PromptLoader实例(单例模式) Args: config_path: 配置文件路径 Returns: PromptLoader实例 """ global _prompt_loader if _prompt_loader is None: _prompt_loader = PromptLoader(config_path=config_path) return _prompt_loader # 便捷函数 def load_prompt(prompt_key: str, **variables) -> str: """ 便捷函数:加载prompt Args: prompt_key: prompt配置中的key **variables: 变量 Returns: 处理后的prompt内容 """ loader = get_prompt_loader() return loader.get_prompt(prompt_key, **variables)