silicon_flow.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. import os
  2. import sys
  3. sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
  4. import requests
  5. from dotenv import load_dotenv
  6. from foundation.models.base_online_platform import BaseApiPlatform
  7. from foundation.base.config import config_handler
  8. from foundation.logger.loggering import server_logger
  9. from foundation.utils.common import handler_err
  10. from openai import OpenAI
  11. from langchain_core.embeddings import Embeddings
  12. #from chromadb.utils.embedding_functions import EmbeddingFunction
  13. from typing import List
  14. import numpy as np
  15. class SiliconFlowEmbeddings(Embeddings):
  16. """
  17. LangChain 兼容的硅基流动嵌入模型客户端
  18. 使用方式:
  19. embeddings = SiliconFlowEmbeddings(
  20. model="netease-youdao/bce-embedding-base_v1",
  21. api_key="sk-..."
  22. )
  23. vectors = embeddings.embed_documents(["文本1", "文本2"])
  24. """
  25. def __init__(self, base_url: str, api_key: str, embed_model_id: str):
  26. self.model = embed_model_id
  27. self.api_key = api_key
  28. if not self.api_key:
  29. raise ValueError("必须提供 api_key 或设置环境变量 SILICONFLOW_API_KEY")
  30. self.client = OpenAI(
  31. api_key=self.api_key,
  32. base_url=base_url
  33. )
  34. def embed_documents(self, texts: List[str]) -> List[List[float]]:
  35. """对文档列表进行向量化"""
  36. if not texts:
  37. return []
  38. response = self.client.embeddings.create(
  39. model=self.model,
  40. input=texts
  41. )
  42. return [item.embedding for item in response.data]
  43. def embed_query(self, text: str) -> List[float]:
  44. """对查询文本进行向量化"""
  45. return self.embed_documents([text])[0]
  46. class SiliconFlowAPI(BaseApiPlatform):
  47. def __init__(self , trace_id=""):
  48. self.trace_id = trace_id
  49. self.config_prefix = "siliconflow"
  50. self.model_server_url = config_handler.get(self.config_prefix, "SLCF_MODEL_SERVER_URL")
  51. self.api_key = config_handler.get(self.config_prefix, "SLCF_API_KEY")
  52. self.embed_url = self.model_server_url +"/embeddings" #/embeddings
  53. self.rerank_url = self.model_server_url +"/rerank" #/rerank
  54. self.embed_model_id = config_handler.get(self.config_prefix, "SLCF_EMBED_MODEL_ID")
  55. self.rerank_model_id = config_handler.get(self.config_prefix, "SLCF_REANKER_MODEL_ID")
  56. server_logger.info(f"SiliconFlowAPI -> embed_url:{self.embed_url},rerank_url:{self.rerank_url}")
  57. server_logger.info(f"SiliconFlowAPI -> embed_model_id:{self.embed_model_id},rerank_model_id:{self.rerank_model_id}")
  58. self.client = self.get_openai_client(self.model_server_url, self.api_key)
  59. # 创建LangChain兼容的嵌入对象
  60. langchain_embeddings = SiliconFlowEmbeddings(base_url = self.model_server_url , api_key=self.api_key , embed_model_id=self.embed_model_id)
  61. #self.embed_model = ChromaSiliconFlowEmbedding(embeddings=langchain_embeddings)
  62. def get_embed_model(self):
  63. """
  64. 获取嵌入模型
  65. """
  66. return self.embed_model
  67. def get_embeddings(self, texts: list[str]):
  68. """获取文本向量(embedding)"""
  69. try:
  70. response = self.client.embeddings.create(
  71. model=self.embed_model_id, # 指定向量模型
  72. input=texts if isinstance(texts, list) else [texts]
  73. )
  74. # 返回 embeddings 列表
  75. return [data.embedding for data in response.data]
  76. except Exception as e:
  77. handler_err(server_logger, trace_id=self.trace_id, err=e, err_name='Embedding 调用失败')
  78. raise
  79. def rerank(self, input_query: str, documents: list, top_n: int = 5, return_documents: bool = True):
  80. """
  81. 使用 BGE 重排序模型进行相关性打分
  82. 使用重排序模型对候选文档进行排序
  83. :param query: 用户查询语句
  84. :param documents: 候选文本列表
  85. :param top_n: 返回前 N 个结果
  86. :return: 排序后的结果列表,包含文本和相似度分数
  87. """
  88. try:
  89. headers = {
  90. "Authorization": f"Bearer {self.api_key}",
  91. "Content-Type": "application/json"
  92. }
  93. payload = {
  94. "model": self.rerank_model_id,
  95. "query": input_query,
  96. "documents": documents,
  97. "top_n": top_n,
  98. "return_documents": return_documents
  99. }
  100. response = requests.post(self.rerank_url, json=payload, headers=headers)
  101. response.raise_for_status()
  102. data = response.json()
  103. results = []
  104. for item in data['results']:
  105. results.append({
  106. "index": item['index'],
  107. "relevance_score": item['relevance_score'],
  108. "document": item.get('document', {}).get('text', None)
  109. })
  110. return results
  111. except Exception as e:
  112. handler_err(server_logger, trace_id=self.trace_id, err=e, err_name='重排序调用失败')
  113. raise
  114. # 使用示例
  115. if __name__ == "__main__":
  116. # 初始化客户端(需提前设置环境变量 SILICONFLOW_API_KEY)
  117. client = SiliconFlowAPI()
  118. # 示例1:向量化文本
  119. texts = ["奶牛养殖技术", "牛肉市场价格分析"]
  120. embeddings = client.get_embeddings(texts)
  121. print(f"向量维度:{len(embeddings[0])}") # 输出向量维度
  122. # 示例2:重排序文档
  123. query = "如何提高牛奶产量?"
  124. documents = [
  125. "奶牛饲料配比指南",
  126. "牧场管理规范",
  127. "牛奶加工工艺流程",
  128. "提高产奶量的10个技巧"
  129. ]
  130. rerank_results = client.rerank(query, documents)
  131. print("\n重排序结果:")
  132. for result in sorted(rerank_results, key=lambda x: x['relevance_score'], reverse=True):
  133. print(f"{result['index']} (得分: {result['relevance_score']:.2f}): {documents[result['index']]}")