models.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. from fastapi import APIRouter
  2. from app.schemas.model import ModelDownloadRequest, ModelDownloadResponse, ModelInfo
  3. from app.services import model_service
# Router collecting the model-management endpoints defined below (list, download,
# per-model info). Presumably mounted under a prefix by the app factory — confirm.
router = APIRouter()
  5. @router.get("/", response_model=list[ModelInfo])
  6. async def list_models():
  7. """列出所有本地缓存的模型。"""
  8. models = model_service.list_cached_models()
  9. return [
  10. ModelInfo(
  11. id=m["id"],
  12. name=m.get("name", m["id"]),
  13. model_type=m.get("model_type", "text"),
  14. path=m.get("path"),
  15. is_downloaded=m.get("is_downloaded", True),
  16. context_length=m.get("context_length"),
  17. supported_peft_methods=m.get("supported_peft_methods", []),
  18. )
  19. for m in models
  20. ]
  21. @router.post("/download", response_model=ModelDownloadResponse)
  22. async def download_model(req: ModelDownloadRequest):
  23. """从 HuggingFace 或 ModelScope 下载模型。"""
  24. result = await model_service.download_model(req.model_id, req.use_modelscope)
  25. return ModelDownloadResponse(
  26. model_id=result["model_id"],
  27. status=result["status"],
  28. path=result.get("path"),
  29. error=result.get("error"),
  30. )
  31. @router.get("/{model_id}", response_model=ModelInfo)
  32. async def get_model_info(model_id: str):
  33. """获取已缓存模型的详细信息。"""
  34. info = await model_service.get_model_info(model_id)
  35. if info:
  36. return ModelInfo(
  37. id=info["id"],
  38. name=info.get("name", model_id.split("/")[-1]),
  39. model_type=info.get("model_type", "text"),
  40. path=info.get("path"),
  41. is_downloaded=info.get("is_downloaded", True),
  42. context_length=info.get("context_length"),
  43. supported_peft_methods=info.get("supported_peft_methods", []),
  44. )
  45. return ModelInfo(
  46. id=model_id,
  47. name=model_id.split("/")[-1],
  48. model_type="text",
  49. is_downloaded=False,
  50. )