cattle_farm_views.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. # !/usr/bin/ python
  2. # -*- coding: utf-8 -*-
  3. '''
  4. @Project : lq-agent-api
  5. @File :cattle_farm_views.py
  6. @IDE :PyCharm
  7. @Author :
  8. @Date :2025/7/10 17:32
  9. '''
  10. import json
  11. from typing import Optional
  12. from fastapi import Depends, Response, Header
  13. from sse_starlette import EventSourceResponse
  14. from starlette.responses import JSONResponse
  15. from agent.agent_mcp import client
  16. from generate.model_generate import XiwuzcModelGenerateClient
  17. from logger.loggering import server_logger
  18. from schemas.cattle_farm import CattleFarm
  19. from utils import yaml_utils
  20. from utils.common import return_json, handler_err
  21. from views import cattle_router, get_operation_id
  22. from agent.intent import intent_identify_client
  23. def get_token(authorization: Optional[str] = Header(default=None)):
  24. """提取 Bearer Token (非必填)"""
  25. if authorization is None:
  26. return None
  27. scheme, _, token = authorization.partition(" ")
  28. return token
  29. def get_tenant_id(tenant_id: Optional[str] | None = Header(None, alias="X-lq-TENANT-ID")):
  30. """处理租户ID"""
  31. return tenant_id
# Routes
  33. @cattle_router.post("/chat", response_model=CattleFarm)
  34. async def chat_endpoint(
  35. param: CattleFarm,
  36. token: str = Depends(get_token),
  37. tenant_id: str = Depends(get_tenant_id),
  38. trace_id: str = Depends(get_operation_id)):
  39. """
  40. 根据场景获取智能体反馈
  41. """
  42. try:
  43. server_logger.info(trace_id=trace_id, msg=f"{param}")
  44. # 验证参数
  45. # 从字典中获取input
  46. input_data = param.input
  47. session_id = param.config.sessionId
  48. business_scene = param.businessScene
  49. context = param.context
  50. supplementInfo = param.supplementInfo
  51. header_info = {
  52. "token": token,
  53. "tenantId": tenant_id,
  54. }
  55. # 如果business_scene为None,则使用大模型进行意图识别
  56. if business_scene is None:
  57. business_scene = await intent_identify_client.recognize_intent(trace_id=trace_id , config=param.config , input=input_data)
  58. server_logger.info(trace_id=trace_id, msg=f"使用意图识别:business_scene={business_scene}")
  59. business_scene_enum, task_prompt_info = yaml_utils.get_business_scene_prompt(trace_id=trace_id, business_scene=business_scene)
  60. final_result_data_type = task_prompt_info["final_result_data_type"]
  61. server_logger.info(trace_id=trace_id, msg=f"session_id:{session_id}, business_scene:{business_scene},final_result_data_type:{final_result_data_type} ,input_data: {input_data}",
  62. log_type="queryex")
  63. # stream 流式执行
  64. output = await client.handle_query(trace_id , business_scene , task_prompt_info, input_data, context, supplementInfo, header_info , param.config)
  65. # 直接执行
  66. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="queryex")
  67. # 返回字典格式的响应
  68. return JSONResponse(
  69. return_json(business_scene=business_scene, data={"output": output}, data_type=final_result_data_type, trace_id=trace_id))
  70. except ValueError as err:
  71. handler_err(server_logger, trace_id=trace_id, err=err, err_name="queryex")
  72. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  73. except Exception as err:
  74. handler_err(server_logger, trace_id=trace_id, err=err, err_name="queryex")
  75. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
@cattle_router.post("/stream", response_class=Response)
async def chat_agent(param: CattleFarm,
                     token: str = Depends(get_token),
                     tenant_id: str = Depends(get_tenant_id),
                     trace_id: str = Depends(get_operation_id)):
    """Agent feedback for the given scene as an SSE stream.

    Streams `message` events for each chunk produced by the MCP client, then a
    final `message_end` event carrying the cached result; errors inside the
    stream are reported as an `error` event rather than an HTTP error.
    """
    try:
        server_logger.info(trace_id=trace_id, msg=f"{param}")
        # Unpack the request payload.
        input_data = param.input
        session_id = param.config.sessionId
        user_role = param.config.userRole
        business_scene = param.businessScene
        context = param.context
        supplementInfo = param.supplementInfo
        header_info = {
            "token": token,
            "tenantId": tenant_id,
        }
        # NOTE(review): unlike /chat, no intent recognition runs here when
        # business_scene is None — the scene/prompt are hard-wired below to
        # COMMON_MODEL_QUERY with an empty task prompt. Confirm intentional.
        from enums.common_enums import BusinessSceneEnum
        business_scene_enum, task_prompt_info = BusinessSceneEnum.COMMON_MODEL_QUERY, {"task_prompt": ""}
        final_result_data_type = "text"
        server_logger.info(trace_id=trace_id, msg=f"session_id:{session_id}, business_scene:{business_scene},final_result_data_type:{final_result_data_type} ,input_data: {input_data}",
                           log_type="queryex")
        # NOTE(review): `param` is logged a second time here — likely redundant.
        server_logger.info(trace_id=trace_id, msg=f"{param}")

        # Build the SSE event stream.
        async def event_generator():
            try:
                # Stream chunks from the MCP client as they arrive.
                async for chunk in client.handle_query_stream(
                        trace_id=trace_id,
                        config_param=param.config,
                        business_scene=business_scene,
                        task_prompt_info=task_prompt_info,
                        input_query=input_data,
                        context=context,
                        supplement_info=supplementInfo,
                        header_info=header_info
                ):
                    server_logger.debug(trace_id=trace_id, msg=f"{chunk}")
                    # Emit one `message` event per chunk.
                    yield {
                        "event": "message",
                        "data": json.dumps({
                            "code": 0,
                            "output": chunk,
                            "completed": False,
                            "trace_id": trace_id,
                            "dataType": final_result_data_type,
                            "business_scene": business_scene,
                        }, ensure_ascii=False)
                    }
                # Fetch the aggregated result cached (by trace_id) in Redis.
                result_data = await client.get_redis_result_cache_data(trace_id=trace_id)
                # Emit the terminal event; note the cached result is
                # double-encoded (JSON string inside the `message` field).
                yield {
                    "event": "message_end",
                    "data": json.dumps({
                        "completed": True,
                        "message": json.dumps(result_data, ensure_ascii=False),
                        "code": 0,
                        "trace_id": trace_id,
                        "dataType": "text",
                        "business_scene": business_scene,
                    }, ensure_ascii=False),
                }
            except Exception as e:
                # Surface stream-time failures as an `error` SSE event.
                yield {
                    "event": "error",
                    "data": json.dumps({
                        "trace_id": trace_id,
                        "message": str(e),
                        "code": 1,
                        "dataType": "text",
                        "business_scene": business_scene,
                    }, ensure_ascii=False)
                }
            finally:
                # No client shutdown needed: `client` is a shared singleton.
                pass

        # Return the SSE response (unbuffered, keep-alive).
        return EventSourceResponse(
            event_generator(),
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive"
            }
        )
    except Exception as err:
        # Failures before the stream starts become a plain 500 JSON response.
        handler_err(server_logger, trace_id=trace_id, err=err, err_name="queryex")
        return JSONResponse(
            return_json(code=1, msg=f"{err}", trace_id=trace_id),
            status_code=500
        )
  176. @cattle_router.post("/generate/ai_stream", response_class=Response)
  177. def chat_stream(param: CattleFarm,
  178. token: str = Depends(get_token),
  179. tenant_id: str = Depends(get_tenant_id),
  180. trace_id: str = Depends(get_operation_id)):
  181. try:
  182. server_logger.info(trace_id=trace_id, msg=f"{param}")
  183. # 提取参数
  184. input_data = param.input
  185. session_id = param.config.sessionId
  186. business_scene = param.businessScene
  187. context = param.context
  188. supplementInfo = param.supplementInfo
  189. header_info = {
  190. "token": token,
  191. "tenantId": tenant_id,
  192. }
  193. # 如果business_scene为None,则使用大模型进行意图识别
  194. if business_scene is None:
  195. business_scene = intent_identify_client.recognize_intent(input_data)
  196. server_logger.info(trace_id=trace_id, msg=f"使用意图识别:business_scene={business_scene}")
  197. # 获取系统提示
  198. business_scene_enum, task_prompt_info = yaml_utils.get_business_scene_prompt(trace_id=trace_id, business_scene=business_scene)
  199. server_logger.info(trace_id=trace_id, msg=f"session_id:{session_id}, business_scene:{business_scene} , business_scene_enum:{business_scene_enum} ,input_data: {input_data}",
  200. log_type="queryex")
  201. xwzc_generate_client = XiwuzcModelGenerateClient()
  202. # 创建 SSE 流式响应
  203. async def event_generator():
  204. try:
  205. # 流式处理查询
  206. for chunk in xwzc_generate_client.get_model_generate_stream(task_prompt_info, session_id, input_data,
  207. context, supplementInfo):
  208. # 发送数据块
  209. yield {
  210. "event": "message",
  211. "data": json.dumps({
  212. "output": chunk,
  213. "completed": False,
  214. "trace_id": trace_id,
  215. "dataType": business_scene_enum.data_type
  216. })
  217. }
  218. # 发送结束事件
  219. yield {
  220. "event": "message_end",
  221. "data": json.dumps({
  222. "completed": True,
  223. "message": "Stream completed",
  224. "code": 0,
  225. "trace_id": trace_id
  226. }),
  227. }
  228. except Exception as e:
  229. # 错误处理
  230. yield {
  231. "event": "error",
  232. "data": json.dumps({
  233. "trace_id": trace_id,
  234. "msg": str(e),
  235. "code": 1,
  236. "dataType": "text"
  237. })
  238. }
  239. # 返回 SSE 响应
  240. return EventSourceResponse(
  241. event_generator(),
  242. headers={
  243. "Cache-Control": "no-cache",
  244. "Connection": "keep-alive"
  245. }
  246. )
  247. except Exception as err:
  248. # 初始错误处理
  249. handler_err(server_logger, trace_id=trace_id, err=err, err_name="queryex")
  250. return JSONResponse(
  251. return_json(code=1, msg=f"{err}", trace_id=trace_id),
  252. status_code=500
  253. )
  254. @cattle_router.post("/generate/tools/execute", response_model=CattleFarm)
  255. async def chat_generate_tools_endpoint(
  256. param: CattleFarm,
  257. token: str = Depends(get_token),
  258. tenant_id: str = Depends(get_tenant_id),
  259. trace_id: str = Depends(get_operation_id)):
  260. """
  261. 工具调用
  262. """
  263. try:
  264. server_logger.info(trace_id=trace_id, msg=f"{param}")
  265. # 验证参数
  266. # 从字典中获取input
  267. input_data = param.input
  268. session_id = param.config.sessionId
  269. business_scene = param.businessScene
  270. context = param.context
  271. supplementInfo = param.supplementInfo
  272. header_info = {
  273. "token": token,
  274. "tenantId": tenant_id,
  275. }
  276. # 如果business_scene为None,则使用大模型进行意图识别
  277. if business_scene is None:
  278. business_scene = intent_identify_client.recognize_intent(input_data)
  279. server_logger.info(trace_id=trace_id, msg=f"使用意图识别:business_scene={business_scene}")
  280. business_scene_enum, task_prompt_info = yaml_utils.get_business_scene_prompt(trace_id=trace_id, business_scene=business_scene)
  281. server_logger.info(trace_id=trace_id, msg=f"session_id:{session_id}, business_scene:{business_scene} , business_scene_enum:{business_scene_enum} ,input_data: {input_data}",
  282. log_type="queryex/tools2")
  283. xwzc_generate_client = XiwuzcModelGenerateClient()
  284. # stream 流式执行
  285. output = await xwzc_generate_client.get_model_tools_call(trace_id, session_id, task_prompt_info, input_data, context, supplementInfo, header_info)
  286. # 直接执行
  287. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="queryex/tools2")
  288. # 返回字典格式的响应
  289. return JSONResponse(
  290. return_json(data={"output": output}, data_type=business_scene_enum.data_type, trace_id=trace_id))
  291. except ValueError as err:
  292. handler_err(server_logger, trace_id=trace_id, err=err, err_name="queryex/tools2")
  293. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  294. except Exception as err:
  295. handler_err(server_logger, trace_id=trace_id, err=err, err_name="queryex/tools2")
  296. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))