@@ -28,6 +28,8 @@ from database.repositories.bus_data_query import BasisOfPreparationDAO
 from foundation.utils.tool_utils import DateTimeEncoder
 
 from foundation.models.silicon_flow import SiliconFlowAPI
 
 from foundation.rag.vector.pg_vector import PGVectorDB
 
+from langchain_core.prompts import ChatPromptTemplate
+from foundation.utils.yaml_utils import system_prompt_config
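
The new import assumes foundation.utils.yaml_utils exposes a module-level system_prompt_config mapping with a 'system_prompt' key. A minimal sketch of what that module might look like (the file path and loader below are assumptions, not confirmed by this diff):

    # foundation/utils/yaml_utils.py -- hypothetical sketch of the assumed module
    import yaml

    # Load the prompt configuration once at import time; the YAML file is
    # assumed to hold a top-level 'system_prompt' key used by the endpoints.
    with open("config/system_prompts.yaml", "r", encoding="utf-8") as f:
        system_prompt_config = yaml.safe_load(f)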

@test_router.post("/generate/chat", response_model=TestForm)
@@ -46,21 +48,27 @@ async def generate_chat_endpoint(
         context = param.context
         header_info = {
         }
-        task_prompt_info = {"task_prompt": ""}
-        output = generate_model_client.get_model_generate_invoke(trace_id, task_prompt_info,
-                                                                 input_query, context)
+
+        # Build the ChatPromptTemplate
+        template = ChatPromptTemplate.from_messages([
+            ("system", system_prompt_config['system_prompt']),
+            ("user", input_query)
+        ])
+
+        task_prompt_info = {"task_prompt": template}
+        output = await generate_model_client.get_model_generate_invoke(trace_id=trace_id, task_prompt_info=task_prompt_info)
         # Execute directly
-        server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="agent/chat")
+        server_logger.info(trace_id=trace_id, msg=f"【result】: {output}", log_type="agent/chat")
         # Return the response as a dict
         return JSONResponse(
             return_json(data={"output": output}, data_type="text", trace_id=trace_id))
     except ValueError as err:
-        handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
+        handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/chat")
         return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
     except Exception as err:
-        handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
+        handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/chat")
         return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
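
For reference, a standalone sketch of the prompt both endpoints now build; the system prompt and user input values below are stand-ins for system_prompt_config['system_prompt'] and the request's input_query:

    from langchain_core.prompts import ChatPromptTemplate

    system_prompt = "You are a helpful assistant."  # stand-in value
    input_query = "Summarize the basis of preparation."  # stand-in user input

    template = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        ("user", input_query),
    ])

    # Caveat: from_messages treats each string as a template, so an input_query
    # containing a literal '{' or '}' would be parsed as a variable placeholder.
    print(template.format_messages())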
@@ -80,13 +88,18 @@ async def generate_stream_endpoint(
         context = param.context
         header_info = {
         }
-        task_prompt_info = {"task_prompt": ""}
+        # Build the ChatPromptTemplate
+        template = ChatPromptTemplate.from_messages([
+            ("system", system_prompt_config['system_prompt']),
+            ("user", input_query)
+        ])
+
+        task_prompt_info = {"task_prompt": template}
         # Create the SSE streaming response
         async def event_generator():
             try:
                 # Stream the query: trace_id, task_prompt_info: dict, input_query, context=None
-                for chunk in generate_model_client.get_model_generate_stream(trace_id, task_prompt_info,
-                                                                             input_query, context):
+                for chunk in generate_model_client.get_model_generate_stream(trace_id=trace_id, task_prompt_info=task_prompt_info):
                     # Send the data chunk
                     yield {
                         "event": "message",