from foundation.logger.loggering import server_logger as logger
import os
import time
from tqdm import tqdm
from typing import List, Dict, Any, Optional
from foundation.models.base_online_platform import BaseApiPlatform


class BaseVectorDB:
    """
    Base class for vector database operations.

    Holds an embedding provider (``BaseApiPlatform``) and defines the
    interface that concrete vector stores implement. Every method except
    ``text_to_vector`` and ``add_tqdm_batch_documents`` raises
    ``NotImplementedError`` here and must be overridden by a subclass.
    """

    def __init__(self, base_api_platform: BaseApiPlatform) -> None:
        """
        Args:
            base_api_platform: platform client whose ``get_embeddings`` is
                used by ``text_to_vector``.
        """
        self.base_api_platform = base_api_platform

    def text_to_vector(self, text: str) -> List[float]:
        """
        Convert a single text into an embedding vector.

        Wraps ``text`` in a one-element list, calls
        ``self.base_api_platform.get_embeddings`` and returns element 0 of the
        result. Assumes ``get_embeddings`` returns one vector per input text,
        in input order (TODO confirm against BaseApiPlatform).
        """
        return self.base_api_platform.get_embeddings([text])[0]

    def document_standard(self, documents: List[Dict[str, Any]]):
        """
        Normalize documents into the store's standard format.

        Abstract: subclasses must override.
        """
        raise NotImplementedError

    def add_document(self, param: Dict[str, Any], document: Dict[str, Any]):
        """
        Add a single document.

        Args:
            param: extra parameters, e.g. the table/collection name.
            document: the document to add, including its metadata.

        Implementations are expected to return the ID(s) of the added
        document; the base class does not enforce this.
        Abstract: subclasses must override.
        """
        raise NotImplementedError

    def add_batch_documents(self, param: Dict[str, Any], documents: List[Dict[str, Any]]):
        """
        Add a batch of documents.

        Args:
            param: extra parameters, e.g. the table/collection name.
            documents: list of documents, each including its metadata.

        Implementations are expected to return the IDs of the added
        documents; the base class does not enforce this.
        Abstract: subclasses must override.
        """
        raise NotImplementedError

    def add_tqdm_batch_documents(
        self,
        param: Dict[str, Any],
        documents: List[Dict[str, Any]],
        batch_size: int = 10,
    ) -> None:
        """
        Add documents in consecutive batches, showing a tqdm progress bar.

        Splits ``documents`` into slices of at most ``batch_size`` items and
        passes each slice, in order, to ``self.add_batch_documents``. After
        each batch the progress bar's postfix shows the insertion rate in
        documents per minute (displayed under the key "TPM").

        Args:
            param: extra parameters forwarded unchanged to
                ``add_batch_documents`` (e.g. the table/collection name).
            documents: documents to insert, each including its metadata.
            batch_size: maximum number of documents per batch; must be > 0.

        Returns:
            None. Return values of ``add_batch_documents`` are discarded.

        Raises:
            ValueError: if ``batch_size`` is not positive.
            Any exception raised by ``add_batch_documents`` propagates and
            stops further batches (earlier batches are not rolled back here).
        """
        # Without this guard, batch_size == 0 surfaces as an opaque
        # ZeroDivisionError and negative sizes produce a bogus batch count.
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

        logger.info(f"Inserting {len(documents)} documents.")
        # monotonic() is unaffected by wall-clock adjustments, so the rate
        # computed below cannot go negative or spike.
        start_time = time.monotonic()
        total_docs_inserted = 0
        # Ceiling division: a trailing partial batch still counts as a batch.
        total_batches = (len(documents) + batch_size - 1) // batch_size

        with tqdm(total=total_batches, desc="Inserting batches", unit="batch") as pbar:
            for i in range(0, len(documents), batch_size):
                batch = documents[i:i + batch_size]
                self.add_batch_documents(param, batch)
                total_docs_inserted += len(batch)

                # Insertion rate in documents per minute, shown as "TPM".
                elapsed_time = time.monotonic() - start_time
                if elapsed_time > 0:
                    tpm = (total_docs_inserted / elapsed_time) * 60
                    pbar.set_postfix({"TPM": f"{tpm:.2f}"})
                pbar.update(1)

    def similarity_search(
        self,
        param: Dict[str, Any],
        query_text: str,
        min_score: float = 0.5,
        top_k: int = 10,
        filters: Optional[Dict[str, Any]] = None,
    ):
        """
        Search for documents similar to the user's query.

        Args:
            param: extra parameters, e.g. the table/collection name.
            query_text: the user query text.
            min_score: similarity threshold (presumably lower-scoring hits are
                dropped; exact semantics are left to the implementation).
            top_k: maximum number of results (presumably).
            filters: optional metadata filters.

        Abstract: subclasses must override.
        """
        raise NotImplementedError

    def retriever(
        self,
        param: Dict[str, Any],
        query_text: str,
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None,
    ):
        """
        Retrieve documents relevant to the user's query.

        Args:
            param: extra parameters, e.g. the table/collection name.
            query_text: the user query text.
            top_k: maximum number of results (presumably).
            filters: optional metadata filters.

        Abstract: subclasses must override.
        """
        raise NotImplementedError