embedding.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. # coding=utf-8
  2. import traceback
  3. from typing import List
  4. from celery_once import QueueOnce
  5. from django.db.models import QuerySet
  6. from django.utils.translation import gettext_lazy as _
  7. from common.config.embedding_config import ModelManage
  8. from common.event.listener_manage import ListenerManagement, UpdateProblemArgs, UpdateEmbeddingKnowledgeIdArgs, \
  9. UpdateEmbeddingDocumentIdArgs
  10. from common.utils.logger import maxkb_logger
  11. from knowledge.models import Document, TaskType, State
  12. from knowledge.serializers.common import drop_knowledge_index
  13. from models_provider.models import Model
  14. from models_provider.tools import get_model, get_model_default_params
  15. from ops import celery_app
  16. def get_embedding_model(model_id, exception_handler=lambda e: maxkb_logger.error(
  17. _('Failed to obtain vector model: {error} {traceback}').format(
  18. error=str(e),
  19. traceback=traceback.format_exc()
  20. ))):
  21. try:
  22. model = QuerySet(Model).filter(id=model_id).first()
  23. default_params = get_model_default_params(model)
  24. embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, **{**default_params}))
  25. except Exception as e:
  26. exception_handler(e)
  27. raise e
  28. return embedding_model
  29. @celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id']}, name='celery:embedding_by_paragraph')
  30. def embedding_by_paragraph(paragraph_id, model_id):
  31. embedding_model = get_embedding_model(model_id)
  32. ListenerManagement.embedding_by_paragraph(paragraph_id, embedding_model)
  33. @celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id_list']}, name='celery:embedding_by_paragraph_data_list')
  34. def embedding_by_paragraph_data_list(data_list, paragraph_id_list, model_id):
  35. embedding_model = get_embedding_model(model_id)
  36. ListenerManagement.embedding_by_paragraph_data_list(data_list, paragraph_id_list, embedding_model)
  37. @celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id_list']}, name='celery:embedding_by_paragraph_list')
  38. def embedding_by_paragraph_list(paragraph_id_list, model_id):
  39. embedding_model = get_embedding_model(model_id)
  40. ListenerManagement.embedding_by_paragraph_list(paragraph_id_list, embedding_model)
  41. @celery_app.task(base=QueueOnce, once={'keys': ['document_id']}, name='celery:embedding_by_document')
  42. def embedding_by_document(document_id, model_id, state_list=None):
  43. """
  44. 向量化文档
  45. @param state_list:
  46. @param document_id: 文档id
  47. @param model_id 向量模型
  48. :return: None
  49. """
  50. if state_list is None:
  51. state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value,
  52. State.REVOKE.value,
  53. State.REVOKED.value, State.IGNORED.value]
  54. def exception_handler(e):
  55. ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING,
  56. State.FAILURE)
  57. maxkb_logger.error(
  58. _('Failed to obtain vector model: {error} {traceback}').format(
  59. error=str(e),
  60. traceback=traceback.format_exc()
  61. ))
  62. embedding_model = get_embedding_model(model_id, exception_handler)
  63. #
  64. ListenerManagement.embedding_by_document(document_id, embedding_model, state_list)
  65. @celery_app.task(name='celery:embedding_by_document_list')
  66. def embedding_by_document_list(document_id_list, model_id):
  67. """
  68. 向量化文档
  69. @param document_id_list: 文档id列表
  70. @param model_id 向量模型
  71. :return: None
  72. """
  73. for document_id in document_id_list:
  74. embedding_by_document.delay(document_id, model_id)
  75. @celery_app.task(base=QueueOnce, once={'keys': ['knowledge_id']}, name='celery:embedding_by_knowledge')
  76. def embedding_by_knowledge(knowledge_id, model_id):
  77. """
  78. 向量化知识库
  79. @param knowledge_id: 知识库id
  80. @param model_id 向量模型
  81. :return: None
  82. """
  83. maxkb_logger.info(_('Start--->Vectorized knowledge: {knowledge_id}').format(knowledge_id=knowledge_id))
  84. try:
  85. ListenerManagement.delete_embedding_by_knowledge(knowledge_id)
  86. drop_knowledge_index(knowledge_id=knowledge_id)
  87. document_list = QuerySet(Document).filter(knowledge_id=knowledge_id)
  88. maxkb_logger.info(_('Knowledge documentation: {document_names}').format(
  89. document_names=", ".join([d.name for d in document_list])))
  90. for document in document_list:
  91. try:
  92. embedding_by_document.delay(document.id, model_id)
  93. except Exception as e:
  94. pass
  95. except Exception as e:
  96. maxkb_logger.error(
  97. _('Vectorized knowledge: {knowledge_id} error {error} {traceback}').format(knowledge_id=knowledge_id,
  98. error=str(e),
  99. traceback=traceback.format_exc()))
  100. finally:
  101. maxkb_logger.info(_('End--->Vectorized knowledge: {knowledge_id}').format(knowledge_id=knowledge_id))
  102. def embedding_by_problem(args, model_id):
  103. """
  104. 向量话问题
  105. @param args: 问题对象
  106. @param model_id: 模型id
  107. @return:
  108. """
  109. embedding_model = get_embedding_model(model_id)
  110. ListenerManagement.embedding_by_problem(args, embedding_model)
  111. def embedding_by_data_list(args: List, model_id):
  112. embedding_model = get_embedding_model(model_id)
  113. ListenerManagement.embedding_by_data_list(args, embedding_model)
  114. def delete_embedding_by_document(document_id):
  115. """
  116. 删除指定文档id的向量
  117. @param document_id: 文档id
  118. @return: None
  119. """
  120. ListenerManagement.delete_embedding_by_document(document_id)
  121. def delete_embedding_by_document_list(document_id_list: List[str]):
  122. """
  123. 删除指定文档列表的向量数据
  124. @param document_id_list: 文档id列表
  125. @return: None
  126. """
  127. ListenerManagement.delete_embedding_by_document_list(document_id_list)
  128. def delete_embedding_by_knowledge(knowledge_id):
  129. """
  130. 删除指定数据集向量数据
  131. @param knowledge_id: 数据集id
  132. @return: None
  133. """
  134. ListenerManagement.delete_embedding_by_knowledge(knowledge_id)
  135. def delete_embedding_by_paragraph(paragraph_id):
  136. """
  137. 删除指定段落的向量数据
  138. @param paragraph_id: 段落id
  139. @return: None
  140. """
  141. ListenerManagement.delete_embedding_by_paragraph(paragraph_id)
  142. def delete_embedding_by_source(source_id):
  143. """
  144. 删除指定资源id的向量数据
  145. @param source_id: 资源id
  146. @return: None
  147. """
  148. ListenerManagement.delete_embedding_by_source(source_id)
  149. def disable_embedding_by_paragraph(paragraph_id):
  150. """
  151. 禁用某个段落id的向量
  152. @param paragraph_id: 段落id
  153. @return: None
  154. """
  155. ListenerManagement.disable_embedding_by_paragraph(paragraph_id)
  156. def enable_embedding_by_paragraph(paragraph_id):
  157. """
  158. 开启某个段落id的向量数据
  159. @param paragraph_id: 段落id
  160. @return: None
  161. """
  162. ListenerManagement.enable_embedding_by_paragraph(paragraph_id)
  163. def delete_embedding_by_source_ids(source_ids: List[str]):
  164. """
  165. 删除向量根据source_id_list
  166. @param source_ids:
  167. @return:
  168. """
  169. ListenerManagement.delete_embedding_by_source_ids(source_ids)
  170. def update_problem_embedding(problem_id: str, problem_content: str, model_id):
  171. """
  172. 更新问题
  173. @param problem_id:
  174. @param problem_content:
  175. @param model_id:
  176. @return:
  177. """
  178. model = get_embedding_model(model_id)
  179. ListenerManagement.update_problem(UpdateProblemArgs(problem_id, problem_content, model))
  180. def update_embedding_knowledge_id(paragraph_id_list, target_knowledge_id):
  181. """
  182. 修改向量数据到指定知识库
  183. @param paragraph_id_list: 指定段落的向量数据
  184. @param target_knowledge_id: 知识库id
  185. @return:
  186. """
  187. ListenerManagement.update_embedding_knowledge_id(
  188. UpdateEmbeddingKnowledgeIdArgs(paragraph_id_list, target_knowledge_id))
  189. def delete_embedding_by_paragraph_ids(paragraph_ids: List[str]):
  190. """
  191. 删除指定段落列表的向量数据
  192. @param paragraph_ids: 段落列表
  193. @return: None
  194. """
  195. ListenerManagement.delete_embedding_by_paragraph_ids(paragraph_ids)
  196. def update_embedding_document_id(paragraph_id_list, target_document_id, target_knowledge_id,
  197. target_embedding_model_id=None):
  198. target_embedding_model = get_embedding_model(
  199. target_embedding_model_id) if target_embedding_model_id is not None else None
  200. ListenerManagement.update_embedding_document_id(
  201. UpdateEmbeddingDocumentIdArgs(paragraph_id_list, target_document_id, target_knowledge_id,
  202. target_embedding_model))
  203. def delete_embedding_by_knowledge_id_list(knowledge_id_list):
  204. ListenerManagement.delete_embedding_by_knowledge_id_list(knowledge_id_list)