# search_engine_view.py (4.1 KB)
  1. """
  2. 检索引擎视图路由
  3. """
  4. from fastapi import APIRouter, Depends, Query, Path, Body, Request
  5. from sqlalchemy.ext.asyncio import AsyncSession
  6. from app.base.async_mysql_connection import get_db
  7. from app.services.search_engine_service import search_engine_service
  8. from app.sample.schemas.search_engine import (
  9. SearchEngineCreate,
  10. SearchEngineUpdate,
  11. SearchEngineResponse,
  12. KBSearchRequest,
  13. KBSearchResponse
  14. )
  15. from app.schemas.base import PaginatedResponseSchema, ResponseSchema
  16. from app.services.jwt_token import verify_token
  17. from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
  18. from app.utils.auth_dependency import get_current_user_with_refresh
  # Router that groups the search-engine endpoints; every route registered on
  # it is served under the "/sample/search-engine" prefix.
  router = APIRouter(
      prefix="/sample/search-engine",
      tags=["样本中心-检索引擎"],
  )

  # HTTP bearer scheme instance. None of the handlers visible in this module
  # reference it directly.
  security = HTTPBearer()
  # --- Added: knowledge-base search endpoint ---
  @router.post("/search", response_model=ResponseSchema)
  async def search_knowledge_base(
      payload: KBSearchRequest,
      db: AsyncSession = Depends(get_db),
      current_user: dict = Depends(get_current_user_with_refresh),
  ):
      """Run a knowledge-base semantic search and wrap the outcome.

      Delegates to ``search_engine_service.search_kb(db, payload)`` and always
      answers with a ``ResponseSchema``; errors are reported through the
      ``code`` field of that envelope instead of being raised.

      Args:
          payload: Search request body.
          db: Async database session supplied by ``get_db``.
          current_user: Value resolved by ``get_current_user_with_refresh``.
              It is not read in the body; presumably the dependency is here
              to require an authenticated caller (confirm against the
              dependency's implementation).

      Returns:
          ``ResponseSchema`` with ``code=0`` and the service result on
          success; ``code=400`` with the error text when a ``ValueError`` is
          raised; ``code=500`` with a generic message for any other
          exception. Both error cases carry an empty ``KBSearchResponse``.
      """
      # Function-scope import keeps this block self-contained.
      import logging

      try:
          # The success return stays inside the try so that a failure while
          # building the envelope is still handled by the clauses below.
          result = await search_engine_service.search_kb(db, payload)
          return ResponseSchema(code=0, message="搜索成功", data=result)
      except ValueError as e:
          return ResponseSchema(code=400, message=str(e), data=KBSearchResponse(results=[], total=0))
      except Exception:
          # Record the traceback: the client only sees a generic message, so
          # without this the root cause of a 500 would be lost.
          logging.getLogger(__name__).exception("Knowledge-base search failed")
          return ResponseSchema(code=500, message="搜索失败", data=KBSearchResponse(results=[], total=0))
  # --- Existing endpoints ---
  @router.get("", response_model=PaginatedResponseSchema)
  async def get_search_engines(
      page: int = Query(1, ge=1, description="页码"),
      page_size: int = Query(10, ge=1, le=1000, description="每页数量"),
      keyword: str = Query(None, description="搜索关键词"),
      status: str = Query(None, description="状态筛选"),
      db: AsyncSession = Depends(get_db),
      current_user: dict = Depends(get_current_user_with_refresh),
  ):
      """Return one page of search engines.

      Args:
          page: Page number, at least 1 (default 1).
          page_size: Items per page, between 1 and 1000 (default 10).
          keyword: Optional search keyword, forwarded to the service as-is.
          status: Optional status filter, forwarded to the service as-is.
          db: Async database session supplied by ``get_db``.
          current_user: Value resolved by ``get_current_user_with_refresh``;
              not read in the body.

      Returns:
          ``PaginatedResponseSchema`` with ``code=0``, each item validated
          into a ``SearchEngineResponse``, and the ``meta`` object returned
          by the service.
      """
      filters = {
          "page": page,
          "page_size": page_size,
          "keyword": keyword,
          "status": status,
      }
      records, pagination = await search_engine_service.get_list(db, **filters)

      # Convert every record into the response model before building the envelope.
      serialized = [SearchEngineResponse.model_validate(record) for record in records]
      return PaginatedResponseSchema(
          code=0,
          message="获取检索引擎列表成功",
          data=serialized,
          meta=pagination,
      )
  @router.post("", response_model=ResponseSchema)
  async def create_search_engine(
      payload: SearchEngineCreate,
      db: AsyncSession = Depends(get_db),
      current_user: dict = Depends(get_current_user_with_refresh),
  ):
      """Create a search engine and return it in the standard envelope.

      Args:
          payload: Creation request body.
          db: Async database session supplied by ``get_db``.
          current_user: Value resolved by ``get_current_user_with_refresh``;
              not read in the body.

      Returns:
          ``ResponseSchema`` with ``code=0`` and the created engine validated
          into a ``SearchEngineResponse``.
      """
      created = await search_engine_service.create(db, payload)
      body = SearchEngineResponse.model_validate(created)
      return ResponseSchema(code=0, message="创建成功", data=body)
  @router.post("/{id}", response_model=ResponseSchema)
  async def update_search_engine(
      payload: SearchEngineUpdate,
      id: str = Path(..., description="检索引擎ID"),
      db: AsyncSession = Depends(get_db),
      current_user: dict = Depends(get_current_user_with_refresh),
  ):
      """Update a search engine and return its new state.

      Args:
          payload: Update request body.
          id: Search engine ID taken from the URL path.
          db: Async database session supplied by ``get_db``.
          current_user: Value resolved by ``get_current_user_with_refresh``;
              not read in the body.

      Returns:
          ``ResponseSchema`` with ``code=0`` and the engine returned by the
          service, validated into a ``SearchEngineResponse``.
      """
      updated = await search_engine_service.update(db, id, payload)
      body = SearchEngineResponse.model_validate(updated)
      return ResponseSchema(code=0, message="更新成功", data=body)
  @router.post("/{id}/status", response_model=ResponseSchema)
  async def update_search_engine_status(
      id: str = Path(..., description="检索引擎ID"),
      status: str = Query(..., description="状态: normal/disabled"),
      db: AsyncSession = Depends(get_db),
      current_user: dict = Depends(get_current_user_with_refresh),
  ):
      """Change the status of a search engine.

      Args:
          id: Search engine ID taken from the URL path.
          status: Required query parameter with the new status. The parameter
              description lists ``normal`` and ``disabled``; this handler
              forwards the value to the service without checking it.
          db: Async database session supplied by ``get_db``.
          current_user: Value resolved by ``get_current_user_with_refresh``;
              not read in the body.

      Returns:
          ``ResponseSchema`` with ``code=0`` and a message echoing the new
          status; no ``data`` is passed.
      """
      await search_engine_service.update_status(db, id, status)
      confirmation = f"状态已更新为 {status}"
      return ResponseSchema(code=0, message=confirmation)
  @router.post("/{id}/delete", response_model=ResponseSchema)
  async def delete_search_engine(
      id: str = Path(..., description="检索引擎ID"),
      db: AsyncSession = Depends(get_db),
      current_user: dict = Depends(get_current_user_with_refresh),
  ):
      """Delete a search engine.

      Args:
          id: Search engine ID taken from the URL path.
          db: Async database session supplied by ``get_db``.
          current_user: Value resolved by ``get_current_user_with_refresh``;
              not read in the body.

      Returns:
          ``ResponseSchema`` with ``code=0`` and a success message; no
          ``data`` is passed.
      """
      await search_engine_service.delete(db, id)
      success_message = "删除成功"
      return ResponseSchema(code=0, message=success_message)