Przeglądaj źródła

增加相关测试案例

lingmin_package@163.com 5 miesięcy temu
rodzic
commit
ac9bc884ee

+ 20 - 2
README.md

@@ -20,7 +20,7 @@
 
 ### 测试接口
 
-  - 生成模型接口 
+  #### 生成模型接口 
     - chat
         http://localhost:8001/test/generate/chat
         {
@@ -40,7 +40,7 @@
 
 
 
-  - agent 智能体
+  #### agent 智能体
     - chat
       http://localhost:8001/test/agent/stream
         {
@@ -60,3 +60,21 @@
         }
 
 
+  #### Workflow-Graph stream
+    - chat
+      http://localhost:8001/test/graph/stream
+        {
+          "config": {
+              "session_id":"111"
+          },
+          "input": "你好"
+        }
+
+    - agent
+      http://localhost:8001/test/graph/stream
+        {
+          "config": {
+              "session_id":"111"
+          },
+          "input": "查询信息"
+        }

+ 41 - 0
agent/function/test_funciton.py

@@ -0,0 +1,41 @@
+
+
+
class TestFunciton:
    """Demo tool functions exposed to the test agent.

    NOTE(review): the class name misspells "Function" and the module-level
    instance below misspells "function"; both are imported by
    agent/test_agent.py, so they are kept unchanged for backward compatibility.

    The bound methods are registered as agent tools, so each docstring also
    serves as the tool description presented to the model.
    """

    def __init__(self):
        # Stateless: the class only groups related tool callables.
        pass

    def query_info(self, session_id: str) -> str:
        """查询信息 (query information for a session).

        Args:
            session_id: 会话ID (conversation/session identifier).

        Returns:
            A canned lookup result tagged with the session id.
        """
        # f-string instead of "+" concatenation — same output, clearer intent.
        return f"查询结果:小红,session_id:{session_id}"

    def execute(self, session_id: str) -> str:
        """执行任务 (execute a task for a session).

        Args:
            session_id: 会话ID (conversation/session identifier).

        Returns:
            A canned completion message tagged with the session id.
        """
        return f"任务执行完成,session_id:{session_id}"

    def handle(self, session_id: str) -> str:
        """处理任务 (handle a task for a session).

        Args:
            session_id: 会话ID (conversation/session identifier).

        Returns:
            A canned handling result tagged with the session id.
        """
        return f"处理结果:小东,session_id:{session_id}"


# Shared singleton imported by agent/test_agent.py; name kept despite the typo.
test_funtion = TestFunciton()

+ 2 - 2
agent/generate/model_generate.py

@@ -52,7 +52,7 @@ class TestGenerateModelClient:
 
         # Step 2: 构建完整的 prompt 模板
         prompt_template = ChatPromptTemplate.from_messages([
-            ("system", task_prompt_info["task_prompt"]),
+            ("system", self.system_prompt), #task_prompt_info["task_prompt"]
             ("human", "{input}")
         ])
         # Step 3: 初始化模型
@@ -70,7 +70,7 @@ class TestGenerateModelClient:
 
         # Step 2: 构建完整的 prompt 模板
         prompt_template = ChatPromptTemplate.from_messages([
-            ("system", task_prompt_info["task_prompt"]),
+            ("system",  self.system_prompt), #task_prompt_info["task_prompt"]
             ("human", "{input}")
         ])
         # Step 3: 初始化模型

+ 105 - 0
agent/generate/test_intent.py

@@ -0,0 +1,105 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+'''
+@Project    : xiwu-agent-api
+@File       :intent.py
+@IDE        :PyCharm
+@Author     :LINGMIN
+@Date       :2025/7/14 12:04
+'''
+
+
+import os
+import sys
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+from logger.loggering import server_logger
+from utils.utils import get_models
+from langchain_core.prompts import SystemMessagePromptTemplate
+from langchain_core.prompts import HumanMessagePromptTemplate
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.prompts import FewShotChatMessagePromptTemplate
+from utils import yaml_utils
+from base.config import config_handler
+
+
class TestIntentIdentifyClient:
    """Few-shot intent classifier.

    Uses the deployed chat model plus a YAML-configured system prompt and
    few-shot examples (intent_prompt.yaml) to map a user question to a
    routing label such as "chat_box_generate" or "common_agent".
    """

    def __init__(self):
        """
            Create the intent-recognition client.
        """
        # Fetch the deployed models; only the chat model is used for recognition.
        llm, chat, embed = get_models()
        self.llm_recognition = chat
        # Load the intent-recognition prompt config (system prompt + examples).
        self.intent_prompt = yaml_utils.get_intent_prompt()

    def recognize_intent(self , trace_id: str , config: dict , input: str):
        """
        Recognize the intent of a user question.

        :param trace_id: log trace id (accepted but not used here yet)
        :param config: request config; must contain key "session_id"
        :param input: the user's question
        :return: the recognized intent label (see recognize_intent_history)
        """
        # NOTE(review): session_id is extracted but never used — presumably
        # intended for a per-session history lookup; confirm with the caller.
        session_id = config["session_id"]
        history = "无"  # "none": chat-history retrieval is not wired in yet
        # Classify based on (empty) history plus the user question.
        return self.recognize_intent_history(input=input , history=history)


    def recognize_intent_history(self , input: str , history="无"):
        """
        Recognize the intent of a user question given recent chat history.

        :param input: the user's question
        :param history: recent conversation context ("无" = none)
        :return: the intent label emitted by the model (raw, unvalidated —
                 callers are expected to sanity-check it, see supervisor node)
        """
        # Few-shot examples: list of {"inn": ..., "out": ...} pairs from YAML.
        examples = self.intent_prompt["intent_examples"]
        system_prompt = self.intent_prompt["system_prompt"]
        # Inject the chat history into the system prompt template.
        # NOTE(review): str.format breaks if the YAML prompt ever contains
        # other literal braces — confirm the template stays brace-free.
        system_prompt = system_prompt.format(history=history)
        server_logger.info(f"增加用户历史记录,用于意图识别,prompt配置.system_prompt: {system_prompt}")

        # Per-example template: each sample is a human/ai message pair.
        examples_prompt = ChatPromptTemplate.from_messages(
            [
                ("human", "{inn}"),
                ("ai", "{out}"),
            ]
        )
        few_shot_prompt = FewShotChatMessagePromptTemplate(example_prompt=examples_prompt,
                                                           examples=examples)
        # Final prompt: system instructions, then examples, then the question.
        final_prompt = ChatPromptTemplate.from_messages(
            [
                ('system', system_prompt),
                few_shot_prompt,
                ('human', '{input}'),
            ]
        )

        chain = final_prompt | self.llm_recognition
        server_logger.info(f"意图识别输入input: {input}")
        result = chain.invoke(input={"input": input})
        # Tolerate both message objects and plain values as model output.
        if hasattr(result, 'content'):
            # Message-like result: return its text content.
            return result.content
        else:
            # Plain value (e.g. already a string): return as-is.
            return result




# Module-level singleton; constructing it loads models and the YAML config.
intent_identify_client = TestIntentIdentifyClient()
+
+
if __name__ == '__main__':
   
    # Quick manual smoke test of the classifier.
    # NOTE(review): the first two assignments are dead code — only the last
    # value ("操作") is actually classified.
    input = "你好"
    input = "查询课程"
    input = "操作"
    result = intent_identify_client.recognize_intent_history(history="" , input=input)
    server_logger.info(f"result={result}")
    

+ 2 - 1
agent/test_agent.py

@@ -24,6 +24,7 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
 from langchain_core.runnables import RunnableConfig
 from agent.base_agent import BaseAgent
 from schemas.test_schemas import FormConfig
+from agent.function.test_funciton import test_funtion
 
 
 class TestAgentClient(BaseAgent):
@@ -79,7 +80,7 @@ class TestAgentClient(BaseAgent):
         # 创建Agent - 不再使用MemorySaver
         self.agent_executor = create_react_agent(
             self.llm,
-            tools=[] ,  # 专用工具集 + 私有知识库检索工具
+            tools=[test_funtion.query_info , test_funtion.execute , test_funtion.handle] ,  # 专用工具集 + 私有知识库检索工具
             prompt=prompt
         )
         self.initialized = True

+ 21 - 0
agent/workflow/test_cus_state.py

@@ -0,0 +1,21 @@
+
+from itertools import count
+from langgraph.graph import MessagesState
+
+
+
+
class TestCusState(MessagesState):
    """Workflow graph state (step 2: define the state schema).

    Extends MessagesState (which contributes the "messages" channel) with the
    routing and bookkeeping fields used by TestWorkflowGraph/TestWorkflowNode.
    """
    route_next: str         # next node chosen by the supervisor ("chat_box_generate" / "common_agent" / "FINISH")

    session_id: str         # conversation/session id (doubles as checkpoint thread_id)
    trace_id: str           # log trace / correlation id
    user_input: str         # raw user question
    context: str            # extra context payload passed from the request
    # FIX: was annotated as str, but callers store a dict here — the graph's
    # initial state sets {} and chat_box_generate assigns ["task_prompt"].
    task_prompt_info: dict  # task prompt payload, e.g. {"task_prompt": "..."}
+
+
+

+ 192 - 0
agent/workflow/test_workflow_graph.py

@@ -0,0 +1,192 @@
+
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+'''
+@Project    : 
+@File       :workflow_graph.py
+@IDE        :Cursor
+@Author     :LINGMIN
+@Date       :2025/08/10 18:00
+'''
+
+from agent.workflow.test_cus_state import TestCusState
+from agent.workflow.test_workflow_node import TestWorkflowNode
+from langgraph.graph import START, StateGraph, END
+from langgraph.checkpoint.memory import MemorySaver
+from logger.loggering import server_logger
+from typing import AsyncGenerator
+import time
+from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
+from utils.common import return_json, handler_err
+import json
+from schemas.test_schemas import TestForm, FormConfig
+
+
class TestWorkflowGraph:
    """Supervisor-routed workflow graph.

    Builds a LangGraph StateGraph whose supervisor node classifies the user's
    intent and routes to either a plain generation node or a tool-using agent
    node, and exposes an SSE-friendly streaming entry point.
    """
    def __init__(self):
        self.workflow_node = TestWorkflowNode()
        # In-memory checkpointing keyed by thread_id (= session_id).
        self.checkpoint_saver = MemorySaver()
        self.app = self.init_workflow_graph()
        # Persist a rendered picture of the compiled graph for inspection.
        self.write_graph()


    def init_workflow_graph(self):
        """Build and compile the workflow graph.

        Debugging tips:
          - inspect state with graph.get_state / get_state_history
          - enable debug=True for verbose logs
          - visualize with graph.get_graph().to_dot()
        """
        # Create the state graph over the custom state schema.
        workflow = StateGraph(TestCusState)

        ###### Branch: supervisor + member agents ##############################
        # Supervisor node: decides which member node runs next.
        workflow.add_node("supervisor_agent", self.workflow_node.supervisor_agent)
        # Member node 1: pure generation (chat) questions.
        workflow.add_node("chat_box_generate", self.workflow_node.chat_box_generate)
        # Member node 2: generic tool-using agent.
        workflow.add_node("common_agent", self.workflow_node.common_agent_node)

        ###### Edges ###########################################################
        # Every run starts at the supervisor.
        workflow.add_edge(START, "supervisor_agent")
        # The supervisor fills state["route_next"]; route on that value with an
        # explicit mapping from each possible value to its target node.
        workflow.add_conditional_edges(
            source="supervisor_agent",
            path=lambda state: state["route_next"],
            path_map={
                "chat_box_generate": "chat_box_generate",
                "common_agent": "common_agent",
            },
        )

        supervisor_members_list = ["chat_box_generate", "common_agent"]
        # Each member node terminates the run when it finishes.
        # (Alternative design: edge back to "supervisor_agent" to keep routing.)
        for agent_member in supervisor_members_list:
            workflow.add_edge(agent_member, END)

        # Compile the graph with checkpointing enabled.
        app = workflow.compile(checkpointer=self.checkpoint_saver)
        server_logger.info(f"【图工作流构建完成】app={app}")
        return app


    async def handle_query_stream(self, param: TestForm, trace_id: str) -> AsyncGenerator[str, None]:
        """Stream the workflow's answer for one request (SSE-style chunks).

        Args:
            param: request form (input text, config.session_id, context).
            trace_id: log trace / correlation id.

        Yields:
            Buffered text chunks of the model's answer; on failure, a single
            JSON string with an "error" key.
        """
        try:
            # Extract request parameters.
            user_input = param.input
            session_id = param.config.session_id
            context = param.context

            human_messages = [HumanMessage(content=user_input)]
            # Full initial graph state (see TestCusState).
            initial_state = {
                "messages": human_messages,
                "session_id": session_id,
                "trace_id": trace_id,
                "task_prompt_info": {},
                "context": context,
                "user_input": user_input,
            }
            # Per-run config: session_id doubles as the checkpoint thread_id.
            # FIX: recursion_limit is a top-level RunnableConfig key; it was
            # previously nested under a non-existent "runnable_kwargs" key and
            # therefore had no effect.
            config = {
                "configurable": {"thread_id": session_id},
                "recursion_limit": 50,
            }
            server_logger.info("======================== 启动新任务 ===========================")

            full_response = []
            buffer = []
            last_flush_time = time.time()
            events = self.app.astream_events(
                initial_state,
                config=config,
                version="v1",  # event-schema version; "v2" exists but changes payload shapes
                stream_mode="values",
            )
            # Stream events as they arrive.
            async for event in events:
                # Forward token chunks produced by the chat model.
                if event['event'] == 'on_chat_model_stream':
                    if 'chunk' in event['data']:
                        chunk = event['data']['chunk']
                        if hasattr(chunk, 'content'):
                            content = chunk.content
                            full_response.append(content)

                            # Buffering: batch tiny token chunks into larger
                            # SSE frames to reduce event overhead.
                            buffer.append(content)
                            current_time = time.time()

                            # Flush when enough chunks accumulated, on timeout,
                            # or at a natural sentence boundary.
                            # (FIX: loop variable renamed from "content" to
                            # avoid shadowing the current token above.)
                            should_flush = (
                                len(buffer) >= 3 or
                                (current_time - last_flush_time) > 0.5 or
                                any(piece.endswith(('.', '。', '!', '?', '\n', ';', ';', '?', '!')) for piece in buffer)
                            )

                            if should_flush:
                                combined = ''.join(buffer)
                                yield combined

                                buffer.clear()
                                last_flush_time = current_time

            # Flush whatever is left in the buffer.
            if buffer:
                yield ''.join(buffer)

            # Log the fully assembled answer for tracing.
            if full_response:
                full_text = "".join(full_response)
                server_logger.info(trace_id=trace_id, msg=f"full_response: {full_text}", log_type="graph/stream")

        except Exception as e:
            handler_err(server_logger, trace_id=trace_id, err=e, err_name='graph/stream')
            yield json.dumps({"error": f"系统错误: {str(e)}"})


    def write_graph(self, path: str = "build_graph_app.png"):
        """Render the compiled graph to a PNG file.

        Args:
            path: output file path (default kept for backward compatibility).
        """
        graph_png = self.app.get_graph().draw_mermaid_png()
        with open(path, "wb") as f:
            f.write(graph_png)
        server_logger.info(f"【图工作流写入文件完成】")


# Module-level singleton. NOTE(review): instantiating at import time builds
# the graph AND writes build_graph_app.png as a side effect.
test_workflow_graph = TestWorkflowGraph()

+ 110 - 0
agent/workflow/test_workflow_node.py

@@ -0,0 +1,110 @@
+
+
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+'''
+@Project    : 
+@File       :workflow_node.py
+@IDE        :Cursor
+@Author     :LINGMIN
+@Date       :2025/08/10 18:00
+'''
+
+
+import json
+import sys
+from logger.loggering import server_logger
+from utils.common import handler_err
+from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
+from langchain_core.prompts import ChatPromptTemplate
+from agent.workflow.test_cus_state import TestCusState
+from agent.generate.test_intent import intent_identify_client
+from agent.test_agent import test_agent_client
+from schemas.test_schemas import FormConfig
+from agent.generate.model_generate import test_generate_model_client
+
+
+
+
class TestWorkflowNode:
    """
        Workflow node definitions: the callables registered on the graph
        built by TestWorkflowGraph. Each node takes the shared TestCusState
        and returns a partial state update dict.
    """
    def __init__(self):
        """Initialize model/session management (currently stateless)."""

    

    def supervisor_agent(self , state: TestCusState):
        """
            Supervisor node: every member agent communicates with this node,
            and it decides which member node should be invoked next.
            :param state: current graph state
            :return: partial state update carrying the routing decision
        """
        session_id = state["session_id"]
        trace_id = state["trace_id"]
        user_input = state["user_input"]
        # Previous routing decision, if any — logged for tracing only.
        route_next = state.get("route_next")
        
        server_logger.info(trace_id=trace_id, msg=f"\n===================================[Supervisor].begin-route_next:{route_next}=============================")
        
        config = {
            "session_id": session_id
        }
        # Classify the user's question into a routing label.
        route_next = intent_identify_client.recognize_intent(trace_id=trace_id , config=config , input=user_input)
        server_logger.info(trace_id=trace_id, msg=f"[Supervisor].intent_identify_client.recognize_intent:{route_next}")
        # Guard against unexpected model output: fall back to plain generation.
        if route_next not in ["chat_box_generate" , "common_agent"]:
            route_next = "chat_box_generate"

        
        server_logger.info(trace_id=trace_id, msg=f"\n===================================[Supervisor].end-route_next:{route_next}=============================")
        return {
            "route_next": route_next
        }



    async def common_agent_node(self , state: TestCusState):
        """
            Generic agent node: delegates the question to the test agent
            client (tool-using ReAct agent).
            :param state: current graph state
            :return: partial state update with the agent's answer message
        """
        session_id = state["session_id"]
        trace_id = state["trace_id"]
        user_input = state["user_input"]
        config_param = FormConfig(session_id=session_id)
        task_prompt_info = {"task_prompt": ""}
        response_content = await test_agent_client.handle_query(trace_id=trace_id , config_param=config_param, 
                                                                task_prompt_info=task_prompt_info, 
                                                                input_query=user_input, context=None)
        messages = [AIMessage(content=response_content, name="common_agent_node")]
        return {
            "messages": messages,
            "previous_agent": "common_agent",
            "route_next": "FINISH"   # end the run directly (members edge to END)
        }
    

    def chat_box_generate(self , state: TestCusState) -> dict:
        """
            Model generation node (pure generation / chat questions).
            :param state: current graph state
            :return: partial state update with the generated answer message
        """
        session_id = state["session_id"]
        trace_id = state["trace_id"]
        user_input = state["user_input"]
        # NOTE(review): mutates the shared task_prompt_info dict in place
        # (assumes the graph initialized it to a dict) — confirm intended.
        task_prompt_info = state["task_prompt_info"]
        task_prompt_info["task_prompt"] = ""
        response_content = test_generate_model_client.get_model_generate_invoke(trace_id=trace_id , task_prompt_info=task_prompt_info, input_query=user_input)
        messages = [AIMessage(content=response_content , name="chat_box_generate")]
        server_logger.info(trace_id=trace_id, msg=f"【result】: {response_content}", log_type="chat_box_generate")
        return {
            "messages": messages,
            "route_next": "FINISH"   # end the run directly (members edge to END)
        }
+
+
+

BIN
build_graph_app.png


+ 22 - 0
config/prompt/intent_prompt.yaml

@@ -0,0 +1,22 @@
+
+# 系统提示词
+system_prompt: |
+  基于提供的样例,结合用户最近的对话历史上下文进行意图识别,精准匹配对应的业务场景指令。
+  必须优先参考最近的上下文语义及用户意图演变,若问题与样例中的任一业务场景相符,则返回对应指令;若无法匹配任何已定义场景,则返回 chat_box_generate。
+  严格遵守:仅输出指令字符串,不附加任何解释、说明或格式。
+  用户目前历史上下文信息:
+  {history}
+
+
+
+
+# 意图案例 准备few-shot样例;
+intent_examples: 
+  - inn: 你好;咨询.
+    out: chat_box_generate
+
+  - inn: 执行;操作;查询;处理;
+    out: common_agent
+
+
+           

+ 22 - 0
utils/yaml_utils.py

@@ -65,5 +65,27 @@ def get_yaml_file_path(file_name: str) -> str:
 
 
 
+
def get_intent_prompt() -> dict:
    """Load the intent-recognition prompt configuration.

    Returns:
        dict with keys "system_prompt" (str template containing a {history}
        slot) and "intent_examples" (list of few-shot {inn, out} pairs),
        parsed from config/prompt/intent_prompt.yaml.

    Raises:
        Re-raises any error from locating, reading, or parsing the YAML file
        (after logging it).
    """
    # Resolve the config file path (existence checked by the helper).
    yaml_file = get_yaml_file_path("intent_prompt.yaml")

    try:
        with open(yaml_file, 'r', encoding='utf-8') as f:
            prompt_config = yaml.safe_load(f)
        # FIX: the original f-strings reused double quotes for the dict keys
        # inside a double-quoted f-string — a SyntaxError on Python < 3.12.
        # Single-quoted keys are valid on all supported versions.
        server_logger.info(f"成功加载[意图识别]系统.system_prompt配置: {prompt_config['system_prompt']}")
        server_logger.info(f"成功加载[意图识别]系统配置.examples: {prompt_config['intent_examples']}")
        return prompt_config
        
    except Exception as e:
        server_logger.error(f"加载意图识别intent_prompt文件失败: {yaml_file}, 错误: {str(e)}")
        raise
+
+
 #获取系统提示语
 system_prompt_config = get_system_prompt()

+ 80 - 0
views/test_views.py

@@ -20,6 +20,7 @@ from logger.loggering import server_logger
 from schemas.test_schemas import TestForm
 from utils.common import return_json, handler_err
 from views import test_router, get_operation_id
+from agent.workflow.test_workflow_graph import test_workflow_graph
 
 
 
@@ -253,3 +254,82 @@ async def chat_agent_stream(param: TestForm,
             return_json(code=1, msg=f"{err}", trace_id=trace_id),
             status_code=500
         )
+
+
+
+@test_router.post("/graph/stream", response_class=Response)
+async def chat_graph_stream(param: TestForm,
+                     trace_id: str = Depends(get_operation_id)):
+    """
+        根据场景获取智能体反馈 (SSE流式响应)
+    """
+    try:
+        server_logger.info(trace_id=trace_id, msg=f"{param}")
+        # request_param = {
+        #     "input": param.input,
+        #     "config": param.config,
+        #     "context": param.context
+        # }
+        # 创建 SSE 流式响应 
+        async def event_generator():
+                try:
+                    # 流式处理查询
+                    async for chunk in test_workflow_graph.handle_query_stream(
+                            param=param,
+                            trace_id=trace_id,
+                    ):
+                        server_logger.debug(trace_id=trace_id, msg=f"{chunk}")
+                        # 发送数据块
+                        yield {
+                            "event": "message",
+                            "data": json.dumps({
+                                "code": 0,
+                                "output": chunk,
+                                "completed": False,
+                                "trace_id": trace_id,
+                                "dataType": "text"
+                            }, ensure_ascii=False)
+                        }
+
+                    # 发送结束事件
+                    yield {
+                        "event": "message_end",
+                        "data": json.dumps({
+                            "completed": True,
+                            "message": "Stream completed",
+                            "code": 0,
+                            "trace_id": trace_id,
+                            "dataType": "text"
+                        }, ensure_ascii=False),
+                    }
+                except Exception as e:
+                    # 错误处理
+                    yield {
+                        "event": "error",
+                        "data": json.dumps({
+                            "trace_id": trace_id,
+                            "msg": str(e),
+                            "code": 1,
+                            "dataType": "text"
+                        }, ensure_ascii=False)
+                    }
+                finally:
+                    # 不需要关闭客户端,因为它是单例
+                    pass
+
+        # 返回 SSE 响应
+        return EventSourceResponse(
+            event_generator(),
+            headers={
+                "Cache-Control": "no-cache",
+                "Connection": "keep-alive"
+            }
+        )
+        
+    except Exception as err:
+        # 初始错误处理
+        handler_err(server_logger, trace_id=trace_id, err=err, err_name="graph/stream")
+        return JSONResponse(
+            return_json(code=1, msg=f"{err}", trace_id=trace_id),
+            status_code=500
+        )