|
@@ -8,13 +8,13 @@ from collections import Counter
|
|
|
|
|
|
|
|
try:
|
|
try:
|
|
|
from .toc.toc_extractor import TOCExtractor
|
|
from .toc.toc_extractor import TOCExtractor
|
|
|
- from .classification.llm_classifier import LLMClassifier
|
|
|
|
|
|
|
+ from .classification.hierarchy_classifier import HierarchyClassifier
|
|
|
from .chunking.text_splitter import TextSplitter
|
|
from .chunking.text_splitter import TextSplitter
|
|
|
from .output.result_saver import ResultSaver
|
|
from .output.result_saver import ResultSaver
|
|
|
from .config.config_loader import get_config
|
|
from .config.config_loader import get_config
|
|
|
except ImportError:
|
|
except ImportError:
|
|
|
from toc.toc_extractor import TOCExtractor
|
|
from toc.toc_extractor import TOCExtractor
|
|
|
- from classification.llm_classifier import LLMClassifier
|
|
|
|
|
|
|
+ from classification.hierarchy_classifier import HierarchyClassifier
|
|
|
from chunking.text_splitter import TextSplitter
|
|
from chunking.text_splitter import TextSplitter
|
|
|
from output.result_saver import ResultSaver
|
|
from output.result_saver import ResultSaver
|
|
|
from config.config_loader import get_config
|
|
from config.config_loader import get_config
|
|
@@ -27,16 +27,13 @@ class DocumentClassifier:
|
|
|
支持PDF和Word文档的目录提取、分类和文本切分
|
|
支持PDF和Word文档的目录提取、分类和文本切分
|
|
|
"""
|
|
"""
|
|
|
|
|
|
|
|
- def __init__(self, model_url=None):
|
|
|
|
|
|
|
+ def __init__(self):
|
|
|
"""
|
|
"""
|
|
|
初始化文档分类器
|
|
初始化文档分类器
|
|
|
-
|
|
|
|
|
- 参数:
|
|
|
|
|
- model_url: 大语言模型API地址(已废弃,保留以兼容旧接口)
|
|
|
|
|
"""
|
|
"""
|
|
|
self.config = get_config()
|
|
self.config = get_config()
|
|
|
self.toc_extractor = TOCExtractor()
|
|
self.toc_extractor = TOCExtractor()
|
|
|
- self.llm_classifier = LLMClassifier(model_url)
|
|
|
|
|
|
|
+ self.hierarchy_classifier = HierarchyClassifier()
|
|
|
self.text_splitter = TextSplitter()
|
|
self.text_splitter = TextSplitter()
|
|
|
self.result_saver = ResultSaver()
|
|
self.result_saver = ResultSaver()
|
|
|
|
|
|
|
@@ -103,18 +100,31 @@ class DocumentClassifier:
|
|
|
print(f"\n成功提取 {toc_info['toc_count']} 个目录项")
|
|
print(f"\n成功提取 {toc_info['toc_count']} 个目录项")
|
|
|
print(f"目录所在页: {', '.join(map(str, toc_info['toc_pages']))}")
|
|
print(f"目录所在页: {', '.join(map(str, toc_info['toc_pages']))}")
|
|
|
|
|
|
|
|
- # 显示目录层级统计
|
|
|
|
|
|
|
+ # ========== 步骤2: 目录层级校对 ==========
|
|
|
|
|
+ print("\n" + "=" * 100)
|
|
|
|
|
+ print("步骤2: 目录层级校对")
|
|
|
|
|
+ print("=" * 100)
|
|
|
|
|
+
|
|
|
|
|
+ # 注意:toc_extractor.extract_toc 已经包含了层级识别
|
|
|
|
|
+ # 这里只是显示层级统计信息
|
|
|
level_counts = Counter([item['level'] for item in toc_info['toc_items']])
|
|
level_counts = Counter([item['level'] for item in toc_info['toc_items']])
|
|
|
print("\n目录层级分布:")
|
|
print("\n目录层级分布:")
|
|
|
for level in sorted(level_counts.keys()):
|
|
for level in sorted(level_counts.keys()):
|
|
|
print(f" {level}级: {level_counts[level]} 项")
|
|
print(f" {level}级: {level_counts[level]} 项")
|
|
|
|
|
|
|
|
- # ========== 步骤2: 使用正则和关键词进行分类 ==========
|
|
|
|
|
|
|
+ # 显示前几个目录项的层级信息
|
|
|
|
|
+ print("\n目录层级示例(前5项):")
|
|
|
|
|
+ for i, item in enumerate(toc_info['toc_items'][:5], 1):
|
|
|
|
|
+ print(f" [{i}] 第{item['level']}级: {item['title']}")
|
|
|
|
|
+ if len(toc_info['toc_items']) > 5:
|
|
|
|
|
+ print(f" ... 还有 {len(toc_info['toc_items']) - 5} 个目录项")
|
|
|
|
|
+
|
|
|
|
|
+ # ========== 步骤3: 目录分类(基于二级目录关键词匹配) ==========
|
|
|
print("\n" + "=" * 100)
|
|
print("\n" + "=" * 100)
|
|
|
- print("步骤2: 使用正则表达式和关键词进行智能分类")
|
|
|
|
|
|
|
+ print("步骤3: 目录分类(基于二级目录关键词匹配)")
|
|
|
print("=" * 100)
|
|
print("=" * 100)
|
|
|
|
|
|
|
|
- classification_result = self.llm_classifier.classify(
|
|
|
|
|
|
|
+ classification_result = self.hierarchy_classifier.classify(
|
|
|
toc_info['toc_items'],
|
|
toc_info['toc_items'],
|
|
|
target_level=target_level
|
|
target_level=target_level
|
|
|
)
|
|
)
|
|
@@ -128,9 +138,22 @@ class DocumentClassifier:
|
|
|
for category, count in sorted(category_counts.items(), key=lambda x: x[1], reverse=True):
|
|
for category, count in sorted(category_counts.items(), key=lambda x: x[1], reverse=True):
|
|
|
print(f" {category}: {count} 项")
|
|
print(f" {category}: {count} 项")
|
|
|
|
|
|
|
|
- # ========== 步骤3: 提取文档全文 ==========
|
|
|
|
|
|
|
+ # 显示分类详情(前几项)
|
|
|
|
|
+ print("\n分类详情示例(前3项):")
|
|
|
|
|
+ for i, item in enumerate(classification_result['items'][:3], 1):
|
|
|
|
|
+ print(f" [{i}] {item['title']}")
|
|
|
|
|
+ print(f" 分类: {item['category']}")
|
|
|
|
|
+ print(f" 二级目录数: {item['level2_count']}")
|
|
|
|
|
+ if item['level2_titles']:
|
|
|
|
|
+ print(f" 二级目录: {', '.join(item['level2_titles'][:3])}")
|
|
|
|
|
+ if len(item['level2_titles']) > 3:
|
|
|
|
|
+ print(f" ... 还有 {len(item['level2_titles']) - 3} 个")
|
|
|
|
|
+ if len(classification_result['items']) > 3:
|
|
|
|
|
+ print(f" ... 还有 {len(classification_result['items']) - 3} 个一级目录")
|
|
|
|
|
+
|
|
|
|
|
+ # ========== 步骤4: 提取文档全文 ==========
|
|
|
print("\n" + "=" * 100)
|
|
print("\n" + "=" * 100)
|
|
|
- print("步骤3: 提取文档全文")
|
|
|
|
|
|
|
+ print("步骤4: 提取文档全文")
|
|
|
print("=" * 100)
|
|
print("=" * 100)
|
|
|
|
|
|
|
|
pages_content = self.text_splitter.extract_full_text(file_path)
|
|
pages_content = self.text_splitter.extract_full_text(file_path)
|
|
@@ -141,9 +164,9 @@ class DocumentClassifier:
|
|
|
total_chars = sum(len(page['text']) for page in pages_content)
|
|
total_chars = sum(len(page['text']) for page in pages_content)
|
|
|
print(f"\n提取完成,共 {len(pages_content)} 页,{total_chars} 个字符")
|
|
print(f"\n提取完成,共 {len(pages_content)} 页,{total_chars} 个字符")
|
|
|
|
|
|
|
|
- # ========== 步骤4: 按分类标题切分文本 ==========
|
|
|
|
|
|
|
+ # ========== 步骤5: 按分类标题切分文本 ==========
|
|
|
print("\n" + "=" * 100)
|
|
print("\n" + "=" * 100)
|
|
|
- print("步骤4: 按分类标题智能切分文本")
|
|
|
|
|
|
|
+ print("步骤5: 按分类标题智能切分文本")
|
|
|
print("=" * 100)
|
|
print("=" * 100)
|
|
|
|
|
|
|
|
chunks = self.text_splitter.split_by_hierarchy(
|
|
chunks = self.text_splitter.split_by_hierarchy(
|
|
@@ -167,11 +190,11 @@ class DocumentClassifier:
|
|
|
if len(chunks) > 5:
|
|
if len(chunks) > 5:
|
|
|
print(f" ... 还有 {len(chunks) - 5} 个文本块")
|
|
print(f" ... 还有 {len(chunks) - 5} 个文本块")
|
|
|
|
|
|
|
|
- # ========== 步骤5: 保存结果(可选) ==========
|
|
|
|
|
|
|
+ # ========== 步骤6: 保存结果(可选) ==========
|
|
|
saved_files = None
|
|
saved_files = None
|
|
|
if save_results:
|
|
if save_results:
|
|
|
print("\n" + "=" * 100)
|
|
print("\n" + "=" * 100)
|
|
|
- print("步骤5: 保存结果")
|
|
|
|
|
|
|
+ print("步骤6: 保存结果")
|
|
|
print("=" * 100)
|
|
print("=" * 100)
|
|
|
|
|
|
|
|
# 保存结果
|
|
# 保存结果
|