| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198 |
- """
- OCR数据库迁移脚本
- 运行: python -m scripts.migrate_ocr
- """
- import os
- import sys
- from pathlib import Path
- sys.path.insert(0, str(Path(__file__).parent.parent))
- from dotenv import load_dotenv
- load_dotenv()
- import psycopg2
- from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
def get_db_connection():
    """Open a PostgreSQL connection configured from environment variables.

    Each connection argument is read from an environment variable and falls
    back to a built-in default when the variable is unset:
    DB_HOST ('localhost'), DB_PORT ('5432'), DB_USER ('postgres'),
    DB_PASSWORD ('') and DB_NAME ('model_square').

    Returns:
        The connection object returned by ``psycopg2.connect``; it is handed
        back open.
    """
    # connect() keyword -> (environment variable, fallback value)
    env_settings = {
        'host': ('DB_HOST', 'localhost'),
        'port': ('DB_PORT', '5432'),
        'user': ('DB_USER', 'postgres'),
        'password': ('DB_PASSWORD', ''),
        'database': ('DB_NAME', 'model_square'),
    }
    connect_kwargs = {
        keyword: os.getenv(env_name, fallback)
        for keyword, (env_name, fallback) in env_settings.items()
    }
    return psycopg2.connect(**connect_kwargs)
def table_exists(cursor):
    """Check whether the table aigcspace.ocr_tasks exists.

    Args:
        cursor: An open database cursor; exactly one query is executed on it
            and one row is fetched from it.

    Returns:
        The first column of the row produced by the ``SELECT EXISTS`` query.
    """
    existence_query = """
        SELECT EXISTS (
            SELECT FROM information_schema.tables
            WHERE table_schema = 'aigcspace'
            AND table_name = 'ocr_tasks'
        );
    """
    cursor.execute(existence_query)
    result_row = cursor.fetchone()
    return result_row[0]
def migrate():
    """Create aigcspace.ocr_tasks together with its indexes, trigger and comments.

    If the table already exists, a message is printed and nothing is changed.

    The whole migration runs inside ONE transaction that is committed only
    after every statement has succeeded. PostgreSQL DDL is transactional, so
    a failure at any step leaves the database exactly as it was. Running each
    statement in autocommit mode instead could leave the table created
    without its indexes or trigger, and every later run would then stop at
    the "table already exists" check and never finish the job.

    Raises:
        Exception: Any error raised while executing the migration is reported
            on stdout and re-raised; nothing is committed in that case.
    """
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        try:
            if table_exists(cursor):
                print("✓ 表 aigcspace.ocr_tasks 已存在")
                return

            print("开始创建 OCR 任务表...")

            # Create the table.
            cursor.execute("""
                CREATE TABLE aigcspace.ocr_tasks (
                    id SERIAL PRIMARY KEY,
                    task_id VARCHAR(64) UNIQUE NOT NULL,
                    user_id VARCHAR(50) NOT NULL REFERENCES aigcspace.users(id) ON DELETE CASCADE,
                    task_type VARCHAR(50) NOT NULL,
                    custom_prompt TEXT,
                    status VARCHAR(20) NOT NULL DEFAULT 'pending',
                    progress INTEGER DEFAULT 0 CHECK (progress >= 0 AND progress <= 100),
                    total_images INTEGER NOT NULL CHECK (total_images > 0 AND total_images <= 10),
                    processed_images INTEGER DEFAULT 0 CHECK (processed_images >= 0),
                    results JSONB DEFAULT '{}'::jsonb,
                    error_message TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    completed_at TIMESTAMP,

                    CONSTRAINT check_processed_images CHECK (processed_images <= total_images),
                    CONSTRAINT check_status CHECK (status IN ('pending', 'processing', 'completed', 'failed', 'cancelled'))
                );
            """)
            print("✓ 表创建成功")

            # Create the indexes.
            # NOTE(review): the UNIQUE constraint on task_id already creates
            # an index, so idx_ocr_tasks_task_id looks redundant; it is kept
            # so the resulting schema is unchanged.
            indexes = [
                "CREATE INDEX idx_ocr_tasks_task_id ON aigcspace.ocr_tasks(task_id);",
                "CREATE INDEX idx_ocr_tasks_user_id ON aigcspace.ocr_tasks(user_id);",
                "CREATE INDEX idx_ocr_tasks_status ON aigcspace.ocr_tasks(status);",
                "CREATE INDEX idx_ocr_tasks_user_status ON aigcspace.ocr_tasks(user_id, status);",
                "CREATE INDEX idx_ocr_tasks_created_at ON aigcspace.ocr_tasks(created_at DESC);",
            ]
            for idx_sql in indexes:
                cursor.execute(idx_sql)
            print("✓ 索引创建成功")

            # Create the trigger that refreshes updated_at on every row update.
            cursor.execute("""
                CREATE OR REPLACE FUNCTION update_ocr_tasks_updated_at()
                RETURNS TRIGGER AS $$
                BEGIN
                    NEW.updated_at = CURRENT_TIMESTAMP;
                    RETURN NEW;
                END;
                $$ LANGUAGE plpgsql;

                CREATE TRIGGER trigger_update_ocr_tasks_updated_at
                    BEFORE UPDATE ON aigcspace.ocr_tasks
                    FOR EACH ROW
                    EXECUTE FUNCTION update_ocr_tasks_updated_at();
            """)
            print("✓ 触发器创建成功")

            # Attach table and column comments.
            comments = [
                "COMMENT ON TABLE aigcspace.ocr_tasks IS 'OCR识别任务表,同时作为历史记录表';",
                "COMMENT ON COLUMN aigcspace.ocr_tasks.task_id IS '任务唯一标识符(UUID)';",
                "COMMENT ON COLUMN aigcspace.ocr_tasks.status IS '任务状态:pending/processing/completed/failed/cancelled';",
                "COMMENT ON COLUMN aigcspace.ocr_tasks.results IS 'JSONB格式存储:{images: [], thumbnails: [], texts: []}';",
            ]
            for comment_sql in comments:
                cursor.execute(comment_sql)
            print("✓ 注释添加成功")

            # Everything succeeded: make the table, indexes, trigger and
            # comments visible in one atomic step.
            conn.commit()

            print("\n✅ OCR数据库迁移完成!")

        except Exception as e:
            # Nothing has been committed; closing the connection below
            # discards the open transaction, so no partial schema survives.
            print(f"❌ 迁移失败: {e}")
            raise
        finally:
            cursor.close()
    finally:
        conn.close()
def rollback():
    """Undo the OCR migration by dropping the trigger, its function and the table.

    If aigcspace.ocr_tasks does not exist, a message is printed and nothing
    is changed.

    The three DROP statements run inside ONE transaction that is committed
    only after all of them have succeeded, so a failure part-way through
    cannot leave the schema half torn down (for example the trigger removed
    while the table is still in place).

    Raises:
        Exception: Any error raised while executing the rollback is reported
            on stdout and re-raised; nothing is committed in that case.
    """
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        try:
            if not table_exists(cursor):
                print("✓ 表 aigcspace.ocr_tasks 不存在,无需回滚")
                return

            print("开始回滚 OCR 数据库迁移...")

            # The trigger is dropped before the function it executes, and
            # the table goes last.
            drop_statements = [
                "DROP TRIGGER IF EXISTS trigger_update_ocr_tasks_updated_at ON aigcspace.ocr_tasks;",
                "DROP FUNCTION IF EXISTS update_ocr_tasks_updated_at();",
                "DROP TABLE IF EXISTS aigcspace.ocr_tasks CASCADE;",
            ]
            for drop_sql in drop_statements:
                cursor.execute(drop_sql)

            # All drops succeeded: apply them atomically.
            conn.commit()

            print("✅ OCR数据库迁移回滚完成!")

        except Exception as e:
            # Nothing has been committed; closing the connection below
            # discards the open transaction, so the schema stays intact.
            print(f"❌ 回滚失败: {e}")
            raise
        finally:
            cursor.close()
    finally:
        conn.close()
def status():
    """Print the migration status of aigcspace.ocr_tasks.

    Always prints whether the table exists. When it does, also prints the
    number of rows in the table and the number of indexes that pg_indexes
    lists for it.

    Raises:
        Exception: Any error raised while querying is reported on stdout and
            re-raised.
    """
    connection = get_db_connection()
    connection.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
    cur = connection.cursor()

    try:
        created = table_exists(cur)
        print(f"表 aigcspace.ocr_tasks: {'✓ 已创建' if created else '✗ 未创建'}")

        # Nothing more to report when the table has not been created.
        if not created:
            return

        cur.execute("SELECT COUNT(*) FROM aigcspace.ocr_tasks;")
        row_total = cur.fetchone()[0]
        print(f"记录数: {row_total}")

        index_query = """
            SELECT indexname FROM pg_indexes
            WHERE schemaname = 'aigcspace' AND tablename = 'ocr_tasks';
        """
        cur.execute(index_query)
        index_names = [record[0] for record in cur.fetchall()]
        print(f"索引数: {len(index_names)}")

    except Exception as exc:
        print(f"❌ 查看状态失败: {exc}")
        raise
    finally:
        cur.close()
        connection.close()
if __name__ == "__main__":
    import argparse

    # Map each accepted --action value to the function that carries it out.
    handlers = {
        'migrate': migrate,
        'rollback': rollback,
        'status': status,
    }

    cli = argparse.ArgumentParser(description='OCR数据库迁移脚本')
    cli.add_argument(
        '--action',
        choices=list(handlers),
        default='migrate',
        help='操作类型',
    )
    options = cli.parse_args()

    # argparse has already rejected any value that is not a key of handlers.
    handlers[options.action]()
|