| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108 |
- from foundation.logger.loggering import server_logger as logger
- import os
- import time
- from tqdm import tqdm
- from typing import List, Dict, Any
- from foundation.models.base_online_platform import BaseApiPlatform
class BaseVectorDB:
    """Base class for vector-database operations.

    Concrete backends subclass this and override the methods that raise
    ``NotImplementedError``. The base class itself only provides text
    embedding (delegated to the injected platform object) and a
    progress-bar wrapper around ``add_batch_documents``.
    """

    def __init__(self, base_api_platform: "BaseApiPlatform"):
        """Store the platform object used to compute embeddings.

        Args:
            base_api_platform: Object exposing ``get_embeddings(list_of_texts)``;
                it is called by ``text_to_vector``.
        """
        self.base_api_platform = base_api_platform

    def text_to_vector(self, text: str) -> List[float]:
        """Convert a single text into its embedding vector.

        The text is wrapped in a one-element list, passed to
        ``base_api_platform.get_embeddings`` and the first result is returned.
        """
        return self.base_api_platform.get_embeddings([text])[0]

    def document_standard(self, documents: List[Dict[str, Any]]):
        """Normalise documents into the backend's standard form.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    def add_document(self, param: Dict[str, Any], document: Dict[str, Any]):
        """Add a single document.

        Args:
            param: Extra backend parameters (e.g. the table name).
            document: The document to add, including its metadata.

        Raises:
            NotImplementedError: Always; subclasses must override.

        NOTE(review): the original notes mention returning the IDs of the
        added documents, but that line was commented out -- confirm the
        intended return contract with the concrete backends.
        """
        raise NotImplementedError

    def add_batch_documents(self, param: Dict[str, Any], documents: List[Dict[str, Any]]):
        """Add several documents in one call.

        Args:
            param: Extra backend parameters (e.g. the table name).
            documents: Documents to add, each including its metadata.

        Raises:
            NotImplementedError: Always; subclasses must override.

        NOTE(review): the original notes mention returning the IDs of the
        added documents, but that line was commented out -- confirm the
        intended return contract with the concrete backends.
        """
        raise NotImplementedError

    def add_tqdm_batch_documents(self, param: Dict[str, Any], documents: List[Dict[str, Any]], batch_size=10):
        """Add documents in fixed-size batches while showing a progress bar.

        ``documents`` is sliced into consecutive chunks of ``batch_size`` and
        each chunk is passed to ``add_batch_documents``. After every chunk the
        progress bar's postfix shows "TPM": documents inserted so far divided
        by elapsed seconds, times 60.

        Args:
            param: Extra backend parameters (e.g. the table name); forwarded
                unchanged to ``add_batch_documents``.
            documents: Documents to add, each including its metadata.
            batch_size: Number of documents per batch; must be positive.

        Raises:
            ValueError: If ``batch_size`` is not positive.

        Returns:
            None. Results of the individual ``add_batch_documents`` calls are
            not collected.
        """
        # Validate first: 0 would cause a ZeroDivisionError below, and a
        # negative step would make range() empty and silently insert nothing.
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

        doc_count = len(documents)
        logger.info(f"Inserting {len(documents)} documents.")
        start_time = time.time()
        total_docs_inserted = 0
        # Ceiling division: a trailing partial batch still counts as a batch.
        total_batches = (doc_count + batch_size - 1) // batch_size
        with tqdm(total=total_batches, desc="Inserting batches", unit="batch") as pbar:
            for offset in range(0, doc_count, batch_size):
                batch = documents[offset:offset + batch_size]
                self.add_batch_documents(param, batch)
                total_docs_inserted += len(batch)
                # Skip the rate while no measurable time has passed to avoid
                # dividing by zero.
                elapsed_time = time.time() - start_time
                if elapsed_time > 0:
                    tpm = (total_docs_inserted / elapsed_time) * 60
                    pbar.set_postfix({"TPM": f"{tpm:.2f}"})
                pbar.update(1)

    def similarity_search(self, param: Dict[str, Any], query_text: str, min_score=0.5,
                          top_k=10, filters: Dict[str, Any] = None):
        """Look up documents matching a user query by similarity.

        Args:
            param: Extra backend parameters (e.g. the table name).
            query_text: The user's query text.
            min_score: Score threshold (default 0.5); exact semantics are
                defined by the subclass.
            top_k: Result-count limit (default 10); exact semantics are
                defined by the subclass.
            filters: Optional filter conditions; ``None`` by default.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    # NOTE: an earlier ``retriever(self, input_query)`` definition used to sit
    # above ``similarity_search``. It was dead code, because this later
    # definition of the same name replaced it in the class namespace, so only
    # this signature was ever callable.
    def retriever(self, param: Dict[str, Any], query_text: str,
                  top_k: int = 5, filters: Dict[str, Any] = None):
        """Retrieve documents for a user query.

        Args:
            param: Extra backend parameters (e.g. the table name).
            query_text: The user's query text.
            top_k: Result-count limit (default 5); exact semantics are
                defined by the subclass.
            filters: Optional filter conditions; ``None`` by default.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError
|