auth.py 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746
  1. import base64
  2. import json
  3. import os
  4. from pathlib import Path
  5. import secrets
  6. import httpx
  7. import logging
  8. import jwt
  9. from jwt.algorithms import RSAAlgorithm
  10. from gpustack.config.config import Config
  11. from typing import Annotated, Dict, Optional
  12. from fastapi import APIRouter, Form, Request, Response
  13. from gpustack.api.exceptions import (
  14. InvalidException,
  15. UnauthorizedException,
  16. BadRequestException,
  17. )
  18. from gpustack.schemas.users import UpdatePassword
  19. from gpustack.schemas.users import User, AuthProviderEnum
  20. from gpustack.security import (
  21. JWTManager,
  22. get_secret_hash,
  23. verify_hashed_secret,
  24. )
  25. from gpustack import envs
  26. from gpustack.api.auth import (
  27. SESSION_COOKIE_NAME,
  28. OIDC_ID_TOKEN_COOKIE_NAME,
  29. OIDC_STATE_COOKIE_NAME,
  30. SSO_LOGIN_COOKIE_NAME,
  31. authenticate_user,
  32. )
  33. from gpustack.server.deps import CurrentUserDep, SessionDep
  34. from gpustack.server.services import (
  35. create_user_with_principal,
  36. provision_user_principal,
  37. )
  38. from onelogin.saml2.auth import OneLogin_Saml2_Auth
  39. from fastapi.responses import RedirectResponse
  40. from gpustack.utils.convert import safe_b64decode, inflate_data
  41. from urllib.parse import urlencode
  42. from gpustack.utils.network import (
  43. get_system_trust_store_ssl_context,
  44. use_proxy_env_for_url,
  45. )
  46. router = APIRouter()
  47. timeout = httpx.Timeout(connect=15.0, read=60.0, write=60.0, pool=10.0)
  48. logger = logging.getLogger(__name__)
  49. async def decode_and_validate_token(
  50. client: httpx.AsyncClient, token: str, config: Config
  51. ) -> Dict:
  52. """
  53. Decode the JWT token without verification and check if required fields are present.
  54. Args:
  55. token: token from OIDC provider
  56. config: Application configuration
  57. Returns:
  58. Dictionary containing decoded token data
  59. """
  60. jwks_uri = config.openid_configuration["jwks_uri"]
  61. jwks_res = await client.get(jwks_uri)
  62. jwks = jwks_res.json()
  63. unverified_header = jwt.get_unverified_header(token)
  64. kid = unverified_header.get("kid", None)
  65. public_key = None
  66. if kid:
  67. for key in jwks['keys']:
  68. if key['kid'] == kid:
  69. public_key = RSAAlgorithm.from_jwk(json.dumps(key))
  70. break
  71. else:
  72. public_key = RSAAlgorithm.from_jwk(json.dumps(jwks['keys'][0]))
  73. if public_key is None:
  74. raise UnauthorizedException(message="Public key not found in JWKS")
  75. claims = jwt.decode(
  76. token,
  77. public_key,
  78. algorithms=['RS256'],
  79. options={"verify_aud": False, "verify_iss": False},
  80. )
  81. return claims
  82. async def get_oidc_user_data(
  83. client: httpx.AsyncClient, token_res, config: Config
  84. ) -> Dict:
  85. """
  86. Retrieve user data from OIDC token or userinfo endpoint.
  87. By default, it uses the userinfo endpoint (standard OIDC).
  88. If `oidc_skip_userinfo` is set to True in config, it retrieves data from the ID token.
  89. Args:
  90. client: HTTP client for making requests
  91. token_res: The token response from OIDC provider
  92. config: Application configuration
  93. Returns:
  94. Dictionary containing user data
  95. """
  96. user_data = None
  97. if not isinstance(token_res, Dict):
  98. raise InvalidException(message="Invalid token response")
  99. if config.oidc_skip_userinfo:
  100. tokens = []
  101. if access_token := token_res.get("access_token", None):
  102. tokens.append(access_token)
  103. if id_token := token_res.get("id_token", None):
  104. tokens.append(id_token)
  105. for token in tokens:
  106. try:
  107. user_data = await decode_and_validate_token(client, token, config)
  108. if user_data:
  109. break
  110. except Exception as e:
  111. logger.warning(f"Token decoding/validation failed: {str(e)}")
  112. else:
  113. token = token_res.get("access_token", "")
  114. userinfo_endpoint = config.openid_configuration["userinfo_endpoint"]
  115. headers = {'Authorization': f'Bearer {token}'}
  116. userinfo_res = await client.get(userinfo_endpoint, headers=headers)
  117. if userinfo_res.status_code == 200:
  118. user_data = userinfo_res.json()
  119. else:
  120. raise UnauthorizedException(
  121. message="Failed to fetch user info from userinfo endpoint"
  122. )
  123. if not user_data:
  124. raise UnauthorizedException(message="Failed to retrieve valid user data")
  125. return user_data
  126. async def init_saml_auth(request: Request):
  127. """
  128. Initialize SAML authentication configuration.
  129. """
  130. config: Config = request.app.state.server_config
  131. form_data = await request.form()
  132. form_dict = dict(form_data)
  133. saml_settings = {
  134. "strict": True,
  135. "sp": {
  136. "entityId": config.saml_sp_entity_id, # sp_entityId
  137. "assertionConsumerService": {
  138. "url": config.saml_sp_acs_url, # callback url
  139. "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
  140. },
  141. "singleLogoutService": {
  142. "url": config.saml_sp_slo_url,
  143. "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
  144. },
  145. "x509cert": config.saml_sp_x509_cert, # SP public key
  146. "privateKey": config.saml_sp_private_key, # sp privateKey
  147. },
  148. "idp": {
  149. "entityId": config.saml_idp_entity_id, # idp_entityId
  150. "singleSignOnService": {
  151. "url": config.saml_idp_server_url, # server url
  152. "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
  153. },
  154. "singleLogoutService": {
  155. "url": config.saml_idp_logout_url,
  156. "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
  157. },
  158. "x509cert": config.saml_idp_x509_cert, # idp public key
  159. },
  160. "security": json.loads(config.saml_security),
  161. } # Signature configuration
  162. req = {
  163. "http_host": request.client.host,
  164. "script_name": request.url.path,
  165. "get_data": dict(request.query_params),
  166. "post_data": form_dict,
  167. }
  168. return OneLogin_Saml2_Auth(req, saml_settings)
  169. # SAML login and callback endpoints
  170. @router.get("/saml/login")
  171. async def saml_login(request: Request):
  172. auth = await init_saml_auth(request)
  173. return RedirectResponse(url=auth.login())
  174. @router.api_route("/saml/callback", methods=["GET", "POST"])
  175. async def saml_callback(request: Request, session: SessionDep):
  176. logger.debug("Invoke saml callback.")
  177. try:
  178. auth = await init_saml_auth(request)
  179. if request.method == "GET":
  180. # For HTTP-Redirect binding, decode the deflated SAML response
  181. # and set it in the request data for process_response to handle.
  182. query = dict(request.query_params)
  183. SAMLResponse = query['SAMLResponse']
  184. decoded = safe_b64decode(SAMLResponse)
  185. xml_bytes = inflate_data(decoded)
  186. auth._request_data['post_data']['SAMLResponse'] = base64.b64encode(
  187. xml_bytes
  188. ).decode()
  189. else:
  190. # For HTTP-POST binding, SAMLResponse is base64-encoded XML.
  191. form_data = await request.form()
  192. SAMLResponse = dict(form_data).get('SAMLResponse')
  193. auth._request_data['post_data']['SAMLResponse'] = SAMLResponse
  194. # Use the SAML library to validate signatures, conditions, and prevent XXE.
  195. auth.process_response()
  196. errors = auth.get_errors()
  197. if errors:
  198. error_reason = auth.get_last_error_reason() or "Unknown error"
  199. raise Exception(f"SAML validation failed: {error_reason}")
  200. # Extract validated attributes from the SAML response.
  201. name_id = auth.get_nameid()
  202. saml_attributes = auth.get_attributes()
  203. attributes = {}
  204. attributes['name_id'] = name_id
  205. for key, values in saml_attributes.items():
  206. attributes[key] = values[0] if len(values) == 1 else values
  207. config: Config = request.app.state.server_config
  208. if config.external_auth_name:
  209. # If external_auth_name is set, use it as username.
  210. username = get_saml_attributes(
  211. config, attributes, config.external_auth_name
  212. )
  213. else:
  214. # Try email or name_id for username if external_auth_name is not set.
  215. for key in ["email", "emailaddress", "name_id", "nameidentifier"]:
  216. username = get_saml_attributes(config, attributes, key)
  217. if username:
  218. break
  219. else:
  220. raise Exception(message="No valid username found in saml attributes")
  221. if config.external_auth_full_name and '+' not in config.external_auth_full_name:
  222. # If external_auth_full_name is set, use it as user's full name.
  223. full_name = get_saml_attributes(
  224. config, attributes, config.external_auth_full_name
  225. )
  226. elif config.external_auth_full_name:
  227. # external_auth_full_name is set with concat symbol '+'.
  228. full_name = ' '.join(
  229. [
  230. get_saml_attributes(config, attributes, v.strip())
  231. for v in config.external_auth_full_name.split('+')
  232. ]
  233. )
  234. else:
  235. full_name = ""
  236. # Try common claims. These are not guaranteed to be present.
  237. for key in ["displayName", "name"]:
  238. full_name = get_saml_attributes(config, attributes, key)
  239. if full_name:
  240. break
  241. avatar_url = None
  242. if config.external_auth_avatar_url:
  243. avatar_url = get_saml_attributes(
  244. config, attributes, config.external_auth_avatar_url
  245. )
  246. # determine whether the user already exists
  247. user = await User.first_by_field(
  248. session=session, field="username", value=username
  249. )
  250. # create user
  251. if not user:
  252. user_info = User(
  253. username=username,
  254. full_name=full_name,
  255. avatar_url=avatar_url,
  256. hashed_password="",
  257. is_admin=False,
  258. is_active=not config.external_auth_default_inactive,
  259. source=AuthProviderEnum.SAML,
  260. require_password_change=False,
  261. )
  262. user = await create_user_with_principal(session, user_info)
  263. await session.commit()
  264. elif (
  265. getattr(user, "id", None) is not None
  266. and getattr(user, "principal_id", None) is None
  267. ):
  268. # Backfill for SSO users created before Personal Org
  269. # provisioning was wired in. Idempotent: only fires when the
  270. # user has no Personal Org pointer.
  271. await provision_user_principal(session, user)
  272. await session.commit()
  273. jwt_manager: JWTManager = request.app.state.jwt_manager
  274. access_token = jwt_manager.create_jwt_token(
  275. username=username,
  276. )
  277. response = RedirectResponse(url='/', status_code=303)
  278. response.set_cookie(
  279. key=SESSION_COOKIE_NAME,
  280. value=access_token,
  281. httponly=True,
  282. max_age=envs.JWT_TOKEN_EXPIRE_MINUTES * 60,
  283. expires=envs.JWT_TOKEN_EXPIRE_MINUTES * 60,
  284. )
  285. response.set_cookie(
  286. key=SSO_LOGIN_COOKIE_NAME,
  287. value="true",
  288. httponly=True,
  289. max_age=envs.JWT_TOKEN_EXPIRE_MINUTES * 60,
  290. expires=envs.JWT_TOKEN_EXPIRE_MINUTES * 60,
  291. )
  292. except Exception as e:
  293. logger.error(f"SAML callback error: {str(e)}")
  294. raise UnauthorizedException(message=str(e))
  295. return response
  296. @router.api_route("/saml/logout/callback", methods=["GET", "POST"])
  297. async def saml_logout_callback(request: Request):
  298. try:
  299. auth = await init_saml_auth(request)
  300. auth.process_slo(False)
  301. except Exception:
  302. pass
  303. response = RedirectResponse(url="/")
  304. response.delete_cookie(key=SESSION_COOKIE_NAME)
  305. response.delete_cookie(key=SSO_LOGIN_COOKIE_NAME)
  306. return response
  307. def get_saml_attributes(
  308. config: Config, attributes: Dict[str, str], name: str
  309. ) -> Optional[str]:
  310. search_keys = []
  311. if config.saml_sp_attribute_prefix:
  312. search_keys.append(config.saml_sp_attribute_prefix + name)
  313. search_keys.extend(
  314. [
  315. f"http://schemas.xmlsoap.org/ws/2005/05/identity/claims/{name}",
  316. name,
  317. ]
  318. )
  319. for key in search_keys:
  320. if key in attributes:
  321. return attributes[key]
  322. return None
  323. # OIDC login and callback endpoints
  324. @router.get("/oidc/login")
  325. async def oidc_login(request: Request):
  326. config: Config = request.app.state.server_config
  327. authorization_endpoint = config.openid_configuration["authorization_endpoint"]
  328. state = secrets.token_urlsafe(32)
  329. authUrl = (
  330. f'{authorization_endpoint}?response_type=code&'
  331. f'client_id={config.oidc_client_id}&'
  332. f'redirect_uri={config.oidc_redirect_uri}&'
  333. f'scope=openid profile email&state={state}'
  334. )
  335. response = RedirectResponse(url=authUrl)
  336. response.set_cookie(
  337. key=OIDC_STATE_COOKIE_NAME,
  338. value=state,
  339. httponly=True,
  340. secure=request.url.scheme == "https",
  341. samesite="lax",
  342. max_age=600,
  343. )
  344. return response
  345. @router.get("/oidc/callback")
  346. async def oidc_callback(request: Request, session: SessionDep):
  347. logger.debug("Invoke oidc callback.")
  348. config: Config = request.app.state.server_config
  349. query = dict(request.query_params)
  350. # Verify OIDC state to prevent CSRF attacks
  351. callback_state = query.get('state')
  352. cookie_state = request.cookies.get(OIDC_STATE_COOKIE_NAME)
  353. if not callback_state or not cookie_state:
  354. raise BadRequestException(message="Missing OIDC state parameter")
  355. if not secrets.compare_digest(callback_state, cookie_state):
  356. raise BadRequestException(message="OIDC state mismatch, possible CSRF attack")
  357. code = query['code']
  358. data = {
  359. "grant_type": "authorization_code",
  360. "code": code,
  361. "client_id": config.oidc_client_id,
  362. "client_secret": config.oidc_client_secret,
  363. "redirect_uri": config.oidc_redirect_uri,
  364. }
  365. token_endpoint = config.openid_configuration["token_endpoint"]
  366. use_proxy_env = use_proxy_env_for_url(token_endpoint)
  367. verify = get_system_trust_store_ssl_context()
  368. async with httpx.AsyncClient(
  369. timeout=timeout, verify=verify, trust_env=use_proxy_env
  370. ) as client:
  371. try:
  372. token_res = await client.request("POST", token_endpoint, data=data)
  373. res_data = json.loads(token_res.text)
  374. if token_res.status_code != 200:
  375. raise BadRequestException(
  376. message=f"Failed to get token, {res_data['error_description']}"
  377. )
  378. # Get user data from token or userinfo endpoint
  379. user_data = await get_oidc_user_data(client, res_data, config)
  380. if config.external_auth_name:
  381. # If external_auth_name is set, use it as username.
  382. username = user_data.get(config.external_auth_name)
  383. else:
  384. # Try common OIDC fields for username if external_auth_name is not set.
  385. # Ref: https://openid.net/specs/openid-connect-core-1_0.html#rfc.section.18.1.1
  386. for key in ["email", "sub"]:
  387. if key in user_data:
  388. username = user_data[key]
  389. break
  390. else:
  391. raise UnauthorizedException(
  392. message="No valid username found in user data"
  393. )
  394. if (
  395. config.external_auth_full_name
  396. and '+' not in config.external_auth_full_name
  397. ):
  398. full_name = user_data.get(config.external_auth_full_name)
  399. elif config.external_auth_full_name:
  400. full_name = ' '.join(
  401. [
  402. user_data.get(v.strip())
  403. for v in config.external_auth_full_name.split('+')
  404. ]
  405. )
  406. else:
  407. full_name = user_data.get("name", "")
  408. if config.external_auth_avatar_url:
  409. avatar_url = user_data.get(config.external_auth_avatar_url)
  410. else:
  411. avatar_url = user_data.get("picture", None)
  412. except Exception as e:
  413. logger.error(f"Get OIDC user info error: {str(e)}")
  414. raise UnauthorizedException(message=str(e))
  415. # determine whether the user already exists
  416. user = await User.first_by_field(session=session, field="username", value=username)
  417. # create user
  418. if not user:
  419. user_info = User(
  420. username=username,
  421. full_name=full_name,
  422. avatar_url=avatar_url,
  423. hashed_password="",
  424. is_admin=False,
  425. is_active=not config.external_auth_default_inactive,
  426. source=AuthProviderEnum.OIDC,
  427. require_password_change=False,
  428. )
  429. user = await create_user_with_principal(session, user_info)
  430. await session.commit()
  431. elif (
  432. getattr(user, "id", None) is not None
  433. and getattr(user, "principal_id", None) is None
  434. ):
  435. # Backfill for SSO users created before Personal Org
  436. # provisioning was wired in. Idempotent: only fires when the
  437. # user has no Personal Org pointer.
  438. await provision_user_principal(session, user)
  439. await session.commit()
  440. jwt_manager: JWTManager = request.app.state.jwt_manager
  441. access_token = jwt_manager.create_jwt_token(
  442. username=username,
  443. )
  444. response = RedirectResponse(url='/')
  445. response.set_cookie(
  446. key=SESSION_COOKIE_NAME,
  447. value=access_token,
  448. httponly=True,
  449. max_age=envs.JWT_TOKEN_EXPIRE_MINUTES * 60,
  450. expires=envs.JWT_TOKEN_EXPIRE_MINUTES * 60,
  451. )
  452. try:
  453. id_token = res_data.get("id_token")
  454. if id_token:
  455. response.set_cookie(
  456. key=OIDC_ID_TOKEN_COOKIE_NAME,
  457. value=id_token,
  458. httponly=True,
  459. max_age=envs.JWT_TOKEN_EXPIRE_MINUTES * 60,
  460. expires=envs.JWT_TOKEN_EXPIRE_MINUTES * 60,
  461. )
  462. response.set_cookie(
  463. key=SSO_LOGIN_COOKIE_NAME,
  464. value="true",
  465. httponly=True,
  466. max_age=envs.JWT_TOKEN_EXPIRE_MINUTES * 60,
  467. expires=envs.JWT_TOKEN_EXPIRE_MINUTES * 60,
  468. )
  469. except Exception as e:
  470. logger.warning(f"Failed to set id_token cookie: {str(e)}")
  471. response.delete_cookie(OIDC_STATE_COOKIE_NAME)
  472. return response
  473. # Local authentication endpoints
  474. @router.post("/login")
  475. async def login(
  476. request: Request,
  477. response: Response,
  478. session: SessionDep,
  479. username: Annotated[str, Form()] = "",
  480. password: Annotated[str, Form()] = "",
  481. ):
  482. user = await authenticate_user(session, username, password)
  483. user_name = user.username
  484. jwt_manager: JWTManager = request.app.state.jwt_manager
  485. access_token = jwt_manager.create_jwt_token(
  486. username=user_name,
  487. )
  488. response.set_cookie(
  489. key=SESSION_COOKIE_NAME,
  490. value=access_token,
  491. httponly=True,
  492. max_age=envs.JWT_TOKEN_EXPIRE_MINUTES * 60,
  493. expires=envs.JWT_TOKEN_EXPIRE_MINUTES * 60,
  494. )
  495. @router.post("/logout")
  496. async def logout(request: Request):
  497. config: Config = request.app.state.server_config
  498. external_logout_url = None
  499. sso_login = None # Ensure initialized before any conditional path
  500. if (
  501. config.external_auth_type == AuthProviderEnum.OIDC
  502. and config.openid_configuration
  503. ):
  504. end_session_endpoint = config.openid_configuration.get("end_session_endpoint")
  505. if end_session_endpoint:
  506. redirect_uri = str(config.server_external_url or request.base_url)
  507. params = {
  508. "client_id": config.oidc_client_id,
  509. "post_logout_redirect_uri": redirect_uri,
  510. "id_token_hint": request.cookies.get(OIDC_ID_TOKEN_COOKIE_NAME),
  511. }
  512. if config.external_auth_post_logout_redirect_key:
  513. params[config.external_auth_post_logout_redirect_key] = redirect_uri
  514. query = urlencode({k: v for k, v in params.items() if v})
  515. external_logout_url = (
  516. end_session_endpoint if not query else f"{end_session_endpoint}?{query}"
  517. )
  518. elif config.external_auth_type == AuthProviderEnum.SAML:
  519. try:
  520. auth = await init_saml_auth(request)
  521. redirect_uri = str(config.server_external_url or request.base_url)
  522. params = {}
  523. if config.external_auth_post_logout_redirect_key:
  524. params[config.external_auth_post_logout_redirect_key] = redirect_uri
  525. external_logout_url = auth.logout(return_to=redirect_uri)
  526. query = urlencode({k: v for k, v in params.items() if v})
  527. if query:
  528. external_logout_url += f"&{query}"
  529. except Exception as e:
  530. logger.error(f"Failed to get SAML logout url: {str(e)}")
  531. external_logout_url = None
  532. # SSO logout: return SSO platform logout URL
  533. sso_login = request.cookies.get(SSO_LOGIN_COOKIE_NAME)
  534. sso_logout_url = config.sso_logout_redirect_url
  535. if sso_login and sso_logout_url:
  536. external_logout_url = sso_logout_url
  537. content = json.dumps({"logout_url": external_logout_url}) if sso_login else ""
  538. resp = Response(content=content, media_type="application/json")
  539. resp.delete_cookie(key=SESSION_COOKIE_NAME)
  540. resp.delete_cookie(key=OIDC_ID_TOKEN_COOKIE_NAME)
  541. resp.delete_cookie(key=SSO_LOGIN_COOKIE_NAME)
  542. return resp
  543. @router.post("/update-password")
  544. async def update_password(
  545. request: Request,
  546. session: SessionDep,
  547. user: CurrentUserDep,
  548. update_in: UpdatePassword,
  549. ):
  550. if not verify_hashed_secret(user.hashed_password, update_in.current_password):
  551. raise InvalidException(message="Incorrect current password")
  552. hashed_password = get_secret_hash(update_in.new_password)
  553. patch = {"hashed_password": hashed_password, "require_password_change": False}
  554. await user.update(session, patch)
  555. remove_initial_password_file_if_exists(request.app.state.server_config)
  556. @router.get("/config")
  557. async def get_auth_config(request: Request):
  558. req_dict = {}
  559. config: Config = request.app.state.server_config
  560. auth_type = (config.external_auth_type or "Local").lower()
  561. if auth_type == "oidc":
  562. req_dict = {"is_oidc": True, "is_saml": False}
  563. elif auth_type == "saml":
  564. req_dict = {"is_oidc": False, "is_saml": True}
  565. initial_password_file = Path(config.data_dir) / "initial_admin_password"
  566. if initial_password_file.exists():
  567. req_dict["first_time_setup"] = True
  568. req_dict["get_initial_password_command"] = _get_initial_password_command(
  569. initial_password_file
  570. )
  571. return req_dict
  572. def _get_initial_password_command(initial_password_file: Path) -> str:
  573. """
  574. Get the command to retrieve the initial admin password.
  575. """
  576. if os.getenv("KUBERNETES_SERVICE_HOST") is not None:
  577. # Kubernetes
  578. pod_name = os.getenv("HOSTNAME", "<pod_name>")
  579. namespace_file = Path("/var/run/secrets/kubernetes.io/serviceaccount/namespace")
  580. namespace = (
  581. namespace_file.read_text().strip()
  582. if namespace_file.exists()
  583. else "<namespace>"
  584. )
  585. return f"kubectl exec {pod_name} -n {namespace} -- cat {initial_password_file}"
  586. elif Path("/.dockerenv").exists():
  587. # Docker
  588. return f"docker exec <container_name_or_id> cat {initial_password_file}"
  589. else:
  590. # Non-containerized
  591. return f"cat {initial_password_file}"
  592. def remove_initial_password_file_if_exists(config: Config):
  593. """
  594. Remove the initial admin password file if it exists.
  595. """
  596. initial_password_file = Path(config.data_dir) / "initial_admin_password"
  597. if initial_password_file.exists():
  598. try:
  599. initial_password_file.unlink()
  600. logger.debug(f"Initial password file deleted: {initial_password_file}")
  601. except Exception as e:
  602. logger.warning(f"Failed to delete initial password file: {e}")
  603. # SSO (LQAI-middle-platform) OAuth2 integration endpoints
  604. from gpustack.api.sso import (
  605. build_sso_authorize_url,
  606. handle_sso_exchange_code,
  607. )
  608. from pydantic import BaseModel
class ExchangeCodeRequest(BaseModel):
    # Request body for POST /oauth/exchange-code.
    # Authorization code obtained from the SSO provider, to be exchanged
    # for a local JWT.
    code: str
  611. @router.get("/sso/authorize")
  612. async def sso_authorize(request: Request, redirect: bool = False):
  613. """
  614. Build SSO OAuth2 authorization URL.
  615. If redirect=True, directly 302 redirect to SSO authorization page.
  616. """
  617. config: Config = request.app.state.server_config
  618. if not config.sso_base_url or not config.sso_client_id:
  619. raise InvalidException(message="SSO 未配置,请先配置 SSO_BASE_URL 和 SSO_CLIENT_ID")
  620. authorize_url = build_sso_authorize_url(config)
  621. if redirect:
  622. return RedirectResponse(url=authorize_url)
  623. return {
  624. "code": "000000",
  625. "message": "获取授权URL成功",
  626. "data": {"authorize_url": authorize_url},
  627. }
  628. @router.post("/oauth/exchange-code")
  629. async def oauth_exchange_code(
  630. request: Request,
  631. session: SessionDep,
  632. body: ExchangeCodeRequest,
  633. ):
  634. """
  635. Exchange SSO authorization code for local JWT.
  636. Core SSO login endpoint.
  637. """
  638. config: Config = request.app.state.server_config
  639. if not config.sso_base_url or not config.sso_client_id:
  640. raise InvalidException(message="SSO 未配置")
  641. if not body.code:
  642. raise BadRequestException(message="缺少授权码")
  643. try:
  644. jwt_manager: JWTManager = request.app.state.jwt_manager
  645. result = await handle_sso_exchange_code(session, config, body.code, jwt_manager)
  646. return {
  647. "code": "000000",
  648. "message": "登录成功",
  649. "data": result,
  650. }
  651. except Exception as e:
  652. logger.error(f"SSO exchange failed: {e}")
  653. error_msg = str(e)
  654. if "invalid_grant" in error_msg or "授权码" in error_msg:
  655. raise BadRequestException(message=f"登录失败: 授权码无效")
  656. raise InvalidException(message=f"登录失败: {error_msg}")