app.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. from functools import partial
  2. from contextlib import asynccontextmanager
  3. import logging
  4. from pathlib import Path
  5. import aiohttp
  6. from fastapi import FastAPI
  7. from fastapi_cdn_host import patch_docs
  8. from gpustack import __version__
  9. from fastapi.middleware.cors import CORSMiddleware
  10. from gpustack.api import exceptions, middlewares
  11. from gpustack.api.auth import BearerTokenAuthenticator
  12. from gpustack.config.config import Config
  13. from gpustack import envs
  14. from gpustack.routes import ui
  15. from gpustack.routes.routes import api_router
  16. from gpustack.utils.forwarded import ForwardedHostPortMiddleware
  17. from gpustack.security import JWTManager
  18. from gpustack.gateway.utils import worker_websocket_connect_callback
  19. from gpustack.websocket_proxy.message_server import MessageServerHandler
  20. from gpustack.extension import Plugin, iter_plugin_classes
  21. logger = logging.getLogger(__name__)
  22. def create_app(cfg: Config) -> FastAPI:
  23. @asynccontextmanager
  24. async def lifespan(app: FastAPI):
  25. app.state.server_config = cfg
  26. connector = aiohttp.TCPConnector(
  27. limit=envs.TCP_CONNECTOR_LIMIT,
  28. force_close=True,
  29. )
  30. app.state.http_client = aiohttp.ClientSession(
  31. connector=connector, trust_env=True
  32. )
  33. app.state.http_client_no_proxy = aiohttp.ClientSession(connector=connector)
  34. yield
  35. await app.state.http_client.close()
  36. await app.state.http_client_no_proxy.close()
  37. app = FastAPI(
  38. title="GPUStack",
  39. lifespan=lifespan,
  40. response_model_exclude_unset=True,
  41. version=__version__,
  42. docs_url=None if (cfg and cfg.disable_openapi_docs) else "/docs",
  43. redoc_url=None if (cfg and cfg.disable_openapi_docs) else "/redoc",
  44. openapi_url=None if (cfg and cfg.disable_openapi_docs) else "/openapi.json",
  45. )
  46. patch_docs(app, Path(__file__).parents[1] / "ui" / "static")
  47. app.add_middleware(ForwardedHostPortMiddleware)
  48. app.add_middleware(middlewares.RequestTimeMiddleware)
  49. app.add_middleware(middlewares.ModelUsageMiddleware)
  50. app.add_middleware(middlewares.RefreshTokenMiddleware)
  51. if cfg.enable_cors:
  52. app.add_middleware(
  53. CORSMiddleware,
  54. allow_origins=cfg.allow_origins,
  55. allow_credentials=cfg.allow_credentials,
  56. allow_methods=cfg.allow_methods,
  57. allow_headers=cfg.allow_headers,
  58. )
  59. app.include_router(api_router)
  60. ui.register(app)
  61. _load_extension_plugins(app, cfg)
  62. exceptions.register_handlers(app)
  63. app.state.jwt_manager = JWTManager(cfg.jwt_secret_key)
  64. app.state.websocket_authenticator = BearerTokenAuthenticator()
  65. app.state.message_server_handler = MessageServerHandler(
  66. listen_address=cfg.get_proxy_listen_address(cfg.get_advertise_address()),
  67. listen_port=cfg.api_port,
  68. proxy_port=cfg.get_proxy_port(),
  69. authenticator=app.state.websocket_authenticator,
  70. callback_on_connect=partial(
  71. worker_websocket_connect_callback,
  72. proxy_address=cfg.get_proxy_url(),
  73. ),
  74. callback_on_disconnect=worker_websocket_connect_callback,
  75. )
  76. return app
  77. def _load_extension_plugins(app: FastAPI, cfg: Config):
  78. """Load extension plugins registered via entry points.
  79. Each entry point is expected to resolve to a ``Plugin`` subclass
  80. whose ``__init__(app, cfg)`` performs the full registration.
  81. """
  82. app.state.extension_plugins = []
  83. for name, plugin_class in iter_plugin_classes():
  84. try:
  85. if not (
  86. isinstance(plugin_class, type) and issubclass(plugin_class, Plugin)
  87. ):
  88. logger.warning(
  89. f"Extension plugin {name} does not implement the Plugin interface."
  90. )
  91. continue
  92. plugin = plugin_class(app, cfg)
  93. app.state.extension_plugins.append(plugin)
  94. logger.info(f"Loaded extension plugin: {name}")
  95. except Exception as e:
  96. raise RuntimeError(f"Failed to load extension plugin '{name}': {e}") from e