_toc_detector.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624
  1. """
  2. YOLO 目录页检测与 OCR 提取模块
  3. 用于在文档处理流程早期检测目录页并提取目录内容,
  4. 输出结构与 outline 保持一致,便于后续进行目录完整性检查。
  5. """
  6. import io
  7. import os
  8. import re
  9. from dataclasses import dataclass
  10. from typing import Dict, Any, List, Optional, Tuple
  11. from pathlib import Path
  12. import fitz
  13. import numpy as np
  14. from utils_test.minimal_pipeline._simple_logger import review_logger as logger
  15. from ultralytics import YOLO
  16. from PIL import Image
  17. @dataclass
  18. class CatalogItem:
  19. """目录项结构"""
  20. index: int # 章节序号(1-based)
  21. title: str # 章节标题
  22. page: str # 页码(字符串)
  23. original: str # 原始文本
  24. level: int = 1 # 层级(1=章,2=节)
  25. parent_title: str = "" # 父章节标题(用于二级)
  26. @dataclass
  27. class CatalogSection:
  28. """目录节结构(对应二级目录)"""
  29. title: str
  30. page: str
  31. level: int
  32. original: str
  33. @dataclass
  34. class CatalogChapter:
  35. """目录章结构(对应一级目录)"""
  36. index: int
  37. title: str
  38. page: str
  39. original: str
  40. subsections: List[CatalogSection]
  41. class TOCCatalogExtractor:
  42. """
  43. 目录页检测与内容提取器
  44. 使用 YOLO 模型检测目录页,使用 GLM-OCR 提取目录文本,
  45. 解析为结构化数据,输出格式与 outline 保持一致。
  46. """
  47. # YOLO 配置
  48. DEFAULT_MODEL_PATH = "best.pt" # 本地副本
  49. CONF_THRESHOLD = 0.25
  50. MAX_CHECK_PAGES = 50
  51. DPI = 150
  52. # OCR 配置(高 DPI 渲染后缩放到 800px,确保目录文字清晰)
  53. OCR_DPI = 600
  54. MAX_SHORT_EDGE = 800
  55. JPEG_QUALITY = 85
  56. MAX_IMAGE_SIZE_MB = 5
  57. def __init__(
  58. self,
  59. model_path: str = None,
  60. ocr_api_url: str = "http://183.220.37.46:25429/v1/chat/completions",
  61. ocr_api_key: str = "",
  62. ocr_timeout: int = 600,
  63. ):
  64. self.model_path = model_path or self.DEFAULT_MODEL_PATH
  65. self.ocr_api_url = ocr_api_url
  66. self.ocr_api_key = ocr_api_key
  67. self.ocr_timeout = ocr_timeout
  68. self._model = None
  69. def _load_model(self) -> bool:
  70. """加载 YOLO 模型,缺少依赖或模型文件直接报错"""
  71. if not os.path.exists(self.model_path):
  72. raise FileNotFoundError(f"[TOC检测] YOLO模型文件不存在: {self.model_path}")
  73. if self._model is None:
  74. logger.info(f"[TOC检测] 正在加载YOLO模型: {self.model_path}")
  75. self._model = YOLO(self.model_path)
  76. return True
  77. def detect_and_extract(
  78. self,
  79. file_content: bytes,
  80. progress_callback=None
  81. ) -> Optional[Dict[str, Any]]:
  82. """
  83. 检测目录页并提取目录内容
  84. Args:
  85. file_content: PDF文件字节流
  86. progress_callback: 进度回调函数
  87. Returns:
  88. 目录结构字典,格式与 outline 保持一致:
  89. {
  90. "chapters": [...],
  91. "total_chapters": N
  92. }
  93. """
  94. if not self._load_model():
  95. return None
  96. doc = fitz.open(stream=file_content)
  97. try:
  98. # 1. 检测目录页范围
  99. toc_pages = self._detect_toc_pages(doc, progress_callback)
  100. if not toc_pages:
  101. logger.info("[TOC检测] 未检测到目录页")
  102. return None
  103. logger.info(f"[TOC检测] 检测到目录页: 第{toc_pages[0]+1}页 - 第{toc_pages[-1]+1}页")
  104. # 2. OCR 提取目录页内容
  105. if progress_callback:
  106. progress_callback("目录识别", 10, f"检测到{len(toc_pages)}页目录,开始OCR识别...")
  107. toc_text = self._ocr_toc_pages(doc, toc_pages, progress_callback)
  108. if not toc_text:
  109. return None
  110. # 3. 解析目录文本为结构化数据
  111. if progress_callback:
  112. progress_callback("目录识别", 80, "解析目录结构...")
  113. catalog = self._parse_toc_text(toc_text)
  114. # 添加目录页页码范围(1-based)
  115. if toc_pages:
  116. catalog["toc_page_range"] = {
  117. "start": toc_pages[0] + 1, # 转换为1-based页码
  118. "end": toc_pages[-1] + 1
  119. }
  120. if progress_callback:
  121. progress_callback("目录识别", 100, f"目录提取完成,共{catalog['total_chapters']}章")
  122. return catalog
  123. finally:
  124. doc.close()
  125. def _detect_toc_pages(
  126. self,
  127. doc: fitz.Document,
  128. progress_callback=None
  129. ) -> List[int]:
  130. """
  131. 使用 YOLO 检测目录页范围
  132. Returns:
  133. 目录页索引列表(0-based)
  134. """
  135. toc_pages = []
  136. total_pages = len(doc)
  137. pages_to_check = min(total_pages, self.MAX_CHECK_PAGES)
  138. for page_idx in range(pages_to_check):
  139. page = doc.load_page(page_idx)
  140. # 渲染页面
  141. zoom = self.DPI / 72
  142. mat = fitz.Matrix(zoom, zoom)
  143. pix = page.get_pixmap(matrix=mat)
  144. # 转换为 numpy 数组
  145. img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
  146. img_array = np.array(img)
  147. # YOLO 检测
  148. results = self._model(img_array, conf=self.CONF_THRESHOLD, verbose=False)
  149. # 检查是否检测到 catalogs 类别
  150. has_catalogs = False
  151. for result in results:
  152. if result.boxes is not None:
  153. for box in result.boxes:
  154. cls_id = int(box.cls.item())
  155. class_name = self._model.names.get(cls_id, f"class_{cls_id}")
  156. if class_name == 'catalogs':
  157. has_catalogs = True
  158. break
  159. if has_catalogs:
  160. break
  161. if has_catalogs:
  162. toc_pages.append(page_idx)
  163. logger.debug(f" 第{page_idx + 1:3d}页: 检测到目录")
  164. else:
  165. logger.debug(f" 第{page_idx + 1:3d}页: 未检测到目录")
  166. # 如果已经检测到目录,且现在没有检测到,认为目录结束
  167. if toc_pages:
  168. break
  169. if progress_callback and (page_idx + 1) % 5 == 0:
  170. progress = int((page_idx + 1) / pages_to_check * 10)
  171. progress_callback("目录识别", progress, f"扫描页面 {page_idx + 1}/{pages_to_check}")
  172. return toc_pages
  173. def _ocr_toc_pages(
  174. self,
  175. doc: fitz.Document,
  176. toc_pages: List[int],
  177. progress_callback=None
  178. ) -> str:
  179. """
  180. 对目录页进行 OCR 识别
  181. Returns:
  182. 合并后的目录文本
  183. """
  184. import base64
  185. import io
  186. import requests
  187. import time
  188. all_texts = []
  189. total = len(toc_pages)
  190. for idx, page_idx in enumerate(toc_pages):
  191. page = doc.load_page(page_idx)
  192. try:
  193. # 渲染页面(使用较低DPI避免图片过大)
  194. pix = page.get_pixmap(dpi=self.OCR_DPI)
  195. img_bytes = pix.tobytes("jpeg")
  196. # 压缩图片
  197. compressed = self._compress_image(img_bytes)
  198. img_size_mb = len(compressed) / (1024 * 1024)
  199. logger.debug(f" 第{page_idx + 1}页图片大小: {img_size_mb:.2f}MB")
  200. # 检查图片大小
  201. if img_size_mb > self.MAX_IMAGE_SIZE_MB:
  202. logger.warning(f" 第{page_idx + 1}页图片过大({img_size_mb:.2f}MB),尝试进一步压缩")
  203. # 再次压缩
  204. compressed = self._compress_image(compressed, force_smaller=True)
  205. img_size_mb = len(compressed) / (1024 * 1024)
  206. logger.debug(f" 压缩后大小: {img_size_mb:.2f}MB")
  207. img_base64 = base64.b64encode(compressed).decode('utf-8')
  208. # 请求 OCR
  209. payload = {
  210. "model": "GLM-OCR",
  211. "messages": [
  212. {
  213. "role": "user",
  214. "content": [
  215. {
  216. "type": "text",
  217. "text": "识别目录内容,按原文格式输出。保留章节层级和页码。"
  218. },
  219. {
  220. "type": "image_url",
  221. "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"}
  222. }
  223. ]
  224. }
  225. ],
  226. "max_tokens": 1024, # 2048 -> 1024,目录页486 tokens够用
  227. "temperature": 0.1,
  228. "seed": 42 # 固定采样随机性
  229. }
  230. headers = {"Content-Type": "application/json"}
  231. if self.ocr_api_key:
  232. headers["Authorization"] = f"Bearer {self.ocr_api_key}"
  233. # 指数退避重试
  234. max_retries = 3
  235. for attempt in range(max_retries):
  236. try:
  237. response = requests.post(
  238. self.ocr_api_url,
  239. headers=headers,
  240. json=payload,
  241. timeout=self.ocr_timeout
  242. )
  243. # 记录响应状态
  244. if response.status_code != 200:
  245. logger.error(f" 第{page_idx + 1}页OCR请求失败: HTTP {response.status_code}, 响应: {response.text[:200]}")
  246. response.raise_for_status()
  247. result = response.json()
  248. content = ""
  249. if "choices" in result and result["choices"]:
  250. content = result["choices"][0].get("message", {}).get("content", "")
  251. if content:
  252. all_texts.append(content)
  253. logger.info(f" 第{page_idx + 1}页目录OCR成功")
  254. break
  255. except requests.exceptions.HTTPError as e:
  256. if response.status_code == 400:
  257. logger.error(f" 第{page_idx + 1}页OCR请求格式错误(400),可能是图片过大")
  258. break # 400错误不需要重试
  259. if attempt < max_retries - 1:
  260. wait_time = 2 ** (attempt + 1)
  261. logger.warning(f" 第{page_idx + 1}页目录OCR失败,{wait_time}秒后重试...")
  262. time.sleep(wait_time)
  263. else:
  264. logger.error(f" 第{page_idx + 1}页目录OCR最终失败: {e}")
  265. except Exception as e:
  266. if attempt < max_retries - 1:
  267. wait_time = 2 ** (attempt + 1)
  268. logger.warning(f" 第{page_idx + 1}页目录OCR失败,{wait_time}秒后重试...")
  269. time.sleep(wait_time)
  270. else:
  271. logger.error(f" 第{page_idx + 1}页目录OCR最终失败: {e}")
  272. if progress_callback:
  273. progress = 10 + int((idx + 1) / total * 60)
  274. progress_callback("目录识别", progress, f"OCR识别中 {idx + 1}/{total}")
  275. except Exception as e:
  276. logger.error(f" 第{page_idx + 1}页OCR处理出错: {e}")
  277. return "\n".join(all_texts)
  278. def _compress_image(self, img_bytes: bytes, force_smaller: bool = False) -> bytes:
  279. """
  280. 压缩图片
  281. Args:
  282. img_bytes: 图片字节
  283. force_smaller: 是否强制更小的尺寸(用于处理过大的图片)
  284. """
  285. try:
  286. img = Image.open(io.BytesIO(img_bytes))
  287. if img.mode in ('RGBA', 'LA', 'P'):
  288. background = Image.new('RGB', img.size, (255, 255, 255))
  289. if img.mode == 'P':
  290. img = img.convert('RGBA')
  291. if img.mode in ('RGBA', 'LA'):
  292. background.paste(img, mask=img.split()[-1])
  293. img = background
  294. elif img.mode != 'RGB':
  295. img = img.convert('RGB')
  296. # 计算目标尺寸
  297. max_edge = self.MAX_SHORT_EDGE
  298. if force_smaller:
  299. max_edge = 640 # 强制小尺寸
  300. min_edge = min(img.size)
  301. if min_edge > max_edge:
  302. ratio = max_edge / min_edge
  303. new_size = (int(img.width * ratio), int(img.height * ratio))
  304. img = img.resize(new_size, Image.Resampling.LANCZOS)
  305. # 二值化增强:将浅灰文字变黑,提高 OCR 识别率
  306. img = img.convert('L')
  307. img = img.point(lambda x: 0 if x < 220 else 255)
  308. img = img.convert('RGB')
  309. buffer = io.BytesIO()
  310. img.save(buffer, format='PNG', optimize=True)
  311. return buffer.getvalue()
  312. except Exception as e:
  313. logger.warning(f"[TOC检测] 图片压缩失败,使用原图: {e}")
  314. return img_bytes
  315. def _parse_toc_text(self, text: str) -> Dict[str, Any]:
  316. """
  317. 解析目录文本为结构化数据,输出标准格式
  318. 标准格式:
  319. 第X章 XXX
  320. 一、XXX
  321. 二、XXX
  322. Returns:
  323. {
  324. "chapters": [...],
  325. "total_chapters": N,
  326. "raw_ocr_text": "原始OCR文本",
  327. "formatted_text": "标准格式文本"
  328. }
  329. """
  330. lines = text.strip().split('\n')
  331. chapters = []
  332. current_chapter = None
  333. # 正则表达式模式
  334. chapter_pattern = re.compile(
  335. r'第\s*([一二三四五六七八九十百0-9]+)\s*章\s*[\s\.]*(.+?)\s*[\.\s]*(\d+)\s*$',
  336. re.IGNORECASE
  337. )
  338. section_pattern = re.compile(
  339. r'([一二三四五六七八九十]+)\s*[、\.\s]+\s*(.+?)\s*[\.\s]*(\d+)\s*$'
  340. )
  341. generic_pattern = re.compile(
  342. r'([0-9]+)[\.\s]+(.+?)\s*[\.\s]+(\d+)\s*$'
  343. )
  344. # 中文数字映射
  345. chinese_nums = {
  346. '一': 1, '二': 2, '三': 3, '四': 4, '五': 5,
  347. '六': 6, '七': 7, '八': 8, '九': 9, '十': 10,
  348. '十一': 11, '十二': 12, '十三': 13, '十四': 14, '十五': 15
  349. }
  350. for line in lines:
  351. line = line.strip()
  352. if not line or len(line) < 3:
  353. continue
  354. # 移除 Markdown 表格符号
  355. line = re.sub(r'^[\|\s]+|[\|\s]+$', '', line)
  356. line = line.replace('|', ' ')
  357. # 尝试匹配章
  358. chapter_match = chapter_pattern.search(line)
  359. if chapter_match:
  360. chapter_num = chapter_match.group(1)
  361. title = chapter_match.group(2).strip()
  362. page = chapter_match.group(3).strip()
  363. # 保存上一个章
  364. if current_chapter:
  365. chapters.append(current_chapter)
  366. # 标准化为阿拉伯数字
  367. if chapter_num.isdigit():
  368. idx = int(chapter_num)
  369. else:
  370. idx = chinese_nums.get(chapter_num, len(chapters) + 1)
  371. # 从原始行提取完整标题(保留原文格式)
  372. # 移除行尾页码,保留章节号+标题的原文形式
  373. original_title = re.sub(r'[\.\s]*(\d+)\s*$', '', line).strip()
  374. current_chapter = {
  375. "index": idx,
  376. "title": original_title,
  377. "page": page,
  378. "original": line,
  379. "subsections": []
  380. }
  381. continue
  382. # 尝试匹配节(二级)- 标准化为一、二、三格式
  383. section_match = section_pattern.search(line)
  384. if section_match and current_chapter:
  385. section_num = section_match.group(1)
  386. title = section_match.group(2).strip()
  387. page = section_match.group(3).strip()
  388. # 标准化节编号
  389. if section_num.isdigit():
  390. section_idx = int(section_num)
  391. section_cn = self._number_to_chinese(section_idx)
  392. else:
  393. section_cn = section_num
  394. current_chapter["subsections"].append({
  395. "title": title,
  396. "page": page,
  397. "level": 2,
  398. "original": line
  399. })
  400. continue
  401. # 尝试通用匹配(数字开头)
  402. generic_match = generic_pattern.search(line)
  403. if generic_match and current_chapter:
  404. title = generic_match.group(2).strip()
  405. page = generic_match.group(3).strip()
  406. # 判断是章还是节(根据内容特征)
  407. if any(kw in title for kw in ['编制依据', '工程概况', '施工计划', '施工工艺',
  408. '安全保证', '质量保证', '环境保证', '人员配备',
  409. '验收要求']):
  410. chapters.append(current_chapter)
  411. idx = len(chapters) + 1
  412. # 保留原标题,只移除页码
  413. original_title = re.sub(r'[\.\s]*(\d+)\s*$', '', line).strip()
  414. current_chapter = {
  415. "index": idx,
  416. "title": original_title,
  417. "page": page,
  418. "original": line,
  419. "subsections": []
  420. }
  421. else:
  422. # 作为节,保留原标题
  423. current_chapter["subsections"].append({
  424. "title": title,
  425. "page": page,
  426. "level": 2,
  427. "original": line
  428. })
  429. # 添加最后一个章
  430. if current_chapter:
  431. chapters.append(current_chapter)
  432. # 如果没有匹配到章,尝试按空行或缩进分割
  433. if not chapters and lines:
  434. chapters = self._fallback_parse(lines)
  435. # 构建标准格式文本
  436. formatted_lines = []
  437. for ch in chapters:
  438. formatted_lines.append(ch["title"])
  439. for sub in ch.get("subsections", []):
  440. formatted_lines.append(f" {sub['title']}")
  441. formatted_text = "\n".join(formatted_lines)
  442. # 日志输出完整的目录解析结果
  443. logger.info(f"[TOC解析] 共 {len(chapters)} 章,标准格式文本:\n{formatted_text}")
  444. return {
  445. "chapters": chapters,
  446. "total_chapters": len(chapters),
  447. "raw_ocr_text": text,
  448. "formatted_text": formatted_text
  449. }
  450. def _fallback_parse(self, lines: List[str]) -> List[Dict[str, Any]]:
  451. """
  452. 降级解析策略:当正则无法匹配时使用启发式方法
  453. 输出标准格式:第X章 XXX / 一、XXX
  454. """
  455. chapters = []
  456. idx = 0
  457. section_idx = 0
  458. for line in lines:
  459. line = line.strip()
  460. if not line:
  461. continue
  462. # 检查是否包含页码(行尾数字)
  463. page_match = re.search(r'(\d+)\s*$', line)
  464. if not page_match:
  465. continue
  466. page = page_match.group(1)
  467. title = re.sub(r'[\.\s]+\d+\s*$', '', line).strip()
  468. # 根据内容特征判断层级
  469. is_chapter = any(kw in title for kw in ['编制依据', '工程概况', '施工计划',
  470. '施工工艺', '安全保证', '质量保证',
  471. '环境保证', '人员配备', '验收',
  472. '其他资料'])
  473. if is_chapter or len(chapters) == 0:
  474. idx += 1
  475. section_idx = 0 # 重置节计数
  476. chapters.append({
  477. "index": idx,
  478. "title": title,
  479. "page": page,
  480. "original": line,
  481. "subsections": []
  482. })
  483. else:
  484. # 作为上一章的节,保留原标题
  485. if chapters:
  486. section_idx += 1
  487. chapters[-1]["subsections"].append({
  488. "title": title,
  489. "page": page,
  490. "level": 2,
  491. "original": line
  492. })
  493. return chapters
  494. def _number_to_chinese(self, num: int) -> str:
  495. """阿拉伯数字转中文数字"""
  496. chinese_nums = {
  497. 1: '一', 2: '二', 3: '三', 4: '四', 5: '五',
  498. 6: '六', 7: '七', 8: '八', 9: '九', 10: '十',
  499. 11: '十一', 12: '十二', 13: '十三', 14: '十四', 15: '十五'
  500. }
  501. return chinese_nums.get(num, str(num))
  502. def extract_catalog_from_pdf(
  503. file_content: bytes,
  504. model_path: str = None,
  505. ocr_api_url: str = "http://183.220.37.46:25429/v1/chat/completions",
  506. ocr_api_key: str = "",
  507. progress_callback=None
  508. ) -> Optional[Dict[str, Any]]:
  509. """
  510. 便捷函数:从 PDF 提取目录结构
  511. Returns:
  512. {"chapters": [...], "total_chapters": N} 或 None
  513. """
  514. extractor = TOCCatalogExtractor(
  515. model_path=model_path,
  516. ocr_api_url=ocr_api_url,
  517. ocr_api_key=ocr_api_key
  518. )
  519. return extractor.detect_and_extract(file_content, progress_callback)