public.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
from __future__ import annotations

import json
import logging
from datetime import datetime
from typing import List, Optional, Union
from urllib.parse import urlparse

from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel

from app.db import get_pool
from app.services.geo import geo_resolver
from app.utils.price_parser import parse_prices
# APIRouter on which this module's GET /prices endpoint is registered;
# presumably included into the FastAPI application elsewhere — verify.
router = APIRouter()
  12. class PublicPriceOut(BaseModel):
  13. url: str
  14. model_name: str
  15. prices: dict
  16. model_info: Optional[dict] = None
  17. rate_limits: Optional[dict] = None
  18. tool_prices: Optional[list] = None
  19. icon: Optional[str] = None
  20. scraped_at: datetime
  21. class ParsedPriceItem(BaseModel):
  22. url: str
  23. model_name: str
  24. tier_min: Optional[float] = None
  25. tier_max: Optional[float] = None
  26. tier_unit: Optional[str] = None
  27. input_price: Optional[float] = None
  28. output_price: Optional[float] = None
  29. currency: str = "CNY"
  30. unit: Optional[str] = None
  31. label: Optional[str] = None
  32. class DiscountedPriceItem(BaseModel):
  33. url: str
  34. model_name: str
  35. tier_min: Optional[float] = None
  36. tier_max: Optional[float] = None
  37. tier_unit: Optional[str] = None
  38. input_price: Optional[float] = None
  39. output_price: Optional[float] = None
  40. currency: str = "CNY"
  41. unit: Optional[str] = None
  42. label: Optional[str] = None
  43. discount: Optional[float] = None # None 表示无折扣(原价)
# Capability types of one model: the recognised type tags found in its
# model_info["display_tags"] (an empty list when none match).
class ModelTypeItem(BaseModel):
    model_name: str
    # `type` shadows the builtin, but it is part of the response's JSON contract.
    type: List[str]
# Full payload of GET /prices, returned when the client's `version` header is
# 0/missing or differs from the current snapshot version.
class PricesResponse(BaseModel):
    version: int  # current snapshot version; clients echo it in the `version` header
    models: List[PublicPriceOut]
    parsed_prices: List[ParsedPriceItem]  # list prices, no discount applied
    discounted_prices: List[DiscountedPriceItem]  # prices after the caller's discount
    types: List[ModelTypeItem]
    discount: float = 1.0  # caller's discount rate; 1.0 when none applies
# Short reply of GET /prices when the client's non-zero `version` header
# already equals the current snapshot version, so no payload is sent.
class UpToDateResponse(BaseModel):
    up_to_date: bool = True
    version: int
  57. def _extract_domain(referer: Optional[str]) -> Optional[str]:
  58. if not referer:
  59. return None
  60. try:
  61. return urlparse(referer).netloc or None
  62. except Exception:
  63. return None
  64. @router.get("/prices", response_model=Union[PricesResponse, UpToDateResponse])
  65. async def get_public_prices(
  66. request: Request,
  67. url: Optional[str] = None,
  68. ) -> Union[PricesResponse, UpToDateResponse]:
  69. pool = get_pool()
  70. # referer 必须提供
  71. referer = request.headers.get("referer") or request.headers.get("origin")
  72. if not referer:
  73. raise HTTPException(status_code=400, detail="Missing Referer header")
  74. # version 从 Header 读取,默认 0(首次请求)
  75. try:
  76. version = int(request.headers.get("version", "0") or "0")
  77. except ValueError:
  78. version = 0
  79. # 记录调用来源
  80. ip = request.client.host if request.client else "unknown"
  81. geo = geo_resolver.resolve(ip)
  82. try:
  83. await pool.execute(
  84. "INSERT INTO price_api_logs (ip, referer, org, country, city) VALUES ($1, $2, $3, $4, $5)",
  85. ip, referer, geo.org, geo.country, geo.city,
  86. )
  87. except Exception:
  88. pass
  89. # 查调用方域名对应的折扣
  90. caller_domain = _extract_domain(referer)
  91. discount_rate: Optional[float] = None
  92. if caller_domain:
  93. row = await pool.fetchrow("SELECT discount FROM discounts WHERE domain = $1", caller_domain)
  94. if row:
  95. discount_rate = float(row["discount"])
  96. def _j(v):
  97. if v is None:
  98. return None
  99. return v if isinstance(v, (dict, list)) else json.loads(v)
  100. # 读取版本号:优先用域名专属版本,没有则回退到全局版本
  101. if caller_domain:
  102. ver_row = await pool.fetchrow(
  103. "SELECT version FROM domain_version WHERE domain = $1", caller_domain
  104. )
  105. if ver_row:
  106. current_version = int(ver_row["version"])
  107. else:
  108. ver_row = await pool.fetchrow("SELECT version FROM price_snapshot_version WHERE id = 1")
  109. current_version = int(ver_row["version"]) if ver_row else 0
  110. else:
  111. ver_row = await pool.fetchrow("SELECT version FROM price_snapshot_version WHERE id = 1")
  112. current_version = int(ver_row["version"]) if ver_row else 0
  113. # version != 0 且与当前一致 → 无需更新(0 视为首次请求,强制返回数据)
  114. if version != 0 and version == current_version:
  115. return UpToDateResponse(up_to_date=True, version=current_version)
  116. # 从 price_snapshot 读取数据
  117. if url is None:
  118. rows = await pool.fetch(
  119. "SELECT url, model_name, prices, model_info, rate_limits, tool_prices, icon, updated_at FROM price_snapshot ORDER BY url"
  120. )
  121. else:
  122. rows = await pool.fetch(
  123. "SELECT url, model_name, prices, model_info, rate_limits, tool_prices, icon, updated_at FROM price_snapshot WHERE url = $1",
  124. url,
  125. )
  126. if not rows:
  127. raise HTTPException(status_code=404, detail="No price snapshot found for the given URL")
  128. if not rows:
  129. raise HTTPException(status_code=503, detail="Price snapshot not yet available")
  130. # version != 0 且与当前一致 → 无需更新
  131. if version != 0 and version == current_version:
  132. return UpToDateResponse(up_to_date=True, version=current_version)
  133. def _extract_type(model_info: Optional[dict]) -> Optional[List[str]]:
  134. if not model_info:
  135. return None
  136. tags = model_info.get("display_tags", [])
  137. TYPE_TAGS = {"文本生成", "图像生成", "视频生成", "向量表示", "向量模型", "多模态向量", "语音识别", "实时语音识别", "语音合成"}
  138. result = [t for t in tags if t in TYPE_TAGS]
  139. return result if result else None
  140. models = [PublicPriceOut(
  141. url=r["url"],
  142. model_name=r["model_name"],
  143. prices=_j(r["prices"]) or {},
  144. model_info=_j(r["model_info"]),
  145. rate_limits=_j(r["rate_limits"]),
  146. tool_prices=_j(r["tool_prices"]),
  147. icon=r["icon"],
  148. scraped_at=r["updated_at"],
  149. ) for r in rows]
  150. parsed_prices: List[ParsedPriceItem] = []
  151. discounted_prices: List[DiscountedPriceItem] = []
  152. for r in rows:
  153. for item in parse_prices(_j(r["prices"]) or {}):
  154. parsed_prices.append(ParsedPriceItem(url=r["url"], model_name=r["model_name"], **item))
  155. d_item = dict(item)
  156. if discount_rate is not None:
  157. if d_item.get("input_price") is not None:
  158. d_item["input_price"] = round(d_item["input_price"] * discount_rate, 6)
  159. if d_item.get("output_price") is not None:
  160. d_item["output_price"] = round(d_item["output_price"] * discount_rate, 6)
  161. discounted_prices.append(DiscountedPriceItem(url=r["url"], model_name=r["model_name"], discount=discount_rate, **d_item))
  162. all_types = [
  163. ModelTypeItem(model_name=r["model_name"], type=_extract_type(_j(r["model_info"])) or [])
  164. for r in rows
  165. ]
  166. return PricesResponse(
  167. version=current_version,
  168. models=models,
  169. parsed_prices=parsed_prices,
  170. discounted_prices=discounted_prices,
  171. types=all_types,
  172. discount=discount_rate if discount_rate is not None else 1.0,
  173. )