|
|
@@ -0,0 +1,391 @@
|
|
|
+#!/usr/bin/env python
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
+"""
|
|
|
+单元测试:思考内容过滤 & 并发安全验证
|
|
|
+
|
|
|
+无网络依赖,使用 stub + mock 模拟模型调用,覆盖以下能力:
|
|
|
+1. _strip_thinking_content:完整内容过滤的各种边界
|
|
|
+2. _ThinkingBlockStreamFilter:流式过滤(含 chunk 边界穿过 <think>/</think>)
|
|
|
+3. GenerateModelClient 端到端:function_name → 思考开关 → 返回过滤
|
|
|
+4. 并发场景:50 个混合模式请求,验证缓存实例不被污染、内容不串话
|
|
|
+
|
|
|
+运行:
|
|
|
+ python utils_test/Model_Test/test_thinking_filter_and_concurrency.py
|
|
|
+"""
|
|
|
+
|
|
|
+import sys
|
|
|
+import asyncio
|
|
|
+import types
|
|
|
+import importlib.util
|
|
|
+import unittest
|
|
|
+from pathlib import Path
|
|
|
+from typing import List, Tuple
|
|
|
+from unittest.mock import MagicMock
|
|
|
+
|
|
|
+
|
|
|
# Repository root, taken as three directory levels above this file (the run
# command in the module docstring places the file two folders below the root).
_THIS_FILE = Path(__file__).resolve()
PROJECT_ROOT = _THIS_FILE.parent.parent.parent
|
|
|
+
|
|
|
+
|
|
|
+# ============================================================
|
|
|
+# Step 1: 注入 stub 依赖(本地环境无 langchain/yaml 等包)
|
|
|
+# ============================================================
|
|
|
+
|
|
|
def _make_module(name: str, **attrs) -> types.ModuleType:
    """Create a stub module, register it in ``sys.modules`` and return it.

    Args:
        name: Fully qualified (possibly dotted) module name to register.
        **attrs: Attributes to expose on the stub module.

    Returns:
        The newly created module object. Any module previously registered
        under ``name`` is replaced.
    """
    mod = types.ModuleType(name)
    for attr_name, attr_value in attrs.items():
        setattr(mod, attr_name, attr_value)
    sys.modules[name] = mod

    # The real import system binds every submodule as an attribute of its
    # parent package. Do the same when the parent is already registered, so
    # that ``import a.b`` followed by ``a.b.x`` works against the stubs too.
    parent_name, _, child_name = name.rpartition(".")
    if parent_name:
        parent = sys.modules.get(parent_name)
        if parent is not None:
            setattr(parent, child_name, mod)
    return mod
|
|
|
+
|
|
|
+
|
|
|
# Stub out langchain_core and its prompts submodule; the module under test only
# needs the ChatPromptTemplate name to be importable, so an empty class is enough.
_make_module("langchain_core")
_ChatPromptTemplateStub = type("ChatPromptTemplate", (), {})
_make_module(
    "langchain_core.prompts",
    ChatPromptTemplate=_ChatPromptTemplateStub,
)
|
|
|
+
|
|
|
+
|
|
|
class BaseMessage:
    """Minimal stand-in for a chat message: it only carries ``content``."""

    def __init__(self, content=""):
        # Stored as given; no validation or conversion is performed.
        self.content = content
|
|
|
+
|
|
|
+
|
|
|
class SystemMessage(BaseMessage):
    """Stub system message; adds nothing on top of ``BaseMessage``."""
|
|
|
+
|
|
|
+
|
|
|
class HumanMessage(BaseMessage):
    """Stub human message; adds nothing on top of ``BaseMessage``."""
|
|
|
+
|
|
|
+
|
|
|
# Expose the stub message classes under the import path of the real ones.
_make_module(
    "langchain_core.messages",
    BaseMessage=BaseMessage,
    SystemMessage=SystemMessage,
    HumanMessage=HumanMessage,
)

# Empty placeholder packages for the project's own ``foundation`` namespace,
# registered parent-first in the same order as before.
for _stub_package in (
    "foundation",
    "foundation.ai",
    "foundation.ai.models",
    "foundation.ai.agent",
    "foundation.ai.agent.generate",
    "foundation.observability",
    "foundation.observability.logger",
):
    _make_module(_stub_package)
del _stub_package
|
|
|
+
|
|
|
+
|
|
|
class _SilentLogger:
    """Logger stand-in whose methods accept any arguments and do nothing."""

    def info(self, *a, **k):
        return None

    def debug(self, *a, **k):
        return None

    def warning(self, *a, **k):
        return None

    def error(self, *a, **k):
        return None
|
|
|
+
|
|
|
+
|
|
|
# Provide ``review_logger`` under the stubbed logging module path, backed by
# the no-op logger so log calls produce no output.
_make_module(
    "foundation.observability.logger.loggering",
    review_logger=_SilentLogger(),
)
|
|
|
+
|
|
|
+
|
|
|
# Stand-in for LangChain's ChatOpenAI. bind() returns a new instance (as the
# real RunnableBinding does), and ainvoke/invoke/stream decide from the current
# extra_body whether to emit a <think> block (simulating the Shutian backend,
# where thinking is enabled by default).
class MockChatOpenAI:
    """Fake chat model whose output depends on ``extra_body`` thinking flags."""

    def __init__(self, model_name="mock", extra_body=None):
        self.model_name = model_name
        # Always keep a private top-level copy so instances never share the dict.
        self.extra_body = dict(extra_body) if extra_body else {}

    def bind(self, **kwargs):
        """Return a new mock whose ``extra_body`` is merged with the override.

        Only the ``extra_body`` keyword is honoured; the merge is a top-level
        ``dict.update`` and this instance is left untouched.
        """
        overrides = kwargs["extra_body"] if "extra_body" in kwargs else {}
        merged = dict(self.extra_body)
        merged.update(overrides)
        return MockChatOpenAI(self.model_name, extra_body=merged)

    def _is_thinking_on(self) -> bool:
        """Read ``chat_template_kwargs.enable_thinking``, defaulting to True."""
        template_kwargs = self.extra_body.get("chat_template_kwargs", {})
        # A missing flag means "on", simulating the server-side default.
        return template_kwargs.get("enable_thinking", True)

    async def ainvoke(self, messages, **kwargs):
        """Asynchronously return a mock reply, with a <think> block if enabled."""
        answer = f"答案:模型 {self.model_name} 处理 {len(messages)} 条消息"
        if self._is_thinking_on():
            answer = (
                "<think>\n推理过程:分析问题 X 的步骤...\n经过推导...\n</think>\n\n"
                + answer
            )
        # Yield control briefly so concurrent callers get a chance to interleave.
        await asyncio.sleep(0.01)
        return MagicMock(content=answer)

    def invoke(self, messages, **kwargs):
        """Synchronously return a mock reply, with a <think> block if enabled."""
        answer = f"同步答案:{self.model_name}"
        if self._is_thinking_on():
            answer = "<think>推理过程</think>\n\n" + answer
        return MagicMock(content=answer)

    def stream(self, messages, **kwargs):
        """Yield the mock reply as chunk objects carrying ``content``."""
        body = "这是流式答案"
        if self._is_thinking_on():
            body = "<think>\n推理过程\n</think>\n\n" + body
        # Deliberately cut into 4-character pieces so that chunk boundaries
        # fall inside the <think>/</think> tags.
        width = 4
        for start in range(0, len(body), width):
            yield MagicMock(content=body[start:start + width])
|
|
|
+
|
|
|
+
|
|
|
class MockModelHandler:
    """Fake model handler that caches one ``MockChatOpenAI`` per model name."""

    def __init__(self):
        # Maps model name -> the single cached mock instance for that name.
        self._cache = {}

    def get_model_by_name(self, model_type=None):
        """Return the cached mock for ``model_type``, creating it on first use.

        A falsy ``model_type`` (``None`` or an empty string) maps to ``"default"``.
        """
        key = model_type or "default"
        try:
            return self._cache[key]
        except KeyError:
            created = MockChatOpenAI(model_name=key)
            self._cache[key] = created
            return created

    def get_models(self):
        """Return the cached mock registered under ``"default"``."""
        return self.get_model_by_name("default")
|
|
|
+
|
|
|
+
|
|
|
# Single shared handler; the tests inspect its cache to detect pollution.
mock_handler = MockModelHandler()
_make_module(
    "foundation.ai.models.model_handler",
    model_handler=mock_handler,
    ModelHandler=MockModelHandler,
    get_models=lambda: mock_handler.get_models(),
)

# Config stub: mimics the key mappings of model_setting.yaml.
_FAKE_THINKING = {
    "outline_chapter_revise": False,       # authoring: thinking off
    "catalog_integrity_review": True,      # catalog review: thinking on
    "doc_classification_tertiary": False,
    "default": True,                       # risky default (the real cloud-side situation)
}
# Every function name, including the fallback, maps to the same model.
_FAKE_MODEL = dict.fromkeys(
    (
        "outline_chapter_revise",
        "catalog_integrity_review",
        "doc_classification_tertiary",
        "default",
    ),
    "shutian_qwen3_5_122b",
)


def _fake_get_model_for_function(n):
    """Look up the model for function ``n``, falling back to the default entry."""
    return _FAKE_MODEL.get(n, _FAKE_MODEL["default"])


def _fake_get_thinking_mode_for_function(n):
    """Look up the thinking flag for function ``n``, falling back to the default."""
    return _FAKE_THINKING.get(n, _FAKE_THINKING["default"])


_make_module(
    "foundation.ai.models.model_config_loader",
    get_model_for_function=_fake_get_model_for_function,
    get_thinking_mode_for_function=_fake_get_thinking_mode_for_function,
)
|
|
|
+
|
|
|
+
|
|
|
+# ============================================================
|
|
|
+# Step 2: 加载 model_generate.py(被测对象)
|
|
|
+# ============================================================
|
|
|
# Load model_generate.py directly from its file path, under a private module
# name, so it resolves its imports against the stub modules registered above.
_target = PROJECT_ROOT / "foundation" / "ai" / "agent" / "generate" / "model_generate.py"
_spec = importlib.util.spec_from_file_location("mg_under_test", _target)
if _spec is None or _spec.loader is None:
    # spec_from_file_location returns None when no loader matches the path.
    raise ImportError(f"cannot build an import spec for {_target}")
mg = importlib.util.module_from_spec(_spec)
# Register the module before executing it, as the importlib recipe for
# importing a source file does: code that looks its own module up through
# sys.modules during execution would otherwise find nothing.
sys.modules[_spec.name] = mg
try:
    _spec.loader.exec_module(mg)
except BaseException:
    # Do not leave a half-initialised module behind if execution fails.
    sys.modules.pop(_spec.name, None)
    raise
|
|
|
+
|
|
|
+
|
|
|
+# ============================================================
|
|
|
+# Step 3: 测试用例
|
|
|
+# ============================================================
|
|
|
+
|
|
|
class TestStripThinkingContent(unittest.TestCase):
    """Unit tests for ``_strip_thinking_content`` on complete text."""

    def _assert_stripped(self, raw, expected):
        """Assert that stripping ``raw`` yields exactly ``expected``."""
        self.assertEqual(mg._strip_thinking_content(raw), expected)

    def test_complete_block(self):
        # A leading block and the blank lines after it are removed.
        self._assert_stripped("<think>推理过程</think>\n\n最终答案", "最终答案")

    def test_multiple_blocks(self):
        # Every block is removed; the text between them is kept in order.
        self._assert_stripped(
            "<think>think1</think>段A<think>think2</think>段B",
            "段A段B",
        )

    def test_dangling_block(self):
        # An unterminated block is dropped together with the newline before it.
        self._assert_stripped("正文段\n<think>推理被截断", "正文段")

    def test_no_thinking(self):
        self._assert_stripped("纯回答内容", "纯回答内容")

    def test_empty_and_none(self):
        # Empty input stays empty and None is passed through as None.
        self._assert_stripped("", "")
        self.assertIsNone(mg._strip_thinking_content(None))

    def test_multiline_block(self):
        self._assert_stripped(
            "<think>\n第一行\n第二行\n第三行\n</think>\n\n答案",
            "答案",
        )

    def test_block_at_end(self):
        self._assert_stripped("答案先行<think>反思</think>", "答案先行")
|
|
|
+
|
|
|
+
|
|
|
class TestStreamFilter(unittest.TestCase):
    """Streaming tests for ``_ThinkingBlockStreamFilter``."""

    def _drive(self, chunks: List[str]) -> str:
        """Feed ``chunks`` through a fresh filter, flush, and join the output.

        Falsy results from ``feed``/``flush`` are skipped.
        """
        stream_filter = mg._ThinkingBlockStreamFilter()
        emitted = [stream_filter.feed(piece) for piece in chunks]
        emitted.append(stream_filter.flush())
        return "".join(part for part in emitted if part)

    def _assert_filtered(self, chunks, expected):
        """Assert that driving ``chunks`` through the filter yields ``expected``."""
        self.assertEqual(self._drive(chunks), expected)

    def test_single_chunk_with_block(self):
        self._assert_filtered(["<think>x</think>正文"], "正文")

    def test_split_open_tag(self):
        # Chunk boundary falls inside the opening tag: <thi|nk>
        self._assert_filtered(["<thi", "nk>推理</think>正文"], "正文")

    def test_split_close_tag(self):
        # Chunk boundary falls inside the closing tag: </thi|nk>
        self._assert_filtered(["<think>推理</thi", "nk>正文"], "正文")

    def test_split_in_middle_of_block(self):
        self._assert_filtered(
            ["<think>", "推理1", "推理2", "</think>", "答案"],
            "答案",
        )

    def test_no_thinking_passes_through(self):
        self._assert_filtered(["普通", "答案", "内容"], "普通答案内容")

    def test_dangling_flush_drops(self):
        stream_filter = mg._ThinkingBlockStreamFilter()
        before_block = stream_filter.feed("正文")          # expected to emit "正文"
        inside_block = stream_filter.feed("<think>未完成")  # now inside a think block
        flushed = stream_filter.flush()                    # still inside -> dropped
        self.assertEqual(before_block, "正文")
        self.assertEqual(inside_block, "")
        self.assertEqual(flushed, "")

    def test_multiple_blocks_streamed(self):
        self._assert_filtered(["<think>a</think>段1<think>b</think>段2"], "段1段2")

    def test_realistic_3char_chunks(self):
        text = "<think>\n推理过程\n</think>\n\n这是答案"
        step = 3
        pieces = [text[start:start + step] for start in range(0, len(text), step)]
        filtered = self._drive(pieces)
        self.assertIn("这是答案", filtered)
        for forbidden in ("<think>", "</think>", "推理过程"):
            self.assertNotIn(forbidden, filtered)
|
|
|
+
|
|
|
+
|
|
|
class TestEndToEndInvoke(unittest.IsolatedAsyncioTestCase):
    """End-to-end checks of ``GenerateModelClient`` against the mocks."""

    async def _invoke(self, trace_id, prompt, function_name):
        """Call ``get_model_generate_invoke`` on a fresh client and return the result."""
        client = mg.GenerateModelClient()
        return await client.get_model_generate_invoke(
            trace_id=trace_id,
            prompt=prompt,
            function_name=function_name,
        )

    def _assert_clean_answer(self, result):
        """Assert no tag or reasoning text is present and the answer survives."""
        self.assertNotIn("<think>", result)
        self.assertNotIn("</think>", result)
        self.assertNotIn("推理过程", result)
        self.assertIn("答案", result)

    async def test_thinking_off_via_function_name(self):
        # The stub config maps this function name to thinking = False.
        result = await self._invoke("t1", "2+2", "outline_chapter_revise")
        self._assert_clean_answer(result)

    async def test_thinking_on_block_stripped(self):
        # The stub config maps this function name to thinking = True; the
        # returned text must still have the think block filtered out.
        result = await self._invoke("t2", "解释勾股定理", "catalog_integrity_review")
        self._assert_clean_answer(result)

    async def test_default_config_thinking_still_filtered(self):
        """Unknown function name falls back to default (True) yet is still filtered."""
        result = await self._invoke("t3", "X", "not_in_yaml_xxx")
        self.assertNotIn("<think>", result)
        self.assertIn("答案", result)
|
|
|
+
|
|
|
+
|
|
|
class TestStreamEndToEnd(unittest.TestCase):
    """End-to-end check of the streaming call path."""

    def test_stream_thinking_on_filtered(self):
        client = mg.GenerateModelClient()
        stream = client.get_model_generate_stream(
            trace_id="s1",
            prompt="X",
            function_name="catalog_integrity_review",
        )
        # Only string chunks contribute to the joined text; anything else is ignored.
        text = "".join(piece for piece in stream if isinstance(piece, str))
        for forbidden in ("<think>", "</think>", "推理过程"):
            self.assertNotIn(forbidden, text)
        self.assertIn("流式答案", text)
|
|
|
+
|
|
|
+
|
|
|
class TestConcurrentNoCrosstalk(unittest.IsolatedAsyncioTestCase):
    """Concurrency: mixed thinking modes must not leak or pollute the cache."""

    async def test_50_concurrent_mixed_modes(self):
        client = mg.GenerateModelClient()
        total = 50

        async def call_one(index: int) -> Tuple[int, str, str]:
            # Even indices use the thinking-off function, odd ones thinking-on.
            function_name = (
                "catalog_integrity_review" if index % 2 else "outline_chapter_revise"
            )
            reply = await client.get_model_generate_invoke(
                trace_id=f"concur-{index}",
                prompt=f"请求 {index}",
                function_name=function_name,
            )
            return (index, function_name, reply)

        results = await asyncio.gather(*map(call_one, range(total)))

        forbidden = ("<think>", "</think>", "推理过程")
        leaks = [
            (i, fn)
            for i, fn, r in results
            if any(marker in r for marker in forbidden)
        ]
        self.assertEqual(leaks, [], f"思考内容泄漏到 {len(leaks)} 个返回中: {leaks[:3]}")
        self.assertEqual(len(results), total)
        no_answer = [(i, r) for i, _fn, r in results if "答案" not in r]
        self.assertEqual(no_answer, [], f"个别结果丢失答案: {no_answer[:3]}")

    async def test_cached_instance_not_polluted(self):
        """Key check: bind() must not pollute the handler's cached model instance."""
        client = mg.GenerateModelClient()
        cached = mock_handler.get_model_by_name("shutian_qwen3_5_122b")

        # Fire 30 concurrent calls alternating between the two modes.
        pending = [
            client.get_model_generate_invoke(
                trace_id=f"iso-{index}",
                prompt="x",
                function_name=(
                    "catalog_integrity_review" if index % 2 else "outline_chapter_revise"
                ),
            )
            for index in range(30)
        ]
        await asyncio.gather(*pending)

        # The cached instance's extra_body must remain empty, because bind()
        # is expected to work on a copy rather than on the cached object.
        self.assertEqual(
            cached.extra_body, {},
            f"缓存实例被污染,extra_body={cached.extra_body}(这意味着并发请求会串话)"
        )
|
|
|
+
|
|
|
+
|
|
|
+# ============================================================
|
|
|
+# 入口
|
|
|
+# ============================================================
|
|
|
def run_all():
    """Run every test case in this file verbosely.

    Returns:
        0 when all tests pass, 1 otherwise (suitable as a process exit code).
    """
    test_cases = (
        TestStripThinkingContent,
        TestStreamFilter,
        TestEndToEndInvoke,
        TestStreamEndToEnd,
        TestConcurrentNoCrosstalk,
    )
    loader = unittest.TestLoader()
    # Build one flat suite, keeping the test cases in the order listed above.
    suite = unittest.TestSuite(
        test
        for case in test_cases
        for test in loader.loadTestsFromTestCase(case)
    )
    outcome = unittest.TextTestRunner(verbosity=2).run(suite)
    if outcome.wasSuccessful():
        return 0
    return 1
|
|
|
+
|
|
|
+
|
|
|
if __name__ == "__main__":
    # Propagate the suite result as the process exit status.
    raise SystemExit(run_all())
|