| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144 |
import os
from collections import defaultdict, deque
from typing import Sequence, List

import requests
from dotenv import load_dotenv
from langchain_core.documents import Document
# Load environment variables (python-dotenv) before reading any setting.
load_dotenv()

# Rerank service settings; environment values override the defaults below.
CUSTOM_API_URL = os.environ.get(
    "CUSTOM_COHERE_BASE_URL",
    "http://192.168.91.253:9005/v1/rerank",
)
RERANK_MODEL = os.environ.get("RERANK_MODEL", "bge-reranker-v2-m3")
TOP_N = 3

# The default above only applies when the variable is unset; an explicitly
# empty value still reaches this point, so fail fast on it.
if not CUSTOM_API_URL:
    raise ValueError("请在.env文件中配置 CUSTOM_COHERE_BASE_URL(你的 API 地址)")
class CustomReranker:
    """Rerank documents through a custom HTTP rerank endpoint.

    The request body sent to the endpoint carries ``model``, ``query``,
    ``candidates`` (a list of document texts) and ``top_n``. The response is
    read as ``{"results": [{"text": ..., "score": ...}, ...]}``; the order of
    ``results`` is kept as returned (presumably most relevant first -- confirm
    against the service).
    """

    def __init__(
        self,
        api_url: str,
        model: str,
        top_n: int = 3,
        *,
        timeout: float = 30,
        session=None,
    ):
        """Store the endpoint settings.

        Args:
            api_url: Full URL of the rerank endpoint.
            model: Model name sent in the request body.
            top_n: Maximum number of documents to return.
            timeout: Per-request timeout in seconds passed to ``post``.
            session: Optional object exposing a ``requests``-compatible
                ``post`` method (e.g. ``requests.Session``) used instead of
                the module-level ``requests.post``.
        """
        self.api_url = api_url
        self.model = model
        self.top_n = top_n
        self.timeout = timeout
        self.session = session

    def compress_documents(
        self, documents: Sequence["Document"], query: str
    ) -> List["Document"]:
        """Return up to ``top_n`` of ``documents`` in the order the API ranks them.

        Each returned document is one of the objects passed in (metadata is
        preserved) and gets ``metadata["relevance_score"]`` set in place to
        the score reported by the API.

        Args:
            documents: Candidate documents; matched back by exact
                ``page_content`` equality.
            query: Query text the candidates are ranked against.

        Returns:
            The matched documents, or ``[]`` when there are no candidates,
            the request fails, or the response is malformed.
        """
        candidates = [doc.page_content for doc in documents]
        if not candidates:
            return []

        results = self._fetch_results(query, candidates)
        return self._match_results(documents, results)

    def _fetch_results(self, query: str, candidates: List[str]) -> list:
        """POST the rerank request and return its raw ``results`` list ([] on failure)."""
        payload = {
            "model": self.model,
            "query": query,
            "candidates": candidates,
            "top_n": self.top_n,
        }
        http = requests if self.session is None else self.session
        try:
            response = http.post(
                self.api_url,
                json=payload,
                headers={"Content-Type": "application/json"},
                timeout=self.timeout,
            )
            response.raise_for_status()
            body = response.json()
        except (requests.exceptions.RequestException, ValueError) as exc:
            # ValueError covers invalid JSON on requests versions whose
            # decode error is not a RequestException subclass.
            print(f"❌ API 请求失败:{exc}")
            failed_response = getattr(exc, "response", None)
            if failed_response is not None:
                print(f"响应内容:{failed_response.text}")
            return []

        # Guard against a non-object body or a null/non-list "results" field.
        raw_results = body.get("results", []) if isinstance(body, dict) else None
        if not isinstance(raw_results, list):
            print("⚠️ 响应格式异常:未包含 results 列表")
            return []
        return raw_results

    def _match_results(
        self, documents: Sequence["Document"], results: list
    ) -> List["Document"]:
        """Map ranked ``results`` back to the original documents by text."""
        # Queue documents per distinct text so duplicated content resolves to
        # distinct Document objects, in input order, rather than always the
        # first one.
        pending = defaultdict(deque)
        for doc in documents:
            pending[doc.page_content].append(doc)

        reranked_docs = []
        for item in results[: self.top_n]:
            text = item.get("text") if isinstance(item, dict) else None
            queue = pending.get(text) if isinstance(text, str) else None
            if not queue:
                print(f"⚠️ 未找到匹配的原始文档:{text}")
                continue

            try:
                score = float(item.get("score", 0.0))
            except (TypeError, ValueError):
                # Null or non-numeric score: fall back to the same default
                # used when the field is absent.
                score = 0.0

            matched_doc = queue.popleft()
            matched_doc.metadata["relevance_score"] = score
            reranked_docs.append(matched_doc)
        return reranked_docs
def rerank_bridge_construction_docs():
    """Demo: rerank Chinese road/bridge construction documents for one query.

    Builds a ``CustomReranker`` from the module-level settings, reranks a
    fixed list of sample documents against a fixed query, and prints the
    ranked documents with their source, section and relevance score.
    """
    # 1. Build the reranker from the module-level configuration.
    reranker = CustomReranker(
        api_url=CUSTOM_API_URL,
        model=RERANK_MODEL,
        top_n=TOP_N,
    )

    # 2. Sample construction documents (stand-in for retrieval results).
    construction_docs = [
        Document(
            page_content="大跨度桥梁挂篮施工时,需严格控制挂篮的变形量,变形值不得超过设计允许的5mm。挂篮前移应匀速缓慢,速度控制在0.5m/h以内。",
            metadata={"source": "桥梁施工规范2025", "section": "挂篮施工章节"},
        ),
        Document(
            page_content="高速公路路基填筑应分层碾压,每层厚度不超过30cm,压实度需达到96%以上。碾压设备优先选用重型振动压路机。",
            metadata={"source": "路基工程施工手册", "section": "路基碾压工艺"},
        ),
        Document(
            page_content="挂篮施工的关键工序包括:底模安装、钢筋绑扎、预应力筋布设、混凝土浇筑、预应力张拉。其中预应力张拉需在混凝土强度达到设计值的90%后进行。",
            metadata={"source": "大跨度桥梁施工技术指南", "section": "挂篮施工关键工序"},
        ),
        Document(
            page_content="桥梁桩基钻孔施工中,泥浆比重应根据地质情况调整,粉质黏土地层泥浆比重控制在1.1~1.2之间。",
            metadata={"source": "桩基施工操作规程", "section": "泥浆指标控制"},
        ),
        Document(
            page_content="挂篮的承重结构采用高强度钢材制作,需进行严格的荷载试验,试验荷载为设计荷载的1.2倍,持荷时间不少于24小时。",
            metadata={"source": "桥梁施工安全技术规范", "section": "挂篮荷载试验"},
        ),
        Document(
            page_content="隧道施工中应加强通风,保证洞内氧气浓度不低于19.5%,有害气体浓度符合国家标准。",
            metadata={"source": "隧道工程施工规范", "section": "洞内通风要求"},
        ),
    ]

    # 3. Query to rank the documents against.
    query = "大跨度桥梁挂篮施工的关键工序和质量控制要点"

    # 4. Rerank.
    reranked_docs = reranker.compress_documents(documents=construction_docs, query=query)

    # 5. Print the ranked documents.
    print("\n" + "=" * 50)
    print(f"查询词:{query}")
    print("=" * 50)
    print(f"重排后 Top-{len(reranked_docs)} 相关文档:\n")
    for idx, doc in enumerate(reranked_docs, 1):
        # Apply the float format only to a numeric score; formatting the
        # "not returned" placeholder string with ".6f" raises ValueError.
        score = doc.metadata.get("relevance_score")
        if isinstance(score, (int, float)):
            score_text = f"{score:.6f}"
        else:
            score_text = "未返回"

        print(f"【第{idx}篇】")
        print(f"来源:{doc.metadata['source']} - {doc.metadata['section']}")
        print(f"相关性分数:{score_text}")
        print(f"内容:{doc.page_content}\n")
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    rerank_bridge_construction_docs()
|