knowledge_base_view.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. """
  2. 知识库视图路由
  3. """
import logging
from typing import Optional

from fastapi import APIRouter, Depends, Query, Path
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.ext.asyncio import AsyncSession

from app.base.async_mysql_connection import get_db
from app.services.knowledge_base_service import knowledge_base_service
from app.sample.schemas.knowledge_base import (
    KnowledgeBaseCreate,
    KnowledgeBaseUpdate,
    KnowledgeBaseResponse,
)
from app.schemas.base import PaginatedResponseSchema, ResponseSchema
from app.services.jwt_token import verify_token
  17. router = APIRouter(prefix="/sample/knowledge-base", tags=["样本中心-知识库"])
  18. security = HTTPBearer()
  19. @router.get("", response_model=PaginatedResponseSchema)
  20. async def get_knowledge_bases(
  21. page: int = Query(1, ge=1, description="页码"),
  22. page_size: int = Query(10, ge=1, le=1000, description="每页数量"),
  23. keyword: str = Query(None, description="搜索关键词"),
  24. status: str = Query(None, description="状态筛选"),
  25. db: AsyncSession = Depends(get_db),
  26. credentials: HTTPAuthorizationCredentials = Depends(security)
  27. ):
  28. """获取知识库列表"""
  29. # 鉴权
  30. payload = verify_token(credentials.credentials)
  31. if not payload:
  32. return PaginatedResponseSchema(code=401, message="无效的访问令牌", data=[], meta=None)
  33. items, meta = await knowledge_base_service.get_list(
  34. db, page=page, page_size=page_size, keyword=keyword, status=status
  35. )
  36. return PaginatedResponseSchema(
  37. code=0,
  38. message="获取知识库列表成功",
  39. data=[KnowledgeBaseResponse.model_validate(item) for item in items],
  40. meta=meta,
  41. )
  42. @router.get("/list", response_model=ResponseSchema)
  43. async def get_knowledge_base_simple_list(
  44. db: AsyncSession = Depends(get_db),
  45. credentials: HTTPAuthorizationCredentials = Depends(security)
  46. ):
  47. """获取知识库简单列表 (用于下拉选择)"""
  48. payload = verify_token(credentials.credentials)
  49. if not payload:
  50. return ResponseSchema(code=401, message="无效的访问令牌")
  51. # 只获取状态正常的知识库
  52. items, _ = await knowledge_base_service.get_list(db, page=1, page_size=1000, status="normal")
  53. return ResponseSchema(
  54. code=0,
  55. message="获取成功",
  56. data=[{"id": item.id, "name": item.name, "collection_name": item.collection_name_children or item.collection_name_parent} for item in items]
  57. )
  58. @router.post("", response_model=ResponseSchema)
  59. async def create_knowledge_base(
  60. payload: KnowledgeBaseCreate,
  61. db: AsyncSession = Depends(get_db),
  62. credentials: HTTPAuthorizationCredentials = Depends(security)
  63. ):
  64. """创建新知识库"""
  65. payload_token = verify_token(credentials.credentials)
  66. if not payload_token:
  67. return ResponseSchema(code=401, message="无效的访问令牌")
  68. try:
  69. new_kb = await knowledge_base_service.create(db, payload)
  70. return ResponseSchema(code=0, message="创建成功", data=KnowledgeBaseResponse.model_validate(new_kb))
  71. except ValueError as e:
  72. return ResponseSchema(code=400, message=str(e))
  73. except Exception:
  74. return ResponseSchema(code=500, message="创建失败")
  75. @router.post("/{id}", response_model=ResponseSchema)
  76. async def update_knowledge_base(
  77. payload: KnowledgeBaseUpdate,
  78. id: str = Path(..., description="知识库ID"),
  79. db: AsyncSession = Depends(get_db),
  80. credentials: HTTPAuthorizationCredentials = Depends(security)
  81. ):
  82. """更新知识库信息"""
  83. payload_token = verify_token(credentials.credentials)
  84. if not payload_token:
  85. return ResponseSchema(code=401, message="无效的访问令牌")
  86. try:
  87. kb = await knowledge_base_service.update(db, id, payload)
  88. return ResponseSchema(code=0, message="更新成功", data=KnowledgeBaseResponse.model_validate(kb))
  89. except ValueError as e:
  90. return ResponseSchema(code=400, message=str(e))
  91. except Exception:
  92. return ResponseSchema(code=500, message="更新失败")
  93. @router.post("/{id}/status", response_model=ResponseSchema)
  94. async def update_knowledge_base_status(
  95. id: str = Path(..., description="知识库ID"),
  96. status: str = Query(..., description="状态: normal/test/disabled"),
  97. db: AsyncSession = Depends(get_db),
  98. credentials: HTTPAuthorizationCredentials = Depends(security)
  99. ):
  100. """更新知识库状态"""
  101. payload_token = verify_token(credentials.credentials)
  102. if not payload_token:
  103. return ResponseSchema(code=401, message="无效的访问令牌")
  104. await knowledge_base_service.update_status(db, id, status)
  105. return ResponseSchema(code=0, message=f"状态已更新为 {status}")
  106. @router.post("/{id}/delete", response_model=ResponseSchema)
  107. async def delete_knowledge_base(
  108. id: str = Path(..., description="知识库ID"),
  109. db: AsyncSession = Depends(get_db),
  110. credentials: HTTPAuthorizationCredentials = Depends(security)
  111. ):
  112. """删除知识库"""
  113. payload_token = verify_token(credentials.credentials)
  114. if not payload_token:
  115. return ResponseSchema(code=401, message="无效的访问令牌")
  116. await knowledge_base_service.delete(db, id)
  117. return ResponseSchema(code=0, message="删除成功")
  118. @router.get("/{id}/metadata", response_model=ResponseSchema)
  119. async def get_knowledge_base_metadata(
  120. id: str = Path(..., description="知识库ID"),
  121. db: AsyncSession = Depends(get_db),
  122. credentials: HTTPAuthorizationCredentials = Depends(security)
  123. ):
  124. """获取知识库的元数据字段定义和自定义Schema"""
  125. payload_token = verify_token(credentials.credentials)
  126. if not payload_token:
  127. return ResponseSchema(code=401, message="无效的访问令牌")
  128. try:
  129. data = await knowledge_base_service.get_metadata_and_schema(db, id)
  130. return ResponseSchema(code=0, message="获取成功", data=data)
  131. except ValueError as e:
  132. return ResponseSchema(code=400, message=str(e))
  133. except Exception as e:
  134. return ResponseSchema(code=500, message=f"获取失败: {str(e)}")
  135. @router.post("/{id}/sync", response_model=ResponseSchema)
  136. async def sync_knowledge_base(
  137. id: str = Path(..., description="知识库ID"),
  138. db: AsyncSession = Depends(get_db),
  139. credentials: HTTPAuthorizationCredentials = Depends(security)
  140. ):
  141. """同步创建Milvus集合"""
  142. payload_token = verify_token(credentials.credentials)
  143. if not payload_token:
  144. return ResponseSchema(code=401, message="无效的访问令牌")
  145. try:
  146. await knowledge_base_service.sync_to_milvus(db, id)
  147. return ResponseSchema(code=0, message="同步成功")
  148. except ValueError as e:
  149. return ResponseSchema(code=400, message=str(e))
  150. except Exception as e:
  151. return ResponseSchema(code=500, message=f"同步失败: {str(e)}")