|
|
@@ -0,0 +1,2149 @@
|
|
|
+#!/usr/bin/env python
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
+"""
|
|
|
+LLM 内容三级分类识别模块
|
|
|
+
|
|
|
+根据 StandardCategoryTable.csv 的标准,让模型识别文档中的三级分类内容,
|
|
|
+输出 JSON 格式包含:三级分类名称、起止行号、原文内容
|
|
|
+
|
|
|
+特点:
|
|
|
+- 行级细粒度分类:返回每个三级分类的起止行号和原文内容
|
|
|
+- 多分类支持:一个段落可包含多个三级分类
|
|
|
+- 全局行号:维护全局连续行号,便于跨段落定位
|
|
|
+- Embedding 优化:相似度 >= 阈值时跳过 LLM,降低 API 成本
|
|
|
+- 分块处理:长段落自动分块,结果合并
|
|
|
+- 统一配置管理:从 config.ini 读取模型配置
|
|
|
+
|
|
|
+使用方式:
|
|
|
+1. 作为模块导入使用:
|
|
|
+ from llm_content_classifier_v2 import LLMContentClassifier, classify_chunks
|
|
|
+ result = await classify_chunks(chunks)
|
|
|
+
|
|
|
+2. 独立运行测试:
|
|
|
+ python llm_content_classifier_v2.py
|
|
|
+"""
|
|
|
+
|
|
|
+import asyncio
|
|
|
+import json
|
|
|
+import re
|
|
|
+import csv
|
|
|
+import time
|
|
|
+import math
|
|
|
+from pathlib import Path
|
|
|
+from typing import Dict, List, Optional, Tuple, Any
|
|
|
+from dataclasses import dataclass, field
|
|
|
+from openai import AsyncOpenAI
|
|
|
+
|
|
|
+# 导入统一配置处理器
|
|
|
+from foundation.infrastructure.config.config import config_handler
|
|
|
+
|
|
|
+
|
|
|
+# ==================== 配置类 ====================
|
|
|
+
|
|
|
def _get_llm_config_from_ini(model_type: str) -> Tuple[str, str, str]:
    """Read LLM connection settings for *model_type* from config.ini.

    Tries the DashScope-style key names first; when no server URL is found
    that way, falls back to keys prefixed with the upper-cased section name
    (e.g. ``QWEN3_5_122B_A10B_SERVER_URL``).

    Args:
        model_type: config.ini section name (e.g. ``qwen3_5_122b_a10b``).

    Returns:
        Tuple[str, str, str]: ``(api_key, base_url, model_id)``; three empty
        strings when nothing could be read.
    """
    try:
        # Preferred: DashScope-style key names.
        server = config_handler.get(model_type, "DASHSCOPE_SERVER_URL", "")
        model = config_handler.get(model_type, "DASHSCOPE_MODEL_ID", "")
        key = config_handler.get(model_type, "DASHSCOPE_API_KEY", "")

        if not server:
            # Fallback: keys named after the upper-cased section.
            prefix = model_type.upper()
            server = config_handler.get(model_type, f"{prefix}_SERVER_URL", "")
            model = config_handler.get(model_type, f"{prefix}_MODEL_ID", "")
            key = config_handler.get(model_type, f"{prefix}_API_KEY", "")

        return key, server, model
    except Exception:
        # Any config error degrades to "not configured".
        return "", "", ""
|
|
|
+
|
|
|
+
|
|
|
def _get_embedding_config_from_ini(embedding_model_type: str) -> Tuple[str, str, str]:
    """Read embedding-model connection settings from config.ini.

    Only two provider sections are recognized: the locally hosted
    ``lq_qwen3_8b_emd`` model and the SiliconFlow-hosted
    ``siliconflow_embed`` model. Any other value yields empty strings.

    Args:
        embedding_model_type: embedding provider section name.

    Returns:
        Tuple[str, str, str]: ``(api_key, base_url, model_id)``; empty
        strings when unrecognized or on any config error.
    """
    try:
        if embedding_model_type == "lq_qwen3_8b_emd":
            # Locally hosted embedding service.
            return (
                config_handler.get("lq_qwen3_8b_emd", "LQ_EMBEDDING_API_KEY", "dummy"),
                config_handler.get("lq_qwen3_8b_emd", "LQ_EMBEDDING_SERVER_URL", ""),
                config_handler.get("lq_qwen3_8b_emd", "LQ_EMBEDDING_MODEL_ID", "Qwen3-Embedding-8B"),
            )

        if embedding_model_type == "siliconflow_embed":
            # SiliconFlow-hosted embedding service.
            return (
                config_handler.get("siliconflow_embed", "SLCF_EMBED_API_KEY", ""),
                config_handler.get("siliconflow_embed", "SLCF_EMBED_SERVER_URL", ""),
                config_handler.get("siliconflow_embed", "SLCF_EMBED_MODEL_ID", "Qwen/Qwen3-Embedding-8B"),
            )

        return "", "", ""
    except Exception:
        return "", "", ""
|
|
|
+
|
|
|
+
|
|
|
@dataclass
class ClassifierConfig:
    """Classifier configuration, populated from config.ini in __post_init__."""

    # LLM API settings (filled from config.ini when available)
    api_key: str = ""
    base_url: str = ""
    model: str = ""

    # Concurrency control
    max_concurrent_requests: int = 10
    max_retries: int = 3
    retry_delay: int = 1

    # Embedding settings (filled from config.ini when available)
    embedding_api_key: str = ""
    embedding_base_url: str = ""
    embedding_model: str = ""
    embedding_similarity_threshold: float = 0.9

    # Path settings (defaulted relative to this file when left empty)
    category_table_path: str = ""
    second_category_path: str = ""
    output_path: str = ""

    def __post_init__(self):
        """Load settings from config.ini; ini values only override empty fields."""
        # LLM settings: resolve the model type, then its connection details.
        llm_model_type = config_handler.get("model", "COMPLETENESS_REVIEW_MODEL_TYPE", "qwen3_5_122b_a10b")
        api_key, base_url, model_id = _get_llm_config_from_ini(llm_model_type)

        # Apply only the values the ini actually provided.
        if api_key:
            self.api_key = api_key
        if base_url:
            self.base_url = base_url
        if model_id:
            self.model = model_id

        # Embedding settings.
        embedding_model_type = config_handler.get("model", "EMBEDDING_MODEL_TYPE", "lq_qwen3_8b_emd")
        emb_api_key, emb_base_url, emb_model_id = _get_embedding_config_from_ini(embedding_model_type)

        if emb_api_key:
            self.embedding_api_key = emb_api_key
        if emb_base_url:
            self.embedding_base_url = emb_base_url
        if emb_model_id:
            self.embedding_model = emb_model_id

        # Default paths when none were supplied explicitly.
        if not self.category_table_path:
            self.category_table_path = str(
                Path(__file__).parent.parent.parent / "doc_worker" / "config" / "StandardCategoryTable.csv"
            )
        if not self.second_category_path:
            self.second_category_path = str(
                Path(__file__).parent.parent.parent / "doc_worker" / "config" / "construction_plan_standards.csv"
            )
        if not self.output_path:
            # temp/construction_review/llm_content_classifier_v2 under the project root
            project_root = Path(__file__).parent.parent.parent.parent.parent.parent
            self.output_path = str(project_root / "temp" / "construction_review" / "llm_content_classifier_v2")
|
|
|
+
|
|
|
+
|
|
|
# Default config instance (loaded from config.ini; used for standalone test runs)
DEFAULT_CONFIG = ClassifierConfig()

# Backward-compatible module-level aliases (for standalone test runs,
# all derived from the config.ini-backed DEFAULT_CONFIG above)
API_KEY = DEFAULT_CONFIG.api_key
MAX_CONCURRENT_REQUESTS = DEFAULT_CONFIG.max_concurrent_requests
MAX_RETRIES = DEFAULT_CONFIG.max_retries
RETRY_DELAY = DEFAULT_CONFIG.retry_delay
BASE_URL = DEFAULT_CONFIG.base_url
MODEL = DEFAULT_CONFIG.model
EMBEDDING_API_KEY = DEFAULT_CONFIG.embedding_api_key
EMBEDDING_BASE_URL = DEFAULT_CONFIG.embedding_base_url
EMBEDDING_MODEL = DEFAULT_CONFIG.embedding_model
EMBEDDING_SIMILARITY_THRESHOLD = DEFAULT_CONFIG.embedding_similarity_threshold
CATEGORY_TABLE_PATH = Path(DEFAULT_CONFIG.category_table_path)
SECOND_CATEGORY_PATH = Path(DEFAULT_CONFIG.second_category_path)
|
|
|
+
|
|
|
+
|
|
|
+# ==================== 数据模型 ====================
|
|
|
+
|
|
|
@dataclass
class CategoryStandard:
    """One row of StandardCategoryTable.csv: a third-level standard plus its ancestry."""
    first_code: str    # first-level category code
    first_name: str    # first-level category name (Chinese)
    second_code: str   # second-level category code
    second_name: str   # second-level category name (Chinese)
    second_focus: str  # review focus of the second-level category
    third_code: str    # third-level category code
    third_name: str    # third-level category name (Chinese)
    third_focus: str   # review focus of the third-level category
|
|
|
+
|
|
|
@dataclass
class SecondCategoryStandard:
    """One row of construction_plan_standards.csv (second-level standard definition)."""
    first_name: str   # first-level category name (Chinese)
    second_name: str  # second-level category name (Chinese)
    second_raw_content: str  # detailed description text of the second-level category
|
|
|
+
|
|
|
@dataclass
class ClassifiedContent:
    """A span of document text attributed to one third-level category."""
    third_category_name: str  # third-level category name
    third_category_code: str  # third-level category code
    start_line: int  # start line number of the span (inclusive)
    end_line: int    # end line number of the span (inclusive)
    content: str     # original text of the span
|
|
|
+
|
|
|
@dataclass
class SectionContent:
    """Content of one second-level heading of the document."""
    section_key: str   # e.g. "第一章->一"
    section_name: str  # e.g. "一)编制依据"
    lines: List[str]   # raw content lines
    numbered_content: str  # content rendered with line-number prefixes
    category_standards: List[CategoryStandard] = field(default_factory=list)  # third-level standards under this section
    line_number_map: List[int] = field(default_factory=list)  # global line number for each local line (may be empty)
    chunk_ranges: List[Tuple[str, int, int]] = field(default_factory=list)  # [(chunk_id, global_start, global_end), ...]
|
|
|
+
|
|
|
@dataclass
class ClassificationResult:
    """Outcome of classifying one section."""
    model: str
    section_key: str
    section_name: str
    classified_contents: List[ClassifiedContent]
    latency: float          # wall-clock seconds spent on this section
    raw_response: str = ""  # (truncated) raw LLM response, for debugging
    error: Optional[str] = None  # parse/API error message; None on success
    total_lines: int = 0       # total number of lines in this section
    classified_lines: int = 0  # number of lines attributed to at least one category
    coverage_rate: float = 0.0 # classified_lines / total_lines, as a percentage
|
|
|
+
|
|
|
# ==================== Second-level category keyword map ====================
# Maps second-level heading names found in documents onto the canonical
# names used by StandardCategoryTable.csv.
# Shape: { canonical CSV name: [possible document heading variants] }
SECONDARY_CATEGORY_KEYWORDS = {
    # Compilation basis (basis)
    "法律法规": ["法律法规", "法律", "法规"],
    "标准规范": ["标准规范", "标准", "规范", "技术标准"],
    "文件制度": ["文件制度", "制度文件", "管理文件"],
    "编制原则": ["编制原则", "原则"],
    "编制范围": ["编制范围", "范围", "工程范围"],

    # Project overview (overview)
    "设计概况": ["设计概况", "工程简介", "工程概况", "概况"],
    "工程地质与水文气象": ["工程地质与水文气象", "地质", "水文", "气象", "工程地质", "水文气象", "地质与水文"],
    "周边环境": ["周边环境", "环境", "周围环境"],
    "施工平面及立面布置": ["施工平面及立面布置", "平面布置", "立面布置", "施工平面", "平面及立面"],
    "施工要求和技术保证条件": ["施工要求和技术保证条件", "施工要求", "技术保证", "保证条件"],
    "风险辨识与分级": ["风险辨识与分级", "风险辨识", "风险分级", "风险", "风险等级"],
    "参建各方责任主体单位": ["参建各方责任主体单位", "参建单位", "责任主体", "参建各方"],

    # Construction plan (plan)
    "施工进度计划": ["施工进度计划", "进度计划", "进度", "工期计划"],
    "施工材料计划": ["施工材料计划", "材料计划", "材料"],
    "施工设备计划": ["施工设备计划", "设备计划", "机械设备", "设备"],
    "劳动力计划": ["劳动力计划", "劳动力", "人员计划", "用工计划"],
    "安全生产费用使用计划": ["安全生产费用使用计划", "安全费用", "安全费", "安全生产费用"],

    # Construction technology (technology)
    "主要施工方法概述": ["主要施工方法概述", "施工方法概述", "方法概述", "施工方法"],
    "技术参数": ["技术参数", "参数", "技术指标"],
    "工艺流程": ["工艺流程", "流程", "施工流程"],
    "施工准备": ["施工准备", "准备", "准备工作"],
    "施工方法及操作要求": ["施工方法及操作要求", "施工方案及操作要求", "操作要求", "施工方案", "施工方法", "方法及操作"],
    "检查要求": ["检查要求", "检查", "验收要求", "检查验收"],

    # Safety assurance measures (safety)
    "安全保证体系": ["安全保证体系", "安全体系", "安全管理体系"],
    "组织保证措施": ["组织保证措施", "组织措施", "组织保证"],
    "技术保证措施": ["技术保证措施", "技术保障措施", "技术措施", "保障措施", "技术保障", "安全防护措施", "安全防护"],
    "监测监控措施": ["监测监控措施", "监测措施", "监控措施", "监测监控"],
    "应急处置措施": ["应急处置措施", "应急预案", "应急措施", "应急处置"],

    # Quality assurance measures (quality)
    "质量保证体系": ["质量保证体系", "质量体系", "质量管理体系"],
    "质量目标": ["质量目标", "质量指标"],
    "工程创优规划": ["工程创优规划", "创优规划", "创优计划", "创优"],
    "质量控制程序与具体措施": ["质量控制程序与具体措施", "质量控制", "质量措施", "质量控制措施"],

    # Environmental assurance measures (environment)
    "环境保证体系": ["环境保证体系", "环境体系", "环境管理体系"],
    "环境保护组织机构": ["环境保护组织机构", "环保组织", "环境组织"],
    "环境保护及文明施工措施": ["环境保护及文明施工措施", "环保措施", "文明施工", "环境保护", "环境措施"],

    # Construction management & staffing (Management)
    "施工管理人员": ["施工管理人员", "管理人员", "管理人员配备"],
    "专职安全生产管理人员": ["专职安全生产管理人员", "专职安全员", "安全管理人员", "安全员", "特种作业人员", "特种工"],
    "其他作业人员": ["其他作业人员", "其他人员", "作业人员"],

    # Acceptance requirements (acceptance)
    "验收标准": ["验收标准", "验收规范", "标准"],
    "验收程序": ["验收程序", "验收流程", "程序"],
    "验收内容": ["验收内容", "验收项目"],
    "验收时间": ["验收时间", "验收日期"],
    "验收人员": ["验收人员", "验收参与人员"],

    # Other materials (other)
    "计算书": ["计算书", "计算", "验算"],
    "相关施工图纸": ["相关施工图纸", "施工图纸", "图纸"],
    "附图附表": ["附图附表", "附图", "附表"],
    "编制及审核人员情况": ["编制及审核人员情况", "编制人员", "审核人员"],
}
|
|
|
+
|
|
|
+
|
|
|
+# ==================== 标准分类加载器 ====================
|
|
|
+
|
|
|
class CategoryStandardLoader:
    """Loader for StandardCategoryTable.csv (the three-level category table)."""

    def __init__(self, csv_path: Path):
        self.csv_path = csv_path
        self.standards: List[CategoryStandard] = []
        self._load()

    def _load(self):
        """Read every CSV row into a CategoryStandard record."""
        # utf-8-sig strips a UTF-8 BOM if the file carries one.
        with open(self.csv_path, 'r', encoding='utf-8-sig') as fh:
            for record in csv.DictReader(fh):
                self.standards.append(CategoryStandard(
                    first_code=record.get('first_code', ''),
                    first_name=record.get('first_name', ''),
                    second_code=record.get('second_code', ''),
                    second_name=record.get('second_name', ''),
                    second_focus=record.get('second_focus', ''),
                    third_code=record.get('third_code', ''),
                    third_name=record.get('third_name', ''),
                    third_focus=record.get('third_focus', '')
                ))

    def get_standards_by_second_code(self, second_code: str) -> List[CategoryStandard]:
        """Return all third-level standards under the given second-level code."""
        return [std for std in self.standards if std.second_code == second_code]

    def _find_standard_name_by_keyword(self, second_name: str) -> Optional[str]:
        """Resolve a document heading to a canonical second-level name via keywords.

        Matching is lenient: a keyword contained in the heading, or the heading
        contained in a keyword, counts as a hit.

        Args:
            second_name: second-level heading as it appears in the document.

        Returns:
            The canonical standard name, or None when nothing matches.
        """
        title = second_name.strip().lower()

        for canonical, keywords in SECONDARY_CATEGORY_KEYWORDS.items():
            if any(kw.lower() in title or title in kw.lower() for kw in keywords):
                return canonical

        return None

    def get_standards_by_second_name(self, second_name: str) -> List[CategoryStandard]:
        """Return the third-level standards for a second-level heading (fuzzy).

        Match priority:
        1. exact equality with a CSV second_name;
        2. substring containment in either direction (first hit decides the
           canonical name; all rows with that name are returned);
        3. keyword-table lookup via SECONDARY_CATEGORY_KEYWORDS.

        Args:
            second_name: second-level heading text.

        Returns:
            Matching third-level standards; empty list when nothing matches.
        """
        title = second_name.strip()

        # 1. Exact match.
        exact_hits = [std for std in self.standards if std.second_name == title]
        if exact_hits:
            return exact_hits

        # 2. Containment match: the first hit fixes the canonical name.
        for std in self.standards:
            if std.second_name in title or title in std.second_name:
                canonical = std.second_name
                return [s for s in self.standards if s.second_name == canonical]

        # 3. Keyword-table fallback.
        canonical = self._find_standard_name_by_keyword(title)
        if canonical:
            return [std for std in self.standards if std.second_name == canonical]

        return []
|
|
|
+
|
|
|
+
|
|
|
class SecondCategoryStandardLoader:
    """Loader for construction_plan_standards.csv (second-level standards)."""

    def __init__(self, csv_path: Path):
        self.csv_path = csv_path
        self.standards: List[SecondCategoryStandard] = []
        self._load()

    def _load(self):
        """Read every CSV row into a SecondCategoryStandard record."""
        # utf-8-sig strips a UTF-8 BOM if present.
        with open(self.csv_path, 'r', encoding='utf-8-sig') as fh:
            for record in csv.DictReader(fh):
                self.standards.append(SecondCategoryStandard(
                    first_name=record.get('first_name', '').strip(),
                    second_name=record.get('second_name', '').strip(),
                    second_raw_content=record.get('second_raw_content', '').strip()
                ))

    def get_standard_by_second_name(self, second_name: str) -> Optional[SecondCategoryStandard]:
        """Find the standard for a second-level heading (supports fuzzy matching).

        Tries exact / containment matching against the loaded CSV first, then
        falls back to the SECONDARY_CATEGORY_KEYWORDS table.

        Args:
            second_name: second-level heading text from the document.

        Returns:
            The matching standard, or None when nothing matches.
        """
        title = second_name.strip().lower()

        # 1. Exact or containment match against the loaded standards.
        for std in self.standards:
            candidate = std.second_name.lower()
            if candidate == title or candidate in title or title in candidate:
                return std

        # 2. Keyword-table fallback: resolve the canonical standard name.
        canonical = None
        for standard_name, keywords in SECONDARY_CATEGORY_KEYWORDS.items():
            if any(kw.lower() in title or title in kw.lower() for kw in keywords):
                canonical = standard_name
                break

        # Look the canonical name back up in the loaded standards.
        if canonical:
            for std in self.standards:
                if std.second_name == canonical:
                    return std

        return None
|
|
|
+
|
|
|
+
|
|
|
+# ==================== Embedding 客户端 ====================
|
|
|
+
|
|
|
class EmbeddingClient:
    """Embedding-model client used to score text similarity.

    Wraps an OpenAI-compatible embeddings endpoint (configured via the
    module-level ``EMBEDDING_*`` settings) and provides cosine-similarity
    helpers plus the section-vs-standard similarity check used to skip
    LLM classification.
    """

    def __init__(self):
        self.client = AsyncOpenAI(
            api_key=EMBEDDING_API_KEY,
            base_url=EMBEDDING_BASE_URL
        )
        self.model = EMBEDDING_MODEL

    async def get_embedding(self, text: str) -> Optional[List[float]]:
        """Return the embedding vector for *text*, or None on any failure."""
        try:
            response = await self.client.embeddings.create(
                model=self.model,
                input=text
            )
            if response.data and len(response.data) > 0:
                return response.data[0].embedding
            return None
        except Exception as e:
            print(f" Embedding API调用失败: {e}")
            return None

    async def get_embeddings_batch(self, texts: List[str]) -> List[Optional[List[float]]]:
        """Return embedding vectors for *texts*; a list of None on failure."""
        try:
            response = await self.client.embeddings.create(
                model=self.model,
                input=texts
            )
            results = []
            for item in response.data:
                results.append(item.embedding)
            return results
        except Exception as e:
            print(f" Embedding API批量调用失败: {e}")
            return [None] * len(texts)

    def cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
        """Return the cosine similarity of two vectors (0.0 on any mismatch)."""
        if not vec1 or not vec2 or len(vec1) != len(vec2):
            return 0.0

        dot_product = sum(a * b for a, b in zip(vec1, vec2))
        norm1 = math.sqrt(sum(a * a for a in vec1))
        norm2 = math.sqrt(sum(b * b for b in vec2))

        if norm1 == 0 or norm2 == 0:
            return 0.0

        return dot_product / (norm1 * norm2)

    def _clean_section_name(self, section_name: str) -> str:
        """Strip ordinal prefixes from a section heading.

        Examples:
        - "一)编制依据"   -> "编制依据"
        - "二) 技术保证措施" -> "技术保证措施"
        - "1. 施工计划"    -> "施工计划"
        - "(1) 工艺流程"   -> "工艺流程"
        """
        cleaned = section_name.strip()

        # BUG FIX: the previous patterns used doubled backslashes inside raw
        # strings (e.g. r'^\\d+'), which matches a literal backslash followed
        # by 'd' instead of digits, so numeric/whitespace classes never
        # matched and ordinal prefixes were not stripped. Raw strings need
        # single backslashes for regex escapes.
        patterns = [
            r'^[一二三四五六七八九十百千]+[)）、.\s]+',  # Chinese numeral + punctuation, e.g. "一)" / "二、"
            r'^\d+[.)）、\s]+',                          # Arabic numeral + punctuation, e.g. "1." / "2)"
            r'^[((]\d+[))][\s.]*',                      # bracketed Arabic numeral, e.g. "(1)"
            r'^[((][一二三四五六七八九十][))][\s.]*',    # bracketed Chinese numeral, e.g. "(一)"
        ]

        for pattern in patterns:
            cleaned = re.sub(pattern, '', cleaned)

        return cleaned.strip()

    async def check_similarity(
        self,
        section_name: str,
        section_content: str,
        second_category_name: str,
        second_category_raw_content: str = ""
    ) -> Tuple[bool, float]:
        """Check how similar a section is to a second-level standard.

        Compares:
        - left: the section's actual content (construction-plan text under review)
        - right: second_raw_content (the standard definition from
          construction_plan_standards.csv)

        Returns:
            (is_similar, similarity_score):
            - is_similar: True when the titles match and the content
              similarity reaches EMBEDDING_SIMILARITY_THRESHOLD
            - similarity_score: cosine similarity in [0, 1]
        """
        # Step 1: gate on title match before spending embedding calls.
        cleaned_section_name = self._clean_section_name(section_name).lower()
        cleaned_second_name = second_category_name.strip().lower()

        if cleaned_section_name != cleaned_second_name:
            # Not equal: accept a containment match, otherwise bail out.
            if cleaned_second_name in cleaned_section_name or cleaned_section_name in cleaned_second_name:
                # Guard against very short titles causing false containment.
                # NOTE(review): the original check used `or`, so one long side
                # suffices even when the contained side is short — kept as-is
                # to preserve behavior; confirm whether `and` was intended.
                if not (len(cleaned_second_name) >= 4 or len(cleaned_section_name) >= 4):
                    return False, 0.0
            else:
                # Titles are unrelated: not similar.
                return False, 0.0

        # Step 2: embedding-based content similarity.
        # Left: first 800 chars of the section's actual content.
        # Right: first 800 chars of the standard definition (fall back to the name).
        section_text = section_content[:800]
        category_text = second_category_raw_content[:800] if second_category_raw_content else second_category_name

        embeddings = await self.get_embeddings_batch([section_text, category_text])

        if embeddings[0] is None or embeddings[1] is None:
            # Embedding failed: be conservative and report not-similar.
            return False, 0.0

        similarity = self.cosine_similarity(embeddings[0], embeddings[1])
        is_similar = similarity >= EMBEDDING_SIMILARITY_THRESHOLD

        return is_similar, similarity
|
|
|
+
|
|
|
+
|
|
|
+# ==================== LLM 客户端 ====================
|
|
|
+
|
|
|
+class ContentClassifierClient:
|
|
|
+ """LLM 内容分类客户端"""
|
|
|
+
|
|
|
    def __init__(self, model: str, semaphore: asyncio.Semaphore, embedding_client: Optional[EmbeddingClient] = None, second_category_loader: Optional[SecondCategoryStandardLoader] = None):
        """Create a classifier client.

        Args:
            model: model id passed to the chat-completion API.
            semaphore: shared semaphore bounding concurrent API calls.
            embedding_client: optional client for the similarity pre-check
                that can skip the LLM call entirely.
            second_category_loader: optional loader providing the second-level
                standard texts used by that pre-check.
        """
        self.model = model
        self.semaphore = semaphore
        # Chat client built from the module-level API settings.
        self.client = AsyncOpenAI(
            api_key=API_KEY,
            base_url=BASE_URL
        )
        self.embedding_client = embedding_client
        self.second_category_loader = second_category_loader
|
|
|
    async def classify_content(self, section: SectionContent) -> ClassificationResult:
        """Classify a section into third-level categories.

        Pipeline: (1) embedding similarity pre-check that may skip the LLM,
        (2) single LLM call for short sections, (3) non-overlapping chunked
        LLM calls plus merge for long sections. Concurrency is bounded by the
        shared semaphore inside the chunk calls; JSON repair is automatic.
        """
        start_time = time.time()

        # Step 1: embedding pre-check of the section against its second-level standard.
        if self.embedding_client and self.second_category_loader and section.category_standards:
            # Look up the canonical second-level standard for this heading
            # in construction_plan_standards.csv.
            std_second_category = self.second_category_loader.get_standard_by_second_name(section.section_name)

            if std_second_category:
                # Standard found: compare the section's content against the
                # standard's second_raw_content.
                section_text = '\n'.join(section.lines)
                is_similar, similarity = await self.embedding_client.check_similarity(
                    section_name=section.section_name,
                    section_content=section_text,
                    second_category_name=std_second_category.second_name,
                    second_category_raw_content=std_second_category.second_raw_content
                )

                if is_similar:
                    print(f" [{section.section_name}] 相似度检查通过 ({similarity:.3f} >= {EMBEDDING_SIMILARITY_THRESHOLD}),跳过LLM分类,默认包含所有三级分类")
                    # Similar enough: assume every third-level category is
                    # present and skip the (costly) LLM call.
                    all_contents = self._generate_default_classification(section)
                    total_lines, classified_lines, coverage_rate = self._calculate_coverage_rate(section, all_contents)
                    latency = time.time() - start_time
                    return ClassificationResult(
                        model=self.model,
                        section_key=section.section_key,
                        section_name=section.section_name,
                        classified_contents=all_contents,
                        latency=latency,
                        raw_response=f"[Embedding相似度跳过] similarity={similarity:.3f}",
                        error=None,
                        total_lines=total_lines,
                        classified_lines=classified_lines,
                        coverage_rate=coverage_rate
                    )
                else:
                    print(f" [{section.section_name}] 相似度检查未通过 ({similarity:.3f} < {EMBEDDING_SIMILARITY_THRESHOLD}),继续LLM分类")
            else:
                print(f" [{section.section_name}] 未在construction_plan_standards.csv中找到对应标准,继续LLM分类")

        # Chunk long sections to keep prompts bounded.
        MAX_LINES_PER_CHUNK = 150  # max lines per chunk
        total_lines = len(section.lines)

        if total_lines <= MAX_LINES_PER_CHUNK:
            # Short enough: classify in one call.
            return await self._classify_single_chunk(section, start_time)

        # Long content: split into NON-overlapping chunks.
        # No overlap on purpose: with overlap, boundary lines seen by two
        # chunks tend to be claimed by neither; without it each line belongs
        # to exactly one chunk, so the "classify every line" prompt
        # constraint is more effective.
        print(f" [{section.section_name}] 内容较长({total_lines}行),分块处理...")
        all_contents = []
        chunk_size = MAX_LINES_PER_CHUNK

        chunk_start = 0
        while chunk_start < total_lines:
            chunk_end = min(chunk_start + chunk_size, total_lines)
            chunk_section = self._create_chunk_section(section, chunk_start, chunk_end)

            chunk_result = await self._classify_single_chunk(chunk_section, 0, is_chunk=True)

            if chunk_result.error:
                print(f" 块 {chunk_start+1}-{chunk_end} 处理失败: {chunk_result.error[:50]}")
            else:
                print(f" 块 {chunk_start+1}-{chunk_end} 成功: {len(chunk_result.classified_contents)} 个分类")
                all_contents.extend(chunk_result.classified_contents)

            # No overlap: the next chunk starts right after this one.
            chunk_start = chunk_end

        # Merge results across chunks (the same category may be split by chunking).
        if all_contents:
            all_contents = self._merge_classified_contents(all_contents, section)

        # Coverage statistics over the whole section.
        total_lines, classified_lines, coverage_rate = self._calculate_coverage_rate(section, all_contents)

        latency = time.time() - start_time

        return ClassificationResult(
            model=self.model,
            section_key=section.section_key,
            section_name=section.section_name,
            classified_contents=all_contents,
            latency=latency,
            raw_response="",
            error=None if all_contents else "所有块处理失败",
            total_lines=total_lines,
            classified_lines=classified_lines,
            coverage_rate=coverage_rate
        )
|
|
|
+
|
|
|
+ def _calculate_coverage_rate(self, section: SectionContent, contents: List[ClassifiedContent]) -> tuple:
|
|
|
+ """计算分类率(已分类行数/总行数)"""
|
|
|
+ total_lines = len(section.lines)
|
|
|
+ if total_lines == 0 or not contents:
|
|
|
+ return total_lines, 0, 0.0
|
|
|
+
|
|
|
+ # 使用集合记录已分类的行号(避免重复计数)
|
|
|
+ classified_line_set = set()
|
|
|
+
|
|
|
+ for content in contents:
|
|
|
+ if section.line_number_map:
|
|
|
+ # 如果有全局行号映射,找出起止行号对应的索引
|
|
|
+ start_idx = -1
|
|
|
+ end_idx = -1
|
|
|
+ for idx, global_line in enumerate(section.line_number_map):
|
|
|
+ if global_line == content.start_line:
|
|
|
+ start_idx = idx
|
|
|
+ if global_line == content.end_line:
|
|
|
+ end_idx = idx
|
|
|
+ break
|
|
|
+
|
|
|
+ if start_idx != -1 and end_idx != -1:
|
|
|
+ for i in range(start_idx, end_idx + 1):
|
|
|
+ if i < len(section.line_number_map):
|
|
|
+ classified_line_set.add(section.line_number_map[i])
|
|
|
+ else:
|
|
|
+ # 没有全局行号,直接使用起止行号
|
|
|
+ for line_num in range(content.start_line, content.end_line + 1):
|
|
|
+ classified_line_set.add(line_num)
|
|
|
+
|
|
|
+ classified_lines = len(classified_line_set)
|
|
|
+ coverage_rate = (classified_lines / total_lines) * 100 if total_lines > 0 else 0.0
|
|
|
+
|
|
|
+ return total_lines, classified_lines, coverage_rate
|
|
|
+
|
|
|
+ def _generate_default_classification(self, section: SectionContent) -> List[ClassifiedContent]:
|
|
|
+ """
|
|
|
+ 生成默认的分类结果(当embedding相似度检查通过时使用)
|
|
|
+ 默认包含所有三级分类,覆盖整个section内容
|
|
|
+ """
|
|
|
+ if not section.category_standards:
|
|
|
+ return []
|
|
|
+
|
|
|
+ # 获取全局行号范围
|
|
|
+ if section.line_number_map:
|
|
|
+ start_line = section.line_number_map[0]
|
|
|
+ end_line = section.line_number_map[-1]
|
|
|
+ else:
|
|
|
+ start_line = 1
|
|
|
+ end_line = len(section.lines)
|
|
|
+
|
|
|
+ # 为每个三级分类创建一个条目,覆盖全部内容
|
|
|
+ default_contents = []
|
|
|
+ for std in section.category_standards:
|
|
|
+ # 提取该分类对应的内容
|
|
|
+ content = self._extract_content_by_line_numbers(section, start_line, end_line)
|
|
|
+ default_contents.append(ClassifiedContent(
|
|
|
+ third_category_name=std.third_name,
|
|
|
+ third_category_code=std.third_code,
|
|
|
+ start_line=start_line,
|
|
|
+ end_line=end_line,
|
|
|
+ content=content
|
|
|
+ ))
|
|
|
+
|
|
|
+ return default_contents
|
|
|
+
|
|
|
+ def _create_chunk_section(self, section: SectionContent, start_idx: int, end_idx: int) -> SectionContent:
|
|
|
+ """从section创建子块"""
|
|
|
+ chunk_lines = section.lines[start_idx:end_idx]
|
|
|
+ chunk_line_map = section.line_number_map[start_idx:end_idx] if section.line_number_map else list(range(start_idx + 1, end_idx + 1))
|
|
|
+
|
|
|
+ # 生成带行号的内容
|
|
|
+ numbered_content = '\n'.join([f"<{chunk_line_map[i]}> {line}" for i, line in enumerate(chunk_lines)])
|
|
|
+
|
|
|
+ return SectionContent(
|
|
|
+ section_key=f"{section.section_key}_chunk_{start_idx}_{end_idx}",
|
|
|
+ section_name=section.section_name,
|
|
|
+ lines=chunk_lines,
|
|
|
+ numbered_content=numbered_content,
|
|
|
+ category_standards=section.category_standards,
|
|
|
+ line_number_map=chunk_line_map
|
|
|
+ )
|
|
|
+
|
|
|
    async def _classify_single_chunk(self, section: SectionContent, start_time: float, is_chunk: bool = False) -> ClassificationResult:
        """Classify one chunk via a single LLM call.

        Args:
            section: the (possibly sub-chunked) section to classify.
            start_time: wall-clock start used for latency when not a chunk.
            is_chunk: True when this is one piece of a longer section; chunk
                results skip latency/coverage stats (the caller aggregates).

        Returns:
            A ClassificationResult; on exception the result carries the
            error string and an empty content list instead of raising.
        """
        prompt = self._build_prompt(section, is_chunk=is_chunk)

        try:
            # Bound concurrent API calls with the shared semaphore.
            async with self.semaphore:
                response = await self._call_api(prompt)

            # Parse; on malformed JSON this internally asks the model to repair it.
            classified_contents, parse_error = await self._parse_with_fix(response, section, prompt)

            if not is_chunk:
                latency = time.time() - start_time
                # Coverage statistics for a standalone (non-chunk) section.
                total_lines, classified_lines, coverage_rate = self._calculate_coverage_rate(section, classified_contents)
                return ClassificationResult(
                    model=self.model,
                    section_key=section.section_key,
                    section_name=section.section_name,
                    classified_contents=classified_contents,
                    latency=latency,
                    raw_response=response[:1000],
                    error=parse_error,
                    total_lines=total_lines,
                    classified_lines=classified_lines,
                    coverage_rate=coverage_rate
                )
            else:
                # Chunk result: stats are computed by the caller after merging.
                return ClassificationResult(
                    model=self.model,
                    section_key=section.section_key,
                    section_name=section.section_name,
                    classified_contents=classified_contents,
                    latency=0,
                    raw_response="",
                    error=parse_error
                )
        except Exception as e:
            # API/transport failure: return an error result rather than raise.
            if not is_chunk:
                latency = time.time() - start_time
                return ClassificationResult(
                    model=self.model,
                    section_key=section.section_key,
                    section_name=section.section_name,
                    classified_contents=[],
                    latency=latency,
                    error=str(e)
                )
            else:
                return ClassificationResult(
                    model=self.model,
                    section_key=section.section_key,
                    section_name=section.section_name,
                    classified_contents=[],
                    latency=0,
                    error=str(e)
                )
|
|
|
+
|
|
|
    async def _parse_with_fix(self, response: str, section: SectionContent, original_prompt: str = "") -> tuple:
        """Parse the LLM response; on JSON failure, ask the model to repair it (up to 3 retries).

        Returns:
            (contents, error_msg):
            - contents: classified spans; may legitimately be empty, meaning
              the model judged nothing matched any category
            - error_msg: None on success (including an empty result), else a
              description of the unrecoverable parse failure.
        """
        # First parse attempt.
        contents, parse_success = self._parse_response(response, section)

        # Success includes an empty list: the model decided no content matched
        # any of the category standards.
        if parse_success:
            if not contents:
                print(f" [{section.section_name}] 模型判定无匹配内容,记录为未分类")
            return contents, None

        # Malformed JSON: ask the model itself to repair it (up to 3 attempts).
        print(f" [{section.section_name}] JSON解析失败,请求模型修复...")
        print(f" 原始响应前200字符: {response[:200]}...")

        original_response = response

        for attempt in range(3):
            fix_prompt = self._build_fix_prompt(original_response)

            try:
                async with self.semaphore:
                    fixed_response = await self._call_api(fix_prompt)

                # Try to parse the repaired output.
                contents, parse_success = self._parse_response(fixed_response, section)
                if parse_success:
                    print(f" [{section.section_name}] 模型修复成功(第{attempt+1}次)")
                    if not contents:
                        print(f" [{section.section_name}] 修复后模型判定无匹配内容,记录为未分类")
                    return contents, None
                else:
                    print(f" 第{attempt+1}次修复失败,继续重试...")
                    # Feed the latest (still broken) output into the next attempt.
                    original_response = fixed_response
            except Exception as e:
                return [], f"请求模型修复失败: {str(e)}"

        print(f" [{section.section_name}] 模型修复3次后仍无法解析JSON")
        return [], "模型修复3次后仍无法解析JSON"
|
|
|
+
|
|
|
    def _build_fix_prompt(self, original_response: str) -> str:
        """Build the prompt that asks the model to repair its own malformed JSON output.

        The prompt instructs the model to fix syntax only (no business-data
        changes, no invented categories) and echoes at most the first 6000
        characters of the broken output.
        """
        return f"""你之前的输出存在JSON格式错误,请修复以下内容为正确的JSON格式。

## 修复要求
1. 严格保持原始数据的完整性和内容,不要修改任何业务数据
2. 只修复JSON语法错误(如缺少逗号、括号不匹配、引号问题等)
3. 确保输出的是合法的JSON格式
4. 【重要】分类名称和代码必须在原有分类范围内,禁止创造新的分类
5. 输出必须严格符合以下结构:
{{
  "classified_contents_list": [
    {{
      "third_category_name": "分类名称",
      "third_category_code": "分类代码",
      "start_line": 数字,
      "end_line": 数字
    }}
  ]
}}

## 原始输出(需要修复的内容)
```
{original_response[:6000]}
```

注意:
- 只输出JSON,不要任何解释文字
- 如果原始内容被截断,修复已提供的部分即可
- 禁止创造新的分类名称和代码"""
|
|
|
+
|
|
|
+ def _build_prompt(self, section: SectionContent, is_chunk: bool = False) -> str:
|
|
|
+ """构建分类提示词(优化版)"""
|
|
|
+
|
|
|
+ # 获取二级分类信息
|
|
|
+ second_code = ""
|
|
|
+ second_name = section.section_name
|
|
|
+ first_code = ""
|
|
|
+ first_name = ""
|
|
|
+
|
|
|
+ if section.category_standards:
|
|
|
+ first_code = section.category_standards[0].first_code
|
|
|
+ first_name = section.category_standards[0].first_name
|
|
|
+ second_code = section.category_standards[0].second_code
|
|
|
+
|
|
|
+ # 构建三级分类标准描述(完整显示关注要点 - third_focus是最重要的分类依据)
|
|
|
+ standards_desc = []
|
|
|
+ for i, std in enumerate(section.category_standards, 1):
|
|
|
+ # 完整显示 third_focus,这是最重要的分类依据!
|
|
|
+ focus_content = std.third_focus if std.third_focus else "(无具体关注要点)"
|
|
|
+ standards_desc.append(
|
|
|
+ f"{i}. {std.third_name} (代码: {std.third_code})\n"
|
|
|
+ f" 【识别要点】{focus_content}"
|
|
|
+ )
|
|
|
+
|
|
|
+ # 添加非标准项作为兜底分类(放在最后,降低优先级)
|
|
|
+ standards_desc.append(
|
|
|
+ f"{len(section.category_standards) + 1}. 非标准项 (代码: no_standard)\n"
|
|
|
+ f" 识别要点: 仅当内容完全不符合以上任何分类时使用,如页眉页脚、纯表格分隔线、无关的广告语等"
|
|
|
+ )
|
|
|
+
|
|
|
+ standards_text = '\n\n'.join(standards_desc) if standards_desc else "无具体标准,请根据内容自行判断"
|
|
|
+
|
|
|
+ # 计算内容长度和分段提示
|
|
|
+ content_length = len(section.numbered_content)
|
|
|
+ max_content_length = 12000 # 增加内容长度限制
|
|
|
+ content_to_use = section.numbered_content[:max_content_length]
|
|
|
+ is_truncated = len(section.numbered_content) > max_content_length
|
|
|
+
|
|
|
+ if is_chunk and section.line_number_map:
|
|
|
+ chunk_hint = (
|
|
|
+ f"\n【注意】这是文档的一个分块(行号 {section.line_number_map[0]}~{section.line_number_map[-1]}),"
|
|
|
+ f"请对此范围内的**每一行**进行分类,首行和末行同样必须分类,不得遗漏。\n"
|
|
|
+ )
|
|
|
+ elif is_chunk:
|
|
|
+ chunk_hint = "\n【注意】这是文档的一个分块,请对此分块内的**每一行**进行分类,不得遗漏。\n"
|
|
|
+ else:
|
|
|
+ chunk_hint = ""
|
|
|
+ truncation_hint = f"\n【提示】内容较长已截断,当前显示前{max_content_length}字符,请对显示的内容进行完整分类。\n" if is_truncated else ""
|
|
|
+
|
|
|
+ return f"""你是一个专业的施工方案文档分析专家。请根据给定的三级分类标准,识别文档内容中属于各个三级分类的部分。{chunk_hint}{truncation_hint}
|
|
|
+
|
|
|
+## 当前文档位置
|
|
|
+- 一级分类: {first_name} ({first_code})
|
|
|
+- 二级分类: {second_name} ({second_code})
|
|
|
+
|
|
|
+## 三级分类标准(共{len(section.category_standards)}个,必须在此范围内分类)
|
|
|
+
|
|
|
+{standards_text}
|
|
|
+
|
|
|
+---
|
|
|
+
|
|
|
+## 文档内容(每行以<行号>开头,共{len(section.lines)}行)
|
|
|
+```
|
|
|
+{content_to_use}
|
|
|
+```
|
|
|
+
|
|
|
+---
|
|
|
+
|
|
|
+## 分类任务指南
|
|
|
+
|
|
|
+### 核心原则(按优先级排序)
|
|
|
+1. **优先匹配标准分类**:首先判断内容是否符合上述任何一个三级分类标准
|
|
|
+2. **关键词匹配**:内容中出现与分类名称相关的关键词时,应归类到该分类
|
|
|
+3. **语义相关**:即使没有精确关键词,只要语义相关,也应归类
|
|
|
+4. **非标准项谨慎使用**:只有当内容完全不符合任何标准分类时,才使用"非标准项"
|
|
|
+
|
|
|
+### 分类示例
|
|
|
+- 看到"验收内容"、"验收标准"、"验收程序"等内容 → 归类到对应的三级分类
|
|
|
+- 看到"检验方法"、"检查内容"等 → 可能属于"检查要求"或"验收内容"
|
|
|
+- 看到"材料"、"钢筋"、"混凝土"等 → 关注上下文判断所属三级分类
|
|
|
+
|
|
|
+### 行号处理规则
|
|
|
+- **必须合并连续行**:连续多行属于同一分类时,合并为一个条目(start_line为起始,end_line为结束)
|
|
|
+- **禁止逐行输出**:不要为每一行单独创建条目
|
|
|
+- **允许重复分类**:同一行内容可以同时属于多个三级分类
|
|
|
+
|
|
|
+### 多主体句拆分规则(重要)
|
|
|
+- 当一行内容同时提及多个不同主体或类别时,**必须为每个主体单独输出一条分类条目,行号相同**
|
|
|
+- 示例:`"3、有关勘察、设计和监测单位项目技术负责人"` 同时涉及设计单位和监测单位,应输出:
|
|
|
+ - `{{"third_category_code": "DesignUnitXxx", "start_line": N, "end_line": N}}`
|
|
|
+ - `{{"third_category_code": "MonitoringUnitXxx", "start_line": N, "end_line": N}}`
|
|
|
+- 示例:`"总承包单位和分包单位技术负责人"` 同时涉及施工单位,应归入施工单位对应分类
|
|
|
+- 凡是"A、B和C单位"句式,需逐一判断每个主体能否对应某个三级分类
|
|
|
+
|
|
|
+### 自查清单
|
|
|
+- [ ] 是否每一行都已分类(非标准项也是分类)?
|
|
|
+- [ ] 是否优先使用了标准分类而非"非标准项"?
|
|
|
+- [ ] 连续相同分类的行是否已合并?
|
|
|
+- [ ] 分类名称是否与标准列表完全一致?
|
|
|
+- [ ] 包含多个主体的行是否已拆分为多条输出?
|
|
|
+
|
|
|
+## 输出格式(严格JSON,不要任何其他文字)
|
|
|
+```{{
|
|
|
+ "classified_contents_list": [
|
|
|
+ {{
|
|
|
+ "third_category_name": "三级分类名称(只写名称,不含代码)",
|
|
|
+ "third_category_code": "三级分类代码",
|
|
|
+ "start_line": 起始行号,
|
|
|
+ "end_line": 结束行号
|
|
|
+ }}
|
|
|
+ ]
|
|
|
+}}
|
|
|
+```
|
|
|
+
|
|
|
+## 强制约束
|
|
|
+1. 分类名称必须与上述标准列表中的名称完全一致
|
|
|
+2. 分类代码必须使用标准列表中括号内的代码
|
|
|
+3. 行号范围: {section.line_number_map[0] if section.line_number_map else 1} - {section.line_number_map[-1] if section.line_number_map else len(section.lines)}
|
|
|
+4. 只输出JSON,禁止任何解释文字"""
|
|
|
+
|
|
|
+ async def _call_api(self, prompt: str) -> str:
|
|
|
+ """调用API(带指数退避重试)"""
|
|
|
+ system_prompt = """你是专业的施工方案文档分析专家。你的任务是:
|
|
|
+1. 仔细阅读文档内容,理解每行的语义
|
|
|
+2. 将内容归类到给定的三级分类标准中
|
|
|
+3. 【重要】优先使用标准分类,只有完全不符合时才使用"非标准项"
|
|
|
+4. 【重要】连续相同分类的多行必须合并为一个条目
|
|
|
+5. 【重要】当一行同时提及多个主体或类别(如"勘察、设计和监测单位"),必须为每个主体单独输出一条条目,行号相同
|
|
|
+5. 【重要】分类名称只写名称,不含代码。例如:写"验收内容"而不是"验收内容 (Content)"
|
|
|
+6. 必须在给定的三级分类标准范围内分类,禁止创造新的分类名称
|
|
|
+7. 只输出JSON格式结果,不要任何解释文字"""
|
|
|
+
|
|
|
+ kwargs = {
|
|
|
+ "model": self.model,
|
|
|
+ "messages": [
|
|
|
+ {"role": "system", "content": system_prompt},
|
|
|
+ {"role": "user", "content": prompt}
|
|
|
+ ],
|
|
|
+ "temperature": 0.1, # 降低温度提高分类准确性
|
|
|
+ "max_tokens": 8000 # 增加输出空间
|
|
|
+ }
|
|
|
+
|
|
|
+ # qwen3.5 系列模型默认开启思考模式,需要显式关闭
|
|
|
+ # qwen3 系列模型不需要 enable_thinking 参数
|
|
|
+ if "qwen3.5" in self.model:
|
|
|
+ kwargs["extra_body"] = {"enable_thinking": False}
|
|
|
+
|
|
|
+ # 指数退避重试
|
|
|
+ max_retries = 5
|
|
|
+ base_delay = 2 # 基础延迟2秒
|
|
|
+
|
|
|
+ for attempt in range(max_retries):
|
|
|
+ try:
|
|
|
+ response = await self.client.chat.completions.create(**kwargs)
|
|
|
+ return response.choices[0].message.content or ""
|
|
|
+ except Exception as e:
|
|
|
+ error_str = str(e)
|
|
|
+ # 检查是否是429限流错误
|
|
|
+ if "429" in error_str or "rate limit" in error_str.lower():
|
|
|
+ if attempt < max_retries - 1:
|
|
|
+ # 指数退避: 2^attempt * (1 + random)
|
|
|
+ delay = base_delay * (2 ** attempt) + (hash(prompt) % 1000) / 1000
|
|
|
+ print(f" API限流(429),等待 {delay:.1f}s 后重试 ({attempt + 1}/{max_retries})...")
|
|
|
+ await asyncio.sleep(delay)
|
|
|
+ continue
|
|
|
+ # 其他错误或重试次数用完,抛出异常
|
|
|
+ raise
|
|
|
+
|
|
|
+ return ""
|
|
|
+
|
|
|
+ def _parse_response(self, response: str, section: SectionContent) -> tuple:
|
|
|
+ """解析响应(增强版,处理各种JSON格式问题)
|
|
|
+
|
|
|
+ 返回: (contents, parse_success)
|
|
|
+ - contents: 分类结果列表
|
|
|
+ - parse_success: True表示JSON解析成功(包括空结果),False表示解析失败
|
|
|
+ """
|
|
|
+ if not response or not response.strip():
|
|
|
+ return [], False # 空响应视为解析失败
|
|
|
+
|
|
|
+ response = response.strip()
|
|
|
+
|
|
|
+ # 尝试多种方式提取JSON
|
|
|
+ json_str = None
|
|
|
+
|
|
|
+ # 方法1: 从代码块中提取
|
|
|
+ code_block_match = re.search(r'```(?:json)?\s*([\s\S]*?)```', response)
|
|
|
+ if code_block_match:
|
|
|
+ json_str = code_block_match.group(1).strip()
|
|
|
+
|
|
|
+ # 方法2: 优先查找JSON数组(模型经常直接输出数组格式)
|
|
|
+ if not json_str:
|
|
|
+ # 使用非贪婪匹配找到第一个完整的数组
|
|
|
+ array_match = re.search(r'\[[\s\S]*?\]', response)
|
|
|
+ if array_match:
|
|
|
+ potential_array = array_match.group(0)
|
|
|
+ # 验证是否是有效的JSON数组
|
|
|
+ try:
|
|
|
+ parsed = json.loads(potential_array)
|
|
|
+ if isinstance(parsed, list):
|
|
|
+ json_str = potential_array
|
|
|
+ except:
|
|
|
+ pass
|
|
|
+
|
|
|
+ # 方法3: 查找JSON对象
|
|
|
+ if not json_str:
|
|
|
+ json_match = re.search(r'\{[\s\S]*\}', response)
|
|
|
+ if json_match:
|
|
|
+ json_str = json_match.group(0)
|
|
|
+
|
|
|
+ if not json_str:
|
|
|
+ return [], False # 未找到JSON结构,解析失败
|
|
|
+
|
|
|
+ # 处理模型直接输出数组的情况(包装成对象格式)
|
|
|
+ if json_str.strip().startswith('['):
|
|
|
+ try:
|
|
|
+ # 验证是有效的JSON数组
|
|
|
+ array_data = json.loads(json_str)
|
|
|
+ if isinstance(array_data, list):
|
|
|
+ # 包装成期望的格式
|
|
|
+ json_str = json.dumps({"classified_contents": array_data})
|
|
|
+ except:
|
|
|
+ pass # 不是有效数组,继续后续处理
|
|
|
+
|
|
|
+ # 先尝试直接解析,如果成功则不需要修复
|
|
|
+ try:
|
|
|
+ json.loads(json_str)
|
|
|
+ # JSON 有效,直接使用
|
|
|
+ except json.JSONDecodeError:
|
|
|
+ # JSON 无效,尝试修复
|
|
|
+ json_str = self._fix_json(json_str)
|
|
|
+
|
|
|
+ try:
|
|
|
+ data = json.loads(json_str)
|
|
|
+ # 处理数组格式
|
|
|
+ if isinstance(data, list):
|
|
|
+ data = {"classified_contents": data}
|
|
|
+ contents = []
|
|
|
+ # 支持两种键名: classified_contents 或 classified_contents_list
|
|
|
+ items = data.get("classified_contents", []) or data.get("classified_contents_list", [])
|
|
|
+
|
|
|
+ # 获取有效的分类代码列表(从section的标准分类中)
|
|
|
+ valid_codes = set()
|
|
|
+ if section.category_standards:
|
|
|
+ for std in section.category_standards:
|
|
|
+ valid_codes.add(std.third_code)
|
|
|
+ # 添加非标准项作为有效代码
|
|
|
+ valid_codes.add("no_standard")
|
|
|
+
|
|
|
+ for item in items:
|
|
|
+ start_line = item.get("start_line", 0)
|
|
|
+ end_line = item.get("end_line", 0)
|
|
|
+ category_code = item.get("third_category_code", "")
|
|
|
+ category_name = item.get("third_category_name", "")
|
|
|
+
|
|
|
+ # 清理分类名称格式:移除末尾的代码部分(如 "非标准项 (no_standard)" -> "非标准项")
|
|
|
+ if category_name and " (" in category_name and category_name.endswith(")"):
|
|
|
+ category_name = re.sub(r'\s*\([^)]+\)\s*$', '', category_name).strip()
|
|
|
+
|
|
|
+ # 检查分类代码是否在有效列表中,如果不在则强制归为非标准项
|
|
|
+ if category_code not in valid_codes:
|
|
|
+ print(f" 警告: 发现非标准分类 '{category_name}' ({category_code}),强制归为非标准项")
|
|
|
+ category_code = "no_standard"
|
|
|
+ category_name = "非标准项"
|
|
|
+
|
|
|
+ # 根据行号从section中提取原文
|
|
|
+ content = self._extract_content_by_line_numbers(section, start_line, end_line)
|
|
|
+ contents.append(ClassifiedContent(
|
|
|
+ third_category_name=category_name,
|
|
|
+ third_category_code=category_code,
|
|
|
+ start_line=start_line,
|
|
|
+ end_line=end_line,
|
|
|
+ content=content
|
|
|
+ ))
|
|
|
+ # 聚合同一分类下相邻的内容
|
|
|
+ contents = self._merge_classified_contents(contents, section)
|
|
|
+ return contents, True # 解析成功(可能为空结果)
|
|
|
+ except Exception as e:
|
|
|
+ # 尝试更激进的修复
|
|
|
+ try:
|
|
|
+ fixed = self._aggressive_json_fix(json_str)
|
|
|
+ data = json.loads(fixed)
|
|
|
+ # 处理数组格式
|
|
|
+ if isinstance(data, list):
|
|
|
+ data = {"classified_contents": data}
|
|
|
+ contents = []
|
|
|
+ # 支持两种键名: classified_contents 或 classified_contents_list
|
|
|
+ items = data.get("classified_contents", []) or data.get("classified_contents_list", [])
|
|
|
+
|
|
|
+ # 获取有效的分类代码列表(从section的标准分类中)
|
|
|
+ valid_codes = set()
|
|
|
+ if section.category_standards:
|
|
|
+ for std in section.category_standards:
|
|
|
+ valid_codes.add(std.third_code)
|
|
|
+ # 添加非标准项作为有效代码
|
|
|
+ valid_codes.add("no_standard")
|
|
|
+
|
|
|
+ for item in items:
|
|
|
+ start_line = item.get("start_line", 0)
|
|
|
+ end_line = item.get("end_line", 0)
|
|
|
+ category_code = item.get("third_category_code", "")
|
|
|
+ category_name = item.get("third_category_name", "")
|
|
|
+
|
|
|
+ # 清理分类名称格式:移除末尾的代码部分(如 "非标准项 (no_standard)" -> "非标准项")
|
|
|
+ if category_name and " (" in category_name and category_name.endswith(")"):
|
|
|
+ category_name = re.sub(r'\s*\([^)]+\)\s*$', '', category_name).strip()
|
|
|
+
|
|
|
+ # 检查分类代码是否在有效列表中,如果不在则强制归为非标准项
|
|
|
+ if category_code not in valid_codes:
|
|
|
+ print(f" 警告: 发现非标准分类 '{category_name}' ({category_code}),强制归为非标准项")
|
|
|
+ category_code = "no_standard"
|
|
|
+ category_name = "非标准项"
|
|
|
+
|
|
|
+ # 根据行号从section中提取原文
|
|
|
+ content = self._extract_content_by_line_numbers(section, start_line, end_line)
|
|
|
+ contents.append(ClassifiedContent(
|
|
|
+ third_category_name=category_name,
|
|
|
+ third_category_code=category_code,
|
|
|
+ start_line=start_line,
|
|
|
+ end_line=end_line,
|
|
|
+ content=content
|
|
|
+ ))
|
|
|
+ # 聚合同一分类下相邻的内容
|
|
|
+ contents = self._merge_classified_contents(contents, section)
|
|
|
+ return contents, True # 解析成功(可能为空结果)
|
|
|
+ except Exception as e2:
|
|
|
+ error_msg = f"解析JSON失败: {e}, 二次修复也失败: {e2}"
|
|
|
+ print(error_msg)
|
|
|
+ print(f"原始响应前500字符: {response[:500]}...")
|
|
|
+ print(f"提取的JSON前300字符: {json_str[:300]}...")
|
|
|
+ return [], False # 解析失败
|
|
|
+
|
|
|
+ def _merge_classified_contents(self, contents: List[ClassifiedContent], section: SectionContent) -> List[ClassifiedContent]:
|
|
|
+ """将同一分类下的内容按区间合并(只有连续或重叠的区间才合并)"""
|
|
|
+ if not contents:
|
|
|
+ return contents
|
|
|
+
|
|
|
+ # 按分类代码分组
|
|
|
+ groups: Dict[str, List[ClassifiedContent]] = {}
|
|
|
+ for content in contents:
|
|
|
+ key = content.third_category_code
|
|
|
+ if key not in groups:
|
|
|
+ groups[key] = []
|
|
|
+ groups[key].append(content)
|
|
|
+
|
|
|
+ merged_contents = []
|
|
|
+
|
|
|
+ for category_code, group_contents in groups.items():
|
|
|
+ # 按起始行号排序
|
|
|
+ group_contents.sort(key=lambda x: x.start_line)
|
|
|
+
|
|
|
+ # 合并连续或重叠的区间
|
|
|
+ merged_ranges = []
|
|
|
+ for content in group_contents:
|
|
|
+ if not merged_ranges:
|
|
|
+ # 第一个区间
|
|
|
+ merged_ranges.append({
|
|
|
+ 'start': content.start_line,
|
|
|
+ 'end': content.end_line
|
|
|
+ })
|
|
|
+ else:
|
|
|
+ last_range = merged_ranges[-1]
|
|
|
+ # 检查是否连续或重叠(允许1行的间隔也算连续)
|
|
|
+ if content.start_line <= last_range['end'] + 1:
|
|
|
+ # 扩展当前区间
|
|
|
+ last_range['end'] = max(last_range['end'], content.end_line)
|
|
|
+ else:
|
|
|
+ # 不连续,新建区间
|
|
|
+ merged_ranges.append({
|
|
|
+ 'start': content.start_line,
|
|
|
+ 'end': content.end_line
|
|
|
+ })
|
|
|
+
|
|
|
+ # 为每个合并后的区间创建条目
|
|
|
+ for range_info in merged_ranges:
|
|
|
+ merged_content = self._extract_content_by_line_numbers(
|
|
|
+ section, range_info['start'], range_info['end']
|
|
|
+ )
|
|
|
+ merged_contents.append(ClassifiedContent(
|
|
|
+ third_category_name=group_contents[0].third_category_name,
|
|
|
+ third_category_code=category_code,
|
|
|
+ start_line=range_info['start'],
|
|
|
+ end_line=range_info['end'],
|
|
|
+ content=merged_content
|
|
|
+ ))
|
|
|
+
|
|
|
+ # 按起始行号排序最终结果
|
|
|
+ merged_contents.sort(key=lambda x: x.start_line)
|
|
|
+ return merged_contents
|
|
|
+
|
|
|
+ def _extract_content_by_line_numbers(self, section: SectionContent, start_line: int, end_line: int) -> str:
|
|
|
+ """根据全局行号从section中提取原文内容"""
|
|
|
+ if not section.line_number_map:
|
|
|
+ # 如果没有行号映射,使用相对索引
|
|
|
+ start_idx = max(0, start_line - 1)
|
|
|
+ end_idx = min(len(section.lines), end_line)
|
|
|
+ return '\n'.join(section.lines[start_idx:end_idx])
|
|
|
+
|
|
|
+ # 找到全局行号对应的索引
|
|
|
+ start_idx = -1
|
|
|
+ end_idx = -1
|
|
|
+
|
|
|
+ for idx, global_line_num in enumerate(section.line_number_map):
|
|
|
+ if global_line_num == start_line:
|
|
|
+ start_idx = idx
|
|
|
+ if global_line_num == end_line:
|
|
|
+ end_idx = idx
|
|
|
+ break
|
|
|
+
|
|
|
+ # 如果没找到精确匹配,使用近似值
|
|
|
+ if start_idx == -1:
|
|
|
+ for idx, global_line_num in enumerate(section.line_number_map):
|
|
|
+ if global_line_num >= start_line:
|
|
|
+ start_idx = idx
|
|
|
+ break
|
|
|
+ if end_idx == -1:
|
|
|
+ for idx in range(len(section.line_number_map) - 1, -1, -1):
|
|
|
+ if section.line_number_map[idx] <= end_line:
|
|
|
+ end_idx = idx
|
|
|
+ break
|
|
|
+
|
|
|
+ if start_idx == -1:
|
|
|
+ start_idx = 0
|
|
|
+ if end_idx == -1:
|
|
|
+ end_idx = len(section.lines) - 1
|
|
|
+
|
|
|
+ # 确保索引有效
|
|
|
+ start_idx = max(0, min(start_idx, len(section.lines) - 1))
|
|
|
+ end_idx = max(0, min(end_idx, len(section.lines) - 1))
|
|
|
+
|
|
|
+ if start_idx > end_idx:
|
|
|
+ start_idx, end_idx = end_idx, start_idx
|
|
|
+
|
|
|
+ # 添加行号标记返回
|
|
|
+ lines_with_numbers = []
|
|
|
+ for i in range(start_idx, end_idx + 1):
|
|
|
+ global_line = section.line_number_map[i] if i < len(section.line_number_map) else (i + 1)
|
|
|
+ lines_with_numbers.append(f"<{global_line}> {section.lines[i]}")
|
|
|
+
|
|
|
+ return '\n'.join(lines_with_numbers)
|
|
|
+
|
|
|
+ def _fix_json(self, json_str: str) -> str:
|
|
|
+ """修复常见的JSON格式问题"""
|
|
|
+ # 去除尾部多余的逗号
|
|
|
+ json_str = re.sub(r',(\s*[}\]])', r'\1', json_str)
|
|
|
+
|
|
|
+ # 确保 JSON 结构闭合
|
|
|
+ json_str = self._ensure_json_closed(json_str)
|
|
|
+
|
|
|
+ # 替换单引号为双引号(但要小心内容中的单引号)
|
|
|
+ # 使用更精确的方法:先尝试解析,失败再替换
|
|
|
+ try:
|
|
|
+ json.loads(json_str)
|
|
|
+ return json_str
|
|
|
+ except:
|
|
|
+ # 尝试替换单引号
|
|
|
+ json_str = json_str.replace("'", '"')
|
|
|
+
|
|
|
+ return json_str
|
|
|
+
|
|
|
+ def _truncate_to_valid_json(self, json_str: str) -> str:
|
|
|
+ """将截断的JSON截断到最后一个完整对象的位置,并保留数组结构"""
|
|
|
+ # 找到 "classified_contents" 数组的开始
|
|
|
+ array_start = json_str.find('"classified_contents"')
|
|
|
+ if array_start == -1:
|
|
|
+ return json_str
|
|
|
+
|
|
|
+ # 找到数组的 '['
|
|
|
+ bracket_start = json_str.find('[', array_start)
|
|
|
+ if bracket_start == -1:
|
|
|
+ return json_str
|
|
|
+
|
|
|
+ # 遍历数组,找到最后一个完整的对象
|
|
|
+ brace_count = 0
|
|
|
+ bracket_count = 1 # 已经进入数组,所以是1
|
|
|
+ in_string = False
|
|
|
+ escape_next = False
|
|
|
+ last_valid_obj_end = 0
|
|
|
+ i = bracket_start + 1
|
|
|
+
|
|
|
+ while i < len(json_str):
|
|
|
+ char = json_str[i]
|
|
|
+
|
|
|
+ if escape_next:
|
|
|
+ escape_next = False
|
|
|
+ i += 1
|
|
|
+ continue
|
|
|
+
|
|
|
+ if char == '\\':
|
|
|
+ escape_next = True
|
|
|
+ i += 1
|
|
|
+ continue
|
|
|
+
|
|
|
+ if char == '"' and not escape_next:
|
|
|
+ in_string = not in_string
|
|
|
+ i += 1
|
|
|
+ continue
|
|
|
+
|
|
|
+ if not in_string:
|
|
|
+ if char == '{':
|
|
|
+ brace_count += 1
|
|
|
+ elif char == '}':
|
|
|
+ brace_count -= 1
|
|
|
+ if brace_count == 0:
|
|
|
+ # 找到一个完整的对象
|
|
|
+ last_valid_obj_end = i
|
|
|
+ elif char == '[':
|
|
|
+ bracket_count += 1
|
|
|
+ elif char == ']':
|
|
|
+ bracket_count -= 1
|
|
|
+ if bracket_count == 0:
|
|
|
+ # 数组正常闭合,不需要截断
|
|
|
+ return json_str
|
|
|
+
|
|
|
+ i += 1
|
|
|
+
|
|
|
+ if last_valid_obj_end > 0:
|
|
|
+ # 截断到最后一个完整对象的位置,并关闭数组
|
|
|
+ return json_str[:last_valid_obj_end + 1] + ']'
|
|
|
+
|
|
|
+ return json_str
|
|
|
+
|
|
|
+ def _ensure_json_closed(self, json_str: str) -> str:
|
|
|
+ """确保JSON结构闭合"""
|
|
|
+ # 计算未闭合的括号
|
|
|
+ brace_count = 0
|
|
|
+ bracket_count = 0
|
|
|
+ in_string = False
|
|
|
+ escape_next = False
|
|
|
+
|
|
|
+ for char in json_str:
|
|
|
+ if escape_next:
|
|
|
+ escape_next = False
|
|
|
+ continue
|
|
|
+ if char == '\\':
|
|
|
+ escape_next = True
|
|
|
+ continue
|
|
|
+ if char == '"' and not escape_next:
|
|
|
+ in_string = not in_string
|
|
|
+ continue
|
|
|
+ if not in_string:
|
|
|
+ if char == '{':
|
|
|
+ brace_count += 1
|
|
|
+ elif char == '}':
|
|
|
+ brace_count -= 1
|
|
|
+ elif char == '[':
|
|
|
+ bracket_count += 1
|
|
|
+ elif char == ']':
|
|
|
+ bracket_count -= 1
|
|
|
+
|
|
|
+ # 添加闭合括号
|
|
|
+ result = json_str
|
|
|
+ # 先去掉尾部可能的逗号
|
|
|
+ result = result.rstrip().rstrip(',').rstrip()
|
|
|
+
|
|
|
+ # 关闭对象
|
|
|
+ while brace_count > 0:
|
|
|
+ result += '}'
|
|
|
+ brace_count -= 1
|
|
|
+
|
|
|
+ # 关闭数组
|
|
|
+ while bracket_count > 0:
|
|
|
+ result += ']'
|
|
|
+ bracket_count -= 1
|
|
|
+
|
|
|
+ return result
|
|
|
+
|
|
|
+ def _aggressive_json_fix(self, json_str: str) -> str:
|
|
|
+ """激进的JSON修复,用于处理复杂情况"""
|
|
|
+ # 首先尝试截断到最后一个完整对象
|
|
|
+ json_str = self._truncate_to_valid_json(json_str)
|
|
|
+ # 然后确保结构闭合
|
|
|
+ json_str = self._ensure_json_closed(json_str)
|
|
|
+ return json_str
|
|
|
+
|
|
|
+
|
|
|
+# ==================== Chunks 转换器(用于集成) ====================
|
|
|
+
|
|
|
class ChunksConverter:
    """Converter between the external ``chunks`` format and :class:`SectionContent`."""

    def __init__(self, category_loader: 'CategoryStandardLoader'):
        # Loader used to resolve second-level names/codes to third-level standards.
        self.category_loader = category_loader

    def chunks_to_sections(self, chunks: List[Dict[str, Any]]) -> List["SectionContent"]:
        """Convert a list of chunks into a list of SectionContent.

        Grouping strategy:
        1. Prefer grouping by ``section_label`` (most precise document structure).
        2. Otherwise group by second-level category code (never merging
           different second-level categories into one group).
        3. The second-level name extracted from ``section_label`` is used to
           match third-level standards.

        Args:
            chunks: document chunks; each should carry
                - chapter_classification: first-level category code
                - secondary_category_code: second-level code (may be "none")
                - secondary_category_cn: second-level Chinese name
                - review_chunk_content or content: the text
                - section_label: section label, e.g. "第一章编制依据->一、法律法规"

        Returns:
            List[SectionContent]: one entry per second-level section group.
        """
        # --- Pass 1: group chunks ------------------------------------------
        section_groups: Dict[str, List[Dict]] = {}

        for chunk in chunks:
            section_label = chunk.get("section_label", "") or chunk.get("chapter", "")
            first_code = chunk.get("chapter_classification", "") or chunk.get("first_code", "")
            second_code = chunk.get("secondary_category_code", "") or chunk.get("second_code", "")
            # (removed: secondary_category_cn lookup — it was never used here)

            # Each second-level category gets its own group; never merge across them.
            if section_label and "->" in section_label:
                # An explicit section label is the most precise grouping key.
                group_key = section_label
            elif second_code and second_code not in ("none", "None", ""):
                # Group independently per second-level code (not under first-level).
                group_key = f"{first_code}->{second_code}"
            elif section_label:
                group_key = section_label
            else:
                # No classification info at all: unique key to avoid accidental merging.
                group_key = f"unknown_{first_code}_{id(chunk)}"

            if group_key not in section_groups:
                section_groups[group_key] = []
            section_groups[group_key].append(chunk)

        # --- Pass 2: build one SectionContent per group --------------------
        section_contents = []
        all_lines = []  # global line-number tracking across all sections

        for group_key, group_chunks in section_groups.items():
            if not group_chunks:
                continue

            # Concatenate the group's content, recording each chunk's line count
            # so its global line range can be reconstructed afterwards.
            section_lines = []
            chunk_line_counts: List[Tuple[str, int]] = []
            for chunk in group_chunks:
                content = chunk.get("review_chunk_content", "") or chunk.get("content", "") or chunk.get("original_content", "")
                if content:
                    lines = content.split('\n')
                    n = len(lines)
                    chunk_id = chunk.get("chunk_id") or chunk.get("id") or str(id(chunk))
                    chunk_line_counts.append((chunk_id, n))
                    section_lines.extend(lines)
                    all_lines.extend(lines)
                else:
                    chunk_id = chunk.get("chunk_id") or chunk.get("id") or str(id(chunk))
                    chunk_line_counts.append((chunk_id, 0))

            if not section_lines:
                continue

            # First-level code comes from the group's first chunk.
            first_code = group_chunks[0].get("chapter_classification", "") or group_chunks[0].get("first_code", "")

            # Second-level name and code.
            second_code = group_chunks[0].get("secondary_category_code", "") or group_chunks[0].get("second_code", "")
            second_cn = group_chunks[0].get("secondary_category_cn", "") or group_chunks[0].get("second_name", "")

            # Prefer the second-level name embedded in the section label.
            section_label = group_chunks[0].get("section_label", "") or group_chunks[0].get("chapter", "")
            if "->" in section_label:
                parts = section_label.split("->")
                if len(parts) >= 2:
                    extracted = parts[1].strip()
                    # Strip a Chinese ordinal prefix such as "一、" / "二、".
                    cleaned = re.sub(r'^[一二三四五六七八九十]+[、)\s]+', '', extracted).strip()
                    if cleaned:
                        second_cn = cleaned
                        # Try to resolve the second-level code from the cleaned name.
                        matched_standards = self.category_loader.get_standards_by_second_name(cleaned)
                        if matched_standards:
                            second_code = matched_standards[0].second_code

            # Build the numbered content with global line numbers.
            start_line = len(all_lines) - len(section_lines) + 1
            line_number_map = list(range(start_line, len(all_lines) + 1))
            numbered_lines = []
            for i, line in enumerate(section_lines):
                numbered_lines.append(f"<{line_number_map[i]}> {line}")
            numbered_content = '\n'.join(numbered_lines)

            # Global line range occupied by each original chunk.
            chunk_ranges: List[Tuple[str, int, int]] = []
            current_global = start_line
            for chunk_id, n_lines in chunk_line_counts:
                if n_lines > 0:
                    chunk_ranges.append((chunk_id, current_global, current_global + n_lines - 1))
                    current_global += n_lines

            # Look up the third-level standards (code first, then name).
            category_standards = self.category_loader.get_standards_by_second_code(second_code)
            if not category_standards:
                category_standards = self.category_loader.get_standards_by_second_name(second_cn)

            section_key = f"{first_code}->{second_code}"

            section_contents.append(SectionContent(
                section_key=section_key,
                section_name=second_cn or second_code,
                lines=section_lines,
                numbered_content=numbered_content,
                category_standards=category_standards,
                line_number_map=line_number_map,
                chunk_ranges=chunk_ranges
            ))

        return section_contents

    def classification_result_to_chunks(
        self,
        result: "ClassificationResult",
        original_chunks: List[Dict[str, Any]],
        first_code: str,
        second_code: str
    ) -> List[Dict[str, Any]]:
        """Convert a ClassificationResult back into the chunks format.

        Expands the line-level classification results and attaches the same
        third-level detail list to every chunk of the group.

        Args:
            result: the classification result for the whole group.
            original_chunks: the original chunks (other fields are preserved).
            first_code: first-level category code.
            second_code: second-level category code.

        Returns:
            List[Dict]: the updated chunks.
        """
        updated_chunks = []

        # Collect third-level classifications, skipping the "no_standard" catch-all.
        tertiary_classifications = []
        for content in result.classified_contents:
            if content.third_category_code == "no_standard":
                continue
            tertiary_classifications.append({
                "third_category_name": content.third_category_name,
                "third_category_code": content.third_category_code,
                "start_line": content.start_line,
                "end_line": content.end_line,
                "content": content.content
            })

        # NOTE(review): every chunk shares the same detail list object here; the
        # per-chunk line filtering happens in LLMContentClassifier instead.
        for chunk in original_chunks:
            updated_chunk = dict(chunk)
            updated_chunk["first_code"] = first_code
            updated_chunk["second_code"] = second_code

            updated_chunk["tertiary_classification_details"] = tertiary_classifications

            # Backward compatibility: the first detail becomes the primary category.
            if tertiary_classifications:
                updated_chunk["tertiary_category_code"] = tertiary_classifications[0]["third_category_code"]
                updated_chunk["tertiary_category_cn"] = tertiary_classifications[0]["third_category_name"]

            updated_chunks.append(updated_chunk)

        return updated_chunks
|
|
|
+
|
|
|
+
|
|
|
+# ==================== 主入口类 ====================
|
|
|
+
|
|
|
class LLMContentClassifier:
    """
    LLM third-level content classifier (main entry-point class).

    Wraps the full classification pipeline behind a small interface:
    chunks -> SectionContent grouping -> concurrent LLM classification ->
    chunks enriched with tertiary classification fields.
    """

    def __init__(self, config: Optional[ClassifierConfig] = None):
        """
        Initialize the classifier.

        Args:
            config: configuration object; a default ClassifierConfig is used
                when None.
        """
        self.config = config or ClassifierConfig()

        # Load the third-level category standards table.
        self.category_loader = CategoryStandardLoader(Path(self.config.category_table_path))

        # Load the second-level standards, if the file exists.
        self.second_category_loader = None
        if Path(self.config.second_category_path).exists():
            self.second_category_loader = SecondCategoryStandardLoader(Path(self.config.second_category_path))

        # Converter between chunks and SectionContent.
        self.converter = ChunksConverter(self.category_loader)

        # Semaphore bounding concurrent API requests.
        self.semaphore = asyncio.Semaphore(self.config.max_concurrent_requests)

        # Optional embedding client (enabled only when a base URL is configured).
        self.embedding_client = None
        if self.config.embedding_base_url:
            self.embedding_client = self._create_embedding_client()

    def _create_embedding_client(self) -> 'EmbeddingClient':
        """Create the embedding client, overriding its defaults from config."""
        client = EmbeddingClient()
        # Replace the client's connection and model with the configured values.
        client.client = AsyncOpenAI(
            api_key=self.config.embedding_api_key,
            base_url=self.config.embedding_base_url
        )
        client.model = self.config.embedding_model
        return client

    async def classify_chunks(
        self,
        chunks: List[Dict[str, Any]],
        progress_callback: Optional[callable] = None
    ) -> List[Dict[str, Any]]:
        """
        Run third-level classification over the given chunks.

        Args:
            chunks: document chunks; each should carry
                - chapter_classification: first-level category code
                - secondary_category_code: second-level category code
                - secondary_category_cn: second-level Chinese name
                - review_chunk_content or content: the text
            progress_callback: optional progress hook ``(completed, total) -> None``.

        Returns:
            List[Dict]: the chunks with these fields added per chunk:
                - tertiary_category_code: primary third-level code
                - tertiary_category_cn: primary third-level name
                - tertiary_classification_details: line-level detail list
        """
        print(f"\n正在对 {len(chunks)} 个内容块进行三级分类...")

        # Step 1: group the chunks into SectionContent objects.
        sections = self.converter.chunks_to_sections(chunks)
        print(f" 按二级标题分组后得到 {len(sections)} 个段落")

        if not sections:
            print(" 没有有效的段落需要分类")
            return chunks

        # Step 2: build the classification client.
        classifier = ContentClassifierClient(
            model=self.config.model,
            semaphore=self.semaphore,
            embedding_client=self.embedding_client,
            second_category_loader=self.second_category_loader
        )

        # Step 3: classify all sections concurrently.
        # NOTE(review): results_map is keyed by section_key, which is
        # first_code->second_code; two label-based groups sharing those codes
        # would overwrite each other here — confirm this cannot happen upstream.
        results_map: Dict[str, ClassificationResult] = {}

        async def classify_with_progress(section: SectionContent, idx: int, total: int):
            # Classify one section and record the result under its key.
            result = await classifier.classify_content(section)
            results_map[section.section_key] = result

            if progress_callback:
                progress_callback(idx + 1, total)
            else:
                status = "成功" if not result.error else f"失败: {result.error[:30]}"
                print(f" [{idx + 1}/{total}] {section.section_name}: {status}")

            return result

        tasks = [
            classify_with_progress(section, idx, len(sections))
            for idx, section in enumerate(sections)
        ]
        await asyncio.gather(*tasks)

        # Step 4: map results back to chunks, filtering by each chunk's global
        # line range so a chunk only receives details that intersect its lines.
        updated_chunks = []

        # chunk_id -> (section_key, global_start, global_end), from chunk_ranges.
        chunk_range_map: Dict[str, Tuple[str, int, int]] = {}
        for section in sections:
            for (cid, g_start, g_end) in section.chunk_ranges:
                chunk_range_map[cid] = (section.section_key, g_start, g_end)

        # Assign each original chunk the details falling inside its line range.
        for chunk in chunks:
            updated_chunk = dict(chunk)
            first_code = chunk.get("chapter_classification", "") or chunk.get("first_code", "")
            second_code = chunk.get("secondary_category_code", "") or chunk.get("second_code", "")

            # Look up this chunk's range (also yields the authoritative section_key).
            chunk_id = chunk.get("chunk_id") or chunk.get("id") or str(id(chunk))
            range_info = chunk_range_map.get(chunk_id)

            if range_info:
                # Prefer the recorded section_key (already name-matched correctly).
                section_key = range_info[0]
            else:
                # Fallback: rebuild the key from chunk fields (may miss when
                # second_code is "none").
                section_key = f"{first_code}->{second_code}"

            result = results_map.get(section_key)

            if result:
                updated_chunk["first_code"] = first_code
                updated_chunk["second_code"] = second_code

                # All valid third-level entries (excluding the catch-all).
                all_tertiary = [
                    {
                        "third_category_name": c.third_category_name,
                        "third_category_code": c.third_category_code,
                        "start_line": c.start_line,
                        "end_line": c.end_line,
                        "content": c.content
                    }
                    for c in result.classified_contents
                    if c.third_category_code != "no_standard"
                ]

                if range_info:
                    # Keep only the details that overlap this chunk's lines.
                    _, g_start, g_end = range_info
                    filtered = [
                        t for t in all_tertiary
                        if t["start_line"] <= g_end and t["end_line"] >= g_start
                    ]
                    updated_chunk["tertiary_classification_details"] = filtered
                else:
                    # Range unknown (e.g. a single-chunk group): keep everything.
                    updated_chunk["tertiary_classification_details"] = all_tertiary

                # Backward compatibility: first detail becomes the primary category.
                tertiary_details = updated_chunk["tertiary_classification_details"]
                if tertiary_details:
                    updated_chunk["tertiary_category_code"] = tertiary_details[0]["third_category_code"]
                    updated_chunk["tertiary_category_cn"] = tertiary_details[0]["third_category_name"]

            updated_chunks.append(updated_chunk)

        print(f" 三级分类完成!共处理 {len(updated_chunks)} 个 chunks")
        return updated_chunks
|
|
|
+
|
|
|
+
|
|
|
+# ==================== 便捷函数 ====================
|
|
|
+
|
|
|
async def classify_chunks(
    chunks: List[Dict[str, Any]],
    config: Optional[ClassifierConfig] = None,
    progress_callback: Optional[callable] = None
) -> List[Dict[str, Any]]:
    """
    Convenience wrapper: run tertiary classification over *chunks*.

    Args:
        chunks: Parsed document chunk list.
        config: Optional classifier configuration; defaults are used when omitted.
        progress_callback: Optional hook called as callback(done, total).

    Returns:
        List[Dict]: the updated chunk list.

    Example:
        from llm_content_classifier_v2 import classify_chunks

        # default configuration
        updated_chunks = await classify_chunks(chunks)

        # custom configuration
        config = ClassifierConfig(
            model="qwen3.5-122b-a10b",
            embedding_similarity_threshold=0.85
        )
        updated_chunks = await classify_chunks(chunks, config=config)
    """
    # Delegate directly to a freshly constructed classifier instance.
    return await LLMContentClassifier(config).classify_chunks(chunks, progress_callback)
|
|
|
+
|
|
|
+
|
|
|
def classify_chunks_sync(
    chunks: List[Dict[str, Any]],
    config: Optional[ClassifierConfig] = None,
    progress_callback: Optional[callable] = None
) -> List[Dict[str, Any]]:
    """
    Synchronous (blocking) variant of :func:`classify_chunks`.

    When no event loop is running in the current thread the coroutine is
    executed via ``asyncio.run``. When called from inside a running loop
    (e.g. a notebook cell), the work is pushed to a worker thread that owns
    its own event loop, so the caller's loop never blocks on itself.

    Args:
        chunks: Document chunk list.
        config: Optional classifier configuration.
        progress_callback: Optional hook called as callback(done, total);
            forwarded to the async implementation (new, defaults to None
            for backward compatibility).

    Returns:
        List[Dict]: the updated chunk list.
    """
    try:
        # Probe only — the loop object itself is not needed.
        asyncio.get_running_loop()
    except RuntimeError:
        # No running event loop in this thread: run directly.
        return asyncio.run(classify_chunks(chunks, config, progress_callback))

    # A loop is already running here; execute on a separate thread.
    import concurrent.futures
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future = executor.submit(
            asyncio.run,
            classify_chunks(chunks, config, progress_callback)
        )
        return future.result()
|
|
|
+
|
|
|
+
|
|
|
+# ==================== 文本切块工具 ====================
|
|
|
+
|
|
|
+def _is_markdown_table_line(line: str) -> bool:
|
|
|
+ """判断一行是否为 Markdown 表格行(以 | 开头且以 | 结尾)"""
|
|
|
+ stripped = line.strip()
|
|
|
+ return stripped.startswith('|') and stripped.endswith('|') and len(stripped) >= 3
|
|
|
+
|
|
|
+
|
|
|
+def _split_text_lines_with_overlap(
|
|
|
+ lines: List[str],
|
|
|
+ max_chars: int,
|
|
|
+ overlap_chars: int
|
|
|
+) -> List[List[str]]:
|
|
|
+ """
|
|
|
+ 将文本行列表按字符数切分,相邻 chunk 之间保留重叠。
|
|
|
+
|
|
|
+ - 普通行(<= max_chars):积累到超限时 flush,下一个 chunk 以末尾若干行作重叠头。
|
|
|
+ - 超长行(> max_chars):先 flush 当前积累,再对该行做字符级滑窗切分,
|
|
|
+ 每片段 max_chars 字符,步长 max_chars - overlap_chars(即相邻片段重叠 overlap_chars)。
|
|
|
+ """
|
|
|
+ if not lines:
|
|
|
+ return []
|
|
|
+
|
|
|
+ chunks: List[List[str]] = []
|
|
|
+ current_lines: List[str] = []
|
|
|
+ current_chars: int = 0
|
|
|
+
|
|
|
+ def _flush():
|
|
|
+ """保存当前 chunk,并以末尾若干行作为下一个 chunk 的重叠起始。"""
|
|
|
+ nonlocal current_lines, current_chars
|
|
|
+ if not current_lines:
|
|
|
+ return
|
|
|
+ chunks.append(list(current_lines))
|
|
|
+ overlap_lines: List[str] = []
|
|
|
+ overlap_len: int = 0
|
|
|
+ for prev in reversed(current_lines):
|
|
|
+ overlap_lines.insert(0, prev)
|
|
|
+ overlap_len += len(prev)
|
|
|
+ if overlap_len >= overlap_chars:
|
|
|
+ break
|
|
|
+ current_lines = overlap_lines
|
|
|
+ current_chars = overlap_len
|
|
|
+
|
|
|
+ for line in lines:
|
|
|
+ line_chars = len(line)
|
|
|
+
|
|
|
+ if line_chars > max_chars:
|
|
|
+ # 超长行:先 flush,再对该行做字符级滑窗切分
|
|
|
+ _flush()
|
|
|
+ step = max_chars - overlap_chars # 滑动步长
|
|
|
+ start = 0
|
|
|
+ while start < line_chars:
|
|
|
+ piece = line[start: start + max_chars]
|
|
|
+ chunks.append([piece])
|
|
|
+ start += step
|
|
|
+ # 以最后一片段末尾的 overlap_chars 个字符作重叠起始
|
|
|
+ last_piece = line[max(0, line_chars - overlap_chars):]
|
|
|
+ current_lines = [last_piece]
|
|
|
+ current_chars = len(last_piece)
|
|
|
+ else:
|
|
|
+ # 普通行:加入后超限则先 flush
|
|
|
+ if current_chars + line_chars > max_chars and current_lines:
|
|
|
+ _flush()
|
|
|
+ current_lines.append(line)
|
|
|
+ current_chars += line_chars
|
|
|
+
|
|
|
+ if current_lines:
|
|
|
+ chunks.append(current_lines)
|
|
|
+
|
|
|
+ return chunks
|
|
|
+
|
|
|
+
|
|
|
def split_section_into_chunks(
    lines: List[str],
    max_chars: int = 600,
    overlap_chars: int = 30
) -> List[Dict[str, Any]]:
    """
    Split the lines of a secondary-category section into chunks.

    Rules:
    - A run of consecutive Markdown table rows (lines starting and ending
      with '|') becomes one standalone 'table' chunk: never split, never
      merged with surrounding text, and without overlap.
    - Plain text is cut into 'text' chunks of at most max_chars characters
      with overlap_chars characters of overlap between neighbours; a single
      line longer than max_chars is sliced with a character-level sliding
      window carrying the same overlap.

    Args:
        lines: section lines (without line-number markers)
        max_chars: maximum characters per text chunk (default 600)
        overlap_chars: overlap between adjacent text chunks (default 30)

    Returns:
        List[Dict]: each element contains:
        - 'type': 'text' or 'table'
        - 'lines': the lines belonging to that chunk
    """
    if not lines:
        return []

    output: List[Dict[str, Any]] = []
    pos = 0
    total = len(lines)

    # Single pass: peel off alternating table / text runs and emit each run
    # directly — tables whole, text through the overlap splitter.
    while pos < total:
        run_is_table = _is_markdown_table_line(lines[pos])
        run_start = pos
        while pos < total and _is_markdown_table_line(lines[pos]) == run_is_table:
            pos += 1
        run = lines[run_start:pos]
        if run_is_table:
            output.append({'type': 'table', 'lines': run})
        else:
            output.extend(
                {'type': 'text', 'lines': piece}
                for piece in _split_text_lines_with_overlap(run, max_chars, overlap_chars)
            )

    return output
|
|
|
+
|
|
|
+
|
|
|
+# ==================== 快速测试入口 ====================
|
|
|
+
|
|
|
if __name__ == "__main__":
    import io
    import sys
    from datetime import datetime

    # Re-wrap stdout so UTF-8 output renders correctly on Windows terminals.
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")

    # Input: a pipeline result JSON; output: timestamped result dumps.
    TEST_JSON_PATH = Path(r"temp\construction_review\final_result\4148f6019f89e061b15679666f646893-1773993108.json")
    OUTPUT_DIR = Path(r"temp\construction_review\llm_content_classifier_v2")
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    def _sep(title: str = "", width: int = 70):
        """Print a titled '=' banner, or a plain horizontal rule when no title is given."""
        print(f"\n{'=' * width}\n {title}\n{'=' * width}" if title else "─" * width)

    def _load_chunks_from_json(json_path: Path) -> List[Dict[str, Any]]:
        """Load the chunk list from a result JSON (handles both envelope layouts)."""
        with open(json_path, encoding="utf-8") as f:
            data = json.load(f)
        if "document_result" in data:
            return data["document_result"]["structured_content"]["chunks"]
        # Fallback layout: chunks nested one level deeper under "data".
        return data["data"]["document_result"]["structured_content"]["chunks"]

    # ── Load test data ────────────────────────────────────────
    _sep("加载测试数据")
    if not TEST_JSON_PATH.exists():
        print(f"[ERROR] 文件不存在: {TEST_JSON_PATH}")
        sys.exit(1)

    raw_chunks = _load_chunks_from_json(TEST_JSON_PATH)
    print(f"原始 chunks 数: {len(raw_chunks)}")

    # ── Run the full classification pipeline ──────────────────
    _sep("运行三级分类(LLMContentClassifier)")
    config = ClassifierConfig()
    print(f"模型: {config.model}")
    print(f"Embedding 模型: {config.embedding_model}")
    print(f"相似度阈值: {config.embedding_similarity_threshold}")

    classifier = LLMContentClassifier(config)
    updated_chunks = asyncio.run(classifier.classify_chunks(raw_chunks))

    # ── Save results ──────────────────────────────────────────
    _sep("保存结果")
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    result_file = OUTPUT_DIR / f"result_{ts}.json"
    with open(result_file, "w", encoding="utf-8") as f:
        json.dump(updated_chunks, f, ensure_ascii=False, indent=2)
    print(f"完整结果已保存: {result_file}")

    # ── Console summary ───────────────────────────────────────
    _sep("分类结果汇总")

    # Aggregate tertiary-classification details per section_label,
    # de-duplicating by third_category_code within each label.
    section_map: Dict[str, List[Dict]] = {}
    for chunk in updated_chunks:
        label = chunk.get("section_label") or chunk.get("chunk_id", "unknown")
        details = chunk.get("tertiary_classification_details", [])
        if label not in section_map:
            section_map[label] = []
        for d in details:
            key = d["third_category_code"]
            if not any(x["third_category_code"] == key for x in section_map[label]):
                section_map[label].append(d)

    total_third = 0
    for label, details in section_map.items():
        print(f"\n[{label}] 三级分类数={len(details)}")
        for d in details:
            line_range = f"L{d.get('start_line', '?')}-{d.get('end_line', '?')}"
            preview = (d.get("content") or "")[:50].replace("\n", " ")
            print(f"  ├ {d['third_category_name']}({d['third_category_code']}) {line_range} {preview}...")
        total_third += len(details)

    _sep()
    print(f"处理 chunks: {len(updated_chunks)} | 识别三级分类: {total_third} | 结果目录: {OUTPUT_DIR}")
|