# test_chat_ai_writing_route.py (7.2 KB)
  1. import importlib.util
  2. import sys
  3. import unittest
  4. from pathlib import Path
  5. from types import SimpleNamespace
  6. from unittest.mock import AsyncMock, Mock, patch
  7. CHAT_PATH = Path(__file__).resolve().parents[1] / "routers" / "chat.py"
  8. sys.path.insert(0, str(CHAT_PATH.parents[1]))
  9. spec = importlib.util.spec_from_file_location("chat_under_test", CHAT_PATH)
  10. chat = importlib.util.module_from_spec(spec)
  11. spec.loader.exec_module(chat)
  12. class FakeDB:
  13. def __init__(self):
  14. self.added = []
  15. self.commits = 0
  16. self.refreshed = []
  17. def add(self, obj):
  18. self.added.append(obj)
  19. def commit(self):
  20. self.commits += 1
  21. def refresh(self, obj):
  22. if getattr(obj, "id", None) is None:
  23. obj.id = len(self.refreshed) + 100
  24. self.refreshed.append(obj)
  25. class ChatAIWritingRouteTest(unittest.IsolatedAsyncioTestCase):
  26. async def test_safety_training_generation_enriches_request_then_uses_legacy_ppt_outline_chain(self):
  27. planning_reply = """
  28. {
  29. "topic": "施工安全培训",
  30. "template": "标准安全培训PPT大纲",
  31. "content_focus": ["员工安全意识", "施工现场着装规范"],
  32. "audience": "蜀道集团员工",
  33. "time": "2025/4/10",
  34. "location": "蜀道集团",
  35. "goal": "提升培训员工的安全意识",
  36. "notes": "着装随意",
  37. "normalized_request": "围绕2025年第一季度施工安全培训生成安全培训PPT大纲"
  38. }
  39. """
  40. with patch.object(chat, "_rag_search", AsyncMock(return_value="Safety RAG context")) as rag_search, \
  41. patch.object(chat, "load_prompt", Mock(return_value="ppt outline prompt")) as load_prompt, \
  42. patch.object(chat.qwen_service, "chat", AsyncMock(side_effect=[planning_reply, "PPT outline reply"])) as qwen_chat, \
  43. patch.object(chat.deepseek_service, "chat", AsyncMock(return_value="DeepSeek reply")) as deepseek_chat:
  44. reply = await chat._generate_ppt_outline_response(
  45. "生成一份施工安全培训通知,2025年第一季度施工安全培训,在蜀道集团,2025/4/10进行培训,培训员工的安全意识,着装随意。"
  46. )
  47. self.assertEqual(reply, "PPT outline reply")
  48. self.assertEqual(qwen_chat.await_count, 2)
  49. planning_messages = qwen_chat.await_args_list[0].args[0]
  50. self.assertEqual(planning_messages[0]["role"], "system")
  51. self.assertIn("安全培训需求整理", planning_messages[0]["content"])
  52. enriched_message = rag_search.await_args.args[0]
  53. self.assertIn("输出类型:安全培训PPT大纲", enriched_message)
  54. self.assertIn("主题:施工安全培训", enriched_message)
  55. self.assertIn("模板:标准安全培训PPT大纲", enriched_message)
  56. self.assertIn("培训时间:2025/4/10", enriched_message)
  57. self.assertIn("培训地点:蜀道集团", enriched_message)
  58. self.assertNotIn("公文写作", enriched_message)
  59. rag_search.assert_awaited_once_with(enriched_message, top_k=10)
  60. load_prompt.assert_called_once_with(
  61. "ppt_outline",
  62. userMessage=enriched_message,
  63. contextJSON="Safety RAG context",
  64. )
  65. generation_messages = qwen_chat.await_args_list[1].args[0]
  66. self.assertEqual(generation_messages[0], {"role": "system", "content": "ppt outline prompt"})
  67. self.assertIn("直接输出安全培训PPT大纲", generation_messages[1]["content"])
  68. deepseek_chat.assert_not_called()
  69. def test_safety_training_fallback_plan_keeps_notice_requests_in_training_outline_domain(self):
  70. plan = chat._build_fallback_safety_training_plan("生成施工安全培训通知")
  71. enriched_message = chat._build_safety_training_generation_message("生成施工安全培训通知", plan)
  72. self.assertEqual(plan["topic"], "施工安全培训")
  73. self.assertIn("输出类型:安全培训PPT大纲", enriched_message)
  74. self.assertIn("主题:施工安全培训", enriched_message)
  75. self.assertIn("原始需求:生成施工安全培训通知", enriched_message)
  76. self.assertNotIn("公文写作", enriched_message)
  77. async def test_ai_writing_generation_uses_deepseek_non_streaming(self):
  78. with patch.object(chat, "_rag_search", AsyncMock(return_value="RAG context")) as rag_search, \
  79. patch.object(chat, "load_prompt", Mock(return_value="loaded prompt")) as load_prompt, \
  80. patch.object(chat.deepseek_service, "chat", AsyncMock(return_value="DeepSeek reply")) as deepseek_chat, \
  81. patch.object(chat.qwen_service, "chat", AsyncMock(return_value="Qwen reply")) as qwen_chat:
  82. reply = await chat._generate_ai_writing_response("Draft a notice")
  83. self.assertEqual(reply, "DeepSeek reply")
  84. rag_search.assert_awaited_once_with("Draft a notice", top_k=10)
  85. load_prompt.assert_called_once_with(
  86. "document_writing",
  87. userMessage="Draft a notice",
  88. contextJSON="RAG context",
  89. )
  90. deepseek_messages = deepseek_chat.await_args.args[0]
  91. self.assertEqual(deepseek_messages[0], {"role": "system", "content": "loaded prompt"})
  92. self.assertEqual(deepseek_messages[1]["role"], "user")
  93. self.assertIn("直接生成可放入富文本编辑器的公文正文 HTML 片段", deepseek_messages[1]["content"])
  94. self.assertIn("Draft a notice", deepseek_messages[1]["content"])
  95. qwen_chat.assert_not_called()
  96. async def test_ai_writing_generation_cleans_full_html_document_for_rich_editor(self):
  97. raw_html = """抱歉,我只能帮助您生成蜀道集团的公文内容。
  98. <!DOCTYPE html>
  99. <html><head><style>body{color:red}</style></head>
  100. <body><div class="document"><h1>安全生产责任制</h1><p>正文内容</p></div></body></html>
  101. """
  102. with patch.object(chat, "_rag_search", AsyncMock(return_value="RAG context")), \
  103. patch.object(chat, "load_prompt", Mock(return_value="loaded prompt")), \
  104. patch.object(chat.deepseek_service, "chat", AsyncMock(return_value=raw_html)):
  105. reply = await chat._generate_ai_writing_response("生成安全生产责任制")
  106. self.assertIn('<div class="document">', reply)
  107. self.assertIn("<h1>安全生产责任制</h1>", reply)
  108. self.assertNotIn("抱歉", reply)
  109. self.assertNotIn("<!DOCTYPE", reply)
  110. self.assertNotIn("<style>", reply)
  111. self.assertNotIn("<body>", reply)
  112. def test_ai_writing_exchange_is_persisted_as_user_and_ai_messages(self):
  113. db = FakeDB()
  114. user_message, ai_message = chat._persist_message_pair(
  115. db=db,
  116. conv_id=123,
  117. user=SimpleNamespace(user_id=70430),
  118. user_content="user request",
  119. ai_content="ai reply",
  120. )
  121. self.assertEqual(len(db.added), 2)
  122. self.assertEqual(user_message.type, "user")
  123. self.assertEqual(user_message.content, "user request")
  124. self.assertEqual(ai_message.type, "ai")
  125. self.assertEqual(ai_message.content, "ai reply")
  126. self.assertEqual(ai_message.prev_user_id, user_message.id)
  127. self.assertEqual(db.commits, 2)
  128. if __name__ == "__main__":
  129. unittest.main()