#!/usr/bin/env python
# -*- coding: utf-8 -*-
- '''
- @Project : lq-agent-api
- @File :model_generate.py
- @IDE :PyCharm
- @Author :
- @Date :2025/7/14 14:22
- '''
- from typing import List, Dict, Any, Optional, AsyncGenerator
- from langchain_core.prompts import HumanMessagePromptTemplate
- from langchain_core.prompts import ChatPromptTemplate
- from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
- from langgraph.prebuilt import ToolNode
- from utils.utils import get_models
- from views import mcp_server
- from utils.yaml_utils import system_prompt_config
- from logger.loggering import server_logger
class XiwuzcModelGenerateClient:
    """Generative-model client.

    Wraps the deployed LLM handles returned by ``get_models()`` and offers:

    * plain streaming generation (``get_model_generate_stream``), and
    * an async tool-calling path (``get_model_tools_call``) backed by a
      tool-bound LLM plus a ToolNode executor.

    NOTE(review): tool wiring is currently disabled — ``tool_node_list``,
    ``llm_with_tools`` and ``system_prompt`` are placeholders (see the
    commented-out initializers in ``__init__``); ``get_model_tools_call``
    will fail until they are enabled.
    """

    def __init__(self):
        # Fetch the deployed model handles (completion LLM, chat model,
        # embedding model — the embedding handle is not retained).
        llm, chat, embed = get_models()
        self.llm = llm
        self.chat = chat
        # Tool executor for tool calls.
        # NOTE(review): get_model_tools_call awaits self.tool_node_list.ainvoke(...),
        # which needs a ToolNode — a bare list has no .ainvoke. Enable the
        # commented initializer before using the tool-calling path.
        self.tool_node_list = []  # ToolNode(mcp_server.tools)
        # LLM with the MCP tools bound; None until tools are enabled.
        self.llm_with_tools = None  # llm.bind_tools(mcp_server.tools)
        # System prompt used for tool calling; empty until tools are enabled.
        self.system_prompt = ""  # system_prompt_config["tools_system_prompt"]

    def get_prompt_template(self):
        """Build a plain (non-tool) chat prompt template.

        Returns:
            ChatPromptTemplate with a single human message expecting the
            ``system_message`` and ``question`` variables.
        """
        human_template = """
        {system_message}
        用户的问题为:
        {question}
        答案为:
        """
        human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
        return ChatPromptTemplate.from_messages([human_message_prompt])

    def get_model_generate_stream(self, task_prompt_info: dict, op_id, input_query, context=None, supplement_info=None):
        """Stream model output token-by-token (typewriter effect).

        Args:
            task_prompt_info: dict carrying the system prompt under
                ``"task_prompt"``.
            op_id: operation/trace id (currently unused here).
            input_query: the user's question.
            context, supplement_info: currently unused in this plain path.

        Yields:
            str: content of each streamed chunk.
        """
        # Build the prompt: task prompt as system message, query as human.
        prompt_template = ChatPromptTemplate.from_messages([
            ("system", task_prompt_info["task_prompt"]),
            ("human", "{input}")
        ])
        messages = prompt_template.invoke({"input": input_query})
        # Stream the model response and relay each chunk's text.
        for chunk in self.llm.stream(messages):
            yield chunk.content

    async def get_model_tools_call(self, operate_id: str, session_id, task_prompt_info: dict, input_query, context=None, supplement_info=None,
                                   header_info=None):
        """Invoke the tool-bound LLM and execute any requested tool calls.

        Args:
            operate_id: trace id used for logging and the input context.
            session_id: session identifier (currently unused).
            task_prompt_info: dict carrying the task prompt.
            input_query: the user's question.
            context / supplement_info: optional extra context strings.
            header_info: optional dict with ``token`` / ``tenantId``.

        Returns:
            str: tool output (newline-joined) or the model's direct answer,
            with any ```json fences stripped.
        """
        # Assemble the full input context (trace id, task, token, ...).
        input_message = self.get_input_context(
            trace_id=operate_id,
            task_prompt_info=task_prompt_info,
            input_query=input_query,
            context=context,
            supplement_info=supplement_info,
            header_info=header_info
        )
        prompt_template = ChatPromptTemplate.from_messages([
            ("system", self.system_prompt),
            ("human", "{input}")
        ])
        messages = prompt_template.format_messages(input=input_message)
        response = await self.llm_with_tools.ainvoke(messages)
        if "tool_calls" in response.additional_kwargs:
            # The model asked for tool calls: forward them to the ToolNode
            # wrapped in an AIMessage of the shape ToolNode expects.
            tool_call_message = AIMessage(
                content="",
                additional_kwargs=response.additional_kwargs
            )
            server_logger.info(operate_id=operate_id, msg=f"self.tool_node_list={self.tool_node_list}")
            tool_response = await self.tool_node_list.ainvoke({"messages": [tool_call_message]})
            # Concatenate every tool message's content into one answer.
            result = "\n".join(m.content for m in tool_response["messages"])
            server_logger.info(operate_id=operate_id, msg=f"tool_calls.tool_response.result={result}")
        else:
            # No tool call requested — use the model's direct answer.
            result = response.content
            server_logger.info(operate_id=operate_id, msg=f"response.content={result}")
        return self.clean_json_output(result)

    def get_input_context(
        self,
        trace_id: str,
        task_prompt_info: dict,
        input_query: str,
        context: Optional[str] = None,
        supplement_info: Optional[str] = None,
        header_info: Optional[Dict] = None
    ) -> str:
        """Build the combined question-and-context string sent to the model.

        Missing optional fields are rendered as "无" (none); token/tenant
        fall back to empty strings when ``header_info`` is absent.
        """
        context = context or "无"
        supplement_info = supplement_info or "无"
        token = header_info.get('token', '') if header_info else ''
        tenantId = header_info.get('tenantId', '') if header_info else ''
        task_prompt_info_str = task_prompt_info["task_prompt"]
        # Scenario-tuned context prompt.
        # FIX: restored dropped leading characters in two labels —
        # 用户问题 (user question) and 租户ID (tenant id, matches tenantId).
        base_context_prompt = """
        日志链路跟踪ID:{trace_id}
        任务信息:{task_prompt_info_str}
        相关上下文数据:{context}
        补充信息:{supplement_info}
        用户问题:{input}
        安全验证:{token}
        租户ID:{tenantId}
        """
        return base_context_prompt.format(
            trace_id=trace_id,
            task_prompt_info_str=task_prompt_info_str,
            context=context,
            input=input_query,
            supplement_info=supplement_info,
            token=token,
            tenantId=tenantId
        )

    def clean_json_output(self, raw_output: str) -> str:
        """Strip a leading ```json fence and a trailing ``` fence, if present."""
        cleaned = raw_output.strip()
        if cleaned.startswith("```json"):
            cleaned = cleaned[7:]  # drop the opening ```json
        if cleaned.endswith("```"):
            cleaned = cleaned[:-3]  # drop the closing ```
        return cleaned.strip()
# Module-level singleton: created at import time, so importing this module
# immediately calls get_models() to load the deployed model handles.
xwzc_generate_client = XiwuzcModelGenerateClient()