auth.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. import re
  2. import uuid
  3. from datetime import datetime, timezone
  4. import logging
  5. import aiohttp
  6. from aiocache import cached
  7. from fastapi import Depends, Request, WebSocket
  8. from starlette.datastructures import Headers
  9. from gpustack.config.config import Config
  10. from gpustack.schemas.config import GatewayModeEnum
  11. from gpustack.server.db import get_session, async_session
  12. from typing import Annotated, Optional, Tuple, List, Dict
  13. from fastapi.security import (
  14. APIKeyCookie,
  15. APIKeyHeader,
  16. HTTPAuthorizationCredentials,
  17. HTTPBasic,
  18. HTTPBasicCredentials,
  19. HTTPBearer,
  20. )
  21. from fastapi.security.utils import get_authorization_scheme_param
  22. from sqlmodel.ext.asyncio.session import AsyncSession
  23. from gpustack.api.exceptions import (
  24. ForbiddenException,
  25. InternalServerErrorException,
  26. UnauthorizedException,
  27. HTTPException,
  28. )
  29. from gpustack.schemas.api_keys import ApiKey, PermissionScope
  30. from gpustack.schemas.users import User, UserRole
  31. from gpustack.security import (
  32. JWTManager,
  33. verify_hashed_secret,
  34. get_key_pair,
  35. )
  36. from gpustack.server.services import APIKeyService, UserService, WorkerService
  37. from gpustack.websocket_proxy.authenticator import (
  38. Authenticator as WebsocketAuthenticator,
  39. )
  # Module-level logger used by the authentication helpers below.
  logger = logging.getLogger(__name__)

  # Cookie names. SESSION_COOKIE_NAME is the cookie read by `cookie_auth`.
  SESSION_COOKIE_NAME = "gpustack_session"
  OIDC_ID_TOKEN_COOKIE_NAME = "gpustack_oidc_id_token"
  SSO_LOGIN_COOKIE_NAME = "gpustack_sso_login"

  # Username prefixes that mark system (non-human) accounts.
  SYSTEM_USER_PREFIX = "system/"
  SYSTEM_WORKER_USER_PREFIX = "system/worker/"

  # Header checked by `gateway_token_auth`.
  GATEWAY_AUTH_TOKEN_HEADER = "X-GPUStack-Auth-Token"

  # Credential extractors. All use auto_error=False, so the dependants in this
  # module receive None for missing credentials and decide themselves how to fail.
  basic_auth = HTTPBasic(auto_error=False)
  bearer_auth = HTTPBearer(auto_error=False)
  api_key_header_auth = APIKeyHeader(name="X-API-Key", auto_error=False)
  cookie_auth = APIKeyCookie(name=SESSION_COOKIE_NAME, auto_error=False)
  _gateway_auth_header = APIKeyHeader(name=GATEWAY_AUTH_TOKEN_HEADER, auto_error=False)

  # Shared exception raised when no credential resolves to a user.
  credentials_exception = UnauthorizedException(
      message="Invalid authentication credentials"
  )
  def gateway_token_auth(
      request: Request,
      token: Annotated[Optional[str], Depends(_gateway_auth_header)] = None,
  ):
      """Reject the request unless it carries the derived gateway token.

      Args:
          request: Incoming request; ``request.app.state.server_config`` supplies
              the expected token via ``get_derived_gateway_token()``.
          token: Value of the ``X-GPUStack-Auth-Token`` header, or None when absent.

      Raises:
          UnauthorizedException: If the header is missing/empty or does not match
              the derived gateway token.
      """
      # Local import keeps this block self-contained.
      import hmac

      if not token:
          raise UnauthorizedException(message="Missing authentication token")

      cfg: Config = request.app.state.server_config
      expected = cfg.get_derived_gateway_token()

      # Compare in constant time so response timing does not reveal how many
      # leading characters of the secret were guessed correctly. Encoding to
      # bytes is required because compare_digest rejects non-ASCII str input.
      # A non-str expected value can never match, exactly as with `!=`.
      if not isinstance(expected, str) or not hmac.compare_digest(
          token.encode("utf-8"), expected.encode("utf-8")
      ):
          raise UnauthorizedException(message="Invalid gateway token")
  def client_ip_getter(request: Request) -> str:
      """Return the client IP address for the request.

      In embedded gateway mode the address is read from the
      ``X-GPUStack-Real-IP`` header (empty string when the header is absent).
      Otherwise the address of the directly connected peer is used.

      Returns:
          The client IP as a string, or "" when it cannot be determined.
      """
      if request.app.state.server_config.gateway_mode == GatewayModeEnum.embedded:
          return request.headers.get("X-GPUStack-Real-IP", "")

      # `request.client` is None when the ASGI server supplies no peer address;
      # fall back to "" instead of raising AttributeError.
      peer = request.client
      return peer.host if peer is not None else ""
  async def get_current_user(
      request: Request,
      session: Annotated[AsyncSession, Depends(get_session)],
      basic_credentials: Annotated[
          Optional[HTTPBasicCredentials], Depends(basic_auth)
      ] = None,
      bearer_token: Annotated[
          Optional[HTTPAuthorizationCredentials], Depends(bearer_auth)
      ] = None,
      x_api_key: Annotated[Optional[str], Depends(api_key_header_auth)] = None,
      cookie_token: Annotated[Optional[str], Depends(cookie_auth)] = None,
  ) -> User:
      """Resolve the authenticated user for the current request.

      Credentials are tried in this order, and only the first kind present is
      used: HTTP basic (system users are routed to ``authenticate_system_user``),
      the session cookie (JWT), then a bearer token or ``X-API-Key`` value.

      If no user was resolved, the client IP is exactly "127.0.0.1" and
      ``force_auth_localhost`` is falsy, the first user with ``is_admin`` set is
      used instead.

      On success the user (and the matching API key, if any) is stored on
      ``request.state`` so later dependencies in the same request reuse it.

      Raises:
          UnauthorizedException: No user could be resolved, or the user is inactive.
          InternalServerErrorException: Any unexpected non-HTTP error.
      """
      # A user resolved earlier in this request is reused as-is.
      if hasattr(request.state, "user"):
          return request.state.user

      resolved: Optional[User] = None
      matched_key: Optional[ApiKey] = None
      try:
          server_config: Config = request.app.state.server_config

          if basic_credentials:
              if is_system_user(basic_credentials.username):
                  resolved = await authenticate_system_user(
                      server_config, basic_credentials
                  )
              else:
                  resolved = await authenticate_basic_user(session, basic_credentials)
          elif cookie_token:
              resolved = await get_user_from_jwt_token(
                  session, request.app.state.jwt_manager, cookie_token
              )
          elif bearer_token or x_api_key:
              token = (bearer_token.credentials if bearer_token else None) or x_api_key
              if token is not None:
                  resolved, matched_key = await get_user_from_api_token(session, token)

          # Loopback fallback; the conditions are evaluated in the original order.
          if (
              resolved is None
              and client_ip_getter(request=request) == "127.0.0.1"
              and not server_config.force_auth_localhost
          ):
              resolved = await User.first_by_field(session, "is_admin", True)

          if resolved:
              if not resolved.is_active:
                  raise UnauthorizedException(message="User account is deactivated")
              request.state.user = resolved
              if matched_key is not None:
                  request.state.api_key = matched_key
              return resolved
      except HTTPException:
          # HTTP errors already carry the intended status; pass them through.
          raise
      except Exception as e:
          raise InternalServerErrorException(message=f"Failed to authenticate user: {e}")

      raise credentials_exception
  async def get_admin_user(
      current_user: Annotated[User, Depends(get_current_user)],
  ) -> User:
      """Return the current user if it is an admin.

      Raises:
          ForbiddenException: The user does not have ``is_admin`` set.
      """
      if current_user.is_admin:
          return current_user
      raise ForbiddenException(message="No permission to access")
  async def get_cluster_user(
      current_user: Annotated[User, Depends(get_current_user)],
  ) -> User:
      """Return the current user if it is a cluster system user or an admin.

      A user qualifies directly when it is a system user with the Cluster role
      and a non-None ``cluster_id``; every other user must pass the admin check.

      Raises:
          ForbiddenException: Raised by ``get_admin_user`` for non-admin users.
      """
      if (
          not current_user.is_system
          or current_user.role != UserRole.Cluster
          or current_user.cluster_id is None
      ):
          # Not a cluster system user: defer to the admin check.
          return await get_admin_user(current_user)
      return current_user
  async def get_worker_user(
      current_user: Annotated[User, Depends(get_current_user)],
  ) -> User:
      """Return the current user if it is a worker system user or an admin.

      A user qualifies directly when it is a system user with the Worker role
      and a non-None ``worker``; every other user must pass the admin check.

      Raises:
          ForbiddenException: Raised by ``get_admin_user`` for non-admin users.
      """
      if (
          not current_user.is_system
          or current_user.role != UserRole.Worker
          or current_user.worker is None
      ):
          # Not a worker system user: defer to the admin check.
          return await get_admin_user(current_user)
      return current_user
  def is_system_user(username: str) -> bool:
      """Return True when ``username`` begins with ``SYSTEM_USER_PREFIX``."""
      has_system_prefix = username.startswith(SYSTEM_USER_PREFIX)
      return has_system_prefix
  async def authenticate_system_user(
      config: Config,
      credentials: HTTPBasicCredentials,
  ) -> Optional[User]:
      """Authenticate a system worker account presented via HTTP basic auth.

      Only usernames starting with ``SYSTEM_WORKER_USER_PREFIX`` are accepted,
      and the password must equal ``config.token``.

      Returns:
          A new, unsaved ``User`` with ``is_admin=True`` on success, else None.
      """
      # Local import keeps this block self-contained.
      import hmac

      if not credentials.username.startswith(SYSTEM_WORKER_USER_PREFIX):
          return None

      expected = config.token
      # Constant-time comparison so response timing does not leak how much of
      # the token matched. A non-str token (e.g. None) can never match, which
      # is the same outcome the plain `==` produced.
      if isinstance(expected, str) and hmac.compare_digest(
          credentials.password.encode("utf-8"), expected.encode("utf-8")
      ):
          return User(username=credentials.username, is_admin=True)
      return None
  async def authenticate_basic_user(
      session: AsyncSession,
      basic_credentials: HTTPBasicCredentials,
  ) -> Optional[User]:
      """Authenticate a regular user from HTTP basic credentials.

      Wraps ``authenticate_user`` so that a failed login yields None instead of
      an exception.

      Returns:
          The authenticated ``User``, or None when authentication fails for any
          reason.
      """
      try:
          return await authenticate_user(
              session, basic_credentials.username, basic_credentials.password
          )
      except UnauthorizedException:
          # Expected outcome for wrong credentials or a deactivated account.
          return None
      except Exception:
          # Still best-effort (the caller gets None), but an unexpected failure
          # such as a backend error is no longer hidden: record it with traceback.
          logger.exception("Unexpected error during basic authentication")
          return None
  def get_access_token(
      bearer_token: Optional[HTTPAuthorizationCredentials],
      oauth2_bearer_token: Optional[str],
      cookie_token: Optional[str],
  ) -> str:
      """Pick the access token from the available credential sources.

      Precedence: bearer credentials, then the OAuth2 bearer token, then the
      cookie token.

      Raises:
          UnauthorizedException: None of the three sources is present.
      """
      if bearer_token:
          return bearer_token.credentials

      # First truthy value of the two remaining sources, if any.
      fallback = oauth2_bearer_token or cookie_token
      if fallback:
          return fallback
      raise credentials_exception
  async def get_user_from_jwt_token(
      session: AsyncSession, jwt_manager: JWTManager, access_token: str
  ) -> Optional[User]:
      """Look up the user named by the ``sub`` claim of a JWT.

      Returns:
          The matching user, or None when the token cannot be decoded or has no
          ``sub`` claim. The user-service result is returned unchanged.

      Raises:
          InternalServerErrorException: The user lookup itself failed.
      """
      try:
          username = jwt_manager.decode_jwt_token(access_token).get("sub")
      except Exception:
          # Any decode failure is treated as "no user" rather than an error.
          logger.debug("Failed to decode JWT token")
          return None

      if username is None:
          return None

      try:
          return await UserService(session).get_by_username(username)
      except Exception as e:
          raise InternalServerErrorException(message=f"Failed to get user: {e}")
  def parse_hyphen_uuid(value: str) -> Optional[str]:
      """Return ``value`` if it is a hyphenated 8-4-4-4-12 hex UUID, else None.

      Matching is case-insensitive and the input is returned unchanged (no
      case normalisation). Un-hyphenated or otherwise formatted UUIDs are
      rejected.
      """
      hyphenated = re.match(
          r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$', value, re.I
      )
      if hyphenated is None:
          return None

      # Second opinion from the stdlib parser; it also rejects inputs the
      # pattern lets through because `$` matches before a trailing newline.
      try:
          uuid.UUID(value)
      except ValueError:
          return None
      return value
  async def get_user_from_api_token(
      session: AsyncSession, token: str
  ) -> Tuple[Optional[User], Optional[ApiKey]]:
      """Resolve the user and API key that an API token belongs to.

      The token is split into an access key and a secret key. The API key
      record is looked up by access key, then accepted only if the secret
      verifies against its stored hash and the key has not expired.

      Returns:
          ``(user, api_key)`` on success, otherwise ``(None, None)``.

      Raises:
          InternalServerErrorException: Any error raised while resolving.
      """
      try:
          access_key, secret_key = get_key_pair(token)

          # A UUID-shaped access key is a legacy worker re-registering with a
          # legacy token; such keys are stored under an empty access key.
          if parse_hyphen_uuid(access_key) is not None:
              access_key = ""

          # Custom keys have a 32-char access key (produced by
          # security.custom_key_hash). For those, also try the empty access key
          # for backward compatibility.
          candidates = [access_key, ""] if len(access_key) == 32 else [access_key]

          api_key: Optional[ApiKey] = None
          for candidate in candidates:
              api_key = await APIKeyService(session).get_by_access_key(candidate)
              if api_key:
                  # NOTE(review): `trace` is not a stdlib Logger method;
                  # presumably installed by the project's logging setup - confirm.
                  logger.trace(f"Found API key for access key: {candidate}")
                  break

          if api_key is None:
              return None, None
          if not verify_hashed_secret(api_key.hashed_secret_key, secret_key):
              return None, None

          expiry = api_key.expires_at
          if expiry is not None and not (expiry > datetime.now(timezone.utc)):
              # The key has expired.
              return None, None

          owner: Optional[User] = await UserService(session).get_by_id(
              user_id=api_key.user_id,
          )
          if owner is not None:
              return owner, api_key
      except Exception as e:
          raise InternalServerErrorException(message=f"Failed to get user: {e}")
      return None, None
  async def authenticate_user(
      session: AsyncSession, username: str, password: str
  ) -> User:
      """Verify a username/password pair and return the matching active user.

      Raises:
          UnauthorizedException: Unknown username or wrong password (both give
              the same message), or the account is deactivated.
      """
      account = await UserService(session).get_by_username(username)

      # One shared failure path, so the message does not reveal which of the
      # two checks failed.
      if not account or not verify_hashed_secret(account.hashed_password, password):
          raise UnauthorizedException(message="Incorrect username or password")

      if not account.is_active:
          raise UnauthorizedException(message="User account is deactivated")
      return account
  async def worker_auth(
      request: Request,
      bearer_token: Annotated[
          Optional[HTTPAuthorizationCredentials], Depends(bearer_auth)
      ] = None,
      x_api_key: Annotated[Optional[str], Depends(api_key_header_auth)] = None,
  ):
      """Authorize a request using a bearer token or ``X-API-Key`` value.

      The presented token is accepted when it equals ``request.app.state.token``
      or ``config.token``. Otherwise, if the request has an
      ``X-Higress-Llm-Model`` header, the token is verified against the server's
      ``/token-auth`` endpoint through ``make_auth_token_via_server``.

      Raises:
          UnauthorizedException: No token was supplied or none of the checks passed.
      """
      # Local import keeps this block self-contained.
      import hmac

      token_value = (bearer_token.credentials if bearer_token else None) or x_api_key
      if not token_value:
          raise UnauthorizedException(message="Invalid authentication credentials")

      token = request.app.state.token
      config: Config = request.app.state.config
      registration_token = config.token
      server_url = config.get_server_url()

      # Constant-time comparison against both known secrets. Both comparisons
      # always run (the list is built before `any`), so timing does not reveal
      # which secret matched. Non-str values can never match, as with `in`.
      presented = token_value.encode("utf-8")
      matches = [
          isinstance(known, str)
          and hmac.compare_digest(presented, known.encode("utf-8"))
          for known in (token, registration_token)
      ]
      if any(matches):
          return

      model_name = request.headers.get("X-Higress-Llm-Model")
      if model_name is not None:
          cred = token_value
          # Log only the trailing characters (at most six) of the credential.
          show_len = max(1, min(6, len(cred)))
          masked_token = f"{'*' * (len(cred) - show_len)}{cred[-show_len:]}"
          logger.debug(f"Verifying worker token {masked_token} via server authentication")
          cached_auth = make_auth_token_via_server(request.app.state.http_client_no_proxy)
          is_valid = await cached_auth(server_url, token_value, model_name)
          if is_valid:
              return

      raise UnauthorizedException(message="Invalid authentication credentials")
  def make_auth_token_via_server(client: aiohttp.ClientSession):
      """Build a cached coroutine that validates a token against the server.

      The returned coroutine sends a GET to ``<server_url>/token-auth`` with the
      token as a bearer credential plus the model name header, and reports
      whether the response status was 200. It is wrapped in ``cached(ttl=60)``.

      Args:
          client: HTTP session used for the verification request.
      """

      # NOTE(review): the inner function keeps the name `inner` because the
      # cache decorator may derive keys from the function name - confirm before
      # renaming.
      @cached(ttl=60)
      async def inner(server_url: str, token: str, model_name: str) -> bool:
          endpoint = f"{server_url.rstrip('/')}/token-auth"
          request_headers = {
              "Authorization": f"Bearer {token}",
              "X-Higress-Llm-Model": model_name,
          }
          try:
              async with client.get(endpoint, headers=request_headers) as resp:
                  verified = resp.status == 200
          except aiohttp.ClientError as e:
              # Client/transport errors count as "not verified".
              logger.error(f"Error verifying token via server: {e}")
              return False
          return verified

      return inner
  def get_scopes(
      request: Request, _current_user: Annotated[User, Depends(get_current_user)]
  ) -> List[str]:
      """Return the permission scopes that apply to the current request.

      When ``request.state`` holds an API key its ``scope`` is returned;
      otherwise the request gets ``[PermissionScope.ALL]``.
      """
      key: Optional[ApiKey] = getattr(request.state, "api_key", None)
      return [PermissionScope.ALL] if key is None else key.scope
  def inference_scope(
      request: Request, _current_user: Annotated[User, Depends(get_current_user)]
  ):
      """Require the ALL or INFERENCE permission scope for this request.

      Raises:
          ForbiddenException: Neither scope is granted.
      """
      granted = get_scopes(request, _current_user)
      if PermissionScope.ALL in granted or PermissionScope.INFERENCE in granted:
          return
      raise ForbiddenException(
          message="API key does not have permission to access inference features"
      )
  def management_scope(
      request: Request, _current_user: Annotated[User, Depends(get_current_user)]
  ):
      """Require the ALL or MANAGEMENT permission scope for this request.

      Raises:
          ForbiddenException: Neither scope is granted.
      """
      granted = get_scopes(request, _current_user)
      if PermissionScope.ALL in granted or PermissionScope.MANAGEMENT in granted:
          return
      raise ForbiddenException(
          message="API key does not have permission to access management features"
      )
  async def authenticate_worker_by_request_headers(
      header_dict: Dict[str, str],
      validate_proxy: Optional[bool] = None,
  ) -> Optional[User]:
      """Authenticate a caller from the bearer token in its request headers.

      Which header supplies the token depends on ``validate_proxy``:

      * ``True``  - only ``Proxy-Authorization`` is read.
      * ``False`` - only ``Authorization`` is read.
      * ``None``  - ``Proxy-Authorization`` is preferred, falling back to
        ``Authorization`` when it is missing or empty.

      Args:
          header_dict: Request headers; lookup is case-insensitive.
          validate_proxy: Header selection mode, see above.

      Returns:
          The user owning the API token, with ``user.worker`` loaded when the
          user has a ``worker_id``. None when the header is missing, is not a
          bearer credential, or the token does not resolve to a user.
      """
      headers = Headers(header_dict)
      authorization: Optional[str]
      if validate_proxy:
          authorization = headers.get("Proxy-Authorization")
      elif validate_proxy is not None:
          authorization = headers.get("Authorization")
      else:
          authorization = headers.get("Proxy-Authorization") or headers.get(
              "Authorization"
          )

      # Check the header before opening a database session, so requests
      # without a usable bearer credential never open one.
      scheme, credentials = get_authorization_scheme_param(authorization)
      if not (authorization and scheme and credentials) or scheme.lower() != "bearer":
          return None

      async with async_session() as session:
          user, _ = await get_user_from_api_token(session, credentials)
          if user is None:
              return None
          if user.worker_id is not None:
              # Load the worker while the session is still open.
              user.worker = await WorkerService(session).get_by_id(user.worker_id)
          return user
  class BearerTokenAuthenticator(WebsocketAuthenticator):
      """Websocket authenticator based on bearer tokens.

      Outgoing: ``inject_headers`` sets ``Authorization: Bearer <token>``.
      Incoming: ``authenticate`` resolves the bearer token to a worker user and
      checks that the ``x-client-id`` header matches that worker's UUID.
      """

      # Token injected into outgoing connections; falsy means nothing is injected.
      token: Optional[str]

      def __init__(
          self,
          token: Optional[str] = None,
          headers: Optional[Dict[str, str]] = None,
      ):
          """Store the token, or derive it from an ``Authorization`` header.

          Args:
              token: Explicit bearer token; takes precedence when truthy.
              headers: Headers from which the token is taken when ``token`` is
                  not given.
          """
          self.token = token
          if not self.token and headers:
              raw = Headers(headers).get("Authorization", "")
              scheme, credentials = get_authorization_scheme_param(raw)
              # Strip only the leading scheme and match it case-insensitively
              # (auth schemes are case-insensitive). The previous
              # `.replace("Bearer ", "")` missed "bearer ..." and removed
              # "Bearer " anywhere inside the value. Non-bearer values are
              # kept verbatim, as before.
              self.token = credentials if scheme.lower() == "bearer" else raw

      def inject_headers(
          self,
          headers: Dict[str, str],
      ) -> None:
          """Replace any Authorization header in ``headers`` with our token.

          Existing Authorization entries are removed whatever their casing;
          a new one is added only when a token is available. Mutates
          ``headers`` in place.
          """
          # Iterate over a snapshot of the keys because the dict is mutated.
          for key in list(headers.keys()):
              if key.lower() == "authorization":
                  headers.pop(key)
          if self.token:
              headers["Authorization"] = f"Bearer {self.token}"

      async def authenticate(self, websocket: WebSocket) -> bool:
          """Authenticate an incoming websocket as a worker.

          Returns:
              True when the ``Authorization`` bearer token resolves to a user
              with a worker whose ``worker_uuid`` equals the ``x-client-id``
              header; the user is then stored in ``websocket.scope["user"]``.
              False otherwise.
          """
          user = await authenticate_worker_by_request_headers(
              websocket.headers, validate_proxy=False
          )
          if user is None:
              return False
          if user.worker is None:
              logger.debug(
                  f"Authenticated user {user.id} with bearer token but it is not associated with any worker"
              )
              return False
          if websocket.headers.get("x-client-id") != user.worker.worker_uuid:
              logger.debug(
                  f"Authenticated worker {user.worker_id} with bearer token but client_id {websocket.headers.get('x-client-id')} does not match worker_uuid {user.worker.worker_uuid}"
              )
              return False
          websocket.scope["user"] = user
          return True