
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@Project : lq-agent-api
@File    : test_views.py
@IDE     : PyCharm
@Author  :
@Date    : 2025/7/10 17:32
'''
import json
from typing import Optional
from fastapi import Depends, Response, Header
from sse_starlette import EventSourceResponse
from starlette.responses import JSONResponse

from foundation.agent.test_agent import test_agent_client
from foundation.agent.generate.model_generate import test_generate_model_client
from foundation.logger.loggering import server_logger
from foundation.schemas.test_schemas import TestForm
from foundation.utils.common import return_json, handler_err
from views import test_router, get_operation_id
from foundation.agent.workflow.test_workflow_graph import test_workflow_graph
from file_processors.pdf_processor import PDFProcessor
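
# Route overview. Every handler accepts a TestForm body; the fields actually read
# below are param.input, param.config.session_id and param.context (the exact
# schema lives in foundation.schemas.test_schemas.TestForm):
#   POST /generate/chat        - generative model, non-streaming JSON response
#   POST /generate/stream      - generative model, SSE streaming response
#   POST /agent/chat           - agent feedback, non-streaming JSON response
#   POST /agent/stream         - agent feedback, SSE streaming response
#   POST /graph/stream         - workflow-graph feedback, SSE streaming response
#   POST /data/governance      - data-governance model, non-streaming
#   POST /data/pdf/governance  - PDF-based data governance, non-streaming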


@test_router.post("/generate/chat", response_model=TestForm)
async def generate_chat_endpoint(
        param: TestForm,
        trace_id: str = Depends(get_operation_id)):
    """
    Generative model (non-streaming invocation).
    """
    try:
        server_logger.info(trace_id=trace_id, msg=f"{param}")
        print(trace_id)
        # Extract the input from the request payload
        input_query = param.input
        session_id = param.config.session_id
        context = param.context
        header_info = {}
        task_prompt_info = {"task_prompt": ""}
        output = test_generate_model_client.get_model_generate_invoke(
            trace_id, task_prompt_info, input_query, context)
        # Direct (non-streaming) execution
        server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="agent/chat")
        # Return the response as a dict
        return JSONResponse(
            return_json(data={"output": output}, data_type="text", trace_id=trace_id))
    except ValueError as err:
        handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/chat")
        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
    except Exception as err:
        handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/chat")
        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))


@test_router.post("/generate/stream", response_model=TestForm)
async def generate_stream_endpoint(
        param: TestForm,
        trace_id: str = Depends(get_operation_id)):
    """
    Generative model (SSE streaming response).
    """
    try:
        server_logger.info(trace_id=trace_id, msg=f"{param}")
        # Extract the input from the request payload
        input_query = param.input
        session_id = param.config.session_id
        context = param.context
        header_info = {}
        task_prompt_info = {"task_prompt": ""}

        # Build the SSE streaming response
        async def event_generator():
            try:
                # Stream the query (signature: trace_id, task_prompt_info: dict, input_query, context=None)
                for chunk in test_generate_model_client.get_model_generate_stream(
                        trace_id, task_prompt_info, input_query, context):
                    # Emit a data chunk
                    yield {
                        "event": "message",
                        "data": json.dumps({
                            "output": chunk,
                            "completed": False,
                        }, ensure_ascii=False)
                    }
                # Collect cached data (currently empty)
                result_data = {}
                # Emit the end-of-stream event
                yield {
                    "event": "message_end",
                    "data": json.dumps({
                        "completed": True,
                        "message": json.dumps(result_data, ensure_ascii=False),
                        "code": 0,
                        "trace_id": trace_id,
                    }, ensure_ascii=False),
                }
            except Exception as e:
                # Error handling
                yield {
                    "event": "error",
                    "data": json.dumps({
                        "trace_id": trace_id,
                        "message": str(e),
                        "code": 1
                    }, ensure_ascii=False)
                }
            finally:
                # No need to close the client: it is a singleton
                pass

        # Return the SSE response
        return EventSourceResponse(
            event_generator(),
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive"
            }
        )
    except ValueError as err:
        handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
    except Exception as err:
        handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
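
# SSE contract shared by the streaming endpoints in this module
# (/generate/stream above, /agent/stream and /graph/stream below):
#   event "message"      - one JSON payload per yielded chunk, with completed=False
#   event "message_end"  - final JSON payload with completed=True, code=0 and the trace_id
#   event "error"        - JSON payload with code=1 and the error message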


# Routes
@test_router.post("/agent/chat", response_model=TestForm)
async def chat_endpoint(
        param: TestForm,
        trace_id: str = Depends(get_operation_id)):
    """
    Get agent feedback for the given scenario.
    """
    try:
        server_logger.info(trace_id=trace_id, msg=f"{param}")
        # Validate parameters and extract the input from the request payload
        input_data = param.input
        session_id = param.config.session_id
        context = param.context
        header_info = {}
        task_prompt_info = {"task_prompt": ""}
        # Direct (non-streaming) execution
        output = await test_agent_client.handle_query(
            trace_id, task_prompt_info, input_data, context, param.config)
        server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="agent/chat")
        # Return the response as a dict
        return JSONResponse(
            return_json(data={"output": output}, data_type="text", trace_id=trace_id))
    except ValueError as err:
        handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/chat")
        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
    except Exception as err:
        handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/chat")
        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))


@test_router.post("/agent/stream", response_class=Response)
async def chat_agent_stream(param: TestForm,
                            trace_id: str = Depends(get_operation_id)):
    """
    Get agent feedback for the given scenario (SSE streaming response).
    """
    try:
        server_logger.info(trace_id=trace_id, msg=f"{param}")
        # Extract parameters
        input_data = param.input
        context = param.context
        header_info = {}
        task_prompt_info = {"task_prompt": ""}
        # If business_scene is None, the LLM performs intent recognition
        server_logger.info(trace_id=trace_id, msg=f"{param}")

        # Build the SSE streaming response
        async def event_generator():
            try:
                # Stream the query results
                async for chunk in test_agent_client.handle_query_stream(
                        trace_id=trace_id,
                        config_param=param.config,
                        task_prompt_info=task_prompt_info,
                        input_query=input_data,
                        context=context,
                        header_info=header_info
                ):
                    server_logger.debug(trace_id=trace_id, msg=f"{chunk}")
                    # Emit a data chunk
                    yield {
                        "event": "message",
                        "data": json.dumps({
                            "code": 0,
                            "output": chunk,
                            "completed": False,
                            "trace_id": trace_id,
                        }, ensure_ascii=False)
                    }
                # Collect cached data (currently empty)
                result_data = {}
                # Emit the end-of-stream event
                yield {
                    "event": "message_end",
                    "data": json.dumps({
                        "completed": True,
                        "message": json.dumps(result_data, ensure_ascii=False),
                        "code": 0,
                        "trace_id": trace_id,
                    }, ensure_ascii=False),
                }
            except Exception as e:
                # Error handling
                yield {
                    "event": "error",
                    "data": json.dumps({
                        "trace_id": trace_id,
                        "message": str(e),
                        "code": 1
                    }, ensure_ascii=False)
                }
            finally:
                # No need to close the client: it is a singleton
                pass

        # Return the SSE response
        return EventSourceResponse(
            event_generator(),
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive"
            }
        )
    except Exception as err:
        # Top-level error handling
        handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/stream")
        return JSONResponse(
            return_json(code=1, msg=f"{err}", trace_id=trace_id),
            status_code=500
        )


@test_router.post("/graph/stream", response_class=Response)
async def chat_graph_stream(param: TestForm,
                            trace_id: str = Depends(get_operation_id)):
    """
    Get agent feedback via the workflow graph (SSE streaming response).
    """
    try:
        server_logger.info(trace_id=trace_id, msg=f"{param}")
        # request_param = {
        #     "input": param.input,
        #     "config": param.config,
        #     "context": param.context
        # }

        # Build the SSE streaming response
        async def event_generator():
            try:
                # Stream the query results
                async for chunk in test_workflow_graph.handle_query_stream(
                        param=param,
                        trace_id=trace_id,
                ):
                    server_logger.debug(trace_id=trace_id, msg=f"{chunk}")
                    # Emit a data chunk
                    yield {
                        "event": "message",
                        "data": json.dumps({
                            "code": 0,
                            "output": chunk,
                            "completed": False,
                            "trace_id": trace_id,
                            "dataType": "text"
                        }, ensure_ascii=False)
                    }
                # Emit the end-of-stream event
                yield {
                    "event": "message_end",
                    "data": json.dumps({
                        "completed": True,
                        "message": "Stream completed",
                        "code": 0,
                        "trace_id": trace_id,
                        "dataType": "text"
                    }, ensure_ascii=False),
                }
            except Exception as e:
                # Error handling
                yield {
                    "event": "error",
                    "data": json.dumps({
                        "trace_id": trace_id,
                        "msg": str(e),
                        "code": 1,
                        "dataType": "text"
                    }, ensure_ascii=False)
                }
            finally:
                # No need to close the client: it is a singleton
                pass

        # Return the SSE response
        return EventSourceResponse(
            event_generator(),
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive"
            }
        )
    except Exception as err:
        # Top-level error handling
        handler_err(server_logger, trace_id=trace_id, err=err, err_name="graph/stream")
        return JSONResponse(
            return_json(code=1, msg=f"{err}", trace_id=trace_id),
            status_code=500
        )


@test_router.post("/data/governance", response_model=TestForm)
async def data_governance_endpoint(
        param: TestForm,
        trace_id: str = Depends(get_operation_id)):
    """
    Data-governance model (non-streaming invocation).
    """
    try:
        server_logger.info(trace_id=trace_id, msg=f"{param}")
        print(trace_id)
        # Extract the input from the request payload
        input_query = param.input
        session_id = param.config.session_id
        context = param.context
        header_info = {}
        task_prompt_info = {"task_prompt": ""}
        output = test_generate_model_client.get_model_data_governance_invoke(
            trace_id, task_prompt_info, input_query, context)
        # Direct (non-streaming) execution
        server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="agent/chat")
        # Return the response as a dict
        return JSONResponse(
            return_json(data={"output": output}, data_type="text", trace_id=trace_id))
    except ValueError as err:
        handler_err(server_logger, trace_id=trace_id, err=err, err_name="data/governance")
        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
    except Exception as err:
        handler_err(server_logger, trace_id=trace_id, err=err, err_name="data/governance")
        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))


@test_router.post("/data/pdf/governance", response_model=TestForm)
async def pdf_data_governance_endpoint(
        param: TestForm,
        trace_id: str = Depends(get_operation_id)):
    """
    PDF-based data governance (non-streaming invocation).
    """
    try:
        server_logger.info(trace_id=trace_id, msg=f"{param}")
        print(trace_id)
        # Extract the input from the request payload
        input_query = param.input
        session_id = param.config.session_id
        context = param.context
        header_info = {}
        task_prompt_info = {"task_prompt": ""}
        # file_directory = "I:/wangxun_dev_workspace/lq_workspace/LQDataGovernance/test/pdf_files"
        file_directory = "test/pdf_files"
        # Initialize knowledge-QA processing over the PDF directory
        pdf_processor = PDFProcessor(directory=file_directory)
        file_data = pdf_processor.process_pdfs_group()
        server_logger.info(trace_id=trace_id, msg=f"【result】: {file_data}", log_type="agent/chat")
        output = None
        # Direct (non-streaming) execution, currently disabled:
        # output = test_generate_model_client.get_model_data_governance_invoke(
        #     trace_id, task_prompt_info, input_query, context)
        # server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="agent/chat")
        # Return the response as a dict
        return JSONResponse(
            return_json(data={"output": output}, data_type="text", trace_id=trace_id))
    except ValueError as err:
        handler_err(server_logger, trace_id=trace_id, err=err, err_name="data/pdf/governance")
        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
    except Exception as err:
        handler_err(server_logger, trace_id=trace_id, err=err, err_name="data/pdf/governance")
        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))