| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286 |
- from typing import Optional
- from fastapi import APIRouter, Depends
- from fastapi.responses import StreamingResponse
- from sqlmodel import String, cast, func, or_
- from pathlib import Path
- from sqlalchemy.orm import selectinload
- from gpustack.api.exceptions import (
- AlreadyExistsException,
- ConflictException,
- InternalServerErrorException,
- )
- from gpustack.api.tenant import (
- bypass_tenant_filter,
- assert_cluster_resource_visible,
- cluster_resource_visibility_conditions,
- )
- from gpustack.schemas.workers import Worker
- from gpustack.server.db import async_session
- from gpustack.server.deps import SessionDep, TenantContextDep
- from gpustack.schemas.model_files import (
- ModelFile,
- ModelFileCreate,
- ModelFileListParams,
- ModelFilePublic,
- ModelFileStateEnum,
- ModelFileUpdate,
- ModelFilesPublic,
- )
- router = APIRouter()
def _make_model_file_visibility_filter(ctx):
    """Return a predicate deciding whether ``ctx`` may see a model file.

    The returned callable accepts one ``ModelFile`` and answers ``True`` when
    any of the following holds, checked in this order:

    * ``bypass_tenant_filter(ctx)`` is truthy;
    * the file's ``owner_principal_id`` and ``ctx.current_principal_id`` are
      both set and equal;
    * the file's ``cluster_id`` is contained in ``ctx.accessible_cluster_ids``.

    Attributes missing from the file are read as ``None``.
    """

    def _visible(m: ModelFile) -> bool:
        if bypass_tenant_filter(ctx):
            return True

        file_owner = getattr(m, "owner_principal_id", None)
        caller = ctx.current_principal_id
        owned_by_caller = (
            caller is not None and file_owner is not None and file_owner == caller
        )
        if owned_by_caller:
            return True

        # Fall back to cluster-level access.
        return getattr(m, "cluster_id", None) in ctx.accessible_cluster_ids

    return _visible
def _model_file_search_clause(search: str):
    """Build the SQL condition used to search model files by free text.

    The condition is an OR of case-insensitive substring matches of
    ``search`` against the resolved paths (a JSON column, compared through
    its string cast), the Hugging Face repo id and filename, the ModelScope
    model id and file path, and the local path.

    The search text is matched literally: ``%`` and ``_`` typed by the user
    are escaped instead of acting as LIKE wildcards, which keeps this clause
    in line with the plain substring test applied to streamed results.
    """
    needle = search.lower()
    searchable_columns = (
        cast(ModelFile.resolved_paths, String),
        ModelFile.huggingface_repo_id,
        ModelFile.huggingface_filename,
        ModelFile.model_scope_model_id,
        ModelFile.model_scope_file_path,
        ModelFile.local_path,
    )
    return or_(
        *(
            # contains() renders LIKE '%' || :needle || '%' with an ESCAPE
            # clause; autoescape neutralises wildcard characters in needle.
            func.lower(column).contains(needle, autoescape=True)
            for column in searchable_columns
        )
    )
def _normalize_model_file_order_by(order_by):
    """Rewrite ``(field, direction)`` ordering pairs for model file queries.

    * A falsy ``order_by`` is handed back untouched.
    * ``"source"`` is kept and followed by the repo/file/path columns, all
      with the same direction, so the ordering is deterministic.
    * ``"resolved_paths"`` is a JSON field, so it is replaced by its string
      cast expression.
    * Every other pair is copied as-is.
    """
    if not order_by:
        return order_by

    source_tiebreakers = (
        "huggingface_repo_id",
        "huggingface_filename",
        "model_scope_model_id",
        "model_scope_file_path",
        "local_path",
    )

    normalized = []
    for column, sort_direction in order_by:
        if column == "source":
            normalized.append((column, sort_direction))
            normalized.extend((name, sort_direction) for name in source_tiebreakers)
            continue
        if column == "resolved_paths":
            column = cast(ModelFile.resolved_paths, String)
        normalized.append((column, sort_direction))
    return normalized
@router.get("", response_model=ModelFilesPublic)
async def get_model_files(
    ctx: TenantContextDep,
    params: ModelFileListParams = Depends(),
    search: str = None,
    worker_id: int = None,
):
    # List model files, either as one page or as an event stream.
    #
    # A truthy worker_id narrows the listing to that worker. With
    # params.watch set, the response is a text/event-stream produced by
    # ModelFile.streaming, filtered in Python by tenant visibility and, when
    # search is given, by search_model_file_filter. Otherwise a paginated
    # query runs with the visibility and search conditions applied in SQL.
    fields = {}
    if worker_id:
        fields["worker_id"] = worker_id

    visible = _make_model_file_visibility_filter(ctx)

    if params.watch:
        if search:

            def filter_func(data):
                return visible(data) and search_model_file_filter(data, search)

        else:
            filter_func = visible

        event_stream = ModelFile.streaming(fields=fields, filter_func=filter_func)
        return StreamingResponse(event_stream, media_type="text/event-stream")

    extra_conditions = [*cluster_resource_visibility_conditions(ctx, ModelFile)]
    if search:
        extra_conditions.append(_model_file_search_clause(search))

    ordering = _normalize_model_file_order_by(params.order_by)
    async with async_session() as session:
        return await ModelFile.paginated_by_query(
            session=session,
            fields=fields,
            extra_conditions=extra_conditions,
            page=params.page,
            per_page=params.perPage,
            order_by=ordering,
        )
def search_model_file_filter(data: ModelFile, search: str) -> bool:
    """Tell whether ``search`` occurs in any searchable field of ``data``.

    The comparison is a case-insensitive substring test against the Hugging
    Face repo id and filename, the ModelScope model id and file path, the
    local path, and every entry of ``resolved_paths``. Empty or ``None``
    fields never match.

    All resolved paths are inspected (not only the first one) so that the
    result agrees with the SQL search clause, which matches against the
    whole serialized path list.
    """
    needle = search.lower()

    text_fields = (
        data.huggingface_repo_id,
        data.huggingface_filename,
        data.model_scope_model_id,
        data.model_scope_file_path,
        data.local_path,
    )
    if any(value and needle in value.lower() for value in text_fields):
        return True

    # resolved_paths may be None or empty; treat both as "no paths".
    return any(path and needle in path.lower() for path in data.resolved_paths or ())
@router.get("/{id}", response_model=ModelFilePublic)
async def get_model_file(session: SessionDep, ctx: TenantContextDep, id: int):
    # Fetch one model file by primary key and return it.
    #
    # The lookup result is passed through assert_cluster_resource_visible
    # before being returned; presumably that helper raises a not-found error
    # when the record is missing or hidden from the caller (confirm in
    # gpustack.api.tenant).
    missing_message = f"Model file {id} not found"
    found = await ModelFile.one_by_id(session, id)
    assert_cluster_resource_visible(ctx, found, not_found_message=missing_message)
    return found
@router.post("", response_model=ModelFilePublic)
async def create_model_file(
    session: SessionDep, ctx: TenantContextDep, model_file_in: ModelFileCreate
):
    # Create a model file record for a worker.
    #
    # Raises AlreadyExistsException when the worker already has a record with
    # the same source index and local_dir, or when the requested local_dir
    # resolves to a directory already held by another record of that worker.
    # Any failure while building or persisting the record is reported as
    # InternalServerErrorException, chained to the original error.
    #
    # NOTE(review): ctx is accepted but never consulted here; confirm whether
    # the target worker should be checked for tenant visibility.
    duplicate = await ModelFile.one_by_fields(
        session,
        {
            "worker_id": model_file_in.worker_id,
            "source_index": model_file_in.model_source_index,
            "local_dir": model_file_in.local_dir,
        },
    )
    if duplicate:
        raise AlreadyExistsException(
            message="Model file with the same model source already exists on the worker."
        )

    if model_file_in.local_dir is not None:
        worker_files = await ModelFile.all_by_field(
            session, field="worker_id", value=model_file_in.worker_id
        )
        # Only records with a local_dir and neither a Hugging Face filename
        # nor a ModelScope file path count as occupying a directory
        # (presumably whole-repository downloads; verify against the schema).
        directory_holders = [
            candidate
            for candidate in (worker_files or [])
            if candidate.local_dir is not None
            and candidate.huggingface_filename is None
            and candidate.model_scope_file_path is None
        ]
        if directory_holders:
            # Resolve the requested directory once, not once per candidate.
            requested_dir = Path(model_file_in.local_dir).resolve()
            for holder in directory_holders:
                if Path(holder.local_dir).resolve() == requested_dir:
                    raise AlreadyExistsException(
                        message=f"The local directory {model_file_in.local_dir} is already occupied by {holder.readable_source} on this worker."
                    )

    # Derive tenant scope from the targeted worker → cluster.
    cluster_id: Optional[int] = None
    owner_principal_id: Optional[int] = None
    if model_file_in.worker_id is not None:
        worker = await Worker.one_by_id(session, model_file_in.worker_id)
        if worker is not None:
            cluster_id = worker.cluster_id
            owner_principal_id = getattr(worker, "owner_principal_id", None)

    try:
        model_file = ModelFile(
            **model_file_in.model_dump(),
            source_index=model_file_in.model_source_index,
            cluster_id=cluster_id,
            owner_principal_id=owner_principal_id,
        )
        model_file = await ModelFile.create(session, model_file)
    except Exception as e:
        raise InternalServerErrorException(
            message=f"Failed to create model file: {e}"
        ) from e
    return model_file
@router.put("/{id}", response_model=ModelFilePublic)
async def update_model_file(
    session: SessionDep,
    ctx: TenantContextDep,
    id: int,
    model_file_in: ModelFileUpdate,
):
    # Apply model_file_in to the model file with the given id and return it.
    #
    # The record is passed through assert_cluster_resource_visible before it
    # is modified. Any error raised by the update is reported as
    # InternalServerErrorException, chained to the original exception so the
    # root cause stays attached to the traceback.
    model_file = await ModelFile.one_by_id(session, id)
    assert_cluster_resource_visible(
        ctx, model_file, not_found_message=f"Model file {id} not found"
    )
    try:
        await model_file.update(session, model_file_in)
    except Exception as e:
        raise InternalServerErrorException(
            message=f"Failed to update model file: {e}"
        ) from e
    return model_file
@router.delete("/{id}")
async def delete_model_file(
    session: SessionDep,
    ctx: TenantContextDep,
    id: int,
    cleanup: Optional[bool] = None,
):
    # Delete the model file with the given id.
    #
    # The record is loaded together with its instances and passed through
    # assert_cluster_resource_visible. Deletion is refused with
    # ConflictException while any model instance still references the file.
    # When cleanup is given and differs from the stored cleanup_on_delete
    # flag, the flag is saved first and the record is deleted afterwards.
    # Failures in either step are reported as InternalServerErrorException,
    # chained to the original exception.
    model_file = await ModelFile.one_by_id(
        session, id, options=[selectinload(ModelFile.instances)]
    )
    assert_cluster_resource_visible(
        ctx, model_file, not_found_message=f"Model file {id} not found"
    )

    if model_file.instances:
        model_instance_names = ", ".join(
            model_instance.name for model_instance in model_file.instances
        )
        raise ConflictException(
            message=f"Cannot delete the model file. It's being used by model instances: {model_instance_names}.",
        )

    try:
        if cleanup is not None and model_file.cleanup_on_delete != cleanup:
            model_file.cleanup_on_delete = cleanup
            await model_file.update(session)
        await model_file.delete(session)
    except Exception as e:
        raise InternalServerErrorException(
            message=f"Failed to delete model file: {e}"
        ) from e
@router.post("/{id}/reset", response_model=ModelFilePublic)
async def reset_model_file(session: SessionDep, ctx: TenantContextDep, id: int):
    # Put the model file with the given id back into the DOWNLOADING state.
    #
    # After the visibility check, the state is set to DOWNLOADING, the
    # download progress to 0 and the state message to an empty string; the
    # record is then saved and returned. A failure while saving is reported
    # as InternalServerErrorException, chained to the original exception.
    model_file = await ModelFile.one_by_id(session, id)
    assert_cluster_resource_visible(
        ctx, model_file, not_found_message=f"Model file {id} not found"
    )
    try:
        model_file.state = ModelFileStateEnum.DOWNLOADING
        model_file.download_progress = 0
        model_file.state_message = ""
        await model_file.update(session)
    except Exception as e:
        raise InternalServerErrorException(
            message=f"Failed to update model file: {e}"
        ) from e
    return model_file
|