# public.py — public price API endpoints.
  1. from __future__ import annotations
  2. from datetime import datetime
  3. from typing import List, Optional
  4. from urllib.parse import urlparse
  5. import json
  6. from app.utils.price_parser import parse_prices
  7. from fastapi import APIRouter, HTTPException, Request
  8. from pydantic import BaseModel
  9. from app.db import get_pool
  10. from app.services.geo import geo_resolver
  11. router = APIRouter()
  12. class PublicPriceOut(BaseModel):
  13. url: str
  14. model_name: str
  15. prices: dict
  16. model_info: Optional[dict] = None
  17. rate_limits: Optional[dict] = None
  18. tool_prices: Optional[list] = None
  19. scraped_at: datetime
  20. class ParsedPriceItem(BaseModel):
  21. url: str
  22. model_name: str
  23. tier_min: Optional[float] = None
  24. tier_max: Optional[float] = None
  25. tier_unit: Optional[str] = None
  26. input_price: Optional[float] = None
  27. output_price: Optional[float] = None
  28. currency: str = "CNY"
  29. unit: Optional[str] = None
  30. label: Optional[str] = None
  31. class DiscountedPriceItem(BaseModel):
  32. url: str
  33. model_name: str
  34. tier_min: Optional[float] = None
  35. tier_max: Optional[float] = None
  36. tier_unit: Optional[str] = None
  37. input_price: Optional[float] = None
  38. output_price: Optional[float] = None
  39. currency: str = "CNY"
  40. unit: Optional[str] = None
  41. label: Optional[str] = None
  42. discount: Optional[float] = None # None 表示无折扣(原价)
  43. class PricesResponse(BaseModel):
  44. models: List[PublicPriceOut]
  45. parsed_prices: List[ParsedPriceItem]
  46. discounted_prices: List[DiscountedPriceItem]
  47. def _extract_domain(referer: Optional[str]) -> Optional[str]:
  48. if not referer:
  49. return None
  50. try:
  51. return urlparse(referer).netloc or None
  52. except Exception:
  53. return None
  54. @router.get("/prices", response_model=PricesResponse)
  55. async def get_public_prices(request: Request, url: Optional[str] = None) -> PricesResponse:
  56. pool = get_pool()
  57. # 记录调用来源
  58. ip = request.client.host if request.client else "unknown"
  59. referer = request.headers.get("referer") or request.headers.get("origin")
  60. geo = geo_resolver.resolve(ip)
  61. try:
  62. await pool.execute(
  63. """
  64. INSERT INTO price_api_logs (ip, referer, org, country, city)
  65. VALUES ($1, $2, $3, $4, $5)
  66. """,
  67. ip, referer, geo.org, geo.country, geo.city,
  68. )
  69. except Exception:
  70. pass
  71. # 查调用方域名对应的折扣
  72. caller_domain = _extract_domain(referer)
  73. discount_rate: Optional[float] = None
  74. if caller_domain:
  75. row = await pool.fetchrow(
  76. "SELECT discount FROM discounts WHERE domain = $1", caller_domain
  77. )
  78. if row:
  79. discount_rate = float(row["discount"])
  80. def _j(v):
  81. if v is None:
  82. return None
  83. return v if isinstance(v, (dict, list)) else json.loads(v)
  84. if url is None:
  85. rows = await pool.fetch(
  86. """
  87. WITH latest_job AS (
  88. SELECT id FROM scrape_jobs
  89. WHERE status = 'done'
  90. ORDER BY created_at DESC LIMIT 1
  91. )
  92. SELECT DISTINCT ON (r.url) r.url, r.model_name, r.prices,
  93. r.model_info, r.rate_limits, r.tool_prices, r.scraped_at
  94. FROM scrape_results r
  95. JOIN latest_job j ON r.job_id = j.id
  96. ORDER BY r.url, r.scraped_at DESC
  97. """
  98. )
  99. else:
  100. rows = await pool.fetch(
  101. """
  102. SELECT url, model_name, prices, model_info, rate_limits, tool_prices, scraped_at
  103. FROM scrape_results
  104. WHERE url = $1
  105. ORDER BY scraped_at DESC LIMIT 1
  106. """,
  107. url,
  108. )
  109. if not rows:
  110. raise HTTPException(status_code=404, detail="No scrape results found for the given URL")
  111. models = [PublicPriceOut(
  112. url=r["url"],
  113. model_name=r["model_name"],
  114. prices=_j(r["prices"]) or {},
  115. model_info=_j(r["model_info"]),
  116. rate_limits=_j(r["rate_limits"]),
  117. tool_prices=_j(r["tool_prices"]),
  118. scraped_at=r["scraped_at"],
  119. ) for r in rows]
  120. parsed_prices: List[ParsedPriceItem] = []
  121. discounted_prices: List[DiscountedPriceItem] = []
  122. for r in rows:
  123. for item in parse_prices(_j(r["prices"]) or {}):
  124. parsed_prices.append(ParsedPriceItem(
  125. url=r["url"],
  126. model_name=r["model_name"],
  127. **item,
  128. ))
  129. # 折扣价:有折扣就乘,没有就原价(discount=None)
  130. d_item = dict(item)
  131. if discount_rate is not None:
  132. if d_item.get("input_price") is not None:
  133. d_item["input_price"] = round(d_item["input_price"] * discount_rate, 6)
  134. if d_item.get("output_price") is not None:
  135. d_item["output_price"] = round(d_item["output_price"] * discount_rate, 6)
  136. discounted_prices.append(DiscountedPriceItem(
  137. url=r["url"],
  138. model_name=r["model_name"],
  139. discount=discount_rate,
  140. **d_item,
  141. ))
  142. return PricesResponse(models=models, parsed_prices=parsed_prices, discounted_prices=discounted_prices)