test_workflow_node.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. # !/usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. '''
  4. @Project :
  5. @File :workflow_node.py
  6. @IDE :Cursor
  7. @Author :LINGMIN
  8. @Date :2025/08/10 18:00
  9. '''
  10. import json
  11. import sys
  12. from foundation.logger.loggering import server_logger
  13. from foundation.utils.common import handler_err
  14. from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
  15. from langchain_core.prompts import ChatPromptTemplate
  16. from foundation.agent.workflow.test_cus_state import TestCusState
  17. from foundation.agent.generate.test_intent import intent_identify_client
  18. from foundation.agent.test_agent import test_agent_client
  19. from foundation.schemas.test_schemas import FormConfig
  20. from foundation.agent.generate.model_generate import generate_model_client
  21. from foundation.utils.yaml_utils import system_prompt_config
  22. class TestWorkflowNode:
  23. """
  24. 工作流节点定义
  25. """
  def __init__(self):
      """
      Construct the workflow-node container.

      No instance attributes are created here; the constructor is
      intentionally empty.
      """
  def supervisor_agent(self , state: TestCusState):
      """
      Supervisor node: choose which agent node handles the current turn.

      The user's input is classified by ``intent_identify_client``. The
      result is accepted only when it names a known node
      ("chat_box_generate" or "common_agent"); anything else, including
      ``None``, falls back to "chat_box_generate".

      :param state: graph state; ``session_id``, ``trace_id`` and
          ``user_input`` are required, ``route_next`` is optional and only
          appears in the "begin" log line.
      :return: partial state update ``{"route_next": <node name>}``.
      """
      session_id, trace_id, user_input = (
          state["session_id"],
          state["trace_id"],
          state["user_input"],
      )
      incoming_route = state.get("route_next")
      server_logger.info(trace_id=trace_id, msg=f"\n===================================[Supervisor].begin-route_next:{incoming_route}=============================")

      # Ask the intent classifier which node should handle this input.
      detected_route = intent_identify_client.recognize_intent(
          trace_id=trace_id,
          config={"session_id": session_id},
          input=user_input,
      )
      server_logger.info(trace_id=trace_id, msg=f"[Supervisor].intent_identify_client.recognize_intent:{detected_route}")

      # Unrecognised classifications are routed to plain generation.
      known_routes = ("chat_box_generate", "common_agent")
      next_route = detected_route if detected_route in known_routes else "chat_box_generate"

      server_logger.info(trace_id=trace_id, msg=f"\n===================================[Supervisor].end-route_next:{next_route}=============================")
      return {"route_next": next_route}
  async def common_agent_node(self , state: TestCusState):
      """
      General-purpose agent node.

      Sends the user's query to ``test_agent_client.handle_query`` with an
      empty task prompt and no extra context, then wraps the reply in an
      ``AIMessage`` named "common_agent_node".

      :param state: graph state; reads ``session_id``, ``trace_id`` and
          ``user_input``.
      :return: partial state update containing ``messages``,
          ``previous_agent`` = "common_agent" and ``route_next`` = "FINISH"
          (the original author's note: end the flow directly).
      """
      session_id = state["session_id"]
      trace_id = state["trace_id"]
      query = state["user_input"]

      reply = await test_agent_client.handle_query(
          trace_id=trace_id,
          config_param=FormConfig(session_id=session_id),
          task_prompt_info={"task_prompt": ""},
          input_query=query,
          context=None,
      )

      return {
          "messages": [AIMessage(content=reply, name="common_agent_node")],
          "previous_agent": "common_agent",
          "route_next": "FINISH",  # end the flow directly
      }
  async def chat_box_generate(self , state: TestCusState) -> dict:
      """
      Plain-generation node: answer the user directly with the chat model.

      Builds a chat prompt from ``system_prompt_config['system_prompt']`` and
      the user's input, then hands it to
      ``generate_model_client.get_model_generate_invoke``.

      The user's text is added as a literal ``HumanMessage`` instead of a
      ``("user", ...)`` template tuple. Template tuples are parsed for
      ``{variable}`` placeholders, so user input containing braces (JSON,
      code snippets, ...) would otherwise be treated as template variables
      and break prompt formatting.

      :param state: graph state; reads ``trace_id`` and ``user_input``.
      :return: partial state update with ``messages`` (one ``AIMessage``
          named "chat_box_generate") and ``route_next`` = "FINISH".
      """
      trace_id = state["trace_id"]
      user_input = state["user_input"]

      # The system prompt stays a template (it may carry placeholders that are
      # filled downstream); the user turn is a concrete message and is never
      # parsed for placeholders.
      template = ChatPromptTemplate.from_messages([
          ("system", system_prompt_config['system_prompt']),
          HumanMessage(content=user_input),
      ])
      # A fresh dict is passed to the model client; the incoming state's
      # ``task_prompt_info`` is deliberately left untouched.
      task_prompt_info = {"task_prompt": template}
      response_content = await generate_model_client.get_model_generate_invoke(trace_id=trace_id , task_prompt_info=task_prompt_info)

      messages = [AIMessage(content=response_content , name="chat_box_generate")]
      server_logger.info(trace_id=trace_id, msg=f"【result】: {response_content}", log_type="chat_box_generate")
      return {
          "messages": messages,
          "route_next": "FINISH"  # end the flow directly
      }
  92. "messages": messages,
  93. "route_next": "FINISH" # ✅ 直接结束流程
  94. }