model_files.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. from datetime import datetime
  2. from enum import Enum
  3. from typing import ClassVar, List, Optional
  4. from sqlmodel import (
  5. JSON,
  6. BigInteger,
  7. Column,
  8. Field,
  9. ForeignKey,
  10. Integer,
  11. Relationship,
  12. SQLModel,
  13. Text,
  14. )
  15. from gpustack.mixins import BaseModelMixin
  16. from gpustack.schemas.common import ListParams, PaginatedList
  17. from gpustack.schemas.links import (
  18. ModelInstanceDraftModelFileLink,
  19. ModelInstanceModelFileLink,
  20. )
  21. from gpustack.schemas.models import ModelSource, ModelInstance
  22. class ModelFileStateEnum(str, Enum):
  23. ERROR = "error"
  24. DOWNLOADING = "downloading"
  25. READY = "ready"
  26. class ModelFileBase(SQLModel, ModelSource):
  27. local_dir: Optional[str] = None
  28. worker_id: Optional[int] = None
  29. cleanup_on_delete: Optional[bool] = None
  30. size: Optional[int] = Field(sa_column=Column(BigInteger), default=None)
  31. download_progress: Optional[float] = None
  32. resolved_paths: List[str] = Field(sa_column=Column(JSON), default=[])
  33. state: ModelFileStateEnum = ModelFileStateEnum.DOWNLOADING
  34. state_message: Optional[str] = Field(
  35. default=None, sa_column=Column(Text, nullable=True)
  36. )
  37. class ModelFile(ModelFileBase, BaseModelMixin, table=True):
  38. __tablename__ = 'model_files'
  39. id: Optional[int] = Field(default=None, primary_key=True)
  40. # Unique index of the model source
  41. source_index: Optional[str] = Field(index=True, unique=True, default=None)
  42. # Tenant scope. Server-derived from worker→cluster on creation; not
  43. # exposed on the create payload to avoid clients smuggling overrides.
  44. cluster_id: Optional[int] = Field(default=None)
  45. owner_principal_id: Optional[int] = Field(
  46. default=None,
  47. sa_column=Column(Integer, ForeignKey("principals.id"), nullable=True),
  48. )
  49. instances: list[ModelInstance] = Relationship(
  50. sa_relationship_kwargs={"lazy": "noload"},
  51. back_populates="model_files",
  52. link_model=ModelInstanceModelFileLink,
  53. )
  54. draft_instances: list[ModelInstance] = Relationship(
  55. back_populates="draft_model_files",
  56. link_model=ModelInstanceDraftModelFileLink,
  57. sa_relationship_kwargs={"lazy": "noload"},
  58. )
  59. class ModelFileListParams(ListParams):
  60. sortable_fields: ClassVar[List[str]] = [
  61. "source",
  62. "worker_id",
  63. "state",
  64. "resolved_paths",
  65. "created_at",
  66. "updated_at",
  67. ]
  68. class ModelFileCreate(ModelFileBase):
  69. pass
  70. class ModelFileUpdate(ModelFileBase):
  71. pass
  72. class ModelFilePublic(
  73. ModelFileBase,
  74. ):
  75. id: int
  76. created_at: datetime
  77. updated_at: datetime
# Paginated list of ModelFilePublic items (PaginatedList parametrized).
ModelFilesPublic = PaginatedList[ModelFilePublic]