"""HTTP routes for listing, downloading, and inspecting models."""

from typing import Any

from fastapi import APIRouter

from app.schemas.model import ModelDownloadRequest, ModelDownloadResponse, ModelInfo
from app.services import model_service

router = APIRouter()


def _to_model_info(record: dict[str, Any], default_name: str) -> ModelInfo:
    """Build a ``ModelInfo`` from a service record, defaulting missing keys.

    Args:
        record: Mapping returned by ``model_service``; must contain ``"id"``.
        default_name: Name used when the record has no ``"name"`` key.

    Returns:
        The populated ``ModelInfo``.

    Raises:
        KeyError: If ``record`` has no ``"id"`` key.
    """
    return ModelInfo(
        id=record["id"],
        name=record.get("name", default_name),
        model_type=record.get("model_type", "text"),
        path=record.get("path"),
        is_downloaded=record.get("is_downloaded", True),
        context_length=record.get("context_length"),
        supported_peft_methods=record.get("supported_peft_methods", []),
    )


@router.get("/", response_model=list[ModelInfo])
async def list_models() -> list[ModelInfo]:
    """List all locally cached models.

    Returns:
        One ``ModelInfo`` per record reported by
        ``model_service.list_cached_models()``; a record without a ``"name"``
        key falls back to its ``"id"``.
    """
    return [
        _to_model_info(record, record["id"])
        for record in model_service.list_cached_models()
    ]


@router.post("/download", response_model=ModelDownloadResponse)
async def download_model(req: ModelDownloadRequest) -> ModelDownloadResponse:
    """Download a model from HuggingFace or ModelScope.

    Args:
        req: Request carrying ``model_id`` and the ``use_modelscope`` flag,
            both forwarded unchanged to ``model_service.download_model``.

    Returns:
        The service result's ``model_id`` and ``status``, plus ``path`` and
        ``error`` when the service provides them (``None`` otherwise).
    """
    result = await model_service.download_model(req.model_id, req.use_modelscope)
    return ModelDownloadResponse(
        model_id=result["model_id"],
        status=result["status"],
        path=result.get("path"),
        error=result.get("error"),
    )


# The ``:path`` converter lets the parameter span several URL segments, so
# repository-style IDs such as "org/name" reach this handler; the default
# converter stops at the first "/" and such requests would never match.
# Because it matches any GET path under this router, keep this route
# registered last so it cannot shadow more specific GET routes.
@router.get("/{model_id:path}", response_model=ModelInfo)
async def get_model_info(model_id: str) -> ModelInfo:
    """Get detailed information about a cached model.

    Args:
        model_id: Model identifier; may contain ``/`` separators.

    Returns:
        The ``ModelInfo`` built from the service record when one is found.
        Otherwise a placeholder with ``is_downloaded=False`` whose name is
        the last ``/``-separated segment of ``model_id``.
    """
    short_name = model_id.split("/")[-1]
    info = await model_service.get_model_info(model_id)
    if info:
        return _to_model_info(info, short_name)
    # No record from the service: report the model as not downloaded rather
    # than failing the request.
    return ModelInfo(
        id=model_id,
        name=short_name,
        model_type="text",
        is_downloaded=False,
    )