# citation_formatter.py (9.5 KB)
"""
Citation formatter.

Implements citation-marker formatting for search results.
Requirements: 7.1, 7.2, 7.3, 7.4
"""

import re
import logging
from typing import Dict, List, Optional, Tuple
from app.schemas.llm_schema import SearchResult

# Module-level logger shared by every CitationFormatter method.
logger = logging.getLogger(__name__)
  11. class CitationFormatter:
  12. """引用格式化器"""
  13. @staticmethod
  14. def extract_search_results(search_info: Dict) -> List[SearchResult]:
  15. """
  16. 从搜索信息中提取结果列表
  17. Args:
  18. search_info: 搜索信息字典
  19. Returns:
  20. 搜索结果列表
  21. """
  22. if not search_info or 'search_results' not in search_info:
  23. return []
  24. results = []
  25. for item in search_info['search_results']:
  26. try:
  27. result = SearchResult(
  28. index=item.get('index', len(results) + 1), # 如果没有index,使用序号
  29. title=item.get('title', ''),
  30. url=item.get('url', ''),
  31. snippet=item.get('snippet')
  32. )
  33. results.append(result)
  34. except Exception as e:
  35. logger.warning(f"解析搜索结果项时出错: {e}, 项目: {item}")
  36. continue
  37. logger.info(f"从搜索信息中提取到 {len(results)} 个搜索结果")
  38. return results
  39. @staticmethod
  40. def format_citations(
  41. content: str,
  42. search_results: List[SearchResult],
  43. format_type: str = "[<number>]"
  44. ) -> str:
  45. """
  46. 格式化引用标注
  47. Args:
  48. content: 原始内容
  49. search_results: 搜索结果列表
  50. format_type: 引用格式类型,支持 "[<number>]" 和 "[ref_<number>]"
  51. Returns:
  52. 格式化后的内容
  53. """
  54. if not content or not search_results:
  55. return content
  56. # 创建索引到搜索结果的映射
  57. index_to_result = {result.index: result for result in search_results}
  58. # 根据格式类型定义正则表达式和替换模式
  59. if format_type == "[<number>]":
  60. # 匹配 [1], [2] 等格式
  61. pattern = r'\[(\d+)\]'
  62. replacement_func = lambda match: CitationFormatter._format_single_citation(
  63. match, index_to_result, "[{}]"
  64. )
  65. elif format_type == "[ref_<number>]":
  66. # 匹配 [ref_1], [ref_2] 等格式
  67. pattern = r'\[ref_(\d+)\]'
  68. replacement_func = lambda match: CitationFormatter._format_single_citation(
  69. match, index_to_result, "[ref_{}]"
  70. )
  71. else:
  72. logger.warning(f"不支持的引用格式类型: {format_type}")
  73. return content
  74. # 执行替换
  75. try:
  76. formatted_content = re.sub(pattern, replacement_func, content)
  77. logger.info(f"完成引用格式化,格式类型: {format_type}")
  78. return formatted_content
  79. except Exception as e:
  80. logger.error(f"引用格式化时出错: {e}")
  81. return content
  82. @staticmethod
  83. def _format_single_citation(
  84. match,
  85. index_to_result: Dict[int, SearchResult],
  86. template: str
  87. ) -> str:
  88. """
  89. 格式化单个引用
  90. Args:
  91. match: 正则匹配对象
  92. index_to_result: 索引到搜索结果的映射
  93. template: 引用模板,如 "[{}]" 或 "[ref_{}]"
  94. Returns:
  95. 格式化后的引用字符串
  96. """
  97. try:
  98. index = int(match.group(1))
  99. if index in index_to_result:
  100. # 引用存在对应的搜索结果,保持原样
  101. return match.group(0)
  102. else:
  103. # 引用不存在对应的搜索结果,保持原样但记录警告
  104. logger.warning(f"引用索引 {index} 没有对应的搜索结果")
  105. return match.group(0)
  106. except (ValueError, IndexError) as e:
  107. logger.warning(f"解析引用索引时出错: {e}")
  108. return match.group(0)
  109. @staticmethod
  110. def append_source_list(content: str, search_results: List[SearchResult]) -> str:
  111. """
  112. 在内容末尾添加搜索来源列表
  113. Args:
  114. content: 原始内容
  115. search_results: 搜索结果列表
  116. Returns:
  117. 添加来源列表后的内容
  118. """
  119. if not search_results:
  120. return content
  121. # 构建来源列表
  122. source_lines = ["\n\n**参考来源:**"]
  123. for result in search_results:
  124. source_line = f"{result.index}. [{result.title}]({result.url})"
  125. if result.snippet:
  126. source_line += f" - {result.snippet}"
  127. source_lines.append(source_line)
  128. source_text = "\n".join(source_lines)
  129. logger.info(f"添加了 {len(search_results)} 个搜索来源")
  130. return content + source_text
  131. @staticmethod
  132. def format_content_with_citations_and_sources(
  133. content: str,
  134. search_info: Optional[Dict],
  135. enable_citation: bool = False,
  136. citation_format: str = "[<number>]",
  137. enable_source: bool = False
  138. ) -> Tuple[str, List[SearchResult]]:
  139. """
  140. 完整的内容格式化:引用标注 + 来源列表
  141. Args:
  142. content: 原始内容
  143. search_info: 搜索信息字典
  144. enable_citation: 是否启用引用标注
  145. citation_format: 引用格式类型
  146. enable_source: 是否启用来源列表
  147. Returns:
  148. 格式化后的内容和搜索结果列表的元组
  149. """
  150. if not search_info:
  151. return content, []
  152. # 提取搜索结果
  153. search_results = CitationFormatter.extract_search_results(search_info)
  154. if not search_results:
  155. return content, []
  156. formatted_content = content
  157. # 格式化引用标注
  158. if enable_citation:
  159. formatted_content = CitationFormatter.format_citations(
  160. formatted_content, search_results, citation_format
  161. )
  162. # 添加来源列表
  163. if enable_source:
  164. formatted_content = CitationFormatter.append_source_list(
  165. formatted_content, search_results
  166. )
  167. return formatted_content, search_results
  168. @staticmethod
  169. def validate_citation_format(format_type: str) -> bool:
  170. """
  171. 验证引用格式类型是否支持
  172. Args:
  173. format_type: 引用格式类型
  174. Returns:
  175. 是否支持该格式
  176. """
  177. supported_formats = ["[<number>]", "[ref_<number>]"]
  178. return format_type in supported_formats
  179. @staticmethod
  180. def extract_citation_indices(content: str, format_type: str = "[<number>]") -> List[int]:
  181. """
  182. 从内容中提取所有引用索引
  183. Args:
  184. content: 内容文本
  185. format_type: 引用格式类型
  186. Returns:
  187. 引用索引列表
  188. """
  189. if not content:
  190. return []
  191. indices = []
  192. try:
  193. if format_type == "[<number>]":
  194. pattern = r'\[(\d+)\]'
  195. elif format_type == "[ref_<number>]":
  196. pattern = r'\[ref_(\d+)\]'
  197. else:
  198. logger.warning(f"不支持的引用格式类型: {format_type}")
  199. return []
  200. matches = re.findall(pattern, content)
  201. indices = [int(match) for match in matches]
  202. indices = sorted(list(set(indices))) # 去重并排序
  203. logger.info(f"从内容中提取到 {len(indices)} 个引用索引: {indices}")
  204. except Exception as e:
  205. logger.error(f"提取引用索引时出错: {e}")
  206. return indices
  207. @staticmethod
  208. def validate_citations_completeness(
  209. content: str,
  210. search_results: List[SearchResult],
  211. format_type: str = "[<number>]"
  212. ) -> Dict[str, List[int]]:
  213. """
  214. 验证引用的完整性
  215. Args:
  216. content: 内容文本
  217. search_results: 搜索结果列表
  218. format_type: 引用格式类型
  219. Returns:
  220. 包含缺失引用和多余引用的字典
  221. """
  222. # 提取内容中的引用索引
  223. content_indices = set(CitationFormatter.extract_citation_indices(content, format_type))
  224. # 提取搜索结果中的索引
  225. result_indices = set(result.index for result in search_results)
  226. # 找出缺失的引用(搜索结果有但内容中没有引用)
  227. missing_citations = sorted(list(result_indices - content_indices))
  228. # 找出多余的引用(内容中有引用但搜索结果中没有)
  229. extra_citations = sorted(list(content_indices - result_indices))
  230. validation_result = {
  231. "missing_citations": missing_citations,
  232. "extra_citations": extra_citations
  233. }
  234. if missing_citations or extra_citations:
  235. logger.warning(f"引用完整性检查: 缺失引用 {missing_citations}, 多余引用 {extra_citations}")
  236. else:
  237. logger.info("引用完整性检查通过")
  238. return validation_result