test_views.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
#!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. '''
  4. @Project : lq-agent-api
@File :test_views.py
  6. @IDE :PyCharm
  7. @Author :
  8. @Date :2025/7/10 17:32
  9. '''
  10. import json
  11. from typing import Optional
  12. from fastapi import Depends, Response, Header
  13. from sse_starlette import EventSourceResponse
  14. from starlette.responses import JSONResponse
  15. from foundation.agent.test_agent import test_agent_client
  16. from foundation.agent.generate.model_generate import generate_model_client
  17. from foundation.logger.loggering import server_logger
  18. from foundation.schemas.test_schemas import TestForm
  19. from foundation.utils.common import return_json, handler_err
  20. from views import test_router, get_operation_id
  21. from foundation.agent.workflow.test_workflow_graph import test_workflow_graph
  22. @test_router.post("/generate/chat", response_model=TestForm)
  23. async def generate_chat_endpoint(
  24. param: TestForm,
  25. trace_id: str = Depends(get_operation_id)):
  26. """
  27. 生成类模型
  28. """
  29. try:
  30. server_logger.info(trace_id=trace_id, msg=f"{param}")
  31. print(trace_id)
  32. # 从字典中获取input
  33. input_query = param.input
  34. session_id = param.config.session_id
  35. context = param.context
  36. header_info = {
  37. }
  38. task_prompt_info = {"task_prompt": ""}
  39. output = generate_model_client.get_model_generate_invoke(trace_id , task_prompt_info,
  40. input_query, context)
  41. # 直接执行
  42. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="agent/chat")
  43. # 返回字典格式的响应
  44. return JSONResponse(
  45. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  46. except ValueError as err:
  47. handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
  48. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  49. except Exception as err:
  50. handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
  51. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  52. @test_router.post("/generate/stream", response_model=TestForm)
  53. async def generate_stream_endpoint(
  54. param: TestForm,
  55. trace_id: str = Depends(get_operation_id)):
  56. """
  57. 生成类模型
  58. """
  59. try:
  60. server_logger.info(trace_id=trace_id, msg=f"{param}")
  61. # 从字典中获取input
  62. input_query = param.input
  63. session_id = param.config.session_id
  64. context = param.context
  65. header_info = {
  66. }
  67. task_prompt_info = {"task_prompt": ""}
  68. # 创建 SSE 流式响应
  69. async def event_generator():
  70. try:
  71. # 流式处理查询 trace_id, task_prompt_info: dict, input_query, context=None
  72. for chunk in generate_model_client.get_model_generate_stream(trace_id , task_prompt_info,
  73. input_query, context):
  74. # 发送数据块
  75. yield {
  76. "event": "message",
  77. "data": json.dumps({
  78. "output": chunk,
  79. "completed": False,
  80. }, ensure_ascii=False)
  81. }
  82. # 获取缓存数据
  83. result_data = {}
  84. # 发送结束事件
  85. yield {
  86. "event": "message_end",
  87. "data": json.dumps({
  88. "completed": True,
  89. "message": json.dumps(result_data, ensure_ascii=False),
  90. "code": 0,
  91. "trace_id": trace_id,
  92. }, ensure_ascii=False),
  93. }
  94. except Exception as e:
  95. # 错误处理
  96. yield {
  97. "event": "error",
  98. "data": json.dumps({
  99. "trace_id": trace_id,
  100. "message": str(e),
  101. "code": 1
  102. }, ensure_ascii=False)
  103. }
  104. finally:
  105. # 不需要关闭客户端,因为它是单例
  106. pass
  107. # 返回 SSE 响应
  108. return EventSourceResponse(
  109. event_generator(),
  110. headers={
  111. "Cache-Control": "no-cache",
  112. "Connection": "keep-alive"
  113. }
  114. )
  115. except ValueError as err:
  116. handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
  117. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  118. except Exception as err:
  119. handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
  120. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  121. # 路由
  122. @test_router.post("/agent/chat", response_model=TestForm)
  123. async def chat_endpoint(
  124. param: TestForm,
  125. trace_id: str = Depends(get_operation_id)):
  126. """
  127. 根据场景获取智能体反馈
  128. """
  129. try:
  130. server_logger.info(trace_id=trace_id, msg=f"{param}")
  131. # 验证参数
  132. # 从字典中获取input
  133. input_data = param.input
  134. session_id = param.config.session_id
  135. context = param.context
  136. header_info = {
  137. }
  138. task_prompt_info = {"task_prompt": ""}
  139. # stream 流式执行
  140. output = await test_agent_client.handle_query(trace_id , task_prompt_info, input_data, context, param.config)
  141. # 直接执行
  142. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="agent/chat")
  143. # 返回字典格式的响应
  144. return JSONResponse(
  145. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  146. except ValueError as err:
  147. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/chat")
  148. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  149. except Exception as err:
  150. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/chat")
  151. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  152. @test_router.post("/agent/stream", response_class=Response)
  153. async def chat_agent_stream(param: TestForm,
  154. trace_id: str = Depends(get_operation_id)):
  155. """
  156. 根据场景获取智能体反馈 (SSE流式响应)
  157. """
  158. try:
  159. server_logger.info(trace_id=trace_id, msg=f"{param}")
  160. # 提取参数
  161. input_data = param.input
  162. context = param.context
  163. header_info = {
  164. }
  165. task_prompt_info = {"task_prompt": ""}
  166. # 如果business_scene为None,则使用大模型进行意图识别
  167. server_logger.info(trace_id=trace_id, msg=f"{param}")
  168. # 创建 SSE 流式响应
  169. async def event_generator():
  170. try:
  171. # 流式处理查询
  172. async for chunk in test_agent_client.handle_query_stream(
  173. trace_id=trace_id,
  174. config_param=param.config,
  175. task_prompt_info=task_prompt_info,
  176. input_query=input_data,
  177. context=context,
  178. header_info=header_info
  179. ):
  180. server_logger.debug(trace_id=trace_id, msg=f"{chunk}")
  181. # 发送数据块
  182. yield {
  183. "event": "message",
  184. "data": json.dumps({
  185. "code": 0,
  186. "output": chunk,
  187. "completed": False,
  188. "trace_id": trace_id,
  189. }, ensure_ascii=False)
  190. }
  191. # 获取缓存数据
  192. result_data = {}
  193. # 发送结束事件
  194. yield {
  195. "event": "message_end",
  196. "data": json.dumps({
  197. "completed": True,
  198. "message": json.dumps(result_data, ensure_ascii=False),
  199. "code": 0,
  200. "trace_id": trace_id,
  201. }, ensure_ascii=False),
  202. }
  203. except Exception as e:
  204. # 错误处理
  205. yield {
  206. "event": "error",
  207. "data": json.dumps({
  208. "trace_id": trace_id,
  209. "message": str(e),
  210. "code": 1
  211. }, ensure_ascii=False)
  212. }
  213. finally:
  214. # 不需要关闭客户端,因为它是单例
  215. pass
  216. # 返回 SSE 响应
  217. return EventSourceResponse(
  218. event_generator(),
  219. headers={
  220. "Cache-Control": "no-cache",
  221. "Connection": "keep-alive"
  222. }
  223. )
  224. except Exception as err:
  225. # 初始错误处理
  226. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/stream")
  227. return JSONResponse(
  228. return_json(code=1, msg=f"{err}", trace_id=trace_id),
  229. status_code=500
  230. )
  231. @test_router.post("/graph/stream", response_class=Response)
  232. async def chat_graph_stream(param: TestForm,
  233. trace_id: str = Depends(get_operation_id)):
  234. """
  235. 根据场景获取智能体反馈 (SSE流式响应)
  236. """
  237. try:
  238. server_logger.info(trace_id=trace_id, msg=f"{param}")
  239. # request_param = {
  240. # "input": param.input,
  241. # "config": param.config,
  242. # "context": param.context
  243. # }
  244. # 创建 SSE 流式响应
  245. async def event_generator():
  246. try:
  247. # 流式处理查询
  248. async for chunk in test_workflow_graph.handle_query_stream(
  249. param=param,
  250. trace_id=trace_id,
  251. ):
  252. server_logger.debug(trace_id=trace_id, msg=f"{chunk}")
  253. # 发送数据块
  254. yield {
  255. "event": "message",
  256. "data": json.dumps({
  257. "code": 0,
  258. "output": chunk,
  259. "completed": False,
  260. "trace_id": trace_id,
  261. "dataType": "text"
  262. }, ensure_ascii=False)
  263. }
  264. # 发送结束事件
  265. yield {
  266. "event": "message_end",
  267. "data": json.dumps({
  268. "completed": True,
  269. "message": "Stream completed",
  270. "code": 0,
  271. "trace_id": trace_id,
  272. "dataType": "text"
  273. }, ensure_ascii=False),
  274. }
  275. except Exception as e:
  276. # 错误处理
  277. yield {
  278. "event": "error",
  279. "data": json.dumps({
  280. "trace_id": trace_id,
  281. "msg": str(e),
  282. "code": 1,
  283. "dataType": "text"
  284. }, ensure_ascii=False)
  285. }
  286. finally:
  287. # 不需要关闭客户端,因为它是单例
  288. pass
  289. # 返回 SSE 响应
  290. return EventSourceResponse(
  291. event_generator(),
  292. headers={
  293. "Cache-Control": "no-cache",
  294. "Connection": "keep-alive"
  295. }
  296. )
  297. except Exception as err:
  298. # 初始错误处理
  299. handler_err(server_logger, trace_id=trace_id, err=err, err_name="graph/stream")
  300. return JSONResponse(
  301. return_json(code=1, msg=f"{err}", trace_id=trace_id),
  302. status_code=500
  303. )