test_views.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742
  1. # !/usr/bin/ python
  2. # -*- coding: utf-8 -*-
  3. '''
  4. @Project : lq-agent-api
  5. @File :cattle_farm_views.py
  6. @IDE :PyCharm
  7. @Author :
  8. @Date :2025/7/10 17:32
  9. '''
  10. import json
  11. from typing import Optional
  12. from fastapi import Depends, Response, Header
  13. from sse_starlette import EventSourceResponse
  14. from starlette.responses import JSONResponse
  15. from fastapi import Depends, Request, Response, Header
  16. from foundation.agent.test_agent import test_agent_client
  17. from foundation.agent.generate.model_generate import generate_model_client
  18. from foundation.logger.loggering import server_logger
  19. from foundation.schemas.test_schemas import TestForm
  20. from foundation.utils.common import return_json, handler_err
  21. from views import test_router, get_operation_id
  22. from foundation.agent.workflow.test_workflow_graph import test_workflow_graph
  23. from foundation.base.mysql.async_mysql_base_dao import TestTabDAO
  24. from database.repositories.bus_data_query import BasisOfPreparationDAO
  25. from foundation.utils.tool_utils import DateTimeEncoder
  26. from foundation.models.silicon_flow import SiliconFlowAPI
  27. from foundation.rag.vector.pg_vector import PGVectorDB
  28. from langchain_core.prompts import ChatPromptTemplate
  29. from foundation.utils.yaml_utils import system_prompt_config
  30. @test_router.post("/generate/chat", response_model=TestForm)
  31. async def generate_chat_endpoint(
  32. param: TestForm,
  33. trace_id: str = Depends(get_operation_id)):
  34. """
  35. 生成类模型
  36. """
  37. try:
  38. server_logger.info(trace_id=trace_id, msg=f"{param}")
  39. print(trace_id)
  40. # 从字典中获取input
  41. input_query = param.input
  42. session_id = param.config.session_id
  43. context = param.context
  44. header_info = {
  45. }
  46. # 创建ChatPromptTemplate
  47. template = ChatPromptTemplate.from_messages([
  48. ("system", system_prompt_config['system_prompt']),
  49. ("user", input_query)
  50. ])
  51. task_prompt_info = {"task_prompt": template}
  52. output = await generate_model_client.get_model_generate_invoke(trace_id=trace_id , task_prompt_info=task_prompt_info)
  53. # 直接执行
  54. server_logger.info(trace_id=trace_id, msg=f"【result】: {output}", log_type="agent/chat")
  55. # 返回字典格式的响应
  56. return JSONResponse(
  57. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  58. except ValueError as err:
  59. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/chat")
  60. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  61. except Exception as err:
  62. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/chat")
  63. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  64. @test_router.post("/generate/stream", response_model=TestForm)
  65. async def generate_stream_endpoint(
  66. param: TestForm,
  67. trace_id: str = Depends(get_operation_id)):
  68. """
  69. 生成类模型
  70. """
  71. try:
  72. server_logger.info(trace_id=trace_id, msg=f"{param}")
  73. # 从字典中获取input
  74. input_query = param.input
  75. session_id = param.config.session_id
  76. context = param.context
  77. header_info = {
  78. }
  79. # 创建ChatPromptTemplate
  80. template = ChatPromptTemplate.from_messages([
  81. ("system", system_prompt_config['system_prompt']),
  82. ("user", input_query)
  83. ])
  84. task_prompt_info = {"task_prompt": template}
  85. # 创建 SSE 流式响应
  86. async def event_generator():
  87. try:
  88. # 流式处理查询 trace_id, task_prompt_info: dict, input_query, context=None
  89. for chunk in generate_model_client.get_model_generate_stream(trace_id=trace_id , task_prompt_info=task_prompt_info):
  90. # 发送数据块
  91. yield {
  92. "event": "message",
  93. "data": json.dumps({
  94. "output": chunk,
  95. "completed": False,
  96. }, ensure_ascii=False)
  97. }
  98. # 获取缓存数据
  99. result_data = {}
  100. # 发送结束事件
  101. yield {
  102. "event": "message_end",
  103. "data": json.dumps({
  104. "completed": True,
  105. "message": json.dumps(result_data, ensure_ascii=False),
  106. "code": 0,
  107. "trace_id": trace_id,
  108. }, ensure_ascii=False),
  109. }
  110. except Exception as e:
  111. # 错误处理
  112. yield {
  113. "event": "error",
  114. "data": json.dumps({
  115. "trace_id": trace_id,
  116. "message": str(e),
  117. "code": 1
  118. }, ensure_ascii=False)
  119. }
  120. finally:
  121. # 不需要关闭客户端,因为它是单例
  122. pass
  123. # 返回 SSE 响应
  124. return EventSourceResponse(
  125. event_generator(),
  126. headers={
  127. "Cache-Control": "no-cache",
  128. "Connection": "keep-alive"
  129. }
  130. )
  131. except ValueError as err:
  132. handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
  133. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  134. except Exception as err:
  135. handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
  136. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  137. # 路由
  138. @test_router.post("/agent/chat", response_model=TestForm)
  139. async def chat_endpoint(
  140. param: TestForm,
  141. trace_id: str = Depends(get_operation_id)):
  142. """
  143. 根据场景获取智能体反馈
  144. """
  145. try:
  146. server_logger.info(trace_id=trace_id, msg=f"{param}")
  147. # 验证参数
  148. # 从字典中获取input
  149. input_data = param.input
  150. session_id = param.config.session_id
  151. context = param.context
  152. header_info = {
  153. }
  154. task_prompt_info = {"task_prompt": ""}
  155. # stream 流式执行
  156. output = await test_agent_client.handle_query(trace_id , task_prompt_info, input_data, context, param.config)
  157. # 直接执行
  158. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="agent/chat")
  159. # 返回字典格式的响应
  160. return JSONResponse(
  161. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  162. except ValueError as err:
  163. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/chat")
  164. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  165. except Exception as err:
  166. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/chat")
  167. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  168. @test_router.post("/agent/stream", response_class=Response)
  169. async def chat_agent_stream(param: TestForm,
  170. trace_id: str = Depends(get_operation_id)):
  171. """
  172. 根据场景获取智能体反馈 (SSE流式响应)
  173. """
  174. try:
  175. server_logger.info(trace_id=trace_id, msg=f"{param}")
  176. # 提取参数
  177. input_data = param.input
  178. context = param.context
  179. header_info = {
  180. }
  181. task_prompt_info = {"task_prompt": ""}
  182. # 如果business_scene为None,则使用大模型进行意图识别
  183. server_logger.info(trace_id=trace_id, msg=f"{param}")
  184. # 创建 SSE 流式响应
  185. async def event_generator():
  186. try:
  187. # 流式处理查询
  188. async for chunk in test_agent_client.handle_query_stream(
  189. trace_id=trace_id,
  190. config_param=param.config,
  191. task_prompt_info=task_prompt_info,
  192. input_query=input_data,
  193. context=context,
  194. header_info=header_info
  195. ):
  196. server_logger.debug(trace_id=trace_id, msg=f"{chunk}")
  197. # 发送数据块
  198. yield {
  199. "event": "message",
  200. "data": json.dumps({
  201. "code": 0,
  202. "output": chunk,
  203. "completed": False,
  204. "trace_id": trace_id,
  205. }, ensure_ascii=False)
  206. }
  207. # 获取缓存数据
  208. result_data = {}
  209. # 发送结束事件
  210. yield {
  211. "event": "message_end",
  212. "data": json.dumps({
  213. "completed": True,
  214. "message": json.dumps(result_data, ensure_ascii=False),
  215. "code": 0,
  216. "trace_id": trace_id,
  217. }, ensure_ascii=False),
  218. }
  219. except Exception as e:
  220. # 错误处理
  221. yield {
  222. "event": "error",
  223. "data": json.dumps({
  224. "trace_id": trace_id,
  225. "message": str(e),
  226. "code": 1
  227. }, ensure_ascii=False)
  228. }
  229. finally:
  230. # 不需要关闭客户端,因为它是单例
  231. pass
  232. # 返回 SSE 响应
  233. return EventSourceResponse(
  234. event_generator(),
  235. headers={
  236. "Cache-Control": "no-cache",
  237. "Connection": "keep-alive"
  238. }
  239. )
  240. except Exception as err:
  241. # 初始错误处理
  242. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/stream")
  243. return JSONResponse(
  244. return_json(code=1, msg=f"{err}", trace_id=trace_id),
  245. status_code=500
  246. )
  247. @test_router.post("/graph/stream", response_class=Response)
  248. async def chat_graph_stream(param: TestForm,
  249. trace_id: str = Depends(get_operation_id)):
  250. """
  251. 根据场景获取智能体反馈 (SSE流式响应)
  252. """
  253. try:
  254. server_logger.info(trace_id=trace_id, msg=f"{param}")
  255. # request_param = {
  256. # "input": param.input,
  257. # "config": param.config,
  258. # "context": param.context
  259. # }
  260. # 创建 SSE 流式响应
  261. async def event_generator():
  262. try:
  263. # 流式处理查询
  264. async for chunk in test_workflow_graph.handle_query_stream(
  265. param=param,
  266. trace_id=trace_id,
  267. ):
  268. server_logger.debug(trace_id=trace_id, msg=f"{chunk}")
  269. # 发送数据块
  270. yield {
  271. "event": "message",
  272. "data": json.dumps({
  273. "code": 0,
  274. "output": chunk,
  275. "completed": False,
  276. "trace_id": trace_id,
  277. "dataType": "text"
  278. }, ensure_ascii=False)
  279. }
  280. # 发送结束事件
  281. yield {
  282. "event": "message_end",
  283. "data": json.dumps({
  284. "completed": True,
  285. "message": "Stream completed",
  286. "code": 0,
  287. "trace_id": trace_id,
  288. "dataType": "text"
  289. }, ensure_ascii=False),
  290. }
  291. except Exception as e:
  292. # 错误处理
  293. yield {
  294. "event": "error",
  295. "data": json.dumps({
  296. "trace_id": trace_id,
  297. "msg": str(e),
  298. "code": 1,
  299. "dataType": "text"
  300. }, ensure_ascii=False)
  301. }
  302. finally:
  303. # 不需要关闭客户端,因为它是单例
  304. pass
  305. # 返回 SSE 响应
  306. return EventSourceResponse(
  307. event_generator(),
  308. headers={
  309. "Cache-Control": "no-cache",
  310. "Connection": "keep-alive"
  311. }
  312. )
  313. except Exception as err:
  314. # 初始错误处理
  315. handler_err(server_logger, trace_id=trace_id, err=err, err_name="graph/stream")
  316. return JSONResponse(
  317. return_json(code=1, msg=f"{err}", trace_id=trace_id),
  318. status_code=500
  319. )
  320. @test_router.post("/mysql/add", response_model=TestForm)
  321. async def test_mysql_add(
  322. request: Request,
  323. param: TestForm,
  324. trace_id: str = Depends(get_operation_id)):
  325. """
  326. 根据MySQL应用
  327. """
  328. try:
  329. server_logger.info(trace_id=trace_id, msg=f"{param}")
  330. # 验证参数
  331. # 从字典中获取input
  332. input_data = param.input
  333. session_id = param.config.session_id
  334. context = param.context
  335. header_info = {
  336. }
  337. # 从app.state中获取数据库连接池
  338. async_db_pool = request.app.state.async_db_pool
  339. from foundation.base.mysql.async_mysql_base_dao import TestTabDAO
  340. test_tab_dao = TestTabDAO(async_db_pool)
  341. # name: str, email: str, age: int
  342. name = input_data
  343. email = session_id
  344. age = 18
  345. test_id = await test_tab_dao.insert_user(name=name, email=email, age=age)
  346. output = f"【test_id】: {test_id}"
  347. # 直接执行
  348. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/mysql/add")
  349. # 返回字典格式的响应
  350. return JSONResponse(
  351. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  352. except ValueError as err:
  353. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/add")
  354. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  355. except Exception as err:
  356. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/add")
  357. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  358. @test_router.post("/mysql/get", response_model=TestForm)
  359. async def test_mysql_add(
  360. request: Request,
  361. param: TestForm,
  362. trace_id: str = Depends(get_operation_id)):
  363. """
  364. 根据MySQL应用
  365. """
  366. try:
  367. server_logger.info(trace_id=trace_id, msg=f"{param}")
  368. # 验证参数
  369. # 从字典中获取input
  370. input_data = param.input
  371. session_id = param.config.session_id
  372. context = param.context
  373. header_info = {
  374. }
  375. # 从app.state中获取数据库连接池
  376. async_db_pool = request.app.state.async_db_pool
  377. test_tab_dao = TestTabDAO(async_db_pool)
  378. test_id = input_data;
  379. data = await test_tab_dao.get_user_by_id(user_id=test_id)
  380. server_logger.info(trace_id=trace_id, msg=f"【result】: {data}", log_type="/mysql/get")
  381. json_str = json.dumps(data , cls=DateTimeEncoder, ensure_ascii=False, indent=2)
  382. output = json_str
  383. # 直接执行
  384. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/mysql/get")
  385. # 返回字典格式的响应
  386. return JSONResponse(
  387. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  388. except ValueError as err:
  389. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/get")
  390. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  391. except Exception as err:
  392. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/get")
  393. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  394. @test_router.post("/mysql/list", response_model=TestForm)
  395. async def test_mysql_add(
  396. request: Request,
  397. param: TestForm,
  398. trace_id: str = Depends(get_operation_id)):
  399. """
  400. 根据MySQL应用
  401. """
  402. try:
  403. server_logger.info(trace_id=trace_id, msg=f"{param}")
  404. # 验证参数
  405. # 从字典中获取input
  406. input_data = param.input
  407. session_id = param.config.session_id
  408. context = param.context
  409. header_info = {
  410. }
  411. # 从app.state中获取数据库连接池
  412. async_db_pool = request.app.state.async_db_pool
  413. from foundation.base.mysql.async_mysql_base_dao import TestTabDAO
  414. test_tab_dao = TestTabDAO(async_db_pool)
  415. test_id = input_data;
  416. data = await test_tab_dao.get_all_users()
  417. server_logger.info(trace_id=trace_id, msg=f"【result】: {data}", log_type="/mysql/list")
  418. json_str = json.dumps(data , cls=DateTimeEncoder, ensure_ascii=False, indent=2)
  419. output = json_str
  420. # 直接执行
  421. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/mysql/list")
  422. # 返回字典格式的响应
  423. return JSONResponse(
  424. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  425. except ValueError as err:
  426. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/list")
  427. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  428. except Exception as err:
  429. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/list")
  430. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  431. @test_router.post("/mysql/update", response_model=TestForm)
  432. async def test_mysql_add(
  433. request: Request,
  434. param: TestForm,
  435. trace_id: str = Depends(get_operation_id)):
  436. """
  437. 根据MySQL应用
  438. """
  439. try:
  440. server_logger.info(trace_id=trace_id, msg=f"{param}")
  441. # 验证参数
  442. # 从字典中获取input
  443. input_data = param.input
  444. session_id = param.config.session_id
  445. context = param.context
  446. header_info = {
  447. }
  448. # 从app.state中获取数据库连接池
  449. async_db_pool = request.app.state.async_db_pool
  450. test_tab_dao = TestTabDAO(async_db_pool)
  451. test_id = session_id;
  452. updates = {
  453. "name": input_data,
  454. "email": "test_email——upt",
  455. "age": 22
  456. }
  457. success = await test_tab_dao.update_user(user_id=test_id , **updates)
  458. server_logger.info(trace_id=trace_id, msg=f"【result】: {success}", log_type="/mysql/update")
  459. output = success
  460. # 直接执行
  461. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/mysql/update")
  462. # 返回字典格式的响应
  463. return JSONResponse(
  464. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  465. except ValueError as err:
  466. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/update")
  467. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  468. except Exception as err:
  469. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/update")
  470. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  471. @test_router.post("/bop/get", response_model=TestForm)
  472. async def test_bop_get(
  473. request: Request,
  474. param: TestForm,
  475. trace_id: str = Depends(get_operation_id)):
  476. """
  477. 根据MySQL应用
  478. """
  479. try:
  480. server_logger.info(trace_id=trace_id, msg=f"{param}")
  481. # 验证参数
  482. # 从字典中获取input
  483. input_data = param.input
  484. session_id = param.config.session_id
  485. context = param.context
  486. header_info = {
  487. }
  488. # 从app.state中获取数据库连接池
  489. async_db_pool = request.app.state.async_db_pool
  490. bop_dao = BasisOfPreparationDAO(async_db_pool)
  491. test_id = input_data;
  492. data = await bop_dao.get_info_by_id(id=test_id)
  493. server_logger.info(trace_id=trace_id, msg=f"【result】: {data}", log_type="/bop/get")
  494. json_str = json.dumps(data , cls=DateTimeEncoder, ensure_ascii=False, indent=2)
  495. output = json_str
  496. # 直接执行
  497. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/bop/get")
  498. # 返回字典格式的响应
  499. return JSONResponse(
  500. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  501. except ValueError as err:
  502. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/bop/get")
  503. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  504. except Exception as err:
  505. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/bop/get")
  506. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  507. @test_router.post("/bop/list", response_model=TestForm)
  508. async def test_mysql_add(
  509. request: Request,
  510. param: TestForm,
  511. trace_id: str = Depends(get_operation_id)):
  512. """
  513. 根据MySQL应用
  514. """
  515. try:
  516. server_logger.info(trace_id=trace_id, msg=f"{param}")
  517. # 验证参数
  518. # 从字典中获取input
  519. input_data = param.input
  520. session_id = param.config.session_id
  521. context = param.context
  522. header_info = {
  523. }
  524. # 从app.state中获取数据库连接池
  525. async_db_pool = request.app.state.async_db_pool
  526. from foundation.base.mysql.async_mysql_base_dao import TestTabDAO
  527. bop_dao = BasisOfPreparationDAO(async_db_pool)
  528. test_id = input_data;
  529. data = await bop_dao.get_list()
  530. server_logger.info(trace_id=trace_id, msg=f"【result】: {data}", log_type="/bop/list")
  531. json_str = json.dumps(data , cls=DateTimeEncoder, ensure_ascii=False, indent=2)
  532. output = json_str
  533. # 直接执行
  534. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/bop/list")
  535. # 返回字典格式的响应
  536. return JSONResponse(
  537. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  538. except ValueError as err:
  539. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/bop/list")
  540. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  541. except Exception as err:
  542. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/bop/list")
  543. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  544. ##################【RAG 相关测试】##############################################
  545. @test_router.post("/embedding", response_model=TestForm)
  546. async def embedding_test_endpoint(
  547. param: TestForm,
  548. trace_id: str = Depends(get_operation_id)):
  549. """
  550. embedding模型测试
  551. """
  552. try:
  553. server_logger.info(trace_id=trace_id, msg=f"{param}")
  554. print(trace_id)
  555. # 从字典中获取input
  556. input_query = param.input
  557. session_id = param.config.session_id
  558. context = param.context
  559. header_info = {
  560. }
  561. task_prompt_info = {"task_prompt": ""}
  562. text = input_query
  563. # 初始化客户端(需提前设置环境变量 SILICONFLOW_API_KEY)
  564. from foundation.models.silicon_flow import SiliconFlowAPI
  565. base_api_platform = SiliconFlowAPI()
  566. embedding = base_api_platform.get_embeddings([text])[0]
  567. embed_dim = len(embedding)
  568. server_logger.info(trace_id=trace_id, msg=f"【result】: {embed_dim}")
  569. output = f"embed_dim={embed_dim},embedding:{embedding}"
  570. #output = test_generate_model_client.get_model_data_governance_invoke(trace_id , task_prompt_info, input_query, context)
  571. # 直接执行
  572. #server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="embedding")
  573. # 返回字典格式的响应
  574. return JSONResponse(
  575. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  576. except ValueError as err:
  577. handler_err(server_logger, trace_id=trace_id, err=err, err_name="embedding")
  578. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  579. except Exception as err:
  580. handler_err(server_logger, trace_id=trace_id, err=err, err_name="embedding")
  581. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  582. @test_router.post("/bfp/search", response_model=TestForm)
  583. async def bfp_search_endpoint(
  584. param: TestForm,
  585. trace_id: str = Depends(get_operation_id)):
  586. """
  587. 编制依据向量检索
  588. """
  589. try:
  590. server_logger.info(trace_id=trace_id, msg=f"{param}")
  591. print(trace_id)
  592. # 从字典中获取input
  593. input_query = param.input
  594. session_id = param.config.session_id
  595. context = param.context
  596. header_info = {
  597. }
  598. task_prompt_info = {"task_prompt": ""}
  599. top_k = int(session_id)
  600. output = None
  601. # 初始化客户端(需提前设置环境变量 SILICONFLOW_API_KEY)
  602. client = SiliconFlowAPI()
  603. # 抽象测试
  604. pg_vector_db = PGVectorDB(base_api_platform=client)
  605. output = pg_vector_db.retriever(param={"table_name": "tv_basis_of_preparation"}, query_text=input_query , top_k=top_k)
  606. # 返回字典格式的响应
  607. return JSONResponse(
  608. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  609. except ValueError as err:
  610. handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/search")
  611. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  612. except Exception as err:
  613. handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/search")
  614. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  615. @test_router.post("/bfp/search/rerank", response_model=TestForm)
  616. async def bfp_search_endpoint(
  617. param: TestForm,
  618. trace_id: str = Depends(get_operation_id)):
  619. """
  620. 编制依据文档检索和重排序
  621. """
  622. try:
  623. server_logger.info(trace_id=trace_id, msg=f"{param}")
  624. print(trace_id)
  625. # 从字典中获取input
  626. input_query = param.input
  627. session_id = param.config.session_id
  628. context = param.context
  629. header_info = {
  630. }
  631. task_prompt_info = {"task_prompt": ""}
  632. top_k = int(session_id)
  633. output = None
  634. # 初始化客户端(需提前设置环境变量 SILICONFLOW_API_KEY)
  635. client = SiliconFlowAPI()
  636. # 抽象测试
  637. pg_vector_db = PGVectorDB(base_api_platform=client)
  638. output = pg_vector_db.retriever(param={"table_name": "tv_basis_of_preparation"}, query_text=input_query , top_k=top_k)
  639. # 重排序处理
  640. content_list = [doc["text_content"] for doc in output]
  641. output = client.rerank(input_query=input_query, documents=content_list , top_n=top_k)
  642. # 返回字典格式的响应
  643. return JSONResponse(
  644. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  645. except ValueError as err:
  646. handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/search")
  647. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  648. except Exception as err:
  649. handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/search")
  650. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))