query_rewrite.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. import uuid
  2. import asyncio
  3. from foundation.infrastructure.config.config import config_handler
  4. from foundation.observability.logger.loggering import server_logger
  5. from foundation.ai.agent.generate.model_generate import generate_model_client
  6. from foundation.database.base.vector.milvus_vector import MilvusVectorManager
  7. from core.construction_review.component.reviewers.utils import prompt_loader
  8. class QueryRewriteManager():
  9. """
  10. 召回管理器,实现多路召回功能
  11. """
  12. def __init__(self):
  13. # 获取部署的模型列表
  14. self.generate_model_client = generate_model_client
  15. self.prompt_loader = prompt_loader
  16. def query_extract(self, review_content):
  17. """
  18. 从审查条文中提取query
  19. return:
  20. query: str
  21. background: str
  22. parameters: str
  23. """
  24. try:
  25. # 获取提示词模板并组装
  26. task_prompt = self.prompt_loader.get_prompt_template(
  27. reviewer_type="query_extract", # 审查器类型
  28. review_content=review_content # 传入审查内容作为参数
  29. )
  30. # 构建任务提示信息 - 参考标准模式
  31. task_prompt_info = {
  32. "task_prompt": task_prompt, # 使用组装好的提示词
  33. "task_name": "query_extract"
  34. }
  35. # 生成唯一的trace_id用于追踪
  36. trace_id = str(uuid.uuid4())
  37. # 调用模型生成接口(使用异步运行)
  38. model_response = asyncio.run(self.generate_model_client.get_model_generate_invoke(
  39. trace_id=trace_id,
  40. task_prompt_info=task_prompt_info
  41. ))
  42. # 记录日志
  43. server_logger.info(f"Query 提取完成长度: {len(review_content)}")
  44. return model_response
  45. except Exception as e:
  46. server_logger.error(f"Query 提取失败: {str(e)}")
  47. return None