#!/usr/bin/env python3
"""
get_db_schema.py - database table-design extraction script.

Features:
- Collects the structure of every table in the database
- Exports columns, types, constraints, indexes and related details
- Supports several output formats: JSON, Markdown, SQL DDL
- Generates database documentation

Usage:
    python get_db_schema.py --format json --output schema.json
    python get_db_schema.py --format markdown --output schema.md
    python get_db_schema.py --format sql --output schema.sql
"""
  14. import os
  15. import sys
  16. import json
  17. import argparse
  18. from pathlib import Path
  19. from datetime import datetime
  20. from typing import Dict, List, Any, Optional
  21. from dataclasses import dataclass, asdict
  22. # 添加项目根目录到Python路径
  23. sys.path.append(str(Path(__file__).parent.parent))
  24. from sqlalchemy import (
  25. create_engine, inspect, MetaData, Table, Column,
  26. text, Integer, String, Boolean, DateTime, Text,
  27. Float, Numeric, Date, Time, JSON
  28. )
  29. from sqlalchemy.engine import Engine
  30. from sqlalchemy.orm import sessionmaker
  31. from app.database import DATABASE_URL, engine
  32. from app.models import * # 导入所有模型以确保表被注册
  33. @dataclass
  34. class ColumnInfo:
  35. """列信息数据类"""
  36. name: str
  37. type: str
  38. nullable: bool
  39. default: Optional[str]
  40. primary_key: bool
  41. foreign_key: Optional[str]
  42. comment: Optional[str]
  43. @dataclass
  44. class IndexInfo:
  45. """索引信息数据类"""
  46. name: str
  47. columns: List[str]
  48. unique: bool
  49. type: str
  50. @dataclass
  51. class ForeignKeyInfo:
  52. """外键信息数据类"""
  53. name: str
  54. constrained_columns: List[str]
  55. referred_table: str
  56. referred_columns: List[str]
  57. on_delete: Optional[str]
  58. on_update: Optional[str]
  59. @dataclass
  60. class TableInfo:
  61. """表信息数据类"""
  62. name: str
  63. comment: Optional[str]
  64. columns: List[ColumnInfo]
  65. indexes: List[IndexInfo]
  66. foreign_keys: List[ForeignKeyInfo]
  67. row_count: Optional[int]
  68. class DatabaseSchemaExtractor:
  69. """数据库表结构提取器"""
  70. def __init__(self, database_url: str = None):
  71. """
  72. 初始化数据库连接
  73. Args:
  74. database_url: 数据库连接URL,默认使用项目配置
  75. """
  76. self.database_url = database_url or DATABASE_URL
  77. self.engine = create_engine(self.database_url)
  78. self.inspector = inspect(self.engine)
  79. self.metadata = MetaData()
  80. def get_all_tables(self) -> List[str]:
  81. """获取所有表名"""
  82. # 首先尝试获取public模式的表
  83. public_tables = self.inspector.get_table_names()
  84. # 然后尝试获取aigcspace模式的表
  85. try:
  86. aigcspace_tables = self.inspector.get_table_names(schema='aigcspace')
  87. # 为aigcspace模式的表添加模式前缀
  88. aigcspace_tables = [f"aigcspace.{table}" for table in aigcspace_tables]
  89. except Exception as e:
  90. print(f"警告:无法访问aigcspace模式: {e}")
  91. aigcspace_tables = []
  92. # 合并所有表
  93. all_tables = public_tables + aigcspace_tables
  94. if not all_tables:
  95. # 如果仍然没有表,尝试直接查询系统表
  96. try:
  97. with self.engine.connect() as conn:
  98. result = conn.execute(text("""
  99. SELECT schemaname || '.' || tablename as full_name
  100. FROM pg_tables
  101. WHERE schemaname NOT IN ('information_schema', 'pg_catalog')
  102. ORDER BY schemaname, tablename
  103. """))
  104. all_tables = [row[0] for row in result]
  105. print(f"通过系统表查询发现 {len(all_tables)} 个表")
  106. except Exception as e:
  107. print(f"系统表查询也失败: {e}")
  108. return all_tables
  109. def get_column_info(self, table_name: str) -> List[ColumnInfo]:
  110. """获取表的列信息"""
  111. columns = []
  112. # 解析表名和模式
  113. if '.' in table_name:
  114. schema, table = table_name.split('.', 1)
  115. else:
  116. schema = None
  117. table = table_name
  118. try:
  119. column_data = self.inspector.get_columns(table, schema=schema)
  120. pk_constraint = self.inspector.get_pk_constraint(table, schema=schema)
  121. fk_constraints = self.inspector.get_foreign_keys(table, schema=schema)
  122. except Exception as e:
  123. print(f"警告:无法获取表 {table_name} 的列信息: {e}")
  124. return []
  125. # 构建外键映射
  126. fk_map = {}
  127. for fk in fk_constraints:
  128. for col in fk['constrained_columns']:
  129. fk_map[col] = f"{fk['referred_table']}.{fk['referred_columns'][0]}"
  130. for col in column_data:
  131. column_info = ColumnInfo(
  132. name=col['name'],
  133. type=str(col['type']),
  134. nullable=col['nullable'],
  135. default=str(col['default']) if col['default'] is not None else None,
  136. primary_key=col['name'] in (pk_constraint.get('constrained_columns', []) or []),
  137. foreign_key=fk_map.get(col['name']),
  138. comment=col.get('comment')
  139. )
  140. columns.append(column_info)
  141. return columns
  142. def get_index_info(self, table_name: str) -> List[IndexInfo]:
  143. """获取表的索引信息"""
  144. indexes = []
  145. # 解析表名和模式
  146. if '.' in table_name:
  147. schema, table = table_name.split('.', 1)
  148. else:
  149. schema = None
  150. table = table_name
  151. try:
  152. index_data = self.inspector.get_indexes(table, schema=schema)
  153. except Exception as e:
  154. print(f"警告:无法获取表 {table_name} 的索引信息: {e}")
  155. return []
  156. for idx in index_data:
  157. index_info = IndexInfo(
  158. name=idx['name'],
  159. columns=idx['column_names'],
  160. unique=idx['unique'],
  161. type=idx.get('type', 'btree')
  162. )
  163. indexes.append(index_info)
  164. return indexes
  165. def get_foreign_key_info(self, table_name: str) -> List[ForeignKeyInfo]:
  166. """获取表的外键信息"""
  167. foreign_keys = []
  168. # 解析表名和模式
  169. if '.' in table_name:
  170. schema, table = table_name.split('.', 1)
  171. else:
  172. schema = None
  173. table = table_name
  174. try:
  175. fk_data = self.inspector.get_foreign_keys(table, schema=schema)
  176. except Exception as e:
  177. print(f"警告:无法获取表 {table_name} 的外键信息: {e}")
  178. return []
  179. for fk in fk_data:
  180. fk_info = ForeignKeyInfo(
  181. name=fk['name'],
  182. constrained_columns=fk['constrained_columns'],
  183. referred_table=fk['referred_table'],
  184. referred_columns=fk['referred_columns'],
  185. on_delete=fk.get('options', {}).get('ondelete'),
  186. on_update=fk.get('options', {}).get('onupdate')
  187. )
  188. foreign_keys.append(fk_info)
  189. return foreign_keys
  190. def get_table_row_count(self, table_name: str) -> Optional[int]:
  191. """获取表的行数"""
  192. try:
  193. # 处理带模式的表名
  194. if '.' in table_name:
  195. schema, table = table_name.split('.', 1)
  196. full_table_name = f'"{schema}"."{table}"'
  197. else:
  198. full_table_name = f'"{table_name}"'
  199. with self.engine.connect() as conn:
  200. result = conn.execute(text(f"SELECT COUNT(*) FROM {full_table_name}"))
  201. return result.scalar()
  202. except Exception as e:
  203. print(f"警告:无法获取表 {table_name} 的行数: {e}")
  204. return None
  205. def get_table_comment(self, table_name: str) -> Optional[str]:
  206. """获取表注释"""
  207. try:
  208. # 解析表名和模式
  209. if '.' in table_name:
  210. schema, table = table_name.split('.', 1)
  211. else:
  212. schema = 'public'
  213. table = table_name
  214. with self.engine.connect() as conn:
  215. result = conn.execute(text("""
  216. SELECT obj_description(c.oid)
  217. FROM pg_class c
  218. JOIN pg_namespace n ON n.oid = c.relnamespace
  219. WHERE c.relname = :table_name AND n.nspname = :schema_name
  220. """), {"table_name": table, "schema_name": schema})
  221. comment = result.scalar()
  222. return comment
  223. except Exception as e:
  224. print(f"警告:无法获取表 {table_name} 的注释: {e}")
  225. return None
  226. def extract_table_info(self, table_name: str) -> TableInfo:
  227. """提取单个表的完整信息"""
  228. print(f"正在提取表 {table_name} 的信息...")
  229. return TableInfo(
  230. name=table_name,
  231. comment=self.get_table_comment(table_name),
  232. columns=self.get_column_info(table_name),
  233. indexes=self.get_index_info(table_name),
  234. foreign_keys=self.get_foreign_key_info(table_name),
  235. row_count=self.get_table_row_count(table_name)
  236. )
  237. def extract_all_tables(self) -> List[TableInfo]:
  238. """提取所有表的信息"""
  239. tables = []
  240. table_names = self.get_all_tables()
  241. print(f"发现 {len(table_names)} 个表")
  242. for table_name in table_names:
  243. try:
  244. table_info = self.extract_table_info(table_name)
  245. tables.append(table_info)
  246. except Exception as e:
  247. print(f"错误:提取表 {table_name} 信息失败: {e}")
  248. return tables
  249. class SchemaFormatter:
  250. """数据库结构格式化器"""
  251. @staticmethod
  252. def to_json(tables: List[TableInfo], indent: int = 2) -> str:
  253. """转换为JSON格式"""
  254. data = {
  255. "generated_at": datetime.now().isoformat(),
  256. "database_url": "***隐藏***",
  257. "total_tables": len(tables),
  258. "tables": [asdict(table) for table in tables]
  259. }
  260. return json.dumps(data, ensure_ascii=False, indent=indent)
  261. @staticmethod
  262. def to_markdown(tables: List[TableInfo]) -> str:
  263. """转换为Markdown格式"""
  264. md_lines = [
  265. "# 数据库表结构文档",
  266. "",
  267. f"**生成时间**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
  268. f"**表总数**: {len(tables)}",
  269. "",
  270. "## 目录",
  271. ""
  272. ]
  273. # 生成目录
  274. for table in tables:
  275. md_lines.append(f"- [{table.name}](#{table.name.lower()})")
  276. md_lines.append("")
  277. # 生成每个表的详细信息
  278. for table in tables:
  279. md_lines.extend([
  280. f"## {table.name}",
  281. ""
  282. ])
  283. if table.comment:
  284. md_lines.extend([
  285. f"**表说明**: {table.comment}",
  286. ""
  287. ])
  288. if table.row_count is not None:
  289. md_lines.extend([
  290. f"**数据行数**: {table.row_count:,}",
  291. ""
  292. ])
  293. # 列信息表格
  294. md_lines.extend([
  295. "### 字段信息",
  296. "",
  297. "| 字段名 | 类型 | 可空 | 默认值 | 主键 | 外键 | 说明 |",
  298. "|--------|------|------|--------|------|------|------|"
  299. ])
  300. for col in table.columns:
  301. pk_mark = "✓" if col.primary_key else ""
  302. fk_mark = col.foreign_key or ""
  303. default_val = col.default or ""
  304. nullable = "✓" if col.nullable else ""
  305. comment = col.comment or ""
  306. md_lines.append(
  307. f"| {col.name} | {col.type} | {nullable} | {default_val} | {pk_mark} | {fk_mark} | {comment} |"
  308. )
  309. md_lines.append("")
  310. # 索引信息
  311. if table.indexes:
  312. md_lines.extend([
  313. "### 索引信息",
  314. "",
  315. "| 索引名 | 字段 | 唯一 | 类型 |",
  316. "|--------|------|------|------|"
  317. ])
  318. for idx in table.indexes:
  319. unique_mark = "✓" if idx.unique else ""
  320. columns_str = ", ".join(idx.columns)
  321. md_lines.append(
  322. f"| {idx.name} | {columns_str} | {unique_mark} | {idx.type} |"
  323. )
  324. md_lines.append("")
  325. # 外键信息
  326. if table.foreign_keys:
  327. md_lines.extend([
  328. "### 外键约束",
  329. "",
  330. "| 约束名 | 本表字段 | 引用表.字段 | 删除规则 | 更新规则 |",
  331. "|--------|----------|-------------|----------|----------|"
  332. ])
  333. for fk in table.foreign_keys:
  334. constrained_cols = ", ".join(fk.constrained_columns)
  335. referred_cols = ", ".join(fk.referred_columns)
  336. referred_info = f"{fk.referred_table}.{referred_cols}"
  337. on_delete = fk.on_delete or ""
  338. on_update = fk.on_update or ""
  339. md_lines.append(
  340. f"| {fk.name} | {constrained_cols} | {referred_info} | {on_delete} | {on_update} |"
  341. )
  342. md_lines.append("")
  343. md_lines.append("---")
  344. md_lines.append("")
  345. return "\n".join(md_lines)
  346. @staticmethod
  347. def to_sql_ddl(tables: List[TableInfo]) -> str:
  348. """转换为SQL DDL格式"""
  349. sql_lines = [
  350. "-- 数据库表结构DDL",
  351. f"-- 生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
  352. f"-- 表总数: {len(tables)}",
  353. "",
  354. "-- ==================== 表结构定义 ===================="
  355. ]
  356. for table in tables:
  357. sql_lines.extend([
  358. "",
  359. f"-- 表: {table.name}",
  360. ])
  361. if table.comment:
  362. sql_lines.append(f"-- 说明: {table.comment}")
  363. if table.row_count is not None:
  364. sql_lines.append(f"-- 数据行数: {table.row_count:,}")
  365. sql_lines.extend([
  366. f"CREATE TABLE {table.name} (",
  367. ])
  368. # 生成列定义
  369. column_definitions = []
  370. for col in table.columns:
  371. col_def = f" {col.name} {col.type}"
  372. if not col.nullable:
  373. col_def += " NOT NULL"
  374. if col.default:
  375. col_def += f" DEFAULT {col.default}"
  376. if col.comment:
  377. col_def += f" -- {col.comment}"
  378. column_definitions.append(col_def)
  379. sql_lines.append(",\n".join(column_definitions))
  380. # 添加主键约束
  381. pk_columns = [col.name for col in table.columns if col.primary_key]
  382. if pk_columns:
  383. sql_lines.append(f",\n PRIMARY KEY ({', '.join(pk_columns)})")
  384. sql_lines.extend([
  385. ");",
  386. ""
  387. ])
  388. # 添加索引
  389. for idx in table.indexes:
  390. if idx.name.endswith("_pkey"): # 跳过主键索引
  391. continue
  392. unique_keyword = "UNIQUE " if idx.unique else ""
  393. columns_str = ", ".join(idx.columns)
  394. sql_lines.append(
  395. f"CREATE {unique_keyword}INDEX {idx.name} ON {table.name} ({columns_str});"
  396. )
  397. # 添加外键约束
  398. for fk in table.foreign_keys:
  399. constrained_cols = ", ".join(fk.constrained_columns)
  400. referred_cols = ", ".join(fk.referred_columns)
  401. fk_sql = f"ALTER TABLE {table.name} ADD CONSTRAINT {fk.name} "
  402. fk_sql += f"FOREIGN KEY ({constrained_cols}) "
  403. fk_sql += f"REFERENCES {fk.referred_table} ({referred_cols})"
  404. if fk.on_delete:
  405. fk_sql += f" ON DELETE {fk.on_delete}"
  406. if fk.on_update:
  407. fk_sql += f" ON UPDATE {fk.on_update}"
  408. fk_sql += ";"
  409. sql_lines.append(fk_sql)
  410. sql_lines.append("")
  411. return "\n".join(sql_lines)
  412. def main():
  413. """主函数"""
  414. parser = argparse.ArgumentParser(description="获取数据库表设计信息")
  415. parser.add_argument(
  416. "--format",
  417. choices=["json", "markdown", "sql"],
  418. default="json",
  419. help="输出格式 (默认: json)"
  420. )
  421. parser.add_argument(
  422. "--output",
  423. type=str,
  424. help="输出文件路径 (默认: 输出到控制台)"
  425. )
  426. parser.add_argument(
  427. "--tables",
  428. nargs="*",
  429. help="指定要提取的表名 (默认: 所有表)"
  430. )
  431. parser.add_argument(
  432. "--database-url",
  433. type=str,
  434. help="数据库连接URL (默认: 使用项目配置)"
  435. )
  436. args = parser.parse_args()
  437. try:
  438. # 初始化提取器
  439. print("正在连接数据库...")
  440. extractor = DatabaseSchemaExtractor(args.database_url)
  441. # 提取表信息
  442. if args.tables:
  443. print(f"提取指定表: {', '.join(args.tables)}")
  444. tables = []
  445. for table_name in args.tables:
  446. try:
  447. table_info = extractor.extract_table_info(table_name)
  448. tables.append(table_info)
  449. except Exception as e:
  450. print(f"错误:表 {table_name} 不存在或提取失败: {e}")
  451. else:
  452. print("提取所有表...")
  453. tables = extractor.extract_all_tables()
  454. if not tables:
  455. print("未找到任何表")
  456. return
  457. # 格式化输出
  458. print(f"正在生成 {args.format} 格式...")
  459. if args.format == "json":
  460. output = SchemaFormatter.to_json(tables)
  461. elif args.format == "markdown":
  462. output = SchemaFormatter.to_markdown(tables)
  463. elif args.format == "sql":
  464. output = SchemaFormatter.to_sql_ddl(tables)
  465. # 输出结果
  466. if args.output:
  467. output_path = Path(args.output)
  468. output_path.parent.mkdir(parents=True, exist_ok=True)
  469. with open(output_path, 'w', encoding='utf-8') as f:
  470. f.write(output)
  471. print(f"结果已保存到: {output_path}")
  472. print(f"文件大小: {output_path.stat().st_size:,} 字节")
  473. else:
  474. print("\n" + "="*50)
  475. print(output)
  476. print(f"\n✅ 成功提取 {len(tables)} 个表的结构信息")
  477. except Exception as e:
  478. print(f"❌ 执行失败: {e}")
  479. sys.exit(1)
  480. if __name__ == "__main__":
  481. main()