sample_center.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. # coding=utf-8
  2. """
  3. 样本中心知识库对接视图
  4. 提供样本中心知识库查询、批量入库等功能
  5. """
  6. import uuid_utils.compat as uuid
  7. from django.utils.translation import gettext as _
  8. from rest_framework.request import Request
  9. from rest_framework.views import APIView
  10. from common.auth import TokenAuth
  11. from common.exception.app_exception import AppApiException
  12. from common.result import result
  13. from knowledge.services.sample_center_client import get_sample_center_client
  14. class SampleCenterView(APIView):
  15. """样本中心知识库对接"""
  16. authentication_classes = [TokenAuth]
  17. class ListKnowledgeBases(APIView):
  18. """查询样本中心知识库列表"""
  19. authentication_classes = [TokenAuth]
  20. def get(self, request: Request):
  21. page = int(request.query_params.get('page', 1))
  22. page_size = int(request.query_params.get('page_size', 20))
  23. base_url = request.query_params.get('base_url', '')
  24. app_id = request.query_params.get('app_id', '')
  25. app_secret = request.query_params.get('app_secret', '')
  26. if not all([base_url, app_id, app_secret]):
  27. raise AppApiException(400, _('base_url, app_id, app_secret are required'))
  28. client = get_sample_center_client(
  29. base_url=base_url,
  30. app_id=app_id,
  31. app_secret=app_secret,
  32. )
  33. data = client.list_knowledge_bases(page=page, page_size=page_size)
  34. return result.success(data)
  35. class GetKnowledgeBase(APIView):
  36. """查询样本中心知识库详情"""
  37. authentication_classes = [TokenAuth]
  38. def get(self, request: Request):
  39. kb_id = request.query_params.get('kb_id', '')
  40. base_url = request.query_params.get('base_url', '')
  41. app_id = request.query_params.get('app_id', '')
  42. app_secret = request.query_params.get('app_secret', '')
  43. if not all([kb_id, base_url, app_id, app_secret]):
  44. raise AppApiException(400, _('kb_id, base_url, app_id, app_secret are required'))
  45. client = get_sample_center_client(
  46. base_url=base_url,
  47. app_id=app_id,
  48. app_secret=app_secret,
  49. )
  50. data = client.get_knowledge_base(kb_id)
  51. return result.success(data)
  52. class BatchImport(APIView):
  53. """提交批量入库任务(直接转发到样本中心)"""
  54. authentication_classes = [TokenAuth]
  55. def post(self, request: Request):
  56. kb_id = request.data.get('kb_id', '')
  57. base_url = request.data.get('base_url', '')
  58. app_id = request.data.get('app_id', '')
  59. app_secret = request.data.get('app_secret', '')
  60. parents = request.data.get('parents', [])
  61. children = request.data.get('children', [])
  62. callback_url = request.data.get('callback_url', '')
  63. if not all([kb_id, base_url, app_id, app_secret, parents]):
  64. raise AppApiException(400, _('kb_id, base_url, app_id, app_secret, parents are required'))
  65. task_no = f"IMP{uuid.uuid4().hex[:16].upper()}"
  66. client = get_sample_center_client(
  67. base_url=base_url,
  68. app_id=app_id,
  69. app_secret=app_secret,
  70. )
  71. data = client.batch_import(
  72. kb_id=kb_id,
  73. task_no=task_no,
  74. parents=parents,
  75. children=children,
  76. callback_url=callback_url,
  77. )
  78. return result.success(data)
  79. class GetImportTask(APIView):
  80. """查询批量入库任务状态"""
  81. authentication_classes = [TokenAuth]
  82. def get(self, request: Request):
  83. task_id = request.query_params.get('task_id', '')
  84. base_url = request.query_params.get('base_url', '')
  85. app_id = request.query_params.get('app_id', '')
  86. app_secret = request.query_params.get('app_secret', '')
  87. if not all([task_id, base_url, app_id, app_secret]):
  88. raise AppApiException(400, _('task_id, base_url, app_id, app_secret are required'))
  89. client = get_sample_center_client(
  90. base_url=base_url,
  91. app_id=app_id,
  92. app_secret=app_secret,
  93. )
  94. data = client.get_import_task(task_id)
  95. return result.success(data)
  96. class SyncDocuments(APIView):
  97. """将本地知识库文档段落推送到样本中心"""
  98. authentication_classes = [TokenAuth]
  99. def post(self, request: Request, workspace_id: str, knowledge_id: str):
  100. from knowledge.models import Knowledge, Document, Paragraph, DocumentTag, Tag
  101. document_ids = request.data.get('document_ids', [])
  102. if not document_ids:
  103. raise AppApiException(400, _('document_ids is required'))
  104. # 获取知识库信息
  105. try:
  106. knowledge = Knowledge.objects.get(id=knowledge_id, workspace_id=workspace_id)
  107. except Knowledge.DoesNotExist:
  108. raise AppApiException(404, _('Knowledge base not found'))
  109. # 获取样本中心配置
  110. sample_center_config = knowledge.meta.get('sample_center', {})
  111. base_url = sample_center_config.get('base_url', '')
  112. app_id = sample_center_config.get('app_id', '')
  113. app_secret = sample_center_config.get('app_secret', '')
  114. kb_id = sample_center_config.get('kb_id', '')
  115. if not all([base_url, app_id, app_secret, kb_id]):
  116. raise AppApiException(400, _('Knowledge base is not linked to a sample center'))
  117. # 查询文档
  118. documents = Document.objects.filter(
  119. id__in=document_ids,
  120. knowledge_id=knowledge_id
  121. )
  122. if not documents.exists():
  123. raise AppApiException(400, _('No valid documents found'))
  124. # 查询段落
  125. paragraphs = Paragraph.objects.filter(
  126. document_id__in=document_ids,
  127. is_active=True
  128. ).order_by('document_id', 'position')
  129. if not paragraphs.exists():
  130. raise AppApiException(400, _('No paragraphs found in selected documents'))
  131. # 查询文档标签
  132. doc_tags = {}
  133. tag_mappings = DocumentTag.objects.filter(
  134. document_id__in=document_ids
  135. ).select_related('tag')
  136. for mapping in tag_mappings:
  137. doc_id = str(mapping.document_id)
  138. if doc_id not in doc_tags:
  139. doc_tags[doc_id] = []
  140. doc_tags[doc_id].append(f"{mapping.tag.key}:{mapping.tag.value}")
  141. # 构建文档名称映射
  142. doc_name_map = {str(doc.id): doc.name for doc in documents}
  143. # 将段落转换为样本中心的 parents 格式
  144. # 每个段落作为一个 parent 条目
  145. parents = []
  146. for idx, paragraph in enumerate(paragraphs):
  147. doc_id = str(paragraph.document_id)
  148. parent_item = {
  149. 'index': idx,
  150. 'parent_id': kb_id,
  151. 'hierarchy': doc_name_map.get(doc_id, ''),
  152. 'text': paragraph.content,
  153. 'metadata': {
  154. 'source_knowledge_id': str(knowledge_id),
  155. 'source_document_id': doc_id,
  156. 'source_paragraph_id': str(paragraph.id),
  157. 'title': paragraph.title or '',
  158. },
  159. 'doc_id': doc_id,
  160. 'tag_list': doc_tags.get(doc_id, []),
  161. }
  162. parents.append(parent_item)
  163. # 生成任务号
  164. task_no = f"IMP{uuid.uuid4().hex[:16].upper()}"
  165. # 调用样本中心批量入库接口
  166. client = get_sample_center_client(
  167. base_url=base_url,
  168. app_id=app_id,
  169. app_secret=app_secret,
  170. )
  171. data = client.batch_import(
  172. kb_id=kb_id,
  173. task_no=task_no,
  174. parents=parents,
  175. )
  176. return result.success({
  177. 'task_id': data.get('task_id', ''),
  178. 'task_no': task_no,
  179. 'status': data.get('status', 'pending'),
  180. 'total_paragraphs': len(parents),
  181. 'total_documents': documents.count(),
  182. })