base_vector.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
import os
import time
from typing import Any, Dict, List, Optional

from tqdm import tqdm

from foundation.observability.logger.loggering import server_logger as logger
  6. class BaseVectorDB:
  7. """
  8. 向量数据库操作基类
  9. """
  10. def text_to_vector(self, text: str) -> List[float]:
  11. """
  12. 将文本转换为向量
  13. """
  14. return self.base_api_platform.get_embeddings([text])[0]
  15. def document_standard(self, documents: List[Dict[str, Any]]):
  16. """
  17. 文档标准处理
  18. """
  19. raise NotImplementedError
  20. def add_document(self , param: Dict[str, Any] , document: Dict[str, Any]):
  21. """
  22. 单条添加文档
  23. param: 扩展参数信息,如:表名称等
  24. documents: 文档列表,包括元数据信息
  25. # 返回: 添加的文档ID列表
  26. """
  27. raise NotImplementedError
  28. def add_batch_documents(self , param: Dict[str, Any] , documents: List[Dict[str, Any]]):
  29. """
  30. 批量添加文档
  31. param: 扩展参数信息,如:表名称等
  32. documents: 文档列表,包括元数据信息
  33. # 返回: 添加的文档ID列表
  34. """
  35. raise NotImplementedError
  36. def add_tqdm_batch_documents(self , param: Dict[str, Any] , documents: List[Dict[str, Any]] , batch_size=10):
  37. """
  38. 批量添加文档(带进度条)
  39. param: 扩展参数信息,如:表名称等
  40. documents: 文档列表,包括元数据信息
  41. # 返回: 添加的文档ID列表
  42. """
  43. logger.info(f"Inserting {len(documents)} documents.")
  44. start_time = time.time()
  45. total_docs_inserted = 0
  46. total_batches = (len(documents) + batch_size - 1) // batch_size
  47. with tqdm(total=total_batches, desc="Inserting batches", unit="batch") as pbar:
  48. for i in range(0, len(documents), batch_size):
  49. batch = documents[i:i + batch_size]
  50. # 调用传入的插入函数
  51. self.add_batch_documents(param, batch)
  52. total_docs_inserted += len(batch)
  53. # 计算并显示当前的TPM
  54. elapsed_time = time.time() - start_time
  55. if elapsed_time > 0:
  56. tpm = (total_docs_inserted / elapsed_time) * 60
  57. pbar.set_postfix({"TPM": f"{tpm:.2f}"})
  58. pbar.update(1)
  59. def retriever(self, input_query):
  60. """
  61. 根据用户问题查询文档
  62. """
  63. raise NotImplementedError
  64. def similarity_search(self, param: Dict[str, Any], query_text: str , min_score=0.5 ,
  65. top_k=10, filters: Dict[str, Any] = None):
  66. """
  67. 根据用户问题查询文档
  68. """
  69. raise NotImplementedError
  70. def retriever(self, param: Dict[str, Any], query_text: str,
  71. top_k: int = 5, filters: Dict[str, Any] = None):
  72. """
  73. 根据用户问题查询文档
  74. """
  75. raise NotImplementedError