document.py 88 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728
  1. import io
  2. import json
  3. import os
  4. import re
  5. import traceback
  6. from collections import defaultdict
  7. from functools import reduce
  8. from tempfile import TemporaryDirectory
  9. from typing import Dict, List
  10. import openpyxl
  11. import uuid_utils.compat as uuid
  12. from celery_once import AlreadyQueued
  13. from django.contrib.postgres.fields import JSONField
  14. from django.core import validators
  15. from django.db import transaction, models
  16. from django.db.models import QuerySet, Func, F, Value
  17. from django.db.models.aggregates import Max
  18. from django.db.models.functions import Substr, Reverse
  19. from django.db.models.query_utils import Q
  20. from django.http import HttpResponse
  21. from django.utils.translation import gettext_lazy as _, gettext, get_language, to_locale
  22. from openpyxl.cell.cell import ILLEGAL_CHARACTERS_RE
  23. from rest_framework import serializers
  24. from xlwt import Utils
  25. from common.db.search import native_search, get_dynamics_model, native_page_search
  26. from common.event.common import work_thread_pool
  27. from common.event.listener_manage import ListenerManagement
  28. from common.exception.app_exception import AppApiException
  29. from common.field.common import UploadedFileField
  30. from common.handle.impl.qa.csv_parse_qa_handle import CsvParseQAHandle
  31. from common.handle.impl.qa.md_parse_qa_handle import MarkdownParseQAHandle
  32. from common.handle.impl.qa.xls_parse_qa_handle import XlsParseQAHandle
  33. from common.handle.impl.qa.xlsx_parse_qa_handle import XlsxParseQAHandle
  34. from common.handle.impl.qa.zip_parse_qa_handle import ZipParseQAHandle
  35. from common.handle.impl.table.csv_parse_table_handle import CsvParseTableHandle
  36. from common.handle.impl.table.xls_parse_table_handle import XlsParseTableHandle
  37. from common.handle.impl.table.xlsx_parse_table_handle import XlsxParseTableHandle
  38. from common.handle.impl.text.csv_split_handle import CsvSplitHandle
  39. from common.handle.impl.text.doc_split_handle import DocSplitHandle
  40. from common.handle.impl.text.html_split_handle import HTMLSplitHandle
  41. from common.handle.impl.text.pdf_split_handle import PdfSplitHandle
  42. from common.handle.impl.text.text_split_handle import TextSplitHandle
  43. from common.handle.impl.text.xls_split_handle import XlsSplitHandle
  44. from common.handle.impl.text.xlsx_split_handle import XlsxSplitHandle
  45. from common.handle.impl.text.zip_split_handle import ZipSplitHandle
  46. from common.utils.common import post, get_file_content, bulk_create_in_batches, parse_image
  47. from common.utils.fork import Fork
  48. from common.utils.logger import maxkb_logger
  49. from common.utils.split_model import get_split_model, flat_map
  50. from knowledge.models import Knowledge, Paragraph, Problem, Document, KnowledgeType, ProblemParagraphMapping, State, \
  51. TaskType, File, FileSourceType, Tag, DocumentTag
  52. from knowledge.serializers.common import ProblemParagraphManage, BatchSerializer, \
  53. get_embedding_model_id_by_knowledge_id, MetaSerializer, write_image, zip_dir
  54. from knowledge.serializers.paragraph import ParagraphSerializers, ParagraphInstanceSerializer, \
  55. delete_problems_and_mappings
  56. from knowledge.task.embedding import embedding_by_document, delete_embedding_by_document_list, \
  57. delete_embedding_by_document, delete_embedding_by_paragraph_ids, embedding_by_document_list, \
  58. update_embedding_knowledge_id
  59. from knowledge.task.generate import generate_related_by_document_id
  60. from knowledge.task.sync import sync_web_document
  61. from maxkb.const import PROJECT_DIR
  62. from models_provider.models import Model
  63. from oss.serializers.file import FileSerializer
  64. default_split_handle = TextSplitHandle()
  65. split_handles = [
  66. HTMLSplitHandle(),
  67. DocSplitHandle(),
  68. PdfSplitHandle(),
  69. XlsxSplitHandle(),
  70. XlsSplitHandle(),
  71. CsvSplitHandle(),
  72. ZipSplitHandle(),
  73. default_split_handle
  74. ]
  75. md_qa_split_handle = MarkdownParseQAHandle()
  76. parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle(), ZipParseQAHandle()]
  77. parse_table_handle_list = [CsvParseTableHandle(), XlsParseTableHandle(), XlsxParseTableHandle()]
  78. def convert_uuid_to_str(obj):
  79. if isinstance(obj, dict):
  80. return {k: convert_uuid_to_str(v) for k, v in obj.items()}
  81. elif isinstance(obj, list):
  82. return [convert_uuid_to_str(i) for i in obj]
  83. elif isinstance(obj, uuid.UUID):
  84. return str(obj)
  85. else:
  86. return obj
  87. class BatchCancelInstanceSerializer(serializers.Serializer):
  88. id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), label=_('id list'))
  89. type = serializers.IntegerField(required=True, label=_('task type'))
  90. def is_valid(self, *, raise_exception=False):
  91. super().is_valid(raise_exception=True)
  92. _type = self.data.get('type')
  93. try:
  94. TaskType(_type)
  95. except Exception as e:
  96. raise AppApiException(500, _('task type not support'))
  97. class DocumentInstanceSerializer(serializers.Serializer):
  98. name = serializers.CharField(required=True, label=_('document name'), max_length=128, min_length=1,
  99. source=_('document name'))
  100. paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True)
  101. source_file_id = serializers.UUIDField(required=False, allow_null=True, label=_('source file id'))
  102. class CancelInstanceSerializer(serializers.Serializer):
  103. type = serializers.IntegerField(required=True, label=_('task type'))
  104. def is_valid(self, *, raise_exception=False):
  105. super().is_valid(raise_exception=True)
  106. _type = self.data.get('type')
  107. try:
  108. TaskType(_type)
  109. except Exception as e:
  110. raise AppApiException(500, _('task type not support'))
  111. class DocumentEditInstanceSerializer(serializers.Serializer):
  112. meta = serializers.DictField(required=False)
  113. name = serializers.CharField(required=False, max_length=128, min_length=1, label=_('document name'),
  114. source=_('document name'))
  115. hit_handling_method = serializers.CharField(required=False, validators=[
  116. validators.RegexValidator(regex=re.compile("^optimization|directly_return$"),
  117. message=_('The type only supports optimization|directly_return'),
  118. code=500)
  119. ], label=_('hit handling method'))
  120. directly_return_similarity = serializers.FloatField(required=False, max_value=2, min_value=0,
  121. label=_('directly return similarity'))
  122. is_active = serializers.BooleanField(required=False, label=_('document is active'))
  123. @staticmethod
  124. def get_meta_valid_map():
  125. knowledge_meta_valid_map = {
  126. KnowledgeType.BASE: MetaSerializer.BaseMeta,
  127. KnowledgeType.WEB: MetaSerializer.WebMeta
  128. }
  129. return knowledge_meta_valid_map
  130. def is_valid(self, *, document: Document = None):
  131. super().is_valid(raise_exception=True)
  132. if 'meta' in self.data and self.data.get('meta') is not None and self.data.get('meta') != {}:
  133. knowledge_meta_valid_map = self.get_meta_valid_map()
  134. valid_class = knowledge_meta_valid_map.get(document.type)
  135. if valid_class is not None:
  136. valid_class(data=self.data.get('meta')).is_valid(raise_exception=True)
  137. class DocumentSplitRequest(serializers.Serializer):
  138. file = serializers.ListField(required=True, label=_('file list'))
  139. limit = serializers.IntegerField(required=False, label=_('limit'))
  140. patterns = serializers.ListField(
  141. required=False,
  142. child=serializers.CharField(required=True, label=_('patterns')),
  143. label=_('patterns')
  144. )
  145. with_filter = serializers.BooleanField(required=False, label=_('Auto Clean'))
  146. class DocumentWebInstanceSerializer(serializers.Serializer):
  147. source_url_list = serializers.ListField(required=True, label=_('document url list'),
  148. child=serializers.CharField(required=True, label=_('document url list')))
  149. selector = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('selector'))
  150. class DocumentInstanceQASerializer(serializers.Serializer):
  151. file_list = serializers.ListSerializer(required=True, label=_('file list'),
  152. child=serializers.FileField(required=True, label=_('file')))
  153. class DocumentInstanceTableSerializer(serializers.Serializer):
  154. file_list = serializers.ListSerializer(required=True, label=_('file list'),
  155. child=serializers.FileField(required=True, label=_('file')))
  156. class DocumentRefreshSerializer(serializers.Serializer):
  157. state_list = serializers.ListField(required=True, label=_('state list'))
  158. class DocumentBatchRefreshSerializer(serializers.Serializer):
  159. id_list = serializers.ListField(required=True, label=_('id list'))
  160. state_list = serializers.ListField(required=True, label=_('state list'))
  161. class DocumentBatchGenerateRelatedSerializer(serializers.Serializer):
  162. document_id_list = serializers.ListField(required=True, label=_('document id list'))
  163. model_id = serializers.UUIDField(required=True, label=_('model id'))
  164. prompt = serializers.CharField(required=True, label=_('prompt'))
  165. state_list = serializers.ListField(required=True, label=_('state list'))
  166. class DocumentMigrateSerializer(serializers.Serializer):
  167. document_id_list = serializers.ListField(required=True, label=_('document id list'))
  168. class BatchEditHitHandlingSerializer(serializers.Serializer):
  169. id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), label=_('id list'))
  170. hit_handling_method = serializers.CharField(required=True, label=_('hit handling method'))
  171. directly_return_similarity = serializers.FloatField(required=False, max_value=2, min_value=0,
  172. label=_('directly return similarity'))
  173. def is_valid(self, *, raise_exception=False):
  174. super().is_valid(raise_exception=True)
  175. if self.data.get('hit_handling_method') not in ['optimization', 'directly_return']:
  176. raise AppApiException(500, _('The type only supports optimization|directly_return'))
  177. class DocumentSerializers(serializers.Serializer):
  178. class Export(serializers.Serializer):
  179. type = serializers.CharField(required=True, validators=[
  180. validators.RegexValidator(regex=re.compile("^csv|excel$"),
  181. message=_('The template type only supports excel|csv'),
  182. code=500)
  183. ], label=_('type'))
  184. def export(self, with_valid=True):
  185. if with_valid:
  186. self.is_valid(raise_exception=True)
  187. language = get_language()
  188. if self.data.get('type') == 'csv':
  189. file = open(
  190. os.path.join(PROJECT_DIR, "apps", "knowledge", 'template',
  191. f'csv_template_{to_locale(language)}.csv'),
  192. "rb")
  193. content = file.read()
  194. file.close()
  195. return HttpResponse(content, status=200, headers={'Content-Type': 'text/csv',
  196. 'Content-Disposition': 'attachment; filename="csv_template.csv"'})
  197. elif self.data.get('type') == 'excel':
  198. file = open(os.path.join(PROJECT_DIR, "apps", "knowledge", 'template',
  199. f'excel_template_{to_locale(language)}.xlsx'), "rb")
  200. content = file.read()
  201. file.close()
  202. return HttpResponse(content, status=200, headers={'Content-Type': 'application/vnd.ms-excel',
  203. 'Content-Disposition': 'attachment; filename="excel_template.xlsx"'})
  204. else:
  205. return None
  206. def table_export(self, with_valid=True):
  207. if with_valid:
  208. self.is_valid(raise_exception=True)
  209. language = get_language()
  210. if self.data.get('type') == 'csv':
  211. file = open(
  212. os.path.join(PROJECT_DIR, "apps", "knowledge", 'template',
  213. f'table_template_{to_locale(language)}.csv'),
  214. "rb")
  215. content = file.read()
  216. file.close()
  217. return HttpResponse(content, status=200, headers={'Content-Type': 'text/csv',
  218. 'Content-Disposition': 'attachment; filename="csv_template.csv"'})
  219. elif self.data.get('type') == 'excel':
  220. file = open(os.path.join(PROJECT_DIR, "apps", "knowledge", 'template',
  221. f'table_template_{to_locale(language)}.xlsx'),
  222. "rb")
  223. content = file.read()
  224. file.close()
  225. return HttpResponse(content, status=200, headers={'Content-Type': 'application/vnd.ms-excel',
  226. 'Content-Disposition': 'attachment; filename="excel_template.xlsx"'})
  227. else:
  228. return None
  229. class Migrate(serializers.Serializer):
  230. workspace_id = serializers.CharField(required=True, label=_('workspace id'))
  231. knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
  232. target_knowledge_id = serializers.UUIDField(required=True, label=_('target knowledge id'))
  233. document_id_list = serializers.ListField(required=True, label=_('document list'),
  234. child=serializers.UUIDField(required=True, label=_('document id')))
  235. def is_valid(self, *, raise_exception=False):
  236. super().is_valid(raise_exception=True)
  237. workspace_id = self.data.get('workspace_id')
  238. query_set = QuerySet(Knowledge).filter(id=self.data.get('knowledge_id'))
  239. if workspace_id:
  240. query_set = query_set.filter(workspace_id=workspace_id)
  241. if not query_set.exists():
  242. raise AppApiException(500, _('Knowledge id does not exist'))
  243. query_set = QuerySet(Knowledge).filter(id=self.data.get('target_knowledge_id'))
  244. if workspace_id:
  245. query_set = query_set.filter(workspace_id=workspace_id)
  246. if not query_set.exists():
  247. raise AppApiException(500, _('Knowledge id does not exist'))
  248. @transaction.atomic
  249. def migrate(self, with_valid=True):
  250. if with_valid:
  251. self.is_valid(raise_exception=True)
  252. knowledge_id = self.data.get('knowledge_id')
  253. target_knowledge_id = self.data.get('target_knowledge_id')
  254. knowledge = QuerySet(Knowledge).filter(id=knowledge_id).first()
  255. target_knowledge = QuerySet(Knowledge).filter(id=target_knowledge_id).first()
  256. document_id_list = self.data.get('document_id_list')
  257. document_list = QuerySet(Document).filter(knowledge_id=knowledge_id, id__in=document_id_list)
  258. paragraph_list = QuerySet(Paragraph).filter(knowledge_id=knowledge_id, document_id__in=document_id_list)
  259. problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(paragraph__in=paragraph_list)
  260. problem_list = QuerySet(Problem).filter(
  261. id__in=[problem_paragraph_mapping.problem_id for problem_paragraph_mapping in
  262. problem_paragraph_mapping_list])
  263. target_problem_list = list(
  264. QuerySet(Problem).filter(content__in=[problem.content for problem in problem_list],
  265. knowledge_id=target_knowledge_id))
  266. target_handle_problem_list = [
  267. self.get_target_knowledge_problem(target_knowledge_id, problem_paragraph_mapping,
  268. problem_list, target_problem_list) for
  269. problem_paragraph_mapping
  270. in
  271. problem_paragraph_mapping_list]
  272. create_problem_list = [problem for problem, is_create in target_handle_problem_list if
  273. is_create is not None and is_create]
  274. # 插入问题
  275. QuerySet(Problem).bulk_create(create_problem_list)
  276. # 修改mapping
  277. QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
  278. ['problem_id', 'knowledge_id'])
  279. # 修改文档
  280. if knowledge.type == KnowledgeType.BASE.value and target_knowledge.type == KnowledgeType.WEB.value:
  281. document_list.update(knowledge_id=target_knowledge_id, type=KnowledgeType.WEB,
  282. meta={'source_url': '', 'selector': ''})
  283. elif target_knowledge.type == KnowledgeType.BASE.value and knowledge.type == KnowledgeType.WEB.value:
  284. document_list.update(knowledge_id=target_knowledge_id, type=KnowledgeType.BASE,
  285. meta={})
  286. else:
  287. document_list.update(knowledge_id=target_knowledge_id)
  288. model_id = None
  289. if knowledge.embedding_model_id != target_knowledge.embedding_model_id:
  290. model_id = get_embedding_model_id_by_knowledge_id(target_knowledge_id)
  291. pid_list = [paragraph.id for paragraph in paragraph_list]
  292. # 修改段落信息
  293. paragraph_list.update(knowledge_id=target_knowledge_id)
  294. # 修改向量信息
  295. if model_id:
  296. delete_embedding_by_paragraph_ids(pid_list)
  297. ListenerManagement.update_status(QuerySet(Document).filter(id__in=document_id_list),
  298. TaskType.EMBEDDING,
  299. State.PENDING)
  300. ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id__in=document_id_list),
  301. TaskType.EMBEDDING,
  302. State.PENDING)
  303. ListenerManagement.get_aggregation_document_status_by_query_set(
  304. QuerySet(Document).filter(id__in=document_id_list))()
  305. embedding_by_document_list.delay(document_id_list, model_id)
  306. else:
  307. update_embedding_knowledge_id(pid_list, target_knowledge_id)
  308. @staticmethod
  309. def get_target_knowledge_problem(target_knowledge_id: str,
  310. problem_paragraph_mapping,
  311. source_problem_list,
  312. target_problem_list):
  313. source_problem_list = [source_problem for source_problem in source_problem_list if
  314. source_problem.id == problem_paragraph_mapping.problem_id]
  315. problem_paragraph_mapping.knowledge_id = target_knowledge_id
  316. if len(source_problem_list) > 0:
  317. problem_content = source_problem_list[-1].content
  318. problem_list = [problem for problem in target_problem_list if problem.content == problem_content]
  319. if len(problem_list) > 0:
  320. problem = problem_list[-1]
  321. problem_paragraph_mapping.problem_id = problem.id
  322. return problem, False
  323. else:
  324. problem = Problem(id=uuid.uuid7(), knowledge_id=target_knowledge_id, content=problem_content)
  325. target_problem_list.append(problem)
  326. problem_paragraph_mapping.problem_id = problem.id
  327. return problem, True
  328. return None
  329. class Query(serializers.Serializer):
  330. # 知识库id
  331. workspace_id = serializers.CharField(required=True, label=_('workspace id'))
  332. knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
  333. name = serializers.CharField(
  334. required=False, max_length=128, min_length=1, allow_null=True, allow_blank=True, label=_('document name')
  335. )
  336. hit_handling_method = serializers.CharField(
  337. required=False, label=_('hit handling method'), allow_null=True, allow_blank=True
  338. )
  339. is_active = serializers.BooleanField(required=False, label=_('document is active'), allow_null=True)
  340. task_type = serializers.IntegerField(required=False, label=_('task type'))
  341. status = serializers.CharField(required=False, label=_('status'), allow_null=True, allow_blank=True)
  342. order_by = serializers.CharField(required=False, label=_('order by'), allow_null=True, allow_blank=True)
  343. tag = serializers.CharField(required=False, label=_('tag'), allow_null=True, allow_blank=True)
  344. tag_ids = serializers.ListField(child=serializers.UUIDField(), allow_null=True, required=False,
  345. allow_empty=True)
  346. no_tag = serializers.BooleanField(required=False, default=False, allow_null=True)
  347. tag_exclude = serializers.BooleanField(required=False, default=False, allow_null=True)
  348. def get_query_set(self):
  349. query_set = QuerySet(model=Document)
  350. query_set = query_set.filter(**{'knowledge_id': self.data.get("knowledge_id")})
  351. tag_ids = self.data.get('tag_ids')
  352. no_tag = self.data.get('no_tag')
  353. tag_exclude = self.data.get('tag_exclude')
  354. if 'name' in self.data and self.data.get('name') is not None:
  355. query_set = query_set.filter(**{'name__icontains': self.data.get('name')})
  356. if 'hit_handling_method' in self.data and self.data.get('hit_handling_method') not in [None, '']:
  357. query_set = query_set.filter(**{'hit_handling_method': self.data.get('hit_handling_method')})
  358. if 'is_active' in self.data and self.data.get('is_active') is not None:
  359. query_set = query_set.filter(**{'is_active': self.data.get('is_active')})
  360. if no_tag and tag_ids:
  361. matched_doc_ids = QuerySet(DocumentTag).filter(tag_id__in=tag_ids).values_list('document_id', flat=True)
  362. tagged_doc_ids = QuerySet(DocumentTag).values_list('document_id', flat=True)
  363. query_set = query_set.filter(
  364. Q(id__in=matched_doc_ids) | ~Q(id__in=tagged_doc_ids)
  365. )
  366. elif no_tag:
  367. tagged_doc_ids = QuerySet(DocumentTag).values_list('document_id', flat=True)
  368. query_set = query_set.exclude(id__in=tagged_doc_ids)
  369. elif tag_ids:
  370. matched_doc_ids = QuerySet(DocumentTag).filter(tag_id__in=tag_ids).values_list('document_id', flat=True)
  371. if tag_exclude:
  372. query_set = query_set.exclude(id__in=matched_doc_ids)
  373. else:
  374. query_set = query_set.filter(id__in=matched_doc_ids)
  375. if 'status' in self.data and self.data.get('status') is not None:
  376. task_type = self.data.get('task_type')
  377. status = self.data.get('status')
  378. if task_type is not None:
  379. query_set = query_set.annotate(
  380. reversed_status=Reverse('status'),
  381. task_type_status=Substr('reversed_status', TaskType(task_type).value, 1),
  382. ).filter(
  383. task_type_status=State(status).value
  384. ).values('id')
  385. else:
  386. if status != State.SUCCESS.value:
  387. query_set = query_set.filter(status__icontains=status)
  388. else:
  389. query_set = query_set.filter(status__iregex='^[2n]*$')
  390. if 'tag' in self.data and self.data.get('tag') not in [None, '']:
  391. tag_name = self.data.get('tag')
  392. document_id_list = QuerySet(DocumentTag).filter(
  393. Q(tag__key__icontains=tag_name) | Q(tag__value__icontains=tag_name)
  394. ).values_list('document_id', flat=True)
  395. query_set = query_set.filter(id__in=document_id_list)
  396. order_by = self.data.get('order_by', '')
  397. order_by_query_set = QuerySet(model=get_dynamics_model(
  398. {'char_length': models.CharField(), 'paragraph_count': models.IntegerField(),
  399. "update_time": models.IntegerField(), 'create_time': models.DateTimeField()}))
  400. if order_by:
  401. order_by_query_set = order_by_query_set.order_by(order_by)
  402. else:
  403. order_by_query_set = order_by_query_set.order_by('-create_time', 'id')
  404. return {
  405. 'document_custom_sql': query_set,
  406. 'order_by_query': order_by_query_set
  407. }
  408. def list(self):
  409. self.is_valid(raise_exception=True)
  410. query_set = self.get_query_set()
  411. return native_search(query_set, select_string=get_file_content(
  412. os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_document.sql')))
  413. def page(self, current_page, page_size):
  414. self.is_valid(raise_exception=True)
  415. query_set = self.get_query_set()
  416. return native_page_search(current_page, page_size, query_set, select_string=get_file_content(
  417. os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_document.sql')))
  418. class Sync(serializers.Serializer):
  419. workspace_id = serializers.CharField(required=False, label=_('workspace id'))
  420. knowledge_id = serializers.UUIDField(required=False, label=_('knowledge id'))
  421. document_id = serializers.UUIDField(required=True, label=_('document id'))
  422. def is_valid(self, *, raise_exception=False):
  423. super().is_valid(raise_exception=True)
  424. workspace_id = self.data.get('workspace_id')
  425. query_set = QuerySet(Knowledge).filter(id=self.data.get('knowledge_id'))
  426. if workspace_id:
  427. query_set = query_set.filter(workspace_id=workspace_id)
  428. if not query_set.exists():
  429. raise AppApiException(500, _('Knowledge id does not exist'))
  430. document_id = self.data.get('document_id')
  431. first = QuerySet(Document).filter(id=document_id).first()
  432. if first is None:
  433. raise AppApiException(500, _('document id not exist'))
  434. if first.type != KnowledgeType.WEB:
  435. raise AppApiException(500, _('Synchronization is only supported for web site types'))
  @transaction.atomic
  def sync(self, with_valid=True, with_embedding=True):
      """
      Re-fetch a web document and rebuild its paragraphs, problems and vectors.

      Documents whose type is not ``KnowledgeType.WEB`` are skipped.

      :param with_valid:     run serializer validation first
      :param with_embedding: queue the embedding task once the paragraphs are rebuilt
      :return: always True; success or failure is recorded as the SYNC task status
      """
      if with_valid:
          self.is_valid(raise_exception=True)
      document_id = self.data.get('document_id')
      document = QuerySet(Document).filter(id=document_id).first()
      if document.type != KnowledgeType.WEB:
          return True
      state = State.SUCCESS
      # NOTE(review): exceptions are swallowed inside @transaction.atomic; after a database
      # error the status updates at the end may fail too - confirm this is acceptable.
      try:
          ListenerManagement.update_status(
              QuerySet(Document).filter(id=document_id), TaskType.SYNC, State.PENDING
          )
          ListenerManagement.get_aggregation_document_status(document_id)()
          source_url = document.meta.get('source_url')
          selector = document.meta.get('selector')
          # The selector is stored as one space separated string
          selector_list = [] if selector is None else selector.split(" ")
          result = Fork(source_url, selector_list).fork()
          if result.status != 200:
              state = State.FAILURE
          else:
              # Remove the old paragraphs
              QuerySet(model=Paragraph).filter(document_id=document_id).delete()
              # Remove the old problem mappings
              QuerySet(model=ProblemParagraphMapping).filter(document_id=document_id).delete()
              # NOTE(review): Operate.delete passes paragraph ids to this helper, here it
              # receives the document id - verify which one the helper expects.
              delete_problems_and_mappings([document_id])
              # Remove the old vectors
              delete_embedding_by_document(document_id)
              paragraphs = get_split_model('web.md').parse(result.content)
              char_length = sum([len(p.get('content')) for p in paragraphs])
              QuerySet(Document).filter(id=document_id).update(char_length=char_length)
              models = DocumentSerializers.Create.get_paragraph_model(document, paragraphs)
              paragraph_model_list = models.get('paragraph_model_list')
              problem_paragraph_object_list = models.get('problem_paragraph_object_list')
              problem_model_list, problem_paragraph_mapping_list = ProblemParagraphManage(
                  problem_paragraph_object_list, document.knowledge_id).to_problem_model_list()
              # Bulk insert the paragraphs, numbering them after the highest stored position
              if len(paragraph_model_list) > 0:
                  aggregated = Paragraph.objects.filter(document_id=document_id).aggregate(
                      max_position=Max('position')
                  )
                  max_position = aggregated['max_position'] or 0
                  for position, paragraph in enumerate(paragraph_model_list, start=max_position + 1):
                      paragraph.position = position
                  QuerySet(Paragraph).bulk_create(paragraph_model_list)
              # Bulk insert the problems
              if len(problem_model_list) > 0:
                  QuerySet(Problem).bulk_create(problem_model_list)
              # Bulk insert the problem/paragraph links
              if len(problem_paragraph_mapping_list) > 0:
                  QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list)
              # Vectorize
              if with_embedding:
                  embedding_model_id = get_embedding_model_id_by_knowledge_id(document.knowledge_id)
                  ListenerManagement.update_status(
                      QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING, State.PENDING
                  )
                  ListenerManagement.update_status(
                      QuerySet(Paragraph).filter(document_id=document_id), TaskType.EMBEDDING, State.PENDING
                  )
                  ListenerManagement.get_aggregation_document_status(document_id)()
                  embedding_by_document.delay(document_id, embedding_model_id)
      except Exception as e:
          maxkb_logger.error(f'{str(e)}:{traceback.format_exc()}')
          state = State.FAILURE
      # Record the final SYNC state on the document and on its paragraphs
      for target in (
              QuerySet(Document).filter(id=document_id),
              QuerySet(Paragraph).filter(document_id=document_id),
      ):
          ListenerManagement.update_status(target, TaskType.SYNC, state)
      ListenerManagement.get_aggregation_document_status(document_id)()
      return True
  513. class Operate(serializers.Serializer):
  514. workspace_id = serializers.CharField(required=False, label=_('workspace id'), allow_blank=True)
  515. document_id = serializers.UUIDField(required=True, label=_('document id'))
  516. knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
  def is_valid(self, *, raise_exception=False):
      """
      Validate the fields, then check that the knowledge base and the document exist.

      Failures always raise; ``raise_exception`` is accepted for signature
      compatibility only.

      :return: True when every check passed
      :raises AppApiException: the knowledge base (scoped to the workspace when one
                               is given) or the document does not exist
      """
      super().is_valid(raise_exception=True)
      workspace_id = self.data.get('workspace_id')
      query_set = QuerySet(Knowledge).filter(id=self.data.get('knowledge_id'))
      # The literal string 'None' is treated like a missing workspace id
      if workspace_id and workspace_id != 'None':
          query_set = query_set.filter(workspace_id=workspace_id)
      if not query_set.exists():
          raise AppApiException(500, _('Knowledge id does not exist'))
      document_id = self.data.get('document_id')
      if not QuerySet(Document).filter(id=document_id).exists():
          raise AppApiException(500, _('document id not exist'))
      # Honour the boolean contract of Serializer.is_valid instead of returning None
      return True
  def export(self, with_valid=True):
      """
      Export the document's paragraphs and their problems as an Excel download.

      :param with_valid: run serializer validation first
      :return: HttpResponse carrying the workbook as attachment ``data.xlsx``
      """
      if with_valid:
          self.is_valid(raise_exception=True)
      document_id = self.data.get("document_id")
      document = QuerySet(Document).filter(id=document_id).first()
      sql_dir = os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql')

      # Paragraphs in display order
      ordered_paragraphs = QuerySet(Paragraph).filter(document_id=document_id).order_by('position')
      paragraph_list = native_search(
          ordered_paragraphs,
          get_file_content(os.path.join(sql_dir, 'list_paragraph_document_name.sql'))
      )
      # Problems linked to those paragraphs
      problem_mapping_list = native_search(
          QuerySet(ProblemParagraphMapping).filter(document_id=document_id),
          get_file_content(os.path.join(sql_dir, 'list_problem_mapping.sql')),
          with_table_name=True
      )
      data_dict, document_dict = self.merge_problem(paragraph_list, problem_mapping_list, [document])
      workbook = self.get_workbook(data_dict, document_dict)

      response = HttpResponse(content_type='application/vnd.ms-excel')
      response['Content-Disposition'] = 'attachment; filename="data.xlsx"'
      workbook.save(response)
      return response
  def export_zip(self, with_valid=True):
      """
      Export the document as a zip archive: the workbook plus the images that
      ``parse_image`` finds in the paragraph contents.

      :param with_valid: run serializer validation first
      :return: HttpResponse carrying the archive, named after the document
      """
      if with_valid:
          self.is_valid(raise_exception=True)
      document_id = self.data.get("document_id")
      document = QuerySet(Document).filter(id=document_id).first()
      sql_dir = os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql')

      ordered_paragraphs = QuerySet(Paragraph).filter(document_id=document_id).order_by('position')
      paragraph_list = native_search(
          ordered_paragraphs,
          get_file_content(os.path.join(sql_dir, 'list_paragraph_document_name.sql'))
      )
      problem_mapping_list = native_search(
          QuerySet(ProblemParagraphMapping).filter(document_id=document_id),
          get_file_content(os.path.join(sql_dir, 'list_problem_mapping.sql')),
          with_table_name=True
      )
      data_dict, document_dict = self.merge_problem(paragraph_list, problem_mapping_list, [document])
      image_results = [parse_image(paragraph.get('content')) for paragraph in paragraph_list]
      workbook = DocumentSerializers.Operate.get_workbook(data_dict, document_dict)

      response = HttpResponse(content_type='application/zip')
      archive_name = document.name.strip()
      response['Content-Disposition'] = f'attachment; filename="{archive_name}.zip"'

      # Stage everything in a temporary directory, then zip it into memory
      zip_buffer = io.BytesIO()
      with TemporaryDirectory() as tempdir:
          workbook.save(os.path.join(tempdir, 'document.xlsx'))
          for image_result in image_results:
              write_image(tempdir, image_result)
          zip_dir(tempdir, zip_buffer)
      response.write(zip_buffer.getvalue())
      return response
  def download_source_file(self):
      """
      Return the stored source file of the document.

      :raises AppApiException: no File row has this document as its source
      """
      self.is_valid(raise_exception=True)
      source_file = QuerySet(File).filter(source_id=self.data.get('document_id')).first()
      if source_file:
          return FileSerializer.Operate(data={'id': source_file.id}).get(with_valid=True)
      raise AppApiException(500, _('File not exist. Only manually uploaded documents are supported'))
  def one(self, with_valid=False):
      """
      Load the document through ``list_document.sql`` as a single record.

      Validation always runs; ``with_valid`` is not consulted.
      """
      self.is_valid(raise_exception=True)
      select_sql = get_file_content(
          os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_document.sql'))
      search_queries = {
          'document_custom_sql': QuerySet(model=Document).filter(id=self.data.get("document_id")),
          'order_by_query': QuerySet(Document).order_by('-create_time', 'id'),
      }
      return native_search(search_queries, select_string=select_sql, with_search_one=True)
  def edit(self, instance: Dict, with_valid=False):
      """
      Update the editable attributes of the document.

      Only the whitelisted keys are copied from ``instance``, and only when
      their value is not None.

      :param instance:   new attribute values
      :param with_valid: validate this serializer and the edit payload first
      :return: the refreshed document as returned by ``one()``
      """
      if with_valid:
          self.is_valid(raise_exception=True)
      _document = QuerySet(Document).get(id=self.data.get("document_id"))
      if with_valid:
          DocumentEditInstanceSerializer(data=instance).is_valid(document=_document)
      editable_keys = ('name', 'is_active', 'hit_handling_method', 'directly_return_similarity', 'meta')
      for key in editable_keys:
          if key not in instance:
              continue
          new_value = instance.get(key)
          if new_value is not None:
              setattr(_document, key, new_value)
      _document.save()
      return self.one()
  def cancel(self, instance, with_valid=True):
      """
      Mark the pending or started task of the given type as REVOKE, on the
      document's paragraphs first and then on the document itself.

      :param instance:   payload; ``instance['type']`` is the TaskType value
      :param with_valid: validate this serializer and the payload first
      :return: True
      """
      if with_valid:
          self.is_valid(raise_exception=True)
          # Raise on an invalid payload, as batch_cancel does, instead of ignoring the result
          CancelInstanceSerializer(data=instance).is_valid(raise_exception=True)
      document_id = self.data.get("document_id")
      task_type = TaskType(instance.get('type'))
      cancellable_states = [State.PENDING.value, State.STARTED.value]
      targets = (
          (Paragraph, {'document_id': document_id}),
          (Document, {'id': document_id}),
      )
      for model, scope in targets:
          # The state of this task type is read as one character of the reversed status string
          ListenerManagement.update_status(
              QuerySet(model).annotate(
                  reversed_status=Reverse('status'),
                  task_type_status=Substr('reversed_status', task_type.value, 1),
              ).filter(
                  task_type_status__in=cancellable_states
              ).filter(
                  **scope
              ).values('id'),
              task_type,
              State.REVOKE
          )
      return True
  @transaction.atomic
  def delete(self):
      """
      Delete the document together with its files, problems, paragraphs,
      vectors and tags.

      :return: True
      """
      self.is_valid(raise_exception=True)
      document_id = self.data.get("document_id")

      # Files: the one referenced from the document meta and those owned by the document
      meta_rows = Document.objects.filter(id=document_id).values("meta")
      source_file_ids = [row['meta'].get('source_file_id') for row in meta_rows]
      QuerySet(File).filter(id__in=source_file_ids).delete()
      QuerySet(File).filter(source_id=document_id, source_type=FileSourceType.DOCUMENT).delete()

      paragraph_query = QuerySet(model=Paragraph).filter(document_id=document_id)
      # Remove the problems
      delete_problems_and_mappings(paragraph_query.values_list("id", flat=True))
      # Remove the paragraphs
      paragraph_query.delete()
      # Remove the vectors
      delete_embedding_by_document(document_id)

      QuerySet(model=DocumentTag).filter(document_id=document_id).delete()
      QuerySet(model=Document).filter(id=document_id).delete()
      return True
  def refresh(self, state_list=None, with_valid=True):
      """
      Queue the document for (re-)embedding.

      The document, and every paragraph whose EMBEDDING state is in
      ``state_list``, is set to PENDING before the task is submitted.

      :param state_list: paragraph embedding states to refresh; defaults to all states
      :param with_valid: run serializer validation first
      :raises AppApiException: the embedding model does not exist, or the task
                               is already queued
      """
      if state_list is None:
          state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value,
                        State.REVOKE.value,
                        State.REVOKED.value, State.IGNORED.value]
      if with_valid:
          self.is_valid(raise_exception=True)
      knowledge = QuerySet(Knowledge).filter(id=self.data.get('knowledge_id')).first()
      embedding_model_id = knowledge.embedding_model_id
      embedding_model = QuerySet(Model).filter(id=embedding_model_id).first()
      if embedding_model is None:
          raise AppApiException(500, _('Model does not exist'))
      document_id = self.data.get("document_id")
      ListenerManagement.update_status(
          QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING, State.PENDING
      )
      # The embedding state is read as one character of the reversed status string
      ListenerManagement.update_status(
          QuerySet(Paragraph).annotate(
              reversed_status=Reverse('status'),
              task_type_status=Substr('reversed_status', TaskType.EMBEDDING.value, 1),
          ).filter(task_type_status__in=state_list, document_id=document_id).values('id'),
          TaskType.EMBEDDING,
          State.PENDING
      )
      ListenerManagement.get_aggregation_document_status(document_id)()
      try:
          embedding_by_document.delay(document_id, embedding_model_id, state_list)
      except AlreadyQueued as e:
          # Keep the original cause attached for debugging
          raise AppApiException(
              500, _('The task is being executed, please do not send it repeatedly.')
          ) from e
  @staticmethod
  def get_workbook(data_dict, document_dict):
      """
      Build a workbook with one sheet per entry of ``data_dict``.

      :param data_dict:     sheet id -> list of rows; an empty dict receives a
                            ``'sheet'`` entry so that one sheet is always produced
      :param document_dict: sheet id -> sheet title
      :return: the populated workbook
      """
      # Start from a workbook without the default sheet
      workbook = openpyxl.Workbook()
      workbook.remove(workbook.active)
      if not data_dict:
          data_dict['sheet'] = []
      for sheet_id in data_dict:
          worksheet = workbook.create_sheet(document_dict.get(sheet_id))
          header = [
              gettext('Section title (optional)'),
              gettext('Section content (required, question answer, no more than 4096 characters)'),
              gettext('Question (optional, one per line in the cell)'),
          ]
          rows = [header] + list(data_dict.get(sheet_id, []))
          for row_number, row in enumerate(rows, start=1):
              for column_number, value in enumerate(row, start=1):
                  cell = worksheet.cell(row=row_number, column=column_number)
                  if isinstance(value, str):
                      value = re.sub(ILLEGAL_CHARACTERS_RE, '', value)
                      # Prefix text that starts with =, +, - or @ with U+FEFF,
                      # presumably so spreadsheet apps do not treat it as a formula
                      if value.startswith(('=', '+', '-', '@')):
                          value = '\ufeff' + value
                  cell.value = value
      return workbook
  @staticmethod
  def merge_problem(paragraph_list: List[Dict], problem_mapping_list: List[Dict], document_list):
      """
      Combine paragraphs and their problems into per-document sheet rows.

      :param paragraph_list:       paragraph rows (id, document_id, document_name, title, content)
      :param problem_mapping_list: problem rows (paragraph_id, content)
      :param document_list:        documents that must get a sheet even without paragraphs
      :return: ``(rows, titles)`` where ``rows`` maps document id -> list of
               ``[title, content, problems joined by newline]`` and ``titles``
               maps document id -> sheet title; documents that share a title get
               a numeric suffix in first-seen order
      """
      # Group the problem texts once instead of rescanning the mapping list per paragraph
      problems_by_paragraph = {}
      for problem_mapping in problem_mapping_list:
          problems_by_paragraph.setdefault(
              problem_mapping.get('paragraph_id'), []
          ).append(problem_mapping.get('content'))

      result = {}
      # sheet title -> document ids; dict keys act as an insertion-ordered set
      document_dict = {}
      for paragraph in paragraph_list:
          document_id = paragraph.get('document_id')
          problem_list = problems_by_paragraph.get(paragraph.get('id'), [])
          document_name = DocumentSerializers.Operate.reset_document_name(paragraph.get('document_name'))
          document_dict.setdefault(document_name, {})[document_id] = None
          result.setdefault(document_id, []).append(
              [paragraph.get('title'), paragraph.get('content'), '\n'.join(problem_list)]
          )
      # Documents without paragraphs still get a sheet, holding one empty row
      for document in document_list:
          if document.id not in result:
              document_name = DocumentSerializers.Operate.reset_document_name(document.name)
              result[document.id] = [[]]
              document_dict.setdefault(document_name, {})[document.id] = None
      result_document_dict = {}
      for d_name, d_ids in document_dict.items():
          for index, d_id in enumerate(d_ids):
              result_document_dict[d_id] = d_name if index == 0 else d_name + str(index)
      return result, result_document_dict
  @staticmethod
  def reset_document_name(document_name):
      """
      Turn a document name into a sheet title.

      The name is stripped and cut to 29 characters (presumably leaving room
      for the numeric suffix merge_problem may append - confirm). ``"Sheet"``
      is returned for None and for names ``Utils.valid_sheet_name`` rejects.
      """
      if document_name is None:
          return "Sheet"
      candidate = document_name.strip()[0:29]
      if not Utils.valid_sheet_name(candidate):
          return "Sheet"
      # Cutting may leave trailing whitespace behind, so strip once more
      return candidate.strip()
  757. class Create(serializers.Serializer):
  758. workspace_id = serializers.CharField(required=False, label=_('workspace id'), allow_null=True)
  759. knowledge_id = serializers.UUIDField(required=True, label=_('document id'))
  def is_valid(self, *, raise_exception=False):
      """
      Validate the fields and check that the knowledge base exists.

      :return: True
      :raises AppApiException: the knowledge base does not exist
      """
      super().is_valid(raise_exception=True)
      knowledge_exists = QuerySet(Knowledge).filter(id=self.data.get('knowledge_id')).exists()
      if knowledge_exists:
          return True
      raise AppApiException(10000, _('knowledge id not exist'))
  @staticmethod
  def post_embedding(result, document_id, knowledge_id):
      """
      Post hook of ``save``: queue the embedding of the new document and hand
      back ``result`` unchanged.
      """
      operate = DocumentSerializers.Operate(
          data={'knowledge_id': knowledge_id, 'document_id': document_id})
      operate.refresh()
      return result
  @post(post_function=post_embedding)
  @transaction.atomic
  def save(self, instance: Dict, with_valid=True, **kwargs):
      """
      Create a document with its paragraphs, problems and problem links.

      :param instance:   document payload
      :param with_valid: validate the payload and this serializer first
      :return: ``(document record, document id, knowledge id)``, the arguments
               expected by ``post_embedding``
      """
      if with_valid:
          DocumentInstanceSerializer(data=instance).is_valid(raise_exception=True)
          self.is_valid(raise_exception=True)
      knowledge_id = self.data.get('knowledge_id')
      models = self.get_document_paragraph_model(knowledge_id, instance)
      document_model = models.get('document')
      paragraph_model_list = models.get('paragraph_model_list')
      problem_paragraph_object_list = models.get('problem_paragraph_object_list')
      problem_manage = ProblemParagraphManage(problem_paragraph_object_list, knowledge_id)
      problem_model_list, problem_paragraph_mapping_list = problem_manage.to_problem_model_list()
      # Insert the document
      document_model.save()
      # Bulk insert the paragraphs, numbering them after the highest stored position
      if len(paragraph_model_list) > 0:
          aggregated = Paragraph.objects.filter(document_id=document_model.id).aggregate(
              max_position=Max('position')
          )
          max_position = aggregated['max_position'] or 0
          for position, paragraph in enumerate(paragraph_model_list, start=max_position + 1):
              paragraph.position = position
          QuerySet(Paragraph).bulk_create(paragraph_model_list)
      # Bulk insert the problems
      if len(problem_model_list) > 0:
          QuerySet(Problem).bulk_create(problem_model_list)
      # Bulk insert the problem/paragraph links
      if len(problem_paragraph_mapping_list) > 0:
          QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list)
      document_id = str(document_model.id)
      created = DocumentSerializers.Operate(
          data={'knowledge_id': knowledge_id, 'document_id': document_id}
      ).one(with_valid=True)
      return created, document_id, knowledge_id
  @staticmethod
  def get_paragraph_model(document_model, paragraph_list: List):
      """
      Build unsaved paragraph models, and collect their problem objects, for a document.

      :param document_model: document the paragraphs belong to
      :param paragraph_list: paragraph payloads
      :return: dict with ``document``, ``paragraph_model_list`` and
               ``problem_paragraph_object_list``
      """
      knowledge_id = document_model.knowledge_id
      document_id = document_model.id
      paragraph_model_list = []
      problem_paragraph_object_list = []
      for paragraph_payload in paragraph_list:
          creator = ParagraphSerializers.Create(
              data={'knowledge_id': knowledge_id, 'document_id': str(document_id)})
          built = creator.get_paragraph_problem_model(knowledge_id, document_id, paragraph_payload)
          problem_paragraph_object_list.extend(built.get('problem_paragraph_object_list'))
          paragraph_model_list.append(built.get('paragraph'))
      return {
          'document': document_model,
          'paragraph_model_list': paragraph_model_list,
          'problem_paragraph_object_list': problem_paragraph_object_list
      }
  @staticmethod
  def get_document_paragraph_model(knowledge_id, instance: Dict):
      """
      Build an unsaved Document from ``instance`` and the models of its paragraphs.

      The document meta is the payload meta plus ``source_file_id`` (when given)
      and ``allow_download=True``; the type defaults to ``KnowledgeType.BASE``.

      :return: the dict produced by ``get_paragraph_model``
      """
      source_file_id = instance.get('source_file_id')
      source_meta = {'source_file_id': source_file_id} if source_file_id else {}
      payload_meta = instance.get('meta')
      meta = source_meta if payload_meta is None else {**payload_meta, **source_meta}
      meta = {**convert_uuid_to_str(meta), 'allow_download': True}

      document_type = instance.get('type')
      char_length = sum([len(p.get('content')) for p in instance.get('paragraphs', [])])
      document_model = Document(
          knowledge_id=knowledge_id,
          id=uuid.uuid7(),
          name=instance.get('name'),
          char_length=char_length,
          meta=meta,
          type=KnowledgeType.BASE if document_type is None else document_type,
      )
      paragraphs = instance.get('paragraphs') if 'paragraphs' in instance else []
      return DocumentSerializers.Create.get_paragraph_model(document_model, paragraphs)
  def save_web(self, instance: Dict, with_valid=True):
      """
      Queue the crawl of ``instance['source_url_list']`` into this knowledge base.

      :param instance:   payload holding ``source_url_list`` and ``selector``
      :param with_valid: validate the payload and this serializer first
      """
      if with_valid:
          DocumentWebInstanceSerializer(data=instance).is_valid(raise_exception=True)
          self.is_valid(raise_exception=True)
      sync_web_document.delay(
          self.data.get('knowledge_id'),
          instance.get('source_url_list'),
          instance.get('selector'),
      )
  def save_qa(self, instance: Dict, with_valid=True):
      """
      Parse the uploaded QA files and store the resulting documents in one batch.

      :param instance:   payload holding ``file_list``
      :param with_valid: validate the payload and this serializer first
      :return: result of ``Batch.batch_save``
      """
      if with_valid:
          DocumentInstanceQASerializer(data=instance).is_valid(raise_exception=True)
          self.is_valid(raise_exception=True)
      parsed_per_file = [self.parse_qa_file(file) for file in instance.get('file_list')]
      document_list = flat_map(parsed_per_file)
      batch = DocumentSerializers.Batch(data={
          'knowledge_id': self.data.get('knowledge_id'), 'workspace_id': self.data.get('workspace_id')
      })
      return batch.batch_save(document_list)
  def save_table(self, instance: Dict, with_valid=True):
      """
      Parse the uploaded table files and store the resulting documents in one batch.

      :param instance:   payload holding ``file_list``
      :param with_valid: validate the payload and this serializer first
      :return: result of ``Batch.batch_save``
      """
      if with_valid:
          DocumentInstanceTableSerializer(data=instance).is_valid(raise_exception=True)
          self.is_valid(raise_exception=True)
      parsed_per_file = [self.parse_table_file(file) for file in instance.get('file_list')]
      document_list = flat_map(parsed_per_file)
      batch = DocumentSerializers.Batch(data={
          'knowledge_id': self.data.get('knowledge_id'), 'workspace_id': self.data.get('workspace_id')
      })
      return batch.batch_save(document_list)
  def parse_qa_file(self, file):
      """
      Store the uploaded file, then parse it with the first QA handler that supports it.

      :return: parsed documents, each tagged with ``source_file_id``
      :raises AppApiException: no handler supports the file
      """
      # Keep a copy of the source file
      source_file_id = uuid.uuid7()
      File(
          id=source_file_id,
          file_name=file.name,
          source_type=FileSourceType.KNOWLEDGE,
          source_id=self.data.get('knowledge_id'),
          meta={}
      ).save(file.read())
      # Rewind so the handlers read from the start
      file.seek(0)
      get_buffer = FileBufferHandle().get_buffer
      handler = next(
          (candidate for candidate in parse_qa_handle_list if candidate.support(file, get_buffer)),
          None
      )
      if handler is None:
          raise AppApiException(500, _('Unsupported file format'))
      documents = handler.handle(file, get_buffer, self.save_image)
      for doc in documents:
          doc['source_file_id'] = source_file_id
      return documents
  def parse_table_file(self, file):
      """
      Store the uploaded file, then parse it with the first table handler that supports it.

      :return: parsed documents, each tagged with ``source_file_id``
      :raises AppApiException: no handler supports the file
      """
      # Keep a copy of the source file
      source_file_id = uuid.uuid7()
      File(
          id=source_file_id,
          file_name=file.name,
          source_type=FileSourceType.KNOWLEDGE,
          source_id=self.data.get('knowledge_id'),
          meta={}
      ).save(file.read())
      # Rewind so the handlers read from the start
      file.seek(0)
      get_buffer = FileBufferHandle().get_buffer
      handler = next(
          (candidate for candidate in parse_table_handle_list if candidate.support(file, get_buffer)),
          None
      )
      if handler is None:
          raise AppApiException(500, _('Unsupported file format'))
      documents = handler.handle(file, get_buffer, self.save_image)
      for doc in documents:
          doc['source_file_id'] = source_file_id
      return documents
  def save_image(self, image_list):
      """
      Persist the images that are not stored yet, attaching them to this knowledge base.

      Images whose id already exists as a File are skipped; when several images
      share an id only the last one is saved.
      """
      if image_list is None or len(image_list) == 0:
          return
      stored_rows = QuerySet(File).filter(id__in=[image.id for image in image_list]).values('id')
      stored_ids = {str(row.get('id')) for row in stored_rows}
      # Keyed by id to drop duplicates inside the batch
      pending = {}
      for image in image_list:
          if str(image.id) not in stored_ids:
              pending[image.id] = image
      # Save the images
      knowledge_id = self.data.get('knowledge_id')
      for image in pending.values():
          # The raw bytes travel in meta['content'] and must not be stored as metadata
          image_bytes = image.meta.pop('content')
          image.meta['knowledge_id'] = knowledge_id
          image.source_type = FileSourceType.KNOWLEDGE
          image.source_id = knowledge_id
          image.save(image_bytes)
  924. class Split(serializers.Serializer):
  925. workspace_id = serializers.CharField(required=False, label=_('workspace id'), allow_null=True)
  926. knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
  def is_valid(self, *, instance=None, raise_exception=True):
      """
      Validate the fields, the knowledge base and the size of the uploaded files.

      :param instance: request payload; ``instance['file']`` is the list of uploads.
                       When it is None or has no files, only the knowledge base is checked.
      :return: True when every check passed
      :raises AppApiException: the knowledge base does not exist (scoped to the
                               workspace when one is given) or a file exceeds
                               the size limit of the knowledge base
      """
      super().is_valid(raise_exception=True)
      workspace_id = self.data.get('workspace_id')
      query_set = QuerySet(Knowledge).filter(id=self.data.get('knowledge_id'))
      if workspace_id:
          query_set = query_set.filter(workspace_id=workspace_id)
      if not query_set.exists():
          raise AppApiException(500, _('Knowledge id does not exist'))
      # instance defaults to None and 'file' may be absent: treat both as "no files"
      files = (instance.get('file') if instance is not None else None) or []
      knowledge = Knowledge.objects.filter(id=self.data.get('knowledge_id')).first()
      for f in files:
          # file_size_limit is compared in MB, as the error message states
          if f.size > 1024 * 1024 * knowledge.file_size_limit:
              raise AppApiException(500, _(
                  'The maximum size of the uploaded file cannot exceed {}MB'
              ).format(knowledge.file_size_limit))
      return True
  def parse(self, instance):
      """
      Split every uploaded file into paragraphs.

      :param instance: payload with ``file`` and the optional ``patterns``,
                       ``with_filter`` and ``limit`` (default 4096)
      :return: the results of all files as one flat list
      """
      self.is_valid(instance=instance, raise_exception=True)
      DocumentSplitRequest(data=instance).is_valid(raise_exception=True)
      patterns = instance.get("patterns", None)
      with_filter = instance.get("with_filter", None)
      limit = instance.get("limit", 4096)
      per_file_results = [
          self.file_to_paragraph(f, patterns, with_filter, limit)
          for f in instance.get("file")
      ]
      merged = []
      for file_result in per_file_results:
          merged.extend(file_result)
      return merged
  def save_image(self, image_list):
      """
      Persist the images that are not stored yet, attaching them to this knowledge base.

      Images whose id already exists as a File are skipped; when several images
      share an id only the last one is saved.
      """
      if image_list is None or len(image_list) == 0:
          return
      requested_ids = [image.id for image in image_list]
      known_ids = {str(row.get('id')) for row in QuerySet(File).filter(id__in=requested_ids).values('id')}
      # Keyed by id to drop duplicates inside the batch
      to_save = {}
      for image in image_list:
          if str(image.id) in known_ids:
              continue
          to_save[image.id] = image
      # Save the images
      for image in to_save.values():
          # The raw bytes travel in meta['content'] and must not be stored as metadata
          content = image.meta.pop('content')
          image.meta['knowledge_id'] = self.data.get('knowledge_id')
          image.source_type = FileSourceType.KNOWLEDGE
          image.source_id = self.data.get('knowledge_id')
          image.save(content)
  def file_to_paragraph(self, file, pattern_list: List, with_filter: bool, limit: int):
      """
      Store the uploaded file and split it into paragraphs.

      The first entry of ``split_handles`` that supports the file is used,
      ``default_split_handle`` otherwise.

      :return: list of split results, each tagged with ``source_file_id``
      """
      # Keep a copy of the source file
      file_id = uuid.uuid7()
      File(
          id=file_id,
          file_name=file.name,
          file_size=file.size,
          source_type=FileSourceType.KNOWLEDGE,
          source_id=self.data.get('knowledge_id'),
      ).save(file.read())
      # Rewind so the handler reads from the start
      file.seek(0)
      get_buffer = FileBufferHandle().get_buffer
      handler = next(
          (candidate for candidate in split_handles if candidate.support(file, get_buffer)),
          default_split_handle
      )
      result = handler.handle(file, pattern_list, with_filter, limit, get_buffer, self.save_image)
      # A handler returns either one result or a list of results; always hand back a list
      results = result if isinstance(result, list) else [result]
      for item in results:
          item['source_file_id'] = file_id
      return results
  class SplitPattern(serializers.Serializer):
      # Serializer exposing the predefined split patterns.
      workspace_id = serializers.CharField(required=False, label=_('workspace id'), allow_null=True)
      knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))

      @staticmethod
      def list():
          """
          Return the predefined split patterns as ``{'key': label, 'value': regex}`` dicts.
          """
          named_patterns = (
              # Markdown headings, level 1 to 6
              ("#", '(?<=^)# .*|(?<=\\n)# .*'),
              ('##', '(?<=\\n)(?<!#)## (?!#).*|(?<=^)(?<!#)## (?!#).*'),
              ('###', "(?<=\\n)(?<!#)### (?!#).*|(?<=^)(?<!#)### (?!#).*"),
              ('####', "(?<=\\n)(?<!#)#### (?!#).*|(?<=^)(?<!#)#### (?!#).*"),
              ('#####', "(?<=\\n)(?<!#)##### (?!#).*|(?<=^)(?<!#)##### (?!#).*"),
              ('######', "(?<=\\n)(?<!#)###### (?!#).*|(?<=^)(?<!#)###### (?!#).*"),
              # List items and single separators
              ('-', '(?<! )- .*'),
              (_('space'), '(?<! ) (?! )'),
              (_('semicolon'), '(?<!;);(?!;)'),
              (_('comma'), '(?<!,),(?!,)'),
              (_('period'), '(?<!。)。(?!。)'),
              (_('enter'), '(?<!\\n)\\n(?!\\n)'),
              (_('blank line'), '(?<!\\n)\\n\\n(?!\\n)'),
          )
          return [{'key': key, 'value': value} for key, value in named_patterns]
  1016. class Batch(serializers.Serializer):
  1017. workspace_id = serializers.CharField(required=True, label=_('workspace id'))
  1018. knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
  def is_valid(self, *, raise_exception=False):
      """
      Validate the fields and check that the knowledge base exists in the workspace.

      Failures always raise; ``raise_exception`` is accepted for signature
      compatibility only.

      :return: True when every check passed
      :raises AppApiException: the knowledge base does not exist
      """
      super().is_valid(raise_exception=True)
      workspace_id = self.data.get('workspace_id')
      query_set = QuerySet(Knowledge).filter(id=self.data.get('knowledge_id'))
      if workspace_id:
          query_set = query_set.filter(workspace_id=workspace_id)
      if not query_set.exists():
          raise AppApiException(500, _('Knowledge id does not exist'))
      # Honour the boolean contract of Serializer.is_valid instead of returning None
      return True
  @staticmethod
  def link_file(source_file_id, document_id):
      """
      Copy the uploaded source file into a new File owned by the document.

      Does nothing when ``source_file_id`` is None or the file cannot be found.
      """
      if source_file_id is None:
          return
      source_file = QuerySet(File).filter(id=source_file_id).first()
      if not source_file:
          return
      # Read the content of the original file
      file_content = source_file.get_bytes()
      copied_meta = source_file.meta.copy() if source_file.meta else {}
      # New file object carrying over the relevant attributes of the original
      new_file = File(
          id=uuid.uuid7(),
          file_name=source_file.file_name,
          file_size=source_file.file_size,
          source_type=FileSourceType.DOCUMENT,
          source_id=document_id,  # the copy is owned by the document
          meta=copied_meta
      )
      # Store content and metadata
      new_file.save(file_content)
  @staticmethod
  def post_embedding(document_list, knowledge_id, workspace_id):
      """
      Post hook of ``batch_save``: queue the embedding of every saved document
      and hand back ``document_list`` unchanged.
      """
      for document_dict in document_list:
          operate = DocumentSerializers.Operate(data={
              'knowledge_id': knowledge_id,
              'document_id': document_dict.get('id'),
              'workspace_id': workspace_id
          })
          operate.refresh()
      return document_list
  @post(post_function=post_embedding)
  @transaction.atomic
  def batch_save(self, instance_list: List[Dict], with_valid=True):
      """
      Create several documents with their paragraphs, problems and problem links.

      :param instance_list: document payloads
      :param with_valid:    validate this serializer and the payloads first
      :return: ``(document records, knowledge id, workspace id)``, the arguments
               expected by ``post_embedding``
      """
      if with_valid:
          self.is_valid(raise_exception=True)
          DocumentInstanceSerializer(many=True, data=instance_list).is_valid(raise_exception=True)
      workspace_id = self.data.get("workspace_id")
      knowledge_id = self.data.get("knowledge_id")
      document_model_list = []
      paragraph_model_list = []
      problem_paragraph_object_list = []
      # Build the models of every document
      for document in instance_list:
          document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(
              knowledge_id,
              document
          )
          # Link the uploaded source file to the new document
          document_instance = document_paragraph_dict_model.get('document')
          self.link_file(document.get('source_file_id'), document_instance.id)
          document_model_list.append(document_instance)
          paragraph_model_list.extend(document_paragraph_dict_model.get('paragraph_model_list'))
          problem_paragraph_object_list.extend(
              document_paragraph_dict_model.get('problem_paragraph_object_list'))
      problem_model_list, problem_paragraph_mapping_list = (
          ProblemParagraphManage(problem_paragraph_object_list, knowledge_id).to_problem_model_list()
      )
      # Insert the documents
      if len(document_model_list) > 0:
          QuerySet(Document).bulk_create(document_model_list)
      # Bulk insert the paragraphs. They are grouped by document once, instead of
      # filtering the complete paragraph list again for every document.
      paragraphs_by_document = {}
      for paragraph in paragraph_model_list:
          paragraphs_by_document.setdefault(paragraph.document_id, []).append(paragraph)
      for document in document_model_list:
          sub_list = paragraphs_by_document.get(document.id)
          if not sub_list:
              continue
          max_position = Paragraph.objects.filter(document_id=document.id).aggregate(
              max_position=Max('position')
          )['max_position'] or 0
          for i, paragraph in enumerate(sub_list):
              paragraph.position = max_position + i + 1
          QuerySet(Paragraph).bulk_create(sub_list)
      # Bulk insert the problems
      bulk_create_in_batches(Problem, problem_model_list, batch_size=1000)
      # Bulk insert the problem/paragraph links
      bulk_create_in_batches(ProblemParagraphMapping, problem_paragraph_mapping_list, batch_size=1000)
      # Load the saved documents
      if len(document_model_list) == 0:
          return [], knowledge_id, workspace_id
      query_set = QuerySet(model=Document).filter(id__in=[d.id for d in document_model_list])
      return native_search(
          {
              'document_custom_sql': query_set,
              'order_by_query': QuerySet(Document).order_by('-create_time', 'id')
          },
          select_string=get_file_content(
              os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_document.sql')
          ),
          with_search_one=False
      ), knowledge_id, workspace_id
  def batch_sync(self, instance: Dict, with_valid=True):
      """
      Sync the documents in ``instance['id_list']`` on the worker thread pool.

      :param instance:   payload holding ``id_list``
      :param with_valid: validate the payload and this serializer first
      :return: True as soon as the work has been submitted
      """
      if with_valid:
          BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
          self.is_valid(raise_exception=True)

      def sync_documents(doc_ids):
          # Runs on the pool: sync the documents one after another
          outcomes = []
          for doc_id in doc_ids:
              serializer = DocumentSerializers.Sync(data={
                  'document_id': doc_id,
                  'knowledge_id': self.data.get('knowledge_id'),
                  'workspace_id': self.data.get('workspace_id')
              })
              outcomes.append(serializer.sync())
          return outcomes

      # Sync asynchronously
      work_thread_pool.submit(sync_documents, instance.get('id_list'))
      return True
  @transaction.atomic
  def batch_delete(self, instance: Dict, with_valid=True):
      """
      Delete several documents together with their files, problems, paragraphs,
      vectors and tags.

      The steps follow the same order as ``Operate.delete``: dependent rows are
      removed while the documents still exist, the documents go last.

      :param instance:   payload holding ``id_list``
      :param with_valid: validate the payload and this serializer first
      :return: True
      """
      if with_valid:
          BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
          self.is_valid(raise_exception=True)
      document_id_list = instance.get("id_list")

      # Files referenced from the document meta
      source_file_ids = []
      for doc in Document.objects.filter(id__in=document_id_list).values("meta"):
          source_file_id = (doc['meta'] or {}).get('source_file_id')
          if source_file_id is not None:
              source_file_ids.append(source_file_id)
      QuerySet(File).filter(id__in=source_file_ids).delete()
      # Files owned by the documents themselves (the copies link_file creates)
      QuerySet(File).filter(
          source_id__in=document_id_list, source_type=FileSourceType.DOCUMENT
      ).delete()

      paragraph_ids = QuerySet(Paragraph).filter(document_id__in=document_id_list).values_list("id", flat=True)
      # Remove the problems and their links before the paragraphs disappear
      delete_problems_and_mappings(paragraph_ids)
      # Remove the paragraphs
      QuerySet(Paragraph).filter(document_id__in=document_id_list).delete()
      # Remove the vectors
      delete_embedding_by_document_list(document_id_list)
      QuerySet(DocumentTag).filter(document_id__in=document_id_list).delete()
      # Documents last, once nothing refers to them any more
      QuerySet(Document).filter(id__in=document_id_list).delete()
      return True
  def batch_cancel(self, instance: Dict, with_valid=True):
      """
      Ask ``ListenerManagement.update_status`` to move the task named by
      ``instance['type']`` to ``State.REVOKE`` for the paragraphs and
      documents in ``instance['id_list']`` that are still pending or started.

      :param instance:   request payload with ``id_list`` and ``type``
      :param with_valid: when true, validate this serializer and the payload first
      """
      if with_valid:
          self.is_valid(raise_exception=True)
          BatchCancelInstanceSerializer(data=instance).is_valid(raise_exception=True)
      doc_ids = instance.get("id_list")
      task_type = TaskType(instance.get('type'))

      def cancellable_ids(model, **scope):
          # The status string is reversed and the single character at the
          # task type's (1-based) position is compared with the two
          # "still running" states.
          annotated = QuerySet(model).annotate(
              reversed_status=Reverse('status'),
              task_type_status=Substr('reversed_status', task_type.value, 1),
          )
          running = annotated.filter(
              task_type_status__in=[State.PENDING.value, State.STARTED.value]
          )
          return running.filter(**scope).values('id')

      ListenerManagement.update_status(
          cancellable_ids(Paragraph, document_id__in=doc_ids),
          task_type,
          State.REVOKE
      )
      ListenerManagement.update_status(
          cancellable_ids(Document, id__in=doc_ids),
          task_type,
          State.REVOKE
      )
  def batch_edit_hit_handling(self, instance: Dict, with_valid=True):
      """
      Apply one hit-handling configuration to every document in
      ``instance['id_list']``.

      ``hit_handling_method`` is always written; ``directly_return_similarity``
      and the ``allow_download`` flag inside ``meta`` are written only when
      the payload provides them (i.e. they are not ``None``).

      :param instance:   request payload
      :param with_valid: when true, validate the id list and the method value
      :raises AppApiException: missing or unsupported ``hit_handling_method``
      """
      if with_valid:
          BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
          requested_method = instance.get('hit_handling_method')
          if requested_method is None:
              raise AppApiException(500, _('Hit handling method is required'))
          if requested_method not in ('optimization', 'directly_return'):
              raise AppApiException(500, _('The hit processing method must be directly_return|optimization'))
      self.is_valid(raise_exception=True)

      doc_ids = instance.get("id_list")
      changes = {'hit_handling_method': instance.get('hit_handling_method')}
      similarity = instance.get('directly_return_similarity')
      if similarity is not None:
          changes['directly_return_similarity'] = similarity
      QuerySet(Document).filter(id__in=doc_ids).update(**changes)

      allow_download = instance.get('allow_download')
      if allow_download is None:
          return
      # ``meta`` is a JSON column on Document and may not contain an
      # ``allow_download`` key yet, so the key is set in the database with
      # jsonb_set instead of rewriting the whole value.
      download_flag = Func(
          F("meta"),
          Value(["allow_download"]),
          Value(json.dumps(allow_download)),  # serialised as "true" / "false"
          Value(True),  # create_missing = true
          function="jsonb_set",
          output_field=JSONField(),
      )
      Document.objects.filter(id__in=doc_ids).update(meta=download_flag)
  def batch_refresh(self, instance: Dict, with_valid=True):
      """
      Call ``DocumentSerializers.Operate(...).refresh`` for each document in
      ``instance['id_list']``, passing along ``instance['state_list']``.

      A document whose refresh raises ``AlreadyQueued`` is skipped; the loop
      continues with the next one.

      :param instance:   request payload with ``id_list`` and ``state_list``
      :param with_valid: when true, validate this serializer first
      """
      if with_valid:
          self.is_valid(raise_exception=True)
      doc_ids = instance.get("id_list")
      states = instance.get("state_list")
      knowledge_id = self.data.get('knowledge_id')
      for doc_id in doc_ids:
          try:
              operate = DocumentSerializers.Operate(
                  data={'knowledge_id': knowledge_id, 'document_id': doc_id})
              operate.refresh(states)
          except AlreadyQueued:
              # Best effort: this document is skipped, the others still run.
              continue
  def batch_add_tag(self, instance: Dict, with_valid=True):
      """
      Link every tag in ``instance['tag_ids']`` to every document in
      ``instance['document_ids']``, creating only the links that do not
      exist yet.

      :param instance:   request payload with ``document_ids`` and ``tag_ids``
      :param with_valid: when true, validate this serializer first
      """
      if with_valid:
          self.is_valid(raise_exception=True)
      doc_ids = instance.get("document_ids")
      tag_ids = instance.get("tag_ids")

      # Load the already stored links in one query
      stored_links = QuerySet(DocumentTag).filter(
          document_id__in=doc_ids,
          tag_id__in=tag_ids
      ).values_list('document_id', 'tag_id')
      known_pairs = set()
      for stored_doc_id, stored_tag_id in stored_links:
          known_pairs.add((str(stored_doc_id), str(stored_tag_id)))

      to_create = []
      for doc_id in doc_ids:
          for tag_id in tag_ids:
              pair = (str(doc_id), str(tag_id))
              # Skip links that are stored already as well as links queued
              # earlier in this same call (repeated ids in the payload).
              if pair in known_pairs:
                  continue
              known_pairs.add(pair)
              to_create.append(DocumentTag(
                  id=uuid.uuid7(),
                  document_id=doc_id,
                  tag_id=tag_id,
              ))

      if to_create:
          QuerySet(DocumentTag).bulk_create(to_create)
  def batch_export(self, instance: Dict, with_valid=True):
      """
      Build a workbook from the paragraphs and problem mappings of the
      documents in ``instance['id_list']`` and return it as an HTTP response.

      :param instance:   request payload; ``id_list`` carries the document ids
      :param with_valid: when true, validate the payload and this serializer first
      :return: ``HttpResponse`` (content type ``application/vnd.ms-excel``)
               with the workbook saved into it
      """
      if with_valid:
          BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
          self.is_valid(raise_exception=True)
      doc_ids = instance.get("id_list")
      sql_dir = os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql')

      documents = QuerySet(Document).filter(id__in=doc_ids)
      paragraph_sql = get_file_content(os.path.join(sql_dir, 'list_paragraph_document_name.sql'))
      paragraphs = native_search(
          QuerySet(Paragraph).filter(document_id__in=doc_ids),
          paragraph_sql
      )
      mapping_sql = get_file_content(os.path.join(sql_dir, 'list_problem_mapping.sql'))
      problem_mappings = native_search(
          QuerySet(ProblemParagraphMapping).filter(document_id__in=doc_ids),
          mapping_sql,
          with_table_name=True
      )

      data_dict, document_dict = DocumentSerializers.Operate.merge_problem(
          paragraphs, problem_mappings, documents
      )
      workbook = DocumentSerializers.Operate.get_workbook(data_dict, document_dict)

      response = HttpResponse(content_type='application/vnd.ms-excel')
      workbook.save(response)
      return response
  def batch_export_zip(self, instance: Dict, with_valid=True):
      """
      Like ``batch_export``, but the workbook is written to a temporary
      directory next to whatever ``write_image`` produces for each
      paragraph, and that directory is returned as a zip archive.

      :param instance:   request payload; ``id_list`` carries the document ids
      :param with_valid: when true, validate the payload and this serializer first
      :return: ``HttpResponse`` (content type ``application/zip``) holding the archive
      """
      if with_valid:
          BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
          self.is_valid(raise_exception=True)
      knowledge = QuerySet(Knowledge).filter(id=self.data.get('knowledge_id')).first()
      doc_ids = instance.get("id_list")
      sql_dir = os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql')

      documents = QuerySet(Document).filter(id__in=doc_ids)
      paragraph_sql = get_file_content(os.path.join(sql_dir, 'list_paragraph_document_name.sql'))
      paragraphs = native_search(
          QuerySet(Paragraph).filter(document_id__in=doc_ids),
          paragraph_sql
      )
      mapping_sql = get_file_content(os.path.join(sql_dir, 'list_problem_mapping.sql'))
      problem_mappings = native_search(
          QuerySet(ProblemParagraphMapping).filter(document_id__in=doc_ids),
          mapping_sql,
          with_table_name=True
      )
      data_dict, document_dict = DocumentSerializers.Operate.merge_problem(
          paragraphs, problem_mappings, documents
      )

      # One parse_image result per paragraph, written out further below.
      parsed_images = []
      for paragraph in paragraphs:
          parsed_images.append(parse_image(paragraph.get('content')))

      workbook = DocumentSerializers.Operate.get_workbook(data_dict, document_dict)
      response = HttpResponse(content_type='application/zip')
      archive = io.BytesIO()
      with TemporaryDirectory() as workdir:
          # The workbook file is named after the knowledge base.
          workbook.save(os.path.join(workdir, f'{knowledge.name}.xlsx'))
          for parsed in parsed_images:
              write_image(workdir, parsed)
          # Zip while the temporary directory still exists.
          zip_dir(workdir, archive)
      response.write(archive.getvalue())
      return response
  class BatchGenerateRelated(serializers.Serializer):
      """Schedules related-problem generation for a batch of documents of one knowledge base."""
      workspace_id = serializers.CharField(required=True, label=_('workspace id'))
      knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))

      def is_valid(self, *, raise_exception=False):
          """
          Validate the fields and check that the knowledge base exists
          (within the workspace when a workspace id is given).

          Field validation always raises on failure; ``raise_exception`` is
          accepted for signature compatibility only. Nothing is returned.

          :raises AppApiException: the knowledge base cannot be found
          """
          super().is_valid(raise_exception=True)
          workspace_id = self.data.get('workspace_id')
          query_set = QuerySet(Knowledge).filter(id=self.data.get('knowledge_id'))
          if workspace_id:
              query_set = query_set.filter(workspace_id=workspace_id)
          if not query_set.exists():
              raise AppApiException(500, _('Knowledge id does not exist'))

      def batch_generate_related(self, instance: Dict, with_valid=True):
          """
          Flag the documents (and their paragraphs whose GENERATE_PROBLEM
          state is in ``state_list``) as pending, refresh the aggregated
          document status, then enqueue one generation task per document.

          :param instance:   payload with ``document_id_list``, ``model_id``,
                             ``prompt``, ``model_params_setting`` and ``state_list``
          :param with_valid: when true, validate this serializer first
          """
          if with_valid:
              self.is_valid(raise_exception=True)
          document_id_list = instance.get("document_id_list")
          model_id = instance.get("model_id")
          prompt = instance.get("prompt")
          model_params_setting = instance.get("model_params_setting")
          state_list = instance.get('state_list')
          ListenerManagement.update_status(
              QuerySet(Document).filter(id__in=document_id_list),
              TaskType.GENERATE_PROBLEM,
              State.PENDING
          )
          ListenerManagement.update_status(
              QuerySet(Paragraph).annotate(
                  reversed_status=Reverse('status'),
                  task_type_status=Substr('reversed_status', TaskType.GENERATE_PROBLEM.value, 1),
              ).filter(
                  task_type_status__in=state_list, document_id__in=document_id_list
              ).values('id'),
              TaskType.GENERATE_PROBLEM,
              State.PENDING
          )
          # The helper returns a callable; invoking it performs the aggregation.
          ListenerManagement.get_aggregation_document_status_by_query_set(
              QuerySet(Document).filter(id__in=document_id_list))()
          for document_id in document_id_list:
              # Handle AlreadyQueued per document: one document that is
              # already queued must not stop the remaining documents (all
              # flagged PENDING above) from being enqueued.
              try:
                  generate_related_by_document_id.delay(
                      document_id, model_id, model_params_setting, prompt, state_list
                  )
              except AlreadyQueued:
                  continue
  class Tags(serializers.Serializer):
      """Lists the tags attached to one document, grouped by tag key."""
      workspace_id = serializers.CharField(required=True, label=_('workspace id'))
      knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
      document_id = serializers.UUIDField(required=True, label=_('document id'))
      name = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('search value'))

      def is_valid(self, *, raise_exception=False):
          """
          Validate the fields, then check that the knowledge base exists and
          that the document belongs to it.

          Field validation always raises on failure; ``raise_exception`` is
          accepted for signature compatibility only. Nothing is returned.

          :raises AppApiException: knowledge base or document not found
          """
          super().is_valid(raise_exception=True)
          knowledge_id = self.data.get('knowledge_id')
          workspace_id = self.data.get('workspace_id')
          knowledge_query = QuerySet(Knowledge).filter(id=knowledge_id)
          # The workspace filter is skipped for an empty value and for the literal string 'None'.
          if workspace_id and workspace_id != 'None':
              knowledge_query = knowledge_query.filter(workspace_id=workspace_id)
          if not knowledge_query.exists():
              raise AppApiException(500, _('Knowledge id does not exist'))
          document_found = QuerySet(Document).filter(
              id=self.data.get('document_id'),
              knowledge_id=self.data.get('knowledge_id')
          ).exists()
          if not document_found:
              raise AppApiException(500, _('Document id does not exist'))

      def list(self):
          """
          Return the document's tags as ``[{'key': ..., 'values': [...]}]``.

          Groups are ordered by key; the values of a group are ordered by
          creation time. When ``name`` is set, only tags whose key or value
          contains it (case-insensitively) are returned.
          """
          self.is_valid(raise_exception=True)
          knowledge_id = self.data.get('knowledge_id')
          keyword = self.data.get('name')

          tag_ids = QuerySet(DocumentTag).filter(
              document_id=self.data.get('document_id')
          ).values_list('tag_id', flat=True)
          if keyword:
              matching_tags = QuerySet(Tag).filter(
                  knowledge_id=knowledge_id,
                  id__in=tag_ids,
              ).filter(
                  Q(key__icontains=keyword) | Q(value__icontains=keyword)
              )
              tag_ids = matching_tags.values_list('id', flat=True)

          # Fetch the tags ordered by creation time so the output order is stable
          tag_rows = QuerySet(Tag).filter(
              knowledge_id=knowledge_id,
              id__in=tag_ids
          ).values('key', 'value', 'id', 'create_time', 'update_time').order_by('create_time', 'key', 'value')

          # Group the rows by key
          by_key = {}
          for row in tag_rows:
              by_key.setdefault(row['key'], []).append({
                  'id': row['id'],
                  'value': row['value'],
                  'create_time': row['create_time'],
                  'update_time': row['update_time']
              })

          # Emit the groups in key order, each group's values in creation order
          return [
              {
                  'key': key,
                  'values': sorted(by_key[key], key=lambda entry: entry['create_time']),
              }
              for key in sorted(by_key)
          ]
  class AddTags(serializers.Serializer):
      """Attaches a list of tags to one document."""
      workspace_id = serializers.CharField(required=True, label=_('workspace id'))
      knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
      document_id = serializers.UUIDField(required=True, label=_('document id'))
      tag_ids = serializers.ListField(
          required=True, label=_('tag ids'), child=serializers.UUIDField(required=True, label=_('tag id'))
      )

      def is_valid(self, *, raise_exception=False):
          """
          Validate the fields, then check that the knowledge base exists and
          that the document belongs to it.

          Field validation always raises on failure; ``raise_exception`` is
          accepted for signature compatibility only. Nothing is returned.

          :raises AppApiException: knowledge base or document not found
          """
          super().is_valid(raise_exception=True)
          knowledge_id = self.data.get('knowledge_id')
          workspace_id = self.data.get('workspace_id')
          knowledge_query = QuerySet(Knowledge).filter(id=knowledge_id)
          # The workspace filter is skipped for an empty value and for the literal string 'None'.
          if workspace_id and workspace_id != 'None':
              knowledge_query = knowledge_query.filter(workspace_id=workspace_id)
          if not knowledge_query.exists():
              raise AppApiException(500, _('Knowledge id does not exist'))
          document_found = QuerySet(Document).filter(
              id=self.data.get('document_id'),
              knowledge_id=self.data.get('knowledge_id')
          ).exists()
          if not document_found:
              raise AppApiException(500, _('Document id does not exist'))

      def add_tags(self):
          """Create the document-tag links that are not stored yet; existing links are left alone."""
          self.is_valid(raise_exception=True)
          document_id = self.data.get('document_id')
          tag_ids = self.data.get('tag_ids')

          stored_tag_ids = QuerySet(DocumentTag).filter(
              document_id=document_id, tag_id__in=tag_ids
          ).values_list('tag_id', flat=True)
          already_linked = {str(stored_id) for stored_id in stored_tag_ids}

          # NOTE(review): the membership test compares ids from self.data with
          # stringified ids - assumes self.data yields strings; confirm.
          to_create = []
          for tag_id in set(tag_ids):
              if tag_id in already_linked:
                  continue
              to_create.append(DocumentTag(
                  id=uuid.uuid7(),
                  document_id=document_id,
                  tag_id=tag_id
              ))

          if to_create:
              QuerySet(DocumentTag).bulk_create(to_create)
  class DeleteTags(serializers.Serializer):
      """Detaches a list of tags from one document."""
      workspace_id = serializers.CharField(required=True, label=_('workspace id'))
      knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
      document_id = serializers.UUIDField(required=True, label=_('document id'))
      tag_ids = serializers.ListField(
          required=True, label=_('tag ids'), child=serializers.UUIDField(required=True, label=_('tag id'))
      )

      def is_valid(self, *, raise_exception=False):
          """
          Validate the fields, then check that the knowledge base exists and
          that the document belongs to it.

          Field validation always raises on failure; ``raise_exception`` is
          accepted for signature compatibility only. Nothing is returned.

          :raises AppApiException: knowledge base or document not found
          """
          super().is_valid(raise_exception=True)
          knowledge_id = self.data.get('knowledge_id')
          workspace_id = self.data.get('workspace_id')
          knowledge_query = QuerySet(Knowledge).filter(id=knowledge_id)
          # The workspace filter is skipped for an empty value and for the literal string 'None'.
          if workspace_id and workspace_id != 'None':
              knowledge_query = knowledge_query.filter(workspace_id=workspace_id)
          if not knowledge_query.exists():
              raise AppApiException(500, _('Knowledge id does not exist'))
          document_found = QuerySet(Document).filter(
              id=self.data.get('document_id'),
              knowledge_id=self.data.get('knowledge_id')
          ).exists()
          if not document_found:
              raise AppApiException(500, _('Document id does not exist'))

      def delete_tags(self):
          """Delete the links between the document and each of the given tags."""
          self.is_valid(raise_exception=True)
          links = QuerySet(DocumentTag).filter(
              document_id=self.data.get('document_id'),
              tag_id__in=self.data.get('tag_ids')
          )
          links.delete()
  class DeleteDocsTag(serializers.Serializer):
      """Removes one tag from a batch of documents of the same knowledge base."""
      workspace_id = serializers.CharField(required=True, label=_('workspace id'))
      knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
      tag_id = serializers.UUIDField(required=True, label=_('tag id'))

      def is_valid(self, *, raise_exception=False):
          """
          Validate the fields, then check that the knowledge base exists and
          that the tag belongs to it.

          Field validation always raises on failure; ``raise_exception`` is
          accepted for signature compatibility only. Nothing is returned.

          :raises AppApiException: knowledge base or tag not found
          """
          super().is_valid(raise_exception=True)
          workspace_id = self.data.get('workspace_id')
          query_set = QuerySet(Knowledge).filter(id=self.data.get('knowledge_id'))
          # The workspace filter is skipped for an empty value and for the literal string 'None'.
          if workspace_id and workspace_id != 'None':
              query_set = query_set.filter(workspace_id=workspace_id)
          if not query_set.exists():
              raise AppApiException(500, _('Knowledge id does not exist'))
          if not QuerySet(Tag).filter(
                  id=self.data.get('tag_id'),
                  knowledge_id=self.data.get('knowledge_id')
          ).exists():
              raise AppApiException(500, _('Tag id does not exist'))

      def batch_delete_docs_tag(self, instance, with_valid=True):
          """
          Delete the link between this tag and every document in
          ``instance['id_list']``.

          :param instance:   request payload; ``id_list`` carries the document ids
          :param with_valid: when true, validate the payload and this serializer first
          :return: always ``True``
          :raises AppApiException: some document is not part of this knowledge base
          """
          if with_valid:
              BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
              self.is_valid(raise_exception=True)
          knowledge_id = self.data.get('knowledge_id')
          tag_id = self.data.get('tag_id')
          doc_id_list = instance.get("id_list")
          # Compare against the number of DISTINCT ids: the query matches each
          # document once, so an id repeated in the request must not be
          # reported as a document of another knowledge base.
          distinct_doc_ids = {str(doc_id) for doc_id in doc_id_list}
          valid_doc_count = Document.objects.filter(id__in=doc_id_list, knowledge_id=knowledge_id).count()
          if valid_doc_count != len(distinct_doc_ids):
              raise AppApiException(500, _('Document id does not belong to current knowledge'))
          DocumentTag.objects.filter(document_id__in=doc_id_list, tag_id=tag_id).delete()
          return True
  class ReplaceSourceFile(serializers.Serializer):
      """Replaces (or attaches for the first time) the source file of a document."""
      workspace_id = serializers.CharField(required=True, label=_('workspace id'))
      knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
      document_id = serializers.UUIDField(required=True, label=_('document id'))
      file = UploadedFileField(required=True, label=_("file"))

      def is_valid(self, *, raise_exception=False):
          """
          Validate the fields, then check that the knowledge base exists and
          that the document belongs to it.

          Field validation always raises on failure; ``raise_exception`` is
          accepted for signature compatibility only. Nothing is returned.

          :raises AppApiException: knowledge base or document not found
          """
          super().is_valid(raise_exception=True)
          knowledge_id = self.data.get('knowledge_id')
          workspace_id = self.data.get('workspace_id')
          knowledge_query = QuerySet(Knowledge).filter(id=knowledge_id)
          # The workspace filter is skipped for an empty value and for the literal string 'None'.
          if workspace_id and workspace_id != 'None':
              knowledge_query = knowledge_query.filter(workspace_id=workspace_id)
          if not knowledge_query.exists():
              raise AppApiException(500, _('Knowledge id does not exist'))
          document_found = QuerySet(Document).filter(
              id=self.data.get('document_id'),
              knowledge_id=self.data.get('knowledge_id')
          ).exists()
          if not document_found:
              raise AppApiException(500, _('Document id does not exist'))

      @transaction.atomic
      def replace(self):
          """
          Store the uploaded file as the document's source file.

          If a ``File`` row with ``source_id`` equal to the document id
          exists, every ``File`` row that shares its ``sha256_hash`` and
          whose ``source_id`` is the knowledge id or the document id is
          renamed and saved with the new content. Otherwise a new ``File`` is
          created and its id is written to ``meta['source_file_id']``.

          :return: always ``True``
          """
          self.is_valid(raise_exception=True)
          upload = self.data.get('file')
          document_id = self.data.get('document_id')
          current_file = QuerySet(File).filter(source_id=document_id).first()

          if current_file:
              # Remember the hash of the current file before anything changes
              original_hash = current_file.sha256_hash
              # Read the new content once; it is reused for every matching row
              new_content = upload.read()
              owner_ids = [self.data.get('knowledge_id'), document_id]
              QuerySet(File).filter(
                  sha256_hash=original_hash,
                  source_id__in=owner_ids
              ).update(file_name=upload.name)
              # Look up all files that carry the same hash ...
              same_hash_files = QuerySet(File).filter(
                  sha256_hash=original_hash,
                  source_id__in=owner_ids
              )
              # ... and save the new content for each of them
              for file_obj in same_hash_files:
                  file_obj.save(new_content)
              return True

          # No source file yet: create one and link it to the document
          new_file_id = uuid.uuid7()
          File(
              id=new_file_id,
              file_name=upload.name,
              source_type=FileSourceType.DOCUMENT,
              source_id=document_id,
          ).save(upload.read())
          # Record the new file id in the document's meta column
          source_file_reference = Func(
              F("meta"),
              Value(["source_file_id"]),
              Value(json.dumps(str(new_file_id))),
              Value(True),  # create_missing = true
              function="jsonb_set",
              output_field=JSONField(),
          )
          QuerySet(Document).filter(id=document_id).update(meta=source_file_reference)
          return True
  class FileBufferHandle:
      """Reads a file-like object once and hands out the cached content afterwards."""
      # Content of the first file passed to get_buffer; None until then.
      buffer = None

      def get_buffer(self, file):
          """
          Return the cached content, reading ``file`` only on the first call.

          Later calls return the same content whatever ``file`` is passed.
          """
          cached = self.buffer
          if cached is not None:
              return cached
          self.buffer = file.read()
          return self.buffer