| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620 |
- """
- YOLO 目录页检测与 OCR 提取模块
- 用于在文档处理流程早期检测目录页并提取目录内容,
- 输出结构与 outline 保持一致,便于后续进行目录完整性检查。
- """
- import io
- import os
- import re
- from dataclasses import dataclass
- from typing import Dict, Any, List, Optional, Tuple
- from pathlib import Path
- import fitz
- import numpy as np
- from utils_test.minimal_pipeline._simple_logger import review_logger as logger
- from ultralytics import YOLO
- from PIL import Image
@dataclass
class CatalogItem:
    """A single flattened table-of-contents entry.

    Attributes:
        index: 1-based ordinal of the entry.
        title: Heading text of the entry.
        page: Page number, kept as a string.
        original: The raw text the entry was taken from.
        level: Nesting depth; 1 for a chapter, 2 for a section.
        parent_title: Title of the owning chapter (used for level-2 entries).
    """

    index: int
    title: str
    page: str
    original: str
    level: int = 1
    parent_title: str = ""
@dataclass
class CatalogSection:
    """A second-level table-of-contents entry (a section inside a chapter).

    Attributes:
        title: Heading text of the section.
        page: Page number, kept as a string.
        level: Nesting depth of the entry.
        original: The raw text the entry was taken from.
    """

    title: str
    page: str
    level: int
    original: str
@dataclass
class CatalogChapter:
    """A first-level table-of-contents entry (a chapter) and its sections.

    Attributes:
        index: Ordinal of the chapter.
        title: Heading text of the chapter.
        page: Page number, kept as a string.
        original: The raw text the entry was taken from.
        subsections: Second-level entries that belong to this chapter.
    """

    index: int
    title: str
    page: str
    original: str
    subsections: List[CatalogSection]
class TOCCatalogExtractor:
    """Detect table-of-contents (TOC) pages in a PDF and extract their entries.

    Pages are rendered with PyMuPDF and run through a YOLO model; a page counts
    as a TOC page when the model reports a box whose class name is
    ``catalogs``. Those pages are sent as JPEG images to a chat-completions
    style OCR endpoint (model name ``GLM-OCR``) and the returned text is parsed
    into a chapter / subsection structure made of plain dicts.
    """

    # YOLO detection settings
    DEFAULT_MODEL_PATH = "best.pt"  # local copy of the weights
    CONF_THRESHOLD = 0.25
    MAX_CHECK_PAGES = 50
    DPI = 150

    # OCR settings (kept modest so the request body does not get too large)
    OCR_DPI = 150
    MAX_SHORT_EDGE = 800
    JPEG_QUALITY = 85
    MAX_IMAGE_SIZE_MB = 5
    FORCED_SHORT_EDGE = 640    # short-edge limit used when force_smaller=True
    FORCED_JPEG_QUALITY = 75   # JPEG quality used when force_smaller=True
    OCR_MAX_RETRIES = 3

    # "第X章 <title> ... <page>"
    _CHAPTER_RE = re.compile(
        r'第\s*([一二三四五六七八九十百0-9]+)\s*章\s*[\s\.]*(.+?)\s*[\.\s]*(\d+)\s*$',
        re.IGNORECASE
    )
    # "一、<title> ... <page>"
    _SECTION_RE = re.compile(
        r'([一二三四五六七八九十]+)\s*[、\.\s]+\s*(.+?)\s*[\.\s]*(\d+)\s*$'
    )
    # "1 <title> ... <page>" or "1.2.3 <title> ... <page>"; the dotted form is
    # consumed as a whole so that its tail does not leak into the title.
    _GENERIC_RE = re.compile(
        r'([0-9]+(?:\.[0-9]+)*)[\.\s]+(.+?)\s*[\.\s]+(\d+)\s*$'
    )
    # Trailing leader dots / spaces followed by the page number.
    _TRAILING_PAGE_RE = re.compile(r'[\.\s]*(\d+)\s*$')

    # Titles containing one of these are promoted to chapters by the main parser.
    _CHAPTER_KEYWORDS = (
        '编制依据', '工程概况', '施工计划', '施工工艺',
        '安全保证', '质量保证', '环境保证', '人员配备',
        '验收要求',
    )
    # Keyword list used by the fallback parser (intentionally slightly wider).
    _FALLBACK_CHAPTER_KEYWORDS = (
        '编制依据', '工程概况', '施工计划',
        '施工工艺', '安全保证', '质量保证',
        '环境保证', '人员配备', '验收',
        '其他资料',
    )

    _CN_DIGIT_CHARS = "零一二三四五六七八九"
    _CN_DIGITS = {ch: value for value, ch in enumerate(_CN_DIGIT_CHARS)}
    _CN_UNITS = {'十': 10, '百': 100}

    def __init__(
        self,
        model_path: Optional[str] = None,
        ocr_api_url: str = "http://183.220.37.46:25429/v1/chat/completions",
        ocr_api_key: str = "",
        ocr_timeout: int = 600,
    ):
        """Store the configuration; the YOLO model is loaded lazily.

        Args:
            model_path: Path to the YOLO weights; ``DEFAULT_MODEL_PATH`` when falsy.
            ocr_api_url: URL of the OCR chat-completions endpoint.
            ocr_api_key: Bearer token; no ``Authorization`` header is sent when empty.
            ocr_timeout: Timeout in seconds passed to each OCR HTTP request.
        """
        self.model_path = model_path or self.DEFAULT_MODEL_PATH
        self.ocr_api_url = ocr_api_url
        self.ocr_api_key = ocr_api_key
        self.ocr_timeout = ocr_timeout
        self._model = None

    def _load_model(self) -> bool:
        """Load the YOLO model once and cache it on the instance.

        Returns:
            Always ``True``; failure is signalled by an exception.

        Raises:
            FileNotFoundError: The model is not loaded yet and the weights
                file does not exist.
        """
        if self._model is not None:
            return True
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"[TOC检测] YOLO模型文件不存在: {self.model_path}")
        logger.info(f"[TOC检测] 正在加载YOLO模型: {self.model_path}")
        self._model = YOLO(self.model_path)
        return True

    def detect_and_extract(
        self,
        file_content: bytes,
        progress_callback=None
    ) -> Optional[Dict[str, Any]]:
        """Detect the TOC pages of a PDF and return the parsed catalog.

        Args:
            file_content: Raw bytes of the PDF document.
            progress_callback: Optional callable invoked as
                ``progress_callback(stage, percent, message)``.

        Returns:
            ``None`` when no TOC page is detected or OCR yields no text.
            Otherwise a dict with the keys ``chapters``, ``total_chapters``,
            ``raw_ocr_text``, ``formatted_text`` and ``toc_page_range``
            (``{"start": ..., "end": ...}``, 1-based page numbers).

        Raises:
            FileNotFoundError: The YOLO weights file is missing.
        """
        self._load_model()
        doc = fitz.open(stream=file_content)
        try:
            # 1. Locate the TOC pages.
            toc_pages = self._detect_toc_pages(doc, progress_callback)
            if not toc_pages:
                logger.info("[TOC检测] 未检测到目录页")
                return None
            first_page = toc_pages[0] + 1   # convert to 1-based page numbers
            last_page = toc_pages[-1] + 1
            logger.info(f"[TOC检测] 检测到目录页: 第{first_page}页 - 第{last_page}页")

            # 2. OCR the TOC pages.
            if progress_callback:
                progress_callback("目录识别", 10, f"检测到{len(toc_pages)}页目录,开始OCR识别...")
            toc_text = self._ocr_toc_pages(doc, toc_pages, progress_callback)
            if not toc_text:
                return None

            # 3. Parse the OCR text into structured data.
            if progress_callback:
                progress_callback("目录识别", 80, "解析目录结构...")
            catalog = self._parse_toc_text(toc_text)
            catalog["toc_page_range"] = {"start": first_page, "end": last_page}
            if progress_callback:
                progress_callback("目录识别", 100, f"目录提取完成,共{catalog['total_chapters']}章")
            return catalog
        finally:
            doc.close()

    def _has_catalog_box(self, results) -> bool:
        """Return True when any detection in *results* has class name ``catalogs``."""
        names = self._model.names
        for result in results:
            if result.boxes is None:
                continue
            for box in result.boxes:
                cls_id = int(box.cls.item())
                if names.get(cls_id, f"class_{cls_id}") == 'catalogs':
                    return True
        return False

    def _detect_toc_pages(
        self,
        doc: "fitz.Document",
        progress_callback=None
    ) -> List[int]:
        """Scan the leading pages with YOLO and collect the TOC page run.

        At most ``MAX_CHECK_PAGES`` pages are examined. Scanning stops at the
        first page without a detection that follows at least one TOC page, so
        the result is a single contiguous run.

        Returns:
            0-based indices of the detected TOC pages (possibly empty).
        """
        toc_pages = []
        pages_to_check = min(len(doc), self.MAX_CHECK_PAGES)
        # The render matrix is identical for every page.
        zoom = self.DPI / 72
        matrix = fitz.Matrix(zoom, zoom)
        for page_idx in range(pages_to_check):
            page = doc.load_page(page_idx)
            pix = page.get_pixmap(matrix=matrix)
            # Hand the RGB PIL image to the model directly: Ultralytics treats
            # PIL input as RGB but interprets a bare ndarray as BGR, so passing
            # np.array(img) would silently swap the red and blue channels.
            img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
            results = self._model(img, conf=self.CONF_THRESHOLD, verbose=False)

            if self._has_catalog_box(results):
                toc_pages.append(page_idx)
                logger.debug(f" 第{page_idx + 1:3d}页: 检测到目录")
            else:
                logger.debug(f" 第{page_idx + 1:3d}页: 未检测到目录")
                # A miss after at least one hit marks the end of the TOC.
                if toc_pages:
                    break

            if progress_callback and (page_idx + 1) % 5 == 0:
                progress = int((page_idx + 1) / pages_to_check * 10)
                progress_callback("目录识别", progress, f"扫描页面 {page_idx + 1}/{pages_to_check}")
        return toc_pages

    def _build_ocr_payload(self, img_base64: str) -> Dict[str, Any]:
        """Build the chat-completions request body for one base64 JPEG page."""
        return {
            "model": "GLM-OCR",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "识别目录内容,按原文格式输出。保留章节层级和页码。"
                        },
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"}
                        }
                    ]
                }
            ],
            "max_tokens": 1024,
            "temperature": 0.1,
            "seed": 42  # fixed seed to keep sampling repeatable
        }

    def _request_ocr_text(self, payload: Dict[str, Any], headers: Dict[str, str], page_no: int) -> str:
        """POST one OCR request with exponential backoff and return the text.

        A 400 response is not retried. Any other failure, including an empty
        answer, is retried up to ``OCR_MAX_RETRIES`` times with waits of
        2, 4, ... seconds between attempts.

        Args:
            payload: JSON body of the request.
            headers: HTTP headers of the request.
            page_no: 1-based page number, used in log messages only.

        Returns:
            The recognised text, or ``""`` when every attempt failed.
        """
        import time

        import requests

        for attempt in range(self.OCR_MAX_RETRIES):
            try:
                response = requests.post(
                    self.ocr_api_url,
                    headers=headers,
                    json=payload,
                    timeout=self.ocr_timeout
                )
                if response.status_code != 200:
                    logger.error(f" 第{page_no}页OCR请求失败: HTTP {response.status_code}, 响应: {response.text[:200]}")
                response.raise_for_status()
                result = response.json()
                content = ""
                if "choices" in result and result["choices"]:
                    content = result["choices"][0].get("message", {}).get("content", "")
                if content:
                    logger.info(f" 第{page_no}页目录OCR成功")
                    return content
                # An empty answer is handled like any other failed attempt
                # so that it is backed off and logged instead of silently
                # retried in a tight loop.
                error = "OCR返回内容为空"
            except requests.exceptions.HTTPError as e:
                if e.response is not None and e.response.status_code == 400:
                    logger.error(f" 第{page_no}页OCR请求格式错误(400),可能是图片过大")
                    return ""  # a malformed request will not succeed on retry
                error = e
            except Exception as e:
                error = e

            if attempt < self.OCR_MAX_RETRIES - 1:
                wait_time = 2 ** (attempt + 1)
                logger.warning(f" 第{page_no}页目录OCR失败,{wait_time}秒后重试...")
                time.sleep(wait_time)
            else:
                logger.error(f" 第{page_no}页目录OCR最终失败: {error}")
        return ""

    def _ocr_toc_pages(
        self,
        doc: "fitz.Document",
        toc_pages: List[int],
        progress_callback=None
    ) -> str:
        """Run OCR on every TOC page and join the recognised text.

        Processing is best-effort: a page that fails is logged and skipped.

        Args:
            doc: The opened document.
            toc_pages: 0-based indices of the pages to recognise.
            progress_callback: Optional ``(stage, percent, message)`` callable.

        Returns:
            The per-page texts joined with newlines (``""`` if all failed).
        """
        import base64

        all_texts = []
        total = len(toc_pages)
        headers = {"Content-Type": "application/json"}
        if self.ocr_api_key:
            headers["Authorization"] = f"Bearer {self.ocr_api_key}"

        for idx, page_idx in enumerate(toc_pages):
            page_no = page_idx + 1
            page = doc.load_page(page_idx)
            try:
                pix = page.get_pixmap(dpi=self.OCR_DPI)
                img_bytes = pix.tobytes("jpeg")
                compressed = self._compress_image(img_bytes)
                img_size_mb = len(compressed) / (1024 * 1024)
                logger.debug(f" 第{page_no}页图片大小: {img_size_mb:.2f}MB")
                if img_size_mb > self.MAX_IMAGE_SIZE_MB:
                    logger.warning(f" 第{page_no}页图片过大({img_size_mb:.2f}MB),尝试进一步压缩")
                    # Re-encode from the rendered page rather than from the
                    # already compressed JPEG to avoid a second lossy pass.
                    compressed = self._compress_image(img_bytes, force_smaller=True)
                    img_size_mb = len(compressed) / (1024 * 1024)
                    logger.debug(f" 压缩后大小: {img_size_mb:.2f}MB")

                img_base64 = base64.b64encode(compressed).decode('utf-8')
                payload = self._build_ocr_payload(img_base64)
                content = self._request_ocr_text(payload, headers, page_no)
                if content:
                    all_texts.append(content)

                if progress_callback:
                    progress = 10 + int((idx + 1) / total * 60)
                    progress_callback("目录识别", progress, f"OCR识别中 {idx + 1}/{total}")
            except Exception as e:
                logger.error(f" 第{page_no}页OCR处理出错: {e}")
        return "\n".join(all_texts)

    def _compress_image(self, img_bytes: bytes, force_smaller: bool = False) -> bytes:
        """Downscale and re-encode an image as JPEG.

        The image is flattened onto a white background when it has an alpha
        channel or a palette, scaled down so that its shorter edge does not
        exceed the limit, and saved as an optimised JPEG.

        Args:
            img_bytes: Encoded source image.
            force_smaller: Use the stricter ``FORCED_*`` size and quality limits.

        Returns:
            The JPEG bytes, or *img_bytes* unchanged if processing fails.
        """
        try:
            img = Image.open(io.BytesIO(img_bytes))
            if img.mode in ('RGBA', 'LA', 'P'):
                background = Image.new('RGB', img.size, (255, 255, 255))
                if img.mode == 'P':
                    img = img.convert('RGBA')
                if img.mode in ('RGBA', 'LA'):
                    # The last band is the alpha channel; use it as paste mask.
                    background.paste(img, mask=img.split()[-1])
                img = background
            elif img.mode != 'RGB':
                img = img.convert('RGB')

            edge_limit = self.FORCED_SHORT_EDGE if force_smaller else self.MAX_SHORT_EDGE
            short_edge = min(img.size)
            if short_edge > edge_limit:
                ratio = edge_limit / short_edge
                new_size = (int(img.width * ratio), int(img.height * ratio))
                img = img.resize(new_size, Image.Resampling.LANCZOS)

            buffer = io.BytesIO()
            quality = self.FORCED_JPEG_QUALITY if force_smaller else self.JPEG_QUALITY
            img.save(buffer, format='JPEG', quality=quality, optimize=True)
            return buffer.getvalue()
        except Exception as e:
            logger.warning(f"[TOC检测] 图片压缩失败,使用原图: {e}")
            return img_bytes

    @classmethod
    def _chinese_to_number(cls, text: str) -> Optional[int]:
        """Convert a Chinese numeral such as ``十六`` or ``二十一`` to an int.

        Handles the digits 零-九 and the units 十 and 百.

        Returns:
            The positive value, or ``None`` when *text* contains any other
            character or does not evaluate to a positive number.
        """
        total = 0
        pending = 0
        for ch in text:
            if ch in cls._CN_DIGITS:
                pending = pending * 10 + cls._CN_DIGITS[ch]
            elif ch in cls._CN_UNITS:
                # A bare unit ("十") stands for one of that unit.
                total += (pending or 1) * cls._CN_UNITS[ch]
                pending = 0
            else:
                return None
        total += pending
        return total if total > 0 else None

    def _chapter_index(self, numeral: str, fallback: int) -> int:
        """Return the chapter number for *numeral*, or *fallback* if unparseable."""
        if numeral.isdigit():
            return int(numeral)
        value = self._chinese_to_number(numeral)
        return fallback if value is None else value

    def _new_chapter(self, index: int, line: str, page: str) -> Dict[str, Any]:
        """Build a chapter dict whose title is *line* without the trailing page."""
        return {
            "index": index,
            "title": self._TRAILING_PAGE_RE.sub('', line).strip(),
            "page": page,
            "original": line,
            "subsections": []
        }

    @staticmethod
    def _new_section(title: str, page: str, line: str) -> Dict[str, Any]:
        """Build a level-2 subsection dict."""
        return {
            "title": title,
            "page": page,
            "level": 2,
            "original": line
        }

    def _parse_lines(self, lines: List[str]) -> List[Dict[str, Any]]:
        """Parse OCR lines into chapter dicts using the regex rules.

        Each line is tried, in order, as a ``第X章`` chapter heading, a
        ``一、`` style section, and a numerically prefixed entry. Section and
        numeric entries seen before the first chapter heading are dropped.

        Returns:
            The chapters in document order; empty when no chapter heading
            was recognised.
        """
        chapters = []
        current = None
        for raw_line in lines:
            line = raw_line.strip()
            if len(line) < 3:
                continue
            # Drop Markdown table pipes that OCR output sometimes contains.
            line = re.sub(r'^[\|\s]+|[\|\s]+$', '', line)
            line = line.replace('|', ' ')

            chapter_match = self._CHAPTER_RE.search(line)
            if chapter_match:
                if current:
                    chapters.append(current)
                index = self._chapter_index(chapter_match.group(1), len(chapters) + 1)
                current = self._new_chapter(index, line, chapter_match.group(3).strip())
                continue

            if current is None:
                continue

            section_match = self._SECTION_RE.search(line)
            if section_match:
                current["subsections"].append(self._new_section(
                    section_match.group(2).strip(),
                    section_match.group(3).strip(),
                    line,
                ))
                continue

            generic_match = self._GENERIC_RE.search(line)
            if not generic_match:
                continue
            title = generic_match.group(2).strip()
            page = generic_match.group(3).strip()
            if any(kw in title for kw in self._CHAPTER_KEYWORDS):
                # The title looks like a chapter heading: start a new chapter.
                chapters.append(current)
                current = self._new_chapter(len(chapters) + 1, line, page)
            else:
                current["subsections"].append(self._new_section(title, page, line))

        if current:
            chapters.append(current)
        return chapters

    def _parse_toc_text(self, text: str) -> Dict[str, Any]:
        """Parse OCR text of the TOC pages into structured data.

        Falls back to ``_fallback_parse`` when the regex rules find no chapter.

        Returns:
            A dict with ``chapters`` (list of chapter dicts),
            ``total_chapters``, ``raw_ocr_text`` (the input text) and
            ``formatted_text`` (chapter titles, each followed by its
            subsection titles, one per line).
        """
        lines = text.strip().split('\n')
        chapters = self._parse_lines(lines)
        if not chapters and lines:
            chapters = self._fallback_parse(lines)

        formatted_lines = []
        for chapter in chapters:
            formatted_lines.append(chapter["title"])
            formatted_lines.extend(f" {sub['title']}" for sub in chapter.get("subsections", []))
        formatted_text = "\n".join(formatted_lines)

        logger.info(f"[TOC解析] 共 {len(chapters)} 章,标准格式文本:\n{formatted_text}")
        return {
            "chapters": chapters,
            "total_chapters": len(chapters),
            "raw_ocr_text": text,
            "formatted_text": formatted_text
        }

    def _fallback_parse(self, lines: List[str]) -> List[Dict[str, Any]]:
        """Heuristic parser used when no ``第X章`` heading was recognised.

        Only lines ending in a page number are considered. The first such
        line, and every line whose title contains a chapter keyword, starts a
        chapter; all other lines become subsections of the latest chapter.
        Lines that carry nothing but a number (for example a stray page
        footer) or only leader dots are ignored.
        """
        chapters = []
        for raw_line in lines:
            line = raw_line.strip()
            if not line:
                continue
            page_match = re.search(r'(\d+)\s*$', line)
            if not page_match:
                continue
            page = page_match.group(1)
            title = re.sub(r'[\.\s]+\d+\s*$', '', line).strip()
            if not title or title.isdigit():
                continue

            is_chapter = any(kw in title for kw in self._FALLBACK_CHAPTER_KEYWORDS)
            if is_chapter or not chapters:
                chapters.append({
                    "index": len(chapters) + 1,
                    "title": title,
                    "page": page,
                    "original": line,
                    "subsections": []
                })
            else:
                chapters[-1]["subsections"].append(self._new_section(title, page, line))
        return chapters

    def _number_to_chinese(self, num: int) -> str:
        """Convert an integer in 1..99 to a Chinese numeral.

        Values outside that range (or non-integers) are returned as
        ``str(num)``.
        """
        if not (isinstance(num, int) and 1 <= num <= 99):
            return str(num)
        tens, ones = divmod(num, 10)
        digits = self._CN_DIGIT_CHARS
        if tens == 0:
            return digits[ones]
        head = '十' if tens == 1 else digits[tens] + '十'
        return head + (digits[ones] if ones else '')
def extract_catalog_from_pdf(
    file_content: bytes,
    model_path: str = None,
    ocr_api_url: str = "http://183.220.37.46:25429/v1/chat/completions",
    ocr_api_key: str = "",
    progress_callback=None
) -> Optional[Dict[str, Any]]:
    """Convenience wrapper: extract the table-of-contents structure from a PDF.

    Builds a ``TOCCatalogExtractor`` from the given settings and delegates to
    its ``detect_and_extract`` method.

    Args:
        file_content: Raw bytes of the PDF document.
        model_path: Path to the YOLO weights, forwarded to the extractor.
        ocr_api_url: URL of the OCR endpoint, forwarded to the extractor.
        ocr_api_key: API key for the OCR endpoint, forwarded to the extractor.
        progress_callback: Optional progress callable, forwarded unchanged.

    Returns:
        Whatever ``detect_and_extract`` returns: a catalog dict containing at
        least ``chapters`` and ``total_chapters``, or ``None``.
    """
    return TOCCatalogExtractor(
        model_path=model_path,
        ocr_api_url=ocr_api_url,
        ocr_api_key=ocr_api_key,
    ).detect_and_extract(file_content, progress_callback)
|