model_router.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. """
  2. 模型API路由(新表结构)
  3. """
  4. import json
  5. import hashlib
  6. import logging
  7. from typing import Optional
  8. from fastapi import APIRouter, Depends, Query, HTTPException
  9. from sqlalchemy.orm import Session
  10. from app.database import get_db
  11. from app.services.model_service import ModelService
  12. from app.schemas.model_schema import (
  13. ApiResponse, PaginatedResponse, ModelListResponse,
  14. KeywordsResponse, FeaturedModelsResponse, ModelPricingResponse,
  15. )
  16. router = APIRouter(prefix="/api/models", tags=["模型广场"])
  17. logger = logging.getLogger(__name__)
  18. # 缓存 TTL(秒)
  19. _MODEL_LIST_TTL = 1800 # 30 分钟
  20. _KEYWORDS_TTL = 3600 # 1 小时
  21. _FEATURED_TTL = 3600 # 1 小时
  22. _PRICING_TTL = 7200 # 2 小时
  23. def _get_sync_redis():
  24. try:
  25. from app.core.redis import redis_manager
  26. return redis_manager.get_sync_client()
  27. except Exception:
  28. return None
  29. def _cache_get(key: str):
  30. r = _get_sync_redis()
  31. if not r:
  32. return None
  33. try:
  34. raw = r.get(key)
  35. return json.loads(raw) if raw else None
  36. except Exception:
  37. return None
  38. def _cache_set(key: str, value, ttl: int):
  39. r = _get_sync_redis()
  40. if not r:
  41. return
  42. try:
  43. r.setex(key, ttl, json.dumps(value, default=str))
  44. except Exception as e:
  45. logger.debug(f"模型缓存写入失败: {e}")
  46. def _model_list_cache_key(page, page_size, keyword, category, supplier,
  47. group_name, filter_keyword, filter_tag, is_api_enabled) -> str:
  48. params = f"{page}:{page_size}:{keyword}:{category}:{supplier}:{group_name}:{filter_keyword}:{filter_tag}:{is_api_enabled}"
  49. return f"model_list:{hashlib.md5(params.encode()).hexdigest()}"
  50. @router.get("/keywords/list", response_model=ApiResponse[KeywordsResponse])
  51. def get_keywords(db: Session = Depends(get_db)):
  52. cached = _cache_get("model_keywords")
  53. if cached:
  54. return ApiResponse(code=200, message="success", data=cached)
  55. service = ModelService(db)
  56. data = service.get_keywords()
  57. _cache_set("model_keywords", data.model_dump(), _KEYWORDS_TTL)
  58. return ApiResponse(code=200, message="success", data=data)
  59. @router.get("/featured/list", response_model=ApiResponse[FeaturedModelsResponse])
  60. def get_featured_models(
  61. limit: int = Query(default=3, ge=1, le=10),
  62. db: Session = Depends(get_db),
  63. ):
  64. cache_key = f"model_featured:{limit}"
  65. cached = _cache_get(cache_key)
  66. if cached:
  67. return ApiResponse(code=200, message="success", data=cached)
  68. service = ModelService(db)
  69. items = service.get_featured_models(limit=limit)
  70. resp = FeaturedModelsResponse(items=items)
  71. _cache_set(cache_key, resp.model_dump(), _FEATURED_TTL)
  72. return ApiResponse(code=200, message="success", data=resp)
  73. @router.get("/pricing/{model_code}")
  74. def get_model_pricing(model_code: str, db: Session = Depends(get_db)):
  75. lookup = model_code.strip()
  76. if not lookup:
  77. raise HTTPException(status_code=400, detail="model_code 不能为空")
  78. cache_key = f"model_pricing:{lookup}"
  79. cached = _cache_get(cache_key)
  80. if cached:
  81. return ApiResponse(code=200, message="success", data=cached)
  82. service = ModelService(db)
  83. data = service.get_pricing(lookup)
  84. model_capabilities = None
  85. if data.features or data.input_modalities or data.output_modalities:
  86. model_capabilities = {
  87. "features": data.features or {},
  88. "input_modalities": data.input_modalities or [],
  89. "output_modalities": data.output_modalities or [],
  90. }
  91. model_pricing = None
  92. if data.prices:
  93. if len(data.prices) > 1:
  94. tiers = []
  95. for p in data.prices:
  96. in_orig = float(p.input_price_original)
  97. out_orig = float(p.output_price_original)
  98. in_disc = float(p.input_price_discounted)
  99. out_disc = float(p.output_price_discounted)
  100. rate = float(p.discount_rate)
  101. tier = {
  102. "input_range": p.label,
  103. "input": in_disc,
  104. "output": out_disc,
  105. "unit": p.unit,
  106. }
  107. if rate < 0.9999:
  108. if in_orig > 0:
  109. tier["input_original"] = in_orig
  110. if out_orig > 0:
  111. tier["output_original"] = out_orig
  112. tiers.append(tier)
  113. model_pricing = tiers
  114. else:
  115. p = data.prices[0]
  116. in_orig = float(p.input_price_original)
  117. out_orig = float(p.output_price_original)
  118. in_disc = float(p.input_price_discounted)
  119. out_disc = float(p.output_price_discounted)
  120. rate = float(p.discount_rate)
  121. pricing = {
  122. "input": in_disc,
  123. "output": out_disc,
  124. "unit": p.unit,
  125. }
  126. if rate < 0.9999:
  127. if in_orig > 0:
  128. pricing["input_original"] = in_orig
  129. if out_orig > 0:
  130. pricing["output_original"] = out_orig
  131. model_pricing = pricing
  132. compat_response = {
  133. "model_code": data.model_code,
  134. "model_intro": data.custom_description or data.description,
  135. "model_tags": data.display_tags,
  136. "model_capabilities": model_capabilities,
  137. "model_pricing": model_pricing,
  138. "tool_call_pricing": data.tool_call_prices,
  139. "model_limits": data.rate_limits or None,
  140. "api_examples": None,
  141. "is_api_enabled": data.is_api_enabled,
  142. "categories": data.categories or [],
  143. }
  144. _cache_set(cache_key, compat_response, _PRICING_TTL)
  145. return ApiResponse(code=200, message="success", data=compat_response)
  146. @router.get("/group-names/list", response_model=ApiResponse[list[str]])
  147. def get_group_names(db: Session = Depends(get_db)):
  148. """获取所有模型分组名称(去重排序)"""
  149. service = ModelService(db)
  150. return ApiResponse(code=200, message="success", data=service.get_group_names())
  151. @router.get("/{model_id}", response_model=ApiResponse[ModelListResponse])
  152. def get_model_by_id(model_id: int, db: Session = Depends(get_db)):
  153. service = ModelService(db)
  154. data = service.get_model_by_id(model_id)
  155. return ApiResponse(code=200, message="success", data=data)
  156. @router.get("", response_model=ApiResponse[PaginatedResponse[ModelListResponse]])
  157. def get_models(
  158. page: int = Query(default=1, ge=1),
  159. page_size: int = Query(default=20, ge=1, le=100),
  160. keyword: Optional[str] = Query(default=None),
  161. category: Optional[str] = Query(default=None),
  162. supplier: Optional[str] = Query(default=None),
  163. group_name: Optional[str] = Query(default=None),
  164. filter_keyword: Optional[str] = Query(default=None),
  165. filter_tag: Optional[str] = Query(default=None),
  166. is_api_enabled: Optional[bool] = Query(default=None),
  167. db: Session = Depends(get_db),
  168. ):
  169. cache_key = _model_list_cache_key(page, page_size, keyword, category, supplier,
  170. group_name, filter_keyword, filter_tag, is_api_enabled)
  171. cached = _cache_get(cache_key)
  172. if cached:
  173. return ApiResponse(code=200, message="success", data=cached)
  174. service = ModelService(db)
  175. data = service.get_models(
  176. page=page, page_size=page_size, keyword=keyword,
  177. category=category, supplier=supplier, group_name=group_name,
  178. filter_keyword=filter_keyword, filter_tag=filter_tag,
  179. is_api_enabled=is_api_enabled,
  180. )
  181. _cache_set(cache_key, data.model_dump(), _MODEL_LIST_TTL)
  182. return ApiResponse(code=200, message="success", data=data)