| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214 |
- from datetime import datetime
- import re
- from enum import Enum
- from sqlalchemy import Enum as SQLEnum, Text
- from sqlalchemy.orm import selectinload
- from sqlmodel.ext.asyncio.session import AsyncSession
- from typing import ClassVar, List, Optional, TYPE_CHECKING
- from pydantic import field_validator
- from sqlmodel import (
- Field,
- Relationship,
- Column,
- SQLModel,
- Integer,
- ForeignKey,
- )
- from gpustack.schemas.common import ListParams
- from .common import PaginatedList
- from ..mixins import BaseModelMixin
- from .clusters import Cluster
- from .workers import Worker
- if TYPE_CHECKING:
- from .api_keys import ApiKey
- from .principals import Principal
# Prefix from which the default cluster user's username is built.
system_name_prefix = "system/cluster"
# Username of the default cluster user: the prefix followed by "-1".
default_cluster_user_name = system_name_prefix + "-1"
class UserRole(Enum):
    """Role values that a user row may carry (see ``UserBase.role``).

    This is a plain ``Enum``: members are not ``str`` instances, unlike
    ``AuthProviderEnum`` in this module.
    """

    Worker = "Worker"
    Cluster = "Cluster"
class AuthProviderEnum(str, Enum):
    """Authentication providers a user can originate from.

    Used as the SQL enum type and as the default (``Local``) of
    ``UserBase.source``. Members are also ``str`` instances because of the
    ``str`` mixin.
    """

    Local = "Local"
    OIDC = "OIDC"
    SAML = "SAML"
class UserBase(SQLModel):
    """Fields shared by the user schemas and the ``users`` table model.

    ``UserCreate``, ``UserUpdate``, ``UserPublic`` and ``User`` in this
    module all inherit from this class.
    """

    username: str
    is_admin: bool = False
    is_active: bool = True
    full_name: Optional[str] = None
    # Backed by a nullable TEXT column.
    avatar_url: Optional[str] = Field(
        default=None, sa_column=Column(Text, nullable=True)
    )
    # Authentication provider. Annotated as ``str`` but persisted through a
    # SQL enum built from AuthProviderEnum; defaults to ``Local``.
    source: Optional[str] = Field(
        default=AuthProviderEnum.Local, sa_type=SQLEnum(AuthProviderEnum)
    )
    require_password_change: bool = Field(default=False)
    is_system: bool = False
    role: Optional[UserRole] = Field(
        default=None, description="Role of the user, e.g., worker or cluster"
    )
    # Optional foreign key to ``clusters.id``, declared with ON DELETE CASCADE.
    cluster_id: Optional[int] = Field(
        default=None,
        sa_column=Column(Integer, ForeignKey("clusters.id", ondelete="CASCADE")),
    )
    # Optional foreign key to ``workers.id``, declared with ON DELETE CASCADE.
    worker_id: Optional[int] = Field(
        default=None,
        sa_column=Column(Integer, ForeignKey("workers.id", ondelete="CASCADE")),
    )
    # 1:1 link to the user's Principal row. NOT NULL by construction —
    # every user has a principal, and that principal is the canonical
    # owner identity for resources the user creates in their personal
    # scope. RESTRICT prevents the principal from being deleted while
    # the user row still references it; ``users`` is supposed to be the
    # source of truth for user existence, so the principal goes away
    # only as part of user deletion.
    # NOTE(review): the annotation is Optional with a None default even
    # though the column is NOT NULL and UNIQUE — presumably so the value
    # can be filled in at flush time; confirm.
    principal_id: Optional[int] = Field(
        default=None,
        sa_column=Column(
            Integer,
            ForeignKey("principals.id", ondelete="RESTRICT"),
            nullable=False,
            unique=True,
        ),
    )
# Password complexity rules shared by every schema below that validates a
# password. Each entry pairs a pre-compiled pattern that must match somewhere
# in the password with the error message raised when it does not. Order
# matters: the first failing rule determines the reported message.
_PASSWORD_RULES = (
    (re.compile(r'[A-Z]'), 'Password must contain at least one uppercase letter'),
    (re.compile(r'[a-z]'), 'Password must contain at least one lowercase letter'),
    (re.compile(r'[0-9]'), 'Password must contain at least one digit'),
    (
        re.compile(r'[!@#$%^&*_+]'),
        'Password must contain at least one special character',
    ),
)


def _check_password_strength(value: str) -> str:
    """Return ``value`` unchanged if it satisfies every password rule.

    Args:
        value: The candidate password.

    Returns:
        The same string that was passed in.

    Raises:
        ValueError: With the message of the first rule the password fails.
    """
    for pattern, message in _PASSWORD_RULES:
        if not pattern.search(value):
            raise ValueError(message)
    return value


class UserCreate(UserBase):
    """Schema for creating a user: the ``UserBase`` fields plus a password."""

    password: str

    @field_validator('password')
    @classmethod
    def validate_password(cls, value):
        """Reject passwords that fail the shared complexity rules."""
        return _check_password_strength(value)


class UserUpdate(UserBase):
    """Schema for updating a user: the ``UserBase`` fields plus an optional password.

    NOTE(review): unlike the sibling schemas, no complexity validator is
    attached to ``password`` here — confirm that this is intentional.
    """

    password: Optional[str] = None


class UserSelfUpdate(SQLModel):
    """Schema for users updating their own profile - excludes privileged fields"""

    full_name: Optional[str] = None
    avatar_url: Optional[str] = Field(
        default=None, sa_column=Column(Text, nullable=True)
    )
    password: Optional[str] = None

    @field_validator('password')
    @classmethod
    def validate_password(cls, value):
        """Validate the password only when one is supplied."""
        # ``None`` means "leave the password unchanged", so it skips the rules.
        if value is None:
            return value
        return _check_password_strength(value)


class UpdatePassword(SQLModel):
    """Schema for a password change: the current password and its replacement."""

    current_password: str
    new_password: str

    @field_validator('new_password')
    @classmethod
    def validate_password(cls, value):
        """Reject new passwords that fail the shared complexity rules."""
        return _check_password_strength(value)
class User(UserBase, BaseModelMixin, table=True):
    """ORM model mapped to the ``users`` table.

    Adds the primary key, the stored password hash and the relationships to
    the fields inherited from ``UserBase``. Every relationship below is
    declared with ``lazy="noload"``, so related objects are expected to be
    populated only when a query asks for them with an explicit loader
    option.
    """

    __tablename__ = 'users'
    id: Optional[int] = Field(default=None, primary_key=True)
    # Password hash; optional, defaults to None.
    hashed_password: Optional[str] = None
    # Cluster this user is linked to; paired with ``Cluster.cluster_users``.
    cluster: Optional[Cluster] = Relationship(
        back_populates="cluster_users", sa_relationship_kwargs={"lazy": "noload"}
    )
    # Worker this user is linked to; no back-reference is declared here.
    worker: Optional[Worker] = Relationship(sa_relationship_kwargs={"lazy": "noload"})
    # 1:1 link to the user's USER-principal. Setting ``user.principal``
    # (instead of ``user.principal_id``) at construction time lets
    # SQLAlchemy's unit of work insert the principal first and
    # auto-populate ``principal_id`` during a combined flush — the
    # standard idiom for satisfying a NOT NULL FK without a separate
    # round trip.
    principal: Optional["Principal"] = Relationship(
        sa_relationship_kwargs={"lazy": "noload"},
    )
    # API keys owned by this user; the relationship cascades deletes.
    api_keys: List["ApiKey"] = Relationship(
        back_populates='user',
        sa_relationship_kwargs={"cascade": "delete", "lazy": "noload"},
    )
class UserActivationUpdate(SQLModel):
    """Schema carrying only the required ``is_active`` flag."""

    is_active: bool
class UserListParams(ListParams):
    """List parameters for users, declaring the user fields open to sorting."""

    # Class-level list, not a model field (``ClassVar``).
    # NOTE(review): presumably read by ``ListParams`` to validate requested
    # sort keys — confirm against its implementation.
    sortable_fields: ClassVar[List[str]] = [
        "username",
        "is_admin",
        "full_name",
        "source",
        "is_active",
        "created_at",
        "updated_at",
    ]
class UserPublic(UserBase):
    """User representation: the ``UserBase`` fields plus id and timestamps.

    ``hashed_password`` is declared only on ``User`` and is therefore not a
    field of this schema.
    """

    id: int
    created_at: datetime
    updated_at: datetime


# Paginated list specialised for ``UserPublic`` items.
UsersPublic = PaginatedList[UserPublic]
def is_default_cluster_user(cluster_user: User) -> bool:
    """Tell whether ``cluster_user`` is the default cluster user.

    The user qualifies when it is flagged as a system user, is attached to a
    cluster (``cluster_id`` is set) and its username equals
    ``default_cluster_user_name``.

    Args:
        cluster_user: The user to inspect.

    Returns:
        A truthy value when all three conditions hold, a falsy one otherwise.
    """
    system_flag = cluster_user.is_system
    if not system_flag:
        # Hand back the falsy flag itself, exactly as a short-circuiting
        # ``and`` chain would.
        return system_flag
    if cluster_user.cluster_id is None:
        return False
    return cluster_user.username == default_cluster_user_name
async def get_default_cluster_user(session: AsyncSession) -> Optional[User]:
    """Look up the user whose username is ``default_cluster_user_name``.

    Delegates to ``User.one_by_field`` and asks for the ``cluster``
    relationship to be loaded via ``selectinload``.

    Args:
        session: Async database session used for the lookup.

    Returns:
        Whatever ``User.one_by_field`` yields for the lookup; annotated as the
        matching ``User`` or ``None``.
    """
    # ``User.cluster`` is declared with lazy="noload", so request it explicitly.
    load_options = [selectinload(User.cluster)]
    default_user = await User.one_by_field(
        session=session,
        field="username",
        value=default_cluster_user_name,
        options=load_options,
    )
    return default_user
|