""" 模型API路由(新表结构) """ import json import hashlib import logging from typing import Optional from fastapi import APIRouter, Depends, Query, HTTPException from sqlalchemy.orm import Session from app.database import get_db from app.services.model_service import ModelService from app.schemas.model_schema import ( ApiResponse, PaginatedResponse, ModelListResponse, KeywordsResponse, FeaturedModelsResponse, ModelPricingResponse, ) router = APIRouter(prefix="/api/models", tags=["模型广场"]) logger = logging.getLogger(__name__) # 缓存 TTL(秒) _MODEL_LIST_TTL = 1800 # 30 分钟 _KEYWORDS_TTL = 3600 # 1 小时 _FEATURED_TTL = 3600 # 1 小时 _PRICING_TTL = 7200 # 2 小时 def _get_sync_redis(): try: from app.core.redis import redis_manager return redis_manager.get_sync_client() except Exception: return None def _cache_get(key: str): r = _get_sync_redis() if not r: return None try: raw = r.get(key) return json.loads(raw) if raw else None except Exception: return None def _cache_set(key: str, value, ttl: int): r = _get_sync_redis() if not r: return try: r.setex(key, ttl, json.dumps(value, default=str)) except Exception as e: logger.debug(f"模型缓存写入失败: {e}") def _model_list_cache_key(page, page_size, keyword, category, supplier, group_name, filter_keyword, filter_tag, is_api_enabled) -> str: params = f"{page}:{page_size}:{keyword}:{category}:{supplier}:{group_name}:{filter_keyword}:{filter_tag}:{is_api_enabled}" return f"model_list:{hashlib.md5(params.encode()).hexdigest()}" @router.get("/keywords/list", response_model=ApiResponse[KeywordsResponse]) def get_keywords(db: Session = Depends(get_db)): cached = _cache_get("model_keywords") if cached: return ApiResponse(code=200, message="success", data=cached) service = ModelService(db) data = service.get_keywords() _cache_set("model_keywords", data.model_dump(), _KEYWORDS_TTL) return ApiResponse(code=200, message="success", data=data) @router.get("/featured/list", response_model=ApiResponse[FeaturedModelsResponse]) def get_featured_models( limit: int = Query(default=3, ge=1, le=10), db: Session = Depends(get_db), ): cache_key = f"model_featured:{limit}" cached = _cache_get(cache_key) if cached: return ApiResponse(code=200, message="success", data=cached) service = ModelService(db) items = service.get_featured_models(limit=limit) resp = FeaturedModelsResponse(items=items) _cache_set(cache_key, resp.model_dump(), _FEATURED_TTL) return ApiResponse(code=200, message="success", data=resp) @router.get("/pricing/{model_code}") def get_model_pricing(model_code: str, db: Session = Depends(get_db)): lookup = model_code.strip() if not lookup: raise HTTPException(status_code=400, detail="model_code 不能为空") cache_key = f"model_pricing:{lookup}" cached = _cache_get(cache_key) if cached: return ApiResponse(code=200, message="success", data=cached) service = ModelService(db) data = service.get_pricing(lookup) model_capabilities = None if data.features or data.input_modalities or data.output_modalities: model_capabilities = { "features": data.features or {}, "input_modalities": data.input_modalities or [], "output_modalities": data.output_modalities or [], } model_pricing = None if data.prices: if len(data.prices) > 1: tiers = [] for p in data.prices: in_orig = float(p.input_price_original) out_orig = float(p.output_price_original) in_disc = float(p.input_price_discounted) out_disc = float(p.output_price_discounted) rate = float(p.discount_rate) tier = { "input_range": p.label, "input": in_disc, "output": out_disc, "unit": p.unit, } if rate < 0.9999: if in_orig > 0: tier["input_original"] = in_orig if out_orig > 0: tier["output_original"] = out_orig tiers.append(tier) model_pricing = tiers else: p = data.prices[0] in_orig = float(p.input_price_original) out_orig = float(p.output_price_original) in_disc = float(p.input_price_discounted) out_disc = float(p.output_price_discounted) rate = float(p.discount_rate) pricing = { "input": in_disc, "output": out_disc, "unit": p.unit, } if rate < 0.9999: if in_orig > 0: pricing["input_original"] = in_orig if out_orig > 0: pricing["output_original"] = out_orig model_pricing = pricing compat_response = { "model_code": data.model_code, "model_intro": data.custom_description or data.description, "model_tags": data.display_tags, "model_capabilities": model_capabilities, "model_pricing": model_pricing, "tool_call_pricing": data.tool_call_prices, "model_limits": data.rate_limits or None, "api_examples": None, "is_api_enabled": data.is_api_enabled, "categories": data.categories or [], } _cache_set(cache_key, compat_response, _PRICING_TTL) return ApiResponse(code=200, message="success", data=compat_response) @router.get("/group-names/list", response_model=ApiResponse[list[str]]) def get_group_names(db: Session = Depends(get_db)): """获取所有模型分组名称(去重排序)""" service = ModelService(db) return ApiResponse(code=200, message="success", data=service.get_group_names()) @router.get("/{model_id}", response_model=ApiResponse[ModelListResponse]) def get_model_by_id(model_id: int, db: Session = Depends(get_db)): service = ModelService(db) data = service.get_model_by_id(model_id) return ApiResponse(code=200, message="success", data=data) @router.get("", response_model=ApiResponse[PaginatedResponse[ModelListResponse]]) def get_models( page: int = Query(default=1, ge=1), page_size: int = Query(default=20, ge=1, le=100), keyword: Optional[str] = Query(default=None), category: Optional[str] = Query(default=None), supplier: Optional[str] = Query(default=None), group_name: Optional[str] = Query(default=None), filter_keyword: Optional[str] = Query(default=None), filter_tag: Optional[str] = Query(default=None), is_api_enabled: Optional[bool] = Query(default=None), db: Session = Depends(get_db), ): cache_key = _model_list_cache_key(page, page_size, keyword, category, supplier, group_name, filter_keyword, filter_tag, is_api_enabled) cached = _cache_get(cache_key) if cached: return ApiResponse(code=200, message="success", data=cached) service = ModelService(db) data = service.get_models( page=page, page_size=page_size, keyword=keyword, category=category, supplier=supplier, group_name=group_name, filter_keyword=filter_keyword, filter_tag=filter_tag, is_api_enabled=is_api_enabled, ) _cache_set(cache_key, data.model_dump(), _MODEL_LIST_TTL) return ApiResponse(code=200, message="success", data=data)