"""
OCR processing module - table detection and recognition.

Provides table-region detection and OCR for PDF pages:
- RapidLayout table/figure region detection
- Concurrent GLM-OCR recognition
- Back-filling OCR'd table text into the page text
"""

import base64
import io
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from typing import Dict, Any, List, Optional, Tuple, Set

import fitz
import numpy as np
import requests

from utils_test.minimal_pipeline._simple_logger import review_logger as logger

# RapidLayout is optional: without it, region detection is disabled.
try:
    from rapid_layout import RapidLayout
    RAPID_LAYOUT_AVAILABLE = True
except ImportError:
    RAPID_LAYOUT_AVAILABLE = False
    RapidLayout = None


@dataclass
class TableRegion:
    """A detected table/figure region on a PDF page."""
    page_num: int
    page: fitz.Page
    bbox: Tuple[float, float, float, float]
    score: float
    label: str = "table"  # raw layout label: table / figure


@dataclass
class OcrResult:
    """OCR outcome for one region."""
    page_num: int
    bbox: Tuple[float, float, float, float]
    score: float
    text: str
    success: bool


class OcrProcessor:
    """OCR processor: table detection and recognition."""

    # Default configuration (used as the constructor defaults below).
    MAX_SHORT_EDGE = 1024
    JPEG_QUALITY = 90
    OCR_DPI = 200
    OCR_CONFIDENCE_THRESHOLD = 0.5
    OCR_CONCURRENT_WORKERS = 20

    def __init__(
        self,
        # NOTE(review): hard-coded endpoint kept for backward compatibility;
        # consider sourcing it from configuration.
        ocr_api_url: str = "http://183.220.37.46:25429/v1/chat/completions",
        ocr_timeout: int = 600,
        ocr_api_key: str = "",
        max_short_edge: int = MAX_SHORT_EDGE,
        jpeg_quality: int = JPEG_QUALITY,
        ocr_dpi: int = OCR_DPI,
        confidence_threshold: float = OCR_CONFIDENCE_THRESHOLD,
        concurrent_workers: int = OCR_CONCURRENT_WORKERS,
    ):
        """
        Initialise the OCR processor.

        Args:
            ocr_api_url: OCR API endpoint.
            ocr_timeout: OCR request timeout in seconds.
            ocr_api_key: OCR API key; when non-empty it is sent as a Bearer token.
            max_short_edge: Maximum short-edge size (pixels) after image compression.
            jpeg_quality: JPEG compression quality.
            ocr_dpi: DPI used when rendering page regions.
            confidence_threshold: Minimum layout-detection score for a region to be kept.
            concurrent_workers: Number of worker threads for concurrent OCR requests.
        """
        self.ocr_api_url = ocr_api_url
        self.ocr_timeout = ocr_timeout
        self.ocr_api_key = ocr_api_key
        self.max_short_edge = max_short_edge
        self.jpeg_quality = jpeg_quality
        self.ocr_dpi = ocr_dpi
        self.confidence_threshold = confidence_threshold
        self.concurrent_workers = concurrent_workers
        self._layout_engine: Optional[Any] = None

        if not RAPID_LAYOUT_AVAILABLE:
            logger.warning("RapidLayout 未安装,表格检测功能不可用")

    def is_available(self) -> bool:
        """Return whether region detection is available (RapidLayout importable)."""
        return RAPID_LAYOUT_AVAILABLE

    def _get_layout_engine(self) -> Optional[Any]:
        """Lazily create and return the RapidLayout engine (None if unavailable)."""
        if self._layout_engine is None and RAPID_LAYOUT_AVAILABLE:
            self._layout_engine = RapidLayout()
        return self._layout_engine

    def detect_table_regions(
        self,
        page: fitz.Page,
        page_num: int,
        clip_box: fitz.Rect
    ) -> List[Tuple[Tuple[float, float, float, float], float, str]]:
        """
        Detect table and figure regions inside ``clip_box`` of a page.

        Detection is best-effort: any rendering or layout-analysis failure is
        logged as a warning and an empty (or partial) list is returned.

        Args:
            page: PDF page object.
            page_num: Page number (used for logging only).
            clip_box: Area of the page to render and analyse.

        Returns:
            List of ``((x1, y1, x2, y2), score, label)`` in PDF coordinates,
            where ``label`` is ``"table"`` or ``"figure"``.
        """
        table_regions: List[Tuple[Tuple[float, float, float, float], float, str]] = []

        if not RAPID_LAYOUT_AVAILABLE:
            return table_regions

        layout_engine = self._get_layout_engine()
        if layout_engine is None:
            return table_regions

        try:
            # Render the clipped page area to a height x width x 3 uint8 array.
            pix = page.get_pixmap(dpi=self.ocr_dpi, clip=clip_box)
            img = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, 3)

            layout_output = layout_engine(img)

            if hasattr(layout_output, 'boxes') and hasattr(layout_output, 'class_names'):
                # Factors mapping rendered-image pixels back to PDF units.
                scale_x = clip_box.width / img.shape[1]
                scale_y = clip_box.height / img.shape[0]

                table_count = 0
                figure_count = 0
                for box, label, score in zip(
                    layout_output.boxes, layout_output.class_names, layout_output.scores
                ):
                    if label not in ("table", "figure") or not score > self.confidence_threshold:
                        continue

                    # Convert to PDF coordinates; cast to plain floats so
                    # numpy scalar types do not leak to callers.
                    pdf_bbox = (
                        float(clip_box.x0 + box[0] * scale_x),
                        float(clip_box.y0 + box[1] * scale_y),
                        float(clip_box.x0 + box[2] * scale_x),
                        float(clip_box.y0 + box[3] * scale_y),
                    )
                    table_regions.append((pdf_bbox, float(score), label))
                    if label == "table":
                        table_count += 1
                    else:
                        figure_count += 1

                if table_count or figure_count:
                    logger.info(f" [YOLO] 第{page_num}页: table={table_count}, figure={figure_count}")

        except Exception as e:
            logger.warning(f" 第 {page_num} 页: 版面分析失败 ({e})")

        return table_regions

    def process_ocr_concurrent(
        self,
        regions: List[TableRegion],
        progress_callback=None
    ) -> List[OcrResult]:
        """
        Run OCR over all regions, issuing the HTTP requests concurrently.

        Page regions are rendered on the calling thread because PyMuPDF is
        not thread-safe; worker threads only compress/encode the image and
        call the OCR service.

        Args:
            regions: Regions to recognise.
            progress_callback: Optional callable receiving ``(completed, total)``;
                invoked after every 5th finished region and after the last one.

        Returns:
            One ``OcrResult`` per region, in completion order (not input order).
            Failed regions have ``success=False`` and empty ``text``.
        """
        results: List[OcrResult] = []
        total = len(regions)
        completed = 0
        table_ok_count = 0
        non_table_count = 0

        table_total = sum(1 for r in regions if r.label == "table")
        figure_total = sum(1 for r in regions if r.label == "figure")
        logger.info(f"[OCR] 开始并发识别: table={table_total}, figure={figure_total}, workers={self.concurrent_workers}")

        def finish(region: TableRegion, text: str, success: bool) -> None:
            """Record one finished region and push progress when due."""
            nonlocal completed, table_ok_count, non_table_count
            completed += 1
            if success and text.strip():
                table_ok_count += 1
            else:
                non_table_count += 1
            results.append(OcrResult(
                page_num=region.page_num,
                bbox=region.bbox,
                score=region.score,
                text=text,
                success=success,
            ))
            if progress_callback and (completed % 5 == 0 or completed == total):
                progress_callback(completed, total)

        with ThreadPoolExecutor(max_workers=max(1, self.concurrent_workers)) as executor:
            future_to_region = {}
            for region in regions:
                # Render here (calling thread); only network work goes to the pool.
                try:
                    img_bytes = self._render_region_jpeg(region.page, region.bbox)
                    page_label = region.page.number + 1
                except Exception as e:
                    logger.error(f" 第 {region.page_num} 页 {region.label} OCR 失败: {e}")
                    finish(region, "", False)
                    continue
                future = executor.submit(self._ocr_image_bytes, img_bytes, page_label)
                future_to_region[future] = region

            for future in as_completed(future_to_region):
                region = future_to_region[future]
                try:
                    text = future.result()
                except Exception as e:
                    logger.error(f" 第 {region.page_num} 页 {region.label} OCR 失败: {e}")
                    finish(region, "", False)
                else:
                    finish(region, text, True)

        logger.info(f"[OCR] 完成: table={table_total}, figure={figure_total}, "
                    f"有效表格={table_ok_count}, Non-table/失败={non_table_count}")

        return results

    def _render_region_jpeg(
        self,
        page: fitz.Page,
        bbox: Tuple[float, float, float, float]
    ) -> bytes:
        """Render ``bbox`` of ``page`` at ``self.ocr_dpi`` and return JPEG bytes."""
        pix = page.get_pixmap(dpi=self.ocr_dpi, clip=fitz.Rect(bbox))
        return pix.tobytes("jpeg")

    def _ocr_image_bytes(
        self,
        img_bytes: bytes,
        page_label: int,
        max_retries: int = 3
    ) -> str:
        """
        Send an already-rendered image to GLM-OCR, retrying with exponential backoff.

        Does not touch any PyMuPDF object, so it is safe to run in worker threads.

        Args:
            img_bytes: Encoded image bytes of the region.
            page_label: 1-based page number, used in log messages only.
            max_retries: Maximum number of attempts (at least one is always made).

        Returns:
            Recognised text (empty string when the model reports no table).

        Raises:
            Exception: The error of the last attempt when all attempts fail.
        """
        compressed = self._compress_image(img_bytes)
        img_base64 = base64.b64encode(compressed).decode('utf-8')

        payload = {
            "model": "GLM-OCR",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "判断图片中是否包含表格。"
                                    "- 若包含表格:用 Markdown 表格格式提取内容,保持行列对齐。"
                                    "- 若不包含任何表格:只输出 Non-table。"
                                    "只输出结果,不要解释。"
                        },
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"}
                        }
                    ]
                }
            ],
            "max_tokens": 2048,
            "temperature": 0.1
        }

        headers = {"Content-Type": "application/json"}
        if self.ocr_api_key:
            headers["Authorization"] = f"Bearer {self.ocr_api_key}"

        # Always try at least once so a non-positive max_retries cannot fall
        # through the loop without either a result or a real exception.
        attempts = max(1, max_retries)
        for attempt in range(attempts):
            try:
                response = requests.post(
                    self.ocr_api_url,
                    headers=headers,
                    json=payload,
                    timeout=self.ocr_timeout
                )
                response.raise_for_status()
                return self._extract_ocr_content(response.json())
            except Exception as e:
                # Deliberately broad: this is the retry boundary, and the
                # final failure is re-raised to the caller.
                if attempt >= attempts - 1:
                    logger.error(f" 第 {page_label} 页表格 OCR 最终失败(已重试{max_retries}次): {e}")
                    raise
                # Exponential backoff: 2, 4, 8 ... seconds.
                wait_time = 2 ** (attempt + 1)
                logger.warning(f" 第 {page_label} 页表格 OCR 第 {attempt + 1} 次失败: {e}, {wait_time}秒后重试...")
                time.sleep(wait_time)

        # Unreachable: the loop either returns or re-raises.
        raise RuntimeError("OCR request loop exited without a result")

    def _ocr_table_region(
        self,
        page: fitz.Page,
        bbox: Tuple[float, float, float, float],
        max_retries: int = 3
    ) -> str:
        """
        OCR one page region with GLM-OCR, retrying with exponential backoff.

        Renders the region and performs the request in the calling thread.
        Do not call this concurrently for pages of the same document:
        PyMuPDF is not thread-safe.

        Args:
            page: PDF page object.
            bbox: Region coordinates ``(x1, y1, x2, y2)``.
            max_retries: Maximum number of attempts.

        Returns:
            Recognised text content.

        Raises:
            Exception: The error of the last attempt when all attempts fail.
        """
        img_bytes = self._render_region_jpeg(page, bbox)
        return self._ocr_image_bytes(img_bytes, page.number + 1, max_retries)

    def _compress_image(self, img_bytes: bytes) -> bytes:
        """
        Re-encode an image as RGB JPEG, downscaling so the short edge is at
        most ``self.max_short_edge``.

        Best-effort: on any failure a warning is logged and the input bytes
        are returned unchanged.

        Args:
            img_bytes: Original encoded image bytes.

        Returns:
            Compressed JPEG bytes, or ``img_bytes`` on failure.
        """
        try:
            from PIL import Image

            img = Image.open(io.BytesIO(img_bytes))

            if img.mode in ('RGBA', 'LA', 'P'):
                # Flatten transparency onto a white background.
                background = Image.new('RGB', img.size, (255, 255, 255))
                if img.mode == 'P':
                    img = img.convert('RGBA')
                if img.mode in ('RGBA', 'LA'):
                    background.paste(img, mask=img.split()[-1])
                img = background
            elif img.mode != 'RGB':
                img = img.convert('RGB')

            min_edge = min(img.size)
            if min_edge > self.max_short_edge:
                ratio = self.max_short_edge / min_edge
                # Clamp to 1 px so truncation can never yield a zero dimension.
                new_size = (
                    max(1, int(img.width * ratio)),
                    max(1, int(img.height * ratio)),
                )
                img = img.resize(new_size, Image.Resampling.LANCZOS)

            buffer = io.BytesIO()
            img.save(buffer, format='JPEG', quality=self.jpeg_quality, optimize=True)
            return buffer.getvalue()

        except Exception as e:
            logger.warning(f"图片压缩失败,使用原图: {e}")
            return img_bytes

    def _extract_ocr_content(self, result: Dict) -> str:
        """
        Extract the text from an OCR chat-completion response, converting
        HTML tables to Markdown.

        Args:
            result: Decoded OCR API response.

        Returns:
            Extracted text. Empty string when the response has no content or
            the model answered ``Non-table``.

        Raises:
            ValueError: If the message content is present but not a string.
        """
        content: Any = ""
        choices = result.get("choices") if isinstance(result, dict) else None
        if isinstance(choices, list) and choices:
            first_choice = choices[0]
            message = first_choice.get("message") if isinstance(first_choice, dict) else None
            if isinstance(message, dict):
                # A null content must not leak out as None: callers expect str.
                content = message.get("content") or ""

        if not isinstance(content, str):
            raise ValueError(f"unexpected OCR content type: {type(content).__name__}")

        # The model judged the region to be a non-table: return an empty
        # string so downstream processing skips it.
        if content.strip().startswith("Non-table"):
            return ""

        # Content that looks like HTML is converted to Markdown; on failure
        # the raw content is kept.
        if "<" in content and ">" in content:
            try:
                from utils_test.minimal_pipeline._html_to_md import convert_html_to_markdown
                content = convert_html_to_markdown(content)
            except Exception as e:
                logger.debug(f"HTML 转 Markdown 失败,保留原始内容: {e}")

        return content

    def replace_table_regions(
        self,
        page: fitz.Page,
        original_text: str,
        ocr_results: List[Dict],
        clip_box: fitz.Rect
    ) -> str:
        """
        Rebuild the page text with table regions replaced by their OCR text.

        Text blocks whose area overlaps an OCR region by more than 50% are
        dropped; the remaining blocks and the OCR texts are then merged in
        reading order (top-to-bottom, then left-to-right), each OCR text
        being positioned at the top-left corner of its region.

        Args:
            page: PDF page object.
            original_text: Original page text, returned as a fallback.
            ocr_results: OCR results; each dict must provide ``bbox``
                ``(x0, y0, x1, y1)`` and ``ocr_text``.
            clip_box: Only text blocks vertically inside this box are used.

        Returns:
            The merged text, or ``original_text`` when ``ocr_results`` is
            empty or the merged text is empty.
        """
        if not ocr_results:
            return original_text

        # Collect the page's text blocks with their coordinates.
        text_blocks = []
        for block in page.get_text("blocks"):
            x0, y0, x1, y1, text, _, block_type = block
            # Image blocks carry only an "<image: ...>" placeholder, not text.
            if block_type != 0:
                continue
            # Only keep blocks vertically inside the clip area.
            if y0 >= clip_box.y0 and y1 <= clip_box.y1:
                text_blocks.append({
                    "bbox": (x0, y0, x1, y1),
                    "text": text.strip(),
                })

        # Mark blocks that mostly lie inside some OCR region.
        replaced_indices: Set[int] = set()
        for ocr_result in ocr_results:
            rx0, ry0, rx1, ry1 = ocr_result["bbox"]
            for idx, block in enumerate(text_blocks):
                if idx in replaced_indices:
                    continue
                bx0, by0, bx1, by1 = block["bbox"]
                overlap_x = max(0, min(bx1, rx1) - max(bx0, rx0))
                overlap_y = max(0, min(by1, ry1) - max(by0, ry0))
                block_area = (bx1 - bx0) * (by1 - by0)
                if block_area > 0 and (overlap_x * overlap_y) / block_area > 0.5:
                    replaced_indices.add(idx)

        # Merge surviving blocks and OCR texts by position. Entries are
        # (top, left, text); the sort is stable, so ties keep insertion order.
        entries: List[Tuple[float, float, str]] = [
            (block["bbox"][1], block["bbox"][0], block["text"])
            for idx, block in enumerate(text_blocks)
            if idx not in replaced_indices
        ]
        entries.extend(
            (ocr_result["bbox"][1], ocr_result["bbox"][0], ocr_result["ocr_text"])
            for ocr_result in ocr_results
        )
        entries.sort(key=lambda entry: (entry[0], entry[1]))

        merged = "\n".join(text for _, _, text in entries).strip()
        return merged or original_text