import importlib.util
import sys
import unittest
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, Mock, patch

# Load routers/chat.py directly from its file path. The directory two levels
# above chat.py is put on sys.path first — presumably so chat.py's own
# absolute imports resolve; confirm against the project layout.
CHAT_PATH = Path(__file__).resolve().parents[1] / "routers" / "chat.py"
sys.path.insert(0, str(CHAT_PATH.parents[1]))
spec = importlib.util.spec_from_file_location("chat_under_test", CHAT_PATH)
chat = importlib.util.module_from_spec(spec)
# Register before executing, as in the importlib "import a source file
# directly" recipe: code that looks its module up in sys.modules (e.g.
# dataclasses) fails otherwise.
sys.modules[spec.name] = chat
spec.loader.exec_module(chat)


class FakeDB:
    """In-memory stand-in for a DB session that records every interaction.

    ``refresh`` gives an object without an ``id`` the value
    ``100 + <number of earlier refresh calls>``, imitating the database
    populating a primary key after a commit.
    """

    def __init__(self):
        self.added = []      # objects passed to add(), in call order
        self.commits = 0     # number of commit() calls
        self.refreshed = []  # objects passed to refresh(), in call order

    def add(self, obj):
        self.added.append(obj)

    def commit(self):
        self.commits += 1

    def refresh(self, obj):
        # Only assign an id when the object does not already have one.
        if getattr(obj, "id", None) is None:
            obj.id = len(self.refreshed) + 100
        self.refreshed.append(obj)


class ChatAIWritingRouteTest(unittest.IsolatedAsyncioTestCase):
    """Route-level tests for the AI-writing and PPT-outline helpers in chat.py."""

    async def test_safety_training_generation_enriches_request_then_uses_legacy_ppt_outline_chain(self):
        """A safety-training request is first planned by Qwen, then the
        enriched request drives RAG + the ``ppt_outline`` prompt + a second
        Qwen call; DeepSeek is never used on this path."""
        planning_reply = """
{
    "topic": "施工安全培训",
    "template": "标准安全培训PPT大纲",
    "content_focus": ["员工安全意识", "施工现场着装规范"],
    "audience": "蜀道集团员工",
    "time": "2025/4/10",
    "location": "蜀道集团",
    "goal": "提升培训员工的安全意识",
    "notes": "着装随意",
    "normalized_request": "围绕2025年第一季度施工安全培训生成安全培训PPT大纲"
}
"""
        with patch.object(chat, "_rag_search", AsyncMock(return_value="Safety RAG context")) as rag_search, \
                patch.object(chat, "load_prompt", Mock(return_value="ppt outline prompt")) as load_prompt, \
                patch.object(
                    chat.qwen_service,
                    "chat",
                    # First await returns the planning JSON, second the outline.
                    AsyncMock(side_effect=[planning_reply, "PPT outline reply"]),
                ) as qwen_chat, \
                patch.object(chat.deepseek_service, "chat", AsyncMock(return_value="DeepSeek reply")) as deepseek_chat:
            reply = await chat._generate_ppt_outline_response(
                "生成一份施工安全培训通知,2025年第一季度施工安全培训,在蜀道集团,2025/4/10进行培训,培训员工的安全意识,着装随意。"
            )

        self.assertEqual(reply, "PPT outline reply")
        self.assertEqual(qwen_chat.await_count, 2)

        # Call 1: the planning step that structures the raw request.
        planning_messages = qwen_chat.await_args_list[0].args[0]
        self.assertEqual(planning_messages[0]["role"], "system")
        self.assertIn("安全培训需求整理", planning_messages[0]["content"])

        # The plan is folded into an enriched message that feeds RAG.
        enriched_message = rag_search.await_args.args[0]
        self.assertIn("输出类型:安全培训PPT大纲", enriched_message)
        self.assertIn("主题:施工安全培训", enriched_message)
        self.assertIn("模板:标准安全培训PPT大纲", enriched_message)
        self.assertIn("培训时间:2025/4/10", enriched_message)
        self.assertIn("培训地点:蜀道集团", enriched_message)
        # A "notice" request must stay in the training-outline domain.
        self.assertNotIn("公文写作", enriched_message)
        rag_search.assert_awaited_once_with(enriched_message, top_k=10)
        load_prompt.assert_called_once_with(
            "ppt_outline",
            userMessage=enriched_message,
            contextJSON="Safety RAG context",
        )

        # Call 2: the outline generation itself.
        generation_messages = qwen_chat.await_args_list[1].args[0]
        self.assertEqual(generation_messages[0], {"role": "system", "content": "ppt outline prompt"})
        self.assertIn("直接输出安全培训PPT大纲", generation_messages[1]["content"])
        deepseek_chat.assert_not_called()

    def test_safety_training_fallback_plan_keeps_notice_requests_in_training_outline_domain(self):
        """The fallback (no-LLM) plan still yields a training-outline message."""
        plan = chat._build_fallback_safety_training_plan("生成施工安全培训通知")
        enriched_message = chat._build_safety_training_generation_message("生成施工安全培训通知", plan)

        self.assertEqual(plan["topic"], "施工安全培训")
        self.assertIn("输出类型:安全培训PPT大纲", enriched_message)
        self.assertIn("主题:施工安全培训", enriched_message)
        self.assertIn("原始需求:生成施工安全培训通知", enriched_message)
        self.assertNotIn("公文写作", enriched_message)

    async def test_ai_writing_generation_uses_deepseek_non_streaming(self):
        """AI writing goes RAG -> ``document_writing`` prompt -> DeepSeek, not Qwen."""
        with patch.object(chat, "_rag_search", AsyncMock(return_value="RAG context")) as rag_search, \
                patch.object(chat, "load_prompt", Mock(return_value="loaded prompt")) as load_prompt, \
                patch.object(chat.deepseek_service, "chat", AsyncMock(return_value="DeepSeek reply")) as deepseek_chat, \
                patch.object(chat.qwen_service, "chat", AsyncMock(return_value="Qwen reply")) as qwen_chat:
            reply = await chat._generate_ai_writing_response("Draft a notice")

        self.assertEqual(reply, "DeepSeek reply")
        rag_search.assert_awaited_once_with("Draft a notice", top_k=10)
        load_prompt.assert_called_once_with(
            "document_writing",
            userMessage="Draft a notice",
            contextJSON="RAG context",
        )
        deepseek_messages = deepseek_chat.await_args.args[0]
        self.assertEqual(deepseek_messages[0], {"role": "system", "content": "loaded prompt"})
        self.assertEqual(deepseek_messages[1]["role"], "user")
        self.assertIn("直接生成可放入富文本编辑器的公文正文 HTML 片段", deepseek_messages[1]["content"])
        self.assertIn("Draft a notice", deepseek_messages[1]["content"])
        qwen_chat.assert_not_called()

    async def test_ai_writing_generation_cleans_full_html_document_for_rich_editor(self):
        """A model reply that starts with an apology and wraps the content in
        a complete HTML document is reduced to editor-ready content."""
        # NOTE(review): the HTML markup of this fixture and of the original
        # assertions was lost (tags stripped), leaving checks such as
        # assertNotIn("", reply) that can never pass. This fixture and the
        # checks below are a reconstruction from the surviving text; compare
        # with the original test and restore any wrapper-element assertion it
        # made.
        raw_html = (
            "抱歉,我只能帮助您生成蜀道集团的公文内容。\n"
            "<html>\n"
            "<body>\n"
            "<h1>安全生产责任制</h1>\n"
            "<p>正文内容</p>\n"
            "</body>\n"
            "</html>\n"
        )
        with patch.object(chat, "_rag_search", AsyncMock(return_value="RAG context")), \
                patch.object(chat, "load_prompt", Mock(return_value="loaded prompt")), \
                patch.object(chat.deepseek_service, "chat", AsyncMock(return_value=raw_html)):
            reply = await chat._generate_ai_writing_response("生成安全生产责任制")

        # Document content survives ...
        self.assertIn("安全生产责任制", reply)
        self.assertIn("正文内容", reply)
        # ... while the apology preamble and document-level wrappers do not.
        self.assertNotIn("抱歉", reply)
        self.assertNotIn("<html", reply)
        self.assertNotIn("<body", reply)

    def test_ai_writing_exchange_is_persisted_as_user_and_ai_messages(self):
        """The request/response pair is stored as two linked rows, committing each."""
        db = FakeDB()

        user_message, ai_message = chat._persist_message_pair(
            db=db,
            conv_id=123,
            user=SimpleNamespace(user_id=70430),
            user_content="user request",
            ai_content="ai reply",
        )

        self.assertEqual(len(db.added), 2)
        self.assertEqual(user_message.type, "user")
        self.assertEqual(user_message.content, "user request")
        self.assertEqual(ai_message.type, "ai")
        self.assertEqual(ai_message.content, "ai reply")
        # The AI row must point back at the user row's database-assigned id.
        self.assertEqual(ai_message.prev_user_id, user_message.id)
        self.assertEqual(db.commits, 2)


if __name__ == "__main__":
    unittest.main()