public.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
from __future__ import annotations

import json
import logging
from datetime import datetime
from typing import List, Optional, Union
from urllib.parse import urlparse

from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel

from app.db import get_pool
from app.services.geo import geo_resolver
from app.utils.apikey_crypto import try_encrypt
from app.utils.price_parser import parse_prices
# Router holding the public price endpoint(s) defined in this module.
router = APIRouter()
class PublicPriceOut(BaseModel):
    """One model's price snapshot entry as returned by GET /prices.

    Field order is part of the serialized response; do not reorder.
    """

    url: str
    model_name: str
    prices: dict  # raw prices JSON from price_snapshot.prices ({} when NULL)
    model_info: Optional[dict] = None
    rate_limits: Optional[dict] = None
    tool_prices: Optional[list] = None
    icon: Optional[str] = None
    scraped_at: datetime  # populated from price_snapshot.updated_at
    discount: Optional[float] = None  # effective discount: per-model override, else the domain rate
    api_key: Optional[str] = None  # value passed through try_encrypt before being returned
    group_name: Optional[str] = None  # name from model_groups, if the model belongs to one
class ParsedPriceItem(BaseModel):
    """A single undiscounted price entry produced by ``parse_prices`` for one model.

    Apart from ``url``/``model_name``, the fields are filled from the dicts
    ``parse_prices`` yields. Field order is part of the serialized response.
    """

    url: str
    model_name: str
    # Tier bounds and their unit as emitted by parse_prices
    # (presumably usage-volume tiers — TODO confirm against parse_prices).
    tier_min: Optional[float] = None
    tier_max: Optional[float] = None
    tier_unit: Optional[str] = None
    input_price: Optional[float] = None
    output_price: Optional[float] = None
    currency: str = "CNY"
    unit: Optional[str] = None
    label: Optional[str] = None
  36. class DiscountedPriceItem(BaseModel):
  37. url: str
  38. model_name: str
  39. tier_min: Optional[float] = None
  40. tier_max: Optional[float] = None
  41. tier_unit: Optional[str] = None
  42. input_price: Optional[float] = None
  43. output_price: Optional[float] = None
  44. currency: str = "CNY"
  45. unit: Optional[str] = None
  46. label: Optional[str] = None
  47. discount: Optional[float] = None # None 表示未知,1.0 表示原价
class ModelTypeItem(BaseModel):
    """Capability-type tags for one model (possibly an empty list)."""

    model_name: str
    # Named ``type`` because it is the JSON key clients read; it shadows the
    # builtin only inside the class body.
    type: List[str]
class PricesResponse(BaseModel):
    """Full GET /prices payload, returned when the client's version is stale or 0."""

    version: int  # snapshot version the payload corresponds to
    models: List[PublicPriceOut]
    parsed_prices: List[ParsedPriceItem]  # undiscounted, flattened price entries
    discounted_prices: List[DiscountedPriceItem]  # same entries with the discount applied
    types: List[ModelTypeItem]
class UpToDateResponse(BaseModel):
    """GET /prices reply when the client's non-zero version already matches."""

    up_to_date: bool = True
    version: int  # the current (matching) snapshot version
  60. def _extract_domain(referer: Optional[str]) -> Optional[str]:
  61. if not referer:
  62. return None
  63. try:
  64. return urlparse(referer).netloc or None
  65. except Exception:
  66. return None
  67. @router.get("/prices", response_model=Union[PricesResponse, UpToDateResponse])
  68. async def get_public_prices(
  69. request: Request,
  70. url: Optional[str] = None,
  71. ) -> Union[PricesResponse, UpToDateResponse]:
  72. pool = get_pool()
  73. # referer 必须提供
  74. referer = request.headers.get("referer") or request.headers.get("origin")
  75. if not referer:
  76. raise HTTPException(status_code=400, detail="Missing Referer header")
  77. # version 从 Header 读取,默认 0(首次请求)
  78. try:
  79. version = int(request.headers.get("version", "0") or "0")
  80. except ValueError:
  81. version = 0
  82. # 记录调用来源
  83. ip = request.client.host if request.client else "unknown"
  84. geo = geo_resolver.resolve(ip)
  85. try:
  86. await pool.execute(
  87. "INSERT INTO price_api_logs (ip, referer, org, country, city) VALUES ($1, $2, $3, $4, $5)",
  88. ip, referer, geo.org, geo.country, geo.city,
  89. )
  90. except Exception:
  91. pass
  92. # 查调用方域名对应的折扣
  93. caller_domain = _extract_domain(referer)
  94. discount_rate: float = 1.0
  95. # 域名级别的模型自定义价格:{model_name: {input_price, output_price}}
  96. model_custom_prices: dict = {}
  97. if caller_domain:
  98. row = await pool.fetchrow("SELECT discount FROM discounts WHERE domain = $1", caller_domain)
  99. if row:
  100. discount_rate = float(row["discount"])
  101. # 加载该域名下所有模型的自定义折扣
  102. mp_rows = await pool.fetch(
  103. "SELECT model_name, discount FROM domain_model_prices WHERE domain = $1",
  104. caller_domain,
  105. )
  106. for mp in mp_rows:
  107. model_custom_prices[mp["model_name"]] = float(mp["discount"])
  108. def _j(v):
  109. if v is None:
  110. return None
  111. return v if isinstance(v, (dict, list)) else json.loads(v)
  112. # 读取版本号:优先用域名专属版本,没有则回退到全局版本
  113. if caller_domain:
  114. ver_row = await pool.fetchrow(
  115. "SELECT version FROM domain_version WHERE domain = $1", caller_domain
  116. )
  117. if ver_row:
  118. current_version = int(ver_row["version"])
  119. else:
  120. ver_row = await pool.fetchrow("SELECT version FROM price_snapshot_version WHERE id = 1")
  121. current_version = int(ver_row["version"]) if ver_row else 0
  122. else:
  123. ver_row = await pool.fetchrow("SELECT version FROM price_snapshot_version WHERE id = 1")
  124. current_version = int(ver_row["version"]) if ver_row else 0
  125. # version != 0 且与当前一致 → 无需更新(0 视为首次请求,强制返回数据)
  126. if version != 0 and version == current_version:
  127. return UpToDateResponse(up_to_date=True, version=current_version)
  128. # 从 price_snapshot 读取数据,LEFT JOIN models 取 api_key 和 group
  129. if url is None:
  130. rows = await pool.fetch(
  131. """
  132. SELECT ps.url, ps.model_name, ps.prices, ps.model_info, ps.rate_limits,
  133. ps.tool_prices, ps.icon, ps.updated_at,
  134. k.key_value AS api_key,
  135. m.group_id, g.name AS group_name
  136. FROM price_snapshot ps
  137. LEFT JOIN models m ON m.url = ps.url
  138. LEFT JOIN api_keys k ON k.id = m.api_key_id
  139. LEFT JOIN model_groups g ON g.id = m.group_id
  140. ORDER BY ps.url
  141. """
  142. )
  143. else:
  144. rows = await pool.fetch(
  145. """
  146. SELECT ps.url, ps.model_name, ps.prices, ps.model_info, ps.rate_limits,
  147. ps.tool_prices, ps.icon, ps.updated_at,
  148. k.key_value AS api_key,
  149. m.group_id, g.name AS group_name
  150. FROM price_snapshot ps
  151. LEFT JOIN models m ON m.url = ps.url
  152. LEFT JOIN api_keys k ON k.id = m.api_key_id
  153. LEFT JOIN model_groups g ON g.id = m.group_id
  154. WHERE ps.url = $1
  155. """,
  156. url,
  157. )
  158. if not rows:
  159. raise HTTPException(status_code=404, detail="No price snapshot found for the given URL")
  160. if not rows:
  161. raise HTTPException(status_code=503, detail="Price snapshot not yet available")
  162. # version != 0 且与当前一致 → 无需更新
  163. if version != 0 and version == current_version:
  164. return UpToDateResponse(up_to_date=True, version=current_version)
  165. def _extract_type(model_info: Optional[dict]) -> Optional[List[str]]:
  166. if not model_info:
  167. return None
  168. tags = model_info.get("display_tags", [])
  169. TYPE_TAGS = {"文本生成", "图像生成", "视频生成", "向量表示", "向量模型", "多模态向量", "语音识别", "实时语音识别", "语音合成"}
  170. result = [t for t in tags if t in TYPE_TAGS]
  171. return result if result else None
  172. models = [PublicPriceOut(
  173. url=r["url"],
  174. model_name=r["model_name"],
  175. prices=_j(r["prices"]) or {},
  176. model_info=_j(r["model_info"]),
  177. rate_limits=_j(r["rate_limits"]),
  178. tool_prices=_j(r["tool_prices"]),
  179. icon=r["icon"],
  180. scraped_at=r["updated_at"],
  181. discount=model_custom_prices.get(r["model_name"]) if model_custom_prices.get(r["model_name"]) is not None else discount_rate,
  182. api_key=try_encrypt(r["api_key"]),
  183. group_name=r["group_name"],
  184. ) for r in rows]
  185. parsed_prices: List[ParsedPriceItem] = []
  186. discounted_prices: List[DiscountedPriceItem] = []
  187. for r in rows:
  188. for item in parse_prices(_j(r["prices"]) or {}):
  189. parsed_prices.append(ParsedPriceItem(url=r["url"], model_name=r["model_name"], **item))
  190. d_item = dict(item)
  191. model_name = r["model_name"]
  192. custom = model_custom_prices.get(model_name)
  193. # 模型级折扣优先,没有则用域名全局折扣
  194. effective_discount = custom if custom is not None else discount_rate
  195. if effective_discount is not None:
  196. if d_item.get("input_price") is not None:
  197. d_item["input_price"] = round(d_item["input_price"] * effective_discount, 6)
  198. if d_item.get("output_price") is not None:
  199. d_item["output_price"] = round(d_item["output_price"] * effective_discount, 6)
  200. discounted_prices.append(DiscountedPriceItem(url=r["url"], model_name=r["model_name"], discount=effective_discount, **d_item))
  201. all_types = [
  202. ModelTypeItem(model_name=r["model_name"], type=_extract_type(_j(r["model_info"])) or [])
  203. for r in rows
  204. ]
  205. return PricesResponse(
  206. version=current_version,
  207. models=models,
  208. parsed_prices=parsed_prices,
  209. discounted_prices=discounted_prices,
  210. types=all_types,
  211. )