| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144 |
import importlib.util
import json
import sys
import unittest
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
# Load routers/chat.py straight from its file path under the private module
# name "chat_under_test", so the tests exercise that exact file.
CHAT_PATH = Path(__file__).resolve().parents[1] / "routers" / "chat.py"
spec = importlib.util.spec_from_file_location("chat_under_test", CHAT_PATH)
if spec is None or spec.loader is None:
    # Fail with a clear message instead of an AttributeError on `spec.loader`.
    raise ImportError(f"cannot build an import spec for {CHAT_PATH}")
chat = importlib.util.module_from_spec(spec)
# Register before executing, as the importlib "import a source file directly"
# recipe does, so anything that resolves the module through sys.modules while
# it is being executed (e.g. annotation/forward-reference resolution) finds it.
sys.modules[spec.name] = chat
spec.loader.exec_module(chat)
def contains_chinese(text):
    """Report whether *text* has at least one CJK Unified Ideograph (U+4E00..U+9FFF).

    Falsy input such as ``None`` or ``""`` is treated as containing no Chinese.
    """
    if not text:
        return False
    for ch in text:
        if "\u4e00" <= ch <= "\u9fff":
            return True
    return False
class FakeQuery:
    """Minimal query stub supporting the ``filter``/``first``/``update`` chain.

    All state lives on the owning fake DB object passed to the constructor.
    """

    def __init__(self, db):
        self.db = db

    def filter(self, *args, **kwargs):
        """Ignore every criterion and keep chaining on this same query."""
        return self

    def first(self):
        """Return the single canned message stored on the owning DB."""
        owner = self.db
        return owner.ai_message

    def update(self, values):
        """Record *values* on the owning DB and report one affected row."""
        self.db.updated_values = values
        affected_rows = 1
        return affected_rows
class FakeDB:
    """In-memory session stub holding one AI message and recording writes."""

    def __init__(self, content):
        message = SimpleNamespace(id=10, content=content, is_deleted=0)
        self.commit_called = False
        # Last dict handed to FakeQuery.update(), or None if nothing was updated.
        self.updated_values = None
        self.ai_message = message

    def query(self, model):
        """Return a query bound to this DB; *model* is intentionally ignored."""
        return FakeQuery(self)

    def commit(self):
        """Remember that a commit happened."""
        self.commit_called = True
class GuessYouWantTests(unittest.IsolatedAsyncioTestCase):
    """Tests for ``chat.guess_you_want`` cleaning up unusable LLM suggestions.

    Each test feeds a canned model reply through a mocked ``qwen_service.chat``
    (with ``load_prompt`` stubbed) and inspects the questions the endpoint
    returns for a ``FakeDB`` holding a single AI message with id 10.
    """

    def _request(self):
        """Return a request stub whose ``state.user.user_id`` is set."""
        return SimpleNamespace(state=SimpleNamespace(user=SimpleNamespace(user_id=70430)))

    async def _call_guess_you_want(self, content, llm_reply):
        """Invoke the endpoint for a message with *content* and a canned LLM reply.

        Args:
            content: Text stored on the fake AI message.
            llm_reply: Raw string the mocked ``qwen_service.chat`` resolves to.

        Returns:
            ``(db, response)``: the ``FakeDB`` (to inspect commits/updates) and
            the endpoint's response dict.
        """
        db = FakeDB(content)
        llm_mock = AsyncMock(return_value=llm_reply)
        with patch.object(chat, "load_prompt", return_value="prompt"), patch.object(
            chat.qwen_service,
            "chat",
            llm_mock,
        ):
            response = await chat.guess_you_want(
                self._request(),
                chat.GuessYouWantRequest(ai_message_id=10),
                db=db,
            )
        # Without this, a code path that never consults the model would let the
        # placeholder/leakage tests below pass vacuously.
        llm_mock.assert_awaited()
        return db, response

    def _assert_ok_with_three_chinese_questions(self, response):
        """Assert a 200 response carrying exactly three Chinese questions; return them."""
        self.assertEqual(response["statusCode"], 200)
        questions = response["data"]["questions"]
        self.assertEqual(len(questions), 3)
        for question in questions:
            # subTest pinpoints the offending question instead of a bare all() failure.
            with self.subTest(question=question):
                self.assertTrue(contains_chinese(question))
        return questions

    async def test_replaces_q_placeholders_with_real_questions(self):
        db, response = await self._call_guess_you_want(
            "满堂支架施工有哪些安全技术措施和验收要求?",
            '{"questions":["q1","q2","q3"]}',
        )
        questions = self._assert_ok_with_three_chinese_questions(response)
        self.assertFalse(any(question.lower() in {"q1", "q2", "q3"} for question in questions))
        # The sanitized list, not the raw placeholders, must be what gets stored.
        self.assertTrue(db.commit_called)
        self.assertEqual(json.loads(db.updated_values["guess_you_want"])["questions"], questions)

    async def test_replaces_problem_number_placeholders_with_real_questions(self):
        _, response = await self._call_guess_you_want(
            "桥梁模板支架验收时需要重点检查哪些项目?",
            '{"questions":["问题1","问题2","问题3"]}',
        )
        questions = self._assert_ok_with_three_chinese_questions(response)
        self.assertFalse(any(question in {"问题1", "问题2", "问题3"} for question in questions))

    async def test_replaces_prompt_leakage_with_real_questions(self):
        leaked_response = "\n".join([
            "Thinking Process:",
            "**Analyze the Request:**",
            "**Role:** Professional question recommendation assistant focused on infrastructure construction technology (roads, bridges, tunnels, rails).",
        ])
        _, response = await self._call_guess_you_want(
            "满堂支架施工方案编制、审批和验收有哪些安全管控要求?",
            leaked_response,
        )
        questions = self._assert_ok_with_three_chinese_questions(response)
        for marker in ("thinking process", "analyze the request", "professional question recommendation"):
            self.assertFalse(any(marker in question.lower() for question in questions), marker)

    async def test_preserves_valid_generated_questions(self):
        valid_questions = [
            "临边防护栏杆的高度和间距要求是什么?",
            "不同作业高度下的防护措施有何差异?",
            "现场验收时应重点检查哪些隐患?",
        ]
        _, response = await self._call_guess_you_want(
            "临边防护验收有哪些关键标准?",
            json.dumps({"questions": valid_questions}, ensure_ascii=False),
        )
        self.assertEqual(response["statusCode"], 200)
        self.assertEqual(response["data"]["questions"], valid_questions)
if __name__ == "__main__":
    # Running this file directly hands control to unittest's command-line runner.
    unittest.main()
|