| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267 |
from __future__ import annotations

import json
import logging
import re
from datetime import datetime
from typing import List, Optional, Union
from urllib.parse import urlparse

from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel

from app.db import get_pool
from app.services.geo import geo_resolver
from app.utils.apikey_crypto import try_encrypt
from app.utils.price_parser import parse_prices
# APIRouter that the price endpoint(s) defined below in this module register on.
router = APIRouter()
class PublicPriceOut(BaseModel):
    """One price-snapshot entry as returned by the public ``/prices`` endpoint."""

    url: str
    model_name: str
    prices: dict  # price structure stored in the snapshot row
    model_info: Union[dict, None] = None
    rate_limits: Union[dict, None] = None
    tool_prices: Union[list, None] = None
    icon: Union[str, None] = None
    scraped_at: datetime  # filled from the snapshot row's updated_at column
    discount: Union[float, None] = None  # discount factor in effect for the caller
    api_key: Union[str, None] = None  # passed through try_encrypt before being returned
    group_name: Union[str, None] = None
class ParsedPriceItem(BaseModel):
    """A single flattened price entry produced by ``parse_prices`` for one model.

    Prices are the listed prices, before any caller-specific discount.
    """

    url: str
    model_name: str
    # Tier bounds and their unit; presumably left unset for untiered pricing
    # (TODO confirm against parse_prices).
    tier_min: Union[float, None] = None
    tier_max: Union[float, None] = None
    tier_unit: Union[str, None] = None
    input_price: Union[float, None] = None
    output_price: Union[float, None] = None
    currency: str = "CNY"
    unit: Union[str, None] = None
    label: Union[str, None] = None  # display label of the price entry
class DiscountedPriceItem(ParsedPriceItem):
    """A ``ParsedPriceItem`` whose prices have the caller's discount applied.

    Inherits every field of ``ParsedPriceItem`` (in the same order), so the
    two schemas cannot drift apart. It adds the discount factor that was used.
    """

    # Multiplier applied to the prices: None means unknown, 1.0 means the
    # original (undiscounted) price.
    discount: Optional[float] = None
class ModelTypeItem(BaseModel):
    """Model-type tags reported for one model in the ``/prices`` response."""

    model_name: str
    # Recognised type tags taken from the model's display tags; an empty list
    # when none match.
    type: List[str]
class PricesResponse(BaseModel):
    """Full ``/prices`` payload, sent when the client's version is 0 or stale."""

    version: int  # version number the payload corresponds to
    models: List[PublicPriceOut]  # one entry per snapshot row
    parsed_prices: List[ParsedPriceItem]  # flattened list prices
    discounted_prices: List[DiscountedPriceItem]  # same entries with the caller's discount applied
    types: List[ModelTypeItem]  # model-type tags, one entry per snapshot row
class UpToDateResponse(BaseModel):
    """Short ``/prices`` reply telling the client its cached version is current."""

    up_to_date: bool = True
    version: int  # the current version, echoed back to the client
- def _extract_domain(referer: Optional[str]) -> Optional[str]:
- if not referer:
- return None
- try:
- return urlparse(referer).netloc or None
- except Exception:
- return None
# Display tags (from model_info["display_tags"]) that are reported as model
# types in the response.
_MODEL_TYPE_TAGS = frozenset({
    "文本生成", "图像生成", "视频生成", "向量表示", "向量模型",
    "多模态向量", "语音识别", "实时语音识别", "语音合成",
})

# Only the main input/output prices are published. Entries whose label matches
# this pattern (cache hits, batch file, batch, fine-tuning, thinking mode) are
# dropped. Compiled once at import rather than on every request.
_EXCLUDED_PRICE_LABEL_RE = re.compile(r"缓存|batch\s*file|批量|调优|思考模式", re.I)

# Snapshot rows joined with their model's API key and group. Callers append
# either an ORDER BY or a WHERE clause (the URL is always a bound parameter).
_SNAPSHOT_SELECT = """
    SELECT ps.url, ps.model_name, ps.prices, ps.model_info, ps.rate_limits,
           ps.tool_prices, ps.icon, ps.updated_at,
           k.key_value AS api_key,
           m.group_id, g.name AS group_name
    FROM price_snapshot ps
    LEFT JOIN models m ON m.url = ps.url
    LEFT JOIN api_keys k ON k.id = m.api_key_id
    LEFT JOIN model_groups g ON g.id = m.group_id
"""

_logger = logging.getLogger(__name__)


def _decode_json(value):
    """Return *value* decoded from JSON text, or unchanged if already a dict/list.

    Snapshot JSON columns are accepted either already decoded (dict/list) or
    as JSON text. ``None`` passes through unchanged.
    """
    if value is None or isinstance(value, (dict, list)):
        return value
    return json.loads(value)


def _extract_model_types(model_info) -> list[str]:
    """Return the display tags of *model_info* that are known model types.

    Returns an empty list when *model_info* is missing, is not a dict, or has
    no recognised tags.
    """
    if not isinstance(model_info, dict):
        return []
    tags = model_info.get("display_tags") or []
    return [tag for tag in tags if tag in _MODEL_TYPE_TAGS]


async def _log_price_api_call(pool, ip: str, referer: str) -> None:
    """Record one call in ``price_api_logs`` on a best-effort basis.

    Geo resolution and the insert are both inside the guard. A failure is
    logged as a warning and never fails the request.
    """
    try:
        geo = geo_resolver.resolve(ip)
        await pool.execute(
            "INSERT INTO price_api_logs (ip, referer, org, country, city) VALUES ($1, $2, $3, $4, $5)",
            ip, referer, geo.org, geo.country, geo.city,
        )
    except Exception:
        _logger.warning("Failed to record price API call from %s", ip, exc_info=True)


async def _load_current_version(pool, domain: Optional[str]) -> int:
    """Return the domain's entry in ``domain_version``, else the global version.

    Falls back to ``price_snapshot_version`` (id 1), and to 0 if that row is
    also missing.
    """
    if domain:
        row = await pool.fetchrow(
            "SELECT version FROM domain_version WHERE domain = $1", domain
        )
        if row is not None:
            return int(row["version"])
    row = await pool.fetchrow("SELECT version FROM price_snapshot_version WHERE id = 1")
    return int(row["version"]) if row is not None else 0


async def _load_domain_discounts(pool, domain: Optional[str]) -> tuple[float, dict[str, float]]:
    """Return ``(domain_discount, {model_name: model_discount})`` for *domain*.

    The domain discount defaults to 1.0 (no discount). Rows with a NULL
    discount are ignored rather than crashing on ``float(None)``.
    """
    discount_rate = 1.0
    per_model: dict[str, float] = {}
    if not domain:
        return discount_rate, per_model
    row = await pool.fetchrow("SELECT discount FROM discounts WHERE domain = $1", domain)
    if row is not None and row["discount"] is not None:
        discount_rate = float(row["discount"])
    mp_rows = await pool.fetch(
        "SELECT model_name, discount FROM domain_model_prices WHERE domain = $1",
        domain,
    )
    for mp in mp_rows:
        if mp["discount"] is not None:
            per_model[mp["model_name"]] = float(mp["discount"])
    return discount_rate, per_model


async def _fetch_snapshot_rows(pool, url: Optional[str]):
    """Fetch all snapshot rows ordered by URL, or only those matching *url*."""
    if url is None:
        return await pool.fetch(_SNAPSHOT_SELECT + "ORDER BY ps.url")
    return await pool.fetch(_SNAPSHOT_SELECT + "WHERE ps.url = $1", url)


def _apply_discount(item: dict, discount: float) -> dict:
    """Return a copy of *item* with its input/output prices scaled by *discount*.

    Scaled prices are rounded to 6 decimal places; missing prices stay None.
    """
    discounted = dict(item)
    for key in ("input_price", "output_price"):
        if discounted.get(key) is not None:
            discounted[key] = round(discounted[key] * discount, 6)
    return discounted


@router.get("/prices", response_model=Union[PricesResponse, UpToDateResponse])
async def get_public_prices(
    request: Request,
    url: Optional[str] = None,
) -> Union[PricesResponse, UpToDateResponse]:
    """Return the price snapshot, with the calling domain's discounts applied.

    The caller is identified by the ``Referer`` header, falling back to
    ``Origin``. Clients may send an integer ``version`` header. When it is
    non-zero and equals the current version, only an ``UpToDateResponse`` is
    returned. The current version is the caller domain's ``domain_version``
    entry if present, otherwise the global ``price_snapshot_version``. A
    missing, empty, ``0`` or non-integer header forces a full payload.

    Args:
        request: the incoming request; its headers and client address are read.
        url: when given, restrict the payload to the snapshot row(s) with this
            ``url``.

    Returns:
        An ``UpToDateResponse`` when the client's version is current.
        Otherwise a ``PricesResponse`` with the snapshot rows, parsed list
        prices, discounted prices and model-type tags.

    Raises:
        HTTPException: 400 if neither ``Referer`` nor ``Origin`` is sent;
            404 if ``url`` is given and matches no snapshot row; 503 if the
            snapshot is empty.
    """
    pool = get_pool()

    referer = request.headers.get("referer") or request.headers.get("origin")
    if not referer:
        raise HTTPException(status_code=400, detail="Missing Referer header")

    try:
        version = int(request.headers.get("version", "0") or "0")
    except ValueError:
        version = 0  # unparsable header: treat as a first request

    ip = request.client.host if request.client else "unknown"
    await _log_price_api_call(pool, ip, referer)

    # NOTE(review): the caller's domain, and therefore its discount, comes from
    # client-supplied Referer/Origin headers. Any client can claim another
    # domain's pricing by forging them; confirm this is acceptable.
    caller_domain = _extract_domain(referer)

    # Checked before the snapshot/discount queries so up-to-date clients are
    # cheap. NOTE(review): the comparison ignores ``url``, so a client that
    # changes ``url`` but resends its version is told it is up to date.
    current_version = await _load_current_version(pool, caller_domain)
    if version != 0 and version == current_version:
        return UpToDateResponse(up_to_date=True, version=current_version)

    rows = await _fetch_snapshot_rows(pool, url)
    if not rows:
        if url is not None:
            raise HTTPException(status_code=404, detail="No price snapshot found for the given URL")
        raise HTTPException(status_code=503, detail="Price snapshot not yet available")

    discount_rate, model_discounts = await _load_domain_discounts(pool, caller_domain)

    models: List[PublicPriceOut] = []
    parsed_prices: List[ParsedPriceItem] = []
    discounted_prices: List[DiscountedPriceItem] = []
    all_types: List[ModelTypeItem] = []
    for r in rows:
        model_name = r["model_name"]
        prices = _decode_json(r["prices"]) or {}
        model_info = _decode_json(r["model_info"])
        # A per-model discount overrides the domain-wide one.
        custom = model_discounts.get(model_name)
        effective_discount = custom if custom is not None else discount_rate

        models.append(PublicPriceOut(
            url=r["url"],
            model_name=model_name,
            prices=prices,
            model_info=model_info,
            rate_limits=_decode_json(r["rate_limits"]),
            tool_prices=_decode_json(r["tool_prices"]),
            icon=r["icon"],
            scraped_at=r["updated_at"],
            discount=effective_discount,
            # NOTE(review): whether try_encrypt can fall back to returning the
            # plaintext key is not visible here; confirm before exposing it.
            api_key=try_encrypt(r["api_key"]),
            group_name=r["group_name"],
        ))
        all_types.append(
            ModelTypeItem(model_name=model_name, type=_extract_model_types(model_info))
        )

        for item in parse_prices(prices):
            # Keep one-sided prices (embeddings, image generation, ...) but
            # drop entries that carry no price at all.
            if item.get("input_price") is None and item.get("output_price") is None:
                continue
            if _EXCLUDED_PRICE_LABEL_RE.search(item.get("label") or ""):
                continue
            item = dict(item)  # copy so the parser's dict is not mutated
            if item.get("label") == "input/output":
                item["label"] = "输入/输出"  # localised display label
            parsed_prices.append(ParsedPriceItem(url=r["url"], model_name=model_name, **item))
            discounted_prices.append(DiscountedPriceItem(
                url=r["url"],
                model_name=model_name,
                discount=effective_discount,
                **_apply_discount(item, effective_discount),
            ))

    return PricesResponse(
        version=current_version,
        models=models,
        parsed_prices=parsed_prices,
        discounted_prices=discounted_prices,
        types=all_types,
    )
|