#!/usr/bin/env python3
"""
Import crawler data into the new database tables.

Supported input (the ``{.txt`` crawler dump): every model carries a ``prices``
field in one of two shapes.

Compound format -- each label maps sub-items to price entries::

    { "label": { "输入": { price, unit, price_original? }, "输出": {...} } }

Simple format -- each label maps directly to a single price entry::

    { "label": { price, unit, price_original? } }

- Model data is imported into the ``models_new`` table.
- Price data is imported into the ``model_price_new`` table.
"""
import json
import sys
from pathlib import Path
from datetime import datetime
from decimal import Decimal

sys.path.insert(0, str(Path(__file__).parent.parent))

from app.database import SessionLocal
from app.models.model import ModelNew, ModelPriceNew, infer_categories


def _compute_discount_label(discount_rate: float) -> str | None:
    """Convert a discount rate into a Chinese "N折" label.

    Examples: 0.1 -> '1折', 0.85 -> '8.5折', 0.04 -> '0.4折', 1.0 -> None.

    Returns None when there is no real discount (rate >= 1) or the rate is
    not positive.
    """
    if discount_rate >= 1.0 or discount_rate <= 0:
        return None
    # Integer rounding would mislabel 0.04 as '0折' and 0.85 as '8折'. The
    # caller rounds the rate to 4 decimals, so 3 decimals of "tenths" keep it
    # exact; ':g' drops trailing zeros so whole values still print as '1折'.
    tenths = round(discount_rate * 10, 3)
    if tenths >= 10:
        return None
    return f"{tenths:g}折"


def _parse_unit_multiplier(unit: str) -> int:
    """Infer ``display_multiplier`` from a unit string.

    "百万"/"million" -> 1_000_000, "万" -> 10_000, "千"/"1k"/"thousand" -> 1_000.
    A unit that only mentions "tokens" falls back to 1_000_000; anything
    else (including an empty/None unit) -> 1.
    """
    u = (unit or '').lower()
    # Explicit magnitudes must be checked before the bare "tokens" fallback,
    # otherwise "元/每千tokens" or "1K tokens" would be treated as per-million.
    if '百万' in u or 'million' in u:
        return 1_000_000
    if '万' in u:
        return 10_000
    if '千' in u or '1k' in u or 'thousand' in u:
        return 1_000
    if 'tokens' in u:
        return 1_000_000
    return 1


def _upsert_price_row(db, model_code: str, label: str, scraped_at: datetime,
                      input_orig: float, output_orig: float,
                      input_disc: float, output_disc: float,
                      unit: str, currency: str = 'CNY',
                      tier_min=None, tier_max=None, tier_unit: str | None = None,
                      source_url: str | None = None):
    """Add a new active price row for ``model_code`` + ``label``.

    Existing active rows for the same ``model_code`` + ``label`` are first
    flagged ``is_active=False`` (kept, not deleted), so only the row added
    here stays active.

    The discount rate is derived from the input price when it differs from
    its original price, otherwise from the output price, otherwise 1.0.

    Returns the new ``ModelPriceNew`` instance (added to the session, not
    committed).
    """
    if input_orig and input_orig > 0 and abs(input_disc - input_orig) > 1e-9:
        discount_rate = round(input_disc / input_orig, 4)
    elif output_orig and output_orig > 0 and abs(output_disc - output_orig) > 1e-9:
        discount_rate = round(output_disc / output_orig, 4)
    else:
        discount_rate = 1.0

    discount_label = _compute_discount_label(discount_rate)
    multiplier = _parse_unit_multiplier(unit)

    # Retire the previously active row(s) for this model/label.
    db.query(ModelPriceNew).filter(
        ModelPriceNew.model_code == model_code,
        ModelPriceNew.label == label,
        ModelPriceNew.is_active == True,  # noqa: E712 -- SQL expression, not a Python bool test
    ).update({"is_active": False}, synchronize_session=False)

    price = ModelPriceNew(
        model_code=model_code,
        label=label,
        tier_min=tier_min,
        tier_max=tier_max,
        tier_unit=tier_unit,
        input_price_original=input_orig,
        output_price_original=output_orig,
        discount_rate=discount_rate,
        discount_label=discount_label,
        input_price_discounted=input_disc,
        output_price_discounted=output_disc,
        currency=currency,
        unit=unit,
        display_multiplier=multiplier,
        source_url=source_url,
        crawled_at=scraped_at,
        is_active=True,
    )
    db.add(price)
    return price


def _as_dict(value) -> dict:
    """Return ``value`` if it is a dict, otherwise an empty dict (null/malformed JSON)."""
    return value if isinstance(value, dict) else {}


def _read_price_pair(entry: dict) -> tuple[float, float]:
    """Return ``(original, discounted)`` for one price entry.

    ``price`` is stored as the discounted price; when ``price_original`` is
    absent the original price equals ``price``.
    """
    discounted = float(entry.get('price', 0) or 0)
    raw_original = entry.get('price_original')
    original = float(raw_original) if raw_original is not None else discounted
    return original, discounted


def _parse_prices_new_format(model_data: dict, db, scraped_at: datetime):
    """Parse a model's ``prices`` field and add one price row per label.

    Compound format::

        {
            "label": {
                "输入": {"price": 2.5, "unit": "元/每百万tokens", "price_original": 10.0},
                "输出": {...},
                ...
            }
        }

    Simple format (the label maps straight to one price entry)::

        {"label": {"price": 2.0, "unit": "元/每万字符"}}

    Rules:
    - A simple entry becomes one row with its price on the input side and
      zero output prices.
    - A compound entry with "输入"/"input" and/or "输出"/"output" becomes one
      row; its other sub-items are not turned into rows here (the full
      ``prices`` dict is still stored as the model's ``raw_prices``).
    - A compound entry without input/output sub-items becomes one row per
      priced sub-item, labelled ``"{label}_{sub_key}"``.

    Returns the list of ``ModelPriceNew`` objects added to the session.
    """
    model_code = model_data.get('model_name')
    source_url = model_data.get('url')
    prices_raw = _as_dict(model_data.get('prices'))

    # Each row: (label, input_orig, output_orig, input_disc, output_disc, unit, currency)
    rows = []
    for label, label_data in prices_raw.items():
        if not isinstance(label_data, dict):
            continue

        if 'price' in label_data:
            # Simple format: the whole label has a single price.
            original, discounted = _read_price_pair(label_data)
            rows.append((label, original, 0.0, discounted, 0.0,
                         label_data.get('unit', ''), label_data.get('currency', 'CNY')))
            continue

        # Compound format: keys are sub-item names (输入/输出/显式缓存创建/...).
        input_entry = _as_dict(label_data.get('输入') or label_data.get('input'))
        output_entry = _as_dict(label_data.get('输出') or label_data.get('output'))

        if not input_entry and not output_entry:
            # No input/output pair: import every priced sub-item as its own label.
            for sub_key, sub_val in label_data.items():
                if not isinstance(sub_val, dict) or 'price' not in sub_val:
                    continue
                original, discounted = _read_price_pair(sub_val)
                rows.append((f"{label}_{sub_key}", original, 0.0, discounted, 0.0,
                             sub_val.get('unit', ''), sub_val.get('currency', 'CNY')))
            continue

        input_orig, input_disc = _read_price_pair(input_entry)
        output_orig, output_disc = _read_price_pair(output_entry)
        unit = input_entry.get('unit') or output_entry.get('unit') or ''
        currency = input_entry.get('currency') or output_entry.get('currency') or 'CNY'
        rows.append((label, input_orig, output_orig, input_disc, output_disc, unit, currency))

    price_objs = []
    for label, in_orig, out_orig, in_disc, out_disc, unit, currency in rows:
        price_objs.append(_upsert_price_row(
            db, model_code, label, scraped_at,
            in_orig, out_orig, in_disc, out_disc,
            unit, currency, source_url=source_url,
        ))
    return price_objs


def _parse_scraped_at(value: str | None) -> datetime:
    """Parse an ISO-8601 ``scraped_at`` string, or return ``datetime.now()`` if empty.

    A trailing ``Z`` is rewritten to ``+00:00`` because ``fromisoformat``
    only accepts ``Z`` from Python 3.11 on.
    """
    if not value:
        return datetime.now()
    # NOTE(review): parsed values are timezone-aware while the fallback is
    # naive; confirm the crawled_at column tolerates both.
    return datetime.fromisoformat(value.replace('Z', '+00:00'))


def _upsert_model(db, model_code: str, model_data: dict, scraped_at: datetime) -> bool:
    """Create or refresh the ``ModelNew`` row for ``model_code``.

    For an existing row only crawler-sourced fields are refreshed, and an
    empty crawled value keeps the stored one. ``categories`` is always
    recomputed. Admin-configured fields (display_name, supplier, keywords,
    show/API switches) are left untouched. A new row gets display_name =
    model_code, supplier 'Qwen' and both switches enabled.

    Returns True when a new row was added.
    """
    model_info = _as_dict(model_data.get('model_info'))
    input_mods = model_info.get('input_modalities') or []
    output_mods = model_info.get('output_modalities') or []
    display_tags = model_info.get('display_tags') or []
    cats = infer_categories(input_mods, output_mods, model_code, display_tags)

    existing = db.query(ModelNew).filter(ModelNew.model_code == model_code).first()
    if existing:
        existing.description = model_info.get('description') or existing.description
        existing.display_tags = display_tags or existing.display_tags
        existing.input_modalities = input_mods or existing.input_modalities
        existing.output_modalities = output_mods or existing.output_modalities
        existing.features = model_info.get('features') or existing.features
        existing.rate_limits = model_data.get('rate_limits') or existing.rate_limits
        existing.tool_call_prices = model_data.get('tool_prices') or existing.tool_call_prices
        existing.raw_prices = model_data.get('prices') or existing.raw_prices
        existing.source_url = model_data.get('url') or existing.source_url
        existing.crawled_at = scraped_at
        existing.categories = cats
        existing.updated_at = datetime.now()
        return False

    db.add(ModelNew(
        model_code=model_code,
        description=model_info.get('description'),
        display_tags=display_tags,
        input_modalities=input_mods,
        output_modalities=output_mods,
        features=model_info.get('features'),
        rate_limits=model_data.get('rate_limits'),
        tool_call_prices=model_data.get('tool_prices'),
        raw_prices=model_data.get('prices'),
        source_url=model_data.get('url'),
        crawled_at=scraped_at,
        display_name=model_code,
        supplier='Qwen',
        categories=cats,
        # str() guards against non-string tags, which would crash join().
        keywords=','.join(str(tag) for tag in display_tags),
        is_show_enabled=True,
        is_api_enabled=True,
    ))
    return True


def import_crawl_data(crawl_file: str):
    """Import one crawler JSON dump into ``models_new`` / ``model_price_new``.

    The file holds either a list of model dicts or a dict with a ``models``
    list and an optional ``version``. All model and price rows are committed
    together. When a version is present, a ``CrawlerSyncLog`` row is then
    committed separately. On any error the session is rolled back, the error
    is printed and re-raised.

    The reported model count covers newly created models only; updated
    models are not counted.
    """
    with open(crawl_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    db = SessionLocal()
    try:
        models_list = data if isinstance(data, list) else (data.get('models') or [])
        version = data.get('version') if isinstance(data, dict) else None
        print(f"开始导入,共 {len(models_list)} 个模型,版本号: {version}")

        model_count = 0
        price_count = 0
        for model_data in models_list:
            if not isinstance(model_data, dict):
                continue
            model_code = model_data.get('model_name')
            if not model_code:
                continue

            scraped_at = _parse_scraped_at(model_data.get('scraped_at'))
            if _upsert_model(db, model_code, model_data, scraped_at):
                model_count += 1

            price_count += len(_parse_prices_new_format(model_data, db, scraped_at))

        db.commit()
        print(f"完成:新增模型 {model_count} 个,价格记录 {price_count} 条")

        # Record the sync in the crawler log (only when the dump is versioned).
        if version is not None:
            from app.models.model import CrawlerSyncLog
            db.add(CrawlerSyncLog(
                crawler_version=version,
                model_count=model_count,
                price_count=price_count,
                status='success',
            ))
            db.commit()

    except Exception as e:
        db.rollback()
        print(f"导入失败: {e}")
        raise
    finally:
        db.close()


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="导入爬虫数据到数据库")
    parser.add_argument(
        "--file",
        type=str,
        default=str(Path(__file__).parent.parent.parent / "{.txt"),
        help="爬虫数据文件路径(默认:项目根目录的 {.txt)",
    )
    args = parser.parse_args()

    crawl_file = args.file
    if not Path(crawl_file).exists():
        print(f"文件不存在: {crawl_file}")
        sys.exit(1)

    print(f"导入数据从: {crawl_file}")
    import_crawl_data(crawl_file)
    print("导入完成!")