test_views.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871
  1. # !/usr/bin/ python
  2. # -*- coding: utf-8 -*-
  3. '''
  4. @Project : lq-agent-api
  5. @File :cattle_farm_views.py
  6. @IDE :PyCharm
  7. @Author :
  8. @Date :2025/7/10 17:32
  9. '''
  10. import json
  11. from typing import Optional, List, Dict, Any
  12. from pydantic import BaseModel, Field
  13. from fastapi import Depends, Response, Header
  14. from sse_starlette import EventSourceResponse
  15. from starlette.responses import JSONResponse
  16. from fastapi import Depends, Request, Response, Header
  17. from foundation.ai.agent.test_agent import test_agent_client
  18. from foundation.ai.agent.generate.model_generate import generate_model_client
  19. from foundation.observability.logger.loggering import server_logger
  20. from foundation.schemas.test_schemas import TestForm
  21. from foundation.utils.common import return_json, handler_err
  22. from views import test_router, get_operation_id
  23. from foundation.ai.agent.workflow.test_workflow_graph import test_workflow_graph
  24. from foundation.database.base.sql.async_mysql_base_dao import TestTabDAO
  25. from database.repositories.bus_data_query import BasisOfPreparationDAO
  26. from foundation.utils.tool_utils import DateTimeEncoder
  27. from langchain_core.prompts import ChatPromptTemplate
  28. from foundation.utils.yaml_utils import system_prompt_config
  29. from foundation.ai.models.model_handler import model_handler as mh
  30. from foundation.database.base.vector.pg_vector import PGVectorDB
  31. from foundation.database.base.vector.milvus_vector import MilvusVectorManager
  32. # 响应模型用于Swagger文档
  33. class StandardResponse(BaseModel):
  34. """标准响应模型"""
  35. code: int = Field(description="响应状态码 (0:成功, 其他:错误)")
  36. msg: str = Field(description="响应消息")
  37. data: Optional[Dict[str, Any]] = Field(default=None, description="响应数据")
  38. data_type: Optional[str] = Field(default=None, description="数据类型")
  39. trace_id: Optional[str] = Field(default=None, description="追踪ID")
  40. class UserRecord(BaseModel):
  41. """用户记录模型"""
  42. user_id: Optional[int] = Field(description="用户ID")
  43. name: Optional[str] = Field(description="用户名")
  44. email: Optional[str] = Field(description="邮箱")
  45. age: Optional[int] = Field(description="年龄")
  46. created_at: Optional[str] = Field(description="创建时间")
  47. class EmbeddingResponse(BaseModel):
  48. """向量嵌入响应模型"""
  49. embed_dim: int = Field(description="嵌入向量维度")
  50. embedding: List[float] = Field(description="嵌入向量数据")
  51. class SearchResult(BaseModel):
  52. """搜索结果模型"""
  53. text_content: Optional[str] = Field(description="文本内容")
  54. score: Optional[float] = Field(description="相关性得分")
  55. metadata: Optional[Dict[str, Any]] = Field(default=None, description="元数据")
  56. class SSEData(BaseModel):
  57. """SSE数据模型"""
  58. code: int = Field(description="状态码")
  59. output: Optional[str] = Field(default=None, description="输出内容")
  60. completed: bool = Field(description="是否完成")
  61. trace_id: Optional[str] = Field(default=None, description="追踪ID")
  62. message: Optional[str] = Field(default=None, description="消息")
  63. dataType: Optional[str] = Field(default="text", description="数据类型")
  64. @test_router.post("/generate/chat", tags=["模型生成"])
  65. async def generate_chat_endpoint(
  66. param: TestForm,
  67. trace_id: str = Depends(get_operation_id)):
  68. """
  69. 生成类模型
  70. """
  71. try:
  72. server_logger.info(trace_id=trace_id, msg=f"{param}")
  73. print(trace_id)
  74. # 从字典中获取input
  75. input_query = param.input
  76. session_id = param.config.session_id
  77. context = param.context
  78. header_info = {
  79. }
  80. # 创建ChatPromptTemplate
  81. template = ChatPromptTemplate.from_messages([
  82. ("system", system_prompt_config['system_prompt']),
  83. ("user", input_query)
  84. ])
  85. task_prompt_info = {"task_prompt": template}
  86. output = await generate_model_client.get_model_generate_invoke(trace_id=trace_id , task_prompt_info=task_prompt_info)
  87. # 直接执行
  88. server_logger.info(trace_id=trace_id, msg=f"【result】: {output}", log_type="agent/chat")
  89. # 返回字典格式的响应
  90. return JSONResponse(
  91. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  92. except ValueError as err:
  93. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/chat")
  94. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  95. except Exception as err:
  96. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/chat")
  97. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  98. @test_router.post("/generate/stream", tags=["模型生成"])
  99. async def generate_stream_endpoint(
  100. param: TestForm,
  101. trace_id: str = Depends(get_operation_id)):
  102. """
  103. 生成类模型
  104. """
  105. try:
  106. server_logger.info(trace_id=trace_id, msg=f"{param}")
  107. # 从字典中获取input
  108. input_query = param.input
  109. session_id = param.config.session_id
  110. context = param.context
  111. header_info = {
  112. }
  113. # 创建ChatPromptTemplate
  114. template = ChatPromptTemplate.from_messages([
  115. ("system", system_prompt_config['system_prompt']),
  116. ("user", input_query)
  117. ])
  118. task_prompt_info = {"task_prompt": template}
  119. # 创建 SSE 流式响应
  120. async def event_generator():
  121. try:
  122. # 流式处理查询 trace_id, task_prompt_info: dict, input_query, context=None
  123. for chunk in generate_model_client.get_model_generate_stream(trace_id=trace_id , task_prompt_info=task_prompt_info):
  124. # 发送数据块
  125. yield {
  126. "event": "message",
  127. "data": json.dumps({
  128. "output": chunk,
  129. "completed": False,
  130. }, ensure_ascii=False)
  131. }
  132. # 获取缓存数据
  133. result_data = {}
  134. # 发送结束事件
  135. yield {
  136. "event": "message_end",
  137. "data": json.dumps({
  138. "completed": True,
  139. "message": json.dumps(result_data, ensure_ascii=False),
  140. "code": 0,
  141. "trace_id": trace_id,
  142. }, ensure_ascii=False),
  143. }
  144. except Exception as e:
  145. # 错误处理
  146. yield {
  147. "event": "error",
  148. "data": json.dumps({
  149. "trace_id": trace_id,
  150. "message": str(e),
  151. "code": 1
  152. }, ensure_ascii=False)
  153. }
  154. finally:
  155. # 不需要关闭客户端,因为它是单例
  156. pass
  157. # 返回 SSE 响应
  158. return EventSourceResponse(
  159. event_generator(),
  160. headers={
  161. "Cache-Control": "no-cache",
  162. "Connection": "keep-alive"
  163. }
  164. )
  165. except ValueError as err:
  166. handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
  167. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  168. except Exception as err:
  169. handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
  170. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  171. # 路由
  172. @test_router.post("/agent/chat", tags=["智能体"])
  173. async def chat_endpoint(
  174. param: TestForm,
  175. trace_id: str = Depends(get_operation_id)):
  176. """
  177. 根据场景获取智能体反馈
  178. """
  179. try:
  180. server_logger.info(trace_id=trace_id, msg=f"{param}")
  181. # 验证参数
  182. # 从字典中获取input
  183. input_data = param.input
  184. session_id = param.config.session_id
  185. context = param.context
  186. header_info = {
  187. }
  188. task_prompt_info = {"task_prompt": ""}
  189. # stream 流式执行
  190. output = await test_agent_client.handle_query(trace_id , task_prompt_info, input_data, context, param.config)
  191. # 直接执行
  192. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="agent/chat")
  193. # 返回字典格式的响应
  194. return JSONResponse(
  195. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  196. except ValueError as err:
  197. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/chat")
  198. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  199. except Exception as err:
  200. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/chat")
  201. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  202. @test_router.post("/agent/stream", response_class=Response, tags=["智能体"])
  203. async def chat_agent_stream(param: TestForm,
  204. trace_id: str = Depends(get_operation_id)):
  205. """
  206. 根据场景获取智能体反馈 (SSE流式响应)
  207. """
  208. try:
  209. server_logger.info(trace_id=trace_id, msg=f"{param}")
  210. # 提取参数
  211. input_data = param.input
  212. context = param.context
  213. header_info = {
  214. }
  215. task_prompt_info = {"task_prompt": ""}
  216. # 如果business_scene为None,则使用大模型进行意图识别
  217. server_logger.info(trace_id=trace_id, msg=f"{param}")
  218. # 创建 SSE 流式响应
  219. async def event_generator():
  220. try:
  221. # 流式处理查询
  222. async for chunk in test_agent_client.handle_query_stream(
  223. trace_id=trace_id,
  224. config_param=param.config,
  225. task_prompt_info=task_prompt_info,
  226. input_query=input_data,
  227. context=context,
  228. header_info=header_info
  229. ):
  230. server_logger.debug(trace_id=trace_id, msg=f"{chunk}")
  231. # 发送数据块
  232. yield {
  233. "event": "message",
  234. "data": json.dumps({
  235. "code": 0,
  236. "output": chunk,
  237. "completed": False,
  238. "trace_id": trace_id,
  239. }, ensure_ascii=False)
  240. }
  241. # 获取缓存数据
  242. result_data = {}
  243. # 发送结束事件
  244. yield {
  245. "event": "message_end",
  246. "data": json.dumps({
  247. "completed": True,
  248. "message": json.dumps(result_data, ensure_ascii=False),
  249. "code": 0,
  250. "trace_id": trace_id,
  251. }, ensure_ascii=False),
  252. }
  253. except Exception as e:
  254. # 错误处理
  255. yield {
  256. "event": "error",
  257. "data": json.dumps({
  258. "trace_id": trace_id,
  259. "message": str(e),
  260. "code": 1
  261. }, ensure_ascii=False)
  262. }
  263. finally:
  264. # 不需要关闭客户端,因为它是单例
  265. pass
  266. # 返回 SSE 响应
  267. return EventSourceResponse(
  268. event_generator(),
  269. headers={
  270. "Cache-Control": "no-cache",
  271. "Connection": "keep-alive"
  272. }
  273. )
  274. except Exception as err:
  275. # 初始错误处理
  276. handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/stream")
  277. return JSONResponse(
  278. return_json(code=1, msg=f"{err}", trace_id=trace_id),
  279. status_code=500
  280. )
  281. @test_router.post("/graph/stream", response_class=Response, tags=["智能体"])
  282. async def chat_graph_stream(param: TestForm,
  283. trace_id: str = Depends(get_operation_id)):
  284. """
  285. 根据场景获取智能体反馈 (SSE流式响应)
  286. """
  287. try:
  288. server_logger.info(trace_id=trace_id, msg=f"{param}")
  289. # request_param = {
  290. # "input": param.input,
  291. # "config": param.config,
  292. # "context": param.context
  293. # }
  294. # 创建 SSE 流式响应
  295. async def event_generator():
  296. try:
  297. # 流式处理查询
  298. async for chunk in test_workflow_graph.handle_query_stream(
  299. param=param,
  300. trace_id=trace_id,
  301. ):
  302. server_logger.debug(trace_id=trace_id, msg=f"{chunk}")
  303. # 发送数据块
  304. yield {
  305. "event": "message",
  306. "data": json.dumps({
  307. "code": 0,
  308. "output": chunk,
  309. "completed": False,
  310. "trace_id": trace_id,
  311. "dataType": "text"
  312. }, ensure_ascii=False)
  313. }
  314. # 发送结束事件
  315. yield {
  316. "event": "message_end",
  317. "data": json.dumps({
  318. "completed": True,
  319. "message": "Stream completed",
  320. "code": 0,
  321. "trace_id": trace_id,
  322. "dataType": "text"
  323. }, ensure_ascii=False),
  324. }
  325. except Exception as e:
  326. # 错误处理
  327. yield {
  328. "event": "error",
  329. "data": json.dumps({
  330. "trace_id": trace_id,
  331. "msg": str(e),
  332. "code": 1,
  333. "dataType": "text"
  334. }, ensure_ascii=False)
  335. }
  336. finally:
  337. # 不需要关闭客户端,因为它是单例
  338. pass
  339. # 返回 SSE 响应
  340. return EventSourceResponse(
  341. event_generator(),
  342. headers={
  343. "Cache-Control": "no-cache",
  344. "Connection": "keep-alive"
  345. }
  346. )
  347. except Exception as err:
  348. # 初始错误处理
  349. handler_err(server_logger, trace_id=trace_id, err=err, err_name="graph/stream")
  350. return JSONResponse(
  351. return_json(code=1, msg=f"{err}", trace_id=trace_id),
  352. status_code=500
  353. )
  354. @test_router.post("/redis", tags=["redis"])
  355. async def test_redis(
  356. request: Request,
  357. param: TestForm,
  358. trace_id: str = Depends(get_operation_id)):
  359. """
  360. 根据MySQL应用
  361. """
  362. try:
  363. server_logger.info(trace_id=trace_id, msg=f"{param}")
  364. # 验证参数
  365. # 从字典中获取input
  366. input_data = param.input
  367. session_id = param.config.session_id
  368. context = param.context
  369. header_info = {
  370. }
  371. from foundation.utils.redis_utils import set_redis_result_cache_data , get_redis_result_cache_data
  372. output = "success"
  373. data_type = "output"
  374. await set_redis_result_cache_data(data_type=data_type , trace_id=trace_id , value=input_data)
  375. server_logger.info(trace_id=trace_id, msg=f"key:{trace_id}:{data_type},value:{input_data} redis 设置成功")
  376. output = await get_redis_result_cache_data(data_type=data_type , trace_id=trace_id)
  377. # 直接执行
  378. server_logger.info(trace_id=trace_id, msg=f"【result】: {output}", log_type="/redis")
  379. # 返回字典格式的响应
  380. return JSONResponse(
  381. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  382. except ValueError as err:
  383. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/redis")
  384. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  385. except Exception as err:
  386. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/redis")
  387. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  388. @test_router.post("/mysql/add", tags=["mysql"])
  389. async def test_mysql_add(
  390. request: Request,
  391. param: TestForm,
  392. trace_id: str = Depends(get_operation_id)):
  393. """
  394. 根据MySQL应用
  395. """
  396. try:
  397. server_logger.info(trace_id=trace_id, msg=f"{param}")
  398. # 验证参数
  399. # 从字典中获取input
  400. input_data = param.input
  401. session_id = param.config.session_id
  402. context = param.context
  403. header_info = {
  404. }
  405. # 从app.state中获取数据库连接池
  406. async_db_pool = request.app.state.async_db_pool
  407. from foundation.database.base.sql.async_mysql_base_dao import TestTabDAO
  408. test_tab_dao = TestTabDAO(async_db_pool)
  409. # name: str, email: str, age: int
  410. name = input_data
  411. email = session_id
  412. age = 18
  413. test_id = await test_tab_dao.insert_user(name=name, email=email, age=age)
  414. output = f"【test_id】: {test_id}"
  415. # 直接执行
  416. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/mysql/add")
  417. # 返回字典格式的响应
  418. return JSONResponse(
  419. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  420. except ValueError as err:
  421. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/add")
  422. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  423. except Exception as err:
  424. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/add")
  425. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  426. @test_router.post("/mysql/get", tags=["mysql"])
  427. async def test_mysql_add(
  428. request: Request,
  429. param: TestForm,
  430. trace_id: str = Depends(get_operation_id)):
  431. """
  432. 根据MySQL应用
  433. """
  434. try:
  435. server_logger.info(trace_id=trace_id, msg=f"{param}")
  436. # 验证参数
  437. # 从字典中获取input
  438. input_data = param.input
  439. session_id = param.config.session_id
  440. context = param.context
  441. header_info = {
  442. }
  443. # 从app.state中获取数据库连接池
  444. async_db_pool = request.app.state.async_db_pool
  445. test_tab_dao = TestTabDAO(async_db_pool)
  446. test_id = input_data;
  447. data = await test_tab_dao.get_user_by_id(user_id=test_id)
  448. server_logger.info(trace_id=trace_id, msg=f"【result】: {data}", log_type="/mysql/get")
  449. json_str = json.dumps(data , cls=DateTimeEncoder, ensure_ascii=False, indent=2)
  450. output = json_str
  451. # 直接执行
  452. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/mysql/get")
  453. # 返回字典格式的响应
  454. return JSONResponse(
  455. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  456. except ValueError as err:
  457. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/get")
  458. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  459. except Exception as err:
  460. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/get")
  461. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  462. @test_router.post("/mysql/list", tags=["mysql"])
  463. async def test_mysql_add(
  464. request: Request,
  465. param: TestForm,
  466. trace_id: str = Depends(get_operation_id)):
  467. """
  468. 根据MySQL应用
  469. """
  470. try:
  471. server_logger.info(trace_id=trace_id, msg=f"{param}")
  472. # 验证参数
  473. # 从字典中获取input
  474. input_data = param.input
  475. session_id = param.config.session_id
  476. context = param.context
  477. header_info = {
  478. }
  479. # 从app.state中获取数据库连接池
  480. async_db_pool = request.app.state.async_db_pool
  481. from foundation.database.base.sql.async_mysql_base_dao import TestTabDAO
  482. test_tab_dao = TestTabDAO(async_db_pool)
  483. test_id = input_data;
  484. data = await test_tab_dao.get_all_users()
  485. server_logger.info(trace_id=trace_id, msg=f"【result】: {data}", log_type="/mysql/list")
  486. json_str = json.dumps(data , cls=DateTimeEncoder, ensure_ascii=False, indent=2)
  487. output = json_str
  488. # 直接执行
  489. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/mysql/list")
  490. # 返回字典格式的响应
  491. return JSONResponse(
  492. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  493. except ValueError as err:
  494. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/list")
  495. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  496. except Exception as err:
  497. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/list")
  498. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  499. @test_router.post("/mysql/update", tags=["mysql"])
  500. async def test_mysql_add(
  501. request: Request,
  502. param: TestForm,
  503. trace_id: str = Depends(get_operation_id)):
  504. """
  505. 根据MySQL应用
  506. """
  507. try:
  508. server_logger.info(trace_id=trace_id, msg=f"{param}")
  509. # 验证参数
  510. # 从字典中获取input
  511. input_data = param.input
  512. session_id = param.config.session_id
  513. context = param.context
  514. header_info = {
  515. }
  516. # 从app.state中获取数据库连接池
  517. async_db_pool = request.app.state.async_db_pool
  518. test_tab_dao = TestTabDAO(async_db_pool)
  519. test_id = session_id;
  520. updates = {
  521. "name": input_data,
  522. "email": "test_email——upt",
  523. "age": 22
  524. }
  525. success = await test_tab_dao.update_user(user_id=test_id , **updates)
  526. server_logger.info(trace_id=trace_id, msg=f"【result】: {success}", log_type="/mysql/update")
  527. output = success
  528. # 直接执行
  529. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/mysql/update")
  530. # 返回字典格式的响应
  531. return JSONResponse(
  532. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  533. except ValueError as err:
  534. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/update")
  535. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  536. except Exception as err:
  537. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/update")
  538. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  539. @test_router.post("/bop/get", tags=["mysql"])
  540. async def test_bop_get(
  541. request: Request,
  542. param: TestForm,
  543. trace_id: str = Depends(get_operation_id)):
  544. """
  545. 根据MySQL应用
  546. """
  547. try:
  548. server_logger.info(trace_id=trace_id, msg=f"{param}")
  549. # 验证参数
  550. # 从字典中获取input
  551. input_data = param.input
  552. session_id = param.config.session_id
  553. context = param.context
  554. header_info = {
  555. }
  556. # 从app.state中获取数据库连接池
  557. async_db_pool = request.app.state.async_db_pool
  558. bop_dao = BasisOfPreparationDAO(async_db_pool)
  559. test_id = input_data;
  560. data = await bop_dao.get_info_by_id(id=test_id)
  561. server_logger.info(trace_id=trace_id, msg=f"【result】: {data}", log_type="/bop/get")
  562. json_str = json.dumps(data , cls=DateTimeEncoder, ensure_ascii=False, indent=2)
  563. output = json_str
  564. # 直接执行
  565. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/bop/get")
  566. # 返回字典格式的响应
  567. return JSONResponse(
  568. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  569. except ValueError as err:
  570. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/bop/get")
  571. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  572. except Exception as err:
  573. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/bop/get")
  574. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  575. @test_router.post("/bop/list", tags=["mysql"])
  576. async def test_mysql_add(
  577. request: Request,
  578. param: TestForm,
  579. trace_id: str = Depends(get_operation_id)):
  580. """
  581. 根据MySQL应用
  582. """
  583. try:
  584. server_logger.info(trace_id=trace_id, msg=f"{param}")
  585. # 验证参数
  586. # 从字典中获取input
  587. input_data = param.input
  588. session_id = param.config.session_id
  589. context = param.context
  590. header_info = {
  591. }
  592. # 从app.state中获取数据库连接池
  593. async_db_pool = request.app.state.async_db_pool
  594. from foundation.database.base.sql.async_mysql_base_dao import TestTabDAO
  595. bop_dao = BasisOfPreparationDAO(async_db_pool)
  596. test_id = input_data;
  597. data = await bop_dao.get_list()
  598. server_logger.info(trace_id=trace_id, msg=f"【result】: {data}", log_type="/bop/list")
  599. json_str = json.dumps(data , cls=DateTimeEncoder, ensure_ascii=False, indent=2)
  600. output = json_str
  601. # 直接执行
  602. server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/bop/list")
  603. # 返回字典格式的响应
  604. return JSONResponse(
  605. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  606. except ValueError as err:
  607. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/bop/list")
  608. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  609. except Exception as err:
  610. handler_err(server_logger, trace_id=trace_id, err=err, err_name="/bop/list")
  611. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  612. # RAG-嵌入接口
  613. @test_router.post("/embedding", tags=["RAG服务"])
  614. async def embedding_test_endpoint(
  615. param: TestForm,
  616. trace_id: str = Depends(get_operation_id)):
  617. """
  618. embedding模型测试
  619. """
  620. try:
  621. server_logger.info(trace_id=trace_id, msg=f"{param}")
  622. print(trace_id)
  623. # 从字典中获取input
  624. input_query = param.input
  625. session_id = param.config.session_id
  626. context = param.context
  627. header_info = {
  628. }
  629. task_prompt_info = {"task_prompt": ""}
  630. text = input_query
  631. embedding = mh._get_lq_qwen3_8b_emd()
  632. embed_dim = len(embedding)
  633. server_logger.info(trace_id=trace_id, msg=f"【result】: {embed_dim}")
  634. output = f"embed_dim={embed_dim},embedding:{embedding}"
  635. #output = test_generate_model_client.get_model_data_governance_invoke(trace_id , task_prompt_info, input_query, context)
  636. # 直接执行
  637. #server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="embedding")
  638. # 返回字典格式的响应
  639. return JSONResponse(
  640. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  641. except ValueError as err:
  642. handler_err(server_logger, trace_id=trace_id, err=err, err_name="embedding")
  643. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  644. except Exception as err:
  645. handler_err(server_logger, trace_id=trace_id, err=err, err_name="embedding")
  646. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  647. #PG向量检索
  648. @test_router.post("/bfp/search", tags=["RAG服务"])
  649. async def bfp_search_endpoint(
  650. param: TestForm,
  651. trace_id: str = Depends(get_operation_id)):
  652. """
  653. 编制依据向量检索
  654. """
  655. try:
  656. server_logger.info(trace_id=trace_id, msg=f"{param}")
  657. print(trace_id)
  658. # 从字典中获取input
  659. input_query = param.input
  660. session_id = param.config.session_id
  661. context = param.context
  662. header_info = {
  663. }
  664. task_prompt_info = {"task_prompt": ""}
  665. top_k = int(session_id)
  666. output = None
  667. # 初始化客户端(需提前设置环境变量 SILICONFLOW_API_KEY)
  668. # 抽象测试
  669. pg_vector_db = PGVectorDB()
  670. output = pg_vector_db.retriever(param={"table_name": "tv_basis_of_preparation"}, query_text=input_query , top_k=top_k)
  671. # 返回字典格式的响应
  672. return JSONResponse(
  673. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  674. except ValueError as err:
  675. handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/search")
  676. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  677. except Exception as err:
  678. handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/search")
  679. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  680. #PG向量检索/重排序
  681. @test_router.post("/bfp/search/rerank", tags=["RAG服务"])
  682. async def bfp_search_endpoint(
  683. param: TestForm,
  684. trace_id: str = Depends(get_operation_id)):
  685. """
  686. 编制依据文档检索和重排序
  687. """
  688. try:
  689. server_logger.info(trace_id=trace_id, msg=f"{param}")
  690. print(trace_id)
  691. # 从字典中获取input
  692. input_query = param.input
  693. session_id = param.config.session_id
  694. context = param.context
  695. header_info = {
  696. }
  697. task_prompt_info = {"task_prompt": ""}
  698. top_k = int(session_id)
  699. output = None
  700. # 初始化客户端(需提前设置环境变量 SILICONFLOW_API_KEY)
  701. client = SiliconFlowAPI()
  702. # 抽象测试
  703. pg_vector_db = PGVectorDB(base_api_platform=client)
  704. output = pg_vector_db.retriever(param={"table_name": "tv_basis_of_preparation"}, query_text=input_query , top_k=top_k)
  705. # 重排序处理
  706. content_list = [doc["text_content"] for doc in output]
  707. output = client.rerank(input_query=input_query, documents=content_list , top_n=top_k)
  708. # 返回字典格式的响应
  709. return JSONResponse(
  710. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  711. except ValueError as err:
  712. handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/search")
  713. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  714. except Exception as err:
  715. handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/search")
  716. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  717. # Milvus向量检索
  718. @test_router.post("/data/bfp/milvus/search", tags=["milvus"])
  719. async def bfp_search_endpoint(
  720. param: TestForm,
  721. trace_id: str = Depends(get_operation_id)):
  722. """
  723. 编制依据文档切分处理 和 入库处理
  724. """
  725. try:
  726. server_logger.info(trace_id=trace_id, msg=f"{param}")
  727. print(trace_id)
  728. # 从字典中获取input
  729. input_query = param.input
  730. session_id = param.config.session_id
  731. context = param.context
  732. header_info = {
  733. }
  734. task_prompt_info = {"task_prompt": ""}
  735. top_k = int(session_id)
  736. output = None
  737. # 初始化客户端(需提前设置环境变量 SILICONFLOW_API_KEY)
  738. client = SiliconFlowAPI()
  739. # 抽象测试
  740. vector_db = MilvusVectorManager(base_api_platform=client)
  741. output = vector_db.retriever(param={"collection_name": "tv_basis_of_preparation"}, query_text=input_query , top_k=top_k)
  742. # 返回字典格式的响应
  743. return JSONResponse(
  744. return_json(data={"output": output}, data_type="text", trace_id=trace_id))
  745. except ValueError as err:
  746. handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/milvus/search")
  747. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
  748. except Exception as err:
  749. handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/milvus/search")
  750. return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))