test_workflow_graph.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. # !/usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. '''
  4. @Project :
  5. @File :workflow_graph.py
  6. @IDE :Cursor
  7. @Author :LINGMIN
  8. @Date :2025/08/10 18:00
  9. '''
  10. from foundation.agent.workflow.test_cus_state import TestCusState
  11. from foundation.agent.workflow.test_workflow_node import TestWorkflowNode
  12. from langgraph.graph import START, StateGraph, END
  13. from langgraph.checkpoint.memory import MemorySaver
  14. from foundation.logger.loggering import server_logger
  15. from typing import AsyncGenerator
  16. import time
  17. from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
  18. from foundation.utils.common import return_json, handler_err
  19. import json
  20. from foundation.schemas.test_schemas import TestForm, FormConfig
  21. class TestWorkflowGraph:
  22. """
  23. 工作流图
  24. """
  25. def __init__(self):
  26. self.workflow_node = TestWorkflowNode()
  27. self.checkpoint_saver = MemorySaver()
  28. self.app = self.init_workflow_graph()
  29. # 将生成的图片保存到文件
  30. self.write_graph()
  31. def init_workflow_graph(self):
  32. """
  33. 初始化工作流图
  34. 使用 graph.get_state 和 get_state_history 检查状态。
  35. 启用 debug=True 查看详细日志。
  36. 使用 graph.get_graph().to_dot() 可视化状态图。
  37. """
  38. # 构建工作流图 创建状态图 , state_update_method="merge"
  39. workflow = StateGraph(TestCusState)
  40. ######分支2、代理Agent supervisor_agent ##################################
  41. # 节点: 代理 agent 节点
  42. workflow.add_node("supervisor_agent", self.workflow_node.supervisor_agent)
  43. # agent节点1: 纯生成类问题
  44. workflow.add_node("chat_box_generate", self.workflow_node.chat_box_generate)
  45. # agent节点2:
  46. workflow.add_node("common_agent", self.workflow_node.common_agent_node)
  47. ###### 节点分支线条 ##################################
  48. # 固定问题识别
  49. workflow.add_edge(START, "supervisor_agent")
  50. # 在图状态中填充 ‘next’字段,路由到具体的某个节点或结束图的运行,从来指定如何执行接下来的任务。
  51. workflow.add_conditional_edges(source="supervisor_agent",
  52. path=lambda state: state["route_next"],
  53. # 显式映射每个返回值到目标节点
  54. path_map={
  55. "chat_box_generate": "chat_box_generate",
  56. "common_agent": "common_agent",
  57. }
  58. )
  59. supervisor_members_list = ["chat_box_generate" , "common_agent"]
  60. # 每个子代理 在完成后总是向主管 “汇报”
  61. for agent_member in supervisor_members_list:
  62. workflow.add_edge(agent_member, END) # 直接结束
  63. #workflow.add_edge(agent_member, "supervisor_agent") # 回到路由 继续 判断执行
  64. #编译图
  65. app = workflow.compile(checkpointer=self.checkpoint_saver)
  66. #print(app.get_graph().draw_ascii())
  67. server_logger.info(f"【图工作流构建完成】app={app}")
  68. return app
  69. async def handle_query_stream(self, param: TestForm, trace_id: str)-> AsyncGenerator[str, None]:
  70. """
  71. 根据场景获取智能体反馈 (SSE流式响应)
  72. """
  73. try:
  74. # 提取参数
  75. user_input = param.input
  76. session_id = param.config.session_id
  77. context = param.context
  78. human_messages = [HumanMessage(content=user_input)]
  79. # 完整的初始状态
  80. initial_state = {
  81. "messages": human_messages,
  82. "session_id": session_id, # 会话id
  83. "trace_id": trace_id, # 日志链路跟踪id
  84. "task_prompt_info": {},
  85. "context": context , # 上下文数据
  86. "user_input": user_input,
  87. }
  88. # 唯一的任务 ID(模拟 session_id / thread_id)
  89. config = {"configurable": {"thread_id": session_id},
  90. "runnable_kwargs":{"recursion_limit": 50}
  91. }
  92. server_logger.info("======================== 启动新任务 ===========================") #, interrupt_before=["user_confirm_task_planning"]
  93. full_response = []
  94. buffer = []
  95. last_flush_time = time.time()
  96. events = self.app.astream_events(initial_state,
  97. config=config ,
  98. version="v1", # 确保使用正确版本
  99. stream_mode="values" # 或者 "updates"
  100. )
  101. # 流式处理事件
  102. async for event in events:
  103. #server_logger.info(trace_id=trace_id, msg=f"→ 事件类型: {event['event']}")
  104. #server_logger.info(trace_id=trace_id, msg=f"→ 事件数据: {event['data']}")
  105. # 处理聊天模型流式输出
  106. if event['event'] == 'on_chat_model_stream':
  107. if 'chunk' in event['data']:
  108. chunk = event['data']['chunk']
  109. if hasattr(chunk, 'content'):
  110. content = chunk.content
  111. full_response.append(content)
  112. # 缓冲管理策略
  113. buffer.append(content)
  114. current_time = time.time()
  115. # 刷新条件
  116. should_flush = (
  117. len(buffer) >= 3 or # 达到最小块数
  118. (current_time - last_flush_time) > 0.5 or # 超时
  119. any(content.endswith(('.', '。', '!', '?', '\n', ';', ';', '?', '!')) for content in buffer) # 自然断点
  120. )
  121. if should_flush:
  122. combined = ''.join(buffer)
  123. yield combined
  124. buffer.clear()
  125. last_flush_time = current_time
  126. # 也可以处理其他类型的事件
  127. # elif event['event'] == 'on_chain_stream':
  128. # server_logger.info(trace_id=trace_id, msg=f"链式处理: {event['data']}")
  129. # elif event['event'] == 'on_tool_stream':
  130. # server_logger.info(trace_id=trace_id, msg=f"工具调用: {event['data']}")
  131. # 处理剩余缓冲内容
  132. if buffer:
  133. yield ''.join(buffer)
  134. # 将完整响应添加到历史并进行压缩
  135. if full_response:
  136. full_text = "".join(full_response)
  137. server_logger.info(trace_id=trace_id, msg=f"full_response: {full_text}", log_type="graph/stream")
  138. except Exception as e:
  139. handler_err(server_logger, trace_id=trace_id, err=e, err_name='graph/stream')
  140. yield json.dumps({"error": f"系统错误: {str(e)}"})
  141. def write_graph(self):
  142. """
  143. 将图写入文件
  144. """
  145. #
  146. graph_png = self.app.get_graph().draw_mermaid_png()
  147. with open("build_graph_app.png", "wb") as f:
  148. f.write(graph_png)
  149. server_logger.info(f"【图工作流写入文件完成】")
  150. # 实例化
  151. test_workflow_graph = TestWorkflowGraph()