model_files.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. from typing import Optional
  2. from fastapi import APIRouter, Depends
  3. from fastapi.responses import StreamingResponse
  4. from sqlmodel import String, cast, func, or_
  5. from pathlib import Path
  6. from sqlalchemy.orm import selectinload
  7. from gpustack.api.exceptions import (
  8. AlreadyExistsException,
  9. ConflictException,
  10. InternalServerErrorException,
  11. )
  12. from gpustack.api.tenant import (
  13. bypass_tenant_filter,
  14. assert_cluster_resource_visible,
  15. cluster_resource_visibility_conditions,
  16. )
  17. from gpustack.schemas.workers import Worker
  18. from gpustack.server.db import async_session
  19. from gpustack.server.deps import SessionDep, TenantContextDep
  20. from gpustack.schemas.model_files import (
  21. ModelFile,
  22. ModelFileCreate,
  23. ModelFileListParams,
  24. ModelFilePublic,
  25. ModelFileStateEnum,
  26. ModelFileUpdate,
  27. ModelFilesPublic,
  28. )
# Router collecting the model-file endpoints defined in this module
# (list/watch, get, create, update, delete, reset).
router = APIRouter()
  30. def _make_model_file_visibility_filter(ctx):
  31. def _visible(m: ModelFile) -> bool:
  32. if bypass_tenant_filter(ctx):
  33. return True
  34. org_id = getattr(m, "owner_principal_id", None)
  35. if (
  36. ctx.current_principal_id is not None
  37. and org_id is not None
  38. and org_id == ctx.current_principal_id
  39. ):
  40. return True
  41. if getattr(m, "cluster_id", None) in ctx.accessible_cluster_ids:
  42. return True
  43. return False
  44. return _visible
  45. def _model_file_search_clause(search: str):
  46. lower_search = search.lower()
  47. return or_(
  48. *[
  49. func.lower(cast(ModelFile.resolved_paths, String)).like(
  50. f"%{lower_search}%"
  51. ),
  52. func.lower(ModelFile.huggingface_repo_id).like(f"%{lower_search}%"),
  53. func.lower(ModelFile.huggingface_filename).like(f"%{lower_search}%"),
  54. func.lower(ModelFile.model_scope_model_id).like(f"%{lower_search}%"),
  55. func.lower(ModelFile.model_scope_file_path).like(f"%{lower_search}%"),
  56. func.lower(ModelFile.local_path).like(f"%{lower_search}%"),
  57. ]
  58. )
  59. def _normalize_model_file_order_by(order_by):
  60. if not order_by:
  61. return order_by
  62. new_order_by = []
  63. for field, direction in order_by:
  64. if field == "source":
  65. # add additional sorting fields for deterministic ordering
  66. new_order_by.append((field, direction))
  67. new_order_by.append(("huggingface_repo_id", direction))
  68. new_order_by.append(("huggingface_filename", direction))
  69. new_order_by.append(("model_scope_model_id", direction))
  70. new_order_by.append(("model_scope_file_path", direction))
  71. new_order_by.append(("local_path", direction))
  72. elif field == "resolved_paths":
  73. # resolved_paths is a JSON field; replace ordering with expression
  74. new_order_by.append((cast(ModelFile.resolved_paths, String), direction))
  75. else:
  76. new_order_by.append((field, direction))
  77. return new_order_by
  78. @router.get("", response_model=ModelFilesPublic)
  79. async def get_model_files(
  80. ctx: TenantContextDep,
  81. params: ModelFileListParams = Depends(),
  82. search: str = None,
  83. worker_id: int = None,
  84. ):
  85. fields = {"worker_id": worker_id} if worker_id else {}
  86. visible = _make_model_file_visibility_filter(ctx)
  87. if params.watch:
  88. filter_func = (
  89. (lambda data: visible(data) and search_model_file_filter(data, search))
  90. if search
  91. else visible
  92. )
  93. return StreamingResponse(
  94. ModelFile.streaming(fields=fields, filter_func=filter_func),
  95. media_type="text/event-stream",
  96. )
  97. extra_conditions = list(cluster_resource_visibility_conditions(ctx, ModelFile))
  98. if search:
  99. extra_conditions.append(_model_file_search_clause(search))
  100. async with async_session() as session:
  101. return await ModelFile.paginated_by_query(
  102. session=session,
  103. fields=fields,
  104. extra_conditions=extra_conditions,
  105. page=params.page,
  106. per_page=params.perPage,
  107. order_by=_normalize_model_file_order_by(params.order_by),
  108. )
  109. def search_model_file_filter(data: ModelFile, search: str) -> bool:
  110. if (
  111. (
  112. data.huggingface_repo_id
  113. and search.lower() in data.huggingface_repo_id.lower()
  114. )
  115. or (
  116. data.huggingface_filename
  117. and search.lower() in data.huggingface_filename.lower()
  118. )
  119. or (
  120. data.model_scope_model_id
  121. and search.lower() in data.model_scope_model_id.lower()
  122. )
  123. or (
  124. data.model_scope_file_path
  125. and search.lower() in data.model_scope_file_path.lower()
  126. )
  127. or (data.local_path and search.lower() in data.local_path.lower())
  128. or (data.resolved_paths and search.lower() in data.resolved_paths[0].lower())
  129. ):
  130. return True
  131. return False
  132. @router.get("/{id}", response_model=ModelFilePublic)
  133. async def get_model_file(session: SessionDep, ctx: TenantContextDep, id: int):
  134. model_file = await ModelFile.one_by_id(session, id)
  135. assert_cluster_resource_visible(
  136. ctx, model_file, not_found_message=f"Model file {id} not found"
  137. )
  138. return model_file
  139. @router.post("", response_model=ModelFilePublic)
  140. async def create_model_file(
  141. session: SessionDep, ctx: TenantContextDep, model_file_in: ModelFileCreate
  142. ):
  143. fields = {
  144. "worker_id": model_file_in.worker_id,
  145. "source_index": model_file_in.model_source_index,
  146. "local_dir": model_file_in.local_dir,
  147. }
  148. existing = await ModelFile.one_by_fields(session, fields)
  149. if existing:
  150. raise AlreadyExistsException(
  151. message="Model file with the same model source already exists on the worker."
  152. )
  153. if model_file_in.local_dir is not None:
  154. fields = {
  155. "worker_id": model_file_in.worker_id,
  156. "local_dir": model_file_in.local_dir,
  157. }
  158. worker_existing_files = await ModelFile.all_by_field(
  159. session, field="worker_id", value=model_file_in.worker_id
  160. )
  161. if worker_existing_files:
  162. for file in worker_existing_files:
  163. if (
  164. file.local_dir is not None
  165. and file.huggingface_filename is None
  166. and file.model_scope_file_path is None
  167. and Path(file.local_dir).resolve()
  168. == Path(model_file_in.local_dir).resolve()
  169. ):
  170. raise AlreadyExistsException(
  171. message=f"The local directory {model_file_in.local_dir} is already occupied by {file.readable_source} on this worker."
  172. )
  173. # Derive tenant scope from the targeted worker → cluster.
  174. cluster_id: Optional[int] = None
  175. owner_principal_id: Optional[int] = None
  176. if model_file_in.worker_id is not None:
  177. worker = await Worker.one_by_id(session, model_file_in.worker_id)
  178. if worker is not None:
  179. cluster_id = worker.cluster_id
  180. owner_principal_id = getattr(worker, "owner_principal_id", None)
  181. try:
  182. model_file = ModelFile(
  183. **model_file_in.model_dump(),
  184. source_index=model_file_in.model_source_index,
  185. cluster_id=cluster_id,
  186. owner_principal_id=owner_principal_id,
  187. )
  188. model_file = await ModelFile.create(session, model_file)
  189. except Exception as e:
  190. raise InternalServerErrorException(message=f"Failed to create model file: {e}")
  191. return model_file
  192. @router.put("/{id}", response_model=ModelFilePublic)
  193. async def update_model_file(
  194. session: SessionDep,
  195. ctx: TenantContextDep,
  196. id: int,
  197. model_file_in: ModelFileUpdate,
  198. ):
  199. model_file = await ModelFile.one_by_id(session, id)
  200. assert_cluster_resource_visible(
  201. ctx, model_file, not_found_message=f"Model file {id} not found"
  202. )
  203. try:
  204. await model_file.update(session, model_file_in)
  205. except Exception as e:
  206. raise InternalServerErrorException(message=f"Failed to update model file: {e}")
  207. return model_file
  208. @router.delete("/{id}")
  209. async def delete_model_file(
  210. session: SessionDep,
  211. ctx: TenantContextDep,
  212. id: int,
  213. cleanup: Optional[bool] = None,
  214. ):
  215. model_file = await ModelFile.one_by_id(
  216. session, id, options=[selectinload(ModelFile.instances)]
  217. )
  218. assert_cluster_resource_visible(
  219. ctx, model_file, not_found_message=f"Model file {id} not found"
  220. )
  221. if model_file.instances:
  222. model_instance_names = ", ".join(
  223. [model_instance.name for model_instance in model_file.instances]
  224. )
  225. raise ConflictException(
  226. message=f"Cannot delete the model file. It's being used by model instances: {model_instance_names}.",
  227. )
  228. try:
  229. if cleanup is not None and model_file.cleanup_on_delete != cleanup:
  230. model_file.cleanup_on_delete = cleanup
  231. await model_file.update(session)
  232. await model_file.delete(session)
  233. except Exception as e:
  234. raise InternalServerErrorException(message=f"Failed to delete model file: {e}")
  235. @router.post("/{id}/reset", response_model=ModelFilePublic)
  236. async def reset_model_file(session: SessionDep, ctx: TenantContextDep, id: int):
  237. model_file = await ModelFile.one_by_id(session, id)
  238. assert_cluster_resource_visible(
  239. ctx, model_file, not_found_message=f"Model file {id} not found"
  240. )
  241. try:
  242. model_file.state = ModelFileStateEnum.DOWNLOADING
  243. model_file.download_progress = 0
  244. model_file.state_message = ""
  245. await model_file.update(session)
  246. except Exception as e:
  247. raise InternalServerErrorException(message=f"Failed to update model file: {e}")
  248. return model_file