#!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. '''
  4. @Project : lq-agent-api
  5. @File :cattle_farm_views.py
  6. @IDE :PyCharm
  7. @Author :
  8. @Date :2025/7/10 17:32
  9. '''
  10. import json
  11. from typing import Optional
  12. from fastapi import Depends, Response, Header
  13. from sse_starlette import EventSourceResponse
  14. from starlette.responses import JSONResponse
  15. from agent.test_agent import test_agent_client
  16. from agent.generate.model_generate import test_generate_model_client
  17. from logger.loggering import server_logger
  18. from schemas.test_schemas import TestForm
  19. from utils.common import return_json, handler_err
  20. from views import test_router, get_operation_id
  21. from agent.workflow.test_workflow_graph import test_workflow_graph
  22. @test_router.post("/generate/chat", response_model=TestForm)
  23. async def generate_chat_endpoint(
  24. param: TestForm,
  25. trace_id: str = Depends(get_operation_id)):
  26. """
  27. 生成类模型
  28. """
  29. try:
  30. server_logger.info(trace_id=trace_id, msg=f"{param}")
  31. # 从字典中获取input
  32. input_query = param.input
  33. session_id = param.config.session_id
  34. context = param.context
  35. header_info = {
  36. }
  37. task_prompt_info = {"task_prompt": ""}
  38. output = test_generate_model_client.get_model_generate_invoke(trace_id , task_prompt_info,
  39. input_query, context)
  40. # 直接执行
  41. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="agent/chat")
  42. # 返回字典格式的响应
  43. return JSONResponse(
  44. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  45. except ValueError as err:
  46. handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
  47. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  48. except Exception as err:
  49. handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
  50. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  51. @test_router.post("/generate/stream", response_model=TestForm)
  52. async def generate_stream_endpoint(
  53. param: TestForm,
  54. trace_id: str = Depends(get_operation_id)):
  55. """
  56. 生成类模型
  57. """
  58. try:
  59. server_logger.info(trace_id=trace_id, msg=f"{param}")
  60. # 从字典中获取input
  61. input_query = param.input
  62. session_id = param.config.session_id
  63. context = param.context
  64. header_info = {
  65. }
  66. task_prompt_info = {"task_prompt": ""}
  67. # 创建 SSE 流式响应
  68. async def event_generator():
  69. try:
  70. # 流式处理查询 trace_id, task_prompt_info: dict, input_query, context=None
  71. for chunk in test_generate_model_client.get_model_generate_stream(trace_id , task_prompt_info,
  72. input_query, context):
  73. # 发送数据块
  74. yield {
  75. "event": "message",
  76. "data": json.dumps({
  77. "output": chunk,
  78. "completed": False,
  79. }, ensure_ascii=False)
  80. }
  81. # 获取缓存数据
  82. result_data = {}
  83. # 发送结束事件
  84. yield {
  85. "event": "message_end",
  86. "data": json.dumps({
  87. "completed": True,
  88. "message": json.dumps(result_data, ensure_ascii=False),
  89. "code": 0,
  90. "trace_id": trace_id,
  91. }, ensure_ascii=False),
  92. }
  93. except Exception as e:
  94. # 错误处理
  95. yield {
  96. "event": "error",
  97. "data": json.dumps({
  98. "trace_id": trace_id,
  99. "message": str(e),
  100. "code": 1
  101. }, ensure_ascii=False)
  102. }
  103. finally:
  104. # 不需要关闭客户端,因为它是单例
  105. pass
  106. # 返回 SSE 响应
  107. return EventSourceResponse(
  108. event_generator(),
  109. headers={
  110. "Cache-Control": "no-cache",
  111. "Connection": "keep-alive"
  112. }
  113. )
  114. except ValueError as err:
  115. handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
  116. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  117. except Exception as err:
  118. handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
  119. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  120. # 路由
  121. @test_router.post("/agent/chat", response_model=TestForm)
  122. async def chat_endpoint(
  123. param: TestForm,
  124. trace_id: str = Depends(get_operation_id)):
  125. """
  126. 根据场景获取智能体反馈
  127. """
  128. try:
  129. server_logger.info(trace_id=trace_id, msg=f"{param}")
  130. # 验证参数
  131. # 从字典中获取input
  132. input_data = param.input
  133. session_id = param.config.session_id
  134. context = param.context
  135. header_info = {
  136. }
  137. task_prompt_info = {"task_prompt": ""}
  138. # stream 流式执行
  139. output = await test_agent_client.handle_query(trace_id , task_prompt_info, input_data, context, param.config)
  140. # 直接执行
  141. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="agent/chat")
  142. # 返回字典格式的响应
  143. return JSONResponse(
  144. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  145. except ValueError as err:
  146. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/chat")
  147. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  148. except Exception as err:
  149. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/chat")
  150. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  151. @test_router.post("/agent/stream", response_class=Response)
  152. async def chat_agent_stream(param: TestForm,
  153. trace_id: str = Depends(get_operation_id)):
  154. """
  155. 根据场景获取智能体反馈 (SSE流式响应)
  156. """
  157. try:
  158. server_logger.info(trace_id=trace_id, msg=f"{param}")
  159. # 提取参数
  160. input_data = param.input
  161. context = param.context
  162. header_info = {
  163. }
  164. task_prompt_info = {"task_prompt": ""}
  165. # 如果business_scene为None,则使用大模型进行意图识别
  166. server_logger.info(trace_id=trace_id, msg=f"{param}")
  167. # 创建 SSE 流式响应
  168. async def event_generator():
  169. try:
  170. # 流式处理查询
  171. async for chunk in test_agent_client.handle_query_stream(
  172. trace_id=trace_id,
  173. config_param=param.config,
  174. task_prompt_info=task_prompt_info,
  175. input_query=input_data,
  176. context=context,
  177. header_info=header_info
  178. ):
  179. server_logger.debug(trace_id=trace_id, msg=f"{chunk}")
  180. # 发送数据块
  181. yield {
  182. "event": "message",
  183. "data": json.dumps({
  184. "code": 0,
  185. "output": chunk,
  186. "completed": False,
  187. "trace_id": trace_id,
  188. }, ensure_ascii=False)
  189. }
  190. # 获取缓存数据
  191. result_data = {}
  192. # 发送结束事件
  193. yield {
  194. "event": "message_end",
  195. "data": json.dumps({
  196. "completed": True,
  197. "message": json.dumps(result_data, ensure_ascii=False),
  198. "code": 0,
  199. "trace_id": trace_id,
  200. }, ensure_ascii=False),
  201. }
  202. except Exception as e:
  203. # 错误处理
  204. yield {
  205. "event": "error",
  206. "data": json.dumps({
  207. "trace_id": trace_id,
  208. "message": str(e),
  209. "code": 1
  210. }, ensure_ascii=False)
  211. }
  212. finally:
  213. # 不需要关闭客户端,因为它是单例
  214. pass
  215. # 返回 SSE 响应
  216. return EventSourceResponse(
  217. event_generator(),
  218. headers={
  219. "Cache-Control": "no-cache",
  220. "Connection": "keep-alive"
  221. }
  222. )
  223. except Exception as err:
  224. # 初始错误处理
  225. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/stream")
  226. return JSONResponse(
  227. return_json(code=1, msg=f"{err}", trace_id=trace_id),
  228. status_code=500
  229. )
  230. @test_router.post("/graph/stream", response_class=Response)
  231. async def chat_graph_stream(param: TestForm,
  232. trace_id: str = Depends(get_operation_id)):
  233. """
  234. 根据场景获取智能体反馈 (SSE流式响应)
  235. """
  236. try:
  237. server_logger.info(trace_id=trace_id, msg=f"{param}")
  238. # request_param = {
  239. # "input": param.input,
  240. # "config": param.config,
  241. # "context": param.context
  242. # }
  243. # 创建 SSE 流式响应
  244. async def event_generator():
  245. try:
  246. # 流式处理查询
  247. async for chunk in test_workflow_graph.handle_query_stream(
  248. param=param,
  249. trace_id=trace_id,
  250. ):
  251. server_logger.debug(trace_id=trace_id, msg=f"{chunk}")
  252. # 发送数据块
  253. yield {
  254. "event": "message",
  255. "data": json.dumps({
  256. "code": 0,
  257. "output": chunk,
  258. "completed": False,
  259. "trace_id": trace_id,
  260. "dataType": "text"
  261. }, ensure_ascii=False)
  262. }
  263. # 发送结束事件
  264. yield {
  265. "event": "message_end",
  266. "data": json.dumps({
  267. "completed": True,
  268. "message": "Stream completed",
  269. "code": 0,
  270. "trace_id": trace_id,
  271. "dataType": "text"
  272. }, ensure_ascii=False),
  273. }
  274. except Exception as e:
  275. # 错误处理
  276. yield {
  277. "event": "error",
  278. "data": json.dumps({
  279. "trace_id": trace_id,
  280. "msg": str(e),
  281. "code": 1,
  282. "dataType": "text"
  283. }, ensure_ascii=False)
  284. }
  285. finally:
  286. # 不需要关闭客户端,因为它是单例
  287. pass
  288. # 返回 SSE 响应
  289. return EventSourceResponse(
  290. event_generator(),
  291. headers={
  292. "Cache-Control": "no-cache",
  293. "Connection": "keep-alive"
  294. }
  295. )
  296. except Exception as err:
  297. # 初始错误处理
  298. handler_err(server_logger, trace_id=trace_id, err=err, err_name="graph/stream")
  299. return JSONResponse(
  300. return_json(code=1, msg=f"{err}", trace_id=trace_id),
  301. status_code=500
  302. )