import os import requests from dotenv import load_dotenv from langchain_core.documents import Document from typing import Sequence, List # 加载环境变量 load_dotenv() # 读取配置 CUSTOM_API_URL = os.getenv("CUSTOM_COHERE_BASE_URL", "http://192.168.91.253:9005/v1/rerank") RERANK_MODEL = os.getenv("RERANK_MODEL", "bge-reranker-v2-m3") TOP_N = 3 # 校验配置 if not CUSTOM_API_URL: raise ValueError("请在.env文件中配置 CUSTOM_COHERE_BASE_URL(你的 API 地址)") class CustomReranker: """自定义重排器,适配 API 格式:candidates 参数 + 响应返回 text+score""" def __init__(self, api_url: str, model: str, top_n: int = 3): self.api_url = api_url self.model = model self.top_n = top_n def compress_documents( self, documents: Sequence[Document], query: str ) -> Sequence[Document]: """重写压缩逻辑,适配自定义 API 格式""" # 1. 提取文档内容(转为字符串数组,适配 candidates 参数) candidates = [doc.page_content for doc in documents] if not candidates: return [] # 2. 构造 API 请求体(用 candidates 而非 documents) request_body = { "model": self.model, "query": query, "candidates": candidates, "top_n": self.top_n } # 3. 发送请求到自定义 API headers = {"Content-Type": "application/json"} try: response = requests.post( self.api_url, json=request_body, headers=headers, timeout=30 ) response.raise_for_status() result = response.json() except requests.exceptions.RequestException as e: print(f"❌ API 请求失败:{str(e)}") if hasattr(e, 'response') and e.response is not None: print(f"响应内容:{e.response.text}") return [] # 4. 解析响应(核心修复:按 text 匹配原始 Document) reranked_docs = [] # 从响应中获取排序后的 text 和 score sorted_results = result.get("results", [])[:self.top_n] # 取 Top-N 结果 for item in sorted_results: sorted_text = item.get("text") # 排序后的文档内容 relevance_score = item.get("score", 0.0) # 相关性分数 # 找到原始 Document(通过内容完全匹配,保留元数据) # 注意:若文档内容有重复,会匹配第一个,建议原始文档避免重复 matched_doc = next( (doc for doc in documents if doc.page_content == sorted_text), None # 未匹配到返回 None ) if matched_doc: # 将分数存入元数据,方便后续查看 matched_doc.metadata["relevance_score"] = float(relevance_score) reranked_docs.append(matched_doc) else: print(f"⚠️ 未找到匹配的原始文档:{sorted_text}") return reranked_docs def rerank_bridge_construction_docs(): """中文交通路桥施工文档重排示例""" # 1. 初始化自定义重排器 reranker = CustomReranker( api_url=CUSTOM_API_URL, model=RERANK_MODEL, top_n=TOP_N ) # 2. 路桥施工相关文档(模拟检索结果) construction_docs = [ Document( page_content="大跨度桥梁挂篮施工时,需严格控制挂篮的变形量,变形值不得超过设计允许的5mm。挂篮前移应匀速缓慢,速度控制在0.5m/h以内。", metadata={"source": "桥梁施工规范2025", "section": "挂篮施工章节"} ), Document( page_content="高速公路路基填筑应分层碾压,每层厚度不超过30cm,压实度需达到96%以上。碾压设备优先选用重型振动压路机。", metadata={"source": "路基工程施工手册", "section": "路基碾压工艺"} ), Document( page_content="挂篮施工的关键工序包括:底模安装、钢筋绑扎、预应力筋布设、混凝土浇筑、预应力张拉。其中预应力张拉需在混凝土强度达到设计值的90%后进行。", metadata={"source": "大跨度桥梁施工技术指南", "section": "挂篮施工关键工序"} ), Document( page_content="桥梁桩基钻孔施工中,泥浆比重应根据地质情况调整,粉质黏土地层泥浆比重控制在1.1~1.2之间。", metadata={"source": "桩基施工操作规程", "section": "泥浆指标控制"} ), Document( page_content="挂篮的承重结构采用高强度钢材制作,需进行严格的荷载试验,试验荷载为设计荷载的1.2倍,持荷时间不少于24小时。", metadata={"source": "桥梁施工安全技术规范", "section": "挂篮荷载试验"} ), Document( page_content="隧道施工中应加强通风,保证洞内氧气浓度不低于19.5%,有害气体浓度符合国家标准。", metadata={"source": "隧道工程施工规范", "section": "洞内通风要求"} ) ] # 3. 施工相关查询词 query = "大跨度桥梁挂篮施工的关键工序和质量控制要点" # 4. 执行重排 reranked_docs = reranker.compress_documents(documents=construction_docs, query=query) # 5. 输出结果 print("\n" + "="*50) print(f"查询词:{query}") print("="*50) print(f"重排后 Top-{len(reranked_docs)} 相关文档:\n") for idx, doc in enumerate(reranked_docs, 1): print(f"【第{idx}篇】") print(f"来源:{doc.metadata['source']} - {doc.metadata['section']}") print(f"相关性分数:{doc.metadata.get('relevance_score', '未返回'):.6f}") print(f"内容:{doc.page_content}\n") if __name__ == "__main__": rerank_bridge_construction_docs()