Просмотр исходного кода

Merge branch 'dev_sgsc_wxm' of CRBC-MaaS-Platform-Project/LQAgentPlatform into dev

WangXuMing 2 недели назад
Родитель
Commit
1c21351807

+ 124 - 69
core/construction_review/component/doc_worker/classification/chunk_classifier.py

@@ -8,14 +8,18 @@ from __future__ import annotations
 
 import asyncio
 import csv
+import json
+import re
 from collections import OrderedDict
 from pathlib import Path
 from typing import Any, Dict, List, Optional
 
+from foundation.infrastructure.config.config import config_handler
+from foundation.observability.logger.loggering import review_logger as logger
+from foundation.ai.agent.generate.model_generate import generate_model_client
+
 from ..config.provider import default_config_provider
-from ..utils.llm_client import LLMClient
 from ..utils.prompt_loader import PromptLoader
-from foundation.observability.logger.loggering import review_logger as logger
 
 # 延迟导入新的三级分类器(避免循环导入)
 _LLM_CONTENT_CLASSIFIER = None
@@ -33,30 +37,50 @@ def _get_llm_content_classifier():
     return _LLM_CONTENT_CLASSIFIER
 
 
+def _extract_json(text: str) -> Optional[Dict[str, Any]]:
+    """从字符串中提取第一个有效 JSON 对象"""
+    for pattern in [r"```json\s*(\{.*?})\s*```", r"```\s*(\{.*?})\s*```"]:
+        m = re.search(pattern, text, re.DOTALL)
+        if m:
+            try:
+                return json.loads(m.group(1))
+            except json.JSONDecodeError:
+                pass
+    try:
+        for candidate in re.findall(r"(\{.*?\})", text, re.DOTALL):
+            try:
+                return json.loads(candidate)
+            except json.JSONDecodeError:
+                pass
+    except Exception:
+        pass
+    return None
+
+
 class ChunkClassifier:
     """内容块分类器(二级和三级分类)"""
 
     def __init__(self):
         """初始化分类器"""
         self._cfg = default_config_provider
-        
-        # 初始化LLM客户端和提示词加载器
-        self.llm_client = LLMClient(config_provider=self._cfg)
+        self._concurrency = int(config_handler.get("llm_keywords", "CONCURRENT_WORKERS", "20"))
+
+        # 初始化提示词加载器
         self.prompt_loader = PromptLoader()
-        
+
         # 加载CSV分类标准
         self._load_classification_standards()
 
     def _load_classification_standards(self):
         """从CSV文件加载二级和三级分类标准"""
         csv_file = Path(__file__).parent.parent / "config" / "StandardCategoryTable.csv"
-        
+
         if not csv_file.exists():
             raise FileNotFoundError(f"分类标准CSV文件不存在: {csv_file}")
-        
+
         # 结构: {first_code: {second_code: {second_cn, second_focus, third_items: [{third_code, third_cn, third_focus}]}}}
         self.classification_tree: Dict[str, Dict[str, Any]] = {}
-        
+
         with csv_file.open("r", encoding="utf-8-sig") as f:
             reader = csv.DictReader(f)
             for row in reader:
@@ -68,14 +92,14 @@ class ChunkClassifier:
                 third_code = (row.get("third_code") or "").strip()
                 third_cn = (row.get("third_name") or "").strip()
                 third_focus = (row.get("third_focus") or "").strip()
-                
+
                 if not first_code or not second_code:
                     continue
-                
+
                 # 初始化一级类别
                 if first_code not in self.classification_tree:
                     self.classification_tree[first_code] = {}
-                
+
                 # 初始化二级类别
                 if second_code not in self.classification_tree[first_code]:
                     self.classification_tree[first_code][second_code] = {
@@ -83,7 +107,7 @@ class ChunkClassifier:
                         "second_focus": second_focus,
                         "third_items": []
                     }
-                
+
                 # 添加三级类别(如果存在)
                 if third_code and third_cn:
                     self.classification_tree[first_code][second_code]["third_items"].append({
@@ -95,102 +119,143 @@ class ChunkClassifier:
     def _build_secondary_standards(self, first_category_code: str) -> tuple[str, dict]:
         """
         构建二级分类标准文本
-        
+
         返回:
             (标准文本, 索引映射字典)
         """
         if first_category_code not in self.classification_tree:
             return "(无二级分类标准)", {}
-        
+
         standards_lines = ["    0. 非标准项 - 不符合以下任何类别"]
         index_mapping = {0: ("非标准项", "non_standard")}
-        
+
         for idx, (second_code, second_data) in enumerate(self.classification_tree[first_category_code].items(), 1):
             second_cn = second_data["second_cn"]
             second_focus = second_data["second_focus"]
-            
+
             # 保存索引映射
             index_mapping[idx] = (second_cn, second_code)
-            
+
             if second_focus and second_focus != "NULL":
                 standards_lines.append(f"    {idx}. {second_cn} - 关注点:{second_focus}")
             else:
                 standards_lines.append(f"    {idx}. {second_cn}")
-        
+
         return "\n".join(standards_lines) if standards_lines else "(无二级分类标准)", index_mapping
 
     def _build_tertiary_standards(self, first_category_code: str, second_category_code: str) -> tuple[str, dict]:
         """
         构建三级分类标准文本
-        
+
         返回:
             (标准文本, 索引映射字典)
         """
         if first_category_code not in self.classification_tree:
             return "(无三级分类标准)", {}
-        
+
         if second_category_code not in self.classification_tree[first_category_code]:
             return "(无三级分类标准)", {}
-        
+
         third_items = self.classification_tree[first_category_code][second_category_code]["third_items"]
-        
+
         if not third_items:
             return "(无三级分类标准)", {}
-        
+
         standards_lines = ["    0. 非标准项 - 不符合以下任何类别"]
         index_mapping = {0: ("非标准项", "non_standard")}
-        
+
         for idx, third_item in enumerate(third_items, 1):
             third_cn = third_item["third_cn"]
             third_code = third_item["third_code"]
             third_focus = third_item["third_focus"]
-            
+
             # 保存索引映射
             index_mapping[idx] = (third_cn, third_code)
-            
+
             if third_focus and third_focus != "NULL":
                 standards_lines.append(f"    {idx}. {third_cn} - 关注点:{third_focus}")
             else:
                 standards_lines.append(f"    {idx}. {third_cn}")
-        
+
         return "\n".join(standards_lines), index_mapping
 
+    async def _call_llm_once(self, system_prompt: str, user_prompt: str) -> Optional[Dict[str, Any]]:
+        """
+        单次异步 LLM 调用(使用统一的 GenerateModelClient)
+
+        失败返回 None,由调用方决定处理逻辑
+        """
+        try:
+            content = await generate_model_client.get_model_generate_invoke(
+                trace_id="chunk_classifier",
+                system_prompt=system_prompt,
+                user_prompt=user_prompt,
+                model_name="qwen3_5_122b_a10b",  # 使用 122B 大模型提升分类准确性
+            )
+            result = _extract_json(content)
+            return result if result is not None else {"raw_content": content}
+        except Exception as e:
+            logger.error(f"[ChunkClassifier] LLM 调用失败: {e}")
+            return None
+
+    async def _batch_call_llm(
+        self,
+        requests: List[tuple],  # [(system_prompt, user_prompt), ...]
+    ) -> List[Optional[Dict[str, Any]]]:
+        """
+        并发批量调用 LLM(带信号量控制)
+
+        参数:
+            requests: 请求列表,每个元素是 (system_prompt, user_prompt) 元组
+
+        返回:
+            结果列表,与输入请求一一对应
+        """
+        semaphore = asyncio.Semaphore(self._concurrency)
+
+        async def bounded_call(system_prompt: str, user_prompt: str):
+            async with semaphore:
+                return await self._call_llm_once(system_prompt, user_prompt)
+
+        tasks = [bounded_call(sp, up) for sp, up in requests]
+        return list(await asyncio.gather(*tasks))
+
     async def classify_chunks_secondary_async(self, chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         """
         异步对chunks进行二级分类
-        
+
         参数:
             chunks: 已完成一级分类的chunk列表
-            
+
         返回:
             添加了二级分类字段的chunk列表
         """
         logger.info(f"正在对 {len(chunks)} 个内容块进行二级分类...")
-        
+
         # 准备LLM请求
         llm_requests = []
         valid_chunks = []
         index_mappings = []  # 保存每个请求对应的索引映射
-        
+
         for chunk in chunks:
             first_category_code = chunk.get("chapter_classification", "")
             chunk_title = chunk.get("section_label", "")
             hierarchy_path = " -> ".join(chunk.get("hierarchy_path", []))
             content = chunk.get("review_chunk_content", "")
             content_preview = content[:300] if content else ""
-            
+
             # 获取一级分类的中文名称
             first_category_cn = self._get_first_category_cn(first_category_code)
-            
+
             # 构建二级分类标准(返回标准文本和索引映射)
             secondary_standards, index_mapping = self._build_secondary_standards(first_category_code)
-            
+
             if secondary_standards == "(无二级分类标准)":
                 # 如果没有二级分类标准,跳过
                 chunk["secondary_category_cn"] = "无"
                 chunk["secondary_category_code"] = "none"
                 continue
-            
+
             # 渲染提示词
             prompt = self.prompt_loader.render(
                 "chunk_secondary_classification",
@@ -200,28 +265,23 @@ class ChunkClassifier:
                 content_preview=content_preview,
                 secondary_standards=secondary_standards
             )
-            
-            messages = [
-                {"role": "system", "content": prompt["system"]},
-                {"role": "user", "content": prompt["user"]}
-            ]
-            
-            llm_requests.append(messages)
+
+            llm_requests.append((prompt["system"], prompt["user"]))
             valid_chunks.append(chunk)
             index_mappings.append(index_mapping)
-        
+
         if not llm_requests:
             logger.info("所有内容块都没有二级分类标准,跳过二级分类")
             return chunks
-        
+
         # 批量异步调用LLM API
-        llm_results = await self.llm_client.batch_call_async(llm_requests)
-        
+        llm_results = await self._batch_call_llm(llm_requests)
+
         # 处理分类结果
         for chunk, llm_result, index_mapping in zip(valid_chunks, llm_results, index_mappings):
             if llm_result and isinstance(llm_result, dict):
                 category_index = llm_result.get("category_index")
-                
+
                 # 验证索引并映射到类别
                 if isinstance(category_index, int) and category_index in index_mapping:
                     secondary_cn, secondary_code = index_mapping[category_index]
@@ -235,7 +295,7 @@ class ChunkClassifier:
             else:
                 chunk["secondary_category_cn"] = "非标准项"
                 chunk["secondary_category_code"] = "non_standard"
-        
+
         logger.info("二级分类完成!")
         return chunks
 
@@ -317,12 +377,12 @@ class ChunkClassifier:
         每个chunk只能属于一个三级分类
         """
         logger.info(f"正在对 {len(chunks)} 个内容块进行三级分类...")
-        
+
         # 准备LLM请求
         llm_requests = []
         valid_chunks = []
         index_mappings = []  # 保存每个请求对应的索引映射
-        
+
         for chunk in chunks:
             first_category_code = chunk.get("chapter_classification", "")
             second_category_code = chunk.get("secondary_category_code", "")
@@ -330,19 +390,19 @@ class ChunkClassifier:
             chunk_title = chunk.get("section_label", "")
             content = chunk.get("review_chunk_content", "")
             content_preview = content[:300] if content else ""
-            
+
             # 获取一级分类的中文名称
             first_category_cn = self._get_first_category_cn(first_category_code)
-            
+
             # 构建三级分类标准(返回标准文本和索引映射)
             tertiary_standards, index_mapping = self._build_tertiary_standards(first_category_code, second_category_code)
-            
+
             if tertiary_standards == "(无三级分类标准)":
                 # 如果没有三级分类标准,跳过
                 chunk["tertiary_category_cn"] = "无"
                 chunk["tertiary_category_code"] = "none"
                 continue
-            
+
             # 渲染提示词
             prompt = self.prompt_loader.render(
                 "chunk_tertiary_classification",
@@ -352,28 +412,23 @@ class ChunkClassifier:
                 content_preview=content_preview,
                 tertiary_standards=tertiary_standards
             )
-            
-            messages = [
-                {"role": "system", "content": prompt["system"]},
-                {"role": "user", "content": prompt["user"]}
-            ]
-            
-            llm_requests.append(messages)
+
+            llm_requests.append((prompt["system"], prompt["user"]))
             valid_chunks.append(chunk)
             index_mappings.append(index_mapping)
-        
+
         if not llm_requests:
             logger.info("所有内容块都没有三级分类标准,跳过三级分类")
             return chunks
-        
+
         # 批量异步调用LLM API
-        llm_results = await self.llm_client.batch_call_async(llm_requests)
-        
+        llm_results = await self._batch_call_llm(llm_requests)
+
         # 处理分类结果
         for chunk, llm_result, index_mapping in zip(valid_chunks, llm_results, index_mappings):
             if llm_result and isinstance(llm_result, dict):
                 category_index = llm_result.get("category_index")
-                
+
                 # 验证索引并映射到类别
                 if isinstance(category_index, int) and category_index in index_mapping:
                     tertiary_cn, tertiary_code = index_mapping[category_index]
@@ -387,7 +442,7 @@ class ChunkClassifier:
             else:
                 chunk["tertiary_category_cn"] = "非标准项"
                 chunk["tertiary_category_code"] = "non_standard"
-        
+
         logger.info("三级分类完成!")
         return chunks
 
@@ -435,4 +490,4 @@ class ChunkClassifier:
                 classifier_config=classifier_config
             ))
         except RuntimeError:
-            raise RuntimeError("请使用 await classify_chunks_tertiary_async")
+            raise RuntimeError("请使用 await classify_chunks_tertiary_async")

+ 1 - 0
core/construction_review/component/doc_worker/classification/hierarchy_classifier.py

@@ -65,6 +65,7 @@ class HierarchyClassifier(IHierarchyClassifier):
                 trace_id="hierarchy_classifier",
                 system_prompt=system_prompt,
                 user_prompt=user_prompt,
+                model_name="qwen3_5_122b_a10b",  # 使用 122B 大模型提升分类准确性
             )
             result = _extract_json(content)
             return result if result is not None else {"raw_content": content}

+ 1 - 1
foundation/ai/agent/generate/model_generate.py

@@ -303,4 +303,4 @@ class GenerateModelClient:
             logger.error(f"[模型流式调用] 异常 trace_id: {trace_id}, 耗时: {elapsed_time:.2f}s, 错误: {type(e).__name__}: {str(e)}")
             raise
 
-generate_model_client = GenerateModelClient(default_timeout=15, max_retries=2, backoff_factor=0.5)
+generate_model_client = GenerateModelClient(default_timeout=60, max_retries=10, backoff_factor=0.5)