|
|
@@ -1,848 +0,0 @@
|
|
|
-#!/usr/bin/env python
|
|
|
-# -*- coding: utf-8 -*-
|
|
|
-"""
|
|
|
-ContentClassifierClient 核心分类逻辑
|
|
|
-"""
|
|
|
-
|
|
|
-import asyncio
|
|
|
-import json
|
|
|
-import re
|
|
|
-import time
|
|
|
-from typing import Dict, List, Optional, Tuple
|
|
|
-
|
|
|
-from .models import CategoryStandard, ClassifiedContent, ClassificationResult, SectionContent
|
|
|
-from .embedding_client import EmbeddingClient
|
|
|
-from foundation.ai.agent.generate.model_generate import generate_model_client
|
|
|
-from .category_loaders import SecondCategoryStandardLoader
|
|
|
-from .json_utils import _fix_json, _aggressive_json_fix
|
|
|
-from .prompt import (
|
|
|
- CLASSIFY_SYSTEM_PROMPT,
|
|
|
- SUPPLEMENT_VERIFY_SYSTEM_PROMPT,
|
|
|
- build_classify_prompt,
|
|
|
- build_fix_prompt,
|
|
|
- build_supplement_verify_prompt,
|
|
|
-)
|
|
|
-from foundation.observability.logger.loggering import review_logger as logger
|
|
|
-
|
|
|
-
|
|
|
-class ContentClassifierClient:
|
|
|
- """LLM 内容分类客户端"""
|
|
|
-
|
|
|
- def __init__(self, model: str, semaphore: asyncio.Semaphore, embedding_client: Optional[EmbeddingClient] = None, second_category_loader: Optional[SecondCategoryStandardLoader] = None, enable_thinking: bool = False):
|
|
|
- self.model = model
|
|
|
- self.semaphore = semaphore
|
|
|
- self.embedding_client = embedding_client
|
|
|
- self.second_category_loader = second_category_loader
|
|
|
- self.enable_thinking = enable_thinking
|
|
|
-
|
|
|
- async def classify_content(self, section: SectionContent) -> ClassificationResult:
|
|
|
- """对内容进行三级分类识别(带并发控制和自动修复,支持长内容分块处理)"""
|
|
|
- start_time = time.time()
|
|
|
-
|
|
|
- # 步骤1: 使用Embedding模型检查二级分类与内容的相似度
|
|
|
- if self.embedding_client and self.second_category_loader and section.category_standards:
|
|
|
- # 从construction_plan_standards.csv中查找对应的标准二级分类
|
|
|
- # 使用section_name进行匹配
|
|
|
- std_second_category = self.second_category_loader.get_standard_by_second_name(section.section_name)
|
|
|
-
|
|
|
- if std_second_category:
|
|
|
- # 找到了对应的标准二级分类,进行相似度检查
|
|
|
- # 检查section内容与标准的second_raw_content的一致性
|
|
|
- section_text = '\n'.join(section.lines)
|
|
|
- is_similar, similarity = await self.embedding_client.check_similarity(
|
|
|
- section_name=section.section_name,
|
|
|
- section_content=section_text,
|
|
|
- second_category_name=std_second_category.second_name,
|
|
|
- second_category_raw_content=std_second_category.second_raw_content
|
|
|
- )
|
|
|
-
|
|
|
- if is_similar:
|
|
|
- from .config import EMBEDDING_SIMILARITY_THRESHOLD
|
|
|
- logger.debug(f"[{section.section_name}] 相似度检查通过 ({similarity:.3f} >= {EMBEDDING_SIMILARITY_THRESHOLD}),跳过LLM分类,默认包含所有三级分类")
|
|
|
- # 生成默认分类结果:包含所有三级分类
|
|
|
- all_contents = self._generate_default_classification(section)
|
|
|
- total_lines, classified_lines, coverage_rate = self._calculate_coverage_rate(section, all_contents)
|
|
|
- latency = time.time() - start_time
|
|
|
- return ClassificationResult(
|
|
|
- model=self.model,
|
|
|
- section_key=section.section_key,
|
|
|
- section_name=section.section_name,
|
|
|
- classified_contents=all_contents,
|
|
|
- latency=latency,
|
|
|
- raw_response=f"[Embedding相似度跳过] similarity={similarity:.3f}",
|
|
|
- error=None,
|
|
|
- total_lines=total_lines,
|
|
|
- classified_lines=classified_lines,
|
|
|
- coverage_rate=coverage_rate
|
|
|
- )
|
|
|
- else:
|
|
|
- logger.debug(f"[{section.section_name}] 相似度检查未通过 ({similarity:.3f} < ?),继续LLM分类")
|
|
|
- else:
|
|
|
- logger.debug(f"[{section.section_name}] 未在construction_plan_standards.csv中找到对应标准,继续LLM分类")
|
|
|
-
|
|
|
- # 如果内容过长,分块处理
|
|
|
- MAX_LINES_PER_CHUNK = 150 # 每个块最多150行
|
|
|
- MAX_CHARS_PER_CHUNK = 3000 # 每个块最多3000字符
|
|
|
- OVERLAP_CHARS = 100 # 相邻块之间重叠约100字符
|
|
|
- total_lines = len(section.lines)
|
|
|
- total_chars = sum(len(line) for line in section.lines)
|
|
|
-
|
|
|
- if total_lines <= MAX_LINES_PER_CHUNK and total_chars <= MAX_CHARS_PER_CHUNK:
|
|
|
- # 内容不长,直接处理
|
|
|
- result = await self._classify_single_chunk(section, start_time)
|
|
|
- # 补充验证:关键字扫描 + LLM二次确认,补充遗漏的分类
|
|
|
- if not result.error and result.classified_contents is not None:
|
|
|
- supplement = await self._detect_and_supplement(section, result.classified_contents)
|
|
|
- if supplement:
|
|
|
- merged = self._merge_classified_contents(result.classified_contents + supplement, section)
|
|
|
- total_l, classified_l, coverage_r = self._calculate_coverage_rate(section, merged)
|
|
|
- return ClassificationResult(
|
|
|
- model=result.model,
|
|
|
- section_key=result.section_key,
|
|
|
- section_name=result.section_name,
|
|
|
- classified_contents=merged,
|
|
|
- latency=result.latency,
|
|
|
- raw_response=result.raw_response,
|
|
|
- error=result.error,
|
|
|
- total_lines=total_l,
|
|
|
- classified_lines=classified_l,
|
|
|
- coverage_rate=coverage_r
|
|
|
- )
|
|
|
- return result
|
|
|
-
|
|
|
- # 内容过长,按字符数+行数双限制分块处理(带重叠)
|
|
|
- logger.debug(
|
|
|
- f"[{section.section_name}] 内容较长({total_lines}行, {total_chars}字符),"
|
|
|
- f"按 max_lines={MAX_LINES_PER_CHUNK}, max_chars={MAX_CHARS_PER_CHUNK}, overlap={OVERLAP_CHARS} 分块处理..."
|
|
|
- )
|
|
|
- chunk_ranges = self._split_section_into_chunks(
|
|
|
- section, MAX_LINES_PER_CHUNK, MAX_CHARS_PER_CHUNK, OVERLAP_CHARS
|
|
|
- )
|
|
|
- all_contents = []
|
|
|
-
|
|
|
- for chunk_start, chunk_end in chunk_ranges:
|
|
|
- chunk_section = self._create_chunk_section(section, chunk_start, chunk_end)
|
|
|
- chunk_result = await self._classify_single_chunk(chunk_section, 0, is_chunk=True)
|
|
|
-
|
|
|
- if chunk_result.error:
|
|
|
- logger.error(f"[{section.section_name}] 块 {chunk_start+1}-{chunk_end} 处理失败: {chunk_result.error[:50]}")
|
|
|
- else:
|
|
|
- logger.debug(f"[{section.section_name}] 块 {chunk_start+1}-{chunk_end} 成功: {len(chunk_result.classified_contents)} 个分类")
|
|
|
- all_contents.extend(chunk_result.classified_contents)
|
|
|
-
|
|
|
- # 所有块处理完成后,再次聚合所有内容(解决分块导致的同一分类分散问题)
|
|
|
- if all_contents:
|
|
|
- all_contents = self._merge_classified_contents(all_contents, section)
|
|
|
-
|
|
|
- # 补充验证:关键字扫描 + LLM二次确认,补充遗漏的分类
|
|
|
- supplement = await self._detect_and_supplement(section, all_contents)
|
|
|
- if supplement:
|
|
|
- all_contents = self._merge_classified_contents(all_contents + supplement, section)
|
|
|
-
|
|
|
- # 计算分类率
|
|
|
- total_lines, classified_lines, coverage_rate = self._calculate_coverage_rate(section, all_contents)
|
|
|
-
|
|
|
- latency = time.time() - start_time
|
|
|
-
|
|
|
- return ClassificationResult(
|
|
|
- model=self.model,
|
|
|
- section_key=section.section_key,
|
|
|
- section_name=section.section_name,
|
|
|
- classified_contents=all_contents,
|
|
|
- latency=latency,
|
|
|
- raw_response="",
|
|
|
- error=None if all_contents else "所有块处理失败",
|
|
|
- total_lines=total_lines,
|
|
|
- classified_lines=classified_lines,
|
|
|
- coverage_rate=coverage_rate
|
|
|
- )
|
|
|
-
|
|
|
- def _split_section_into_chunks(
|
|
|
- self,
|
|
|
- section: SectionContent,
|
|
|
- max_lines: int = 150,
|
|
|
- max_chars: int = 3000,
|
|
|
- overlap_chars: int = 100
|
|
|
- ) -> List[Tuple[int, int]]:
|
|
|
- """将 section 切分成多个子块,满足行数和字符数上限,并带字符重叠。"""
|
|
|
- lines = section.lines
|
|
|
- total = len(lines)
|
|
|
- if total == 0:
|
|
|
- return [(0, 0)]
|
|
|
-
|
|
|
- chunks = []
|
|
|
- start = 0
|
|
|
- while start < total:
|
|
|
- end = start
|
|
|
- chars = 0
|
|
|
- # 同时满足行数和字符数两个限制
|
|
|
- while end < total and (end - start) < max_lines and chars + len(lines[end]) <= max_chars:
|
|
|
- chars += len(lines[end])
|
|
|
- end += 1
|
|
|
-
|
|
|
- # 至少保证一行
|
|
|
- if end == start:
|
|
|
- end = start + 1
|
|
|
-
|
|
|
- chunks.append((start, end))
|
|
|
-
|
|
|
- if end >= total:
|
|
|
- break
|
|
|
-
|
|
|
- # 计算下一次 start,保留约 overlap_chars 的字符重叠
|
|
|
- next_start = end - 1
|
|
|
- overlap_acc = 0
|
|
|
- while next_start > start and overlap_acc < overlap_chars:
|
|
|
- overlap_acc += len(lines[next_start])
|
|
|
- next_start -= 1
|
|
|
- start = next_start + 1
|
|
|
-
|
|
|
- return chunks
|
|
|
-
|
|
|
- def _calculate_coverage_rate(self, section: SectionContent, contents: List[ClassifiedContent]) -> tuple:
|
|
|
- """计算分类率(已分类行数/总行数)"""
|
|
|
- total_lines = len(section.lines)
|
|
|
- if total_lines == 0 or not contents:
|
|
|
- return total_lines, 0, 0.0
|
|
|
-
|
|
|
- # 使用集合记录已分类的行号(避免重复计数)
|
|
|
- classified_line_set = set()
|
|
|
-
|
|
|
- for content in contents:
|
|
|
- if section.line_number_map:
|
|
|
- # 如果有全局行号映射,找出起止行号对应的索引
|
|
|
- start_idx = -1
|
|
|
- end_idx = -1
|
|
|
- for idx, global_line in enumerate(section.line_number_map):
|
|
|
- if global_line == content.start_line:
|
|
|
- start_idx = idx
|
|
|
- if global_line == content.end_line:
|
|
|
- end_idx = idx
|
|
|
- break
|
|
|
-
|
|
|
- if start_idx != -1 and end_idx != -1:
|
|
|
- for i in range(start_idx, end_idx + 1):
|
|
|
- if i < len(section.line_number_map):
|
|
|
- classified_line_set.add(section.line_number_map[i])
|
|
|
- else:
|
|
|
- # 没有全局行号,直接使用起止行号
|
|
|
- for line_num in range(content.start_line, content.end_line + 1):
|
|
|
- classified_line_set.add(line_num)
|
|
|
-
|
|
|
- classified_lines = len(classified_line_set)
|
|
|
- coverage_rate = (classified_lines / total_lines) * 100 if total_lines > 0 else 0.0
|
|
|
-
|
|
|
- return total_lines, classified_lines, coverage_rate
|
|
|
-
|
|
|
- def _generate_default_classification(self, section: SectionContent) -> List[ClassifiedContent]:
|
|
|
- """
|
|
|
- 生成默认的分类结果(当embedding相似度检查通过时使用)
|
|
|
- 默认包含所有三级分类,覆盖整个section内容
|
|
|
- """
|
|
|
- if not section.category_standards:
|
|
|
- return []
|
|
|
-
|
|
|
- # 获取全局行号范围
|
|
|
- if section.line_number_map:
|
|
|
- start_line = section.line_number_map[0]
|
|
|
- end_line = section.line_number_map[-1]
|
|
|
- else:
|
|
|
- start_line = 1
|
|
|
- end_line = len(section.lines)
|
|
|
-
|
|
|
- # 为每个三级分类创建一个条目,覆盖全部内容
|
|
|
- default_contents = []
|
|
|
- for std in section.category_standards:
|
|
|
- # 提取该分类对应的内容
|
|
|
- content = self._extract_content_by_line_numbers(section, start_line, end_line)
|
|
|
- default_contents.append(ClassifiedContent(
|
|
|
- third_category_name=std.third_name,
|
|
|
- third_category_code=std.third_code,
|
|
|
- third_seq=std.third_seq,
|
|
|
- start_line=start_line,
|
|
|
- end_line=end_line,
|
|
|
- content=content
|
|
|
- ))
|
|
|
-
|
|
|
- return default_contents
|
|
|
-
|
|
|
- def _create_chunk_section(self, section: SectionContent, start_idx: int, end_idx: int) -> SectionContent:
|
|
|
- """从section创建子块"""
|
|
|
- chunk_lines = section.lines[start_idx:end_idx]
|
|
|
- chunk_line_map = section.line_number_map[start_idx:end_idx] if section.line_number_map else list(range(start_idx + 1, end_idx + 1))
|
|
|
-
|
|
|
- # 生成带行号的内容
|
|
|
- numbered_content = '\n'.join([f"<{chunk_line_map[i]}> {line}" for i, line in enumerate(chunk_lines)])
|
|
|
-
|
|
|
- return SectionContent(
|
|
|
- section_key=f"{section.section_key}_chunk_{start_idx}_{end_idx}",
|
|
|
- section_name=section.section_name,
|
|
|
- lines=chunk_lines,
|
|
|
- numbered_content=numbered_content,
|
|
|
- category_standards=section.category_standards,
|
|
|
- line_number_map=chunk_line_map
|
|
|
- )
|
|
|
-
|
|
|
- async def _classify_single_chunk(self, section: SectionContent, start_time: float, is_chunk: bool = False) -> ClassificationResult:
|
|
|
- """处理单个块"""
|
|
|
- prompt = self._build_prompt(section, is_chunk=is_chunk)
|
|
|
-
|
|
|
- try:
|
|
|
- async with self.semaphore:
|
|
|
- response = await self._call_api(prompt)
|
|
|
-
|
|
|
- classified_contents, parse_error = await self._parse_with_fix(response, section, prompt)
|
|
|
-
|
|
|
- if not is_chunk:
|
|
|
- latency = time.time() - start_time
|
|
|
- # 计算分类率
|
|
|
- total_lines, classified_lines, coverage_rate = self._calculate_coverage_rate(section, classified_contents)
|
|
|
- return ClassificationResult(
|
|
|
- model=self.model,
|
|
|
- section_key=section.section_key,
|
|
|
- section_name=section.section_name,
|
|
|
- classified_contents=classified_contents,
|
|
|
- latency=latency,
|
|
|
- raw_response=response[:1000],
|
|
|
- error=parse_error,
|
|
|
- total_lines=total_lines,
|
|
|
- classified_lines=classified_lines,
|
|
|
- coverage_rate=coverage_rate
|
|
|
- )
|
|
|
- else:
|
|
|
- return ClassificationResult(
|
|
|
- model=self.model,
|
|
|
- section_key=section.section_key,
|
|
|
- section_name=section.section_name,
|
|
|
- classified_contents=classified_contents,
|
|
|
- latency=0,
|
|
|
- raw_response="",
|
|
|
- error=parse_error
|
|
|
- )
|
|
|
- except Exception as e:
|
|
|
- if not is_chunk:
|
|
|
- latency = time.time() - start_time
|
|
|
- return ClassificationResult(
|
|
|
- model=self.model,
|
|
|
- section_key=section.section_key,
|
|
|
- section_name=section.section_name,
|
|
|
- classified_contents=[],
|
|
|
- latency=latency,
|
|
|
- error=str(e)
|
|
|
- )
|
|
|
- else:
|
|
|
- return ClassificationResult(
|
|
|
- model=self.model,
|
|
|
- section_key=section.section_key,
|
|
|
- section_name=section.section_name,
|
|
|
- classified_contents=[],
|
|
|
- latency=0,
|
|
|
- error=str(e)
|
|
|
- )
|
|
|
-
|
|
|
- async def _parse_with_fix(self, response: str, section: SectionContent, original_prompt: str = "") -> tuple:
|
|
|
- """解析响应,失败时让模型修复(最多3次重试)
|
|
|
-
|
|
|
- 返回: (contents, error_msg)
|
|
|
- - contents: 分类结果列表(可能为空,表示模型判定无匹配内容)
|
|
|
- - error_msg: 错误信息,None表示成功(包括空结果),非None表示解析失败
|
|
|
- """
|
|
|
- # 第一次尝试解析
|
|
|
- contents, parse_success = self._parse_response(response, section)
|
|
|
-
|
|
|
- # 解析成功(包括空结果,表示模型判定内容不符合任何分类标准)
|
|
|
- if parse_success:
|
|
|
- if not contents:
|
|
|
- logger.debug(f"[{section.section_name}] 模型判定无匹配内容,记录为未分类")
|
|
|
- return contents, None
|
|
|
-
|
|
|
- # 解析失败(JSON格式错误),尝试让模型修复(最多3次)
|
|
|
- logger.warning(f"[{section.section_name}] JSON解析失败,请求模型修复...")
|
|
|
- logger.debug(f"[{section.section_name}] 原始响应前200字符: {response[:200]}...")
|
|
|
-
|
|
|
- original_response = response
|
|
|
-
|
|
|
- for attempt in range(3):
|
|
|
- fix_prompt = self._build_fix_prompt(original_response)
|
|
|
-
|
|
|
- try:
|
|
|
- async with self.semaphore:
|
|
|
- fixed_response = await self._call_api(fix_prompt)
|
|
|
-
|
|
|
- # 尝试解析修复后的输出
|
|
|
- contents, parse_success = self._parse_response(fixed_response, section)
|
|
|
- if parse_success:
|
|
|
- logger.debug(f"[{section.section_name}] 模型修复成功(第{attempt+1}次)")
|
|
|
- if not contents:
|
|
|
- logger.debug(f"[{section.section_name}] 修复后模型判定无匹配内容,记录为未分类")
|
|
|
- return contents, None
|
|
|
- else:
|
|
|
- logger.debug(f"[{section.section_name}] 第{attempt+1}次修复失败,继续重试...")
|
|
|
- original_response = fixed_response
|
|
|
- except Exception as e:
|
|
|
- return [], f"请求模型修复失败: {str(e)}"
|
|
|
-
|
|
|
- logger.error(f"[{section.section_name}] 模型修复3次后仍无法解析JSON")
|
|
|
- return [], "模型修复3次后仍无法解析JSON"
|
|
|
-
|
|
|
- def _build_fix_prompt(self, original_response: str) -> str:
|
|
|
- """构建JSON修复提示词(委托给 prompt.py 中的 build_fix_prompt)"""
|
|
|
- return build_fix_prompt(original_response)
|
|
|
-
|
|
|
- def _build_prompt(self, section: SectionContent, is_chunk: bool = False) -> str:
|
|
|
- """构建分类提示词(委托给 prompt.py 中的 build_classify_prompt)"""
|
|
|
- return build_classify_prompt(section, is_chunk)
|
|
|
-
|
|
|
- async def _call_api(self, prompt: str) -> str:
|
|
|
- """调用API(使用统一的 GenerateModelClient,带指数退避重试)"""
|
|
|
- max_retries = 5
|
|
|
- base_delay = 2 # 基础延迟2秒
|
|
|
-
|
|
|
- for attempt in range(max_retries):
|
|
|
- try:
|
|
|
- # 使用统一的模型调用客户端
|
|
|
- # 该客户端已内置重试机制和 thinking 模式控制
|
|
|
- # 从配置获取 enable_thinking,默认禁用
|
|
|
- enable_thinking = getattr(self, 'enable_thinking', False)
|
|
|
- response = await generate_model_client.get_model_generate_invoke(
|
|
|
- trace_id="content_classifier",
|
|
|
- system_prompt=CLASSIFY_SYSTEM_PROMPT,
|
|
|
- user_prompt=prompt,
|
|
|
- model_name=self.model,
|
|
|
- enable_thinking=enable_thinking
|
|
|
- )
|
|
|
- return response
|
|
|
- except Exception as e:
|
|
|
- error_str = str(e)
|
|
|
- # 检查是否是429限流错误
|
|
|
- if "429" in error_str or "rate limit" in error_str.lower():
|
|
|
- if attempt < max_retries - 1:
|
|
|
- # 指数退避: 2^attempt * (1 + random)
|
|
|
- delay = base_delay * (2 ** attempt) + (hash(prompt) % 1000) / 1000
|
|
|
- logger.warning(f"API限流(429),等待 {delay:.1f}s 后重试 ({attempt + 1}/{max_retries})...")
|
|
|
- await asyncio.sleep(delay)
|
|
|
- continue
|
|
|
- # 其他错误或重试次数用完,抛出异常
|
|
|
- raise
|
|
|
-
|
|
|
- return ""
|
|
|
-
|
|
|
- def _parse_response(self, response: str, section: SectionContent) -> tuple:
|
|
|
- """解析响应(增强版,处理各种JSON格式问题)
|
|
|
-
|
|
|
- 返回: (contents, parse_success)
|
|
|
- - contents: 分类结果列表
|
|
|
- - parse_success: True表示JSON解析成功(包括空结果),False表示解析失败
|
|
|
- """
|
|
|
- if not response or not response.strip():
|
|
|
- return [], False # 空响应视为解析失败
|
|
|
-
|
|
|
- response = response.strip()
|
|
|
-
|
|
|
- # 尝试多种方式提取JSON
|
|
|
- json_str = None
|
|
|
-
|
|
|
- # 方法1: 从代码块中提取
|
|
|
- code_block_match = re.search(r'```(?:json)?\s*([\s\S]*?)```', response)
|
|
|
- if code_block_match:
|
|
|
- json_str = code_block_match.group(1).strip()
|
|
|
-
|
|
|
- # 方法2: 优先查找JSON数组(模型经常直接输出数组格式)
|
|
|
- if not json_str:
|
|
|
- # 使用非贪婪匹配找到第一个完整的数组
|
|
|
- array_match = re.search(r'\[[\s\S]*?\]', response)
|
|
|
- if array_match:
|
|
|
- potential_array = array_match.group(0)
|
|
|
- # 验证是否是有效的JSON数组
|
|
|
- try:
|
|
|
- parsed = json.loads(potential_array)
|
|
|
- if isinstance(parsed, list):
|
|
|
- json_str = potential_array
|
|
|
- except Exception:
|
|
|
- pass
|
|
|
-
|
|
|
- # 方法3: 查找JSON对象
|
|
|
- if not json_str:
|
|
|
- json_match = re.search(r'\{[\s\S]*\}', response)
|
|
|
- if json_match:
|
|
|
- json_str = json_match.group(0)
|
|
|
-
|
|
|
- if not json_str:
|
|
|
- return [], False # 未找到JSON结构,解析失败
|
|
|
-
|
|
|
- # 处理模型直接输出数组的情况(包装成对象格式)
|
|
|
- if json_str.strip().startswith('['):
|
|
|
- try:
|
|
|
- # 验证是有效的JSON数组
|
|
|
- array_data = json.loads(json_str)
|
|
|
- if isinstance(array_data, list):
|
|
|
- # 包装成期望的格式
|
|
|
- json_str = json.dumps({"classified_contents": array_data})
|
|
|
- except Exception:
|
|
|
- pass # 不是有效数组,继续后续处理
|
|
|
-
|
|
|
- # 先尝试直接解析,如果成功则不需要修复
|
|
|
- try:
|
|
|
- json.loads(json_str)
|
|
|
- # JSON 有效,直接使用
|
|
|
- except json.JSONDecodeError:
|
|
|
- # JSON 无效,尝试修复
|
|
|
- json_str = self._fix_json(json_str)
|
|
|
-
|
|
|
- try:
|
|
|
- data = json.loads(json_str)
|
|
|
- # 处理数组格式
|
|
|
- if isinstance(data, list):
|
|
|
- data = {"classified_contents": data}
|
|
|
- contents = []
|
|
|
- # 支持两种键名: classified_contents 或 classified_contents_list
|
|
|
- items = data.get("classified_contents", []) or data.get("classified_contents_list", [])
|
|
|
-
|
|
|
- # 构建索引映射表:索引 -> (third_name, third_code, third_seq)
|
|
|
- index_mapping = {0: ("非标准项", "no_standard", 0)}
|
|
|
- if section.category_standards:
|
|
|
- for i, std in enumerate(section.category_standards, 1):
|
|
|
- index_mapping[i] = (std.third_name, std.third_code, std.third_seq)
|
|
|
-
|
|
|
- for item in items:
|
|
|
- start_line = item.get("start_line", 0)
|
|
|
- end_line = item.get("end_line", 0)
|
|
|
-
|
|
|
- # 优先使用 category_index 进行映射
|
|
|
- category_index = item.get("category_index")
|
|
|
- if category_index is not None:
|
|
|
- # 通过索引映射获取标准名称、代码和序号
|
|
|
- idx = int(category_index) if isinstance(category_index, (int, float, str)) else 0
|
|
|
- category_name, category_code, category_seq = index_mapping.get(idx, ("非标准项", "no_standard", 0))
|
|
|
- else:
|
|
|
- # 兼容旧格式:直接读取 third_category_code 和 third_category_name
|
|
|
- category_code = item.get("third_category_code", "")
|
|
|
- category_name = item.get("third_category_name", "")
|
|
|
-
|
|
|
- # 清理分类名称格式:移除末尾的代码部分
|
|
|
- if category_name and " (" in category_name and category_name.endswith(")"):
|
|
|
- category_name = re.sub(r'\s*\([^)]+\)\s*$', '', category_name).strip()
|
|
|
-
|
|
|
- # 验证分类代码是否在有效列表中
|
|
|
- valid_codes = set(v[1] for v in index_mapping.values())
|
|
|
- if category_code not in valid_codes:
|
|
|
- logger.warning(f"发现非标准分类 '{category_name}' ({category_code}),强制归为非标准项")
|
|
|
- category_code = "no_standard"
|
|
|
- category_name = "非标准项"
|
|
|
-
|
|
|
- # 根据行号从section中提取原文
|
|
|
- content = self._extract_content_by_line_numbers(section, start_line, end_line)
|
|
|
- contents.append(ClassifiedContent(
|
|
|
- third_category_name=category_name,
|
|
|
- third_category_code=category_code,
|
|
|
- third_seq=category_seq,
|
|
|
- start_line=start_line,
|
|
|
- end_line=end_line,
|
|
|
- content=content
|
|
|
- ))
|
|
|
- # 聚合同一分类下相邻的内容
|
|
|
- contents = self._merge_classified_contents(contents, section)
|
|
|
- return contents, True # 解析成功(可能为空结果)
|
|
|
- except Exception as e:
|
|
|
- # 尝试更激进的修复
|
|
|
- try:
|
|
|
- fixed = self._aggressive_json_fix(json_str)
|
|
|
- data = json.loads(fixed)
|
|
|
- # 处理数组格式
|
|
|
- if isinstance(data, list):
|
|
|
- data = {"classified_contents": data}
|
|
|
- contents = []
|
|
|
- # 支持两种键名: classified_contents 或 classified_contents_list
|
|
|
- items = data.get("classified_contents", []) or data.get("classified_contents_list", [])
|
|
|
-
|
|
|
- # 构建索引映射表:索引 -> (third_name, third_code)
|
|
|
- index_mapping = {0: ("非标准项", "no_standard")}
|
|
|
- if section.category_standards:
|
|
|
- for i, std in enumerate(section.category_standards, 1):
|
|
|
- index_mapping[i] = (std.third_name, std.third_code)
|
|
|
-
|
|
|
- for item in items:
|
|
|
- start_line = item.get("start_line", 0)
|
|
|
- end_line = item.get("end_line", 0)
|
|
|
-
|
|
|
- # 优先使用 category_index 进行映射
|
|
|
- category_index = item.get("category_index")
|
|
|
- if category_index is not None:
|
|
|
- idx = int(category_index) if isinstance(category_index, (int, float, str)) else 0
|
|
|
- category_name, category_code = index_mapping.get(idx, ("非标准项", "no_standard"))
|
|
|
- else:
|
|
|
- # 兼容旧格式
|
|
|
- category_code = item.get("third_category_code", "")
|
|
|
- category_name = item.get("third_category_name", "")
|
|
|
- valid_codes = set(v[1] for v in index_mapping.values())
|
|
|
- if category_code not in valid_codes:
|
|
|
- logger.warning(f"发现非标准分类 '{category_name}' ({category_code}),强制归为非标准项")
|
|
|
- category_code = "no_standard"
|
|
|
- category_name = "非标准项"
|
|
|
-
|
|
|
- # 根据行号从section中提取原文
|
|
|
- content = self._extract_content_by_line_numbers(section, start_line, end_line)
|
|
|
- contents.append(ClassifiedContent(
|
|
|
- third_category_name=category_name,
|
|
|
- third_category_code=category_code,
|
|
|
- third_seq=0,
|
|
|
- start_line=start_line,
|
|
|
- end_line=end_line,
|
|
|
- content=content
|
|
|
- ))
|
|
|
- # 聚合同一分类下相邻的内容
|
|
|
- contents = self._merge_classified_contents(contents, section)
|
|
|
- return contents, True # 解析成功(可能为空结果)
|
|
|
- except Exception as e2:
|
|
|
- logger.error(f"解析JSON失败: {e}, 二次修复也失败: {e2}")
|
|
|
- logger.debug(f"原始响应前500字符: {response[:500]}...")
|
|
|
- logger.debug(f"提取的JSON前300字符: {json_str[:300]}...")
|
|
|
- return [], False # 解析失败
|
|
|
-
|
|
|
- def _merge_classified_contents(self, contents: List[ClassifiedContent], section: SectionContent) -> List[ClassifiedContent]:
|
|
|
- """将同一分类下的内容按区间合并(只有连续或重叠的区间才合并)"""
|
|
|
- if not contents:
|
|
|
- return contents
|
|
|
-
|
|
|
- # 按分类代码分组
|
|
|
- groups: Dict[str, List[ClassifiedContent]] = {}
|
|
|
- for content in contents:
|
|
|
- key = content.third_category_code
|
|
|
- if key not in groups:
|
|
|
- groups[key] = []
|
|
|
- groups[key].append(content)
|
|
|
-
|
|
|
- merged_contents = []
|
|
|
-
|
|
|
- for category_code, group_contents in groups.items():
|
|
|
- # 按起始行号排序
|
|
|
- group_contents.sort(key=lambda x: x.start_line)
|
|
|
-
|
|
|
- # 合并连续或重叠的区间
|
|
|
- merged_ranges = []
|
|
|
- for content in group_contents:
|
|
|
- if not merged_ranges:
|
|
|
- # 第一个区间
|
|
|
- merged_ranges.append({
|
|
|
- 'start': content.start_line,
|
|
|
- 'end': content.end_line
|
|
|
- })
|
|
|
- else:
|
|
|
- last_range = merged_ranges[-1]
|
|
|
- # 检查是否连续或重叠(允许3行的间隔也算连续)
|
|
|
- if content.start_line <= last_range['end'] + 3:
|
|
|
- # 扩展当前区间
|
|
|
- last_range['end'] = max(last_range['end'], content.end_line)
|
|
|
- else:
|
|
|
- # 不连续,新建区间
|
|
|
- merged_ranges.append({
|
|
|
- 'start': content.start_line,
|
|
|
- 'end': content.end_line
|
|
|
- })
|
|
|
-
|
|
|
- # 为每个合并后的区间创建条目
|
|
|
- for range_info in merged_ranges:
|
|
|
- merged_content = self._extract_content_by_line_numbers(
|
|
|
- section, range_info['start'], range_info['end']
|
|
|
- )
|
|
|
- merged_contents.append(ClassifiedContent(
|
|
|
- third_category_name=group_contents[0].third_category_name,
|
|
|
- third_category_code=category_code,
|
|
|
- third_seq=group_contents[0].third_seq,
|
|
|
- start_line=range_info['start'],
|
|
|
- end_line=range_info['end'],
|
|
|
- content=merged_content
|
|
|
- ))
|
|
|
-
|
|
|
- # 按起始行号排序最终结果
|
|
|
- merged_contents.sort(key=lambda x: x.start_line)
|
|
|
- return merged_contents
|
|
|
-
|
|
|
- def _extract_content_by_line_numbers(self, section: SectionContent, start_line: int, end_line: int) -> str:
|
|
|
- """根据全局行号从section中提取原文内容"""
|
|
|
- if not section.line_number_map:
|
|
|
- # 如果没有行号映射,使用相对索引
|
|
|
- start_idx = max(0, start_line - 1)
|
|
|
- end_idx = min(len(section.lines), end_line)
|
|
|
- return '\n'.join(section.lines[start_idx:end_idx])
|
|
|
-
|
|
|
- # 找到全局行号对应的索引
|
|
|
- start_idx = -1
|
|
|
- end_idx = -1
|
|
|
-
|
|
|
- for idx, global_line_num in enumerate(section.line_number_map):
|
|
|
- if global_line_num == start_line:
|
|
|
- start_idx = idx
|
|
|
- if global_line_num == end_line:
|
|
|
- end_idx = idx
|
|
|
- break
|
|
|
-
|
|
|
- # 如果没找到精确匹配,使用近似值
|
|
|
- if start_idx == -1:
|
|
|
- for idx, global_line_num in enumerate(section.line_number_map):
|
|
|
- if global_line_num >= start_line:
|
|
|
- start_idx = idx
|
|
|
- break
|
|
|
- if end_idx == -1:
|
|
|
- for idx in range(len(section.line_number_map) - 1, -1, -1):
|
|
|
- if section.line_number_map[idx] <= end_line:
|
|
|
- end_idx = idx
|
|
|
- break
|
|
|
-
|
|
|
- if start_idx == -1:
|
|
|
- start_idx = 0
|
|
|
- if end_idx == -1:
|
|
|
- end_idx = len(section.lines) - 1
|
|
|
-
|
|
|
- # 确保索引有效
|
|
|
- start_idx = max(0, min(start_idx, len(section.lines) - 1))
|
|
|
- end_idx = max(0, min(end_idx, len(section.lines) - 1))
|
|
|
-
|
|
|
- if start_idx > end_idx:
|
|
|
- start_idx, end_idx = end_idx, start_idx
|
|
|
-
|
|
|
- # 添加行号标记返回
|
|
|
- lines_with_numbers = []
|
|
|
- for i in range(start_idx, end_idx + 1):
|
|
|
- global_line = section.line_number_map[i] if i < len(section.line_number_map) else (i + 1)
|
|
|
- lines_with_numbers.append(f"<{global_line}> {section.lines[i]}")
|
|
|
-
|
|
|
- return '\n'.join(lines_with_numbers)
|
|
|
-
|
|
|
- async def _call_supplement_verification(
|
|
|
- self,
|
|
|
- section: SectionContent,
|
|
|
- std: CategoryStandard,
|
|
|
- hit_lines: List[int],
|
|
|
- matched_kws: List[str],
|
|
|
- is_table: bool = False
|
|
|
- ) -> bool:
|
|
|
- """针对单个候选遗漏分类发起补充验证LLM调用,返回是否存在。"""
|
|
|
- start = min(hit_lines)
|
|
|
- end = max(hit_lines)
|
|
|
- chunk_text = self._extract_content_by_line_numbers(section, start, end)
|
|
|
-
|
|
|
- prompt = build_supplement_verify_prompt(std, chunk_text, start, end, hit_lines, matched_kws, is_table)
|
|
|
-
|
|
|
- try:
|
|
|
- # 使用统一的模型调用客户端
|
|
|
- resp = await generate_model_client.get_model_generate_invoke(
|
|
|
- trace_id="content_classifier_supplement",
|
|
|
- system_prompt=SUPPLEMENT_VERIFY_SYSTEM_PROMPT,
|
|
|
- user_prompt=prompt,
|
|
|
- model_name=self.model,
|
|
|
- enable_thinking=False,
|
|
|
- timeout=30 # 补充验证较短超时
|
|
|
- )
|
|
|
- if "不存在" in resp:
|
|
|
- return False
|
|
|
- if "存在" in resp:
|
|
|
- return True
|
|
|
- # 格式异常,保守返回 True
|
|
|
- logger.warning(f"supplement_verify 格式异常: {resp[:50]}")
|
|
|
- return True
|
|
|
- except Exception as e:
|
|
|
- logger.warning(f"supplement_verify 调用失败: {e}")
|
|
|
- return True
|
|
|
-
|
|
|
- async def _detect_and_supplement(
|
|
|
- self,
|
|
|
- section: SectionContent,
|
|
|
- llm_results: List[ClassifiedContent]
|
|
|
- ) -> List[ClassifiedContent]:
|
|
|
- """扫描整个 section,补充 LLM 遗漏的三级分类(并发优化版)。
|
|
|
-
|
|
|
- 扫描范围:当前二级分类下的所有行(不跨二级分类,由 section.category_standards 保证)。
|
|
|
- 触发条件:该二级分类下某个三级标准未出现在 LLM 结果中。
|
|
|
- 注意:同一行内容可同时属于多个三级分类,不限制"已覆盖行"。
|
|
|
- """
|
|
|
- if not section.category_standards or not section.lines:
|
|
|
- return []
|
|
|
-
|
|
|
- # 已命中的有效分类(排除 no_standard)
|
|
|
- found_codes = {c.third_category_code for c in llm_results if c.third_category_code != 'no_standard'}
|
|
|
-
|
|
|
- # 判断整个 section 是否含表格特征
|
|
|
- full_text = ' '.join(section.lines)
|
|
|
- is_table = (
|
|
|
- any(kw in full_text for kw in ['序号', '作业活动', '风险源', '防范措施'])
|
|
|
- or full_text.count('|') > 5
|
|
|
- )
|
|
|
-
|
|
|
- # 准备需要验证的任务列表
|
|
|
- verification_tasks = []
|
|
|
- verification_info = [] # 保存对应的 std 和 hit_lines 信息
|
|
|
-
|
|
|
- for std in section.category_standards:
|
|
|
- if std.third_code in found_codes:
|
|
|
- continue
|
|
|
- if not std.keywords and not std.extra_prompt:
|
|
|
- continue
|
|
|
-
|
|
|
- keywords = [k.strip() for k in std.keywords.split(';') if k.strip()]
|
|
|
- # 同时从 extra_prompt 的引号内容中提取补充信号词,用于触发验证
|
|
|
- extra_signals = []
|
|
|
- if std.extra_prompt:
|
|
|
- import re
|
|
|
- quoted = re.findall(r'[""""]([^""""]+)[""""]', std.extra_prompt)
|
|
|
- extra_signals.extend([q.strip() for q in quoted if len(q.strip()) >= 2])
|
|
|
- scan_signals = keywords + extra_signals
|
|
|
-
|
|
|
- if is_table:
|
|
|
- # 表格路径:整个 section 行范围提交 LLM 验证
|
|
|
- if not section.line_number_map:
|
|
|
- continue
|
|
|
- hit_lines = [section.line_number_map[0], section.line_number_map[-1]]
|
|
|
- verification_tasks.append(
|
|
|
- self._call_supplement_verification(section, std, hit_lines, [], is_table=True)
|
|
|
- )
|
|
|
- verification_info.append((std, hit_lines))
|
|
|
- else:
|
|
|
- # 普通路径:扫描整个 section 所有行的关键字
|
|
|
- hit_lines, matched_kws = [], []
|
|
|
- for i, line_text in enumerate(section.lines):
|
|
|
- line_num = section.line_number_map[i] if section.line_number_map else (i + 1)
|
|
|
- for kw in scan_signals:
|
|
|
- if kw in line_text and line_num not in hit_lines:
|
|
|
- hit_lines.append(line_num)
|
|
|
- if kw not in matched_kws:
|
|
|
- matched_kws.append(kw)
|
|
|
- if not hit_lines:
|
|
|
- continue
|
|
|
- verification_tasks.append(
|
|
|
- self._call_supplement_verification(section, std, hit_lines, matched_kws)
|
|
|
- )
|
|
|
- verification_info.append((std, hit_lines))
|
|
|
-
|
|
|
- if not verification_tasks:
|
|
|
- return []
|
|
|
-
|
|
|
- # 并发执行所有验证任务
|
|
|
- results = await asyncio.gather(*verification_tasks, return_exceptions=True)
|
|
|
-
|
|
|
- # 收集验证通过的结果
|
|
|
- supplemented = []
|
|
|
- for (std, hit_lines), confirmed in zip(verification_info, results):
|
|
|
- if isinstance(confirmed, Exception):
|
|
|
- logger.warning(f"[{section.section_name}] 补充验证异常: {confirmed}")
|
|
|
- continue
|
|
|
- if confirmed:
|
|
|
- start, end = min(hit_lines), max(hit_lines)
|
|
|
- content = self._extract_content_by_line_numbers(section, start, end)
|
|
|
- supplemented.append(ClassifiedContent(
|
|
|
- third_category_name=std.third_name,
|
|
|
- third_category_code=std.third_code,
|
|
|
- third_seq=std.third_seq,
|
|
|
- start_line=start,
|
|
|
- end_line=end,
|
|
|
- content=content
|
|
|
- ))
|
|
|
-
|
|
|
- return supplemented
|
|
|
-
|
|
|
-
|
|
|
- def _fix_json(self, json_str: str) -> str:
|
|
|
- return _fix_json(json_str)
|
|
|
-
|
|
|
- def _aggressive_json_fix(self, json_str: str) -> str:
|
|
|
- return _aggressive_json_fix(json_str)
|