| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154 |
- import importlib.util
- import sys
- import unittest
- from pathlib import Path
- from types import SimpleNamespace
- from unittest.mock import AsyncMock, Mock, patch
# Load ``routers/chat.py`` directly from its file path under the module name
# ``chat_under_test``. The directory that contains ``routers/`` is put at the
# front of ``sys.path`` (presumably so chat.py's own top-level imports resolve).
CHAT_PATH = Path(__file__).resolve().parents[1] / "routers" / "chat.py"
sys.path.insert(0, str(CHAT_PATH.parents[1]))
spec = importlib.util.spec_from_file_location("chat_under_test", CHAT_PATH)
if spec is None or spec.loader is None:
    # Fail loudly here instead of with an opaque AttributeError below.
    raise ImportError(f"Cannot create an import spec for {CHAT_PATH}")
chat = importlib.util.module_from_spec(spec)
# Register the module before executing it (importlib "Importing a source file
# directly" recipe). Code that resolves a class's module by name, such as
# dataclasses or pydantic evaluating string annotations via
# sys.modules[cls.__module__], breaks if the entry is missing.
sys.modules[spec.name] = chat
spec.loader.exec_module(chat)
class FakeDB:
    """Call-recording fake of the DB session passed to ``chat._persist_message_pair``.

    Nothing is persisted. The fake only keeps track of what it was asked to do,
    so tests can inspect it afterwards:

    * ``added``     - objects given to :meth:`add`, in call order.
    * ``commits``   - how many times :meth:`commit` was called.
    * ``refreshed`` - objects given to :meth:`refresh`, in call order.
    """

    def __init__(self):
        self.added, self.commits, self.refreshed = [], 0, []

    def add(self, obj):
        """Record *obj* as added."""
        self.added += [obj]

    def commit(self):
        """Count one commit."""
        self.commits = self.commits + 1

    def refresh(self, obj):
        """Record *obj*, first giving it a synthetic ``id`` if it has none.

        The assigned id is ``100 + <number of earlier refresh calls>``.
        Objects that already carry a non-None ``id`` keep it.
        """
        needs_id = getattr(obj, "id", None) is None
        if needs_id:
            obj.id = 100 + len(self.refreshed)
        self.refreshed.append(obj)
class ChatAIWritingRouteTest(unittest.IsolatedAsyncioTestCase):
    """Tests for the AI-writing and safety-training helpers in ``routers/chat.py``.

    RAG search, prompt loading and the Qwen/DeepSeek chat services are patched
    on the loaded ``chat`` module, so no external calls are made.
    """

    async def test_safety_training_generation_enriches_request_then_uses_legacy_ppt_outline_chain(self):
        """A safety-training request is planned by Qwen, then sent through the PPT-outline chain.

        The first Qwen call returns a JSON plan. Its fields must appear in the
        enriched message given to RAG search and to the ``ppt_outline`` prompt.
        The second Qwen call produces the final reply, and DeepSeek is never used.
        """
        planning_reply = """
{
"topic": "施工安全培训",
"template": "标准安全培训PPT大纲",
"content_focus": ["员工安全意识", "施工现场着装规范"],
"audience": "蜀道集团员工",
"time": "2025/4/10",
"location": "蜀道集团",
"goal": "提升培训员工的安全意识",
"notes": "着装随意",
"normalized_request": "围绕2025年第一季度施工安全培训生成安全培训PPT大纲"
}
"""
        with patch.object(chat, "_rag_search", AsyncMock(return_value="Safety RAG context")) as rag_search, \
                patch.object(chat, "load_prompt", Mock(return_value="ppt outline prompt")) as load_prompt, \
                patch.object(chat.qwen_service, "chat", AsyncMock(side_effect=[planning_reply, "PPT outline reply"])) as qwen_chat, \
                patch.object(chat.deepseek_service, "chat", AsyncMock(return_value="DeepSeek reply")) as deepseek_chat:
            reply = await chat._generate_ppt_outline_response(
                "生成一份施工安全培训通知,2025年第一季度施工安全培训,在蜀道集团,2025/4/10进行培训,培训员工的安全意识,着装随意。"
            )

        self.assertEqual(reply, "PPT outline reply")
        # One planning call plus one generation call.
        self.assertEqual(qwen_chat.await_count, 2)

        planning_messages = qwen_chat.await_args_list[0].args[0]
        self.assertEqual(planning_messages[0]["role"], "system")
        self.assertIn("安全培训需求整理", planning_messages[0]["content"])

        # The plan fields must be folded into the message used for retrieval.
        enriched_message = rag_search.await_args.args[0]
        self.assertIn("输出类型:安全培训PPT大纲", enriched_message)
        self.assertIn("主题:施工安全培训", enriched_message)
        self.assertIn("模板:标准安全培训PPT大纲", enriched_message)
        self.assertIn("培训时间:2025/4/10", enriched_message)
        self.assertIn("培训地点:蜀道集团", enriched_message)
        self.assertNotIn("公文写作", enriched_message)
        rag_search.assert_awaited_once_with(enriched_message, top_k=10)
        load_prompt.assert_called_once_with(
            "ppt_outline",
            userMessage=enriched_message,
            contextJSON="Safety RAG context",
        )

        generation_messages = qwen_chat.await_args_list[1].args[0]
        self.assertEqual(generation_messages[0], {"role": "system", "content": "ppt outline prompt"})
        self.assertIn("直接输出安全培训PPT大纲", generation_messages[1]["content"])
        deepseek_chat.assert_not_called()

    def test_safety_training_fallback_plan_keeps_notice_requests_in_training_outline_domain(self):
        """The fallback plan keeps a "notice" request in the training PPT-outline domain.

        It must not be routed to document writing.
        """
        plan = chat._build_fallback_safety_training_plan("生成施工安全培训通知")
        enriched_message = chat._build_safety_training_generation_message("生成施工安全培训通知", plan)

        self.assertEqual(plan["topic"], "施工安全培训")
        self.assertIn("输出类型:安全培训PPT大纲", enriched_message)
        self.assertIn("主题:施工安全培训", enriched_message)
        self.assertIn("原始需求:生成施工安全培训通知", enriched_message)
        self.assertNotIn("公文写作", enriched_message)

    async def test_ai_writing_generation_uses_deepseek_non_streaming(self):
        """AI writing uses the ``document_writing`` prompt and DeepSeek; Qwen is not called."""
        with patch.object(chat, "_rag_search", AsyncMock(return_value="RAG context")) as rag_search, \
                patch.object(chat, "load_prompt", Mock(return_value="loaded prompt")) as load_prompt, \
                patch.object(chat.deepseek_service, "chat", AsyncMock(return_value="DeepSeek reply")) as deepseek_chat, \
                patch.object(chat.qwen_service, "chat", AsyncMock(return_value="Qwen reply")) as qwen_chat:
            reply = await chat._generate_ai_writing_response("Draft a notice")

        self.assertEqual(reply, "DeepSeek reply")
        rag_search.assert_awaited_once_with("Draft a notice", top_k=10)
        load_prompt.assert_called_once_with(
            "document_writing",
            userMessage="Draft a notice",
            contextJSON="RAG context",
        )

        deepseek_messages = deepseek_chat.await_args.args[0]
        self.assertEqual(deepseek_messages[0], {"role": "system", "content": "loaded prompt"})
        self.assertEqual(deepseek_messages[1]["role"], "user")
        self.assertIn("直接生成可放入富文本编辑器的公文正文 HTML 片段", deepseek_messages[1]["content"])
        self.assertIn("Draft a notice", deepseek_messages[1]["content"])
        qwen_chat.assert_not_called()

    async def test_ai_writing_generation_cleans_full_html_document_for_rich_editor(self):
        """A full HTML document reply is reduced to a body fragment for the rich-text editor.

        The reply starts with a leading apology line, which must be removed along
        with the document-level markup.
        """
        raw_html = """抱歉,我只能帮助您生成蜀道集团的公文内容。
<!DOCTYPE html>
<html><head><style>body{color:red}</style></head>
<body><div class="document"><h1>安全生产责任制</h1><p>正文内容</p></div></body></html>
"""
        with patch.object(chat, "_rag_search", AsyncMock(return_value="RAG context")), \
                patch.object(chat, "load_prompt", Mock(return_value="loaded prompt")), \
                patch.object(chat.deepseek_service, "chat", AsyncMock(return_value=raw_html)):
            reply = await chat._generate_ai_writing_response("生成安全生产责任制")

        # The body content survives...
        self.assertIn('<div class="document">', reply)
        self.assertIn("<h1>安全生产责任制</h1>", reply)
        # ...while the preamble and document-level markup are stripped.
        self.assertNotIn("抱歉", reply)
        self.assertNotIn("<!DOCTYPE", reply)
        self.assertNotIn("<style>", reply)
        self.assertNotIn("<body>", reply)

    def test_ai_writing_exchange_is_persisted_as_user_and_ai_messages(self):
        """``_persist_message_pair`` stores both messages and links the AI reply to the user message."""
        db = FakeDB()
        user_message, ai_message = chat._persist_message_pair(
            db=db,
            conv_id=123,
            user=SimpleNamespace(user_id=70430),
            user_content="user request",
            ai_content="ai reply",
        )

        self.assertEqual(len(db.added), 2)
        self.assertEqual(user_message.type, "user")
        self.assertEqual(user_message.content, "user request")
        self.assertEqual(ai_message.type, "ai")
        self.assertEqual(ai_message.content, "ai reply")
        # Without this check the link assertion below would pass vacuously
        # when both prev_user_id and the user message id are None.
        self.assertIsNotNone(user_message.id)
        self.assertEqual(ai_message.prev_user_id, user_message.id)
        self.assertEqual(db.commits, 2)
if __name__ == "__main__":
    # Running this file as a script runs every test case defined in it.
    unittest.main()
|