knowledge_base_view.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. """
  2. 知识库视图路由
  3. """
  4. from fastapi import APIRouter, Depends, Query, Path
  5. from sqlalchemy.ext.asyncio import AsyncSession
  6. from typing import Optional
  7. from app.base.async_mysql_connection import get_db
  8. from app.services.knowledge_base_service import knowledge_base_service
  9. from app.sample.schemas.knowledge_base import (
  10. KnowledgeBaseCreate,
  11. KnowledgeBaseUpdate,
  12. KnowledgeBaseResponse
  13. )
  14. from app.schemas.base import PaginatedResponseSchema, ResponseSchema
  15. from app.services.jwt_token import verify_token
  16. from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
  17. router = APIRouter(prefix="/sample/knowledge-base", tags=["样本中心-知识库"])
  18. security = HTTPBearer()
  19. @router.get("", response_model=PaginatedResponseSchema)
  20. async def get_knowledge_bases(
  21. page: int = Query(1, ge=1, description="页码"),
  22. page_size: int = Query(10, ge=1, le=100, description="每页数量"),
  23. keyword: str = Query(None, description="搜索关键词"),
  24. status: str = Query(None, description="状态筛选"),
  25. db: AsyncSession = Depends(get_db),
  26. credentials: HTTPAuthorizationCredentials = Depends(security)
  27. ):
  28. """获取知识库列表"""
  29. # 鉴权
  30. payload = verify_token(credentials.credentials)
  31. if not payload:
  32. return PaginatedResponseSchema(code=401, message="无效的访问令牌", data=[], meta=None)
  33. items, meta = await knowledge_base_service.get_list(
  34. db, page=page, page_size=page_size, keyword=keyword, status=status
  35. )
  36. return PaginatedResponseSchema(
  37. code=0,
  38. message="获取知识库列表成功",
  39. data=[KnowledgeBaseResponse.model_validate(item) for item in items],
  40. meta=meta,
  41. )
  42. @router.post("", response_model=ResponseSchema)
  43. async def create_knowledge_base(
  44. payload: KnowledgeBaseCreate,
  45. db: AsyncSession = Depends(get_db),
  46. credentials: HTTPAuthorizationCredentials = Depends(security)
  47. ):
  48. """创建新知识库"""
  49. payload_token = verify_token(credentials.credentials)
  50. if not payload_token:
  51. return ResponseSchema(code=401, message="无效的访问令牌")
  52. new_kb = await knowledge_base_service.create(db, payload)
  53. return ResponseSchema(code=0, message="创建成功", data=KnowledgeBaseResponse.model_validate(new_kb))
  54. @router.post("/{id}", response_model=ResponseSchema)
  55. async def update_knowledge_base(
  56. payload: KnowledgeBaseUpdate,
  57. id: str = Path(..., description="知识库ID"),
  58. db: AsyncSession = Depends(get_db),
  59. credentials: HTTPAuthorizationCredentials = Depends(security)
  60. ):
  61. """更新知识库信息"""
  62. payload_token = verify_token(credentials.credentials)
  63. if not payload_token:
  64. return ResponseSchema(code=401, message="无效的访问令牌")
  65. kb = await knowledge_base_service.update(db, id, payload)
  66. return ResponseSchema(code=0, message="更新成功", data=KnowledgeBaseResponse.model_validate(kb))
  67. @router.post("/{id}/status", response_model=ResponseSchema)
  68. async def update_knowledge_base_status(
  69. id: str = Path(..., description="知识库ID"),
  70. status: str = Query(..., description="状态: normal/test/disabled"),
  71. db: AsyncSession = Depends(get_db),
  72. credentials: HTTPAuthorizationCredentials = Depends(security)
  73. ):
  74. """更新知识库状态"""
  75. payload_token = verify_token(credentials.credentials)
  76. if not payload_token:
  77. return ResponseSchema(code=401, message="无效的访问令牌")
  78. await knowledge_base_service.update_status(db, id, status)
  79. return ResponseSchema(code=0, message=f"状态已更新为 {status}")
  80. @router.post("/{id}/delete", response_model=ResponseSchema)
  81. async def delete_knowledge_base(
  82. id: str = Path(..., description="知识库ID"),
  83. db: AsyncSession = Depends(get_db),
  84. credentials: HTTPAuthorizationCredentials = Depends(security)
  85. ):
  86. """删除知识库"""
  87. payload_token = verify_token(credentials.credentials)
  88. if not payload_token:
  89. return ResponseSchema(code=401, message="无效的访问令牌")
  90. await knowledge_base_service.delete(db, id)
  91. return ResponseSchema(code=0, message="删除成功")
  92. @router.get("/{id}/metadata", response_model=ResponseSchema)
  93. async def get_knowledge_base_metadata(
  94. id: str = Path(..., description="知识库ID"),
  95. db: AsyncSession = Depends(get_db),
  96. credentials: HTTPAuthorizationCredentials = Depends(security)
  97. ):
  98. """获取知识库的元数据字段定义和自定义Schema"""
  99. payload_token = verify_token(credentials.credentials)
  100. if not payload_token:
  101. return ResponseSchema(code=401, message="无效的访问令牌")
  102. try:
  103. data = await knowledge_base_service.get_metadata_and_schema(db, id)
  104. return ResponseSchema(code=0, message="获取成功", data=data)
  105. except ValueError as e:
  106. return ResponseSchema(code=400, message=str(e))
  107. except Exception as e:
  108. return ResponseSchema(code=500, message=f"获取失败: {str(e)}")
  109. @router.post("/{id}/sync", response_model=ResponseSchema)
  110. async def sync_knowledge_base(
  111. id: str = Path(..., description="知识库ID"),
  112. db: AsyncSession = Depends(get_db),
  113. credentials: HTTPAuthorizationCredentials = Depends(security)
  114. ):
  115. """同步创建Milvus集合"""
  116. payload_token = verify_token(credentials.credentials)
  117. if not payload_token:
  118. return ResponseSchema(code=401, message="无效的访问令牌")
  119. try:
  120. await knowledge_base_service.sync_to_milvus(db, id)
  121. return ResponseSchema(code=0, message="同步成功")
  122. except ValueError as e:
  123. return ResponseSchema(code=400, message=str(e))
  124. except Exception as e:
  125. return ResponseSchema(code=500, message=f"同步失败: {str(e)}")