base_vector.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. # coding=utf-8
  2. """
  3. @project: maxkb
  4. @Author:虎
  5. @file: base_vector.py
  6. @date:2023/10/18 19:16
  7. @desc:
  8. """
  9. import re
  10. import threading
  11. from abc import ABC, abstractmethod
  12. from functools import reduce
  13. from typing import List, Dict
  14. from langchain_core.embeddings import Embeddings
  15. from common.chunk import text_to_chunk
  16. from common.utils.common import sub_array
  17. from knowledge.models import SourceType, SearchMode
# Module-level lock object. Nothing in the code visible in this file acquires it.
# NOTE(review): presumably imported and used elsewhere; confirm before removing.
lock = threading.Lock()
  19. def chunk_data(data: Dict):
  20. if str(data.get('source_type')) == str(SourceType.PARAGRAPH.value):
  21. text = data.get('text')
  22. chunk_list = data.get('chunks') if data.get('chunks') else text_to_chunk(text)
  23. return [{**data, 'text': chunk} for chunk in chunk_list]
  24. return [data]
  25. def chunk_data_list(data_list: List[Dict]):
  26. result = [chunk_data(data) for data in data_list]
  27. return reduce(lambda x, y: [*x, *y], result, [])
  28. # 预编译正则,性能更好
  29. RE_EMOJI = re.compile(
  30. r"[\U0001F300-\U0001FAFF]" # Emoji
  31. r"|[\u2600-\u27BF]" # Dingbats / Symbols(⚓ 在这)
  32. r"|[\uFE0E\uFE0F]", # Variation Selectors
  33. flags=re.UNICODE
  34. )
  35. RE_WHITESPACE = re.compile(r"\s+")
  36. def normalize_for_embedding(text: str) -> str:
  37. if not text:
  38. return ""
  39. text = RE_EMOJI.sub("", text)
  40. text = RE_WHITESPACE.sub(" ", text)
  41. return text.strip()
  42. class BaseVectorStore(ABC):
  43. vector_exists = False
  44. @abstractmethod
  45. def vector_is_create(self) -> bool:
  46. """
  47. 判断向量库是否创建
  48. :return: 是否创建向量库
  49. """
  50. pass
  51. @abstractmethod
  52. def vector_create(self):
  53. """
  54. 创建 向量库
  55. :return:
  56. """
  57. pass
  58. def save_pre_handler(self):
  59. """
  60. 插入前置处理器 主要是判断向量库是否创建
  61. :return: True
  62. """
  63. if not BaseVectorStore.vector_exists:
  64. if not self.vector_is_create():
  65. self.vector_create()
  66. BaseVectorStore.vector_exists = True
  67. return True
  68. def save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str,
  69. source_id: str,
  70. is_active: bool,
  71. embedding: Embeddings):
  72. """
  73. 插入向量数据
  74. :param source_id: 资源id
  75. :param knowledge_id: 知识库id
  76. :param text: 文本
  77. :param source_type: 资源类型
  78. :param document_id: 文档id
  79. :param is_active: 是否禁用
  80. :param embedding: 向量化处理器
  81. :param paragraph_id 段落id
  82. :return: bool
  83. """
  84. self.save_pre_handler()
  85. data = {'document_id': document_id, 'paragraph_id': paragraph_id, 'knowledge_id': knowledge_id,
  86. 'is_active': is_active, 'source_id': source_id, 'source_type': source_type, 'text': text}
  87. chunk_list = chunk_data(data)
  88. result = sub_array(chunk_list)
  89. for child_array in result:
  90. self._batch_save(child_array, embedding, lambda: False)
  91. def batch_save(self, data_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
  92. """
  93. 批量插入
  94. @param data_list: 数据列表
  95. @param embedding: 向量化处理器
  96. @param is_the_task_interrupted: 判断是否中断任务
  97. :return: bool
  98. """
  99. self.save_pre_handler()
  100. chunk_list = chunk_data_list(data_list)
  101. result = sub_array(chunk_list)
  102. for child_array in result:
  103. if not is_the_task_interrupted():
  104. self._batch_save(child_array, embedding, is_the_task_interrupted)
  105. else:
  106. break
  107. @abstractmethod
  108. def _save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str,
  109. source_id: str,
  110. is_active: bool,
  111. embedding: Embeddings):
  112. pass
  113. @abstractmethod
  114. def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
  115. pass
  116. def search(self, query_text, knowledge_id_list: list[str], exclude_document_id_list: list[str],
  117. exclude_paragraph_list: list[str],
  118. is_active: bool,
  119. embedding: Embeddings):
  120. if knowledge_id_list is None or len(knowledge_id_list) == 0:
  121. return []
  122. query_text = normalize_for_embedding(query_text)
  123. embedding_query = embedding.embed_query(query_text)
  124. result = self.query(embedding_query, knowledge_id_list, exclude_document_id_list, exclude_paragraph_list,
  125. is_active, 1, 3, 0.65)
  126. return result[0]
  127. @abstractmethod
  128. def query(self, query_text: str, query_embedding: List[float], knowledge_id_list: list[str],
  129. document_id_list: list[str] | None,
  130. exclude_document_id_list: list[str],
  131. exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float,
  132. search_mode: SearchMode):
  133. pass
  134. @abstractmethod
  135. def hit_test(self, query_text, knowledge_id: list[str], exclude_document_id_list: list[str], top_number: int,
  136. similarity: float,
  137. search_mode: SearchMode,
  138. embedding: Embeddings):
  139. pass
  140. @abstractmethod
  141. def update_by_paragraph_id(self, paragraph_id: str, instance: Dict):
  142. pass
  143. @abstractmethod
  144. def update_by_paragraph_ids(self, paragraph_ids: str, instance: Dict):
  145. pass
  146. @abstractmethod
  147. def update_by_source_id(self, source_id: str, instance: Dict):
  148. pass
  149. @abstractmethod
  150. def update_by_source_ids(self, source_ids: List[str], instance: Dict):
  151. pass
  152. @abstractmethod
  153. def delete_by_knowledge_id(self, knowledge_id: str):
  154. pass
  155. @abstractmethod
  156. def delete_by_document_id(self, document_id: str):
  157. pass
  158. @abstractmethod
  159. def delete_by_document_id_list(self, document_id_list: List[str]):
  160. pass
  161. @abstractmethod
  162. def delete_by_knowledge_id_list(self, knowledge_id_list: List[str]):
  163. pass
  164. @abstractmethod
  165. def delete_by_source_id(self, source_id: str, source_type: str):
  166. pass
  167. @abstractmethod
  168. def delete_by_source_ids(self, source_ids: List[str], source_type: str):
  169. pass
  170. @abstractmethod
  171. def delete_by_paragraph_id(self, paragraph_id: str):
  172. pass
  173. @abstractmethod
  174. def delete_by_paragraph_ids(self, paragraph_ids: List[str]):
  175. pass