models.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. from fastapi import APIRouter, HTTPException
  2. from app.schemas.model import ModelDownloadRequest, ModelDownloadResponse, ModelInfo
  3. from app.schemas.model_test import ModelTestRequest, ModelTestResponse
  4. from app.services import model_service, model_test_service
# Router holding the model-management endpoints defined below (list, download,
# inspect, delete, quick-test). Any URL prefix is applied where the application
# includes this router, which is not visible in this file.
router = APIRouter()
  6. @router.get("/", response_model=list[ModelInfo])
  7. async def list_models():
  8. """列出所有本地缓存的模型(从数据库读取)。"""
  9. models = await model_service.list_cached_models()
  10. return [
  11. ModelInfo(
  12. id=m["id"],
  13. name=m["name"],
  14. model_type=m["model_type"],
  15. path=m["path"],
  16. is_downloaded=m["is_downloaded"],
  17. context_length=m["context_length"],
  18. supported_peft_methods=m["supported_peft_methods"],
  19. )
  20. for m in models
  21. ]
  22. @router.post("/download", response_model=ModelDownloadResponse, status_code=200)
  23. async def download_model(req: ModelDownloadRequest):
  24. """从 HuggingFace 或 ModelScope 下载模型。"""
  25. from app.core.logging import logger
  26. logger.info(f"[DEBUG] download request: model_id={req.model_id}, use_modelscope={req.use_modelscope}, type={type(req.use_modelscope)}")
  27. result = await model_service.download_model(req.model_id, req.use_modelscope)
  28. logger.info(f"[DEBUG] download result: {result}")
  29. if result["status"] == "failed":
  30. raise HTTPException(status_code=400, detail=result.get("error", "Download failed"))
  31. return ModelDownloadResponse(
  32. model_id=result["model_id"],
  33. status=result["status"],
  34. path=result.get("path"),
  35. error=result.get("error"),
  36. )
  37. @router.get("/{model_id}", response_model=ModelInfo)
  38. async def get_model_info(model_id: str):
  39. """获取已缓存模型的详细信息。"""
  40. info = await model_service.get_model_info(model_id)
  41. if info:
  42. return ModelInfo(
  43. id=info["id"],
  44. name=info["name"],
  45. model_type=info["model_type"],
  46. path=info["path"],
  47. is_downloaded=info["is_downloaded"],
  48. context_length=info["context_length"],
  49. supported_peft_methods=info["supported_peft_methods"],
  50. )
  51. raise HTTPException(status_code=404, detail=f"Model not found: {model_id}")
  52. @router.delete("/{model_id}")
  53. async def delete_model(model_id: str):
  54. """删除已缓存的模型(数据库记录 + 本地文件)。"""
  55. result = await model_service.delete_model(model_id)
  56. return result
  57. @router.post("/test", response_model=ModelTestResponse)
  58. async def test_model(req: ModelTestRequest):
  59. """快速测试已缓存模型的生成能力。"""
  60. result = await model_test_service.test_model(
  61. req.model_id, req.prompt, req.max_new_tokens, req.temperature, req.top_p,
  62. )
  63. if "error" in result:
  64. raise HTTPException(status_code=400, detail=result["error"])
  65. return ModelTestResponse(**result)