snippet_view.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. """
  2. 知识片段视图路由
  3. """
import inspect
import urllib.parse
from datetime import datetime
from typing import Optional, Dict, Any

from fastapi import APIRouter, Depends, Query, Path, Body
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession

from app.base.async_mysql_connection import get_db
from app.schemas.base import ResponseSchema, PaginatedResponseSchema
from app.services.jwt_token import verify_token
from app.services.knowledge_base_service import knowledge_base_service
from app.services.snippet_service import snippet_service
  17. router = APIRouter(prefix="/document/snippet", tags=["样本中心-知识片段"])
  18. security = HTTPBearer()
  19. # Schemas
  20. class SnippetCreate(BaseModel):
  21. collection_name: str
  22. doc_name: str = "手动添加"
  23. content: str
  24. meta_info: Optional[str] = None
  25. custom_fields: Optional[Dict[str, Any]] = None
  26. class SnippetUpdate(BaseModel):
  27. collection_name: str
  28. doc_name: Optional[str] = None
  29. content: str
  30. custom_fields: Optional[Dict[str, Any]] = None
  31. @router.get("", response_model=PaginatedResponseSchema)
  32. async def get_snippets(
  33. page: int = Query(1, ge=1),
  34. page_size: int = Query(10, ge=1),
  35. kb: Optional[str] = Query(None, description="知识库集合名称"),
  36. keyword: Optional[str] = Query(None),
  37. status: Optional[str] = Query(None),
  38. credentials: HTTPAuthorizationCredentials = Depends(security)
  39. ):
  40. """获取知识片段列表 (跨集合查询)"""
  41. items, meta = await snippet_service.get_list(page, page_size, kb, keyword, status)
  42. return PaginatedResponseSchema(
  43. code=0,
  44. message="获取成功",
  45. data=items,
  46. meta=meta
  47. )
  48. @router.get("/detail", response_model=ResponseSchema)
  49. async def get_snippet_detail(
  50. kb: str = Query(..., description="知识库名称"),
  51. id: str = Query(..., description="片段ID (document_id 或 pk)"),
  52. credentials: HTTPAuthorizationCredentials = Depends(security),
  53. db: AsyncSession = Depends(get_db)
  54. ):
  55. """获取知识片段详情"""
  56. payload_token = verify_token(credentials.credentials)
  57. if not payload_token:
  58. return ResponseSchema(code=401, message="无效的访问令牌")
  59. data = await snippet_service.get_by_id(db, kb, id)
  60. if not data:
  61. return ResponseSchema(code=404, message="未找到该片段")
  62. return ResponseSchema(code=0, message="获取成功", data=data)
  63. @router.get("/export")
  64. async def export_snippets(
  65. kb: Optional[str] = Query(None, description="知识库集合名称"),
  66. keyword: Optional[str] = Query(None),
  67. status: Optional[str] = Query(None),
  68. credentials: HTTPAuthorizationCredentials = Depends(security)
  69. ):
  70. """导出知识片段"""
  71. payload_token = verify_token(credentials.credentials)
  72. if not payload_token:
  73. return ResponseSchema(code=401, message="无效的访问令牌")
  74. filename = f"snippets_export_{datetime.now().strftime('%Y%m%d%H%M%S')}.csv"
  75. encoded_filename = urllib.parse.quote(filename)
  76. return StreamingResponse(
  77. snippet_service.generate_csv_stream(kb, keyword),
  78. media_type="text/csv",
  79. headers={
  80. "Content-Disposition": f"attachment; filename={filename}; filename*=utf-8''{encoded_filename}"
  81. }
  82. )
  83. @router.post("", response_model=ResponseSchema)
  84. async def create_snippet(
  85. payload: SnippetCreate,
  86. credentials: HTTPAuthorizationCredentials = Depends(security),
  87. db: AsyncSession = Depends(get_db)
  88. ):
  89. """创建知识片段"""
  90. payload_token = verify_token(credentials.credentials)
  91. if not payload_token:
  92. return ResponseSchema(code=401, message="无效的访问令牌")
  93. data = await snippet_service.create(db, payload)
  94. return ResponseSchema(code=0, message="创建成功", data=data)
  95. @router.post("/{id}", response_model=ResponseSchema)
  96. async def update_snippet(
  97. id: str,
  98. payload: SnippetUpdate,
  99. credentials: HTTPAuthorizationCredentials = Depends(security),
  100. db: AsyncSession = Depends(get_db)
  101. ):
  102. """更新知识片段"""
  103. payload_token = verify_token(credentials.credentials)
  104. if not payload_token:
  105. return ResponseSchema(code=401, message="无效的访问令牌")
  106. msg = await snippet_service.update(db, id, payload)
  107. return ResponseSchema(code=0, message=msg)
  108. @router.post("/{id}/delete", response_model=ResponseSchema)
  109. async def delete_snippet(
  110. id: str,
  111. kb: str = Query(..., description="知识库名称"),
  112. credentials: HTTPAuthorizationCredentials = Depends(security),
  113. db: AsyncSession = Depends(get_db)
  114. ):
  115. """删除知识片段"""
  116. payload_token = verify_token(credentials.credentials)
  117. if not payload_token:
  118. return ResponseSchema(code=401, message="无效的访问令牌")
  119. snippet_service.delete(id, kb)
  120. # 更新知识库文档数量并返回最新计数,便于前端立即刷新展示
  121. new_count = await knowledge_base_service.update_doc_count(db, kb)
  122. return ResponseSchema(code=0, message="删除成功", data={"document_count": new_count})