Просмотр исходного кода

Merge branch 'dev_sgsc_wxm' of CRBC-MaaS-Platform-Project/LQAgentPlatform into dev

WangXuMing 2 недели назад
Родитель
Commit
1c21351807

+ 124 - 69
core/construction_review/component/doc_worker/classification/chunk_classifier.py

@@ -8,14 +8,18 @@ from __future__ import annotations
 
 
 import asyncio
 import asyncio
 import csv
 import csv
+import json
+import re
 from collections import OrderedDict
 from collections import OrderedDict
 from pathlib import Path
 from pathlib import Path
 from typing import Any, Dict, List, Optional
 from typing import Any, Dict, List, Optional
 
 
+from foundation.infrastructure.config.config import config_handler
+from foundation.observability.logger.loggering import review_logger as logger
+from foundation.ai.agent.generate.model_generate import generate_model_client
+
 from ..config.provider import default_config_provider
 from ..config.provider import default_config_provider
-from ..utils.llm_client import LLMClient
 from ..utils.prompt_loader import PromptLoader
 from ..utils.prompt_loader import PromptLoader
-from foundation.observability.logger.loggering import review_logger as logger
 
 
 # 延迟导入新的三级分类器(避免循环导入)
 # 延迟导入新的三级分类器(避免循环导入)
 _LLM_CONTENT_CLASSIFIER = None
 _LLM_CONTENT_CLASSIFIER = None
@@ -33,30 +37,50 @@ def _get_llm_content_classifier():
     return _LLM_CONTENT_CLASSIFIER
     return _LLM_CONTENT_CLASSIFIER
 
 
 
 
+def _extract_json(text: str) -> Optional[Dict[str, Any]]:
+    """从字符串中提取第一个有效 JSON 对象"""
+    for pattern in [r"```json\s*(\{.*?})\s*```", r"```\s*(\{.*?})\s*```"]:
+        m = re.search(pattern, text, re.DOTALL)
+        if m:
+            try:
+                return json.loads(m.group(1))
+            except json.JSONDecodeError:
+                pass
+    try:
+        for candidate in re.findall(r"(\{.*?\})", text, re.DOTALL):
+            try:
+                return json.loads(candidate)
+            except json.JSONDecodeError:
+                pass
+    except Exception:
+        pass
+    return None
+
+
 class ChunkClassifier:
 class ChunkClassifier:
     """内容块分类器(二级和三级分类)"""
     """内容块分类器(二级和三级分类)"""
 
 
     def __init__(self):
     def __init__(self):
         """初始化分类器"""
         """初始化分类器"""
         self._cfg = default_config_provider
         self._cfg = default_config_provider
-        
-        # 初始化LLM客户端和提示词加载器
-        self.llm_client = LLMClient(config_provider=self._cfg)
+        self._concurrency = int(config_handler.get("llm_keywords", "CONCURRENT_WORKERS", "20"))
+
+        # 初始化提示词加载器
         self.prompt_loader = PromptLoader()
         self.prompt_loader = PromptLoader()
-        
+
         # 加载CSV分类标准
         # 加载CSV分类标准
         self._load_classification_standards()
         self._load_classification_standards()
 
 
     def _load_classification_standards(self):
     def _load_classification_standards(self):
         """从CSV文件加载二级和三级分类标准"""
         """从CSV文件加载二级和三级分类标准"""
         csv_file = Path(__file__).parent.parent / "config" / "StandardCategoryTable.csv"
         csv_file = Path(__file__).parent.parent / "config" / "StandardCategoryTable.csv"
-        
+
         if not csv_file.exists():
         if not csv_file.exists():
             raise FileNotFoundError(f"分类标准CSV文件不存在: {csv_file}")
             raise FileNotFoundError(f"分类标准CSV文件不存在: {csv_file}")
-        
+
         # 结构: {first_code: {second_code: {second_cn, second_focus, third_items: [{third_code, third_cn, third_focus}]}}}
         # 结构: {first_code: {second_code: {second_cn, second_focus, third_items: [{third_code, third_cn, third_focus}]}}}
         self.classification_tree: Dict[str, Dict[str, Any]] = {}
         self.classification_tree: Dict[str, Dict[str, Any]] = {}
-        
+
         with csv_file.open("r", encoding="utf-8-sig") as f:
         with csv_file.open("r", encoding="utf-8-sig") as f:
             reader = csv.DictReader(f)
             reader = csv.DictReader(f)
             for row in reader:
             for row in reader:
@@ -68,14 +92,14 @@ class ChunkClassifier:
                 third_code = (row.get("third_code") or "").strip()
                 third_code = (row.get("third_code") or "").strip()
                 third_cn = (row.get("third_name") or "").strip()
                 third_cn = (row.get("third_name") or "").strip()
                 third_focus = (row.get("third_focus") or "").strip()
                 third_focus = (row.get("third_focus") or "").strip()
-                
+
                 if not first_code or not second_code:
                 if not first_code or not second_code:
                     continue
                     continue
-                
+
                 # 初始化一级类别
                 # 初始化一级类别
                 if first_code not in self.classification_tree:
                 if first_code not in self.classification_tree:
                     self.classification_tree[first_code] = {}
                     self.classification_tree[first_code] = {}
-                
+
                 # 初始化二级类别
                 # 初始化二级类别
                 if second_code not in self.classification_tree[first_code]:
                 if second_code not in self.classification_tree[first_code]:
                     self.classification_tree[first_code][second_code] = {
                     self.classification_tree[first_code][second_code] = {
@@ -83,7 +107,7 @@ class ChunkClassifier:
                         "second_focus": second_focus,
                         "second_focus": second_focus,
                         "third_items": []
                         "third_items": []
                     }
                     }
-                
+
                 # 添加三级类别(如果存在)
                 # 添加三级类别(如果存在)
                 if third_code and third_cn:
                 if third_code and third_cn:
                     self.classification_tree[first_code][second_code]["third_items"].append({
                     self.classification_tree[first_code][second_code]["third_items"].append({
@@ -95,102 +119,143 @@ class ChunkClassifier:
     def _build_secondary_standards(self, first_category_code: str) -> tuple[str, dict]:
     def _build_secondary_standards(self, first_category_code: str) -> tuple[str, dict]:
         """
         """
         构建二级分类标准文本
         构建二级分类标准文本
-        
+
         返回:
         返回:
             (标准文本, 索引映射字典)
             (标准文本, 索引映射字典)
         """
         """
         if first_category_code not in self.classification_tree:
         if first_category_code not in self.classification_tree:
             return "(无二级分类标准)", {}
             return "(无二级分类标准)", {}
-        
+
         standards_lines = ["    0. 非标准项 - 不符合以下任何类别"]
         standards_lines = ["    0. 非标准项 - 不符合以下任何类别"]
         index_mapping = {0: ("非标准项", "non_standard")}
         index_mapping = {0: ("非标准项", "non_standard")}
-        
+
         for idx, (second_code, second_data) in enumerate(self.classification_tree[first_category_code].items(), 1):
         for idx, (second_code, second_data) in enumerate(self.classification_tree[first_category_code].items(), 1):
             second_cn = second_data["second_cn"]
             second_cn = second_data["second_cn"]
             second_focus = second_data["second_focus"]
             second_focus = second_data["second_focus"]
-            
+
             # 保存索引映射
             # 保存索引映射
             index_mapping[idx] = (second_cn, second_code)
             index_mapping[idx] = (second_cn, second_code)
-            
+
             if second_focus and second_focus != "NULL":
             if second_focus and second_focus != "NULL":
                 standards_lines.append(f"    {idx}. {second_cn} - 关注点:{second_focus}")
                 standards_lines.append(f"    {idx}. {second_cn} - 关注点:{second_focus}")
             else:
             else:
                 standards_lines.append(f"    {idx}. {second_cn}")
                 standards_lines.append(f"    {idx}. {second_cn}")
-        
+
         return "\n".join(standards_lines) if standards_lines else "(无二级分类标准)", index_mapping
         return "\n".join(standards_lines) if standards_lines else "(无二级分类标准)", index_mapping
 
 
     def _build_tertiary_standards(self, first_category_code: str, second_category_code: str) -> tuple[str, dict]:
     def _build_tertiary_standards(self, first_category_code: str, second_category_code: str) -> tuple[str, dict]:
         """
         """
         构建三级分类标准文本
         构建三级分类标准文本
-        
+
         返回:
         返回:
             (标准文本, 索引映射字典)
             (标准文本, 索引映射字典)
         """
         """
         if first_category_code not in self.classification_tree:
         if first_category_code not in self.classification_tree:
             return "(无三级分类标准)", {}
             return "(无三级分类标准)", {}
-        
+
         if second_category_code not in self.classification_tree[first_category_code]:
         if second_category_code not in self.classification_tree[first_category_code]:
             return "(无三级分类标准)", {}
             return "(无三级分类标准)", {}
-        
+
         third_items = self.classification_tree[first_category_code][second_category_code]["third_items"]
         third_items = self.classification_tree[first_category_code][second_category_code]["third_items"]
-        
+
         if not third_items:
         if not third_items:
             return "(无三级分类标准)", {}
             return "(无三级分类标准)", {}
-        
+
         standards_lines = ["    0. 非标准项 - 不符合以下任何类别"]
         standards_lines = ["    0. 非标准项 - 不符合以下任何类别"]
         index_mapping = {0: ("非标准项", "non_standard")}
         index_mapping = {0: ("非标准项", "non_standard")}
-        
+
         for idx, third_item in enumerate(third_items, 1):
         for idx, third_item in enumerate(third_items, 1):
             third_cn = third_item["third_cn"]
             third_cn = third_item["third_cn"]
             third_code = third_item["third_code"]
             third_code = third_item["third_code"]
             third_focus = third_item["third_focus"]
             third_focus = third_item["third_focus"]
-            
+
             # 保存索引映射
             # 保存索引映射
             index_mapping[idx] = (third_cn, third_code)
             index_mapping[idx] = (third_cn, third_code)
-            
+
             if third_focus and third_focus != "NULL":
             if third_focus and third_focus != "NULL":
                 standards_lines.append(f"    {idx}. {third_cn} - 关注点:{third_focus}")
                 standards_lines.append(f"    {idx}. {third_cn} - 关注点:{third_focus}")
             else:
             else:
                 standards_lines.append(f"    {idx}. {third_cn}")
                 standards_lines.append(f"    {idx}. {third_cn}")
-        
+
         return "\n".join(standards_lines), index_mapping
         return "\n".join(standards_lines), index_mapping
 
 
    async def _call_llm_once(self, system_prompt: str, user_prompt: str) -> Optional[Dict[str, Any]]:
        """Single async LLM call through the shared GenerateModelClient.

        Args:
            system_prompt: system prompt text.
            user_prompt: user prompt text.

        Returns:
            The parsed JSON dict extracted from the model response; a
            ``{"raw_content": ...}`` wrapper when the response contains no
            parsable JSON object; or ``None`` when the call itself fails
            (the caller decides how to handle a failed request).
        """
        try:
            content = await generate_model_client.get_model_generate_invoke(
                trace_id="chunk_classifier",
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                model_name="qwen3_5_122b_a10b",  # larger model chosen for classification accuracy
            )
            result = _extract_json(content)
            # Keep the raw text when no JSON could be extracted so callers can
            # inspect/log the malformed response instead of losing it.
            return result if result is not None else {"raw_content": content}
        except Exception as e:
            # Swallow and signal failure with None — batch callers fall back to
            # the "non_standard" category rather than aborting the whole run.
            logger.error(f"[ChunkClassifier] LLM 调用失败: {e}")
            return None
+
    async def _batch_call_llm(
        self,
        requests: List[tuple],  # [(system_prompt, user_prompt), ...]
    ) -> List[Optional[Dict[str, Any]]]:
        """Concurrently fan out LLM calls, bounded by a semaphore.

        Args:
            requests: request list; each element is a
                (system_prompt, user_prompt) tuple.

        Returns:
            Results in the same order as *requests* (asyncio.gather preserves
            input order); individual failed calls yield ``None``.
        """
        # Cap in-flight requests at the configured concurrency limit
        # (self._concurrency, read from config in __init__).
        semaphore = asyncio.Semaphore(self._concurrency)

        async def bounded_call(system_prompt: str, user_prompt: str):
            async with semaphore:
                return await self._call_llm_once(system_prompt, user_prompt)

        tasks = [bounded_call(sp, up) for sp, up in requests]
        return list(await asyncio.gather(*tasks))
+
     async def classify_chunks_secondary_async(self, chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
     async def classify_chunks_secondary_async(self, chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         """
         """
         异步对chunks进行二级分类
         异步对chunks进行二级分类
-        
+
         参数:
         参数:
             chunks: 已完成一级分类的chunk列表
             chunks: 已完成一级分类的chunk列表
-            
+
         返回:
         返回:
             添加了二级分类字段的chunk列表
             添加了二级分类字段的chunk列表
         """
         """
         logger.info(f"正在对 {len(chunks)} 个内容块进行二级分类...")
         logger.info(f"正在对 {len(chunks)} 个内容块进行二级分类...")
-        
+
         # 准备LLM请求
         # 准备LLM请求
         llm_requests = []
         llm_requests = []
         valid_chunks = []
         valid_chunks = []
         index_mappings = []  # 保存每个请求对应的索引映射
         index_mappings = []  # 保存每个请求对应的索引映射
-        
+
         for chunk in chunks:
         for chunk in chunks:
             first_category_code = chunk.get("chapter_classification", "")
             first_category_code = chunk.get("chapter_classification", "")
             chunk_title = chunk.get("section_label", "")
             chunk_title = chunk.get("section_label", "")
             hierarchy_path = " -> ".join(chunk.get("hierarchy_path", []))
             hierarchy_path = " -> ".join(chunk.get("hierarchy_path", []))
             content = chunk.get("review_chunk_content", "")
             content = chunk.get("review_chunk_content", "")
             content_preview = content[:300] if content else ""
             content_preview = content[:300] if content else ""
-            
+
             # 获取一级分类的中文名称
             # 获取一级分类的中文名称
             first_category_cn = self._get_first_category_cn(first_category_code)
             first_category_cn = self._get_first_category_cn(first_category_code)
-            
+
             # 构建二级分类标准(返回标准文本和索引映射)
             # 构建二级分类标准(返回标准文本和索引映射)
             secondary_standards, index_mapping = self._build_secondary_standards(first_category_code)
             secondary_standards, index_mapping = self._build_secondary_standards(first_category_code)
-            
+
             if secondary_standards == "(无二级分类标准)":
             if secondary_standards == "(无二级分类标准)":
                 # 如果没有二级分类标准,跳过
                 # 如果没有二级分类标准,跳过
                 chunk["secondary_category_cn"] = "无"
                 chunk["secondary_category_cn"] = "无"
                 chunk["secondary_category_code"] = "none"
                 chunk["secondary_category_code"] = "none"
                 continue
                 continue
-            
+
             # 渲染提示词
             # 渲染提示词
             prompt = self.prompt_loader.render(
             prompt = self.prompt_loader.render(
                 "chunk_secondary_classification",
                 "chunk_secondary_classification",
@@ -200,28 +265,23 @@ class ChunkClassifier:
                 content_preview=content_preview,
                 content_preview=content_preview,
                 secondary_standards=secondary_standards
                 secondary_standards=secondary_standards
             )
             )
-            
-            messages = [
-                {"role": "system", "content": prompt["system"]},
-                {"role": "user", "content": prompt["user"]}
-            ]
-            
-            llm_requests.append(messages)
+
+            llm_requests.append((prompt["system"], prompt["user"]))
             valid_chunks.append(chunk)
             valid_chunks.append(chunk)
             index_mappings.append(index_mapping)
             index_mappings.append(index_mapping)
-        
+
         if not llm_requests:
         if not llm_requests:
             logger.info("所有内容块都没有二级分类标准,跳过二级分类")
             logger.info("所有内容块都没有二级分类标准,跳过二级分类")
             return chunks
             return chunks
-        
+
         # 批量异步调用LLM API
         # 批量异步调用LLM API
-        llm_results = await self.llm_client.batch_call_async(llm_requests)
-        
+        llm_results = await self._batch_call_llm(llm_requests)
+
         # 处理分类结果
         # 处理分类结果
         for chunk, llm_result, index_mapping in zip(valid_chunks, llm_results, index_mappings):
         for chunk, llm_result, index_mapping in zip(valid_chunks, llm_results, index_mappings):
             if llm_result and isinstance(llm_result, dict):
             if llm_result and isinstance(llm_result, dict):
                 category_index = llm_result.get("category_index")
                 category_index = llm_result.get("category_index")
-                
+
                 # 验证索引并映射到类别
                 # 验证索引并映射到类别
                 if isinstance(category_index, int) and category_index in index_mapping:
                 if isinstance(category_index, int) and category_index in index_mapping:
                     secondary_cn, secondary_code = index_mapping[category_index]
                     secondary_cn, secondary_code = index_mapping[category_index]
@@ -235,7 +295,7 @@ class ChunkClassifier:
             else:
             else:
                 chunk["secondary_category_cn"] = "非标准项"
                 chunk["secondary_category_cn"] = "非标准项"
                 chunk["secondary_category_code"] = "non_standard"
                 chunk["secondary_category_code"] = "non_standard"
-        
+
         logger.info("二级分类完成!")
         logger.info("二级分类完成!")
         return chunks
         return chunks
 
 
@@ -317,12 +377,12 @@ class ChunkClassifier:
         每个chunk只能属于一个三级分类
         每个chunk只能属于一个三级分类
         """
         """
         logger.info(f"正在对 {len(chunks)} 个内容块进行三级分类...")
         logger.info(f"正在对 {len(chunks)} 个内容块进行三级分类...")
-        
+
         # 准备LLM请求
         # 准备LLM请求
         llm_requests = []
         llm_requests = []
         valid_chunks = []
         valid_chunks = []
         index_mappings = []  # 保存每个请求对应的索引映射
         index_mappings = []  # 保存每个请求对应的索引映射
-        
+
         for chunk in chunks:
         for chunk in chunks:
             first_category_code = chunk.get("chapter_classification", "")
             first_category_code = chunk.get("chapter_classification", "")
             second_category_code = chunk.get("secondary_category_code", "")
             second_category_code = chunk.get("secondary_category_code", "")
@@ -330,19 +390,19 @@ class ChunkClassifier:
             chunk_title = chunk.get("section_label", "")
             chunk_title = chunk.get("section_label", "")
             content = chunk.get("review_chunk_content", "")
             content = chunk.get("review_chunk_content", "")
             content_preview = content[:300] if content else ""
             content_preview = content[:300] if content else ""
-            
+
             # 获取一级分类的中文名称
             # 获取一级分类的中文名称
             first_category_cn = self._get_first_category_cn(first_category_code)
             first_category_cn = self._get_first_category_cn(first_category_code)
-            
+
             # 构建三级分类标准(返回标准文本和索引映射)
             # 构建三级分类标准(返回标准文本和索引映射)
             tertiary_standards, index_mapping = self._build_tertiary_standards(first_category_code, second_category_code)
             tertiary_standards, index_mapping = self._build_tertiary_standards(first_category_code, second_category_code)
-            
+
             if tertiary_standards == "(无三级分类标准)":
             if tertiary_standards == "(无三级分类标准)":
                 # 如果没有三级分类标准,跳过
                 # 如果没有三级分类标准,跳过
                 chunk["tertiary_category_cn"] = "无"
                 chunk["tertiary_category_cn"] = "无"
                 chunk["tertiary_category_code"] = "none"
                 chunk["tertiary_category_code"] = "none"
                 continue
                 continue
-            
+
             # 渲染提示词
             # 渲染提示词
             prompt = self.prompt_loader.render(
             prompt = self.prompt_loader.render(
                 "chunk_tertiary_classification",
                 "chunk_tertiary_classification",
@@ -352,28 +412,23 @@ class ChunkClassifier:
                 content_preview=content_preview,
                 content_preview=content_preview,
                 tertiary_standards=tertiary_standards
                 tertiary_standards=tertiary_standards
             )
             )
-            
-            messages = [
-                {"role": "system", "content": prompt["system"]},
-                {"role": "user", "content": prompt["user"]}
-            ]
-            
-            llm_requests.append(messages)
+
+            llm_requests.append((prompt["system"], prompt["user"]))
             valid_chunks.append(chunk)
             valid_chunks.append(chunk)
             index_mappings.append(index_mapping)
             index_mappings.append(index_mapping)
-        
+
         if not llm_requests:
         if not llm_requests:
             logger.info("所有内容块都没有三级分类标准,跳过三级分类")
             logger.info("所有内容块都没有三级分类标准,跳过三级分类")
             return chunks
             return chunks
-        
+
         # 批量异步调用LLM API
         # 批量异步调用LLM API
-        llm_results = await self.llm_client.batch_call_async(llm_requests)
-        
+        llm_results = await self._batch_call_llm(llm_requests)
+
         # 处理分类结果
         # 处理分类结果
         for chunk, llm_result, index_mapping in zip(valid_chunks, llm_results, index_mappings):
         for chunk, llm_result, index_mapping in zip(valid_chunks, llm_results, index_mappings):
             if llm_result and isinstance(llm_result, dict):
             if llm_result and isinstance(llm_result, dict):
                 category_index = llm_result.get("category_index")
                 category_index = llm_result.get("category_index")
-                
+
                 # 验证索引并映射到类别
                 # 验证索引并映射到类别
                 if isinstance(category_index, int) and category_index in index_mapping:
                 if isinstance(category_index, int) and category_index in index_mapping:
                     tertiary_cn, tertiary_code = index_mapping[category_index]
                     tertiary_cn, tertiary_code = index_mapping[category_index]
@@ -387,7 +442,7 @@ class ChunkClassifier:
             else:
             else:
                 chunk["tertiary_category_cn"] = "非标准项"
                 chunk["tertiary_category_cn"] = "非标准项"
                 chunk["tertiary_category_code"] = "non_standard"
                 chunk["tertiary_category_code"] = "non_standard"
-        
+
         logger.info("三级分类完成!")
         logger.info("三级分类完成!")
         return chunks
         return chunks
 
 
@@ -435,4 +490,4 @@ class ChunkClassifier:
                 classifier_config=classifier_config
                 classifier_config=classifier_config
             ))
             ))
         except RuntimeError:
         except RuntimeError:
-            raise RuntimeError("请使用 await classify_chunks_tertiary_async")
+            raise RuntimeError("请使用 await classify_chunks_tertiary_async")

+ 1 - 0
core/construction_review/component/doc_worker/classification/hierarchy_classifier.py

@@ -65,6 +65,7 @@ class HierarchyClassifier(IHierarchyClassifier):
                 trace_id="hierarchy_classifier",
                 trace_id="hierarchy_classifier",
                 system_prompt=system_prompt,
                 system_prompt=system_prompt,
                 user_prompt=user_prompt,
                 user_prompt=user_prompt,
+                model_name="qwen3_5_122b_a10b",  # 使用 122B 大模型提升分类准确性
             )
             )
             result = _extract_json(content)
             result = _extract_json(content)
             return result if result is not None else {"raw_content": content}
             return result if result is not None else {"raw_content": content}

+ 1 - 1
foundation/ai/agent/generate/model_generate.py

@@ -303,4 +303,4 @@ class GenerateModelClient:
             logger.error(f"[模型流式调用] 异常 trace_id: {trace_id}, 耗时: {elapsed_time:.2f}s, 错误: {type(e).__name__}: {str(e)}")
             logger.error(f"[模型流式调用] 异常 trace_id: {trace_id}, 耗时: {elapsed_time:.2f}s, 错误: {type(e).__name__}: {str(e)}")
             raise
             raise
 
 
-generate_model_client = GenerateModelClient(default_timeout=15, max_retries=2, backoff_factor=0.5)
+generate_model_client = GenerateModelClient(default_timeout=60, max_retries=10, backoff_factor=0.5)