|
|
@@ -1,26 +1,102 @@
|
|
|
-"""OpenAI Embeddings 配置和客户端获取。"""
|
|
|
+"""MetaX Embeddings 配置和客户端获取。"""
|
|
|
from __future__ import annotations
|
|
|
|
|
|
from functools import lru_cache
|
|
|
+import time
|
|
|
+from typing import List
|
|
|
|
|
|
-from langchain_openai import OpenAIEmbeddings
|
|
|
+import requests
|
|
|
+from requests import HTTPError
|
|
|
+from requests import RequestException
|
|
|
|
|
|
from .setting import settings
|
|
|
|
|
|
|
|
|
class MetaXEmbeddings:
    """Requests-based embeddings client exposing ``embed_query`` / ``embed_documents``.

    Texts are POSTed as JSON (``{"input": [...], "model": ...}``) to
    ``<base_url>/embeddings`` with a bearer token. Inputs are sent in batches
    and every batch is retried with a linearly growing delay before the
    failure is surfaced as a ``RuntimeError``.
    """

    def __init__(
        self,
        base_url: str,
        model: str,
        api_key: str,
        timeout: int = 60,
        *,
        max_retries: int = 3,
        retry_delay_seconds: float = 1.5,
        max_batch_size: int = 16,
    ) -> None:
        """Configure the client.

        Args:
            base_url: Service root; a trailing ``/`` and a trailing
                ``/embeddings`` segment are stripped, so both ``.../v1`` and
                ``.../v1/embeddings`` are accepted.
            model: Model name sent in every request payload.
            api_key: Bearer token; a falsy value is replaced by ``"dummy"``.
            timeout: Per-request timeout (seconds) passed to ``requests``.
            max_retries: Total attempts per batch (must be >= 1).
            retry_delay_seconds: Base delay; attempt ``k`` waits
                ``retry_delay_seconds * k`` before the next try (must be >= 0).
            max_batch_size: Maximum number of texts per request (must be >= 1).

        Raises:
            ValueError: If a retry/batch setting is out of range.
        """
        # Validate up front: zero attempts would silently drop a batch and a
        # zero batch size would make ``range`` raise deep inside a request.
        if max_retries < 1:
            raise ValueError("max_retries must be at least 1")
        if max_batch_size < 1:
            raise ValueError("max_batch_size must be at least 1")
        if retry_delay_seconds < 0:
            raise ValueError("retry_delay_seconds must not be negative")

        cleaned_base_url = base_url.rstrip("/")
        # Allow the setting to be the full .../v1/embeddings URL; normalise to the base.
        if cleaned_base_url.endswith("/embeddings"):
            cleaned_base_url = cleaned_base_url[: -len("/embeddings")]
        self.base_url = cleaned_base_url
        self.model = model
        self.api_key = api_key or "dummy"
        self.timeout = timeout
        self.endpoint = f"{self.base_url}/embeddings"
        self.session = requests.Session()
        self.max_retries = max_retries
        self.retry_delay_seconds = retry_delay_seconds
        self.max_batch_size = max_batch_size

    def _post_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Send one batch and return its vectors in the order of ``texts``.

        Raises:
            requests.HTTPError: On a non-2xx response (via ``raise_for_status``).
            RuntimeError: If the number of returned items differs from ``len(texts)``.
        """
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        payload = {"input": texts, "model": self.model}
        resp = self.session.post(self.endpoint, headers=headers, json=payload, timeout=self.timeout)
        resp.raise_for_status()

        body = resp.json()
        # A non-object body or a null "data" is treated as "no embeddings" so
        # it is reported by the count check below instead of an AttributeError.
        data = (body.get("data") if isinstance(body, dict) else None) or []
        if len(data) != len(texts):
            raise RuntimeError(f"Expected {len(texts)} embeddings but got {len(data)}")

        # OpenAI-style responses tag each item with the position of its input.
        # When every item carries an integer "index", honour it rather than
        # trusting the array order; otherwise keep the order as received.
        if all(isinstance(item, dict) and isinstance(item.get("index"), int) for item in data):
            data = sorted(data, key=lambda item: item["index"])
        return [item["embedding"] for item in data]

    def _request_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` batch by batch, retrying each batch on request errors.

        Returns:
            One vector per input text, in input order (``[]`` for no input).

        Raises:
            RuntimeError: When a batch still fails after ``max_retries``
                attempts; the underlying ``requests`` exception is chained as
                ``__cause__``.
        """
        if not texts:
            return []

        vectors: List[List[float]] = []
        for start in range(0, len(texts), self.max_batch_size):
            chunk = texts[start : start + self.max_batch_size]
            failure = None
            cause = None
            for attempt in range(1, self.max_retries + 1):
                try:
                    batch_vectors = self._post_embeddings(chunk)
                except HTTPError as exc:
                    response = exc.response
                    status = response.status_code if response is not None else "unknown"
                    body = ""
                    if response is not None and response.text:
                        # Keep the backend's key error text but cap it so logs stay short.
                        body = response.text.strip().replace("\n", " ")[:500]
                    failure = RuntimeError(
                        "Embedding request failed: "
                        f"endpoint={self.endpoint}, model={self.model}, status={status}, body={body!r}"
                    )
                    # ``exc`` is unbound when the except block ends; keep it for chaining.
                    cause = exc
                except RequestException as exc:
                    failure = RuntimeError(
                        "Embedding request failed: "
                        f"endpoint={self.endpoint}, model={self.model}, error={exc!r}"
                    )
                    cause = exc
                else:
                    vectors.extend(batch_vectors)
                    failure = None
                    break
                if attempt < self.max_retries:
                    # Linear back-off: 1x, 2x, ... the base delay.
                    time.sleep(self.retry_delay_seconds * attempt)
            if failure is not None:
                # Chain the original exception so the traceback shows the root cause.
                raise failure from cause
        return vectors

    def _embed(self, text: str) -> List[float]:
        """Embed a single text; an empty text yields ``[]`` without a request."""
        if not text:
            return []
        return self._request_embeddings([text])[0]

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding vector for one query string."""
        return self._embed(text)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per document, in input order."""
        return self._request_embeddings(texts)
|
|
|
+
|
|
|
+
|
|
|
@lru_cache(maxsize=1)
def get_embeddings() -> MetaXEmbeddings:
    """Build the MetaX embeddings client once and hand back the cached instance.

    Raises:
        ValueError: If ``settings.EMBEDDING_BASE_URL`` or
            ``settings.EMBEDDING_MODEL`` is empty or unset.
    """
    # Both the endpoint and the model name are mandatory; the API key is
    # passed through as-is.
    if not (settings.EMBEDDING_BASE_URL and settings.EMBEDDING_MODEL):
        raise ValueError("Embedding configuration is incomplete")

    client_kwargs = {
        "base_url": settings.EMBEDDING_BASE_URL,
        "model": settings.EMBEDDING_MODEL,
        "api_key": settings.EMBEDDING_API_KEY,
    }
    return MetaXEmbeddings(**client_kwargs)
|