#!/usr/bin/env python3 """ 搜索功能数据库迁移脚本 功能: - 应用搜索相关的数据库迁移 - 创建用户搜索偏好表 - 创建搜索使用记录表 - 扩展ai_message表支持搜索信息 - 添加相关索引 使用方法: python apply_search_migrations.py python apply_search_migrations.py --rollback # 回滚迁移 """ import os import sys import argparse from pathlib import Path from typing import List # 添加项目根目录到Python路径 sys.path.append(str(Path(__file__).parent.parent)) from sqlalchemy import create_engine, text from sqlalchemy.engine import Engine from app.database import DATABASE_URL, engine class SearchMigrationManager: """搜索功能迁移管理器""" def __init__(self, database_url: str = None): """ 初始化迁移管理器 Args: database_url: 数据库连接URL,默认使用项目配置 """ self.database_url = database_url or DATABASE_URL self.engine = create_engine(self.database_url) # 定义迁移文件列表(按执行顺序) self.migration_files = [ "030_create_user_search_preferences.sql", "031_create_search_usage_log.sql", "032_extend_ai_message_search_support.sql" ] # 迁移文件目录 self.migrations_dir = Path(__file__).parent.parent / "migrations" def read_migration_file(self, filename: str) -> str: """读取迁移文件内容""" file_path = self.migrations_dir / filename if not file_path.exists(): raise FileNotFoundError(f"迁移文件不存在: {file_path}") with open(file_path, 'r', encoding='utf-8') as f: content = f.read() return content def extract_forward_migration(self, content: str) -> str: """提取正向迁移SQL""" lines = content.split('\n') forward_lines = [] in_forward_section = False for line in lines: if '正向迁移:' in line or 'Forward Migration' in line: in_forward_section = True continue elif '回滚迁移:' in line or 'Rollback Migration' in line: break elif in_forward_section and not line.strip().startswith('--'): forward_lines.append(line) return '\n'.join(forward_lines).strip() def extract_rollback_migration(self, content: str) -> str: """提取回滚迁移SQL""" lines = content.split('\n') rollback_lines = [] in_rollback_section = False for line in lines: if '回滚迁移:' in line or 'Rollback Migration' in line: in_rollback_section = True continue elif in_rollback_section and line.strip().startswith('-- '): # 移除注释符号 sql_line = line.strip()[3:] # 移除 '-- ' if sql_line.strip(): rollback_lines.append(sql_line) return '\n'.join(rollback_lines).strip() def execute_sql(self, sql: str, description: str = ""): """执行SQL语句""" if not sql.strip(): print(f"跳过空SQL: {description}") return try: with self.engine.connect() as conn: # 分割SQL语句(按分号分割) statements = [stmt.strip() for stmt in sql.split(';') if stmt.strip()] for stmt in statements: if stmt.strip(): print(f"执行SQL: {stmt[:100]}...") conn.execute(text(stmt)) conn.commit() print(f"✅ 成功执行: {description}") except Exception as e: print(f"❌ 执行失败: {description}") print(f"错误信息: {e}") raise def check_table_exists(self, table_name: str, schema: str = 'aigcspace') -> bool: """检查表是否存在""" try: with self.engine.connect() as conn: result = conn.execute(text(""" SELECT EXISTS ( SELECT FROM information_schema.tables WHERE table_schema = :schema AND table_name = :table_name ) """), {"schema": schema, "table_name": table_name}) return result.scalar() except Exception as e: print(f"检查表存在性失败: {e}") return False def check_column_exists(self, table_name: str, column_name: str, schema: str = 'aigcspace') -> bool: """检查列是否存在""" try: with self.engine.connect() as conn: result = conn.execute(text(""" SELECT EXISTS ( SELECT FROM information_schema.columns WHERE table_schema = :schema AND table_name = :table_name AND column_name = :column_name ) """), {"schema": schema, "table_name": table_name, "column_name": column_name}) return result.scalar() except Exception as e: print(f"检查列存在性失败: {e}") return False def apply_migrations(self): """应用所有搜索相关迁移""" print("开始应用搜索功能数据库迁移...") print(f"数据库URL: {self.database_url}") print(f"迁移文件数量: {len(self.migration_files)}") print() for i, filename in enumerate(self.migration_files, 1): print(f"[{i}/{len(self.migration_files)}] 处理迁移文件: {filename}") try: # 读取迁移文件 content = self.read_migration_file(filename) # 提取正向迁移SQL forward_sql = self.extract_forward_migration(content) if forward_sql: # 执行迁移 self.execute_sql(forward_sql, f"迁移文件 {filename}") else: print(f"⚠️ 未找到正向迁移SQL: {filename}") except Exception as e: print(f"❌ 迁移失败: {filename}") print(f"错误信息: {e}") return False print() print("🎉 所有搜索功能迁移已成功应用!") # 验证迁移结果 self.verify_migrations() return True def rollback_migrations(self): """回滚所有搜索相关迁移""" print("开始回滚搜索功能数据库迁移...") print(f"数据库URL: {self.database_url}") print() # 反向执行回滚 for i, filename in enumerate(reversed(self.migration_files), 1): print(f"[{i}/{len(self.migration_files)}] 回滚迁移文件: {filename}") try: # 读取迁移文件 content = self.read_migration_file(filename) # 提取回滚迁移SQL rollback_sql = self.extract_rollback_migration(content) if rollback_sql: # 执行回滚 self.execute_sql(rollback_sql, f"回滚迁移文件 {filename}") else: print(f"⚠️ 未找到回滚迁移SQL: {filename}") except Exception as e: print(f"❌ 回滚失败: {filename}") print(f"错误信息: {e}") return False print() print("🎉 所有搜索功能迁移已成功回滚!") return True def verify_migrations(self): """验证迁移结果""" print("验证迁移结果...") # 检查表是否创建成功 tables_to_check = [ "user_search_preferences", "search_usage_log" ] for table in tables_to_check: exists = self.check_table_exists(table) status = "✅" if exists else "❌" print(f"{status} 表 {table}: {'存在' if exists else '不存在'}") # 检查ai_message表的新列 columns_to_check = [ ("ai_message", "search_info"), ("ai_message", "search_results") ] for table, column in columns_to_check: exists = self.check_column_exists(table, column) status = "✅" if exists else "❌" print(f"{status} 列 {table}.{column}: {'存在' if exists else '不存在'}") def main(): """主函数""" parser = argparse.ArgumentParser(description="搜索功能数据库迁移脚本") parser.add_argument( "--rollback", action="store_true", help="回滚迁移(默认为应用迁移)" ) parser.add_argument( "--database-url", type=str, help="数据库连接URL(默认使用项目配置)" ) args = parser.parse_args() try: # 初始化迁移管理器 manager = SearchMigrationManager(args.database_url) # 执行迁移或回滚 if args.rollback: success = manager.rollback_migrations() else: success = manager.apply_migrations() if success: print("\n✅ 操作完成!") else: print("\n❌ 操作失败!") sys.exit(1) except Exception as e: print(f"\n❌ 执行失败: {e}") sys.exit(1) if __name__ == "__main__": main()