rerank_model.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. 重排序执行模块
  5. 用于调用重排序模型进行文档重排序
  6. 支持的重排序模型:
  7. - BGE Reranker (本地部署)
  8. - Qwen3-Reranker-8B (硅基流动API)
  9. """
  10. import json
  11. import requests
  12. from typing import List, Dict, Any
  13. from foundation.infrastructure.config.config import config_handler
  14. from foundation.observability.logger.loggering import server_logger
  15. class LqReranker:
  16. """
  17. 重排序执行器
  18. """
  19. def __init__(self):
  20. self.api_url = config_handler.get('rerank_model', 'BGE_RERANKER_SERVER_RUL')
  21. self.model = config_handler.get('rerank_model', 'BGE_RERANKER_MODEL_ID')
  22. # 确保top_k是整数类型,避免切片错误
  23. self.top_k = int(config_handler.get('rerank_model', 'BGE_RERANKER_TOP_N', 3))
  24. # Qwen3-Reranker-8B 配置
  25. self.qwen_api_url = config_handler.get('rerank_model_qwen', 'QWEN_RERANKER_API_URL', 'https://api.siliconflow.cn/v1/rerank')
  26. self.qwen_api_key = config_handler.get('rerank_model_qwen', 'QWEN_RERANKER_API_KEY')
  27. self.qwen_model = config_handler.get('rerank_model_qwen', 'QWEN_RERANKER_MODEL', 'Qwen/Qwen3-Reranker-8B')
  28. def bge_rerank(self,query: str, candidates: List[str],top_k :int = None) -> List[Dict[str, Any]]:
  29. """
  30. 执行重排序的全局函数
  31. Args:
  32. query: 查询文本
  33. candidates: 候选文档列表
  34. top_k: 调用时chaurnum参数,默认为None
  35. Returns:
  36. List[Dict]: 重排序后的结果列表
  37. """
  38. try:
  39. # self.top_k 是config.ini生产环境中实际使用的重排序数量,bge_rerank中的top_k,用于开发环境中快速效果调试
  40. if not top_k:# 如果开发top_k未指定,则使用配置文件中的top_k
  41. top_k = self.top_k
  42. server_logger.info(f"开始执行重排序,查询: '{query}', 候选文档数量: {len(candidates)}")
  43. # 构建重排序请求
  44. rerank_request = {
  45. "model": "bge-reranker-v2-m3",
  46. "query": query,
  47. "candidates": candidates
  48. }
  49. # 直接调用重排序API
  50. url = self.api_url
  51. headers = {
  52. "Content-Type": "application/json"
  53. }
  54. server_logger.debug(f"调用重排序API: {url}")
  55. server_logger.debug(f"请求数据: {json.dumps(rerank_request, ensure_ascii=False)}")
  56. response = requests.post(url, headers=headers, json=rerank_request, timeout=30)
  57. if response.status_code == 200:
  58. result = response.json()
  59. server_logger.debug(f"API响应: {json.dumps(result, ensure_ascii=False)}")
  60. if "results" in result:
  61. return result["results"][:top_k]
  62. else:
  63. server_logger.warning(f"API响应格式异常: {result}")
  64. return []
  65. else:
  66. server_logger.error(f"API调用失败,状态码: {response.status_code}, 响应: {response.text}")
  67. return []
  68. except Exception as e:
  69. server_logger.error(f"执行重排序失败: {str(e)}")
  70. # 返回原始顺序作为fallback
  71. return [{"text": doc, "score": "0.0"} for doc in candidates[:top_k]]
  72. def qwen3_rerank(self, query: str, documents: List[str], top_k: int = None,
  73. instruction: str = "请根据桥梁施工建设相关的查询内容,对文档进行重新排序,优先返回与桥梁施工、建设标准、技术规范、质量控制、安全管理等高度相关的文档。") -> List[Dict[str, Any]]:
  74. """
  75. 使用 Qwen3-Reranker-8B 进行重排序
  76. Args:
  77. query: 查询文本
  78. documents: 文档列表
  79. top_k: 返回前k个结果,默认使用配置文件的top_k
  80. instruction: 重排序指令
  81. Returns:
  82. List[Dict]: 重排序后的结果列表,包含 text 和 score
  83. """
  84. try:
  85. if not top_k:
  86. top_k = self.top_k
  87. if not self.qwen_api_key:
  88. server_logger.error("Qwen Reranker API Key 未配置")
  89. return []
  90. server_logger.info(f"开始执行Qwen3重排序,查询: '{query}', 文档数量: {len(documents)}")
  91. # 构建请求数据
  92. request_data = {
  93. "model": self.qwen_model,
  94. "query": query,
  95. "documents": documents,
  96. "instruction": instruction,
  97. "top_n": top_k,
  98. "return_documents": True,
  99. "max_chunks_per_doc": 123,
  100. "overlap_tokens": 79
  101. }
  102. headers = {
  103. "Authorization": f"Bearer {self.qwen_api_key}",
  104. "Content-Type": "application/json"
  105. }
  106. server_logger.debug(f"调用Qwen3 Reranker API: {self.qwen_api_url}")
  107. server_logger.debug(f"请求数据: {json.dumps(request_data, ensure_ascii=False)}")
  108. response = requests.post(
  109. self.qwen_api_url,
  110. headers=headers,
  111. json=request_data,
  112. timeout=30
  113. )
  114. if response.status_code == 200:
  115. result = response.json()
  116. server_logger.debug(f"Qwen3 API响应: {json.dumps(result, ensure_ascii=False)}")
  117. if "results" in result:
  118. # 格式化结果为统一格式
  119. formatted_results = []
  120. for item in result["results"]:
  121. formatted_results.append({
  122. "text": item.get("document", {}).get("text", ""),
  123. "score": float(item.get("relevance_score", 0.0)),
  124. "index": item.get("index", 0)
  125. })
  126. return formatted_results[:top_k]
  127. else:
  128. server_logger.warning(f"Qwen3 API响应格式异常: {result}")
  129. return []
  130. else:
  131. server_logger.error(f"Qwen3 API调用失败,状态码: {response.status_code}, 响应: {response.text}")
  132. return []
  133. except Exception as e:
  134. server_logger.error(f"执行Qwen3重排序失败: {str(e)}")
  135. # 返回原始顺序作为fallback
  136. return [{"text": doc, "score": 0.0, "index": i} for i, doc in enumerate(documents[:top_k])]
  137. rerank_model = LqReranker()