pg_vector.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. # coding=utf-8
  2. """
  3. @project: maxkb
  4. @Author:虎
  5. @file: pg_vector.py
  6. @date:2023/10/19 15:28
  7. @desc:
  8. """
  9. import json
  10. import os
  11. from abc import ABC, abstractmethod
  12. from typing import Dict, List
  13. import uuid_utils.compat as uuid
  14. from django.contrib.postgres.search import SearchVector
  15. from django.db.models import QuerySet, Value
  16. from langchain_core.embeddings import Embeddings
  17. from common.db.search import generate_sql_by_query_dict
  18. from common.db.sql_execute import select_list
  19. from common.utils.common import get_file_content
  20. from common.utils.ts_vecto_util import to_ts_vector, to_query
  21. from knowledge.models import Embedding, SearchMode, SourceType
  22. from knowledge.vector.base_vector import BaseVectorStore, normalize_for_embedding
  23. from maxkb.conf import PROJECT_DIR
  24. class PGVector(BaseVectorStore):
  25. def delete_by_source_ids(self, source_ids: List[str], source_type: str):
  26. if len(source_ids) == 0:
  27. return
  28. QuerySet(Embedding).filter(source_id__in=source_ids, source_type=source_type).delete()
  29. def update_by_source_ids(self, source_ids: List[str], instance: Dict):
  30. QuerySet(Embedding).filter(source_id__in=source_ids).update(**instance)
  31. def vector_is_create(self) -> bool:
  32. # 项目启动默认是创建好的 不需要再创建
  33. return True
  34. def vector_create(self):
  35. return True
  36. def _save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str,
  37. source_id: str,
  38. is_active: bool,
  39. embedding: Embeddings):
  40. text = normalize_for_embedding(text)
  41. text_embedding = [float(x) for x in embedding.embed_query(text)]
  42. embedding = Embedding(
  43. id=uuid.uuid7(),
  44. knowledge_id=knowledge_id,
  45. document_id=document_id,
  46. is_active=is_active,
  47. paragraph_id=paragraph_id,
  48. source_id=source_id,
  49. embedding=text_embedding,
  50. source_type=source_type,
  51. search_vector=to_ts_vector(text)
  52. )
  53. embedding.save()
  54. return True
  55. def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
  56. texts = [normalize_for_embedding(row.get('text')) for row in text_list]
  57. embeddings = embedding.embed_documents(texts)
  58. embedding_list = [
  59. Embedding(
  60. id=uuid.uuid7(),
  61. document_id=text_list[index].get('document_id'),
  62. paragraph_id=text_list[index].get('paragraph_id'),
  63. knowledge_id=text_list[index].get('knowledge_id'),
  64. is_active=text_list[index].get('is_active', True),
  65. source_id=text_list[index].get('source_id'),
  66. source_type=text_list[index].get('source_type'),
  67. embedding=[float(x) for x in embeddings[index]],
  68. search_vector=SearchVector(Value(to_ts_vector(text_list[index]['text'])))
  69. ) for index in range(0, len(texts))]
  70. if not is_the_task_interrupted():
  71. QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
  72. return True
  73. def hit_test(self, query_text, knowledge_id_list: list[str], exclude_document_id_list: list[str], top_number: int,
  74. similarity: float,
  75. search_mode: SearchMode,
  76. embedding: Embeddings):
  77. if knowledge_id_list is None or len(knowledge_id_list) == 0:
  78. return []
  79. exclude_dict = {}
  80. query_text = normalize_for_embedding(query_text)
  81. embedding_query = embedding.embed_query(query_text)
  82. query_set = QuerySet(Embedding).filter(knowledge_id__in=knowledge_id_list, is_active=True)
  83. if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
  84. exclude_dict.__setitem__('document_id__in', exclude_document_id_list)
  85. query_set = query_set.exclude(**exclude_dict)
  86. for search_handle in search_handle_list:
  87. if search_handle.support(search_mode):
  88. return search_handle.handle(query_set, query_text, embedding_query, top_number, similarity, search_mode)
  89. def query(self, query_text: str, query_embedding: List[float], knowledge_id_list: list[str],
  90. document_id_list: list[str],
  91. exclude_document_id_list: list[str],
  92. exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float,
  93. search_mode: SearchMode):
  94. exclude_dict = {}
  95. if knowledge_id_list is None or len(knowledge_id_list) == 0:
  96. return []
  97. query_set = QuerySet(Embedding).filter(knowledge_id__in=knowledge_id_list, is_active=is_active)
  98. if document_id_list is not None and len(document_id_list) > 0:
  99. query_set = query_set.filter(document_id__in=document_id_list)
  100. if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
  101. query_set = query_set.exclude(document_id__in=exclude_document_id_list)
  102. if exclude_paragraph_list is not None and len(exclude_paragraph_list) > 0:
  103. query_set = query_set.exclude(paragraph_id__in=exclude_paragraph_list)
  104. query_set = query_set.exclude(**exclude_dict)
  105. for search_handle in search_handle_list:
  106. if search_handle.support(search_mode):
  107. return search_handle.handle(query_set, query_text, query_embedding, top_n, similarity, search_mode)
  108. def update_by_source_id(self, source_id: str, instance: Dict):
  109. QuerySet(Embedding).filter(source_id=source_id).update(**instance)
  110. def update_by_paragraph_id(self, paragraph_id: str, instance: Dict):
  111. QuerySet(Embedding).filter(paragraph_id=paragraph_id).update(**instance)
  112. def update_by_paragraph_ids(self, paragraph_id: str, instance: Dict):
  113. QuerySet(Embedding).filter(paragraph_id__in=paragraph_id).update(**instance)
  114. def delete_by_knowledge_id(self, knowledge_id: str):
  115. QuerySet(Embedding).filter(knowledge_id=knowledge_id).delete()
  116. def delete_by_knowledge_id_list(self, knowledge_id_list: List[str]):
  117. QuerySet(Embedding).filter(knowledge_id__in=knowledge_id_list).delete()
  118. def delete_by_document_id(self, document_id: str):
  119. QuerySet(Embedding).filter(document_id=document_id).delete()
  120. return True
  121. def delete_by_document_id_list(self, document_id_list: List[str]):
  122. if len(document_id_list) == 0:
  123. return True
  124. return QuerySet(Embedding).filter(document_id__in=document_id_list).delete()
  125. def delete_by_source_id(self, source_id: str, source_type: str):
  126. QuerySet(Embedding).filter(source_id=source_id, source_type=source_type).delete()
  127. return True
  128. def delete_by_paragraph_id(self, paragraph_id: str):
  129. QuerySet(Embedding).filter(paragraph_id=paragraph_id).delete()
  130. def delete_by_paragraph_ids(self, paragraph_ids: List[str]):
  131. QuerySet(Embedding).filter(paragraph_id__in=paragraph_ids).delete()
  132. class ISearch(ABC):
  133. @abstractmethod
  134. def support(self, search_mode: SearchMode):
  135. pass
  136. @abstractmethod
  137. def handle(self, query_set, query_text, query_embedding, top_number: int,
  138. similarity: float, search_mode: SearchMode):
  139. pass
  140. class EmbeddingSearch(ISearch):
  141. def handle(self,
  142. query_set,
  143. query_text,
  144. query_embedding,
  145. top_number: int,
  146. similarity: float,
  147. search_mode: SearchMode):
  148. exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
  149. select_string=get_file_content(
  150. os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql',
  151. 'embedding_search.sql')),
  152. with_table_name=True)
  153. embedding_model = select_list(exec_sql, [
  154. len(query_embedding),
  155. json.dumps(query_embedding),
  156. *exec_params,
  157. similarity,
  158. top_number
  159. ])
  160. return embedding_model
  161. def support(self, search_mode: SearchMode):
  162. return search_mode.value == SearchMode.embedding.value
  163. class KeywordsSearch(ISearch):
  164. def handle(self,
  165. query_set,
  166. query_text,
  167. query_embedding,
  168. top_number: int,
  169. similarity: float,
  170. search_mode: SearchMode):
  171. exec_sql, exec_params = generate_sql_by_query_dict({'keywords_query': query_set},
  172. select_string=get_file_content(
  173. os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql',
  174. 'keywords_search.sql')),
  175. with_table_name=True)
  176. embedding_model = select_list(exec_sql, [
  177. to_query(query_text),
  178. *exec_params,
  179. similarity,
  180. top_number
  181. ])
  182. return embedding_model
  183. def support(self, search_mode: SearchMode):
  184. return search_mode.value == SearchMode.keywords.value
  185. class BlendSearch(ISearch):
  186. def handle(self,
  187. query_set,
  188. query_text,
  189. query_embedding,
  190. top_number: int,
  191. similarity: float,
  192. search_mode: SearchMode):
  193. exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
  194. select_string=get_file_content(
  195. os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql',
  196. 'blend_search.sql')),
  197. with_table_name=True)
  198. embedding_model = select_list(exec_sql, [
  199. len(query_embedding),
  200. json.dumps(query_embedding),
  201. to_query(query_text),
  202. *exec_params, similarity,
  203. top_number
  204. ])
  205. return embedding_model
  206. def support(self, search_mode: SearchMode):
  207. return search_mode.value == SearchMode.blend.value
  208. search_handle_list = [EmbeddingSearch(), KeywordsSearch(), BlendSearch()]