| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124 |
- import secrets
- import string
- from datetime import datetime, timedelta, timezone
- from typing import Optional, Union, Tuple
- from functools import lru_cache
- import jwt
- from argon2 import PasswordHasher
- import hashlib
- from gpustack import envs
# Leading segment of a standard-format API key: "gpustack_{access_key}_{secret_key}".
API_KEY_PREFIX = "gpustack"
# Single argon2 PasswordHasher shared by get_secret_hash and verify_hashed_secret.
ph = PasswordHasher()
@lru_cache(maxsize=2048)
def verify_hashed_secret(hashed: Union[str, bytes], plain: Union[str, bytes]) -> bool:
    """Check *plain* against the argon2 hash *hashed* without ever raising.

    Results are memoised for up to 2048 distinct ``(hashed, plain)`` pairs, so
    repeated checks of the same pair do not re-run ``ph.verify``.
    NOTE(review): the cache keeps the plaintext secrets it was called with in
    process memory for as long as they stay in the LRU.

    Returns:
        Whatever ``ph.verify`` returns on success; ``False`` if it raises.
    """
    try:
        verified = ph.verify(hashed, plain)
    except Exception:
        # Every failure (mismatch, malformed hash, bad argument type, ...) is
        # reported as a plain "not verified" rather than propagated.
        return False
    return verified
def get_secret_hash(plain: Union[str, bytes]):
    """Hash *plain* with the module-level argon2 ``PasswordHasher`` (``ph``).

    The result is the encoded hash string that ``verify_hashed_secret``
    accepts as its *hashed* argument.
    """
    encoded = ph.hash(plain)
    return encoded
def generate_secure_password(length=12):
    """Return a random password of exactly *length* characters.

    Characters are drawn with :func:`secrets.choice` from ASCII letters,
    digits and ``!@#$%^&*_+``. Candidates are regenerated until one contains
    at least one lowercase letter, one uppercase letter, one digit and one of
    the special characters.

    Raises:
        ValueError: if *length* is less than 8.
    """
    if length < 8:
        raise ValueError("Password length should be at least 8 characters")

    specials = "!@#$%^&*_+"
    alphabet = string.ascii_letters + string.digits + specials
    # One predicate per character class that must appear at least once.
    required = (str.islower, str.isupper, str.isdigit, specials.__contains__)

    while True:
        candidate = "".join(secrets.choice(alphabet) for _ in range(length))
        if all(any(matches(ch) for ch in candidate) for matches in required):
            return candidate
def custom_key_hash(secret_key: str) -> str:
    """Return the BLAKE2b digest (16-byte output, 32 hex chars) of *secret_key*.

    ``get_key_pair`` uses this to derive a deterministic access key for API
    keys that are not in the standard prefixed format.
    """
    hasher = hashlib.blake2b(digest_size=16)
    hasher.update(secret_key.encode())
    return hasher.hexdigest()
def is_valid_format(key: str) -> Tuple[bool, str, str]:
    """Check whether *key* looks like ``{API_KEY_PREFIX}_{access_key}_{secret_key}``.

    Only the prefix and the presence of two ``_`` separators are checked. The
    key is split at its first two underscores, so the secret part keeps any
    further underscores.
    NOTE(review): empty parts are accepted, e.g. ``"gpustack__"`` yields
    ``(True, "", "")``.

    Returns:
        ``(True, access_key, secret_key)`` for a standard-format key, otherwise
        ``(False, "", "")``.
    """
    rejected = (False, "", "")
    if not key.startswith(API_KEY_PREFIX + "_"):
        return rejected
    fields = key.split("_", 2)
    # With maxsplit=2 there are at most three fields; fewer means a separator
    # is missing.
    if len(fields) < 3:
        return rejected
    return True, fields[1], fields[2]
def get_key_pair(key: str) -> Tuple[str, str]:
    """Split an API key into the ``(access_key, secret_key)`` pair used for lookup.

    Only the *shape* of the key is checked (via ``is_valid_format``): the
    ``gpustack_`` prefix followed by an access key and a secret key separated
    by ``_``. Lengths and character sets are NOT validated here.

    Scenarios:
    1. Standard format key: ``"gpustack_{access_key}_{secret_key}"``
       - Split at the first two underscores; the secret key keeps any further
         underscores.
       - Example: ``"gpustack_3192253c_c11c75ed6334ea9505da4ad9"`` ->
         ``("3192253c", "c11c75ed6334ea9505da4ad9")``: an 8-character hex
         access key and, in this example, a 24-character hex secret key.
       - Used for normal API authentication via /v2/* routes.
    2. Any other string is treated as an opaque custom key, for backward
       compatibility. This includes:
       - legacy worker tokens that use a UUID as identifier, e.g.
         ``"3192253c-c11c-75ed-6334-ea9505da4ad9"``;
       - arbitrary keys such as ``"any_random_string_here"`` or ``"sk-xxx"``;
       - prefixed keys with too few ``_`` separators, e.g. ``"gpustack_abc"``.
       For these the access key is ``custom_key_hash(key)`` (a deterministic
       digest of the whole key) and the secret key is ``key`` itself.

    Returns:
        Tuple of (access_key, secret_key):
        - standard format: the parsed access_key and secret_key;
        - anything else: ``(custom_key_hash(key), key)``.
    """
    is_standard, access_key, secret_key = is_valid_format(key)
    if is_standard:
        return access_key, secret_key
    # A non-standard key has no embedded access key, so derive a stable one
    # from the key itself and treat the whole key as the secret.
    return custom_key_hash(key), key
# Header name used to flag auth-cache handling; presumably an HTTP request /
# response header consumed by code outside this module — confirm at call sites.
AUTH_CACHE_HEADER: str = "x-gpustack-auth-cache"
class JWTManager:
    """Create and decode JWTs signed with one secret key and algorithm.

    Two token shapes are produced, both signed with ``secret_key`` using
    ``algorithm``:

    * ``create_jwt_token(username)`` -> claims ``{"sub": username, "exp": ...}``
    * ``create_token(payload)``      -> claims ``{"data": payload, "exp": ...}``
    """

    def __init__(
        self,
        secret_key: str,
        algorithm: str = "HS256",
        expires_delta: Optional[timedelta] = None,
    ):
        """
        Args:
            secret_key: Key used both to sign and to verify tokens.
            algorithm: Signing algorithm name; it is also the only algorithm
                accepted when decoding.
            expires_delta: Default token lifetime. ``None`` means
                ``envs.JWT_TOKEN_EXPIRE_MINUTES`` minutes.
        """
        lifetime = (
            expires_delta
            if expires_delta is not None
            else timedelta(minutes=envs.JWT_TOKEN_EXPIRE_MINUTES)
        )
        self.secret_key = secret_key
        self.algorithm = algorithm
        self.expires_delta = lifetime

    def _expires_at(self, delta: timedelta) -> datetime:
        # Timezone-aware UTC instant that lies `delta` after now.
        return datetime.now(timezone.utc) + delta

    def create_jwt_token(self, username: str):
        """Return a signed token whose ``sub`` claim is *username*.

        The ``exp`` claim is set ``self.expires_delta`` after issuance.
        """
        claims = {"sub": username, "exp": self._expires_at(self.expires_delta)}
        return jwt.encode(claims, self.secret_key, algorithm=self.algorithm)

    def create_token(self, payload: dict, expires_delta: Optional[timedelta] = None):
        """Return a signed token carrying *payload* under its ``data`` claim.

        Args:
            payload: Value nested under ``"data"`` (presumably must be
                JSON-serialisable for ``jwt.encode`` — confirm for new callers).
            expires_delta: Lifetime override; ``None`` uses ``self.expires_delta``.
        """
        lifetime = self.expires_delta if expires_delta is None else expires_delta
        claims = {"data": payload, "exp": self._expires_at(lifetime)}
        return jwt.encode(claims, self.secret_key, algorithm=self.algorithm)

    def decode_jwt_token(self, token: str):
        """Verify *token* with ``secret_key`` and return all of its claims.

        Only ``self.algorithm`` is accepted. Exceptions raised by
        ``jwt.decode`` propagate unchanged.
        """
        accepted_algorithms = [self.algorithm]
        return jwt.decode(token, self.secret_key, algorithms=accepted_algorithms)

    def decode_jwt_data(self, token: str) -> dict:
        """Verify *token* and return the value stored under its ``data`` claim.

        Raises:
            KeyError: if the verified token has no ``data`` claim, e.g. a token
                produced by ``create_jwt_token``.
        """
        claims = self.decode_jwt_token(token)
        return claims["data"]
|