init_price_data.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
"""
Price-data initialization script.

Imports price data from the model_prices.json file into the database,
using model_price.id (the primary key) as the linking field.
"""
import json
import sys
from pathlib import Path

# Put the parent of this script's directory on sys.path so the `app` package
# (presumably at the project root — confirm) is importable when this file is
# run directly. Must stay above the `app` import below.
sys.path.insert(0, str(Path(__file__).parent.parent))

from sqlalchemy import text
from app.database import SessionLocal
  12. def init_price_data():
  13. """初始化模型价格数据"""
  14. json_path = Path(__file__).parent / "model_prices.json"
  15. with open(json_path, 'r', encoding='utf-8') as f:
  16. prices_data = json.load(f)
  17. db = SessionLocal()
  18. price_inserted = 0
  19. tier_inserted = 0
  20. model_updated = 0
  21. for item in prices_data:
  22. model_name = item['model_name']
  23. pricing_mode = item['pricing_mode']
  24. unit = item.get('unit', 'tokens')
  25. currency = item.get('currency', 'CNY')
  26. # 检查模型是否存在
  27. model_row = db.execute(
  28. text("SELECT id, price_id FROM aigcspace.models WHERE title = :title"),
  29. {"title": model_name}
  30. ).fetchone()
  31. if not model_row:
  32. print(f"警告: 模型 {model_name} 不存在,跳过")
  33. continue
  34. model_id = model_row[0]
  35. existing_price_id = model_row[1]
  36. if pricing_mode == 'simple':
  37. input_price = item['input_price']
  38. output_price = item['output_price']
  39. else:
  40. first_tier = item['tiers'][0]
  41. input_price = first_tier['input_price']
  42. output_price = first_tier['output_price']
  43. if existing_price_id:
  44. # 更新现有价格
  45. db.execute(
  46. text("""UPDATE aigcspace.model_price
  47. SET input_price = :input_price, output_price = :output_price,
  48. pricing_mode = :pricing_mode, unit = :unit, currency = :currency,
  49. updated_at = CURRENT_TIMESTAMP
  50. WHERE id = :id"""),
  51. {"input_price": input_price, "output_price": output_price,
  52. "pricing_mode": pricing_mode, "unit": unit, "currency": currency,
  53. "id": existing_price_id}
  54. )
  55. if pricing_mode == 'tier':
  56. db.execute(
  57. text("DELETE FROM aigcspace.model_price_tier WHERE price_id = :price_id"),
  58. {"price_id": existing_price_id}
  59. )
  60. for tier in item['tiers']:
  61. db.execute(
  62. text("""INSERT INTO aigcspace.model_price_tier
  63. (price_id, tier_min, tier_max, input_price, output_price)
  64. VALUES (:price_id, :tier_min, :tier_max, :input_price, :output_price)"""),
  65. {"price_id": existing_price_id, "tier_min": tier['tier_min'],
  66. "tier_max": tier.get('tier_max'),
  67. "input_price": tier['input_price'], "output_price": tier['output_price']}
  68. )
  69. tier_inserted += 1
  70. model_updated += 1
  71. else:
  72. # 创建新价格记录,使用RETURNING获取自增ID
  73. result = db.execute(
  74. text("""INSERT INTO aigcspace.model_price
  75. (input_price, output_price, pricing_mode, unit, currency)
  76. VALUES (:input_price, :output_price, :pricing_mode, :unit, :currency)
  77. RETURNING id"""),
  78. {"input_price": input_price, "output_price": output_price,
  79. "pricing_mode": pricing_mode, "unit": unit, "currency": currency}
  80. )
  81. new_price_id = result.fetchone()[0]
  82. if pricing_mode == 'tier':
  83. for tier in item['tiers']:
  84. db.execute(
  85. text("""INSERT INTO aigcspace.model_price_tier
  86. (price_id, tier_min, tier_max, input_price, output_price)
  87. VALUES (:price_id, :tier_min, :tier_max, :input_price, :output_price)"""),
  88. {"price_id": new_price_id, "tier_min": tier['tier_min'],
  89. "tier_max": tier.get('tier_max'),
  90. "input_price": tier['input_price'], "output_price": tier['output_price']}
  91. )
  92. tier_inserted += 1
  93. # 更新模型的price_id为新创建的价格记录ID
  94. db.execute(
  95. text("UPDATE aigcspace.models SET price_id = :price_id WHERE id = :model_id"),
  96. {"price_id": new_price_id, "model_id": model_id}
  97. )
  98. price_inserted += 1
  99. db.commit()
  100. db.close()
  101. print(f"价格数据导入完成:")
  102. print(f" - 新增价格记录: {price_inserted} 条")
  103. print(f" - 更新价格记录: {model_updated} 条")
  104. print(f" - 阶梯价格记录: {tier_inserted} 条")
  105. if __name__ == "__main__":
  106. init_price_data()