query_rewrite.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. import uuid
  2. from foundation.observability.logger.loggering import review_logger as server_logger
  3. from foundation.ai.agent.generate.model_generate import generate_model_client
  4. class QueryRewriteManager():
  5. """
  6. 查询改写管理器 — 从施工方案文本中提取审查要点
  7. """
  8. def __init__(self):
  9. self.generate_model_client = generate_model_client
  10. @property
  11. def prompt_loader(self):
  12. """延迟加载 prompt_loader,避免循环导入"""
  13. from core.construction_review.component.reviewers.utils import prompt_loader
  14. return prompt_loader
  15. def query_extract(self, review_content):
  16. """
  17. 从审查条文中提取审查要点 (review points)
  18. Args:
  19. review_content: 审查内容文本
  20. Returns:
  21. list: 审查要点列表
  22. [
  23. {
  24. "label": str, # 审查要点标签
  25. "search_queries": list, # 规范检索语句
  26. "original_text": str, # 原文摘录
  27. "parameter": str, # 技术参数
  28. # --- 向后兼容别名 (由 _add_backward_compat_aliases 自动添加) ---
  29. "entity": str, # = label
  30. "search_keywords": list, # = search_queries
  31. "background": str, # = original_text
  32. }
  33. ]
  34. 或 None(提取失败时)
  35. """
  36. try:
  37. # 获取提示词模板并组装 — 优先使用新 key,回退到旧 key
  38. task_prompt = self.prompt_loader.get_prompt_template(
  39. reviewer_type="review_point_extract",
  40. prompt_name="review_point_extract",
  41. review_content=review_content
  42. )
  43. task_prompt_info = {
  44. "task_prompt": task_prompt,
  45. "task_name": "review_point_extract"
  46. }
  47. trace_id = str(uuid.uuid4())
  48. # 调用模型 — function_name 对应 model_setting.yaml 中的配置
  49. model_response = self.generate_model_client.get_model_generate_invoke_sync(
  50. trace_id=trace_id,
  51. task_prompt_info=task_prompt_info,
  52. timeout=60,
  53. function_name="review_point_extract"
  54. )
  55. # 格式化模型响应
  56. formatted_response = self.ai_respose_format(model_response)
  57. if formatted_response:
  58. # 添加向后兼容字段别名
  59. formatted_response = self._add_backward_compat_aliases(formatted_response)
  60. server_logger.info(f"审查要点提取完成, 提取到 {len(formatted_response)} 个要点")
  61. else:
  62. server_logger.warning("审查要点提取失败, 格式化后为空")
  63. return formatted_response
  64. except Exception as e:
  65. server_logger.error(f"审查要点提取失败: {str(e)}")
  66. return None
  67. def _add_backward_compat_aliases(self, review_points):
  68. """
  69. 为每个审查要点添加双向字段别名,确保新旧格式都能工作
  70. 新字段 → 旧字段: label→entity, search_queries→search_keywords, original_text→background
  71. 旧字段 → 新字段: entity→label, search_keywords→search_queries, background→original_text
  72. """
  73. for point in review_points:
  74. # 新 → 旧(LLM 使用新格式时)
  75. if 'label' in point and 'entity' not in point:
  76. point['entity'] = point['label']
  77. if 'search_queries' in point and 'search_keywords' not in point:
  78. point['search_keywords'] = point['search_queries']
  79. if 'original_text' in point and 'background' not in point:
  80. point['background'] = point['original_text']
  81. # 旧 → 新(LLM 使用旧格式时)
  82. if 'entity' in point and 'label' not in point:
  83. point['label'] = point['entity']
  84. if 'search_keywords' in point and 'search_queries' not in point:
  85. point['search_queries'] = point['search_keywords']
  86. if 'background' in point and 'original_text' not in point:
  87. point['original_text'] = point['background']
  88. return review_points
  89. def ai_respose_format(self, model_response):
  90. """
  91. 将模型返回的响应格式化为标准格式
  92. Args:
  93. model_response: AI模型返回的原始响应(可能是字符串或已解析的JSON)
  94. Returns:
  95. list: 标准格式的审查要点列表, 或 None(解析失败时)
  96. """
  97. import re
  98. import json
  99. try:
  100. # 1. 如果model_response已经是list,直接返回
  101. if isinstance(model_response, list):
  102. server_logger.info(f"模型响应已是list格式, 包含 {len(model_response)} 个要点")
  103. return model_response
  104. # 2. 如果是dict,包装成list返回
  105. if isinstance(model_response, dict):
  106. server_logger.info("模型响应是dict格式, 包装为list")
  107. return [model_response]
  108. # 3. 如果是字符串,需要解析
  109. if isinstance(model_response, str):
  110. response_text = model_response.strip()
  111. # 3.1 尝试去除 ```json 和 ``` 标记
  112. json_pattern = r'```(?:json)?\s*\n?(.*?)\n?```'
  113. json_match = re.search(json_pattern, response_text, re.DOTALL | re.IGNORECASE)
  114. if json_match:
  115. json_str = json_match.group(1).strip()
  116. else:
  117. json_str = response_text
  118. # 3.2 去除可能的Markdown注释或多余空白
  119. json_str = re.sub(r'\n+', '\n', json_str)
  120. json_str = json_str.strip()
  121. # 3.3 解析JSON
  122. parsed_data = json.loads(json_str)
  123. # 3.4 确保返回list格式
  124. if isinstance(parsed_data, list):
  125. server_logger.info(f"JSON解析成功, 提取到 {len(parsed_data)} 个审查要点")
  126. return parsed_data
  127. elif isinstance(parsed_data, dict):
  128. server_logger.info("JSON解析成功, 单个要点包装为list")
  129. return [parsed_data]
  130. server_logger.warning(f"无法识别的JSON格式: {type(parsed_data)}")
  131. return None
  132. server_logger.warning(f"无法识别的响应类型: {type(model_response)}")
  133. return None
  134. except json.JSONDecodeError as e:
  135. server_logger.error(f"JSON解析失败: {e}")
  136. server_logger.error(f"原始响应: {str(model_response)[:500]}")
  137. return None
  138. except Exception as e:
  139. server_logger.error(f"响应格式化异常: {e}")
  140. server_logger.error(f"原始响应: {str(model_response)[:500]}")
  141. return None
  # Module-level QueryRewriteManager instance, created when this module is imported.
  # NOTE(review): presumably the shared instance other modules import — confirm against callers.
  142. query_rewrite_manager = QueryRewriteManager()