models.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. from fastapi import APIRouter, HTTPException
  2. from app.schemas.model import ModelDownloadRequest, ModelDownloadResponse, ModelInfo
  3. from app.services import model_service
# Router that collects the model-management endpoints declared in this module.
router = APIRouter()
  5. @router.get("/", response_model=list[ModelInfo])
  6. async def list_models():
  7. """列出所有本地缓存的模型(从数据库读取)。"""
  8. models = await model_service.list_cached_models()
  9. return [
  10. ModelInfo(
  11. id=m["id"],
  12. name=m["name"],
  13. model_type=m["model_type"],
  14. path=m["path"],
  15. is_downloaded=m["is_downloaded"],
  16. context_length=m["context_length"],
  17. supported_peft_methods=m["supported_peft_methods"],
  18. )
  19. for m in models
  20. ]
  21. @router.post("/download", response_model=ModelDownloadResponse, status_code=200)
  22. async def download_model(req: ModelDownloadRequest):
  23. """从 HuggingFace 或 ModelScope 下载模型。"""
  24. result = await model_service.download_model(req.model_id, req.use_modelscope)
  25. if result["status"] == "failed":
  26. raise HTTPException(status_code=400, detail=result.get("error", "Download failed"))
  27. return ModelDownloadResponse(
  28. model_id=result["model_id"],
  29. status=result["status"],
  30. path=result.get("path"),
  31. error=result.get("error"),
  32. )
  33. @router.get("/{model_id}", response_model=ModelInfo)
  34. async def get_model_info(model_id: str):
  35. """获取已缓存模型的详细信息。"""
  36. info = await model_service.get_model_info(model_id)
  37. if info:
  38. return ModelInfo(
  39. id=info["id"],
  40. name=info["name"],
  41. model_type=info["model_type"],
  42. path=info["path"],
  43. is_downloaded=info["is_downloaded"],
  44. context_length=info["context_length"],
  45. supported_peft_methods=info["supported_peft_methods"],
  46. )
  47. raise HTTPException(status_code=404, detail=f"Model not found: {model_id}")
  48. @router.delete("/{model_id}")
  49. async def delete_model(model_id: str):
  50. """删除已缓存的模型(数据库记录 + 本地文件)。"""
  51. result = await model_service.delete_model(model_id)
  52. return result