|
@@ -1,194 +1,208 @@
|
|
|
"""
|
|
"""
|
|
|
目录分类模块(基于LLM API智能识别)
|
|
目录分类模块(基于LLM API智能识别)
|
|
|
|
|
|
|
|
-适配 file_parse 的配置系统,通过异步并发调用LLM API来判断一级目录的分类。
|
|
|
|
|
|
|
+使用 config/config.ini 中的通用 LLM 配置,通过异步并发调用 LLM API 来判断一级目录的分类。
|
|
|
"""
|
|
"""
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
-from collections import Counter
|
|
|
|
|
import asyncio
|
|
import asyncio
|
|
|
import json
|
|
import json
|
|
|
|
|
+import re
|
|
|
|
|
+from collections import Counter
|
|
|
from typing import Any, Dict, List, Optional
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
|
|
|
|
|
|
+from foundation.infrastructure.config.config import config_handler
|
|
|
|
|
+from foundation.observability.logger.loggering import review_logger as logger
|
|
|
|
|
+from foundation.ai.agent.generate.model_generate import generate_model_client
|
|
|
|
|
+
|
|
|
from ..interfaces import HierarchyClassifier as IHierarchyClassifier
|
|
from ..interfaces import HierarchyClassifier as IHierarchyClassifier
|
|
|
from ..config.provider import default_config_provider
|
|
from ..config.provider import default_config_provider
|
|
|
-from ..utils.llm_client import LLMClient
|
|
|
|
|
from ..utils.prompt_loader import PromptLoader
|
|
from ..utils.prompt_loader import PromptLoader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+def _extract_json(text: str) -> Optional[Dict[str, Any]]:
|
|
|
|
|
+ """从字符串中提取第一个有效 JSON 对象"""
|
|
|
|
|
+ for pattern in [r"```json\s*(\{.*?})\s*```", r"```\s*(\{.*?})\s*```"]:
|
|
|
|
|
+ m = re.search(pattern, text, re.DOTALL)
|
|
|
|
|
+ if m:
|
|
|
|
|
+ try:
|
|
|
|
|
+ return json.loads(m.group(1))
|
|
|
|
|
+ except json.JSONDecodeError:
|
|
|
|
|
+ pass
|
|
|
|
|
+ try:
|
|
|
|
|
+ for candidate in re.findall(r"(\{.*?\})", text, re.DOTALL):
|
|
|
|
|
+ try:
|
|
|
|
|
+ return json.loads(candidate)
|
|
|
|
|
+ except json.JSONDecodeError:
|
|
|
|
|
+ pass
|
|
|
|
|
+ except Exception:
|
|
|
|
|
+ pass
|
|
|
|
|
+ return None
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
class HierarchyClassifier(IHierarchyClassifier):
|
|
class HierarchyClassifier(IHierarchyClassifier):
|
|
|
- """基于层级结构的目录分类器(通过LLM API智能识别来分类一级目录)"""
|
|
|
|
|
|
|
+ """基于层级结构的目录分类器(通过 LLM API 智能识别来分类一级目录)"""
|
|
|
|
|
|
|
|
def __init__(self):
    """Set up configuration, the prompt loader and the category tables.

    Reads the LLM concurrency limit from the ``llm_keywords`` section
    (key ``CONCURRENT_WORKERS``, default ``"20"``) and the category
    name-to-code mapping from ``categories.mapping``.
    """
    cfg = default_config_provider
    raw_workers = config_handler.get("llm_keywords", "CONCURRENT_WORKERS", "20")
    workers = int(raw_workers)
    mapping = cfg.get("categories.mapping", {})
    loader = PromptLoader()
    categories = loader.get_standard_categories()

    self._cfg = cfg
    # Upper bound on simultaneous LLM requests issued by one batch.
    self._concurrency = workers
    self.category_mapping = mapping
    self.prompt_loader = loader
    # Standard category names as provided by the prompt loader.
    self.standard_categories = categories
|
|
|
|
|
|
|
|
|
|
+ # ------------------------------------------------------------------
|
|
|
|
|
+ # 内部 LLM 调用
|
|
|
|
|
+ # ------------------------------------------------------------------
|
|
|
|
|
+
|
|
|
|
|
async def _call_once(self, messages: List[Dict[str, str]]) -> Optional[Dict[str, Any]]:
    """Perform one asynchronous LLM call; return ``None`` on failure.

    Args:
        messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
            The first ``system`` and first ``user`` message are used; a
            missing role contributes an empty prompt.

    Returns:
        The JSON object extracted from the model reply, or
        ``{"raw_content": <reply>}`` when no JSON object could be
        extracted, or ``None`` when the call (or extraction) raised.
    """

    def first_content(role: str) -> str:
        # Content of the first message with the given role, "" if none.
        for message in messages:
            if message["role"] == role:
                return message["content"]
        return ""

    system_prompt = first_content("system")
    user_prompt = first_content("user")

    try:
        content = await generate_model_client.get_model_generate_invoke(
            trace_id="hierarchy_classifier",
            system_prompt=system_prompt,
            user_prompt=user_prompt,
        )
        parsed = _extract_json(content)
    except Exception as e:
        logger.error(f"[HierarchyClassifier] LLM 调用失败: {e}")
        return None

    if parsed is None:
        # Keep the raw reply so callers can still tell the call succeeded.
        return {"raw_content": content}
    return parsed
|
|
|
|
|
+
|
|
|
|
|
async def _batch_call(self, requests: List[List[Dict[str, str]]]) -> List[Optional[Dict[str, Any]]]:
    """Run all LLM requests concurrently, bounded by the configured limit.

    Args:
        requests: One message list per LLM call.

    Returns:
        One result per request, in the same order as ``requests``; each
        entry is whatever ``_call_once`` returned for that request.
    """
    # A semaphore created with 0 would block every task forever and a
    # negative value raises ValueError, so a misconfigured worker count
    # is clamped to at least one concurrent call.
    limit = max(1, self._concurrency)
    semaphore = asyncio.Semaphore(limit)

    async def bounded(msgs):
        async with semaphore:
            return await self._call_once(msgs)

    return list(await asyncio.gather(*(bounded(r) for r in requests)))
|
|
|
|
|
+
|
|
|
|
|
+ # ------------------------------------------------------------------
|
|
|
|
|
+ # 公开接口
|
|
|
|
|
+ # ------------------------------------------------------------------
|
|
|
|
|
+
|
|
|
async def classify_async(
    self, toc_items: List[Dict[str, Any]], target_level: int = 1
) -> Dict[str, Any]:
    """Classify the TOC entries at ``target_level`` via the LLM (async).

    Each entry at ``target_level`` is sent to the LLM together with the
    titles of its direct children (entries at ``target_level + 1`` that
    appear before the next ``target_level`` entry). The returned category
    is validated against ``self.standard_categories``; anything missing,
    invalid or unknown falls back to the "非标准项" / ``non_standard``
    bucket.

    Args:
        toc_items: Flat, ordered TOC entries. Each entry must provide
            ``level``, ``title`` and ``page``; ``original`` is optional.
        target_level: Level of the entries to classify.

    Returns:
        A dict with ``items`` (one record per classified entry holding
        ``title``, ``page``, ``level``, ``category``, ``category_code``,
        ``original``, ``level2_count``, ``level2_titles`` and
        ``confidence``), ``total_count``, ``target_level`` and
        ``category_stats`` (category name -> count).
    """
    logger.debug(f"[HierarchyClassifier] 开始对 {target_level} 级目录进行智能分类...")

    # Record positions rather than looking entries up with list.index():
    # index() returns the first *equal* dict, which assigns the wrong
    # children when two entries have identical content, and it rescans
    # the list for every entry.
    level1_positions = [
        pos for pos, item in enumerate(toc_items) if item["level"] == target_level
    ]
    if not level1_positions:
        logger.warning(f"[HierarchyClassifier] 未找到 {target_level} 级目录项")
        return {"items": [], "total_count": 0, "target_level": target_level, "category_stats": {}}

    logger.debug(f"[HierarchyClassifier] 找到 {len(level1_positions)} 个 {target_level} 级目录项,准备 LLM 分类")

    # Pair every target-level entry with its direct children: the slice
    # runs up to the next target-level entry (or the end of the list).
    boundaries = level1_positions[1:] + [len(toc_items)]
    level1_with_children = []
    for start, end in zip(level1_positions, boundaries):
        children = [
            item for item in toc_items[start + 1:end]
            if item["level"] == target_level + 1
        ]
        level1_with_children.append({"level1_item": toc_items[start], "level2_children": children})

    # Build one system/user message pair per entry.
    llm_requests = []
    for entry in level1_with_children:
        level1_item = entry["level1_item"]
        level2_titles = "\n".join(f"- {c['title']}" for c in entry["level2_children"]) or "(无二级目录)"
        prompt = self.prompt_loader.render(
            "toc_classification",
            level1_title=level1_item["title"],
            level2_titles=level2_titles,
        )
        llm_requests.append([
            {"role": "system", "content": prompt["system"]},
            {"role": "user", "content": prompt["user"]},
        ])

    # Results come back in request order, so they can be zipped below.
    llm_results = await self._batch_call(llm_requests)

    classified_items = []
    category_stats: Counter = Counter()

    for entry, llm_result in zip(level1_with_children, llm_results):
        level1_item = entry["level1_item"]
        level2_children = entry["level2_children"]

        logger.debug(f"[HierarchyClassifier] '{level1_item['title']}' LLM 返回: {llm_result}")

        if llm_result and isinstance(llm_result, dict):
            category_cn = llm_result.get("category_cn", "")
            category_code = llm_result.get("category_code", "")
            confidence = llm_result.get("confidence", 0.0)

            # These codes mark an unusable answer; treat it as empty.
            if category_code in ("non_standard_invalid", "unknown"):
                category_cn = category_code = ""

            # "非标准项" is the accepted fallback bucket even though it is
            # not part of the standard category list.
            if not category_cn or (
                category_cn not in self.standard_categories and category_cn != "非标准项"
            ):
                if category_cn:
                    logger.warning(
                        f"[HierarchyClassifier] '{level1_item['title']}' "
                        f"LLM 返回类别 '{category_cn}' 不在标准列表,归为'非标准项'"
                    )
                else:
                    logger.warning(
                        f"[HierarchyClassifier] '{level1_item['title']}' "
                        f"LLM 返回类别为空或无效,归为'非标准项'"
                    )
                category_cn = "非标准项"
                category_code = "non_standard"

            # The configured mapping wins over whatever code the LLM returned.
            if category_cn in self.category_mapping:
                category_code = self.category_mapping.get(category_cn, category_code)
            elif category_cn == "非标准项":
                category_code = "non_standard"
        else:
            # No usable LLM result for this entry (call failed or empty).
            logger.error(
                f"[HierarchyClassifier] '{level1_item['title']}' LLM 分类失败,归为'非标准项'"
            )
            category_cn = "非标准项"
            category_code = "non_standard"
            confidence = 0.0

        classified_items.append({
            "title": level1_item["title"],
            "page": level1_item["page"],
            "level": level1_item["level"],
            "category": category_cn,
            "category_code": category_code,
            "original": level1_item.get("original", ""),
            "level2_count": len(level2_children),
            "level2_titles": [c["title"] for c in level2_children],
            "confidence": confidence,
        })
        category_stats[category_cn] += 1

    logger.debug(
        f"[HierarchyClassifier] 分类完成,共 {len(classified_items)} 个目录项,"
        f"分布: {dict(category_stats)}"
    )

    return {
        "items": classified_items,
        "total_count": len(classified_items),
        "target_level": target_level,
        "category_stats": dict(category_stats),
    }
|
|
|
|
|
|
|
|
def classify(
|
|
def classify(
|
|
|
self, toc_items: List[Dict[str, Any]], target_level: int = 1
|
|
self, toc_items: List[Dict[str, Any]], target_level: int = 1
|
|
|
) -> Dict[str, Any]:
|
|
) -> Dict[str, Any]:
|
|
|
- """
|
|
|
|
|
- 同步包装,内部调用异步实现。适合无事件循环的同步场景。
|
|
|
|
|
- """
|
|
|
|
|
|
|
+ """同步包装,内部调用异步实现。适合无事件循环的同步场景。"""
|
|
|
try:
|
|
try:
|
|
|
return asyncio.run(self.classify_async(toc_items, target_level))
|
|
return asyncio.run(self.classify_async(toc_items, target_level))
|
|
|
except RuntimeError as exc:
|
|
except RuntimeError as exc:
|