scheduler.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873
  1. import asyncio
  2. from datetime import datetime, timedelta, timezone
  3. import json
  4. import logging
  5. import os
  6. import queue
  7. from typing import List, Tuple, Optional
  8. from sqlmodel.ext.asyncio.session import AsyncSession
  9. from sqlalchemy.orm import selectinload
  10. from apscheduler.schedulers.asyncio import AsyncIOScheduler
  11. from apscheduler.triggers.interval import IntervalTrigger
  12. from gpustack.policies.scorers.placement_scorer import PlacementScorer
  13. from gpustack.policies.scorers.model_file_locality_scorer import (
  14. ModelFileLocalityScorer,
  15. )
  16. from gpustack.policies.scorers.score_chain import CandidateScoreChain
  17. from gpustack.config.config import Config, get_global_config
  18. from gpustack.policies.base import (
  19. ModelInstanceScheduleCandidate,
  20. WorkerFilterChain,
  21. )
  22. from gpustack.policies.candidate_selectors import (
  23. AscendMindIEResourceFitSelector,
  24. GGUFResourceFitSelector,
  25. SGLangResourceFitSelector,
  26. VLLMResourceFitSelector,
  27. )
  28. from gpustack.policies.candidate_selectors.custom_backend_resource_fit_selector import (
  29. CustomBackendResourceFitSelector,
  30. )
  31. from gpustack.policies.utils import ListMessageBuilder
  32. from gpustack.policies.worker_filters.backend_framework_filter import (
  33. BackendFrameworkFilter,
  34. )
  35. from gpustack.policies.worker_filters.label_matching_filter import LabelMatchingFilter
  36. from gpustack.policies.worker_filters.gpu_matching_filter import GPUMatchingFilter
  37. from gpustack.policies.worker_filters.local_path_filter import LocalPathFilter
  38. from gpustack.policies.worker_filters.cluster_filter import ClusterFilter
  39. from gpustack.scheduler.model_registry import detect_model_type
  40. from gpustack.scheduler.meta_registry import get_model_meta
  41. from gpustack.scheduler.queue import AsyncUniqueQueue
  42. from gpustack.policies.worker_filters.status_filter import StatusFilter
  43. from gpustack import envs
  44. from gpustack.schemas.inference_backend import is_built_in_backend
  45. from gpustack.schemas.workers import Worker
  46. from gpustack.schemas.models import (
  47. BackendEnum,
  48. CategoryEnum,
  49. DistributedServers,
  50. Model,
  51. ModelInstance,
  52. ModelInstanceStateEnum,
  53. get_backend,
  54. is_gguf_model,
  55. DistributedServerCoordinateModeEnum,
  56. SourceEnum,
  57. is_omni_model,
  58. )
  59. from gpustack.schemas.model_files import ModelFileStateEnum
  60. from gpustack.server.bus import EventType
  61. from gpustack.server.db import async_session
  62. from gpustack.scheduler.calculator import (
  63. GPUOffloadEnum,
  64. calculate_gguf_model_resource_claim,
  65. check_diffusers_model_index_from_workers,
  66. )
  67. from gpustack.server.services import (
  68. ModelInstanceService,
  69. ModelService,
  70. ModelFileService,
  71. )
  72. from gpustack.utils.command import find_parameter
  73. from gpustack.utils.gpu import group_gpu_ids_by_worker
  74. from gpustack.utils.hub import has_diffusers_model_index
  75. from gpustack.utils.math import largest_power_of_2_leq
  76. from gpustack.utils.model_source import get_draft_model_source
  77. from gpustack.scheduler.calculator import get_pretrained_config_with_workers
  78. from sqlalchemy.orm.attributes import flag_modified
# Module-level logger used by the Scheduler class and the helper functions in this module.
logger = logging.getLogger(__name__)
  80. class Scheduler:
  81. def __init__(self, cfg: Config, check_interval: int = 180):
  82. """
  83. Init the scheduler with queue and interval.
  84. """
  85. self._id = "model-instance-scheduler"
  86. self._config = cfg
  87. self._check_interval = check_interval
  88. self._queue = AsyncUniqueQueue()
  89. self._cache_dir = None
  90. if self._config.cache_dir is not None:
  91. self._cache_dir = os.path.join(self._config.cache_dir, "gguf-parser")
  92. os.makedirs(self._cache_dir, exist_ok=True)
  93. async def start(self):
  94. """
  95. Start the scheduler.
  96. """
  97. try:
  98. # scheduler queue.
  99. asyncio.create_task(self._schedule_cycle())
  100. # scheduler job trigger by time interval.
  101. trigger = IntervalTrigger(
  102. seconds=self._check_interval, timezone=timezone.utc
  103. )
  104. scheduler = AsyncIOScheduler(timezone=timezone.utc)
  105. scheduler.add_job(
  106. self._enqueue_pending_instances,
  107. trigger=trigger,
  108. id=self._id,
  109. max_instances=1,
  110. )
  111. scheduler.start()
  112. except Exception as e:
  113. logger.error(f"Failed to start scheduler: {e}")
  114. logger.info("Scheduler started.")
  115. # Bootstrap pending state once at startup; replaces the bus replay
  116. # of every existing instance which would flood the queue (#4794).
  117. await self._enqueue_pending_instances()
  118. # Live trigger. event_types/replay_existing keep this subscription
  119. # cheap so UPDATED/HEARTBEAT bursts don't fill the queue.
  120. async for event in ModelInstance.subscribe(
  121. source="scheduler",
  122. event_types={EventType.CREATED},
  123. replay_existing=False,
  124. ):
  125. # The bus filter only blocks events from publishers; the
  126. # subscribe() generator still yields HEARTBEAT events on its
  127. # own to keep the stream alive (active_record.py). Skip those
  128. # and any other non-CREATED events that may surface in future.
  129. if event.type != EventType.CREATED:
  130. continue
  131. # Single-instance path; the IntervalTrigger above is still the
  132. # full-scan fallback for anything missed here.
  133. await self._enqueue_event_instance(event.data)
  134. async def _enqueue_pending_instances(self):
  135. """
  136. Periodic / bootstrap full scan of pending model instances.
  137. """
  138. try:
  139. async with async_session() as session:
  140. instances = await ModelInstance.all(session)
  141. tasks = []
  142. for instance in instances:
  143. if self._should_schedule(instance):
  144. task = asyncio.create_task(self._evaluate(instance))
  145. tasks.append(task)
  146. await asyncio.gather(*tasks)
  147. except Exception as e:
  148. logger.error(f"Failed to enqueue pending model instances: {e}")
  149. async def _enqueue_event_instance(self, instance: Optional[ModelInstance]):
  150. """Event-driven single-instance path. ``_evaluate`` re-fetches from
  151. DB, so the event payload is used only for ``_should_schedule``."""
  152. if instance is None or instance.id is None:
  153. return
  154. try:
  155. if self._should_schedule(instance):
  156. await self._evaluate(instance)
  157. except Exception as e:
  158. logger.error(f"Failed to evaluate instance {instance.id} from event: {e}")
  159. async def _evaluate(self, instance: ModelInstance): # noqa: C901
  160. """
  161. Evaluate the model instance's metadata.
  162. """
  163. async with async_session() as session:
  164. try:
  165. instance = await ModelInstance.one_by_id(session, instance.id)
  166. # Re-check against the freshly-fetched row: the caller's
  167. # snapshot may be stale (event payload, last full scan, etc.)
  168. # and the user may have deleted or transitioned the instance
  169. # between dispatch and now.
  170. if instance is None or not self._should_schedule(instance):
  171. return
  172. model = await Model.one_by_id(session, instance.model_id)
  173. if model is None:
  174. raise Exception("Model not found.")
  175. if instance.state != ModelInstanceStateEnum.ANALYZING:
  176. instance.state = ModelInstanceStateEnum.ANALYZING
  177. instance.state_message = "Evaluating resource requirements"
  178. await ModelInstanceService(session).update(instance)
  179. # Get available workers for potential remote parsing
  180. workers = await Worker.all(session)
  181. sorted_workers = await prioritize_workers_with_model_files(
  182. session, model, workers
  183. )
  184. should_update_model = False
  185. try:
  186. if is_gguf_model(model):
  187. should_update_model = await evaluate_gguf_model(
  188. model, sorted_workers
  189. )
  190. if await self.check_model_distributability(
  191. session, model, instance
  192. ):
  193. return
  194. else:
  195. should_update_model = await evaluate_pretrained_config(
  196. model,
  197. workers=sorted_workers,
  198. raise_raw=True,
  199. )
  200. except Exception as e:
  201. # Even if the evaluation failed, we still want to proceed to deployment.
  202. # Cases can be:
  203. # 1. Model config is not valid, but is overridable by backend parameters.
  204. # 2. It may not be required to be transformer-compatible for certain backends.
  205. logger.error(
  206. f"Failed to evaluate model {model.name or model.readable_source}: {e}"
  207. )
  208. if should_update_model:
  209. await ModelService(session).update(model)
  210. await self._queue.put(instance)
  211. except Exception as e:
  212. try:
  213. instance.state = ModelInstanceStateEnum.ERROR
  214. instance.state_message = str(e)
  215. await ModelInstanceService(session).update(instance)
  216. except Exception as ue:
  217. logger.error(
  218. f"Failed to update model instance: {ue}. Original error: {e}"
  219. )
  220. async def check_model_distributability(
  221. self, session: AsyncSession, model: Model, instance: ModelInstance
  222. ):
  223. if (
  224. not model.distributable
  225. and model.gpu_selector
  226. and model.gpu_selector.gpu_ids
  227. ):
  228. worker_gpu_ids = group_gpu_ids_by_worker(model.gpu_selector.gpu_ids)
  229. if len(worker_gpu_ids) > 1:
  230. instance.state = ModelInstanceStateEnum.ERROR
  231. instance.state_message = (
  232. "The model is not distributable to multiple workers."
  233. )
  234. await ModelInstanceService(session).update(instance)
  235. return True
  236. return False
  237. def _should_schedule(self, instance: ModelInstance) -> bool:
  238. """
  239. Check if the model instance should be scheduled.
  240. Args:
  241. instance: ModelInstance to check.
  242. """
  243. newly_created = (instance.updated_at - instance.created_at) < timedelta(
  244. seconds=1
  245. )
  246. update_delta = datetime.now(timezone.utc) - instance.updated_at.replace(
  247. tzinfo=timezone.utc
  248. )
  249. return (
  250. (
  251. # When enqueueing pending state model instances, handle two cases:
  252. # 1. Newly created model instances (updated_at - created_at < 1 second),
  253. # which will be updated to ANALYZING in _evaluate.
  254. # 2. Existing PENDING model instances periodically enqueued by the scheduler job.
  255. # In this case, update_delta is longer than 90s, as the scheduler runs every 180s.
  256. instance.worker_id is None
  257. and instance.state == ModelInstanceStateEnum.PENDING
  258. and (newly_created or update_delta > timedelta(seconds=90))
  259. )
  260. or (
  261. # Reschedule while it stays in anayzing state for too long,
  262. # maybe the server is restarted.
  263. instance.worker_id is None
  264. and instance.state == ModelInstanceStateEnum.ANALYZING
  265. and update_delta > timedelta(minutes=3)
  266. )
  267. or (
  268. # Reschedule while it stays in scheduled state for too long,
  269. # maybe the worker is down.
  270. instance.worker_id is not None
  271. and instance.state == ModelInstanceStateEnum.SCHEDULED
  272. and update_delta > timedelta(minutes=3)
  273. )
  274. )
  275. async def _schedule_cycle(self):
  276. while True:
  277. try:
  278. item = await self._queue.get()
  279. try:
  280. await self._schedule_one(item)
  281. self._queue.task_done()
  282. except Exception as e:
  283. logger.error(f"Failed to schedule model instance: {e}")
  284. except queue.Empty:
  285. continue
  286. except Exception as e:
  287. logger.error(f"Failed to get item from schedule queue: {e}")
  288. async def _schedule_one(self, instance: ModelInstance): # noqa: C901
  289. """
  290. Schedule a model instance by picking one candidate.
  291. Args:
  292. item: Model instance to schedule.
  293. """
  294. logger.debug(f"Scheduling model instance {instance.name}")
  295. state_message = ""
  296. async with async_session() as session:
  297. workers = await Worker.all(session)
  298. if not workers:
  299. state_message = "No available workers"
  300. model = await Model.one_by_id(session, instance.model_id)
  301. if model is None:
  302. state_message = "Model not found"
  303. model_instance = await ModelInstance.one_by_id(session, instance.id)
  304. if model_instance is None:
  305. logger.debug(
  306. f"Model instance(ID: {instance.id}) was deleted before scheduling due"
  307. )
  308. return
  309. model_instances = await ModelInstance.all(
  310. session, options=[selectinload(ModelInstance.model)]
  311. )
  312. candidate = None
  313. messages = []
  314. if workers and model:
  315. try:
  316. candidate, messages = await find_candidate(
  317. self._config, model, workers, model_instances
  318. )
  319. except Exception as e:
  320. state_message = f"Failed to find candidate: {e}"
  321. if candidate is None:
  322. # update model instance.
  323. if model_instance.state in (
  324. ModelInstanceStateEnum.SCHEDULED,
  325. ModelInstanceStateEnum.ANALYZING,
  326. ):
  327. model_instance.state = ModelInstanceStateEnum.PENDING
  328. model_instance.state_message = (
  329. "No suitable workers.\nDetails:\n" + "".join(messages)
  330. )
  331. if state_message != "":
  332. model_instance.state_message = state_message
  333. await ModelInstanceService(session).update(model_instance)
  334. logger.debug(
  335. f"No suitable workers for model instance {model_instance.name}, state: {model_instance.state}"
  336. )
  337. else:
  338. # update model instance.
  339. model_instance.state = ModelInstanceStateEnum.SCHEDULED
  340. model_instance.state_message = ""
  341. model_instance.worker_id = candidate.worker.id
  342. model_instance.worker_name = candidate.worker.name
  343. model_instance.worker_ip = candidate.worker.ip
  344. model_instance.worker_advertise_address = (
  345. candidate.worker.advertise_address
  346. )
  347. model_instance.worker_ifname = candidate.worker.ifname
  348. model_instance.computed_resource_claim = (
  349. candidate.computed_resource_claim
  350. )
  351. model_instance.gpu_type = candidate.gpu_type
  352. model_instance.gpu_indexes = candidate.gpu_indexes
  353. model_instance.gpu_addresses = candidate.gpu_addresses
  354. model_instance.distributed_servers = DistributedServers(
  355. subordinate_workers=candidate.subordinate_workers,
  356. )
  357. if get_backend(model) in (
  358. BackendEnum.VLLM,
  359. BackendEnum.ASCEND_MINDIE,
  360. BackendEnum.SGLANG,
  361. ):
  362. model_instance.distributed_servers.mode = (
  363. DistributedServerCoordinateModeEnum.INITIALIZE_LATER
  364. )
  365. await ModelInstanceService(session).update(model_instance)
  366. logger.debug(
  367. f"Scheduled model instance {model_instance.name} to worker "
  368. f"{model_instance.worker_name} gpu {candidate.gpu_indexes}"
  369. )
  370. async def find_candidate(
  371. config: Config,
  372. model: Model,
  373. workers: List[Worker],
  374. model_instances: List[ModelInstance],
  375. ) -> Tuple[Optional[ModelInstanceScheduleCandidate], List[str]]:
  376. """
  377. Find a schedule candidate for the model instance.
  378. :param config: GPUStack configuration.
  379. :param model: Model to schedule.
  380. :param workers: List of workers to consider.
  381. :return: A tuple containing:
  382. - The schedule candidate.
  383. - A list of messages for the scheduling process.
  384. """
  385. # Filter workers.
  386. filters = [
  387. ClusterFilter(model),
  388. GPUMatchingFilter(model),
  389. LabelMatchingFilter(model),
  390. StatusFilter(model),
  391. BackendFrameworkFilter(model),
  392. LocalPathFilter(model),
  393. ]
  394. worker_filter_chain = WorkerFilterChain(filters)
  395. workers, filter_messages = await worker_filter_chain.filter(workers)
  396. messages = []
  397. if filter_messages:
  398. messages.append(str(ListMessageBuilder(filter_messages)) + "\n")
  399. if len(workers) == 0:
  400. return None, messages
  401. # Initialize candidate selector.
  402. try:
  403. if is_gguf_model(model):
  404. candidates_selector = GGUFResourceFitSelector(
  405. model, model_instances, config.cache_dir
  406. )
  407. elif model.backend == BackendEnum.ASCEND_MINDIE:
  408. candidates_selector = AscendMindIEResourceFitSelector(
  409. config, model, model_instances
  410. )
  411. elif model.backend == BackendEnum.VLLM and not is_omni_model(model):
  412. # Note: Route omni categories to CustomSelector for vLLM-Omni.
  413. candidates_selector = VLLMResourceFitSelector(
  414. config, model, model_instances
  415. )
  416. elif model.backend == BackendEnum.SGLANG:
  417. candidates_selector = SGLangResourceFitSelector(
  418. config, model, model_instances
  419. )
  420. else:
  421. candidates_selector = CustomBackendResourceFitSelector(
  422. config, model, model_instances
  423. )
  424. except Exception as e:
  425. return None, [f"Failed to initialize {model.backend} candidates selector: {e}"]
  426. # Select candidates.
  427. candidates = await candidates_selector.select_candidates(workers)
  428. # Score candidates.
  429. candidate_scorers = [
  430. PlacementScorer(model, model_instances),
  431. ]
  432. locality_max_score = envs.SCHEDULER_SCALE_UP_LOCALITY_MAX_SCORE
  433. if locality_max_score > 0:
  434. candidate_scorers.append(
  435. ModelFileLocalityScorer(
  436. model,
  437. draft_model_source=get_draft_model_source(model),
  438. max_score=locality_max_score,
  439. )
  440. )
  441. candidates = await CandidateScoreChain(candidate_scorers).score(candidates)
  442. # Pick the highest score candidate.
  443. candidate = pick_highest_score_candidate(candidates)
  444. # Collect messages.
  445. if candidate is None and len(workers) > 0:
  446. resource_fit_messages = candidates_selector.get_messages() or [
  447. "No workers meet the resource requirements."
  448. ]
  449. messages.extend(resource_fit_messages)
  450. elif candidate and candidate.overcommit:
  451. messages.extend(candidates_selector.get_messages())
  452. # Return the candidate and messages.
  453. return candidate, messages
  454. def pick_highest_score_candidate(candidates: List[ModelInstanceScheduleCandidate]):
  455. """
  456. Pick the most offload layers from candidates.
  457. Args:
  458. candidates: List of ModelInstanceScheduleCandidate.
  459. """
  460. logger.debug(f"Pick highest score candidate from {len(candidates)} candidates")
  461. if len(candidates) == 0:
  462. return None
  463. candidate = candidates[0]
  464. for i in range(1, len(candidates)):
  465. if candidates[i].score > candidate.score:
  466. candidate = candidates[i]
  467. return candidate
  468. async def evaluate_gguf_model(
  469. model: Model,
  470. workers: Optional[List[Worker]] = None,
  471. ) -> bool:
  472. task_output = await calculate_gguf_model_resource_claim(
  473. model, offload=GPUOffloadEnum.Full, workers=workers
  474. )
  475. if (
  476. task_output.resource_architecture
  477. and not task_output.resource_architecture.is_deployable()
  478. ):
  479. raise ValueError(
  480. "Unsupported model. To proceed with deployment, ensure the model is supported by backend, or deploy it using a custom backend version or custom backend."
  481. )
  482. should_update = False
  483. if task_output.resource_claim_estimate.reranking and not model.categories:
  484. should_update = True
  485. model.categories = [CategoryEnum.RERANKER]
  486. if task_output.resource_claim_estimate.embeddingOnly and not model.categories:
  487. should_update = True
  488. model.categories = [CategoryEnum.EMBEDDING]
  489. if task_output.resource_claim_estimate.imageOnly and not model.categories:
  490. should_update = True
  491. model.categories = [CategoryEnum.IMAGE]
  492. if not model.categories:
  493. should_update = True
  494. model.categories = [CategoryEnum.LLM]
  495. if task_output.resource_claim_estimate.distributable and not model.distributable:
  496. should_update = True
  497. model.distributable = True
  498. if model.gpu_selector and model.gpu_selector.gpu_ids:
  499. worker_gpu_ids = group_gpu_ids_by_worker(model.gpu_selector.gpu_ids)
  500. if (
  501. len(worker_gpu_ids) > 1
  502. and model.distributable
  503. and not model.distributed_inference_across_workers
  504. ):
  505. should_update = True
  506. model.distributed_inference_across_workers = True
  507. gpus_per_replica_modified = set_model_gpus_per_replica(model)
  508. should_update = should_update or gpus_per_replica_modified
  509. return should_update
  510. async def evaluate_diffusion_model(
  511. model: Model,
  512. workers: Optional[List[Worker]] = None,
  513. ):
  514. """
  515. Evaluate diffusion model and update model categories.
  516. Args:
  517. model: Model to evaluate
  518. workers: Optional list of workers (for LOCAL_PATH remote read)
  519. Returns:
  520. True if the model is a diffusion model, False otherwise
  521. """
  522. # vLLM/SGLang support Diffusers (image) models.
  523. # If the source (HF/ModelScope/Local Path) contains model_index.json with "_diffusers_version",
  524. # classify as IMAGE directly.
  525. if model.categories and CategoryEnum.IMAGE not in model.categories:
  526. return False
  527. hf_token = get_global_config().huggingface_token
  528. # For Hub sources and local files, use hub.py function
  529. if model.source in (SourceEnum.HUGGING_FACE, SourceEnum.MODEL_SCOPE):
  530. is_diffusers = await asyncio.wait_for(
  531. asyncio.to_thread(has_diffusers_model_index, model, token=hf_token),
  532. timeout=10,
  533. )
  534. # For LOCAL_PATH, try local first, then workers
  535. elif model.source == SourceEnum.LOCAL_PATH:
  536. # Try local read first
  537. is_diffusers = await asyncio.wait_for(
  538. asyncio.to_thread(has_diffusers_model_index, model, token=hf_token),
  539. timeout=10,
  540. )
  541. # If not found locally and workers are provided, query workers
  542. if not is_diffusers and workers:
  543. is_diffusers = await asyncio.wait_for(
  544. check_diffusers_model_index_from_workers(model, workers),
  545. timeout=10,
  546. )
  547. else:
  548. return False
  549. if is_diffusers:
  550. model.categories = [CategoryEnum.IMAGE]
  551. return True
  552. return False
  553. async def prioritize_workers_with_model_files(
  554. session: AsyncSession, model: Model, workers: List[Worker]
  555. ) -> List[Worker]:
  556. """
  557. Prioritize workers that have the model files. This helps optimization for getting model config from remote worker local paths.
  558. Args:
  559. session: Database session for querying worker files.
  560. model: Model to check for.
  561. workers: List of workers to prioritize.
  562. Returns:
  563. List of prioritized workers.
  564. """
  565. if not workers:
  566. return []
  567. source_index = model.model_source_index
  568. if not source_index:
  569. return workers
  570. model_files = await ModelFileService(session).get_by_source_index(source_index)
  571. if not model_files:
  572. return workers
  573. worker_ids_with_ready_files = {
  574. mf.worker_id for mf in model_files if mf.state == ModelFileStateEnum.READY
  575. }
  576. # Put workers with ready model files at the front
  577. sorted_workers = sorted(
  578. workers,
  579. key=lambda w: 0 if w.id in worker_ids_with_ready_files else 1,
  580. )
  581. return sorted_workers
  582. async def evaluate_pretrained_config(
  583. model: Model,
  584. workers: Optional[List[Worker]] = None,
  585. raise_raw: bool = False,
  586. ) -> bool:
  587. """
  588. evaluate the model's pretrained config to determine its categories, meta and gpus_per_replica.
  589. Args:
  590. model: Model to evaluate.
  591. workers: Optional list of workers (for LOCAL_PATH).
  592. raise_raw: If True, raise the raw exception.
  593. Returns:
  594. True if the model's categories are updated, False otherwise.
  595. """
  596. # 1) try to evaluate as diffusion model
  597. try:
  598. is_image_category = await evaluate_diffusion_model(model, workers=workers)
  599. if is_image_category:
  600. return True
  601. except Exception:
  602. pass
  603. # 2) Check overrided architectures if specified in backend parameters.
  604. architectures = get_vllm_override_architectures(model)
  605. if not architectures:
  606. try:
  607. trust_remote_code = _extract_trust_remote_code(model)
  608. pretrained_config = await get_pretrained_config_with_workers(
  609. model,
  610. workers=workers,
  611. trust_remote_code=trust_remote_code,
  612. )
  613. except ValueError as e:
  614. # Skip value error exceptions and defaults to LLM catagory for certain cases.
  615. if should_skip_architecture_check(model):
  616. model.categories = model.categories or [CategoryEnum.LLM]
  617. return True
  618. if raise_raw:
  619. raise
  620. logger.debug(
  621. f"Failed to get config for model {model.name or model.readable_source}, ValueError: {e}"
  622. )
  623. raise simplify_auto_config_value_error(e)
  624. except (TimeoutError, asyncio.TimeoutError) as e:
  625. raise Exception(
  626. f"Timeout while getting config for model {model.name or model.readable_source}: {e}."
  627. )
  628. except Exception as e:
  629. raise Exception(
  630. f"Failed to get config for model {model.name or model.readable_source}: {e}"
  631. )
  632. architectures = getattr(pretrained_config, "architectures", []) or []
  633. if not architectures and not model.backend_version:
  634. raise ValueError(
  635. "Unrecognized architecture. To proceed with deployment, ensure the model is supported by backend, or deploy it using a custom backend version or custom backend."
  636. )
  637. model_type = detect_model_type(architectures)
  638. # TODO : Additional checks for unsupported architectures for other backends.
  639. if (
  640. model.backend == BackendEnum.VLLM
  641. and model_type == CategoryEnum.UNKNOWN
  642. and not model.backend_version
  643. ):
  644. raise ValueError(
  645. f"Unsupported architecture: {architectures}. To proceed with deployment, ensure the model is supported by backend, or deploy it using a custom backend version or custom backend."
  646. )
  647. meta_modified = False
  648. if not model.meta and (known_meta := get_model_meta(pretrained_config)):
  649. model.meta = known_meta
  650. meta_modified = True
  651. categories_modified = set_model_categories(model, model_type)
  652. gpus_per_replica_modified = set_model_gpus_per_replica(model)
  653. return categories_modified or gpus_per_replica_modified or meta_modified
  654. def _extract_trust_remote_code(model: Model) -> bool:
  655. """Extract trust_remote_code from model backend parameters."""
  656. if model.backend_parameters and "--trust-remote-code" in model.backend_parameters:
  657. return True
  658. return False
  659. def get_vllm_override_architectures(model: Model) -> List[str]:
  660. """
  661. Get the vLLM override architectures from the model's backend parameters.
  662. Args:
  663. model: Model to check.
  664. Returns:
  665. List of override architectures.
  666. """
  667. backend = get_backend(model)
  668. if backend != BackendEnum.VLLM:
  669. return []
  670. hf_overrides = find_parameter(model.backend_parameters, ["hf-overrides"])
  671. if hf_overrides:
  672. overrides_dict = json.loads(hf_overrides)
  673. return overrides_dict.get("architectures", [])
  674. return []
  675. def should_skip_architecture_check(model: Model) -> bool:
  676. """
  677. Check if the model should skip architecture check.
  678. Args:
  679. model: Model to check.
  680. Returns:
  681. True if the model should skip architecture check, False otherwise.
  682. """
  683. if (
  684. model.backend == BackendEnum.CUSTOM
  685. or not is_built_in_backend(model.backend)
  686. or model.backend_version
  687. ):
  688. # New model architectures may be added with custom backend/version.
  689. return True
  690. if model.backend_parameters and find_parameter(
  691. model.backend_parameters, ["tokenizer-mode"]
  692. ):
  693. # Models like Pixtral may not provide compatible config but still work with custom parameters.
  694. return True
  695. return False
  696. def simplify_auto_config_value_error(e: ValueError) -> ValueError:
  697. """
  698. Simplify the error message for ValueError exceptions.
  699. """
  700. message = str(e)
  701. if "trust_remote_code=True" in message:
  702. return ValueError(
  703. "The model contains custom code that must be executed to load correctly. If you trust the source, please pass the backend parameter `--trust-remote-code` to allow custom code to be run."
  704. )
  705. if "pip install --upgrade transformers" in message:
  706. return ValueError(
  707. "Unsupported model. To proceed with deployment, ensure the model is supported by backend, or deploy it using a custom backend version or custom backend."
  708. )
  709. return ValueError(f"Not a supported model.\n\n{message}")
  710. def set_model_categories(model: Model, model_type: CategoryEnum) -> bool:
  711. if model.categories:
  712. return False
  713. if model_type == CategoryEnum.UNKNOWN:
  714. # Default to LLM for unknown architectures
  715. model.categories = [CategoryEnum.LLM]
  716. else:
  717. model.categories = [model_type]
  718. return True
  719. def set_model_gpus_per_replica(model: Model) -> bool:
  720. """
  721. Set the model's gpu_selector.gpus_per_replica based on its gpu_selector.gpu_ids and backend parameters.
  722. Args:
  723. model: Model to set.
  724. Returns:
  725. True if the model's gpu_selector.gpus_per_replica is updated, False otherwise.
  726. """
  727. def calculate_gpus_per_replica(model: Model) -> int:
  728. if model.backend == BackendEnum.VOX_BOX.value:
  729. return 1
  730. # User-specified world size from backend parameters takes precedence.
  731. if model.backend_parameters is not None:
  732. selector_map = {
  733. BackendEnum.VLLM.value: VLLMResourceFitSelector,
  734. BackendEnum.ASCEND_MINDIE.value: AscendMindIEResourceFitSelector,
  735. BackendEnum.SGLANG.value: SGLangResourceFitSelector,
  736. }
  737. selector = selector_map.get(model.backend)
  738. world_size = None
  739. if selector:
  740. result = selector.get_world_size_from_backend_parameters(model)
  741. world_size, _ = result if result is not None else (None, None)
  742. if world_size and world_size > 0:
  743. return world_size
  744. # The largest power of 2 less than or equal to (total GPUs / replicas), used as the initial per-replica GPU count.
  745. gpus_per_replica = largest_power_of_2_leq(
  746. len(model.gpu_selector.gpu_ids) // model.replicas
  747. )
  748. return gpus_per_replica
  749. if not model.gpu_selector or not model.gpu_selector.gpu_ids:
  750. return False
  751. if model.gpu_selector.gpus_per_replica and model.gpu_selector.gpus_per_replica > 0:
  752. return False
  753. gpus_per_replica = calculate_gpus_per_replica(model)
  754. model.gpu_selector.gpus_per_replica = gpus_per_replica
  755. try:
  756. flag_modified(model, "gpu_selector")
  757. except AttributeError:
  758. # Ignore if the given model is not a SQLModel instance.
  759. pass
  760. return True