apply_search_migrations.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. #!/usr/bin/env python3
  2. """
  3. 搜索功能数据库迁移脚本
  4. 功能:
  5. - 应用搜索相关的数据库迁移
  6. - 创建用户搜索偏好表
  7. - 创建搜索使用记录表
  8. - 扩展ai_message表支持搜索信息
  9. - 添加相关索引
  10. 使用方法:
  11. python apply_search_migrations.py
  12. python apply_search_migrations.py --rollback # 回滚迁移
  13. """
  14. import os
  15. import sys
  16. import argparse
  17. from pathlib import Path
  18. from typing import List
  19. # 添加项目根目录到Python路径
  20. sys.path.append(str(Path(__file__).parent.parent))
  21. from sqlalchemy import create_engine, text
  22. from sqlalchemy.engine import Engine
  23. from app.database import DATABASE_URL, engine
  24. class SearchMigrationManager:
  25. """搜索功能迁移管理器"""
  26. def __init__(self, database_url: str = None):
  27. """
  28. 初始化迁移管理器
  29. Args:
  30. database_url: 数据库连接URL,默认使用项目配置
  31. """
  32. self.database_url = database_url or DATABASE_URL
  33. self.engine = create_engine(self.database_url)
  34. # 定义迁移文件列表(按执行顺序)
  35. self.migration_files = [
  36. "030_create_user_search_preferences.sql",
  37. "031_create_search_usage_log.sql",
  38. "032_extend_ai_message_search_support.sql"
  39. ]
  40. # 迁移文件目录
  41. self.migrations_dir = Path(__file__).parent.parent / "migrations"
  42. def read_migration_file(self, filename: str) -> str:
  43. """读取迁移文件内容"""
  44. file_path = self.migrations_dir / filename
  45. if not file_path.exists():
  46. raise FileNotFoundError(f"迁移文件不存在: {file_path}")
  47. with open(file_path, 'r', encoding='utf-8') as f:
  48. content = f.read()
  49. return content
  50. def extract_forward_migration(self, content: str) -> str:
  51. """提取正向迁移SQL"""
  52. lines = content.split('\n')
  53. forward_lines = []
  54. in_forward_section = False
  55. for line in lines:
  56. if '正向迁移:' in line or 'Forward Migration' in line:
  57. in_forward_section = True
  58. continue
  59. elif '回滚迁移:' in line or 'Rollback Migration' in line:
  60. break
  61. elif in_forward_section and not line.strip().startswith('--'):
  62. forward_lines.append(line)
  63. return '\n'.join(forward_lines).strip()
  64. def extract_rollback_migration(self, content: str) -> str:
  65. """提取回滚迁移SQL"""
  66. lines = content.split('\n')
  67. rollback_lines = []
  68. in_rollback_section = False
  69. for line in lines:
  70. if '回滚迁移:' in line or 'Rollback Migration' in line:
  71. in_rollback_section = True
  72. continue
  73. elif in_rollback_section and line.strip().startswith('-- '):
  74. # 移除注释符号
  75. sql_line = line.strip()[3:] # 移除 '-- '
  76. if sql_line.strip():
  77. rollback_lines.append(sql_line)
  78. return '\n'.join(rollback_lines).strip()
  79. def execute_sql(self, sql: str, description: str = ""):
  80. """执行SQL语句"""
  81. if not sql.strip():
  82. print(f"跳过空SQL: {description}")
  83. return
  84. try:
  85. with self.engine.connect() as conn:
  86. # 分割SQL语句(按分号分割)
  87. statements = [stmt.strip() for stmt in sql.split(';') if stmt.strip()]
  88. for stmt in statements:
  89. if stmt.strip():
  90. print(f"执行SQL: {stmt[:100]}...")
  91. conn.execute(text(stmt))
  92. conn.commit()
  93. print(f"✅ 成功执行: {description}")
  94. except Exception as e:
  95. print(f"❌ 执行失败: {description}")
  96. print(f"错误信息: {e}")
  97. raise
  98. def check_table_exists(self, table_name: str, schema: str = 'aigcspace') -> bool:
  99. """检查表是否存在"""
  100. try:
  101. with self.engine.connect() as conn:
  102. result = conn.execute(text("""
  103. SELECT EXISTS (
  104. SELECT FROM information_schema.tables
  105. WHERE table_schema = :schema
  106. AND table_name = :table_name
  107. )
  108. """), {"schema": schema, "table_name": table_name})
  109. return result.scalar()
  110. except Exception as e:
  111. print(f"检查表存在性失败: {e}")
  112. return False
  113. def check_column_exists(self, table_name: str, column_name: str, schema: str = 'aigcspace') -> bool:
  114. """检查列是否存在"""
  115. try:
  116. with self.engine.connect() as conn:
  117. result = conn.execute(text("""
  118. SELECT EXISTS (
  119. SELECT FROM information_schema.columns
  120. WHERE table_schema = :schema
  121. AND table_name = :table_name
  122. AND column_name = :column_name
  123. )
  124. """), {"schema": schema, "table_name": table_name, "column_name": column_name})
  125. return result.scalar()
  126. except Exception as e:
  127. print(f"检查列存在性失败: {e}")
  128. return False
  129. def apply_migrations(self):
  130. """应用所有搜索相关迁移"""
  131. print("开始应用搜索功能数据库迁移...")
  132. print(f"数据库URL: {self.database_url}")
  133. print(f"迁移文件数量: {len(self.migration_files)}")
  134. print()
  135. for i, filename in enumerate(self.migration_files, 1):
  136. print(f"[{i}/{len(self.migration_files)}] 处理迁移文件: {filename}")
  137. try:
  138. # 读取迁移文件
  139. content = self.read_migration_file(filename)
  140. # 提取正向迁移SQL
  141. forward_sql = self.extract_forward_migration(content)
  142. if forward_sql:
  143. # 执行迁移
  144. self.execute_sql(forward_sql, f"迁移文件 {filename}")
  145. else:
  146. print(f"⚠️ 未找到正向迁移SQL: {filename}")
  147. except Exception as e:
  148. print(f"❌ 迁移失败: {filename}")
  149. print(f"错误信息: {e}")
  150. return False
  151. print()
  152. print("🎉 所有搜索功能迁移已成功应用!")
  153. # 验证迁移结果
  154. self.verify_migrations()
  155. return True
  156. def rollback_migrations(self):
  157. """回滚所有搜索相关迁移"""
  158. print("开始回滚搜索功能数据库迁移...")
  159. print(f"数据库URL: {self.database_url}")
  160. print()
  161. # 反向执行回滚
  162. for i, filename in enumerate(reversed(self.migration_files), 1):
  163. print(f"[{i}/{len(self.migration_files)}] 回滚迁移文件: {filename}")
  164. try:
  165. # 读取迁移文件
  166. content = self.read_migration_file(filename)
  167. # 提取回滚迁移SQL
  168. rollback_sql = self.extract_rollback_migration(content)
  169. if rollback_sql:
  170. # 执行回滚
  171. self.execute_sql(rollback_sql, f"回滚迁移文件 {filename}")
  172. else:
  173. print(f"⚠️ 未找到回滚迁移SQL: {filename}")
  174. except Exception as e:
  175. print(f"❌ 回滚失败: {filename}")
  176. print(f"错误信息: {e}")
  177. return False
  178. print()
  179. print("🎉 所有搜索功能迁移已成功回滚!")
  180. return True
  181. def verify_migrations(self):
  182. """验证迁移结果"""
  183. print("验证迁移结果...")
  184. # 检查表是否创建成功
  185. tables_to_check = [
  186. "user_search_preferences",
  187. "search_usage_log"
  188. ]
  189. for table in tables_to_check:
  190. exists = self.check_table_exists(table)
  191. status = "✅" if exists else "❌"
  192. print(f"{status} 表 {table}: {'存在' if exists else '不存在'}")
  193. # 检查ai_message表的新列
  194. columns_to_check = [
  195. ("ai_message", "search_info"),
  196. ("ai_message", "search_results")
  197. ]
  198. for table, column in columns_to_check:
  199. exists = self.check_column_exists(table, column)
  200. status = "✅" if exists else "❌"
  201. print(f"{status} 列 {table}.{column}: {'存在' if exists else '不存在'}")
  202. def main():
  203. """主函数"""
  204. parser = argparse.ArgumentParser(description="搜索功能数据库迁移脚本")
  205. parser.add_argument(
  206. "--rollback",
  207. action="store_true",
  208. help="回滚迁移(默认为应用迁移)"
  209. )
  210. parser.add_argument(
  211. "--database-url",
  212. type=str,
  213. help="数据库连接URL(默认使用项目配置)"
  214. )
  215. args = parser.parse_args()
  216. try:
  217. # 初始化迁移管理器
  218. manager = SearchMigrationManager(args.database_url)
  219. # 执行迁移或回滚
  220. if args.rollback:
  221. success = manager.rollback_migrations()
  222. else:
  223. success = manager.apply_migrations()
  224. if success:
  225. print("\n✅ 操作完成!")
  226. else:
  227. print("\n❌ 操作失败!")
  228. sys.exit(1)
  229. except Exception as e:
  230. print(f"\n❌ 执行失败: {e}")
  231. sys.exit(1)
  232. if __name__ == "__main__":
  233. main()