| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364 |
- from fastapi import APIRouter, HTTPException
- from app.schemas.model import ModelDownloadRequest, ModelDownloadResponse, ModelInfo
- from app.services import model_service
# Router grouping the model-management endpoints defined below. It is created
# with no prefix or tags; presumably the application includes it elsewhere.
router: APIRouter = APIRouter()
@router.get("/", response_model=list[ModelInfo])
async def list_models():
    """List every locally cached model, as recorded in the database.

    Each record returned by ``model_service.list_cached_models()`` is mapped
    field-by-field onto a ``ModelInfo`` response object.
    """
    # Keys copied verbatim from each service record into ModelInfo, in order.
    field_names = (
        "id",
        "name",
        "model_type",
        "path",
        "is_downloaded",
        "context_length",
        "supported_peft_methods",
    )
    cached = await model_service.list_cached_models()
    return [
        ModelInfo(**{key: record[key] for key in field_names})
        for record in cached
    ]
@router.post("/download", response_model=ModelDownloadResponse, status_code=200)
async def download_model(req: ModelDownloadRequest):
    """Download a model from HuggingFace or ModelScope.

    Delegates to ``model_service.download_model`` with the requested
    ``model_id`` and the ``use_modelscope`` flag, then maps the service's
    result dict onto a ``ModelDownloadResponse``.

    Raises:
        HTTPException: 400 when the service reports ``status == "failed"``.
            The detail is the service's ``error`` message, or
            ``"Download failed"`` when that message is missing, ``None`` or
            empty.
    """
    result = await model_service.download_model(req.model_id, req.use_modelscope)
    if result["status"] == "failed":
        # `or` rather than a .get() default: an explicit None/"" error value
        # must still produce a meaningful message, not an empty/generic detail.
        raise HTTPException(
            status_code=400,
            detail=result.get("error") or "Download failed",
        )
    return ModelDownloadResponse(
        model_id=result["model_id"],
        status=result["status"],
        path=result.get("path"),
        error=result.get("error"),
    )
@router.get("/{model_id:path}", response_model=ModelInfo)
async def get_model_info(model_id: str):
    """Return the details of a cached model.

    The route uses ``{model_id:path}`` rather than ``{model_id}``. A plain path
    parameter matches a single URL segment only, so ids containing slashes
    (HuggingFace/ModelScope repo ids are usually ``org/name``) could never
    reach this handler. Ids without a slash are routed exactly as before.

    Raises:
        HTTPException: 404 when the service returns a falsy record
            (e.g. ``None``) for ``model_id``.
    """
    # With the path converter a trailing slash becomes part of the id; strip
    # it so ``.../org/name/`` resolves like ``.../org/name`` (the old
    # single-segment route got this via Starlette's default slash redirect).
    model_id = model_id.rstrip("/")
    info = await model_service.get_model_info(model_id)
    if not info:
        raise HTTPException(status_code=404, detail=f"Model not found: {model_id}")
    return ModelInfo(
        id=info["id"],
        name=info["name"],
        model_type=info["model_type"],
        path=info["path"],
        is_downloaded=info["is_downloaded"],
        context_length=info["context_length"],
        supported_peft_methods=info["supported_peft_methods"],
    )
@router.delete("/{model_id:path}")
async def delete_model(model_id: str):
    """Delete a cached model (its database record and its local files).

    The route uses ``{model_id:path}`` so that ids containing slashes
    (HuggingFace/ModelScope-style ``org/name``) can be deleted. A plain
    ``{model_id}`` matches a single URL segment only.

    Returns:
        Whatever ``model_service.delete_model`` returns, passed through as-is.

    Raises:
        HTTPException: 400 when the id is empty or contains empty, ``.`` or
            ``..`` segments.
    """
    # Trailing slash is not part of the id (mirrors the old redirect behavior).
    model_id = model_id.rstrip("/")
    # The path converter also matches "" (e.g. DELETE on the router root) and
    # multi-segment traversal like "../..". Reject those before calling a
    # service that removes local files.
    if not model_id or any(part in ("", ".", "..") for part in model_id.split("/")):
        raise HTTPException(status_code=400, detail=f"Invalid model id: {model_id!r}")
    return await model_service.delete_model(model_id)
|