silicon_flow.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. import os
  2. import sys
  3. sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
  4. import requests
  5. from dotenv import load_dotenv
  6. from foundation.models.base_online_platform import BaseApiPlatform
  7. from foundation.base.config import config_handler
  8. from foundation.logger.loggering import server_logger
  9. from foundation.utils.common import handler_err
  10. from openai import OpenAI
  11. from langchain_core.embeddings import Embeddings
  12. from chromadb.utils.embedding_functions import EmbeddingFunction
  13. from typing import List
  14. import numpy as np
  15. class SiliconFlowEmbeddings(Embeddings):
  16. """
  17. LangChain 兼容的硅基流动嵌入模型客户端
  18. 使用方式:
  19. embeddings = SiliconFlowEmbeddings(
  20. model="netease-youdao/bce-embedding-base_v1",
  21. api_key="sk-..."
  22. )
  23. vectors = embeddings.embed_documents(["文本1", "文本2"])
  24. """
  25. def __init__(self, base_url: str, api_key: str, embed_model_id: str):
  26. self.model = embed_model_id
  27. self.api_key = api_key
  28. if not self.api_key:
  29. raise ValueError("必须提供 api_key 或设置环境变量 SILICONFLOW_API_KEY")
  30. self.client = OpenAI(
  31. api_key=self.api_key,
  32. base_url=base_url
  33. )
  34. def embed_documents(self, texts: List[str]) -> List[List[float]]:
  35. """对文档列表进行向量化"""
  36. if not texts:
  37. return []
  38. response = self.client.embeddings.create(
  39. model=self.model,
  40. input=texts
  41. )
  42. return [item.embedding for item in response.data]
  43. def embed_query(self, text: str) -> List[float]:
  44. """对查询文本进行向量化"""
  45. return self.embed_documents([text])[0]
  46. class ChromaSiliconFlowEmbedding(EmbeddingFunction):
  47. """
  48. 将SiliconFlowEmbeddings适配到ChromaDB的嵌入函数接口
  49. """
  50. def __init__(self, embeddings):
  51. self.embeddings = embeddings
  52. def __call__(self, input: List[str]) -> List[List[float]]:
  53. raw_embeddings = self.embeddings.embed_documents(input) # 关键添加
  54. return self.normalized_embeddings(raw_embeddings)
  55. def embed_documents(self, input: List[str]) -> List[List[float]]:
  56. raw_embeddings = self.embeddings.embed_documents(input) # 关键添加
  57. return self.normalized_embeddings(raw_embeddings)
  58. def embed_query(self, text: str) -> List[float]:
  59. """对查询文本进行向量化"""
  60. raw_embeddings = self.embeddings.embed_documents([text])[0]
  61. return self.normalized_embeddings(raw_embeddings)
  62. def normalized_embeddings(self , raw_embeddings):
  63. # L2归一化处理
  64. normalized = []
  65. for vector in raw_embeddings:
  66. norm = np.linalg.norm(vector)
  67. if norm > 0:
  68. normalized.append(vector / norm)
  69. else:
  70. normalized.append(vector)
  71. return normalized
  72. class SiliconFlowAPI(BaseApiPlatform):
  73. def __init__(self , trace_id=""):
  74. self.trace_id = trace_id
  75. self.config_prefix = "siliconflow"
  76. self.model_server_url = config_handler.get(self.config_prefix, "SLCF_MODEL_SERVER_URL")
  77. self.api_key = config_handler.get(self.config_prefix, "SLCF_API_KEY")
  78. self.embed_url = self.model_server_url +"/embeddings" #/embeddings
  79. self.rerank_url = self.model_server_url +"/rerank" #/rerank
  80. self.embed_model_id = config_handler.get(self.config_prefix, "SLCF_EMBED_MODEL_ID")
  81. self.rerank_model_id = config_handler.get(self.config_prefix, "SLCF_REANKER_MODEL_ID")
  82. server_logger.info(f"SiliconFlowAPI -> embed_url:{self.embed_url},rerank_url:{self.rerank_url}")
  83. server_logger.info(f"SiliconFlowAPI -> embed_model_id:{self.embed_model_id},rerank_model_id:{self.rerank_model_id}")
  84. self.client = self.get_openai_client(self.model_server_url, self.api_key)
  85. # 创建LangChain兼容的嵌入对象
  86. langchain_embeddings = SiliconFlowEmbeddings(base_url = self.model_server_url , api_key=self.api_key , embed_model_id=self.embed_model_id)
  87. self.embed_model = ChromaSiliconFlowEmbedding(embeddings=langchain_embeddings)
  88. def get_embed_model(self):
  89. """
  90. 获取嵌入模型
  91. """
  92. return self.embed_model
  93. def get_embeddings(self, texts: list[str]):
  94. """获取文本向量(embedding)"""
  95. try:
  96. response = self.client.embeddings.create(
  97. model=self.embed_model_id, # 指定向量模型
  98. input=texts if isinstance(texts, list) else [texts]
  99. )
  100. # 返回 embeddings 列表
  101. return [data.embedding for data in response.data]
  102. except Exception as e:
  103. handler_err(server_logger, trace_id=self.trace_id, err=e, err_name='Embedding 调用失败')
  104. raise
  105. def rerank(self, input_query: str, documents: list, top_n: int = 5, return_documents: bool = True):
  106. """
  107. 使用 BGE 重排序模型进行相关性打分
  108. 使用重排序模型对候选文档进行排序
  109. :param query: 用户查询语句
  110. :param documents: 候选文本列表
  111. :param top_n: 返回前 N 个结果
  112. :return: 排序后的结果列表,包含文本和相似度分数
  113. """
  114. try:
  115. headers = {
  116. "Authorization": f"Bearer {self.api_key}",
  117. "Content-Type": "application/json"
  118. }
  119. payload = {
  120. "model": self.rerank_model_id,
  121. "query": input_query,
  122. "documents": documents,
  123. "top_n": top_n,
  124. "return_documents": return_documents
  125. }
  126. response = requests.post(self.rerank_url, json=payload, headers=headers)
  127. response.raise_for_status()
  128. data = response.json()
  129. results = []
  130. for item in data['results']:
  131. results.append({
  132. "index": item['index'],
  133. "relevance_score": item['relevance_score'],
  134. "document": item.get('document', {}).get('text', None)
  135. })
  136. return results
  137. except Exception as e:
  138. handler_err(server_logger, trace_id=self.trace_id, err=e, err_name='重排序调用失败')
  139. raise
  140. # 使用示例
  141. if __name__ == "__main__":
  142. # 初始化客户端(需提前设置环境变量 SILICONFLOW_API_KEY)
  143. client = SiliconFlowAPI()
  144. # 示例1:向量化文本
  145. texts = ["奶牛养殖技术", "牛肉市场价格分析"]
  146. embeddings = client.get_embeddings(texts)
  147. print(f"向量维度:{len(embeddings[0])}") # 输出向量维度
  148. # 示例2:重排序文档
  149. query = "如何提高牛奶产量?"
  150. documents = [
  151. "奶牛饲料配比指南",
  152. "牧场管理规范",
  153. "牛奶加工工艺流程",
  154. "提高产奶量的10个技巧"
  155. ]
  156. rerank_results = client.rerank(query, documents)
  157. print("\n重排序结果:")
  158. for result in sorted(rerank_results, key=lambda x: x['relevance_score'], reverse=True):
  159. print(f"{result['index']} (得分: {result['relevance_score']:.2f}): {documents[result['index']]}")