from fastapi import APIRouter, HTTPException

from app.schemas.model import ModelDownloadRequest, ModelDownloadResponse, ModelInfo
from app.schemas.model_test import ModelTestRequest, ModelTestResponse
from app.services import model_service, model_test_service

router = APIRouter()


def _to_model_info(record: dict) -> ModelInfo:
    """Build a ``ModelInfo`` response object from a model record.

    Shared by ``list_models`` and ``get_model_info`` so both endpoints map the
    service's record keys to response fields identically.

    Args:
        record: A model record as returned by ``model_service``; it must
            contain every key read below.

    Returns:
        The populated ``ModelInfo``.
    """
    return ModelInfo(
        id=record["id"],
        name=record["name"],
        model_type=record["model_type"],
        path=record["path"],
        is_downloaded=record["is_downloaded"],
        context_length=record["context_length"],
        supported_peft_methods=record["supported_peft_methods"],
    )


@router.get("/", response_model=list[ModelInfo])
async def list_models() -> list[ModelInfo]:
    """List all locally cached models (read from the database)."""
    models = await model_service.list_cached_models()
    return [_to_model_info(m) for m in models]


@router.post("/download", response_model=ModelDownloadResponse, status_code=200)
async def download_model(req: ModelDownloadRequest) -> ModelDownloadResponse:
    """Download a model from HuggingFace or ModelScope.

    Raises:
        HTTPException: 400 when the service reports ``status == "failed"``.
    """
    result = await model_service.download_model(req.model_id, req.use_modelscope)
    if result["status"] == "failed":
        # Use `or` rather than a `.get()` default so that an "error" key that
        # is present but None/empty still produces a meaningful detail.
        raise HTTPException(
            status_code=400,
            detail=result.get("error") or "Download failed",
        )
    return ModelDownloadResponse(
        model_id=result["model_id"],
        status=result["status"],
        path=result.get("path"),
        error=result.get("error"),
    )


# Static routes are registered before the catch-all `{model_id:path}` routes
# below so they can never be shadowed by them.
@router.post("/test", response_model=ModelTestResponse)
async def test_model(req: ModelTestRequest) -> ModelTestResponse:
    """Quickly test a cached model's text-generation ability.

    Raises:
        HTTPException: 400 when the service result contains an "error" key.
    """
    result = await model_test_service.test_model(
        req.model_id,
        req.prompt,
        req.max_new_tokens,
        req.temperature,
        req.top_p,
    )
    # Assumes the service returns a dict; an "error" key signals failure.
    if "error" in result:
        raise HTTPException(status_code=400, detail=result["error"])
    return ModelTestResponse(**result)


# `:path` lets model IDs contain "/" (e.g. "org/name" repo IDs). With the
# default `str` convertor such IDs could never reach these handlers, because
# the router decodes %2F before matching. Single-segment IDs match as before.
@router.get("/{model_id:path}", response_model=ModelInfo)
async def get_model_info(model_id: str) -> ModelInfo:
    """Return detailed information about a cached model.

    Raises:
        HTTPException: 404 when the service has no record for ``model_id``.
    """
    info = await model_service.get_model_info(model_id)
    if not info:
        raise HTTPException(status_code=404, detail=f"Model not found: {model_id}")
    return _to_model_info(info)


@router.delete("/{model_id:path}")
async def delete_model(model_id: str):
    """Delete a cached model (database record and local files).

    The service result is returned to the client unchanged.
    """
    # NOTE(review): unlike the other endpoints, failures are not mapped to an
    # HTTPException here; confirm the shape of the service's return value.
    result = await model_service.delete_model(model_id)
    return result