test_workflow_node.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
#!/usr/bin/python
# -*- coding: utf-8 -*-
'''
@Project :
@File :test_workflow_node.py
@IDE :Cursor
@Author :LINGMIN
@Date :2025/08/10 18:00
'''
  10. import json
  11. import sys
  12. from foundation.logger.loggering import server_logger
  13. from foundation.utils.common import handler_err
  14. from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
  15. from langchain_core.prompts import ChatPromptTemplate
  16. from foundation.agent.workflow.test_cus_state import TestCusState
  17. from foundation.agent.generate.test_intent import intent_identify_client
  18. from foundation.agent.test_agent import test_agent_client
  19. from foundation.schemas.test_schemas import FormConfig
  20. from foundation.agent.generate.model_generate import generate_model_client
  21. class TestWorkflowNode:
  22. """
  23. 工作流节点定义
  24. """
  25. def __init__(self):
  26. """初始化模型和会话管理"""
  27. def supervisor_agent(self , state: TestCusState):
  28. """
  29. 每个代理都与一个 Supervisor 代理通信(主管代理)。由 Supervisor 代理决定接下来应调用哪个代理
  30. :param state:
  31. :return:
  32. """
  33. session_id = state["session_id"]
  34. trace_id = state["trace_id"]
  35. user_input = state["user_input"]
  36. route_next = state.get("route_next")
  37. server_logger.info(trace_id=trace_id, msg=f"\n===================================[Supervisor].begin-route_next:{route_next}=============================")
  38. config = {
  39. "session_id": session_id
  40. }
  41. # 格式化输出,智能格式化输出
  42. route_next = intent_identify_client.recognize_intent(trace_id=trace_id , config=config , input=user_input)
  43. server_logger.info(trace_id=trace_id, msg=f"[Supervisor].intent_identify_client.recognize_intent:{route_next}")
  44. if route_next not in ["chat_box_generate" , "common_agent"]:
  45. route_next = "chat_box_generate"
  46. server_logger.info(trace_id=trace_id, msg=f"\n===================================[Supervisor].end-route_next:{route_next}=============================")
  47. return {
  48. "route_next": route_next
  49. }
  50. async def common_agent_node(self , state: TestCusState):
  51. """
  52. 通用代理节点
  53. :param state:
  54. :return:
  55. """
  56. session_id = state["session_id"]
  57. trace_id = state["trace_id"]
  58. user_input = state["user_input"]
  59. config_param = FormConfig(session_id=session_id)
  60. task_prompt_info = {"task_prompt": ""}
  61. response_content = await test_agent_client.handle_query(trace_id=trace_id , config_param=config_param,
  62. task_prompt_info=task_prompt_info,
  63. input_query=user_input, context=None)
  64. messages = [AIMessage(content=response_content, name="common_agent_node")]
  65. return {
  66. "messages": messages,
  67. "previous_agent": "common_agent",
  68. "route_next": "FINISH" # ✅ 直接结束流程
  69. }
  70. def chat_box_generate(self , state: TestCusState) -> dict:
  71. """
  72. 模型生成节点(纯生成类问题)
  73. :param state:
  74. :return:
  75. """
  76. session_id = state["session_id"]
  77. trace_id = state["trace_id"]
  78. user_input = state["user_input"]
  79. task_prompt_info = state["task_prompt_info"]
  80. task_prompt_info["task_prompt"] = ""
  81. response_content = generate_model_client.get_model_generate_invoke(trace_id=trace_id , task_prompt_info=task_prompt_info, input_query=user_input)
  82. messages = [AIMessage(content=response_content , name="chat_box_generate")]
  83. server_logger.info(trace_id=trace_id, msg=f"【result】: {response_content}", log_type="chat_box_generate")
  84. return {
  85. "messages": messages,
  86. "route_next": "FINISH" # ✅ 直接结束流程
  87. }