| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217 |
- # coding=utf-8
- """
- 样本中心知识库对接视图
- 提供样本中心知识库查询、批量入库等功能
- """
- import uuid_utils.compat as uuid
- from django.utils.translation import gettext as _
- from rest_framework.request import Request
- from rest_framework.views import APIView
- from common.auth import TokenAuth
- from common.exception.app_exception import AppApiException
- from common.result import result
- from knowledge.services.sample_center_client import get_sample_center_client
# Connection parameters every sample-center proxy endpoint requires.
_CONNECTION_KEYS = ('base_url', 'app_id', 'app_secret')


def _read_params(source, keys):
    """Return ``{key: source.get(key, '')}`` for every name in *keys*.

    *source* is ``request.query_params`` or ``request.data``. Missing keys map
    to ``''`` so callers can validate them all at once with ``all(...)``.
    """
    return {key: source.get(key, '') for key in keys}


def _client_from(params):
    """Build a sample-center client from a dict holding base_url/app_id/app_secret."""
    return get_sample_center_client(
        base_url=params['base_url'],
        app_id=params['app_id'],
        app_secret=params['app_secret'],
    )


def _parse_positive_int(raw, name, default):
    """Parse an optional positive-integer request parameter.

    Returns *default* when *raw* is ``None`` or ``''``.

    Raises:
        AppApiException(400): *raw* is present but is not an integer >= 1.
    """
    if raw in (None, ''):
        return default
    try:
        value = int(raw)
    except (TypeError, ValueError):
        value = 0  # fall through to the shared error below
    if value < 1:
        raise AppApiException(400, _('%s must be a positive integer') % name)
    return value


def _new_task_no():
    """Return a task number: 'IMP' + the first 16 hex digits of a UUID4, upper-cased."""
    return f"IMP{uuid.uuid4().hex[:16].upper()}"


class SampleCenterView(APIView):
    """Integration endpoints for the external sample-center knowledge base.

    The outer class only groups the nested views below.

    NOTE(review): the proxy endpoints take ``base_url`` from the caller and make
    server-side requests to it. Any authenticated user can therefore point the
    server at arbitrary hosts (SSRF risk). Consider allow-listing hosts or reading
    the connection settings from server-side config.
    """
    authentication_classes = [TokenAuth]

    class ListKnowledgeBases(APIView):
        """List the knowledge bases available in the sample center."""
        authentication_classes = [TokenAuth]

        def get(self, request: Request):
            """Proxy a paginated knowledge-base listing to the sample center.

            Query params: ``page`` (default 1), ``page_size`` (default 20), and
            the required ``base_url``, ``app_id``, ``app_secret``.

            Raises:
                AppApiException(400): a connection param is missing, or
                    ``page`` / ``page_size`` is not a positive integer.
            """
            # NOTE(review): app_secret travels in the URL query string here, so it
            # may be recorded in access logs / proxies.
            page = _parse_positive_int(request.query_params.get('page'), 'page', 1)
            page_size = _parse_positive_int(
                request.query_params.get('page_size'), 'page_size', 20
            )
            params = _read_params(request.query_params, _CONNECTION_KEYS)
            if not all(params.values()):
                raise AppApiException(400, _('base_url, app_id, app_secret are required'))
            data = _client_from(params).list_knowledge_bases(page=page, page_size=page_size)
            return result.success(data)

    class GetKnowledgeBase(APIView):
        """Fetch the details of one sample-center knowledge base."""
        authentication_classes = [TokenAuth]

        def get(self, request: Request):
            """Proxy a knowledge-base detail lookup to the sample center.

            Query params (all required): ``kb_id``, ``base_url``, ``app_id``,
            ``app_secret``.

            Raises:
                AppApiException(400): a required param is missing.
            """
            params = _read_params(request.query_params, ('kb_id',) + _CONNECTION_KEYS)
            if not all(params.values()):
                raise AppApiException(400, _('kb_id, base_url, app_id, app_secret are required'))
            data = _client_from(params).get_knowledge_base(params['kb_id'])
            return result.success(data)

    class BatchImport(APIView):
        """Submit a batch-import task (forwarded directly to the sample center)."""
        authentication_classes = [TokenAuth]

        def post(self, request: Request):
            """Forward caller-supplied ``parents`` / ``children`` to the sample center.

            Body fields:
                Required: ``kb_id``, ``base_url``, ``app_id``, ``app_secret``,
                    and a non-empty list ``parents``.
                Optional: ``children`` (list) and ``callback_url``.

            A fresh task number is generated for each call.

            Raises:
                AppApiException(400): a required field is missing, or
                    ``parents`` / ``children`` is not a list.
            """
            params = _read_params(request.data, ('kb_id',) + _CONNECTION_KEYS)
            parents = request.data.get('parents', [])
            # Treat an explicit null the same as an omitted field.
            children = request.data.get('children') or []
            callback_url = request.data.get('callback_url', '')
            if not all(params.values()) or not parents:
                raise AppApiException(
                    400, _('kb_id, base_url, app_id, app_secret, parents are required')
                )
            if not isinstance(parents, list) or not isinstance(children, list):
                raise AppApiException(400, _('parents and children must be lists'))
            data = _client_from(params).batch_import(
                kb_id=params['kb_id'],
                task_no=_new_task_no(),
                parents=parents,
                children=children,
                callback_url=callback_url,
            )
            return result.success(data)

    class GetImportTask(APIView):
        """Query the status of a batch-import task."""
        authentication_classes = [TokenAuth]

        def get(self, request: Request):
            """Proxy an import-task status lookup to the sample center.

            Query params (all required): ``task_id``, ``base_url``, ``app_id``,
            ``app_secret``.

            Raises:
                AppApiException(400): a required param is missing.
            """
            params = _read_params(request.query_params, ('task_id',) + _CONNECTION_KEYS)
            if not all(params.values()):
                raise AppApiException(400, _('task_id, base_url, app_id, app_secret are required'))
            data = _client_from(params).get_import_task(params['task_id'])
            return result.success(data)

    class SyncDocuments(APIView):
        """Push paragraphs of local knowledge-base documents to the sample center."""
        authentication_classes = [TokenAuth]

        def post(self, request: Request, workspace_id: str, knowledge_id: str):
            """Batch-import the active paragraphs of the selected documents.

            The sample-center connection (``base_url``, ``app_id``, ``app_secret``,
            ``kb_id``) is read from ``knowledge.meta['sample_center']``. Only
            documents that belong to *knowledge_id* are synced. Ids in
            ``document_ids`` that do not belong to it are ignored.

            Returns a summary: ``task_id``, ``task_no``, ``status``,
            ``total_paragraphs``, ``total_documents``.

            Raises:
                AppApiException(400): ``document_ids`` is missing or not a list;
                    the knowledge base has no sample-center link; or no valid
                    documents / active paragraphs were found.
                AppApiException(404): the knowledge base does not exist in
                    *workspace_id*.
            """
            # Imported lazily, presumably to avoid an import cycle - TODO confirm.
            from knowledge.models import Knowledge, Document, Paragraph, DocumentTag

            # NOTE(review): no permission check is visible here beyond TokenAuth.
            # Confirm that the caller is authorized for workspace_id.
            document_ids = request.data.get('document_ids', [])
            if not document_ids:
                raise AppApiException(400, _('document_ids is required'))
            if not isinstance(document_ids, (list, tuple)):
                raise AppApiException(400, _('document_ids must be a list'))

            try:
                knowledge = Knowledge.objects.get(id=knowledge_id, workspace_id=workspace_id)
            except Knowledge.DoesNotExist:
                raise AppApiException(404, _('Knowledge base not found')) from None

            # Tolerate meta being None and 'sample_center' being stored as None.
            sample_center_config = (knowledge.meta or {}).get('sample_center') or {}
            base_url = sample_center_config.get('base_url', '')
            app_id = sample_center_config.get('app_id', '')
            app_secret = sample_center_config.get('app_secret', '')
            kb_id = sample_center_config.get('kb_id', '')
            if not all([base_url, app_id, app_secret, kb_id]):
                raise AppApiException(400, _('Knowledge base is not linked to a sample center'))

            # Keep only documents that really belong to this knowledge base. All
            # later queries use these vetted ids, never the raw request ids.
            doc_rows = list(
                Document.objects.filter(
                    id__in=document_ids, knowledge_id=knowledge_id
                ).values_list('id', 'name')
            )
            if not doc_rows:
                raise AppApiException(400, _('No valid documents found'))
            valid_doc_ids = [row[0] for row in doc_rows]
            doc_name_map = {str(doc_id): name for doc_id, name in doc_rows}

            paragraphs = list(
                Paragraph.objects.filter(
                    document_id__in=valid_doc_ids, is_active=True
                ).order_by('document_id', 'position')
            )
            if not paragraphs:
                raise AppApiException(400, _('No paragraphs found in selected documents'))

            # document id (str) -> ["key:value", ...] tag labels
            doc_tags = {}
            tag_mappings = DocumentTag.objects.filter(
                document_id__in=valid_doc_ids
            ).select_related('tag')
            for mapping in tag_mappings:
                doc_tags.setdefault(str(mapping.document_id), []).append(
                    f"{mapping.tag.key}:{mapping.tag.value}"
                )

            # Each paragraph becomes one "parent" entry in the sample-center payload.
            parents = []
            for idx, paragraph in enumerate(paragraphs):
                doc_id = str(paragraph.document_id)
                parents.append({
                    'index': idx,
                    'parent_id': kb_id,
                    'hierarchy': doc_name_map.get(doc_id, ''),
                    'text': paragraph.content,
                    'metadata': {
                        'source_knowledge_id': str(knowledge_id),
                        'source_document_id': doc_id,
                        'source_paragraph_id': str(paragraph.id),
                        'title': paragraph.title or '',
                    },
                    'doc_id': doc_id,
                    'tag_list': doc_tags.get(doc_id, []),
                })

            task_no = _new_task_no()
            client = get_sample_center_client(
                base_url=base_url,
                app_id=app_id,
                app_secret=app_secret,
            )
            # Guard against a None response so the summary below cannot crash.
            data = client.batch_import(
                kb_id=kb_id,
                task_no=task_no,
                parents=parents,
            ) or {}
            return result.success({
                'task_id': data.get('task_id', ''),
                'task_no': task_no,
                'status': data.get('status', 'pending'),
                'total_paragraphs': len(parents),
                'total_documents': len(doc_name_map),
            })
|