crawler_sync_service.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. """
  2. 爬虫数据同步服务
  3. 爬虫 API 响应格式(参考 {.txt):
  4. {
  5. "version": 14,
  6. "discount": 0.8,
  7. "models": [...],
  8. "parsed_prices": [
  9. { "model_name": "qwen3-max", "label": "input<=32k",
  10. "tier_min": 0, "tier_max": 32000, "tier_unit": "tokens",
  11. "input_price": 2.5, "output_price": 10.0,
  12. "unit": "元/每百万tokens", "currency": "CNY", "url": "..." }
  13. ],
  14. "discounted_prices": [
  15. { "model_name": "qwen3-max", "label": "input<=32k",
  16. "input_price": 2.0, "output_price": 8.0,
  17. "unit": "元/每百万tokens", "currency": "CNY",
  18. "discount": 0.8, "url": "..." }
  19. ],
  20. "types": [...]
  21. }
  22. """
  23. import base64
  24. import hashlib
  25. import logging
  26. from datetime import datetime
  27. from decimal import Decimal
  28. from typing import Optional
  29. import httpx
  30. from sqlalchemy.orm import Session
  31. from app.core.config import get_performance_settings
  32. from app.models.model import ModelNew, ModelPriceNew, CrawlerSyncLog, infer_categories
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Settings object from get_performance_settings(); the getattr() lookups below
# fall back to the given defaults when an attribute is absent.
settings = get_performance_settings()
# Base URL of the crawler service; "/api/public/prices" is appended when syncing.
CRAWLER_BASE_URL = getattr(settings, "crawler_base_url", "http://localhost:8000")
# Value sent as the "Referer" header on crawler requests.
CRAWLER_REFERER = getattr(settings, "crawler_referer", "aigcspace.com")
# Secret used to decrypt api_key ciphertexts returned by the crawler;
# an empty string means "not configured" and api_keys are then not stored.
APIKEY_ENCRYPT_KEY = getattr(settings, "apikey_encrypt_key", "")
  38. def _crawler_decrypt_api_key(ciphertext: str, secret_key: str) -> Optional[str]:
  39. """
  40. 解密爬虫返回的 api_key 密文。
  41. 算法:Base64 URL-safe 解码 → 循环右移 → XOR 密钥流 → UTF-8
  42. 密钥流:SHA-256(key_bytes + block_index_big_endian) 循环拼接
  43. """
  44. if not ciphertext or not secret_key:
  45. return None
  46. try:
  47. # 1. Base64 URL-safe 解码(补齐 padding)
  48. rem = len(ciphertext) % 4
  49. if rem:
  50. ciphertext += '=' * (4 - rem)
  51. data = base64.urlsafe_b64decode(ciphertext)
  52. # 2. 派生密钥流
  53. key_bytes = secret_key.encode('utf-8')
  54. stream = bytearray()
  55. block = 0
  56. while len(stream) < len(data):
  57. stream.extend(
  58. hashlib.sha256(key_bytes + block.to_bytes(4, 'big')).digest()
  59. )
  60. block += 1
  61. keystream = bytes(stream[:len(data)])
  62. # 3. 循环右移 + XOR
  63. result = bytearray(len(data))
  64. for i, byte in enumerate(data):
  65. n = i % 5 + 1
  66. unshifted = ((byte >> n) | (byte << (8 - n))) & 0xFF
  67. result[i] = unshifted ^ keystream[i]
  68. return result.decode('utf-8')
  69. except Exception as e:
  70. logger.warning(f"解密爬虫 api_key 失败: {e}")
  71. return None
  72. def _get_local_version(db: Session) -> int:
  73. log = db.query(CrawlerSyncLog).filter(
  74. CrawlerSyncLog.status == "success"
  75. ).order_by(CrawlerSyncLog.crawler_version.desc()).first()
  76. return log.crawler_version if log else 0
  77. def _infer_tags(display_tags: list) -> tuple[Optional[str], Optional[str]]:
  78. """
  79. 从 display_tags 推断 tag1(功能标签)和 tag2(品牌/系列标签)。
  80. 规则:
  81. - tag2(右上角紫色):品牌/系列标签,如 Qwen3、wan2.6、Realtime-Omni
  82. - tag1(左下角灰色):功能标签,如 文本生成、视觉理解、图像生成、语音合成
  83. 品牌标签特征:全英文或含数字的系列名
  84. """
  85. if not display_tags:
  86. return None, None
  87. # 已知品牌/系列标签前缀
  88. BRAND_PREFIXES = ('qwen', 'wan', 'realtime', 'omni')
  89. BRAND_EXACT = {'Qwen3', 'Qwen2.5', 'wan2.6', 'wan2.2', 'Realtime-Omni'}
  90. tag2 = None # 品牌标签
  91. tag1 = None # 功能标签
  92. for tag in display_tags:
  93. tag_lower = tag.lower()
  94. is_brand = (
  95. tag in BRAND_EXACT
  96. or any(tag_lower.startswith(p) for p in BRAND_PREFIXES)
  97. )
  98. if is_brand:
  99. if tag2 is None:
  100. tag2 = tag
  101. else:
  102. if tag1 is None:
  103. tag1 = tag
  104. return tag1, tag2
  105. def _infer_multiplier(unit: str) -> int:
  106. u = unit.lower()
  107. if "百万" in unit or "million" in u:
  108. return 1_000_000
  109. if "千" in unit or ("k" in u and "token" in u):
  110. return 1_000
  111. if "万" in unit:
  112. return 10_000
  113. return 1
  114. def _discount_label(rate: Decimal) -> Optional[str]:
  115. """0.8 -> '8折',1.0 -> None"""
  116. if rate >= Decimal("1.0") or rate <= Decimal("0"):
  117. return None
  118. tenths = round(float(rate) * 10)
  119. return f"{tenths}折" if tenths < 10 else None
  120. def _d(val, default=0) -> Decimal:
  121. try:
  122. return Decimal(str(val)) if val is not None else Decimal(str(default))
  123. except Exception:
  124. return Decimal(str(default))
  125. async def sync_from_crawler(db: Session) -> dict:
  126. local_version = _get_local_version(db)
  127. try:
  128. async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
  129. resp = await client.get(
  130. f"{CRAWLER_BASE_URL}/api/public/prices",
  131. headers={"Referer": CRAWLER_REFERER, "version": str(local_version)},
  132. )
  133. resp.raise_for_status()
  134. data = resp.json()
  135. except Exception as e:
  136. logger.error(f"请求爬虫数据失败: {e}")
  137. db.add(CrawlerSyncLog(crawler_version=local_version, status="failed", error_message=str(e)))
  138. db.commit()
  139. return {"synced": False, "error": str(e)}
  140. if data.get("up_to_date"):
  141. logger.info(f"爬虫数据已是最新 v{data.get('version')}")
  142. return {"synced": False, "version": data.get("version"), "up_to_date": True}
  143. remote_version = data.get("version", 0)
  144. models_data = data.get("models", [])
  145. parsed_prices = data.get("parsed_prices", [])
  146. disc_prices = data.get("discounted_prices", [])
  147. # 全局折扣兜底(旧格式兼容),新格式中折扣已下沉到每个模型/价格条目
  148. global_discount = _d(data.get("discount", 1))
  149. # 折扣价索引:(model_name, label) -> discounted item
  150. disc_map: dict[tuple, dict] = {
  151. (d["model_name"], d["label"]): d
  152. for d in disc_prices
  153. if d.get("model_name") and d.get("label")
  154. }
  155. # 模型级折扣索引:model_name -> discount rate
  156. # 新格式中每个模型对象携带自己的 discount 字段
  157. model_discount_map: dict[str, Decimal] = {
  158. m["model_name"]: _d(m["discount"], 1)
  159. for m in models_data
  160. if m.get("model_name") and m.get("discount") is not None
  161. }
  162. model_count = 0
  163. price_count = 0
  164. crawled_at = datetime.utcnow()
  165. try:
  166. # ── 同步模型主表 ──────────────────────────────────────────
  167. for m in models_data:
  168. model_code = m.get("model_name")
  169. if not model_code:
  170. continue
  171. info = m.get("model_info") or {}
  172. inp = info.get("input_modalities") or []
  173. out = info.get("output_modalities") or []
  174. cats = infer_categories(inp, out, model_code, info.get("display_tags"))
  175. tag1, tag2 = _infer_tags(info.get("display_tags") or [])
  176. # keywords 用 display_tags 里的功能标签(非品牌标签),换行分隔,供前端左下角展示
  177. keywords = '\n'.join(t for t in (info.get("display_tags") or []) if t != tag2) or None
  178. # 图像翻译模型(如 qwen-mt-image)走专用路由,不在通用模型列表展示
  179. is_image_translation = 'mt-image' in model_code.lower()
  180. group_name = m.get("group_name") or None
  181. raw_api_key = m.get("api_key") or None
  182. # 爬虫传入的 api_key 是用专有算法加密的密文,需先解密得到明文,再用 AES 加密存储
  183. encrypted_api_key = None
  184. if raw_api_key and APIKEY_ENCRYPT_KEY:
  185. plain_key = _crawler_decrypt_api_key(raw_api_key, APIKEY_ENCRYPT_KEY)
  186. if plain_key:
  187. from app.services.crypto_utils import encrypt_api_key
  188. encrypted_api_key = encrypt_api_key(plain_key)
  189. else:
  190. logger.warning(f"模型 {model_code} 的 api_key 解密失败,跳过存储")
  191. elif raw_api_key:
  192. logger.warning(f"未配置 APIKEY_ENCRYPT_KEY,模型 {model_code} 的 api_key 无法解密")
  193. existing = db.query(ModelNew).filter(ModelNew.model_code == model_code).first()
  194. if existing:
  195. existing.description = info.get("description") or existing.description
  196. existing.display_tags = info.get("display_tags") or existing.display_tags
  197. existing.input_modalities = inp or existing.input_modalities
  198. existing.output_modalities = out or existing.output_modalities
  199. existing.features = info.get("features") or existing.features
  200. existing.rate_limits = m.get("rate_limits") or existing.rate_limits
  201. existing.tool_call_prices = m.get("tool_prices") or existing.tool_call_prices
  202. existing.source_url = m.get("url") or existing.source_url
  203. existing.crawled_at = crawled_at
  204. existing.raw_prices = m.get("prices") or existing.raw_prices
  205. if not existing.img and m.get("icon"):
  206. existing.img = m.get("icon")
  207. # categories 始终根据最新 modalities 重新推断,确保分类准确
  208. existing.categories = cats
  209. if is_image_translation:
  210. existing.is_show_enabled = False
  211. # 只在未手动设置时才用推断值更新 tag1/tag2
  212. if not existing.tag1 and tag1:
  213. existing.tag1 = tag1
  214. if not existing.tag2 and tag2:
  215. existing.tag2 = tag2
  216. if not existing.keywords and keywords:
  217. existing.keywords = keywords
  218. # 始终更新 group_name 和 encrypted_api_key(爬虫数据优先)
  219. if group_name:
  220. existing.group_name = group_name
  221. if encrypted_api_key:
  222. existing.encrypted_api_key = encrypted_api_key
  223. else:
  224. db.add(ModelNew(
  225. model_code=model_code,
  226. description=info.get("description"),
  227. display_tags=info.get("display_tags"),
  228. input_modalities=inp,
  229. output_modalities=out,
  230. features=info.get("features"),
  231. rate_limits=m.get("rate_limits"),
  232. tool_call_prices=m.get("tool_prices"),
  233. source_url=m.get("url"),
  234. crawled_at=crawled_at,
  235. display_name=model_code,
  236. supplier="Qwen",
  237. categories=cats,
  238. tag1=tag1,
  239. tag2=tag2,
  240. keywords=keywords,
  241. is_show_enabled=not is_image_translation,
  242. is_api_enabled=True,
  243. raw_prices=m.get("prices"),
  244. img=m.get("icon"),
  245. group_name=group_name,
  246. encrypted_api_key=encrypted_api_key,
  247. ))
  248. model_count += 1
  249. db.flush()
  250. # ── 同步价格表 ────────────────────────────────────────────
  251. # 先按模型维度批量失效,避免旧 label 残留导致误判为阶梯计费
  252. synced_model_codes = {p["model_name"] for p in parsed_prices if p.get("model_name")}
  253. for mc in synced_model_codes:
  254. db.query(ModelPriceNew).filter(
  255. ModelPriceNew.model_code == mc,
  256. ModelPriceNew.is_active == True,
  257. ).update({"is_active": False})
  258. db.flush()
  259. # parsed_prices = 原价,discounted_prices = 折扣价
  260. for p in parsed_prices:
  261. model_code = p.get("model_name")
  262. label = p.get("label")
  263. if not model_code or not label:
  264. continue
  265. in_orig = _d(p.get("input_price"), 0)
  266. out_orig = _d(p.get("output_price"), 0)
  267. unit = p.get("unit", "")
  268. currency = p.get("currency", "CNY")
  269. tier_min = _d(p["tier_min"]) if p.get("tier_min") is not None else None
  270. tier_max = _d(p["tier_max"]) if p.get("tier_max") is not None else None
  271. tier_unit = p.get("tier_unit")
  272. # 从折扣价表取对应记录
  273. disc = disc_map.get((model_code, label), {})
  274. if disc:
  275. in_disc = _d(disc.get("input_price"), in_orig)
  276. out_disc = _d(disc.get("output_price"), out_orig)
  277. # 优先取 discounted_prices 条目自带的 discount,
  278. # 其次取模型级 discount,最后兜底全局 discount
  279. rate = _d(
  280. disc.get("discount")
  281. or model_discount_map.get(model_code)
  282. or global_discount,
  283. 1
  284. )
  285. else:
  286. # 没有对应折扣记录,按优先级取折扣率:模型级 > 全局
  287. rate = model_discount_map.get(model_code, global_discount)
  288. in_disc = (in_orig * rate).quantize(Decimal("0.00000001"))
  289. out_disc = (out_orig * rate).quantize(Decimal("0.00000001"))
  290. # 旧记录失效
  291. db.query(ModelPriceNew).filter(
  292. ModelPriceNew.model_code == model_code,
  293. ModelPriceNew.label == label,
  294. ModelPriceNew.is_active == True,
  295. ).update({"is_active": False})
  296. db.add(ModelPriceNew(
  297. model_code=model_code,
  298. label=label,
  299. tier_min=tier_min,
  300. tier_max=tier_max,
  301. tier_unit=tier_unit,
  302. input_price_original=in_orig,
  303. output_price_original=out_orig,
  304. input_price_discounted=in_disc,
  305. output_price_discounted=out_disc,
  306. discount_rate=rate,
  307. discount_label=_discount_label(rate),
  308. currency=currency,
  309. unit=unit,
  310. display_multiplier=_infer_multiplier(unit),
  311. source_url=p.get("url"),
  312. crawled_at=crawled_at,
  313. is_active=True,
  314. ))
  315. price_count += 1
  316. db.add(CrawlerSyncLog(
  317. crawler_version=remote_version,
  318. model_count=model_count,
  319. price_count=price_count,
  320. status="success",
  321. ))
  322. db.commit()
  323. logger.info(
  324. f"同步完成: v{local_version}→v{remote_version}, "
  325. f"模型+{model_count}, 价格{price_count}条, "
  326. f"模型级折扣数={len(model_discount_map)}"
  327. )
  328. return {"synced": True, "version": remote_version,
  329. "model_count": model_count, "price_count": price_count}
  330. except Exception as e:
  331. db.rollback()
  332. logger.error(f"同步写入失败: {e}")
  333. db.add(CrawlerSyncLog(crawler_version=remote_version, status="failed", error_message=str(e)))
  334. db.commit()
  335. return {"synced": False, "error": str(e)}