""" YOLO 版面检测模型测试脚本 测试 RapidLayout 对表格(table)、图片(image)的识别情况, 确认非标准表格是否被误判为 image,以及纯图片的分类标签。 用法: python utils_test/Yolo_Test/test_yolo_layout.py -p python utils_test/Yolo_Test/test_yolo_layout.py -p --save-images python utils_test/Yolo_Test/test_yolo_layout.py -p --pages 0,1,2 依赖: pip install rapid-layout pymupdf numpy Pillow """ import argparse import json import sys from collections import Counter from pathlib import Path from typing import Dict, List, Optional, Tuple import fitz import numpy as np try: from rapid_layout import RapidLayout RAPID_LAYOUT_AVAILABLE = True except ImportError: RAPID_LAYOUT_AVAILABLE = False RapidLayout = None class YoloLayoutTester: """YOLO 版面检测测试器""" def __init__( self, dpi: int = 200, clip_top: float = 60, clip_bottom: float = 60, confidence_threshold: float = 0.3, ): self.dpi = dpi self.clip_top = clip_top self.clip_bottom = clip_bottom self.confidence_threshold = confidence_threshold self._engine: Optional[RapidLayout] = None def _get_engine(self) -> Optional[RapidLayout]: if not RAPID_LAYOUT_AVAILABLE: return None if self._engine is None: self._engine = RapidLayout() return self._engine def analyze_pdf( self, pdf_path: Path, pages: Optional[List[int]] = None, save_images_dir: Optional[Path] = None, ) -> Dict: """分析 PDF 文件的版面检测结果""" if not RAPID_LAYOUT_AVAILABLE: return {"error": "RapidLayout 未安装,请执行: pip install rapid-layout"} engine = self._get_engine() if engine is None: return {"error": "RapidLayout 初始化失败"} doc = fitz.open(str(pdf_path)) try: total_pages = len(doc) target_pages = pages if pages is not None else list(range(total_pages)) all_labels: List[str] = [] page_details: List[Dict] = [] for page_num in target_pages: if page_num >= total_pages: continue page = doc.load_page(page_num) rect = page.rect clip_box = fitz.Rect( 0, self.clip_top, rect.width, rect.height - self.clip_bottom, ) pix = page.get_pixmap(dpi=self.dpi, clip=clip_box) img = np.frombuffer(pix.samples, dtype=np.uint8).reshape( pix.height, pix.width, 3, ) layout_output = engine(img) scale_x = clip_box.width / img.shape[1] scale_y = clip_box.height / img.shape[0] page_regions: List[Dict] = [] page_labels: List[str] = [] if hasattr(layout_output, 'boxes') and hasattr(layout_output, 'class_names'): for box, label, score in zip( layout_output.boxes, layout_output.class_names, layout_output.scores, ): if score < self.confidence_threshold: continue pdf_x1 = clip_box.x0 + box[0] * scale_x pdf_y1 = clip_box.y0 + box[1] * scale_y pdf_x2 = clip_box.x0 + box[2] * scale_x pdf_y2 = clip_box.y0 + box[3] * scale_y width = pdf_x2 - pdf_x1 height = pdf_y2 - pdf_y1 page_regions.append({ "label": label, "score": round(float(score), 4), "bbox": [round(pdf_x1, 1), round(pdf_y1, 1), round(pdf_x2, 1), round(pdf_y2, 1)], "size": [round(width, 1), round(height, 1)], }) page_labels.append(label) all_labels.extend(page_labels) page_details.append({ "page": page_num + 1, "regions": page_regions, "counts": dict(Counter(page_labels)), }) if save_images_dir: self._save_annotated_image( img, layout_output, page_num + 1, scale_x, scale_y, save_images_dir, ) finally: doc.close() label_counter = Counter(all_labels) return { "total_pages": total_pages, "analyzed_pages": len(target_pages), "total_regions": len(all_labels), "label_distribution": dict(label_counter.most_common()), "table_count": label_counter.get("table", 0), "image_count": label_counter.get("image", 0), "figure_count": label_counter.get("figure", 0), "page_details": page_details, } def _save_annotated_image( self, img: np.ndarray, layout_output, page_num: int, scale_x: float, scale_y: float, output_dir: Path, ): """保存带标注框的图片""" try: from PIL import Image, ImageDraw, ImageFont except ImportError: print(" [跳过] Pillow 未安装,无法保存标注图片") return pil_img = Image.fromarray(img) draw = ImageDraw.Draw(pil_img) label_colors = { "table": (0, 255, 0), # 绿色 "figure": (255, 80, 80), # 红色 — 关键:非标表格可能在这 "figure_caption": (255, 165, 0),# 橙色 "table_caption": (200, 200, 0), # 黄绿 "text": (0, 0, 255), # 蓝色 "title": (255, 255, 0), # 黄色 "header": (128, 0, 128), # 紫色 "footer": (128, 128, 0), # 橄榄色 "reference": (0, 128, 128), "equation": (0, 200, 200), } default_color = (200, 200, 200) if hasattr(layout_output, 'boxes') and hasattr(layout_output, 'class_names'): for box, label, score in zip( layout_output.boxes, layout_output.class_names, layout_output.scores, ): if score < self.confidence_threshold: continue x1_img = box[0] / scale_x y1_img = box[1] / scale_y x2_img = box[2] / scale_x y2_img = box[3] / scale_y color = label_colors.get(label, default_color) draw.rectangle([x1_img, y1_img, x2_img, y2_img], outline=color, width=2) draw.text( (x1_img + 2, y1_img + 2), f"{label} ({score:.2f})", fill=color, ) output_path = output_dir / f"page_{page_num:03d}_layout.jpg" pil_img.save(str(output_path), quality=85) print(f" [保存] {output_path}") def print_report(result: Dict): """打印检测报告""" if "error" in result: print(f"[错误] {result['error']}") return print() print("=" * 70) print("YOLO 版面检测报告") print("=" * 70) print(f"总页数: {result['total_pages']}") print(f"分析页数: {result['analyzed_pages']}") print(f"检测区域总数: {result['total_regions']}") print() print("标签分布:") print("-" * 50) for label, count in result["label_distribution"].items(): pct = count / max(result["total_regions"], 1) * 100 bar = "█" * int(pct / 2) print(f" {label:15s}: {count:4d} ({pct:5.1f}%) {bar}") print() # 重点关注 print("关键指标:") print(f" table : {result['table_count']}") print(f" image : {result['image_count']}") print(f" figure : {result['figure_count']}") print() # 逐页详情 print("逐页详情:") print("-" * 50) for page_info in result["page_details"]: page_num = page_info["page"] regions = page_info["regions"] if not regions: continue print(f"\n --- 第 {page_num} 页 ({len(regions)} 个区域) ---") for r in regions: size_str = f"{r['size'][0]}x{r['size'][1]}" print(f" [{r['label']:12s}] score={r['score']:.3f} " f"bbox=({r['bbox'][0]:.0f},{r['bbox'][1]:.0f},{r['bbox'][2]:.0f},{r['bbox'][3]:.0f}) " f"size={size_str}") print() def print_batch_report(batch_results: List[Dict]): """打印批统计报告""" valid = [r for r in batch_results if "error" not in r] errors = [r for r in batch_results if "error" in r] if not valid: print("[错误] 没有成功分析任何 PDF 文件") return print() print("=" * 80) print("YOLO 版面检测 — 批统计报告") print("=" * 80) print(f"分析文件数: {len(batch_results)} (成功 {len(valid)}, 失败 {len(errors)})") # 汇总所有文件的标签计数 all_labels: Counter = Counter() file_summaries: List[Dict] = [] for r in valid: file_labels = r["label_distribution"] all_labels.update(file_labels) total = r["total_regions"] file_summaries.append({ "file": r["file_name"], "pages": r["total_pages"], "regions": total, "table_pct": file_labels.get("table", 0) / max(total, 1) * 100, "figure_pct": file_labels.get("figure", 0) / max(total, 1) * 100, "table_count": file_labels.get("table", 0), "figure_count": file_labels.get("figure", 0), }) total_regions = sum(s["regions"] for s in file_summaries) total_pages = sum(s["pages"] for s in file_summaries) print(f"总页数: {total_pages}") print(f"总区域数: {total_regions}") print() # 全局标签分布 print("全局标签分布:") print("-" * 55) for label, count in all_labels.most_common(): pct = count / max(total_regions, 1) * 100 bar = "█" * int(pct) print(f" {label:15s}: {count:5d} ({pct:5.1f}%) {bar}") print() # 逐文件摘要 print("逐文件摘要:") print("-" * 80) print(f" {'文件':40s} {'页':>4s} {'区域':>5s} {'table%':>7s} {'figure%':>7s} {'table':>6s} {'figure':>6s}") print(" " + "-" * 76) for s in file_summaries: name = s["file"][:38] + ".." if len(s["file"]) > 40 else s["file"] print(f" {name:40s} {s['pages']:4d} {s['regions']:5d} " f"{s['table_pct']:6.1f}% {s['figure_pct']:6.1f}% " f"{s['table_count']:5d} {s['figure_count']:5d}") # 平均统计 avg_table_pct = sum(s["table_pct"] for s in file_summaries) / len(file_summaries) avg_figure_pct = sum(s["figure_pct"] for s in file_summaries) / len(file_summaries) avg_regions_per_page = total_regions / max(total_pages, 1) avg_table_per_page = sum(s["table_count"] for s in file_summaries) / max(total_pages, 1) avg_figure_per_page = sum(s["figure_count"] for s in file_summaries) / max(total_pages, 1) print() print("平均统计 (按页):") print("-" * 40) print(f" 平均区域/页: {avg_regions_per_page:.1f}") print(f" 平均 table/页: {avg_table_per_page:.2f}") print(f" 平均 figure/页: {avg_figure_per_page:.2f}") print(f" 平均 table 占比: {avg_table_pct:.1f}%") print(f" 平均 figure 占比:{avg_figure_pct:.1f}%") print(f" table+figure/页: {avg_table_per_page + avg_figure_per_page:.2f}") if errors: print() print(f"失败文件 ({len(errors)}):") for e in errors: print(f" - {e['file_name']}: {e['error']}") print() def main(): parser = argparse.ArgumentParser(description="YOLO 版面检测模型测试") parser.add_argument("-p", "--pdf", default=None, help="单个 PDF 文件路径") parser.add_argument("-d", "--dir", default=None, help="批量: 扫描目录下所有 PDF 文件") parser.add_argument("--pages", default=None, help="分析指定页码, 逗号分隔, 如 0,1,2 (0-based)") parser.add_argument("--save-images", action="store_true", help="保存标注图片 (批模式不生效)") parser.add_argument("--output-dir", default=None, help="输出目录 (默认与 PDF 同目录)") parser.add_argument("--dpi", type=int, default=200, help="渲染 DPI (默认 200)") parser.add_argument("--confidence", type=float, default=0.3, help="置信度阈值 (默认 0.3)") parser.add_argument("--clip-top", type=float, default=60, help="顶部裁剪 (默认 60)") parser.add_argument("--clip-bottom", type=float, default=60, help="底部裁剪 (默认 60)") parser.add_argument("--json", action="store_true", help="输出 JSON 格式 (批模式输出每个文件的关键统计)") args = parser.parse_args() if not args.pdf and not args.dir: print("[错误] 请指定 -p 或 -d ") return 1 tester = YoloLayoutTester( dpi=args.dpi, clip_top=args.clip_top, clip_bottom=args.clip_bottom, confidence_threshold=args.confidence, ) pages = None if args.pages: pages = [int(p.strip()) for p in args.pages.split(",")] # ---- 单文件模式 ---- if args.pdf: pdf_path = Path(args.pdf) if not pdf_path.exists(): print(f"[错误] PDF 不存在: {pdf_path}") return 1 output_dir = None if args.save_images: output_dir = Path(args.output_dir) if args.output_dir else pdf_path.parent / "yolo_layout_output" output_dir.mkdir(parents=True, exist_ok=True) print(f"[分析] {pdf_path}") result = tester.analyze_pdf(pdf_path, pages=pages, save_images_dir=output_dir) result["file_name"] = pdf_path.name if args.json: print(json.dumps(result, ensure_ascii=False, indent=2)) else: print_report(result) return 0 # ---- 批模式 ---- dir_path = Path(args.dir) if not dir_path.is_dir(): print(f"[错误] 目录不存在: {dir_path}") return 1 pdf_files = sorted(dir_path.glob("*.pdf")) if not pdf_files: print(f"[错误] 目录下无 PDF 文件: {dir_path}") return 1 print(f"[批分析] 找到 {len(pdf_files)} 个 PDF 文件") print(f"[批分析] 目录: {dir_path}") print() batch_results: List[Dict] = [] for idx, pdf_path in enumerate(pdf_files, 1): print(f"[{idx}/{len(pdf_files)}] {pdf_path.name} ...", end=" ", flush=True) try: result = tester.analyze_pdf(pdf_path, pages=pages) result["file_name"] = pdf_path.name batch_results.append(result) regions = result["total_regions"] t = result.get("table_count", 0) f = result.get("figure_count", 0) print(f"OK ({result['total_pages']}页, {regions}区域, table={t}, figure={f})") except Exception as e: print(f"失败: {e}") batch_results.append({"file_name": pdf_path.name, "error": str(e)}) if args.json: summary = [] for r in batch_results: if "error" in r: summary.append(r) else: summary.append({ "file": r["file_name"], "pages": r["total_pages"], "regions": r["total_regions"], "label_distribution": r["label_distribution"], "table_count": r.get("table_count", 0), "figure_count": r.get("figure_count", 0), }) print(json.dumps(summary, ensure_ascii=False, indent=2)) else: print_batch_report(batch_results) return 0 if __name__ == "__main__": sys.exit(main())