query_rewrite.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. import uuid
  2. import asyncio
  3. from foundation.observability.logger.loggering import server_logger
  4. from foundation.ai.agent.generate.model_generate import generate_model_client
  5. class QueryRewriteManager():
  6. """
  7. 召回管理器,实现多路召回功能
  8. """
  9. def __init__(self):
  10. # 获取部署的模型列表
  11. self.generate_model_client = generate_model_client
  12. @property
  13. def prompt_loader(self):
  14. """延迟加载 prompt_loader,避免循环导入"""
  15. from core.construction_review.component.reviewers.utils import prompt_loader
  16. return prompt_loader
  17. def query_extract(self, review_content):
  18. """
  19. 从审查条文中提取query
  20. Args:
  21. review_content: 审查内容文本
  22. Returns:
  23. list: 标准格式的查询列表
  24. [
  25. {
  26. "entity": str, # 实体名称
  27. "search_keywords": list, # 搜索关键词列表
  28. "background": str, # 背景信息
  29. "parameter": str # 技术参数
  30. }
  31. ]
  32. 或 None(提取失败时)
  33. """
  34. try:
  35. # 获取提示词模板并组装
  36. task_prompt = self.prompt_loader.get_prompt_template(
  37. reviewer_type="query_extract", # 审查器类型
  38. review_content=review_content # 传入审查内容作为参数
  39. )
  40. # 构建任务提示信息 - 参考标准模式
  41. task_prompt_info = {
  42. "task_prompt": task_prompt, # 使用组装好的提示词
  43. "task_name": "query_extract"
  44. }
  45. # 生成唯一的trace_id用于追踪
  46. trace_id = str(uuid.uuid4())
  47. # 调用模型生成接口(处理异步调用)
  48. try:
  49. loop = asyncio.get_running_loop()
  50. # 如果已有运行中的事件循环,使用create_task
  51. import concurrent.futures
  52. with concurrent.futures.ThreadPoolExecutor() as executor:
  53. future = executor.submit(
  54. asyncio.run,
  55. self.generate_model_client.get_model_generate_invoke(
  56. trace_id=trace_id,
  57. task_prompt_info=task_prompt_info
  58. )
  59. )
  60. model_response = future.result()
  61. except RuntimeError:
  62. # 没有运行中的事件循环,直接使用asyncio.run
  63. model_response = asyncio.run(self.generate_model_client.get_model_generate_invoke(
  64. trace_id=trace_id,
  65. task_prompt_info=task_prompt_info
  66. ))
  67. # 格式化模型响应
  68. formatted_response = self.ai_respose_format(model_response)
  69. server_logger.info(f"查询对构建完成,构建 {len(formatted_response)}条。")
  70. # 记录日志
  71. if formatted_response:
  72. server_logger.info(f"Query 提取成功, 提取到 {len(formatted_response)} 个实体")
  73. else:
  74. server_logger.warning(f"Query 提取失败, 格式化后为空")
  75. return formatted_response
  76. except Exception as e:
  77. server_logger.error(f"Query 提取失败: {str(e)}")
  78. return None
  79. def ai_respose_format(self, model_response):
  80. """
  81. 将模型返回的响应格式化为标准格式
  82. Args:
  83. model_response: AI模型返回的原始响应(可能是字符串或已解析的JSON)
  84. Returns:
  85. list: 标准格式的查询列表
  86. [
  87. {
  88. "entity": str, # 实体名称
  89. "search_keywords": list, # 搜索关键词列表
  90. "background": str, # 背景信息
  91. "parameter": str # 技术参数
  92. }
  93. ]
  94. 或 None(解析失败时)
  95. """
  96. import re
  97. import json
  98. try:
  99. # 1. 如果model_response已经是list,直接返回
  100. if isinstance(model_response, list):
  101. server_logger.info(f"模型响应已是list格式, 包含 {len(model_response)} 个实体")
  102. return model_response
  103. # 2. 如果是dict,包装成list返回
  104. if isinstance(model_response, dict):
  105. server_logger.info("模型响应是dict格式, 包装为list")
  106. return [model_response]
  107. # 3. 如果是字符串,需要解析
  108. if isinstance(model_response, str):
  109. response_text = model_response.strip()
  110. server_logger.debug(f"原始响应字符串长度: {len(response_text)}")
  111. # 3.1 尝试去除 ```json 和 ``` 标记
  112. # 匹配 ```json ... ``` 或 ``` ... ```
  113. json_pattern = r'```(?:json)?\s*\n?(.*?)\n?```'
  114. json_match = re.search(json_pattern, response_text, re.DOTALL | re.IGNORECASE)
  115. if json_match:
  116. # 提取代码块中的JSON内容
  117. json_str = json_match.group(1).strip()
  118. server_logger.debug("检测到markdown代码块, 已提取纯JSON内容")
  119. else:
  120. # 如果没有代码块标记,尝试直接解析整个字符串
  121. json_str = response_text
  122. server_logger.debug("未检测到markdown代码块, 尝试直接解析")
  123. # 3.2 去除可能的Markdown注释或多余空白
  124. json_str = re.sub(r'\n+', '\n', json_str) # 多个换行压缩为一个
  125. json_str = json_str.strip()
  126. server_logger.debug(f"待解析的JSON字符串: {json_str[:200]}...")
  127. # 3.3 解析JSON
  128. parsed_data = json.loads(json_str)
  129. # 3.4 确保返回list格式
  130. if isinstance(parsed_data, list):
  131. server_logger.info(f"JSON解析成功, 提取到 {len(parsed_data)} 个实体")
  132. return parsed_data
  133. elif isinstance(parsed_data, dict):
  134. server_logger.info("JSON解析成功, 单个实体包装为list")
  135. return [parsed_data]
  136. server_logger.warning(f"无法识别的JSON格式: {type(parsed_data)}")
  137. return None
  138. server_logger.warning(f"无法识别的响应类型: {type(model_response)}")
  139. return None
  140. except json.JSONDecodeError as e:
  141. server_logger.error(f"JSON解析失败: {e}")
  142. server_logger.error(f"原始响应: {str(model_response)[:500]}")
  143. return None
  144. except Exception as e:
  145. server_logger.error(f"响应格式化异常: {e}")
  146. server_logger.error(f"原始响应: {str(model_response)[:500]}")
  147. return None
  148. query_rewrite_manager = QueryRewriteManager()