models.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814
  1. from dataclasses import dataclass
  2. from datetime import datetime
  3. from enum import Enum
  4. import hashlib
  5. from pathlib import Path
  6. from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Union
  7. from pydantic import BaseModel, ConfigDict, model_validator
  8. from sqlalchemy import JSON, Column, ForeignKey, Integer
  9. from sqlmodel import Field, Relationship, SQLModel, Text
  10. from gpustack.schemas.common import (
  11. ListParams,
  12. PaginatedList,
  13. UTCDateTime,
  14. pydantic_column_type,
  15. )
  16. from gpustack.mixins import BaseModelMixin
  17. from gpustack.schemas.links import (
  18. ModelInstanceDraftModelFileLink,
  19. ModelInstanceModelFileLink,
  20. )
  21. from gpustack.utils.command import find_parameter, find_bool_parameter
  22. from gpustack.schemas.model_routes import (
  23. ModelRoute,
  24. ModelRouteTarget,
  25. AccessPolicyEnum,
  26. )
  27. from gpustack.schemas.principals import PLATFORM_PRINCIPAL_ID
  28. if TYPE_CHECKING:
  29. from gpustack.schemas.model_files import ModelFile
  30. from gpustack.schemas.clusters import Cluster
  31. # Models
  32. class SourceEnum(str, Enum):
  33. HUGGING_FACE = "huggingface"
  34. MODEL_SCOPE = "model_scope"
  35. LOCAL_PATH = "local_path"
  36. class CategoryEnum(str, Enum):
  37. LLM = "llm"
  38. EMBEDDING = "embedding"
  39. IMAGE = "image"
  40. RERANKER = "reranker"
  41. SPEECH_TO_TEXT = "speech_to_text"
  42. TEXT_TO_SPEECH = "text_to_speech"
  43. UNKNOWN = "unknown"
  44. class PlacementStrategyEnum(str, Enum):
  45. SPREAD = "spread"
  46. BINPACK = "binpack"
  47. class BackendEnum(str, Enum):
  48. VLLM = "vLLM"
  49. VOX_BOX = "VoxBox"
  50. ASCEND_MINDIE = "MindIE"
  51. SGLANG = "SGLang"
  52. CUSTOM = "Custom"
  53. class BackendSourceEnum(str, Enum):
  54. CUSTOM = "custom"
  55. BUILT_IN = "built_in"
  56. COMMUNITY = "community"
  57. class SpeculativeAlgorithmEnum(str, Enum):
  58. EAGLE3 = "eagle3"
  59. MTP = "mtp"
  60. NGRAM = "ngram"
  61. class GPUSelector(BaseModel):
  62. # format of each element: "worker_name:device:gpu_index", example: "worker1:cuda:0"
  63. gpu_ids: Optional[List[str]] = None
  64. gpus_per_replica: Optional[int] = None
  65. class ExtendedKVCacheConfig(BaseModel):
  66. enabled: bool = False
  67. """ Enable extended KV cache for the model."""
  68. ram_ratio: Optional[float] = 1.2
  69. """ RAM-to-VRAM ratio for KV cache. For example, 2.0 means the RAM is twice the size of the VRAM. """
  70. ram_size: Optional[int] = None
  71. """ Maximum size of the KV cache to be stored in local CPU memory (unit: GiB). Overrides ram_ratio if both are set. """
  72. chunk_size: Optional[int] = None
  73. """ Size for each KV cache chunk (unit: number of tokens). """
  74. class ModelSource(BaseModel):
  75. source: SourceEnum
  76. huggingface_repo_id: Optional[str] = None
  77. huggingface_filename: Optional[str] = None
  78. model_scope_model_id: Optional[str] = None
  79. model_scope_file_path: Optional[str] = None
  80. local_path: Optional[str] = None
  81. @property
  82. def model_source_key(self) -> str:
  83. """Returns a unique identifier for the model, independent of quantization."""
  84. if self.source == SourceEnum.HUGGING_FACE:
  85. return self.huggingface_repo_id or ""
  86. elif self.source == SourceEnum.MODEL_SCOPE:
  87. return self.model_scope_model_id or ""
  88. elif self.source == SourceEnum.LOCAL_PATH:
  89. return self.local_path or ""
  90. return ""
  91. @property
  92. def readable_source(self) -> str:
  93. values = []
  94. if self.source == SourceEnum.HUGGING_FACE:
  95. values.extend([self.huggingface_repo_id, self.huggingface_filename])
  96. elif self.source == SourceEnum.MODEL_SCOPE:
  97. values.extend([self.model_scope_model_id, self.model_scope_file_path])
  98. elif self.source == SourceEnum.LOCAL_PATH:
  99. values.extend([self.local_path])
  100. return "/".join([value for value in values if value is not None])
  101. @property
  102. def model_source_index(self) -> str:
  103. values = []
  104. if self.source == SourceEnum.HUGGING_FACE:
  105. values.extend([self.huggingface_repo_id, self.huggingface_filename])
  106. elif self.source == SourceEnum.MODEL_SCOPE:
  107. values.extend(
  108. [self.source, self.model_scope_model_id, self.model_scope_file_path]
  109. )
  110. elif self.source == SourceEnum.LOCAL_PATH:
  111. values.extend([self.local_path])
  112. # Filter out None values and join
  113. filtered_values = [v for v in values if v is not None]
  114. source_string = "/".join(filtered_values)
  115. return hashlib.sha256(source_string.encode()).hexdigest()
  116. @model_validator(mode="after")
  117. def check_huggingface_fields(self):
  118. if self.source == SourceEnum.HUGGING_FACE:
  119. if not self.huggingface_repo_id:
  120. raise ValueError(
  121. "huggingface_repo_id must be provided "
  122. "when source is 'huggingface'"
  123. )
  124. if self.source == SourceEnum.MODEL_SCOPE:
  125. if not self.model_scope_model_id:
  126. raise ValueError(
  127. "model_scope_model_id must be provided when source is 'model_scope'"
  128. )
  129. if self.source == SourceEnum.LOCAL_PATH:
  130. if not self.local_path:
  131. raise ValueError(
  132. "local_path must be provided when source is 'local_path'"
  133. )
  134. return self
  135. model_config = ConfigDict(protected_namespaces=())
  136. class SpeculativeConfig(BaseModel):
  137. """Configuration for speculative decoding."""
  138. enabled: bool = False
  139. """Whether speculative decoding is enabled."""
  140. algorithm: Optional[SpeculativeAlgorithmEnum] = None
  141. """The algorithm to use for speculative decoding."""
  142. draft_model: Optional[str] = None
  143. """The draft model to use for speculative decoding.
  144. It can be a draft model name from the model catalog, a local path or a model ID from the main model source."""
  145. num_draft_tokens: Optional[int] = None
  146. """The number of draft tokens."""
  147. # For ngram only
  148. ngram_min_match_length: Optional[int] = None
  149. """Minimum length of the n-gram to match."""
  150. ngram_max_match_length: Optional[int] = None
  151. """Maximum length of the n-gram to match."""
  152. class ModelSpecBase(SQLModel, ModelSource):
  153. name: str = Field(index=True, unique=True)
  154. description: Optional[str] = Field(
  155. sa_type=Text,
  156. nullable=True,
  157. default=None,
  158. )
  159. meta: Optional[Dict[str, Any]] = Field(sa_type=JSON, default={})
  160. replicas: int = Field(default=1, ge=0)
  161. ready_replicas: int = Field(default=0, ge=0)
  162. categories: List[str] = Field(sa_type=JSON, default=[])
  163. placement_strategy: PlacementStrategyEnum = PlacementStrategyEnum.SPREAD
  164. cpu_offloading: Optional[bool] = None
  165. distributed_inference_across_workers: Optional[bool] = None
  166. worker_selector: Optional[Dict[str, str]] = Field(sa_type=JSON, default={})
  167. gpu_selector: Optional[GPUSelector] = Field(
  168. sa_type=pydantic_column_type(GPUSelector), default=None
  169. )
  170. backend: Optional[str] = None
  171. backend_version: Optional[str] = None
  172. backend_parameters: Optional[List[str]] = Field(sa_type=JSON, default=None)
  173. image_name: Optional[str] = None
  174. run_command: Optional[str] = Field(sa_type=Text, default=None)
  175. env: Optional[Dict[str, str]] = Field(sa_type=JSON, default=None)
  176. restart_on_error: Optional[bool] = True
  177. distributable: Optional[bool] = False
  178. # Extended KV Cache configuration. Currently maps to LMCache config in vLLM and SGLang.
  179. extended_kv_cache: Optional[ExtendedKVCacheConfig] = Field(
  180. sa_type=pydantic_column_type(ExtendedKVCacheConfig), default=None
  181. )
  182. speculative_config: Optional[SpeculativeConfig] = Field(
  183. sa_type=pydantic_column_type(SpeculativeConfig), default=None
  184. )
  185. # Enable generic proxy for model, the control of generic proxy
  186. # is migrated to ModelAccess. Keeping this field for backward compatibility
  187. generic_proxy: Optional[bool] = Field(default=False)
  188. @model_validator(mode="after")
  189. def set_defaults(self):
  190. backend = get_backend(self)
  191. if self.distributed_inference_across_workers is None:
  192. self.distributed_inference_across_workers = (
  193. True
  194. if backend
  195. in [BackendEnum.VLLM, BackendEnum.ASCEND_MINDIE, BackendEnum.SGLANG]
  196. else False
  197. )
  198. return self
  199. class ModelBase(ModelSpecBase):
  200. cluster_id: Optional[int] = Field(default=None, foreign_key="clusters.id")
  201. owner_principal_id: int = Field(
  202. default=PLATFORM_PRINCIPAL_ID,
  203. sa_column=Column(
  204. Integer,
  205. ForeignKey("principals.id", ondelete="CASCADE"),
  206. nullable=False,
  207. ),
  208. )
  209. # Deprecated field, kept for backward compatibility
  210. access_policy: AccessPolicyEnum = Field(default=AccessPolicyEnum.AUTHED)
  211. class Model(ModelBase, BaseModelMixin, table=True):
  212. __tablename__ = 'models'
  213. id: Optional[int] = Field(default=None, primary_key=True)
  214. instances: list["ModelInstance"] = Relationship(
  215. sa_relationship_kwargs={"cascade": "delete", "lazy": "noload"},
  216. back_populates="model",
  217. )
  218. cluster: "Cluster" = Relationship(
  219. back_populates="cluster_models",
  220. sa_relationship_kwargs={"lazy": "noload"},
  221. )
  222. model_route_targets: List["ModelRouteTarget"] = Relationship(
  223. back_populates="model",
  224. sa_relationship_kwargs={
  225. "lazy": "noload",
  226. "overlaps": "models",
  227. "cascade": "delete",
  228. },
  229. )
  230. model_routes: List["ModelRoute"] = Relationship(
  231. back_populates="models",
  232. link_model=ModelRouteTarget,
  233. sa_relationship_kwargs={
  234. "lazy": "noload",
  235. "overlaps": "model,model_route_targets,route_targets,model_route",
  236. },
  237. )
  238. class ModelListParams(ListParams):
  239. sortable_fields: ClassVar[List[str]] = [
  240. "name",
  241. "source",
  242. "cluster_id",
  243. "replicas",
  244. "ready_replicas",
  245. "created_at",
  246. "updated_at",
  247. ]
  248. class ModelCreate(ModelBase):
  249. enable_model_route: Optional[bool] = Field(default=None)
  250. class ModelUpdate(ModelBase):
  251. pass
  252. class ModelPublic(
  253. ModelBase,
  254. ):
  255. id: int
  256. created_at: datetime
  257. updated_at: datetime
  258. ModelsPublic = PaginatedList[ModelPublic]
  259. # Model Instances
  260. class ModelInstanceStateEnum(str, Enum):
  261. r"""
  262. Enum for Model Instance State
  263. Transitions:
  264. |- - - - - Scheduler - - - - |- - ServeManager - -|- - - - Controller - - - -|- ServeManager -|
  265. | | | | |
  266. PENDING ---> ANALYZING ---> SCHEDULED ---> INITIALIZING ---> DOWNLOADING ---> STARTING ---> RUNNING
  267. | ^ | | | | ^
  268. | | | | | | |(Worker ready)
  269. |------------|--|---------------|----------------|---------------|----------|
  270. \____________|_____________________________________________________________/|
  271. | ERROR |(Worker unreachable)
  272. └--------------------┘ v
  273. (Restart on Error) UNREACHABLE
  274. """
  275. INITIALIZING = "initializing"
  276. PENDING = "pending"
  277. STARTING = "starting"
  278. RUNNING = "running"
  279. SCHEDULED = "scheduled"
  280. ERROR = "error"
  281. DOWNLOADING = "downloading"
  282. ANALYZING = "analyzing"
  283. UNREACHABLE = "unreachable"
  284. def __str__(self):
  285. return self.value
  286. class ComputedResourceClaim(BaseModel):
  287. is_unified_memory: Optional[bool] = False
  288. offload_layers: Optional[int] = None
  289. total_layers: Optional[int] = None
  290. ram: Optional[int] = Field(default=None) # in bytes
  291. vram: Optional[Dict[int, int]] = Field(default=None) # in bytes
  292. tensor_split: Optional[List[int]] = Field(default=None)
  293. vram_utilization: Optional[float] = Field(default=None)
  294. class ModelInstanceSubordinateWorker(BaseModel):
  295. worker_id: Optional[int] = None
  296. worker_name: Optional[str] = None
  297. worker_ip: Optional[str] = None
  298. worker_ifname: Optional[str] = None
  299. total_gpus: Optional[int] = None
  300. gpu_type: Optional[str] = None
  301. gpu_indexes: Optional[List[int]] = Field(sa_column=Column(JSON), default=[])
  302. gpu_addresses: Optional[List[str]] = Field(sa_column=Column(JSON), default=[])
  303. computed_resource_claim: Optional[ComputedResourceClaim] = Field(
  304. sa_column=Column(pydantic_column_type(ComputedResourceClaim)), default=None
  305. )
  306. # - For model file preparation
  307. download_progress: Optional[float] = None
  308. # - For model instance serving preparation
  309. pid: Optional[int] = None
  310. ports: Optional[List[int]] = Field(sa_column=Column(JSON), default=[])
  311. arguments: Optional[List[str]] = Field(sa_column=Column(JSON), default=[])
  312. state: ModelInstanceStateEnum = ModelInstanceStateEnum.PENDING
  313. state_message: Optional[str] = Field(
  314. default=None, sa_column=Column(Text, nullable=True)
  315. )
  316. class DistributedServerCoordinateModeEnum(Enum):
  317. # DELEGATED means that the subordinate workers' coordinate is by-pass to other framework.
  318. DELEGATED = "delegated"
  319. # INITIALIZE_LATER means that the subordinate workers' coordinate is handled by GPUStack,
  320. # all subordinate workers belong to one model instance SHOULD start after the main worker initializes.
  321. # For example, Ascend MindIE/vLLM/SGLang instances need to start their subordinate workers after the main worker initializes.
  322. INITIALIZE_LATER = "initialize_later"
  323. # RUN_FIRST means that the subordinate workers' coordinate is handled by GPUStack,
  324. # all subordinate workers belong to one model instance MUST get ready before the main worker starts.
  325. RUN_FIRST = "run_first"
  326. class DistributedServers(BaseModel):
  327. # Indicates how the distributed servers coordinate with the main worker.
  328. mode: DistributedServerCoordinateModeEnum = (
  329. DistributedServerCoordinateModeEnum.DELEGATED
  330. )
  331. # Indicates if subordinate workers should download model files.
  332. download_model_files: Optional[bool] = True
  333. subordinate_workers: Optional[List[ModelInstanceSubordinateWorker]] = Field(
  334. sa_column=Column(JSON), default=[]
  335. )
  336. model_config = ConfigDict(from_attributes=True)
  337. @dataclass
  338. class ModelInstanceDeploymentMetadata:
  339. """
  340. Metadata for model instance deployment.
  341. """
  342. name: str
  343. """
  344. Name for model instance deployment.
  345. """
  346. distributed: bool = False
  347. """
  348. Whether the model instance is deployed in distributed mode.
  349. """
  350. distributed_leader: bool = False
  351. """
  352. Whether the model instance is the leader in distributed mode.
  353. """
  354. distributed_follower: bool = False
  355. """
  356. Whether the model instance is a follower in distributed mode.
  357. """
  358. distributed_follower_index: Optional[int] = None
  359. """
  360. Index of the follower in distributed mode.
  361. It is None for leader or non-distributed mode.
  362. """
  363. class ModelInstanceBase(SQLModel, ModelSource):
  364. name: str = Field(index=True, unique=True)
  365. worker_id: Optional[int] = None
  366. worker_name: Optional[str] = None
  367. worker_advertise_address: Optional[str] = None
  368. worker_ip: Optional[str] = None
  369. worker_ifname: Optional[str] = None
  370. pid: Optional[int] = None
  371. # FIXME: Migrate to ports.
  372. port: Optional[int] = None
  373. ports: Optional[List[int]] = Field(sa_column=Column(JSON), default=[])
  374. download_progress: Optional[float] = None
  375. resolved_path: Optional[str] = None
  376. draft_model_source: Optional[ModelSource] = Field(
  377. sa_column=Column(pydantic_column_type(ModelSource)), default=None
  378. )
  379. draft_model_download_progress: Optional[float] = None
  380. draft_model_resolved_path: Optional[str] = None
  381. restart_count: Optional[int] = 0
  382. last_restart_time: Optional[datetime] = Field(
  383. sa_column=Column(UTCDateTime), default=None
  384. )
  385. state: ModelInstanceStateEnum = ModelInstanceStateEnum.PENDING
  386. state_message: Optional[str] = Field(
  387. default=None, sa_column=Column(Text, nullable=True)
  388. )
  389. computed_resource_claim: Optional[ComputedResourceClaim] = Field(
  390. sa_column=Column(pydantic_column_type(ComputedResourceClaim)), default=None
  391. )
  392. gpu_type: Optional[str] = None
  393. gpu_indexes: Optional[List[int]] = Field(sa_column=Column(JSON), default=[])
  394. gpu_addresses: Optional[List[str]] = Field(sa_column=Column(JSON), default=[])
  395. model_id: int = Field(default=None, foreign_key="models.id")
  396. model_name: str
  397. backend: Optional[str] = None
  398. backend_version: Optional[str] = None
  399. api_detected_backend_version: Optional[str] = None
  400. injected_backend_parameters: Optional[List[str]] = Field(
  401. sa_column=Column(JSON), default=None
  402. )
  403. distributed_servers: Optional[DistributedServers] = Field(
  404. sa_column=Column(pydantic_column_type(DistributedServers)), default=None
  405. )
  406. # The "model_id" field conflicts with the protected namespace "model_" in Pydantic.
  407. # Disable it given that it's not a real issue for this particular field.
  408. model_config = ConfigDict(protected_namespaces=())
  409. cluster_id: Optional[int] = Field(default=None, foreign_key="clusters.id")
  410. owner_principal_id: int = Field(
  411. default=PLATFORM_PRINCIPAL_ID,
  412. sa_column=Column(
  413. Integer,
  414. ForeignKey("principals.id", ondelete="CASCADE"),
  415. nullable=False,
  416. ),
  417. )
  418. def get_deployment_metadata(
  419. self,
  420. worker_id: int,
  421. ) -> Optional[ModelInstanceDeploymentMetadata]:
  422. """
  423. Get the deployment metadata for the model instance.
  424. Args:
  425. worker_id:
  426. The ID of the worker to get the deployment metadata for.
  427. Returns:
  428. The deployment metadata,
  429. or None if the model instance is not handling by the given `worker_id` worker.
  430. """
  431. dservers = self.distributed_servers
  432. subworkers = (
  433. dservers.subordinate_workers
  434. if dservers and dservers.subordinate_workers
  435. else []
  436. )
  437. name = self.name
  438. distributed = bool(subworkers)
  439. distributed_leader = distributed and self.worker_id == worker_id
  440. distributed_follower = distributed and not distributed_leader
  441. distributed_follower_index = None
  442. if distributed_follower:
  443. for idx, subworker in enumerate(subworkers):
  444. if subworker.worker_id == worker_id:
  445. distributed_follower_index = idx
  446. break
  447. if distributed_follower_index is not None:
  448. # Mutate the name to include the follower index,
  449. # so that each follower has a unique name.
  450. name += f"-f{distributed_follower_index}"
  451. if self.worker_id != worker_id and distributed_follower_index is None:
  452. # This model instance is not handling by the given worker.
  453. return None
  454. return ModelInstanceDeploymentMetadata(
  455. name=name,
  456. distributed=distributed,
  457. distributed_leader=distributed_leader,
  458. distributed_follower=distributed_follower,
  459. distributed_follower_index=distributed_follower_index,
  460. )
  461. class ModelInstance(ModelInstanceBase, BaseModelMixin, table=True):
  462. __tablename__ = 'model_instances'
  463. id: Optional[int] = Field(default=None, primary_key=True)
  464. model: Optional[Model] = Relationship(
  465. back_populates="instances",
  466. sa_relationship_kwargs={"lazy": "noload"},
  467. )
  468. model_files: List["ModelFile"] = Relationship(
  469. back_populates="instances",
  470. link_model=ModelInstanceModelFileLink,
  471. sa_relationship_kwargs={"lazy": "noload"},
  472. )
  473. draft_model_files: List["ModelFile"] = Relationship(
  474. back_populates="draft_instances",
  475. link_model=ModelInstanceDraftModelFileLink,
  476. sa_relationship_kwargs={"lazy": "noload"},
  477. )
  478. cluster: "Cluster" = Relationship(
  479. back_populates="cluster_model_instances",
  480. sa_relationship_kwargs={"lazy": "noload"},
  481. )
  482. # overwrite the hash to use in uniquequeue
  483. def __hash__(self):
  484. return self.id
  485. class ModelInstanceCreate(ModelInstanceBase):
  486. pass
  487. class ModelInstanceUpdate(ModelInstanceBase):
  488. pass
  489. class ModelInstancePublic(
  490. ModelInstanceBase,
  491. ):
  492. id: int
  493. created_at: datetime
  494. updated_at: datetime
  495. ModelInstancesPublic = PaginatedList[ModelInstancePublic]
  496. class ModelInstanceLogWorker(BaseModel):
  497. id: int
  498. name: str
  499. class ModelInstanceLogRestartEntry(BaseModel):
  500. """One main serve log session on disk, with optional UX label time."""
  501. previous: bool = False
  502. started_at: Optional[datetime] = Field(
  503. default=None,
  504. description=(
  505. "Approximate start time from the main log file metadata "
  506. "(birthtime if available, else mtime), UTC."
  507. ),
  508. )
  509. containers: List[str] = Field(
  510. default_factory=list,
  511. description=(
  512. "Available container names for this restart. "
  513. "'default' is the main workload container; others are sidecars "
  514. "(e.g., ['default', 'ray-head'])."
  515. ),
  516. )
  517. class ModelInstanceLogWorkerOption(BaseModel):
  518. """Per-worker result for GET /model-instances/{id}/log-options (one node on disk)."""
  519. worker_id: int
  520. name: str = ""
  521. restarts: List[ModelInstanceLogRestartEntry] = Field(default_factory=list)
  522. error: Optional[str] = Field(
  523. default=None,
  524. description="If set, log options could not be fetched from this worker.",
  525. )
  526. class ServeLogOptionsResponse(BaseModel):
  527. """Worker GET /serveLogOptions JSON; also validates that payload when the server proxies."""
  528. restarts: List[ModelInstanceLogRestartEntry] = Field(default_factory=list)
  529. @model_validator(mode="before")
  530. @classmethod
  531. def _legacy_restart_counts(cls, data: Any) -> Any:
  532. """Old workers only sent restart_counts; expand to restarts when `restarts` is absent."""
  533. if not isinstance(data, dict):
  534. return data
  535. if "restarts" in data:
  536. return data
  537. raw = data.get("restart_counts")
  538. if not isinstance(raw, list):
  539. return {**data, "restarts": []}
  540. counts: List[int] = []
  541. for x in raw:
  542. try:
  543. counts.append(int(x))
  544. except (TypeError, ValueError):
  545. continue
  546. counts.sort(reverse=True)
  547. # Map the highest restart_count to previous=False (current),
  548. # the second highest to previous=True.
  549. entries = []
  550. for i, c in enumerate(counts):
  551. entries.append({"previous": i > 0, "started_at": None})
  552. return {**data, "restarts": entries}
  553. class ModelInstanceLogOptions(BaseModel):
  554. """Server GET /model-instances/{id}/log-options: per-worker serve log distribution."""
  555. main_worker_id: Optional[int] = Field(
  556. default=None,
  557. description="same as model instance worker_id.",
  558. )
  559. workers: List[ModelInstanceLogWorkerOption] = Field(
  560. default_factory=list,
  561. description=(
  562. "Ordered list: main worker first, then subordinate workers. "
  563. "Each entry reflects that worker's local serve logs."
  564. ),
  565. )
  566. def is_gguf_model(model: Union[Model, ModelSource]):
  567. """
  568. Check if the model is a GGUF model.
  569. Args:
  570. model: Model to check.
  571. """
  572. return (
  573. (
  574. model.source == SourceEnum.HUGGING_FACE
  575. and model.huggingface_filename
  576. and model.huggingface_filename.endswith(".gguf")
  577. )
  578. or (
  579. model.source == SourceEnum.MODEL_SCOPE
  580. and model.model_scope_file_path
  581. and model.model_scope_file_path.endswith(".gguf")
  582. )
  583. or (
  584. model.source == SourceEnum.LOCAL_PATH
  585. and model.local_path
  586. and model.local_path.endswith(".gguf")
  587. )
  588. )
  589. def is_audio_model(model: Model):
  590. """
  591. Check if the model is a STT or TTS model.
  592. Args:
  593. model: Model to check.
  594. """
  595. if model.backend == BackendEnum.VOX_BOX:
  596. return True
  597. if model.categories:
  598. return (
  599. 'speech_to_text' in model.categories or 'text_to_speech' in model.categories
  600. )
  601. return False
  602. def is_llm_model(model: Model):
  603. """
  604. Check if the model is an LLM model.
  605. Args:
  606. model: Model to check.
  607. """
  608. return not model.categories or CategoryEnum.LLM in model.categories
  609. def is_omni_model(model: Model) -> bool:
  610. """
  611. Check if the model is an omni model (Image or Audio category).
  612. Args:
  613. model: Model to check.
  614. """
  615. if model.backend == BackendEnum.VLLM and find_bool_parameter(
  616. model.backend_parameters, ["omni"]
  617. ):
  618. return True
  619. OMNI_CATEGORIES = (
  620. CategoryEnum.IMAGE,
  621. CategoryEnum.TEXT_TO_SPEECH,
  622. )
  623. return any(cat in model.categories for cat in OMNI_CATEGORIES)
  624. def is_image_model(model: Model):
  625. """
  626. Check if the model is an image model.
  627. Args:
  628. model: Model to check.
  629. """
  630. return "image" in model.categories
  631. def is_embedding_model(model: Model):
  632. """
  633. Check if the model is an embedding model.
  634. Args:
  635. model: Model to check.
  636. """
  637. return "embedding" in model.categories
  638. def is_reranker_model(model: Model):
  639. """
  640. Check if the model is a reranker model.
  641. Args:
  642. model: Model to check.
  643. """
  644. return "reranker" in model.categories
  645. def get_backend(model: Model) -> str:
  646. if model.backend:
  647. return model.backend
  648. if is_gguf_model(model):
  649. return BackendEnum.CUSTOM
  650. return BackendEnum.VLLM
  651. def get_mmproj_filename(model: Union[Model, ModelSource]) -> Optional[str]:
  652. """
  653. Get the mmproj filename for the model. If the mmproj is not provided in the model's
  654. backend parameters, it will try to find the default mmproj file.
  655. """
  656. if not is_gguf_model(model):
  657. return None
  658. if hasattr(model, "backend_parameters"):
  659. mmproj = find_parameter(model.backend_parameters, ["mmproj"])
  660. if mmproj and Path(mmproj).name == mmproj:
  661. return mmproj
  662. return "*mmproj*.gguf"