rerank_model.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. 重排序执行模块
  5. 用于调用重排序模型进行文档重排序
  6. 支持的重排序模型:
  7. - BGE Reranker (本地部署)
  8. - Qwen3-Reranker-8B (本地部署)
  9. - Qwen3-Reranker-8B (硅基流动API)
  10. """
  11. import json
  12. import requests
  13. from typing import List, Dict, Any
  14. from foundation.infrastructure.config.config import config_handler
  15. from foundation.observability.logger.loggering import review_logger as server_logger
  16. class LqReranker:
  17. """
  18. 重排序执行器
  19. """
  20. def __init__(self):
  21. # BGE Reranker 配置
  22. self.bge_api_url = config_handler.get('bge_rerank_model', 'BGE_RERANKER_SERVER_URL')
  23. self.bge_model = config_handler.get('bge_rerank_model', 'BGE_RERANKER_MODEL')
  24. self.bge_top_k = int(config_handler.get('bge_rerank_model', 'BGE_RERANKER_TOP_N', 10))
  25. # 本地Qwen3-Reranker-8B配置
  26. self.lq_rerank_api_url = config_handler.get('lq_rerank_model', 'LQ_RERANKER_SERVER_URL')
  27. self.lq_rerank_model = config_handler.get('lq_rerank_model', 'LQ_RERANKER_MODEL')
  28. self.lq_rerank_top_k = int(config_handler.get('lq_rerank_model', 'LQ_RERANKER_TOP_N', 10))
  29. # SHUTIAN Qwen3-Reranker-8B配置(蜀天云算力 25426端口)
  30. self.shutian_rerank_api_url = config_handler.get('shutian', 'SHUTIAN_RERANK_SERVER_URL')
  31. self.shutian_rerank_model = config_handler.get('shutian', 'SHUTIAN_RERANK_MODEL_ID')
  32. self.shutian_rerank_api_key = config_handler.get('shutian', 'SHUTIAN_RERANK_API_KEY')
  33. # 硅基流动Qwen3-Reranker-8B配置
  34. self.silicoflow_rerank_api_url = config_handler.get('silicoflow_rerank_model', 'SILICOFLOW_RERANKER_API_URL', 'https://api.siliconflow.cn/v1/rerank')
  35. self.silicoflow_rerank_api_key = config_handler.get('silicoflow_rerank_model', 'SILICOFLOW_RERANKER_API_KEY')
  36. self.silicoflow_rerank_model = config_handler.get('silicoflow_rerank_model', 'SILICOFLOW_RERANKER_MODEL', 'Qwen/Qwen3-Reranker-8B')
  37. def bge_rerank(self,query: str, candidates: List[str],top_k :int = None) -> List[Dict[str, Any]]:
  38. """
  39. 执行重排序的全局函数
  40. Args:
  41. query: 查询文本
  42. candidates: 候选文档列表
  43. top_k: 调用时chaurnum参数,默认为None
  44. Returns:
  45. List[Dict]: 重排序后的结果列表
  46. """
  47. try:
  48. # self.top_k 是config.ini生产环境中实际使用的重排序数量,bge_rerank中的top_k,用于开发环境中快速效果调试
  49. if not top_k:# 如果开发top_k未指定,则使用配置文件中的top_k
  50. top_k = self.bge_top_k
  51. server_logger.info(f"开始执行重排序,查询: '{query}', 候选文档数量: {len(candidates)}")
  52. # 构建重排序请求
  53. rerank_request = {
  54. "model": self.bge_model,
  55. "query": query,
  56. "candidates": candidates
  57. }
  58. # 直接调用重排序API
  59. url = self.bge_api_url
  60. headers = {
  61. "Content-Type": "application/json"
  62. }
  63. server_logger.debug(f"调用重排序API: {url}")
  64. server_logger.debug(f"请求数据: {json.dumps(rerank_request, ensure_ascii=False)}")
  65. response = requests.post(url, headers=headers, json=rerank_request, timeout=30)
  66. if response.status_code == 200:
  67. result = response.json()
  68. server_logger.debug(f"API响应: {json.dumps(result, ensure_ascii=False)}")
  69. if "results" in result:
  70. return result["results"][:top_k]
  71. else:
  72. server_logger.warning(f"API响应格式异常: {result}")
  73. return []
  74. else:
  75. server_logger.error(f"API调用失败,状态码: {response.status_code}, 响应: {response.text}")
  76. return []
  77. except Exception as e:
  78. server_logger.error(f"执行重排序失败: {str(e)}")
  79. # 返回原始顺序作为fallback
  80. return [{"text": doc, "score": "0.0"} for doc in candidates[:top_k]]
  81. def lq_rerank(self, query: str, candidates: List[str], top_k: int = None) -> List[Dict[str, Any]]:
  82. """
  83. 使用本地部署的 Qwen3-Reranker-8B 进行重排序
  84. Args:
  85. query: 查询文本
  86. candidates: 候选文档列表
  87. top_k: 返回前k个结果,默认使用配置文件的top_k
  88. Returns:
  89. List[Dict[str, Any]]: 重排序后的结果列表
  90. [
  91. {
  92. "text": str, # 文档文本内容
  93. "score": float, # 相关性得分
  94. "index": int # 原始索引
  95. },
  96. ...
  97. ]
  98. """
  99. try:
  100. if not top_k:
  101. top_k = self.lq_rerank_top_k
  102. # 检查query是否为空
  103. if not query or not query.strip():
  104. server_logger.warning(f"本地Qwen3重排序跳过:query为空")
  105. return [{"text": doc, "score": 0.0} for doc in candidates[:top_k]]
  106. server_logger.info(f"开始执行本地Qwen3重排序,查询: '{query}', 候选文档数量: {len(candidates)}")
  107. # 定义变量(与测试脚本完全一致)
  108. url = self.lq_rerank_api_url
  109. prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
  110. suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
  111. query_template = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"
  112. document_template = "<Document>: {doc}{suffix}"
  113. instruction = (
  114. "请根据桥梁施工建设相关的查询内容,对文档进行重新排序,优先返回与桥梁施工、建设标准、技术规范、质量控制、安全管理等高度相关的文档。"
  115. )
  116. query = query_template.format(prefix=prefix, instruction=instruction, query=query)
  117. documents = [document_template.format(doc=doc, suffix=suffix) for doc in candidates]
  118. data = {
  119. "model": self.lq_rerank_model,
  120. "query": query,
  121. "documents": documents
  122. }
  123. headers = {"Content-Type": "application/json"}
  124. response = requests.post(url, headers=headers, json=data, timeout=30)
  125. if response.status_code == 200:
  126. result = response.json()
  127. if "results" in result:
  128. # 格式化结果:将嵌套的 document.text 提取到外层,并清理模板标记
  129. formatted_results = []
  130. for item in result["results"]:
  131. # 获取包含模板的原始文本
  132. raw_text = item.get("document", {}).get("text", "")
  133. # 清理模板标记:去除 <Document>: 和 <|im_end|>...assistant 之后的内容
  134. # 文本格式: <Document>: 原始内容<|im_end|>\n<|im_start|>assistant\n...
  135. if "<Document>:" in raw_text:
  136. # 提取 <Document>: 和 <|im_end|> 之间的内容
  137. start = raw_text.find("<Document>:") + len("<Document>:")
  138. end = raw_text.find("<|im_end|>")
  139. if end > start:
  140. cleaned_text = raw_text[start:end].strip()
  141. else:
  142. cleaned_text = raw_text[start:].strip()
  143. else:
  144. cleaned_text = raw_text
  145. formatted_results.append({
  146. "text": cleaned_text,
  147. "score": float(item.get("relevance_score", 0.0)),
  148. "index": item.get("index", 0)
  149. })
  150. server_logger.info(f"本地Qwen3 API响应: {formatted_results[:top_k]}")
  151. return formatted_results[:top_k]
  152. else:
  153. server_logger.warning(f"API响应格式异常: {result}")
  154. return []
  155. else:
  156. server_logger.error(f"API调用失败,状态码: {response.status_code}, 响应: {response.text}")
  157. return []
  158. except Exception as e:
  159. server_logger.error(f"执行本地Qwen3重排序失败: {str(e)}")
  160. return [{"text": doc, "score": 0.0} for doc in candidates[:top_k]]
  161. def shutian_rerank(self, query: str, candidates: List[str], top_k: int = None) -> List[Dict[str, Any]]:
  162. """
  163. 使用蜀天云算力部署的 Qwen3-Reranker-8B (端口25426) 进行重排序
  164. 接口为标准 OpenAI 兼容 rerank API,无需模板包装,直接传原始 query/documents
  165. """
  166. try:
  167. if not top_k:
  168. top_k = self.lq_rerank_top_k
  169. if not query or not query.strip():
  170. server_logger.warning("SHUTIAN重排序跳过:query为空")
  171. return [{"text": doc, "score": 0.0} for doc in candidates[:top_k]]
  172. server_logger.info(f"开始执行SHUTIAN Qwen3重排序,查询: '{query}', 候选文档数量: {len(candidates)}")
  173. data = {
  174. "model": self.shutian_rerank_model,
  175. "query": query,
  176. "documents": candidates,
  177. "top_n": top_k
  178. }
  179. headers = {
  180. "Content-Type": "application/json",
  181. "Authorization": f"Bearer {self.shutian_rerank_api_key}"
  182. }
  183. response = requests.post(self.shutian_rerank_api_url, headers=headers, json=data, timeout=30)
  184. if response.status_code == 200:
  185. result = response.json()
  186. # SHUTIAN API直接返回列表: [{"score": x, "document": "文本", "index": 0}, ...]
  187. results_list = result.get("results", result) if isinstance(result, dict) else result
  188. if isinstance(results_list, list) and results_list:
  189. formatted_results = []
  190. for item in results_list:
  191. doc = item.get("document", "")
  192. # document 可能是字符串或 {"text": "..."} 对象
  193. text = doc if isinstance(doc, str) else doc.get("text", "")
  194. formatted_results.append({
  195. "text": text,
  196. "score": float(item.get("relevance_score", item.get("score", 0.0))),
  197. "index": item.get("index", 0)
  198. })
  199. server_logger.info(f"SHUTIAN Qwen3重排序完成,返回 {len(formatted_results)} 个结果")
  200. return formatted_results[:top_k]
  201. else:
  202. server_logger.warning(f"SHUTIAN API响应格式异常: {result}")
  203. return []
  204. else:
  205. server_logger.error(f"SHUTIAN API调用失败,状态码: {response.status_code}, 响应: {response.text}")
  206. return []
  207. except Exception as e:
  208. server_logger.error(f"执行SHUTIAN Qwen3重排序失败: {str(e)}")
  209. return [{"text": doc, "score": 0.0} for doc in candidates[:top_k]]
  210. def qwen3_rerank(self, query: str, documents: List[str], top_k: int = None,
  211. instruction: str = "请根据桥梁施工建设相关的查询内容,对文档进行重新排序,优先返回与桥梁施工、建设标准、技术规范、质量控制、安全管理等高度相关的文档。") -> List[Dict[str, Any]]:
  212. """
  213. 使用硅基流动 Qwen3-Reranker-8B API 进行重排序
  214. Args:
  215. query: 查询文本
  216. documents: 文档列表
  217. top_k: 返回前k个结果,默认使用配置文件的top_k
  218. instruction: 重排序指令
  219. Returns:
  220. List[Dict]: 重排序后的结果列表,包含 text 和 score
  221. """
  222. try:
  223. if not top_k:
  224. top_k = 10 # 默认值
  225. if not self.silicoflow_rerank_api_key:
  226. server_logger.error("硅基流动 Reranker API Key 未配置")
  227. return []
  228. server_logger.info(f"开始执行硅基流动Qwen3重排序,查询: '{query}', 文档数量: {len(documents)}")
  229. # 构建请求数据
  230. request_data = {
  231. "model": self.silicoflow_rerank_model,
  232. "query": query,
  233. "documents": documents,
  234. "instruction": instruction,
  235. "top_n": top_k,
  236. "return_documents": True,
  237. # "max_chunks_per_doc": 123,
  238. # "overlap_tokens": 79
  239. }
  240. headers = {
  241. "Authorization": f"Bearer {self.silicoflow_rerank_api_key}",
  242. "Content-Type": "application/json"
  243. }
  244. server_logger.debug(f"调用硅基流动Qwen3 Reranker API: {self.silicoflow_rerank_api_url}")
  245. server_logger.debug(f"请求数据: {json.dumps(request_data, ensure_ascii=False)}")
  246. response = requests.post(
  247. self.silicoflow_rerank_api_url,
  248. headers=headers,
  249. json=request_data,
  250. timeout=30
  251. )
  252. if response.status_code == 200:
  253. result = response.json()
  254. server_logger.debug(f"硅基流动Qwen3 API响应: {json.dumps(result, ensure_ascii=False)}")
  255. if "results" in result:
  256. # 格式化结果为统一格式
  257. formatted_results = []
  258. for item in result["results"]:
  259. formatted_results.append({
  260. "text": item.get("document", {}).get("text", ""),
  261. "score": float(item.get("relevance_score", 0.0)),
  262. "index": item.get("index", 0)
  263. })
  264. return formatted_results[:top_k]
  265. else:
  266. server_logger.warning(f"API响应格式异常: {result}")
  267. return []
  268. else:
  269. server_logger.error(f"API调用失败,状态码: {response.status_code}, 响应: {response.text}")
  270. return []
  271. except Exception as e:
  272. server_logger.error(f"执行硅基流动Qwen3重排序失败: {str(e)}")
  273. # 返回原始顺序作为fallback
  274. return [{"text": doc, "score": 0.0} for doc in documents[:top_k]]
# Module-level shared reranker instance. Constructing it runs
# LqReranker.__init__, which reads the rerank settings from config_handler
# when this module is imported.
rerank_model = LqReranker()