test_guess_you_want.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
import importlib.util
import json
import sys
import unittest
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
  7. CHAT_PATH = Path(__file__).resolve().parents[1] / "routers" / "chat.py"
  8. spec = importlib.util.spec_from_file_location("chat_under_test", CHAT_PATH)
  9. chat = importlib.util.module_from_spec(spec)
  10. spec.loader.exec_module(chat)
  11. def contains_chinese(text):
  12. return any("\u4e00" <= char <= "\u9fff" for char in text or "")
  13. class FakeQuery:
  14. def __init__(self, db):
  15. self.db = db
  16. def filter(self, *args, **kwargs):
  17. return self
  18. def first(self):
  19. return self.db.ai_message
  20. def update(self, values):
  21. self.db.updated_values = values
  22. return 1
  23. class FakeDB:
  24. def __init__(self, content):
  25. self.ai_message = SimpleNamespace(id=10, content=content, is_deleted=0)
  26. self.updated_values = None
  27. self.commit_called = False
  28. def query(self, model):
  29. return FakeQuery(self)
  30. def commit(self):
  31. self.commit_called = True
  32. class GuessYouWantTests(unittest.IsolatedAsyncioTestCase):
  33. def _request(self):
  34. return SimpleNamespace(state=SimpleNamespace(user=SimpleNamespace(user_id=70430)))
  35. async def test_replaces_q_placeholders_with_real_questions(self):
  36. db = FakeDB("满堂支架施工有哪些安全技术措施和验收要求?")
  37. with patch.object(chat, "load_prompt", return_value="prompt"), patch.object(
  38. chat.qwen_service,
  39. "chat",
  40. AsyncMock(return_value='{"questions":["q1","q2","q3"]}'),
  41. ):
  42. response = await chat.guess_you_want(
  43. self._request(),
  44. chat.GuessYouWantRequest(ai_message_id=10),
  45. db=db,
  46. )
  47. questions = response["data"]["questions"]
  48. self.assertEqual(response["statusCode"], 200)
  49. self.assertEqual(len(questions), 3)
  50. self.assertTrue(all(contains_chinese(question) for question in questions))
  51. self.assertFalse(any(question.lower() in {"q1", "q2", "q3"} for question in questions))
  52. self.assertTrue(db.commit_called)
  53. self.assertEqual(json.loads(db.updated_values["guess_you_want"])["questions"], questions)
  54. async def test_replaces_problem_number_placeholders_with_real_questions(self):
  55. db = FakeDB("桥梁模板支架验收时需要重点检查哪些项目?")
  56. with patch.object(chat, "load_prompt", return_value="prompt"), patch.object(
  57. chat.qwen_service,
  58. "chat",
  59. AsyncMock(return_value='{"questions":["问题1","问题2","问题3"]}'),
  60. ):
  61. response = await chat.guess_you_want(
  62. self._request(),
  63. chat.GuessYouWantRequest(ai_message_id=10),
  64. db=db,
  65. )
  66. questions = response["data"]["questions"]
  67. self.assertEqual(len(questions), 3)
  68. self.assertTrue(all(contains_chinese(question) for question in questions))
  69. self.assertFalse(any(question in {"问题1", "问题2", "问题3"} for question in questions))
  70. async def test_replaces_prompt_leakage_with_real_questions(self):
  71. db = FakeDB("满堂支架施工方案编制、审批和验收有哪些安全管控要求?")
  72. leaked_response = "\n".join([
  73. "Thinking Process:",
  74. "**Analyze the Request:**",
  75. "**Role:** Professional question recommendation assistant focused on infrastructure construction technology (roads, bridges, tunnels, rails).",
  76. ])
  77. with patch.object(chat, "load_prompt", return_value="prompt"), patch.object(
  78. chat.qwen_service,
  79. "chat",
  80. AsyncMock(return_value=leaked_response),
  81. ):
  82. response = await chat.guess_you_want(
  83. self._request(),
  84. chat.GuessYouWantRequest(ai_message_id=10),
  85. db=db,
  86. )
  87. questions = response["data"]["questions"]
  88. self.assertEqual(len(questions), 3)
  89. self.assertTrue(all(contains_chinese(question) for question in questions))
  90. self.assertFalse(any("thinking process" in question.lower() for question in questions))
  91. self.assertFalse(any("analyze the request" in question.lower() for question in questions))
  92. self.assertFalse(any("professional question recommendation" in question.lower() for question in questions))
  93. async def test_preserves_valid_generated_questions(self):
  94. db = FakeDB("临边防护验收有哪些关键标准?")
  95. valid_questions = [
  96. "临边防护栏杆的高度和间距要求是什么?",
  97. "不同作业高度下的防护措施有何差异?",
  98. "现场验收时应重点检查哪些隐患?",
  99. ]
  100. with patch.object(chat, "load_prompt", return_value="prompt"), patch.object(
  101. chat.qwen_service,
  102. "chat",
  103. AsyncMock(return_value=json.dumps({"questions": valid_questions}, ensure_ascii=False)),
  104. ):
  105. response = await chat.guess_you_want(
  106. self._request(),
  107. chat.GuessYouWantRequest(ai_message_id=10),
  108. db=db,
  109. )
  110. self.assertEqual(response["data"]["questions"], valid_questions)
  111. if __name__ == "__main__":
  112. unittest.main()