| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274 |
- """
- 模型服务层(新表结构)
- """
- from typing import Optional, List
- from sqlalchemy import or_, case
- from sqlalchemy.orm import Session, selectinload
- from app.models.model import ModelNew, ModelPriceNew, ModelCategory
- from app.schemas.model_schema import (
- ModelListResponse,
- ModelPriceNewResponse,
- PaginatedResponse,
- KeywordsResponse,
- ModelPricingResponse,
- )
# Integer codes stored in the database ``categories`` column. The names for
# LLM and MULTIMODAL come from the original mapping comment; the remaining
# names are derived from the tab labels they are mapped from below.
_CAT_LLM = 0
_CAT_MULTIMODAL = 1
_CAT_TTS = 2
_CAT_STT = 3
_CAT_IMAGE_GENERATION = 4
_CAT_VIDEO_GENERATION = 5
_CAT_IMAGE_EDITING = 6
_CAT_EMBEDDING = 7
_CAT_RERANK = 8

# Maps a category tab name used by the frontend (ModelSquare) to the list of
# database ``categories`` enum values that the tab should match.
CATEGORY_NAME_TO_VALUE = {
    # "Language models" covers both LLM and MULTIMODAL, because multimodal
    # models can also be used for chat.
    "语言模型": [_CAT_LLM, _CAT_MULTIMODAL],
    "多模态模型": [_CAT_MULTIMODAL],
    "TTS模型": [_CAT_TTS],
    "STT模型": [_CAT_STT],
    "生图模型": [_CAT_IMAGE_GENERATION],
    "生视频模型": [_CAT_VIDEO_GENERATION],
    "图像编辑": [_CAT_IMAGE_EDITING],
    "Embedding": [_CAT_EMBEDDING],
    "Rerank": [_CAT_RERANK],
    # Legacy short names kept for backward compatibility.
    "文本": [_CAT_LLM],
    # NOTE(review): the legacy "image" name maps to code 1, not to the image
    # generation code 4 -- confirm this is intentional.
    "图像": [_CAT_MULTIMODAL],
    "视频": [_CAT_VIDEO_GENERATION],
    "语音": [_CAT_TTS],
}
class ModelService:
    """Query service for models and their prices (new table schema).

    Wraps a SQLAlchemy ``Session`` and converts ``ModelNew`` /
    ``ModelPriceNew`` rows into the response schemas returned to callers.
    """

    # Model code that is always sorted to the top of the model list.
    _PINNED_MODEL_CODE = "qwen3-max"

    def __init__(self, db: Session):
        """Keep the session that every query on this service runs against."""
        self.db = db

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _escape_like(value: str) -> str:
        """Escape LIKE wildcards in *value* so it is matched literally.

        Backslash is used as the escape character, so it is escaped first;
        callers must pass ``escape="\\\\"`` to ``like`` / ``ilike``.
        """
        return (
            value.replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )

    @staticmethod
    def _normalize_model_code(value: str) -> str:
        """Return the normalized form of a model code used for key lookup.

        Strips surrounding whitespace, removes zero-width characters,
        removes whitespace around hyphens and collapses remaining
        whitespace runs into a single space.
        """
        import re

        value = value.strip()
        value = re.sub(r"[\u200B-\u200D\uFEFF]", "", value)
        value = re.sub(r"\s*-\s*", "-", value)
        return re.sub(r"\s+", " ", value)

    @staticmethod
    def _build_price_response(price: ModelPriceNew) -> ModelPriceNewResponse:
        """Convert one price row into its response schema."""
        return ModelPriceNewResponse(
            id=price.id,
            label=price.label,
            tier_min=price.tier_min,
            tier_max=price.tier_max,
            tier_unit=price.tier_unit,
            input_price_original=price.input_price_original,
            output_price_original=price.output_price_original,
            discount_rate=price.discount_rate,
            discount_label=price.discount_label,
            input_price_discounted=price.input_price_discounted,
            output_price_discounted=price.output_price_discounted,
            currency=price.currency,
            unit=price.unit,
            display_multiplier=price.display_multiplier,
        )

    def _build_list_response(self, item: ModelNew) -> ModelListResponse:
        """Convert a model row (with its ``prices``) into a list response."""
        prices = [self._build_price_response(p) for p in (item.prices or [])]
        return ModelListResponse(
            id=item.id,
            model_code=item.model_code,
            display_name=item.display_name,
            img=item.img,
            tag1=item.tag1,
            tag2=item.tag2,
            description=item.description,
            custom_description=item.custom_description,
            keywords=item.keywords,
            display_tags=item.display_tags,
            categories=item.categories or [],
            # Single-value field: first category, or 0 when there is none.
            category=item.categories[0] if item.categories else 0,
            supplier=item.supplier,
            is_featured=item.is_featured,
            is_search=item.is_search,
            is_thinking=item.is_thinking,
            is_api_enabled=item.is_api_enabled,
            is_show_enabled=item.is_show_enabled,
            created_at=item.created_at,
            updated_at=item.updated_at,
            # Frontend compatibility fields.
            title=item.model_code,
            name=item.display_name or item.model_code,
            keyword=item.keywords,
            prices=prices,
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def get_models(
        self,
        page: int = 1,
        page_size: int = 20,
        keyword: Optional[str] = None,
        category: Optional[str] = None,
        supplier: Optional[str] = None,
        group_name: Optional[str] = None,
        filter_keyword: Optional[str] = None,
        filter_tag: Optional[str] = None,
        is_api_enabled: Optional[bool] = None,
    ) -> PaginatedResponse[ModelListResponse]:
        """Return one page of visible models matching the given filters.

        Args:
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Rows per page; negative values are treated as 0.
            keyword: Case-insensitive substring searched in model code,
                display name, description, custom description and keywords.
                LIKE wildcards in the input are matched literally.
            category: Category tab name; looked up in
                ``CATEGORY_NAME_TO_VALUE``. Unknown names apply no filter.
            supplier: Exact supplier match.
            group_name: Exact group name match.
            filter_keyword: Exact match against the ``keywords`` column.
            filter_tag: Exact match against ``tag1`` or ``tag2``.
            is_api_enabled: When not ``None``, filter on this flag.

        Returns:
            A paginated response whose ``total`` is the number of matching
            rows across all pages.
        """
        # A non-positive page or negative size would yield a negative
        # OFFSET/LIMIT, which the database rejects.
        page = max(page, 1)
        page_size = max(page_size, 0)

        # ``== True`` builds a SQL expression; it is not a Python comparison.
        query = self.db.query(ModelNew).filter(
            ModelNew.is_show_enabled == True  # noqa: E712
        )

        search = keyword.strip() if keyword else ""
        if search:
            pattern = f"%{self._escape_like(search)}%"
            query = query.filter(
                or_(
                    ModelNew.model_code.ilike(pattern, escape="\\"),
                    ModelNew.display_name.ilike(pattern, escape="\\"),
                    ModelNew.description.ilike(pattern, escape="\\"),
                    ModelNew.custom_description.ilike(pattern, escape="\\"),
                    ModelNew.keywords.ilike(pattern, escape="\\"),
                )
            )

        category_name = category.strip() if category else ""
        if category_name:
            values = CATEGORY_NAME_TO_VALUE.get(category_name)
            if values:
                from sqlalchemy import cast
                from sqlalchemy.dialects.postgresql import ARRAY, INTEGER

                # A tab may map to several category codes (e.g. language
                # models include LLM and MULTIMODAL): match any of them.
                query = query.filter(
                    or_(
                        *[
                            ModelNew.categories.contains(
                                cast([value], ARRAY(INTEGER))
                            )
                            for value in values
                        ]
                    )
                )

        if supplier and supplier.strip():
            query = query.filter(ModelNew.supplier == supplier.strip())
        if group_name and group_name.strip():
            query = query.filter(ModelNew.group_name == group_name.strip())
        if filter_keyword and filter_keyword.strip():
            query = query.filter(ModelNew.keywords == filter_keyword.strip())
        if filter_tag and filter_tag.strip():
            tag = filter_tag.strip()
            query = query.filter(
                or_(ModelNew.tag1 == tag, ModelNew.tag2 == tag)
            )
        if is_api_enabled is not None:
            query = query.filter(ModelNew.is_api_enabled == is_api_enabled)

        # Count before ordering / eager loading: neither affects the total.
        total = query.count()

        items = (
            query.options(selectinload(ModelNew.prices))
            .order_by(
                # Pinned model first, then API-enabled models, then newest.
                case(
                    (ModelNew.model_code == self._PINNED_MODEL_CODE, 0),
                    else_=1,
                ),
                ModelNew.is_api_enabled.desc(),
                ModelNew.created_at.desc(),
                # Unique tie-breaker so pages stay stable when several rows
                # share the same ``created_at``.
                ModelNew.id.desc(),
            )
            .offset((page - 1) * page_size)
            .limit(page_size)
            .all()
        )

        return PaginatedResponse(
            total=total,
            page=page,
            page_size=page_size,
            items=[self._build_list_response(item) for item in items],
        )

    def get_model_by_id(self, model_id: int) -> ModelListResponse:
        """Return the model with the given primary key.

        Raises:
            HTTPException: 404 when no model has this id.
        """
        from fastapi import HTTPException

        model = (
            self.db.query(ModelNew)
            .options(selectinload(ModelNew.prices))
            .filter(ModelNew.id == model_id)
            .first()
        )
        if not model:
            raise HTTPException(status_code=404, detail="Model not found")
        return self._build_list_response(model)

    def get_keywords(self) -> KeywordsResponse:
        """Return the distinct, sorted, non-empty keywords of visible models."""
        rows = (
            self.db.query(ModelNew.keywords)
            .filter(
                ModelNew.keywords.isnot(None),
                ModelNew.keywords != "",
                ModelNew.is_show_enabled == True,  # noqa: E712
            )
            .distinct()
            .order_by(ModelNew.keywords)
            .all()
        )
        return KeywordsResponse(keywords=[row[0] for row in rows if row[0]])

    def get_group_names(self) -> List[str]:
        """Return the distinct, sorted group names of visible non-local models."""
        rows = (
            self.db.query(ModelNew.group_name)
            .filter(
                ModelNew.group_name.isnot(None),
                ModelNew.group_name != "",
                ModelNew.is_show_enabled == True,  # noqa: E712
                ModelNew.is_local == False,  # noqa: E712
            )
            .distinct()
            .order_by(ModelNew.group_name)
            .all()
        )
        return [row[0] for row in rows if row[0]]

    def get_featured_models(self, limit: int = 3) -> List[ModelListResponse]:
        """Return up to *limit* visible models flagged as featured.

        NOTE(review): the query has no ORDER BY, so which rows are returned
        when more than *limit* models are featured is up to the database --
        confirm whether a specific order is wanted.
        """
        items = (
            self.db.query(ModelNew)
            .options(selectinload(ModelNew.prices))
            .filter(
                ModelNew.is_featured == True,  # noqa: E712
                ModelNew.is_show_enabled == True,  # noqa: E712
            )
            .limit(limit)
            .all()
        )
        return [self._build_list_response(item) for item in items]

    def get_pricing(self, model_code: str) -> ModelPricingResponse:
        """Return pricing information for a model.

        The model is looked up by exact ``model_code`` first; if that fails,
        by membership of the raw code in ``source_keys`` or of the
        normalized code in ``normalized_keys``.

        Returns:
            The model's details with its active prices ordered by label.
            When no model matches, a response echoing *model_code* with
            empty/``None`` fields, ``is_api_enabled=False`` and no prices.
        """
        normalized = self._normalize_model_code(model_code)

        # 1) Exact match on model_code.
        model = (
            self.db.query(ModelNew)
            .filter(ModelNew.model_code == model_code)
            .first()
        )

        # 2) Fall back to the source_keys / normalized_keys arrays.
        if model is None:
            model = (
                self.db.query(ModelNew)
                .filter(
                    or_(
                        ModelNew.source_keys.any(model_code),
                        ModelNew.normalized_keys.any(normalized),
                    )
                )
                .first()
            )

        if model is None:
            return ModelPricingResponse(
                model_code=model_code,
                display_name=None,
                description=None,
                custom_description=None,
                features=None,
                rate_limits=None,
                tool_call_prices=None,
                display_tags=None,
                input_modalities=None,
                output_modalities=None,
                categories=[],
                source_keys=None,
                normalized_keys=None,
                is_api_enabled=False,
                prices=[],
            )

        price_rows = (
            self.db.query(ModelPriceNew)
            .filter(
                ModelPriceNew.model_code == model.model_code,
                ModelPriceNew.is_active == True,  # noqa: E712
            )
            .order_by(ModelPriceNew.label)
            .all()
        )

        return ModelPricingResponse(
            model_code=model.model_code,
            display_name=model.display_name,
            description=model.description,
            custom_description=model.custom_description,
            features=model.features,
            rate_limits=model.rate_limits,
            tool_call_prices=model.tool_call_prices,
            display_tags=model.display_tags,
            input_modalities=model.input_modalities,
            output_modalities=model.output_modalities,
            categories=model.categories,
            source_keys=model.source_keys,
            normalized_keys=model.normalized_keys,
            is_api_enabled=model.is_api_enabled,
            prices=[self._build_price_response(p) for p in price_rows],
        )
|