| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282 |
- #!/usr/bin/env python3
- """
- 导入爬虫数据到新的数据库表
- 支持新格式({.txt):
- 每个 model 有 prices 字段,结构为:
- { "label": { "输入": { price, unit, price_original? }, "输出": {...} } }
- 或简单格式:
- { "label": { price, unit, price_original? } }
- - 导入模型数据到 models_new 表
- - 导入价格数据到 model_price_new 表
- """
- import json
- import sys
- from pathlib import Path
- from datetime import datetime
- from decimal import Decimal
- sys.path.insert(0, str(Path(__file__).parent.parent))
- from app.database import SessionLocal
- from app.models.model import ModelNew, ModelPriceNew, infer_categories
- def _compute_discount_label(discount_rate: float) -> str | None:
- """将折扣率转换为折扣标签,如 0.1 -> '1折',1.0 -> None"""
- if discount_rate >= 1.0 or discount_rate <= 0:
- return None
- tenths = round(discount_rate * 10)
- if tenths >= 10:
- return None
- return f"{tenths}折"
def _parse_unit_multiplier(unit: str) -> int:
    """Infer ``display_multiplier`` from a price unit string.

    Explicit scale words are tested before the bare ``tokens`` fallback so
    that a per-thousand unit such as ``元/每千tokens`` yields 1,000 instead of
    being captured by the ``tokens`` match.

    Args:
        unit: Unit text, e.g. ``"元/每百万tokens"``. An empty string or
            ``None`` is tolerated.

    Returns:
        ``1_000_000`` for per-million units, ``1_000`` for per-thousand
        units, ``1_000_000`` for a token unit with no explicit scale word,
        and ``1`` for anything else (including an empty unit).
    """
    if not unit:
        return 1
    u = unit.lower()
    if '百万' in u or 'million' in u:
        return 1_000_000
    if '千' in u or '1k' in u:
        return 1_000
    # A token unit without an explicit scale keeps the per-million default.
    if 'tokens' in u:
        return 1_000_000
    return 1
def _upsert_price_row(db, model_code: str, label: str, scraped_at: datetime,
                      input_orig: float, output_orig: float,
                      input_disc: float, output_disc: float,
                      unit: str, currency: str = 'CNY',
                      tier_min=None, tier_max=None, tier_unit: str | None = None,
                      source_url: str | None = None):
    """Add a new active price row, deactivating earlier ones for the same key.

    Existing ``ModelPriceNew`` rows with the same ``model_code`` and ``label``
    that are currently active are flagged ``is_active=False`` through a bulk
    update; a fresh row with ``is_active=True`` is then added to the session.
    Nothing is committed here.

    The discount rate is taken from the input price pair when the discounted
    value differs from a positive original, otherwise from the output pair,
    otherwise it is ``1.0``.

    Returns:
        The newly created ``ModelPriceNew`` instance (already added to ``db``).
    """
    # First price pair with a real difference decides the discount rate.
    rate = 1.0
    for original, discounted in ((input_orig, input_disc), (output_orig, output_disc)):
        if original and original > 0 and abs(discounted - original) > 1e-9:
            rate = round(discounted / original, 4)
            break

    rate_label = _compute_discount_label(rate)
    display_multiplier = _parse_unit_multiplier(unit)

    # Retire whatever is currently active for this model_code + label.
    currently_active = db.query(ModelPriceNew).filter(
        ModelPriceNew.model_code == model_code,
        ModelPriceNew.label == label,
        ModelPriceNew.is_active == True,
    )
    currently_active.update({"is_active": False}, synchronize_session=False)

    fields = {
        "model_code": model_code,
        "label": label,
        "tier_min": tier_min,
        "tier_max": tier_max,
        "tier_unit": tier_unit,
        "input_price_original": input_orig,
        "output_price_original": output_orig,
        "discount_rate": rate,
        "discount_label": rate_label,
        "input_price_discounted": input_disc,
        "output_price_discounted": output_disc,
        "currency": currency,
        "unit": unit,
        "display_multiplier": display_multiplier,
        "source_url": source_url,
        "crawled_at": scraped_at,
        "is_active": True,
    }
    new_row = ModelPriceNew(**fields)
    db.add(new_row)
    return new_row
def _read_price_entry(entry: dict) -> tuple[float, float]:
    """Return ``(original, discounted)`` for one ``{price, price_original?}`` dict."""
    discounted = float(entry.get('price', 0) or 0)
    original_raw = entry.get('price_original')
    original = float(original_raw) if original_raw is not None else discounted
    return original, discounted


def _parse_prices_new_format(model_data: dict, db, scraped_at: datetime):
    """Parse a model's ``prices`` field and upsert one price row per label.

    Accepted shapes of ``model_data['prices']``::

        {
          "label": {
            "输入": { "price": 2.5, "unit": "元/每百万tokens", "price_original": 10.0 },
            "输出": { ... },
            ...
          }
        }

    or the simple form without sub-keys::

        { "label": { "price": 2.0, "unit": "元/每万字符" } }

    Handling per label:
      * simple form -> one row; the price goes into the input columns and the
        output columns are ``0.0``;
      * compound form with an ``输入``/``input`` or ``输出``/``output`` entry ->
        one row built from those two entries (other sub-keys are not read);
      * compound form with neither -> every sub-item that has a ``price``
        becomes its own row labelled ``"{label}_{sub_key}"``.

    A missing, ``null`` or non-dict ``prices`` value yields no rows, and a
    ``null`` ``unit``/``currency`` falls back to ``''``/``'CNY'`` in every
    branch.

    Args:
        model_data: One crawled model record (reads ``model_name``,
            ``prices`` and ``url``).
        db: Session passed through to ``_upsert_price_row``.
        scraped_at: Crawl timestamp stored on each row.

    Returns:
        List of the objects returned by ``_upsert_price_row``, one per row.
    """
    model_code = model_data.get('model_name')
    source_url = model_data.get('url')

    prices_raw = model_data.get('prices')
    # JSON null (or any non-mapping) means "no prices", not a crash.
    if not isinstance(prices_raw, dict):
        return []

    # Each row: (label, in_orig, out_orig, in_disc, out_disc, unit, currency)
    rows = []
    for label, label_data in prices_raw.items():
        if not isinstance(label_data, dict):
            continue

        # Simple form: the label itself carries a single price.
        if 'price' in label_data:
            orig_val, price_val = _read_price_entry(label_data)
            unit = label_data.get('unit') or ''
            currency = label_data.get('currency') or 'CNY'
            rows.append((label, orig_val, 0.0, price_val, 0.0, unit, currency))
            continue

        # Compound form: keys are sub-item names (input / output / ...).
        input_entry = label_data.get('输入') or label_data.get('input') or {}
        output_entry = label_data.get('输出') or label_data.get('output') or {}

        if not input_entry and not output_entry:
            # No input/output pair: treat each priced sub-item as its own label.
            for sub_key, sub_val in label_data.items():
                if not isinstance(sub_val, dict) or 'price' not in sub_val:
                    continue
                orig_val, price_val = _read_price_entry(sub_val)
                unit = sub_val.get('unit') or ''
                currency = sub_val.get('currency') or 'CNY'
                rows.append((f"{label}_{sub_key}", orig_val, 0.0, price_val, 0.0, unit, currency))
            continue

        input_orig, input_price = _read_price_entry(input_entry)
        output_orig, output_price = _read_price_entry(output_entry)
        unit = input_entry.get('unit') or output_entry.get('unit') or ''
        currency = input_entry.get('currency') or output_entry.get('currency') or 'CNY'
        rows.append((label, input_orig, output_orig, input_price, output_price, unit, currency))

    price_objs = []
    for label, in_orig, out_orig, in_disc, out_disc, unit, currency in rows:
        price_objs.append(
            _upsert_price_row(
                db, model_code, label, scraped_at,
                in_orig, out_orig, in_disc, out_disc,
                unit, currency, source_url=source_url,
            )
        )
    return price_objs
def import_crawl_data(crawl_file: str) -> None:
    """Load a crawler JSON dump and import its models and prices.

    The file is read as UTF-8 JSON. The top level may be either a list of
    model records or a dict holding them under ``models`` (with an optional
    ``version``). Records without ``model_name`` are skipped. For every other
    record the matching ``ModelNew`` row is updated or created, then its
    prices are imported via ``_parse_prices_new_format``. Everything is
    committed once after the loop; when a ``version`` is present a
    ``CrawlerSyncLog`` row is written and committed afterwards.

    Args:
        crawl_file: Path of the crawler data file.

    Raises:
        Exception: Any error raised while importing is re-raised after the
            session has been rolled back; the session is always closed.
    """
    with open(crawl_file, 'r', encoding='utf-8') as f:
        data = json.load(f)
    db = SessionLocal()
    try:
        # A bare list carries no version; a dict carries both.
        models_list = data if isinstance(data, list) else data.get('models', [])
        version = data.get('version') if isinstance(data, dict) else None
        print(f"开始导入,共 {len(models_list)} 个模型,版本号: {version}")
        model_count = 0
        price_count = 0
        for model_data in models_list:
            model_code = model_data.get('model_name')
            if not model_code:
                continue
            model_info = model_data.get('model_info', {})
            input_mods = model_info.get('input_modalities') or []
            output_mods = model_info.get('output_modalities') or []
            display_tags = model_info.get('display_tags') or []
            cats = infer_categories(input_mods, output_mods, model_code, display_tags)
            scraped_at_str = model_data.get('scraped_at')
            # A trailing 'Z' is rewritten to '+00:00' before parsing; without a
            # timestamp the current local (naive) time is used instead.
            scraped_at = datetime.fromisoformat(scraped_at_str.replace('Z', '+00:00')) if scraped_at_str else datetime.now()
            existing = db.query(ModelNew).filter(ModelNew.model_code == model_code).first()
            if existing:
                # Refresh crawler-owned fields only (admin-configured fields are
                # not overwritten); an empty crawled value keeps the stored one.
                existing.description = model_info.get('description') or existing.description
                existing.display_tags = display_tags or existing.display_tags
                existing.input_modalities = input_mods or existing.input_modalities
                existing.output_modalities = output_mods or existing.output_modalities
                existing.features = model_info.get('features') or existing.features
                existing.rate_limits = model_data.get('rate_limits') or existing.rate_limits
                existing.tool_call_prices = model_data.get('tool_prices') or existing.tool_call_prices
                existing.raw_prices = model_data.get('prices') or existing.raw_prices
                existing.source_url = model_data.get('url') or existing.source_url
                existing.crawled_at = scraped_at
                existing.categories = cats
                existing.updated_at = datetime.now()
            else:
                model = ModelNew(
                    model_code=model_code,
                    description=model_info.get('description'),
                    display_tags=display_tags,
                    input_modalities=input_mods,
                    output_modalities=output_mods,
                    features=model_info.get('features'),
                    rate_limits=model_data.get('rate_limits'),
                    tool_call_prices=model_data.get('tool_prices'),
                    raw_prices=model_data.get('prices'),
                    source_url=model_data.get('url'),
                    crawled_at=scraped_at,
                    display_name=model_code,
                    supplier='Qwen',
                    categories=cats,
                    keywords=','.join(display_tags),
                    is_show_enabled=True,
                    is_api_enabled=True,
                )
                db.add(model)
                # NOTE(review): counted on the insert branch only, matching the
                # "新增模型" (newly added) wording of the summary below — confirm
                # this nesting against the original file.
                model_count += 1
            # Import prices
            price_objs = _parse_prices_new_format(model_data, db, scraped_at)
            price_count += len(price_objs)
        db.commit()
        print(f"完成:新增模型 {model_count} 个,价格记录 {price_count} 条")
        # Record the sync log (only when the dump carries a version)
        if version is not None:
            from app.models.model import CrawlerSyncLog
            log = CrawlerSyncLog(
                crawler_version=version,
                model_count=model_count,
                price_count=price_count,
                status='success',
            )
            db.add(log)
            db.commit()
    except Exception as e:
        db.rollback()
        print(f"导入失败: {e}")
        raise
    finally:
        db.close()
if __name__ == "__main__":
    import argparse

    # Default data file: "{.txt" three directory levels above this script.
    default_file = Path(__file__).parent.parent.parent / "{.txt"

    cli = argparse.ArgumentParser(description="导入爬虫数据到数据库")
    cli.add_argument(
        "--file",
        type=str,
        default=str(default_file),
        help="爬虫数据文件路径(默认:项目根目录的 {.txt)",
    )
    crawl_file = cli.parse_args().file

    if Path(crawl_file).exists():
        print(f"导入数据从: {crawl_file}")
        import_crawl_data(crawl_file)
        print("导入完成!")
    else:
        # Missing input file: report and exit with a non-zero status.
        print(f"文件不存在: {crawl_file}")
        sys.exit(1)
|