|
|
@@ -1,16 +1,16 @@
|
|
|
"""
|
|
|
-OCR 增强提取器 - 稳定版
|
|
|
+OCR 增强提取器 - 精准表格区域版
|
|
|
|
|
|
流程:
|
|
|
-1. PyMuPDF 提取全部文本(用于章节切分)
|
|
|
-2. RapidLayout 检测表格页
|
|
|
-3. 对表格页 OCR,替换该页内容
|
|
|
-4. 保持章节切分逻辑不变
|
|
|
+1. PyMuPDF 提取全部文本(用于章节切分,确保格式稳定)
|
|
|
+2. RapidLayout 检测表格区域(返回坐标)
|
|
|
+3. 只对表格区域进行 OCR,替换该区域内容
|
|
|
+4. 其他文本保持 PyMuPDF 提取结果,章节标题不受影响
|
|
|
|
|
|
特点:
|
|
|
- 章节切分基于 PyMuPDF 文本(格式稳定,正则匹配可靠)
|
|
|
-- 表格页内容通过 OCR 补充(识别率高)
|
|
|
-- 输出标记哪些页使用了 OCR
|
|
|
+- 仅表格区域使用 OCR(精准定位,不影响其他内容)
|
|
|
+- 输出标记哪些页使用了 OCR 及表格区域坐标
|
|
|
"""
|
|
|
|
|
|
from __future__ import annotations
|
|
|
@@ -124,61 +124,69 @@ class OcrEnhancedExtractor(FullTextExtractor):
|
|
|
self._layout_engine = RapidLayout()
|
|
|
return self._layout_engine
|
|
|
|
|
|
- def _detect_table_pages(self, doc: fitz.Document) -> Set[int]:
|
|
|
- """检测含表格的页码"""
|
|
|
- table_pages: Set[int] = set()
|
|
|
+ def _detect_table_regions(self, page: fitz.Page, page_num: int) -> List[Tuple[Tuple[float, float, float, float], float]]:
|
|
|
+ """
|
|
|
+ 检测页面中的表格区域
|
|
|
+
|
|
|
+ Args:
|
|
|
+ page: PDF 页面对象
|
|
|
+ page_num: 页码(用于日志)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 表格区域列表,每个元素为 ((x1, y1, x2, y2), score)
|
|
|
+ """
|
|
|
+ table_regions: List[Tuple[Tuple[float, float, float, float], float]] = []
|
|
|
|
|
|
if not RAPID_LAYOUT_AVAILABLE:
|
|
|
- return table_pages
|
|
|
+ return table_regions
|
|
|
|
|
|
layout_engine = self._get_layout_engine()
|
|
|
if layout_engine is None:
|
|
|
- return table_pages
|
|
|
+ return table_regions
|
|
|
|
|
|
- logger.info(f"[版面分析] 检测表格页,共 {len(doc)} 页")
|
|
|
-
|
|
|
- for page_num in range(1, len(doc) + 1):
|
|
|
- page = doc[page_num - 1]
|
|
|
-
|
|
|
- # 裁剪页眉页脚
|
|
|
- rect = page.rect
|
|
|
- clip_box = fitz.Rect(0, self.clip_top, rect.width, rect.height - self.clip_bottom)
|
|
|
-
|
|
|
- # 渲染页面
|
|
|
- pix = page.get_pixmap(dpi=self.dpi, clip=clip_box)
|
|
|
- img = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, 3)
|
|
|
-
|
|
|
- try:
|
|
|
- layout_output = layout_engine(img)
|
|
|
+ # 裁剪页眉页脚
|
|
|
+ rect = page.rect
|
|
|
+ clip_box = fitz.Rect(0, self.clip_top, rect.width, rect.height - self.clip_bottom)
|
|
|
|
|
|
- # 解析版面结果
|
|
|
- labels = []
|
|
|
- if hasattr(layout_output, 'class_names'):
|
|
|
- labels = list(layout_output.class_names)
|
|
|
- elif hasattr(layout_output, 'boxes'):
|
|
|
- labels = [
|
|
|
- label for _, label, _
|
|
|
- in zip(layout_output.boxes, layout_output.class_names, layout_output.scores)
|
|
|
- ]
|
|
|
+ # 渲染页面
|
|
|
+ pix = page.get_pixmap(dpi=self.dpi, clip=clip_box)
|
|
|
+ img = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, 3)
|
|
|
|
|
|
- if "table" in labels:
|
|
|
- table_pages.add(page_num)
|
|
|
- logger.debug(f" 第 {page_num} 页: 检测到表格")
|
|
|
+ try:
|
|
|
+ layout_output = layout_engine(img)
|
|
|
+
|
|
|
+ # 解析版面结果
|
|
|
+ if hasattr(layout_output, 'boxes') and hasattr(layout_output, 'class_names'):
|
|
|
+ # 获取缩放比例(像素坐标转 PDF 坐标)
|
|
|
+ scale_x = clip_box.width / img.shape[1]
|
|
|
+ scale_y = clip_box.height / img.shape[0]
|
|
|
+
|
|
|
+ for box, label, score in zip(layout_output.boxes, layout_output.class_names, layout_output.scores):
|
|
|
+ if label == "table" and score > 0.5: # 置信度阈值
|
|
|
+ # box 格式: [x1, y1, x2, y2] 像素坐标
|
|
|
+ # 转换为 PDF 坐标(加上裁剪区域的偏移)
|
|
|
+ pdf_x1 = clip_box.x0 + box[0] * scale_x
|
|
|
+ pdf_y1 = clip_box.y0 + box[1] * scale_y
|
|
|
+ pdf_x2 = clip_box.x0 + box[2] * scale_x
|
|
|
+ pdf_y2 = clip_box.y0 + box[3] * scale_y
|
|
|
+
|
|
|
+ table_regions.append(((pdf_x1, pdf_y1, pdf_x2, pdf_y2), score))
|
|
|
+ logger.debug(f" 第 {page_num} 页: 检测到表格 ({pdf_x1:.1f}, {pdf_y1:.1f}, {pdf_x2:.1f}, {pdf_y2:.1f}), 置信度 {score:.2f}")
|
|
|
|
|
|
- except Exception as e:
|
|
|
- logger.warning(f" 第 {page_num} 页: 版面分析失败 ({e})")
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f" 第 {page_num} 页: 版面分析失败 ({e})")
|
|
|
|
|
|
- logger.info(f"[版面分析] 检测到 {len(table_pages)} 页含表格")
|
|
|
- return table_pages
|
|
|
+ return table_regions
|
|
|
|
|
|
def extract_full_text(self, source: DocumentSource) -> List[Dict[str, Any]]:
|
|
|
"""
|
|
|
- 执行 OCR 增强提取
|
|
|
+ 执行 OCR 增强提取(精准表格区域版)
|
|
|
|
|
|
流程:
|
|
|
- 1. PyMuPDF 提取全部文本
|
|
|
- 2. 检测表格页
|
|
|
- 3. 对表格页 OCR 替换内容
|
|
|
+ 1. PyMuPDF 提取全部文本(确保章节格式稳定)
|
|
|
+ 2. 检测每页的表格区域(返回坐标)
|
|
|
+ 3. 只 OCR 表格区域,替换该区域内容
|
|
|
+ 4. 其他文本保持 PyMuPDF 结果
|
|
|
"""
|
|
|
total_start = time.time()
|
|
|
|
|
|
@@ -211,44 +219,69 @@ class OcrEnhancedExtractor(FullTextExtractor):
|
|
|
"start_pos": 0, # 后续计算
|
|
|
"end_pos": 0,
|
|
|
"source_file": source_file,
|
|
|
- "is_ocr": False, # 标记是否 OCR
|
|
|
+ "is_ocr": False,
|
|
|
+ "ocr_regions": [], # OCR 区域信息
|
|
|
})
|
|
|
|
|
|
- # 阶段 2: 检测表格页
|
|
|
- logger.info("[阶段2] 检测表格页...")
|
|
|
- table_pages = self._detect_table_pages(doc)
|
|
|
+ # 阶段 2&3: 逐页检测表格区域并 OCR 替换
|
|
|
+ logger.info("[阶段2] 检测表格区域并精准 OCR...")
|
|
|
+ total_ocr_count = 0
|
|
|
+ total_ocr_time = 0.0
|
|
|
+
|
|
|
+ for page_num in range(1, total_pages + 1):
|
|
|
+ page = doc[page_num - 1]
|
|
|
+
|
|
|
+ # 检测该页的表格区域
|
|
|
+ table_regions = self._detect_table_regions(page, page_num)
|
|
|
|
|
|
- # 阶段 3: 对表格页 OCR
|
|
|
- if table_pages:
|
|
|
- logger.info(f"[阶段3] 对 {len(table_pages)} 页进行 OCR...")
|
|
|
- ocr_count = 0
|
|
|
- ocr_time = 0.0
|
|
|
+ if not table_regions:
|
|
|
+ continue
|
|
|
|
|
|
- for page_num in table_pages:
|
|
|
- page = doc[page_num - 1]
|
|
|
+ logger.info(f" 第 {page_num} 页: 检测到 {len(table_regions)} 个表格区域")
|
|
|
|
|
|
+ # 对每个表格区域进行 OCR
|
|
|
+ ocr_results = []
|
|
|
+ for idx, (bbox, score) in enumerate(table_regions):
|
|
|
try:
|
|
|
ocr_start = time.time()
|
|
|
|
|
|
- if self.ocr_engine_normalized == "glm_ocr":
|
|
|
- ocr_text = self._ocr_with_glm(page, page_num)
|
|
|
- else:
|
|
|
- ocr_text = self._ocr_with_mineru(doc, page_num)
|
|
|
+ # 只 OCR 表格区域
|
|
|
+ ocr_text = self._ocr_table_region(page, bbox)
|
|
|
|
|
|
- ocr_time += time.time() - ocr_start
|
|
|
- ocr_count += 1
|
|
|
+ ocr_time = time.time() - ocr_start
|
|
|
+ total_ocr_time += ocr_time
|
|
|
|
|
|
- # 替换该页内容
|
|
|
- pages[page_num - 1]["text"] = ocr_text
|
|
|
- pages[page_num - 1]["is_ocr"] = True
|
|
|
- pages[page_num - 1]["original_text"] = pages[page_num - 1]["text"] # 保留原文
|
|
|
+ ocr_results.append({
|
|
|
+ "region_index": idx,
|
|
|
+ "bbox": bbox,
|
|
|
+ "score": score,
|
|
|
+ "ocr_text": ocr_text,
|
|
|
+ "ocr_time": ocr_time,
|
|
|
+ })
|
|
|
|
|
|
- logger.debug(f" 第 {page_num} 页: OCR 完成 ({len(ocr_text)} 字符)")
|
|
|
+ logger.debug(f" 区域 {idx+1}: OCR 完成 ({len(ocr_text)} 字符), 耗时 {ocr_time:.2f}s")
|
|
|
|
|
|
except Exception as e:
|
|
|
- logger.error(f" 第 {page_num} 页: OCR 失败 ({e}),使用原文")
|
|
|
+ logger.error(f" 区域 {idx+1}: OCR 失败 ({e}),保留原文")
|
|
|
+
|
|
|
+ # 替换表格区域内容
|
|
|
+ if ocr_results:
|
|
|
+ original_text = pages[page_num - 1]["text"]
|
|
|
+ updated_text = self._replace_table_regions(
|
|
|
+ page, original_text, ocr_results, table_regions
|
|
|
+ )
|
|
|
+
|
|
|
+ pages[page_num - 1]["text"] = updated_text
|
|
|
+ pages[page_num - 1]["is_ocr"] = True
|
|
|
+ pages[page_num - 1]["ocr_regions"] = [
|
|
|
+ {"bbox": r["bbox"], "score": r["score"], "chars": len(r["ocr_text"])}
|
|
|
+ for r in ocr_results
|
|
|
+ ]
|
|
|
+
|
|
|
+ total_ocr_count += len(ocr_results)
|
|
|
|
|
|
- logger.info(f"[OCR] 完成 {ocr_count} 页,耗时 {ocr_time:.2f}s")
|
|
|
+ if total_ocr_count > 0:
|
|
|
+ logger.info(f"[OCR] 完成 {total_ocr_count} 个表格区域,耗时 {total_ocr_time:.2f}s")
|
|
|
|
|
|
# 阶段 4: 计算位置
|
|
|
current_pos = 0
|
|
|
@@ -264,19 +297,186 @@ class OcrEnhancedExtractor(FullTextExtractor):
|
|
|
# 统计
|
|
|
total_time = time.time() - total_start
|
|
|
ocr_pages = sum(1 for p in pages if p.get("is_ocr"))
|
|
|
+ total_ocr_regions = sum(len(p.get("ocr_regions", [])) for p in pages)
|
|
|
total_chars = sum(len(p["text"]) for p in pages)
|
|
|
|
|
|
logger.info(
|
|
|
f"[提取完成] 总页数: {total_pages} | "
|
|
|
- f"OCR: {ocr_pages} | 本地: {total_pages - ocr_pages} | "
|
|
|
+ f"OCR页: {ocr_pages} | 本地页: {total_pages - ocr_pages} | "
|
|
|
+ f"OCR区域: {total_ocr_regions} | "
|
|
|
f"总耗时: {total_time:.2f}s | "
|
|
|
f"总字符: {total_chars}"
|
|
|
)
|
|
|
|
|
|
return pages
|
|
|
|
|
|
+ def _ocr_table_region(self, page: fitz.Page, bbox: Tuple[float, float, float, float]) -> str:
|
|
|
+ """
|
|
|
+ 对指定区域进行 OCR 识别
|
|
|
+
|
|
|
+ Args:
|
|
|
+ page: PDF 页面对象
|
|
|
+ bbox: 区域坐标 (x1, y1, x2, y2)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ OCR 识别结果文本
|
|
|
+ """
|
|
|
+ # 渲染指定区域
|
|
|
+ rect = fitz.Rect(bbox)
|
|
|
+ pix = page.get_pixmap(dpi=self.dpi, clip=rect)
|
|
|
+ img_bytes = pix.tobytes("jpeg")
|
|
|
+
|
|
|
+ # 压缩
|
|
|
+ compressed = self._compress_image(img_bytes)
|
|
|
+ img_base64 = base64.b64encode(compressed).decode('utf-8')
|
|
|
+
|
|
|
+ # 请求 OCR
|
|
|
+ payload = {
|
|
|
+ "model": "GLM-OCR",
|
|
|
+ "messages": [
|
|
|
+ {
|
|
|
+ "role": "user",
|
|
|
+ "content": [
|
|
|
+ {
|
|
|
+ "type": "text",
|
|
|
+ "text": "识别图片中的表格内容,按原文排版输出。"
|
|
|
+ "注意:"
|
|
|
+ "1. 表格用 Markdown 表格格式"
|
|
|
+ "2. 保持换行和列对齐"
|
|
|
+ "3. 只输出表格内容,不要其他说明"
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "type": "image_url",
|
|
|
+ "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"}
|
|
|
+ }
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "max_tokens": 2048,
|
|
|
+ "temperature": 0.1
|
|
|
+ }
|
|
|
+
|
|
|
+ response = requests.post(
|
|
|
+ self.glm_api_url,
|
|
|
+ headers=self.glm_headers,
|
|
|
+ json=payload,
|
|
|
+ timeout=self.glm_timeout
|
|
|
+ )
|
|
|
+ response.raise_for_status()
|
|
|
+
|
|
|
+ result = response.json()
|
|
|
+ content = self._extract_content(result)
|
|
|
+
|
|
|
+ return content
|
|
|
+
|
|
|
+ def _replace_table_regions(
|
|
|
+ self,
|
|
|
+ page: fitz.Page,
|
|
|
+ original_text: str,
|
|
|
+ ocr_results: List[Dict[str, Any]],
|
|
|
+ table_regions: List[Tuple[Tuple[float, float, float, float], float]]
|
|
|
+ ) -> str:
|
|
|
+ """
|
|
|
+ 用 OCR 结果替换原始文本中的表格区域
|
|
|
+
|
|
|
+ 策略:
|
|
|
+ 1. 找到表格区域在原始文本中的位置
|
|
|
+ 2. 用 OCR 结果替换该部分内容
|
|
|
+ 3. 保留其他所有文本(包括章节标题)
|
|
|
+
|
|
|
+ Args:
|
|
|
+ page: PDF 页面对象
|
|
|
+ original_text: 原始文本(PyMuPDF 提取)
|
|
|
+ ocr_results: OCR 结果列表
|
|
|
+ table_regions: 表格区域坐标列表
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 替换后的文本
|
|
|
+ """
|
|
|
+ if not ocr_results:
|
|
|
+ return original_text
|
|
|
+
|
|
|
+ # 获取页面上的文本块及其坐标
|
|
|
+ text_blocks = []
|
|
|
+ for block in page.get_text("blocks"):
|
|
|
+ # block 格式: (x0, y0, x1, y1, text, block_no, block_type)
|
|
|
+ x0, y0, x1, y1, text, _, _ = block
|
|
|
+ # 只考虑页眉页脚裁剪区域内的文本
|
|
|
+ if y0 >= self.clip_top and y1 <= page.rect.height - self.clip_bottom:
|
|
|
+ text_blocks.append({
|
|
|
+ "bbox": (x0, y0, x1, y1),
|
|
|
+ "text": text.strip(),
|
|
|
+ })
|
|
|
+
|
|
|
+ # 按 Y 坐标排序(从上到下)
|
|
|
+ text_blocks.sort(key=lambda b: (b["bbox"][1], b["bbox"][0]))
|
|
|
+
|
|
|
+ # 标记哪些文本块属于表格区域
|
|
|
+ replaced_indices = set()
|
|
|
+ for region_idx, (bbox, _) in enumerate(table_regions):
|
|
|
+ for idx, block in enumerate(text_blocks):
|
|
|
+ if idx in replaced_indices:
|
|
|
+ continue
|
|
|
+ # 检查文本块是否与表格区域有重叠
|
|
|
+ bx0, by0, bx1, by1 = block["bbox"]
|
|
|
+ rx0, ry0, rx1, ry1 = bbox
|
|
|
+
|
|
|
+ # 计算重叠区域
|
|
|
+ overlap_x = max(0, min(bx1, rx1) - max(bx0, rx0))
|
|
|
+ overlap_y = max(0, min(by1, ry1) - max(by0, ry0))
|
|
|
+ overlap_area = overlap_x * overlap_y
|
|
|
+ block_area = (bx1 - bx0) * (by1 - by0)
|
|
|
+
|
|
|
+ # 如果重叠面积超过 50%,认为是表格内的文本
|
|
|
+ if block_area > 0 and overlap_area / block_area > 0.5:
|
|
|
+ replaced_indices.add(idx)
|
|
|
+
|
|
|
+ # 构建新文本:保留非表格区域的文本,替换表格区域为 OCR 结果
|
|
|
+ result_parts = []
|
|
|
+ last_idx = 0
|
|
|
+
|
|
|
+ # 按顺序处理每个表格区域
|
|
|
+ for region_idx, (bbox, score) in enumerate(table_regions):
|
|
|
+ if region_idx >= len(ocr_results):
|
|
|
+ continue
|
|
|
+
|
|
|
+ ocr_text = ocr_results[region_idx]["ocr_text"]
|
|
|
+
|
|
|
+ # 找到该表格区域之前需要保留的文本
|
|
|
+ region_blocks = []
|
|
|
+ for idx, block in enumerate(text_blocks):
|
|
|
+ if idx in replaced_indices:
|
|
|
+ bx0, by0, bx1, by1 = block["bbox"]
|
|
|
+ rx0, ry0, rx1, ry1 = bbox
|
|
|
+ # 如果该文本块属于当前表格区域
|
|
|
+ if (bx0 >= rx0 - 5 and bx1 <= rx1 + 5 and
|
|
|
+ by0 >= ry0 - 5 and by1 <= ry1 + 5):
|
|
|
+ region_blocks.append((idx, block))
|
|
|
+
|
|
|
+ if region_blocks:
|
|
|
+ # 在第一个表格块之前添加之前的内容
|
|
|
+ first_idx = region_blocks[0][0]
|
|
|
+ for idx in range(last_idx, first_idx):
|
|
|
+ if idx not in replaced_indices:
|
|
|
+ result_parts.append(text_blocks[idx]["text"])
|
|
|
+ result_parts.append("\n")
|
|
|
+
|
|
|
+ # 添加 OCR 结果
|
|
|
+ result_parts.append(ocr_text)
|
|
|
+ result_parts.append("\n")
|
|
|
+
|
|
|
+ last_idx = region_blocks[-1][0] + 1
|
|
|
+
|
|
|
+ # 添加剩余的非表格文本
|
|
|
+ for idx in range(last_idx, len(text_blocks)):
|
|
|
+ if idx not in replaced_indices:
|
|
|
+ result_parts.append(text_blocks[idx]["text"])
|
|
|
+ result_parts.append("\n")
|
|
|
+
|
|
|
+ return "".join(result_parts)
|
|
|
+
|
|
|
def _ocr_with_glm(self, page: fitz.Page, page_num: int) -> str:
|
|
|
- """GLM-OCR 识别"""
|
|
|
+ """GLM-OCR 识别(整页版本,保留用于兼容)"""
|
|
|
# 渲染页面
|
|
|
rect = page.rect
|
|
|
clip_box = fitz.Rect(0, self.clip_top, rect.width, rect.height - self.clip_bottom)
|