auth.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451
  1. import re
  2. import uuid
  3. from datetime import datetime, timezone
  4. import logging
  5. import aiohttp
  6. from aiocache import cached
  7. from fastapi import Depends, Request, WebSocket
  8. from starlette.datastructures import Headers
  9. from gpustack.config.config import Config
  10. from gpustack.schemas.config import GatewayModeEnum
  11. from gpustack.server.db import get_session, async_session
  12. from typing import Annotated, Optional, Tuple, List, Dict
  13. from fastapi.security import (
  14. APIKeyCookie,
  15. APIKeyHeader,
  16. HTTPAuthorizationCredentials,
  17. HTTPBasic,
  18. HTTPBasicCredentials,
  19. HTTPBearer,
  20. )
  21. from fastapi.security.utils import get_authorization_scheme_param
  22. from sqlmodel.ext.asyncio.session import AsyncSession
  23. from gpustack.api.exceptions import (
  24. ForbiddenException,
  25. InternalServerErrorException,
  26. UnauthorizedException,
  27. HTTPException,
  28. )
  29. from gpustack.schemas.api_keys import ApiKey, PermissionScope
  30. from gpustack.schemas.users import User, UserRole
  31. from gpustack.security import (
  32. JWTManager,
  33. verify_hashed_secret,
  34. get_key_pair,
  35. )
  36. from gpustack.server.services import APIKeyService, UserService, WorkerService
  37. from gpustack.websocket_proxy.authenticator import (
  38. Authenticator as WebsocketAuthenticator,
  39. )
  40. logger = logging.getLogger(__name__)
  41. SESSION_COOKIE_NAME = "gpustack_session"
  42. OIDC_ID_TOKEN_COOKIE_NAME = "gpustack_oidc_id_token"
  43. OIDC_STATE_COOKIE_NAME = "gpustack_oidc_state"
  44. SSO_LOGIN_COOKIE_NAME = "gpustack_sso_login"
  45. SYSTEM_USER_PREFIX = "system/"
  46. SYSTEM_WORKER_USER_PREFIX = "system/worker/"
  47. GATEWAY_AUTH_TOKEN_HEADER = "X-GPUStack-Auth-Token"
  48. basic_auth = HTTPBasic(auto_error=False)
  49. bearer_auth = HTTPBearer(auto_error=False)
  50. api_key_header_auth = APIKeyHeader(name="X-API-Key", auto_error=False)
  51. cookie_auth = APIKeyCookie(name=SESSION_COOKIE_NAME, auto_error=False)
  52. _gateway_auth_header = APIKeyHeader(name=GATEWAY_AUTH_TOKEN_HEADER, auto_error=False)
  53. credentials_exception = UnauthorizedException(
  54. message="Invalid authentication credentials"
  55. )
  56. def gateway_token_auth(
  57. request: Request,
  58. token: Annotated[Optional[str], Depends(_gateway_auth_header)] = None,
  59. ):
  60. if not token:
  61. raise UnauthorizedException(message="Missing authentication token")
  62. cfg: Config = request.app.state.server_config
  63. if token != cfg.get_derived_gateway_token():
  64. raise UnauthorizedException(message="Invalid gateway token")
  65. def client_ip_getter(request: Request) -> str:
  66. if request.app.state.server_config.gateway_mode == GatewayModeEnum.embedded:
  67. return request.headers.get("X-GPUStack-Real-IP", "")
  68. else:
  69. return request.client.host
  70. async def get_current_user(
  71. request: Request,
  72. session: Annotated[AsyncSession, Depends(get_session)],
  73. basic_credentials: Annotated[
  74. Optional[HTTPBasicCredentials], Depends(basic_auth)
  75. ] = None,
  76. bearer_token: Annotated[
  77. Optional[HTTPAuthorizationCredentials], Depends(bearer_auth)
  78. ] = None,
  79. x_api_key: Annotated[Optional[str], Depends(api_key_header_auth)] = None,
  80. cookie_token: Annotated[Optional[str], Depends(cookie_auth)] = None,
  81. ) -> User:
  82. if hasattr(request.state, "user"):
  83. user: User = getattr(request.state, "user")
  84. return user
  85. api_key: Optional[ApiKey] = None
  86. user = None
  87. try:
  88. server_config: Config = request.app.state.server_config
  89. if basic_credentials and is_system_user(basic_credentials.username):
  90. user = await authenticate_system_user(server_config, basic_credentials)
  91. elif basic_credentials:
  92. user = await authenticate_basic_user(session, basic_credentials)
  93. elif cookie_token:
  94. jwt_manager: JWTManager = request.app.state.jwt_manager
  95. user = await get_user_from_jwt_token(session, jwt_manager, cookie_token)
  96. elif bearer_token or x_api_key:
  97. token = (bearer_token.credentials if bearer_token else None) or x_api_key
  98. if token is not None:
  99. # Try API key first, fall back to JWT token validation
  100. user, api_key = await get_user_from_api_token(session, token)
  101. if user is None:
  102. # Bearer token might be a JWT (e.g. from SSO callback)
  103. jwt_manager: JWTManager = request.app.state.jwt_manager
  104. user = await get_user_from_jwt_token(session, jwt_manager, token)
  105. if user is None and client_ip_getter(request=request) == "127.0.0.1":
  106. if not server_config.force_auth_localhost:
  107. user = await User.first_by_field(session, "is_admin", True)
  108. if user:
  109. if not user.is_active:
  110. raise UnauthorizedException(message="User account is deactivated")
  111. request.state.user = user
  112. if api_key is not None:
  113. request.state.api_key = api_key
  114. return user
  115. except HTTPException:
  116. raise
  117. except Exception as e:
  118. raise InternalServerErrorException(message=f"Failed to authenticate user: {e}")
  119. raise credentials_exception
  120. async def get_admin_user(
  121. current_user: Annotated[User, Depends(get_current_user)],
  122. ) -> User:
  123. if not current_user.is_admin:
  124. raise ForbiddenException(message="No permission to access")
  125. return current_user
  126. async def get_cluster_user(
  127. current_user: Annotated[User, Depends(get_current_user)],
  128. ) -> User:
  129. if (
  130. current_user.is_system
  131. and current_user.role == UserRole.Cluster
  132. and current_user.cluster_id is not None
  133. ):
  134. return current_user
  135. return await get_admin_user(current_user)
  136. async def get_worker_user(
  137. current_user: Annotated[User, Depends(get_current_user)],
  138. ) -> User:
  139. if (
  140. current_user.is_system
  141. and current_user.role == UserRole.Worker
  142. and current_user.worker is not None
  143. ):
  144. return current_user
  145. return await get_admin_user(current_user)
  146. def is_system_user(username: str) -> bool:
  147. return username.startswith(SYSTEM_USER_PREFIX)
  148. async def authenticate_system_user(
  149. config: Config,
  150. credentials: HTTPBasicCredentials,
  151. ) -> Optional[User]:
  152. if credentials.username.startswith(SYSTEM_WORKER_USER_PREFIX):
  153. if credentials.password == config.token:
  154. return User(username=credentials.username, is_admin=True)
  155. return None
  156. async def authenticate_basic_user(
  157. session: AsyncSession,
  158. basic_credentials: HTTPBasicCredentials,
  159. ) -> Optional[User]:
  160. try:
  161. user = await authenticate_user(
  162. session, basic_credentials.username, basic_credentials.password
  163. )
  164. return user
  165. except Exception:
  166. return None
  167. def get_access_token(
  168. bearer_token: Optional[HTTPAuthorizationCredentials],
  169. oauth2_bearer_token: Optional[str],
  170. cookie_token: Optional[str],
  171. ) -> str:
  172. if bearer_token:
  173. return bearer_token.credentials
  174. elif oauth2_bearer_token:
  175. return oauth2_bearer_token
  176. elif cookie_token:
  177. return cookie_token
  178. else:
  179. raise credentials_exception
  180. async def get_user_from_jwt_token(
  181. session: AsyncSession, jwt_manager: JWTManager, access_token: str
  182. ) -> Optional[User]:
  183. try:
  184. payload = jwt_manager.decode_jwt_token(access_token)
  185. username = payload.get("sub")
  186. except Exception:
  187. logger.debug("Failed to decode JWT token")
  188. return None
  189. if username is None:
  190. return None
  191. try:
  192. user = await UserService(session).get_by_username(username)
  193. except Exception as e:
  194. raise InternalServerErrorException(message=f"Failed to get user: {e}")
  195. return user
  196. def parse_hyphen_uuid(value: str) -> Optional[str]:
  197. if not re.match(
  198. r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$', value, re.I
  199. ):
  200. return None
  201. try:
  202. uuid.UUID(value)
  203. return value
  204. except ValueError:
  205. return None
  206. async def get_user_from_api_token(
  207. session: AsyncSession, token: str
  208. ) -> Tuple[Optional[User], Optional[ApiKey]]:
  209. try:
  210. access_key, secret_key = get_key_pair(token)
  211. worker_uuid = parse_hyphen_uuid(access_key)
  212. if worker_uuid is not None:
  213. # if access_key is a valid uuid, it's legacy worker re-registering with legacy token
  214. access_key = ""
  215. access_keys = [access_key]
  216. # the custom key should have 32 chars access key as it is generated by security.custom_key_hash which will return 32 chars hex string.
  217. if len(access_key) == 32 and "" not in access_keys:
  218. # this means it is custom key or legacy worker token, we should also try to find api key with empty access key for backward compatibility
  219. access_keys.append("")
  220. for access_key in access_keys:
  221. api_key: ApiKey = await APIKeyService(session).get_by_access_key(access_key)
  222. if api_key:
  223. logger.trace(f"Found API key for access key: {access_key}")
  224. break
  225. if (
  226. api_key is not None
  227. and verify_hashed_secret(api_key.hashed_secret_key, secret_key)
  228. and (
  229. api_key.expires_at is None
  230. or api_key.expires_at > datetime.now(timezone.utc)
  231. )
  232. ):
  233. user: Optional[User] = await UserService(session).get_by_id(
  234. user_id=api_key.user_id,
  235. )
  236. if user is not None:
  237. return user, api_key
  238. except Exception as e:
  239. raise InternalServerErrorException(message=f"Failed to get user: {e}")
  240. return None, None
  241. async def authenticate_user(
  242. session: AsyncSession, username: str, password: str
  243. ) -> User:
  244. user = await UserService(session).get_by_username(username)
  245. if not user:
  246. raise UnauthorizedException(message="Incorrect username or password")
  247. if not verify_hashed_secret(user.hashed_password, password):
  248. raise UnauthorizedException(message="Incorrect username or password")
  249. if not user.is_active:
  250. raise UnauthorizedException(message="User account is deactivated")
  251. return user
  252. async def worker_auth(
  253. request: Request,
  254. bearer_token: Annotated[
  255. Optional[HTTPAuthorizationCredentials], Depends(bearer_auth)
  256. ] = None,
  257. x_api_key: Annotated[Optional[str], Depends(api_key_header_auth)] = None,
  258. ):
  259. token_value = (bearer_token.credentials if bearer_token else None) or x_api_key
  260. if not token_value:
  261. raise UnauthorizedException(message="Invalid authentication credentials")
  262. token = request.app.state.token
  263. config: Config = request.app.state.config
  264. registration_token = config.token
  265. server_url = config.get_server_url()
  266. if token_value in [token, registration_token]:
  267. return
  268. model_name = request.headers.get("X-Higress-Llm-Model")
  269. if model_name is not None:
  270. cred = token_value
  271. show_len = max(1, min(6, len(cred)))
  272. masked_token = f"{'*' * (len(cred) - show_len)}{cred[-show_len:]}"
  273. logger.debug(f"Verifying worker token {masked_token} via server authentication")
  274. cached_auth = make_auth_token_via_server(request.app.state.http_client_no_proxy)
  275. is_valid = await cached_auth(server_url, token_value, model_name)
  276. if is_valid:
  277. return
  278. raise UnauthorizedException(message="Invalid authentication credentials")
  279. def make_auth_token_via_server(client: aiohttp.ClientSession):
  280. @cached(ttl=60)
  281. async def inner(server_url: str, token: str, model_name: str) -> bool:
  282. auth_url = f"{server_url.rstrip('/')}/token-auth"
  283. headers = {
  284. "Authorization": f"Bearer {token}",
  285. "X-Higress-Llm-Model": model_name,
  286. }
  287. try:
  288. async with client.get(auth_url, headers=headers) as resp:
  289. return resp.status == 200
  290. except aiohttp.ClientError as e:
  291. logger.error(f"Error verifying token via server: {e}")
  292. return False
  293. return inner
  294. def get_scopes(
  295. request: Request, _current_user: Annotated[User, Depends(get_current_user)]
  296. ) -> List[str]:
  297. api_key: ApiKey = getattr(request.state, "api_key", None)
  298. if api_key is not None:
  299. return api_key.scope
  300. return [PermissionScope.ALL]
  301. def inference_scope(
  302. request: Request, _current_user: Annotated[User, Depends(get_current_user)]
  303. ):
  304. scopes = get_scopes(request, _current_user)
  305. if PermissionScope.ALL not in scopes and PermissionScope.INFERENCE not in scopes:
  306. raise ForbiddenException(
  307. message="API key does not have permission to access inference features"
  308. )
  309. def management_scope(
  310. request: Request, _current_user: Annotated[User, Depends(get_current_user)]
  311. ):
  312. scopes = get_scopes(request, _current_user)
  313. if PermissionScope.ALL not in scopes and PermissionScope.MANAGEMENT not in scopes:
  314. raise ForbiddenException(
  315. message="API key does not have permission to access management features"
  316. )
  317. async def authenticate_worker_by_request_headers(
  318. header_dict: Dict[str, str],
  319. validate_proxy: Optional[bool] = None,
  320. ) -> Optional[User]:
  321. """
  322. Authenticate a worker based on request headers, used for both WebSocket and non-WebSocket requests.
  323. For WebSocket requests, the Bearer token is expected in the "Authorization" header.
  324. For non-WebSocket requests (e.g. HTTP requests to the proxy), the Bearer token can be in either "Authorization"
  325. or "Proxy-Authorization" header, with "Proxy-Authorization" taking precedence if both are present.
  326. """
  327. headers = Headers(header_dict)
  328. authorization: Optional[str] = None
  329. if validate_proxy:
  330. authorization = headers.get("Proxy-Authorization")
  331. elif validate_proxy is not None:
  332. authorization = headers.get("Authorization")
  333. else:
  334. # if validate_proxy is None, it means we are in a context where both headers could be used (e.g. WebSocket connection from the proxy)
  335. # in this case we give precedence to Proxy-Authorization if it exists, otherwise fall back to Authorization
  336. authorization = headers.get("Proxy-Authorization") or headers.get(
  337. "Authorization"
  338. )
  339. async with async_session() as session:
  340. scheme, credentials = get_authorization_scheme_param(authorization)
  341. if not (authorization and scheme and credentials) or scheme.lower() != "bearer":
  342. return None
  343. bearer_token = HTTPAuthorizationCredentials(
  344. scheme=scheme, credentials=credentials
  345. )
  346. user, _ = await get_user_from_api_token(session, bearer_token.credentials)
  347. if user is None:
  348. return None
  349. if user.worker_id is not None:
  350. user.worker = await WorkerService(session).get_by_id(user.worker_id)
  351. return user
  352. class BearerTokenAuthenticator(WebsocketAuthenticator):
  353. """Websocket authenticator that verifies bearer tokens via the main server."""
  354. token: Optional[str]
  355. def __init__(
  356. self,
  357. token: Optional[str] = None,
  358. headers: Optional[Dict[str, str]] = None,
  359. ):
  360. self.token = token
  361. if not self.token and headers:
  362. parsed_headers = Headers(headers)
  363. self.token = parsed_headers.get("Authorization", "").replace("Bearer ", "")
  364. def inject_headers(
  365. self,
  366. headers: Dict[str, str],
  367. ) -> None:
  368. # No need to inject headers for outgoing connections from the proxy
  369. for key in list(headers.keys()):
  370. if key.lower() == "authorization":
  371. headers.pop(key)
  372. if self.token:
  373. headers.setdefault("Authorization", f"Bearer {self.token}")
  374. async def authenticate(self, websocket: WebSocket) -> bool:
  375. user = await authenticate_worker_by_request_headers(
  376. websocket.headers, validate_proxy=False
  377. )
  378. if user is None:
  379. return False
  380. if user.worker is None:
  381. logger.debug(
  382. f"Authenticated user {user.id} with bearer token but it is not associated with any worker"
  383. )
  384. return False
  385. if websocket.headers.get("x-client-id") != user.worker.worker_uuid:
  386. logger.debug(
  387. f"Authenticated worker {user.worker_id} with bearer token but client_id {websocket.headers.get('x-client-id')} does not match worker_uuid {user.worker.worker_uuid}"
  388. )
  389. return False
  390. websocket.scope["user"] = user
  391. return True