| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215 |
- """
- Prompt 加载器
- 用于加载和管理 prompt 模板,实现 prompts 与代码分离
- """
- import os
- import yaml
- from pathlib import Path
- from typing import Dict, Optional
- from utils.logger import logger
- class PromptLoader:
- """Prompt 模板加载器"""
-
- def __init__(self, config_path: str = "config/prompt_config.yaml", base_dir: str = None):
- """
- 初始化 Prompt 加载器
-
- Args:
- config_path: prompt配置文件路径(相对于base_dir)
- base_dir: 基础目录,默认为项目根目录
- """
- self.base_dir = base_dir or os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
- self.config_path = os.path.join(self.base_dir, config_path)
- self.config = {}
- self.cache = {} # prompt内容缓存
- self.cache_enabled = True
-
- # 加载配置
- self._load_config()
-
- def _load_config(self):
- """加载配置文件"""
- try:
- if not os.path.exists(self.config_path):
- logger.warning(f"Prompt配置文件不存在: {self.config_path}")
- return
-
- with open(self.config_path, 'r', encoding='utf-8') as f:
- self.config = yaml.safe_load(f) or {}
-
- # 读取默认配置
- defaults = self.config.get('defaults', {})
- self.cache_enabled = defaults.get('cache_enabled', True)
-
- logger.info(f"Prompt配置加载成功,共 {len(self.config.get('prompts', {}))} 个模板")
-
- except Exception as e:
- logger.error(f"加载Prompt配置失败: {e}")
- self.config = {}
-
- def _read_prompt_file(self, file_path: str, encoding: str = 'utf-8') -> str:
- """
- 读取prompt文件内容
-
- Args:
- file_path: 文件路径(相对于base_dir)
- encoding: 文件编码
-
- Returns:
- 文件内容
- """
- full_path = os.path.join(self.base_dir, file_path)
-
- try:
- if not os.path.exists(full_path):
- logger.error(f"Prompt文件不存在: {full_path}")
- return ""
-
- with open(full_path, 'r', encoding=encoding) as f:
- content = f.read()
-
- return content
-
- except Exception as e:
- logger.error(f"读取Prompt文件失败 {full_path}: {e}")
- return ""
-
- def get_prompt(self, prompt_key: str, **variables) -> str:
- """
- 获取prompt模板并替换变量
-
- Args:
- prompt_key: prompt配置中的key
- **variables: 要替换的变量,如 context="xxx", question="xxx"
-
- Returns:
- 处理后的prompt内容
- """
- # 检查配置
- prompts_config = self.config.get('prompts', {})
- if prompt_key not in prompts_config:
- logger.error(f"未找到prompt配置: {prompt_key}")
- return ""
-
- prompt_info = prompts_config[prompt_key]
- file_path = prompt_info.get('file', '')
- encoding = prompt_info.get('encoding', 'utf-8')
-
- # 从缓存或文件读取内容
- if self.cache_enabled and prompt_key in self.cache:
- content = self.cache[prompt_key]
- else:
- content = self._read_prompt_file(file_path, encoding)
- if self.cache_enabled:
- self.cache[prompt_key] = content
-
- # 替换变量
- if variables:
- content = self._replace_variables(content, variables)
-
- return content
-
- def _replace_variables(self, content: str, variables: Dict) -> str:
- """
- 替换prompt中的变量
-
- 支持格式:
- - {variable_name}
- - ${variable_name}
-
- Args:
- content: prompt内容
- variables: 变量字典
-
- Returns:
- 替换后的内容
- """
- result = content
-
- for key, value in variables.items():
- # 转换为字符串
- value_str = str(value) if value is not None else ""
-
- # 替换 {key} 格式
- result = result.replace(f"{{{key}}}", value_str)
- # 替换 ${key} 格式(可选)
- result = result.replace(f"${{{key}}}", value_str)
-
- return result
-
- def reload_prompt(self, prompt_key: str):
- """
- 重新加载指定prompt(清除缓存)
-
- Args:
- prompt_key: prompt配置中的key
- """
- if prompt_key in self.cache:
- del self.cache[prompt_key]
- logger.info(f"已清除prompt缓存: {prompt_key}")
-
- def reload_all(self):
- """重新加载所有prompts(清除所有缓存)"""
- self.cache.clear()
- self._load_config()
- logger.info("已重新加载所有prompt配置和缓存")
-
- def get_prompt_info(self, prompt_key: str) -> Optional[Dict]:
- """
- 获取prompt的配置信息
-
- Args:
- prompt_key: prompt配置中的key
-
- Returns:
- 配置信息字典
- """
- prompts_config = self.config.get('prompts', {})
- return prompts_config.get(prompt_key)
-
- def list_prompts(self) -> Dict:
- """
- 列出所有可用的prompt
-
- Returns:
- prompt配置字典
- """
- return self.config.get('prompts', {})
- # 全局单例
- _prompt_loader = None
- def get_prompt_loader(config_path: str = "config/prompt_config.yaml") -> PromptLoader:
- """
- 获取全局PromptLoader实例(单例模式)
-
- Args:
- config_path: 配置文件路径
-
- Returns:
- PromptLoader实例
- """
- global _prompt_loader
- if _prompt_loader is None:
- _prompt_loader = PromptLoader(config_path=config_path)
- return _prompt_loader
- # 便捷函数
- def load_prompt(prompt_key: str, **variables) -> str:
- """
- 便捷函数:加载prompt
-
- Args:
- prompt_key: prompt配置中的key
- **variables: 变量
-
- Returns:
- 处理后的prompt内容
- """
- loader = get_prompt_loader()
- return loader.get_prompt(prompt_key, **variables)
|