Kaynağa Gözat

feat: 接入统一认证平台 SSO (LQAI-middle-platform)

后端:
- 新增 SSO 配置项和 CLI 参数 (sso_base_url, sso_client_id 等)
- 新增 SSO 核心服务模块 (api/sso.py): code 换 token、获取用户信息、同步用户
- 新增 GET /auth/sso/authorize 和 POST /auth/oauth/exchange-code 端点
- 修改登出接口返回 SSO 登出 URL
- RefreshTokenMiddleware 增加 Bearer token 滑动过期支持
- 修改默认模型库数据源为 ModelScope

前端:
- 新增 SSO 回调页面 /auth/callback
- 请求拦截器添加 Bearer token,响应拦截器处理 X-New-Token
- 登录页面改为统一认证平台登录,隐藏本地用户名密码登录
- getInitialState 跳过 SSO 回调页面的用户信息获取
- 新增 /auth/callback 路由
kinglee 1 hafta önce
ebeveyn
işleme
06650138e4

+ 22 - 0
gpustack/api/middlewares.py

@@ -357,6 +357,8 @@ class RefreshTokenMiddleware(BaseHTTPMiddleware):
         response = await call_next(request)
 
         jwt_manager: JWTManager = request.app.state.jwt_manager
+
+        # Cookie-based refresh (existing local auth)
         token = request.cookies.get(SESSION_COOKIE_NAME)
 
         if token:
@@ -377,6 +379,26 @@ class RefreshTokenMiddleware(BaseHTTPMiddleware):
                         )
             except (ExpiredSignatureError, DecodeError):
                 pass
+        else:
+            # SSO Bearer token sliding expiration
+            auth_header = request.headers.get("Authorization", "")
+            if auth_header.startswith("Bearer "):
+                bearer_token = auth_header[7:]
+                try:
+                    payload = jwt_manager.decode_jwt_token(bearer_token)
+                    if payload:
+                        exp = payload.get('exp', 0)
+                        iat = payload.get('iat', 0) or (exp - envs.JWT_TOKEN_EXPIRE_MINUTES * 60)
+                        lifetime = exp - iat
+                        remaining = exp - time.time()
+                        # If token has used more than 50% of its lifetime, issue a new one
+                        if remaining < lifetime * 0.5:
+                            new_token = jwt_manager.create_jwt_token(
+                                username=payload['sub']
+                            )
+                            response.headers['X-New-Token'] = new_token
+                except (ExpiredSignatureError, DecodeError):
+                    pass
 
         return response
 

+ 212 - 0
gpustack/api/sso.py

@@ -0,0 +1,212 @@
+"""
+SSO (LQAI-middle-platform) OAuth2 integration.
+Implements the code exchange flow: code -> SSO access_token -> userinfo -> local JWT.
+"""
+
+import logging
+from typing import Optional, Dict, Any
+from urllib.parse import urlencode
+
+import httpx
+
+from gpustack.config.config import Config
+from gpustack.security import JWTManager
+from gpustack.server.services import create_user_with_principal
+
+logger = logging.getLogger(__name__)
+
+SSO_TOKEN_TIMEOUT = httpx.Timeout(connect=15.0, read=30.0, write=30.0, pool=5.0)
+SSO_USERINFO_TIMEOUT = httpx.Timeout(connect=15.0, read=30.0, write=30.0, pool=5.0)
+
+
+def build_sso_authorize_url(config: Config, redirect: bool = False) -> str:
+    """Build the SSO OAuth2 authorization URL."""
+    params = {
+        "response_type": "code",
+        "client_id": config.sso_client_id,
+        "redirect_uri": config.sso_redirect_uri,
+        "scope": config.sso_scope,
+    }
+    authorize_url = f"{config.sso_base_url}/oauth/authorize?{urlencode(params)}"
+    return authorize_url
+
+
+async def exchange_code_for_sso_token(
+    config: Config, code: str
+) -> Dict[str, Any]:
+    """
+    Step 4a: Use authorization code to get SSO access_token.
+    POST {SSO_BASE_URL}/oauth/token
+    """
+    data = {
+        "grant_type": "authorization_code",
+        "code": code,
+        "redirect_uri": config.sso_redirect_uri,
+        "client_id": config.sso_client_id,
+        "client_secret": config.sso_client_secret,
+    }
+
+    async with httpx.AsyncClient(
+        timeout=SSO_TOKEN_TIMEOUT,
+        verify=not config.sso_base_url.startswith("http://"),
+    ) as client:
+        resp = await client.post(
+            f"{config.sso_base_url}/oauth/token",
+            data=data,
+            headers={"Content-Type": "application/x-www-form-urlencoded"},
+        )
+
+    if resp.status_code != 200:
+        logger.error(f"SSO token exchange failed: {resp.status_code} {resp.text}")
+        error_data = resp.json() if resp.text else {}
+        error = error_data.get("error", "unknown_error")
+        error_desc = error_data.get("error_description", "令牌交换失败")
+        raise Exception(f"SSO token exchange failed: {error} - {error_desc}")
+
+    return resp.json()
+
+
+async def get_sso_userinfo(
+    config: Config, access_token: str
+) -> Dict[str, Any]:
+    """
+    Step 4b: Get user info from SSO platform.
+    GET {SSO_BASE_URL}/oauth/userinfo
+    """
+    async with httpx.AsyncClient(
+        timeout=SSO_USERINFO_TIMEOUT,
+        verify=not config.sso_base_url.startswith("http://"),
+    ) as client:
+        resp = await client.get(
+            f"{config.sso_base_url}/oauth/userinfo",
+            headers={"Authorization": f"Bearer {access_token}"},
+        )
+
+    if resp.status_code != 200:
+        logger.error(f"SSO userinfo failed: {resp.status_code} {resp.text}")
+        raise Exception("获取用户信息失败")
+
+    return resp.json()
+
+
+def extract_role_codes(userinfo: Dict[str, Any]) -> list:
+    """Extract role codes from SSO userinfo roles field."""
+    roles = userinfo.get("roles", [])
+    role_codes = []
+    for role in roles:
+        if isinstance(role, dict):
+            code = role.get("code")
+            if code:
+                role_codes.append(code)
+        elif isinstance(role, str):
+            role_codes.append(role)
+    return role_codes
+
+
+async def sync_user_from_sso(
+    session,
+    config: Config,
+    userinfo: Dict[str, Any],
+) -> Any:
+    """
+    Step 5: Sync user from SSO to local database.
+    Find or create user, sync roles.
+    """
+    username = userinfo.get("username") or userinfo.get("sub")
+    if not username:
+        raise Exception("SSO 返回的用户信息中缺少 username")
+
+    email = userinfo.get("email", "")
+    full_name = userinfo.get("real_name", username)
+    avatar_url = userinfo.get("avatar_url", "")
+    role_codes = extract_role_codes(userinfo)
+
+    is_admin = "super_admin" in role_codes
+
+    # Find existing user by username
+    from gpustack.schemas.users import User, AuthProviderEnum
+
+    existing = await User.first_by_field(
+        session, "username", username
+    )
+
+    if existing:
+        # Update user info
+        patch = {
+            "full_name": full_name,
+            "avatar_url": avatar_url,
+            "is_admin": is_admin,
+            "is_active": True,
+            "source": AuthProviderEnum.OIDC,
+        }
+        await existing.update(session, patch)
+        logger.info(f"Updated SSO user: {username}")
+        return existing
+    else:
+        # Create new user
+        # SSO users don't have a local password; generate a random one
+        import secrets
+        random_password = secrets.token_urlsafe(32)
+
+        user = await create_user_with_principal(
+            session=session,
+            username=username,
+            password=random_password,
+            is_admin=is_admin,
+            full_name=full_name,
+            avatar_url=avatar_url,
+            source=AuthProviderEnum.OIDC,
+        )
+        logger.info(f"Created SSO user: {username}")
+        return user
+
+
+async def handle_sso_exchange_code(
+    session,
+    config: Config,
+    code: str,
+    jwt_manager,
+) -> Dict[str, Any]:
+    """
+    Core SSO exchange code flow (Steps 4-6):
+    1. Exchange code for SSO access_token
+    2. Get user info from SSO
+    3. Sync user to local DB
+    4. Issue local JWT
+    """
+    # Step 4a: Get SSO access_token
+    token_data = await exchange_code_for_sso_token(config, code)
+    sso_access_token = token_data.get("access_token")
+    if not sso_access_token:
+        raise Exception("获取 SSO access_token 失败")
+
+    # Step 4b: Get user info
+    userinfo = await get_sso_userinfo(config, sso_access_token)
+    if not userinfo.get("username") and not userinfo.get("sub"):
+        raise Exception("SSO 用户信息格式异常")
+
+    # Step 5: Sync user
+    user = await sync_user_from_sso(session, config, userinfo)
+
+    # Step 6: Issue local JWT
+    local_token = jwt_manager.create_jwt_token(username=user.username)
+
+    # Build user response
+    role_codes = extract_role_codes(userinfo)
+    user_data = {
+        "id": str(user.id),
+        "username": user.username,
+        "email": userinfo.get("email", ""),
+        "phone": userinfo.get("phone", ""),
+        "full_name": user.full_name,
+        "avatar_url": user.avatar_url,
+        "is_superuser": user.is_admin,
+        "is_active": user.is_active,
+        "roles": role_codes,
+    }
+
+    return {
+        "token": local_token,
+        "refresh_token": "",  # SSO flow doesn't need refresh token for now
+        "user": user_data,
+    }

+ 42 - 0
gpustack/cmd/start.py

@@ -498,6 +498,48 @@ def start_cmd_options(parser_server: argparse.ArgumentParser):
         help="Generic key for post-logout redirection across IdPs.",
         default=get_gpustack_env("EXTERNAL_AUTH_POST_LOGOUT_REDIRECT_KEY"),
     )
+    server_group.add_argument(
+        "--sso-base-url",
+        type=str,
+        help="SSO platform base URL (e.g. http://192.168.92.61:8200).",
+        default=get_gpustack_env("SSO_BASE_URL"),
+    )
+    server_group.add_argument(
+        "--sso-client-id",
+        type=str,
+        help="SSO client ID (app_key from t_sys_app).",
+        default=get_gpustack_env("SSO_CLIENT_ID"),
+    )
+    server_group.add_argument(
+        "--sso-client-secret",
+        type=str,
+        help="SSO client secret (app_secret from t_sys_app).",
+        default=get_gpustack_env("SSO_CLIENT_SECRET"),
+    )
+    server_group.add_argument(
+        "--sso-redirect-uri",
+        type=str,
+        help="SSO redirect URI (must be registered in t_sys_app).",
+        default=get_gpustack_env("SSO_REDIRECT_URI"),
+    )
+    server_group.add_argument(
+        "--sso-frontend-url",
+        type=str,
+        help="Frontend base URL for SSO redirect.",
+        default=get_gpustack_env("SSO_FRONTEND_URL"),
+    )
+    server_group.add_argument(
+        "--sso-scope",
+        type=str,
+        help="OAuth2 scope for SSO (default: email).",
+        default=get_gpustack_env("SSO_SCOPE") or "email",
+    )
+    server_group.add_argument(
+        "--sso-logout-redirect-url",
+        type=str,
+        help="SSO logout redirect URL.",
+        default=get_gpustack_env("SSO_LOGOUT_REDIRECT_URL"),
+    )
 
     worker_group = parser_server.add_argument_group("Worker settings")
     worker_group.add_argument(

+ 8 - 0
gpustack/config/config.py

@@ -197,6 +197,14 @@ class Config(WorkerConfig, BaseSettings):
     server_external_url: Optional[str] = None
     # custom post-logout redirection key for compatibility with different IdPs.
     external_auth_post_logout_redirect_key: Optional[str] = None
+    # SSO (LQAI-middle-platform) OAuth2 configuration
+    sso_base_url: Optional[str] = None  # e.g. http://192.168.92.61:8200
+    sso_client_id: Optional[str] = None  # app_key from t_sys_app
+    sso_client_secret: Optional[str] = None  # app_secret from t_sys_app
+    sso_redirect_uri: Optional[str] = None  # e.g. http://localhost:3000/auth/callback
+    sso_frontend_url: Optional[str] = None  # frontend base URL
+    sso_scope: str = "email"  # OAuth2 scope
+    sso_logout_redirect_url: Optional[str] = None  # e.g. http://192.168.92.61:9200/login
     # Number of concurrent connections for the embedded gateway.
     gateway_concurrency: int = 16
     gateway_plugin_server_url: Optional[str] = None

+ 77 - 0
gpustack/routes/auth.py

@@ -552,6 +552,12 @@ async def logout(request: Request):
         except Exception as e:
             logger.error(f"Failed to get SAML logout url: {str(e)}")
             external_logout_url = None
+
+    # SSO logout: return SSO platform logout URL
+    sso_logout_url = config.sso_logout_redirect_url
+    if sso_login and sso_logout_url:
+        external_logout_url = sso_logout_url
+
     sso_login = request.cookies.get(SSO_LOGIN_COOKIE_NAME)
     content = json.dumps({"logout_url": external_logout_url}) if sso_login else ""
     resp = Response(content=content, media_type="application/json")
@@ -632,3 +638,74 @@ def remove_initial_password_file_if_exists(config: Config):
             logger.debug(f"Initial password file deleted: {initial_password_file}")
         except Exception as e:
             logger.warning(f"Failed to delete initial password file: {e}")
+
+
+# SSO (LQAI-middle-platform) OAuth2 integration endpoints
+
+
+from gpustack.api.sso import (
+    build_sso_authorize_url,
+    handle_sso_exchange_code,
+)
+from pydantic import BaseModel
+
+
+class ExchangeCodeRequest(BaseModel):
+    code: str
+
+
+@router.get("/sso/authorize")
+async def sso_authorize(request: Request, redirect: bool = False):
+    """
+    Build SSO OAuth2 authorization URL.
+    If redirect=True, directly 302 redirect to SSO authorization page.
+    """
+    config: Config = request.app.state.server_config
+
+    if not config.sso_base_url or not config.sso_client_id:
+        raise InvalidException(message="SSO 未配置,请先配置 SSO_BASE_URL 和 SSO_CLIENT_ID")
+
+    authorize_url = build_sso_authorize_url(config)
+
+    if redirect:
+        return RedirectResponse(url=authorize_url)
+
+    return {
+        "code": "000000",
+        "message": "获取授权URL成功",
+        "data": {"authorize_url": authorize_url},
+    }
+
+
+@router.post("/oauth/exchange-code")
+async def oauth_exchange_code(
+    request: Request,
+    session: SessionDep,
+    body: ExchangeCodeRequest,
+):
+    """
+    Exchange SSO authorization code for local JWT.
+    Core SSO login endpoint.
+    """
+    config: Config = request.app.state.server_config
+
+    if not config.sso_base_url or not config.sso_client_id:
+        raise InvalidException(message="SSO 未配置")
+
+    if not body.code:
+        raise BadRequestException(message="缺少授权码")
+
+    try:
+        jwt_manager: JWTManager = request.app.state.jwt_manager
+        result = await handle_sso_exchange_code(session, config, body.code, jwt_manager)
+        return {
+            "code": "000000",
+            "message": "登录成功",
+            "data": result,
+        }
+    except Exception as e:
+        logger.error(f"SSO exchange failed: {e}")
+        error_msg = str(e)
+        if "invalid_grant" in error_msg or "授权码" in error_msg:
+            raise BadRequestException(message=f"登录失败: 授权码无效")
+        raise InvalidException(message=f"登录失败: {error_msg}")

+ 4 - 4
gpustack/server/catalog.py

@@ -121,10 +121,10 @@ def get_builtin_model_catalog_file() -> str:
     huggingface_url = "https://huggingface.co"
     modelscope_url = "https://modelscope.cn"
 
-    model_catalog_file_name = "model-catalog.yaml"
-    if not can_access(huggingface_url) and can_access(modelscope_url):
-        model_catalog_file_name = "model-catalog-modelscope.yaml"
-        logger.info(f"Cannot access {huggingface_url}, using ModelScope model catalog.")
+    model_catalog_file_name = "model-catalog-modelscope.yaml"
+    if not can_access(modelscope_url) and can_access(huggingface_url):
+        model_catalog_file_name = "model-catalog.yaml"
+        logger.info(f"Cannot access {modelscope_url}, using HuggingFace model catalog.")
 
     return str(pkg_resources.files("gpustack.assets").joinpath(model_catalog_file_name))