query_rewrite.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. import uuid
  2. import asyncio
  3. from foundation.observability.logger.loggering import server_logger
  4. from foundation.ai.agent.generate.model_generate import generate_model_client
  5. class QueryRewriteManager():
  6. """
  7. 召回管理器,实现多路召回功能
  8. """
  9. def __init__(self):
  10. # 获取部署的模型列表
  11. self.generate_model_client = generate_model_client
  12. @property
  13. def prompt_loader(self):
  14. """延迟加载 prompt_loader,避免循环导入"""
  15. from core.construction_review.component.reviewers.utils import prompt_loader
  16. return prompt_loader
  17. def query_extract(self, review_content):
  18. """
  19. 从审查条文中提取query
  20. Args:
  21. review_content: 审查内容文本
  22. Returns:
  23. list: 标准格式的查询列表
  24. [
  25. {
  26. "entity": str, # 实体名称
  27. "search_keywords": list, # 搜索关键词列表
  28. "background": str, # 背景信息
  29. "parameter": str # 技术参数
  30. }
  31. ]
  32. 或 None(提取失败时)
  33. """
  34. try:
  35. # 获取提示词模板并组装
  36. task_prompt = self.prompt_loader.get_prompt_template(
  37. reviewer_type="query_extract", # 审查器类型
  38. review_content=review_content # 传入审查内容作为参数
  39. )
  40. # 构建任务提示信息 - 参考标准模式
  41. task_prompt_info = {
  42. "task_prompt": task_prompt, # 使用组装好的提示词
  43. "task_name": "query_extract"
  44. }
  45. # 生成唯一的trace_id用于追踪
  46. trace_id = str(uuid.uuid4())
  47. # 调用模型生成接口(处理异步调用)
  48. try:
  49. loop = asyncio.get_running_loop()
  50. # 如果已有运行中的事件循环,使用create_task
  51. import concurrent.futures
  52. with concurrent.futures.ThreadPoolExecutor() as executor:
  53. future = executor.submit(
  54. asyncio.run,
  55. self.generate_model_client.get_model_generate_invoke(
  56. trace_id=trace_id,
  57. task_prompt_info=task_prompt_info,
  58. timeout=60,
  59. model_name="qwen3_30b" # 修复: 使用正确的模型名称(下划线)
  60. )
  61. )
  62. model_response = future.result()
  63. except RuntimeError:
  64. # 没有运行中的事件循环,直接使用asyncio.run
  65. model_response = asyncio.run(self.generate_model_client.get_model_generate_invoke(
  66. trace_id=trace_id,
  67. task_prompt_info=task_prompt_info,
  68. timeout=60,
  69. model_name="qwen3_30b" # 修复: 使用正确的模型名称(下划线)
  70. ))
  71. # 格式化模型响应
  72. formatted_response = self.ai_respose_format(model_response)
  73. # 检查 formatted_response 是否为 None
  74. if formatted_response is not None:
  75. server_logger.info(f"查询对构建完成,构建 {len(formatted_response)}条。")
  76. else:
  77. server_logger.warning("查询对构建失败,formatted_response 为 None")
  78. # 记录日志
  79. if formatted_response:
  80. server_logger.info(f"Query 提取成功, 提取到 {len(formatted_response)} 个实体")
  81. else:
  82. server_logger.warning(f"Query 提取失败, 格式化后为空")
  83. return formatted_response
  84. except Exception as e:
  85. server_logger.error(f"Query 提取失败: {str(e)}")
  86. return None
  87. def ai_respose_format(self, model_response):
  88. """
  89. 将模型返回的响应格式化为标准格式
  90. Args:
  91. model_response: AI模型返回的原始响应(可能是字符串或已解析的JSON)
  92. Returns:
  93. list: 标准格式的查询列表
  94. [
  95. {
  96. "entity": str, # 实体名称
  97. "search_keywords": list, # 搜索关键词列表
  98. "background": str, # 背景信息
  99. "parameter": str # 技术参数
  100. }
  101. ]
  102. 或 None(解析失败时)
  103. """
  104. import re
  105. import json
  106. try:
  107. # 1. 如果model_response已经是list,直接返回
  108. if isinstance(model_response, list):
  109. server_logger.info(f"模型响应已是list格式, 包含 {len(model_response)} 个实体")
  110. return model_response
  111. # 2. 如果是dict,包装成list返回
  112. if isinstance(model_response, dict):
  113. server_logger.info("模型响应是dict格式, 包装为list")
  114. return [model_response]
  115. # 3. 如果是字符串,需要解析
  116. if isinstance(model_response, str):
  117. response_text = model_response.strip()
  118. server_logger.debug(f"原始响应字符串长度: {len(response_text)}")
  119. # 3.1 尝试去除 ```json 和 ``` 标记
  120. # 匹配 ```json ... ``` 或 ``` ... ```
  121. json_pattern = r'```(?:json)?\s*\n?(.*?)\n?```'
  122. json_match = re.search(json_pattern, response_text, re.DOTALL | re.IGNORECASE)
  123. if json_match:
  124. # 提取代码块中的JSON内容
  125. json_str = json_match.group(1).strip()
  126. server_logger.debug("检测到markdown代码块, 已提取纯JSON内容")
  127. else:
  128. # 如果没有代码块标记,尝试直接解析整个字符串
  129. json_str = response_text
  130. server_logger.debug("未检测到markdown代码块, 尝试直接解析")
  131. # 3.2 去除可能的Markdown注释或多余空白
  132. json_str = re.sub(r'\n+', '\n', json_str) # 多个换行压缩为一个
  133. json_str = json_str.strip()
  134. server_logger.debug(f"待解析的JSON字符串: {json_str[:200]}...")
  135. # 3.3 解析JSON
  136. parsed_data = json.loads(json_str)
  137. # 3.4 确保返回list格式
  138. if isinstance(parsed_data, list):
  139. server_logger.info(f"JSON解析成功, 提取到 {len(parsed_data)} 个实体")
  140. return parsed_data
  141. elif isinstance(parsed_data, dict):
  142. server_logger.info("JSON解析成功, 单个实体包装为list")
  143. return [parsed_data]
  144. server_logger.warning(f"无法识别的JSON格式: {type(parsed_data)}")
  145. return None
  146. server_logger.warning(f"无法识别的响应类型: {type(model_response)}")
  147. return None
  148. except json.JSONDecodeError as e:
  149. server_logger.error(f"JSON解析失败: {e}")
  150. server_logger.error(f"原始响应: {str(model_response)[:500]}")
  151. return None
  152. except Exception as e:
  153. server_logger.error(f"响应格式化异常: {e}")
  154. server_logger.error(f"原始响应: {str(model_response)[:500]}")
  155. return None
  156. query_rewrite_manager = QueryRewriteManager()