|
@@ -6,6 +6,7 @@ Embedding 客户端
|
|
|
统一通过 model_handler 获取 Embedding 模型,配置从 config.ini 读取
|
|
统一通过 model_handler 获取 Embedding 模型,配置从 config.ini 读取
|
|
|
"""
|
|
"""
|
|
|
|
|
|
|
|
|
|
+import asyncio
|
|
|
import math
|
|
import math
|
|
|
import re
|
|
import re
|
|
|
from typing import List, Optional, Tuple
|
|
from typing import List, Optional, Tuple
|
|
@@ -18,10 +19,15 @@ from foundation.observability.logger.loggering import review_logger as logger
|
|
|
class EmbeddingClient:
|
|
class EmbeddingClient:
|
|
|
"""Embedding模型客户端,用于计算文本相似度"""
|
|
"""Embedding模型客户端,用于计算文本相似度"""
|
|
|
|
|
|
|
|
|
|
+ # 连续失败次数阈值,超过后清除缓存触发降级
|
|
|
|
|
+ _FAILURE_THRESHOLD = 3
|
|
|
|
|
+ # 重试次数
|
|
|
|
|
+ _MAX_RETRIES = 2
|
|
|
|
|
+
|
|
|
def __init__(self):
|
|
def __init__(self):
|
|
|
"""初始化 Embedding 客户端,通过 model_handler 获取模型"""
|
|
"""初始化 Embedding 客户端,通过 model_handler 获取模型"""
|
|
|
- # 统一通过 model_handler 获取 Embedding 模型
|
|
|
|
|
self._embedding_model = None
|
|
self._embedding_model = None
|
|
|
|
|
+ self._consecutive_failures = 0
|
|
|
|
|
|
|
|
@property
|
|
@property
|
|
|
def embedding_model(self):
|
|
def embedding_model(self):
|
|
@@ -30,25 +36,51 @@ class EmbeddingClient:
|
|
|
self._embedding_model = model_handler.get_embedding_model()
|
|
self._embedding_model = model_handler.get_embedding_model()
|
|
|
return self._embedding_model
|
|
return self._embedding_model
|
|
|
|
|
|
|
|
|
|
+ def _invalidate_cache(self):
|
|
|
|
|
+ """清除本地和 model_handler 的 embedding 缓存,触发降级重新初始化"""
|
|
|
|
|
+ self._embedding_model = None
|
|
|
|
|
+ self._consecutive_failures = 0
|
|
|
|
|
+ # 清除 model_handler 中的 embedding 缓存,使下次 get_embedding_model 重新走初始化+降级逻辑
|
|
|
|
|
+ for key in list(model_handler._model_cache.keys()):
|
|
|
|
|
+ if "embed" in key.lower():
|
|
|
|
|
+ del model_handler._model_cache[key]
|
|
|
|
|
+ logger.info(f"已清除 model_handler embedding 缓存: {key}")
|
|
|
|
|
+
|
|
|
async def get_embedding(self, text: str) -> Optional[List[float]]:
|
|
async def get_embedding(self, text: str) -> Optional[List[float]]:
|
|
|
- """获取文本的embedding向量"""
|
|
|
|
|
- try:
|
|
|
|
|
- # 使用 model_handler 提供的 embedding 模型
|
|
|
|
|
- embedding = self.embedding_model.embed_query(text)
|
|
|
|
|
- return embedding
|
|
|
|
|
- except Exception as e:
|
|
|
|
|
- logger.error(f"Embedding API调用失败: {e}")
|
|
|
|
|
- return None
|
|
|
|
|
|
|
+ """获取文本的embedding向量,带重试和缓存失效机制"""
|
|
|
|
|
+ for attempt in range(self._MAX_RETRIES + 1):
|
|
|
|
|
+ try:
|
|
|
|
|
+ embedding = self.embedding_model.embed_query(text)
|
|
|
|
|
+ self._consecutive_failures = 0
|
|
|
|
|
+ return embedding
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ if attempt < self._MAX_RETRIES:
|
|
|
|
|
+ await asyncio.sleep(1 * (attempt + 1))
|
|
|
|
|
+ continue
|
|
|
|
|
+ self._consecutive_failures += 1
|
|
|
|
|
+ logger.error(f"Embedding API调用失败 (连续第{self._consecutive_failures}次): {e}")
|
|
|
|
|
+ if self._consecutive_failures >= self._FAILURE_THRESHOLD:
|
|
|
|
|
+ logger.warning("Embedding连续失败超过阈值,清除缓存触发降级")
|
|
|
|
|
+ self._invalidate_cache()
|
|
|
|
|
+ return None
|
|
|
|
|
|
|
|
async def get_embeddings_batch(self, texts: List[str]) -> List[Optional[List[float]]]:
|
|
async def get_embeddings_batch(self, texts: List[str]) -> List[Optional[List[float]]]:
|
|
|
- """批量获取文本的embedding向量"""
|
|
|
|
|
- try:
|
|
|
|
|
- # 使用 model_handler 提供的 embedding 模型
|
|
|
|
|
- embeddings = self.embedding_model.embed_documents(texts)
|
|
|
|
|
- return embeddings
|
|
|
|
|
- except Exception as e:
|
|
|
|
|
- logger.error(f"Embedding API批量调用失败: {e}")
|
|
|
|
|
- return [None] * len(texts)
|
|
|
|
|
|
|
+ """批量获取文本的embedding向量,带重试和缓存失效机制"""
|
|
|
|
|
+ for attempt in range(self._MAX_RETRIES + 1):
|
|
|
|
|
+ try:
|
|
|
|
|
+ embeddings = self.embedding_model.embed_documents(texts)
|
|
|
|
|
+ self._consecutive_failures = 0
|
|
|
|
|
+ return embeddings
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ if attempt < self._MAX_RETRIES:
|
|
|
|
|
+ await asyncio.sleep(1 * (attempt + 1))
|
|
|
|
|
+ continue
|
|
|
|
|
+ self._consecutive_failures += 1
|
|
|
|
|
+ logger.error(f"Embedding API批量调用失败 (连续第{self._consecutive_failures}次): {e}")
|
|
|
|
|
+ if self._consecutive_failures >= self._FAILURE_THRESHOLD:
|
|
|
|
|
+ logger.warning("Embedding连续失败超过阈值,清除缓存触发降级")
|
|
|
|
|
+ self._invalidate_cache()
|
|
|
|
|
+ return [None] * len(texts)
|
|
|
|
|
|
|
|
def cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
|
|
def cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
|
|
|
"""计算两个向量的余弦相似度"""
|
|
"""计算两个向量的余弦相似度"""
|