model_generate.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
#!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. '''
  4. @Project : lq-agent-api
  5. @File :model_generate.py
  6. @IDE :PyCharm
  7. @Author :
  8. @Date :2025/7/14 14:22
  9. '''
  10. from typing import List, Dict, Any, Optional, AsyncGenerator
  11. from langchain_core.prompts import HumanMessagePromptTemplate
  12. from langchain_core.prompts import ChatPromptTemplate
  13. from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
  14. from langgraph.prebuilt import ToolNode
  15. from utils.utils import get_models
  16. from views import mcp_server
  17. from utils.yaml_utils import system_prompt_config
  18. from logger.loggering import server_logger
  19. class XiwuzcModelGenerateClient:
  20. """
  21. 主要是生成式模型
  22. """
  23. def __init__(self):
  24. # 获取部署的模型列表
  25. llm, chat, embed = get_models()
  26. self.llm = llm
  27. self.chat = chat
  28. # 构造工具列表
  29. self.tool_node_list = [] # ToolNode(mcp_server.tools)
  30. # 模型绑定工具列表
  31. self.llm_with_tools = None #llm.bind_tools(mcp_server.tools)
  32. # 工具调用系统提示词
  33. self.system_prompt = "" #system_prompt_config["tools_system_prompt"]
  34. def get_prompt_template(self):
  35. """
  36. 构造普通Prompt提示词模板
  37. """
  38. human_template = """
  39. {system_message}
  40. 用户的问题为:
  41. {question}
  42. 答案为:
  43. """
  44. human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
  45. chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
  46. return chat_prompt_template
  47. def get_model_generate_stream(self, task_prompt_info: dict, op_id, input_query, context=None, supplement_info=None):
  48. """
  49. 模型生成链
  50. """
  51. # Step 1: 定义系统提示词模板 system_prompt
  52. # Step 2: 构建完整的 prompt 模板
  53. prompt_template = ChatPromptTemplate.from_messages([
  54. ("system", task_prompt_info["task_prompt"]),
  55. ("human", "{input}")
  56. ])
  57. # Step 3: 初始化模型
  58. # Step 4: 使用模板格式化输入
  59. messages = prompt_template.invoke({"input": input_query})
  60. # Step 5: 流式调用模型
  61. response = self.llm.stream(messages)
  62. # Step 6: 逐 token 输出(打字机效果)
  63. for chunk in response:
  64. yield chunk.content
  65. async def get_model_tools_call(self, operate_id: str, session_id, task_prompt_info: dict, input_query, context=None, supplement_info=None,
  66. header_info=None):
  67. """
  68. 工具调用
  69. """
  70. # 构建输入消息
  71. input_message = self.get_input_context(
  72. trace_id= operate_id,
  73. task_prompt_info=task_prompt_info,
  74. input_query=input_query,
  75. context=context,
  76. supplement_info=supplement_info,
  77. header_info=header_info
  78. )
  79. # Step 1: 构建完整的 prompt 模板
  80. prompt_template = ChatPromptTemplate.from_messages([
  81. ("system", self.system_prompt),
  82. ("human", "{input}")
  83. ])
  84. # Step 2: 调用带有工具的 LLM
  85. # response = self.llm_with_tools.invoke(
  86. # [HumanMessage(content="北京的天气怎么样?")]
  87. # )
  88. messages = prompt_template.format_messages(input=input_message)
  89. response = await self.llm_with_tools.ainvoke(messages)
  90. #server_logger.info(f"response={response},{dir(response)}")
  91. # 2. 检查是否有工具调用
  92. if "tool_calls" in response.additional_kwargs:
  93. # 构造符合要求的 AIMessage
  94. tool_call_message = AIMessage(
  95. content="",
  96. additional_kwargs=response.additional_kwargs
  97. )
  98. server_logger.info(operate_id=operate_id, msg=f"self.tool_node_list={self.tool_node_list}")
  99. # 传入格式化的消息
  100. tool_response = await self.tool_node_list.ainvoke({"messages": [tool_call_message]})
  101. #server_logger.info(operate_id=operate_id, msg=f"tool_response={tool_response}")
  102. tools_message_result_list = []
  103. for tools_message in tool_response["messages"]:
  104. tools_message_result_list.append(tools_message.content)
  105. result = "\n".join(tools_message_result_list)
  106. server_logger.info(operate_id=operate_id, msg=f"tool_calls.tool_response.result={result}")
  107. result = self.clean_json_output(result)
  108. return result
  109. else:
  110. result = response.content
  111. server_logger.info(operate_id=operate_id, msg=f"response.content={result}")
  112. result = self.clean_json_output(result)
  113. return result
  114. def get_input_context(
  115. self,
  116. trace_id: str,
  117. task_prompt_info: dict,
  118. input_query: str,
  119. context: Optional[str] = None,
  120. supplement_info: Optional[str] = None,
  121. header_info: Optional[Dict] = None
  122. ) -> str:
  123. #server_logger.info(f"task_prompt_info: {task_prompt_info}")
  124. """构建问题和上下文"""
  125. context = context or "无"
  126. supplement_info = supplement_info or "无"
  127. token = header_info.get('token', '') if header_info else ''
  128. tenantId = header_info.get('tenantId', '') if header_info else ''
  129. task_prompt_info_str = task_prompt_info["task_prompt"]
  130. # 针对场景优化的上下文提示
  131. base_context_prompt = """
  132. 日志链路跟踪ID:{trace_id}
  133. 任务信息:{task_prompt_info_str}
  134. 相关上下文数据:{context}
  135. 补充信息:{supplement_info}
  136. 户问题:{input}
  137. 安全验证:{token}
  138. 场ID:{tenantId}
  139. """
  140. return base_context_prompt.format(
  141. trace_id=trace_id,
  142. task_prompt_info_str=task_prompt_info_str,
  143. context=context,
  144. input=input_query,
  145. supplement_info=supplement_info,
  146. token=token,
  147. tenantId=tenantId
  148. )
  149. def clean_json_output(self , raw_output: str) -> str:
  150. """
  151. 去除开头和结尾的 ```json 和 ```
  152. """
  153. cleaned = raw_output.strip()
  154. if cleaned.startswith("```json"):
  155. cleaned = cleaned[7:] # 去掉开头的 ```json
  156. if cleaned.endswith("```"):
  157. cleaned = cleaned[:-3] # 去掉结尾的 ```
  158. return cleaned.strip()
# Module-level singleton client, shared by importers of this module.
xwzc_generate_client = XiwuzcModelGenerateClient()