# paragraph.py (38 KB)
  1. # coding=utf-8
  2. from typing import Dict
  3. import uuid_utils.compat as uuid
  4. from celery_once import AlreadyQueued
  5. from django.db import transaction
  6. from django.db.models import QuerySet, Count, F
  7. from django.utils.translation import gettext_lazy as _
  8. from rest_framework import serializers
  9. from common.db.search import page_search
  10. from common.event.listener_manage import ListenerManagement
  11. from common.exception.app_exception import AppApiException
  12. from common.utils.common import post
  13. from knowledge.models import Paragraph, Problem, Document, ProblemParagraphMapping, SourceType, TaskType, State, \
  14. Knowledge
  15. from knowledge.serializers.common import ProblemParagraphObject, ProblemParagraphManage, \
  16. get_embedding_model_id_by_knowledge_id, update_document_char_length, BatchSerializer
  17. from knowledge.serializers.problem import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
  18. from knowledge.task.embedding import embedding_by_paragraph, enable_embedding_by_paragraph, \
  19. disable_embedding_by_paragraph, \
  20. delete_embedding_by_paragraph, embedding_by_problem as embedding_by_problem_task, delete_embedding_by_paragraph_ids, \
  21. embedding_by_problem, delete_embedding_by_source, update_embedding_document_id
  22. from knowledge.task.generate import generate_related_by_paragraph_id_list
  23. class ParagraphSerializer(serializers.ModelSerializer):
  24. class Meta:
  25. model = Paragraph
  26. fields = ['id', 'content', 'is_active', 'document_id', 'title', 'create_time', 'update_time', 'position']
  27. class ParagraphInstanceSerializer(serializers.Serializer):
  28. """
  29. 段落实例对象
  30. """
  31. content = serializers.CharField(required=True, label=_('content'), max_length=102400, min_length=1, allow_null=True,
  32. allow_blank=True)
  33. title = serializers.CharField(required=False, max_length=256, label=_('section title'), allow_null=True,
  34. allow_blank=True)
  35. problem_list = ProblemInstanceSerializer(required=False, many=True)
  36. is_active = serializers.BooleanField(required=False, label=_('Is active'))
  37. class EditParagraphSerializers(serializers.Serializer):
  38. title = serializers.CharField(required=False, max_length=256, label=_('section title'), allow_null=True,
  39. allow_blank=True)
  40. content = serializers.CharField(required=False, max_length=102400, allow_null=True, allow_blank=True,
  41. label=_('section title'))
  42. problem_list = ProblemInstanceSerializer(required=False, many=True)
  43. class ParagraphBatchGenerateRelatedSerializer(serializers.Serializer):
  44. paragraph_id_list = serializers.ListField(required=True, label=_('paragraph id list'),
  45. child=serializers.UUIDField(required=True, label=_('paragraph id')))
  46. model_id = serializers.UUIDField(required=True, label=_('model id'))
  47. prompt = serializers.CharField(required=True, label=_('prompt'), max_length=102400, allow_null=True,
  48. allow_blank=True)
  49. document_id = serializers.UUIDField(required=True, label=_('document id'))
  50. class ParagraphSerializers(serializers.Serializer):
  51. title = serializers.CharField(required=False, max_length=256, label=_('section title'), allow_null=True,
  52. allow_blank=True)
  53. content = serializers.CharField(required=True, max_length=102400, label=_('section title'))
  54. class Problem(serializers.Serializer):
  55. workspace_id = serializers.CharField(required=True, label=_('workspace id'))
  56. knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
  57. document_id = serializers.UUIDField(required=True, label=_('document id'))
  58. paragraph_id = serializers.UUIDField(required=True, label=_('paragraph id'))
  59. def is_valid(self, *, raise_exception=False):
  60. super().is_valid(raise_exception=True)
  61. workspace_id = self.data.get('workspace_id')
  62. query_set = QuerySet(Knowledge).filter(id=self.data.get('knowledge_id'))
  63. if workspace_id:
  64. query_set = query_set.filter(workspace_id=workspace_id)
  65. if not query_set.exists():
  66. raise AppApiException(500, _('Knowledge id does not exist'))
  67. if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists():
  68. raise AppApiException(500, _('Paragraph id does not exist'))
  69. def list(self, with_valid=False):
  70. """
  71. 获取问题列表
  72. :param with_valid: 是否校验
  73. :return: 问题列表
  74. """
  75. if with_valid:
  76. self.is_valid(raise_exception=True)
  77. problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(
  78. knowledge_id=self.data.get("knowledge_id"),
  79. paragraph_id=self.data.get(
  80. 'paragraph_id'))
  81. return [ProblemSerializer(row).data for row in
  82. QuerySet(Problem).filter(id__in=[row.problem_id for row in problem_paragraph_mapping])]
  83. @transaction.atomic
  84. def save(self, instance: Dict, with_valid=True, with_embedding=True, embedding_by_problem=None):
  85. if with_valid:
  86. self.is_valid()
  87. ProblemInstanceSerializer(data=instance).is_valid(raise_exception=True)
  88. problem = QuerySet(Problem).filter(knowledge_id=self.data.get('knowledge_id'),
  89. content=instance.get('content')).first()
  90. if problem is None:
  91. problem = Problem(id=uuid.uuid7(), knowledge_id=self.data.get('knowledge_id'),
  92. content=instance.get('content'))
  93. problem.save()
  94. if QuerySet(ProblemParagraphMapping).filter(knowledge_id=self.data.get('knowledge_id'),
  95. problem_id=problem.id,
  96. paragraph_id=self.data.get('paragraph_id')).exists():
  97. raise AppApiException(500, _('Already associated, please do not associate again'))
  98. problem_paragraph_mapping = ProblemParagraphMapping(
  99. id=uuid.uuid7(),
  100. problem_id=problem.id,
  101. document_id=self.data.get('document_id'),
  102. paragraph_id=self.data.get('paragraph_id'),
  103. knowledge_id=self.data.get('knowledge_id')
  104. )
  105. problem_paragraph_mapping.save()
  106. model_id = get_embedding_model_id_by_knowledge_id(self.data.get('knowledge_id'))
  107. if with_embedding:
  108. embedding_by_problem_task({
  109. 'text': problem.content,
  110. 'is_active': True,
  111. 'source_type': SourceType.PROBLEM,
  112. 'source_id': problem_paragraph_mapping.id,
  113. 'document_id': self.data.get('document_id'),
  114. 'paragraph_id': self.data.get('paragraph_id'),
  115. 'knowledge_id': self.data.get('knowledge_id'),
  116. }, model_id)
  117. return ProblemSerializers.Operate(
  118. data={
  119. 'workspace_id': self.data.get('workspace_id'),
  120. 'knowledge_id': self.data.get('knowledge_id'),
  121. 'problem_id': problem.id
  122. }
  123. ).one(with_valid=True)
  124. class Operate(serializers.Serializer):
  125. workspace_id = serializers.CharField(required=True, label=_('workspace id'))
  126. # 段落id
  127. paragraph_id = serializers.UUIDField(required=True, label=_('paragraph id'))
  128. # 知识库id
  129. knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
  130. # 文档id
  131. document_id = serializers.UUIDField(required=True, label=_('document id'))
  132. def is_valid(self, *, raise_exception=True):
  133. super().is_valid(raise_exception=True)
  134. workspace_id = self.data.get('workspace_id')
  135. query_set = QuerySet(Knowledge).filter(id=self.data.get('knowledge_id'))
  136. if workspace_id:
  137. query_set = query_set.filter(workspace_id=workspace_id)
  138. if not query_set.exists():
  139. raise AppApiException(500, _('Knowledge id does not exist'))
  140. if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists():
  141. raise AppApiException(500, _('Paragraph id does not exist'))
  142. @staticmethod
  143. def post_embedding(paragraph, instance, knowledge_id):
  144. if 'is_active' in instance and instance.get('is_active') is not None:
  145. (enable_embedding_by_paragraph if instance.get(
  146. 'is_active') else disable_embedding_by_paragraph)(paragraph.get('id'))
  147. else:
  148. model_id = get_embedding_model_id_by_knowledge_id(knowledge_id)
  149. embedding_by_paragraph(paragraph.get('id'), model_id)
  150. return paragraph
  151. @post(post_embedding)
  152. @transaction.atomic
  153. def edit(self, instance: Dict):
  154. self.is_valid()
  155. EditParagraphSerializers(data=instance).is_valid(raise_exception=True)
  156. _paragraph = QuerySet(Paragraph).get(id=self.data.get("paragraph_id"))
  157. update_keys = ['title', 'content', 'is_active']
  158. for update_key in update_keys:
  159. if update_key in instance and instance.get(update_key) is not None:
  160. _paragraph.__setattr__(update_key, instance.get(update_key))
  161. if 'problem_list' in instance:
  162. update_problem_list = list(
  163. filter(lambda row: 'id' in row and row.get('id') is not None, instance.get('problem_list')))
  164. create_problem_list = list(filter(lambda row: row.get('id') is None, instance.get('problem_list')))
  165. # 问题集合
  166. problem_list = QuerySet(Problem).filter(paragraph_id=self.data.get("paragraph_id"))
  167. # 校验前端 携带过来的id
  168. for update_problem in update_problem_list:
  169. if not set([str(row.id) for row in problem_list]).__contains__(update_problem.get('id')):
  170. raise AppApiException(500, _('Problem id does not exist'))
  171. # 对比需要删除的问题
  172. delete_problem_list = list(filter(
  173. lambda row: not [str(update_row.get('id')) for update_row in update_problem_list].__contains__(
  174. str(row.id)), problem_list)) if len(update_problem_list) > 0 else []
  175. # 删除问题
  176. QuerySet(Problem).filter(id__in=[row.id for row in delete_problem_list]).delete() if len(
  177. delete_problem_list) > 0 else None
  178. # 插入新的问题
  179. QuerySet(Problem).bulk_create([
  180. Problem(
  181. id=uuid.uuid7(),
  182. content=p.get('content'),
  183. paragraph_id=self.data.get('paragraph_id'),
  184. knowledge_id=self.data.get('knowledge_id'),
  185. document_id=self.data.get('document_id')
  186. ) for p in create_problem_list
  187. ]) if len(create_problem_list) else None
  188. # 修改问题集合
  189. QuerySet(Problem).bulk_update([
  190. Problem(
  191. id=row.get('id'),
  192. content=row.get('content')
  193. ) for row in update_problem_list], ['content']
  194. ) if len(update_problem_list) > 0 else None
  195. _paragraph.save()
  196. update_document_char_length(self.data.get('document_id'))
  197. return self.one(), instance, self.data.get('knowledge_id')
  198. def get_problem_list(self):
  199. ProblemParagraphMapping(ProblemParagraphMapping)
  200. problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(
  201. paragraph_id=self.data.get("paragraph_id"))
  202. if len(problem_paragraph_mapping) > 0:
  203. return [ProblemSerializer(problem).data for problem in
  204. QuerySet(Problem).filter(id__in=[ppm.problem_id for ppm in problem_paragraph_mapping])]
  205. return []
  206. def one(self, with_valid=False):
  207. if with_valid:
  208. self.is_valid(raise_exception=True)
  209. return {**ParagraphSerializer(QuerySet(model=Paragraph).get(id=self.data.get('paragraph_id'))).data,
  210. 'problem_list': self.get_problem_list()}
  211. def delete(self, with_valid=False):
  212. if with_valid:
  213. self.is_valid(raise_exception=True)
  214. paragraph_id = self.data.get('paragraph_id')
  215. Paragraph.objects.filter(id=paragraph_id).delete()
  216. delete_problems_and_mappings([paragraph_id])
  217. update_document_char_length(self.data.get('document_id'))
  218. delete_embedding_by_paragraph(paragraph_id)
  219. class Create(serializers.Serializer):
  220. workspace_id = serializers.CharField(required=True, label='Workspace ID')
  221. knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
  222. document_id = serializers.UUIDField(required=True, label=_('document id'))
  223. def is_valid(self, *, raise_exception=False):
  224. super().is_valid(raise_exception=True)
  225. if not QuerySet(Document).filter(id=self.data.get('document_id'),
  226. knowledge_id=self.data.get('knowledge_id')).exists():
  227. raise AppApiException(500, _('The document id is incorrect'))
  228. @transaction.atomic
  229. def save(self, instance: Dict, with_valid=True, with_embedding=True):
  230. if with_valid:
  231. ParagraphSerializers(data=instance).is_valid(raise_exception=True)
  232. self.is_valid()
  233. knowledge_id = self.data.get("knowledge_id")
  234. document_id = self.data.get('document_id')
  235. # 先将同一文档中的所有段落位置向下移动一位
  236. Paragraph.objects.filter(document_id=document_id).update(position=F('position') + 1)
  237. paragraph_problem_model = self.get_paragraph_problem_model(knowledge_id, document_id, instance)
  238. paragraph = paragraph_problem_model.get('paragraph')
  239. problem_paragraph_object_list = paragraph_problem_model.get('problem_paragraph_object_list')
  240. problem_model_list, problem_paragraph_mapping_list = (
  241. ProblemParagraphManage(problem_paragraph_object_list, knowledge_id)
  242. .to_problem_model_list())
  243. # 新加的在最上面
  244. paragraph.position = 1
  245. paragraph.save()
  246. # 调整位置
  247. if 'position' in instance:
  248. if type(instance['position']) is not int:
  249. instance['position'] = 1
  250. else:
  251. instance['position'] = 1
  252. ParagraphSerializers.AdjustPosition(data={
  253. 'paragraph_id': str(paragraph.id),
  254. 'knowledge_id': knowledge_id,
  255. 'document_id': document_id,
  256. 'workspace_id': self.data.get('workspace_id')
  257. }).adjust_position(instance.get('position'))
  258. # 插入問題
  259. QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
  260. # 插入问题关联关系
  261. QuerySet(ProblemParagraphMapping).bulk_create(
  262. problem_paragraph_mapping_list
  263. ) if len(problem_paragraph_mapping_list) > 0 else None
  264. # 修改长度
  265. update_document_char_length(document_id)
  266. if with_embedding:
  267. model_id = get_embedding_model_id_by_knowledge_id(knowledge_id)
  268. embedding_by_paragraph(str(paragraph.id), model_id)
  269. ListenerManagement.update_status(
  270. QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING, State.SUCCESS
  271. )
  272. ListenerManagement.get_aggregation_document_status(document_id)()
  273. return ParagraphSerializers.Operate(
  274. data={
  275. 'paragraph_id': str(paragraph.id),
  276. 'knowledge_id': knowledge_id,
  277. 'document_id': document_id,
  278. 'workspace_id': self.data.get('workspace_id')
  279. }
  280. ).one(with_valid=True)
  281. @staticmethod
  282. def get_paragraph_problem_model(knowledge_id: str, document_id: str, instance: Dict):
  283. paragraph = Paragraph(
  284. id=uuid.uuid7(),
  285. document_id=document_id,
  286. content=instance.get("content"),
  287. knowledge_id=knowledge_id,
  288. title=instance.get("title") if 'title' in instance else ''
  289. )
  290. problem_paragraph_object_list = [ProblemParagraphObject(
  291. knowledge_id, document_id, str(paragraph.id), problem.get('content')
  292. ) for problem in (instance.get('problem_list') if 'problem_list' in instance else [])]
  293. return {
  294. 'paragraph': paragraph,
  295. 'problem_paragraph_object_list': problem_paragraph_object_list
  296. }
  297. @staticmethod
  298. def or_get(exists_problem_list, content, knowledge_id):
  299. exists = [row for row in exists_problem_list if row.content == content]
  300. if len(exists) > 0:
  301. return exists[0]
  302. else:
  303. return Problem(id=uuid.uuid7(), content=content, knowledge_id=knowledge_id)
  304. class Query(serializers.Serializer):
  305. workspace_id = serializers.CharField(required=True, label=_('workspace id'))
  306. knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
  307. document_id = serializers.UUIDField(required=True, label=_('document id'))
  308. title = serializers.CharField(required=False, label=_('section title'))
  309. content = serializers.CharField(required=False)
  310. def is_valid(self, *, raise_exception=False):
  311. super().is_valid(raise_exception=True)
  312. workspace_id = self.data.get('workspace_id')
  313. query_set = QuerySet(Knowledge).filter(id=self.data.get('knowledge_id'))
  314. if workspace_id:
  315. query_set = query_set.filter(workspace_id=workspace_id)
  316. if not query_set.exists():
  317. raise AppApiException(500, _('Knowledge id does not exist'))
  318. def get_query_set(self):
  319. self.is_valid()
  320. query_set = QuerySet(model=Paragraph)
  321. query_set = query_set.filter(
  322. **{'knowledge_id': self.data.get('knowledge_id'), 'document_id': self.data.get("document_id")})
  323. if 'title' in self.data:
  324. query_set = query_set.filter(
  325. **{'title__icontains': self.data.get('title')})
  326. if 'content' in self.data:
  327. query_set = query_set.filter(**{'content__icontains': self.data.get('content')})
  328. query_set = query_set.order_by('position', 'create_time')
  329. return query_set
  330. def list(self):
  331. return list(map(lambda row: ParagraphSerializer(row).data, self.get_query_set()))
  332. def page(self, current_page, page_size):
  333. query_set = self.get_query_set()
  334. return page_search(current_page, page_size, query_set, lambda row: ParagraphSerializer(row).data)
  335. class Association(serializers.Serializer):
  336. workspace_id = serializers.CharField(required=True, label=_('workspace id'))
  337. knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
  338. problem_id = serializers.UUIDField(required=True, label=_('problem id'))
  339. document_id = serializers.UUIDField(required=True, label=_('document id'))
  340. paragraph_id = serializers.UUIDField(required=True, label=_('paragraph id'))
  341. def is_valid(self, *, raise_exception=True):
  342. super().is_valid(raise_exception=True)
  343. knowledge_id = self.data.get('knowledge_id')
  344. paragraph_id = self.data.get('paragraph_id')
  345. problem_id = self.data.get("problem_id")
  346. workspace_id = self.data.get('workspace_id')
  347. query_set = QuerySet(Knowledge).filter(id=self.data.get('knowledge_id'))
  348. if workspace_id:
  349. query_set = query_set.filter(workspace_id=workspace_id)
  350. if not query_set.exists():
  351. raise AppApiException(500, _('Knowledge id does not exist'))
  352. if not QuerySet(Paragraph).filter(knowledge_id=knowledge_id, id=paragraph_id).exists():
  353. raise AppApiException(500, _('Paragraph does not exist'))
  354. if not QuerySet(Problem).filter(knowledge_id=knowledge_id, id=problem_id).exists():
  355. raise AppApiException(500, _('Problem does not exist'))
  356. def association(self, with_valid=True, with_embedding=True):
  357. if with_valid:
  358. self.is_valid(raise_exception=True)
  359. # 已关联则直接返回
  360. if QuerySet(ProblemParagraphMapping).filter(
  361. knowledge_id=self.data.get('knowledge_id'),
  362. document_id=self.data.get('document_id'),
  363. paragraph_id=self.data.get('paragraph_id'),
  364. problem_id=self.data.get('problem_id')
  365. ).exists():
  366. return True
  367. problem = QuerySet(Problem).filter(id=self.data.get("problem_id")).first()
  368. problem_paragraph_mapping = ProblemParagraphMapping(id=uuid.uuid7(),
  369. document_id=self.data.get('document_id'),
  370. paragraph_id=self.data.get('paragraph_id'),
  371. knowledge_id=self.data.get('knowledge_id'),
  372. problem_id=problem.id)
  373. problem_paragraph_mapping.save()
  374. if with_embedding:
  375. model_id = get_embedding_model_id_by_knowledge_id(self.data.get('knowledge_id'))
  376. embedding_by_problem({
  377. 'text': problem.content,
  378. 'is_active': True,
  379. 'source_type': SourceType.PROBLEM,
  380. 'source_id': problem_paragraph_mapping.id,
  381. 'document_id': self.data.get('document_id'),
  382. 'paragraph_id': self.data.get('paragraph_id'),
  383. 'knowledge_id': self.data.get('knowledge_id'),
  384. }, model_id)
  385. def un_association(self, with_valid=True):
  386. if with_valid:
  387. self.is_valid(raise_exception=True)
  388. problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(
  389. paragraph_id=self.data.get('paragraph_id'),
  390. knowledge_id=self.data.get('knowledge_id'),
  391. problem_id=self.data.get(
  392. 'problem_id')).first()
  393. problem_paragraph_mapping_id = problem_paragraph_mapping.id
  394. problem_paragraph_mapping.delete()
  395. delete_embedding_by_source(problem_paragraph_mapping_id)
  396. return True
  397. class Batch(serializers.Serializer):
  398. workspace_id = serializers.CharField(required=False, label=_('workspace id'))
  399. knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
  400. document_id = serializers.UUIDField(required=True, label=_('document id'))
  401. def is_valid(self, *, raise_exception=False):
  402. super().is_valid(raise_exception=True)
  403. workspace_id = self.data.get('workspace_id')
  404. query_set = QuerySet(Knowledge).filter(id=self.data.get('knowledge_id'))
  405. if workspace_id:
  406. query_set = query_set.filter(workspace_id=workspace_id)
  407. if not query_set.exists():
  408. raise AppApiException(500, _('Knowledge id does not exist'))
  409. @transaction.atomic
  410. def batch_delete(self, instance: Dict, with_valid=True):
  411. if with_valid:
  412. BatchSerializer(data=instance).is_valid(model=Paragraph, raise_exception=True)
  413. self.is_valid(raise_exception=True)
  414. paragraph_id_list = instance.get("id_list")
  415. QuerySet(Paragraph).filter(id__in=paragraph_id_list).delete()
  416. delete_problems_and_mappings(paragraph_id_list)
  417. update_document_char_length(self.data.get('document_id'))
  418. # 删除向量库
  419. delete_embedding_by_paragraph_ids(paragraph_id_list)
  420. return True
  421. def batch_generate_related(self, instance: Dict, with_valid=True):
  422. if with_valid:
  423. self.is_valid(raise_exception=True)
  424. paragraph_id_list = instance.get("paragraph_id_list")
  425. model_id = instance.get("model_id")
  426. prompt = instance.get("prompt")
  427. model_params_setting = instance.get("model_params_setting")
  428. document_id = self.data.get('document_id')
  429. ListenerManagement.update_status(
  430. QuerySet(Document).filter(id=document_id),
  431. TaskType.GENERATE_PROBLEM,
  432. State.PENDING
  433. )
  434. ListenerManagement.update_status(
  435. QuerySet(Paragraph).filter(id__in=paragraph_id_list),
  436. TaskType.GENERATE_PROBLEM,
  437. State.PENDING
  438. )
  439. ListenerManagement.get_aggregation_document_status(document_id)()
  440. try:
  441. generate_related_by_paragraph_id_list.delay(document_id, paragraph_id_list, model_id, model_params_setting, prompt)
  442. except AlreadyQueued as e:
  443. raise AppApiException(500, _('The task is being executed, please do not send it again.'))
  444. class Migrate(serializers.Serializer):
  445. workspace_id = serializers.CharField(required=True, label=_('workspace id'))
  446. knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
  447. document_id = serializers.UUIDField(required=True, label=_('document id'))
  448. target_knowledge_id = serializers.UUIDField(required=True, label=_('target knowledge id'))
  449. target_document_id = serializers.UUIDField(required=True, label=_('target document id'))
  450. paragraph_id_list = serializers.ListField(required=True, label=_('paragraph id list'),
  451. child=serializers.UUIDField(required=True, label=_('paragraph id')))
  452. def is_valid(self, *, raise_exception=False):
  453. super().is_valid(raise_exception=True)
  454. workspace_id = self.data.get('workspace_id')
  455. query_set = QuerySet(Knowledge).filter(id=self.data.get('knowledge_id'))
  456. if workspace_id:
  457. query_set = query_set.filter(workspace_id=workspace_id)
  458. if not query_set.exists():
  459. raise AppApiException(500, _('Knowledge id does not exist'))
  460. document_list = QuerySet(Document).filter(
  461. id__in=[self.data.get('document_id'), self.data.get('target_document_id')])
  462. document_id = self.data.get('document_id')
  463. target_document_id = self.data.get('target_document_id')
  464. if document_id == target_document_id:
  465. raise AppApiException(5000, _('The document to be migrated is consistent with the target document'))
  466. if len([document for document in document_list if str(document.id) == self.data.get('document_id')]) < 1:
  467. raise AppApiException(5000, _('The document id does not exist [{document_id}]').format(
  468. document_id=self.data.get('document_id')))
  469. if len([document for document in document_list if
  470. str(document.id) == self.data.get('target_document_id')]) < 1:
  471. raise AppApiException(5000, _('The target document id does not exist [{document_id}]').format(
  472. document_id=self.data.get('target_document_id')))
  473. @transaction.atomic
  474. def migrate(self, with_valid=True):
  475. if with_valid:
  476. self.is_valid(raise_exception=True)
  477. knowledge_id = self.data.get('knowledge_id')
  478. target_knowledge_id = self.data.get('target_knowledge_id')
  479. document_id = self.data.get('document_id')
  480. target_document_id = self.data.get('target_document_id')
  481. paragraph_id_list = self.data.get('paragraph_id_list')
  482. paragraph_list = QuerySet(Paragraph).filter(knowledge_id=knowledge_id, document_id=document_id,
  483. id__in=paragraph_id_list)
  484. problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(paragraph__in=paragraph_list)
  485. # 同数据集迁移
  486. if target_knowledge_id == knowledge_id:
  487. if len(problem_paragraph_mapping_list):
  488. problem_paragraph_mapping_list = [
  489. self.update_problem_paragraph_mapping(target_document_id,
  490. problem_paragraph_mapping) for problem_paragraph_mapping
  491. in
  492. problem_paragraph_mapping_list]
  493. # 修改mapping
  494. QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
  495. ['document_id'])
  496. update_embedding_document_id([paragraph.id for paragraph in paragraph_list],
  497. target_document_id, target_knowledge_id, None)
  498. # 修改段落信息
  499. paragraph_list.update(document_id=target_document_id)
  500. # 将当前文档中所有段落的位置向下移动,为新段落腾出空间
  501. Paragraph.objects.filter(document_id=target_document_id).exclude(
  502. id__in=paragraph_id_list
  503. ).update(position=F('position') + len(paragraph_id_list))
  504. # 重新查询迁移的段落
  505. paragraph_list = Paragraph.objects.filter(
  506. id__in=paragraph_id_list, document_id=target_document_id
  507. )
  508. # 将迁移的段落位置设置为从1开始的序号
  509. for i, paragraph in enumerate(paragraph_list):
  510. paragraph.position = i + 1
  511. paragraph.save()
  512. # 不同数据集迁移
  513. else:
  514. problem_list = QuerySet(Problem).filter(
  515. id__in=[problem_paragraph_mapping.problem_id for problem_paragraph_mapping in
  516. problem_paragraph_mapping_list])
  517. # 目标数据集问题
  518. target_problem_list = list(
  519. QuerySet(Problem).filter(content__in=[problem.content for problem in problem_list],
  520. knowledge_id=target_knowledge_id))
  521. target_handle_problem_list = [
  522. self.get_target_knowledge_problem(target_knowledge_id, target_document_id,
  523. problem_paragraph_mapping,
  524. problem_list, target_problem_list) for
  525. problem_paragraph_mapping
  526. in
  527. problem_paragraph_mapping_list]
  528. create_problem_list = [problem for problem, is_create in target_handle_problem_list if
  529. is_create is not None and is_create]
  530. # 插入问题
  531. QuerySet(Problem).bulk_create(create_problem_list)
  532. # 修改mapping
  533. QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
  534. ['problem_id', 'knowledge_id', 'document_id'])
  535. target_knowledge = QuerySet(Knowledge).filter(id=target_knowledge_id).first()
  536. knowledge = QuerySet(Knowledge).filter(id=knowledge_id).first()
  537. embedding_model_id = None
  538. if target_knowledge.embedding_model_id != knowledge.embedding_model_id:
  539. embedding_model_id = str(target_knowledge.embedding_model_id)
  540. pid_list = [paragraph.id for paragraph in paragraph_list]
  541. # 修改段落信息
  542. paragraph_list.update(knowledge_id=target_knowledge_id, document_id=target_document_id)
  543. # 将当前文档中所有段落的位置向下移动,为新段落腾出空间
  544. Paragraph.objects.filter(document_id=target_document_id).exclude(
  545. id__in=pid_list
  546. ).update(position=F('position') + len(pid_list))
  547. # 重新查询迁移的段落
  548. paragraph_list = Paragraph.objects.filter(
  549. id__in=pid_list, document_id=target_document_id
  550. )
  551. # 将迁移的段落位置设置为从1开始的序号
  552. for i, paragraph in enumerate(paragraph_list):
  553. paragraph.position = i + 1
  554. paragraph.save()
  555. # 修改向量段落信息
  556. update_embedding_document_id(pid_list, target_document_id, target_knowledge_id, embedding_model_id)
  557. update_document_char_length(document_id)
  558. update_document_char_length(target_document_id)
  559. @staticmethod
  560. def update_problem_paragraph_mapping(target_document_id: str, problem_paragraph_mapping):
  561. problem_paragraph_mapping.document_id = target_document_id
  562. return problem_paragraph_mapping
  563. @staticmethod
  564. def get_target_knowledge_problem(target_knowledge_id: str,
  565. target_document_id: str,
  566. problem_paragraph_mapping,
  567. source_problem_list,
  568. target_problem_list):
  569. source_problem_list = [source_problem for source_problem in source_problem_list if
  570. source_problem.id == problem_paragraph_mapping.problem_id]
  571. problem_paragraph_mapping.knowledge_id = target_knowledge_id
  572. problem_paragraph_mapping.document_id = target_document_id
  573. if len(source_problem_list) > 0:
  574. problem_content = source_problem_list[-1].content
  575. problem_list = [problem for problem in target_problem_list if problem.content == problem_content]
  576. if len(problem_list) > 0:
  577. problem = problem_list[-1]
  578. problem_paragraph_mapping.problem_id = problem.id
  579. return problem, False
  580. else:
  581. problem = Problem(id=uuid.uuid7(), knowledge_id=target_knowledge_id, content=problem_content)
  582. target_problem_list.append(problem)
  583. problem_paragraph_mapping.problem_id = problem.id
  584. return problem, True
  585. return None
  586. class AdjustPosition(serializers.Serializer):
  587. workspace_id = serializers.CharField(required=True, label=_('workspace id'))
  588. knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
  589. document_id = serializers.UUIDField(required=True, label=_('document id'))
  590. paragraph_id = serializers.UUIDField(required=True, label=_('paragraph id'))
  591. def is_valid(self, *, raise_exception=False):
  592. super().is_valid(raise_exception=True)
  593. workspace_id = self.data.get('workspace_id')
  594. query_set = QuerySet(Knowledge).filter(id=self.data.get('knowledge_id'))
  595. if workspace_id:
  596. query_set = query_set.filter(workspace_id=workspace_id)
  597. if not query_set.exists():
  598. raise AppApiException(500, _('Knowledge id does not exist'))
  599. @transaction.atomic
  600. def adjust_position(self, new_position):
  601. """
  602. 调整段落顺序
  603. :param new_position: 新的顺序值
  604. """
  605. self.is_valid(raise_exception=True)
  606. try:
  607. new_position = int(new_position)
  608. except (TypeError, ValueError):
  609. raise serializers.ValidationError(_('new_position must be an integer'))
  610. # 获取当前段落
  611. paragraph = Paragraph.objects.get(id=self.data.get('paragraph_id'))
  612. old_position = paragraph.position
  613. if old_position < new_position:
  614. # 如果新顺序在当前顺序之后,更新受影响段落的顺序
  615. Paragraph.objects.filter(
  616. position__gt=old_position, position__lte=new_position
  617. ).update(position=F('position') - 1)
  618. elif old_position > new_position:
  619. # 如果新顺序在当前顺序之前,更新受影响段落的顺序
  620. Paragraph.objects.filter(
  621. position__lt=old_position, position__gte=new_position
  622. ).update(position=F('position') + 1)
  623. # 更新当前段落的顺序
  624. paragraph.position = new_position
  625. paragraph.save()
  626. def delete_problems_and_mappings(paragraph_ids):
  627. problem_paragraph_mappings = ProblemParagraphMapping.objects.filter(paragraph_id__in=paragraph_ids)
  628. problem_ids = set(problem_paragraph_mappings.values_list('problem_id', flat=True))
  629. if problem_ids:
  630. problem_paragraph_mappings.delete()
  631. remaining_problem_counts = ProblemParagraphMapping.objects.filter(problem_id__in=problem_ids).values(
  632. 'problem_id').annotate(count=Count('problem_id'))
  633. remaining_problem_ids = {pc['problem_id'] for pc in remaining_problem_counts}
  634. problem_ids_to_delete = problem_ids - remaining_problem_ids
  635. Problem.objects.filter(id__in=problem_ids_to_delete).delete()
  636. else:
  637. problem_paragraph_mappings.delete()