浏览代码

fix: 收敛思考内容过滤到模型调用层,修复云端思考块泄漏

云端编写功能因 default.enable_thinking 误设为 true 且未同步本地修复,
导致输出夹带 <think> 块。将过滤逻辑下沉到 model_generate 基础层,
async/sync/stream 三种调用路径自动剥离思考内容;流式过滤实现 chunk
边界缝合,处理 <think>/</think> 标签被切分的情况。

- model_generate: 新增 _strip_thinking_content 与 _ThinkingBlockStreamFilter
- catalog_reviewer: 移除冗余的本地思考块处理
- model_setting: default.enable_thinking 调回 false 规避未匹配场景
- 新增 21 项单测覆盖过滤逻辑、流式边界、50 并发混合模式无串话

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
WangXuMing 1 月之前
父节点
当前提交
71f87de9b4

+ 1 - 1
config/model_setting.yaml

@@ -150,4 +150,4 @@ model_settings:
 # 默认配置(当功能未指定时使用)
 default:
   model: shutian_qwen3_5_122b
-  enable_thinking: true
+  enable_thinking: false

+ 1 - 9
core/construction_review/component/minimal_pipeline/catalog_reviewer.py

@@ -374,7 +374,7 @@ check_result 中必须包含以下字段:
 """
 
     def _extract_json(self, content: str) -> Optional[Dict[str, Any]]:
-        """从LLM响应中提取JSON,增强健壮性(支持思考模式输出)"""
+        """从LLM响应中提取JSON,增强健壮性"""
         try:
             # 清理内容:移除 markdown 代码块标记
             content = content.strip()
@@ -385,14 +385,6 @@ check_result 中必须包含以下字段:
             content = re.sub(r'\s*```\s*$', '', content, flags=re.MULTILINE)
             content = re.sub(r'^```\s*', '', content, flags=re.MULTILINE)
 
-            # 处理思考模式输出:跳过思考部分,提取最终答案
-            # 检查 <think>...</think> 标签 (Qwen3.5 思考模式标准格式)
-            think_end = content.find("</think>")
-            if think_end != -1:
-                # 提取 </think> 之后的内容
-                content = content[think_end + len("</think>"):].strip()
-                logger.debug(f"[CatalogReviewer] 检测到 <think> 标签,从 </think> 后提取内容,长度: {len(content)}")
-
             # 找到第一个 { 开始的位置
             json_start = content.find('{')
             if json_start == -1:

+ 113 - 3
foundation/ai/agent/generate/model_generate.py

@@ -13,10 +13,114 @@ from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage
 from foundation.ai.models.model_handler import model_handler
 from foundation.observability.logger.loggering import review_logger as logger
 import asyncio
+import re
 import time
 from typing import Optional, Callable, Any, List, Union
 
 
+# ============================================================
+# 思考内容过滤(统一收敛在调用层)
+#
+# Qwen3.5 等模型在 enable_thinking=True 时会先输出 <think>...</think>
+# 块再给出最终答案。所有业务方都不需要思考过程,统一在此处去除,
+# 避免每个调用点重复实现,也防止漏处理导致思考内容污染输出。
+# ============================================================
+_THINK_BLOCK_PATTERN = re.compile(r"<think>.*?</think>\s*", re.DOTALL)
+_DANGLING_THINK_PATTERN = re.compile(r"<think>[\s\S]*$")
+
+
+def _strip_thinking_content(content: str) -> str:
+    """去除完整响应中的 <think>...</think> 块。
+
+    - 完整闭合块:整段去除(含尾随空白)
+    - 仅 <think> 无 </think>(被截断):从 <think> 起全部丢弃,记录警告
+    - 不含思考标签:原文返回
+    """
+    if not content:
+        return content
+    cleaned = _THINK_BLOCK_PATTERN.sub("", content)
+    if "<think>" in cleaned:
+        cleaned = _DANGLING_THINK_PATTERN.sub("", cleaned)
+        logger.warning("[模型调用] 响应包含未闭合的 <think> 块,已截断丢弃")
+    return cleaned.strip()
+
+
+class _ThinkingBlockStreamFilter:
+    """流式响应中过滤 <think>...</think> 块的状态机。
+
+    处理 chunk 边界穿过标签的情况(如先收到 "<thi"、下次再到 "nk>正文"),
+    保证调用方拿到的流不会泄漏任何思考片段。
+
+    用法:
+        flt = _ThinkingBlockStreamFilter()
+        for chunk in stream:
+            cleaned = flt.feed(chunk)
+            if cleaned:
+                yield cleaned
+        tail = flt.flush()
+        if tail:
+            yield tail
+    """
+
+    _OPEN = "<think>"
+    _CLOSE = "</think>"
+
+    def __init__(self):
+        self._buf = ""
+        self._inside = False
+
+    def feed(self, chunk: str) -> str:
+        """喂入一个 chunk,返回此刻应输出的内容(可能为空字符串)。"""
+        if not chunk:
+            return ""
+        self._buf += chunk
+        out = []
+        while True:
+            if self._inside:
+                idx = self._buf.find(self._CLOSE)
+                if idx == -1:
+                    keep_len = self._partial_match_len(self._buf, self._CLOSE)
+                    self._buf = self._buf[-keep_len:] if keep_len else ""
+                    break
+                self._buf = self._buf[idx + len(self._CLOSE):].lstrip()
+                self._inside = False
+            else:
+                idx = self._buf.find(self._OPEN)
+                if idx == -1:
+                    keep_len = self._partial_match_len(self._buf, self._OPEN)
+                    if keep_len:
+                        out.append(self._buf[:-keep_len])
+                        self._buf = self._buf[-keep_len:]
+                    else:
+                        out.append(self._buf)
+                        self._buf = ""
+                    break
+                if idx > 0:
+                    out.append(self._buf[:idx])
+                self._buf = self._buf[idx + len(self._OPEN):]
+                self._inside = True
+        return "".join(out)
+
+    def flush(self) -> str:
+        """流结束时调用,返回缓冲区剩余可输出内容。"""
+        if self._inside:
+            logger.warning("[模型流式调用] 流结束时仍在 <think> 块内,已丢弃尾部")
+            self._buf = ""
+            return ""
+        result = self._buf
+        self._buf = ""
+        return result
+
+    @staticmethod
+    def _partial_match_len(buf: str, tag: str) -> int:
+        """返回 buf 末尾匹配 tag 前缀的最大长度(避免标签被切断后误输出)。"""
+        max_n = min(len(tag) - 1, len(buf))
+        for n in range(max_n, 0, -1):
+            if buf[-n:] == tag[:n]:
+                return n
+        return 0
+
+
 def _sync_retry_with_backoff(
     func: Callable,
     *args,
@@ -260,7 +364,7 @@ class GenerateModelClient:
 
             elapsed_time = time.time() - start_time
             logger.info(f"[模型调用] 成功 trace_id: {trace_id}, 耗时: {elapsed_time:.2f}s")
-            return response.content
+            return _strip_thinking_content(response.content)
 
         except asyncio.TimeoutError:
             elapsed_time = time.time() - start_time
@@ -452,7 +556,7 @@ class GenerateModelClient:
 
             elapsed_time = time.time() - start_time
             logger.info(f"[模型调用-同步] 成功 trace_id: {trace_id}, 耗时: {elapsed_time:.2f}s")
-            return response.content
+            return _strip_thinking_content(response.content)
 
         except Exception as e:
             elapsed_time = time.time() - start_time
@@ -538,12 +642,18 @@ class GenerateModelClient:
             response = llm_to_use.stream(final_messages)
 
             chunk_count = 0
+            think_filter = _ThinkingBlockStreamFilter()
             for chunk in response:
                 chunk_count += 1
                 if hasattr(chunk, 'content') and chunk.content:
-                    yield chunk.content
+                    cleaned = think_filter.feed(chunk.content)
+                    if cleaned:
+                        yield cleaned
                 elif chunk:
                     yield chunk
+            tail = think_filter.flush()
+            if tail:
+                yield tail
 
             elapsed_time = time.time() - start_time
             logger.info(f"[模型流式调用] 成功 trace_id: {trace_id}, 生成块数: {chunk_count}, 耗时: {elapsed_time:.2f}s")

+ 391 - 0
utils_test/Model_Test/test_thinking_filter_and_concurrency.py

@@ -0,0 +1,391 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+单元测试:思考内容过滤 & 并发安全验证
+
+无网络依赖,使用 stub + mock 模拟模型调用,覆盖以下能力:
+1. _strip_thinking_content:完整内容过滤的各种边界
+2. _ThinkingBlockStreamFilter:流式过滤(含 chunk 边界穿过 <think>/</think>)
+3. GenerateModelClient 端到端:function_name → 思考开关 → 返回过滤
+4. 并发场景:50 个混合模式请求,验证缓存实例不被污染、内容不串话
+
+运行:
+    python utils_test/Model_Test/test_thinking_filter_and_concurrency.py
+"""
+
+import sys
+import asyncio
+import types
+import importlib.util
+import unittest
+from pathlib import Path
+from typing import List, Tuple
+from unittest.mock import MagicMock
+
+
+PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
+
+
+# ============================================================
+# Step 1: 注入 stub 依赖(本地环境无 langchain/yaml 等包)
+# ============================================================
+
+def _make_module(name: str, **attrs) -> types.ModuleType:
+    mod = types.ModuleType(name)
+    for k, v in attrs.items():
+        setattr(mod, k, v)
+    sys.modules[name] = mod
+    return mod
+
+
+_make_module("langchain_core")
+_make_module("langchain_core.prompts",
+             ChatPromptTemplate=type("ChatPromptTemplate", (), {}))
+
+
+class BaseMessage:
+    def __init__(self, content=""):
+        self.content = content
+
+
+class SystemMessage(BaseMessage):
+    pass
+
+
+class HumanMessage(BaseMessage):
+    pass
+
+
+_make_module("langchain_core.messages",
+             BaseMessage=BaseMessage,
+             SystemMessage=SystemMessage,
+             HumanMessage=HumanMessage)
+
+_make_module("foundation")
+_make_module("foundation.ai")
+_make_module("foundation.ai.models")
+_make_module("foundation.ai.agent")
+_make_module("foundation.ai.agent.generate")
+_make_module("foundation.observability")
+_make_module("foundation.observability.logger")
+
+
+class _SilentLogger:
+    def info(self, *a, **k): pass
+    def debug(self, *a, **k): pass
+    def warning(self, *a, **k): pass
+    def error(self, *a, **k): pass
+
+
+_make_module("foundation.observability.logger.loggering",
+             review_logger=_SilentLogger())
+
+
+# Mock LangChain ChatOpenAI:bind() 返回新实例(与真实 RunnableBinding 一致),
+# ainvoke/stream 根据当前 extra_body 决定是否吐 <think> 块(模拟蜀天后端默认开启思考)
+class MockChatOpenAI:
+    def __init__(self, model_name="mock", extra_body=None):
+        self.model_name = model_name
+        self.extra_body = dict(extra_body or {})
+
+    def bind(self, **kwargs):
+        new_eb = dict(self.extra_body)
+        if "extra_body" in kwargs:
+            new_eb.update(kwargs["extra_body"])
+        return MockChatOpenAI(self.model_name, extra_body=new_eb)
+
+    def _is_thinking_on(self) -> bool:
+        ck = self.extra_body.get("chat_template_kwargs", {})
+        return ck.get("enable_thinking", True)  # 模拟服务端默认 True
+
+    async def ainvoke(self, messages, **kwargs):
+        if self._is_thinking_on():
+            content = (
+                "<think>\n推理过程:分析问题 X 的步骤...\n经过推导...\n</think>\n\n"
+                f"答案:模型 {self.model_name} 处理 {len(messages)} 条消息"
+            )
+        else:
+            content = f"答案:模型 {self.model_name} 处理 {len(messages)} 条消息"
+        await asyncio.sleep(0.01)  # 制造并发竞态机会
+        return MagicMock(content=content)
+
+    def invoke(self, messages, **kwargs):
+        if self._is_thinking_on():
+            content = (
+                "<think>推理过程</think>\n\n"
+                f"同步答案:{self.model_name}"
+            )
+        else:
+            content = f"同步答案:{self.model_name}"
+        return MagicMock(content=content)
+
+    def stream(self, messages, **kwargs):
+        if self._is_thinking_on():
+            full = "<think>\n推理过程\n</think>\n\n这是流式答案"
+        else:
+            full = "这是流式答案"
+        # 故意切到 4 字符一组,逼出 chunk 边界穿越 <think>/</think> 标签
+        for i in range(0, len(full), 4):
+            chunk = MagicMock()
+            chunk.content = full[i:i + 4]
+            yield chunk
+
+
+class MockModelHandler:
+    def __init__(self):
+        self._cache = {}
+
+    def get_model_by_name(self, model_type=None):
+        key = model_type or "default"
+        if key not in self._cache:
+            self._cache[key] = MockChatOpenAI(model_name=key)
+        return self._cache[key]
+
+    def get_models(self):
+        return self.get_model_by_name("default")
+
+
+mock_handler = MockModelHandler()
+_make_module("foundation.ai.models.model_handler",
+             model_handler=mock_handler,
+             ModelHandler=MockModelHandler,
+             get_models=lambda: mock_handler.get_models())
+
+# 配置 stub:模拟 model_setting.yaml 的关键映射
+_FAKE_THINKING = {
+    "outline_chapter_revise": False,        # 编写:非思考
+    "catalog_integrity_review": True,       # 目录审查:思考
+    "doc_classification_tertiary": False,
+    "default": True,                        # 危险默认值(云端真实情况)
+}
+_FAKE_MODEL = {
+    "outline_chapter_revise": "shutian_qwen3_5_122b",
+    "catalog_integrity_review": "shutian_qwen3_5_122b",
+    "doc_classification_tertiary": "shutian_qwen3_5_122b",
+    "default": "shutian_qwen3_5_122b",
+}
+_make_module("foundation.ai.models.model_config_loader",
+             get_model_for_function=lambda n: _FAKE_MODEL.get(n, _FAKE_MODEL["default"]),
+             get_thinking_mode_for_function=lambda n: _FAKE_THINKING.get(n, _FAKE_THINKING["default"]))
+
+
+# ============================================================
+# Step 2: 加载 model_generate.py(被测对象)
+# ============================================================
+_target = PROJECT_ROOT / "foundation" / "ai" / "agent" / "generate" / "model_generate.py"
+_spec = importlib.util.spec_from_file_location("mg_under_test", _target)
+mg = importlib.util.module_from_spec(_spec)
+_spec.loader.exec_module(mg)
+
+
+# ============================================================
+# Step 3: 测试用例
+# ============================================================
+
+class TestStripThinkingContent(unittest.TestCase):
+    """_strip_thinking_content 单元测试"""
+
+    def test_complete_block(self):
+        s = "<think>推理过程</think>\n\n最终答案"
+        self.assertEqual(mg._strip_thinking_content(s), "最终答案")
+
+    def test_multiple_blocks(self):
+        s = "<think>think1</think>段A<think>think2</think>段B"
+        self.assertEqual(mg._strip_thinking_content(s), "段A段B")
+
+    def test_dangling_block(self):
+        s = "正文段\n<think>推理被截断"
+        self.assertEqual(mg._strip_thinking_content(s), "正文段")
+
+    def test_no_thinking(self):
+        self.assertEqual(mg._strip_thinking_content("纯回答内容"), "纯回答内容")
+
+    def test_empty_and_none(self):
+        self.assertEqual(mg._strip_thinking_content(""), "")
+        self.assertIsNone(mg._strip_thinking_content(None))
+
+    def test_multiline_block(self):
+        s = "<think>\n第一行\n第二行\n第三行\n</think>\n\n答案"
+        self.assertEqual(mg._strip_thinking_content(s), "答案")
+
+    def test_block_at_end(self):
+        s = "答案先行<think>反思</think>"
+        self.assertEqual(mg._strip_thinking_content(s), "答案先行")
+
+
+class TestStreamFilter(unittest.TestCase):
+    """_ThinkingBlockStreamFilter 流式过滤测试"""
+
+    def _drive(self, chunks: List[str]) -> str:
+        flt = mg._ThinkingBlockStreamFilter()
+        out = []
+        for c in chunks:
+            r = flt.feed(c)
+            if r:
+                out.append(r)
+        tail = flt.flush()
+        if tail:
+            out.append(tail)
+        return "".join(out)
+
+    def test_single_chunk_with_block(self):
+        self.assertEqual(self._drive(["<think>x</think>正文"]), "正文")
+
+    def test_split_open_tag(self):
+        # chunk 边界切到 <thi|nk>
+        self.assertEqual(self._drive(["<thi", "nk>推理</think>正文"]), "正文")
+
+    def test_split_close_tag(self):
+        # chunk 边界切到 </thi|nk>
+        self.assertEqual(self._drive(["<think>推理</thi", "nk>正文"]), "正文")
+
+    def test_split_in_middle_of_block(self):
+        chunks = ["<think>", "推理1", "推理2", "</think>", "答案"]
+        self.assertEqual(self._drive(chunks), "答案")
+
+    def test_no_thinking_passes_through(self):
+        self.assertEqual(self._drive(["普通", "答案", "内容"]), "普通答案内容")
+
+    def test_dangling_flush_drops(self):
+        flt = mg._ThinkingBlockStreamFilter()
+        first = flt.feed("正文")          # 应输出 "正文"
+        second = flt.feed("<think>未完成")  # 进入 think 内
+        tail = flt.flush()                # 在 think 内 → 丢弃
+        self.assertEqual(first, "正文")
+        self.assertEqual(second, "")
+        self.assertEqual(tail, "")
+
+    def test_multiple_blocks_streamed(self):
+        s = "<think>a</think>段1<think>b</think>段2"
+        self.assertEqual(self._drive([s]), "段1段2")
+
+    def test_realistic_3char_chunks(self):
+        full = "<think>\n推理过程\n</think>\n\n这是答案"
+        chunks = [full[i:i + 3] for i in range(0, len(full), 3)]
+        result = self._drive(chunks)
+        self.assertIn("这是答案", result)
+        self.assertNotIn("<think>", result)
+        self.assertNotIn("</think>", result)
+        self.assertNotIn("推理过程", result)
+
+
+class TestEndToEndInvoke(unittest.IsolatedAsyncioTestCase):
+    """通过 mock 验证 GenerateModelClient 端到端"""
+
+    async def test_thinking_off_via_function_name(self):
+        client = mg.GenerateModelClient()
+        result = await client.get_model_generate_invoke(
+            trace_id="t1",
+            prompt="2+2",
+            function_name="outline_chapter_revise",  # config: false
+        )
+        self.assertNotIn("<think>", result)
+        self.assertNotIn("</think>", result)
+        self.assertNotIn("推理过程", result)
+        self.assertIn("答案", result)
+
+    async def test_thinking_on_block_stripped(self):
+        client = mg.GenerateModelClient()
+        result = await client.get_model_generate_invoke(
+            trace_id="t2",
+            prompt="解释勾股定理",
+            function_name="catalog_integrity_review",  # config: true
+        )
+        # 即使开启思考,底层也会过滤
+        self.assertNotIn("<think>", result)
+        self.assertNotIn("</think>", result)
+        self.assertNotIn("推理过程", result)
+        self.assertIn("答案", result)
+
+    async def test_default_config_thinking_still_filtered(self):
+        """函数名未匹配 → 走 default(true),但仍然被过滤"""
+        client = mg.GenerateModelClient()
+        result = await client.get_model_generate_invoke(
+            trace_id="t3",
+            prompt="X",
+            function_name="not_in_yaml_xxx",  # 未匹配 → 走 default
+        )
+        self.assertNotIn("<think>", result)
+        self.assertIn("答案", result)
+
+
+class TestStreamEndToEnd(unittest.TestCase):
+    """流式调用端到端"""
+
+    def test_stream_thinking_on_filtered(self):
+        client = mg.GenerateModelClient()
+        chunks = list(client.get_model_generate_stream(
+            trace_id="s1",
+            prompt="X",
+            function_name="catalog_integrity_review",
+        ))
+        joined = "".join(c if isinstance(c, str) else "" for c in chunks)
+        self.assertNotIn("<think>", joined)
+        self.assertNotIn("</think>", joined)
+        self.assertNotIn("推理过程", joined)
+        self.assertIn("流式答案", joined)
+
+
+class TestConcurrentNoCrosstalk(unittest.IsolatedAsyncioTestCase):
+    """并发场景:思考/非思考混合不串话,缓存实例不被污染"""
+
+    async def test_50_concurrent_mixed_modes(self):
+        client = mg.GenerateModelClient()
+        N = 50
+
+        async def run(i: int) -> Tuple[int, str, str]:
+            fn = "outline_chapter_revise" if i % 2 == 0 else "catalog_integrity_review"
+            r = await client.get_model_generate_invoke(
+                trace_id=f"concur-{i}",
+                prompt=f"请求 {i}",
+                function_name=fn,
+            )
+            return (i, fn, r)
+
+        results = await asyncio.gather(*(run(i) for i in range(N)))
+
+        leaks = [(i, fn) for i, fn, r in results
+                 if "<think>" in r or "</think>" in r or "推理过程" in r]
+        self.assertEqual(leaks, [], f"思考内容泄漏到 {len(leaks)} 个返回中: {leaks[:3]}")
+        self.assertEqual(len(results), N)
+        no_answer = [(i, r) for i, fn, r in results if "答案" not in r]
+        self.assertEqual(no_answer, [], f"个别结果丢失答案: {no_answer[:3]}")
+
+    async def test_cached_instance_not_polluted(self):
+        """关键:bind() 不能污染 model_handler 缓存的 ChatOpenAI 实例"""
+        client = mg.GenerateModelClient()
+        cached = mock_handler.get_model_by_name("shutian_qwen3_5_122b")
+        # 跑 30 个混合并发
+        await asyncio.gather(*(
+            client.get_model_generate_invoke(
+                trace_id=f"iso-{i}",
+                prompt="x",
+                function_name="catalog_integrity_review" if i % 2 else "outline_chapter_revise"
+            ) for i in range(30)
+        ))
+        # cached 实例的 extra_body 应当始终为空 —— bind 走的是 RunnableBinding 副本
+        self.assertEqual(
+            cached.extra_body, {},
+            f"缓存实例被污染,extra_body={cached.extra_body}(这意味着并发请求会串话)"
+        )
+
+
+# ============================================================
+# 入口
+# ============================================================
+def run_all():
+    loader = unittest.TestLoader()
+    suite = unittest.TestSuite()
+    for cls in (TestStripThinkingContent,
+                TestStreamFilter,
+                TestEndToEndInvoke,
+                TestStreamEndToEnd,
+                TestConcurrentNoCrosstalk):
+        suite.addTests(loader.loadTestsFromTestCase(cls))
+    runner = unittest.TextTestRunner(verbosity=2)
+    result = runner.run(suite)
+    return 0 if result.wasSuccessful() else 1
+
+
+if __name__ == "__main__":
+    sys.exit(run_all())