|
|
@@ -0,0 +1,499 @@
|
|
|
+"""
|
|
|
+目录审查主程序
|
|
|
+使用llm_chain_client模块审查施工方案目录,找出缺失的目录项
|
|
|
+"""
|
|
|
+import asyncio
|
|
|
+import json
|
|
|
+import csv
|
|
|
+import logging
|
|
|
+import ast
|
|
|
+from pathlib import Path
|
|
|
+from typing import Dict, List, Any
|
|
|
+import sys
|
|
|
+from typing import List, Dict, Any, Union
|
|
|
+
|
|
|
+import pandas as pd
|
|
|
+
|
|
|
+from ..utils.llm_chain_client.bootstrap import Bootstrap
|
|
|
+from ..utils.llm_chain_client.orchestration import PromptChainProcessor
|
|
|
+
|
|
|
+# 配置日志
|
|
|
+logging.basicConfig(
|
|
|
+ level=logging.INFO,
|
|
|
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
|
|
+)
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+
|
|
|
class CatalogCheckProcessor:
    """Catalog review processor (目录审查处理器).

    Loads the specification requirements and the catalog under review, asks
    an LLM prompt chain which required sub-items are missing from each
    chapter, and maps the returned item numbers back to item names.
    """

    def __init__(self, processor: "PromptChainProcessor"):
        """
        Args:
            processor: Prompt-chain processor instance used for LLM calls.
        """
        self.processor = processor

    def load_specifications(self, spec_file: str) -> Dict[str, Dict[str, Any]]:
        """Load the specification-requirements file.

        Args:
            spec_file: Path to the tab-separated specification file.

        Returns:
            Requirements grouped by label, e.g.::

                {
                    "basis": {
                        "一级目录": "编制依据",
                        "二级目录": ["法律法规", "标准规范", ...]
                    },
                    ...
                }

        Raises:
            ValueError: If the file cannot be decoded with any known encoding.
        """
        logger.info(f"加载规范要求文件: {spec_file}")

        specifications: Dict[str, Dict[str, Any]] = {}

        # Try several encodings: utf-8-sig first (handles a BOM), then utf-16,
        # then the common Chinese encodings (GBK/GB2312/GB18030).
        encodings = ['utf-8-sig', 'utf-16', 'gbk', 'gb2312', 'gb18030']
        content = None

        for encoding in encodings:
            try:
                with open(spec_file, 'r', encoding=encoding) as f:
                    content = f.read()
                logger.info(f"成功使用 {encoding} 编码读取文件")
                break
            except UnicodeDecodeError:
                continue

        if content is None:
            raise ValueError(f"无法使用常见编码读取文件: {spec_file}")

        # Parse the content as tab-separated CSV.
        lines = content.strip().split('\n')
        reader = csv.reader(lines, delimiter='\t')

        # Skip the header row.
        next(reader, None)

        for row in reader:
            if len(row) >= 3:
                label = row[0].strip()
                primary_dir = row[1].strip()
                secondary_dir = row[2].strip()

                if label not in specifications:
                    specifications[label] = {
                        "一级目录": primary_dir,
                        "二级目录": []
                    }

                # Avoid duplicate secondary entries.
                if secondary_dir not in specifications[label]["二级目录"]:
                    specifications[label]["二级目录"].append(secondary_dir)

        logger.info(f"加载规范要求完成,共 {len(specifications)} 个标签")
        return specifications

    def load_catalog_data(self, csv_file: str) -> List[Dict[str, Any]]:
        """Load the catalog data that is to be reviewed.

        Args:
            csv_file: Path to the catalog CSV file.

        Returns:
            List of chapter records with parsed ``subsections`` lists.
        """
        logger.info(f"加载待审查目录文件: {csv_file}")

        catalog_data: List[Dict[str, Any]] = []

        with open(csv_file, 'r', encoding='utf-8-sig') as f:
            reader = csv.DictReader(f)

            for row in reader:
                # The 'subsections' column holds a Python-literal list
                # (single quotes), so try ast.literal_eval first.
                subsections_str = row.get('subsections', '[]')
                try:
                    subsections = ast.literal_eval(subsections_str)
                    if not isinstance(subsections, list):
                        subsections = []
                except (ValueError, SyntaxError):
                    # Fall back to JSON (double quotes); give up with [].
                    try:
                        subsections = json.loads(subsections_str)
                    except json.JSONDecodeError:
                        subsections = []

                catalog_data.append({
                    'index': row.get('index', ''),
                    'title': row.get('title', ''),
                    'page': row.get('page', ''),
                    'chapter_classification': row.get('chapter_classification', ''),
                    'subsections': subsections
                })

        logger.info(f"加载待审查目录完成,共 {len(catalog_data)} 个章节")
        return catalog_data

    def build_requirements_text(self, spec: Dict[str, Any]) -> str:
        """Build the requirements sentence for one specification entry.

        Args:
            spec: Specification dict with ``一级目录`` and ``二级目录`` keys.

        Returns:
            A sentence such as
            "编制依据章节应包含1.法律法规、2.标准规范、3.文件制度等方面的内容".
        """
        primary_dir = spec["一级目录"]
        secondary_dirs = spec["二级目录"]

        # Number each secondary item (1-based), then join with 、.
        secondary_text = "、".join(
            f"{i+1}.{item}" for i, item in enumerate(secondary_dirs)
        )

        return f"{primary_dir}章节应包含{secondary_text}等方面的内容"

    def build_catalog_content_text(self, subsections: List[Dict[str, Any]]) -> str:
        """Build the text describing the catalog under review.

        Args:
            subsections: List of secondary-catalog items (dicts with 'title').

        Returns:
            A sentence listing the sub-section titles, or a placeholder
            when the list is empty.
        """
        if not subsections:
            return "待审查目录为空"

        titles = [item.get('title', '') for item in subsections]
        return f"待审查目录包含:{'、'.join(titles)}"

    async def check_catalog(
        self,
        chapter_title: str,
        catalog_content: str,
        requirements: str
    ) -> str:
        """Run the LLM prompt chain to find missing catalog items.

        Args:
            chapter_title: Chapter title.
            catalog_content: Text describing the catalog under review.
            requirements: Text describing the specification requirements.

        Returns:
            Missing item numbers (e.g. "3,5"), "无缺失" when nothing is
            missing, or an error message when the chain fails.
        """
        input_data = {
            "chapter_title": chapter_title,
            "catalog_content": catalog_content,
            "requirements": requirements
        }

        # Resolve the chain config with an absolute path so the module
        # works regardless of the current working directory.
        current_dir = Path(__file__).parent
        chain_config_path = str(current_dir / "config" / "prompts" / "catalog_check_chain.yaml")

        try:
            result = await self.processor.process(
                chain_config_path=chain_config_path,
                input_data=input_data
            )
            return result.get("final_result", "")
        except Exception as e:
            # Best-effort: report the failure inline rather than aborting
            # the whole review run.
            logger.error(f"目录检查失败: {e}")
            return f"检查失败: {str(e)}"

    async def process_all_catalogs(
        self,
        spec_file: Union[str, Dict[str, Dict[str, Any]]],
        catalog_file: Union[str, List[Dict[str, Any]]]
    ) -> List[Dict[str, Any]]:
        """Review the catalog of every chapter.

        Args:
            spec_file: Path to the specification file, or an already-loaded
                specifications dict.
            catalog_file: Path to the catalog CSV, or an already-loaded list
                of chapter records.

        Returns:
            One result record per chapter with the missing items resolved
            to item names (JSON-encoded list string).
        """
        # Accept either file paths or pre-loaded data structures.
        if isinstance(spec_file, str):
            specifications = self.load_specifications(spec_file)
        else:
            specifications = spec_file

        if isinstance(catalog_file, str):
            catalog_data = self.load_catalog_data(catalog_file)
        else:
            # catalog_file is a list of records (e.g. original_outline);
            # normalize through a DataFrame into a list of dicts.
            catalog_data = pd.DataFrame(catalog_file).to_dict('records')

        results: List[Dict[str, Any]] = []

        for i, catalog in enumerate(catalog_data):
            # Records coming from a DataFrame may lack some keys; use .get.
            title = catalog.get('title', '')
            logger.info(f"处理第 {i+1}/{len(catalog_data)} 个章节: {title}")

            label = catalog.get('chapter_classification', '')

            # No specification for this label: record the problem and move on.
            if label not in specifications:
                logger.warning(f"未找到标签 '{label}' 的规范要求")
                results.append({
                    'index': catalog.get('index', ''),
                    'title': title,
                    'chapter_classification': label,
                    'missing_items': f"未找到标签 '{label}' 的规范要求"
                })
                continue

            spec = specifications[label]
            requirements = self.build_requirements_text(spec)
            catalog_content = self.build_catalog_content_text(catalog.get('subsections', []))

            missing_items = await self.check_catalog(
                chapter_title=title,
                catalog_content=catalog_content,
                requirements=requirements
            )

            results.append({
                'index': catalog.get('index', ''),
                'title': title,
                'chapter_classification': label,
                'missing_items': missing_items
            })

            logger.info(f"审查结果: {missing_items}")

        # Replace the numeric missing-item references with item names.
        return self._replace_missing_numbers_with_names(results, specifications)

    def _replace_missing_numbers_with_names(
        self,
        results: List[Dict[str, Any]],
        specifications: Dict[str, Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Replace missing-item numbers with the corresponding item names.

        Args:
            results: Review result records (mutated in place).
            specifications: Specification dict keyed by label.

        Returns:
            The same list, with 'missing_items' rewritten as a JSON-encoded
            list of item names where parsing succeeded.
        """
        for result in results:
            label = result.get('chapter_classification', '')
            missing_items_str = result.get('missing_items', '')

            # Nothing missing: normalize to an empty JSON list.
            # (.strip() tolerates stray whitespace from the LLM output.)
            if not missing_items_str or missing_items_str.strip() == '无缺失':
                result['missing_items'] = json.dumps([], ensure_ascii=False)
                continue

            if label not in specifications:
                logger.warning(f"未找到标签 '{label}' 的规范要求,无法替换缺失项")
                continue

            spec = specifications[label]
            secondary_dirs = spec.get('二级目录', [])

            # Parse the numbers; tolerate full-width commas from the LLM.
            try:
                missing_numbers = [
                    int(x.strip())
                    for x in missing_items_str.replace(',', ',').split(',')
                ]
            except (ValueError, AttributeError):
                # Leave the raw string in place so the failure is visible.
                logger.warning(f"无法解析缺失项: {missing_items_str}")
                continue

            # Map 1-based numbers to item names; log out-of-range numbers.
            missing_names = []
            for num in missing_numbers:
                if 1 <= num <= len(secondary_dirs):
                    missing_names.append(secondary_dirs[num - 1])
                else:
                    logger.warning(f"缺失项编号 {num} 超出范围,标签 '{label}' 只有 {len(secondary_dirs)} 项")

            # Store as a JSON list string for downstream consumers.
            result['missing_items'] = json.dumps(missing_names, ensure_ascii=False)

        return results

    def save_results(self, results: List[Dict[str, Any]], output_file: str):
        """Save the review results to a CSV file.

        Args:
            results: Review result records.
            output_file: Destination CSV path.
        """
        logger.info(f"保存审查结果到: {output_file}")

        with open(output_file, 'w', encoding='utf-8-sig', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=['index', 'title', 'chapter_classification', 'missing_items'])
            writer.writeheader()
            writer.writerows(results)

        logger.info(f"审查结果已保存,共 {len(results)} 条记录")
|
|
|
+
|
|
|
+
|
|
|
def remove_common_elements_between_dataframes(
    miss_outline_df: pd.DataFrame,
    redis_data: pd.DataFrame
) -> tuple[pd.DataFrame, Dict[str, set]]:
    """For rows sharing a ``chapter_label``, drop from the ``miss_outline``
    column every element that also appears in ``missing_items``.

    Args:
        miss_outline_df: DataFrame carrying a ``miss_outline`` list column.
        redis_data: DataFrame carrying a ``missing_items`` list column.

    Returns:
        tuple: (updated miss_outline_df, dict of shared elements per label)
    """
    # Pair up rows with the same chapter_label.
    merged = pd.merge(
        miss_outline_df, redis_data,
        on='chapter_label', how='inner', suffixes=('_outline', '_redis')
    )

    # Shared elements per chapter label.
    common_elements_dict: Dict[str, set] = {}

    for _, merged_row in merged.iterrows():
        label = merged_row['chapter_label']
        outline_values = merged_row['miss_outline']
        redis_values = merged_row['missing_items']

        # Non-list cells are treated as empty.
        outline_set = set(outline_values) if isinstance(outline_values, list) else set()
        redis_set = set(redis_values) if isinstance(redis_values, list) else set()

        shared = outline_set & redis_set
        common_elements_dict[label] = shared

        logger.info(f"[目录审查] 章节: {label}, 公共元素: {shared}")

    def _strip_shared(df_row):
        # Remove this row's shared elements; non-list cells become [].
        values = df_row['miss_outline']
        if not isinstance(values, list):
            return []
        return list(set(values) - common_elements_dict.get(df_row['chapter_label'], set()))

    miss_outline_df['miss_outline'] = miss_outline_df.apply(_strip_shared, axis=1)

    logger.info(f"[目录审查] 已去除公共元素,更新后的miss_outline_df: {miss_outline_df.to_dict('records')}")

    return miss_outline_df, common_elements_dict
|
|
|
+
|
|
|
+
|
|
|
async def catalogues_check(catalog_file = None):
    """Run the catalog review and return the per-chapter results.

    Args:
        catalog_file: Catalog CSV path, or an already-loaded list of
            chapter records.
    """
    base_dir = Path(__file__).parent
    config_dir = base_dir / "config"

    # Build the prompt-chain processor from the packaged configuration.
    chain_processor = Bootstrap.create_processor(
        model_type=None,  # read from the config file
        prompts_dir=str(config_dir / "prompts"),
        config_path=str(config_dir / "llm_api.yaml")
    )

    checker = CatalogCheckProcessor(chain_processor)

    spec_path = str(config_dir / "Construction_Plan_Content_Specification.csv")

    # Review every chapter and hand the results back to the caller.
    return await checker.process_all_catalogs(
        spec_file=spec_path,
        catalog_file=catalog_file
    )
|
|
|
+
|
|
|
async def main():
    """Entry point: review the catalog CSV and write the results to disk."""
    base_dir = Path(__file__).parent
    config_dir = base_dir / "config"

    # Build the prompt-chain processor from the packaged configuration.
    chain_processor = Bootstrap.create_processor(
        model_type=None,  # read from the config file
        prompts_dir=str(config_dir / "prompts"),
        config_path=str(config_dir / "llm_api.yaml")
    )

    checker = CatalogCheckProcessor(chain_processor)

    # Input/output locations.
    spec_path = str(config_dir / "Construction_Plan_Content_Specification.csv")
    catalog_path = "文档切分预处理结果.csv"
    output_path = "catalog_check_results.csv"

    # Review every chapter, then persist the results.
    results = await checker.process_all_catalogs(
        spec_file=spec_path,
        catalog_file=catalog_path
    )

    checker.save_results(results, output_path)

    # Final summary.
    banner = "=" * 50
    logger.info(banner)
    logger.info("目录审查完成")
    logger.info(f"共处理 {len(results)} 个章节")
    logger.info(f"结果已保存到: {output_path}")
    logger.info(banner)
|
|
|
+
|
|
|
+
|
|
|
# Script entry point: run the full catalog review under asyncio.
if __name__ == "__main__":
    asyncio.run(main())
|