problem.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. import os
  2. from functools import reduce
  3. from typing import Dict, List
  4. import uuid_utils.compat as uuid
  5. from django.db import transaction
  6. from django.db.models import QuerySet
  7. from django.utils.translation import gettext_lazy as _
  8. from rest_framework import serializers
  9. from common.db.search import native_search, native_page_search
  10. from common.exception.app_exception import AppApiException
  11. from common.utils.common import get_file_content
  12. from knowledge.models import Problem, ProblemParagraphMapping, Paragraph, Knowledge, SourceType
  13. from knowledge.serializers.common import get_embedding_model_id_by_knowledge_id
  14. from knowledge.task.embedding import delete_embedding_by_source_ids, update_problem_embedding, embedding_by_data_list
  15. from maxkb.const import PROJECT_DIR
  16. class ProblemSerializer(serializers.ModelSerializer):
  17. class Meta:
  18. model = Problem
  19. fields = ['id', 'content', 'knowledge_id', 'create_time', 'update_time']
  20. class ProblemInstanceSerializer(serializers.Serializer):
  21. id = serializers.CharField(required=False, label=_('problem id'))
  22. content = serializers.CharField(required=True, max_length=256, label=_('content'))
  23. class ProblemEditSerializer(serializers.Serializer):
  24. content = serializers.CharField(required=True, max_length=256, label=_('content'))
  25. class ProblemMappingSerializer(serializers.Serializer):
  26. paragraph_id = serializers.UUIDField(required=True, label=_('paragraph id'))
  27. document_id = serializers.UUIDField(required=True, label=_('document id'))
  28. class ProblemBatchSerializer(serializers.Serializer):
  29. problem_list = serializers.ListField(required=True, label=_('problem list'),
  30. child=serializers.CharField(required=True, max_length=256, label=_('problem')))
  31. class ProblemBatchDeleteSerializer(serializers.Serializer):
  32. problem_id_list = serializers.ListField(required=True, label=_('problem id list'),
  33. child=serializers.UUIDField(required=True, label=_('problem id')))
  34. class AssociationParagraph(serializers.Serializer):
  35. paragraph_id = serializers.UUIDField(required=True, label=_('paragraph id'))
  36. document_id = serializers.UUIDField(required=True, label=_('document id'))
  37. class BatchAssociation(serializers.Serializer):
  38. problem_id_list = serializers.ListField(required=True, label=_('problem id list'),
  39. child=serializers.UUIDField(required=True, label=_('problem id')))
  40. paragraph_list = AssociationParagraph(many=True)
  41. class ProblemSerializers(serializers.Serializer):
  42. class BatchOperate(serializers.Serializer):
  43. workspace_id = serializers.CharField(required=True, label=_('workspace id'))
  44. knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
  45. def is_valid(self, *, raise_exception=False):
  46. super().is_valid(raise_exception=True)
  47. workspace_id = self.data.get('workspace_id')
  48. query_set = QuerySet(Knowledge).filter(id=self.data.get('knowledge_id'))
  49. if workspace_id:
  50. query_set = query_set.filter(workspace_id=workspace_id)
  51. if not query_set.exists():
  52. raise AppApiException(500, _('Knowledge id does not exist'))
  53. def delete(self, problem_id_list: List, with_valid=True):
  54. if with_valid:
  55. self.is_valid(raise_exception=True)
  56. knowledge_id = self.data.get('knowledge_id')
  57. problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(
  58. knowledge_id=knowledge_id,
  59. problem_id__in=problem_id_list)
  60. source_ids = [row.id for row in problem_paragraph_mapping_list]
  61. problem_paragraph_mapping_list.delete()
  62. QuerySet(Problem).filter(id__in=problem_id_list).delete()
  63. delete_embedding_by_source_ids(source_ids)
  64. return True
  65. def association(self, instance: Dict, with_valid=True):
  66. if with_valid:
  67. self.is_valid(raise_exception=True)
  68. BatchAssociation(data=instance).is_valid(raise_exception=True)
  69. knowledge_id = self.data.get('knowledge_id')
  70. paragraph_list = instance.get('paragraph_list')
  71. problem_id_list = instance.get('problem_id_list')
  72. problem_list = QuerySet(Problem).filter(id__in=problem_id_list)
  73. exits_problem_paragraph_mapping = QuerySet(
  74. ProblemParagraphMapping
  75. ).filter(problem_id__in=problem_id_list, paragraph_id__in=[p.get('paragraph_id') for p in paragraph_list])
  76. problem_paragraph_mapping_list = [
  77. (problem_paragraph_mapping, problem) for problem_paragraph_mapping, problem in
  78. reduce(
  79. lambda x, y: [*x, *y],
  80. [
  81. [
  82. to_problem_paragraph_mapping(
  83. problem, paragraph.get('document_id'),
  84. paragraph.get('paragraph_id'),
  85. knowledge_id
  86. ) for paragraph in paragraph_list
  87. ] for problem in problem_list
  88. ],
  89. []
  90. ) if not is_exits(exits_problem_paragraph_mapping, problem_paragraph_mapping)
  91. ]
  92. QuerySet(ProblemParagraphMapping).bulk_create(
  93. [problem_paragraph_mapping for problem_paragraph_mapping, problem in problem_paragraph_mapping_list]
  94. )
  95. data_list = [
  96. {
  97. 'text': problem.content,
  98. 'is_active': True,
  99. 'source_type': SourceType.PROBLEM,
  100. 'source_id': str(problem_paragraph_mapping.id),
  101. 'document_id': str(problem_paragraph_mapping.document_id),
  102. 'paragraph_id': str(problem_paragraph_mapping.paragraph_id),
  103. 'knowledge_id': knowledge_id,
  104. } for problem_paragraph_mapping, problem in problem_paragraph_mapping_list
  105. ]
  106. model_id = get_embedding_model_id_by_knowledge_id(self.data.get('knowledge_id'))
  107. embedding_by_data_list(data_list, model_id=model_id)
  108. class Operate(serializers.Serializer):
  109. workspace_id = serializers.CharField(required=True, label=_('workspace id'))
  110. knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
  111. problem_id = serializers.UUIDField(required=True, label=_('problem id'))
  112. def is_valid(self, *, raise_exception=False):
  113. super().is_valid(raise_exception=True)
  114. workspace_id = self.data.get('workspace_id')
  115. query_set = QuerySet(Knowledge).filter(id=self.data.get('knowledge_id'))
  116. if workspace_id:
  117. query_set = query_set.filter(workspace_id=workspace_id)
  118. if not query_set.exists():
  119. raise AppApiException(500, _('Knowledge id does not exist'))
  120. def list_paragraph(self, with_valid=True):
  121. if with_valid:
  122. self.is_valid(raise_exception=True)
  123. problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(
  124. knowledge_id=self.data.get("knowledge_id"),
  125. problem_id=self.data.get("problem_id")
  126. )
  127. if problem_paragraph_mapping is None or len(problem_paragraph_mapping) == 0:
  128. return []
  129. return native_search(
  130. QuerySet(Paragraph).filter(id__in=[row.paragraph_id for row in problem_paragraph_mapping]),
  131. select_string=get_file_content(
  132. os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_paragraph.sql')))
  133. def one(self, with_valid=True):
  134. if with_valid:
  135. self.is_valid(raise_exception=True)
  136. return ProblemInstanceSerializer(QuerySet(Problem).get(**{'id': self.data.get('problem_id')})).data
  137. @transaction.atomic
  138. def delete(self, with_valid=True):
  139. if with_valid:
  140. self.is_valid(raise_exception=True)
  141. problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(
  142. knowledge_id=self.data.get('knowledge_id'),
  143. problem_id=self.data.get('problem_id'))
  144. source_ids = [row.id for row in problem_paragraph_mapping_list]
  145. problem_paragraph_mapping_list.delete()
  146. QuerySet(Problem).filter(id=self.data.get('problem_id')).delete()
  147. delete_embedding_by_source_ids(source_ids)
  148. return True
  149. @transaction.atomic
  150. def edit(self, instance: Dict, with_valid=True):
  151. if with_valid:
  152. self.is_valid(raise_exception=True)
  153. problem_id = self.data.get('problem_id')
  154. knowledge_id = self.data.get('knowledge_id')
  155. content = instance.get('content')
  156. problem = QuerySet(Problem).filter(id=problem_id, knowledge_id=knowledge_id).first()
  157. QuerySet(Knowledge).filter(id=knowledge_id)
  158. problem.content = content
  159. problem.save()
  160. model_id = get_embedding_model_id_by_knowledge_id(knowledge_id)
  161. update_problem_embedding(problem_id, content, model_id)
  162. class Create(serializers.Serializer):
  163. workspace_id = serializers.CharField(required=True, label=_('workspace id'))
  164. knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
  165. def is_valid(self, *, raise_exception=False):
  166. super().is_valid(raise_exception=True)
  167. workspace_id = self.data.get('workspace_id')
  168. query_set = QuerySet(Knowledge).filter(id=self.data.get('knowledge_id'))
  169. if workspace_id:
  170. query_set = query_set.filter(workspace_id=workspace_id)
  171. if not query_set.exists():
  172. raise AppApiException(500, _('Knowledge id does not exist'))
  173. def batch(self, problem_list, with_valid=True):
  174. if with_valid:
  175. self.is_valid(raise_exception=True)
  176. ProblemBatchSerializer(data={'problem_list': problem_list}).is_valid(raise_exception=True)
  177. problem_list = list(set(problem_list))
  178. knowledge_id = self.data.get('knowledge_id')
  179. exists_problem_content_list = [
  180. problem.content for problem in QuerySet(
  181. Problem
  182. ).filter(knowledge_id=knowledge_id, content__in=problem_list)
  183. ]
  184. problem_instance_list = [
  185. Problem(
  186. id=uuid.uuid7(), knowledge_id=knowledge_id, content=problem_content
  187. ) for problem_content in problem_list if (
  188. not exists_problem_content_list.__contains__(
  189. problem_content
  190. ) if len(exists_problem_content_list) > 0 else True
  191. )
  192. ]
  193. QuerySet(Problem).bulk_create(problem_instance_list) if len(problem_instance_list) > 0 else None
  194. return [ProblemSerializer(problem_instance).data for problem_instance in problem_instance_list]
  195. class Query(serializers.Serializer):
  196. workspace_id = serializers.CharField(required=True, label=_('workspace id'))
  197. knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
  198. content = serializers.CharField(required=False, label=_('content'))
  199. def is_valid(self, *, raise_exception=False):
  200. super().is_valid(raise_exception=True)
  201. workspace_id = self.data.get('workspace_id')
  202. query_set = QuerySet(Knowledge).filter(id=self.data.get('knowledge_id'))
  203. if workspace_id:
  204. query_set = query_set.filter(workspace_id=workspace_id)
  205. if not query_set.exists():
  206. raise AppApiException(500, _('Knowledge id does not exist'))
  207. def get_query_set(self):
  208. self.is_valid()
  209. query_set = QuerySet(model=Problem)
  210. query_set = query_set.filter(
  211. **{'knowledge_id': self.data.get('knowledge_id')})
  212. if 'content' in self.data:
  213. query_set = query_set.filter(**{'content__icontains': self.data.get('content')})
  214. query_set = query_set.order_by("-create_time")
  215. return query_set
  216. def list(self):
  217. query_set = self.get_query_set()
  218. return native_search(query_set, select_string=get_file_content(
  219. os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_problem.sql')))
  220. def page(self, current_page, page_size):
  221. query_set = self.get_query_set()
  222. return native_page_search(current_page, page_size, query_set, select_string=get_file_content(
  223. os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_problem.sql')))
  224. def is_exits(exits_problem_paragraph_mapping_list, new_paragraph_mapping):
  225. filter_list = [exits_problem_paragraph_mapping for exits_problem_paragraph_mapping in
  226. exits_problem_paragraph_mapping_list if
  227. str(exits_problem_paragraph_mapping.paragraph_id) == new_paragraph_mapping.paragraph_id
  228. and str(exits_problem_paragraph_mapping.problem_id) == new_paragraph_mapping.problem_id
  229. and str(exits_problem_paragraph_mapping.knowledge_id) == new_paragraph_mapping.knowledge_id]
  230. return len(filter_list) > 0
  231. def to_problem_paragraph_mapping(problem, document_id: str, paragraph_id: str, knowledge_id: str):
  232. return ProblemParagraphMapping(
  233. id=uuid.uuid7(),
  234. document_id=document_id,
  235. paragraph_id=paragraph_id,
  236. knowledge_id=knowledge_id,
  237. problem_id=str(problem.id)
  238. ), problem