| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106 |
- from functools import partial
- from contextlib import asynccontextmanager
- import logging
- from pathlib import Path
- import aiohttp
- from fastapi import FastAPI
- from fastapi_cdn_host import patch_docs
- from gpustack import __version__
- from fastapi.middleware.cors import CORSMiddleware
- from gpustack.api import exceptions, middlewares
- from gpustack.api.auth import BearerTokenAuthenticator
- from gpustack.config.config import Config
- from gpustack import envs
- from gpustack.routes import ui
- from gpustack.routes.routes import api_router
- from gpustack.utils.forwarded import ForwardedHostPortMiddleware
- from gpustack.security import JWTManager
- from gpustack.gateway.utils import worker_websocket_connect_callback
- from gpustack.websocket_proxy.message_server import MessageServerHandler
- from gpustack.extension import Plugin, iter_plugin_classes
- logger = logging.getLogger(__name__)
def create_app(cfg: Config) -> FastAPI:
    """Build and configure the GPUStack FastAPI application.

    The returned app has a lifespan that creates (and closes) two shared
    aiohttp client sessions, the request middlewares, the API router, the UI
    routes, any extension plugins, the exception handlers, and the
    ``jwt_manager`` / ``websocket_authenticator`` / ``message_server_handler``
    objects attached to ``app.state``.

    Args:
        cfg: Server configuration used for docs exposure, CORS, JWT secret
            and the websocket proxy addresses/ports.

    Returns:
        The configured ``FastAPI`` instance.
    """

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        """Create the shared HTTP clients on startup; close them on shutdown."""
        app.state.server_config = cfg
        # Both sessions are built on the same connector instance.
        connector = aiohttp.TCPConnector(
            limit=envs.TCP_CONNECTOR_LIMIT,
            force_close=True,
        )
        # trust_env=True asks aiohttp to read proxy settings from the
        # environment; the second session is created without it.
        app.state.http_client = aiohttp.ClientSession(
            connector=connector, trust_env=True
        )
        app.state.http_client_no_proxy = aiohttp.ClientSession(connector=connector)
        try:
            yield
        finally:
            # Run the cleanup even when an exception or cancellation is
            # thrown into the generator at the yield, and make sure a failure
            # closing the first session does not skip closing the second.
            try:
                await app.state.http_client.close()
            finally:
                await app.state.http_client_no_proxy.close()

    # Evaluated once and reused for the three documentation endpoints.
    # NOTE(review): the ``cfg and`` guard tolerates a falsy cfg here, but cfg
    # is dereferenced unconditionally further down (e.g. ``cfg.enable_cors``).
    docs_disabled = bool(cfg and cfg.disable_openapi_docs)

    app = FastAPI(
        title="GPUStack",
        lifespan=lifespan,
        response_model_exclude_unset=True,
        version=__version__,
        docs_url=None if docs_disabled else "/docs",
        redoc_url=None if docs_disabled else "/redoc",
        openapi_url=None if docs_disabled else "/openapi.json",
    )
    patch_docs(app, Path(__file__).parents[1] / "ui" / "static")

    # Registration order is kept exactly as in the original.
    app.add_middleware(ForwardedHostPortMiddleware)
    app.add_middleware(middlewares.RequestTimeMiddleware)
    app.add_middleware(middlewares.ModelUsageMiddleware)
    app.add_middleware(middlewares.RefreshTokenMiddleware)
    if cfg.enable_cors:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=cfg.allow_origins,
            allow_credentials=cfg.allow_credentials,
            allow_methods=cfg.allow_methods,
            allow_headers=cfg.allow_headers,
        )

    app.include_router(api_router)
    ui.register(app)
    _load_extension_plugins(app, cfg)
    exceptions.register_handlers(app)

    app.state.jwt_manager = JWTManager(cfg.jwt_secret_key)
    app.state.websocket_authenticator = BearerTokenAuthenticator()
    app.state.message_server_handler = MessageServerHandler(
        listen_address=cfg.get_proxy_listen_address(cfg.get_advertise_address()),
        listen_port=cfg.api_port,
        proxy_port=cfg.get_proxy_port(),
        authenticator=app.state.websocket_authenticator,
        callback_on_connect=partial(
            worker_websocket_connect_callback,
            proxy_address=cfg.get_proxy_url(),
        ),
        # NOTE(review): the disconnect hook reuses the *connect* callback,
        # without ``proxy_address`` bound — confirm this is intentional.
        callback_on_disconnect=worker_websocket_connect_callback,
    )
    return app
def _load_extension_plugins(app: FastAPI, cfg: Config):
    """Instantiate every extension plugin discovered via entry points.

    ``app.state.extension_plugins`` is reset to an empty list, then each
    ``(name, class)`` pair yielded by ``iter_plugin_classes()`` is examined:
    a class that subclasses ``Plugin`` is constructed as
    ``plugin_class(app, cfg)`` and appended to that list; anything else is
    skipped with a warning.

    Raises:
        RuntimeError: If handling a plugin raises; the original exception is
            chained as the cause.
    """
    app.state.extension_plugins = []

    for name, plugin_cls in iter_plugin_classes():
        try:
            implements_plugin = isinstance(plugin_cls, type) and issubclass(
                plugin_cls, Plugin
            )
            if implements_plugin:
                instance = plugin_cls(app, cfg)
                app.state.extension_plugins.append(instance)
                logger.info(f"Loaded extension plugin: {name}")
            else:
                # Not a Plugin subclass: report it and move on to the next one.
                logger.warning(
                    f"Extension plugin {name} does not implement the Plugin interface."
                )
        except Exception as exc:
            raise RuntimeError(
                f"Failed to load extension plugin '{name}': {exc}"
            ) from exc
|