from fastapi import APIRouter, HTTPException

from app.schemas.model import ModelDownloadRequest, ModelDownloadResponse, ModelInfo
from app.schemas.model_test import ModelTestRequest, ModelTestResponse
from app.schemas.background_task import ModelDownloadTaskResponse
from app.services import model_service, model_test_service

router = APIRouter()

# NOTE: declaration order matters. Routes are matched in the order they are
# registered, so static GET paths such as "/downloads" must stay above the
# parametric GET "/{model_id}" route; otherwise "downloads" would be captured
# as a model_id. Add new static GET/DELETE routes above the "/{model_id}" ones.


def _to_model_info(record) -> ModelInfo:
    """Build a ``ModelInfo`` response from a model record returned by ``model_service``.

    Shared by ``list_models`` and ``get_model_info`` so both endpoints map the
    same fields. ``record`` is indexed by string keys (presumably a dict from
    the service layer; verify). A missing key raises ``KeyError``.
    """
    return ModelInfo(
        id=record["id"],
        name=record["name"],
        model_type=record["model_type"],
        path=record["path"],
        is_downloaded=record["is_downloaded"],
        context_length=record["context_length"],
        supported_peft_methods=record["supported_peft_methods"],
    )


@router.get("/", response_model=list[ModelInfo])
async def list_models():
    """List all locally cached models (read from the database)."""
    models = await model_service.list_cached_models()
    return [_to_model_info(m) for m in models]


@router.post("/download", response_model=ModelDownloadTaskResponse, status_code=200)
async def download_model(req: ModelDownloadRequest):
    """Start a model-download background task and return its task_id immediately."""
    result = await model_service.download_model(req.model_id, req.use_modelscope)
    return ModelDownloadTaskResponse(
        task_id=result["task_id"],
        model_id=result["model_id"],
        status=result["status"],
    )


@router.get("/download/{task_id}", response_model=ModelDownloadTaskResponse)
async def get_model_download_status(task_id: str):
    """Return the status of a model-download task.

    Raises:
        HTTPException: 404 when the service reports ``status == "not_found"``.
    """
    result = await model_service.get_model_download_status(task_id)
    if result.get("status") == "not_found":
        raise HTTPException(status_code=404, detail="Download task not found")
    return ModelDownloadTaskResponse(**result)


@router.get("/downloads")
async def list_model_downloads():
    """List all model-download tasks (service result returned as-is)."""
    return await model_service.list_model_downloads()


@router.post("/download/{task_id}/cancel")
async def cancel_model_download(task_id: str):
    """Cancel a model-download task (service result returned as-is)."""
    return await model_service.cancel_model_download(task_id)


@router.get("/{model_id}", response_model=ModelInfo)
async def get_model_info(model_id: str):
    """Return details of a cached model.

    Raises:
        HTTPException: 404 when ``model_service.get_model_info`` returns a falsy value.
    """
    info = await model_service.get_model_info(model_id)
    if not info:
        raise HTTPException(status_code=404, detail=f"Model not found: {model_id}")
    return _to_model_info(info)


@router.delete("/{model_id}")
async def delete_model(model_id: str):
    """Delete a cached model (database record + local files) via ``model_service``.

    The service result is returned unchanged.
    """
    return await model_service.delete_model(model_id)


@router.post("/test", response_model=ModelTestResponse)
async def test_model(req: ModelTestRequest):
    """Run a quick generation test against a cached model.

    Raises:
        HTTPException: 400 when the service result contains an ``"error"`` key;
            the error value is used as the response detail.
    """
    result = await model_test_service.test_model(
        req.model_id,
        req.prompt,
        req.max_new_tokens,
        req.temperature,
        req.top_p,
    )
    if "error" in result:
        raise HTTPException(status_code=400, detail=result["error"])
    return ModelTestResponse(**result)