| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139 |
import os

# Disable FlashAttention to work around insufficient shared memory on
# MetaX (沐曦) GPUs.
# This must stay at the very top of the module so the variables are set
# before any other library is imported.
os.environ.update(
    {
        "PYTORCH_NO_FLASH": "1",
        "FLASH_ATTENTION_ENABLED": "0",
        "USE_FLASH_ATTENTION": "0",
        "TORCH_FLASH_ATTN": "0",
    }
)
- from contextlib import asynccontextmanager
- from fastapi import Depends, FastAPI
- from fastapi.middleware.cors import CORSMiddleware
- from app.config import get_settings
# Application settings, resolved once at import time and read by both
# lifespan() (directory setup) and create_app() (CORS origins, /health).
settings = get_settings()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup and shutdown work around the application's serving phase.

    Startup, in order:
      1. ``settings.ensure_dirs()`` and ``init_db()``.
      2. Register ``update_job_in_db`` as a job-queue callback and start the
         job queue.
      3. Set per-category concurrency limits on the background task manager.
      4. Run each service's recovery routine for tasks interrupted by a
         restart.

    Shutdown: stop the job queue. The stop is in a ``finally`` clause so it
    also runs when a startup step after ``job_queue.start()`` raises, or when
    an exception or cancellation is thrown into this generator at ``yield``.

    Args:
        app: The FastAPI application this lifespan is attached to. It is not
            used inside the function.
    """
    # NOTE(review): project imports are kept function-local, as in the
    # original — presumably to defer loading until startup; confirm before
    # hoisting them to module level.

    # Startup: make sure the data directories exist, initialise the database
    # and start the JobQueue.
    settings.ensure_dirs()

    from app.core.db import init_db
    await init_db()

    from app.core.job_queue import job_queue
    from app.services.training_service import update_job_in_db
    job_queue.register_callback(update_job_in_db)
    await job_queue.start()

    # From here on the queue is running, so every exit path must stop it.
    try:
        # Initialise the background task manager.
        from app.core.background_tasks import background_task_manager
        background_task_manager.set_concurrency("model_download", 5)
        background_task_manager.set_concurrency("dataset_download", 5)
        background_task_manager.set_concurrency("evaluation", 1)
        background_task_manager.set_concurrency("deployment", 1)

        # Recover tasks that were interrupted by a restart.
        from app.services import model_service, dataset_service, eval_service, deploy_service
        await model_service.recover_stale_downloads()
        await dataset_service.recover_stale_downloads()
        await eval_service.recover_stale_evaluations()
        await deploy_service.recover_stale_deploys()

        yield
    finally:
        # Shutdown: stop the JobQueue.
        await job_queue.stop()
def create_app() -> FastAPI:
    """Build the FastAPI application with middleware and all routers attached.

    The application is created with the module's ``lifespan`` handler, gets a
    CORS middleware configured from ``settings.backend_cors_origins``, and
    then has its routers registered in a fixed order: auth, the ``/api/v1``
    feature routers, and finally the WebSocket router. A ``GET /health``
    endpoint is added last.

    Returns:
        The configured ``FastAPI`` instance.
    """
    application = FastAPI(
        title="四川路桥模型微调平台",
        version="0.1.0",
        lifespan=lifespan,
    )

    # CORS middleware.
    cors_options = {
        "allow_origins": settings.backend_cors_origins,
        "allow_credentials": True,
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    }
    application.add_middleware(CORSMiddleware, **cors_options)

    # Router modules (imported locally, in the same order as before).
    from app.api import (
        models as models_api,
        datasets as datasets_api,
        training as training_api,
        evaluation as evaluation_api,
        deployment as deployment_api,
        inference as inference_api,
        auth as auth_api,
        sample_center as sample_center_api,
        api_keys as api_keys_api,
    )
    from app.core.auth import get_current_active_user

    # Auth router: no prefix, its endpoints carry their full paths.
    application.include_router(auth_api.router)

    # (router, prefix, tag, requires_login), listed in registration order.
    # Routers flagged True get the current-active-user dependency attached.
    route_table = (
        # API key management.
        (api_keys_api.router, "/api/v1/api-keys", "api-keys", True),
        # Feature routers protected by the authentication dependency.
        (models_api.router, "/api/v1/models", "models", True),
        (datasets_api.router, "/api/v1/datasets", "datasets", True),
        (training_api.router, "/api/v1/training", "training", True),
        (evaluation_api.router, "/api/v1/evaluation", "evaluation", True),
        (deployment_api.router, "/api/v1/deployment", "deployment", True),
        # Proxy endpoints do not need a JWT; they use API key authentication.
        (deployment_api.proxy_router, "/api/v1/deployment", "deployment-proxy", False),
        (inference_api.router, "/api/v1/inference", "inference", True),
        (sample_center_api.router, "/api/v1/sample-center", "sample-center", True),
    )
    for feature_router, url_prefix, tag, requires_login in route_table:
        extra_kwargs = {}
        if requires_login:
            # Fresh list and Depends object per router, as in the original.
            extra_kwargs["dependencies"] = [Depends(get_current_active_user)]
        application.include_router(
            feature_router, prefix=url_prefix, tags=[tag], **extra_kwargs
        )

    # WebSocket.
    from app.core.websocket import router as ws_router
    application.include_router(ws_router)

    @application.get("/health")
    async def health_check():
        """Report liveness together with the configured backend environment."""
        payload = {"status": "ok", "env": settings.backend_env}
        return payload

    return application
# Module-level application instance, built once at import time.
# NOTE(review): presumably the object the ASGI server is pointed at
# (e.g. "<module>:app") — confirm against the launch command.
app = create_app()
|