| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293 |
- #!/usr/bin/env python3
- """
- 搜索功能数据库迁移脚本
- 功能:
- - 应用搜索相关的数据库迁移
- - 创建用户搜索偏好表
- - 创建搜索使用记录表
- - 扩展ai_message表支持搜索信息
- - 添加相关索引
- 使用方法:
- python apply_search_migrations.py
- python apply_search_migrations.py --rollback # 回滚迁移
- """
- import os
- import sys
- import argparse
- from pathlib import Path
- from typing import List
- # 添加项目根目录到Python路径
- sys.path.append(str(Path(__file__).parent.parent))
- from sqlalchemy import create_engine, text
- from sqlalchemy.engine import Engine
- from app.database import DATABASE_URL, engine
- class SearchMigrationManager:
- """搜索功能迁移管理器"""
-
- def __init__(self, database_url: str = None):
- """
- 初始化迁移管理器
-
- Args:
- database_url: 数据库连接URL,默认使用项目配置
- """
- self.database_url = database_url or DATABASE_URL
- self.engine = create_engine(self.database_url)
-
- # 定义迁移文件列表(按执行顺序)
- self.migration_files = [
- "030_create_user_search_preferences.sql",
- "031_create_search_usage_log.sql",
- "032_extend_ai_message_search_support.sql"
- ]
-
- # 迁移文件目录
- self.migrations_dir = Path(__file__).parent.parent / "migrations"
-
- def read_migration_file(self, filename: str) -> str:
- """读取迁移文件内容"""
- file_path = self.migrations_dir / filename
-
- if not file_path.exists():
- raise FileNotFoundError(f"迁移文件不存在: {file_path}")
-
- with open(file_path, 'r', encoding='utf-8') as f:
- content = f.read()
-
- return content
-
- def extract_forward_migration(self, content: str) -> str:
- """提取正向迁移SQL"""
- lines = content.split('\n')
- forward_lines = []
- in_forward_section = False
-
- for line in lines:
- if '正向迁移:' in line or 'Forward Migration' in line:
- in_forward_section = True
- continue
- elif '回滚迁移:' in line or 'Rollback Migration' in line:
- break
- elif in_forward_section and not line.strip().startswith('--'):
- forward_lines.append(line)
-
- return '\n'.join(forward_lines).strip()
-
- def extract_rollback_migration(self, content: str) -> str:
- """提取回滚迁移SQL"""
- lines = content.split('\n')
- rollback_lines = []
- in_rollback_section = False
-
- for line in lines:
- if '回滚迁移:' in line or 'Rollback Migration' in line:
- in_rollback_section = True
- continue
- elif in_rollback_section and line.strip().startswith('-- '):
- # 移除注释符号
- sql_line = line.strip()[3:] # 移除 '-- '
- if sql_line.strip():
- rollback_lines.append(sql_line)
-
- return '\n'.join(rollback_lines).strip()
-
- def execute_sql(self, sql: str, description: str = ""):
- """执行SQL语句"""
- if not sql.strip():
- print(f"跳过空SQL: {description}")
- return
-
- try:
- with self.engine.connect() as conn:
- # 分割SQL语句(按分号分割)
- statements = [stmt.strip() for stmt in sql.split(';') if stmt.strip()]
-
- for stmt in statements:
- if stmt.strip():
- print(f"执行SQL: {stmt[:100]}...")
- conn.execute(text(stmt))
- conn.commit()
-
- print(f"✅ 成功执行: {description}")
-
- except Exception as e:
- print(f"❌ 执行失败: {description}")
- print(f"错误信息: {e}")
- raise
-
- def check_table_exists(self, table_name: str, schema: str = 'aigcspace') -> bool:
- """检查表是否存在"""
- try:
- with self.engine.connect() as conn:
- result = conn.execute(text("""
- SELECT EXISTS (
- SELECT FROM information_schema.tables
- WHERE table_schema = :schema
- AND table_name = :table_name
- )
- """), {"schema": schema, "table_name": table_name})
- return result.scalar()
- except Exception as e:
- print(f"检查表存在性失败: {e}")
- return False
-
- def check_column_exists(self, table_name: str, column_name: str, schema: str = 'aigcspace') -> bool:
- """检查列是否存在"""
- try:
- with self.engine.connect() as conn:
- result = conn.execute(text("""
- SELECT EXISTS (
- SELECT FROM information_schema.columns
- WHERE table_schema = :schema
- AND table_name = :table_name
- AND column_name = :column_name
- )
- """), {"schema": schema, "table_name": table_name, "column_name": column_name})
- return result.scalar()
- except Exception as e:
- print(f"检查列存在性失败: {e}")
- return False
-
- def apply_migrations(self):
- """应用所有搜索相关迁移"""
- print("开始应用搜索功能数据库迁移...")
- print(f"数据库URL: {self.database_url}")
- print(f"迁移文件数量: {len(self.migration_files)}")
- print()
-
- for i, filename in enumerate(self.migration_files, 1):
- print(f"[{i}/{len(self.migration_files)}] 处理迁移文件: {filename}")
-
- try:
- # 读取迁移文件
- content = self.read_migration_file(filename)
-
- # 提取正向迁移SQL
- forward_sql = self.extract_forward_migration(content)
-
- if forward_sql:
- # 执行迁移
- self.execute_sql(forward_sql, f"迁移文件 {filename}")
- else:
- print(f"⚠️ 未找到正向迁移SQL: {filename}")
-
- except Exception as e:
- print(f"❌ 迁移失败: {filename}")
- print(f"错误信息: {e}")
- return False
-
- print()
-
- print("🎉 所有搜索功能迁移已成功应用!")
-
- # 验证迁移结果
- self.verify_migrations()
-
- return True
-
- def rollback_migrations(self):
- """回滚所有搜索相关迁移"""
- print("开始回滚搜索功能数据库迁移...")
- print(f"数据库URL: {self.database_url}")
- print()
-
- # 反向执行回滚
- for i, filename in enumerate(reversed(self.migration_files), 1):
- print(f"[{i}/{len(self.migration_files)}] 回滚迁移文件: {filename}")
-
- try:
- # 读取迁移文件
- content = self.read_migration_file(filename)
-
- # 提取回滚迁移SQL
- rollback_sql = self.extract_rollback_migration(content)
-
- if rollback_sql:
- # 执行回滚
- self.execute_sql(rollback_sql, f"回滚迁移文件 {filename}")
- else:
- print(f"⚠️ 未找到回滚迁移SQL: {filename}")
-
- except Exception as e:
- print(f"❌ 回滚失败: {filename}")
- print(f"错误信息: {e}")
- return False
-
- print()
-
- print("🎉 所有搜索功能迁移已成功回滚!")
- return True
-
- def verify_migrations(self):
- """验证迁移结果"""
- print("验证迁移结果...")
-
- # 检查表是否创建成功
- tables_to_check = [
- "user_search_preferences",
- "search_usage_log"
- ]
-
- for table in tables_to_check:
- exists = self.check_table_exists(table)
- status = "✅" if exists else "❌"
- print(f"{status} 表 {table}: {'存在' if exists else '不存在'}")
-
- # 检查ai_message表的新列
- columns_to_check = [
- ("ai_message", "search_info"),
- ("ai_message", "search_results")
- ]
-
- for table, column in columns_to_check:
- exists = self.check_column_exists(table, column)
- status = "✅" if exists else "❌"
- print(f"{status} 列 {table}.{column}: {'存在' if exists else '不存在'}")
- def main():
- """主函数"""
- parser = argparse.ArgumentParser(description="搜索功能数据库迁移脚本")
- parser.add_argument(
- "--rollback",
- action="store_true",
- help="回滚迁移(默认为应用迁移)"
- )
- parser.add_argument(
- "--database-url",
- type=str,
- help="数据库连接URL(默认使用项目配置)"
- )
-
- args = parser.parse_args()
-
- try:
- # 初始化迁移管理器
- manager = SearchMigrationManager(args.database_url)
-
- # 执行迁移或回滚
- if args.rollback:
- success = manager.rollback_migrations()
- else:
- success = manager.apply_migrations()
-
- if success:
- print("\n✅ 操作完成!")
- else:
- print("\n❌ 操作失败!")
- sys.exit(1)
-
- except Exception as e:
- print(f"\n❌ 执行失败: {e}")
- sys.exit(1)
- if __name__ == "__main__":
- main()
|