#!/usr/bin/env python3
"""
Database table-design export script.

Features:
- Fetch the structure of every table in the database
- Export detailed information on columns, types, constraints and indexes
- Support several output formats: JSON, Markdown, SQL DDL
- Generate database documentation

Usage:
    python get_db_schema.py --format json --output schema.json
    python get_db_schema.py --format markdown --output schema.md
    python get_db_schema.py --format sql --output schema.sql
"""
- import os
- import sys
- import json
- import argparse
- from pathlib import Path
- from datetime import datetime
- from typing import Dict, List, Any, Optional
- from dataclasses import dataclass, asdict
- # 添加项目根目录到Python路径
- sys.path.append(str(Path(__file__).parent.parent))
- from sqlalchemy import (
- create_engine, inspect, MetaData, Table, Column,
- text, Integer, String, Boolean, DateTime, Text,
- Float, Numeric, Date, Time, JSON
- )
- from sqlalchemy.engine import Engine
- from sqlalchemy.orm import sessionmaker
- from app.database import DATABASE_URL, engine
- from app.models import * # 导入所有模型以确保表被注册
@dataclass
class ColumnInfo:
    """Description of a single table column.

    A plain value object: all seven fields are required and carry no
    defaults, so they must be supplied in declaration order or by keyword.
    """

    name: str  # column name
    type: str  # column type rendered as text
    nullable: bool  # True when the column accepts NULL
    default: Optional[str]  # default expression as text, or None if absent
    primary_key: bool  # True when the column is part of the primary key
    foreign_key: Optional[str]  # referenced "table.column", or None
    comment: Optional[str]  # column comment, or None if absent
@dataclass
class IndexInfo:
    """Description of a single table index.

    A plain value object: all four fields are required and carry no defaults.
    """

    name: Optional[str]  # index name; reflection may report None for unnamed indexes
    columns: List[str]  # names of the indexed columns, in index order
    unique: bool  # True for a unique index
    type: str  # index type / access method as text (e.g. "btree")
@dataclass
class ForeignKeyInfo:
    """Description of a single foreign-key constraint.

    A plain value object: all six fields are required and carry no defaults.
    ``constrained_columns`` and ``referred_columns`` are parallel lists.
    """

    name: Optional[str]  # constraint name; reflection may report None when unnamed
    constrained_columns: List[str]  # columns of the owning table
    referred_table: str  # name of the referenced table
    referred_columns: List[str]  # referenced columns, parallel to constrained_columns
    on_delete: Optional[str]  # ON DELETE action, or None if not specified
    on_update: Optional[str]  # ON UPDATE action, or None if not specified
@dataclass
class TableInfo:
    """Everything collected about one table.

    A plain value object: all six fields are required and carry no defaults.
    """

    name: str  # table name, bare or schema-qualified ("schema.table")
    comment: Optional[str]  # table comment, or None if absent/unavailable
    columns: List["ColumnInfo"]  # one entry per column
    indexes: List["IndexInfo"]  # one entry per index
    foreign_keys: List["ForeignKeyInfo"]  # one entry per foreign-key constraint
    row_count: Optional[int]  # number of rows, or None if it could not be counted
class DatabaseSchemaExtractor:
    """Extract table structure from a database via SQLAlchemy reflection.

    Table names accepted and returned by the methods below are either bare
    (``"users"``) or schema-qualified (``"aigcspace.users"``); everything
    before the first dot is treated as the schema name.  Every ``get_*``
    method is best-effort: when the underlying lookup fails it prints a
    warning and returns an empty list or ``None`` instead of raising.
    """

    def __init__(self, database_url: Optional[str] = None):
        """Create the engine and inspector used by all other methods.

        Args:
            database_url: SQLAlchemy connection URL.  When omitted or empty,
                the project-level ``DATABASE_URL`` is used.
        """
        self.database_url = database_url or DATABASE_URL
        self.engine = create_engine(self.database_url)
        self.inspector = inspect(self.engine)
        self.metadata = MetaData()

    # ------------------------------------------------------------------
    # Pure helpers (no database access)
    # ------------------------------------------------------------------

    @staticmethod
    def _split_table_name(table_name: str) -> tuple:
        """Split ``"schema.table"`` into ``(schema, table)``.

        A name without a dot yields ``(None, table_name)``.  Only the first
        dot separates the schema from the table.
        """
        schema, dot, table = table_name.partition('.')
        if not dot:
            return None, table_name
        return schema, table

    @staticmethod
    def _quote_qualified_name(table_name: str) -> str:
        """Return *table_name* as a double-quoted SQL identifier.

        A schema-qualified name becomes ``"schema"."table"``.  Double quotes
        embedded in either part are doubled so the name cannot terminate the
        quoted identifier early.
        """
        parts = table_name.split('.', 1)
        return '.'.join('"' + part.replace('"', '""') + '"' for part in parts)

    @staticmethod
    def _build_fk_map(fk_constraints) -> Dict[str, str]:
        """Map each constrained column to its ``"table.column"`` target.

        Local and referred columns are paired positionally, so every column
        of a composite foreign key points at its own referred column.
        """
        fk_map = {}
        for fk in fk_constraints:
            pairs = zip(fk['constrained_columns'], fk['referred_columns'])
            for local_col, remote_col in pairs:
                fk_map[local_col] = f"{fk['referred_table']}.{remote_col}"
        return fk_map

    @staticmethod
    def _index_columns(idx) -> List[str]:
        """Return the column names (or expressions) an index covers.

        Reflection reports ``None`` in ``column_names`` for expression
        elements; when an ``expressions`` list is present it is used instead,
        otherwise the ``None`` placeholders are dropped.
        """
        expressions = idx.get('expressions')
        if expressions:
            return [str(item) for item in expressions]
        return [name for name in idx['column_names'] if name is not None]

    @staticmethod
    def _index_type(idx) -> str:
        """Return the index type, defaulting to ``"btree"``.

        Uses an explicit ``type`` entry when the dialect provides one, then
        the ``postgresql_using`` dialect option, and finally ``"btree"``.
        """
        dialect_options = idx.get('dialect_options') or {}
        return idx.get('type') or dialect_options.get('postgresql_using') or 'btree'

    # ------------------------------------------------------------------
    # Database lookups
    # ------------------------------------------------------------------

    def get_all_tables(self) -> List[str]:
        """Return the names of every table that can be discovered.

        Tables of the connection's default schema are returned bare; tables
        of the ``aigcspace`` schema are prefixed with ``aigcspace.``.  If
        neither lookup finds anything, ``pg_tables`` is queried directly and
        every non-system table is returned as ``schema.table``.
        """
        default_tables = self.inspector.get_table_names()

        try:
            extra_tables = [
                f"aigcspace.{table}"
                for table in self.inspector.get_table_names(schema='aigcspace')
            ]
        except Exception as e:
            print(f"警告:无法访问aigcspace模式: {e}")
            extra_tables = []

        all_tables = default_tables + extra_tables
        if all_tables:
            return all_tables

        # Reflection found nothing: fall back to the PostgreSQL catalog.
        try:
            with self.engine.connect() as conn:
                result = conn.execute(text("""
                    SELECT schemaname || '.' || tablename as full_name
                    FROM pg_tables
                    WHERE schemaname NOT IN ('information_schema', 'pg_catalog')
                    ORDER BY schemaname, tablename
                """))
                all_tables = [row[0] for row in result]
            print(f"通过系统表查询发现 {len(all_tables)} 个表")
        except Exception as e:
            print(f"系统表查询也失败: {e}")

        return all_tables

    def get_column_info(self, table_name: str) -> List["ColumnInfo"]:
        """Return one ``ColumnInfo`` per column of *table_name*.

        Returns an empty list (after printing a warning) when the columns,
        primary key or foreign keys cannot be reflected.
        """
        schema, table = self._split_table_name(table_name)

        try:
            column_data = self.inspector.get_columns(table, schema=schema)
            pk_constraint = self.inspector.get_pk_constraint(table, schema=schema)
            fk_constraints = self.inspector.get_foreign_keys(table, schema=schema)
        except Exception as e:
            print(f"警告:无法获取表 {table_name} 的列信息: {e}")
            return []

        # Computed once instead of once per column.
        pk_columns = set(pk_constraint.get('constrained_columns') or [])
        fk_map = self._build_fk_map(fk_constraints)

        return [
            ColumnInfo(
                name=col['name'],
                type=str(col['type']),
                nullable=col['nullable'],
                default=str(col['default']) if col['default'] is not None else None,
                primary_key=col['name'] in pk_columns,
                foreign_key=fk_map.get(col['name']),
                comment=col.get('comment'),
            )
            for col in column_data
        ]

    def get_index_info(self, table_name: str) -> List["IndexInfo"]:
        """Return one ``IndexInfo`` per index of *table_name*.

        Returns an empty list (after printing a warning) when the indexes
        cannot be reflected.
        """
        schema, table = self._split_table_name(table_name)

        try:
            index_data = self.inspector.get_indexes(table, schema=schema)
        except Exception as e:
            print(f"警告:无法获取表 {table_name} 的索引信息: {e}")
            return []

        return [
            IndexInfo(
                name=idx['name'],
                columns=self._index_columns(idx),
                unique=idx['unique'],
                type=self._index_type(idx),
            )
            for idx in index_data
        ]

    def get_foreign_key_info(self, table_name: str) -> List["ForeignKeyInfo"]:
        """Return one ``ForeignKeyInfo`` per foreign key of *table_name*.

        Returns an empty list (after printing a warning) when the foreign
        keys cannot be reflected.
        """
        schema, table = self._split_table_name(table_name)

        try:
            fk_data = self.inspector.get_foreign_keys(table, schema=schema)
        except Exception as e:
            print(f"警告:无法获取表 {table_name} 的外键信息: {e}")
            return []

        foreign_keys = []
        for fk in fk_data:
            options = fk.get('options') or {}
            foreign_keys.append(
                ForeignKeyInfo(
                    name=fk['name'],
                    constrained_columns=fk['constrained_columns'],
                    referred_table=fk['referred_table'],
                    referred_columns=fk['referred_columns'],
                    on_delete=options.get('ondelete'),
                    on_update=options.get('onupdate'),
                )
            )
        return foreign_keys

    def get_table_row_count(self, table_name: str) -> Optional[int]:
        """Return ``SELECT COUNT(*)`` for *table_name*.

        Returns ``None`` (after printing a warning) when the query fails.
        """
        try:
            # Identifiers cannot be bound as parameters, so the name is
            # quoted and escaped before being placed into the statement.
            full_table_name = self._quote_qualified_name(table_name)

            with self.engine.connect() as conn:
                result = conn.execute(text(f"SELECT COUNT(*) FROM {full_table_name}"))
                return result.scalar()
        except Exception as e:
            print(f"警告:无法获取表 {table_name} 的行数: {e}")
            return None

    def get_table_comment(self, table_name: str) -> Optional[str]:
        """Return the table comment read from the PostgreSQL catalog.

        A bare table name is looked up in the ``public`` schema.  Returns
        ``None`` when the table has no comment or (after printing a warning)
        when the query fails.
        """
        try:
            schema, table = self._split_table_name(table_name)
            if schema is None:
                schema = 'public'

            with self.engine.connect() as conn:
                result = conn.execute(text("""
                    SELECT obj_description(c.oid)
                    FROM pg_class c
                    JOIN pg_namespace n ON n.oid = c.relnamespace
                    WHERE c.relname = :table_name AND n.nspname = :schema_name
                """), {"table_name": table, "schema_name": schema})
                return result.scalar()
        except Exception as e:
            print(f"警告:无法获取表 {table_name} 的注释: {e}")
            return None

    def extract_table_info(self, table_name: str) -> "TableInfo":
        """Collect comment, columns, indexes, foreign keys and row count."""
        print(f"正在提取表 {table_name} 的信息...")

        return TableInfo(
            name=table_name,
            comment=self.get_table_comment(table_name),
            columns=self.get_column_info(table_name),
            indexes=self.get_index_info(table_name),
            foreign_keys=self.get_foreign_key_info(table_name),
            row_count=self.get_table_row_count(table_name),
        )

    def extract_all_tables(self) -> List["TableInfo"]:
        """Return a ``TableInfo`` for every table found by ``get_all_tables``.

        A table whose extraction raises is reported and skipped so one bad
        table does not abort the whole run.
        """
        tables = []
        table_names = self.get_all_tables()

        print(f"发现 {len(table_names)} 个表")

        for table_name in table_names:
            try:
                tables.append(self.extract_table_info(table_name))
            except Exception as e:
                print(f"错误:提取表 {table_name} 信息失败: {e}")

        return tables
class SchemaFormatter:
    """Render a list of ``TableInfo`` objects as JSON, Markdown or SQL DDL.

    All public methods are static and return the rendered document as a
    single string; nothing is written to disk here.
    """

    # ------------------------------------------------------------------
    # Text helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _single_line(value) -> str:
        """Return ``str(value)`` with every line break replaced by a space."""
        return str(value).replace("\r\n", " ").replace("\n", " ").replace("\r", " ")

    @staticmethod
    def _md_cell(value) -> str:
        """Make *value* safe inside a Markdown table cell.

        ``None`` and empty values become ``""``; line breaks are flattened
        and ``|`` is escaped so the value cannot split the row.
        """
        if not value:
            return ""
        return SchemaFormatter._single_line(value).replace("|", "\\|")

    @staticmethod
    def _md_anchor(heading: str) -> str:
        """Return a GitHub-style heading anchor for *heading*.

        Lower-cases the text, drops every character that is not alphanumeric,
        ``-``, ``_`` or a space, and turns spaces into hyphens, so a
        schema-qualified name such as ``aigcspace.users`` links correctly.
        """
        kept = "".join(
            ch for ch in heading.lower() if ch.isalnum() or ch in "-_ "
        )
        return kept.replace(" ", "-")

    @staticmethod
    def _join_names(names) -> str:
        """Join column names with ``", "``, skipping ``None`` placeholders."""
        return ", ".join(str(name) for name in names if name is not None)

    # ------------------------------------------------------------------
    # Public formatters
    # ------------------------------------------------------------------

    @staticmethod
    def to_json(tables: "List[TableInfo]", indent: int = 2) -> str:
        """Serialise *tables* to a JSON document.

        The document holds the generation timestamp, a masked
        ``database_url`` placeholder, the table count and one dict per table.
        Non-ASCII text is written as-is.
        """
        data = {
            "generated_at": datetime.now().isoformat(),
            "database_url": "***隐藏***",
            "total_tables": len(tables),
            "tables": [asdict(table) for table in tables],
        }
        return json.dumps(data, ensure_ascii=False, indent=indent)

    @staticmethod
    def to_markdown(tables: "List[TableInfo]") -> str:
        """Render *tables* as a Markdown document.

        The document starts with a header and a linked table of contents,
        followed by one section per table containing its column table and,
        when present, its index and foreign-key tables.
        """
        cell = SchemaFormatter._md_cell
        join_names = SchemaFormatter._join_names

        md_lines = [
            "# 数据库表结构文档",
            "",
            f"**生成时间**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            f"**表总数**: {len(tables)}",
            "",
            "## 目录",
            "",
        ]

        # Table of contents.
        md_lines.extend(
            f"- [{table.name}](#{SchemaFormatter._md_anchor(table.name)})"
            for table in tables
        )
        md_lines.append("")

        for table in tables:
            md_lines.extend([f"## {table.name}", ""])

            if table.comment:
                md_lines.extend([f"**表说明**: {table.comment}", ""])

            if table.row_count is not None:
                md_lines.extend([f"**数据行数**: {table.row_count:,}", ""])

            # Columns.
            md_lines.extend([
                "### 字段信息",
                "",
                "| 字段名 | 类型 | 可空 | 默认值 | 主键 | 外键 | 说明 |",
                "|--------|------|------|--------|------|------|------|",
            ])
            for col in table.columns:
                nullable = "✓" if col.nullable else ""
                pk_mark = "✓" if col.primary_key else ""
                md_lines.append(
                    f"| {cell(col.name)} | {cell(col.type)} | {nullable} "
                    f"| {cell(col.default)} | {pk_mark} "
                    f"| {cell(col.foreign_key)} | {cell(col.comment)} |"
                )
            md_lines.append("")

            # Indexes.
            if table.indexes:
                md_lines.extend([
                    "### 索引信息",
                    "",
                    "| 索引名 | 字段 | 唯一 | 类型 |",
                    "|--------|------|------|------|",
                ])
                for idx in table.indexes:
                    unique_mark = "✓" if idx.unique else ""
                    md_lines.append(
                        f"| {cell(idx.name)} | {cell(join_names(idx.columns))} "
                        f"| {unique_mark} | {cell(idx.type)} |"
                    )
                md_lines.append("")

            # Foreign keys.
            if table.foreign_keys:
                md_lines.extend([
                    "### 外键约束",
                    "",
                    "| 约束名 | 本表字段 | 引用表.字段 | 删除规则 | 更新规则 |",
                    "|--------|----------|-------------|----------|----------|",
                ])
                for fk in table.foreign_keys:
                    constrained_cols = join_names(fk.constrained_columns)
                    referred_cols = join_names(fk.referred_columns)
                    referred_info = f"{fk.referred_table}.{referred_cols}"
                    md_lines.append(
                        f"| {cell(fk.name)} | {cell(constrained_cols)} "
                        f"| {cell(referred_info)} | {cell(fk.on_delete)} "
                        f"| {cell(fk.on_update)} |"
                    )
                md_lines.append("")

            md_lines.extend(["---", ""])

        return "\n".join(md_lines)

    @staticmethod
    def to_sql_ddl(tables: "List[TableInfo]") -> str:
        """Render *tables* as SQL DDL statements.

        For each table this emits a ``CREATE TABLE`` statement (columns plus
        a ``PRIMARY KEY`` clause when there is one), ``CREATE INDEX``
        statements for every index whose name does not end in ``_pkey``, and
        ``ALTER TABLE ... ADD ... FOREIGN KEY`` statements.  Identifiers are
        written unquoted, exactly as stored in the ``TableInfo`` objects.
        """
        single_line = SchemaFormatter._single_line
        join_names = SchemaFormatter._join_names

        sql_lines = [
            "-- 数据库表结构DDL",
            f"-- 生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            f"-- 表总数: {len(tables)}",
            "",
            "-- ==================== 表结构定义 ====================",
        ]

        for table in tables:
            sql_lines.extend(["", f"-- 表: {table.name}"])

            if table.comment:
                # Flattened so a multi-line comment cannot escape the "--".
                sql_lines.append(f"-- 说明: {single_line(table.comment)}")

            if table.row_count is not None:
                sql_lines.append(f"-- 数据行数: {table.row_count:,}")

            sql_lines.append(f"CREATE TABLE {table.name} (")

            # Each entry is (sql_fragment, trailing_comment_or_None).
            definitions = []
            for col in table.columns:
                col_def = f"    {col.name} {col.type}"
                if not col.nullable:
                    col_def += " NOT NULL"
                if col.default:
                    col_def += f" DEFAULT {col.default}"
                comment = single_line(col.comment) if col.comment else None
                definitions.append((col_def, comment))

            pk_columns = [col.name for col in table.columns if col.primary_key]
            if pk_columns:
                definitions.append(
                    (f"    PRIMARY KEY ({', '.join(pk_columns)})", None)
                )

            last = len(definitions) - 1
            for position, (fragment, comment) in enumerate(definitions):
                # The separator must precede the comment; otherwise "--"
                # would comment the comma out and break the statement.
                line = fragment + ("," if position < last else "")
                if comment:
                    line += f" -- {comment}"
                sql_lines.append(line)

            sql_lines.extend([");", ""])

            # Indexes (the primary-key index is implied by PRIMARY KEY).
            for idx in table.indexes:
                if idx.name and idx.name.endswith("_pkey"):
                    continue

                unique_keyword = "UNIQUE " if idx.unique else ""
                name_part = f"{idx.name} " if idx.name else ""
                columns_str = join_names(idx.columns)
                sql_lines.append(
                    f"CREATE {unique_keyword}INDEX {name_part}"
                    f"ON {table.name} ({columns_str});"
                )

            # Foreign keys; an unnamed constraint omits the CONSTRAINT clause.
            for fk in table.foreign_keys:
                constrained_cols = join_names(fk.constrained_columns)
                referred_cols = join_names(fk.referred_columns)
                constraint = f"CONSTRAINT {fk.name} " if fk.name else ""

                fk_sql = (
                    f"ALTER TABLE {table.name} ADD {constraint}"
                    f"FOREIGN KEY ({constrained_cols}) "
                    f"REFERENCES {fk.referred_table} ({referred_cols})"
                )
                if fk.on_delete:
                    fk_sql += f" ON DELETE {fk.on_delete}"
                if fk.on_update:
                    fk_sql += f" ON UPDATE {fk.on_update}"

                sql_lines.append(fk_sql + ";")

            sql_lines.append("")

        return "\n".join(sql_lines)
def _build_parser():
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(description="获取数据库表设计信息")
    parser.add_argument(
        "--format",
        choices=["json", "markdown", "sql"],
        default="json",
        help="输出格式 (默认: json)",
    )
    parser.add_argument(
        "--output",
        type=str,
        help="输出文件路径 (默认: 输出到控制台)",
    )
    parser.add_argument(
        "--tables",
        nargs="*",
        help="指定要提取的表名 (默认: 所有表)",
    )
    parser.add_argument(
        "--database-url",
        type=str,
        help="数据库连接URL (默认: 使用项目配置)",
    )
    return parser


def main():
    """Command-line entry point.

    Parses the arguments, extracts either the tables named with ``--tables``
    or every table, renders them in the requested format and writes the
    result to ``--output`` (creating parent directories) or to the console.
    Returns normally when no table was found; on any unexpected error the
    message is printed and the process exits with status 1.
    """
    args = _build_parser().parse_args()

    # One renderer per value accepted by --format.
    formatters = {
        "json": SchemaFormatter.to_json,
        "markdown": SchemaFormatter.to_markdown,
        "sql": SchemaFormatter.to_sql_ddl,
    }

    try:
        print("正在连接数据库...")
        extractor = DatabaseSchemaExtractor(args.database_url)

        if args.tables:
            print(f"提取指定表: {', '.join(args.tables)}")
            tables = []
            for table_name in args.tables:
                try:
                    tables.append(extractor.extract_table_info(table_name))
                except Exception as e:
                    print(f"错误:表 {table_name} 不存在或提取失败: {e}")
        else:
            print("提取所有表...")
            tables = extractor.extract_all_tables()

        if not tables:
            print("未找到任何表")
            return

        print(f"正在生成 {args.format} 格式...")
        output = formatters[args.format](tables)

        if args.output:
            output_path = Path(args.output)
            output_path.parent.mkdir(parents=True, exist_ok=True)

            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(output)

            print(f"结果已保存到: {output_path}")
            print(f"文件大小: {output_path.stat().st_size:,} 字节")
        else:
            print("\n" + "="*50)
            print(output)

        print(f"\n✅ 成功提取 {len(tables)} 个表的结构信息")

    except Exception as e:
        # Top-level boundary: report the failure and signal it via exit code.
        print(f"❌ 执行失败: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()
|