test_views.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596
  1. # !/usr/bin/ python
  2. # -*- coding: utf-8 -*-
  3. '''
  4. @Project : lq-agent-api
  5. @File :cattle_farm_views.py
  6. @IDE :PyCharm
  7. @Author :
  8. @Date :2025/7/10 17:32
  9. '''
  10. import json
  11. from typing import Optional
  12. from fastapi import Depends, Response, Header
  13. from sse_starlette import EventSourceResponse
  14. from starlette.responses import JSONResponse
  15. from fastapi import Depends, Request, Response, Header
  16. from foundation.agent.test_agent import test_agent_client
  17. from foundation.agent.generate.model_generate import generate_model_client
  18. from foundation.logger.loggering import server_logger
  19. from foundation.schemas.test_schemas import TestForm
  20. from foundation.utils.common import return_json, handler_err
  21. from views import test_router, get_operation_id
  22. from foundation.agent.workflow.test_workflow_graph import test_workflow_graph
  23. from foundation.base.mysql.async_mysql_base_dao import TestTabDAO
  24. from database.repositories.bus_data_query import BasisOfPreparationDAO
  25. from foundation.utils.tool_utils import DateTimeEncoder
  26. @test_router.post("/generate/chat", response_model=TestForm)
  27. async def generate_chat_endpoint(
  28. param: TestForm,
  29. trace_id: str = Depends(get_operation_id)):
  30. """
  31. 生成类模型
  32. """
  33. try:
  34. server_logger.info(trace_id=trace_id, msg=f"{param}")
  35. print(trace_id)
  36. # 从字典中获取input
  37. input_query = param.input
  38. session_id = param.config.session_id
  39. context = param.context
  40. header_info = {
  41. }
  42. task_prompt_info = {"task_prompt": ""}
  43. output = generate_model_client.get_model_generate_invoke(trace_id , task_prompt_info,
  44. input_query, context)
  45. # 直接执行
  46. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="agent/chat")
  47. # 返回字典格式的响应
  48. return JSONResponse(
  49. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  50. except ValueError as err:
  51. handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
  52. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  53. except Exception as err:
  54. handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
  55. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  56. @test_router.post("/generate/stream", response_model=TestForm)
  57. async def generate_stream_endpoint(
  58. param: TestForm,
  59. trace_id: str = Depends(get_operation_id)):
  60. """
  61. 生成类模型
  62. """
  63. try:
  64. server_logger.info(trace_id=trace_id, msg=f"{param}")
  65. # 从字典中获取input
  66. input_query = param.input
  67. session_id = param.config.session_id
  68. context = param.context
  69. header_info = {
  70. }
  71. task_prompt_info = {"task_prompt": ""}
  72. # 创建 SSE 流式响应
  73. async def event_generator():
  74. try:
  75. # 流式处理查询 trace_id, task_prompt_info: dict, input_query, context=None
  76. for chunk in generate_model_client.get_model_generate_stream(trace_id , task_prompt_info,
  77. input_query, context):
  78. # 发送数据块
  79. yield {
  80. "event": "message",
  81. "data": json.dumps({
  82. "output": chunk,
  83. "completed": False,
  84. }, ensure_ascii=False)
  85. }
  86. # 获取缓存数据
  87. result_data = {}
  88. # 发送结束事件
  89. yield {
  90. "event": "message_end",
  91. "data": json.dumps({
  92. "completed": True,
  93. "message": json.dumps(result_data, ensure_ascii=False),
  94. "code": 0,
  95. "trace_id": trace_id,
  96. }, ensure_ascii=False),
  97. }
  98. except Exception as e:
  99. # 错误处理
  100. yield {
  101. "event": "error",
  102. "data": json.dumps({
  103. "trace_id": trace_id,
  104. "message": str(e),
  105. "code": 1
  106. }, ensure_ascii=False)
  107. }
  108. finally:
  109. # 不需要关闭客户端,因为它是单例
  110. pass
  111. # 返回 SSE 响应
  112. return EventSourceResponse(
  113. event_generator(),
  114. headers={
  115. "Cache-Control": "no-cache",
  116. "Connection": "keep-alive"
  117. }
  118. )
  119. except ValueError as err:
  120. handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
  121. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  122. except Exception as err:
  123. handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
  124. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  125. # 路由
  126. @test_router.post("/agent/chat", response_model=TestForm)
  127. async def chat_endpoint(
  128. param: TestForm,
  129. trace_id: str = Depends(get_operation_id)):
  130. """
  131. 根据场景获取智能体反馈
  132. """
  133. try:
  134. server_logger.info(trace_id=trace_id, msg=f"{param}")
  135. # 验证参数
  136. # 从字典中获取input
  137. input_data = param.input
  138. session_id = param.config.session_id
  139. context = param.context
  140. header_info = {
  141. }
  142. task_prompt_info = {"task_prompt": ""}
  143. # stream 流式执行
  144. output = await test_agent_client.handle_query(trace_id , task_prompt_info, input_data, context, param.config)
  145. # 直接执行
  146. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="agent/chat")
  147. # 返回字典格式的响应
  148. return JSONResponse(
  149. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  150. except ValueError as err:
  151. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/chat")
  152. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  153. except Exception as err:
  154. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/chat")
  155. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  156. @test_router.post("/agent/stream", response_class=Response)
  157. async def chat_agent_stream(param: TestForm,
  158. trace_id: str = Depends(get_operation_id)):
  159. """
  160. 根据场景获取智能体反馈 (SSE流式响应)
  161. """
  162. try:
  163. server_logger.info(trace_id=trace_id, msg=f"{param}")
  164. # 提取参数
  165. input_data = param.input
  166. context = param.context
  167. header_info = {
  168. }
  169. task_prompt_info = {"task_prompt": ""}
  170. # 如果business_scene为None,则使用大模型进行意图识别
  171. server_logger.info(trace_id=trace_id, msg=f"{param}")
  172. # 创建 SSE 流式响应
  173. async def event_generator():
  174. try:
  175. # 流式处理查询
  176. async for chunk in test_agent_client.handle_query_stream(
  177. trace_id=trace_id,
  178. config_param=param.config,
  179. task_prompt_info=task_prompt_info,
  180. input_query=input_data,
  181. context=context,
  182. header_info=header_info
  183. ):
  184. server_logger.debug(trace_id=trace_id, msg=f"{chunk}")
  185. # 发送数据块
  186. yield {
  187. "event": "message",
  188. "data": json.dumps({
  189. "code": 0,
  190. "output": chunk,
  191. "completed": False,
  192. "trace_id": trace_id,
  193. }, ensure_ascii=False)
  194. }
  195. # 获取缓存数据
  196. result_data = {}
  197. # 发送结束事件
  198. yield {
  199. "event": "message_end",
  200. "data": json.dumps({
  201. "completed": True,
  202. "message": json.dumps(result_data, ensure_ascii=False),
  203. "code": 0,
  204. "trace_id": trace_id,
  205. }, ensure_ascii=False),
  206. }
  207. except Exception as e:
  208. # 错误处理
  209. yield {
  210. "event": "error",
  211. "data": json.dumps({
  212. "trace_id": trace_id,
  213. "message": str(e),
  214. "code": 1
  215. }, ensure_ascii=False)
  216. }
  217. finally:
  218. # 不需要关闭客户端,因为它是单例
  219. pass
  220. # 返回 SSE 响应
  221. return EventSourceResponse(
  222. event_generator(),
  223. headers={
  224. "Cache-Control": "no-cache",
  225. "Connection": "keep-alive"
  226. }
  227. )
  228. except Exception as err:
  229. # 初始错误处理
  230. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/stream")
  231. return JSONResponse(
  232. return_json(code=1, msg=f"{err}", trace_id=trace_id),
  233. status_code=500
  234. )
  235. @test_router.post("/graph/stream", response_class=Response)
  236. async def chat_graph_stream(param: TestForm,
  237. trace_id: str = Depends(get_operation_id)):
  238. """
  239. 根据场景获取智能体反馈 (SSE流式响应)
  240. """
  241. try:
  242. server_logger.info(trace_id=trace_id, msg=f"{param}")
  243. # request_param = {
  244. # "input": param.input,
  245. # "config": param.config,
  246. # "context": param.context
  247. # }
  248. # 创建 SSE 流式响应
  249. async def event_generator():
  250. try:
  251. # 流式处理查询
  252. async for chunk in test_workflow_graph.handle_query_stream(
  253. param=param,
  254. trace_id=trace_id,
  255. ):
  256. server_logger.debug(trace_id=trace_id, msg=f"{chunk}")
  257. # 发送数据块
  258. yield {
  259. "event": "message",
  260. "data": json.dumps({
  261. "code": 0,
  262. "output": chunk,
  263. "completed": False,
  264. "trace_id": trace_id,
  265. "dataType": "text"
  266. }, ensure_ascii=False)
  267. }
  268. # 发送结束事件
  269. yield {
  270. "event": "message_end",
  271. "data": json.dumps({
  272. "completed": True,
  273. "message": "Stream completed",
  274. "code": 0,
  275. "trace_id": trace_id,
  276. "dataType": "text"
  277. }, ensure_ascii=False),
  278. }
  279. except Exception as e:
  280. # 错误处理
  281. yield {
  282. "event": "error",
  283. "data": json.dumps({
  284. "trace_id": trace_id,
  285. "msg": str(e),
  286. "code": 1,
  287. "dataType": "text"
  288. }, ensure_ascii=False)
  289. }
  290. finally:
  291. # 不需要关闭客户端,因为它是单例
  292. pass
  293. # 返回 SSE 响应
  294. return EventSourceResponse(
  295. event_generator(),
  296. headers={
  297. "Cache-Control": "no-cache",
  298. "Connection": "keep-alive"
  299. }
  300. )
  301. except Exception as err:
  302. # 初始错误处理
  303. handler_err(server_logger, trace_id=trace_id, err=err, err_name="graph/stream")
  304. return JSONResponse(
  305. return_json(code=1, msg=f"{err}", trace_id=trace_id),
  306. status_code=500
  307. )
  308. @test_router.post("/mysql/add", response_model=TestForm)
  309. async def test_mysql_add(
  310. request: Request,
  311. param: TestForm,
  312. trace_id: str = Depends(get_operation_id)):
  313. """
  314. 根据MySQL应用
  315. """
  316. try:
  317. server_logger.info(trace_id=trace_id, msg=f"{param}")
  318. # 验证参数
  319. # 从字典中获取input
  320. input_data = param.input
  321. session_id = param.config.session_id
  322. context = param.context
  323. header_info = {
  324. }
  325. # 从app.state中获取数据库连接池
  326. async_db_pool = request.app.state.async_db_pool
  327. from foundation.base.mysql.async_mysql_base_dao import TestTabDAO
  328. test_tab_dao = TestTabDAO(async_db_pool)
  329. # name: str, email: str, age: int
  330. name = input_data
  331. email = session_id
  332. age = 18
  333. test_id = await test_tab_dao.insert_user(name=name, email=email, age=age)
  334. output = f"【test_id】: {test_id}"
  335. # 直接执行
  336. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/mysql/add")
  337. # 返回字典格式的响应
  338. return JSONResponse(
  339. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  340. except ValueError as err:
  341. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/add")
  342. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  343. except Exception as err:
  344. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/add")
  345. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  346. @test_router.post("/mysql/get", response_model=TestForm)
  347. async def test_mysql_add(
  348. request: Request,
  349. param: TestForm,
  350. trace_id: str = Depends(get_operation_id)):
  351. """
  352. 根据MySQL应用
  353. """
  354. try:
  355. server_logger.info(trace_id=trace_id, msg=f"{param}")
  356. # 验证参数
  357. # 从字典中获取input
  358. input_data = param.input
  359. session_id = param.config.session_id
  360. context = param.context
  361. header_info = {
  362. }
  363. # 从app.state中获取数据库连接池
  364. async_db_pool = request.app.state.async_db_pool
  365. test_tab_dao = TestTabDAO(async_db_pool)
  366. test_id = input_data;
  367. data = await test_tab_dao.get_user_by_id(user_id=test_id)
  368. server_logger.info(trace_id=trace_id, msg=f"【result】: {data}", log_type="/mysql/get")
  369. json_str = json.dumps(data , cls=DateTimeEncoder, ensure_ascii=False, indent=2)
  370. output = json_str
  371. # 直接执行
  372. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/mysql/get")
  373. # 返回字典格式的响应
  374. return JSONResponse(
  375. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  376. except ValueError as err:
  377. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/get")
  378. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  379. except Exception as err:
  380. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/get")
  381. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  382. @test_router.post("/mysql/list", response_model=TestForm)
  383. async def test_mysql_add(
  384. request: Request,
  385. param: TestForm,
  386. trace_id: str = Depends(get_operation_id)):
  387. """
  388. 根据MySQL应用
  389. """
  390. try:
  391. server_logger.info(trace_id=trace_id, msg=f"{param}")
  392. # 验证参数
  393. # 从字典中获取input
  394. input_data = param.input
  395. session_id = param.config.session_id
  396. context = param.context
  397. header_info = {
  398. }
  399. # 从app.state中获取数据库连接池
  400. async_db_pool = request.app.state.async_db_pool
  401. from foundation.base.mysql.async_mysql_base_dao import TestTabDAO
  402. test_tab_dao = TestTabDAO(async_db_pool)
  403. test_id = input_data;
  404. data = await test_tab_dao.get_all_users()
  405. server_logger.info(trace_id=trace_id, msg=f"【result】: {data}", log_type="/mysql/list")
  406. json_str = json.dumps(data , cls=DateTimeEncoder, ensure_ascii=False, indent=2)
  407. output = json_str
  408. # 直接执行
  409. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/mysql/list")
  410. # 返回字典格式的响应
  411. return JSONResponse(
  412. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  413. except ValueError as err:
  414. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/list")
  415. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  416. except Exception as err:
  417. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/list")
  418. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  419. @test_router.post("/mysql/update", response_model=TestForm)
  420. async def test_mysql_add(
  421. request: Request,
  422. param: TestForm,
  423. trace_id: str = Depends(get_operation_id)):
  424. """
  425. 根据MySQL应用
  426. """
  427. try:
  428. server_logger.info(trace_id=trace_id, msg=f"{param}")
  429. # 验证参数
  430. # 从字典中获取input
  431. input_data = param.input
  432. session_id = param.config.session_id
  433. context = param.context
  434. header_info = {
  435. }
  436. # 从app.state中获取数据库连接池
  437. async_db_pool = request.app.state.async_db_pool
  438. test_tab_dao = TestTabDAO(async_db_pool)
  439. test_id = session_id;
  440. updates = {
  441. "name": input_data,
  442. "email": "test_email——upt",
  443. "age": 22
  444. }
  445. success = await test_tab_dao.update_user(user_id=test_id , **updates)
  446. server_logger.info(trace_id=trace_id, msg=f"【result】: {success}", log_type="/mysql/update")
  447. output = success
  448. # 直接执行
  449. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/mysql/update")
  450. # 返回字典格式的响应
  451. return JSONResponse(
  452. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  453. except ValueError as err:
  454. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/update")
  455. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  456. except Exception as err:
  457. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/update")
  458. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  459. @test_router.post("/bop/get", response_model=TestForm)
  460. async def test_bop_get(
  461. request: Request,
  462. param: TestForm,
  463. trace_id: str = Depends(get_operation_id)):
  464. """
  465. 根据MySQL应用
  466. """
  467. try:
  468. server_logger.info(trace_id=trace_id, msg=f"{param}")
  469. # 验证参数
  470. # 从字典中获取input
  471. input_data = param.input
  472. session_id = param.config.session_id
  473. context = param.context
  474. header_info = {
  475. }
  476. # 从app.state中获取数据库连接池
  477. async_db_pool = request.app.state.async_db_pool
  478. bop_dao = BasisOfPreparationDAO(async_db_pool)
  479. test_id = input_data;
  480. data = await bop_dao.get_info_by_id(id=test_id)
  481. server_logger.info(trace_id=trace_id, msg=f"【result】: {data}", log_type="/bop/get")
  482. json_str = json.dumps(data , cls=DateTimeEncoder, ensure_ascii=False, indent=2)
  483. output = json_str
  484. # 直接执行
  485. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/bop/get")
  486. # 返回字典格式的响应
  487. return JSONResponse(
  488. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  489. except ValueError as err:
  490. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/bop/get")
  491. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  492. except Exception as err:
  493. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/bop/get")
  494. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  495. @test_router.post("/bop/list", response_model=TestForm)
  496. async def test_mysql_add(
  497. request: Request,
  498. param: TestForm,
  499. trace_id: str = Depends(get_operation_id)):
  500. """
  501. 根据MySQL应用
  502. """
  503. try:
  504. server_logger.info(trace_id=trace_id, msg=f"{param}")
  505. # 验证参数
  506. # 从字典中获取input
  507. input_data = param.input
  508. session_id = param.config.session_id
  509. context = param.context
  510. header_info = {
  511. }
  512. # 从app.state中获取数据库连接池
  513. async_db_pool = request.app.state.async_db_pool
  514. from foundation.base.mysql.async_mysql_base_dao import TestTabDAO
  515. bop_dao = BasisOfPreparationDAO(async_db_pool)
  516. test_id = input_data;
  517. data = await bop_dao.get_list()
  518. server_logger.info(trace_id=trace_id, msg=f"【result】: {data}", log_type="/bop/list")
  519. json_str = json.dumps(data , cls=DateTimeEncoder, ensure_ascii=False, indent=2)
  520. output = json_str
  521. # 直接执行
  522. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/bop/list")
  523. # 返回字典格式的响应
  524. return JSONResponse(
  525. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  526. except ValueError as err:
  527. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/bop/list")
  528. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  529. except Exception as err:
  530. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/bop/list")
  531. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))