#!/usr/bin/env python3
"""
Database schema extraction script.

Features:
    - Collects structural information for every table in the database
    - Exports column, type, constraint and index details
    - Supports several output formats: JSON, Markdown, SQL DDL
    - Generates database documentation

Usage:
    python get_db_schema.py --format json --output schema.json
    python get_db_schema.py --format markdown --output schema.md
    python get_db_schema.py --format sql --output schema.sql
"""

import os
import sys
import json
import argparse
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, asdict

# Add the project root directory to the Python path so that `app` is importable.
sys.path.append(str(Path(__file__).parent.parent))

from sqlalchemy import (
    create_engine, inspect, MetaData, Table, Column, text,
    Integer, String, Boolean, DateTime, Text, Float, Numeric,
    Date, Time, JSON
)
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker

from app.database import DATABASE_URL, engine
from app.models import *  # Import every model so that its table gets registered


# Fallback query used when the inspector reports no tables at all.
# PostgreSQL-specific (reads pg_tables).
_LIST_TABLES_SQL = (
    " SELECT schemaname || '.' || tablename as full_name"
    " FROM pg_tables"
    " WHERE schemaname NOT IN ('information_schema', 'pg_catalog')"
    " ORDER BY schemaname, tablename "
)

# Looks up the comment attached to a table. PostgreSQL-specific.
_TABLE_COMMENT_SQL = (
    " SELECT obj_description(c.oid)"
    " FROM pg_class c"
    " JOIN pg_namespace n ON n.oid = c.relnamespace"
    " WHERE c.relname = :table_name AND n.nspname = :schema_name "
)


def _split_table_name(
    table_name: str, default_schema: Optional[str] = None
) -> Tuple[Optional[str], str]:
    """Split ``"schema.table"`` into ``(schema, table)``.

    Only the first dot separates the schema. A name without a dot yields
    ``(default_schema, table_name)``.
    """
    if '.' in table_name:
        schema, table = table_name.split('.', 1)
        return schema, table
    return default_schema, table_name


def _quote_identifier(identifier: str) -> str:
    """Return *identifier* as a double-quoted SQL identifier.

    Embedded double quotes are doubled so the name cannot terminate the
    quoting early (table names may come straight from the command line).
    """
    return '"' + identifier.replace('"', '""') + '"'


def _one_line(value: str) -> str:
    """Collapse line breaks in *value* into single spaces."""
    return " ".join(value.splitlines())


def _md_cell(value: Any) -> str:
    """Make *value* safe for a Markdown table cell (one line, escaped pipes)."""
    return _one_line(str(value)).replace("|", "\\|")


@dataclass
class ColumnInfo:
    """Column information record."""
    name: str
    type: str
    nullable: bool
    default: Optional[str]
    primary_key: bool
    foreign_key: Optional[str]  # "referred_table.referred_column", if any
    comment: Optional[str]


@dataclass
class IndexInfo:
    """Index information record."""
    name: str
    columns: List[str]
    unique: bool
    type: str


@dataclass
class ForeignKeyInfo:
    """Foreign key information record."""
    name: str
    constrained_columns: List[str]
    referred_table: str
    referred_columns: List[str]
    on_delete: Optional[str]
    on_update: Optional[str]


@dataclass
class TableInfo:
    """Table information record."""
    name: str
    comment: Optional[str]
    columns: List[ColumnInfo]
    indexes: List[IndexInfo]
    foreign_keys: List[ForeignKeyInfo]
    row_count: Optional[int]


class DatabaseSchemaExtractor:
    """Extracts table structure information from a database.

    Table names handed to the ``get_*`` / ``extract_*`` methods may be either
    a bare table name or ``"schema.table"``.
    """

    def __init__(self, database_url: Optional[str] = None):
        """
        Set up the database connection.

        Args:
            database_url: Database connection URL; falls back to the project
                configuration (``DATABASE_URL``) when empty or omitted.
        """
        self.database_url = database_url or DATABASE_URL
        self.engine = create_engine(self.database_url)
        self.inspector = inspect(self.engine)
        self.metadata = MetaData()

    def get_all_tables(self) -> List[str]:
        """Return all table names.

        Tables of the default schema are returned bare; tables of the
        ``aigcspace`` schema are returned as ``"aigcspace.<table>"``. When
        neither source yields anything, the PostgreSQL system catalog is
        queried directly and fully qualified names are returned.
        """
        # Tables in the default schema first.
        public_tables = self.inspector.get_table_names()

        # Then the tables of the aigcspace schema, prefixed with the schema.
        try:
            aigcspace_tables = self.inspector.get_table_names(schema='aigcspace')
            aigcspace_tables = [f"aigcspace.{table}" for table in aigcspace_tables]
        except Exception as e:
            print(f"警告:无法访问aigcspace模式: {e}")
            aigcspace_tables = []

        all_tables = public_tables + aigcspace_tables

        if not all_tables:
            # Still nothing: ask the system catalog directly.
            try:
                with self.engine.connect() as conn:
                    result = conn.execute(text(_LIST_TABLES_SQL))
                    all_tables = [row[0] for row in result]
                    print(f"通过系统表查询发现 {len(all_tables)} 个表")
            except Exception as e:
                print(f"系统表查询也失败: {e}")

        return all_tables

    def get_column_info(self, table_name: str) -> List[ColumnInfo]:
        """Return the column descriptions of *table_name*.

        Returns an empty list (after printing a warning) when reflection
        fails.
        """
        schema, table = _split_table_name(table_name)

        try:
            column_data = self.inspector.get_columns(table, schema=schema)
            pk_constraint = self.inspector.get_pk_constraint(table, schema=schema)
            fk_constraints = self.inspector.get_foreign_keys(table, schema=schema)
        except Exception as e:
            print(f"警告:无法获取表 {table_name} 的列信息: {e}")
            return []

        # Map each constrained column to the column it references. Columns
        # are paired positionally so composite foreign keys are reported
        # correctly instead of all pointing at the first referred column.
        fk_map: Dict[str, str] = {}
        for fk in fk_constraints:
            pairs = zip(fk['constrained_columns'], fk['referred_columns'])
            for local_col, remote_col in pairs:
                fk_map[local_col] = f"{fk['referred_table']}.{remote_col}"

        pk_columns = set(pk_constraint.get('constrained_columns') or [])

        columns = []
        for col in column_data:
            default = col.get('default')
            columns.append(ColumnInfo(
                name=col['name'],
                type=str(col['type']),
                nullable=col['nullable'],
                default=str(default) if default is not None else None,
                primary_key=col['name'] in pk_columns,
                foreign_key=fk_map.get(col['name']),
                comment=col.get('comment'),
            ))

        return columns

    def get_index_info(self, table_name: str) -> List[IndexInfo]:
        """Return the index descriptions of *table_name*.

        Returns an empty list (after printing a warning) when reflection
        fails.
        """
        schema, table = _split_table_name(table_name)

        try:
            index_data = self.inspector.get_indexes(table, schema=schema)
        except Exception as e:
            print(f"警告:无法获取表 {table_name} 的索引信息: {e}")
            return []

        indexes = []
        for idx in index_data:
            # Expression indexes carry None in `column_names`; when the
            # dialect also provides `expressions`, prefer it so that every
            # element has a printable form.
            elements = idx.get('expressions') or idx['column_names']
            # The access method is reported through dialect options rather
            # than a `type` key; fall back to btree when it is not stated.
            dialect_options = idx.get('dialect_options') or {}
            index_type = idx.get('type') or dialect_options.get(
                'postgresql_using', 'btree'
            )
            indexes.append(IndexInfo(
                name=idx['name'],
                columns=[str(item) for item in elements if item is not None],
                unique=idx['unique'],
                type=index_type,
            ))

        return indexes

    def get_foreign_key_info(self, table_name: str) -> List[ForeignKeyInfo]:
        """Return the foreign key constraints of *table_name*.

        Returns an empty list (after printing a warning) when reflection
        fails.
        """
        schema, table = _split_table_name(table_name)

        try:
            fk_data = self.inspector.get_foreign_keys(table, schema=schema)
        except Exception as e:
            print(f"警告:无法获取表 {table_name} 的外键信息: {e}")
            return []

        foreign_keys = []
        for fk in fk_data:
            options = fk.get('options') or {}
            foreign_keys.append(ForeignKeyInfo(
                name=fk.get('name'),
                constrained_columns=fk['constrained_columns'],
                referred_table=fk['referred_table'],
                referred_columns=fk['referred_columns'],
                on_delete=options.get('ondelete'),
                on_update=options.get('onupdate'),
            ))

        return foreign_keys

    def get_table_row_count(self, table_name: str) -> Optional[int]:
        """Return the number of rows in *table_name*, or None on failure."""
        try:
            # Identifiers cannot be bound as parameters, so quote them.
            schema, table = _split_table_name(table_name)
            if schema is not None:
                full_table_name = (
                    f"{_quote_identifier(schema)}.{_quote_identifier(table)}"
                )
            else:
                full_table_name = _quote_identifier(table)

            with self.engine.connect() as conn:
                result = conn.execute(
                    text(f"SELECT COUNT(*) FROM {full_table_name}")
                )
                return result.scalar()
        except Exception as e:
            print(f"警告:无法获取表 {table_name} 的行数: {e}")
            return None

    def get_table_comment(self, table_name: str) -> Optional[str]:
        """Return the comment stored for *table_name*, or None.

        A name without a schema part is looked up in the ``public`` schema.
        """
        try:
            schema, table = _split_table_name(table_name, default_schema='public')

            with self.engine.connect() as conn:
                result = conn.execute(
                    text(_TABLE_COMMENT_SQL),
                    {"table_name": table, "schema_name": schema},
                )
                return result.scalar()
        except Exception as e:
            print(f"警告:无法获取表 {table_name} 的注释: {e}")
            return None

    def extract_table_info(self, table_name: str) -> TableInfo:
        """Collect the complete description of a single table."""
        print(f"正在提取表 {table_name} 的信息...")

        return TableInfo(
            name=table_name,
            comment=self.get_table_comment(table_name),
            columns=self.get_column_info(table_name),
            indexes=self.get_index_info(table_name),
            foreign_keys=self.get_foreign_key_info(table_name),
            row_count=self.get_table_row_count(table_name),
        )

    def extract_all_tables(self) -> List[TableInfo]:
        """Collect the description of every table found by get_all_tables.

        A table whose extraction raises is reported and skipped.
        """
        tables = []
        table_names = self.get_all_tables()

        print(f"发现 {len(table_names)} 个表")

        for table_name in table_names:
            try:
                tables.append(self.extract_table_info(table_name))
            except Exception as e:
                print(f"错误:提取表 {table_name} 信息失败: {e}")

        return tables


class SchemaFormatter:
    """Renders extracted table information in the supported output formats."""

    @staticmethod
    def to_json(tables: List[TableInfo], indent: int = 2) -> str:
        """Render *tables* as a JSON document.

        The connection URL is never written; a fixed placeholder is emitted
        in its place.
        """
        data = {
            "generated_at": datetime.now().isoformat(),
            "database_url": "***隐藏***",
            "total_tables": len(tables),
            "tables": [asdict(table) for table in tables]
        }
        return json.dumps(data, ensure_ascii=False, indent=indent)

    @staticmethod
    def to_markdown(tables: List[TableInfo]) -> str:
        """Render *tables* as a Markdown document.

        The document has a table of contents followed by one section per
        table (columns, then indexes and foreign keys when present).
        """
        md_lines = [
            "# 数据库表结构文档",
            "",
            f"**生成时间**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            f"**表总数**: {len(tables)}",
            "",
            "## 目录",
            ""
        ]

        # Table of contents
        for table in tables:
            md_lines.append(f"- [{table.name}](#{table.name.lower()})")

        md_lines.append("")

        # One detailed section per table
        for table in tables:
            md_lines.extend([
                f"## {table.name}",
                ""
            ])

            if table.comment:
                md_lines.extend([
                    f"**表说明**: {table.comment}",
                    ""
                ])

            if table.row_count is not None:
                md_lines.extend([
                    f"**数据行数**: {table.row_count:,}",
                    ""
                ])

            # Column table
            md_lines.extend([
                "### 字段信息",
                "",
                "| 字段名 | 类型 | 可空 | 默认值 | 主键 | 外键 | 说明 |",
                "|--------|------|------|--------|------|------|------|"
            ])

            for col in table.columns:
                pk_mark = "✓" if col.primary_key else ""
                fk_mark = col.foreign_key or ""
                # Defaults and comments are free text: keep them on one line
                # and escape pipes so they cannot break the table row.
                default_val = _md_cell(col.default or "")
                nullable = "✓" if col.nullable else ""
                comment = _md_cell(col.comment or "")

                md_lines.append(
                    f"| {col.name} | {col.type} | {nullable} | {default_val} | {pk_mark} | {fk_mark} | {comment} |"
                )

            md_lines.append("")

            # Index table
            if table.indexes:
                md_lines.extend([
                    "### 索引信息",
                    "",
                    "| 索引名 | 字段 | 唯一 | 类型 |",
                    "|--------|------|------|------|"
                ])

                for idx in table.indexes:
                    unique_mark = "✓" if idx.unique else ""
                    columns_str = ", ".join(idx.columns)
                    md_lines.append(
                        f"| {idx.name} | {columns_str} | {unique_mark} | {idx.type} |"
                    )

                md_lines.append("")

            # Foreign key table
            if table.foreign_keys:
                md_lines.extend([
                    "### 外键约束",
                    "",
                    "| 约束名 | 本表字段 | 引用表.字段 | 删除规则 | 更新规则 |",
                    "|--------|----------|-------------|----------|----------|"
                ])

                for fk in table.foreign_keys:
                    constrained_cols = ", ".join(fk.constrained_columns)
                    referred_cols = ", ".join(fk.referred_columns)
                    referred_info = f"{fk.referred_table}.{referred_cols}"
                    on_delete = fk.on_delete or ""
                    on_update = fk.on_update or ""

                    md_lines.append(
                        f"| {fk.name} | {constrained_cols} | {referred_info} | {on_delete} | {on_update} |"
                    )

                md_lines.append("")

            md_lines.append("---")
            md_lines.append("")

        return "\n".join(md_lines)

    @staticmethod
    def to_sql_ddl(tables: List[TableInfo]) -> str:
        """Render *tables* as SQL DDL.

        Each table produces a ``CREATE TABLE`` statement followed by its
        ``CREATE INDEX`` and ``ALTER TABLE ... FOREIGN KEY`` statements.
        """
        sql_lines = [
            "-- 数据库表结构DDL",
            f"-- 生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            f"-- 表总数: {len(tables)}",
            "",
            "-- ==================== 表结构定义 ===================="
        ]

        for table in tables:
            sql_lines.extend([
                "",
                f"-- 表: {table.name}",
            ])

            if table.comment:
                # A line break would end the `--` comment early.
                sql_lines.append(f"-- 说明: {_one_line(table.comment)}")

            if table.row_count is not None:
                sql_lines.append(f"-- 数据行数: {table.row_count:,}")

            sql_lines.append(f"CREATE TABLE {table.name} (")

            # Collect (definition, trailing comment) pairs. The separator
            # comma is added afterwards so that it lands BEFORE the `--`
            # comment; otherwise the comment would swallow it.
            definitions: List[Tuple[str, Optional[str]]] = []
            for col in table.columns:
                col_def = f" {col.name} {col.type}"

                if not col.nullable:
                    col_def += " NOT NULL"

                if col.default:
                    col_def += f" DEFAULT {col.default}"

                definitions.append((col_def, col.comment))

            # Primary key constraint goes last inside the parentheses.
            pk_columns = [col.name for col in table.columns if col.primary_key]
            if pk_columns:
                definitions.append(
                    (f" PRIMARY KEY ({', '.join(pk_columns)})", None)
                )

            last_index = len(definitions) - 1
            for position, (definition, comment) in enumerate(definitions):
                line = definition if position == last_index else definition + ","
                if comment:
                    line += f" -- {_one_line(comment)}"
                sql_lines.append(line)

            sql_lines.extend([
                ");",
                ""
            ])

            # Indexes
            for idx in table.indexes:
                # Skip the index that backs the primary key.
                if idx.name and idx.name.endswith("_pkey"):
                    continue

                unique_keyword = "UNIQUE " if idx.unique else ""
                columns_str = ", ".join(idx.columns)
                sql_lines.append(
                    f"CREATE {unique_keyword}INDEX {idx.name} ON {table.name} ({columns_str});"
                )

            # Foreign key constraints
            for fk in table.foreign_keys:
                constrained_cols = ", ".join(fk.constrained_columns)
                referred_cols = ", ".join(fk.referred_columns)

                # Unnamed constraints are emitted without a CONSTRAINT clause
                # rather than with the literal text "None".
                if fk.name:
                    fk_sql = f"ALTER TABLE {table.name} ADD CONSTRAINT {fk.name} "
                else:
                    fk_sql = f"ALTER TABLE {table.name} ADD "
                fk_sql += f"FOREIGN KEY ({constrained_cols}) "
                fk_sql += f"REFERENCES {fk.referred_table} ({referred_cols})"

                if fk.on_delete:
                    fk_sql += f" ON DELETE {fk.on_delete}"
                if fk.on_update:
                    fk_sql += f" ON UPDATE {fk.on_update}"

                fk_sql += ";"
                sql_lines.append(fk_sql)

            sql_lines.append("")

        return "\n".join(sql_lines)


def main():
    """Command-line entry point.

    Parses the arguments, extracts the requested tables, renders them in the
    chosen format and writes the result to a file or to the console. Exits
    with status 1 when an unexpected error occurs.
    """
    parser = argparse.ArgumentParser(description="获取数据库表设计信息")
    parser.add_argument(
        "--format",
        choices=["json", "markdown", "sql"],
        default="json",
        help="输出格式 (默认: json)"
    )
    parser.add_argument(
        "--output",
        type=str,
        help="输出文件路径 (默认: 输出到控制台)"
    )
    parser.add_argument(
        "--tables",
        nargs="*",
        help="指定要提取的表名 (默认: 所有表)"
    )
    parser.add_argument(
        "--database-url",
        type=str,
        help="数据库连接URL (默认: 使用项目配置)"
    )

    args = parser.parse_args()

    # One renderer per value accepted by --format.
    formatters = {
        "json": SchemaFormatter.to_json,
        "markdown": SchemaFormatter.to_markdown,
        "sql": SchemaFormatter.to_sql_ddl,
    }

    try:
        # Set up the extractor
        print("正在连接数据库...")
        extractor = DatabaseSchemaExtractor(args.database_url)

        # Extract table information
        if args.tables:
            print(f"提取指定表: {', '.join(args.tables)}")
            tables = []
            for table_name in args.tables:
                try:
                    tables.append(extractor.extract_table_info(table_name))
                except Exception as e:
                    print(f"错误:表 {table_name} 不存在或提取失败: {e}")
        else:
            print("提取所有表...")
            tables = extractor.extract_all_tables()

        if not tables:
            print("未找到任何表")
            return

        # Render
        print(f"正在生成 {args.format} 格式...")
        output = formatters[args.format](tables)

        # Emit the result
        if args.output:
            output_path = Path(args.output)
            output_path.parent.mkdir(parents=True, exist_ok=True)

            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(output)

            print(f"结果已保存到: {output_path}")
            print(f"文件大小: {output_path.stat().st_size:,} 字节")
        else:
            print("\n" + "="*50)
            print(output)

        print(f"\n✅ 成功提取 {len(tables)} 个表的结构信息")

    except Exception as e:
        print(f"❌ 执行失败: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()