# main.py
  1. import os
  2. # 禁用 FlashAttention,解决沐曦显卡共享内存不足问题
  3. # 必须放在最开头,在任何库导入之前设置
  4. os.environ["PYTORCH_NO_FLASH"] = "1"
  5. os.environ["FLASH_ATTENTION_ENABLED"] = "0"
  6. os.environ["USE_FLASH_ATTENTION"] = "0"
  7. os.environ["TORCH_FLASH_ATTN"] = "0"
  8. from contextlib import asynccontextmanager
  9. from fastapi import FastAPI
  10. from fastapi.middleware.cors import CORSMiddleware
  11. from app.config import get_settings
# Module-level settings object obtained from app.config.get_settings();
# read by lifespan() and create_app() in this module.
settings = get_settings()
  13. @asynccontextmanager
  14. async def lifespan(app: FastAPI):
  15. # 启动时:确保数据目录存在 + 初始化数据库 + 启动 JobQueue
  16. settings.ensure_dirs()
  17. from app.core.db import init_db
  18. await init_db()
  19. from app.core.job_queue import job_queue
  20. from app.services.training_service import update_job_in_db
  21. job_queue.register_callback(update_job_in_db)
  22. await job_queue.start()
  23. yield
  24. # 关闭时:停止 JobQueue
  25. await job_queue.stop()
  26. def create_app() -> FastAPI:
  27. app = FastAPI(
  28. title="PEFT Fine-Tuning Platform",
  29. version="0.1.0",
  30. lifespan=lifespan,
  31. )
  32. # CORS 中间件
  33. app.add_middleware(
  34. CORSMiddleware,
  35. allow_origins=settings.backend_cors_origins,
  36. allow_credentials=True,
  37. allow_methods=["*"],
  38. allow_headers=["*"],
  39. )
  40. # 挂载路由
  41. from app.api import models as models_api
  42. from app.api import datasets as datasets_api
  43. from app.api import training as training_api
  44. from app.api import evaluation as evaluation_api
  45. from app.api import deployment as deployment_api
  46. from app.api import inference as inference_api
  47. app.include_router(models_api.router, prefix="/api/v1/models", tags=["models"])
  48. app.include_router(datasets_api.router, prefix="/api/v1/datasets", tags=["datasets"])
  49. app.include_router(training_api.router, prefix="/api/v1/training", tags=["training"])
  50. app.include_router(evaluation_api.router, prefix="/api/v1/evaluation", tags=["evaluation"])
  51. app.include_router(deployment_api.router, prefix="/api/v1/deployment", tags=["deployment"])
  52. app.include_router(inference_api.router, prefix="/api/v1/inference", tags=["inference"])
  53. # WebSocket
  54. from app.core.websocket import router as ws_router
  55. app.include_router(ws_router)
  56. @app.get("/health")
  57. async def health_check():
  58. return {"status": "ok", "env": settings.backend_env}
  59. return app
# Module-level application instance, built once when this module is imported.
# NOTE(review): presumably the ASGI entry point the server loads (e.g. "main:app") — confirm.
app = create_app()