common.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. # coding=utf-8
  2. """
  3. @project: maxkb
  4. @Author:虎
  5. @file: common_serializers.py
  6. @date:2023/11/17 11:00
  7. @desc:
  8. """
  9. import os
  10. import re
  11. import zipfile
  12. from typing import List
  13. import uuid_utils.compat as uuid
  14. from django.db.models import QuerySet
  15. from django.utils.translation import gettext_lazy as _
  16. from rest_framework import serializers
  17. from application.flow.tools import save_workflow_mapping, get_instance_resource, knowledge_instance_field_call_dict
  18. from common.config.embedding_config import ModelManage
  19. from common.db.search import native_search
  20. from common.db.sql_execute import sql_execute, update_execute
  21. from common.exception.app_exception import AppApiException
  22. from common.utils.common import get_file_content
  23. from common.utils.fork import Fork
  24. from common.utils.logger import maxkb_logger
  25. from knowledge.models import Document, KnowledgeWorkflow, KnowledgeWorkflowVersion, KnowledgeType
  26. from knowledge.models import Paragraph, Problem, ProblemParagraphMapping, Knowledge, File
  27. from maxkb.conf import PROJECT_DIR
  28. from models_provider.tools import get_model, get_model_default_params
  29. from system_manage.models.resource_mapping import ResourceMapping, ResourceType
  30. class MetaSerializer(serializers.Serializer):
  31. class WebMeta(serializers.Serializer):
  32. source_url = serializers.CharField(required=True, label=_('source url'))
  33. selector = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('selector'))
  34. def is_valid(self, *, raise_exception=False):
  35. super().is_valid(raise_exception=True)
  36. source_url = self.data.get('source_url')
  37. response = Fork(source_url, []).fork()
  38. if response.status == 500:
  39. raise AppApiException(500, _('URL error, cannot parse [{source_url}]').format(source_url=source_url))
  40. class BaseMeta(serializers.Serializer):
  41. def is_valid(self, *, raise_exception=False):
  42. super().is_valid(raise_exception=True)
  43. class BatchSerializer(serializers.Serializer):
  44. id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), label=_('id list'))
  45. def is_valid(self, *, model=None, raise_exception=False):
  46. super().is_valid(raise_exception=True)
  47. if model is not None:
  48. id_list = self.data.get('id_list')
  49. model_list = QuerySet(model).filter(id__in=id_list)
  50. if len(model_list) != len(id_list):
  51. model_id_list = [str(m.id) for m in model_list]
  52. error_id_list = list(filter(lambda row_id: not model_id_list.__contains__(row_id), id_list))
  53. raise AppApiException(500, _('The following id does not exist: {error_id_list}').format(
  54. error_id_list=error_id_list))
  55. class BatchMoveSerializer(BatchSerializer):
  56. folder_id = serializers.CharField(required=True, label=_('folder id'))
  57. class ProblemParagraphObject:
  58. def __init__(self, knowledge_id: str, document_id: str, paragraph_id: str, problem_content: str):
  59. self.knowledge_id = knowledge_id
  60. self.document_id = document_id
  61. self.paragraph_id = paragraph_id
  62. self.problem_content = problem_content
  63. class GenerateRelatedSerializer(serializers.Serializer):
  64. model_id = serializers.UUIDField(required=True, label=_('Model id'))
  65. prompt = serializers.CharField(required=True, label=_('Prompt word'))
  66. state_list = serializers.ListField(required=False, child=serializers.CharField(required=True),
  67. label=_("state list"))
  68. class ProblemParagraphManage:
  69. def __init__(self, problem_paragraph_object_list: List[ProblemParagraphObject], knowledge_id):
  70. self.knowledge_id = knowledge_id
  71. self.problem_paragraph_object_list = problem_paragraph_object_list
  72. def to_problem_model_list(self):
  73. problem_list = [item.problem_content for item in self.problem_paragraph_object_list]
  74. exists_problem_list = []
  75. if len(self.problem_paragraph_object_list) > 0:
  76. # 查询到已存在的问题列表
  77. exists_problem_list = QuerySet(Problem).filter(knowledge_id=self.knowledge_id,
  78. content__in=problem_list).all()
  79. problem_content_dict = {}
  80. problem_model_list = [
  81. or_get(
  82. exists_problem_list,
  83. problemParagraphObject.problem_content,
  84. problemParagraphObject.knowledge_id,
  85. problemParagraphObject.document_id,
  86. problemParagraphObject.paragraph_id, problem_content_dict
  87. ) for problemParagraphObject in self.problem_paragraph_object_list]
  88. problem_paragraph_mapping_list = [
  89. ProblemParagraphMapping(
  90. id=uuid.uuid7(),
  91. document_id=document_id,
  92. problem_id=problem_model.id,
  93. paragraph_id=paragraph_id,
  94. knowledge_id=self.knowledge_id
  95. ) for problem_model, document_id, paragraph_id in problem_model_list]
  96. result = [
  97. problem_model for problem_model, is_create in problem_content_dict.values() if is_create
  98. ], problem_paragraph_mapping_list
  99. return result
  100. def get_embedding_model_by_knowledge_id_list(knowledge_id_list: List):
  101. knowledge_list = QuerySet(Knowledge).filter(id__in=knowledge_id_list)
  102. if len(set([knowledge.embedding_model_id for knowledge in knowledge_list])) > 1:
  103. raise Exception(_('The knowledge base is inconsistent with the vector model'))
  104. if len(knowledge_list) == 0:
  105. raise Exception(_('Knowledge base setting error, please reset the knowledge base'))
  106. default_params = get_model_default_params(knowledge_list[0].embedding_model)
  107. return ModelManage.get_model(
  108. str(knowledge_list[0].embedding_model_id),
  109. lambda _id: get_model(knowledge_list[0].embedding_model, **{**default_params})
  110. )
  111. def get_embedding_model_by_knowledge_id(knowledge_id: str):
  112. knowledge = QuerySet(Knowledge).select_related('embedding_model').filter(id=knowledge_id).first()
  113. default_params = get_model_default_params(knowledge.embedding_model)
  114. return ModelManage.get_model(str(knowledge.embedding_model_id),
  115. lambda _id: get_model(knowledge.embedding_model, **{**default_params}))
  116. def get_embedding_model_by_knowledge(knowledge):
  117. default_params = get_model_default_params(knowledge.embedding_model)
  118. return ModelManage.get_model(str(knowledge.embedding_model_id),
  119. lambda _id: get_model(knowledge.embedding_model, **{**default_params}))
  120. def get_embedding_model_id_by_knowledge_id(knowledge_id):
  121. knowledge = QuerySet(Knowledge).select_related('embedding_model').filter(id=knowledge_id).first()
  122. return str(knowledge.embedding_model_id)
  123. def get_embedding_model_id_by_knowledge_id_list(knowledge_id_list: List):
  124. knowledge_list = QuerySet(Knowledge).filter(id__in=knowledge_id_list)
  125. if len(set([knowledge.embedding_model_id for knowledge in knowledge_list])) > 1:
  126. raise Exception(_('The knowledge base is inconsistent with the vector model'))
  127. if len(knowledge_list) == 0:
  128. raise Exception(_('Knowledge base setting error, please reset the knowledge base'))
  129. return str(knowledge_list[0].embedding_model_id)
  130. def zip_dir(zip_path, output=None):
  131. output = output or os.path.basename(zip_path) + '.zip'
  132. zip = zipfile.ZipFile(output, 'w', zipfile.ZIP_DEFLATED)
  133. for root, dirs, files in os.walk(zip_path):
  134. relative_root = '' if root == zip_path else root.replace(zip_path, '') + os.sep
  135. for filename in files:
  136. zip.write(os.path.join(root, filename), relative_root + filename)
  137. zip.close()
  138. def is_valid_uuid(s):
  139. try:
  140. uuid.UUID(s)
  141. return True
  142. except ValueError:
  143. return False
  144. def write_image(zip_path: str, image_list: List[str]):
  145. for image in image_list:
  146. search = re.search("\(.*\)", image)
  147. if search:
  148. text = search.group()
  149. if text.startswith('(./oss/file/'):
  150. r = text.replace('(./oss/file/', '').replace(')', '')
  151. r = r.strip().split(" ")[0]
  152. if not is_valid_uuid(r):
  153. break
  154. file = QuerySet(File).filter(id=r).first()
  155. if file is None:
  156. break
  157. zip_inner_path = os.path.join('oss', 'file', r)
  158. file_path = os.path.join(zip_path, zip_inner_path)
  159. if not os.path.exists(os.path.dirname(file_path)):
  160. os.makedirs(os.path.dirname(file_path))
  161. with open(os.path.join(zip_path, file_path), 'wb') as f:
  162. f.write(file.get_bytes())
  163. def update_document_char_length(document_id: str):
  164. update_execute(get_file_content(
  165. os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'update_document_char_length.sql')),
  166. (document_id, document_id))
  167. def list_paragraph(paragraph_list: List[str]):
  168. if paragraph_list is None or len(paragraph_list) == 0:
  169. return []
  170. return native_search(QuerySet(Paragraph).filter(id__in=paragraph_list), get_file_content(
  171. os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_paragraph.sql')))
  172. def or_get(exists_problem_list, content, knowledge_id, document_id, paragraph_id, problem_content_dict):
  173. if content in problem_content_dict:
  174. return problem_content_dict.get(content)[0], document_id, paragraph_id
  175. exists = [row for row in exists_problem_list if row.content == content]
  176. if len(exists) > 0:
  177. problem_content_dict[content] = exists[0], False
  178. return exists[0], document_id, paragraph_id
  179. else:
  180. problem = Problem(id=uuid.uuid7(), content=content, knowledge_id=knowledge_id)
  181. problem_content_dict[content] = problem, True
  182. return problem, document_id, paragraph_id
  183. def get_knowledge_operation_object(knowledge_id: str):
  184. knowledge_model = QuerySet(model=Knowledge).filter(id=knowledge_id).first()
  185. if knowledge_model is not None:
  186. return {
  187. "name": knowledge_model.name,
  188. "desc": knowledge_model.desc,
  189. "type": knowledge_model.type,
  190. "create_time": knowledge_model.create_time,
  191. "update_time": knowledge_model.update_time
  192. }
  193. return {}
  194. def create_knowledge_index(knowledge_id=None, document_id=None):
  195. if knowledge_id is None and document_id is None:
  196. raise AppApiException(500, _('Knowledge ID or Document ID must be provided'))
  197. if knowledge_id is not None:
  198. k_id = knowledge_id
  199. else:
  200. document = QuerySet(Document).filter(id=document_id).first()
  201. k_id = document.knowledge_id
  202. sql = f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = 'embedding' AND indexname = 'embedding_hnsw_idx_{k_id}'"
  203. index = sql_execute(sql, [])
  204. if not index:
  205. sql = f"SELECT vector_dims(embedding) AS dims FROM embedding WHERE knowledge_id = '{k_id}' LIMIT 1"
  206. result = sql_execute(sql, [])
  207. if len(result) == 0:
  208. return
  209. dims = result[0]['dims']
  210. # 超过2000维度不创建索引,pgvector hnsw索引不支持超过2000维度
  211. if dims < 2000:
  212. sql = f"""CREATE INDEX "embedding_hnsw_idx_{k_id}" ON embedding USING hnsw ((embedding::vector({dims})) vector_cosine_ops) WHERE knowledge_id = '{k_id}'"""
  213. update_execute(sql, [])
  214. maxkb_logger.info(f'Created index for knowledge ID: {k_id}')
  215. def drop_knowledge_index(knowledge_id=None, document_id=None):
  216. if knowledge_id is None and document_id is None:
  217. raise AppApiException(500, _('Knowledge ID or Document ID must be provided'))
  218. if knowledge_id is not None:
  219. k_id = knowledge_id
  220. else:
  221. document = QuerySet(Document).filter(id=document_id).first()
  222. k_id = document.knowledge_id
  223. sql = f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = 'embedding' AND indexname = 'embedding_hnsw_idx_{k_id}'"
  224. index = sql_execute(sql, [])
  225. if index:
  226. sql = f'DROP INDEX "embedding_hnsw_idx_{k_id}"'
  227. update_execute(sql, [])
  228. maxkb_logger.info(f'Dropped index for knowledge ID: {k_id}')
  229. def update_resource_mapping_by_knowledge(knowledge_id: str):
  230. knowledge = QuerySet(Knowledge).filter(id=knowledge_id).first()
  231. instance_mapping = get_instance_resource(knowledge, ResourceType.KNOWLEDGE, str(knowledge.id),
  232. knowledge_instance_field_call_dict)
  233. if knowledge.type == KnowledgeType.WORKFLOW:
  234. knowledge_workflow = QuerySet(KnowledgeWorkflow).filter(
  235. knowledge_id=knowledge_id).order_by(
  236. '-create_time')[0:1].first()
  237. if knowledge_workflow:
  238. save_workflow_mapping(knowledge_workflow.work_flow, ResourceType.KNOWLEDGE,
  239. str(knowledge_id), instance_mapping)
  240. return
  241. else:
  242. save_workflow_mapping({}, ResourceType.KNOWLEDGE,
  243. str(knowledge_id), instance_mapping)