test_yolo_layout.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451
  1. """
  2. YOLO 版面检测模型测试脚本
  3. 测试 RapidLayout 对表格(table)、图片(image)的识别情况,
  4. 确认非标准表格是否被误判为 image,以及纯图片的分类标签。
  5. 用法:
  6. python utils_test/Yolo_Test/test_yolo_layout.py -p <pdf_path>
  7. python utils_test/Yolo_Test/test_yolo_layout.py -p <pdf_path> --save-images
  8. python utils_test/Yolo_Test/test_yolo_layout.py -p <pdf_path> --pages 0,1,2
  9. 依赖:
  10. pip install rapid-layout pymupdf numpy Pillow
  11. """
  12. import argparse
  13. import json
  14. import sys
  15. from collections import Counter
  16. from pathlib import Path
  17. from typing import Dict, List, Optional, Tuple
  18. import fitz
  19. import numpy as np
  20. try:
  21. from rapid_layout import RapidLayout
  22. RAPID_LAYOUT_AVAILABLE = True
  23. except ImportError:
  24. RAPID_LAYOUT_AVAILABLE = False
  25. RapidLayout = None
  26. class YoloLayoutTester:
  27. """YOLO 版面检测测试器"""
  28. def __init__(
  29. self,
  30. dpi: int = 200,
  31. clip_top: float = 60,
  32. clip_bottom: float = 60,
  33. confidence_threshold: float = 0.3,
  34. ):
  35. self.dpi = dpi
  36. self.clip_top = clip_top
  37. self.clip_bottom = clip_bottom
  38. self.confidence_threshold = confidence_threshold
  39. self._engine: Optional[RapidLayout] = None
  40. def _get_engine(self) -> Optional[RapidLayout]:
  41. if not RAPID_LAYOUT_AVAILABLE:
  42. return None
  43. if self._engine is None:
  44. self._engine = RapidLayout()
  45. return self._engine
  46. def analyze_pdf(
  47. self,
  48. pdf_path: Path,
  49. pages: Optional[List[int]] = None,
  50. save_images_dir: Optional[Path] = None,
  51. ) -> Dict:
  52. """分析 PDF 文件的版面检测结果"""
  53. if not RAPID_LAYOUT_AVAILABLE:
  54. return {"error": "RapidLayout 未安装,请执行: pip install rapid-layout"}
  55. engine = self._get_engine()
  56. if engine is None:
  57. return {"error": "RapidLayout 初始化失败"}
  58. doc = fitz.open(str(pdf_path))
  59. try:
  60. total_pages = len(doc)
  61. target_pages = pages if pages is not None else list(range(total_pages))
  62. all_labels: List[str] = []
  63. page_details: List[Dict] = []
  64. for page_num in target_pages:
  65. if page_num >= total_pages:
  66. continue
  67. page = doc.load_page(page_num)
  68. rect = page.rect
  69. clip_box = fitz.Rect(
  70. 0, self.clip_top,
  71. rect.width, rect.height - self.clip_bottom,
  72. )
  73. pix = page.get_pixmap(dpi=self.dpi, clip=clip_box)
  74. img = np.frombuffer(pix.samples, dtype=np.uint8).reshape(
  75. pix.height, pix.width, 3,
  76. )
  77. layout_output = engine(img)
  78. scale_x = clip_box.width / img.shape[1]
  79. scale_y = clip_box.height / img.shape[0]
  80. page_regions: List[Dict] = []
  81. page_labels: List[str] = []
  82. if hasattr(layout_output, 'boxes') and hasattr(layout_output, 'class_names'):
  83. for box, label, score in zip(
  84. layout_output.boxes,
  85. layout_output.class_names,
  86. layout_output.scores,
  87. ):
  88. if score < self.confidence_threshold:
  89. continue
  90. pdf_x1 = clip_box.x0 + box[0] * scale_x
  91. pdf_y1 = clip_box.y0 + box[1] * scale_y
  92. pdf_x2 = clip_box.x0 + box[2] * scale_x
  93. pdf_y2 = clip_box.y0 + box[3] * scale_y
  94. width = pdf_x2 - pdf_x1
  95. height = pdf_y2 - pdf_y1
  96. page_regions.append({
  97. "label": label,
  98. "score": round(float(score), 4),
  99. "bbox": [round(pdf_x1, 1), round(pdf_y1, 1),
  100. round(pdf_x2, 1), round(pdf_y2, 1)],
  101. "size": [round(width, 1), round(height, 1)],
  102. })
  103. page_labels.append(label)
  104. all_labels.extend(page_labels)
  105. page_details.append({
  106. "page": page_num + 1,
  107. "regions": page_regions,
  108. "counts": dict(Counter(page_labels)),
  109. })
  110. if save_images_dir:
  111. self._save_annotated_image(
  112. img, layout_output, page_num + 1,
  113. scale_x, scale_y, save_images_dir,
  114. )
  115. finally:
  116. doc.close()
  117. label_counter = Counter(all_labels)
  118. return {
  119. "total_pages": total_pages,
  120. "analyzed_pages": len(target_pages),
  121. "total_regions": len(all_labels),
  122. "label_distribution": dict(label_counter.most_common()),
  123. "table_count": label_counter.get("table", 0),
  124. "image_count": label_counter.get("image", 0),
  125. "figure_count": label_counter.get("figure", 0),
  126. "page_details": page_details,
  127. }
  128. def _save_annotated_image(
  129. self,
  130. img: np.ndarray,
  131. layout_output,
  132. page_num: int,
  133. scale_x: float,
  134. scale_y: float,
  135. output_dir: Path,
  136. ):
  137. """保存带标注框的图片"""
  138. try:
  139. from PIL import Image, ImageDraw, ImageFont
  140. except ImportError:
  141. print(" [跳过] Pillow 未安装,无法保存标注图片")
  142. return
  143. pil_img = Image.fromarray(img)
  144. draw = ImageDraw.Draw(pil_img)
  145. label_colors = {
  146. "table": (0, 255, 0), # 绿色
  147. "figure": (255, 80, 80), # 红色 — 关键:非标表格可能在这
  148. "figure_caption": (255, 165, 0),# 橙色
  149. "table_caption": (200, 200, 0), # 黄绿
  150. "text": (0, 0, 255), # 蓝色
  151. "title": (255, 255, 0), # 黄色
  152. "header": (128, 0, 128), # 紫色
  153. "footer": (128, 128, 0), # 橄榄色
  154. "reference": (0, 128, 128),
  155. "equation": (0, 200, 200),
  156. }
  157. default_color = (200, 200, 200)
  158. if hasattr(layout_output, 'boxes') and hasattr(layout_output, 'class_names'):
  159. for box, label, score in zip(
  160. layout_output.boxes,
  161. layout_output.class_names,
  162. layout_output.scores,
  163. ):
  164. if score < self.confidence_threshold:
  165. continue
  166. x1_img = box[0] / scale_x
  167. y1_img = box[1] / scale_y
  168. x2_img = box[2] / scale_x
  169. y2_img = box[3] / scale_y
  170. color = label_colors.get(label, default_color)
  171. draw.rectangle([x1_img, y1_img, x2_img, y2_img], outline=color, width=2)
  172. draw.text(
  173. (x1_img + 2, y1_img + 2),
  174. f"{label} ({score:.2f})",
  175. fill=color,
  176. )
  177. output_path = output_dir / f"page_{page_num:03d}_layout.jpg"
  178. pil_img.save(str(output_path), quality=85)
  179. print(f" [保存] {output_path}")
  180. def print_report(result: Dict):
  181. """打印检测报告"""
  182. if "error" in result:
  183. print(f"[错误] {result['error']}")
  184. return
  185. print()
  186. print("=" * 70)
  187. print("YOLO 版面检测报告")
  188. print("=" * 70)
  189. print(f"总页数: {result['total_pages']}")
  190. print(f"分析页数: {result['analyzed_pages']}")
  191. print(f"检测区域总数: {result['total_regions']}")
  192. print()
  193. print("标签分布:")
  194. print("-" * 50)
  195. for label, count in result["label_distribution"].items():
  196. pct = count / max(result["total_regions"], 1) * 100
  197. bar = "█" * int(pct / 2)
  198. print(f" {label:15s}: {count:4d} ({pct:5.1f}%) {bar}")
  199. print()
  200. # 重点关注
  201. print("关键指标:")
  202. print(f" table : {result['table_count']}")
  203. print(f" image : {result['image_count']}")
  204. print(f" figure : {result['figure_count']}")
  205. print()
  206. # 逐页详情
  207. print("逐页详情:")
  208. print("-" * 50)
  209. for page_info in result["page_details"]:
  210. page_num = page_info["page"]
  211. regions = page_info["regions"]
  212. if not regions:
  213. continue
  214. print(f"\n --- 第 {page_num} 页 ({len(regions)} 个区域) ---")
  215. for r in regions:
  216. size_str = f"{r['size'][0]}x{r['size'][1]}"
  217. print(f" [{r['label']:12s}] score={r['score']:.3f} "
  218. f"bbox=({r['bbox'][0]:.0f},{r['bbox'][1]:.0f},{r['bbox'][2]:.0f},{r['bbox'][3]:.0f}) "
  219. f"size={size_str}")
  220. print()
  221. def print_batch_report(batch_results: List[Dict]):
  222. """打印批统计报告"""
  223. valid = [r for r in batch_results if "error" not in r]
  224. errors = [r for r in batch_results if "error" in r]
  225. if not valid:
  226. print("[错误] 没有成功分析任何 PDF 文件")
  227. return
  228. print()
  229. print("=" * 80)
  230. print("YOLO 版面检测 — 批统计报告")
  231. print("=" * 80)
  232. print(f"分析文件数: {len(batch_results)} (成功 {len(valid)}, 失败 {len(errors)})")
  233. # 汇总所有文件的标签计数
  234. all_labels: Counter = Counter()
  235. file_summaries: List[Dict] = []
  236. for r in valid:
  237. file_labels = r["label_distribution"]
  238. all_labels.update(file_labels)
  239. total = r["total_regions"]
  240. file_summaries.append({
  241. "file": r["file_name"],
  242. "pages": r["total_pages"],
  243. "regions": total,
  244. "table_pct": file_labels.get("table", 0) / max(total, 1) * 100,
  245. "figure_pct": file_labels.get("figure", 0) / max(total, 1) * 100,
  246. "table_count": file_labels.get("table", 0),
  247. "figure_count": file_labels.get("figure", 0),
  248. })
  249. total_regions = sum(s["regions"] for s in file_summaries)
  250. total_pages = sum(s["pages"] for s in file_summaries)
  251. print(f"总页数: {total_pages}")
  252. print(f"总区域数: {total_regions}")
  253. print()
  254. # 全局标签分布
  255. print("全局标签分布:")
  256. print("-" * 55)
  257. for label, count in all_labels.most_common():
  258. pct = count / max(total_regions, 1) * 100
  259. bar = "█" * int(pct)
  260. print(f" {label:15s}: {count:5d} ({pct:5.1f}%) {bar}")
  261. print()
  262. # 逐文件摘要
  263. print("逐文件摘要:")
  264. print("-" * 80)
  265. print(f" {'文件':40s} {'页':>4s} {'区域':>5s} {'table%':>7s} {'figure%':>7s} {'table':>6s} {'figure':>6s}")
  266. print(" " + "-" * 76)
  267. for s in file_summaries:
  268. name = s["file"][:38] + ".." if len(s["file"]) > 40 else s["file"]
  269. print(f" {name:40s} {s['pages']:4d} {s['regions']:5d} "
  270. f"{s['table_pct']:6.1f}% {s['figure_pct']:6.1f}% "
  271. f"{s['table_count']:5d} {s['figure_count']:5d}")
  272. # 平均统计
  273. avg_table_pct = sum(s["table_pct"] for s in file_summaries) / len(file_summaries)
  274. avg_figure_pct = sum(s["figure_pct"] for s in file_summaries) / len(file_summaries)
  275. avg_regions_per_page = total_regions / max(total_pages, 1)
  276. avg_table_per_page = sum(s["table_count"] for s in file_summaries) / max(total_pages, 1)
  277. avg_figure_per_page = sum(s["figure_count"] for s in file_summaries) / max(total_pages, 1)
  278. print()
  279. print("平均统计 (按页):")
  280. print("-" * 40)
  281. print(f" 平均区域/页: {avg_regions_per_page:.1f}")
  282. print(f" 平均 table/页: {avg_table_per_page:.2f}")
  283. print(f" 平均 figure/页: {avg_figure_per_page:.2f}")
  284. print(f" 平均 table 占比: {avg_table_pct:.1f}%")
  285. print(f" 平均 figure 占比:{avg_figure_pct:.1f}%")
  286. print(f" table+figure/页: {avg_table_per_page + avg_figure_per_page:.2f}")
  287. if errors:
  288. print()
  289. print(f"失败文件 ({len(errors)}):")
  290. for e in errors:
  291. print(f" - {e['file_name']}: {e['error']}")
  292. print()
  293. def main():
  294. parser = argparse.ArgumentParser(description="YOLO 版面检测模型测试")
  295. parser.add_argument("-p", "--pdf", default=None, help="单个 PDF 文件路径")
  296. parser.add_argument("-d", "--dir", default=None, help="批量: 扫描目录下所有 PDF 文件")
  297. parser.add_argument("--pages", default=None, help="分析指定页码, 逗号分隔, 如 0,1,2 (0-based)")
  298. parser.add_argument("--save-images", action="store_true", help="保存标注图片 (批模式不生效)")
  299. parser.add_argument("--output-dir", default=None, help="输出目录 (默认与 PDF 同目录)")
  300. parser.add_argument("--dpi", type=int, default=200, help="渲染 DPI (默认 200)")
  301. parser.add_argument("--confidence", type=float, default=0.3, help="置信度阈值 (默认 0.3)")
  302. parser.add_argument("--clip-top", type=float, default=60, help="顶部裁剪 (默认 60)")
  303. parser.add_argument("--clip-bottom", type=float, default=60, help="底部裁剪 (默认 60)")
  304. parser.add_argument("--json", action="store_true", help="输出 JSON 格式 (批模式输出每个文件的关键统计)")
  305. args = parser.parse_args()
  306. if not args.pdf and not args.dir:
  307. print("[错误] 请指定 -p <pdf文件> 或 -d <pdf目录>")
  308. return 1
  309. tester = YoloLayoutTester(
  310. dpi=args.dpi,
  311. clip_top=args.clip_top,
  312. clip_bottom=args.clip_bottom,
  313. confidence_threshold=args.confidence,
  314. )
  315. pages = None
  316. if args.pages:
  317. pages = [int(p.strip()) for p in args.pages.split(",")]
  318. # ---- 单文件模式 ----
  319. if args.pdf:
  320. pdf_path = Path(args.pdf)
  321. if not pdf_path.exists():
  322. print(f"[错误] PDF 不存在: {pdf_path}")
  323. return 1
  324. output_dir = None
  325. if args.save_images:
  326. output_dir = Path(args.output_dir) if args.output_dir else pdf_path.parent / "yolo_layout_output"
  327. output_dir.mkdir(parents=True, exist_ok=True)
  328. print(f"[分析] {pdf_path}")
  329. result = tester.analyze_pdf(pdf_path, pages=pages, save_images_dir=output_dir)
  330. result["file_name"] = pdf_path.name
  331. if args.json:
  332. print(json.dumps(result, ensure_ascii=False, indent=2))
  333. else:
  334. print_report(result)
  335. return 0
  336. # ---- 批模式 ----
  337. dir_path = Path(args.dir)
  338. if not dir_path.is_dir():
  339. print(f"[错误] 目录不存在: {dir_path}")
  340. return 1
  341. pdf_files = sorted(dir_path.glob("*.pdf"))
  342. if not pdf_files:
  343. print(f"[错误] 目录下无 PDF 文件: {dir_path}")
  344. return 1
  345. print(f"[批分析] 找到 {len(pdf_files)} 个 PDF 文件")
  346. print(f"[批分析] 目录: {dir_path}")
  347. print()
  348. batch_results: List[Dict] = []
  349. for idx, pdf_path in enumerate(pdf_files, 1):
  350. print(f"[{idx}/{len(pdf_files)}] {pdf_path.name} ...", end=" ", flush=True)
  351. try:
  352. result = tester.analyze_pdf(pdf_path, pages=pages)
  353. result["file_name"] = pdf_path.name
  354. batch_results.append(result)
  355. regions = result["total_regions"]
  356. t = result.get("table_count", 0)
  357. f = result.get("figure_count", 0)
  358. print(f"OK ({result['total_pages']}页, {regions}区域, table={t}, figure={f})")
  359. except Exception as e:
  360. print(f"失败: {e}")
  361. batch_results.append({"file_name": pdf_path.name, "error": str(e)})
  362. if args.json:
  363. summary = []
  364. for r in batch_results:
  365. if "error" in r:
  366. summary.append(r)
  367. else:
  368. summary.append({
  369. "file": r["file_name"],
  370. "pages": r["total_pages"],
  371. "regions": r["total_regions"],
  372. "label_distribution": r["label_distribution"],
  373. "table_count": r.get("table_count", 0),
  374. "figure_count": r.get("figure_count", 0),
  375. })
  376. print(json.dumps(summary, ensure_ascii=False, indent=2))
  377. else:
  378. print_batch_report(batch_results)
  379. return 0
  380. if __name__ == "__main__":
  381. sys.exit(main())