# !/usr/bin/ python
# -*- coding: utf-8 -*-
'''
@Project    : lq-agent-api
@File       :cattle_farm_views.py
@IDE        :PyCharm
@Author     :
@Date       :2025/7/10 17:32
'''
import json
from typing import Optional

from fastapi import Depends, Request, Response, Header
from sse_starlette import EventSourceResponse
from starlette.concurrency import iterate_in_threadpool
from starlette.responses import JSONResponse

from foundation.agent.test_agent import test_agent_client
from foundation.agent.generate.model_generate import generate_model_client
from foundation.logger.loggering import server_logger
from foundation.schemas.test_schemas import TestForm
from foundation.utils.common import return_json, handler_err
from views import test_router, get_operation_id
from foundation.agent.workflow.test_workflow_graph import test_workflow_graph
from foundation.base.mysql.async_mysql_base_dao import TestTabDAO
from database.repositories.bus_data_query import BasisOfPreparationDAO
from foundation.utils.tool_utils import DateTimeEncoder
from foundation.models.silicon_flow import SiliconFlowAPI
from foundation.rag.vector.pg_vector import PGVectorDB
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from foundation.utils.yaml_utils import system_prompt_config


# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------

def _text_response(trace_id: str, output) -> JSONResponse:
    """Wrap ``output`` in the project's standard success envelope (data_type="text")."""
    return JSONResponse(
        return_json(data={"output": output}, data_type="text", trace_id=trace_id))


def _error_response(trace_id: str, err: Exception, err_name: str,
                    code: int = 100500, status_code: int = 200) -> JSONResponse:
    """Log ``err`` through ``handler_err`` and return it in the standard error envelope.

    Args:
        trace_id: Request trace id, echoed back in the response body.
        err: The exception that aborted the request.
        err_name: Label passed to ``handler_err`` for the log entry.
        code: Business error code placed in the response body.
        status_code: HTTP status of the response (200 unless overridden).
    """
    handler_err(server_logger, trace_id=trace_id, err=err, err_name=err_name)
    return JSONResponse(
        return_json(code=code, msg=f"{err}", trace_id=trace_id),
        status_code=status_code)


def _sse_response(events) -> EventSourceResponse:
    """Wrap an async event generator in an SSE response with no-cache / keep-alive headers."""
    # A fresh headers dict per call, so no state is shared between responses.
    return EventSourceResponse(
        events,
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive"
        }
    )


def _to_json_text(data) -> str:
    """Serialize a DB query result to pretty-printed JSON text.

    Uses the project's ``DateTimeEncoder`` (presumably for date/datetime
    columns — confirm in foundation.utils.tool_utils).
    """
    return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False, indent=2)


def _build_chat_prompt(input_query: str) -> ChatPromptTemplate:
    """Build the system + user prompt used by the /generate endpoints.

    The user text is wrapped in a ``HumanMessage`` so it is passed through
    literally. As a ``("user", text)`` tuple it would be parsed as an f-string
    template, and any ``{...}`` in the user's input would be treated as a
    template variable (breaking or hijacking the prompt).
    """
    return ChatPromptTemplate.from_messages([
        ("system", system_prompt_config['system_prompt']),
        HumanMessage(content=input_query),
    ])


def _search_basis_of_preparation(client: SiliconFlowAPI, query_text: str, top_k: int):
    """Vector-search the ``tv_basis_of_preparation`` table through PGVectorDB."""
    pg_vector_db = PGVectorDB(base_api_platform=client)
    return pg_vector_db.retriever(param={"table_name": "tv_basis_of_preparation"},
                                  query_text=query_text, top_k=top_k)


# ---------------------------------------------------------------------------
# Generation endpoints
# ---------------------------------------------------------------------------

@test_router.post("/generate/chat", response_model=TestForm)
async def generate_chat_endpoint(
        param: TestForm,
        trace_id: str = Depends(get_operation_id)):
    """
    Generative model endpoint: send the configured system prompt plus the
    user's input to the generation client and return the complete result.
    """
    try:
        server_logger.info(trace_id=trace_id, msg=f"{param}")
        task_prompt_info = {"task_prompt": _build_chat_prompt(param.input)}
        output = await generate_model_client.get_model_generate_invoke(
            trace_id=trace_id, task_prompt_info=task_prompt_info)
        server_logger.info(trace_id=trace_id, msg=f"【result】: {output}", log_type="generate/chat")
        return _text_response(trace_id, output)
    except Exception as err:
        return _error_response(trace_id, err, "generate/chat")


@test_router.post("/generate/stream", response_model=TestForm)
async def generate_stream_endpoint(
        param: TestForm,
        trace_id: str = Depends(get_operation_id)):
    """
    Generative model endpoint (SSE streaming response).

    Emits one "message" event per generated chunk, then a "message_end" event;
    if generation fails mid-stream an "error" event is emitted instead.
    """
    try:
        server_logger.info(trace_id=trace_id, msg=f"{param}")
        task_prompt_info = {"task_prompt": _build_chat_prompt(param.input)}

        async def event_generator():
            try:
                chunks = generate_model_client.get_model_generate_stream(
                    trace_id=trace_id, task_prompt_info=task_prompt_info)
                # The client returns a synchronous iterable; pull it in a worker
                # thread so a slow model does not block the event loop.
                async for chunk in iterate_in_threadpool(chunks):
                    yield {
                        "event": "message",
                        "data": json.dumps({
                            "output": chunk,
                            "completed": False,
                        }, ensure_ascii=False)
                    }

                # Placeholder for cached result data; currently always empty.
                result_data = {}
                yield {
                    "event": "message_end",
                    "data": json.dumps({
                        "completed": True,
                        "message": json.dumps(result_data, ensure_ascii=False),
                        "code": 0,
                        "trace_id": trace_id,
                    }, ensure_ascii=False),
                }
            except Exception as e:
                # Log server-side as well; the client only sees the error event.
                handler_err(server_logger, trace_id=trace_id, err=e, err_name="generate/stream")
                yield {
                    "event": "error",
                    "data": json.dumps({
                        "trace_id": trace_id,
                        "message": str(e),
                        "code": 1
                    }, ensure_ascii=False)
                }

        return _sse_response(event_generator())
    except Exception as err:
        return _error_response(trace_id, err, "generate/stream")


# ---------------------------------------------------------------------------
# Agent / workflow endpoints
# ---------------------------------------------------------------------------

@test_router.post("/agent/chat", response_model=TestForm)
async def chat_endpoint(
        param: TestForm,
        trace_id: str = Depends(get_operation_id)):
    """
    Get agent feedback for the given scenario (single, non-streaming response).
    """
    try:
        server_logger.info(trace_id=trace_id, msg=f"{param}")
        task_prompt_info = {"task_prompt": ""}
        output = await test_agent_client.handle_query(
            trace_id, task_prompt_info, param.input, param.context, param.config)
        server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="agent/chat")
        return _text_response(trace_id, output)
    except Exception as err:
        return _error_response(trace_id, err, "agent/chat")


@test_router.post("/agent/stream", response_class=Response)
async def chat_agent_stream(param: TestForm,
                            trace_id: str = Depends(get_operation_id)):
    """
    Get agent feedback for the given scenario (SSE streaming response).

    Setup failures return HTTP 500 with code=1; failures while streaming are
    reported as an "error" SSE event.
    """
    try:
        server_logger.info(trace_id=trace_id, msg=f"{param}")
        input_data = param.input
        context = param.context
        header_info = {}
        task_prompt_info = {"task_prompt": ""}

        async def event_generator():
            try:
                async for chunk in test_agent_client.handle_query_stream(
                        trace_id=trace_id,
                        config_param=param.config,
                        task_prompt_info=task_prompt_info,
                        input_query=input_data,
                        context=context,
                        header_info=header_info
                ):
                    server_logger.debug(trace_id=trace_id, msg=f"{chunk}")
                    yield {
                        "event": "message",
                        "data": json.dumps({
                            "code": 0,
                            "output": chunk,
                            "completed": False,
                            "trace_id": trace_id,
                        }, ensure_ascii=False)
                    }

                # Placeholder for cached result data; currently always empty.
                result_data = {}
                yield {
                    "event": "message_end",
                    "data": json.dumps({
                        "completed": True,
                        "message": json.dumps(result_data, ensure_ascii=False),
                        "code": 0,
                        "trace_id": trace_id,
                    }, ensure_ascii=False),
                }
            except Exception as e:
                # Log server-side as well; the client only sees the error event.
                handler_err(server_logger, trace_id=trace_id, err=e, err_name="agent/stream")
                yield {
                    "event": "error",
                    "data": json.dumps({
                        "trace_id": trace_id,
                        "message": str(e),
                        "code": 1
                    }, ensure_ascii=False)
                }

        return _sse_response(event_generator())
    except Exception as err:
        return _error_response(trace_id, err, "agent/stream", code=1, status_code=500)


@test_router.post("/graph/stream", response_class=Response)
async def chat_graph_stream(param: TestForm,
                            trace_id: str = Depends(get_operation_id)):
    """
    Get agent feedback for the given scenario via the test workflow graph
    (SSE streaming response).

    Setup failures return HTTP 500 with code=1; failures while streaming are
    reported as an "error" SSE event.
    """
    try:
        server_logger.info(trace_id=trace_id, msg=f"{param}")

        async def event_generator():
            try:
                async for chunk in test_workflow_graph.handle_query_stream(
                        param=param,
                        trace_id=trace_id,
                ):
                    server_logger.debug(trace_id=trace_id, msg=f"{chunk}")
                    yield {
                        "event": "message",
                        "data": json.dumps({
                            "code": 0,
                            "output": chunk,
                            "completed": False,
                            "trace_id": trace_id,
                            "dataType": "text"
                        }, ensure_ascii=False)
                    }

                yield {
                    "event": "message_end",
                    "data": json.dumps({
                        "completed": True,
                        "message": "Stream completed",
                        "code": 0,
                        "trace_id": trace_id,
                        "dataType": "text"
                    }, ensure_ascii=False),
                }
            except Exception as e:
                # Log server-side as well; the client only sees the error event.
                handler_err(server_logger, trace_id=trace_id, err=e, err_name="graph/stream")
                # NOTE: this endpoint reports the error under "msg" (the other
                # stream endpoints use "message"); kept as-is for client compatibility.
                yield {
                    "event": "error",
                    "data": json.dumps({
                        "trace_id": trace_id,
                        "msg": str(e),
                        "code": 1,
                        "dataType": "text"
                    }, ensure_ascii=False)
                }

        return _sse_response(event_generator())
    except Exception as err:
        return _error_response(trace_id, err, "graph/stream", code=1, status_code=500)


# ---------------------------------------------------------------------------
# MySQL test endpoints (connection pool taken from app.state.async_db_pool)
# ---------------------------------------------------------------------------

@test_router.post("/mysql/add", response_model=TestForm)
async def test_mysql_add(
        request: Request,
        param: TestForm,
        trace_id: str = Depends(get_operation_id)):
    """
    MySQL test: insert a user row (name=input, email=config.session_id, age=18)
    and return the new id.
    """
    try:
        server_logger.info(trace_id=trace_id, msg=f"{param}")
        test_tab_dao = TestTabDAO(request.app.state.async_db_pool)
        test_id = await test_tab_dao.insert_user(
            name=param.input, email=param.config.session_id, age=18)
        output = f"【test_id】: {test_id}"
        server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/mysql/add")
        return _text_response(trace_id, output)
    except Exception as err:
        return _error_response(trace_id, err, "/mysql/add")


@test_router.post("/mysql/get", response_model=TestForm)
async def test_mysql_get(
        request: Request,
        param: TestForm,
        trace_id: str = Depends(get_operation_id)):
    """
    MySQL test: fetch one user by id (taken from ``input``) and return it as JSON text.
    """
    try:
        server_logger.info(trace_id=trace_id, msg=f"{param}")
        test_tab_dao = TestTabDAO(request.app.state.async_db_pool)
        data = await test_tab_dao.get_user_by_id(user_id=param.input)
        server_logger.info(trace_id=trace_id, msg=f"【result】: {data}", log_type="/mysql/get")
        output = _to_json_text(data)
        server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/mysql/get")
        return _text_response(trace_id, output)
    except Exception as err:
        return _error_response(trace_id, err, "/mysql/get")


@test_router.post("/mysql/list", response_model=TestForm)
async def test_mysql_list(
        request: Request,
        param: TestForm,
        trace_id: str = Depends(get_operation_id)):
    """
    MySQL test: list all users and return them as JSON text.
    """
    try:
        server_logger.info(trace_id=trace_id, msg=f"{param}")
        test_tab_dao = TestTabDAO(request.app.state.async_db_pool)
        data = await test_tab_dao.get_all_users()
        server_logger.info(trace_id=trace_id, msg=f"【result】: {data}", log_type="/mysql/list")
        output = _to_json_text(data)
        server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/mysql/list")
        return _text_response(trace_id, output)
    except Exception as err:
        return _error_response(trace_id, err, "/mysql/list")


@test_router.post("/mysql/update", response_model=TestForm)
async def test_mysql_update(
        request: Request,
        param: TestForm,
        trace_id: str = Depends(get_operation_id)):
    """
    MySQL test: update the user whose id is ``config.session_id`` with
    name=input and fixed test email/age values; returns the DAO's result.
    """
    try:
        server_logger.info(trace_id=trace_id, msg=f"{param}")
        test_tab_dao = TestTabDAO(request.app.state.async_db_pool)
        updates = {
            "name": param.input,
            "email": "test_email——upt",
            "age": 22
        }
        success = await test_tab_dao.update_user(user_id=param.config.session_id, **updates)
        server_logger.info(trace_id=trace_id, msg=f"【result】: {success}", log_type="/mysql/update")
        server_logger.debug(trace_id=trace_id, msg=f"【result】: {success}", log_type="/mysql/update")
        return _text_response(trace_id, success)
    except Exception as err:
        return _error_response(trace_id, err, "/mysql/update")


@test_router.post("/bop/get", response_model=TestForm)
async def test_bop_get(
        request: Request,
        param: TestForm,
        trace_id: str = Depends(get_operation_id)):
    """
    MySQL test: fetch one basis-of-preparation record by id (taken from
    ``input``) and return it as JSON text.
    """
    try:
        server_logger.info(trace_id=trace_id, msg=f"{param}")
        bop_dao = BasisOfPreparationDAO(request.app.state.async_db_pool)
        data = await bop_dao.get_info_by_id(id=param.input)
        server_logger.info(trace_id=trace_id, msg=f"【result】: {data}", log_type="/bop/get")
        output = _to_json_text(data)
        server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/bop/get")
        return _text_response(trace_id, output)
    except Exception as err:
        return _error_response(trace_id, err, "/bop/get")


@test_router.post("/bop/list", response_model=TestForm)
async def test_bop_list(
        request: Request,
        param: TestForm,
        trace_id: str = Depends(get_operation_id)):
    """
    MySQL test: list basis-of-preparation records and return them as JSON text.
    """
    try:
        server_logger.info(trace_id=trace_id, msg=f"{param}")
        bop_dao = BasisOfPreparationDAO(request.app.state.async_db_pool)
        data = await bop_dao.get_list()
        server_logger.info(trace_id=trace_id, msg=f"【result】: {data}", log_type="/bop/list")
        output = _to_json_text(data)
        server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/bop/list")
        return _text_response(trace_id, output)
    except Exception as err:
        return _error_response(trace_id, err, "/bop/list")


##################【RAG 相关测试】##############################################
# RAG-related test endpoints.
# NOTE(review): SiliconFlowAPI / PGVectorDB calls below are synchronous and run
# on the event loop; consider run_in_threadpool if these endpoints see real load.

@test_router.post("/embedding", response_model=TestForm)
async def embedding_test_endpoint(
        param: TestForm,
        trace_id: str = Depends(get_operation_id)):
    """
    Embedding model test: embed ``input`` and return the dimension and vector.

    The SiliconFlow client expects the SILICONFLOW_API_KEY environment
    variable to be set beforehand (per the original note).
    """
    try:
        server_logger.info(trace_id=trace_id, msg=f"{param}")
        base_api_platform = SiliconFlowAPI()
        embedding = base_api_platform.get_embeddings([param.input])[0]
        embed_dim = len(embedding)
        server_logger.info(trace_id=trace_id, msg=f"【result】: {embed_dim}")
        output = f"embed_dim={embed_dim},embedding:{embedding}"
        return _text_response(trace_id, output)
    except Exception as err:
        return _error_response(trace_id, err, "embedding")


@test_router.post("/bfp/search", response_model=TestForm)
async def bfp_search_endpoint(
        param: TestForm,
        trace_id: str = Depends(get_operation_id)):
    """
    Vector search over basis-of-preparation documents.

    ``input`` is the query text; ``config.session_id`` is reused as top_k and
    must be an integer string (otherwise an error response is returned).
    """
    try:
        server_logger.info(trace_id=trace_id, msg=f"{param}")
        top_k = int(param.config.session_id)
        client = SiliconFlowAPI()
        output = _search_basis_of_preparation(client, param.input, top_k)
        return _text_response(trace_id, output)
    except Exception as err:
        return _error_response(trace_id, err, "bfp/search")


@test_router.post("/bfp/search/rerank", response_model=TestForm)
async def bfp_search_rerank_endpoint(
        param: TestForm,
        trace_id: str = Depends(get_operation_id)):
    """
    Retrieve basis-of-preparation documents, then rerank them against the query.

    ``input`` is the query text; ``config.session_id`` is reused as top_k (for
    both retrieval and rerank top_n) and must be an integer string.
    """
    try:
        server_logger.info(trace_id=trace_id, msg=f"{param}")
        input_query = param.input
        top_k = int(param.config.session_id)
        client = SiliconFlowAPI()
        documents = _search_basis_of_preparation(client, input_query, top_k) or []
        content_list = [doc["text_content"] for doc in documents]
        if content_list:
            output = client.rerank(input_query=input_query, documents=content_list, top_n=top_k)
        else:
            # Nothing retrieved: skip the rerank call rather than sending it
            # an empty document list.
            output = []
        return _text_response(trace_id, output)
    except Exception as err:
        return _error_response(trace_id, err, "bfp/search/rerank")