# main.py — FastAPI application entry point: defines lifespan() and
# create_app() and exposes the module-level `app`.
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from app.config import get_settings
  5. settings = get_settings()
  6. @asynccontextmanager
  7. async def lifespan(app: FastAPI):
  8. # 启动时:确保数据目录存在 + 初始化数据库 + 启动 JobQueue
  9. settings.ensure_dirs()
  10. from app.core.db import init_db
  11. await init_db()
  12. from app.core.job_queue import job_queue
  13. from app.services.training_service import update_job_in_db
  14. job_queue.register_callback(update_job_in_db)
  15. await job_queue.start()
  16. yield
  17. # 关闭时:停止 JobQueue
  18. await job_queue.stop()
  19. def create_app() -> FastAPI:
  20. app = FastAPI(
  21. title="PEFT Fine-Tuning Platform",
  22. version="0.1.0",
  23. lifespan=lifespan,
  24. )
  25. # CORS 中间件
  26. app.add_middleware(
  27. CORSMiddleware,
  28. allow_origins=settings.backend_cors_origins,
  29. allow_credentials=True,
  30. allow_methods=["*"],
  31. allow_headers=["*"],
  32. )
  33. # 挂载路由
  34. from app.api import models as models_api
  35. from app.api import datasets as datasets_api
  36. from app.api import training as training_api
  37. from app.api import evaluation as evaluation_api
  38. from app.api import deployment as deployment_api
  39. from app.api import inference as inference_api
  40. app.include_router(models_api.router, prefix="/api/v1/models", tags=["models"])
  41. app.include_router(datasets_api.router, prefix="/api/v1/datasets", tags=["datasets"])
  42. app.include_router(training_api.router, prefix="/api/v1/training", tags=["training"])
  43. app.include_router(evaluation_api.router, prefix="/api/v1/evaluation", tags=["evaluation"])
  44. app.include_router(deployment_api.router, prefix="/api/v1/deployment", tags=["deployment"])
  45. app.include_router(inference_api.router, prefix="/api/v1/inference", tags=["inference"])
  46. # WebSocket
  47. from app.core.websocket import router as ws_router
  48. app.include_router(ws_router)
  49. @app.get("/health")
  50. async def health_check():
  51. return {"status": "ok", "env": settings.backend_env}
  52. return app
  53. app = create_app()