prompt_loader.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. """
  2. Prompt 加载器
  3. 用于加载和管理 prompt 模板,实现 prompts 与代码分离
  4. """
  5. import os
  6. import yaml
  7. from pathlib import Path
  8. from typing import Dict, Optional
  9. from utils.logger import logger
  10. class PromptLoader:
  11. """Prompt 模板加载器"""
  12. def __init__(self, config_path: str = "config/prompt_config.yaml", base_dir: str = None):
  13. """
  14. 初始化 Prompt 加载器
  15. Args:
  16. config_path: prompt配置文件路径(相对于base_dir)
  17. base_dir: 基础目录,默认为项目根目录
  18. """
  19. self.base_dir = base_dir or os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  20. self.config_path = os.path.join(self.base_dir, config_path)
  21. self.config = {}
  22. self.cache = {} # prompt内容缓存
  23. self.cache_enabled = True
  24. # 加载配置
  25. self._load_config()
  26. def _load_config(self):
  27. """加载配置文件"""
  28. try:
  29. if not os.path.exists(self.config_path):
  30. logger.warning(f"Prompt配置文件不存在: {self.config_path}")
  31. return
  32. with open(self.config_path, 'r', encoding='utf-8') as f:
  33. self.config = yaml.safe_load(f) or {}
  34. # 读取默认配置
  35. defaults = self.config.get('defaults', {})
  36. self.cache_enabled = defaults.get('cache_enabled', True)
  37. logger.info(f"Prompt配置加载成功,共 {len(self.config.get('prompts', {}))} 个模板")
  38. except Exception as e:
  39. logger.error(f"加载Prompt配置失败: {e}")
  40. self.config = {}
  41. def _read_prompt_file(self, file_path: str, encoding: str = 'utf-8') -> str:
  42. """
  43. 读取prompt文件内容
  44. Args:
  45. file_path: 文件路径(相对于base_dir)
  46. encoding: 文件编码
  47. Returns:
  48. 文件内容
  49. """
  50. full_path = os.path.join(self.base_dir, file_path)
  51. try:
  52. if not os.path.exists(full_path):
  53. logger.error(f"Prompt文件不存在: {full_path}")
  54. return ""
  55. with open(full_path, 'r', encoding=encoding) as f:
  56. content = f.read()
  57. return content
  58. except Exception as e:
  59. logger.error(f"读取Prompt文件失败 {full_path}: {e}")
  60. return ""
  61. def get_prompt(self, prompt_key: str, **variables) -> str:
  62. """
  63. 获取prompt模板并替换变量
  64. Args:
  65. prompt_key: prompt配置中的key
  66. **variables: 要替换的变量,如 context="xxx", question="xxx"
  67. Returns:
  68. 处理后的prompt内容
  69. """
  70. # 检查配置
  71. prompts_config = self.config.get('prompts', {})
  72. if prompt_key not in prompts_config:
  73. logger.error(f"未找到prompt配置: {prompt_key}")
  74. return ""
  75. prompt_info = prompts_config[prompt_key]
  76. file_path = prompt_info.get('file', '')
  77. encoding = prompt_info.get('encoding', 'utf-8')
  78. # 从缓存或文件读取内容
  79. if self.cache_enabled and prompt_key in self.cache:
  80. content = self.cache[prompt_key]
  81. else:
  82. content = self._read_prompt_file(file_path, encoding)
  83. if self.cache_enabled:
  84. self.cache[prompt_key] = content
  85. # 替换变量
  86. if variables:
  87. content = self._replace_variables(content, variables)
  88. return content
  89. def _replace_variables(self, content: str, variables: Dict) -> str:
  90. """
  91. 替换prompt中的变量
  92. 支持格式:
  93. - {variable_name}
  94. - ${variable_name}
  95. Args:
  96. content: prompt内容
  97. variables: 变量字典
  98. Returns:
  99. 替换后的内容
  100. """
  101. result = content
  102. for key, value in variables.items():
  103. # 转换为字符串
  104. value_str = str(value) if value is not None else ""
  105. # 替换 {key} 格式
  106. result = result.replace(f"{{{key}}}", value_str)
  107. # 替换 ${key} 格式(可选)
  108. result = result.replace(f"${{{key}}}", value_str)
  109. return result
  110. def reload_prompt(self, prompt_key: str):
  111. """
  112. 重新加载指定prompt(清除缓存)
  113. Args:
  114. prompt_key: prompt配置中的key
  115. """
  116. if prompt_key in self.cache:
  117. del self.cache[prompt_key]
  118. logger.info(f"已清除prompt缓存: {prompt_key}")
  119. def reload_all(self):
  120. """重新加载所有prompts(清除所有缓存)"""
  121. self.cache.clear()
  122. self._load_config()
  123. logger.info("已重新加载所有prompt配置和缓存")
  124. def get_prompt_info(self, prompt_key: str) -> Optional[Dict]:
  125. """
  126. 获取prompt的配置信息
  127. Args:
  128. prompt_key: prompt配置中的key
  129. Returns:
  130. 配置信息字典
  131. """
  132. prompts_config = self.config.get('prompts', {})
  133. return prompts_config.get(prompt_key)
  134. def list_prompts(self) -> Dict:
  135. """
  136. 列出所有可用的prompt
  137. Returns:
  138. prompt配置字典
  139. """
  140. return self.config.get('prompts', {})
  141. # 全局单例
  142. _prompt_loader = None
  143. def get_prompt_loader(config_path: str = "config/prompt_config.yaml") -> PromptLoader:
  144. """
  145. 获取全局PromptLoader实例(单例模式)
  146. Args:
  147. config_path: 配置文件路径
  148. Returns:
  149. PromptLoader实例
  150. """
  151. global _prompt_loader
  152. if _prompt_loader is None:
  153. _prompt_loader = PromptLoader(config_path=config_path)
  154. return _prompt_loader
  155. # 便捷函数
  156. def load_prompt(prompt_key: str, **variables) -> str:
  157. """
  158. 便捷函数:加载prompt
  159. Args:
  160. prompt_key: prompt配置中的key
  161. **variables: 变量
  162. Returns:
  163. 处理后的prompt内容
  164. """
  165. loader = get_prompt_loader()
  166. return loader.get_prompt(prompt_key, **variables)