# base_vector.py
import os
import time
from typing import Any, Dict, List, Optional

from tqdm import tqdm

from foundation.logger.loggering import server_logger as logger
from foundation.models.base_online_platform import BaseApiPlatform
  7. class BaseVectorDB:
  8. """
  9. 向量数据库操作基类
  10. """
  11. def __init__(self , base_api_platform :BaseApiPlatform):
  12. self.base_api_platform = base_api_platform
  13. def text_to_vector(self, text: str) -> List[float]:
  14. """
  15. 将文本转换为向量
  16. """
  17. return self.base_api_platform.get_embeddings([text])[0]
  18. def document_standard(self, documents: List[Dict[str, Any]]):
  19. """
  20. 文档标准处理
  21. """
  22. raise NotImplementedError
  23. def add_document(self , param: Dict[str, Any] , document: Dict[str, Any]):
  24. """
  25. 单条添加文档
  26. param: 扩展参数信息,如:表名称等
  27. documents: 文档列表,包括元数据信息
  28. # 返回: 添加的文档ID列表
  29. """
  30. raise NotImplementedError
  31. def add_batch_documents(self , param: Dict[str, Any] , documents: List[Dict[str, Any]]):
  32. """
  33. 批量添加文档
  34. param: 扩展参数信息,如:表名称等
  35. documents: 文档列表,包括元数据信息
  36. # 返回: 添加的文档ID列表
  37. """
  38. raise NotImplementedError
  39. def add_tqdm_batch_documents(self , param: Dict[str, Any] , documents: List[Dict[str, Any]] , batch_size=10):
  40. """
  41. 批量添加文档(带进度条)
  42. param: 扩展参数信息,如:表名称等
  43. documents: 文档列表,包括元数据信息
  44. # 返回: 添加的文档ID列表
  45. """
  46. logger.info(f"Inserting {len(documents)} documents.")
  47. start_time = time.time()
  48. total_docs_inserted = 0
  49. total_batches = (len(documents) + batch_size - 1) // batch_size
  50. with tqdm(total=total_batches, desc="Inserting batches", unit="batch") as pbar:
  51. for i in range(0, len(documents), batch_size):
  52. batch = documents[i:i + batch_size]
  53. # 调用传入的插入函数
  54. self.add_batch_documents(param, batch)
  55. total_docs_inserted += len(batch)
  56. # 计算并显示当前的TPM
  57. elapsed_time = time.time() - start_time
  58. if elapsed_time > 0:
  59. tpm = (total_docs_inserted / elapsed_time) * 60
  60. pbar.set_postfix({"TPM": f"{tpm:.2f}"})
  61. pbar.update(1)
  62. def retriever(self, input_query):
  63. """
  64. 根据用户问题查询文档
  65. """
  66. raise NotImplementedError
  67. def similarity_search(self, param: Dict[str, Any], query_text: str , min_score=0.5 ,
  68. top_k=10, filters: Dict[str, Any] = None):
  69. """
  70. 根据用户问题查询文档
  71. """
  72. raise NotImplementedError
  73. def retriever(self, param: Dict[str, Any], query_text: str,
  74. top_k: int = 5, filters: Dict[str, Any] = None):
  75. """
  76. 根据用户问题查询文档
  77. """
  78. raise NotImplementedError