| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167 |
from __future__ import annotations

import json
import logging
from datetime import datetime
from typing import List, Optional
from urllib.parse import urlparse

from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel

from app.db import get_pool
from app.services.geo import geo_resolver
from app.utils.price_parser import parse_prices
# Router holding the public price endpoint(s) declared in this module; it is
# presumably included into the FastAPI app elsewhere (not visible here).
router = APIRouter()
class PublicPriceOut(BaseModel):
    """One scraped model entry as returned by the public ``/prices`` endpoint.

    Built from a ``scrape_results`` row; the JSON-ish columns (``prices``,
    ``model_info``, ``rate_limits``, ``tool_prices``) are exposed as decoded
    Python ``dict``/``list`` values.
    """
    url: str  # URL the row was scraped from (key used by the ``url`` query filter)
    model_name: str
    prices: dict  # raw price structure as stored by the scraper; shape not validated here
    model_info: Optional[dict] = None
    rate_limits: Optional[dict] = None
    tool_prices: Optional[list] = None
    scraped_at: datetime  # NOTE(review): tz-awareness depends on the DB column type -- confirm
class ParsedPriceItem(BaseModel):
    """A single normalized price entry derived from one scraped model's prices.

    ``url`` and ``model_name`` identify the source row; every other field is
    filled from one item produced by ``parse_prices``.
    """
    url: str
    model_name: str
    # Presumably the lower/upper bound and unit of a usage tier (e.g. a
    # context-length band); None when the price is not tiered -- confirm
    # against parse_prices.
    tier_min: Optional[float] = None
    tier_max: Optional[float] = None
    tier_unit: Optional[str] = None
    # Input/output prices; the quantity they are quoted per is presumably
    # described by ``unit`` -- verify against parse_prices.
    input_price: Optional[float] = None
    output_price: Optional[float] = None
    currency: str = "CNY"  # used when the parser item does not supply a currency
    unit: Optional[str] = None
    label: Optional[str] = None  # assumed free-form description set by parse_prices
class DiscountedPriceItem(ParsedPriceItem):
    """A parsed price entry with the caller's domain discount applied.

    Inherits every field (in the same order, with the same defaults) from
    ``ParsedPriceItem`` instead of re-declaring them, so the two models cannot
    drift apart -- the ``/prices`` handler builds both from the same parser
    output dict. ``input_price``/``output_price`` carry the discounted values.
    """

    # Multiplier that was applied to the prices; None means no discount
    # (the prices are the original ones).
    discount: Optional[float] = None
class PricesResponse(BaseModel):
    """Response body of ``GET /prices``."""
    models: List[PublicPriceOut]  # one entry per scraped row returned
    parsed_prices: List[ParsedPriceItem]  # every parsed price item across all rows
    discounted_prices: List[DiscountedPriceItem]  # parsed_prices with the caller's discount applied
- def _extract_domain(referer: Optional[str]) -> Optional[str]:
- if not referer:
- return None
- try:
- return urlparse(referer).netloc or None
- except Exception:
- return None
def _decode_json_column(value):
    """Return *value* as a Python object, decoding JSON text when necessary.

    JSON columns may arrive already decoded (``dict``/``list``) or as raw JSON
    text -- which one depends on the pool's codec setup, not visible here -- so
    both forms are accepted. ``None`` is passed through unchanged.
    """
    if value is None:
        return None
    return value if isinstance(value, (dict, list)) else json.loads(value)


async def _record_price_api_call(pool, ip: str, referer: Optional[str]) -> None:
    """Best-effort insert of one row into ``price_api_logs``.

    Geo resolution and the insert both sit inside the ``try`` so that a
    failure in either is logged rather than failing the public price request.
    """
    try:
        geo = geo_resolver.resolve(ip)
        await pool.execute(
            """
            INSERT INTO price_api_logs (ip, referer, org, country, city)
            VALUES ($1, $2, $3, $4, $5)
            """,
            ip, referer, geo.org, geo.country, geo.city,
        )
    except Exception:
        logging.getLogger(__name__).warning(
            "Failed to record price API call (ip=%s, referer=%s)",
            ip, referer, exc_info=True,
        )


async def _lookup_discount_rate(pool, referer: Optional[str]) -> Optional[float]:
    """Return the price multiplier configured for the caller's domain.

    Returns ``None`` when no domain can be extracted from *referer*, when the
    ``discounts`` table has no row for it, or when that row's discount is NULL
    (previously a NULL crashed the request in ``float(None)``).
    """
    caller_domain = _extract_domain(referer)
    if not caller_domain:
        return None
    row = await pool.fetchrow(
        "SELECT discount FROM discounts WHERE domain = $1", caller_domain
    )
    if not row or row["discount"] is None:
        return None
    return float(row["discount"])


async def _fetch_latest_results(pool, url: Optional[str]):
    """Fetch the ``scrape_results`` rows to publish.

    Without *url*: one row per URL (the most recently scraped) from the newest
    scrape job whose status is ``'done'``. With *url*: the single most recent
    row for that URL.
    """
    if url is None:
        return await pool.fetch(
            """
            WITH latest_job AS (
                SELECT id FROM scrape_jobs
                WHERE status = 'done'
                ORDER BY created_at DESC LIMIT 1
            )
            SELECT DISTINCT ON (r.url) r.url, r.model_name, r.prices,
                r.model_info, r.rate_limits, r.tool_prices, r.scraped_at
            FROM scrape_results r
            JOIN latest_job j ON r.job_id = j.id
            ORDER BY r.url, r.scraped_at DESC
            """
        )
    # NOTE(review): unlike the listing above, this branch does not restrict to
    # jobs with status 'done' -- confirm that is intended.
    return await pool.fetch(
        """
        SELECT url, model_name, prices, model_info, rate_limits, tool_prices, scraped_at
        FROM scrape_results
        WHERE url = $1
        ORDER BY scraped_at DESC LIMIT 1
        """,
        url,
    )


def _apply_discount(item: dict, discount_rate: Optional[float]) -> dict:
    """Return a copy of parsed price *item* with its prices scaled by *discount_rate*.

    With ``discount_rate is None`` the copy is unchanged (original price).
    Only non-None ``input_price``/``output_price`` are scaled, rounded to
    6 decimal places.
    """
    discounted = dict(item)
    if discount_rate is not None:
        for key in ("input_price", "output_price"):
            if discounted.get(key) is not None:
                discounted[key] = round(discounted[key] * discount_rate, 6)
    return discounted


@router.get("/prices", response_model=PricesResponse)
async def get_public_prices(request: Request, url: Optional[str] = None) -> PricesResponse:
    """Public price listing, plus prices discounted for the calling domain.

    Args:
        request: Incoming request; the client IP and the Referer (falling back
            to Origin) header are read for logging and discount lookup.
        url: If given, return only the most recent result for this URL;
            otherwise return the latest result per URL from the newest
            completed scrape job.

    Returns:
        ``PricesResponse`` with the raw model rows, their parsed price items,
        and the same items with the caller's discount applied (each item's
        ``discount`` is ``None`` when no discount applies).

    Raises:
        HTTPException: 404 when no scrape results match.
    """
    pool = get_pool()

    # Record where the call came from (best-effort; never fails the request).
    ip = request.client.host if request.client else "unknown"
    referer = request.headers.get("referer") or request.headers.get("origin")
    await _record_price_api_call(pool, ip, referer)

    # NOTE(review): Referer/Origin are client-supplied, so any caller can claim
    # a partner domain and receive its discounted prices. Acceptable only if
    # these prices are informational -- confirm.
    discount_rate = await _lookup_discount_rate(pool, referer)

    rows = await _fetch_latest_results(pool, url)
    if not rows:
        # Only mention a URL when the caller actually supplied one.
        detail = (
            "No scrape results found for the given URL"
            if url is not None
            else "No scrape results found"
        )
        raise HTTPException(status_code=404, detail=detail)

    models: List[PublicPriceOut] = []
    parsed_prices: List[ParsedPriceItem] = []
    discounted_prices: List[DiscountedPriceItem] = []
    for r in rows:
        # Decode the prices column once; it feeds both the model row and the parser.
        prices = _decode_json_column(r["prices"]) or {}
        models.append(PublicPriceOut(
            url=r["url"],
            model_name=r["model_name"],
            prices=prices,
            model_info=_decode_json_column(r["model_info"]),
            rate_limits=_decode_json_column(r["rate_limits"]),
            tool_prices=_decode_json_column(r["tool_prices"]),
            scraped_at=r["scraped_at"],
        ))
        for item in parse_prices(prices):
            parsed_prices.append(ParsedPriceItem(
                url=r["url"],
                model_name=r["model_name"],
                **item,
            ))
            discounted_prices.append(DiscountedPriceItem(
                url=r["url"],
                model_name=r["model_name"],
                discount=discount_rate,
                **_apply_discount(item, discount_rate),
            ))

    return PricesResponse(
        models=models,
        parsed_prices=parsed_prices,
        discounted_prices=discounted_prices,
    )
|