models.py 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822
  1. from dataclasses import dataclass
  2. from datetime import datetime
  3. from enum import Enum
  4. import hashlib
  5. from pathlib import Path
  6. from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Union
  7. from pydantic import BaseModel, ConfigDict, model_validator
  8. from sqlalchemy import JSON, Column, ForeignKey, Integer
  9. from sqlmodel import Field, Relationship, SQLModel, Text
  10. from gpustack.schemas.common import (
  11. ListParams,
  12. PaginatedList,
  13. UTCDateTime,
  14. pydantic_column_type,
  15. )
  16. from gpustack.mixins import BaseModelMixin
  17. from gpustack.schemas.links import (
  18. ModelInstanceDraftModelFileLink,
  19. ModelInstanceModelFileLink,
  20. )
  21. from gpustack.utils.command import find_parameter, find_bool_parameter
  22. from gpustack.schemas.model_routes import (
  23. ModelRoute,
  24. ModelRouteTarget,
  25. AccessPolicyEnum,
  26. )
  27. from gpustack.schemas.principals import PLATFORM_PRINCIPAL_ID
  28. if TYPE_CHECKING:
  29. from gpustack.schemas.model_files import ModelFile
  30. from gpustack.schemas.clusters import Cluster
  31. # Models
# Identifies where a model's files come from. ModelSource.source holds one of
# these values and decides which of its identifier fields is required.
class SourceEnum(str, Enum):
    HUGGING_FACE = "huggingface"
    MODEL_SCOPE = "model_scope"
    LOCAL_PATH = "local_path"
# Category labels for a model. The is_*_model helpers in this module test
# membership of these values (or their string form) in `model.categories`.
class CategoryEnum(str, Enum):
    LLM = "llm"
    EMBEDDING = "embedding"
    IMAGE = "image"
    RERANKER = "reranker"
    SPEECH_TO_TEXT = "speech_to_text"
    TEXT_TO_SPEECH = "text_to_speech"
    UNKNOWN = "unknown"
# Allowed values for ModelSpecBase.placement_strategy (which defaults to SPREAD).
class PlacementStrategyEnum(str, Enum):
    SPREAD = "spread"
    BINPACK = "binpack"
# Well-known inference backend names. The `backend` field on models is a
# free-form string; helpers in this module compare it against these members.
class BackendEnum(str, Enum):
    VLLM = "vLLM"
    VOX_BOX = "VoxBox"
    ASCEND_MINDIE = "MindIE"
    SGLANG = "SGLang"
    CUSTOM = "Custom"
# Labels describing where a backend definition originates.
# NOTE(review): not referenced in the visible part of this module; presumably
# consumed by the backend schemas -- confirm before changing values.
class BackendSourceEnum(str, Enum):
    CUSTOM = "custom"
    BUILT_IN = "built_in"
    COMMUNITY = "community"
# Speculative decoding algorithms selectable through SpeculativeConfig.algorithm.
class SpeculativeAlgorithmEnum(str, Enum):
    EAGLE3 = "eagle3"
    MTP = "mtp"
    NGRAM = "ngram"
# Explicit GPU selection for a model; stored on ModelSpecBase.gpu_selector.
class GPUSelector(BaseModel):
    # format of each element: "worker_name:device:gpu_index", example: "worker1:cuda:0"
    gpu_ids: Optional[List[str]] = None
    # Optional GPU count per replica; None leaves it unspecified.
    gpus_per_replica: Optional[int] = None
# Settings for the optional extended KV cache; stored on
# ModelSpecBase.extended_kv_cache. The bare strings after each field are
# documentation notes kept from the original source, not runtime values.
class ExtendedKVCacheConfig(BaseModel):
    enabled: bool = False
    """ Enable extended KV cache for the model."""
    ram_ratio: Optional[float] = 1.2
    """ RAM-to-VRAM ratio for KV cache. For example, 2.0 means the RAM is twice the size of the VRAM. """
    ram_size: Optional[int] = None
    """ Maximum size of the KV cache to be stored in local CPU memory (unit: GiB). Overrides ram_ratio if both are set. """
    chunk_size: Optional[int] = None
    """ Size for each KV cache chunk (unit: number of tokens). """
  74. class ModelSource(BaseModel):
  75. source: SourceEnum
  76. huggingface_repo_id: Optional[str] = None
  77. huggingface_filename: Optional[str] = None
  78. model_scope_model_id: Optional[str] = None
  79. model_scope_file_path: Optional[str] = None
  80. local_path: Optional[str] = None
  81. @property
  82. def model_source_key(self) -> str:
  83. """Returns a unique identifier for the model, independent of quantization."""
  84. if self.source == SourceEnum.HUGGING_FACE:
  85. return self.huggingface_repo_id or ""
  86. elif self.source == SourceEnum.MODEL_SCOPE:
  87. return self.model_scope_model_id or ""
  88. elif self.source == SourceEnum.LOCAL_PATH:
  89. return self.local_path or ""
  90. return ""
  91. @property
  92. def readable_source(self) -> str:
  93. values = []
  94. if self.source == SourceEnum.HUGGING_FACE:
  95. values.extend([self.huggingface_repo_id, self.huggingface_filename])
  96. elif self.source == SourceEnum.MODEL_SCOPE:
  97. values.extend([self.model_scope_model_id, self.model_scope_file_path])
  98. elif self.source == SourceEnum.LOCAL_PATH:
  99. values.extend([self.local_path])
  100. return "/".join([value for value in values if value is not None])
  101. @property
  102. def model_source_index(self) -> str:
  103. values = []
  104. if self.source == SourceEnum.HUGGING_FACE:
  105. values.extend([self.huggingface_repo_id, self.huggingface_filename])
  106. elif self.source == SourceEnum.MODEL_SCOPE:
  107. values.extend(
  108. [self.source, self.model_scope_model_id, self.model_scope_file_path]
  109. )
  110. elif self.source == SourceEnum.LOCAL_PATH:
  111. values.extend([self.local_path])
  112. # Filter out None values and join
  113. filtered_values = [v for v in values if v is not None]
  114. source_string = "/".join(filtered_values)
  115. return hashlib.sha256(source_string.encode()).hexdigest()
  116. @model_validator(mode="after")
  117. def check_huggingface_fields(self):
  118. if self.source == SourceEnum.HUGGING_FACE:
  119. if not self.huggingface_repo_id:
  120. raise ValueError(
  121. "huggingface_repo_id must be provided "
  122. "when source is 'huggingface'"
  123. )
  124. if self.source == SourceEnum.MODEL_SCOPE:
  125. if not self.model_scope_model_id:
  126. raise ValueError(
  127. "model_scope_model_id must be provided when source is 'model_scope'"
  128. )
  129. if self.source == SourceEnum.LOCAL_PATH:
  130. if not self.local_path:
  131. raise ValueError(
  132. "local_path must be provided when source is 'local_path'"
  133. )
  134. return self
  135. model_config = ConfigDict(protected_namespaces=())
# Stored on ModelSpecBase.speculative_config. The bare strings after each
# field are documentation notes kept from the original source.
class SpeculativeConfig(BaseModel):
    """Configuration for speculative decoding."""
    enabled: bool = False
    """Whether speculative decoding is enabled."""
    algorithm: Optional[SpeculativeAlgorithmEnum] = None
    """The algorithm to use for speculative decoding."""
    draft_model: Optional[str] = None
    """The draft model to use for speculative decoding.
    It can be a draft model name from the model catalog, a local path or a model ID from the main model source."""
    num_draft_tokens: Optional[int] = None
    """The number of draft tokens."""
    # For ngram only
    ngram_min_match_length: Optional[int] = None
    """Minimum length of the n-gram to match."""
    ngram_max_match_length: Optional[int] = None
    """Maximum length of the n-gram to match."""
  152. class ModelSpecBase(SQLModel, ModelSource):
  153. name: str = Field(index=True, unique=True)
  154. description: Optional[str] = Field(
  155. sa_type=Text,
  156. nullable=True,
  157. default=None,
  158. )
  159. meta: Optional[Dict[str, Any]] = Field(sa_type=JSON, default={})
  160. replicas: int = Field(default=1, ge=0)
  161. ready_replicas: int = Field(default=0, ge=0)
  162. categories: List[str] = Field(sa_type=JSON, default=[])
  163. placement_strategy: PlacementStrategyEnum = PlacementStrategyEnum.SPREAD
  164. cpu_offloading: Optional[bool] = None
  165. distributed_inference_across_workers: Optional[bool] = None
  166. worker_selector: Optional[Dict[str, str]] = Field(sa_type=JSON, default={})
  167. gpu_selector: Optional[GPUSelector] = Field(
  168. sa_type=pydantic_column_type(GPUSelector), default=None
  169. )
  170. backend: Optional[str] = None
  171. backend_version: Optional[str] = None
  172. backend_parameters: Optional[List[str]] = Field(sa_type=JSON, default=None)
  173. image_name: Optional[str] = None
  174. run_command: Optional[str] = Field(sa_type=Text, default=None)
  175. env: Optional[Dict[str, str]] = Field(sa_type=JSON, default=None)
  176. restart_on_error: Optional[bool] = True
  177. distributable: Optional[bool] = False
  178. # Extended KV Cache configuration. Currently maps to LMCache config in vLLM and SGLang.
  179. extended_kv_cache: Optional[ExtendedKVCacheConfig] = Field(
  180. sa_type=pydantic_column_type(ExtendedKVCacheConfig), default=None
  181. )
  182. speculative_config: Optional[SpeculativeConfig] = Field(
  183. sa_type=pydantic_column_type(SpeculativeConfig), default=None
  184. )
  185. # Enable generic proxy for model, the control of generic proxy
  186. # is migrated to ModelAccess. Keeping this field for backward compatibility
  187. generic_proxy: Optional[bool] = Field(default=False)
  188. @model_validator(mode="after")
  189. def set_defaults(self):
  190. backend = get_backend(self)
  191. if self.distributed_inference_across_workers is None:
  192. self.distributed_inference_across_workers = (
  193. True
  194. if backend
  195. in [BackendEnum.VLLM, BackendEnum.ASCEND_MINDIE, BackendEnum.SGLANG]
  196. else False
  197. )
  198. return self
# Model specification plus its cluster and owner references.
class ModelBase(ModelSpecBase):
    # Optional reference to the cluster row (clusters.id).
    cluster_id: Optional[int] = Field(default=None, foreign_key="clusters.id")
    # Owning principal; defaults to the platform principal. The foreign key is
    # declared with ON DELETE CASCADE and the column is NOT NULL.
    owner_principal_id: int = Field(
        default=PLATFORM_PRINCIPAL_ID,
        sa_column=Column(
            Integer,
            ForeignKey("principals.id", ondelete="CASCADE"),
            nullable=False,
        ),
    )
    # Deprecated field, kept for backward compatibility
    access_policy: AccessPolicyEnum = Field(default=AccessPolicyEnum.AUTHED)
# ORM table for models ("models"). Every relationship uses lazy="noload", so
# related rows are not loaded implicitly on attribute access.
class Model(ModelBase, BaseModelMixin, table=True):
    __tablename__ = 'models'
    id: Optional[int] = Field(default=None, primary_key=True)

    # Instances of this model; the "delete" cascade removes them with the model.
    instances: list["ModelInstance"] = Relationship(
        sa_relationship_kwargs={"cascade": "delete", "lazy": "noload"},
        back_populates="model",
    )
    # Cluster this model belongs to (see cluster_id on ModelBase).
    cluster: "Cluster" = Relationship(
        back_populates="cluster_models",
        sa_relationship_kwargs={"lazy": "noload"},
    )

    # Link rows between this model and model routes; deleted with the model.
    model_route_targets: List["ModelRouteTarget"] = Relationship(
        back_populates="model",
        sa_relationship_kwargs={
            "lazy": "noload",
            "overlaps": "models",
            "cascade": "delete",
        },
    )
    # Many-to-many view of routes through ModelRouteTarget. "overlaps" names
    # the relationships that share the same link-table columns.
    model_routes: List["ModelRoute"] = Relationship(
        back_populates="models",
        link_model=ModelRouteTarget,
        sa_relationship_kwargs={
            "lazy": "noload",
            "overlaps": "model,model_route_targets,route_targets,model_route",
        },
    )
# List/query parameters for models.
# NOTE(review): `sortable_fields` is presumably read by ListParams to restrict
# which columns may be used for sorting -- confirm in gpustack.schemas.common.
class ModelListParams(ListParams):
    sortable_fields: ClassVar[List[str]] = [
        "name",
        "source",
        "cluster_id",
        "replicas",
        "ready_replicas",
        "created_at",
        "updated_at",
    ]
# Payload for creating a model: all ModelBase fields plus creation-only options.
class ModelCreate(ModelBase):
    # Optional tri-state flag; None means the caller did not specify it.
    enable_model_route: Optional[bool] = Field(default=None)
    overwrite_deleted: bool = Field(
        default=False,
        description="When true, overwrite soft-deleted model with the same name"
    )
# Payload for updating a model; identical in shape to ModelBase.
class ModelUpdate(ModelBase):
    pass
# Model representation returned to clients: ModelBase fields plus the id and
# the created/updated timestamps, all required.
class ModelPublic(
    ModelBase,
):
    id: int
    created_at: datetime
    updated_at: datetime


# Paginated list of ModelPublic items.
ModelsPublic = PaginatedList[ModelPublic]
  263. # Model Instances
# NOTE(review): the column alignment of the transition diagram below appears
# to have been lost; the text is kept as-is. The main path it shows is
# PENDING -> ANALYZING -> SCHEDULED -> INITIALIZING -> DOWNLOADING -> STARTING -> RUNNING.
class ModelInstanceStateEnum(str, Enum):
    r"""
    Enum for Model Instance State
    Transitions:
    |- - - - - Scheduler - - - - |- - ServeManager - -|- - - - Controller - - - -|- ServeManager -|
    | | | | |
    PENDING ---> ANALYZING ---> SCHEDULED ---> INITIALIZING ---> DOWNLOADING ---> STARTING ---> RUNNING
    | ^ | | | | ^
    | | | | | | |(Worker ready)
    |------------|--|---------------|----------------|---------------|----------|
    \____________|_____________________________________________________________/|
    | ERROR |(Worker unreachable)
    └--------------------┘ v
    (Restart on Error) UNREACHABLE
    """
    INITIALIZING = "initializing"
    PENDING = "pending"
    STARTING = "starting"
    RUNNING = "running"
    SCHEDULED = "scheduled"
    ERROR = "error"
    DOWNLOADING = "downloading"
    ANALYZING = "analyzing"
    UNREACHABLE = "unreachable"

    # Render a member as its plain string value (e.g. "running").
    def __str__(self):
        return self.value
# Resource figures computed for a model instance (or a subordinate worker);
# stored in the `computed_resource_claim` columns.
class ComputedResourceClaim(BaseModel):
    is_unified_memory: Optional[bool] = False
    offload_layers: Optional[int] = None
    total_layers: Optional[int] = None
    ram: Optional[int] = Field(default=None)  # in bytes
    # NOTE(review): keys are presumably GPU indexes -- confirm against the scheduler.
    vram: Optional[Dict[int, int]] = Field(default=None)  # in bytes
    tensor_split: Optional[List[int]] = Field(default=None)
    vram_utilization: Optional[float] = Field(default=None)
    # estimated_vram is the model's actual estimated VRAM requirement in bytes,
    # independent of the target GPU's total memory. This differs from `vram`
    # which represents the allocated amount (total_gpu_memory * utilization_rate).
    estimated_vram: Optional[int] = Field(default=None)  # in bytes
# One subordinate (non-main) worker taking part in a distributed model
# instance; listed in DistributedServers.subordinate_workers.
class ModelInstanceSubordinateWorker(BaseModel):
    # - Identity and addressing of the worker
    worker_id: Optional[int] = None
    worker_name: Optional[str] = None
    worker_ip: Optional[str] = None
    worker_ifname: Optional[str] = None
    # - GPU assignment on the worker
    total_gpus: Optional[int] = None
    gpu_type: Optional[str] = None
    gpu_indexes: Optional[List[int]] = Field(sa_column=Column(JSON), default=[])
    gpu_addresses: Optional[List[str]] = Field(sa_column=Column(JSON), default=[])
    computed_resource_claim: Optional[ComputedResourceClaim] = Field(
        sa_column=Column(pydantic_column_type(ComputedResourceClaim)), default=None
    )
    # - For model file preparation
    download_progress: Optional[float] = None
    # - For model instance serving preparation
    pid: Optional[int] = None
    ports: Optional[List[int]] = Field(sa_column=Column(JSON), default=[])
    arguments: Optional[List[str]] = Field(sa_column=Column(JSON), default=[])
    # - Lifecycle state of this worker's part of the instance
    state: ModelInstanceStateEnum = ModelInstanceStateEnum.PENDING
    state_message: Optional[str] = Field(
        default=None, sa_column=Column(Text, nullable=True)
    )
# How subordinate workers are coordinated with the main worker. This is a
# plain Enum (not a str mixin), unlike the other enums in this module.
class DistributedServerCoordinateModeEnum(Enum):
    # DELEGATED means that coordination of the subordinate workers is handed
    # off to another framework.
    DELEGATED = "delegated"
    # INITIALIZE_LATER means that coordination of the subordinate workers is handled by GPUStack:
    # all subordinate workers belonging to one model instance SHOULD start after the main worker initializes.
    # For example, Ascend MindIE/vLLM/SGLang instances need to start their subordinate workers after the main worker initializes.
    INITIALIZE_LATER = "initialize_later"
    # RUN_FIRST means that coordination of the subordinate workers is handled by GPUStack:
    # all subordinate workers belonging to one model instance MUST be ready before the main worker starts.
    RUN_FIRST = "run_first"
# Distributed-serving layout of a model instance; stored on
# ModelInstanceBase.distributed_servers.
class DistributedServers(BaseModel):
    # Indicates how the distributed servers coordinate with the main worker.
    mode: DistributedServerCoordinateModeEnum = (
        DistributedServerCoordinateModeEnum.DELEGATED
    )
    # Indicates if subordinate workers should download model files.
    download_model_files: Optional[bool] = True
    # Subordinate workers of the instance; a non-empty list is what
    # get_deployment_metadata treats as "distributed".
    subordinate_workers: Optional[List[ModelInstanceSubordinateWorker]] = Field(
        sa_column=Column(JSON), default=[]
    )
    # Allow building this model from attribute-bearing objects.
    model_config = ConfigDict(from_attributes=True)
  345. @dataclass
  346. class ModelInstanceDeploymentMetadata:
  347. """
  348. Metadata for model instance deployment.
  349. """
  350. name: str
  351. """
  352. Name for model instance deployment.
  353. """
  354. distributed: bool = False
  355. """
  356. Whether the model instance is deployed in distributed mode.
  357. """
  358. distributed_leader: bool = False
  359. """
  360. Whether the model instance is the leader in distributed mode.
  361. """
  362. distributed_follower: bool = False
  363. """
  364. Whether the model instance is a follower in distributed mode.
  365. """
  366. distributed_follower_index: Optional[int] = None
  367. """
  368. Index of the follower in distributed mode.
  369. It is None for leader or non-distributed mode.
  370. """
  371. class ModelInstanceBase(SQLModel, ModelSource):
  372. name: str = Field(index=True, unique=True)
  373. worker_id: Optional[int] = None
  374. worker_name: Optional[str] = None
  375. worker_advertise_address: Optional[str] = None
  376. worker_ip: Optional[str] = None
  377. worker_ifname: Optional[str] = None
  378. pid: Optional[int] = None
  379. # FIXME: Migrate to ports.
  380. port: Optional[int] = None
  381. ports: Optional[List[int]] = Field(sa_column=Column(JSON), default=[])
  382. download_progress: Optional[float] = None
  383. resolved_path: Optional[str] = None
  384. draft_model_source: Optional[ModelSource] = Field(
  385. sa_column=Column(pydantic_column_type(ModelSource)), default=None
  386. )
  387. draft_model_download_progress: Optional[float] = None
  388. draft_model_resolved_path: Optional[str] = None
  389. restart_count: Optional[int] = 0
  390. last_restart_time: Optional[datetime] = Field(
  391. sa_column=Column(UTCDateTime), default=None
  392. )
  393. state: ModelInstanceStateEnum = ModelInstanceStateEnum.PENDING
  394. state_message: Optional[str] = Field(
  395. default=None, sa_column=Column(Text, nullable=True)
  396. )
  397. computed_resource_claim: Optional[ComputedResourceClaim] = Field(
  398. sa_column=Column(pydantic_column_type(ComputedResourceClaim)), default=None
  399. )
  400. gpu_type: Optional[str] = None
  401. gpu_indexes: Optional[List[int]] = Field(sa_column=Column(JSON), default=[])
  402. gpu_addresses: Optional[List[str]] = Field(sa_column=Column(JSON), default=[])
  403. model_id: int = Field(default=None, foreign_key="models.id")
  404. model_name: str
  405. backend: Optional[str] = None
  406. backend_version: Optional[str] = None
  407. api_detected_backend_version: Optional[str] = None
  408. injected_backend_parameters: Optional[List[str]] = Field(
  409. sa_column=Column(JSON), default=None
  410. )
  411. distributed_servers: Optional[DistributedServers] = Field(
  412. sa_column=Column(pydantic_column_type(DistributedServers)), default=None
  413. )
  414. # The "model_id" field conflicts with the protected namespace "model_" in Pydantic.
  415. # Disable it given that it's not a real issue for this particular field.
  416. model_config = ConfigDict(protected_namespaces=())
  417. cluster_id: Optional[int] = Field(default=None, foreign_key="clusters.id")
  418. owner_principal_id: int = Field(
  419. default=PLATFORM_PRINCIPAL_ID,
  420. sa_column=Column(
  421. Integer,
  422. ForeignKey("principals.id", ondelete="CASCADE"),
  423. nullable=False,
  424. ),
  425. )
  426. def get_deployment_metadata(
  427. self,
  428. worker_id: int,
  429. ) -> Optional[ModelInstanceDeploymentMetadata]:
  430. """
  431. Get the deployment metadata for the model instance.
  432. Args:
  433. worker_id:
  434. The ID of the worker to get the deployment metadata for.
  435. Returns:
  436. The deployment metadata,
  437. or None if the model instance is not handling by the given `worker_id` worker.
  438. """
  439. dservers = self.distributed_servers
  440. subworkers = (
  441. dservers.subordinate_workers
  442. if dservers and dservers.subordinate_workers
  443. else []
  444. )
  445. name = self.name
  446. distributed = bool(subworkers)
  447. distributed_leader = distributed and self.worker_id == worker_id
  448. distributed_follower = distributed and not distributed_leader
  449. distributed_follower_index = None
  450. if distributed_follower:
  451. for idx, subworker in enumerate(subworkers):
  452. if subworker.worker_id == worker_id:
  453. distributed_follower_index = idx
  454. break
  455. if distributed_follower_index is not None:
  456. # Mutate the name to include the follower index,
  457. # so that each follower has a unique name.
  458. name += f"-f{distributed_follower_index}"
  459. if self.worker_id != worker_id and distributed_follower_index is None:
  460. # This model instance is not handling by the given worker.
  461. return None
  462. return ModelInstanceDeploymentMetadata(
  463. name=name,
  464. distributed=distributed,
  465. distributed_leader=distributed_leader,
  466. distributed_follower=distributed_follower,
  467. distributed_follower_index=distributed_follower_index,
  468. )
  469. class ModelInstance(ModelInstanceBase, BaseModelMixin, table=True):
  470. __tablename__ = 'model_instances'
  471. id: Optional[int] = Field(default=None, primary_key=True)
  472. model: Optional[Model] = Relationship(
  473. back_populates="instances",
  474. sa_relationship_kwargs={"lazy": "noload"},
  475. )
  476. model_files: List["ModelFile"] = Relationship(
  477. back_populates="instances",
  478. link_model=ModelInstanceModelFileLink,
  479. sa_relationship_kwargs={"lazy": "noload"},
  480. )
  481. draft_model_files: List["ModelFile"] = Relationship(
  482. back_populates="draft_instances",
  483. link_model=ModelInstanceDraftModelFileLink,
  484. sa_relationship_kwargs={"lazy": "noload"},
  485. )
  486. cluster: "Cluster" = Relationship(
  487. back_populates="cluster_model_instances",
  488. sa_relationship_kwargs={"lazy": "noload"},
  489. )
  490. # overwrite the hash to use in uniquequeue
  491. def __hash__(self):
  492. return self.id
# Payload for creating a model instance; identical in shape to ModelInstanceBase.
class ModelInstanceCreate(ModelInstanceBase):
    pass
# Payload for updating a model instance; identical in shape to ModelInstanceBase.
class ModelInstanceUpdate(ModelInstanceBase):
    pass
# Model instance representation returned to clients: ModelInstanceBase fields
# plus the id and the created/updated timestamps, all required.
class ModelInstancePublic(
    ModelInstanceBase,
):
    id: int
    created_at: datetime
    updated_at: datetime


# Paginated list of ModelInstancePublic items.
ModelInstancesPublic = PaginatedList[ModelInstancePublic]
# Minimal worker reference (id and name) for the model instance log schemas.
# NOTE(review): not referenced in the visible part of this module.
class ModelInstanceLogWorker(BaseModel):
    id: int
    name: str
class ModelInstanceLogRestartEntry(BaseModel):
    """One main serve log session on disk, with optional UX label time."""
    # False for the current session, True for an earlier one (see the legacy
    # conversion in ServeLogOptionsResponse).
    previous: bool = False
    started_at: Optional[datetime] = Field(
        default=None,
        description=(
            "Approximate start time from the main log file metadata "
            "(birthtime if available, else mtime), UTC."
        ),
    )
    containers: List[str] = Field(
        default_factory=list,
        description=(
            "Available container names for this restart. "
            "'default' is the main workload container; others are sidecars "
            "(e.g., ['default', 'ray-head'])."
        ),
    )
class ModelInstanceLogWorkerOption(BaseModel):
    """Per-worker result for GET /model-instances/{id}/log-options (one node on disk)."""
    worker_id: int
    name: str = ""
    # Log sessions found for this worker; empty by default.
    restarts: List[ModelInstanceLogRestartEntry] = Field(default_factory=list)
    error: Optional[str] = Field(
        default=None,
        description="If set, log options could not be fetched from this worker.",
    )
  534. class ServeLogOptionsResponse(BaseModel):
  535. """Worker GET /serveLogOptions JSON; also validates that payload when the server proxies."""
  536. restarts: List[ModelInstanceLogRestartEntry] = Field(default_factory=list)
  537. @model_validator(mode="before")
  538. @classmethod
  539. def _legacy_restart_counts(cls, data: Any) -> Any:
  540. """Old workers only sent restart_counts; expand to restarts when `restarts` is absent."""
  541. if not isinstance(data, dict):
  542. return data
  543. if "restarts" in data:
  544. return data
  545. raw = data.get("restart_counts")
  546. if not isinstance(raw, list):
  547. return {**data, "restarts": []}
  548. counts: List[int] = []
  549. for x in raw:
  550. try:
  551. counts.append(int(x))
  552. except (TypeError, ValueError):
  553. continue
  554. counts.sort(reverse=True)
  555. # Map the highest restart_count to previous=False (current),
  556. # the second highest to previous=True.
  557. entries = []
  558. for i, c in enumerate(counts):
  559. entries.append({"previous": i > 0, "started_at": None})
  560. return {**data, "restarts": entries}
class ModelInstanceLogOptions(BaseModel):
    """Server GET /model-instances/{id}/log-options: per-worker serve log distribution."""
    main_worker_id: Optional[int] = Field(
        default=None,
        description="same as model instance worker_id.",
    )
    # One entry per worker; ordering is described in the field description.
    workers: List[ModelInstanceLogWorkerOption] = Field(
        default_factory=list,
        description=(
            "Ordered list: main worker first, then subordinate workers. "
            "Each entry reflects that worker's local serve logs."
        ),
    )
  574. def is_gguf_model(model: Union[Model, ModelSource]):
  575. """
  576. Check if the model is a GGUF model.
  577. Args:
  578. model: Model to check.
  579. """
  580. return (
  581. (
  582. model.source == SourceEnum.HUGGING_FACE
  583. and model.huggingface_filename
  584. and model.huggingface_filename.endswith(".gguf")
  585. )
  586. or (
  587. model.source == SourceEnum.MODEL_SCOPE
  588. and model.model_scope_file_path
  589. and model.model_scope_file_path.endswith(".gguf")
  590. )
  591. or (
  592. model.source == SourceEnum.LOCAL_PATH
  593. and model.local_path
  594. and model.local_path.endswith(".gguf")
  595. )
  596. )
  597. def is_audio_model(model: Model):
  598. """
  599. Check if the model is a STT or TTS model.
  600. Args:
  601. model: Model to check.
  602. """
  603. if model.backend == BackendEnum.VOX_BOX:
  604. return True
  605. if model.categories:
  606. return (
  607. 'speech_to_text' in model.categories or 'text_to_speech' in model.categories
  608. )
  609. return False
  610. def is_llm_model(model: Model):
  611. """
  612. Check if the model is an LLM model.
  613. Args:
  614. model: Model to check.
  615. """
  616. return not model.categories or CategoryEnum.LLM in model.categories
  617. def is_omni_model(model: Model) -> bool:
  618. """
  619. Check if the model is an omni model (Image or Audio category).
  620. Args:
  621. model: Model to check.
  622. """
  623. if model.backend == BackendEnum.VLLM and find_bool_parameter(
  624. model.backend_parameters, ["omni"]
  625. ):
  626. return True
  627. OMNI_CATEGORIES = (
  628. CategoryEnum.IMAGE,
  629. CategoryEnum.TEXT_TO_SPEECH,
  630. )
  631. return any(cat in model.categories for cat in OMNI_CATEGORIES)
  632. def is_image_model(model: Model):
  633. """
  634. Check if the model is an image model.
  635. Args:
  636. model: Model to check.
  637. """
  638. return "image" in model.categories
  639. def is_embedding_model(model: Model):
  640. """
  641. Check if the model is an embedding model.
  642. Args:
  643. model: Model to check.
  644. """
  645. return "embedding" in model.categories
  646. def is_reranker_model(model: Model):
  647. """
  648. Check if the model is a reranker model.
  649. Args:
  650. model: Model to check.
  651. """
  652. return "reranker" in model.categories
  653. def get_backend(model: Model) -> str:
  654. if model.backend:
  655. return model.backend
  656. if is_gguf_model(model):
  657. return BackendEnum.CUSTOM
  658. return BackendEnum.VLLM
  659. def get_mmproj_filename(model: Union[Model, ModelSource]) -> Optional[str]:
  660. """
  661. Get the mmproj filename for the model. If the mmproj is not provided in the model's
  662. backend parameters, it will try to find the default mmproj file.
  663. """
  664. if not is_gguf_model(model):
  665. return None
  666. if hasattr(model, "backend_parameters"):
  667. mmproj = find_parameter(model.backend_parameters, ["mmproj"])
  668. if mmproj and Path(mmproj).name == mmproj:
  669. return mmproj
  670. return "*mmproj*.gguf"