| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657 |
- from typing import List, Optional, Union
- from sqlalchemy import bindparam, cast
- from sqlmodel import func
- from sqlalchemy.dialects.postgresql import JSONB
- from sqlalchemy.dialects.mysql import JSON
- from gpustack.schemas.models import Model
- from gpustack.schemas.model_routes import ModelRoute, MyModel
# Mapped classes that expose a ``categories`` attribute and can therefore be
# passed to the category-filtering helpers in this module.
category_classes = Union[Model, ModelRoute, MyModel]
def build_pg_category_condition(target_class: category_classes, category: str):
    """Build a PostgreSQL condition matching rows that carry *category*.

    ``target_class.categories`` is cast to ``JSONB`` so the jsonb operators
    apply regardless of how the column itself is declared.

    Args:
        target_class: Mapped class exposing a ``categories`` column.
        category: Category to look for. The empty string selects rows with
            no categories (an empty JSON array or SQL NULL), the same rule
            ``categories_filter`` applies in Python.

    Returns:
        A SQLAlchemy boolean clause element.
    """
    categories_json = cast(target_class.categories, JSONB)
    if category == "":
        # ``x @> '[]'`` is vacuously true for every array, so it cannot select
        # "no categories"; test equality with an empty array instead.
        # jsonb_build_array() produces ``[]`` server-side, avoiding a
        # JSONB-typed Python literal (SQLAlchemy would JSON-encode '[]' into
        # the string "[]"). COALESCE lets a NULL column count as empty.
        empty_array = func.jsonb_build_array()
        return func.coalesce(categories_json, empty_array) == empty_array
    # ``?`` tests whether the string is a top-level element of the array.
    # The category travels as an explicit bindparam (type inferred from the
    # str value); unique=True gives it an anonymized placeholder name so
    # user-supplied text never becomes part of the SQL parameter name and
    # repeated conditions in one statement cannot collide.
    return categories_json.op('?')(bindparam("category", category, unique=True))
# MySQL counterpart of build_pg_category_condition.
def build_mysql_category_condition(target_class: category_classes, category: str):
    """Build a MySQL condition matching rows that carry *category*.

    Args:
        target_class: Mapped class exposing a ``categories`` JSON column.
        category: Category to look for. The empty string selects rows with
            no categories (an empty JSON array or SQL NULL), the same rule
            ``categories_filter`` applies in Python.

    Returns:
        A SQLAlchemy boolean clause element.
    """
    if category == "":
        # JSON_LENGTH(NULL) is NULL, which would silently exclude NULL rows;
        # COALESCE makes a missing value count as an empty list.
        return func.coalesce(func.json_length(target_class.categories), 0) == 0
    # JSON_CONTAINS needs a JSON candidate: JSON_QUOTE turns the category into
    # a JSON string literal, which is then checked against the document root.
    return func.json_contains(
        target_class.categories, cast(func.json_quote(category), JSON), '$'
    )
def build_category_conditions(session, target_class: category_classes, categories):
    """Return one SQL category condition per entry in *categories*.

    The SQL flavour is chosen from ``session.bind.dialect.name``.

    Args:
        session: Session whose ``bind.dialect.name`` identifies the database.
        target_class: Mapped class exposing a ``categories`` column.
        categories: Iterable of category strings to build conditions for.

    Returns:
        A list of condition clauses, in the same order as *categories*.

    Raises:
        NotImplementedError: If the dialect is neither PostgreSQL nor MySQL.
    """
    dialect_name = session.bind.dialect.name
    if dialect_name == "postgresql":
        builder = build_pg_category_condition
    elif dialect_name == "mysql":
        builder = build_mysql_category_condition
    else:
        raise NotImplementedError(f'Unsupported database {dialect_name}')
    return [builder(target_class, wanted) for wanted in categories]
def categories_filter(data: category_classes, categories: Optional[List[str]]):
    """Return whether *data* passes an in-memory category filter.

    Args:
        data: Object whose ``categories`` attribute lists its categories;
            a falsy value (``None`` or empty) means it has none.
        categories: Requested categories. ``None`` or an empty list disables
            filtering. The empty string ``""`` selects objects that have no
            categories.

    Returns:
        ``True`` if *data* matches at least one requested category (or no
        filter was requested), ``False`` otherwise.
    """
    if not categories:
        return True  # no filter requested -> everything passes
    owned = data.categories or []
    if not owned:
        # An uncategorized object can only match the "" selector.
        return "" in categories
    return any(wanted in owned for wanted in categories)
|