# test_report_compat.py — regression tests for the request body that routers/report_compat.py forwards to aichat.
import importlib.util
import json
import sys
import unittest
from pathlib import Path
from types import SimpleNamespace

from models.report import ReportCompleteFlowRequest
  7. REPORT_COMPAT_PATH = Path(__file__).resolve().parents[1] / "routers" / "report_compat.py"
  8. spec = importlib.util.spec_from_file_location("report_compat_under_test", REPORT_COMPAT_PATH)
  9. report_compat = importlib.util.module_from_spec(spec)
  10. spec.loader.exec_module(report_compat)
  11. class ReportCompatProxyBodyTests(unittest.TestCase):
  12. def _request(self, user_id=70430):
  13. return SimpleNamespace(state=SimpleNamespace(user=SimpleNamespace(user_id=user_id)))
  14. def _payload_for(self, ai_conversation_id, request_user_id=70430, payload_user_id=None):
  15. request_data = ReportCompleteFlowRequest(
  16. user_question="history save regression",
  17. ai_conversation_id=ai_conversation_id,
  18. user_id=payload_user_id,
  19. )
  20. body = report_compat._build_aichat_complete_flow_body(
  21. request_data,
  22. self._request(user_id=request_user_id),
  23. )
  24. return json.loads(body.decode("utf-8"))
  25. def test_new_conversation_is_forwarded_as_zero(self):
  26. payload = self._payload_for(0)
  27. self.assertEqual(payload["ai_conversation_id"], 0)
  28. def test_missing_conversation_id_is_forwarded_as_zero(self):
  29. payload = self._payload_for(None)
  30. self.assertEqual(payload["ai_conversation_id"], 0)
  31. def test_existing_conversation_id_is_preserved(self):
  32. payload = self._payload_for(12345)
  33. self.assertEqual(payload["ai_conversation_id"], 12345)
  34. def test_request_user_id_is_forwarded_to_aichat(self):
  35. payload = self._payload_for(0, request_user_id=70430)
  36. self.assertEqual(payload["user_id"], 70430)
  37. def test_payload_user_id_takes_precedence(self):
  38. payload = self._payload_for(0, request_user_id=70430, payload_user_id=88)
  39. self.assertEqual(payload["user_id"], 88)
  40. if __name__ == "__main__":
  41. unittest.main()