test_views.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827
  1. # !/usr/bin/ python
  2. # -*- coding: utf-8 -*-
  3. '''
  4. @Project : lq-agent-api
  5. @File :cattle_farm_views.py
  6. @IDE :PyCharm
  7. @Author :
  8. @Date :2025/7/10 17:32
  9. '''
  10. import json
  11. from typing import Optional
  12. from fastapi import Depends, Response, Header
  13. from sse_starlette import EventSourceResponse
  14. from starlette.responses import JSONResponse
  15. from fastapi import Depends, Request, Response, Header
  16. from foundation.agent.test_agent import test_agent_client
  17. from foundation.agent.generate.model_generate import generate_model_client
  18. from foundation.logger.loggering import server_logger
  19. from foundation.schemas.test_schemas import TestForm
  20. from foundation.utils.common import return_json, handler_err
  21. from views import test_router, get_operation_id
  22. from foundation.agent.workflow.test_workflow_graph import test_workflow_graph
  23. from foundation.base.mysql.async_mysql_base_dao import TestTabDAO
  24. from database.repositories.bus_data_query import BasisOfPreparationDAO
  25. from foundation.utils.tool_utils import DateTimeEncoder
  26. from langchain_core.prompts import ChatPromptTemplate
  27. from foundation.utils.yaml_utils import system_prompt_config
  28. from foundation.models.silicon_flow import SiliconFlowAPI
  29. from foundation.rag.vector.pg_vector import PGVectorDB
  30. from foundation.rag.vector.milvus_vector import MilvusVectorManager
  31. @test_router.post("/generate/chat", response_model=TestForm)
  32. async def generate_chat_endpoint(
  33. param: TestForm,
  34. trace_id: str = Depends(get_operation_id)):
  35. """
  36. 生成类模型
  37. """
  38. try:
  39. server_logger.info(trace_id=trace_id, msg=f"{param}")
  40. print(trace_id)
  41. # 从字典中获取input
  42. input_query = param.input
  43. session_id = param.config.session_id
  44. context = param.context
  45. header_info = {
  46. }
  47. # 创建ChatPromptTemplate
  48. template = ChatPromptTemplate.from_messages([
  49. ("system", system_prompt_config['system_prompt']),
  50. ("user", input_query)
  51. ])
  52. task_prompt_info = {"task_prompt": template}
  53. output = await generate_model_client.get_model_generate_invoke(trace_id=trace_id , task_prompt_info=task_prompt_info)
  54. # 直接执行
  55. server_logger.info(trace_id=trace_id, msg=f"【result】: {output}", log_type="agent/chat")
  56. # 返回字典格式的响应
  57. return JSONResponse(
  58. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  59. except ValueError as err:
  60. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/chat")
  61. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  62. except Exception as err:
  63. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/chat")
  64. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  65. @test_router.post("/generate/stream", response_model=TestForm)
  66. async def generate_stream_endpoint(
  67. param: TestForm,
  68. trace_id: str = Depends(get_operation_id)):
  69. """
  70. 生成类模型
  71. """
  72. try:
  73. server_logger.info(trace_id=trace_id, msg=f"{param}")
  74. # 从字典中获取input
  75. input_query = param.input
  76. session_id = param.config.session_id
  77. context = param.context
  78. header_info = {
  79. }
  80. # 创建ChatPromptTemplate
  81. template = ChatPromptTemplate.from_messages([
  82. ("system", system_prompt_config['system_prompt']),
  83. ("user", input_query)
  84. ])
  85. task_prompt_info = {"task_prompt": template}
  86. # 创建 SSE 流式响应
  87. async def event_generator():
  88. try:
  89. # 流式处理查询 trace_id, task_prompt_info: dict, input_query, context=None
  90. for chunk in generate_model_client.get_model_generate_stream(trace_id=trace_id , task_prompt_info=task_prompt_info):
  91. # 发送数据块
  92. yield {
  93. "event": "message",
  94. "data": json.dumps({
  95. "output": chunk,
  96. "completed": False,
  97. }, ensure_ascii=False)
  98. }
  99. # 获取缓存数据
  100. result_data = {}
  101. # 发送结束事件
  102. yield {
  103. "event": "message_end",
  104. "data": json.dumps({
  105. "completed": True,
  106. "message": json.dumps(result_data, ensure_ascii=False),
  107. "code": 0,
  108. "trace_id": trace_id,
  109. }, ensure_ascii=False),
  110. }
  111. except Exception as e:
  112. # 错误处理
  113. yield {
  114. "event": "error",
  115. "data": json.dumps({
  116. "trace_id": trace_id,
  117. "message": str(e),
  118. "code": 1
  119. }, ensure_ascii=False)
  120. }
  121. finally:
  122. # 不需要关闭客户端,因为它是单例
  123. pass
  124. # 返回 SSE 响应
  125. return EventSourceResponse(
  126. event_generator(),
  127. headers={
  128. "Cache-Control": "no-cache",
  129. "Connection": "keep-alive"
  130. }
  131. )
  132. except ValueError as err:
  133. handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
  134. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  135. except Exception as err:
  136. handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
  137. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  138. # 路由
  139. @test_router.post("/agent/chat", response_model=TestForm)
  140. async def chat_endpoint(
  141. param: TestForm,
  142. trace_id: str = Depends(get_operation_id)):
  143. """
  144. 根据场景获取智能体反馈
  145. """
  146. try:
  147. server_logger.info(trace_id=trace_id, msg=f"{param}")
  148. # 验证参数
  149. # 从字典中获取input
  150. input_data = param.input
  151. session_id = param.config.session_id
  152. context = param.context
  153. header_info = {
  154. }
  155. task_prompt_info = {"task_prompt": ""}
  156. # stream 流式执行
  157. output = await test_agent_client.handle_query(trace_id , task_prompt_info, input_data, context, param.config)
  158. # 直接执行
  159. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="agent/chat")
  160. # 返回字典格式的响应
  161. return JSONResponse(
  162. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  163. except ValueError as err:
  164. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/chat")
  165. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  166. except Exception as err:
  167. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/chat")
  168. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  169. @test_router.post("/agent/stream", response_class=Response)
  170. async def chat_agent_stream(param: TestForm,
  171. trace_id: str = Depends(get_operation_id)):
  172. """
  173. 根据场景获取智能体反馈 (SSE流式响应)
  174. """
  175. try:
  176. server_logger.info(trace_id=trace_id, msg=f"{param}")
  177. # 提取参数
  178. input_data = param.input
  179. context = param.context
  180. header_info = {
  181. }
  182. task_prompt_info = {"task_prompt": ""}
  183. # 如果business_scene为None,则使用大模型进行意图识别
  184. server_logger.info(trace_id=trace_id, msg=f"{param}")
  185. # 创建 SSE 流式响应
  186. async def event_generator():
  187. try:
  188. # 流式处理查询
  189. async for chunk in test_agent_client.handle_query_stream(
  190. trace_id=trace_id,
  191. config_param=param.config,
  192. task_prompt_info=task_prompt_info,
  193. input_query=input_data,
  194. context=context,
  195. header_info=header_info
  196. ):
  197. server_logger.debug(trace_id=trace_id, msg=f"{chunk}")
  198. # 发送数据块
  199. yield {
  200. "event": "message",
  201. "data": json.dumps({
  202. "code": 0,
  203. "output": chunk,
  204. "completed": False,
  205. "trace_id": trace_id,
  206. }, ensure_ascii=False)
  207. }
  208. # 获取缓存数据
  209. result_data = {}
  210. # 发送结束事件
  211. yield {
  212. "event": "message_end",
  213. "data": json.dumps({
  214. "completed": True,
  215. "message": json.dumps(result_data, ensure_ascii=False),
  216. "code": 0,
  217. "trace_id": trace_id,
  218. }, ensure_ascii=False),
  219. }
  220. except Exception as e:
  221. # 错误处理
  222. yield {
  223. "event": "error",
  224. "data": json.dumps({
  225. "trace_id": trace_id,
  226. "message": str(e),
  227. "code": 1
  228. }, ensure_ascii=False)
  229. }
  230. finally:
  231. # 不需要关闭客户端,因为它是单例
  232. pass
  233. # 返回 SSE 响应
  234. return EventSourceResponse(
  235. event_generator(),
  236. headers={
  237. "Cache-Control": "no-cache",
  238. "Connection": "keep-alive"
  239. }
  240. )
  241. except Exception as err:
  242. # 初始错误处理
  243. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/stream")
  244. return JSONResponse(
  245. return_json(code=1, msg=f"{err}", trace_id=trace_id),
  246. status_code=500
  247. )
  248. @test_router.post("/graph/stream", response_class=Response)
  249. async def chat_graph_stream(param: TestForm,
  250. trace_id: str = Depends(get_operation_id)):
  251. """
  252. 根据场景获取智能体反馈 (SSE流式响应)
  253. """
  254. try:
  255. server_logger.info(trace_id=trace_id, msg=f"{param}")
  256. # request_param = {
  257. # "input": param.input,
  258. # "config": param.config,
  259. # "context": param.context
  260. # }
  261. # 创建 SSE 流式响应
  262. async def event_generator():
  263. try:
  264. # 流式处理查询
  265. async for chunk in test_workflow_graph.handle_query_stream(
  266. param=param,
  267. trace_id=trace_id,
  268. ):
  269. server_logger.debug(trace_id=trace_id, msg=f"{chunk}")
  270. # 发送数据块
  271. yield {
  272. "event": "message",
  273. "data": json.dumps({
  274. "code": 0,
  275. "output": chunk,
  276. "completed": False,
  277. "trace_id": trace_id,
  278. "dataType": "text"
  279. }, ensure_ascii=False)
  280. }
  281. # 发送结束事件
  282. yield {
  283. "event": "message_end",
  284. "data": json.dumps({
  285. "completed": True,
  286. "message": "Stream completed",
  287. "code": 0,
  288. "trace_id": trace_id,
  289. "dataType": "text"
  290. }, ensure_ascii=False),
  291. }
  292. except Exception as e:
  293. # 错误处理
  294. yield {
  295. "event": "error",
  296. "data": json.dumps({
  297. "trace_id": trace_id,
  298. "msg": str(e),
  299. "code": 1,
  300. "dataType": "text"
  301. }, ensure_ascii=False)
  302. }
  303. finally:
  304. # 不需要关闭客户端,因为它是单例
  305. pass
  306. # 返回 SSE 响应
  307. return EventSourceResponse(
  308. event_generator(),
  309. headers={
  310. "Cache-Control": "no-cache",
  311. "Connection": "keep-alive"
  312. }
  313. )
  314. except Exception as err:
  315. # 初始错误处理
  316. handler_err(server_logger, trace_id=trace_id, err=err, err_name="graph/stream")
  317. return JSONResponse(
  318. return_json(code=1, msg=f"{err}", trace_id=trace_id),
  319. status_code=500
  320. )
  321. @test_router.post("/redis", response_model=TestForm)
  322. async def test_redis(
  323. request: Request,
  324. param: TestForm,
  325. trace_id: str = Depends(get_operation_id)):
  326. """
  327. 根据MySQL应用
  328. """
  329. try:
  330. server_logger.info(trace_id=trace_id, msg=f"{param}")
  331. # 验证参数
  332. # 从字典中获取input
  333. input_data = param.input
  334. session_id = param.config.session_id
  335. context = param.context
  336. header_info = {
  337. }
  338. from foundation.utils.redis_utils import set_redis_result_cache_data , get_redis_result_cache_data
  339. output = "success"
  340. data_type = "output"
  341. await set_redis_result_cache_data(data_type=data_type , trace_id=trace_id , value=input_data)
  342. server_logger.info(trace_id=trace_id, msg=f"key:{trace_id}:{data_type},value:{input_data} redis 设置成功")
  343. output = await get_redis_result_cache_data(data_type=data_type , trace_id=trace_id)
  344. # 直接执行
  345. server_logger.info(trace_id=trace_id, msg=f"【result】: {output}", log_type="/redis")
  346. # 返回字典格式的响应
  347. return JSONResponse(
  348. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  349. except ValueError as err:
  350. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/redis")
  351. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  352. except Exception as err:
  353. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/redis")
  354. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  355. @test_router.post("/mysql/add", response_model=TestForm)
  356. async def test_mysql_add(
  357. request: Request,
  358. param: TestForm,
  359. trace_id: str = Depends(get_operation_id)):
  360. """
  361. 根据MySQL应用
  362. """
  363. try:
  364. server_logger.info(trace_id=trace_id, msg=f"{param}")
  365. # 验证参数
  366. # 从字典中获取input
  367. input_data = param.input
  368. session_id = param.config.session_id
  369. context = param.context
  370. header_info = {
  371. }
  372. # 从app.state中获取数据库连接池
  373. async_db_pool = request.app.state.async_db_pool
  374. from foundation.base.mysql.async_mysql_base_dao import TestTabDAO
  375. test_tab_dao = TestTabDAO(async_db_pool)
  376. # name: str, email: str, age: int
  377. name = input_data
  378. email = session_id
  379. age = 18
  380. test_id = await test_tab_dao.insert_user(name=name, email=email, age=age)
  381. output = f"【test_id】: {test_id}"
  382. # 直接执行
  383. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/mysql/add")
  384. # 返回字典格式的响应
  385. return JSONResponse(
  386. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  387. except ValueError as err:
  388. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/add")
  389. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  390. except Exception as err:
  391. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/add")
  392. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  393. @test_router.post("/mysql/get", response_model=TestForm)
  394. async def test_mysql_add(
  395. request: Request,
  396. param: TestForm,
  397. trace_id: str = Depends(get_operation_id)):
  398. """
  399. 根据MySQL应用
  400. """
  401. try:
  402. server_logger.info(trace_id=trace_id, msg=f"{param}")
  403. # 验证参数
  404. # 从字典中获取input
  405. input_data = param.input
  406. session_id = param.config.session_id
  407. context = param.context
  408. header_info = {
  409. }
  410. # 从app.state中获取数据库连接池
  411. async_db_pool = request.app.state.async_db_pool
  412. test_tab_dao = TestTabDAO(async_db_pool)
  413. test_id = input_data;
  414. data = await test_tab_dao.get_user_by_id(user_id=test_id)
  415. server_logger.info(trace_id=trace_id, msg=f"【result】: {data}", log_type="/mysql/get")
  416. json_str = json.dumps(data , cls=DateTimeEncoder, ensure_ascii=False, indent=2)
  417. output = json_str
  418. # 直接执行
  419. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/mysql/get")
  420. # 返回字典格式的响应
  421. return JSONResponse(
  422. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  423. except ValueError as err:
  424. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/get")
  425. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  426. except Exception as err:
  427. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/get")
  428. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  429. @test_router.post("/mysql/list", response_model=TestForm)
  430. async def test_mysql_add(
  431. request: Request,
  432. param: TestForm,
  433. trace_id: str = Depends(get_operation_id)):
  434. """
  435. 根据MySQL应用
  436. """
  437. try:
  438. server_logger.info(trace_id=trace_id, msg=f"{param}")
  439. # 验证参数
  440. # 从字典中获取input
  441. input_data = param.input
  442. session_id = param.config.session_id
  443. context = param.context
  444. header_info = {
  445. }
  446. # 从app.state中获取数据库连接池
  447. async_db_pool = request.app.state.async_db_pool
  448. from foundation.base.mysql.async_mysql_base_dao import TestTabDAO
  449. test_tab_dao = TestTabDAO(async_db_pool)
  450. test_id = input_data;
  451. data = await test_tab_dao.get_all_users()
  452. server_logger.info(trace_id=trace_id, msg=f"【result】: {data}", log_type="/mysql/list")
  453. json_str = json.dumps(data , cls=DateTimeEncoder, ensure_ascii=False, indent=2)
  454. output = json_str
  455. # 直接执行
  456. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/mysql/list")
  457. # 返回字典格式的响应
  458. return JSONResponse(
  459. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  460. except ValueError as err:
  461. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/list")
  462. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  463. except Exception as err:
  464. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/list")
  465. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  466. @test_router.post("/mysql/update", response_model=TestForm)
  467. async def test_mysql_add(
  468. request: Request,
  469. param: TestForm,
  470. trace_id: str = Depends(get_operation_id)):
  471. """
  472. 根据MySQL应用
  473. """
  474. try:
  475. server_logger.info(trace_id=trace_id, msg=f"{param}")
  476. # 验证参数
  477. # 从字典中获取input
  478. input_data = param.input
  479. session_id = param.config.session_id
  480. context = param.context
  481. header_info = {
  482. }
  483. # 从app.state中获取数据库连接池
  484. async_db_pool = request.app.state.async_db_pool
  485. test_tab_dao = TestTabDAO(async_db_pool)
  486. test_id = session_id;
  487. updates = {
  488. "name": input_data,
  489. "email": "test_email——upt",
  490. "age": 22
  491. }
  492. success = await test_tab_dao.update_user(user_id=test_id , **updates)
  493. server_logger.info(trace_id=trace_id, msg=f"【result】: {success}", log_type="/mysql/update")
  494. output = success
  495. # 直接执行
  496. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/mysql/update")
  497. # 返回字典格式的响应
  498. return JSONResponse(
  499. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  500. except ValueError as err:
  501. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/update")
  502. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  503. except Exception as err:
  504. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/update")
  505. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  506. @test_router.post("/bop/get", response_model=TestForm)
  507. async def test_bop_get(
  508. request: Request,
  509. param: TestForm,
  510. trace_id: str = Depends(get_operation_id)):
  511. """
  512. 根据MySQL应用
  513. """
  514. try:
  515. server_logger.info(trace_id=trace_id, msg=f"{param}")
  516. # 验证参数
  517. # 从字典中获取input
  518. input_data = param.input
  519. session_id = param.config.session_id
  520. context = param.context
  521. header_info = {
  522. }
  523. # 从app.state中获取数据库连接池
  524. async_db_pool = request.app.state.async_db_pool
  525. bop_dao = BasisOfPreparationDAO(async_db_pool)
  526. test_id = input_data;
  527. data = await bop_dao.get_info_by_id(id=test_id)
  528. server_logger.info(trace_id=trace_id, msg=f"【result】: {data}", log_type="/bop/get")
  529. json_str = json.dumps(data , cls=DateTimeEncoder, ensure_ascii=False, indent=2)
  530. output = json_str
  531. # 直接执行
  532. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/bop/get")
  533. # 返回字典格式的响应
  534. return JSONResponse(
  535. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  536. except ValueError as err:
  537. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/bop/get")
  538. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  539. except Exception as err:
  540. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/bop/get")
  541. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  542. @test_router.post("/bop/list", response_model=TestForm)
  543. async def test_mysql_add(
  544. request: Request,
  545. param: TestForm,
  546. trace_id: str = Depends(get_operation_id)):
  547. """
  548. 根据MySQL应用
  549. """
  550. try:
  551. server_logger.info(trace_id=trace_id, msg=f"{param}")
  552. # 验证参数
  553. # 从字典中获取input
  554. input_data = param.input
  555. session_id = param.config.session_id
  556. context = param.context
  557. header_info = {
  558. }
  559. # 从app.state中获取数据库连接池
  560. async_db_pool = request.app.state.async_db_pool
  561. from foundation.base.mysql.async_mysql_base_dao import TestTabDAO
  562. bop_dao = BasisOfPreparationDAO(async_db_pool)
  563. test_id = input_data;
  564. data = await bop_dao.get_list()
  565. server_logger.info(trace_id=trace_id, msg=f"【result】: {data}", log_type="/bop/list")
  566. json_str = json.dumps(data , cls=DateTimeEncoder, ensure_ascii=False, indent=2)
  567. output = json_str
  568. # 直接执行
  569. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/bop/list")
  570. # 返回字典格式的响应
  571. return JSONResponse(
  572. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  573. except ValueError as err:
  574. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/bop/list")
  575. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  576. except Exception as err:
  577. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/bop/list")
  578. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  579. ##################【RAG 相关测试】##############################################
  580. @test_router.post("/embedding", response_model=TestForm)
  581. async def embedding_test_endpoint(
  582. param: TestForm,
  583. trace_id: str = Depends(get_operation_id)):
  584. """
  585. embedding模型测试
  586. """
  587. try:
  588. server_logger.info(trace_id=trace_id, msg=f"{param}")
  589. print(trace_id)
  590. # 从字典中获取input
  591. input_query = param.input
  592. session_id = param.config.session_id
  593. context = param.context
  594. header_info = {
  595. }
  596. task_prompt_info = {"task_prompt": ""}
  597. text = input_query
  598. # 初始化客户端(需提前设置环境变量 SILICONFLOW_API_KEY)
  599. from foundation.models.silicon_flow import SiliconFlowAPI
  600. base_api_platform = SiliconFlowAPI()
  601. embedding = base_api_platform.get_embeddings([text])[0]
  602. embed_dim = len(embedding)
  603. server_logger.info(trace_id=trace_id, msg=f"【result】: {embed_dim}")
  604. output = f"embed_dim={embed_dim},embedding:{embedding}"
  605. #output = test_generate_model_client.get_model_data_governance_invoke(trace_id , task_prompt_info, input_query, context)
  606. # 直接执行
  607. #server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="embedding")
  608. # 返回字典格式的响应
  609. return JSONResponse(
  610. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  611. except ValueError as err:
  612. handler_err(server_logger, trace_id=trace_id, err=err, err_name="embedding")
  613. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  614. except Exception as err:
  615. handler_err(server_logger, trace_id=trace_id, err=err, err_name="embedding")
  616. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  617. @test_router.post("/bfp/search", response_model=TestForm)
  618. async def bfp_search_endpoint(
  619. param: TestForm,
  620. trace_id: str = Depends(get_operation_id)):
  621. """
  622. 编制依据向量检索
  623. """
  624. try:
  625. server_logger.info(trace_id=trace_id, msg=f"{param}")
  626. print(trace_id)
  627. # 从字典中获取input
  628. input_query = param.input
  629. session_id = param.config.session_id
  630. context = param.context
  631. header_info = {
  632. }
  633. task_prompt_info = {"task_prompt": ""}
  634. top_k = int(session_id)
  635. output = None
  636. # 初始化客户端(需提前设置环境变量 SILICONFLOW_API_KEY)
  637. client = SiliconFlowAPI()
  638. # 抽象测试
  639. pg_vector_db = PGVectorDB(base_api_platform=client)
  640. output = pg_vector_db.retriever(param={"table_name": "tv_basis_of_preparation"}, query_text=input_query , top_k=top_k)
  641. # 返回字典格式的响应
  642. return JSONResponse(
  643. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  644. except ValueError as err:
  645. handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/search")
  646. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  647. except Exception as err:
  648. handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/search")
  649. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  650. @test_router.post("/bfp/search/rerank", response_model=TestForm)
  651. async def bfp_search_endpoint(
  652. param: TestForm,
  653. trace_id: str = Depends(get_operation_id)):
  654. """
  655. 编制依据文档检索和重排序
  656. """
  657. try:
  658. server_logger.info(trace_id=trace_id, msg=f"{param}")
  659. print(trace_id)
  660. # 从字典中获取input
  661. input_query = param.input
  662. session_id = param.config.session_id
  663. context = param.context
  664. header_info = {
  665. }
  666. task_prompt_info = {"task_prompt": ""}
  667. top_k = int(session_id)
  668. output = None
  669. # 初始化客户端(需提前设置环境变量 SILICONFLOW_API_KEY)
  670. client = SiliconFlowAPI()
  671. # 抽象测试
  672. pg_vector_db = PGVectorDB(base_api_platform=client)
  673. output = pg_vector_db.retriever(param={"table_name": "tv_basis_of_preparation"}, query_text=input_query , top_k=top_k)
  674. # 重排序处理
  675. content_list = [doc["text_content"] for doc in output]
  676. output = client.rerank(input_query=input_query, documents=content_list , top_n=top_k)
  677. # 返回字典格式的响应
  678. return JSONResponse(
  679. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  680. except ValueError as err:
  681. handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/search")
  682. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  683. except Exception as err:
  684. handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/search")
  685. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  686. @test_router.post("/data/bfp/milvus/search", response_model=TestForm)
  687. async def bfp_search_endpoint(
  688. param: TestForm,
  689. trace_id: str = Depends(get_operation_id)):
  690. """
  691. 编制依据文档切分处理 和 入库处理
  692. """
  693. try:
  694. server_logger.info(trace_id=trace_id, msg=f"{param}")
  695. print(trace_id)
  696. # 从字典中获取input
  697. input_query = param.input
  698. session_id = param.config.session_id
  699. context = param.context
  700. header_info = {
  701. }
  702. task_prompt_info = {"task_prompt": ""}
  703. top_k = int(session_id)
  704. output = None
  705. # 初始化客户端(需提前设置环境变量 SILICONFLOW_API_KEY)
  706. client = SiliconFlowAPI()
  707. # 抽象测试
  708. vector_db = MilvusVectorManager(base_api_platform=client)
  709. output = vector_db.retriever(param={"collection_name": "tv_basis_of_preparation"}, query_text=input_query , top_k=top_k)
  710. # 返回字典格式的响应
  711. return JSONResponse(
  712. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  713. except ValueError as err:
  714. handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/milvus/search")
  715. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  716. except Exception as err:
  717. handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/milvus/search")
  718. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))