""" 模型服务层(新表结构) """ from typing import Optional, List from sqlalchemy import or_, case from sqlalchemy.orm import Session, selectinload from app.models.model import ModelNew, ModelPriceNew, ModelCategory from app.schemas.model_schema import ( ModelListResponse, ModelPriceNewResponse, PaginatedResponse, KeywordsResponse, ModelPricingResponse, ) CATEGORY_NAME_TO_VALUE = { # 前端 ModelSquare 分类 Tab 名称 → 数据库 categories 枚举值 # 语言模型包含 LLM(0) 和 MULTIMODAL(1),多模态模型也可用于对话 "语言模型": [0, 1], "多模态模型": [1], "TTS模型": [2], "STT模型": [3], "生图模型": [4], "生视频模型": [5], "图像编辑": [6], "Embedding": [7], "Rerank": [8], # 兼容旧的短名称 "文本": [0], "图像": [1], "视频": [5], "语音": [2], } class ModelService: def __init__(self, db: Session): self.db = db def _build_list_response(self, item: ModelNew) -> ModelListResponse: prices = [ ModelPriceNewResponse( id=p.id, label=p.label, tier_min=p.tier_min, tier_max=p.tier_max, tier_unit=p.tier_unit, input_price_original=p.input_price_original, output_price_original=p.output_price_original, discount_rate=p.discount_rate, discount_label=p.discount_label, input_price_discounted=p.input_price_discounted, output_price_discounted=p.output_price_discounted, currency=p.currency, unit=p.unit, display_multiplier=p.display_multiplier, ) for p in (item.prices or []) ] return ModelListResponse( id=item.id, model_code=item.model_code, display_name=item.display_name, img=item.img, tag1=item.tag1, tag2=item.tag2, description=item.description, custom_description=item.custom_description, keywords=item.keywords, display_tags=item.display_tags, categories=item.categories or [], category=item.categories[0] if item.categories else 0, supplier=item.supplier, is_featured=item.is_featured, is_search=item.is_search, is_thinking=item.is_thinking, is_api_enabled=item.is_api_enabled, is_show_enabled=item.is_show_enabled, created_at=item.created_at, updated_at=item.updated_at, # 前端兼容字段 title=item.model_code, name=item.display_name or item.model_code, keyword=item.keywords, prices=prices, ) def get_models( self, page: int = 1, page_size: int = 20, keyword: Optional[str] = None, category: Optional[str] = None, supplier: Optional[str] = None, group_name: Optional[str] = None, filter_keyword: Optional[str] = None, filter_tag: Optional[str] = None, is_api_enabled: Optional[bool] = None, ) -> PaginatedResponse[ModelListResponse]: query = self.db.query(ModelNew).options( selectinload(ModelNew.prices) ).filter(ModelNew.is_show_enabled == True) if keyword and keyword.strip(): s = f"%{keyword}%" query = query.filter( or_( ModelNew.model_code.ilike(s), ModelNew.display_name.ilike(s), ModelNew.description.ilike(s), ModelNew.custom_description.ilike(s), ModelNew.keywords.ilike(s), ) ) if category and category.strip(): vals = CATEGORY_NAME_TO_VALUE.get(category.strip()) if vals is not None: from sqlalchemy import cast from sqlalchemy.dialects.postgresql import ARRAY, INTEGER if len(vals) == 1: query = query.filter( ModelNew.categories.contains(cast([vals[0]], ARRAY(INTEGER))) ) else: # 多个分类值用 OR 连接(如语言模型包含 LLM 和 MULTIMODAL) query = query.filter( or_(*[ ModelNew.categories.contains(cast([v], ARRAY(INTEGER))) for v in vals ]) ) if supplier and supplier.strip(): query = query.filter(ModelNew.supplier == supplier.strip()) if group_name and group_name.strip(): query = query.filter(ModelNew.group_name == group_name.strip()) if filter_keyword and filter_keyword.strip(): query = query.filter(ModelNew.keywords == filter_keyword.strip()) if filter_tag and filter_tag.strip(): tag = filter_tag.strip() query = query.filter(or_(ModelNew.tag1 == tag, ModelNew.tag2 == tag)) if is_api_enabled is not None: query = query.filter(ModelNew.is_api_enabled == is_api_enabled) query = query.order_by( case((ModelNew.model_code == "qwen3-max", 0), else_=1), ModelNew.is_api_enabled.desc(), ModelNew.created_at.desc(), ) total = query.count() items = query.offset((page - 1) * page_size).limit(page_size).all() return PaginatedResponse( total=total, page=page, page_size=page_size, items=[self._build_list_response(i) for i in items], ) def get_model_by_id(self, model_id: int) -> ModelListResponse: from fastapi import HTTPException model = self.db.query(ModelNew).options( selectinload(ModelNew.prices) ).filter(ModelNew.id == model_id).first() if not model: raise HTTPException(status_code=404, detail="Model not found") return self._build_list_response(model) def get_keywords(self) -> KeywordsResponse: rows = self.db.query(ModelNew.keywords).filter( ModelNew.keywords.isnot(None), ModelNew.keywords != "", ModelNew.is_show_enabled == True, ).distinct().order_by(ModelNew.keywords).all() return KeywordsResponse(keywords=[r[0] for r in rows if r[0]]) def get_group_names(self) -> List[str]: """获取所有模型分组名称(去重排序)""" from sqlalchemy import distinct rows = self.db.query(distinct(ModelNew.group_name)).filter( ModelNew.group_name.isnot(None), ModelNew.group_name != "", ModelNew.is_show_enabled == True, ModelNew.is_local == False, ).order_by(ModelNew.group_name).all() return [r[0] for r in rows if r[0]] def get_featured_models(self, limit: int = 3) -> List[ModelListResponse]: items = self.db.query(ModelNew).options( selectinload(ModelNew.prices) ).filter( ModelNew.is_featured == True, ModelNew.is_show_enabled == True, ).limit(limit).all() return [self._build_list_response(i) for i in items] def get_pricing(self, model_code: str) -> ModelPricingResponse: """获取模型定价信息,支持 source_keys / normalized_keys 查找""" import re def normalize(v: str) -> str: v = v.strip() v = re.sub(r"[\u200B-\u200D\uFEFF]", "", v) v = re.sub(r"\s*-\s*", "-", v) return re.sub(r"\s+", " ", v) normalized = normalize(model_code) # 先精确匹配 model_code model = self.db.query(ModelNew).filter( ModelNew.model_code == model_code ).first() # 再查 source_keys / normalized_keys 数组 if not model: from sqlalchemy import func, cast from sqlalchemy.dialects.postgresql import ARRAY, TEXT model = self.db.query(ModelNew).filter( or_( ModelNew.source_keys.any(model_code), ModelNew.normalized_keys.any(normalized), ) ).first() prices = [] if model: price_rows = self.db.query(ModelPriceNew).filter( ModelPriceNew.model_code == model.model_code, ModelPriceNew.is_active == True, ).order_by(ModelPriceNew.label).all() prices = [ ModelPriceNewResponse( id=p.id, label=p.label, tier_min=p.tier_min, tier_max=p.tier_max, tier_unit=p.tier_unit, input_price_original=p.input_price_original, output_price_original=p.output_price_original, discount_rate=p.discount_rate, discount_label=p.discount_label, input_price_discounted=p.input_price_discounted, output_price_discounted=p.output_price_discounted, currency=p.currency, unit=p.unit, display_multiplier=p.display_multiplier, ) for p in price_rows ] return ModelPricingResponse( model_code=model.model_code if model else model_code, display_name=model.display_name if model else None, description=model.description if model else None, custom_description=model.custom_description if model else None, features=model.features if model else None, rate_limits=model.rate_limits if model else None, tool_call_prices=model.tool_call_prices if model else None, display_tags=model.display_tags if model else None, input_modalities=model.input_modalities if model else None, output_modalities=model.output_modalities if model else None, categories=model.categories if model else [], source_keys=model.source_keys if model else None, normalized_keys=model.normalized_keys if model else None, is_api_enabled=model.is_api_enabled if model else False, prices=prices, )