test_views.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513
  1. # !/usr/bin/ python
  2. # -*- coding: utf-8 -*-
  3. '''
  4. @Project : lq-agent-api
  5. @File :cattle_farm_views.py
  6. @IDE :PyCharm
  7. @Author :
  8. @Date :2025/7/10 17:32
  9. '''
  10. import json
  11. from typing import Optional
  12. from fastapi import Depends, Response, Header
  13. from sse_starlette import EventSourceResponse
  14. from starlette.responses import JSONResponse
  15. from fastapi import Depends, Request, Response, Header
  16. from foundation.agent.test_agent import test_agent_client
  17. from foundation.agent.generate.model_generate import test_generate_model_client
  18. from foundation.logger.loggering import server_logger
  19. from foundation.schemas.test_schemas import TestForm
  20. from foundation.utils.common import return_json, handler_err
  21. from views import test_router, get_operation_id
  22. from foundation.agent.workflow.test_workflow_graph import test_workflow_graph
  23. from foundation.base.mysql.async_mysql_base_dao import TestTabDAO
  24. from foundation.utils.tool_utils import DateTimeEncoder
  25. @test_router.post("/generate/chat", response_model=TestForm)
  26. async def generate_chat_endpoint(
  27. param: TestForm,
  28. trace_id: str = Depends(get_operation_id)):
  29. """
  30. 生成类模型
  31. """
  32. try:
  33. server_logger.info(trace_id=trace_id, msg=f"{param}")
  34. print(trace_id)
  35. # 从字典中获取input
  36. input_query = param.input
  37. session_id = param.config.session_id
  38. context = param.context
  39. header_info = {
  40. }
  41. task_prompt_info = {"task_prompt": ""}
  42. output = test_generate_model_client.get_model_generate_invoke(trace_id , task_prompt_info,
  43. input_query, context)
  44. # 直接执行
  45. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="agent/chat")
  46. # 返回字典格式的响应
  47. return JSONResponse(
  48. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  49. except ValueError as err:
  50. handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
  51. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  52. except Exception as err:
  53. handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
  54. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  55. @test_router.post("/generate/stream", response_model=TestForm)
  56. async def generate_stream_endpoint(
  57. param: TestForm,
  58. trace_id: str = Depends(get_operation_id)):
  59. """
  60. 生成类模型
  61. """
  62. try:
  63. server_logger.info(trace_id=trace_id, msg=f"{param}")
  64. # 从字典中获取input
  65. input_query = param.input
  66. session_id = param.config.session_id
  67. context = param.context
  68. header_info = {
  69. }
  70. task_prompt_info = {"task_prompt": ""}
  71. # 创建 SSE 流式响应
  72. async def event_generator():
  73. try:
  74. # 流式处理查询 trace_id, task_prompt_info: dict, input_query, context=None
  75. for chunk in test_generate_model_client.get_model_generate_stream(trace_id , task_prompt_info,
  76. input_query, context):
  77. # 发送数据块
  78. yield {
  79. "event": "message",
  80. "data": json.dumps({
  81. "output": chunk,
  82. "completed": False,
  83. }, ensure_ascii=False)
  84. }
  85. # 获取缓存数据
  86. result_data = {}
  87. # 发送结束事件
  88. yield {
  89. "event": "message_end",
  90. "data": json.dumps({
  91. "completed": True,
  92. "message": json.dumps(result_data, ensure_ascii=False),
  93. "code": 0,
  94. "trace_id": trace_id,
  95. }, ensure_ascii=False),
  96. }
  97. except Exception as e:
  98. # 错误处理
  99. yield {
  100. "event": "error",
  101. "data": json.dumps({
  102. "trace_id": trace_id,
  103. "message": str(e),
  104. "code": 1
  105. }, ensure_ascii=False)
  106. }
  107. finally:
  108. # 不需要关闭客户端,因为它是单例
  109. pass
  110. # 返回 SSE 响应
  111. return EventSourceResponse(
  112. event_generator(),
  113. headers={
  114. "Cache-Control": "no-cache",
  115. "Connection": "keep-alive"
  116. }
  117. )
  118. except ValueError as err:
  119. handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
  120. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  121. except Exception as err:
  122. handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
  123. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  124. # 路由
  125. @test_router.post("/agent/chat", response_model=TestForm)
  126. async def chat_endpoint(
  127. param: TestForm,
  128. trace_id: str = Depends(get_operation_id)):
  129. """
  130. 根据场景获取智能体反馈
  131. """
  132. try:
  133. server_logger.info(trace_id=trace_id, msg=f"{param}")
  134. # 验证参数
  135. # 从字典中获取input
  136. input_data = param.input
  137. session_id = param.config.session_id
  138. context = param.context
  139. header_info = {
  140. }
  141. task_prompt_info = {"task_prompt": ""}
  142. # stream 流式执行
  143. output = await test_agent_client.handle_query(trace_id , task_prompt_info, input_data, context, param.config)
  144. # 直接执行
  145. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="agent/chat")
  146. # 返回字典格式的响应
  147. return JSONResponse(
  148. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  149. except ValueError as err:
  150. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/chat")
  151. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  152. except Exception as err:
  153. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/chat")
  154. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  155. @test_router.post("/agent/stream", response_class=Response)
  156. async def chat_agent_stream(param: TestForm,
  157. trace_id: str = Depends(get_operation_id)):
  158. """
  159. 根据场景获取智能体反馈 (SSE流式响应)
  160. """
  161. try:
  162. server_logger.info(trace_id=trace_id, msg=f"{param}")
  163. # 提取参数
  164. input_data = param.input
  165. context = param.context
  166. header_info = {
  167. }
  168. task_prompt_info = {"task_prompt": ""}
  169. # 如果business_scene为None,则使用大模型进行意图识别
  170. server_logger.info(trace_id=trace_id, msg=f"{param}")
  171. # 创建 SSE 流式响应
  172. async def event_generator():
  173. try:
  174. # 流式处理查询
  175. async for chunk in test_agent_client.handle_query_stream(
  176. trace_id=trace_id,
  177. config_param=param.config,
  178. task_prompt_info=task_prompt_info,
  179. input_query=input_data,
  180. context=context,
  181. header_info=header_info
  182. ):
  183. server_logger.debug(trace_id=trace_id, msg=f"{chunk}")
  184. # 发送数据块
  185. yield {
  186. "event": "message",
  187. "data": json.dumps({
  188. "code": 0,
  189. "output": chunk,
  190. "completed": False,
  191. "trace_id": trace_id,
  192. }, ensure_ascii=False)
  193. }
  194. # 获取缓存数据
  195. result_data = {}
  196. # 发送结束事件
  197. yield {
  198. "event": "message_end",
  199. "data": json.dumps({
  200. "completed": True,
  201. "message": json.dumps(result_data, ensure_ascii=False),
  202. "code": 0,
  203. "trace_id": trace_id,
  204. }, ensure_ascii=False),
  205. }
  206. except Exception as e:
  207. # 错误处理
  208. yield {
  209. "event": "error",
  210. "data": json.dumps({
  211. "trace_id": trace_id,
  212. "message": str(e),
  213. "code": 1
  214. }, ensure_ascii=False)
  215. }
  216. finally:
  217. # 不需要关闭客户端,因为它是单例
  218. pass
  219. # 返回 SSE 响应
  220. return EventSourceResponse(
  221. event_generator(),
  222. headers={
  223. "Cache-Control": "no-cache",
  224. "Connection": "keep-alive"
  225. }
  226. )
  227. except Exception as err:
  228. # 初始错误处理
  229. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/stream")
  230. return JSONResponse(
  231. return_json(code=1, msg=f"{err}", trace_id=trace_id),
  232. status_code=500
  233. )
  234. @test_router.post("/graph/stream", response_class=Response)
  235. async def chat_graph_stream(param: TestForm,
  236. trace_id: str = Depends(get_operation_id)):
  237. """
  238. 根据场景获取智能体反馈 (SSE流式响应)
  239. """
  240. try:
  241. server_logger.info(trace_id=trace_id, msg=f"{param}")
  242. # request_param = {
  243. # "input": param.input,
  244. # "config": param.config,
  245. # "context": param.context
  246. # }
  247. # 创建 SSE 流式响应
  248. async def event_generator():
  249. try:
  250. # 流式处理查询
  251. async for chunk in test_workflow_graph.handle_query_stream(
  252. param=param,
  253. trace_id=trace_id,
  254. ):
  255. server_logger.debug(trace_id=trace_id, msg=f"{chunk}")
  256. # 发送数据块
  257. yield {
  258. "event": "message",
  259. "data": json.dumps({
  260. "code": 0,
  261. "output": chunk,
  262. "completed": False,
  263. "trace_id": trace_id,
  264. "dataType": "text"
  265. }, ensure_ascii=False)
  266. }
  267. # 发送结束事件
  268. yield {
  269. "event": "message_end",
  270. "data": json.dumps({
  271. "completed": True,
  272. "message": "Stream completed",
  273. "code": 0,
  274. "trace_id": trace_id,
  275. "dataType": "text"
  276. }, ensure_ascii=False),
  277. }
  278. except Exception as e:
  279. # 错误处理
  280. yield {
  281. "event": "error",
  282. "data": json.dumps({
  283. "trace_id": trace_id,
  284. "msg": str(e),
  285. "code": 1,
  286. "dataType": "text"
  287. }, ensure_ascii=False)
  288. }
  289. finally:
  290. # 不需要关闭客户端,因为它是单例
  291. pass
  292. # 返回 SSE 响应
  293. return EventSourceResponse(
  294. event_generator(),
  295. headers={
  296. "Cache-Control": "no-cache",
  297. "Connection": "keep-alive"
  298. }
  299. )
  300. except Exception as err:
  301. # 初始错误处理
  302. handler_err(server_logger, trace_id=trace_id, err=err, err_name="graph/stream")
  303. return JSONResponse(
  304. return_json(code=1, msg=f"{err}", trace_id=trace_id),
  305. status_code=500
  306. )
  307. @test_router.post("/mysql/add", response_model=TestForm)
  308. async def test_mysql_add(
  309. request: Request,
  310. param: TestForm,
  311. trace_id: str = Depends(get_operation_id)):
  312. """
  313. 根据MySQL应用
  314. """
  315. try:
  316. server_logger.info(trace_id=trace_id, msg=f"{param}")
  317. # 验证参数
  318. # 从字典中获取input
  319. input_data = param.input
  320. session_id = param.config.session_id
  321. context = param.context
  322. header_info = {
  323. }
  324. # 从app.state中获取数据库连接池
  325. async_db_pool = request.app.state.async_db_pool
  326. from foundation.base.mysql.async_mysql_base_dao import TestTabDAO
  327. test_tab_dao = TestTabDAO(async_db_pool)
  328. # name: str, email: str, age: int
  329. name = input_data
  330. email = session_id
  331. age = 18
  332. test_id = await test_tab_dao.insert_user(name=name, email=email, age=age)
  333. output = f"【test_id】: {test_id}"
  334. # 直接执行
  335. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/mysql/add")
  336. # 返回字典格式的响应
  337. return JSONResponse(
  338. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  339. except ValueError as err:
  340. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/add")
  341. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  342. except Exception as err:
  343. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/add")
  344. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  345. @test_router.post("/mysql/get", response_model=TestForm)
  346. async def test_mysql_add(
  347. request: Request,
  348. param: TestForm,
  349. trace_id: str = Depends(get_operation_id)):
  350. """
  351. 根据MySQL应用
  352. """
  353. try:
  354. server_logger.info(trace_id=trace_id, msg=f"{param}")
  355. # 验证参数
  356. # 从字典中获取input
  357. input_data = param.input
  358. session_id = param.config.session_id
  359. context = param.context
  360. header_info = {
  361. }
  362. # 从app.state中获取数据库连接池
  363. async_db_pool = request.app.state.async_db_pool
  364. test_tab_dao = TestTabDAO(async_db_pool)
  365. test_id = input_data;
  366. data = await test_tab_dao.get_user_by_id(user_id=test_id)
  367. server_logger.info(trace_id=trace_id, msg=f"【result】: {data}", log_type="/mysql/get")
  368. json_str = json.dumps(data , cls=DateTimeEncoder, ensure_ascii=False, indent=2)
  369. output = json_str
  370. # 直接执行
  371. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/mysql/get")
  372. # 返回字典格式的响应
  373. return JSONResponse(
  374. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  375. except ValueError as err:
  376. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/get")
  377. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  378. except Exception as err:
  379. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/get")
  380. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  381. @test_router.post("/mysql/list", response_model=TestForm)
  382. async def test_mysql_add(
  383. request: Request,
  384. param: TestForm,
  385. trace_id: str = Depends(get_operation_id)):
  386. """
  387. 根据MySQL应用
  388. """
  389. try:
  390. server_logger.info(trace_id=trace_id, msg=f"{param}")
  391. # 验证参数
  392. # 从字典中获取input
  393. input_data = param.input
  394. session_id = param.config.session_id
  395. context = param.context
  396. header_info = {
  397. }
  398. # 从app.state中获取数据库连接池
  399. async_db_pool = request.app.state.async_db_pool
  400. from foundation.base.mysql.async_mysql_base_dao import TestTabDAO
  401. test_tab_dao = TestTabDAO(async_db_pool)
  402. test_id = input_data;
  403. data = await test_tab_dao.get_all_users()
  404. server_logger.info(trace_id=trace_id, msg=f"【result】: {data}", log_type="/mysql/list")
  405. json_str = json.dumps(data , cls=DateTimeEncoder, ensure_ascii=False, indent=2)
  406. output = json_str
  407. # 直接执行
  408. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/mysql/list")
  409. # 返回字典格式的响应
  410. return JSONResponse(
  411. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  412. except ValueError as err:
  413. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/list")
  414. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  415. except Exception as err:
  416. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/list")
  417. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  418. @test_router.post("/mysql/update", response_model=TestForm)
  419. async def test_mysql_add(
  420. request: Request,
  421. param: TestForm,
  422. trace_id: str = Depends(get_operation_id)):
  423. """
  424. 根据MySQL应用
  425. """
  426. try:
  427. server_logger.info(trace_id=trace_id, msg=f"{param}")
  428. # 验证参数
  429. # 从字典中获取input
  430. input_data = param.input
  431. session_id = param.config.session_id
  432. context = param.context
  433. header_info = {
  434. }
  435. # 从app.state中获取数据库连接池
  436. async_db_pool = request.app.state.async_db_pool
  437. test_tab_dao = TestTabDAO(async_db_pool)
  438. test_id = session_id;
  439. updates = {
  440. "name": input_data,
  441. "email": "test_email——upt",
  442. "age": 22
  443. }
  444. success = await test_tab_dao.update_user(user_id=test_id , **updates)
  445. server_logger.info(trace_id=trace_id, msg=f"【result】: {success}", log_type="/mysql/update")
  446. output = success
  447. # 直接执行
  448. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/mysql/update")
  449. # 返回字典格式的响应
  450. return JSONResponse(
  451. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  452. except ValueError as err:
  453. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/update")
  454. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  455. except Exception as err:
  456. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/update")
  457. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))