models.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. from fastapi import APIRouter, HTTPException
  2. from app.schemas.model import ModelDownloadRequest, ModelDownloadResponse, ModelInfo
  3. from app.schemas.model_test import ModelTestRequest, ModelTestResponse
  4. from app.services import model_service, model_test_service
# Router holding the model-management endpoints defined below. The URL prefix
# it is mounted under is set by whichever app/router includes it (not visible
# in this file).
router = APIRouter()
  6. @router.get("/", response_model=list[ModelInfo])
  7. async def list_models():
  8. """列出所有本地缓存的模型(从数据库读取)。"""
  9. models = await model_service.list_cached_models()
  10. return [
  11. ModelInfo(
  12. id=m["id"],
  13. name=m["name"],
  14. model_type=m["model_type"],
  15. path=m["path"],
  16. is_downloaded=m["is_downloaded"],
  17. context_length=m["context_length"],
  18. supported_peft_methods=m["supported_peft_methods"],
  19. )
  20. for m in models
  21. ]
  22. @router.post("/download", response_model=ModelDownloadResponse, status_code=200)
  23. async def download_model(req: ModelDownloadRequest):
  24. """从 HuggingFace 或 ModelScope 下载模型。"""
  25. result = await model_service.download_model(req.model_id, req.use_modelscope)
  26. if result["status"] == "failed":
  27. raise HTTPException(status_code=400, detail=result.get("error", "Download failed"))
  28. return ModelDownloadResponse(
  29. model_id=result["model_id"],
  30. status=result["status"],
  31. path=result.get("path"),
  32. error=result.get("error"),
  33. )
  34. @router.get("/{model_id}", response_model=ModelInfo)
  35. async def get_model_info(model_id: str):
  36. """获取已缓存模型的详细信息。"""
  37. info = await model_service.get_model_info(model_id)
  38. if info:
  39. return ModelInfo(
  40. id=info["id"],
  41. name=info["name"],
  42. model_type=info["model_type"],
  43. path=info["path"],
  44. is_downloaded=info["is_downloaded"],
  45. context_length=info["context_length"],
  46. supported_peft_methods=info["supported_peft_methods"],
  47. )
  48. raise HTTPException(status_code=404, detail=f"Model not found: {model_id}")
  49. @router.delete("/{model_id}")
  50. async def delete_model(model_id: str):
  51. """删除已缓存的模型(数据库记录 + 本地文件)。"""
  52. result = await model_service.delete_model(model_id)
  53. return result
  54. @router.post("/test", response_model=ModelTestResponse)
  55. async def test_model(req: ModelTestRequest):
  56. """快速测试已缓存模型的生成能力。"""
  57. result = await model_test_service.test_model(
  58. req.model_id, req.prompt, req.max_new_tokens, req.temperature, req.top_p,
  59. )
  60. if "error" in result:
  61. raise HTTPException(status_code=400, detail=result["error"])
  62. return ModelTestResponse(**result)