model_common.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. from typing import List, Optional, Union
  2. from sqlalchemy import bindparam, cast
  3. from sqlmodel import func
  4. from sqlalchemy.dialects.postgresql import JSONB
  5. from sqlalchemy.dialects.mysql import JSON
  6. from gpustack.schemas.models import Model
  7. from gpustack.schemas.model_routes import ModelRoute, MyModel
  8. category_classes = Union[
  9. Model,
  10. ModelRoute,
  11. MyModel,
  12. ]
  13. def build_pg_category_condition(target_class: category_classes, category: str):
  14. if category == "":
  15. return cast(target_class.categories, JSONB).op('@>')(cast('[]', JSONB))
  16. return cast(target_class.categories, JSONB).op('?')(
  17. bindparam(f"category_{category}", category)
  18. )
  19. # Add MySQL category condition construction function
  20. def build_mysql_category_condition(target_class: category_classes, category: str):
  21. if category == "":
  22. return func.json_length(target_class.categories) == 0
  23. return func.json_contains(
  24. target_class.categories, func.cast(func.json_quote(category), JSON), '$'
  25. )
  26. def build_category_conditions(session, target_class: category_classes, categories):
  27. dialect = session.bind.dialect.name
  28. if dialect == "postgresql":
  29. return [
  30. build_pg_category_condition(target_class, category)
  31. for category in categories
  32. ]
  33. elif dialect == "mysql":
  34. return [
  35. build_mysql_category_condition(target_class, category)
  36. for category in categories
  37. ]
  38. else:
  39. raise NotImplementedError(f'Unsupported database {dialect}')
  40. def categories_filter(data: category_classes, categories: Optional[List[str]]):
  41. if not categories:
  42. return True
  43. data_categories = data.categories or []
  44. if not data_categories and "" in categories:
  45. return True
  46. return any(category in data_categories for category in categories)