models.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
from fastapi import APIRouter, HTTPException

from app.schemas.model import ModelDownloadRequest, ModelDownloadResponse, ModelInfo
from app.schemas.model_test import ModelTestRequest, ModelTestResponse
from app.schemas.background_task import ModelDownloadTaskResponse
from app.services import model_service, model_test_service

# Router instance that the endpoint decorators in this module register on.
router = APIRouter()
  7. @router.get("/", response_model=list[ModelInfo])
  8. async def list_models():
  9. """列出所有本地缓存的模型(从数据库读取)。"""
  10. models = await model_service.list_cached_models()
  11. return [
  12. ModelInfo(
  13. id=m["id"],
  14. name=m["name"],
  15. model_type=m["model_type"],
  16. path=m["path"],
  17. is_downloaded=m["is_downloaded"],
  18. context_length=m["context_length"],
  19. supported_peft_methods=m["supported_peft_methods"],
  20. )
  21. for m in models
  22. ]
  23. @router.post("/download", response_model=ModelDownloadTaskResponse, status_code=200)
  24. async def download_model(req: ModelDownloadRequest):
  25. """启动模型下载后台任务,立即返回 task_id。"""
  26. result = await model_service.download_model(req.model_id, req.use_modelscope)
  27. return ModelDownloadTaskResponse(
  28. task_id=result["task_id"],
  29. model_id=result["model_id"],
  30. status=result["status"],
  31. )
  32. @router.get("/download/{task_id}", response_model=ModelDownloadTaskResponse)
  33. async def get_model_download_status(task_id: str):
  34. """查询模型下载任务状态。"""
  35. result = await model_service.get_model_download_status(task_id)
  36. if result.get("status") == "not_found":
  37. raise HTTPException(status_code=404, detail="Download task not found")
  38. return ModelDownloadTaskResponse(**result)
  39. @router.get("/downloads")
  40. async def list_model_downloads():
  41. """列出所有模型下载任务。"""
  42. return await model_service.list_model_downloads()
  43. @router.post("/download/{task_id}/cancel")
  44. async def cancel_model_download(task_id: str):
  45. """取消模型下载任务。"""
  46. return await model_service.cancel_model_download(task_id)
  47. @router.get("/{model_id}", response_model=ModelInfo)
  48. async def get_model_info(model_id: str):
  49. """获取已缓存模型的详细信息。"""
  50. info = await model_service.get_model_info(model_id)
  51. if info:
  52. return ModelInfo(
  53. id=info["id"],
  54. name=info["name"],
  55. model_type=info["model_type"],
  56. path=info["path"],
  57. is_downloaded=info["is_downloaded"],
  58. context_length=info["context_length"],
  59. supported_peft_methods=info["supported_peft_methods"],
  60. )
  61. raise HTTPException(status_code=404, detail=f"Model not found: {model_id}")
  62. @router.delete("/{model_id}")
  63. async def delete_model(model_id: str):
  64. """删除已缓存的模型(数据库记录 + 本地文件)。"""
  65. result = await model_service.delete_model(model_id)
  66. return result
  67. @router.post("/test", response_model=ModelTestResponse)
  68. async def test_model(req: ModelTestRequest):
  69. """快速测试已缓存模型的生成能力。"""
  70. result = await model_test_service.test_model(
  71. req.model_id, req.prompt, req.max_new_tokens, req.temperature, req.top_p,
  72. )
  73. if "error" in result:
  74. raise HTTPException(status_code=400, detail=result["error"])
  75. return ModelTestResponse(**result)