| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451 |
"""
YOLO layout-detection model test script.

Checks how RapidLayout classifies tables (table) and pictures (image):
whether non-standard tables are misclassified as image, and which label
pure pictures receive.

Usage:
    python utils_test/Yolo_Test/test_yolo_layout.py -p <pdf_path>
    python utils_test/Yolo_Test/test_yolo_layout.py -p <pdf_path> --save-images
    python utils_test/Yolo_Test/test_yolo_layout.py -p <pdf_path> --pages 0,1,2
    python utils_test/Yolo_Test/test_yolo_layout.py -d <pdf_dir>

Dependencies:
    pip install rapid-layout pymupdf numpy Pillow
"""
- import argparse
- import json
- import sys
- from collections import Counter
- from pathlib import Path
- from typing import Dict, List, Optional, Tuple
- import fitz
- import numpy as np
# RapidLayout is treated as an optional dependency: the import outcome is
# recorded in RAPID_LAYOUT_AVAILABLE, and the name RapidLayout is bound to None
# when the package is missing so later references to it still resolve.
try:
    from rapid_layout import RapidLayout
    RAPID_LAYOUT_AVAILABLE = True
except ImportError:
    RAPID_LAYOUT_AVAILABLE = False
    RapidLayout = None
class YoloLayoutTester:
    """Run RapidLayout over rendered PDF pages and summarise the detections.

    Each page is cropped by ``clip_top`` / ``clip_bottom`` (page units),
    rendered with PyMuPDF at ``dpi``, and handed to the layout engine.
    Detections scoring below ``confidence_threshold`` are discarded; the rest
    are reported with their boxes mapped back into page coordinates.
    """

    def __init__(
        self,
        dpi: int = 200,
        clip_top: float = 60,
        clip_bottom: float = 60,
        confidence_threshold: float = 0.3,
    ):
        """Store rendering/filtering options; the engine is created lazily.

        Args:
            dpi: Resolution used when rendering a page to an image.
            clip_top: Height removed from the top of every page before detection.
            clip_bottom: Height removed from the bottom of every page.
            confidence_threshold: Detections scoring below this are ignored.
        """
        self.dpi = dpi
        self.clip_top = clip_top
        self.clip_bottom = clip_bottom
        self.confidence_threshold = confidence_threshold
        self._engine: Optional["RapidLayout"] = None

    # The annotation is quoted so defining the class never needs the
    # RapidLayout name to be resolvable.
    def _get_engine(self) -> Optional["RapidLayout"]:
        """Return the shared RapidLayout instance, or None if it is unavailable."""
        if not RAPID_LAYOUT_AVAILABLE:
            return None
        if self._engine is None:
            self._engine = RapidLayout()
        return self._engine

    def _iter_detections(self, layout_output):
        """Yield ``(box, label, score)`` for detections at/above the threshold.

        Yields nothing when the engine output lacks ``boxes``, ``class_names``
        or ``scores`` (or any of them is None), so callers never have to guard
        against a partially populated result.
        """
        boxes = getattr(layout_output, "boxes", None)
        labels = getattr(layout_output, "class_names", None)
        scores = getattr(layout_output, "scores", None)
        if boxes is None or labels is None or scores is None:
            return
        for box, label, score in zip(boxes, labels, scores):
            if score < self.confidence_threshold:
                continue
            yield box, label, score

    def _extract_regions(
        self,
        layout_output,
        origin_x: float,
        origin_y: float,
        scale_x: float,
        scale_y: float,
    ) -> List[Dict]:
        """Convert engine detections (image pixels) into page-space region dicts.

        Args:
            layout_output: Engine result exposing ``boxes``, ``class_names``
                and ``scores``.
            origin_x: Page-space x of the rendered clip's left edge.
            origin_y: Page-space y of the rendered clip's top edge.
            scale_x: Page units per image pixel, horizontally.
            scale_y: Page units per image pixel, vertically.

        Returns:
            One dict per kept detection with ``label``, ``score``, ``bbox``
            (x1, y1, x2, y2) and ``size`` (width, height).
        """
        regions: List[Dict] = []
        for box, label, score in self._iter_detections(layout_output):
            # Cast to built-in float: NumPy scalars (e.g. float32) would
            # otherwise leak into the result and break json.dumps.
            x1 = origin_x + float(box[0]) * scale_x
            y1 = origin_y + float(box[1]) * scale_y
            x2 = origin_x + float(box[2]) * scale_x
            y2 = origin_y + float(box[3]) * scale_y
            regions.append({
                "label": label,
                "score": round(float(score), 4),
                "bbox": [round(x1, 1), round(y1, 1), round(x2, 1), round(y2, 1)],
                "size": [round(x2 - x1, 1), round(y2 - y1, 1)],
            })
        return regions

    def analyze_pdf(
        self,
        pdf_path: Path,
        pages: Optional[List[int]] = None,
        save_images_dir: Optional[Path] = None,
    ) -> Dict:
        """Run layout detection on a PDF and return per-page and overall stats.

        Args:
            pdf_path: PDF file to analyse.
            pages: 0-based page indices to analyse; None means every page.
                Indices at or beyond the page count are skipped.
            save_images_dir: When given, an annotated JPEG is written there
                for every analysed page.

        Returns:
            ``{"error": ...}`` when RapidLayout is unavailable; otherwise a
            dict with ``total_pages``, ``analyzed_pages`` (pages actually
            processed), ``total_regions``, ``label_distribution``,
            ``table_count``, ``image_count``, ``figure_count`` and
            ``page_details``.
        """
        if not RAPID_LAYOUT_AVAILABLE:
            return {"error": "RapidLayout 未安装,请执行: pip install rapid-layout"}
        engine = self._get_engine()
        if engine is None:
            return {"error": "RapidLayout 初始化失败"}

        doc = fitz.open(str(pdf_path))
        try:
            total_pages = len(doc)
            target_pages = pages if pages is not None else list(range(total_pages))
            all_labels: List[str] = []
            page_details: List[Dict] = []

            for page_num in target_pages:
                if page_num >= total_pages:
                    continue

                page = doc.load_page(page_num)
                rect = page.rect
                clip_box = fitz.Rect(
                    0, self.clip_top,
                    rect.width, rect.height - self.clip_bottom,
                )
                pix = page.get_pixmap(dpi=self.dpi, clip=clip_box)
                img = np.frombuffer(pix.samples, dtype=np.uint8).reshape(
                    pix.height, pix.width, 3,
                )
                layout_output = engine(img)

                # Page units per rendered pixel, used to map boxes back.
                scale_x = clip_box.width / img.shape[1]
                scale_y = clip_box.height / img.shape[0]

                page_regions = self._extract_regions(
                    layout_output, clip_box.x0, clip_box.y0, scale_x, scale_y,
                )
                page_labels = [region["label"] for region in page_regions]

                all_labels.extend(page_labels)
                page_details.append({
                    "page": page_num + 1,
                    "regions": page_regions,
                    "counts": dict(Counter(page_labels)),
                })

                if save_images_dir:
                    self._save_annotated_image(
                        img, layout_output, page_num + 1,
                        scale_x, scale_y, save_images_dir,
                    )
        finally:
            doc.close()

        label_counter = Counter(all_labels)
        return {
            "total_pages": total_pages,
            # Count the pages that were really processed, not the ones
            # requested: out-of-range indices are skipped above.
            "analyzed_pages": len(page_details),
            "total_regions": len(all_labels),
            "label_distribution": dict(label_counter.most_common()),
            "table_count": label_counter.get("table", 0),
            "image_count": label_counter.get("image", 0),
            "figure_count": label_counter.get("figure", 0),
            "page_details": page_details,
        }

    def _save_annotated_image(
        self,
        img: np.ndarray,
        layout_output,
        page_num: int,
        scale_x: float,
        scale_y: float,
        output_dir: Path,
    ):
        """Write the rendered page with detection boxes drawn on it as a JPEG.

        Args:
            img: The rendered page image that was passed to the engine.
            layout_output: Engine result for ``img``.
            page_num: 1-based page number, used in the output file name.
            scale_x: Unused; kept so existing callers keep working.
            scale_y: Unused; kept so existing callers keep working.
            output_dir: Directory receiving ``page_NNN_layout.jpg``.

        The boxes are drawn exactly as returned by the engine: they are
        already expressed in pixels of ``img`` (``analyze_pdf`` relies on the
        same assumption when it maps them to page space), so rescaling them
        again would misplace every rectangle.
        """
        try:
            from PIL import Image, ImageDraw
        except ImportError:
            print(" [跳过] Pillow 未安装,无法保存标注图片")
            return

        pil_img = Image.fromarray(img)
        draw = ImageDraw.Draw(pil_img)
        label_colors = {
            "table": (0, 255, 0),            # green
            "figure": (255, 80, 80),         # red - key case: non-standard tables may land here
            "figure_caption": (255, 165, 0),  # orange
            "table_caption": (200, 200, 0),   # yellow-green
            "text": (0, 0, 255),             # blue
            "title": (255, 255, 0),          # yellow
            "header": (128, 0, 128),         # purple
            "footer": (128, 128, 0),         # olive
            "reference": (0, 128, 128),
            "equation": (0, 200, 200),
        }
        default_color = (200, 200, 200)

        for box, label, score in self._iter_detections(layout_output):
            x1, y1, x2, y2 = (float(box[i]) for i in range(4))
            color = label_colors.get(label, default_color)
            draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
            draw.text(
                (x1 + 2, y1 + 2),
                f"{label} ({score:.2f})",
                fill=color,
            )

        output_path = output_dir / f"page_{page_num:03d}_layout.jpg"
        pil_img.save(str(output_path), quality=85)
        print(f" [保存] {output_path}")
def print_report(result: Dict):
    """Print the single-file detection report to stdout.

    Args:
        result: Dict produced by the analyzer. When it contains an ``error``
            key only that message is printed; otherwise the totals, the label
            distribution, the table/image/figure counts and the regions of
            every non-empty page are printed.
    """
    if "error" in result:
        print(f"[错误] {result['error']}")
        return

    heavy_rule = "=" * 70
    light_rule = "-" * 50

    print()
    print(heavy_rule)
    print("YOLO 版面检测报告")
    print(heavy_rule)
    print(f"总页数: {result['total_pages']}")
    print(f"分析页数: {result['analyzed_pages']}")
    print(f"检测区域总数: {result['total_regions']}")
    print()

    print("标签分布:")
    print(light_rule)
    for name, hits in result["label_distribution"].items():
        share = hits / max(result["total_regions"], 1) * 100
        blocks = "█" * int(share / 2)  # one block per two percent
        print(f" {name:15s}: {hits:4d} ({share:5.1f}%) {blocks}")
    print()

    # Headline counts
    print("关键指标:")
    print(f" table : {result['table_count']}")
    print(f" image : {result['image_count']}")
    print(f" figure : {result['figure_count']}")
    print()

    # Per-page breakdown; pages without regions are omitted
    print("逐页详情:")
    print(light_rule)
    for entry in result["page_details"]:
        number = entry["page"]
        found = entry["regions"]
        if found:
            print(f"\n --- 第 {number} 页 ({len(found)} 个区域) ---")
            for region in found:
                dims = f"{region['size'][0]}x{region['size'][1]}"
                head = f" [{region['label']:12s}] score={region['score']:.3f} "
                box = region["bbox"]
                coords = f"bbox=({box[0]:.0f},{box[1]:.0f},{box[2]:.0f},{box[3]:.0f}) "
                print(f"{head}{coords}size={dims}")
    print()
def print_batch_report(batch_results: List[Dict]):
    """Print the aggregated report for a batch of analysed PDFs.

    Args:
        batch_results: One dict per file. Entries containing an ``error`` key
            are treated as failures and listed at the end; the rest feed the
            global label distribution, the per-file table and the per-page
            averages. If no entry succeeded, a single error line is printed.
    """
    succeeded: List[Dict] = []
    failed: List[Dict] = []
    for entry in batch_results:
        if "error" in entry:
            failed.append(entry)
        else:
            succeeded.append(entry)

    if not succeeded:
        print("[错误] 没有成功分析任何 PDF 文件")
        return

    banner = "=" * 80
    print()
    print(banner)
    print("YOLO 版面检测 — 批统计报告")
    print(banner)
    print(f"分析文件数: {len(batch_results)} (成功 {len(succeeded)}, 失败 {len(failed)})")

    # Accumulate label counts over every successfully analysed file
    label_totals: Counter = Counter()
    rows: List[Dict] = []
    for report in succeeded:
        distribution = report["label_distribution"]
        label_totals.update(distribution)
        region_count = report["total_regions"]
        rows.append({
            "file": report["file_name"],
            "pages": report["total_pages"],
            "regions": region_count,
            "table_pct": distribution.get("table", 0) / max(region_count, 1) * 100,
            "figure_pct": distribution.get("figure", 0) / max(region_count, 1) * 100,
            "table_count": distribution.get("table", 0),
            "figure_count": distribution.get("figure", 0),
        })

    total_regions = sum(row["regions"] for row in rows)
    total_pages = sum(row["pages"] for row in rows)
    print(f"总页数: {total_pages}")
    print(f"总区域数: {total_regions}")
    print()

    # Global label distribution
    print("全局标签分布:")
    print("-" * 55)
    for name, hits in label_totals.most_common():
        share = hits / max(total_regions, 1) * 100
        blocks = "█" * int(share)  # one block per percent
        print(f" {name:15s}: {hits:5d} ({share:5.1f}%) {blocks}")
    print()

    # Per-file summary table
    print("逐文件摘要:")
    print("-" * 80)
    print(" {:40s} {:>4s} {:>5s} {:>7s} {:>7s} {:>6s} {:>6s}".format(
        "文件", "页", "区域", "table%", "figure%", "table", "figure",
    ))
    print(" " + "-" * 76)
    for row in rows:
        shown = row["file"]
        if len(shown) > 40:
            shown = shown[:38] + ".."
        print(f" {shown:40s} {row['pages']:4d} {row['regions']:5d} "
              f"{row['table_pct']:6.1f}% {row['figure_pct']:6.1f}% "
              f"{row['table_count']:5d} {row['figure_count']:5d}")

    # Averages
    file_count = len(rows)
    avg_table_pct = sum(row["table_pct"] for row in rows) / file_count
    avg_figure_pct = sum(row["figure_pct"] for row in rows) / file_count
    page_divisor = max(total_pages, 1)
    avg_regions_per_page = total_regions / page_divisor
    avg_table_per_page = sum(row["table_count"] for row in rows) / page_divisor
    avg_figure_per_page = sum(row["figure_count"] for row in rows) / page_divisor

    print()
    print("平均统计 (按页):")
    print("-" * 40)
    print(f" 平均区域/页: {avg_regions_per_page:.1f}")
    print(f" 平均 table/页: {avg_table_per_page:.2f}")
    print(f" 平均 figure/页: {avg_figure_per_page:.2f}")
    print(f" 平均 table 占比: {avg_table_pct:.1f}%")
    print(f" 平均 figure 占比:{avg_figure_pct:.1f}%")
    print(f" table+figure/页: {avg_table_per_page + avg_figure_per_page:.2f}")

    if failed:
        print()
        print(f"失败文件 ({len(failed)}):")
        for item in failed:
            print(f" - {item['file_name']}: {item['error']}")
    print()
def main(argv=None):
    """Command-line entry point for single-file and batch layout analysis.

    Args:
        argv: Argument list to parse; None (the default) makes argparse read
            ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success; 1 for invalid arguments, a missing
        input, a failed single-file analysis, or a batch in which no file
        could be analysed.
    """
    import contextlib

    parser = argparse.ArgumentParser(description="YOLO 版面检测模型测试")
    parser.add_argument("-p", "--pdf", default=None, help="单个 PDF 文件路径")
    parser.add_argument("-d", "--dir", default=None, help="批量: 扫描目录下所有 PDF 文件")
    parser.add_argument("--pages", default=None, help="分析指定页码, 逗号分隔, 如 0,1,2 (0-based)")
    parser.add_argument("--save-images", action="store_true", help="保存标注图片 (批模式不生效)")
    parser.add_argument("--output-dir", default=None, help="输出目录 (默认与 PDF 同目录)")
    parser.add_argument("--dpi", type=int, default=200, help="渲染 DPI (默认 200)")
    parser.add_argument("--confidence", type=float, default=0.3, help="置信度阈值 (默认 0.3)")
    parser.add_argument("--clip-top", type=float, default=60, help="顶部裁剪 (默认 60)")
    parser.add_argument("--clip-bottom", type=float, default=60, help="底部裁剪 (默认 60)")
    parser.add_argument("--json", action="store_true", help="输出 JSON 格式 (批模式输出每个文件的关键统计)")
    args = parser.parse_args(argv)

    if not args.pdf and not args.dir:
        print("[错误] 请指定 -p <pdf文件> 或 -d <pdf目录>")
        return 1

    # Validate --pages up front so a typo yields a message, not a traceback.
    # Empty items (e.g. a trailing comma) are ignored.
    pages = None
    if args.pages:
        try:
            pages = [int(tok) for tok in map(str.strip, args.pages.split(",")) if tok]
        except ValueError:
            pages = []
        if not pages:
            print(f"[错误] --pages 参数无效: {args.pages!r} (应为逗号分隔的整数, 如 0,1,2)")
            return 1

    # In JSON mode stdout must carry nothing but the JSON document, so
    # progress messages are sent to stderr instead.
    progress = sys.stderr if args.json else sys.stdout

    tester_options = {
        "dpi": args.dpi,
        "clip_top": args.clip_top,
        "clip_bottom": args.clip_bottom,
        "confidence_threshold": args.confidence,
    }

    # ---- single-file mode ----
    if args.pdf:
        pdf_path = Path(args.pdf)
        if not pdf_path.exists():
            print(f"[错误] PDF 不存在: {pdf_path}")
            return 1

        output_dir = None
        if args.save_images:
            output_dir = Path(args.output_dir) if args.output_dir else pdf_path.parent / "yolo_layout_output"
            output_dir.mkdir(parents=True, exist_ok=True)

        tester = YoloLayoutTester(**tester_options)
        print(f"[分析] {pdf_path}", file=progress)
        # The analyzer prints a line per saved image; keep those out of the
        # JSON stream as well.
        with contextlib.redirect_stdout(progress):
            result = tester.analyze_pdf(pdf_path, pages=pages, save_images_dir=output_dir)
        result["file_name"] = pdf_path.name

        if args.json:
            print(json.dumps(result, ensure_ascii=False, indent=2))
        else:
            print_report(result)
        return 1 if "error" in result else 0

    # ---- batch mode ----
    dir_path = Path(args.dir)
    if not dir_path.is_dir():
        print(f"[错误] 目录不存在: {dir_path}")
        return 1
    pdf_files = sorted(dir_path.glob("*.pdf"))
    if not pdf_files:
        print(f"[错误] 目录下无 PDF 文件: {dir_path}")
        return 1

    print(f"[批分析] 找到 {len(pdf_files)} 个 PDF 文件", file=progress)
    print(f"[批分析] 目录: {dir_path}", file=progress)
    print(file=progress)

    tester = YoloLayoutTester(**tester_options)
    batch_results: List[Dict] = []
    for idx, pdf_path in enumerate(pdf_files, 1):
        print(f"[{idx}/{len(pdf_files)}] {pdf_path.name} ...", end=" ", flush=True, file=progress)
        try:
            result = tester.analyze_pdf(pdf_path, pages=pages)
        except Exception as e:
            # Best effort: one unreadable PDF must not abort the whole batch.
            print(f"失败: {e}", file=progress)
            batch_results.append({"file_name": pdf_path.name, "error": str(e)})
            continue

        result["file_name"] = pdf_path.name
        batch_results.append(result)
        if "error" in result:
            # The analyzer reports some failures as an error dict rather than
            # raising; record it once and keep its original message.
            print(f"失败: {result['error']}", file=progress)
            continue

        regions = result["total_regions"]
        t = result.get("table_count", 0)
        f = result.get("figure_count", 0)
        print(f"OK ({result['total_pages']}页, {regions}区域, table={t}, figure={f})", file=progress)

    if args.json:
        summary = []
        for r in batch_results:
            if "error" in r:
                summary.append(r)
            else:
                summary.append({
                    "file": r["file_name"],
                    "pages": r["total_pages"],
                    "regions": r["total_regions"],
                    "label_distribution": r["label_distribution"],
                    "table_count": r.get("table_count", 0),
                    "figure_count": r.get("figure_count", 0),
                })
        print(json.dumps(summary, ensure_ascii=False, indent=2))
    else:
        print_batch_report(batch_results)

    return 0 if any("error" not in r for r in batch_results) else 1
# Script entry point: run main() and use its return value as the process
# exit status.
if __name__ == "__main__":
    sys.exit(main())
|