test_views.py 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729
  1. # !/usr/bin/ python
  2. # -*- coding: utf-8 -*-
  3. '''
  4. @Project : lq-agent-api
  5. @File :test_views.py
  6. @IDE :PyCharm
  7. @Author :
  8. @Date :2025/7/10 17:32
  9. '''
  10. import json
  11. from typing import Optional
  12. from fastapi import Depends, Response, Header
  13. from sse_starlette import EventSourceResponse
  14. from starlette.responses import JSONResponse
  15. from fastapi import Depends, Request, Response, Header
  16. from foundation.agent.test_agent import test_agent_client
  17. from foundation.agent.generate.model_generate import generate_model_client
  18. from foundation.logger.loggering import server_logger
  19. from foundation.schemas.test_schemas import TestForm
  20. from foundation.utils.common import return_json, handler_err
  21. from views import test_router, get_operation_id
  22. from foundation.agent.workflow.test_workflow_graph import test_workflow_graph
  23. from foundation.base.mysql.async_mysql_base_dao import TestTabDAO
  24. from database.repositories.bus_data_query import BasisOfPreparationDAO
  25. from foundation.utils.tool_utils import DateTimeEncoder
  26. from foundation.models.silicon_flow import SiliconFlowAPI
  27. from foundation.rag.vector.pg_vector import PGVectorDB
  28. @test_router.post("/generate/chat", response_model=TestForm)
  29. async def generate_chat_endpoint(
  30. param: TestForm,
  31. trace_id: str = Depends(get_operation_id)):
  32. """
  33. 生成类模型
  34. """
  35. try:
  36. server_logger.info(trace_id=trace_id, msg=f"{param}")
  37. print(trace_id)
  38. # 从字典中获取input
  39. input_query = param.input
  40. session_id = param.config.session_id
  41. context = param.context
  42. header_info = {
  43. }
  44. task_prompt_info = {"task_prompt": ""}
  45. output = generate_model_client.get_model_generate_invoke(trace_id , task_prompt_info,
  46. input_query, context)
  47. # 直接执行
  48. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="agent/chat")
  49. # 返回字典格式的响应
  50. return JSONResponse(
  51. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  52. except ValueError as err:
  53. handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
  54. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  55. except Exception as err:
  56. handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
  57. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  58. @test_router.post("/generate/stream", response_model=TestForm)
  59. async def generate_stream_endpoint(
  60. param: TestForm,
  61. trace_id: str = Depends(get_operation_id)):
  62. """
  63. 生成类模型
  64. """
  65. try:
  66. server_logger.info(trace_id=trace_id, msg=f"{param}")
  67. # 从字典中获取input
  68. input_query = param.input
  69. session_id = param.config.session_id
  70. context = param.context
  71. header_info = {
  72. }
  73. task_prompt_info = {"task_prompt": ""}
  74. # 创建 SSE 流式响应
  75. async def event_generator():
  76. try:
  77. # 流式处理查询 trace_id, task_prompt_info: dict, input_query, context=None
  78. for chunk in generate_model_client.get_model_generate_stream(trace_id , task_prompt_info,
  79. input_query, context):
  80. # 发送数据块
  81. yield {
  82. "event": "message",
  83. "data": json.dumps({
  84. "output": chunk,
  85. "completed": False,
  86. }, ensure_ascii=False)
  87. }
  88. # 获取缓存数据
  89. result_data = {}
  90. # 发送结束事件
  91. yield {
  92. "event": "message_end",
  93. "data": json.dumps({
  94. "completed": True,
  95. "message": json.dumps(result_data, ensure_ascii=False),
  96. "code": 0,
  97. "trace_id": trace_id,
  98. }, ensure_ascii=False),
  99. }
  100. except Exception as e:
  101. # 错误处理
  102. yield {
  103. "event": "error",
  104. "data": json.dumps({
  105. "trace_id": trace_id,
  106. "message": str(e),
  107. "code": 1
  108. }, ensure_ascii=False)
  109. }
  110. finally:
  111. # 不需要关闭客户端,因为它是单例
  112. pass
  113. # 返回 SSE 响应
  114. return EventSourceResponse(
  115. event_generator(),
  116. headers={
  117. "Cache-Control": "no-cache",
  118. "Connection": "keep-alive"
  119. }
  120. )
  121. except ValueError as err:
  122. handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
  123. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  124. except Exception as err:
  125. handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
  126. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  127. # 路由
  128. @test_router.post("/agent/chat", response_model=TestForm)
  129. async def chat_endpoint(
  130. param: TestForm,
  131. trace_id: str = Depends(get_operation_id)):
  132. """
  133. 根据场景获取智能体反馈
  134. """
  135. try:
  136. server_logger.info(trace_id=trace_id, msg=f"{param}")
  137. # 验证参数
  138. # 从字典中获取input
  139. input_data = param.input
  140. session_id = param.config.session_id
  141. context = param.context
  142. header_info = {
  143. }
  144. task_prompt_info = {"task_prompt": ""}
  145. # stream 流式执行
  146. output = await test_agent_client.handle_query(trace_id , task_prompt_info, input_data, context, param.config)
  147. # 直接执行
  148. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="agent/chat")
  149. # 返回字典格式的响应
  150. return JSONResponse(
  151. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  152. except ValueError as err:
  153. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/chat")
  154. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  155. except Exception as err:
  156. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/chat")
  157. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  158. @test_router.post("/agent/stream", response_class=Response)
  159. async def chat_agent_stream(param: TestForm,
  160. trace_id: str = Depends(get_operation_id)):
  161. """
  162. 根据场景获取智能体反馈 (SSE流式响应)
  163. """
  164. try:
  165. server_logger.info(trace_id=trace_id, msg=f"{param}")
  166. # 提取参数
  167. input_data = param.input
  168. context = param.context
  169. header_info = {
  170. }
  171. task_prompt_info = {"task_prompt": ""}
  172. # 如果business_scene为None,则使用大模型进行意图识别
  173. server_logger.info(trace_id=trace_id, msg=f"{param}")
  174. # 创建 SSE 流式响应
  175. async def event_generator():
  176. try:
  177. # 流式处理查询
  178. async for chunk in test_agent_client.handle_query_stream(
  179. trace_id=trace_id,
  180. config_param=param.config,
  181. task_prompt_info=task_prompt_info,
  182. input_query=input_data,
  183. context=context,
  184. header_info=header_info
  185. ):
  186. server_logger.debug(trace_id=trace_id, msg=f"{chunk}")
  187. # 发送数据块
  188. yield {
  189. "event": "message",
  190. "data": json.dumps({
  191. "code": 0,
  192. "output": chunk,
  193. "completed": False,
  194. "trace_id": trace_id,
  195. }, ensure_ascii=False)
  196. }
  197. # 获取缓存数据
  198. result_data = {}
  199. # 发送结束事件
  200. yield {
  201. "event": "message_end",
  202. "data": json.dumps({
  203. "completed": True,
  204. "message": json.dumps(result_data, ensure_ascii=False),
  205. "code": 0,
  206. "trace_id": trace_id,
  207. }, ensure_ascii=False),
  208. }
  209. except Exception as e:
  210. # 错误处理
  211. yield {
  212. "event": "error",
  213. "data": json.dumps({
  214. "trace_id": trace_id,
  215. "message": str(e),
  216. "code": 1
  217. }, ensure_ascii=False)
  218. }
  219. finally:
  220. # 不需要关闭客户端,因为它是单例
  221. pass
  222. # 返回 SSE 响应
  223. return EventSourceResponse(
  224. event_generator(),
  225. headers={
  226. "Cache-Control": "no-cache",
  227. "Connection": "keep-alive"
  228. }
  229. )
  230. except Exception as err:
  231. # 初始错误处理
  232. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/stream")
  233. return JSONResponse(
  234. return_json(code=1, msg=f"{err}", trace_id=trace_id),
  235. status_code=500
  236. )
  237. @test_router.post("/graph/stream", response_class=Response)
  238. async def chat_graph_stream(param: TestForm,
  239. trace_id: str = Depends(get_operation_id)):
  240. """
  241. 根据场景获取智能体反馈 (SSE流式响应)
  242. """
  243. try:
  244. server_logger.info(trace_id=trace_id, msg=f"{param}")
  245. # request_param = {
  246. # "input": param.input,
  247. # "config": param.config,
  248. # "context": param.context
  249. # }
  250. # 创建 SSE 流式响应
  251. async def event_generator():
  252. try:
  253. # 流式处理查询
  254. async for chunk in test_workflow_graph.handle_query_stream(
  255. param=param,
  256. trace_id=trace_id,
  257. ):
  258. server_logger.debug(trace_id=trace_id, msg=f"{chunk}")
  259. # 发送数据块
  260. yield {
  261. "event": "message",
  262. "data": json.dumps({
  263. "code": 0,
  264. "output": chunk,
  265. "completed": False,
  266. "trace_id": trace_id,
  267. "dataType": "text"
  268. }, ensure_ascii=False)
  269. }
  270. # 发送结束事件
  271. yield {
  272. "event": "message_end",
  273. "data": json.dumps({
  274. "completed": True,
  275. "message": "Stream completed",
  276. "code": 0,
  277. "trace_id": trace_id,
  278. "dataType": "text"
  279. }, ensure_ascii=False),
  280. }
  281. except Exception as e:
  282. # 错误处理
  283. yield {
  284. "event": "error",
  285. "data": json.dumps({
  286. "trace_id": trace_id,
  287. "msg": str(e),
  288. "code": 1,
  289. "dataType": "text"
  290. }, ensure_ascii=False)
  291. }
  292. finally:
  293. # 不需要关闭客户端,因为它是单例
  294. pass
  295. # 返回 SSE 响应
  296. return EventSourceResponse(
  297. event_generator(),
  298. headers={
  299. "Cache-Control": "no-cache",
  300. "Connection": "keep-alive"
  301. }
  302. )
  303. except Exception as err:
  304. # 初始错误处理
  305. handler_err(server_logger, trace_id=trace_id, err=err, err_name="graph/stream")
  306. return JSONResponse(
  307. return_json(code=1, msg=f"{err}", trace_id=trace_id),
  308. status_code=500
  309. )
  310. @test_router.post("/mysql/add", response_model=TestForm)
  311. async def test_mysql_add(
  312. request: Request,
  313. param: TestForm,
  314. trace_id: str = Depends(get_operation_id)):
  315. """
  316. 根据MySQL应用
  317. """
  318. try:
  319. server_logger.info(trace_id=trace_id, msg=f"{param}")
  320. # 验证参数
  321. # 从字典中获取input
  322. input_data = param.input
  323. session_id = param.config.session_id
  324. context = param.context
  325. header_info = {
  326. }
  327. # 从app.state中获取数据库连接池
  328. async_db_pool = request.app.state.async_db_pool
  329. from foundation.base.mysql.async_mysql_base_dao import TestTabDAO
  330. test_tab_dao = TestTabDAO(async_db_pool)
  331. # name: str, email: str, age: int
  332. name = input_data
  333. email = session_id
  334. age = 18
  335. test_id = await test_tab_dao.insert_user(name=name, email=email, age=age)
  336. output = f"【test_id】: {test_id}"
  337. # 直接执行
  338. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/mysql/add")
  339. # 返回字典格式的响应
  340. return JSONResponse(
  341. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  342. except ValueError as err:
  343. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/add")
  344. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  345. except Exception as err:
  346. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/add")
  347. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  348. @test_router.post("/mysql/get", response_model=TestForm)
  349. async def test_mysql_add(
  350. request: Request,
  351. param: TestForm,
  352. trace_id: str = Depends(get_operation_id)):
  353. """
  354. 根据MySQL应用
  355. """
  356. try:
  357. server_logger.info(trace_id=trace_id, msg=f"{param}")
  358. # 验证参数
  359. # 从字典中获取input
  360. input_data = param.input
  361. session_id = param.config.session_id
  362. context = param.context
  363. header_info = {
  364. }
  365. # 从app.state中获取数据库连接池
  366. async_db_pool = request.app.state.async_db_pool
  367. test_tab_dao = TestTabDAO(async_db_pool)
  368. test_id = input_data;
  369. data = await test_tab_dao.get_user_by_id(user_id=test_id)
  370. server_logger.info(trace_id=trace_id, msg=f"【result】: {data}", log_type="/mysql/get")
  371. json_str = json.dumps(data , cls=DateTimeEncoder, ensure_ascii=False, indent=2)
  372. output = json_str
  373. # 直接执行
  374. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/mysql/get")
  375. # 返回字典格式的响应
  376. return JSONResponse(
  377. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  378. except ValueError as err:
  379. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/get")
  380. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  381. except Exception as err:
  382. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/get")
  383. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  384. @test_router.post("/mysql/list", response_model=TestForm)
  385. async def test_mysql_add(
  386. request: Request,
  387. param: TestForm,
  388. trace_id: str = Depends(get_operation_id)):
  389. """
  390. 根据MySQL应用
  391. """
  392. try:
  393. server_logger.info(trace_id=trace_id, msg=f"{param}")
  394. # 验证参数
  395. # 从字典中获取input
  396. input_data = param.input
  397. session_id = param.config.session_id
  398. context = param.context
  399. header_info = {
  400. }
  401. # 从app.state中获取数据库连接池
  402. async_db_pool = request.app.state.async_db_pool
  403. from foundation.base.mysql.async_mysql_base_dao import TestTabDAO
  404. test_tab_dao = TestTabDAO(async_db_pool)
  405. test_id = input_data;
  406. data = await test_tab_dao.get_all_users()
  407. server_logger.info(trace_id=trace_id, msg=f"【result】: {data}", log_type="/mysql/list")
  408. json_str = json.dumps(data , cls=DateTimeEncoder, ensure_ascii=False, indent=2)
  409. output = json_str
  410. # 直接执行
  411. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/mysql/list")
  412. # 返回字典格式的响应
  413. return JSONResponse(
  414. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  415. except ValueError as err:
  416. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/list")
  417. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  418. except Exception as err:
  419. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/list")
  420. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  421. @test_router.post("/mysql/update", response_model=TestForm)
  422. async def test_mysql_add(
  423. request: Request,
  424. param: TestForm,
  425. trace_id: str = Depends(get_operation_id)):
  426. """
  427. 根据MySQL应用
  428. """
  429. try:
  430. server_logger.info(trace_id=trace_id, msg=f"{param}")
  431. # 验证参数
  432. # 从字典中获取input
  433. input_data = param.input
  434. session_id = param.config.session_id
  435. context = param.context
  436. header_info = {
  437. }
  438. # 从app.state中获取数据库连接池
  439. async_db_pool = request.app.state.async_db_pool
  440. test_tab_dao = TestTabDAO(async_db_pool)
  441. test_id = session_id;
  442. updates = {
  443. "name": input_data,
  444. "email": "test_email——upt",
  445. "age": 22
  446. }
  447. success = await test_tab_dao.update_user(user_id=test_id , **updates)
  448. server_logger.info(trace_id=trace_id, msg=f"【result】: {success}", log_type="/mysql/update")
  449. output = success
  450. # 直接执行
  451. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/mysql/update")
  452. # 返回字典格式的响应
  453. return JSONResponse(
  454. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  455. except ValueError as err:
  456. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/update")
  457. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  458. except Exception as err:
  459. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/update")
  460. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  461. @test_router.post("/bop/get", response_model=TestForm)
  462. async def test_bop_get(
  463. request: Request,
  464. param: TestForm,
  465. trace_id: str = Depends(get_operation_id)):
  466. """
  467. 根据MySQL应用
  468. """
  469. try:
  470. server_logger.info(trace_id=trace_id, msg=f"{param}")
  471. # 验证参数
  472. # 从字典中获取input
  473. input_data = param.input
  474. session_id = param.config.session_id
  475. context = param.context
  476. header_info = {
  477. }
  478. # 从app.state中获取数据库连接池
  479. async_db_pool = request.app.state.async_db_pool
  480. bop_dao = BasisOfPreparationDAO(async_db_pool)
  481. test_id = input_data;
  482. data = await bop_dao.get_info_by_id(id=test_id)
  483. server_logger.info(trace_id=trace_id, msg=f"【result】: {data}", log_type="/bop/get")
  484. json_str = json.dumps(data , cls=DateTimeEncoder, ensure_ascii=False, indent=2)
  485. output = json_str
  486. # 直接执行
  487. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/bop/get")
  488. # 返回字典格式的响应
  489. return JSONResponse(
  490. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  491. except ValueError as err:
  492. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/bop/get")
  493. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  494. except Exception as err:
  495. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/bop/get")
  496. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  497. @test_router.post("/bop/list", response_model=TestForm)
  498. async def test_mysql_add(
  499. request: Request,
  500. param: TestForm,
  501. trace_id: str = Depends(get_operation_id)):
  502. """
  503. 根据MySQL应用
  504. """
  505. try:
  506. server_logger.info(trace_id=trace_id, msg=f"{param}")
  507. # 验证参数
  508. # 从字典中获取input
  509. input_data = param.input
  510. session_id = param.config.session_id
  511. context = param.context
  512. header_info = {
  513. }
  514. # 从app.state中获取数据库连接池
  515. async_db_pool = request.app.state.async_db_pool
  516. from foundation.base.mysql.async_mysql_base_dao import TestTabDAO
  517. bop_dao = BasisOfPreparationDAO(async_db_pool)
  518. test_id = input_data;
  519. data = await bop_dao.get_list()
  520. server_logger.info(trace_id=trace_id, msg=f"【result】: {data}", log_type="/bop/list")
  521. json_str = json.dumps(data , cls=DateTimeEncoder, ensure_ascii=False, indent=2)
  522. output = json_str
  523. # 直接执行
  524. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/bop/list")
  525. # 返回字典格式的响应
  526. return JSONResponse(
  527. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  528. except ValueError as err:
  529. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/bop/list")
  530. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  531. except Exception as err:
  532. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/bop/list")
  533. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  534. ##################【RAG 相关测试】##############################################
  535. @test_router.post("/embedding", response_model=TestForm)
  536. async def embedding_test_endpoint(
  537. param: TestForm,
  538. trace_id: str = Depends(get_operation_id)):
  539. """
  540. embedding模型测试
  541. """
  542. try:
  543. server_logger.info(trace_id=trace_id, msg=f"{param}")
  544. print(trace_id)
  545. # 从字典中获取input
  546. input_query = param.input
  547. session_id = param.config.session_id
  548. context = param.context
  549. header_info = {
  550. }
  551. task_prompt_info = {"task_prompt": ""}
  552. text = input_query
  553. # 初始化客户端(需提前设置环境变量 SILICONFLOW_API_KEY)
  554. from foundation.models.silicon_flow import SiliconFlowAPI
  555. base_api_platform = SiliconFlowAPI()
  556. embedding = base_api_platform.get_embeddings([text])[0]
  557. embed_dim = len(embedding)
  558. server_logger.info(trace_id=trace_id, msg=f"【result】: {embed_dim}")
  559. output = f"embed_dim={embed_dim},embedding:{embedding}"
  560. #output = test_generate_model_client.get_model_data_governance_invoke(trace_id , task_prompt_info, input_query, context)
  561. # 直接执行
  562. #server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="embedding")
  563. # 返回字典格式的响应
  564. return JSONResponse(
  565. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  566. except ValueError as err:
  567. handler_err(server_logger, trace_id=trace_id, err=err, err_name="embedding")
  568. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  569. except Exception as err:
  570. handler_err(server_logger, trace_id=trace_id, err=err, err_name="embedding")
  571. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  572. @test_router.post("/bfp/search", response_model=TestForm)
  573. async def bfp_search_endpoint(
  574. param: TestForm,
  575. trace_id: str = Depends(get_operation_id)):
  576. """
  577. 编制依据向量检索
  578. """
  579. try:
  580. server_logger.info(trace_id=trace_id, msg=f"{param}")
  581. print(trace_id)
  582. # 从字典中获取input
  583. input_query = param.input
  584. session_id = param.config.session_id
  585. context = param.context
  586. header_info = {
  587. }
  588. task_prompt_info = {"task_prompt": ""}
  589. top_k = int(session_id)
  590. output = None
  591. # 初始化客户端(需提前设置环境变量 SILICONFLOW_API_KEY)
  592. client = SiliconFlowAPI()
  593. # 抽象测试
  594. pg_vector_db = PGVectorDB(base_api_platform=client)
  595. output = pg_vector_db.retriever(param={"table_name": "tv_basis_of_preparation"}, query_text=input_query , top_k=top_k)
  596. # 返回字典格式的响应
  597. return JSONResponse(
  598. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  599. except ValueError as err:
  600. handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/search")
  601. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  602. except Exception as err:
  603. handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/search")
  604. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  605. @test_router.post("/bfp/search/rerank", response_model=TestForm)
  606. async def bfp_search_endpoint(
  607. param: TestForm,
  608. trace_id: str = Depends(get_operation_id)):
  609. """
  610. 编制依据文档检索和重排序
  611. """
  612. try:
  613. server_logger.info(trace_id=trace_id, msg=f"{param}")
  614. print(trace_id)
  615. # 从字典中获取input
  616. input_query = param.input
  617. session_id = param.config.session_id
  618. context = param.context
  619. header_info = {
  620. }
  621. task_prompt_info = {"task_prompt": ""}
  622. top_k = int(session_id)
  623. output = None
  624. # 初始化客户端(需提前设置环境变量 SILICONFLOW_API_KEY)
  625. client = SiliconFlowAPI()
  626. # 抽象测试
  627. pg_vector_db = PGVectorDB(base_api_platform=client)
  628. output = pg_vector_db.retriever(param={"table_name": "tv_basis_of_preparation"}, query_text=input_query , top_k=top_k)
  629. # 重排序处理
  630. content_list = [doc["text_content"] for doc in output]
  631. output = client.rerank(input_query=input_query, documents=content_list , top_n=top_k)
  632. # 返回字典格式的响应
  633. return JSONResponse(
  634. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  635. except ValueError as err:
  636. handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/search")
  637. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  638. except Exception as err:
  639. handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/search")
  640. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))