| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485 |
- """
- OCR 处理模块 - 表格检测与识别
- 提供 PDF 表格区域检测和 OCR 识别功能,支持:
- - RapidLayout 表格区域检测
- - GLM-OCR 并发识别
- - 表格文本替换回填
- """
- import base64
- import io
- import time
- from concurrent.futures import ThreadPoolExecutor, as_completed
- from dataclasses import dataclass
- from typing import Dict, Any, List, Optional, Tuple, Set
- import fitz
- import numpy as np
- import requests
- from utils_test.minimal_pipeline._simple_logger import review_logger as logger
# Optional dependency: RapidLayout performs the layout (table/figure)
# detection. If it cannot be imported, the module still loads; the name
# RapidLayout is bound to None and the availability flag is False.
try:
    from rapid_layout import RapidLayout
except ImportError:
    RapidLayout = None
    RAPID_LAYOUT_AVAILABLE = False
else:
    RAPID_LAYOUT_AVAILABLE = True
@dataclass()
class TableRegion:
    """One detected layout region of a PDF page, queued for OCR.

    Positional field order: page_num, page, bbox, score, label.
    """

    # Page number as supplied by the caller (used for logging/reporting).
    page_num: int
    # Page object the region belongs to.
    page: fitz.Page
    # Region rectangle (x1, y1, x2, y2).
    bbox: Tuple[float, float, float, float]
    # Detector confidence for this region.
    score: float
    # Raw YOLO label reported by the detector: "table" or "figure".
    label: str = "table"
@dataclass()
class OcrResult:
    """Outcome of OCR for a single region.

    Positional field order: page_num, bbox, score, text, success.
    """

    # Page number copied from the originating region.
    page_num: int
    # Region rectangle (x1, y1, x2, y2) copied from the originating region.
    bbox: Tuple[float, float, float, float]
    # Detector confidence copied from the originating region.
    score: float
    # Recognized text; empty when the region held no table or OCR failed.
    text: str
    # False when the OCR request for this region raised an error.
    success: bool
class OcrProcessor:
    """Table detection and OCR for PDF pages.

    RapidLayout locates "table"/"figure" regions on a page. Each region is
    rendered to JPEG and sent to an OpenAI-compatible chat-completions
    endpoint (model "GLM-OCR"). The recognized text can then be spliced back
    into the page text with ``replace_table_regions``.

    PyMuPDF objects are not thread-safe, so all page rendering happens on the
    calling thread. Only the HTTP requests run in the worker pool.
    """

    # Default configuration; also used as the constructor defaults.
    MAX_SHORT_EDGE = 1024
    JPEG_QUALITY = 90
    OCR_DPI = 200
    OCR_CONFIDENCE_THRESHOLD = 0.5
    OCR_CONCURRENT_WORKERS = 20

    def __init__(
        self,
        # NOTE(review): hard-coded endpoint default; prefer injecting it from configuration.
        ocr_api_url: str = "http://183.220.37.46:25429/v1/chat/completions",
        ocr_timeout: int = 600,
        ocr_api_key: str = "",
        max_short_edge: int = MAX_SHORT_EDGE,
        jpeg_quality: int = JPEG_QUALITY,
        ocr_dpi: int = OCR_DPI,
        confidence_threshold: float = OCR_CONFIDENCE_THRESHOLD,
        concurrent_workers: int = OCR_CONCURRENT_WORKERS,
    ):
        """Create an OCR processor.

        Args:
            ocr_api_url: Chat-completions URL of the OCR service.
            ocr_timeout: Per-request timeout in seconds.
            ocr_api_key: Bearer token; no Authorization header is sent when empty.
            max_short_edge: Images whose short edge exceeds this are downscaled.
            jpeg_quality: JPEG quality used when re-encoding images.
            ocr_dpi: Rendering DPI for both layout detection and OCR.
            confidence_threshold: Detector score a region must exceed to be kept.
            concurrent_workers: Size of the OCR request thread pool (at least 1 is used).
        """
        self.ocr_api_url = ocr_api_url
        self.ocr_timeout = ocr_timeout
        self.ocr_api_key = ocr_api_key
        self.max_short_edge = max_short_edge
        self.jpeg_quality = jpeg_quality
        self.ocr_dpi = ocr_dpi
        self.confidence_threshold = confidence_threshold
        self.concurrent_workers = concurrent_workers
        # Created lazily by _get_layout_engine().
        self._layout_engine: Optional[Any] = None
        if not RAPID_LAYOUT_AVAILABLE:
            logger.warning("RapidLayout 未安装,表格检测功能不可用")

    def is_available(self) -> bool:
        """Return True when RapidLayout could be imported (table detection usable)."""
        return RAPID_LAYOUT_AVAILABLE

    def _get_layout_engine(self) -> Optional[Any]:
        """Return the RapidLayout engine, creating it on first use; None if unavailable."""
        if self._layout_engine is None and RAPID_LAYOUT_AVAILABLE:
            self._layout_engine = RapidLayout()
        return self._layout_engine

    def detect_table_regions(
        self,
        page: fitz.Page,
        page_num: int,
        clip_box: fitz.Rect
    ) -> List[Tuple[Tuple[float, float, float, float], float, str]]:
        """Detect table and figure regions inside ``clip_box`` of a page.

        Args:
            page: PDF page to analyse.
            page_num: Page number, used only in log messages.
            clip_box: Area of the page to render and analyse.

        Returns:
            A list of ``((x1, y1, x2, y2), score, label)`` tuples. The box is
            in PDF page coordinates, ``score`` exceeds ``confidence_threshold``
            and ``label`` is ``"table"`` or ``"figure"``. The list is empty
            when RapidLayout is unavailable or layout analysis fails.
        """
        table_regions: List[Tuple[Tuple[float, float, float, float], float, str]] = []
        if not RAPID_LAYOUT_AVAILABLE:
            return table_regions
        layout_engine = self._get_layout_engine()
        if layout_engine is None:
            return table_regions

        # Render only the clipped area. The reshape assumes an RGB pixmap
        # without alpha, which is what get_pixmap produces by default.
        pix = page.get_pixmap(dpi=self.ocr_dpi, clip=clip_box)
        img = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, 3)

        try:
            layout_output = layout_engine(img)
            boxes = getattr(layout_output, "boxes", None)
            class_names = getattr(layout_output, "class_names", None)
            scores = getattr(layout_output, "scores", None)
            # A missing/None field means "nothing detected", not a failure.
            if boxes is not None and class_names is not None and scores is not None:
                # Factors converting rendered pixels back to PDF units.
                scale_x = clip_box.width / img.shape[1]
                scale_y = clip_box.height / img.shape[0]
                table_count = 0
                figure_count = 0
                for box, label, score in zip(boxes, class_names, scores):
                    if label in ("table", "figure") and score > self.confidence_threshold:
                        pdf_bbox = (
                            float(clip_box.x0 + box[0] * scale_x),
                            float(clip_box.y0 + box[1] * scale_y),
                            float(clip_box.x0 + box[2] * scale_x),
                            float(clip_box.y0 + box[3] * scale_y),
                        )
                        table_regions.append((pdf_bbox, float(score), label))
                        if label == "table":
                            table_count += 1
                        else:
                            figure_count += 1
                if table_count or figure_count:
                    logger.info(f" [YOLO] 第{page_num}页: table={table_count}, figure={figure_count}")
        except Exception as e:
            logger.warning(f" 第 {page_num} 页: 版面分析失败 ({e})")
        return table_regions

    def process_ocr_concurrent(
        self,
        regions: List[TableRegion],
        progress_callback=None
    ) -> List[OcrResult]:
        """OCR all regions using a thread pool for the HTTP requests.

        Args:
            regions: Regions to recognize.
            progress_callback: Optional ``callback(completed, total)``. It is
                called after every 5th finished region and after the last one.

        Returns:
            One ``OcrResult`` per input region, in the same order as ``regions``.
            Regions whose rendering or request failed get ``success=False``
            and empty text.
        """
        total = len(regions)
        table_total = sum(1 for r in regions if r.label == "table")
        figure_total = sum(1 for r in regions if r.label == "figure")
        logger.info(f"[OCR] 开始并发识别: table={table_total}, figure={figure_total}, workers={self.concurrent_workers}")

        # One slot per input region so output order matches input order.
        results: List[Optional[OcrResult]] = [None] * total
        completed = 0
        table_ok_count = 0
        non_table_count = 0

        def record(idx: int, text: Optional[str], error: Optional[BaseException]) -> None:
            """Store the outcome for regions[idx], update counters and report progress."""
            nonlocal completed, table_ok_count, non_table_count
            region = regions[idx]
            completed += 1
            if error is None:
                text = text or ""
                if text.strip():
                    table_ok_count += 1
                else:
                    non_table_count += 1
            else:
                text = ""
                non_table_count += 1
                logger.error(f" 第 {region.page_num} 页 {region.label} OCR 失败: {error}")
            results[idx] = OcrResult(
                page_num=region.page_num,
                bbox=region.bbox,
                score=region.score,
                text=text,
                success=error is None,
            )
            if progress_callback and (completed % 5 == 0 or completed == total):
                progress_callback(completed, total)

        # PyMuPDF is not thread-safe: render every region here, on the calling thread.
        jobs: List[Tuple[int, str, int]] = []
        for idx, region in enumerate(regions):
            try:
                img_base64 = self._render_region_base64(region.page, region.bbox)
                jobs.append((idx, img_base64, region.page.number + 1))
            except Exception as e:
                record(idx, None, e)

        if jobs:
            # ThreadPoolExecutor rejects max_workers <= 0.
            workers = max(1, self.concurrent_workers)
            with ThreadPoolExecutor(max_workers=workers) as executor:
                future_to_idx = {
                    executor.submit(self._request_ocr, img_base64, page_label): idx
                    for idx, img_base64, page_label in jobs
                }
                for future in as_completed(future_to_idx):
                    idx = future_to_idx[future]
                    try:
                        text = future.result()
                    except Exception as e:
                        record(idx, None, e)
                    else:
                        record(idx, text, None)

        logger.info(f"[OCR] 完成: table={table_total}, figure={figure_total}, "
                    f"有效表格={table_ok_count}, Non-table/失败={non_table_count}")
        return [r for r in results if r is not None]

    def _ocr_table_region(
        self,
        page: fitz.Page,
        bbox: Tuple[float, float, float, float],
        max_retries: int = 3
    ) -> str:
        """Render and OCR one region of ``page`` (GLM-OCR) with retry/backoff.

        Rendering uses PyMuPDF, so call this only from the thread that owns
        the document.

        Args:
            page: PDF page object.
            bbox: Region (x1, y1, x2, y2) in PDF coordinates.
            max_retries: Total number of request attempts (at least one is made).

        Returns:
            The recognized text; "" when the model reports no table.

        Raises:
            Exception: The last request/parse error once all attempts failed.
        """
        img_base64 = self._render_region_base64(page, bbox)
        return self._request_ocr(img_base64, page.number + 1, max_retries)

    def _render_region_base64(
        self,
        page: fitz.Page,
        bbox: Tuple[float, float, float, float]
    ) -> str:
        """Render ``bbox`` of ``page`` at ``ocr_dpi``, compress it, return base64 JPEG.

        Uses PyMuPDF, so call it only from the thread that owns the document.
        """
        pix = page.get_pixmap(dpi=self.ocr_dpi, clip=fitz.Rect(bbox))
        compressed = self._compress_image(pix.tobytes("jpeg"))
        return base64.b64encode(compressed).decode('utf-8')

    def _build_ocr_payload(self, img_base64: str) -> Dict[str, Any]:
        """Build the chat-completions request body for one base64 JPEG image."""
        return {
            "model": "GLM-OCR",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "判断图片中是否包含表格。"
                                    "- 若包含表格:用 Markdown 表格格式提取内容,保持行列对齐。"
                                    "- 若不包含任何表格:只输出 Non-table。"
                                    "只输出结果,不要解释。"
                        },
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"}
                        }
                    ]
                }
            ],
            "max_tokens": 2048,
            "temperature": 0.1
        }

    def _request_ocr(self, img_base64: str, page_label: int, max_retries: int = 3) -> str:
        """Send one image to the OCR API, retrying with exponential backoff.

        Performs no PyMuPDF access, so it is safe to run in worker threads.

        Args:
            img_base64: Base64-encoded JPEG.
            page_label: Page number used in log messages.
            max_retries: Total number of attempts; values below 1 are treated as 1.

        Returns:
            Text extracted by ``_extract_ocr_content``.

        Raises:
            Exception: The error from the final attempt.
        """
        payload = self._build_ocr_payload(img_base64)
        headers = {"Content-Type": "application/json"}
        if self.ocr_api_key:
            headers["Authorization"] = f"Bearer {self.ocr_api_key}"

        attempts = max(1, max_retries)
        for attempt in range(attempts):
            try:
                response = requests.post(
                    self.ocr_api_url,
                    headers=headers,
                    json=payload,
                    timeout=self.ocr_timeout
                )
                response.raise_for_status()
                return self._extract_ocr_content(response.json())
            except Exception as e:
                if attempt == attempts - 1:
                    logger.error(f" 第 {page_label} 页表格 OCR 最终失败(已重试{attempts}次): {e}")
                    raise
                # Exponential backoff between attempts: 2s, 4s, ...
                wait_time = 2 ** (attempt + 1)
                logger.warning(f" 第 {page_label} 页表格 OCR 第 {attempt + 1} 次失败: {e}, {wait_time}秒后重试...")
                time.sleep(wait_time)
        # Unreachable: the loop always returns or re-raises on its last attempt.
        raise RuntimeError("OCR retry loop exited without a result")

    def _compress_image(self, img_bytes: bytes) -> bytes:
        """Re-encode an image as RGB JPEG with its short edge capped at ``max_short_edge``.

        Transparent images are flattened onto white. This is best effort: on
        any failure (including Pillow being unavailable) the input bytes are
        returned unchanged.

        Args:
            img_bytes: Encoded source image.

        Returns:
            The re-encoded JPEG bytes, or ``img_bytes`` on failure.
        """
        try:
            from PIL import Image

            img = Image.open(io.BytesIO(img_bytes))
            if img.mode in ('RGBA', 'LA', 'P'):
                # Flatten any transparency onto a white background.
                background = Image.new('RGB', img.size, (255, 255, 255))
                if img.mode == 'P':
                    img = img.convert('RGBA')
                if img.mode in ('RGBA', 'LA'):
                    background.paste(img, mask=img.split()[-1])
                img = background
            elif img.mode != 'RGB':
                img = img.convert('RGB')

            min_edge = min(img.size)
            if min_edge > self.max_short_edge:
                ratio = self.max_short_edge / min_edge
                new_size = (int(img.width * ratio), int(img.height * ratio))
                # Image.Resampling exists only in Pillow >= 9.1; older versions expose LANCZOS on Image.
                resampling = getattr(Image, "Resampling", Image)
                img = img.resize(new_size, resampling.LANCZOS)

            buffer = io.BytesIO()
            img.save(buffer, format='JPEG', quality=self.jpeg_quality, optimize=True)
            return buffer.getvalue()
        except Exception as e:
            logger.warning(f"图片压缩失败,使用原图: {e}")
            return img_bytes

    def _extract_ocr_content(self, result: Dict) -> str:
        """Extract the assistant text from a chat-completions response.

        If the model answered "Non-table", "" is returned. Content containing
        both '<' and '>' is passed through ``convert_html_to_markdown``; if
        that conversion fails, the raw content is kept.

        Args:
            result: Decoded JSON response of the OCR API.

        Returns:
            The extracted text, or "" when there is none.
        """
        content = ""
        choices = result.get("choices") if isinstance(result, dict) else None
        if isinstance(choices, list) and choices:
            message = choices[0].get("message") or {}
            # "content" may be present but null; normalise to a string.
            content = message.get("content") or ""

        # GLM judged the region not to be a table: return "" so downstream skips it.
        if content and content.strip().startswith("Non-table"):
            return ""

        # Looks like HTML (e.g. an HTML table): convert it to Markdown.
        if content and "<" in content and ">" in content:
            try:
                from utils_test.minimal_pipeline._html_to_md import convert_html_to_markdown
                content = convert_html_to_markdown(content)
            except Exception as e:
                logger.debug(f"HTML 转 Markdown 失败,保留原始内容: {e}")
        return content

    def replace_table_regions(
        self,
        page: fitz.Page,
        original_text: str,
        ocr_results: List[Dict],
        clip_box: fitz.Rect
    ) -> str:
        """Replace the text of OCR'd table regions with their OCR output.

        The page's text blocks inside the vertical band of ``clip_box`` are
        put in reading order (top-to-bottom, then left-to-right). Blocks
        lying mostly (>50% of their area) inside an OCR region are dropped.
        Each region's OCR text is inserted where its first contained block
        was. A region with no contained block is inserted before the first
        remaining block starting at or below the region's top edge.

        Args:
            page: PDF page the regions belong to.
            original_text: Returned unchanged when there is nothing to replace.
            ocr_results: Dicts with at least ``bbox`` (x1, y1, x2, y2) and
                ``ocr_text`` (callers also pass ``region_index`` and ``score``).
                Entries whose ``ocr_text`` is blank are ignored.
            clip_box: Only blocks whose vertical extent lies within
                ``clip_box.y0..clip_box.y1`` are considered.

        Returns:
            The rebuilt text, or ``original_text`` when nothing was replaced
            or the rebuilt text is empty.
        """
        # Blank OCR output (e.g. a region judged "Non-table") must not erase
        # the original text underneath it.
        ocr_results = [r for r in ocr_results if (r.get("ocr_text") or "").strip()]
        if not ocr_results:
            return original_text

        # Text blocks within the clip band, in reading order. Image blocks
        # (type != 0) only carry an "<image: ...>" placeholder and are skipped.
        text_blocks: List[Dict[str, Any]] = []
        for x0, y0, x1, y1, text, _block_no, block_type in page.get_text("blocks"):
            if block_type != 0:
                continue
            if y0 >= clip_box.y0 and y1 <= clip_box.y1:
                text_blocks.append({
                    "bbox": (x0, y0, x1, y1),
                    "text": text.strip(),
                })
        text_blocks.sort(key=lambda b: (b["bbox"][1], b["bbox"][0]))

        # Mark blocks covered (>50% of their own area) by any OCR region.
        replaced_indices: Set[int] = set()
        for ocr_result in ocr_results:
            rx0, ry0, rx1, ry1 = ocr_result["bbox"]
            for idx, block in enumerate(text_blocks):
                if idx in replaced_indices:
                    continue
                bx0, by0, bx1, by1 = block["bbox"]
                overlap_x = max(0, min(bx1, rx1) - max(bx0, rx0))
                overlap_y = max(0, min(by1, ry1) - max(by0, ry0))
                block_area = (bx1 - bx0) * (by1 - by0)
                if block_area > 0 and overlap_x * overlap_y / block_area > 0.5:
                    replaced_indices.add(idx)

        result_parts: List[str] = []

        def append_untouched(start: int, stop: int) -> None:
            """Append the text of blocks in [start, stop) not replaced by any region."""
            for idx in range(start, stop):
                if idx not in replaced_indices:
                    result_parts.append(text_blocks[idx]["text"])
                    result_parts.append("\n")

        next_idx = 0  # first block index not yet emitted
        for ocr_result in sorted(ocr_results, key=lambda r: r["bbox"][1]):
            rx0, ry0, rx1, ry1 = ocr_result["bbox"]
            # Span of replaced blocks lying inside this region (5pt tolerance).
            region_start: Optional[int] = None
            region_end = next_idx
            for idx, block in enumerate(text_blocks):
                if idx not in replaced_indices:
                    continue
                bx0, by0, bx1, by1 = block["bbox"]
                if (bx0 >= rx0 - 5 and bx1 <= rx1 + 5 and
                        by0 >= ry0 - 5 and by1 <= ry1 + 5):
                    if region_start is None:
                        region_start = idx
                    region_end = idx + 1
            if region_start is None:
                # No text block inside the region (e.g. a scanned table):
                # insert before the first pending block at or below its top edge.
                region_start = next(
                    (i for i in range(next_idx, len(text_blocks))
                     if text_blocks[i]["bbox"][1] >= ry0),
                    len(text_blocks),
                )
                region_end = region_start
            # Untouched text between the previous region and this one.
            append_untouched(next_idx, region_start)
            result_parts.append(ocr_result["ocr_text"])
            result_parts.append("\n")
            next_idx = max(next_idx, region_end)

        # Remaining text after the last region.
        append_untouched(next_idx, len(text_blocks))

        return "".join(result_parts).strip() or original_text
|