| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758 |
- from fastapi import APIRouter
- from app.schemas.model import ModelDownloadRequest, ModelDownloadResponse, ModelInfo
- from app.services import model_service
# Router collecting the model-management endpoints defined below; presumably
# mounted by the application under some prefix (e.g. "/models") — TODO confirm.
router: APIRouter = APIRouter()
@router.get("/", response_model=list[ModelInfo])
async def list_models():
    """List all locally cached models.

    Every record produced by ``model_service.list_cached_models()`` is
    converted into a ``ModelInfo``. ``"id"`` is required (a record without it
    raises ``KeyError``); the remaining fields fall back to defaults when
    absent: name -> the model id, model type -> ``"text"``,
    downloaded -> ``True``, supported PEFT methods -> empty list, and
    path / context length -> ``None``.
    """

    def _to_info(record) -> ModelInfo:
        # Read the id once; it doubles as the fallback display name.
        record_id = record["id"]
        return ModelInfo(
            id=record_id,
            name=record.get("name", record_id),
            model_type=record.get("model_type", "text"),
            path=record.get("path"),
            is_downloaded=record.get("is_downloaded", True),
            context_length=record.get("context_length"),
            supported_peft_methods=record.get("supported_peft_methods", []),
        )

    # NOTE(review): list_cached_models() is called synchronously inside an
    # async handler; if it scans the filesystem it blocks the event loop for
    # the duration of the scan. Consider a plain `def` endpoint or a thread
    # offload — TODO confirm what the service does.
    cached = model_service.list_cached_models()
    return [_to_info(record) for record in cached]
@router.post("/download", response_model=ModelDownloadResponse)
async def download_model(req: ModelDownloadRequest):
    """Download a model from HuggingFace or ModelScope.

    ``req.use_modelscope`` is forwarded to the service to pick the source hub.
    The service result must contain ``"model_id"`` and ``"status"`` (a missing
    key raises ``KeyError``); ``"path"`` and ``"error"`` are optional and are
    reported as ``None`` when absent.
    """
    outcome = await model_service.download_model(req.model_id, req.use_modelscope)

    # NOTE(review): a failed download appears to be reported through the
    # status/error fields of a normal response rather than an HTTP error
    # status — confirm against model_service.download_model.
    required = {"model_id": outcome["model_id"], "status": outcome["status"]}
    optional = {"path": outcome.get("path"), "error": outcome.get("error")}
    return ModelDownloadResponse(**required, **optional)
# The `:path` convertor is required because model ids such as "org/name"
# contain "/", which the default `str` convertor ([^/]+) refuses to match.
# `GET /` is still served by list_models, which is registered earlier.
@router.get("/{model_id:path}", response_model=ModelInfo)
async def get_model_info(model_id: str):
    """Get detailed information about a cached model.

    Args:
        model_id: Model identifier, possibly containing slashes
            (e.g. a HuggingFace ``org/name`` id).

    Returns:
        A ``ModelInfo`` built from ``model_service.get_model_info``. Missing
        fields fall back to defaults: id -> ``model_id``, name -> the last
        ``/``-separated segment of ``model_id``, model type -> ``"text"``,
        downloaded -> ``True``, supported PEFT methods -> empty list.
        If the service returns a falsy value, a placeholder with
        ``is_downloaded=False`` is returned instead of an HTTP error.
    """
    # Last path segment, e.g. "org/name" -> "name"; used as the display name
    # whenever the service does not supply one.
    short_name = model_id.split("/")[-1]

    info = await model_service.get_model_info(model_id)
    if not info:
        # Unknown / not-yet-downloaded model: answer with a stub, not a 404.
        return ModelInfo(
            id=model_id,
            name=short_name,
            model_type="text",
            is_downloaded=False,
        )

    return ModelInfo(
        # Fall back to the requested id rather than failing with a KeyError
        # (HTTP 500) if the service record omits "id".
        id=info.get("id", model_id),
        name=info.get("name", short_name),
        model_type=info.get("model_type", "text"),
        path=info.get("path"),
        is_downloaded=info.get("is_downloaded", True),
        context_length=info.get("context_length"),
        supported_peft_methods=info.get("supported_peft_methods", []),
    )
|