users.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. from datetime import datetime
  2. import re
  3. from enum import Enum
  4. from sqlalchemy import Enum as SQLEnum, Text
  5. from sqlalchemy.orm import selectinload
  6. from sqlmodel.ext.asyncio.session import AsyncSession
  7. from typing import ClassVar, List, Optional, TYPE_CHECKING
  8. from pydantic import field_validator
  9. from sqlmodel import (
  10. Field,
  11. Relationship,
  12. Column,
  13. SQLModel,
  14. Integer,
  15. ForeignKey,
  16. )
  17. from gpustack.schemas.common import ListParams
  18. from .common import PaginatedList
  19. from ..mixins import BaseModelMixin
  20. from .clusters import Cluster
  21. from .workers import Worker
  22. if TYPE_CHECKING:
  23. from .api_keys import ApiKey
  24. from .principals import Principal
  25. system_name_prefix = "system/cluster"
  26. default_cluster_user_name = f"{system_name_prefix}-1"
# Allowed values for ``UserBase.role`` (described there as "Role of the user,
# e.g., worker or cluster").
# NOTE: this is a plain ``Enum`` — unlike ``AuthProviderEnum`` it does not mix
# in ``str``, so members do not compare equal to raw strings such as "Worker".
class UserRole(Enum):
    Worker = "Worker"
    Cluster = "Cluster"
# Allowed values for ``UserBase.source``, whose column is declared as
# ``SQLEnum(AuthProviderEnum)``. Mixes in ``str``, so each member compares
# equal to its string value (e.g. ``AuthProviderEnum.OIDC == "OIDC"``).
class AuthProviderEnum(str, Enum):
    Local = "Local"
    OIDC = "OIDC"
    SAML = "SAML"
  34. class UserBase(SQLModel):
  35. username: str
  36. is_admin: bool = False
  37. is_active: bool = True
  38. full_name: Optional[str] = None
  39. avatar_url: Optional[str] = Field(
  40. default=None, sa_column=Column(Text, nullable=True)
  41. )
  42. source: Optional[str] = Field(
  43. default=AuthProviderEnum.Local, sa_type=SQLEnum(AuthProviderEnum)
  44. )
  45. require_password_change: bool = Field(default=False)
  46. is_system: bool = False
  47. role: Optional[UserRole] = Field(
  48. default=None, description="Role of the user, e.g., worker or cluster"
  49. )
  50. cluster_id: Optional[int] = Field(
  51. default=None,
  52. sa_column=Column(Integer, ForeignKey("clusters.id", ondelete="CASCADE")),
  53. )
  54. worker_id: Optional[int] = Field(
  55. default=None,
  56. sa_column=Column(Integer, ForeignKey("workers.id", ondelete="CASCADE")),
  57. )
  58. # 1:1 link to the user's Principal row. NOT NULL by construction —
  59. # every user has a principal, and that principal is the canonical
  60. # owner identity for resources the user creates in their personal
  61. # scope. RESTRICT prevents the principal from being deleted while
  62. # the user row still references it; ``users`` is supposed to be the
  63. # source of truth for user existence, so the principal goes away
  64. # only as part of user deletion.
  65. principal_id: Optional[int] = Field(
  66. default=None,
  67. sa_column=Column(
  68. Integer,
  69. ForeignKey("principals.id", ondelete="RESTRICT"),
  70. nullable=False,
  71. unique=True,
  72. ),
  73. )
  74. class UserCreate(UserBase):
  75. password: str
  76. @field_validator('password')
  77. def validate_password(cls, value):
  78. if not re.search(r'[A-Z]', value):
  79. raise ValueError('Password must contain at least one uppercase letter')
  80. if not re.search(r'[a-z]', value):
  81. raise ValueError('Password must contain at least one lowercase letter')
  82. if not re.search(r'[0-9]', value):
  83. raise ValueError('Password must contain at least one digit')
  84. if not re.search(r'[!@#$%^&*_+]', value):
  85. raise ValueError('Password must contain at least one special character')
  86. return value
# Request body for updating a user: all UserBase fields plus an optional new
# plaintext password (None by default).
# NOTE(review): unlike UserCreate, UserSelfUpdate and UpdatePassword, this
# schema attaches no complexity validator to ``password`` — confirm whether
# that is intentional (e.g. an admin-only path) before relying on it.
class UserUpdate(UserBase):
    password: Optional[str] = None
  89. class UserSelfUpdate(SQLModel):
  90. """Schema for users updating their own profile - excludes privileged fields"""
  91. full_name: Optional[str] = None
  92. avatar_url: Optional[str] = Field(
  93. default=None, sa_column=Column(Text, nullable=True)
  94. )
  95. password: Optional[str] = None
  96. @field_validator('password')
  97. def validate_password(cls, value):
  98. if value is None:
  99. return value
  100. if not re.search(r'[A-Z]', value):
  101. raise ValueError('Password must contain at least one uppercase letter')
  102. if not re.search(r'[a-z]', value):
  103. raise ValueError('Password must contain at least one lowercase letter')
  104. if not re.search(r'[0-9]', value):
  105. raise ValueError('Password must contain at least one digit')
  106. if not re.search(r'[!@#$%^&*_+]', value):
  107. raise ValueError('Password must contain at least one special character')
  108. return value
  109. class UpdatePassword(SQLModel):
  110. current_password: str
  111. new_password: str
  112. @field_validator('new_password')
  113. def validate_password(cls, value):
  114. if not re.search(r'[A-Z]', value):
  115. raise ValueError('Password must contain at least one uppercase letter')
  116. if not re.search(r'[a-z]', value):
  117. raise ValueError('Password must contain at least one lowercase letter')
  118. if not re.search(r'[0-9]', value):
  119. raise ValueError('Password must contain at least one digit')
  120. if not re.search(r'[!@#$%^&*_+]', value):
  121. raise ValueError('Password must contain at least one special character')
  122. return value
# ORM model for the ``users`` table: the UserBase columns plus a primary key,
# the password hash and relationships. BaseModelMixin presumably supplies the
# timestamp columns and query helpers such as ``one_by_field`` — confirm.
# Every relationship here uses ``lazy="noload"``: nothing is loaded implicitly,
# so callers that need a related object must request it with a loader option
# (e.g. ``selectinload``).
class User(UserBase, BaseModelMixin, table=True):
    __tablename__ = 'users'
    id: Optional[int] = Field(default=None, primary_key=True)
    # NOTE(review): presumably None for accounts without a local password
    # (external-provider or system users) — confirm.
    hashed_password: Optional[str] = None
    cluster: Optional[Cluster] = Relationship(
        back_populates="cluster_users", sa_relationship_kwargs={"lazy": "noload"}
    )
    worker: Optional[Worker] = Relationship(sa_relationship_kwargs={"lazy": "noload"})
    # 1:1 link to the user's USER-principal. Setting ``user.principal``
    # (instead of ``user.principal_id``) at construction time lets
    # SQLAlchemy's unit of work insert the principal first and
    # auto-populate ``principal_id`` during a combined flush — the
    # standard idiom for satisfying a NOT NULL FK without a separate
    # round trip.
    principal: Optional["Principal"] = Relationship(
        sa_relationship_kwargs={"lazy": "noload"},
    )
    # ``cascade="delete"``: deleting a User through the ORM session also
    # deletes its ApiKey rows.
    api_keys: List["ApiKey"] = Relationship(
        back_populates='user',
        sa_relationship_kwargs={"cascade": "delete", "lazy": "noload"},
    )
# Request body that carries only the new value of a user's ``is_active`` flag.
class UserActivationUpdate(SQLModel):
    is_active: bool
# Query parameters for listing users.
class UserListParams(ListParams):
    # Columns a client may sort the user list by. Declared ClassVar so it is a
    # class-level setting rather than a pydantic field; NOTE(review): presumably
    # enforced by ListParams — confirm.
    sortable_fields: ClassVar[List[str]] = [
        "username",
        "is_admin",
        "full_name",
        "source",
        "is_active",
        "created_at",
        "updated_at",
    ]
# Response schema for a single user: the UserBase fields plus the row id and
# timestamps. Password data is not included.
class UserPublic(UserBase):
    id: int
    created_at: datetime
    updated_at: datetime
# Paginated response type whose items are UserPublic.
UsersPublic = PaginatedList[UserPublic]
  161. def is_default_cluster_user(cluster_user: User) -> bool:
  162. return (
  163. cluster_user.is_system
  164. and cluster_user.cluster_id is not None
  165. and cluster_user.username == default_cluster_user_name
  166. )
  167. async def get_default_cluster_user(session: AsyncSession) -> Optional[User]:
  168. return await User.one_by_field(
  169. session=session,
  170. field="username",
  171. value=default_cluster_user_name,
  172. options=[selectinload(User.cluster)],
  173. )