| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142 |
- """
- 知识片段视图路由
- """
import inspect
import urllib.parse
from datetime import datetime
from typing import Any, Dict, Optional

from fastapi import APIRouter, Body, Depends, Path, Query
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession

from app.base.async_mysql_connection import get_db
from app.schemas.base import PaginatedResponseSchema, ResponseSchema
from app.services.jwt_token import verify_token
from app.services.knowledge_base_service import knowledge_base_service
from app.services.snippet_service import snippet_service
# Router for the knowledge-snippet endpoints; every route below is mounted
# under /document/snippet and grouped under a single OpenAPI tag.
router = APIRouter(
    prefix="/document/snippet",
    tags=["样本中心-知识片段"],
)

# Shared bearer-token extractor. Each endpoint declares it as a dependency;
# whether the token is actually valid is checked separately via verify_token.
security = HTTPBearer()
# Request-body schemas
class SnippetCreate(BaseModel):
    """Request body for creating a knowledge snippet (``POST ""``)."""

    # Knowledge-base collection name; presumably the same value the GET
    # endpoints take as ``kb`` — confirm in snippet_service.create.
    collection_name: str
    # Source document name; defaults to "手动添加" ("added manually").
    doc_name: str = "手动添加"
    # Snippet text content.
    content: str
    # Optional extra metadata string; its expected format is not visible
    # here — TODO confirm against snippet_service.create.
    meta_info: Optional[str] = None
    # Optional free-form custom fields (arbitrary string-keyed mapping).
    custom_fields: Optional[Dict[str, Any]] = None
class SnippetUpdate(BaseModel):
    """Request body for updating an existing knowledge snippet (``POST /{id}``)."""

    # Knowledge-base collection the snippet lives in.
    collection_name: str
    # New document name; None presumably means "leave unchanged" — confirm
    # in snippet_service.update.
    doc_name: Optional[str] = None
    # Replacement snippet text content (required).
    content: str
    # Optional free-form custom fields (arbitrary string-keyed mapping).
    custom_fields: Optional[Dict[str, Any]] = None
@router.get("", response_model=PaginatedResponseSchema)
async def get_snippets(
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1),
    kb: Optional[str] = Query(None, description="知识库集合名称"),
    keyword: Optional[str] = Query(None),
    status: Optional[str] = Query(None),
    credentials: HTTPAuthorizationCredentials = Depends(security),
):
    """List knowledge snippets (cross-collection query).

    Args:
        page: 1-based page number (validated ``>= 1``).
        page_size: Items per page (validated ``>= 1``).
        kb: Knowledge-base collection name; when omitted the service
            presumably queries across all collections — confirm in
            ``snippet_service.get_list``.
        keyword: Optional keyword filter, forwarded to the service.
        status: Optional status filter, forwarded to the service.
        credentials: Bearer credentials extracted by ``HTTPBearer``.

    Returns:
        A ``PaginatedResponseSchema`` holding the page items and pagination
        meta, or a ``code=401`` envelope when the token fails verification.
    """
    # HTTPBearer only ensures an "Authorization: Bearer ..." header exists;
    # it does not validate the token. Every other endpoint in this router
    # calls verify_token, and the list endpoint previously skipped it.
    if not verify_token(credentials.credentials):
        # The 401 envelope carries no pagination fields, so return it as a
        # raw JSONResponse. That bypasses PaginatedResponseSchema validation
        # while producing the same envelope the other endpoints use.
        return JSONResponse(
            content=jsonable_encoder(
                ResponseSchema(code=401, message="无效的访问令牌")
            )
        )

    items, meta = await snippet_service.get_list(page, page_size, kb, keyword, status)

    return PaginatedResponseSchema(
        code=0,
        message="获取成功",
        data=items,
        meta=meta,
    )
@router.get("/detail", response_model=ResponseSchema)
async def get_snippet_detail(
    kb: str = Query(..., description="知识库名称"),
    id: str = Query(..., description="片段ID (document_id 或 pk)"),
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db),
):
    """Fetch a single knowledge snippet by id.

    Args:
        kb: Knowledge-base name the snippet belongs to.
        id: Snippet identifier (document_id or primary key, per the query
            parameter description).
        credentials: Bearer credentials extracted by ``HTTPBearer``.
        db: Async database session.

    Returns:
        ``code=0`` with the snippet as ``data``. Returns ``code=404`` when the
        service yields a falsy result, and ``code=401`` when the token is rejected.
    """
    token_ok = verify_token(credentials.credentials)
    if token_ok:
        snippet = await snippet_service.get_by_id(db, kb, id)
        if snippet:
            return ResponseSchema(code=0, message="获取成功", data=snippet)
        return ResponseSchema(code=404, message="未找到该片段")
    return ResponseSchema(code=401, message="无效的访问令牌")
@router.get("/export")
async def export_snippets(
    kb: Optional[str] = Query(None, description="知识库集合名称"),
    keyword: Optional[str] = Query(None),
    status: Optional[str] = Query(None),
    credentials: HTTPAuthorizationCredentials = Depends(security),
):
    """Stream the matching knowledge snippets as a CSV attachment.

    The download is named ``snippets_export_<YYYYmmddHHMMSS>.csv``. The name
    uses the server's local time, and is sent in both the plain ``filename``
    form and the RFC 5987 ``filename*`` form.

    Returns:
        A ``StreamingResponse`` (``text/csv``) fed by
        ``snippet_service.generate_csv_stream``, or a ``code=401`` envelope
        when the token is rejected.
    """
    if not verify_token(credentials.credentials):
        return ResponseSchema(code=401, message="无效的访问令牌")

    stamp = datetime.now().strftime("%Y%m%d%H%M%S")
    csv_name = f"snippets_export_{stamp}.csv"
    disposition = (
        f"attachment; filename={csv_name}; "
        f"filename*=utf-8''{urllib.parse.quote(csv_name)}"
    )

    # NOTE(review): `status` is accepted but NOT forwarded here, so the status
    # filter has no effect on exports (the list endpoint does pass it to
    # get_list). Confirm whether generate_csv_stream accepts a status
    # argument before wiring it through.
    rows = snippet_service.generate_csv_stream(kb, keyword)
    return StreamingResponse(
        rows,
        media_type="text/csv",
        headers={"Content-Disposition": disposition},
    )
@router.post("", response_model=ResponseSchema)
async def create_snippet(
    payload: SnippetCreate,
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db),
):
    """Create a knowledge snippet from the request body.

    Returns:
        ``code=0`` with whatever ``snippet_service.create`` returns as
        ``data``, or a ``code=401`` envelope when the token is rejected.
    """
    if verify_token(credentials.credentials):
        created = await snippet_service.create(db, payload)
        return ResponseSchema(code=0, message="创建成功", data=created)
    return ResponseSchema(code=401, message="无效的访问令牌")
@router.post("/{id}", response_model=ResponseSchema)
async def update_snippet(
    id: str,
    payload: SnippetUpdate,
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db),
):
    """Update the knowledge snippet identified by the ``id`` path segment.

    Returns:
        ``code=0`` with the message string produced by
        ``snippet_service.update``, or a ``code=401`` envelope when the
        token is rejected.
    """
    if verify_token(credentials.credentials):
        result_message = await snippet_service.update(db, id, payload)
        return ResponseSchema(code=0, message=result_message)
    return ResponseSchema(code=401, message="无效的访问令牌")
@router.post("/{id}/delete", response_model=ResponseSchema)
async def delete_snippet(
    id: str,
    kb: str = Query(..., description="知识库名称"),
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db),
):
    """Delete a knowledge snippet and return the refreshed document count.

    Args:
        id: Snippet identifier taken from the path.
        kb: Knowledge-base name the snippet belongs to.
        credentials: Bearer credentials extracted by ``HTTPBearer``.
        db: Async database session, used for the document-count update.

    Returns:
        ``code=0`` with ``{"document_count": <new count>}`` as ``data``, or a
        ``code=401`` envelope when the token is rejected.
    """
    if not verify_token(credentials.credentials):
        return ResponseSchema(code=401, message="无效的访问令牌")

    # Every other snippet_service call in this router is awaited. If delete is
    # a coroutine function, calling it without `await` would never run the
    # deletion, and the count below would still include the snippet. Await
    # the result when it is awaitable, so this is correct whether delete is
    # sync or async.
    outcome = snippet_service.delete(id, kb)
    if inspect.isawaitable(outcome):
        await outcome

    # Recompute the knowledge base's document count and return it so the
    # frontend can refresh its display immediately.
    new_count = await knowledge_base_service.update_doc_count(db, kb)
    return ResponseSchema(code=0, message="删除成功", data={"document_count": new_count})
|