# migrate_ocr_add_fields.py
  1. """
  2. OCR任务表添加新字段迁移脚本
  3. 添加字段:thumbnail_url(缩略图URL)、bill(消费金额)、model(使用模型)
  4. 运行: python -m scripts.migrate_ocr_add_fields
  5. """
  6. import os
  7. import sys
  8. from pathlib import Path
  9. sys.path.insert(0, str(Path(__file__).parent.parent))
  10. from dotenv import load_dotenv
  11. load_dotenv()
  12. import psycopg2
  13. from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
  14. def get_db_connection():
  15. """获取数据库连接"""
  16. return psycopg2.connect(
  17. host=os.getenv('DB_HOST', 'localhost'),
  18. port=os.getenv('DB_PORT', '5432'),
  19. user=os.getenv('DB_USER', 'postgres'),
  20. password=os.getenv('DB_PASSWORD', ''),
  21. database=os.getenv('DB_NAME', 'model_square')
  22. )
  23. def table_exists(cursor):
  24. """检查表是否存在"""
  25. cursor.execute("""
  26. SELECT EXISTS (
  27. SELECT FROM information_schema.tables
  28. WHERE table_schema = 'aigcspace'
  29. AND table_name = 'ocr_tasks'
  30. );
  31. """)
  32. return cursor.fetchone()[0]
  33. def column_exists(cursor, column_name):
  34. """检查列是否存在"""
  35. cursor.execute("""
  36. SELECT EXISTS (
  37. SELECT FROM information_schema.columns
  38. WHERE table_schema = 'aigcspace'
  39. AND table_name = 'ocr_tasks'
  40. AND column_name = %s
  41. );
  42. """, (column_name,))
  43. return cursor.fetchone()[0]
  44. def migrate():
  45. """执行迁移"""
  46. conn = get_db_connection()
  47. conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
  48. cursor = conn.cursor()
  49. try:
  50. if not table_exists(cursor):
  51. print("❌ 表 aigcspace.ocr_tasks 不存在,请先运行 migrate_ocr.py")
  52. return
  53. print("开始添加新字段到 OCR 任务表...")
  54. # 添加 thumbnail_url 字段
  55. if not column_exists(cursor, 'thumbnail_url'):
  56. cursor.execute("""
  57. ALTER TABLE aigcspace.ocr_tasks
  58. ADD COLUMN thumbnail_url VARCHAR(500);
  59. """)
  60. cursor.execute("""
  61. COMMENT ON COLUMN aigcspace.ocr_tasks.thumbnail_url IS '缩略图URL';
  62. """)
  63. print("✓ 添加字段 thumbnail_url")
  64. else:
  65. print("✓ 字段 thumbnail_url 已存在")
  66. # 添加 model 字段
  67. if not column_exists(cursor, 'model'):
  68. cursor.execute("""
  69. ALTER TABLE aigcspace.ocr_tasks
  70. ADD COLUMN model VARCHAR(100);
  71. """)
  72. cursor.execute("""
  73. COMMENT ON COLUMN aigcspace.ocr_tasks.model IS '使用的OCR模型';
  74. """)
  75. print("✓ 添加字段 model")
  76. else:
  77. print("✓ 字段 model 已存在")
  78. # 添加 input_tokens 字段
  79. if not column_exists(cursor, 'input_tokens'):
  80. cursor.execute("""
  81. ALTER TABLE aigcspace.ocr_tasks
  82. ADD COLUMN input_tokens INTEGER DEFAULT 0;
  83. """)
  84. cursor.execute("""
  85. COMMENT ON COLUMN aigcspace.ocr_tasks.input_tokens IS '输入Token数(图片token)';
  86. """)
  87. print("✓ 添加字段 input_tokens")
  88. else:
  89. print("✓ 字段 input_tokens 已存在")
  90. # 添加 output_tokens 字段
  91. if not column_exists(cursor, 'output_tokens'):
  92. cursor.execute("""
  93. ALTER TABLE aigcspace.ocr_tasks
  94. ADD COLUMN output_tokens INTEGER DEFAULT 0;
  95. """)
  96. cursor.execute("""
  97. COMMENT ON COLUMN aigcspace.ocr_tasks.output_tokens IS '输出Token数(识别文本token)';
  98. """)
  99. print("✓ 添加字段 output_tokens")
  100. else:
  101. print("✓ 字段 output_tokens 已存在")
  102. # 添加 bill 字段
  103. if not column_exists(cursor, 'bill'):
  104. cursor.execute("""
  105. ALTER TABLE aigcspace.ocr_tasks
  106. ADD COLUMN bill NUMERIC(10, 4) DEFAULT 0;
  107. """)
  108. cursor.execute("""
  109. COMMENT ON COLUMN aigcspace.ocr_tasks.bill IS '消费金额(元)';
  110. """)
  111. print("✓ 添加字段 bill")
  112. else:
  113. print("✓ 字段 bill 已存在")
  114. print("\n✅ OCR任务表字段添加完成!")
  115. except Exception as e:
  116. print(f"❌ 迁移失败: {e}")
  117. raise
  118. finally:
  119. cursor.close()
  120. conn.close()
  121. def rollback():
  122. """回滚迁移"""
  123. conn = get_db_connection()
  124. conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
  125. cursor = conn.cursor()
  126. try:
  127. if not table_exists(cursor):
  128. print("✓ 表 aigcspace.ocr_tasks 不存在,无需回滚")
  129. return
  130. print("开始回滚 OCR 任务表字段...")
  131. # 删除字段
  132. if column_exists(cursor, 'thumbnail_url'):
  133. cursor.execute("ALTER TABLE aigcspace.ocr_tasks DROP COLUMN thumbnail_url;")
  134. print("✓ 删除字段 thumbnail_url")
  135. if column_exists(cursor, 'model'):
  136. cursor.execute("ALTER TABLE aigcspace.ocr_tasks DROP COLUMN model;")
  137. print("✓ 删除字段 model")
  138. if column_exists(cursor, 'input_tokens'):
  139. cursor.execute("ALTER TABLE aigcspace.ocr_tasks DROP COLUMN input_tokens;")
  140. print("✓ 删除字段 input_tokens")
  141. if column_exists(cursor, 'output_tokens'):
  142. cursor.execute("ALTER TABLE aigcspace.ocr_tasks DROP COLUMN output_tokens;")
  143. print("✓ 删除字段 output_tokens")
  144. if column_exists(cursor, 'bill'):
  145. cursor.execute("ALTER TABLE aigcspace.ocr_tasks DROP COLUMN bill;")
  146. print("✓ 删除字段 bill")
  147. print("\n✅ OCR任务表字段回滚完成!")
  148. except Exception as e:
  149. print(f"❌ 回滚失败: {e}")
  150. raise
  151. finally:
  152. cursor.close()
  153. conn.close()
  154. def status():
  155. """查看迁移状态"""
  156. conn = get_db_connection()
  157. conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
  158. cursor = conn.cursor()
  159. try:
  160. if not table_exists(cursor):
  161. print("❌ 表 aigcspace.ocr_tasks 不存在")
  162. return
  163. print("OCR任务表字段状态:")
  164. fields = ['thumbnail_url', 'model', 'input_tokens', 'output_tokens', 'bill']
  165. for field in fields:
  166. exists = column_exists(cursor, field)
  167. status_icon = '✓' if exists else '✗'
  168. print(f"{status_icon} {field}: {'已添加' if exists else '未添加'}")
  169. except Exception as e:
  170. print(f"❌ 查看状态失败: {e}")
  171. raise
  172. finally:
  173. cursor.close()
  174. conn.close()
  175. if __name__ == "__main__":
  176. import argparse
  177. parser = argparse.ArgumentParser(description='OCR任务表字段迁移脚本')
  178. parser.add_argument('--action', choices=['migrate', 'rollback', 'status'],
  179. default='migrate', help='操作类型')
  180. args = parser.parse_args()
  181. if args.action == 'migrate':
  182. migrate()
  183. elif args.action == 'rollback':
  184. rollback()
  185. elif args.action == 'status':
  186. status()