test_rerank_model.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. import os
  2. import requests
  3. from dotenv import load_dotenv
  4. from langchain_core.documents import Document
  5. from typing import Sequence, List
  6. # 加载环境变量
  7. load_dotenv()
  8. # 读取配置
  9. CUSTOM_API_URL = os.getenv("CUSTOM_COHERE_BASE_URL", "http://192.168.91.253:9005/v1/rerank")
  10. RERANK_MODEL = os.getenv("RERANK_MODEL", "bge-reranker-v2-m3")
  11. TOP_N = 3
  12. # 校验配置
  13. if not CUSTOM_API_URL:
  14. raise ValueError("请在.env文件中配置 CUSTOM_COHERE_BASE_URL(你的 API 地址)")
  15. class CustomReranker:
  16. """自定义重排器,适配 API 格式:candidates 参数 + 响应返回 text+score"""
  17. def __init__(self, api_url: str, model: str, top_n: int = 3):
  18. self.api_url = api_url
  19. self.model = model
  20. self.top_n = top_n
  21. def compress_documents(
  22. self, documents: Sequence[Document], query: str
  23. ) -> Sequence[Document]:
  24. """重写压缩逻辑,适配自定义 API 格式"""
  25. # 1. 提取文档内容(转为字符串数组,适配 candidates 参数)
  26. candidates = [doc.page_content for doc in documents]
  27. if not candidates:
  28. return []
  29. # 2. 构造 API 请求体(用 candidates 而非 documents)
  30. request_body = {
  31. "model": self.model,
  32. "query": query,
  33. "candidates": candidates,
  34. "top_n": self.top_n
  35. }
  36. # 3. 发送请求到自定义 API
  37. headers = {"Content-Type": "application/json"}
  38. try:
  39. response = requests.post(
  40. self.api_url,
  41. json=request_body,
  42. headers=headers,
  43. timeout=30
  44. )
  45. response.raise_for_status()
  46. result = response.json()
  47. except requests.exceptions.RequestException as e:
  48. print(f"❌ API 请求失败:{str(e)}")
  49. if hasattr(e, 'response') and e.response is not None:
  50. print(f"响应内容:{e.response.text}")
  51. return []
  52. # 4. 解析响应(核心修复:按 text 匹配原始 Document)
  53. reranked_docs = []
  54. # 从响应中获取排序后的 text 和 score
  55. sorted_results = result.get("results", [])[:self.top_n] # 取 Top-N 结果
  56. for item in sorted_results:
  57. sorted_text = item.get("text") # 排序后的文档内容
  58. relevance_score = item.get("score", 0.0) # 相关性分数
  59. # 找到原始 Document(通过内容完全匹配,保留元数据)
  60. # 注意:若文档内容有重复,会匹配第一个,建议原始文档避免重复
  61. matched_doc = next(
  62. (doc for doc in documents if doc.page_content == sorted_text),
  63. None # 未匹配到返回 None
  64. )
  65. if matched_doc:
  66. # 将分数存入元数据,方便后续查看
  67. matched_doc.metadata["relevance_score"] = float(relevance_score)
  68. reranked_docs.append(matched_doc)
  69. else:
  70. print(f"⚠️ 未找到匹配的原始文档:{sorted_text}")
  71. return reranked_docs
  72. def rerank_bridge_construction_docs():
  73. """中文交通路桥施工文档重排示例"""
  74. # 1. 初始化自定义重排器
  75. reranker = CustomReranker(
  76. api_url=CUSTOM_API_URL,
  77. model=RERANK_MODEL,
  78. top_n=TOP_N
  79. )
  80. # 2. 路桥施工相关文档(模拟检索结果)
  81. construction_docs = [
  82. Document(
  83. page_content="大跨度桥梁挂篮施工时,需严格控制挂篮的变形量,变形值不得超过设计允许的5mm。挂篮前移应匀速缓慢,速度控制在0.5m/h以内。",
  84. metadata={"source": "桥梁施工规范2025", "section": "挂篮施工章节"}
  85. ),
  86. Document(
  87. page_content="高速公路路基填筑应分层碾压,每层厚度不超过30cm,压实度需达到96%以上。碾压设备优先选用重型振动压路机。",
  88. metadata={"source": "路基工程施工手册", "section": "路基碾压工艺"}
  89. ),
  90. Document(
  91. page_content="挂篮施工的关键工序包括:底模安装、钢筋绑扎、预应力筋布设、混凝土浇筑、预应力张拉。其中预应力张拉需在混凝土强度达到设计值的90%后进行。",
  92. metadata={"source": "大跨度桥梁施工技术指南", "section": "挂篮施工关键工序"}
  93. ),
  94. Document(
  95. page_content="桥梁桩基钻孔施工中,泥浆比重应根据地质情况调整,粉质黏土地层泥浆比重控制在1.1~1.2之间。",
  96. metadata={"source": "桩基施工操作规程", "section": "泥浆指标控制"}
  97. ),
  98. Document(
  99. page_content="挂篮的承重结构采用高强度钢材制作,需进行严格的荷载试验,试验荷载为设计荷载的1.2倍,持荷时间不少于24小时。",
  100. metadata={"source": "桥梁施工安全技术规范", "section": "挂篮荷载试验"}
  101. ),
  102. Document(
  103. page_content="隧道施工中应加强通风,保证洞内氧气浓度不低于19.5%,有害气体浓度符合国家标准。",
  104. metadata={"source": "隧道工程施工规范", "section": "洞内通风要求"}
  105. )
  106. ]
  107. # 3. 施工相关查询词
  108. query = "大跨度桥梁挂篮施工的关键工序和质量控制要点"
  109. # 4. 执行重排
  110. reranked_docs = reranker.compress_documents(documents=construction_docs, query=query)
  111. # 5. 输出结果
  112. print("\n" + "="*50)
  113. print(f"查询词:{query}")
  114. print("="*50)
  115. print(f"重排后 Top-{len(reranked_docs)} 相关文档:\n")
  116. for idx, doc in enumerate(reranked_docs, 1):
  117. print(f"【第{idx}篇】")
  118. print(f"来源:{doc.metadata['source']} - {doc.metadata['section']}")
  119. print(f"相关性分数:{doc.metadata.get('relevance_score', '未返回'):.6f}")
  120. print(f"内容:{doc.page_content}\n")
  121. if __name__ == "__main__":
  122. rerank_bridge_construction_docs()