# main.py (4.4 KB)

  1. import os
  2. # 禁用 FlashAttention,解决沐曦显卡共享内存不足问题
  3. # 必须放在最开头,在任何库导入之前设置
  4. os.environ["PYTORCH_NO_FLASH"] = "1"
  5. os.environ["FLASH_ATTENTION_ENABLED"] = "0"
  6. os.environ["USE_FLASH_ATTENTION"] = "0"
  7. os.environ["TORCH_FLASH_ATTN"] = "0"
  8. from contextlib import asynccontextmanager
  9. from fastapi import Depends, FastAPI
  10. from fastapi.middleware.cors import CORSMiddleware
  11. from app.config import get_settings
  12. settings = get_settings()
  13. @asynccontextmanager
  14. async def lifespan(app: FastAPI):
  15. # 启动时:确保数据目录存在 + 初始化数据库 + 启动 JobQueue
  16. settings.ensure_dirs()
  17. from app.core.db import init_db
  18. await init_db()
  19. from app.core.job_queue import job_queue
  20. from app.services.training_service import update_job_in_db
  21. job_queue.register_callback(update_job_in_db)
  22. await job_queue.start()
  23. # 初始化后台任务管理器
  24. from app.core.background_tasks import background_task_manager
  25. background_task_manager.set_concurrency("model_download", 5)
  26. background_task_manager.set_concurrency("dataset_download", 5)
  27. background_task_manager.set_concurrency("evaluation", 1)
  28. background_task_manager.set_concurrency("deployment", 1)
  29. # 恢复因重启中断的任务
  30. from app.services import model_service, dataset_service, eval_service, deploy_service
  31. await model_service.recover_stale_downloads()
  32. await dataset_service.recover_stale_downloads()
  33. await eval_service.recover_stale_evaluations()
  34. await deploy_service.recover_stale_deploys()
  35. yield
  36. # 关闭时:停止 JobQueue
  37. await job_queue.stop()
  38. def create_app() -> FastAPI:
  39. app = FastAPI(
  40. title="四川路桥模型微调平台",
  41. version="0.1.0",
  42. lifespan=lifespan,
  43. )
  44. # CORS 中间件
  45. app.add_middleware(
  46. CORSMiddleware,
  47. allow_origins=settings.backend_cors_origins,
  48. allow_credentials=True,
  49. allow_methods=["*"],
  50. allow_headers=["*"],
  51. )
  52. # 挂载路由
  53. from app.api import models as models_api
  54. from app.api import datasets as datasets_api
  55. from app.api import training as training_api
  56. from app.api import evaluation as evaluation_api
  57. from app.api import deployment as deployment_api
  58. from app.api import inference as inference_api
  59. from app.api import auth as auth_api
  60. from app.api import sample_center as sample_center_api
  61. from app.api import api_keys as api_keys_api
  62. from app.core.auth import get_current_active_user
  63. # 认证路由(无 prefix,端点自带完整路径)
  64. app.include_router(auth_api.router)
  65. # API Key 管理路由
  66. app.include_router(
  67. api_keys_api.router, prefix="/api/v1/api-keys", tags=["api-keys"],
  68. dependencies=[Depends(get_current_active_user)],
  69. )
  70. # 已有路由:添加认证依赖保护
  71. app.include_router(
  72. models_api.router, prefix="/api/v1/models", tags=["models"],
  73. dependencies=[Depends(get_current_active_user)],
  74. )
  75. app.include_router(
  76. datasets_api.router, prefix="/api/v1/datasets", tags=["datasets"],
  77. dependencies=[Depends(get_current_active_user)],
  78. )
  79. app.include_router(
  80. training_api.router, prefix="/api/v1/training", tags=["training"],
  81. dependencies=[Depends(get_current_active_user)],
  82. )
  83. app.include_router(
  84. evaluation_api.router, prefix="/api/v1/evaluation", tags=["evaluation"],
  85. dependencies=[Depends(get_current_active_user)],
  86. )
  87. app.include_router(
  88. deployment_api.router, prefix="/api/v1/deployment", tags=["deployment"],
  89. dependencies=[Depends(get_current_active_user)],
  90. )
  91. # 代理端点不需要 JWT,使用 API Key 认证
  92. app.include_router(
  93. deployment_api.proxy_router, prefix="/api/v1/deployment", tags=["deployment-proxy"],
  94. )
  95. app.include_router(
  96. inference_api.router, prefix="/api/v1/inference", tags=["inference"],
  97. dependencies=[Depends(get_current_active_user)],
  98. )
  99. app.include_router(
  100. sample_center_api.router, prefix="/api/v1/sample-center", tags=["sample-center"],
  101. dependencies=[Depends(get_current_active_user)],
  102. )
  103. # WebSocket
  104. from app.core.websocket import router as ws_router
  105. app.include_router(ws_router)
  106. @app.get("/health")
  107. async def health_check():
  108. return {"status": "ok", "env": settings.backend_env}
  109. return app
# Module-level application instance, built once at import time.
# Presumably this is the object an ASGI server loads (e.g. `uvicorn main:app`)
# — confirm against the deployment configuration.
app = create_app()