"""Crawler data synchronization service.

Pulls model metadata and prices from the crawler's public price API
(``{CRAWLER_BASE_URL}/api/public/prices``) and upserts them into
``ModelNew`` / ``ModelPriceNew``, recording every attempt in ``CrawlerSyncLog``.

Example crawler API response::

    {
        "version": 14,
        "discount": 0.8,
        "models": [...],
        "parsed_prices": [
            {
                "model_name": "qwen3-max", "label": "input<=32k",
                "tier_min": 0, "tier_max": 32000, "tier_unit": "tokens",
                "input_price": 2.5, "output_price": 10.0,
                "unit": "元/每百万tokens", "currency": "CNY", "url": "..."
            }
        ],
        "discounted_prices": [
            {
                "model_name": "qwen3-max", "label": "input<=32k",
                "input_price": 2.0, "output_price": 8.0,
                "unit": "元/每百万tokens", "currency": "CNY",
                "discount": 0.8, "url": "..."
            }
        ],
        "types": [...]
    }

When the crawler has nothing newer than the ``version`` header we send, the
response carries ``"up_to_date": true`` instead and nothing is written.
"""
import base64
import hashlib
import logging
from datetime import datetime
from decimal import ROUND_HALF_UP, Decimal
from typing import Optional

import httpx
from sqlalchemy.orm import Session

from app.core.config import get_performance_settings
from app.models.model import ModelNew, ModelPriceNew, CrawlerSyncLog, infer_categories

logger = logging.getLogger(__name__)

settings = get_performance_settings()
CRAWLER_BASE_URL = getattr(settings, "crawler_base_url", "http://localhost:8000")
CRAWLER_REFERER = getattr(settings, "crawler_referer", "aigcspace.com")
APIKEY_ENCRYPT_KEY = getattr(settings, "apikey_encrypt_key", "")

# Brand/series tag markers used by _infer_tags: a tag is a brand tag if it is
# listed exactly, or if its lower-cased form starts with one of the prefixes.
_BRAND_PREFIXES = ("qwen", "wan", "realtime", "omni")
_BRAND_EXACT = frozenset({"Qwen3", "Qwen2.5", "wan2.6", "wan2.2", "Realtime-Omni"})

# Precision used when a discounted price is derived as original * rate.
_PRICE_QUANTUM = Decimal("0.00000001")


def _crawler_decrypt_api_key(ciphertext: str, secret_key: str) -> Optional[str]:
    """Decrypt an ``api_key`` ciphertext returned by the crawler.

    Steps: URL-safe Base64 decode -> per-byte rotate right -> XOR with a key
    stream -> UTF-8 decode. The key stream is
    ``SHA-256(key_bytes + block_index.to_bytes(4, "big"))`` for block indices
    0, 1, 2, ... concatenated and truncated to the data length. Byte ``i`` is
    rotated right by ``i % 5 + 1`` bits before the XOR.

    Returns the plaintext, or ``None`` when either argument is empty or any
    step fails. Failures are logged rather than raised so a single bad key
    cannot abort a whole sync.
    """
    if not ciphertext or not secret_key:
        return None
    try:
        # 1. URL-safe Base64 decode; restore any stripped '=' padding first.
        rem = len(ciphertext) % 4
        if rem:
            ciphertext += "=" * (4 - rem)
        data = base64.urlsafe_b64decode(ciphertext)

        # 2. Derive the key stream, one SHA-256 digest per 32-byte block.
        key_bytes = secret_key.encode("utf-8")
        stream = bytearray()
        block = 0
        while len(stream) < len(data):
            stream.extend(hashlib.sha256(key_bytes + block.to_bytes(4, "big")).digest())
            block += 1
        keystream = bytes(stream[: len(data)])

        # 3. Rotate each byte right by (i % 5 + 1) bits, then XOR with the key stream.
        result = bytearray(len(data))
        for i, byte in enumerate(data):
            n = i % 5 + 1
            unshifted = ((byte >> n) | (byte << (8 - n))) & 0xFF
            result[i] = unshifted ^ keystream[i]
        return result.decode("utf-8")
    except Exception as e:  # deliberate best-effort: caller treats None as "skip key"
        logger.warning(f"解密爬虫 api_key 失败: {e}")
        return None


def _get_local_version(db: Session) -> int:
    """Return the highest crawler version of any successful sync, or 0 if there is none."""
    log = db.query(CrawlerSyncLog).filter(
        CrawlerSyncLog.status == "success"
    ).order_by(CrawlerSyncLog.crawler_version.desc()).first()
    return log.crawler_version if log else 0


def _infer_tags(display_tags: list) -> tuple[Optional[str], Optional[str]]:
    """Split ``display_tags`` into ``(tag1, tag2)``.

    - tag2 (shown top-right, purple): brand/series tag such as Qwen3, wan2.6,
      Realtime-Omni. It is recognised via ``_BRAND_EXACT`` / ``_BRAND_PREFIXES``.
    - tag1 (shown bottom-left, grey): feature tag such as 文本生成, 视觉理解,
      图像生成, 语音合成, i.e. the first tag that is not a brand tag.

    Only the first match of each kind is kept. Non-string entries are ignored.
    Returns ``(None, None)`` for an empty or ``None`` input.
    """
    if not display_tags:
        return None, None

    tag1: Optional[str] = None  # feature tag
    tag2: Optional[str] = None  # brand tag
    for tag in display_tags:
        if not isinstance(tag, str):
            continue
        is_brand = tag in _BRAND_EXACT or tag.lower().startswith(_BRAND_PREFIXES)
        if is_brand:
            if tag2 is None:
                tag2 = tag
        elif tag1 is None:
            tag1 = tag
    return tag1, tag2


def _infer_multiplier(unit: Optional[str]) -> int:
    """Infer how many base units one price quote covers from its unit text.

    "元/每百万tokens" / "per million tokens" -> 1_000_000,
    "元/每千tokens" / "per 1k tokens" -> 1_000, "元/每万tokens" -> 10_000,
    anything else (including empty/``None``) -> 1.
    """
    if not unit:
        return 1
    u = unit.lower()
    if "百万" in unit or "million" in u:
        return 1_000_000
    if "千" in unit or "thousand" in u:
        return 1_000
    # The word "tokens" itself contains a 'k', so drop it before looking for a
    # "1k"-style thousand marker; otherwise every token unit would match here.
    if "token" in u and "k" in u.replace("token", ""):
        return 1_000
    if "万" in unit:
        return 10_000
    return 1


def _discount_label(rate: Decimal) -> Optional[str]:
    """Render a discount rate in Chinese "折" notation.

    0.8 -> "8折", 0.85 -> "8.5折", 0.95 -> "9.5折". Returns ``None`` when there
    is no real discount: rate >= 1, rate <= 0, or a rate that rounds to 0 or 10.
    """
    if rate >= Decimal("1.0") or rate <= Decimal("0"):
        return None
    # Round half-up to one decimal of "折" in exact Decimal arithmetic. Float
    # banker's rounding to whole tenths would turn 0.85 into "8折" and 0.95 into None.
    tenths = (rate * 10).quantize(Decimal("0.1"), rounding=ROUND_HALF_UP)
    if tenths <= 0 or tenths >= 10:
        return None
    return f"{format(tenths.normalize(), 'f')}折"


def _d(val, default=0) -> Decimal:
    """Convert ``val`` to a finite ``Decimal``.

    Falls back to ``Decimal(str(default))`` when ``val`` is ``None``,
    unparsable, or NaN/Infinity. Non-finite values would otherwise raise
    ``InvalidOperation`` later, during comparisons.
    """
    fallback = Decimal(str(default))
    if val is None:
        return fallback
    try:
        result = Decimal(str(val))
    except (ArithmeticError, ValueError, TypeError):
        return fallback
    return result if result.is_finite() else fallback


async def _fetch_crawler_payload(local_version: int) -> dict:
    """GET the crawler price payload, sending our current version in a header.

    Raises on transport errors, non-2xx status, invalid JSON, or a JSON body
    that is not an object.
    """
    async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
        resp = await client.get(
            f"{CRAWLER_BASE_URL}/api/public/prices",
            headers={"Referer": CRAWLER_REFERER, "version": str(local_version)},
        )
        resp.raise_for_status()
        data = resp.json()
    if not isinstance(data, dict):
        raise ValueError(f"unexpected crawler payload type: {type(data).__name__}")
    return data


def _encrypt_crawler_api_key(raw_api_key: Optional[str], model_code: str) -> Optional[str]:
    """Turn the crawler's proprietary ciphertext into our AES-encrypted form.

    Returns ``None`` (with a warning where appropriate) when there is no key,
    no ``APIKEY_ENCRYPT_KEY`` is configured, or decryption fails.
    """
    if not raw_api_key:
        return None
    if not APIKEY_ENCRYPT_KEY:
        logger.warning(f"未配置 APIKEY_ENCRYPT_KEY,模型 {model_code} 的 api_key 无法解密")
        return None
    plain_key = _crawler_decrypt_api_key(raw_api_key, APIKEY_ENCRYPT_KEY)
    if not plain_key:
        logger.warning(f"模型 {model_code} 的 api_key 解密失败,跳过存储")
        return None
    # Kept as a local import, as in the original code. NOTE(review): presumably
    # this avoids an import cycle — confirm before hoisting.
    from app.services.crypto_utils import encrypt_api_key
    return encrypt_api_key(plain_key)


def _upsert_model(db: Session, m: dict, crawled_at: datetime) -> bool:
    """Insert or update one ``ModelNew`` row from a crawler ``models`` entry.

    Returns ``False`` (and writes nothing) when the entry has no ``model_name``.
    """
    model_code = m.get("model_name")
    if not model_code:
        return False

    info = m.get("model_info") or {}
    display_tags = info.get("display_tags")
    inp = info.get("input_modalities") or []
    out = info.get("output_modalities") or []
    cats = infer_categories(inp, out, model_code, display_tags)
    tag1, tag2 = _infer_tags(display_tags or [])
    # keywords: every display tag except the chosen brand tag (tag2), newline
    # separated, for the bottom-left area of the front-end card.
    keywords = "\n".join(
        t for t in (display_tags or []) if isinstance(t, str) and t != tag2
    ) or None
    # Image-translation models (e.g. qwen-mt-image) use a dedicated route and
    # are hidden from the general model list.
    is_image_translation = "mt-image" in model_code.lower()
    group_name = m.get("group_name") or None
    encrypted_api_key = _encrypt_crawler_api_key(m.get("api_key") or None, model_code)

    existing = db.query(ModelNew).filter(ModelNew.model_code == model_code).first()
    if existing is None:
        db.add(ModelNew(
            model_code=model_code,
            description=info.get("description"),
            display_tags=display_tags,
            input_modalities=inp,
            output_modalities=out,
            features=info.get("features"),
            rate_limits=m.get("rate_limits"),
            tool_call_prices=m.get("tool_prices"),
            source_url=m.get("url"),
            crawled_at=crawled_at,
            display_name=model_code,
            supplier="Qwen",
            categories=cats,
            tag1=tag1,
            tag2=tag2,
            keywords=keywords,
            is_show_enabled=not is_image_translation,
            is_api_enabled=True,
            raw_prices=m.get("prices"),
            img=m.get("icon"),
            group_name=group_name,
            encrypted_api_key=encrypted_api_key,
        ))
        return True

    # Crawled fields overwrite stored values only when the crawler supplied them.
    existing.description = info.get("description") or existing.description
    existing.display_tags = display_tags or existing.display_tags
    existing.input_modalities = inp or existing.input_modalities
    existing.output_modalities = out or existing.output_modalities
    existing.features = info.get("features") or existing.features
    existing.rate_limits = m.get("rate_limits") or existing.rate_limits
    existing.tool_call_prices = m.get("tool_prices") or existing.tool_call_prices
    existing.source_url = m.get("url") or existing.source_url
    existing.crawled_at = crawled_at
    existing.raw_prices = m.get("prices") or existing.raw_prices
    if not existing.img and m.get("icon"):
        existing.img = m.get("icon")
    # Categories are always re-inferred from the latest modalities.
    existing.categories = cats
    if is_image_translation:
        existing.is_show_enabled = False
    # Inferred tags/keywords only fill empty fields, so manual edits survive.
    if not existing.tag1 and tag1:
        existing.tag1 = tag1
    if not existing.tag2 and tag2:
        existing.tag2 = tag2
    if not existing.keywords and keywords:
        existing.keywords = keywords
    # group_name / encrypted_api_key: crawler data wins whenever present.
    if group_name:
        existing.group_name = group_name
    if encrypted_api_key:
        existing.encrypted_api_key = encrypted_api_key
    return True


def _sync_prices(
    db: Session,
    parsed_prices: list,
    disc_map: dict,
    model_discount_map: dict,
    global_discount: Decimal,
    crawled_at: datetime,
) -> int:
    """Replace the active ``ModelPriceNew`` rows for every model in ``parsed_prices``.

    ``parsed_prices`` holds original prices. Discounted prices come from
    ``disc_map`` when present; otherwise they are derived from the model-level
    rate, falling back to the global rate. Returns the number of rows inserted.
    """
    # Deactivate per model first so labels dropped by the crawler don't linger
    # and get mistaken for tiered pricing.
    synced_model_codes = {p["model_name"] for p in parsed_prices if p.get("model_name")}
    for mc in synced_model_codes:
        db.query(ModelPriceNew).filter(
            ModelPriceNew.model_code == mc,
            ModelPriceNew.is_active == True,  # noqa: E712 - SQLAlchemy expression
        ).update({"is_active": False})
    db.flush()

    price_count = 0
    for p in parsed_prices:
        model_code = p.get("model_name")
        label = p.get("label")
        if not model_code or not label:
            continue

        in_orig = _d(p.get("input_price"), 0)
        out_orig = _d(p.get("output_price"), 0)
        unit = p.get("unit", "")
        currency = p.get("currency", "CNY")
        tier_min = _d(p["tier_min"]) if p.get("tier_min") is not None else None
        tier_max = _d(p["tier_max"]) if p.get("tier_max") is not None else None
        tier_unit = p.get("tier_unit")

        disc = disc_map.get((model_code, label))
        if disc:
            in_disc = _d(disc.get("input_price"), in_orig)
            out_disc = _d(disc.get("output_price"), out_orig)
            # Rate priority: the discounted entry's own rate, then the model-level
            # rate, then the global rate.
            rate = _d(
                disc.get("discount") or model_discount_map.get(model_code) or global_discount,
                1,
            )
        else:
            # No discounted entry: derive prices from model-level rate, else global rate.
            rate = model_discount_map.get(model_code, global_discount)
            in_disc = (in_orig * rate).quantize(_PRICE_QUANTUM)
            out_disc = (out_orig * rate).quantize(_PRICE_QUANTUM)

        # Deactivate any active row for this exact label (relevant only when the
        # payload repeats a (model, label) pair).
        db.query(ModelPriceNew).filter(
            ModelPriceNew.model_code == model_code,
            ModelPriceNew.label == label,
            ModelPriceNew.is_active == True,  # noqa: E712 - SQLAlchemy expression
        ).update({"is_active": False})

        db.add(ModelPriceNew(
            model_code=model_code,
            label=label,
            tier_min=tier_min,
            tier_max=tier_max,
            tier_unit=tier_unit,
            input_price_original=in_orig,
            output_price_original=out_orig,
            input_price_discounted=in_disc,
            output_price_discounted=out_disc,
            discount_rate=rate,
            discount_label=_discount_label(rate),
            currency=currency,
            unit=unit,
            display_multiplier=_infer_multiplier(unit),
            source_url=p.get("url"),
            crawled_at=crawled_at,
            is_active=True,
        ))
        price_count += 1
    return price_count


async def sync_from_crawler(db: Session) -> dict:
    """Pull the latest crawler payload and upsert models and prices.

    Every outcome except "already up to date" writes a ``CrawlerSyncLog`` row.
    The write phase is a single transaction that is rolled back on failure.

    Returns a dict with ``synced`` plus one of:

    - ``error`` on failure;
    - ``version`` and ``up_to_date`` when nothing changed;
    - ``version``, ``model_count`` and ``price_count`` on success.
    """
    local_version = _get_local_version(db)

    try:
        data = await _fetch_crawler_payload(local_version)
    except Exception as e:
        logger.error(f"请求爬虫数据失败: {e}")
        db.add(CrawlerSyncLog(crawler_version=local_version, status="failed", error_message=str(e)))
        db.commit()
        return {"synced": False, "error": str(e)}

    if data.get("up_to_date"):
        logger.info(f"爬虫数据已是最新 v{data.get('version')}")
        return {"synced": False, "version": data.get("version"), "up_to_date": True}

    remote_version = data.get("version", 0)
    # `or []` also covers explicit JSON nulls, which would otherwise raise
    # TypeError while iterating.
    models_data = data.get("models") or []
    parsed_prices = data.get("parsed_prices") or []
    disc_prices = data.get("discounted_prices") or []
    # Global fallback rate (legacy format); newer payloads carry per-model and
    # per-price rates. A missing or null value means "no discount" (1), not 0.
    global_discount = _d(data.get("discount"), 1)
    # Kept naive UTC for consistency with previously stored crawled_at values.
    crawled_at = datetime.utcnow()

    model_discount_map: dict[str, Decimal] = {}
    try:
        # (model_name, label) -> discounted price entry
        disc_map: dict[tuple, dict] = {
            (d["model_name"], d["label"]): d
            for d in disc_prices
            if d.get("model_name") and d.get("label")
        }
        # model_name -> model-level discount rate
        model_discount_map = {
            m["model_name"]: _d(m["discount"], 1)
            for m in models_data
            if m.get("model_name") and m.get("discount") is not None
        }

        model_count = 0
        for m in models_data:
            if _upsert_model(db, m, crawled_at):
                model_count += 1
        db.flush()

        price_count = _sync_prices(
            db, parsed_prices, disc_map, model_discount_map, global_discount, crawled_at
        )

        db.add(CrawlerSyncLog(
            crawler_version=remote_version,
            model_count=model_count,
            price_count=price_count,
            status="success",
        ))
        db.commit()
        logger.info(
            f"同步完成: v{local_version}→v{remote_version}, "
            f"模型+{model_count}, 价格{price_count}条, "
            f"模型级折扣数={len(model_discount_map)}"
        )
        return {"synced": True, "version": remote_version,
                "model_count": model_count, "price_count": price_count}
    except Exception as e:
        db.rollback()
        logger.exception(f"同步写入失败: {e}")
        db.add(CrawlerSyncLog(crawler_version=remote_version, status="failed", error_message=str(e)))
        db.commit()
        return {"synced": False, "error": str(e)}