| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213 |
- """
- 模型API路由(新表结构)
- """
- import json
- import hashlib
- import logging
- from typing import Optional
- from fastapi import APIRouter, Depends, Query, HTTPException
- from sqlalchemy.orm import Session
- from app.database import get_db
- from app.services.model_service import ModelService
- from app.schemas.model_schema import (
- ApiResponse, PaginatedResponse, ModelListResponse,
- KeywordsResponse, FeaturedModelsResponse, ModelPricingResponse,
- )
router = APIRouter(prefix="/api/models", tags=["模型广场"])
logger = logging.getLogger(__name__)

# Redis cache lifetimes for the endpoints below, expressed in seconds.
_MODEL_LIST_TTL = 30 * 60       # 30 minutes
_KEYWORDS_TTL = 60 * 60         # 1 hour
_FEATURED_TTL = 60 * 60         # 1 hour
_PRICING_TTL = 2 * 60 * 60      # 2 hours
- def _get_sync_redis():
- try:
- from app.core.redis import redis_manager
- return redis_manager.get_sync_client()
- except Exception:
- return None
def _cache_get(key: str):
    """Read a value from Redis and JSON-decode it.

    Args:
        key: Redis key to look up.

    Returns:
        The decoded JSON value. Returns ``None`` in any of these cases:
        - a cache miss;
        - Redis is unavailable;
        - the read or the decode fails.

    Failures are logged at debug level, the same way ``_cache_set`` reports
    write failures, instead of being raised. The cache is only an
    optimisation, so it must not break the request.
    """
    r = _get_sync_redis()
    if not r:
        return None
    try:
        raw = r.get(key)
    except Exception as e:
        # Connection/timeout errors etc.: degrade to a cache miss.
        logger.debug("模型缓存读取失败: %s", e)
        return None
    if not raw:
        return None
    try:
        return json.loads(raw)
    except (TypeError, ValueError) as e:
        # JSONDecodeError / UnicodeDecodeError are ValueError subclasses.
        logger.debug("模型缓存解析失败: %s", e)
        return None
def _cache_set(key: str, value, ttl: int) -> None:
    """JSON-encode *value* and store it under *key*, expiring after *ttl* seconds.

    This is best-effort. It does nothing when Redis is unavailable. When
    serialisation or the write fails, it logs at debug level instead of
    raising.

    NOTE(review): ``default=str`` stringifies anything ``json`` cannot encode
    natively (e.g. Decimal, datetime). A later cache hit may therefore return
    strings where the freshly computed value had richer types. Confirm that
    callers only pass JSON-compatible data, or that consumers re-validate it.
    """
    r = _get_sync_redis()
    if not r:
        return
    try:
        r.setex(key, ttl, json.dumps(value, default=str))
    except Exception as e:
        # Lazy %-style args: the message is only formatted if DEBUG is enabled.
        logger.debug("模型缓存写入失败: %s", e)
- def _model_list_cache_key(page, page_size, keyword, category, supplier,
- group_name, filter_keyword, filter_tag, is_api_enabled) -> str:
- params = f"{page}:{page_size}:{keyword}:{category}:{supplier}:{group_name}:{filter_keyword}:{filter_tag}:{is_api_enabled}"
- return f"model_list:{hashlib.md5(params.encode()).hexdigest()}"
@router.get("/keywords/list", response_model=ApiResponse[KeywordsResponse])
def get_keywords(db: Session = Depends(get_db)):
    """Return the model keyword list, preferring the Redis copy when present.

    On a cache miss, the keywords are loaded through ``ModelService``. The
    result is then cached under ``"model_keywords"`` for ``_KEYWORDS_TTL``
    seconds.
    """
    cache_key = "model_keywords"
    payload = _cache_get(cache_key)
    if not payload:
        payload = ModelService(db).get_keywords()
        _cache_set(cache_key, payload.model_dump(), _KEYWORDS_TTL)
    return ApiResponse(code=200, message="success", data=payload)
@router.get("/featured/list", response_model=ApiResponse[FeaturedModelsResponse])
def get_featured_models(
    limit: int = Query(default=3, ge=1, le=10),
    db: Session = Depends(get_db),
):
    """Return up to ``limit`` featured models.

    Each ``limit`` value has its own cache entry, kept for ``_FEATURED_TTL``
    seconds.
    """
    key = f"model_featured:{limit}"
    hit = _cache_get(key)
    if hit:
        return ApiResponse(code=200, message="success", data=hit)

    featured = FeaturedModelsResponse(
        items=ModelService(db).get_featured_models(limit=limit),
    )
    _cache_set(key, featured.model_dump(), _FEATURED_TTL)
    return ApiResponse(code=200, message="success", data=featured)
# Rates below this value count as "discounted". The tolerance just under 1.0
# presumably absorbs float rounding of a stored 1.0 rate — confirm with the
# pricing data owners.
_NO_DISCOUNT_THRESHOLD = 0.9999


def _format_price_entry(price, include_range: bool) -> dict:
    """Convert one row of ``data.prices`` into the response's price dict.

    Key order:
    - ``input_range`` (the row's ``label``), present only when
      *include_range* is true, i.e. for tiered pricing;
    - ``input`` / ``output``: the discounted prices;
    - ``unit``;
    - ``input_original`` / ``output_original``: present only when the row is
      discounted, and each only if that original price is positive.
    """
    in_orig = float(price.input_price_original)
    out_orig = float(price.output_price_original)
    rate = float(price.discount_rate)

    entry = {"input_range": price.label} if include_range else {}
    entry["input"] = float(price.input_price_discounted)
    entry["output"] = float(price.output_price_discounted)
    entry["unit"] = price.unit
    if rate < _NO_DISCOUNT_THRESHOLD:
        if in_orig > 0:
            entry["input_original"] = in_orig
        if out_orig > 0:
            entry["output_original"] = out_orig
    return entry


@router.get("/pricing/{model_code}")
def get_model_pricing(model_code: str, db: Session = Depends(get_db)):
    """Return pricing, capabilities and limits for one model, in a compat dict.

    Args:
        model_code: Model identifier from the path. Surrounding whitespace is
            stripped before lookup.

    Returns:
        ``ApiResponse`` whose ``data`` is a plain dict. It is cached under
        ``model_pricing:<model_code>`` for ``_PRICING_TTL`` seconds.
        ``model_pricing`` is a single dict when there is one price row, and a
        list of tier dicts when there are several.

    Raises:
        HTTPException: 400 if ``model_code`` is blank after stripping.
    """
    lookup = model_code.strip()
    if not lookup:
        raise HTTPException(status_code=400, detail="model_code 不能为空")

    cache_key = f"model_pricing:{lookup}"
    cached = _cache_get(cache_key)
    if cached:
        return ApiResponse(code=200, message="success", data=cached)

    data = ModelService(db).get_pricing(lookup)

    model_capabilities = None
    if data.features or data.input_modalities or data.output_modalities:
        model_capabilities = {
            "features": data.features or {},
            "input_modalities": data.input_modalities or [],
            "output_modalities": data.output_modalities or [],
        }

    model_pricing = None
    if data.prices:
        if len(data.prices) > 1:
            # Tiered pricing: one labelled entry per tier.
            model_pricing = [
                _format_price_entry(p, include_range=True) for p in data.prices
            ]
        else:
            model_pricing = _format_price_entry(data.prices[0], include_range=False)

    compat_response = {
        "model_code": data.model_code,
        "model_intro": data.custom_description or data.description,
        "model_tags": data.display_tags,
        "model_capabilities": model_capabilities,
        "model_pricing": model_pricing,
        "tool_call_pricing": data.tool_call_prices,
        "model_limits": data.rate_limits or None,
        "api_examples": None,
        "is_api_enabled": data.is_api_enabled,
        "categories": data.categories or [],
    }
    _cache_set(cache_key, compat_response, _PRICING_TTL)
    return ApiResponse(code=200, message="success", data=compat_response)
@router.get("/group-names/list", response_model=ApiResponse[list[str]])
def get_group_names(db: Session = Depends(get_db)):
    """Return all model group names reported by ``ModelService.get_group_names``.

    The service is expected to return them de-duplicated and sorted. This
    endpoint does not use the Redis cache.
    """
    group_names = ModelService(db).get_group_names()
    return ApiResponse(code=200, message="success", data=group_names)
@router.get("/{model_id}", response_model=ApiResponse[ModelListResponse])
def get_model_by_id(model_id: int, db: Session = Depends(get_db)):
    """Return a single model looked up by integer ID (not cached)."""
    model = ModelService(db).get_model_by_id(model_id)
    return ApiResponse(code=200, message="success", data=model)
@router.get("", response_model=ApiResponse[PaginatedResponse[ModelListResponse]])
def get_models(
    page: int = Query(default=1, ge=1),
    page_size: int = Query(default=20, ge=1, le=100),
    keyword: Optional[str] = Query(default=None),
    category: Optional[str] = Query(default=None),
    supplier: Optional[str] = Query(default=None),
    group_name: Optional[str] = Query(default=None),
    filter_keyword: Optional[str] = Query(default=None),
    filter_tag: Optional[str] = Query(default=None),
    is_api_enabled: Optional[bool] = Query(default=None),
    db: Session = Depends(get_db),
):
    """Return one page of models matching the optional filters.

    Each combination of pagination and filter values has its own Redis
    cache entry, kept for ``_MODEL_LIST_TTL`` seconds.
    """
    # The same filter set feeds both the cache key and the service query.
    filters = dict(
        keyword=keyword,
        category=category,
        supplier=supplier,
        group_name=group_name,
        filter_keyword=filter_keyword,
        filter_tag=filter_tag,
        is_api_enabled=is_api_enabled,
    )
    cache_key = _model_list_cache_key(page, page_size, **filters)

    page_data = _cache_get(cache_key)
    if not page_data:
        page_data = ModelService(db).get_models(
            page=page, page_size=page_size, **filters,
        )
        _cache_set(cache_key, page_data.model_dump(), _MODEL_LIST_TTL)
    return ApiResponse(code=200, message="success", data=page_data)
|