""" OCR任务表添加新字段迁移脚本 添加字段:thumbnail_url(缩略图URL)、bill(消费金额)、model(使用模型) 运行: python -m scripts.migrate_ocr_add_fields """ import os import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent)) from dotenv import load_dotenv load_dotenv() import psycopg2 from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT def get_db_connection(): """获取数据库连接""" return psycopg2.connect( host=os.getenv('DB_HOST', 'localhost'), port=os.getenv('DB_PORT', '5432'), user=os.getenv('DB_USER', 'postgres'), password=os.getenv('DB_PASSWORD', ''), database=os.getenv('DB_NAME', 'model_square') ) def table_exists(cursor): """检查表是否存在""" cursor.execute(""" SELECT EXISTS ( SELECT FROM information_schema.tables WHERE table_schema = 'aigcspace' AND table_name = 'ocr_tasks' ); """) return cursor.fetchone()[0] def column_exists(cursor, column_name): """检查列是否存在""" cursor.execute(""" SELECT EXISTS ( SELECT FROM information_schema.columns WHERE table_schema = 'aigcspace' AND table_name = 'ocr_tasks' AND column_name = %s ); """, (column_name,)) return cursor.fetchone()[0] def migrate(): """执行迁移""" conn = get_db_connection() conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) cursor = conn.cursor() try: if not table_exists(cursor): print("❌ 表 aigcspace.ocr_tasks 不存在,请先运行 migrate_ocr.py") return print("开始添加新字段到 OCR 任务表...") # 添加 thumbnail_url 字段 if not column_exists(cursor, 'thumbnail_url'): cursor.execute(""" ALTER TABLE aigcspace.ocr_tasks ADD COLUMN thumbnail_url VARCHAR(500); """) cursor.execute(""" COMMENT ON COLUMN aigcspace.ocr_tasks.thumbnail_url IS '缩略图URL'; """) print("✓ 添加字段 thumbnail_url") else: print("✓ 字段 thumbnail_url 已存在") # 添加 model 字段 if not column_exists(cursor, 'model'): cursor.execute(""" ALTER TABLE aigcspace.ocr_tasks ADD COLUMN model VARCHAR(100); """) cursor.execute(""" COMMENT ON COLUMN aigcspace.ocr_tasks.model IS '使用的OCR模型'; """) print("✓ 添加字段 model") else: print("✓ 字段 model 已存在") # 添加 input_tokens 字段 if not column_exists(cursor, 'input_tokens'): cursor.execute(""" ALTER TABLE aigcspace.ocr_tasks ADD COLUMN input_tokens INTEGER DEFAULT 0; """) cursor.execute(""" COMMENT ON COLUMN aigcspace.ocr_tasks.input_tokens IS '输入Token数(图片token)'; """) print("✓ 添加字段 input_tokens") else: print("✓ 字段 input_tokens 已存在") # 添加 output_tokens 字段 if not column_exists(cursor, 'output_tokens'): cursor.execute(""" ALTER TABLE aigcspace.ocr_tasks ADD COLUMN output_tokens INTEGER DEFAULT 0; """) cursor.execute(""" COMMENT ON COLUMN aigcspace.ocr_tasks.output_tokens IS '输出Token数(识别文本token)'; """) print("✓ 添加字段 output_tokens") else: print("✓ 字段 output_tokens 已存在") # 添加 bill 字段 if not column_exists(cursor, 'bill'): cursor.execute(""" ALTER TABLE aigcspace.ocr_tasks ADD COLUMN bill NUMERIC(10, 4) DEFAULT 0; """) cursor.execute(""" COMMENT ON COLUMN aigcspace.ocr_tasks.bill IS '消费金额(元)'; """) print("✓ 添加字段 bill") else: print("✓ 字段 bill 已存在") print("\n✅ OCR任务表字段添加完成!") except Exception as e: print(f"❌ 迁移失败: {e}") raise finally: cursor.close() conn.close() def rollback(): """回滚迁移""" conn = get_db_connection() conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) cursor = conn.cursor() try: if not table_exists(cursor): print("✓ 表 aigcspace.ocr_tasks 不存在,无需回滚") return print("开始回滚 OCR 任务表字段...") # 删除字段 if column_exists(cursor, 'thumbnail_url'): cursor.execute("ALTER TABLE aigcspace.ocr_tasks DROP COLUMN thumbnail_url;") print("✓ 删除字段 thumbnail_url") if column_exists(cursor, 'model'): cursor.execute("ALTER TABLE aigcspace.ocr_tasks DROP COLUMN model;") print("✓ 删除字段 model") if column_exists(cursor, 'input_tokens'): cursor.execute("ALTER TABLE aigcspace.ocr_tasks DROP COLUMN input_tokens;") print("✓ 删除字段 input_tokens") if column_exists(cursor, 'output_tokens'): cursor.execute("ALTER TABLE aigcspace.ocr_tasks DROP COLUMN output_tokens;") print("✓ 删除字段 output_tokens") if column_exists(cursor, 'bill'): cursor.execute("ALTER TABLE aigcspace.ocr_tasks DROP COLUMN bill;") print("✓ 删除字段 bill") print("\n✅ OCR任务表字段回滚完成!") except Exception as e: print(f"❌ 回滚失败: {e}") raise finally: cursor.close() conn.close() def status(): """查看迁移状态""" conn = get_db_connection() conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) cursor = conn.cursor() try: if not table_exists(cursor): print("❌ 表 aigcspace.ocr_tasks 不存在") return print("OCR任务表字段状态:") fields = ['thumbnail_url', 'model', 'input_tokens', 'output_tokens', 'bill'] for field in fields: exists = column_exists(cursor, field) status_icon = '✓' if exists else '✗' print(f"{status_icon} {field}: {'已添加' if exists else '未添加'}") except Exception as e: print(f"❌ 查看状态失败: {e}") raise finally: cursor.close() conn.close() if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description='OCR任务表字段迁移脚本') parser.add_argument('--action', choices=['migrate', 'rollback', 'status'], default='migrate', help='操作类型') args = parser.parse_args() if args.action == 'migrate': migrate() elif args.action == 'rollback': rollback() elif args.action == 'status': status()