from __future__ import annotations

import json
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse

from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel

from app.db import get_pool
from app.services.geo import geo_resolver
from app.utils.apikey_crypto import try_encrypt
from app.utils.price_parser import parse_prices

logger = logging.getLogger(__name__)

router = APIRouter()

# Tags in model_info["display_tags"] that classify a model's type.
_TYPE_TAGS = frozenset({
    "文本生成", "图像生成", "视频生成", "向量表示", "向量模型",
    "多模态向量", "语音识别", "实时语音识别", "语音合成",
})

# Shared projection for price_snapshot rows, joined with each model's API key
# and group name. Callers append either an ORDER BY or a WHERE clause.
_SNAPSHOT_SELECT = """
    SELECT ps.url, ps.model_name, ps.prices, ps.model_info, ps.rate_limits,
           ps.tool_prices, ps.icon, ps.updated_at,
           k.key_value AS api_key, m.group_id, g.name AS group_name
    FROM price_snapshot ps
    LEFT JOIN models m ON m.url = ps.url
    LEFT JOIN api_keys k ON k.id = m.api_key_id
    LEFT JOIN model_groups g ON g.id = m.group_id
"""


class PublicPriceOut(BaseModel):
    url: str
    model_name: str
    prices: dict
    model_info: Optional[dict] = None
    rate_limits: Optional[dict] = None
    tool_prices: Optional[list] = None
    icon: Optional[str] = None
    scraped_at: datetime
    discount: Optional[float] = None
    api_key: Optional[str] = None
    group_name: Optional[str] = None


class ParsedPriceItem(BaseModel):
    url: str
    model_name: str
    tier_min: Optional[float] = None
    tier_max: Optional[float] = None
    tier_unit: Optional[str] = None
    input_price: Optional[float] = None
    output_price: Optional[float] = None
    currency: str = "CNY"
    unit: Optional[str] = None
    label: Optional[str] = None


class DiscountedPriceItem(BaseModel):
    url: str
    model_name: str
    tier_min: Optional[float] = None
    tier_max: Optional[float] = None
    tier_unit: Optional[str] = None
    input_price: Optional[float] = None
    output_price: Optional[float] = None
    currency: str = "CNY"
    unit: Optional[str] = None
    label: Optional[str] = None
    # None means unknown; 1.0 means the original (undiscounted) price.
    discount: Optional[float] = None


class ModelTypeItem(BaseModel):
    model_name: str
    type: List[str]


class PricesResponse(BaseModel):
    version: int
    models: List[PublicPriceOut]
    parsed_prices: List[ParsedPriceItem]
    discounted_prices: List[DiscountedPriceItem]
    types: List[ModelTypeItem]


class UpToDateResponse(BaseModel):
    up_to_date: bool = True
    version: int


def _extract_domain(referer: Optional[str]) -> Optional[str]:
    """Return the network location (``host[:port]``) of *referer*, or None.

    None is returned for an empty value, a value without a netloc, or a value
    ``urlparse`` rejects as malformed.
    """
    if not referer:
        return None
    try:
        return urlparse(referer).netloc or None
    except ValueError:
        # urlparse raises ValueError on malformed netlocs (e.g. bad IPv6 brackets).
        return None


def _load_json(value: Any) -> Any:
    """Decode a JSON column value.

    None and already-decoded dict/list values are passed through unchanged;
    anything else is handed to ``json.loads``.
    """
    if value is None:
        return None
    return value if isinstance(value, (dict, list)) else json.loads(value)


def _extract_type(model_info: Optional[dict]) -> Optional[List[str]]:
    """Return the model-type tags found in ``model_info["display_tags"]``, or None if none match."""
    if not model_info:
        return None
    # `or []` also covers a display_tags key that is present but null.
    tags = model_info.get("display_tags") or []
    result = [t for t in tags if t in _TYPE_TAGS]
    return result or None


def _effective_discount(
    model_name: str, model_discounts: Dict[str, float], domain_discount: float
) -> float:
    """Per-model discount takes precedence; otherwise fall back to the domain-wide discount."""
    custom = model_discounts.get(model_name)
    return custom if custom is not None else domain_discount


def _apply_discount(item: dict, discount: Optional[float]) -> dict:
    """Return a copy of a parsed price *item* with input/output prices scaled by *discount*.

    Prices that are None are left untouched; results are rounded to 6 decimals.
    """
    discounted = dict(item)
    if discount is None:
        return discounted
    for key in ("input_price", "output_price"):
        price = discounted.get(key)
        if price is not None:
            discounted[key] = round(price * discount, 6)
    return discounted


async def _record_call(pool: Any, ip: str, referer: str) -> None:
    """Best-effort insert of the caller's IP, referer and geo data into price_api_logs."""
    geo = geo_resolver.resolve(ip)
    try:
        await pool.execute(
            "INSERT INTO price_api_logs (ip, referer, org, country, city) VALUES ($1, $2, $3, $4, $5)",
            ip, referer, geo.org, geo.country, geo.city,
        )
    except Exception:
        # Logging the call must never fail the request, but the failure should be visible.
        logger.warning("Failed to record price API call from %s", ip, exc_info=True)


async def _resolve_current_version(pool: Any, caller_domain: Optional[str]) -> int:
    """Return the domain-specific snapshot version, falling back to the global one (0 if absent)."""
    if caller_domain:
        row = await pool.fetchrow(
            "SELECT version FROM domain_version WHERE domain = $1", caller_domain
        )
        if row:
            return int(row["version"])
    row = await pool.fetchrow("SELECT version FROM price_snapshot_version WHERE id = 1")
    return int(row["version"]) if row else 0


async def _resolve_discounts(
    pool: Any, caller_domain: Optional[str]
) -> Tuple[float, Dict[str, float]]:
    """Return ``(domain_discount, {model_name: discount})`` for *caller_domain*.

    The domain-wide discount defaults to 1.0 (original price) when the domain is
    unknown or has no row in ``discounts``.
    """
    discount_rate = 1.0
    model_discounts: Dict[str, float] = {}
    if not caller_domain:
        return discount_rate, model_discounts

    row = await pool.fetchrow("SELECT discount FROM discounts WHERE domain = $1", caller_domain)
    if row:
        discount_rate = float(row["discount"])

    # Per-model custom discounts configured for this domain.
    mp_rows = await pool.fetch(
        "SELECT model_name, discount FROM domain_model_prices WHERE domain = $1",
        caller_domain,
    )
    for mp in mp_rows:
        model_discounts[mp["model_name"]] = float(mp["discount"])
    return discount_rate, model_discounts


@router.get("/prices", response_model=Union[PricesResponse, UpToDateResponse])
async def get_public_prices(
    request: Request,
    url: Optional[str] = None,
) -> Union[PricesResponse, UpToDateResponse]:
    """Return the public price snapshot, discounted for the calling domain.

    Headers:
        Referer / Origin: required (400 if both are missing). Its netloc selects
            the domain-specific discount and snapshot version.
        version: the client's cached version. If it is non-zero and equals the
            current version, an ``UpToDateResponse`` is returned instead of data.
            Missing or non-integer values count as 0 (first request).

    Args:
        url: optional filter restricting the snapshot to a single model URL.

    Raises:
        HTTPException 400: no Referer/Origin header.
        HTTPException 404: *url* was given but has no snapshot row.
        HTTPException 503: the snapshot table is empty.
    """
    pool = get_pool()

    referer = request.headers.get("referer") or request.headers.get("origin")
    if not referer:
        raise HTTPException(status_code=400, detail="Missing Referer header")

    try:
        version = int(request.headers.get("version", "0") or "0")
    except ValueError:
        version = 0

    ip = request.client.host if request.client else "unknown"
    await _record_call(pool, ip, referer)

    # NOTE(review): the discount tier is selected from the client-supplied
    # Referer/Origin header, which non-browser clients can set freely.
    caller_domain = _extract_domain(referer)

    current_version = await _resolve_current_version(pool, caller_domain)
    # Version 0 means "first request" and always gets data.
    if version != 0 and version == current_version:
        return UpToDateResponse(up_to_date=True, version=current_version)

    discount_rate, model_discounts = await _resolve_discounts(pool, caller_domain)

    if url is None:
        rows = await pool.fetch(_SNAPSHOT_SELECT + " ORDER BY ps.url")
    else:
        rows = await pool.fetch(_SNAPSHOT_SELECT + " WHERE ps.url = $1", url)
        if not rows:
            raise HTTPException(status_code=404, detail="No price snapshot found for the given URL")

    if not rows:
        raise HTTPException(status_code=503, detail="Price snapshot not yet available")

    models: List[PublicPriceOut] = []
    parsed_prices: List[ParsedPriceItem] = []
    discounted_prices: List[DiscountedPriceItem] = []
    all_types: List[ModelTypeItem] = []

    for r in rows:
        row_url = r["url"]
        model_name = r["model_name"]
        # Decode each JSON column once per row.
        prices = _load_json(r["prices"]) or {}
        model_info = _load_json(r["model_info"])
        effective_discount = _effective_discount(model_name, model_discounts, discount_rate)

        models.append(PublicPriceOut(
            url=row_url,
            model_name=model_name,
            prices=prices,
            model_info=model_info,
            rate_limits=_load_json(r["rate_limits"]),
            tool_prices=_load_json(r["tool_prices"]),
            icon=r["icon"],
            scraped_at=r["updated_at"],
            discount=effective_discount,
            # NOTE(review): API keys are included in a public response; they
            # rely entirely on try_encrypt for protection.
            api_key=try_encrypt(r["api_key"]),
            group_name=r["group_name"],
        ))

        for item in parse_prices(prices):
            parsed_prices.append(ParsedPriceItem(url=row_url, model_name=model_name, **item))
            discounted_prices.append(DiscountedPriceItem(
                url=row_url,
                model_name=model_name,
                discount=effective_discount,
                **_apply_discount(item, effective_discount),
            ))

        all_types.append(ModelTypeItem(model_name=model_name, type=_extract_type(model_info) or []))

    return PricesResponse(
        version=current_version,
        models=models,
        parsed_prices=parsed_prices,
        discounted_prices=discounted_prices,
        types=all_types,
    )