import_crawl_data.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. #!/usr/bin/env python3
  2. """
  3. 导入爬虫数据到新的数据库表
  4. 支持新格式({.txt):
  5. 每个 model 有 prices 字段,结构为:
  6. { "label": { "输入": { price, unit, price_original? }, "输出": {...} } }
  7. 或简单格式:
  8. { "label": { price, unit, price_original? } }
  9. - 导入模型数据到 models_new 表
  10. - 导入价格数据到 model_price_new 表
  11. """
  12. import json
  13. import sys
  14. from pathlib import Path
  15. from datetime import datetime
  16. from decimal import Decimal
  17. sys.path.insert(0, str(Path(__file__).parent.parent))
  18. from app.database import SessionLocal
  19. from app.models.model import ModelNew, ModelPriceNew, infer_categories
  20. def _compute_discount_label(discount_rate: float) -> str | None:
  21. """将折扣率转换为折扣标签,如 0.1 -> '1折',1.0 -> None"""
  22. if discount_rate >= 1.0 or discount_rate <= 0:
  23. return None
  24. tenths = round(discount_rate * 10)
  25. if tenths >= 10:
  26. return None
  27. return f"{tenths}折"
  28. def _parse_unit_multiplier(unit: str) -> int:
  29. """根据单位字符串推断 display_multiplier"""
  30. u = unit.lower()
  31. if '百万' in u or 'million' in u or 'tokens' in u:
  32. return 1_000_000
  33. if '千' in u or '1k' in u:
  34. return 1_000
  35. return 1
  def _upsert_price_row(db, model_code: str, label: str, scraped_at: datetime,
                        input_orig: float, output_orig: float,
                        input_disc: float, output_disc: float,
                        unit: str, currency: str = 'CNY',
                        tier_min=None, tier_max=None, tier_unit: str | None = None,
                        source_url: str | None = None):
      """Record a fresh price snapshot for one (model_code, label) pair.

      This always inserts. Every currently active ``ModelPriceNew`` row for the
      same model_code + label is flagged ``is_active=False`` (kept as history,
      not deleted), and a new active row is added to the session. Nothing is
      committed here; the caller owns the transaction.

      The discount rate comes from the input price when it differs from its
      original, else from the output price, else it is 1.0 (no discount).

      Returns:
          The newly added ``ModelPriceNew`` instance.
      """
      discount_rate = 1.0
      # Input price takes precedence; the output price is only consulted when
      # the input price carries no discount.
      for original, discounted in ((input_orig, input_disc), (output_orig, output_disc)):
          if original and original > 0 and abs(discounted - original) > 1e-9:
              discount_rate = round(discounted / original, 4)
              break

      discount_label = _compute_discount_label(discount_rate)
      multiplier = _parse_unit_multiplier(unit)

      # Retire previous snapshots rather than deleting them.
      stale = db.query(ModelPriceNew).filter(
          ModelPriceNew.model_code == model_code,
          ModelPriceNew.label == label,
          ModelPriceNew.is_active == True,  # noqa: E712 (SQL expression, must stay ==)
      )
      stale.update({"is_active": False}, synchronize_session=False)

      row = ModelPriceNew(
          model_code=model_code,
          label=label,
          tier_min=tier_min,
          tier_max=tier_max,
          tier_unit=tier_unit,
          input_price_original=input_orig,
          input_price_discounted=input_disc,
          output_price_original=output_orig,
          output_price_discounted=output_disc,
          discount_rate=discount_rate,
          discount_label=discount_label,
          currency=currency,
          unit=unit,
          display_multiplier=multiplier,
          source_url=source_url,
          crawled_at=scraped_at,
          is_active=True,
      )
      db.add(row)
      return row
  79. def _parse_prices_new_format(model_data: dict, db, scraped_at: datetime):
  80. """
  81. 解析新格式 prices 字段:
  82. {
  83. "label": {
  84. "输入": { "price": 2.5, "unit": "元/每百万tokens", "price_original": 10.0 },
  85. "输出": { ... },
  86. ...
  87. }
  88. }
  89. 或简单格式(无子键):
  90. {
  91. "label": { "price": 2.0, "unit": "元/每万字符" }
  92. }
  93. """
  94. model_code = model_data.get('model_name')
  95. prices_raw = model_data.get('prices', {})
  96. source_url = model_data.get('url')
  97. rows = []
  98. for label, label_data in prices_raw.items():
  99. if not isinstance(label_data, dict):
  100. continue
  101. # 判断是否为简单格式(直接有 price 字段)
  102. if 'price' in label_data:
  103. # 简单格式:整个 label 只有一个价格
  104. price_val = float(label_data.get('price', 0) or 0)
  105. price_orig = label_data.get('price_original')
  106. orig_val = float(price_orig) if price_orig is not None else price_val
  107. unit = label_data.get('unit', '')
  108. currency = label_data.get('currency', 'CNY')
  109. rows.append((label, orig_val, 0.0, price_val, 0.0, unit, currency))
  110. else:
  111. # 复合格式:label_data 的 key 是子项名(输入/输出/显式缓存创建/...)
  112. # 找输入和输出价格
  113. input_entry = label_data.get('输入') or label_data.get('input') or {}
  114. output_entry = label_data.get('输出') or label_data.get('output') or {}
  115. if not input_entry and not output_entry:
  116. # 没有输入/输出,把所有子项各自作为独立 label
  117. for sub_key, sub_val in label_data.items():
  118. if not isinstance(sub_val, dict) or 'price' not in sub_val:
  119. continue
  120. sub_label = f"{label}_{sub_key}"
  121. price_val = float(sub_val.get('price', 0) or 0)
  122. price_orig = sub_val.get('price_original')
  123. orig_val = float(price_orig) if price_orig is not None else price_val
  124. unit = sub_val.get('unit', '')
  125. currency = sub_val.get('currency', 'CNY')
  126. rows.append((sub_label, orig_val, 0.0, price_val, 0.0, unit, currency))
  127. continue
  128. input_price = float(input_entry.get('price', 0) or 0)
  129. input_orig_raw = input_entry.get('price_original')
  130. input_orig = float(input_orig_raw) if input_orig_raw is not None else input_price
  131. output_price = float(output_entry.get('price', 0) or 0)
  132. output_orig_raw = output_entry.get('price_original')
  133. output_orig = float(output_orig_raw) if output_orig_raw is not None else output_price
  134. unit = (input_entry.get('unit') or output_entry.get('unit') or '')
  135. currency = (input_entry.get('currency') or output_entry.get('currency') or 'CNY')
  136. rows.append((label, input_orig, output_orig, input_price, output_price, unit, currency))
  137. price_objs = []
  138. for label, in_orig, out_orig, in_disc, out_disc, unit, currency in rows:
  139. p = _upsert_price_row(
  140. db, model_code, label, scraped_at,
  141. in_orig, out_orig, in_disc, out_disc,
  142. unit, currency, source_url=source_url
  143. )
  144. price_objs.append(p)
  145. return price_objs
  def _parse_scraped_at(raw: str | None) -> datetime:
      """Parse a crawler ``scraped_at`` ISO-8601 string, defaulting to ``datetime.now()``.

      A trailing 'Z' is rewritten to '+00:00' because ``datetime.fromisoformat``
      rejects the 'Z' suffix before Python 3.11.
      NOTE(review): a parsed value with an offset is timezone-aware, while the
      fallback is naive. Confirm which one the ``crawled_at`` column expects.
      """
      if not raw:
          return datetime.now()
      return datetime.fromisoformat(raw.replace('Z', '+00:00'))


  def import_crawl_data(crawl_file: str):
      """Import a crawler JSON dump into ``ModelNew`` / ``ModelPriceNew`` rows.

      The file holds either a bare list of model dicts or an object with a
      ``models`` list and an optional ``version``. For each entry with a
      ``model_name``:
        - An existing ``ModelNew`` has its crawler-sourced fields refreshed,
          leaving admin-configured fields alone.
        - Otherwise a new row is created.
        - The entry's prices are then imported via
          ``_parse_prices_new_format``.

      All model and price changes are committed together. When ``version``
      is present, a ``CrawlerSyncLog`` row is committed afterwards.

      Raises:
          Any exception from parsing or the database, re-raised after rollback.
      """
      with open(crawl_file, 'r', encoding='utf-8') as f:
          data = json.load(f)

      db = SessionLocal()
      try:
          if isinstance(data, list):
              models_list = data
              version = None
          else:
              # ``or []`` also covers an explicit "models": null.
              models_list = data.get('models') or []
              version = data.get('version')
          print(f"开始导入,共 {len(models_list)} 个模型,版本号: {version}")

          model_count = 0  # newly created models only; updated ones are not counted
          price_count = 0
          for model_data in models_list:
              # Skip malformed (non-object) entries instead of aborting the import.
              if not isinstance(model_data, dict):
                  continue
              model_code = model_data.get('model_name')
              if not model_code:
                  continue

              # ``or {}``: an explicit null model_info would otherwise crash on .get().
              model_info = model_data.get('model_info') or {}
              input_mods = model_info.get('input_modalities') or []
              output_mods = model_info.get('output_modalities') or []
              display_tags = model_info.get('display_tags') or []
              cats = infer_categories(input_mods, output_mods, model_code, display_tags)
              scraped_at = _parse_scraped_at(model_data.get('scraped_at'))

              existing = db.query(ModelNew).filter(ModelNew.model_code == model_code).first()
              if existing:
                  # Refresh crawler-sourced fields only (admin-configured fields
                  # are not touched). Empty crawler values keep the stored ones.
                  existing.description = model_info.get('description') or existing.description
                  existing.display_tags = display_tags or existing.display_tags
                  existing.input_modalities = input_mods or existing.input_modalities
                  existing.output_modalities = output_mods or existing.output_modalities
                  existing.features = model_info.get('features') or existing.features
                  existing.rate_limits = model_data.get('rate_limits') or existing.rate_limits
                  existing.tool_call_prices = model_data.get('tool_prices') or existing.tool_call_prices
                  existing.raw_prices = model_data.get('prices') or existing.raw_prices
                  existing.source_url = model_data.get('url') or existing.source_url
                  existing.crawled_at = scraped_at
                  existing.categories = cats
                  existing.updated_at = datetime.now()
              else:
                  db.add(ModelNew(
                      model_code=model_code,
                      description=model_info.get('description'),
                      display_tags=display_tags,
                      input_modalities=input_mods,
                      output_modalities=output_mods,
                      features=model_info.get('features'),
                      rate_limits=model_data.get('rate_limits'),
                      tool_call_prices=model_data.get('tool_prices'),
                      raw_prices=model_data.get('prices'),
                      source_url=model_data.get('url'),
                      crawled_at=scraped_at,
                      display_name=model_code,
                      supplier='Qwen',
                      categories=cats,
                      keywords=','.join(display_tags),
                      is_show_enabled=True,
                      is_api_enabled=True,
                  ))
                  model_count += 1

              price_objs = _parse_prices_new_format(model_data, db, scraped_at)
              price_count += len(price_objs)
              # Flush per model so later queries in this same import (a repeated
              # model_name, the price-deactivation UPDATE) see the rows just
              # added, regardless of how SessionLocal configures autoflush.
              db.flush()

          db.commit()
          print(f"完成:新增模型 {model_count} 个,价格记录 {price_count} 条")

          # Record the sync log. It is committed separately, so a failure here
          # rolls back only the log; the imported data is already committed.
          if version is not None:
              from app.models.model import CrawlerSyncLog
              db.add(CrawlerSyncLog(
                  crawler_version=version,
                  model_count=model_count,
                  price_count=price_count,
                  status='success',
              ))
              db.commit()
      except Exception as e:
          db.rollback()
          print(f"导入失败: {e}")
          raise
      finally:
          db.close()
  if __name__ == "__main__":
      import argparse

      # Default input: the "{.txt" dump at the project root, two levels above
      # this script's directory.
      default_file = Path(__file__).parent.parent.parent / "{.txt"

      cli = argparse.ArgumentParser(description="导入爬虫数据到数据库")
      cli.add_argument(
          "--file",
          type=str,
          default=str(default_file),
          help="爬虫数据文件路径(默认:项目根目录的 {.txt)",
      )
      target = cli.parse_args().file

      if Path(target).exists():
          print(f"导入数据从: {target}")
          import_crawl_data(target)
          print("导入完成!")
      else:
          print(f"文件不存在: {target}")
          sys.exit(1)