""" 知识片段视图路由 """ from fastapi import APIRouter, Depends, Query, Path, Body from fastapi.responses import StreamingResponse from typing import Optional, Dict, Any from datetime import datetime import urllib.parse from sqlalchemy.ext.asyncio import AsyncSession from app.base.async_mysql_connection import get_db from app.services.snippet_service import snippet_service from app.services.knowledge_base_service import knowledge_base_service from app.schemas.base import ResponseSchema, PaginatedResponseSchema from app.services.jwt_token import verify_token from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from pydantic import BaseModel router = APIRouter(prefix="/document/snippet", tags=["样本中心-知识片段"]) security = HTTPBearer() # Schemas class SnippetCreate(BaseModel): collection_name: str doc_name: str = "手动添加" content: str meta_info: Optional[str] = None custom_fields: Optional[Dict[str, Any]] = None class SnippetUpdate(BaseModel): collection_name: str doc_name: Optional[str] = None content: str custom_fields: Optional[Dict[str, Any]] = None @router.get("", response_model=PaginatedResponseSchema) async def get_snippets( page: int = Query(1, ge=1), page_size: int = Query(10, ge=1), kb: Optional[str] = Query(None, description="知识库集合名称"), keyword: Optional[str] = Query(None), status: Optional[str] = Query(None), credentials: HTTPAuthorizationCredentials = Depends(security) ): """获取知识片段列表 (跨集合查询)""" items, meta = await snippet_service.get_list(page, page_size, kb, keyword, status) return PaginatedResponseSchema( code=0, message="获取成功", data=items, meta=meta ) @router.get("/detail", response_model=ResponseSchema) async def get_snippet_detail( kb: str = Query(..., description="知识库名称"), id: str = Query(..., description="片段ID (document_id 或 pk)"), credentials: HTTPAuthorizationCredentials = Depends(security), db: AsyncSession = Depends(get_db) ): """获取知识片段详情""" payload_token = verify_token(credentials.credentials) if not payload_token: return ResponseSchema(code=401, message="无效的访问令牌") data = await snippet_service.get_by_id(db, kb, id) if not data: return ResponseSchema(code=404, message="未找到该片段") return ResponseSchema(code=0, message="获取成功", data=data) @router.get("/export") async def export_snippets( kb: Optional[str] = Query(None, description="知识库集合名称"), keyword: Optional[str] = Query(None), status: Optional[str] = Query(None), credentials: HTTPAuthorizationCredentials = Depends(security) ): """导出知识片段""" payload_token = verify_token(credentials.credentials) if not payload_token: return ResponseSchema(code=401, message="无效的访问令牌") filename = f"snippets_export_{datetime.now().strftime('%Y%m%d%H%M%S')}.csv" encoded_filename = urllib.parse.quote(filename) return StreamingResponse( snippet_service.generate_csv_stream(kb, keyword), media_type="text/csv", headers={ "Content-Disposition": f"attachment; filename={filename}; filename*=utf-8''{encoded_filename}" } ) @router.post("", response_model=ResponseSchema) async def create_snippet( payload: SnippetCreate, credentials: HTTPAuthorizationCredentials = Depends(security), db: AsyncSession = Depends(get_db) ): """创建知识片段""" payload_token = verify_token(credentials.credentials) if not payload_token: return ResponseSchema(code=401, message="无效的访问令牌") data = await snippet_service.create(db, payload) return ResponseSchema(code=0, message="创建成功", data=data) @router.post("/{id}", response_model=ResponseSchema) async def update_snippet( id: str, payload: SnippetUpdate, credentials: HTTPAuthorizationCredentials = Depends(security), db: AsyncSession = Depends(get_db) ): """更新知识片段""" payload_token = verify_token(credentials.credentials) if not payload_token: return ResponseSchema(code=401, message="无效的访问令牌") msg = await snippet_service.update(db, id, payload) return ResponseSchema(code=0, message=msg) @router.post("/{id}/delete", response_model=ResponseSchema) async def delete_snippet( id: str, kb: str = Query(..., description="知识库名称"), credentials: HTTPAuthorizationCredentials = Depends(security), db: AsyncSession = Depends(get_db) ): """删除知识片段""" payload_token = verify_token(credentials.credentials) if not payload_token: return ResponseSchema(code=401, message="无效的访问令牌") snippet_service.delete(id, kb) # 更新知识库文档数量并返回最新计数,便于前端立即刷新展示 new_count = await knowledge_base_service.update_doc_count(db, kb) return ResponseSchema(code=0, message="删除成功", data={"document_count": new_count})