| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383 |
- """
- 爬虫数据同步服务
- 爬虫 API 响应格式(参考 {.txt):
- {
- "version": 14,
- "discount": 0.8,
- "models": [...],
- "parsed_prices": [
- { "model_name": "qwen3-max", "label": "input<=32k",
- "tier_min": 0, "tier_max": 32000, "tier_unit": "tokens",
- "input_price": 2.5, "output_price": 10.0,
- "unit": "元/每百万tokens", "currency": "CNY", "url": "..." }
- ],
- "discounted_prices": [
- { "model_name": "qwen3-max", "label": "input<=32k",
- "input_price": 2.0, "output_price": 8.0,
- "unit": "元/每百万tokens", "currency": "CNY",
- "discount": 0.8, "url": "..." }
- ],
- "types": [...]
- }
- """
- import base64
- import hashlib
- import logging
- from datetime import datetime
- from decimal import Decimal
- from typing import Optional
- import httpx
- from sqlalchemy.orm import Session
- from app.core.config import get_performance_settings
- from app.models.model import ModelNew, ModelPriceNew, CrawlerSyncLog, infer_categories
logger = logging.getLogger(__name__)

settings = get_performance_settings()

# Crawler base URL, the value sent as the Referer header, and the secret used to
# decrypt api_key values returned by the crawler. Each one falls back to a
# default when the settings object does not define the attribute.
CRAWLER_BASE_URL, CRAWLER_REFERER, APIKEY_ENCRYPT_KEY = (
    getattr(settings, attr_name, fallback)
    for attr_name, fallback in (
        ("crawler_base_url", "http://localhost:8000"),
        ("crawler_referer", "aigcspace.com"),
        ("apikey_encrypt_key", ""),
    )
)
- def _crawler_decrypt_api_key(ciphertext: str, secret_key: str) -> Optional[str]:
- """
- 解密爬虫返回的 api_key 密文。
- 算法:Base64 URL-safe 解码 → 循环右移 → XOR 密钥流 → UTF-8
- 密钥流:SHA-256(key_bytes + block_index_big_endian) 循环拼接
- """
- if not ciphertext or not secret_key:
- return None
- try:
- # 1. Base64 URL-safe 解码(补齐 padding)
- rem = len(ciphertext) % 4
- if rem:
- ciphertext += '=' * (4 - rem)
- data = base64.urlsafe_b64decode(ciphertext)
- # 2. 派生密钥流
- key_bytes = secret_key.encode('utf-8')
- stream = bytearray()
- block = 0
- while len(stream) < len(data):
- stream.extend(
- hashlib.sha256(key_bytes + block.to_bytes(4, 'big')).digest()
- )
- block += 1
- keystream = bytes(stream[:len(data)])
- # 3. 循环右移 + XOR
- result = bytearray(len(data))
- for i, byte in enumerate(data):
- n = i % 5 + 1
- unshifted = ((byte >> n) | (byte << (8 - n))) & 0xFF
- result[i] = unshifted ^ keystream[i]
- return result.decode('utf-8')
- except Exception as e:
- logger.warning(f"解密爬虫 api_key 失败: {e}")
- return None
def _get_local_version(db: Session) -> int:
    """
    Return the crawler version recorded by the newest successful sync.

    Selects the ``CrawlerSyncLog`` row with status ``"success"`` that has the
    highest ``crawler_version``. Returns 0 when no such row exists.
    """
    newest_success = (
        db.query(CrawlerSyncLog)
        .filter(CrawlerSyncLog.status == "success")
        .order_by(CrawlerSyncLog.crawler_version.desc())
        .first()
    )
    if not newest_success:
        return 0
    return newest_success.crawler_version
- def _infer_tags(display_tags: list) -> tuple[Optional[str], Optional[str]]:
- """
- 从 display_tags 推断 tag1(功能标签)和 tag2(品牌/系列标签)。
- 规则:
- - tag2(右上角紫色):品牌/系列标签,如 Qwen3、wan2.6、Realtime-Omni
- - tag1(左下角灰色):功能标签,如 文本生成、视觉理解、图像生成、语音合成
- 品牌标签特征:全英文或含数字的系列名
- """
- if not display_tags:
- return None, None
- # 已知品牌/系列标签前缀
- BRAND_PREFIXES = ('qwen', 'wan', 'realtime', 'omni')
- BRAND_EXACT = {'Qwen3', 'Qwen2.5', 'wan2.6', 'wan2.2', 'Realtime-Omni'}
- tag2 = None # 品牌标签
- tag1 = None # 功能标签
- for tag in display_tags:
- tag_lower = tag.lower()
- is_brand = (
- tag in BRAND_EXACT
- or any(tag_lower.startswith(p) for p in BRAND_PREFIXES)
- )
- if is_brand:
- if tag2 is None:
- tag2 = tag
- else:
- if tag1 is None:
- tag1 = tag
- return tag1, tag2
- def _infer_multiplier(unit: str) -> int:
- u = unit.lower()
- if "百万" in unit or "million" in u:
- return 1_000_000
- if "千" in unit or ("k" in u and "token" in u):
- return 1_000
- if "万" in unit:
- return 10_000
- return 1
- def _discount_label(rate: Decimal) -> Optional[str]:
- """0.8 -> '8折',1.0 -> None"""
- if rate >= Decimal("1.0") or rate <= Decimal("0"):
- return None
- tenths = round(float(rate) * 10)
- return f"{tenths}折" if tenths < 10 else None
- def _d(val, default=0) -> Decimal:
- try:
- return Decimal(str(val)) if val is not None else Decimal(str(default))
- except Exception:
- return Decimal(str(default))
def _dict_entries(value) -> list:
    """Return the dict elements of a JSON array field; None/missing yields []."""
    return [item for item in (value or []) if isinstance(item, dict)]


def _encrypt_crawler_api_key(model_code: str, raw_api_key: Optional[str]) -> Optional[str]:
    """
    Turn the crawler's encrypted api_key into the locally encrypted form.

    The crawler's ciphertext is decrypted with APIKEY_ENCRYPT_KEY and the
    plaintext is re-encrypted via crypto_utils.encrypt_api_key for storage.
    Returns None, logging a warning where appropriate, when there is no key,
    when APIKEY_ENCRYPT_KEY is not configured, or when decryption fails.
    """
    if not raw_api_key:
        return None
    if not APIKEY_ENCRYPT_KEY:
        logger.warning(f"未配置 APIKEY_ENCRYPT_KEY,模型 {model_code} 的 api_key 无法解密")
        return None
    plain_key = _crawler_decrypt_api_key(raw_api_key, APIKEY_ENCRYPT_KEY)
    if not plain_key:
        logger.warning(f"模型 {model_code} 的 api_key 解密失败,跳过存储")
        return None
    from app.services.crypto_utils import encrypt_api_key
    return encrypt_api_key(plain_key)


def _upsert_model(db: Session, m: dict, crawled_at: datetime) -> bool:
    """
    Insert or update the ModelNew row for one crawler ``models`` entry.

    Returns True when the entry was processed. Returns False when it has no
    ``model_name`` and was skipped.
    """
    model_code = m.get("model_name")
    if not model_code:
        return False

    info = m.get("model_info") or {}
    display_tags = info.get("display_tags")
    inp = info.get("input_modalities") or []
    out = info.get("output_modalities") or []
    cats = infer_categories(inp, out, model_code, display_tags)
    tag1, tag2 = _infer_tags(display_tags or [])
    # keywords: the non-brand display_tags, newline-separated, shown bottom-left
    # in the frontend. Non-string/empty entries are skipped so join() cannot fail.
    keywords = "\n".join(
        t for t in (display_tags or []) if isinstance(t, str) and t and t != tag2
    ) or None
    # Image-translation models (e.g. qwen-mt-image) use a dedicated route and
    # are hidden from the general model list.
    is_image_translation = "mt-image" in model_code.lower()
    group_name = m.get("group_name") or None
    encrypted_api_key = _encrypt_crawler_api_key(model_code, m.get("api_key") or None)

    existing = db.query(ModelNew).filter(ModelNew.model_code == model_code).first()
    if existing:
        existing.description = info.get("description") or existing.description
        existing.display_tags = display_tags or existing.display_tags
        existing.input_modalities = inp or existing.input_modalities
        existing.output_modalities = out or existing.output_modalities
        existing.features = info.get("features") or existing.features
        existing.rate_limits = m.get("rate_limits") or existing.rate_limits
        existing.tool_call_prices = m.get("tool_prices") or existing.tool_call_prices
        existing.source_url = m.get("url") or existing.source_url
        existing.crawled_at = crawled_at
        existing.raw_prices = m.get("prices") or existing.raw_prices
        if not existing.img and m.get("icon"):
            existing.img = m.get("icon")
        # Categories are always re-inferred from the latest modalities.
        existing.categories = cats
        if is_image_translation:
            existing.is_show_enabled = False
        # tag1/tag2/keywords are only filled in when not already set (e.g. manually).
        if not existing.tag1 and tag1:
            existing.tag1 = tag1
        if not existing.tag2 and tag2:
            existing.tag2 = tag2
        if not existing.keywords and keywords:
            existing.keywords = keywords
        # group_name / encrypted_api_key: crawler data wins whenever it is present.
        if group_name:
            existing.group_name = group_name
        if encrypted_api_key:
            existing.encrypted_api_key = encrypted_api_key
    else:
        db.add(ModelNew(
            model_code=model_code,
            description=info.get("description"),
            display_tags=display_tags,
            input_modalities=inp,
            output_modalities=out,
            features=info.get("features"),
            rate_limits=m.get("rate_limits"),
            tool_call_prices=m.get("tool_prices"),
            source_url=m.get("url"),
            crawled_at=crawled_at,
            display_name=model_code,
            supplier="Qwen",
            categories=cats,
            tag1=tag1,
            tag2=tag2,
            keywords=keywords,
            is_show_enabled=not is_image_translation,
            is_api_enabled=True,
            raw_prices=m.get("prices"),
            img=m.get("icon"),
            group_name=group_name,
            encrypted_api_key=encrypted_api_key,
        ))
    return True


def _resolve_discount_rate(
    disc: Optional[dict],
    model_code: str,
    model_discount_map: dict,
    global_discount: Decimal,
) -> Decimal:
    """
    Pick the discount rate for one price entry.

    Priority: the discounted_prices entry's own ``discount``, then the
    model-level discount, then the global discount. Presence is tested
    explicitly rather than by truthiness, so a legitimate rate of 0 is not
    skipped. The previous ``or``-chain skipped it, which disagreed with the
    no-discount-entry path.
    """
    entry_rate = disc.get("discount") if disc else None
    if entry_rate is not None and entry_rate != "":
        return _d(entry_rate, 1)
    return model_discount_map.get(model_code, global_discount)


def _sync_prices(
    db: Session,
    parsed_prices: list,
    disc_map: dict,
    model_discount_map: dict,
    global_discount: Decimal,
    crawled_at: datetime,
) -> int:
    """
    Replace the active ModelPriceNew rows for every model in ``parsed_prices``.

    ``parsed_prices`` holds list prices. ``disc_map``, keyed by
    (model_name, label), holds the matching discounted prices; without a
    match, discounted prices are derived from the resolved rate. Returns the
    number of price rows inserted.
    """
    # Deactivate per model first so labels that disappeared upstream do not
    # remain active (they would be mistaken for tiered pricing).
    synced_model_codes = {p["model_name"] for p in parsed_prices if p.get("model_name")}
    for mc in synced_model_codes:
        db.query(ModelPriceNew).filter(
            ModelPriceNew.model_code == mc,
            ModelPriceNew.is_active == True,
        ).update({"is_active": False})
    db.flush()

    price_count = 0
    for p in parsed_prices:
        model_code = p.get("model_name")
        label = p.get("label")
        if not model_code or not label:
            continue
        in_orig = _d(p.get("input_price"), 0)
        out_orig = _d(p.get("output_price"), 0)
        # Normalise JSON nulls: a None unit would break _infer_multiplier.
        unit = p.get("unit") or ""
        currency = p.get("currency") or "CNY"
        tier_min = _d(p["tier_min"]) if p.get("tier_min") is not None else None
        tier_max = _d(p["tier_max"]) if p.get("tier_max") is not None else None
        tier_unit = p.get("tier_unit")

        disc = disc_map.get((model_code, label))
        rate = _resolve_discount_rate(disc, model_code, model_discount_map, global_discount)
        if disc:
            in_disc = _d(disc.get("input_price"), in_orig)
            out_disc = _d(disc.get("output_price"), out_orig)
        else:
            in_disc = (in_orig * rate).quantize(Decimal("0.00000001"))
            out_disc = (out_orig * rate).quantize(Decimal("0.00000001"))

        # After the bulk deactivation above, this only matters when
        # parsed_prices repeats a (model, label) pair. Presumably, with session
        # autoflush enabled, the row inserted earlier in this run is
        # deactivated so the last entry wins. NOTE(review): confirm autoflush.
        db.query(ModelPriceNew).filter(
            ModelPriceNew.model_code == model_code,
            ModelPriceNew.label == label,
            ModelPriceNew.is_active == True,
        ).update({"is_active": False})
        db.add(ModelPriceNew(
            model_code=model_code,
            label=label,
            tier_min=tier_min,
            tier_max=tier_max,
            tier_unit=tier_unit,
            input_price_original=in_orig,
            output_price_original=out_orig,
            input_price_discounted=in_disc,
            output_price_discounted=out_disc,
            discount_rate=rate,
            discount_label=_discount_label(rate),
            currency=currency,
            unit=unit,
            display_multiplier=_infer_multiplier(unit),
            source_url=p.get("url"),
            crawled_at=crawled_at,
            is_active=True,
        ))
        price_count += 1
    return price_count


async def sync_from_crawler(db: Session) -> dict:
    """
    Pull the latest model/price snapshot from the crawler and upsert it locally.

    Flow:
      1. Request ``/api/public/prices``, sending the version of the last
         successful sync in the ``version`` header.
      2. If the crawler answers ``up_to_date``, return without writing anything.
      3. Otherwise upsert ModelNew rows from ``models``, replace the active
         ModelPriceNew rows of every model listed in ``parsed_prices``, and
         record a ``success`` CrawlerSyncLog.

    Any request or write failure records a ``failed`` CrawlerSyncLog. Write
    failures roll back first. The session is committed in every branch except
    ``up_to_date``.

    Args:
        db: SQLAlchemy session; this function commits or rolls it back.

    Returns:
        Success: ``{"synced": True, "version", "model_count", "price_count"}``.
        Nothing new: ``{"synced": False, "version", "up_to_date": True}``.
        Failure: ``{"synced": False, "error": str}``.
    """
    from datetime import timezone

    local_version = _get_local_version(db)
    try:
        async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
            resp = await client.get(
                f"{CRAWLER_BASE_URL}/api/public/prices",
                headers={"Referer": CRAWLER_REFERER, "version": str(local_version)},
            )
            resp.raise_for_status()
            data = resp.json()
        # Everything below calls data.get(); reject other JSON shapes here so
        # they are logged as a failed sync instead of escaping unhandled.
        if not isinstance(data, dict):
            raise ValueError(f"爬虫响应不是 JSON 对象: {type(data).__name__}")
    except Exception as e:
        logger.error(f"请求爬虫数据失败: {e}")
        db.add(CrawlerSyncLog(crawler_version=local_version, status="failed", error_message=str(e)))
        db.commit()
        return {"synced": False, "error": str(e)}

    if data.get("up_to_date"):
        logger.info(f"爬虫数据已是最新 v{data.get('version')}")
        return {"synced": False, "version": data.get("version"), "up_to_date": True}

    # A null version must not be recorded as NULL crawler_version.
    remote_version = data.get("version") or 0
    # Naive UTC timestamp, the same value datetime.utcnow() returned
    # (utcnow is deprecated since Python 3.12).
    crawled_at = datetime.now(timezone.utc).replace(tzinfo=None)

    try:
        # Payload parsing sits inside the try so malformed data is logged as failed.
        models_data = _dict_entries(data.get("models"))
        parsed_prices = _dict_entries(data.get("parsed_prices"))
        disc_prices = _dict_entries(data.get("discounted_prices"))
        # Global discount: fallback for the legacy format. Newer payloads carry
        # the discount per model / per price entry, so the key may be missing or
        # null; both must mean "no discount" (1). The old _d(..., default 0)
        # zeroed every fallback-discounted price when the key was null.
        global_discount = _d(data.get("discount"), 1)
        # Discounted-price index: (model_name, label) -> discounted entry.
        disc_map: dict[tuple, dict] = {
            (d["model_name"], d["label"]): d
            for d in disc_prices
            if d.get("model_name") and d.get("label")
        }
        # Model-level discount index: model_name -> rate.
        model_discount_map: dict[str, Decimal] = {
            m["model_name"]: _d(m["discount"], 1)
            for m in models_data
            if m.get("model_name") and m.get("discount") is not None
        }

        model_count = 0
        for m in models_data:
            if _upsert_model(db, m, crawled_at):
                model_count += 1
        db.flush()

        price_count = _sync_prices(
            db, parsed_prices, disc_map, model_discount_map, global_discount, crawled_at
        )

        db.add(CrawlerSyncLog(
            crawler_version=remote_version,
            model_count=model_count,
            price_count=price_count,
            status="success",
        ))
        db.commit()
        logger.info(
            f"同步完成: v{local_version}→v{remote_version}, "
            f"模型+{model_count}, 价格{price_count}条, "
            f"模型级折扣数={len(model_discount_map)}"
        )
        return {"synced": True, "version": remote_version,
                "model_count": model_count, "price_count": price_count}
    except Exception as e:
        db.rollback()
        # logger.exception keeps the traceback, which logger.error discarded.
        logger.exception(f"同步写入失败: {e}")
        db.add(CrawlerSyncLog(crawler_version=remote_version, status="failed", error_message=str(e)))
        db.commit()
        return {"synced": False, "error": str(e)}
|