local_model_router.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
"""
Local model API router.

Exposes RESTful API endpoints for managing locally deployed models.
Requirements: 2.1, 2.8, 4.5, 4.6, 4.7 (individual endpoints also cite
requirements 2.5, 2.6, 2.7 and 4.1).
"""
  6. from typing import List
  7. from fastapi import APIRouter, Depends, HTTPException, status
  8. from sqlalchemy.orm import Session
  9. from app.database import get_db
  10. from app.dependencies.auth import get_current_user
  11. from app.dependencies.admin_auth import get_current_admin
  12. from app.models.user import User
  13. from app.models.admin import AdminUser
  14. from app.services.local_model_service import LocalModelService
  15. from app.services.system_config_manager import get_config_bool
  16. from app.schemas.local_model import (
  17. LocalModelCreate,
  18. LocalModelUpdate,
  19. LocalModelResponse,
  20. ConnectionTestRequest,
  21. ConnectionTestResponse
  22. )
  23. from app.schemas.model_schema import ApiResponse
# Router for every local-model endpoint in this module: all paths are mounted
# under /api/models/local and grouped under the "本地模型" ("local models") tag
# in the generated OpenAPI docs.
router = APIRouter(prefix="/api/models/local", tags=["本地模型"])
  25. @router.post("", response_model=ApiResponse[LocalModelResponse])
  26. def create_local_model(
  27. request: LocalModelCreate,
  28. current_admin: AdminUser = Depends(get_current_admin),
  29. db: Session = Depends(get_db)
  30. ):
  31. """
  32. 添加本地模型
  33. 需求 2.1: 管理员可以添加本地部署的OpenAI API兼容模型
  34. """
  35. try:
  36. service = LocalModelService(db)
  37. model = service.create_local_model(
  38. user_id=None,
  39. name=request.name,
  40. supplier=request.supplier or 'Custom',
  41. base_url=request.base_url,
  42. api_key=request.api_key,
  43. visibility=request.visibility or 'global',
  44. categories=request.categories or [0]
  45. )
  46. return ApiResponse(code=200, message="success", data=LocalModelResponse.from_model(model))
  47. except HTTPException:
  48. raise
  49. except Exception as e:
  50. raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")
  51. @router.get("", response_model=ApiResponse[List[LocalModelResponse]])
  52. async def get_local_models(
  53. current_user: User = Depends(get_current_user),
  54. db: Session = Depends(get_db)
  55. ):
  56. """
  57. 获取本地模型列表
  58. 需求 4.1: 用户可以查看自己添加的本地模型列表
  59. """
  60. service = LocalModelService(db)
  61. # 检查本地模型是否启用
  62. if get_config_bool("enable_local_models", True):
  63. # 如果本地模型启用,获取所有模型(无论visibility是什么)
  64. models = await service.get_all_local_models()
  65. else:
  66. # 如果本地模型未启用,获取用户有权限的模型
  67. models = await service.get_user_local_models(current_user.id)
  68. return ApiResponse(
  69. code=200,
  70. message="success",
  71. data=[LocalModelResponse.from_model(m) for m in models]
  72. )
  73. @router.put("/{model_id}", response_model=ApiResponse[LocalModelResponse])
  74. async def update_local_model(
  75. model_id: int,
  76. request: LocalModelUpdate,
  77. current_admin: AdminUser = Depends(get_current_admin),
  78. db: Session = Depends(get_db)
  79. ):
  80. """
  81. 更新本地模型
  82. 需求 4.5: 管理员可以编辑本地模型的配置
  83. """
  84. try:
  85. service = LocalModelService(db)
  86. model = service.update_local_model(
  87. model_id=model_id,
  88. user_id=None,
  89. name=request.name,
  90. supplier=request.supplier,
  91. base_url=request.base_url,
  92. api_key=request.api_key,
  93. visibility=request.visibility,
  94. categories=request.categories
  95. )
  96. if not model:
  97. raise HTTPException(status_code=404, detail="模型不存在或无权限")
  98. # 异步清理缓存
  99. from app.services.cache_service import CacheService
  100. await CacheService.delete_model(model_id)
  101. return ApiResponse(code=200, message="success", data=LocalModelResponse.from_model(model))
  102. except HTTPException:
  103. raise
  104. except Exception as e:
  105. raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")
  106. @router.delete("/{model_id}", response_model=ApiResponse[dict])
  107. async def delete_local_model(
  108. model_id: int,
  109. current_admin: AdminUser = Depends(get_current_admin),
  110. db: Session = Depends(get_db)
  111. ):
  112. """
  113. 删除本地模型
  114. 需求 4.6: 管理员可以删除本地模型
  115. 需求 4.7: 删除前需要确认
  116. """
  117. service = LocalModelService(db)
  118. success = service.delete_local_model(model_id, None) # 本地模型不关联用户
  119. if not success:
  120. raise HTTPException(status_code=404, detail="模型不存在或无权限")
  121. # 异步清理缓存
  122. from app.services.cache_service import CacheService
  123. await CacheService.delete_model(model_id)
  124. return ApiResponse(code=200, message="success", data={"success": True})
  125. @router.post("/test", response_model=ApiResponse[ConnectionTestResponse])
  126. async def test_connection(
  127. request: ConnectionTestRequest,
  128. db: Session = Depends(get_db),
  129. current_admin: AdminUser = Depends(get_current_admin)
  130. ):
  131. """
  132. 测试本地模型连接(管理员)
  133. 需求 2.5: 提供"测试连接"按钮
  134. 需求 2.6: 测试成功显示绿色提示
  135. 需求 2.7: 测试失败显示红色错误信息
  136. """
  137. service = LocalModelService(db)
  138. result = await service.test_connection(request.base_url, request.api_key, request.model_name, (request.categories or [0])[0])
  139. return ApiResponse(code=200, message="success", data=ConnectionTestResponse(**result))
  140. @router.post("/test-public", response_model=ApiResponse[ConnectionTestResponse])
  141. async def test_connection_public(
  142. request: ConnectionTestRequest,
  143. db: Session = Depends(get_db),
  144. current_user: User = Depends(get_current_user)
  145. ):
  146. """
  147. 测试本地模型连接(普通用户)
  148. 需求 2.5: 提供"测试连接"按钮
  149. 需求 2.6: 测试成功显示绿色提示
  150. 需求 2.7: 测试失败显示红色错误信息
  151. """
  152. # 检查本地模型是否启用
  153. if not get_config_bool("enable_local_models", False):
  154. raise HTTPException(
  155. status_code=status.HTTP_403_FORBIDDEN,
  156. detail="本地模型功能已关闭"
  157. )
  158. service = LocalModelService(db)
  159. result = await service.test_connection(request.base_url, request.api_key, request.model_name, (request.categories or [0])[0])
  160. return ApiResponse(code=200, message="success", data=ConnectionTestResponse(**result))