rerank_model.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. 重排序执行模块
  5. 用于调用BGE重排序模型进行文档重排序
  6. """
  7. import json
  8. import requests
  9. from typing import List, Dict, Any
  10. from foundation.infrastructure.config.config import config_handler
  11. from foundation.observability.logger.loggering import server_logger
  12. class LqReranker:
  13. """
  14. 重排序执行器
  15. """
  16. def __init__(self):
  17. self.api_url = config_handler.get('rerank_model', 'BGE_RERANKER_SERVER_RUL')
  18. self.model = config_handler.get('rerank_model', 'BGE_RERANKER_MODEL_ID')
  19. # 确保top_k是整数类型,避免切片错误
  20. self.top_k = int(config_handler.get('rerank_model', 'BGE_RERANKER_TOP_N', 5))
  21. def bge_rerank(self,query: str, candidates: List[str],top_k :int = None) -> List[Dict[str, Any]]:
  22. """
  23. 执行重排序的全局函数
  24. Args:
  25. query: 查询文本
  26. candidates: 候选文档列表
  27. top_k: 调用时chaurnum参数,默认为None
  28. Returns:
  29. List[Dict]: 重排序后的结果列表
  30. """
  31. try:
  32. # self.top_k 是config.ini生产环境中实际使用的重排序数量,bge_rerank中的top_k,用于开发环境中快速效果调试
  33. if not top_k:# 如果开发top_k未指定,则使用配置文件中的top_k
  34. top_k = self.top_k
  35. server_logger.info(f"开始执行重排序,查询:, 候选文档数量: {len(candidates)}")
  36. # 构建重排序请求
  37. rerank_request = {
  38. "model": "bge-reranker-v2-m3",
  39. "query": query,
  40. "candidates": candidates
  41. }
  42. # 直接调用重排序API
  43. url = self.api_url
  44. headers = {
  45. "Content-Type": "application/json"
  46. }
  47. server_logger.debug(f"调用重排序API: {url}")
  48. server_logger.debug(f"请求数据: {json.dumps(rerank_request, ensure_ascii=False)}")
  49. response = requests.post(url, headers=headers, json=rerank_request, timeout=30)
  50. if response.status_code == 200:
  51. result = response.json()
  52. server_logger.debug(f"API响应: {json.dumps(result, ensure_ascii=False)}")
  53. if "results" in result:
  54. return result["results"][:top_k]
  55. else:
  56. server_logger.warning(f"API响应格式异常: {result}")
  57. return []
  58. else:
  59. server_logger.error(f"API调用失败,状态码: {response.status_code}, 响应: {response.text}")
  60. return []
  61. except Exception as e:
  62. server_logger.error(f"执行重排序失败: {str(e)}")
  63. # 返回原始顺序作为fallback
  64. return [{"text": doc, "score": "0.0"} for doc in candidates[:top_k]]
  65. rerank_model = LqReranker()