run.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
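r"""
Entry point for the standalone minimal document-processing pipeline.

Usage:
    python run.py -p <pdf_path> [-o <output_dir>] [--skip-tertiary]
                  [--api-key <key>] [--base-url <url>] [--model <name>]
                  [--csv <StandardCategoryTable.csv>]

Relative -p/-o paths are resolved against the project root, because the
script changes into that directory at startup.

Example:
    python utils_test/minimal_pipeline/run.py \
        -p "D:/wx_work/sichuan_luqiao/lu_sgsc_testfile/测试模版.pdf" \
        -o ./output \
        --skip-tertiary
"""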
  13. import argparse
  14. import asyncio
  15. import json
  16. import os
  17. import sys
  18. import time
  19. from pathlib import Path
  20. PROJECT_ROOT = Path(__file__).parent.parent.parent
  21. os.chdir(PROJECT_ROOT)
  22. from utils_test.minimal_pipeline import MinimalPipeline
  23. from utils_test.minimal_pipeline.models import PipelineResult
  24. def parse_args():
  25. parser = argparse.ArgumentParser(description="独立最小化文档处理管线")
  26. parser.add_argument("-p", "--pdf", required=True, help="PDF 文件路径")
  27. parser.add_argument("-o", "--output", default="./output", help="输出目录(默认 ./output)")
  28. parser.add_argument("--skip-tertiary", action="store_true", help="跳过三级分类(节省 LLM 调用)")
  29. parser.add_argument("--api-key", default=os.environ.get("DASHSCOPE_API_KEY", ""), help="API Key(默认从环境变量 DASHSCOPE_API_KEY 读取)")
  30. parser.add_argument("--base-url", default="https://dashscope.aliyuncs.com/compatible-mode/v1", help="API Base URL")
  31. parser.add_argument("--model", default="qwen3.5-122b-a10b", help="模型名称")
  32. parser.add_argument("--csv", default=None, help="StandardCategoryTable.csv 路径(默认自动查找)")
  33. return parser.parse_args()
  34. def print_progress(stage: str, percent: int, message: str):
  35. """进度回调"""
  36. bar_len = 30
  37. filled = int(bar_len * percent / 100)
  38. bar = "█" * filled + "░" * (bar_len - filled)
  39. print(f"\r[{bar}] {percent:3d}% | {stage:10s} | {message}", end="", flush=True)
  40. if percent >= 100:
  41. print()
  42. def print_result(result: PipelineResult, elapsed: float):
  43. """打印结果摘要"""
  44. print("\n" + "=" * 80)
  45. print("处理结果摘要")
  46. print("=" * 80)
  47. print(f"文档名称: {result.document_name}")
  48. print(f"总页数: {result.total_pages}")
  49. print(f"处理耗时: {elapsed:.2f} 秒")
  50. print(f"\n一级章节数: {len(result.primary_items)}")
  51. for item in result.primary_items:
  52. print(f" [{item.category_code:15s}] {item.title}")
  53. print(f"\nChunks 数: {len(result.chunks)}")
  54. for chunk in result.chunks[:5]:
  55. print(f" {chunk.chunk_id} | {chunk.section_label} | "
  56. f"一级={chunk.first_name} 二级={chunk.secondary_category_cn} "
  57. f"三级={chunk.tertiary_category_cn}")
  58. if len(result.chunks) > 5:
  59. print(f" ... 共 {len(result.chunks)} 个 chunks")
  60. print(f"\n质量检查:")
  61. qc = result.quality_check
  62. l1 = qc.get("l1_chapter_quality", {})
  63. l2 = qc.get("l2_subsection_quality", {})
  64. print(f" 一级提取率: {l1.get('extraction_rate', 0):.1f}% ({l1.get('extracted_count', 0)}/{l1.get('expected_count', 0)})")
  65. print(f" 二级提取率: {l2.get('extraction_rate', 0):.1f}% ({l2.get('extracted_count', 0)}/{l2.get('expected_count', 0)})")
  66. print(f"\n分类统计:")
  67. for level, stats in result.stats.items():
  68. if isinstance(stats, dict) and stats:
  69. print(f" {level}:")
  70. for cat, count in stats.items():
  71. print(f" {cat}: {count}")
  72. print("=" * 80)
  73. def main():
  74. args = parse_args()
  75. pdf_path = Path(args.pdf)
  76. if not pdf_path.exists():
  77. print(f"[错误] PDF 文件不存在: {pdf_path}")
  78. return 1
  79. if not args.api_key:
  80. print("[错误] 未提供 API Key。请通过 --api-key 参数或 DASHSCOPE_API_KEY 环境变量设置。")
  81. return 1
  82. output_dir = Path(args.output)
  83. output_dir.mkdir(parents=True, exist_ok=True)
  84. print(f"[信息] 处理文档: {pdf_path}")
  85. print(f"[信息] 输出目录: {output_dir}")
  86. print(f"[信息] 模型: {args.model}")
  87. print(f"[信息] 跳过三级分类: {args.skip_tertiary}")
  88. print()
  89. # 读取 PDF
  90. with open(pdf_path, "rb") as f:
  91. file_content = f.read()
  92. # 初始化管线
  93. pipeline = MinimalPipeline(
  94. api_key=args.api_key,
  95. base_url=args.base_url,
  96. model=args.model,
  97. concurrency=10,
  98. csv_path=args.csv,
  99. )
  100. # 运行管线
  101. start_time = time.time()
  102. try:
  103. result = asyncio.run(pipeline.process(
  104. file_content=file_content,
  105. file_name=pdf_path.name,
  106. skip_tertiary=args.skip_tertiary,
  107. progress_callback=print_progress,
  108. ))
  109. except Exception as e:
  110. print(f"\n[错误] 处理失败: {e}")
  111. import traceback
  112. traceback.print_exc()
  113. return 1
  114. elapsed = time.time() - start_time
  115. # 打印结果
  116. print_result(result, elapsed)
  117. # 保存结果
  118. output_file = output_dir / f"{pdf_path.stem}_result.json"
  119. with open(output_file, "w", encoding="utf-8") as f:
  120. json.dump(result.to_dict(), f, ensure_ascii=False, indent=2)
  121. print(f"[信息] 结果已保存到: {output_file}")
  122. # 保存 chunks 明细
  123. chunks_file = output_dir / f"{pdf_path.stem}_chunks.jsonl"
  124. with open(chunks_file, "w", encoding="utf-8") as f:
  125. for chunk in result.chunks:
  126. f.write(json.dumps({
  127. "chunk_id": chunk.chunk_id,
  128. "section_label": chunk.section_label,
  129. "chapter_classification": chunk.chapter_classification,
  130. "first_name": chunk.first_name,
  131. "secondary_category_code": chunk.secondary_category_code,
  132. "secondary_category_cn": chunk.secondary_category_cn,
  133. "tertiary_category_code": chunk.tertiary_category_code,
  134. "tertiary_category_cn": chunk.tertiary_category_cn,
  135. "page_start": chunk.page_start,
  136. "page_end": chunk.page_end,
  137. "content_preview": chunk.review_chunk_content[:200] + "...",
  138. }, ensure_ascii=False) + "\n")
  139. print(f"[信息] Chunks 明细已保存到: {chunks_file}")
  140. return 0
  141. if __name__ == "__main__":
  142. sys.exit(main())