|
|
@@ -8,14 +8,18 @@ from __future__ import annotations
|
|
|
|
|
|
import asyncio
|
|
|
import csv
|
|
|
+import json
|
|
|
+import re
|
|
|
from collections import OrderedDict
|
|
|
from pathlib import Path
|
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
|
|
+from foundation.infrastructure.config.config import config_handler
|
|
|
+from foundation.observability.logger.loggering import review_logger as logger
|
|
|
+from foundation.ai.agent.generate.model_generate import generate_model_client
|
|
|
+
|
|
|
from ..config.provider import default_config_provider
|
|
|
-from ..utils.llm_client import LLMClient
|
|
|
from ..utils.prompt_loader import PromptLoader
|
|
|
-from foundation.observability.logger.loggering import review_logger as logger
|
|
|
|
|
|
# 延迟导入新的三级分类器(避免循环导入)
|
|
|
_LLM_CONTENT_CLASSIFIER = None
|
|
|
@@ -33,30 +37,50 @@ def _get_llm_content_classifier():
|
|
|
return _LLM_CONTENT_CLASSIFIER
|
|
|
|
|
|
|
|
|
+def _extract_json(text: str) -> Optional[Dict[str, Any]]:
|
|
|
+ """从字符串中提取第一个有效 JSON 对象"""
|
|
|
+ for pattern in [r"```json\s*(\{.*?})\s*```", r"```\s*(\{.*?})\s*```"]:
|
|
|
+ m = re.search(pattern, text, re.DOTALL)
|
|
|
+ if m:
|
|
|
+ try:
|
|
|
+ return json.loads(m.group(1))
|
|
|
+ except json.JSONDecodeError:
|
|
|
+ pass
|
|
|
+ try:
|
|
|
+ for candidate in re.findall(r"(\{.*?\})", text, re.DOTALL):
|
|
|
+ try:
|
|
|
+ return json.loads(candidate)
|
|
|
+ except json.JSONDecodeError:
|
|
|
+ pass
|
|
|
+ except Exception:
|
|
|
+ pass
|
|
|
+ return None
|
|
|
+
|
|
|
+
|
|
|
class ChunkClassifier:
|
|
|
"""内容块分类器(二级和三级分类)"""
|
|
|
|
|
|
def __init__(self):
    """Set up configuration, the prompt loader and the classification standards."""
    self._cfg = default_config_provider

    # Value later used as the semaphore size for concurrent LLM calls; read
    # from the "llm_keywords" config section with "20" passed as the third
    # argument (presumably the fallback - confirm against config_handler).
    workers_setting = config_handler.get("llm_keywords", "CONCURRENT_WORKERS", "20")
    self._concurrency = int(workers_setting)

    # Prompt loader used to render the classification prompts.
    self.prompt_loader = PromptLoader()

    # Load the second- and third-level classification standards from CSV.
    self._load_classification_standards()
|
|
|
|
|
|
def _load_classification_standards(self):
|
|
|
"""从CSV文件加载二级和三级分类标准"""
|
|
|
csv_file = Path(__file__).parent.parent / "config" / "StandardCategoryTable.csv"
|
|
|
-
|
|
|
+
|
|
|
if not csv_file.exists():
|
|
|
raise FileNotFoundError(f"分类标准CSV文件不存在: {csv_file}")
|
|
|
-
|
|
|
+
|
|
|
# 结构: {first_code: {second_code: {second_cn, second_focus, third_items: [{third_code, third_cn, third_focus}]}}}
|
|
|
self.classification_tree: Dict[str, Dict[str, Any]] = {}
|
|
|
-
|
|
|
+
|
|
|
with csv_file.open("r", encoding="utf-8-sig") as f:
|
|
|
reader = csv.DictReader(f)
|
|
|
for row in reader:
|
|
|
@@ -68,14 +92,14 @@ class ChunkClassifier:
|
|
|
third_code = (row.get("third_code") or "").strip()
|
|
|
third_cn = (row.get("third_name") or "").strip()
|
|
|
third_focus = (row.get("third_focus") or "").strip()
|
|
|
-
|
|
|
+
|
|
|
if not first_code or not second_code:
|
|
|
continue
|
|
|
-
|
|
|
+
|
|
|
# 初始化一级类别
|
|
|
if first_code not in self.classification_tree:
|
|
|
self.classification_tree[first_code] = {}
|
|
|
-
|
|
|
+
|
|
|
# 初始化二级类别
|
|
|
if second_code not in self.classification_tree[first_code]:
|
|
|
self.classification_tree[first_code][second_code] = {
|
|
|
@@ -83,7 +107,7 @@ class ChunkClassifier:
|
|
|
"second_focus": second_focus,
|
|
|
"third_items": []
|
|
|
}
|
|
|
-
|
|
|
+
|
|
|
# 添加三级类别(如果存在)
|
|
|
if third_code and third_cn:
|
|
|
self.classification_tree[first_code][second_code]["third_items"].append({
|
|
|
@@ -95,102 +119,143 @@ class ChunkClassifier:
|
|
|
def _build_secondary_standards(self, first_category_code: str) -> tuple[str, dict]:
|
|
|
"""
|
|
|
构建二级分类标准文本
|
|
|
-
|
|
|
+
|
|
|
返回:
|
|
|
(标准文本, 索引映射字典)
|
|
|
"""
|
|
|
if first_category_code not in self.classification_tree:
|
|
|
return "(无二级分类标准)", {}
|
|
|
-
|
|
|
+
|
|
|
standards_lines = [" 0. 非标准项 - 不符合以下任何类别"]
|
|
|
index_mapping = {0: ("非标准项", "non_standard")}
|
|
|
-
|
|
|
+
|
|
|
for idx, (second_code, second_data) in enumerate(self.classification_tree[first_category_code].items(), 1):
|
|
|
second_cn = second_data["second_cn"]
|
|
|
second_focus = second_data["second_focus"]
|
|
|
-
|
|
|
+
|
|
|
# 保存索引映射
|
|
|
index_mapping[idx] = (second_cn, second_code)
|
|
|
-
|
|
|
+
|
|
|
if second_focus and second_focus != "NULL":
|
|
|
standards_lines.append(f" {idx}. {second_cn} - 关注点:{second_focus}")
|
|
|
else:
|
|
|
standards_lines.append(f" {idx}. {second_cn}")
|
|
|
-
|
|
|
+
|
|
|
return "\n".join(standards_lines) if standards_lines else "(无二级分类标准)", index_mapping
|
|
|
|
|
|
def _build_tertiary_standards(self, first_category_code: str, second_category_code: str) -> tuple[str, dict]:
|
|
|
"""
|
|
|
构建三级分类标准文本
|
|
|
-
|
|
|
+
|
|
|
返回:
|
|
|
(标准文本, 索引映射字典)
|
|
|
"""
|
|
|
if first_category_code not in self.classification_tree:
|
|
|
return "(无三级分类标准)", {}
|
|
|
-
|
|
|
+
|
|
|
if second_category_code not in self.classification_tree[first_category_code]:
|
|
|
return "(无三级分类标准)", {}
|
|
|
-
|
|
|
+
|
|
|
third_items = self.classification_tree[first_category_code][second_category_code]["third_items"]
|
|
|
-
|
|
|
+
|
|
|
if not third_items:
|
|
|
return "(无三级分类标准)", {}
|
|
|
-
|
|
|
+
|
|
|
standards_lines = [" 0. 非标准项 - 不符合以下任何类别"]
|
|
|
index_mapping = {0: ("非标准项", "non_standard")}
|
|
|
-
|
|
|
+
|
|
|
for idx, third_item in enumerate(third_items, 1):
|
|
|
third_cn = third_item["third_cn"]
|
|
|
third_code = third_item["third_code"]
|
|
|
third_focus = third_item["third_focus"]
|
|
|
-
|
|
|
+
|
|
|
# 保存索引映射
|
|
|
index_mapping[idx] = (third_cn, third_code)
|
|
|
-
|
|
|
+
|
|
|
if third_focus and third_focus != "NULL":
|
|
|
standards_lines.append(f" {idx}. {third_cn} - 关注点:{third_focus}")
|
|
|
else:
|
|
|
standards_lines.append(f" {idx}. {third_cn}")
|
|
|
-
|
|
|
+
|
|
|
return "\n".join(standards_lines), index_mapping
|
|
|
|
|
|
async def _call_llm_once(self, system_prompt: str, user_prompt: str) -> Optional[Dict[str, Any]]:
    """Issue a single asynchronous LLM request via ``generate_model_client``.

    Returns:
        The JSON object that ``_extract_json`` finds in the reply; when none
        is found, the reply wrapped as ``{"raw_content": reply}``. If the
        request or the parsing raises, the error is logged and ``None`` is
        returned so the caller can decide how to handle the failure.
    """
    try:
        reply = await generate_model_client.get_model_generate_invoke(
            trace_id="chunk_classifier",
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            # The 122B model is used to improve classification accuracy.
            model_name="qwen3_5_122b_a10b",
        )
        parsed = _extract_json(reply)
    except Exception as exc:
        logger.error(f"[ChunkClassifier] LLM 调用失败: {exc}")
        return None

    if parsed is None:
        # Keep the unparsed reply so the caller can still inspect it.
        return {"raw_content": reply}
    return parsed
|
|
|
+
|
|
|
+ async def _batch_call_llm(
|
|
|
+ self,
|
|
|
+ requests: List[tuple], # [(system_prompt, user_prompt), ...]
|
|
|
+ ) -> List[Optional[Dict[str, Any]]]:
|
|
|
+ """
|
|
|
+ 并发批量调用 LLM(带信号量控制)
|
|
|
+
|
|
|
+ 参数:
|
|
|
+ requests: 请求列表,每个元素是 (system_prompt, user_prompt) 元组
|
|
|
+
|
|
|
+ 返回:
|
|
|
+ 结果列表,与输入请求一一对应
|
|
|
+ """
|
|
|
+ semaphore = asyncio.Semaphore(self._concurrency)
|
|
|
+
|
|
|
+ async def bounded_call(system_prompt: str, user_prompt: str):
|
|
|
+ async with semaphore:
|
|
|
+ return await self._call_llm_once(system_prompt, user_prompt)
|
|
|
+
|
|
|
+ tasks = [bounded_call(sp, up) for sp, up in requests]
|
|
|
+ return list(await asyncio.gather(*tasks))
|
|
|
+
|
|
|
async def classify_chunks_secondary_async(self, chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
|
"""
|
|
|
异步对chunks进行二级分类
|
|
|
-
|
|
|
+
|
|
|
参数:
|
|
|
chunks: 已完成一级分类的chunk列表
|
|
|
-
|
|
|
+
|
|
|
返回:
|
|
|
添加了二级分类字段的chunk列表
|
|
|
"""
|
|
|
logger.info(f"正在对 {len(chunks)} 个内容块进行二级分类...")
|
|
|
-
|
|
|
+
|
|
|
# 准备LLM请求
|
|
|
llm_requests = []
|
|
|
valid_chunks = []
|
|
|
index_mappings = [] # 保存每个请求对应的索引映射
|
|
|
-
|
|
|
+
|
|
|
for chunk in chunks:
|
|
|
first_category_code = chunk.get("chapter_classification", "")
|
|
|
chunk_title = chunk.get("section_label", "")
|
|
|
hierarchy_path = " -> ".join(chunk.get("hierarchy_path", []))
|
|
|
content = chunk.get("review_chunk_content", "")
|
|
|
content_preview = content[:300] if content else ""
|
|
|
-
|
|
|
+
|
|
|
# 获取一级分类的中文名称
|
|
|
first_category_cn = self._get_first_category_cn(first_category_code)
|
|
|
-
|
|
|
+
|
|
|
# 构建二级分类标准(返回标准文本和索引映射)
|
|
|
secondary_standards, index_mapping = self._build_secondary_standards(first_category_code)
|
|
|
-
|
|
|
+
|
|
|
if secondary_standards == "(无二级分类标准)":
|
|
|
# 如果没有二级分类标准,跳过
|
|
|
chunk["secondary_category_cn"] = "无"
|
|
|
chunk["secondary_category_code"] = "none"
|
|
|
continue
|
|
|
-
|
|
|
+
|
|
|
# 渲染提示词
|
|
|
prompt = self.prompt_loader.render(
|
|
|
"chunk_secondary_classification",
|
|
|
@@ -200,28 +265,23 @@ class ChunkClassifier:
|
|
|
content_preview=content_preview,
|
|
|
secondary_standards=secondary_standards
|
|
|
)
|
|
|
-
|
|
|
- messages = [
|
|
|
- {"role": "system", "content": prompt["system"]},
|
|
|
- {"role": "user", "content": prompt["user"]}
|
|
|
- ]
|
|
|
-
|
|
|
- llm_requests.append(messages)
|
|
|
+
|
|
|
+ llm_requests.append((prompt["system"], prompt["user"]))
|
|
|
valid_chunks.append(chunk)
|
|
|
index_mappings.append(index_mapping)
|
|
|
-
|
|
|
+
|
|
|
if not llm_requests:
|
|
|
logger.info("所有内容块都没有二级分类标准,跳过二级分类")
|
|
|
return chunks
|
|
|
-
|
|
|
+
|
|
|
# 批量异步调用LLM API
|
|
|
- llm_results = await self.llm_client.batch_call_async(llm_requests)
|
|
|
-
|
|
|
+ llm_results = await self._batch_call_llm(llm_requests)
|
|
|
+
|
|
|
# 处理分类结果
|
|
|
for chunk, llm_result, index_mapping in zip(valid_chunks, llm_results, index_mappings):
|
|
|
if llm_result and isinstance(llm_result, dict):
|
|
|
category_index = llm_result.get("category_index")
|
|
|
-
|
|
|
+
|
|
|
# 验证索引并映射到类别
|
|
|
if isinstance(category_index, int) and category_index in index_mapping:
|
|
|
secondary_cn, secondary_code = index_mapping[category_index]
|
|
|
@@ -235,7 +295,7 @@ class ChunkClassifier:
|
|
|
else:
|
|
|
chunk["secondary_category_cn"] = "非标准项"
|
|
|
chunk["secondary_category_code"] = "non_standard"
|
|
|
-
|
|
|
+
|
|
|
logger.info("二级分类完成!")
|
|
|
return chunks
|
|
|
|
|
|
@@ -317,12 +377,12 @@ class ChunkClassifier:
|
|
|
每个chunk只能属于一个三级分类
|
|
|
"""
|
|
|
logger.info(f"正在对 {len(chunks)} 个内容块进行三级分类...")
|
|
|
-
|
|
|
+
|
|
|
# 准备LLM请求
|
|
|
llm_requests = []
|
|
|
valid_chunks = []
|
|
|
index_mappings = [] # 保存每个请求对应的索引映射
|
|
|
-
|
|
|
+
|
|
|
for chunk in chunks:
|
|
|
first_category_code = chunk.get("chapter_classification", "")
|
|
|
second_category_code = chunk.get("secondary_category_code", "")
|
|
|
@@ -330,19 +390,19 @@ class ChunkClassifier:
|
|
|
chunk_title = chunk.get("section_label", "")
|
|
|
content = chunk.get("review_chunk_content", "")
|
|
|
content_preview = content[:300] if content else ""
|
|
|
-
|
|
|
+
|
|
|
# 获取一级分类的中文名称
|
|
|
first_category_cn = self._get_first_category_cn(first_category_code)
|
|
|
-
|
|
|
+
|
|
|
# 构建三级分类标准(返回标准文本和索引映射)
|
|
|
tertiary_standards, index_mapping = self._build_tertiary_standards(first_category_code, second_category_code)
|
|
|
-
|
|
|
+
|
|
|
if tertiary_standards == "(无三级分类标准)":
|
|
|
# 如果没有三级分类标准,跳过
|
|
|
chunk["tertiary_category_cn"] = "无"
|
|
|
chunk["tertiary_category_code"] = "none"
|
|
|
continue
|
|
|
-
|
|
|
+
|
|
|
# 渲染提示词
|
|
|
prompt = self.prompt_loader.render(
|
|
|
"chunk_tertiary_classification",
|
|
|
@@ -352,28 +412,23 @@ class ChunkClassifier:
|
|
|
content_preview=content_preview,
|
|
|
tertiary_standards=tertiary_standards
|
|
|
)
|
|
|
-
|
|
|
- messages = [
|
|
|
- {"role": "system", "content": prompt["system"]},
|
|
|
- {"role": "user", "content": prompt["user"]}
|
|
|
- ]
|
|
|
-
|
|
|
- llm_requests.append(messages)
|
|
|
+
|
|
|
+ llm_requests.append((prompt["system"], prompt["user"]))
|
|
|
valid_chunks.append(chunk)
|
|
|
index_mappings.append(index_mapping)
|
|
|
-
|
|
|
+
|
|
|
if not llm_requests:
|
|
|
logger.info("所有内容块都没有三级分类标准,跳过三级分类")
|
|
|
return chunks
|
|
|
-
|
|
|
+
|
|
|
# 批量异步调用LLM API
|
|
|
- llm_results = await self.llm_client.batch_call_async(llm_requests)
|
|
|
-
|
|
|
+ llm_results = await self._batch_call_llm(llm_requests)
|
|
|
+
|
|
|
# 处理分类结果
|
|
|
for chunk, llm_result, index_mapping in zip(valid_chunks, llm_results, index_mappings):
|
|
|
if llm_result and isinstance(llm_result, dict):
|
|
|
category_index = llm_result.get("category_index")
|
|
|
-
|
|
|
+
|
|
|
# 验证索引并映射到类别
|
|
|
if isinstance(category_index, int) and category_index in index_mapping:
|
|
|
tertiary_cn, tertiary_code = index_mapping[category_index]
|
|
|
@@ -387,7 +442,7 @@ class ChunkClassifier:
|
|
|
else:
|
|
|
chunk["tertiary_category_cn"] = "非标准项"
|
|
|
chunk["tertiary_category_code"] = "non_standard"
|
|
|
-
|
|
|
+
|
|
|
logger.info("三级分类完成!")
|
|
|
return chunks
|
|
|
|
|
|
@@ -435,4 +490,4 @@ class ChunkClassifier:
|
|
|
classifier_config=classifier_config
|
|
|
))
|
|
|
except RuntimeError:
|
|
|
- raise RuntimeError("请使用 await classify_chunks_tertiary_async")
|
|
|
+ raise RuntimeError("请使用 await classify_chunks_tertiary_async")
|