rerank_model.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. 重排序执行模块
  5. 用于调用重排序模型进行文档重排序
  6. 支持的重排序模型:
  7. - BGE Reranker (本地部署)
  8. - Qwen3-Reranker-8B (本地部署)
  9. - Qwen3-Reranker-8B (蜀天算力)
  10. - Qwen3-Reranker-8B (硅基流动API)
  11. 配置加载策略: 懒加载(首次调用时从 config.ini 读取该后端的凭证并缓存)
  12. 路由决策: 由 retrieval.py 通过 model_setting.yaml 的 rerank 功能决定使用哪个后端
  13. """
  14. import json
  15. import requests
  16. from typing import List, Dict, Any, Optional
  17. from foundation.infrastructure.config.config import config_handler
  18. from foundation.observability.logger.loggering import review_logger as server_logger
  19. class LqReranker:
  20. """
  21. 重排序执行器
  22. 各后端配置按需加载:首次调用某后端时才从 config.ini 读取其凭证,
  23. 避免初始化时加载所有 4 个后端的配置。
  24. """
  25. def __init__(self):
  26. # 各后端配置缓存(首次调用时加载)
  27. self._bge_config: Optional[Dict[str, Any]] = None
  28. self._lq_config: Optional[Dict[str, Any]] = None
  29. self._shutian_config: Optional[Dict[str, Any]] = None
  30. self._silicoflow_config: Optional[Dict[str, Any]] = None
  31. def _get_bge_config(self) -> Dict[str, Any]:
  32. """懒加载 BGE Reranker 配置"""
  33. if self._bge_config is None:
  34. self._bge_config = {
  35. 'api_url': config_handler.get('bge_rerank_model', 'BGE_RERANKER_SERVER_URL'),
  36. 'model': config_handler.get('bge_rerank_model', 'BGE_RERANKER_MODEL'),
  37. 'top_k': int(config_handler.get('bge_rerank_model', 'BGE_RERANKER_TOP_N', 10)),
  38. }
  39. return self._bge_config
  40. def _get_lq_config(self) -> Dict[str, Any]:
  41. """懒加载本地 Qwen3-Reranker 配置"""
  42. if self._lq_config is None:
  43. self._lq_config = {
  44. 'api_url': config_handler.get('lq_rerank_model', 'LQ_RERANKER_SERVER_URL'),
  45. 'model': config_handler.get('lq_rerank_model', 'LQ_RERANKER_MODEL'),
  46. 'top_k': int(config_handler.get('lq_rerank_model', 'LQ_RERANKER_TOP_N', 10)),
  47. }
  48. return self._lq_config
  49. def _get_shutian_config(self) -> Dict[str, Any]:
  50. """懒加载蜀天 Qwen3-Reranker 配置"""
  51. if self._shutian_config is None:
  52. self._shutian_config = {
  53. 'api_url': config_handler.get('shutian', 'SHUTIAN_RERANK_SERVER_URL'),
  54. 'model': config_handler.get('shutian', 'SHUTIAN_RERANK_MODEL_ID'),
  55. 'api_key': config_handler.get('shutian', 'SHUTIAN_RERANK_API_KEY'),
  56. }
  57. return self._shutian_config
  58. def _get_silicoflow_config(self) -> Dict[str, Any]:
  59. """懒加载硅基流动 Qwen3-Reranker 配置"""
  60. if self._silicoflow_config is None:
  61. self._silicoflow_config = {
  62. 'api_url': config_handler.get('silicoflow_rerank_model', 'SILICOFLOW_RERANKER_API_URL',
  63. 'https://api.siliconflow.cn/v1/rerank'),
  64. 'api_key': config_handler.get('silicoflow_rerank_model', 'SILICOFLOW_RERANKER_API_KEY'),
  65. 'model': config_handler.get('silicoflow_rerank_model', 'SILICOFLOW_RERANKER_MODEL',
  66. 'Qwen/Qwen3-Reranker-8B'),
  67. }
  68. return self._silicoflow_config
  69. def bge_rerank(self, query: str, candidates: List[str], top_k: int = None) -> List[Dict[str, Any]]:
  70. """
  71. 使用本地 BGE-reranker-v2-m3 进行重排序
  72. Args:
  73. query: 查询文本
  74. candidates: 候选文档列表
  75. top_k: 返回前k个结果,默认使用配置文件的top_k
  76. Returns:
  77. List[Dict]: 重排序后的结果列表
  78. """
  79. try:
  80. cfg = self._get_bge_config()
  81. if not top_k:
  82. top_k = cfg['top_k']
  83. server_logger.info(f"开始执行重排序,查询: '{query}', 候选文档数量: {len(candidates)}")
  84. rerank_request = {
  85. "model": cfg['model'],
  86. "query": query,
  87. "documents": candidates
  88. }
  89. headers = {"Content-Type": "application/json"}
  90. server_logger.debug(f"调用重排序API: {cfg['api_url']}")
  91. server_logger.debug(f"请求数据: {json.dumps(rerank_request, ensure_ascii=False)}")
  92. response = requests.post(cfg['api_url'], headers=headers, json=rerank_request, timeout=30)
  93. if response.status_code == 200:
  94. result = response.json()
  95. server_logger.debug(f"API响应: {json.dumps(result, ensure_ascii=False)}")
  96. if "results" in result:
  97. return result["results"][:top_k]
  98. else:
  99. server_logger.warning(f"API响应格式异常: {result}")
  100. return []
  101. else:
  102. server_logger.error(f"API调用失败,状态码: {response.status_code}, 响应: {response.text}")
  103. return []
  104. except Exception as e:
  105. server_logger.error(f"执行重排序失败: {str(e)}")
  106. return [{"text": doc, "score": "0.0"} for doc in candidates[:top_k]]
  107. def lq_rerank(self, query: str, candidates: List[str], top_k: int = None) -> List[Dict[str, Any]]:
  108. """
  109. 使用本地部署的 Qwen3-Reranker-8B 进行重排序
  110. Args:
  111. query: 查询文本
  112. candidates: 候选文档列表
  113. top_k: 返回前k个结果,默认使用配置文件的top_k
  114. Returns:
  115. List[Dict[str, Any]]: 重排序后的结果列表
  116. """
  117. try:
  118. cfg = self._get_lq_config()
  119. if not top_k:
  120. top_k = cfg['top_k']
  121. if not query or not query.strip():
  122. server_logger.warning(f"本地Qwen3重排序跳过:query为空")
  123. return [{"text": doc, "score": 0.0} for doc in candidates[:top_k]]
  124. server_logger.info(f"开始执行本地Qwen3重排序,查询: '{query}', 候选文档数量: {len(candidates)}")
  125. url = cfg['api_url']
  126. prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
  127. suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
  128. query_template = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"
  129. document_template = "<Document>: {doc}{suffix}"
  130. instruction = (
  131. "请根据桥梁施工建设相关的查询内容,对文档进行重新排序,优先返回与桥梁施工、建设标准、技术规范、质量控制、安全管理等高度相关的文档。"
  132. )
  133. query = query_template.format(prefix=prefix, instruction=instruction, query=query)
  134. documents = [document_template.format(doc=doc, suffix=suffix) for doc in candidates]
  135. data = {
  136. "model": cfg['model'],
  137. "query": query,
  138. "documents": documents
  139. }
  140. headers = {"Content-Type": "application/json"}
  141. response = requests.post(url, headers=headers, json=data, timeout=30)
  142. if response.status_code == 200:
  143. result = response.json()
  144. if "results" in result:
  145. formatted_results = []
  146. for item in result["results"]:
  147. raw_text = item.get("document", {}).get("text", "")
  148. if "<Document>:" in raw_text:
  149. start = raw_text.find("<Document>:") + len("<Document>:")
  150. end = raw_text.find("<|im_end|>")
  151. if end > start:
  152. cleaned_text = raw_text[start:end].strip()
  153. else:
  154. cleaned_text = raw_text[start:].strip()
  155. else:
  156. cleaned_text = raw_text
  157. formatted_results.append({
  158. "text": cleaned_text,
  159. "score": float(item.get("relevance_score", 0.0)),
  160. "index": item.get("index", 0)
  161. })
  162. server_logger.info(f"本地Qwen3 API响应: {formatted_results[:top_k]}")
  163. return formatted_results[:top_k]
  164. else:
  165. server_logger.warning(f"API响应格式异常: {result}")
  166. return []
  167. else:
  168. server_logger.error(f"API调用失败,状态码: {response.status_code}, 响应: {response.text}")
  169. return []
  170. except Exception as e:
  171. server_logger.error(f"执行本地Qwen3重排序失败: {str(e)}")
  172. return [{"text": doc, "score": 0.0} for doc in candidates[:top_k]]
  173. def shutian_rerank(self, query: str, candidates: List[str], top_k: int = None) -> List[Dict[str, Any]]:
  174. """
  175. 使用蜀天云算力部署的 Qwen3-Reranker-8B (端口25426) 进行重排序
  176. 接口为标准 OpenAI 兼容 rerank API,无需模板包装,直接传原始 query/documents
  177. """
  178. try:
  179. cfg = self._get_shutian_config()
  180. if not top_k:
  181. top_k = self._get_lq_config()['top_k']
  182. if not query or not query.strip():
  183. server_logger.warning("SHUTIAN重排序跳过:query为空")
  184. return [{"text": doc, "score": 0.0} for doc in candidates[:top_k]]
  185. server_logger.info(f"开始执行SHUTIAN Qwen3重排序,查询: '{query}', 候选文档数量: {len(candidates)}")
  186. data = {
  187. "model": cfg['model'],
  188. "query": query,
  189. "documents": candidates,
  190. "top_n": top_k
  191. }
  192. headers = {
  193. "Content-Type": "application/json",
  194. "Authorization": f"Bearer {cfg['api_key']}"
  195. }
  196. response = requests.post(cfg['api_url'], headers=headers, json=data, timeout=30)
  197. if response.status_code == 200:
  198. result = response.json()
  199. results_list = result.get("results", result) if isinstance(result, dict) else result
  200. if isinstance(results_list, list) and results_list:
  201. formatted_results = []
  202. for item in results_list:
  203. doc = item.get("document", "")
  204. text = doc if isinstance(doc, str) else doc.get("text", "")
  205. formatted_results.append({
  206. "text": text,
  207. "score": float(item.get("relevance_score", item.get("score", 0.0))),
  208. "index": item.get("index", 0)
  209. })
  210. server_logger.info(f"SHUTIAN Qwen3重排序完成,返回 {len(formatted_results)} 个结果")
  211. return formatted_results[:top_k]
  212. else:
  213. server_logger.warning(f"SHUTIAN API响应格式异常: {result}")
  214. return []
  215. else:
  216. server_logger.error(f"SHUTIAN API调用失败,状态码: {response.status_code}, 响应: {response.text}")
  217. return []
  218. except Exception as e:
  219. server_logger.error(f"执行SHUTIAN Qwen3重排序失败: {str(e)}")
  220. return [{"text": doc, "score": 0.0} for doc in candidates[:top_k]]
  221. def qwen3_rerank(self, query: str, documents: List[str], top_k: int = None,
  222. instruction: str = "请根据桥梁施工建设相关的查询内容,对文档进行重新排序,优先返回与桥梁施工、建设标准、技术规范、质量控制、安全管理等高度相关的文档。") -> List[Dict[str, Any]]:
  223. """
  224. 使用硅基流动 Qwen3-Reranker-8B API 进行重排序
  225. Args:
  226. query: 查询文本
  227. documents: 文档列表
  228. top_k: 返回前k个结果,默认10
  229. instruction: 重排序指令
  230. Returns:
  231. List[Dict]: 重排序后的结果列表,包含 text 和 score
  232. """
  233. try:
  234. cfg = self._get_silicoflow_config()
  235. if not top_k:
  236. top_k = 10
  237. if not cfg['api_key']:
  238. server_logger.error("硅基流动 Reranker API Key 未配置")
  239. return []
  240. server_logger.info(f"开始执行硅基流动Qwen3重排序,查询: '{query}', 文档数量: {len(documents)}")
  241. request_data = {
  242. "model": cfg['model'],
  243. "query": query,
  244. "documents": documents,
  245. "instruction": instruction,
  246. "top_n": top_k,
  247. "return_documents": True,
  248. }
  249. headers = {
  250. "Authorization": f"Bearer {cfg['api_key']}",
  251. "Content-Type": "application/json"
  252. }
  253. server_logger.debug(f"调用硅基流动Qwen3 Reranker API: {cfg['api_url']}")
  254. server_logger.debug(f"请求数据: {json.dumps(request_data, ensure_ascii=False)}")
  255. response = requests.post(
  256. cfg['api_url'],
  257. headers=headers,
  258. json=request_data,
  259. timeout=30
  260. )
  261. if response.status_code == 200:
  262. result = response.json()
  263. server_logger.debug(f"硅基流动Qwen3 API响应: {json.dumps(result, ensure_ascii=False)}")
  264. if "results" in result:
  265. formatted_results = []
  266. for item in result["results"]:
  267. formatted_results.append({
  268. "text": item.get("document", {}).get("text", ""),
  269. "score": float(item.get("relevance_score", 0.0)),
  270. "index": item.get("index", 0)
  271. })
  272. return formatted_results[:top_k]
  273. else:
  274. server_logger.warning(f"API响应格式异常: {result}")
  275. return []
  276. else:
  277. server_logger.error(f"API调用失败,状态码: {response.status_code}, 响应: {response.text}")
  278. return []
  279. except Exception as e:
  280. server_logger.error(f"执行硅基流动Qwen3重排序失败: {str(e)}")
  281. return [{"text": doc, "score": 0.0} for doc in documents[:top_k]]
  282. rerank_model = LqReranker()