| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187 |
- import uuid
- import asyncio
- from foundation.observability.logger.loggering import server_logger
- from foundation.ai.agent.generate.model_generate import generate_model_client
- class QueryRewriteManager():
- """
- 召回管理器,实现多路召回功能
- """
- def __init__(self):
- # 获取部署的模型列表
- self.generate_model_client = generate_model_client
- @property
- def prompt_loader(self):
- """延迟加载 prompt_loader,避免循环导入"""
- from core.construction_review.component.reviewers.utils import prompt_loader
- return prompt_loader
- def query_extract(self, review_content):
- """
- 从审查条文中提取query
- Args:
- review_content: 审查内容文本
- Returns:
- list: 标准格式的查询列表
- [
- {
- "entity": str, # 实体名称
- "search_keywords": list, # 搜索关键词列表
- "background": str, # 背景信息
- "parameter": str # 技术参数
- }
- ]
- 或 None(提取失败时)
- """
- try:
- # 获取提示词模板并组装
- task_prompt = self.prompt_loader.get_prompt_template(
- reviewer_type="query_extract", # 审查器类型
- review_content=review_content # 传入审查内容作为参数
- )
- # 构建任务提示信息 - 参考标准模式
- task_prompt_info = {
- "task_prompt": task_prompt, # 使用组装好的提示词
- "task_name": "query_extract"
- }
- # 生成唯一的trace_id用于追踪
- trace_id = str(uuid.uuid4())
- # 调用模型生成接口(处理异步调用)
- try:
- loop = asyncio.get_running_loop()
- # 如果已有运行中的事件循环,使用create_task
- import concurrent.futures
- with concurrent.futures.ThreadPoolExecutor() as executor:
- future = executor.submit(
- asyncio.run,
- self.generate_model_client.get_model_generate_invoke(
- trace_id=trace_id,
- task_prompt_info=task_prompt_info,
- timeout=60,
- model_name="qwen3_30b" # 修复: 使用正确的模型名称(下划线)
- )
- )
- model_response = future.result()
- except RuntimeError:
- # 没有运行中的事件循环,直接使用asyncio.run
- model_response = asyncio.run(self.generate_model_client.get_model_generate_invoke(
- trace_id=trace_id,
- task_prompt_info=task_prompt_info,
- timeout=60,
- model_name="qwen3_30b" # 修复: 使用正确的模型名称(下划线)
- ))
- # 格式化模型响应
- formatted_response = self.ai_respose_format(model_response)
- # 检查 formatted_response 是否为 None
- if formatted_response is not None:
- server_logger.info(f"查询对构建完成,构建 {len(formatted_response)}条。")
- else:
- server_logger.warning("查询对构建失败,formatted_response 为 None")
- # 记录日志
- if formatted_response:
- server_logger.info(f"Query 提取成功, 提取到 {len(formatted_response)} 个实体")
- else:
- server_logger.warning(f"Query 提取失败, 格式化后为空")
- return formatted_response
- except Exception as e:
- server_logger.error(f"Query 提取失败: {str(e)}")
- return None
-
- def ai_respose_format(self, model_response):
- """
- 将模型返回的响应格式化为标准格式
- Args:
- model_response: AI模型返回的原始响应(可能是字符串或已解析的JSON)
- Returns:
- list: 标准格式的查询列表
- [
- {
- "entity": str, # 实体名称
- "search_keywords": list, # 搜索关键词列表
- "background": str, # 背景信息
- "parameter": str # 技术参数
- }
- ]
- 或 None(解析失败时)
- """
- import re
- import json
- try:
- # 1. 如果model_response已经是list,直接返回
- if isinstance(model_response, list):
- server_logger.info(f"模型响应已是list格式, 包含 {len(model_response)} 个实体")
- return model_response
- # 2. 如果是dict,包装成list返回
- if isinstance(model_response, dict):
- server_logger.info("模型响应是dict格式, 包装为list")
- return [model_response]
- # 3. 如果是字符串,需要解析
- if isinstance(model_response, str):
- response_text = model_response.strip()
- server_logger.debug(f"原始响应字符串长度: {len(response_text)}")
- # 3.1 尝试去除 ```json 和 ``` 标记
- # 匹配 ```json ... ``` 或 ``` ... ```
- json_pattern = r'```(?:json)?\s*\n?(.*?)\n?```'
- json_match = re.search(json_pattern, response_text, re.DOTALL | re.IGNORECASE)
- if json_match:
- # 提取代码块中的JSON内容
- json_str = json_match.group(1).strip()
- server_logger.debug("检测到markdown代码块, 已提取纯JSON内容")
- else:
- # 如果没有代码块标记,尝试直接解析整个字符串
- json_str = response_text
- server_logger.debug("未检测到markdown代码块, 尝试直接解析")
- # 3.2 去除可能的Markdown注释或多余空白
- json_str = re.sub(r'\n+', '\n', json_str) # 多个换行压缩为一个
- json_str = json_str.strip()
- server_logger.debug(f"待解析的JSON字符串: {json_str[:200]}...")
- # 3.3 解析JSON
- parsed_data = json.loads(json_str)
- # 3.4 确保返回list格式
- if isinstance(parsed_data, list):
- server_logger.info(f"JSON解析成功, 提取到 {len(parsed_data)} 个实体")
- return parsed_data
- elif isinstance(parsed_data, dict):
- server_logger.info("JSON解析成功, 单个实体包装为list")
- return [parsed_data]
- server_logger.warning(f"无法识别的JSON格式: {type(parsed_data)}")
- return None
- server_logger.warning(f"无法识别的响应类型: {type(model_response)}")
- return None
- except json.JSONDecodeError as e:
- server_logger.error(f"JSON解析失败: {e}")
- server_logger.error(f"原始响应: {str(model_response)[:500]}")
- return None
- except Exception as e:
- server_logger.error(f"响应格式化异常: {e}")
- server_logger.error(f"原始响应: {str(model_response)[:500]}")
- return None
- query_rewrite_manager = QueryRewriteManager()
|