_ocr_processor.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485
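"""
OCR processing module: table detection and recognition.

Provides table-region detection and OCR recognition for PDF pages, supporting:
- RapidLayout table-region detection
- Concurrent GLM-OCR recognition
- Writing recognized table text back into the page text
"""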
import base64
import io
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

import fitz
import numpy as np
import requests

from utils_test.minimal_pipeline._simple_logger import review_logger as logger
  18. # 尝试导入 RapidLayout
  19. try:
  20. from rapid_layout import RapidLayout
  21. RAPID_LAYOUT_AVAILABLE = True
  22. except ImportError:
  23. RAPID_LAYOUT_AVAILABLE = False
  24. RapidLayout = None
  25. @dataclass
  26. class TableRegion:
  27. """表格区域信息"""
  28. page_num: int
  29. page: fitz.Page
  30. bbox: Tuple[float, float, float, float]
  31. score: float
  32. label: str = "table" # YOLO 原始标签: table / figure
  33. @dataclass
  34. class OcrResult:
  35. """OCR 结果"""
  36. page_num: int
  37. bbox: Tuple[float, float, float, float]
  38. score: float
  39. text: str
  40. success: bool
  41. class OcrProcessor:
  42. """OCR 处理器:表格检测与识别"""
  43. # 默认配置
  44. MAX_SHORT_EDGE = 1024
  45. JPEG_QUALITY = 90
  46. OCR_DPI = 200
  47. OCR_CONFIDENCE_THRESHOLD = 0.5
  48. OCR_CONCURRENT_WORKERS = 20
  49. def __init__(
  50. self,
  51. ocr_api_url: str = "http://183.220.37.46:25429/v1/chat/completions",
  52. ocr_timeout: int = 600,
  53. ocr_api_key: str = "",
  54. max_short_edge: int = 1024,
  55. jpeg_quality: int = 90,
  56. ocr_dpi: int = 200,
  57. confidence_threshold: float = 0.5,
  58. concurrent_workers: int = 20,
  59. ):
  60. """
  61. 初始化 OCR 处理器
  62. Args:
  63. ocr_api_url: OCR API 地址
  64. ocr_timeout: OCR 请求超时时间(秒)
  65. ocr_api_key: OCR API 密钥
  66. max_short_edge: 图片压缩后短边最大尺寸
  67. jpeg_quality: JPEG 压缩质量
  68. ocr_dpi: OCR 渲染 DPI
  69. confidence_threshold: 表格检测置信度阈值
  70. concurrent_workers: OCR 并发工作线程数
  71. """
  72. self.ocr_api_url = ocr_api_url
  73. self.ocr_timeout = ocr_timeout
  74. self.ocr_api_key = ocr_api_key
  75. self.max_short_edge = max_short_edge
  76. self.jpeg_quality = jpeg_quality
  77. self.ocr_dpi = ocr_dpi
  78. self.confidence_threshold = confidence_threshold
  79. self.concurrent_workers = concurrent_workers
  80. self._layout_engine: Optional[Any] = None
  81. if not RAPID_LAYOUT_AVAILABLE:
  82. logger.warning("RapidLayout 未安装,表格检测功能不可用")
  83. def is_available(self) -> bool:
  84. """检查 OCR 功能是否可用"""
  85. return RAPID_LAYOUT_AVAILABLE
  86. def _get_layout_engine(self) -> Optional[Any]:
  87. """延迟初始化 RapidLayout"""
  88. if self._layout_engine is None and RAPID_LAYOUT_AVAILABLE:
  89. self._layout_engine = RapidLayout()
  90. return self._layout_engine
  91. def detect_table_regions(
  92. self,
  93. page: fitz.Page,
  94. page_num: int,
  95. clip_box: fitz.Rect
  96. ) -> List[Tuple[Tuple[float, float, float, float], float]]:
  97. """
  98. 检测页面中的表格区域
  99. Args:
  100. page: PDF 页面对象
  101. page_num: 页码(用于日志)
  102. clip_box: 裁剪区域
  103. Returns:
  104. 列表,元素为 ((x1, y1, x2, y2), score)
  105. """
  106. table_regions: List[Tuple[Tuple[float, float, float, float], float, str]] = []
  107. if not RAPID_LAYOUT_AVAILABLE:
  108. return table_regions
  109. layout_engine = self._get_layout_engine()
  110. if layout_engine is None:
  111. return table_regions
  112. # 渲染页面(裁剪区域)
  113. pix = page.get_pixmap(dpi=self.ocr_dpi, clip=clip_box)
  114. img = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, 3)
  115. try:
  116. layout_output = layout_engine(img)
  117. # 解析版面结果
  118. if hasattr(layout_output, 'boxes') and hasattr(layout_output, 'class_names'):
  119. # 获取缩放比例
  120. scale_x = clip_box.width / img.shape[1]
  121. scale_y = clip_box.height / img.shape[0]
  122. table_count = 0
  123. figure_count = 0
  124. for box, label, score in zip(layout_output.boxes, layout_output.class_names, layout_output.scores):
  125. if label in ("table", "figure") and score > self.confidence_threshold:
  126. # 转换为 PDF 坐标
  127. pdf_x1 = clip_box.x0 + box[0] * scale_x
  128. pdf_y1 = clip_box.y0 + box[1] * scale_y
  129. pdf_x2 = clip_box.x0 + box[2] * scale_x
  130. pdf_y2 = clip_box.y0 + box[3] * scale_y
  131. table_regions.append(((pdf_x1, pdf_y1, pdf_x2, pdf_y2), score, label))
  132. if label == "table":
  133. table_count += 1
  134. else:
  135. figure_count += 1
  136. if table_count or figure_count:
  137. logger.info(f" [YOLO] 第{page_num}页: table={table_count}, figure={figure_count}")
  138. except Exception as e:
  139. logger.warning(f" 第 {page_num} 页: 版面分析失败 ({e})")
  140. return table_regions
  141. def process_ocr_concurrent(
  142. self,
  143. regions: List[TableRegion],
  144. progress_callback=None
  145. ) -> List[OcrResult]:
  146. """
  147. 同步并发处理 OCR
  148. Args:
  149. regions: 表格区域列表
  150. progress_callback: 进度回调函数,接收 (completed, total) 参数
  151. Returns:
  152. OCR 结果列表
  153. """
  154. results: List[OcrResult] = []
  155. total = len(regions)
  156. completed = 0
  157. # 统计
  158. table_total = sum(1 for r in regions if r.label == "table")
  159. figure_total = sum(1 for r in regions if r.label == "figure")
  160. logger.info(f"[OCR] 开始并发识别: table={table_total}, figure={figure_total}, workers={self.concurrent_workers}")
  161. with ThreadPoolExecutor(max_workers=self.concurrent_workers) as executor:
  162. # 提交所有任务
  163. future_to_region = {
  164. executor.submit(self._ocr_table_region, r.page, r.bbox): r
  165. for r in regions
  166. }
  167. # 处理完成的结果
  168. non_table_count = 0
  169. table_ok_count = 0
  170. for future in as_completed(future_to_region):
  171. region = future_to_region[future]
  172. completed += 1
  173. try:
  174. text = future.result()
  175. if text.strip():
  176. table_ok_count += 1
  177. else:
  178. non_table_count += 1
  179. results.append(OcrResult(
  180. page_num=region.page_num,
  181. bbox=region.bbox,
  182. score=region.score,
  183. text=text,
  184. success=True,
  185. ))
  186. except Exception as e:
  187. non_table_count += 1
  188. logger.error(f" 第 {region.page_num} 页 {region.label} OCR 失败: {e}")
  189. results.append(OcrResult(
  190. page_num=region.page_num,
  191. bbox=region.bbox,
  192. score=region.score,
  193. text="",
  194. success=False,
  195. ))
  196. # 每完成5个或最后一个时推送进度
  197. if progress_callback and (completed % 5 == 0 or completed == total):
  198. progress_callback(completed, total)
  199. logger.info(f"[OCR] 完成: table={table_total}, figure={figure_total}, "
  200. f"有效表格={table_ok_count}, Non-table/失败={non_table_count}")
  201. return results
  202. def _ocr_table_region(
  203. self,
  204. page: fitz.Page,
  205. bbox: Tuple[float, float, float, float],
  206. max_retries: int = 3
  207. ) -> str:
  208. """
  209. 对指定区域进行 OCR 识别(使用 GLM-OCR),支持指数退避重试
  210. Args:
  211. page: PDF 页面对象
  212. bbox: 区域坐标 (x1, y1, x2, y2)
  213. max_retries: 最大重试次数
  214. Returns:
  215. 识别的文本内容
  216. """
  217. # 渲染指定区域
  218. rect = fitz.Rect(bbox)
  219. pix = page.get_pixmap(dpi=self.ocr_dpi, clip=rect)
  220. img_bytes = pix.tobytes("jpeg")
  221. # 压缩图片
  222. compressed = self._compress_image(img_bytes)
  223. img_base64 = base64.b64encode(compressed).decode('utf-8')
  224. # 请求 OCR
  225. payload = {
  226. "model": "GLM-OCR",
  227. "messages": [
  228. {
  229. "role": "user",
  230. "content": [
  231. {
  232. "type": "text",
  233. "text": "判断图片中是否包含表格。"
  234. "- 若包含表格:用 Markdown 表格格式提取内容,保持行列对齐。"
  235. "- 若不包含任何表格:只输出 Non-table。"
  236. "只输出结果,不要解释。"
  237. },
  238. {
  239. "type": "image_url",
  240. "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"}
  241. }
  242. ]
  243. }
  244. ],
  245. "max_tokens": 2048,
  246. "temperature": 0.1
  247. }
  248. headers = {"Content-Type": "application/json"}
  249. if self.ocr_api_key:
  250. headers["Authorization"] = f"Bearer {self.ocr_api_key}"
  251. # 指数退避重试
  252. last_error = None
  253. for attempt in range(max_retries):
  254. try:
  255. response = requests.post(
  256. self.ocr_api_url,
  257. headers=headers,
  258. json=payload,
  259. timeout=self.ocr_timeout
  260. )
  261. response.raise_for_status()
  262. result = response.json()
  263. return self._extract_ocr_content(result)
  264. except Exception as e:
  265. last_error = e
  266. if attempt < max_retries - 1:
  267. # 指数退避: 2, 4, 8 秒
  268. wait_time = 2 ** (attempt + 1)
  269. logger.warning(f" 第 {page.number + 1} 页表格 OCR 第 {attempt + 1} 次失败: {e}, {wait_time}秒后重试...")
  270. time.sleep(wait_time)
  271. else:
  272. logger.error(f" 第 {page.number + 1} 页表格 OCR 最终失败(已重试{max_retries}次): {e}")
  273. # 所有重试都失败,抛出最后一个错误
  274. raise last_error
  275. def _compress_image(self, img_bytes: bytes) -> bytes:
  276. """
  277. 压缩图片
  278. Args:
  279. img_bytes: 原始图片字节
  280. Returns:
  281. 压缩后的图片字节
  282. """
  283. try:
  284. from PIL import Image
  285. img = Image.open(io.BytesIO(img_bytes))
  286. if img.mode in ('RGBA', 'LA', 'P'):
  287. background = Image.new('RGB', img.size, (255, 255, 255))
  288. if img.mode == 'P':
  289. img = img.convert('RGBA')
  290. if img.mode in ('RGBA', 'LA'):
  291. background.paste(img, mask=img.split()[-1])
  292. img = background
  293. elif img.mode != 'RGB':
  294. img = img.convert('RGB')
  295. min_edge = min(img.size)
  296. if min_edge > self.max_short_edge:
  297. ratio = self.max_short_edge / min_edge
  298. new_size = (int(img.width * ratio), int(img.height * ratio))
  299. img = img.resize(new_size, Image.Resampling.LANCZOS)
  300. buffer = io.BytesIO()
  301. img.save(buffer, format='JPEG', quality=self.jpeg_quality, optimize=True)
  302. return buffer.getvalue()
  303. except Exception as e:
  304. logger.warning(f"图片压缩失败,使用原图: {e}")
  305. return img_bytes
  306. def _extract_ocr_content(self, result: Dict) -> str:
  307. """
  308. 从 OCR 响应提取内容,并将 HTML 表格转换为 Markdown
  309. Args:
  310. result: OCR API 响应
  311. Returns:
  312. 提取的文本内容
  313. """
  314. content = ""
  315. if "choices" in result and isinstance(result["choices"], list):
  316. if len(result["choices"]) > 0:
  317. message = result["choices"][0].get("message", {})
  318. content = message.get("content", "")
  319. # GLM 判定为非表格区域,返回空字符串,下游自然跳过
  320. if content and content.strip().startswith("Non-table"):
  321. return ""
  322. # 如果内容包含 HTML 标签,转换为 Markdown
  323. if content and "<" in content and ">" in content:
  324. try:
  325. from utils_test.minimal_pipeline._html_to_md import convert_html_to_markdown
  326. content = convert_html_to_markdown(content)
  327. except Exception as e:
  328. logger.debug(f"HTML 转 Markdown 失败,保留原始内容: {e}")
  329. return content
  330. def replace_table_regions(
  331. self,
  332. page: fitz.Page,
  333. original_text: str,
  334. ocr_results: List[Dict],
  335. clip_box: fitz.Rect
  336. ) -> str:
  337. """
  338. 用 OCR 结果替换原始文本中的表格区域
  339. Args:
  340. page: PDF 页面对象
  341. original_text: 原始文本
  342. ocr_results: OCR 结果列表,每个元素包含 region_index, bbox, score, ocr_text
  343. clip_box: 裁剪区域
  344. Returns:
  345. 替换后的文本
  346. """
  347. if not ocr_results:
  348. return original_text
  349. # 获取页面上的文本块及其坐标
  350. text_blocks = []
  351. for block in page.get_text("blocks"):
  352. x0, y0, x1, y1, text, _, _ = block
  353. # 只考虑裁剪区域内的文本
  354. if y0 >= clip_box.y0 and y1 <= clip_box.y1:
  355. text_blocks.append({
  356. "bbox": (x0, y0, x1, y1),
  357. "text": text.strip(),
  358. })
  359. # 按 Y 坐标排序
  360. text_blocks.sort(key=lambda b: (b["bbox"][1], b["bbox"][0]))
  361. # 找出属于表格区域的文本块
  362. replaced_indices: Set[int] = set()
  363. for ocr_result in ocr_results:
  364. bbox = ocr_result["bbox"]
  365. rx0, ry0, rx1, ry1 = bbox
  366. for idx, block in enumerate(text_blocks):
  367. if idx in replaced_indices:
  368. continue
  369. bx0, by0, bx1, by1 = block["bbox"]
  370. # 检查重叠
  371. overlap_x = max(0, min(bx1, rx1) - max(bx0, rx0))
  372. overlap_y = max(0, min(by1, ry1) - max(by0, ry0))
  373. overlap_area = overlap_x * overlap_y
  374. block_area = (bx1 - bx0) * (by1 - by0)
  375. if block_area > 0 and overlap_area / block_area > 0.5:
  376. replaced_indices.add(idx)
  377. # 构建新文本
  378. result_parts: List[str] = []
  379. last_idx = 0
  380. for ocr_result in sorted(ocr_results, key=lambda r: r["bbox"][1]):
  381. bbox = ocr_result["bbox"]
  382. rx0, ry0, rx1, ry1 = bbox
  383. # 找到该表格区域之前的文本
  384. region_start_idx = None
  385. for idx, block in enumerate(text_blocks):
  386. if idx in replaced_indices:
  387. bx0, by0, bx1, by1 = block["bbox"]
  388. if (bx0 >= rx0 - 5 and bx1 <= rx1 + 5 and
  389. by0 >= ry0 - 5 and by1 <= ry1 + 5):
  390. if region_start_idx is None:
  391. region_start_idx = idx
  392. last_idx = idx + 1
  393. if region_start_idx is not None:
  394. # 添加表格前的非表格文本
  395. for idx in range(last_idx - (last_idx - region_start_idx), region_start_idx):
  396. if idx not in replaced_indices and idx < len(text_blocks):
  397. result_parts.append(text_blocks[idx]["text"])
  398. result_parts.append("\n")
  399. # 添加 OCR 结果
  400. result_parts.append(ocr_result["ocr_text"])
  401. result_parts.append("\n")
  402. # 添加剩余文本
  403. for idx in range(last_idx, len(text_blocks)):
  404. if idx not in replaced_indices:
  405. result_parts.append(text_blocks[idx]["text"])
  406. result_parts.append("\n")
  407. return "".join(result_parts).strip() or original_text