query_rewrite.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. import uuid
  2. import asyncio
  3. from foundation.observability.logger.loggering import server_logger
  4. from foundation.ai.agent.generate.model_generate import generate_model_client
  5. class QueryRewriteManager():
  6. """
  7. 召回管理器,实现多路召回功能
  8. """
  9. def __init__(self):
  10. # 获取部署的模型列表
  11. self.generate_model_client = generate_model_client
  12. @property
  13. def prompt_loader(self):
  14. """延迟加载 prompt_loader,避免循环导入"""
  15. from core.construction_review.component.reviewers.utils import prompt_loader
  16. return prompt_loader
  17. def query_extract(self, review_content):
  18. """
  19. 从审查条文中提取query
  20. Args:
  21. review_content: 审查内容文本
  22. Returns:
  23. list: 标准格式的查询列表
  24. [
  25. {
  26. "entity": str, # 实体名称
  27. "search_keywords": list, # 搜索关键词列表
  28. "background": str, # 背景信息
  29. "parameter": str # 技术参数
  30. }
  31. ]
  32. 或 None(提取失败时)
  33. """
  34. try:
  35. # 获取提示词模板并组装
  36. task_prompt = self.prompt_loader.get_prompt_template(
  37. reviewer_type="query_extract", # 审查器类型
  38. review_content=review_content # 传入审查内容作为参数
  39. )
  40. # 构建任务提示信息 - 参考标准模式
  41. task_prompt_info = {
  42. "task_prompt": task_prompt, # 使用组装好的提示词
  43. "task_name": "query_extract"
  44. }
  45. # 生成唯一的trace_id用于追踪
  46. trace_id = str(uuid.uuid4())
  47. # 调用模型生成接口(处理异步调用)
  48. try:
  49. loop = asyncio.get_running_loop()
  50. # 如果已有运行中的事件循环,使用create_task
  51. import concurrent.futures
  52. with concurrent.futures.ThreadPoolExecutor() as executor:
  53. future = executor.submit(
  54. asyncio.run,
  55. self.generate_model_client.get_model_generate_invoke(
  56. trace_id=trace_id,
  57. task_prompt_info=task_prompt_info,
  58. timeout=60
  59. )
  60. )
  61. model_response = future.result()
  62. except RuntimeError:
  63. # 没有运行中的事件循环,直接使用asyncio.run
  64. model_response = asyncio.run(self.generate_model_client.get_model_generate_invoke(
  65. trace_id=trace_id,
  66. task_prompt_info=task_prompt_info,
  67. timeout=60
  68. ))
  69. # 格式化模型响应
  70. formatted_response = self.ai_respose_format(model_response)
  71. # 检查 formatted_response 是否为 None
  72. if formatted_response is not None:
  73. server_logger.info(f"查询对构建完成,构建 {len(formatted_response)}条。")
  74. else:
  75. server_logger.warning("查询对构建失败,formatted_response 为 None")
  76. # 记录日志
  77. if formatted_response:
  78. server_logger.info(f"Query 提取成功, 提取到 {len(formatted_response)} 个实体")
  79. else:
  80. server_logger.warning(f"Query 提取失败, 格式化后为空")
  81. return formatted_response
  82. except Exception as e:
  83. server_logger.error(f"Query 提取失败: {str(e)}")
  84. return None
  85. def ai_respose_format(self, model_response):
  86. """
  87. 将模型返回的响应格式化为标准格式
  88. Args:
  89. model_response: AI模型返回的原始响应(可能是字符串或已解析的JSON)
  90. Returns:
  91. list: 标准格式的查询列表
  92. [
  93. {
  94. "entity": str, # 实体名称
  95. "search_keywords": list, # 搜索关键词列表
  96. "background": str, # 背景信息
  97. "parameter": str # 技术参数
  98. }
  99. ]
  100. 或 None(解析失败时)
  101. """
  102. import re
  103. import json
  104. try:
  105. # 1. 如果model_response已经是list,直接返回
  106. if isinstance(model_response, list):
  107. server_logger.info(f"模型响应已是list格式, 包含 {len(model_response)} 个实体")
  108. return model_response
  109. # 2. 如果是dict,包装成list返回
  110. if isinstance(model_response, dict):
  111. server_logger.info("模型响应是dict格式, 包装为list")
  112. return [model_response]
  113. # 3. 如果是字符串,需要解析
  114. if isinstance(model_response, str):
  115. response_text = model_response.strip()
  116. server_logger.debug(f"原始响应字符串长度: {len(response_text)}")
  117. # 3.1 尝试去除 ```json 和 ``` 标记
  118. # 匹配 ```json ... ``` 或 ``` ... ```
  119. json_pattern = r'```(?:json)?\s*\n?(.*?)\n?```'
  120. json_match = re.search(json_pattern, response_text, re.DOTALL | re.IGNORECASE)
  121. if json_match:
  122. # 提取代码块中的JSON内容
  123. json_str = json_match.group(1).strip()
  124. server_logger.debug("检测到markdown代码块, 已提取纯JSON内容")
  125. else:
  126. # 如果没有代码块标记,尝试直接解析整个字符串
  127. json_str = response_text
  128. server_logger.debug("未检测到markdown代码块, 尝试直接解析")
  129. # 3.2 去除可能的Markdown注释或多余空白
  130. json_str = re.sub(r'\n+', '\n', json_str) # 多个换行压缩为一个
  131. json_str = json_str.strip()
  132. server_logger.debug(f"待解析的JSON字符串: {json_str[:200]}...")
  133. # 3.3 解析JSON
  134. parsed_data = json.loads(json_str)
  135. # 3.4 确保返回list格式
  136. if isinstance(parsed_data, list):
  137. server_logger.info(f"JSON解析成功, 提取到 {len(parsed_data)} 个实体")
  138. return parsed_data
  139. elif isinstance(parsed_data, dict):
  140. server_logger.info("JSON解析成功, 单个实体包装为list")
  141. return [parsed_data]
  142. server_logger.warning(f"无法识别的JSON格式: {type(parsed_data)}")
  143. return None
  144. server_logger.warning(f"无法识别的响应类型: {type(model_response)}")
  145. return None
  146. except json.JSONDecodeError as e:
  147. server_logger.error(f"JSON解析失败: {e}")
  148. server_logger.error(f"原始响应: {str(model_response)[:500]}")
  149. return None
  150. except Exception as e:
  151. server_logger.error(f"响应格式化异常: {e}")
  152. server_logger.error(f"原始响应: {str(model_response)[:500]}")
  153. return None
  154. query_rewrite_manager = QueryRewriteManager()