migrate_ocr.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. """
  2. OCR数据库迁移脚本
  3. 运行: python -m scripts.migrate_ocr
  4. """
  5. import os
  6. import sys
  7. from pathlib import Path
  8. sys.path.insert(0, str(Path(__file__).parent.parent))
  9. from dotenv import load_dotenv
  10. load_dotenv()
  11. import psycopg2
  12. from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
  13. def get_db_connection():
  14. """获取数据库连接"""
  15. return psycopg2.connect(
  16. host=os.getenv('DB_HOST', 'localhost'),
  17. port=os.getenv('DB_PORT', '5432'),
  18. user=os.getenv('DB_USER', 'postgres'),
  19. password=os.getenv('DB_PASSWORD', ''),
  20. database=os.getenv('DB_NAME', 'model_square')
  21. )
  22. def table_exists(cursor):
  23. """检查表是否存在"""
  24. cursor.execute("""
  25. SELECT EXISTS (
  26. SELECT FROM information_schema.tables
  27. WHERE table_schema = 'aigcspace'
  28. AND table_name = 'ocr_tasks'
  29. );
  30. """)
  31. return cursor.fetchone()[0]
  32. def migrate():
  33. """执行迁移"""
  34. conn = get_db_connection()
  35. conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
  36. cursor = conn.cursor()
  37. try:
  38. if table_exists(cursor):
  39. print("✓ 表 aigcspace.ocr_tasks 已存在")
  40. return
  41. print("开始创建 OCR 任务表...")
  42. # 创建表
  43. cursor.execute("""
  44. CREATE TABLE aigcspace.ocr_tasks (
  45. id SERIAL PRIMARY KEY,
  46. task_id VARCHAR(64) UNIQUE NOT NULL,
  47. user_id VARCHAR(50) NOT NULL REFERENCES aigcspace.users(id) ON DELETE CASCADE,
  48. task_type VARCHAR(50) NOT NULL,
  49. custom_prompt TEXT,
  50. status VARCHAR(20) NOT NULL DEFAULT 'pending',
  51. progress INTEGER DEFAULT 0 CHECK (progress >= 0 AND progress <= 100),
  52. total_images INTEGER NOT NULL CHECK (total_images > 0 AND total_images <= 10),
  53. processed_images INTEGER DEFAULT 0 CHECK (processed_images >= 0),
  54. results JSONB DEFAULT '{}'::jsonb,
  55. error_message TEXT,
  56. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  57. updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  58. completed_at TIMESTAMP,
  59. CONSTRAINT check_processed_images CHECK (processed_images <= total_images),
  60. CONSTRAINT check_status CHECK (status IN ('pending', 'processing', 'completed', 'failed', 'cancelled'))
  61. );
  62. """)
  63. print("✓ 表创建成功")
  64. # 创建索引
  65. indexes = [
  66. "CREATE INDEX idx_ocr_tasks_task_id ON aigcspace.ocr_tasks(task_id);",
  67. "CREATE INDEX idx_ocr_tasks_user_id ON aigcspace.ocr_tasks(user_id);",
  68. "CREATE INDEX idx_ocr_tasks_status ON aigcspace.ocr_tasks(status);",
  69. "CREATE INDEX idx_ocr_tasks_user_status ON aigcspace.ocr_tasks(user_id, status);",
  70. "CREATE INDEX idx_ocr_tasks_created_at ON aigcspace.ocr_tasks(created_at DESC);"
  71. ]
  72. for idx_sql in indexes:
  73. cursor.execute(idx_sql)
  74. print("✓ 索引创建成功")
  75. # 创建触发器
  76. cursor.execute("""
  77. CREATE OR REPLACE FUNCTION update_ocr_tasks_updated_at()
  78. RETURNS TRIGGER AS $$
  79. BEGIN
  80. NEW.updated_at = CURRENT_TIMESTAMP;
  81. RETURN NEW;
  82. END;
  83. $$ LANGUAGE plpgsql;
  84. CREATE TRIGGER trigger_update_ocr_tasks_updated_at
  85. BEFORE UPDATE ON aigcspace.ocr_tasks
  86. FOR EACH ROW
  87. EXECUTE FUNCTION update_ocr_tasks_updated_at();
  88. """)
  89. print("✓ 触发器创建成功")
  90. # 添加注释
  91. comments = [
  92. "COMMENT ON TABLE aigcspace.ocr_tasks IS 'OCR识别任务表,同时作为历史记录表';",
  93. "COMMENT ON COLUMN aigcspace.ocr_tasks.task_id IS '任务唯一标识符(UUID)';",
  94. "COMMENT ON COLUMN aigcspace.ocr_tasks.status IS '任务状态:pending/processing/completed/failed/cancelled';",
  95. "COMMENT ON COLUMN aigcspace.ocr_tasks.results IS 'JSONB格式存储:{images: [], thumbnails: [], texts: []}';"
  96. ]
  97. for comment_sql in comments:
  98. cursor.execute(comment_sql)
  99. print("✓ 注释添加成功")
  100. print("\n✅ OCR数据库迁移完成!")
  101. except Exception as e:
  102. print(f"❌ 迁移失败: {e}")
  103. raise
  104. finally:
  105. cursor.close()
  106. conn.close()
  107. def rollback():
  108. """回滚迁移"""
  109. conn = get_db_connection()
  110. conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
  111. cursor = conn.cursor()
  112. try:
  113. if not table_exists(cursor):
  114. print("✓ 表 aigcspace.ocr_tasks 不存在,无需回滚")
  115. return
  116. print("开始回滚 OCR 数据库迁移...")
  117. cursor.execute("DROP TRIGGER IF EXISTS trigger_update_ocr_tasks_updated_at ON aigcspace.ocr_tasks;")
  118. cursor.execute("DROP FUNCTION IF EXISTS update_ocr_tasks_updated_at();")
  119. cursor.execute("DROP TABLE IF EXISTS aigcspace.ocr_tasks CASCADE;")
  120. print("✅ OCR数据库迁移回滚完成!")
  121. except Exception as e:
  122. print(f"❌ 回滚失败: {e}")
  123. raise
  124. finally:
  125. cursor.close()
  126. conn.close()
  127. def status():
  128. """查看迁移状态"""
  129. conn = get_db_connection()
  130. conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
  131. cursor = conn.cursor()
  132. try:
  133. exists = table_exists(cursor)
  134. print(f"表 aigcspace.ocr_tasks: {'✓ 已创建' if exists else '✗ 未创建'}")
  135. if exists:
  136. cursor.execute("SELECT COUNT(*) FROM aigcspace.ocr_tasks;")
  137. count = cursor.fetchone()[0]
  138. print(f"记录数: {count}")
  139. cursor.execute("""
  140. SELECT indexname FROM pg_indexes
  141. WHERE schemaname = 'aigcspace' AND tablename = 'ocr_tasks';
  142. """)
  143. indexes = [row[0] for row in cursor.fetchall()]
  144. print(f"索引数: {len(indexes)}")
  145. except Exception as e:
  146. print(f"❌ 查看状态失败: {e}")
  147. raise
  148. finally:
  149. cursor.close()
  150. conn.close()
  151. if __name__ == "__main__":
  152. import argparse
  153. parser = argparse.ArgumentParser(description='OCR数据库迁移脚本')
  154. parser.add_argument('--action', choices=['migrate', 'rollback', 'status'],
  155. default='migrate', help='操作类型')
  156. args = parser.parse_args()
  157. if args.action == 'migrate':
  158. migrate()
  159. elif args.action == 'rollback':
  160. rollback()
  161. elif args.action == 'status':
  162. status()