common.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. from datetime import timezone, datetime
  2. import json
  3. from typing import ClassVar, Generic, List, Optional, Tuple, Type, TypeVar
  4. from fastapi import Query
  5. from fastapi.encoders import jsonable_encoder
  6. from pydantic import BaseModel, TypeAdapter, computed_field, field_validator
  7. import sqlalchemy as sa
  8. from sqlalchemy import JSON as SQLAlchemyJSON, TypeDecorator
  9. from gpustack.api.exceptions import InvalidException
# Generic type parameter used by the list containers and the pydantic-backed
# column type defined in this module; constrained to pydantic BaseModel subclasses.
T = TypeVar("T", bound=BaseModel)
class PublicFields:
    """Annotation-only holder for an integer id and three timestamps.

    NOTE(review): presumably used as a mixin for public-facing schemas;
    confirm against the classes that inherit from it.
    """

    id: int
    created_at: datetime
    updated_at: datetime
    # The only field with a default: None unless explicitly set.
    # NOTE(review): looks like a soft-delete marker; confirm with the ORM models.
    deleted_at: Optional[datetime] = None
# Paging metadata carried by PaginatedList.
# NOTE(review): field meanings below are inferred from their names
# (current page, page size, total item count, total page count); confirm
# against the code that fills them in.
class Pagination(BaseModel):
    page: int
    perPage: int
    total: int
    totalPage: int
  21. class ListParams(BaseModel):
  22. page: int = Query(default=1)
  23. # FIXME It uses camelCase but most APIs use snake_case. We might want to migrate to snake_case later.
  24. perPage: int = Query(default=100)
  25. watch: bool = Query(default=False)
  26. sort_by: Optional[str] = Query(
  27. default=None,
  28. description="Sorting in the format: field1,-field2,field3. A leading '-' indicates descending order.",
  29. )
  30. sortable_fields: ClassVar[List[str]] = []
  31. @field_validator('sort_by')
  32. def validate_sort_by(cls, v: Optional[str]) -> Optional[str]:
  33. """Validates the sort_by string format."""
  34. if not v:
  35. return v
  36. if not cls.sortable_fields:
  37. return v
  38. for field in v.split(','):
  39. field = field.strip()
  40. if not field:
  41. continue
  42. field_name = field[1:] if field.startswith('-') else field
  43. # Verify if the field is in the allowed sortable fields
  44. if field_name not in cls.sortable_fields:
  45. raise InvalidException(
  46. f"Field '{field_name}' is not sortable. "
  47. f"Allowed fields: {', '.join(cls.sortable_fields)}"
  48. )
  49. return v
  50. @computed_field
  51. @property
  52. def order_by(self) -> Optional[List[Tuple[str, str]]]:
  53. """
  54. Parses the sort_by string into a list of (field, direction) tuples.
  55. For example, "name,-created_at,status" will be parsed to:
  56. [("name", "asc"), ("created_at", "desc"), ("status", "asc")]
  57. Returns None if sort_by is not set.
  58. """
  59. if self.sort_by is None:
  60. return None
  61. order_by = []
  62. for field in self.sort_by.split(','):
  63. field = field.strip()
  64. if not field:
  65. continue
  66. if field.startswith('-'):
  67. direction = "desc"
  68. field_name = field[1:]
  69. else:
  70. direction = "asc"
  71. field_name = field
  72. order_by.append((field_name, direction))
  73. return order_by
# Generic envelope holding a list of items of model type T.
class ItemList(BaseModel, Generic[T]):
    items: list[T]
# ItemList extended with a Pagination block describing the returned page.
class PaginatedList(ItemList[T]):
    pagination: Pagination
class JSON(SQLAlchemyJSON):
    """Project-local alias of SQLAlchemy's JSON type; adds no behavior of its own.

    Used as the ``impl`` of the type produced by ``pydantic_column_type``.
    """

    pass
  80. class UTCDateTime(sa.TypeDecorator):
  81. impl = sa.TIMESTAMP(timezone=False)
  82. cache_ok = True
  83. def process_bind_param(self, value, dialect):
  84. if value is not None and value.tzinfo is not None:
  85. # Ensure the datetime is in UTC and clear tzinfo before storing
  86. value = value.astimezone(timezone.utc).replace(tzinfo=None)
  87. return value
  88. def process_result_value(self, value, dialect):
  89. if value is not None:
  90. # Assume stored datetime is in UTC and attach tzinfo
  91. value = value.replace(tzinfo=timezone.utc)
  92. return value
  93. def pydantic_column_type(
  94. pydantic_type: Type[T],
  95. exclude_defaults: bool = False,
  96. exclude_none: bool = False,
  97. exclude_unset: bool = False,
  98. ): # noqa: C901
  99. class PydanticJSONType(TypeDecorator, Generic[T]):
  100. impl = JSON()
  101. # https://docs.sqlalchemy.org/en/20/core/type_api.html#sqlalchemy.types.ExternalType.cache_ok
  102. cache_ok = True
  103. def __init__(self, json_encoder=json):
  104. self.json_encoder = json_encoder
  105. super(PydanticJSONType, self).__init__()
  106. def bind_processor(self, dialect):
  107. impl_processor = self.impl.bind_processor(dialect)
  108. dumps = self.json_encoder.dumps
  109. def process(value: T):
  110. if value is not None:
  111. value_to_dump = self._prepare_value_for_dump(value)
  112. value = jsonable_encoder(
  113. value_to_dump,
  114. exclude_defaults=exclude_defaults,
  115. exclude_none=exclude_none,
  116. exclude_unset=exclude_unset,
  117. )
  118. return (
  119. impl_processor(value)
  120. if impl_processor
  121. else dumps(
  122. jsonable_encoder(
  123. value_to_dump,
  124. exclude_defaults=exclude_defaults,
  125. exclude_none=exclude_none,
  126. exclude_unset=exclude_unset,
  127. )
  128. )
  129. )
  130. return process
  131. def result_processor(self, dialect, coltype) -> T:
  132. impl_processor = self.impl.result_processor(dialect, coltype)
  133. def process(value):
  134. if impl_processor:
  135. value = impl_processor(value)
  136. if value is None:
  137. return None
  138. return TypeAdapter(pydantic_type).validate_python(value)
  139. return process
  140. def compare_values(self, x, y):
  141. return x == y
  142. def _prepare_value_for_dump(self, value):
  143. return TypeAdapter(pydantic_type).validate_python(value)
  144. def __repr__(self):
  145. return "JSON()"
  146. def __str__(self):
  147. return "JSON()"
  148. return PydanticJSONType