# coding=utf-8
"""
Sample-center knowledge base integration views.

Endpoints for listing/inspecting knowledge bases hosted by the sample center,
submitting and polling batch import tasks, and pushing the paragraphs of local
knowledge-base documents to a linked sample-center knowledge base.
"""
import uuid_utils.compat as uuid
from django.utils.translation import gettext as _
from rest_framework.request import Request
from rest_framework.views import APIView

from common.auth import TokenAuth
from common.exception.app_exception import AppApiException
from common.result import result
from knowledge.services.sample_center_client import get_sample_center_client


def _parse_positive_int(raw, name: str, default: int) -> int:
    """Parse a paging parameter as an integer >= 1.

    Missing or empty values fall back to ``default``. Malformed or
    non-positive values raise ``AppApiException(400)`` rather than letting
    ``int()`` raise ``ValueError``, which would otherwise surface as a 500.
    """
    if raw is None or raw == '':
        return default
    try:
        value = int(raw)
    except (TypeError, ValueError):
        value = 0  # treated as invalid below
    if value < 1:
        raise AppApiException(400, _('%(name)s must be a positive integer') % {'name': name})
    return value


def _new_task_no() -> str:
    """Return a new task number: ``IMP`` followed by 16 upper-case hex chars of a UUID4."""
    return f"IMP{uuid.uuid4().hex[:16].upper()}"


def _paragraph_to_parent(index: int, paragraph, kb_id: str, knowledge_id: str,
                         doc_name_map: dict, doc_tags: dict) -> dict:
    """Convert one local paragraph into a sample-center ``parents`` entry.

    ``doc_name_map`` maps document id (str) -> document name and ``doc_tags``
    maps document id (str) -> list of ``"key:value"`` tag strings.
    """
    doc_id = str(paragraph.document_id)
    return {
        'index': index,
        # NOTE(review): parent_id carries the sample-center kb_id rather than a
        # paragraph-level id; confirm this is what the batch-import API expects.
        'parent_id': kb_id,
        'hierarchy': doc_name_map.get(doc_id, ''),
        'text': paragraph.content,
        'metadata': {
            'source_knowledge_id': str(knowledge_id),
            'source_document_id': doc_id,
            'source_paragraph_id': str(paragraph.id),
            'title': paragraph.title or '',
        },
        'doc_id': doc_id,
        'tag_list': doc_tags.get(doc_id, []),
    }


class SampleCenterView(APIView):
    """Namespace for the sample-center knowledge base integration endpoints."""
    authentication_classes = [TokenAuth]

    # NOTE(review): the GET endpoints below receive app_secret as a query
    # parameter, so it can end up in access logs and proxy logs; consider
    # moving credentials to headers or server-side configuration.

    class ListKnowledgeBases(APIView):
        """List the knowledge bases available on the sample center (paged)."""
        authentication_classes = [TokenAuth]

        def get(self, request: Request):
            params = request.query_params
            page = _parse_positive_int(params.get('page'), 'page', 1)
            page_size = _parse_positive_int(params.get('page_size'), 'page_size', 20)
            base_url = params.get('base_url', '')
            app_id = params.get('app_id', '')
            app_secret = params.get('app_secret', '')
            if not all([base_url, app_id, app_secret]):
                raise AppApiException(400, _('base_url, app_id, app_secret are required'))

            client = get_sample_center_client(
                base_url=base_url,
                app_id=app_id,
                app_secret=app_secret,
            )
            data = client.list_knowledge_bases(page=page, page_size=page_size)
            return result.success(data)

    class GetKnowledgeBase(APIView):
        """Fetch the details of one sample-center knowledge base."""
        authentication_classes = [TokenAuth]

        def get(self, request: Request):
            params = request.query_params
            kb_id = params.get('kb_id', '')
            base_url = params.get('base_url', '')
            app_id = params.get('app_id', '')
            app_secret = params.get('app_secret', '')
            if not all([kb_id, base_url, app_id, app_secret]):
                raise AppApiException(400, _('kb_id, base_url, app_id, app_secret are required'))

            client = get_sample_center_client(
                base_url=base_url,
                app_id=app_id,
                app_secret=app_secret,
            )
            data = client.get_knowledge_base(kb_id)
            return result.success(data)

    class BatchImport(APIView):
        """Submit a batch import task, forwarding the payload to the sample center as-is."""
        authentication_classes = [TokenAuth]

        def post(self, request: Request):
            payload = request.data
            kb_id = payload.get('kb_id', '')
            base_url = payload.get('base_url', '')
            app_id = payload.get('app_id', '')
            app_secret = payload.get('app_secret', '')
            parents = payload.get('parents', [])
            children = payload.get('children', [])
            callback_url = payload.get('callback_url', '')
            if not all([kb_id, base_url, app_id, app_secret, parents]):
                raise AppApiException(400, _('kb_id, base_url, app_id, app_secret, parents are required'))
            # Reject malformed item collections here instead of forwarding them
            # (e.g. a string would otherwise be sent as "parents").
            if not isinstance(parents, list) or (children and not isinstance(children, list)):
                raise AppApiException(400, _('parents and children must be lists'))

            client = get_sample_center_client(
                base_url=base_url,
                app_id=app_id,
                app_secret=app_secret,
            )
            data = client.batch_import(
                kb_id=kb_id,
                task_no=_new_task_no(),
                parents=parents,
                children=children,
                callback_url=callback_url,
            )
            return result.success(data)

    class GetImportTask(APIView):
        """Query the status of a batch import task on the sample center."""
        authentication_classes = [TokenAuth]

        def get(self, request: Request):
            params = request.query_params
            task_id = params.get('task_id', '')
            base_url = params.get('base_url', '')
            app_id = params.get('app_id', '')
            app_secret = params.get('app_secret', '')
            if not all([task_id, base_url, app_id, app_secret]):
                raise AppApiException(400, _('task_id, base_url, app_id, app_secret are required'))

            client = get_sample_center_client(
                base_url=base_url,
                app_id=app_id,
                app_secret=app_secret,
            )
            data = client.get_import_task(task_id)
            return result.success(data)

    class SyncDocuments(APIView):
        """Push the active paragraphs of selected local documents to the linked sample center."""
        authentication_classes = [TokenAuth]

        def post(self, request: Request, workspace_id: str, knowledge_id: str):
            from knowledge.models import Knowledge, Document, Paragraph, DocumentTag

            document_ids = request.data.get('document_ids', [])
            if not document_ids:
                raise AppApiException(400, _('document_ids is required'))
            # A bare string would be iterated character-by-character by `__in`.
            if not isinstance(document_ids, (list, tuple)):
                raise AppApiException(400, _('document_ids must be a list'))

            try:
                knowledge = Knowledge.objects.get(id=knowledge_id, workspace_id=workspace_id)
            except Knowledge.DoesNotExist:
                raise AppApiException(404, _('Knowledge base not found'))

            # Sample-center link settings live under meta['sample_center'];
            # tolerate a null meta or null sub-config instead of crashing.
            sample_center_config = (knowledge.meta or {}).get('sample_center') or {}
            base_url = sample_center_config.get('base_url', '')
            app_id = sample_center_config.get('app_id', '')
            app_secret = sample_center_config.get('app_secret', '')
            kb_id = sample_center_config.get('kb_id', '')
            if not all([base_url, app_id, app_secret, kb_id]):
                raise AppApiException(400, _('Knowledge base is not linked to a sample center'))

            # Resolve the requested ids to documents that really belong to this
            # knowledge base (one query; also yields the name map).
            doc_name_map = {
                str(doc_id): name
                for doc_id, name in Document.objects.filter(
                    id__in=document_ids, knowledge_id=knowledge_id
                ).values_list('id', 'name')
            }
            if not doc_name_map:
                raise AppApiException(400, _('No valid documents found'))
            # Every later query is scoped to these validated ids, never to the
            # raw request ids, so documents from other knowledge bases can't leak.
            valid_doc_ids = list(doc_name_map)

            paragraphs = list(
                Paragraph.objects.filter(
                    document_id__in=valid_doc_ids, is_active=True
                ).order_by('document_id', 'position')
            )
            if not paragraphs:
                raise AppApiException(400, _('No paragraphs found in selected documents'))

            # document id (str) -> ["key:value", ...]
            doc_tags = {}
            tag_mappings = DocumentTag.objects.filter(
                document_id__in=valid_doc_ids
            ).select_related('tag')
            for mapping in tag_mappings:
                doc_tags.setdefault(str(mapping.document_id), []).append(
                    f"{mapping.tag.key}:{mapping.tag.value}"
                )

            # Each paragraph becomes one sample-center "parents" entry.
            parents = [
                _paragraph_to_parent(idx, paragraph, kb_id, knowledge_id, doc_name_map, doc_tags)
                for idx, paragraph in enumerate(paragraphs)
            ]

            task_no = _new_task_no()
            client = get_sample_center_client(
                base_url=base_url,
                app_id=app_id,
                app_secret=app_secret,
            )
            # Guard against a client that returns None instead of a dict.
            data = client.batch_import(
                kb_id=kb_id,
                task_no=task_no,
                parents=parents,
            ) or {}

            return result.success({
                'task_id': data.get('task_id', ''),
                'task_no': task_no,
                'status': data.get('status', 'pending'),
                'total_paragraphs': len(parents),
                'total_documents': len(doc_name_map),
            })