"""OCR database migration script.

Creates (or rolls back) the ``aigcspace.ocr_tasks`` table together with its
indexes, ``updated_at`` trigger and comments, and can report its status.

Run:
    python -m scripts.migrate_ocr [--action {migrate,rollback,status}]
"""
import argparse
import os
import sys
from pathlib import Path

# Make the project root importable when this file is run as a module/script.
sys.path.insert(0, str(Path(__file__).parent.parent))

from dotenv import load_dotenv

# Load DB_* settings from a .env file before they are read below.
load_dotenv()

import psycopg2
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT


def get_db_connection():
    """Open a new psycopg2 connection configured from DB_* environment variables."""
    return psycopg2.connect(
        host=os.getenv('DB_HOST', 'localhost'),
        port=os.getenv('DB_PORT', '5432'),
        user=os.getenv('DB_USER', 'postgres'),
        password=os.getenv('DB_PASSWORD', ''),
        database=os.getenv('DB_NAME', 'model_square'),
    )


def table_exists(cursor, schema='aigcspace', table='ocr_tasks'):
    """Return True if ``schema.table`` is listed in information_schema.tables.

    Args:
        cursor: An open psycopg2 cursor.
        schema: Schema to search; defaults to the OCR migration's schema.
        table: Table name to look for; defaults to ``ocr_tasks``.

    Note: information_schema only lists tables the current user has some
    privilege on.
    """
    cursor.execute(
        """
        SELECT EXISTS (
            SELECT FROM information_schema.tables
            WHERE table_schema = %s
              AND table_name = %s
        );
        """,
        (schema, table),
    )
    return cursor.fetchone()[0]


def _create_table(cursor):
    """Create aigcspace.ocr_tasks with its column and CHECK constraints."""
    # NOTE(review): TIMESTAMP (without time zone) filled from CURRENT_TIMESTAMP
    # stores session-local time with no zone; TIMESTAMPTZ may be intended.
    # Left unchanged so the schema matches databases that were already migrated.
    cursor.execute("""
        CREATE TABLE aigcspace.ocr_tasks (
            id SERIAL PRIMARY KEY,
            task_id VARCHAR(64) UNIQUE NOT NULL,
            user_id VARCHAR(50) NOT NULL REFERENCES aigcspace.users(id) ON DELETE CASCADE,
            task_type VARCHAR(50) NOT NULL,
            custom_prompt TEXT,
            status VARCHAR(20) NOT NULL DEFAULT 'pending',
            progress INTEGER DEFAULT 0 CHECK (progress >= 0 AND progress <= 100),
            total_images INTEGER NOT NULL CHECK (total_images > 0 AND total_images <= 10),
            processed_images INTEGER DEFAULT 0 CHECK (processed_images >= 0),
            results JSONB DEFAULT '{}'::jsonb,
            error_message TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            completed_at TIMESTAMP,
            CONSTRAINT check_processed_images CHECK (processed_images <= total_images),
            CONSTRAINT check_status CHECK (status IN ('pending', 'processing', 'completed', 'failed', 'cancelled'))
        );
    """)
    print("✓ 表创建成功")


def _create_indexes(cursor):
    """Create the secondary indexes on aigcspace.ocr_tasks."""
    # NOTE(review): the UNIQUE constraint on task_id already creates a unique
    # index, so idx_ocr_tasks_task_id duplicates it. idx_ocr_tasks_user_id is
    # also covered by the leading column of idx_ocr_tasks_user_status. Both are
    # kept so the schema stays identical to existing deployments; consider
    # dropping them in a follow-up migration.
    indexes = [
        "CREATE INDEX idx_ocr_tasks_task_id ON aigcspace.ocr_tasks(task_id);",
        "CREATE INDEX idx_ocr_tasks_user_id ON aigcspace.ocr_tasks(user_id);",
        "CREATE INDEX idx_ocr_tasks_status ON aigcspace.ocr_tasks(status);",
        "CREATE INDEX idx_ocr_tasks_user_status ON aigcspace.ocr_tasks(user_id, status);",
        "CREATE INDEX idx_ocr_tasks_created_at ON aigcspace.ocr_tasks(created_at DESC);",
    ]
    for idx_sql in indexes:
        cursor.execute(idx_sql)
    print("✓ 索引创建成功")


def _create_updated_at_trigger(cursor):
    """Create the BEFORE UPDATE trigger that refreshes ocr_tasks.updated_at."""
    # The function name is not schema-qualified, so it lands in the first
    # schema on search_path. rollback() drops it the same unqualified way.
    # `EXECUTE FUNCTION` requires PostgreSQL 11 or newer.
    cursor.execute("""
        CREATE OR REPLACE FUNCTION update_ocr_tasks_updated_at()
        RETURNS TRIGGER AS $$
        BEGIN
            NEW.updated_at = CURRENT_TIMESTAMP;
            RETURN NEW;
        END;
        $$ LANGUAGE plpgsql;

        CREATE TRIGGER trigger_update_ocr_tasks_updated_at
            BEFORE UPDATE ON aigcspace.ocr_tasks
            FOR EACH ROW
            EXECUTE FUNCTION update_ocr_tasks_updated_at();
    """)
    print("✓ 触发器创建成功")


def _add_comments(cursor):
    """Attach descriptive COMMENTs to the table and selected columns."""
    comments = [
        "COMMENT ON TABLE aigcspace.ocr_tasks IS 'OCR识别任务表,同时作为历史记录表';",
        "COMMENT ON COLUMN aigcspace.ocr_tasks.task_id IS '任务唯一标识符(UUID)';",
        "COMMENT ON COLUMN aigcspace.ocr_tasks.status IS '任务状态:pending/processing/completed/failed/cancelled';",
        "COMMENT ON COLUMN aigcspace.ocr_tasks.results IS 'JSONB格式存储:{images: [], thumbnails: [], texts: []}';",
    ]
    for comment_sql in comments:
        cursor.execute(comment_sql)
    print("✓ 注释添加成功")


def migrate():
    """Create the OCR task table, its indexes, trigger and comments.

    Does nothing if the table already exists. All DDL runs in a single
    transaction (PostgreSQL DDL is transactional) and is committed only after
    every step succeeds. This prevents a half-built table, which a re-run
    would otherwise skip because the existence check passes.

    Raises:
        Exception: any error is printed and re-raised; nothing is committed.
    """
    conn = get_db_connection()
    try:
        with conn.cursor() as cursor:
            if table_exists(cursor):
                print("✓ 表 aigcspace.ocr_tasks 已存在")
                return

            print("开始创建 OCR 任务表...")
            _create_table(cursor)
            _create_indexes(cursor)
            _create_updated_at_trigger(cursor)
            _add_comments(cursor)

        conn.commit()
        print("\n✅ OCR数据库迁移完成!")
    except Exception as e:
        print(f"❌ 迁移失败: {e}")
        raise
    finally:
        # Closing without commit() discards any open transaction, so a failed
        # migration leaves no partial objects behind.
        conn.close()


def rollback():
    """Drop the OCR task table, its trigger and the trigger function.

    Does nothing if the table does not exist. The drops run in one
    transaction, so they either all take effect or none do.

    Raises:
        Exception: any error is printed and re-raised; nothing is committed.
    """
    conn = get_db_connection()
    try:
        with conn.cursor() as cursor:
            if not table_exists(cursor):
                print("✓ 表 aigcspace.ocr_tasks 不存在,无需回滚")
                return

            print("开始回滚 OCR 数据库迁移...")
            # Drop the trigger before its function: a function still used by a
            # trigger cannot be dropped without CASCADE.
            cursor.execute("DROP TRIGGER IF EXISTS trigger_update_ocr_tasks_updated_at ON aigcspace.ocr_tasks;")
            cursor.execute("DROP FUNCTION IF EXISTS update_ocr_tasks_updated_at();")
            # CASCADE also removes dependent objects, e.g. views on this table
            # or foreign-key constraints in other tables that reference it.
            cursor.execute("DROP TABLE IF EXISTS aigcspace.ocr_tasks CASCADE;")

        conn.commit()
        print("✅ OCR数据库迁移回滚完成!")
    except Exception as e:
        print(f"❌ 回滚失败: {e}")
        raise
    finally:
        conn.close()


def status():
    """Print whether the OCR task table exists and, if so, its row and index counts.

    Raises:
        Exception: any error is printed and re-raised.
    """
    conn = get_db_connection()
    try:
        # Read-only inspection; autocommit avoids holding an idle transaction.
        conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
        with conn.cursor() as cursor:
            exists = table_exists(cursor)
            print(f"表 aigcspace.ocr_tasks: {'✓ 已创建' if exists else '✗ 未创建'}")

            if exists:
                cursor.execute("SELECT COUNT(*) FROM aigcspace.ocr_tasks;")
                count = cursor.fetchone()[0]
                print(f"记录数: {count}")

                cursor.execute("""
                    SELECT COUNT(*) FROM pg_indexes
                    WHERE schemaname = 'aigcspace' AND tablename = 'ocr_tasks';
                """)
                index_count = cursor.fetchone()[0]
                print(f"索引数: {index_count}")
    except Exception as e:
        print(f"❌ 查看状态失败: {e}")
        raise
    finally:
        conn.close()


def main():
    """Parse ``--action`` from the command line and run the matching command."""
    actions = {
        'migrate': migrate,
        'rollback': rollback,
        'status': status,
    }
    parser = argparse.ArgumentParser(description='OCR数据库迁移脚本')
    parser.add_argument('--action', choices=list(actions), default='migrate',
                        help='操作类型')
    args = parser.parse_args()
    actions[args.action]()


if __name__ == "__main__":
    main()