i_chat_step.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. # coding=utf-8
  2. """
  3. @project: maxkb
  4. @Author:虎
  5. @file: i_chat_step.py
  6. @date:2024/1/9 18:17
  7. @desc: 对话
  8. """
  9. from abc import abstractmethod
  10. from typing import Type, List
  11. from django.utils.translation import gettext_lazy as _
  12. from langchain.chat_models.base import BaseChatModel
  13. from langchain_core.messages import BaseMessage
  14. from rest_framework import serializers
  15. from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
  16. from application.chat_pipeline.pipeline_manage import PipelineManage
  17. from application.serializers.application import NoReferencesSetting
  18. from common.field.common import InstanceField
  19. class ModelField(serializers.Field):
  20. def to_internal_value(self, data):
  21. if not isinstance(data, BaseChatModel):
  22. self.fail(_('Model type error'), value=data)
  23. return data
  24. def to_representation(self, value):
  25. return value
  26. class MessageField(serializers.Field):
  27. def to_internal_value(self, data):
  28. if not isinstance(data, BaseMessage):
  29. self.fail(_('Message type error'), value=data)
  30. return data
  31. def to_representation(self, value):
  32. return value
  33. class PostResponseHandler:
  34. @abstractmethod
  35. def handler(self, chat_id, chat_record_id, paragraph_list: List[ParagraphPipelineModel], problem_text: str,
  36. answer_text,
  37. manage, step, padding_problem_text: str = None, **kwargs):
  38. pass
  39. class IChatStep(IBaseChatPipelineStep):
  40. class InstanceSerializer(serializers.Serializer):
  41. # 对话列表
  42. message_list = serializers.ListField(required=True, child=MessageField(required=True),
  43. label=_("Conversation list"))
  44. model_id = serializers.UUIDField(required=False, allow_null=True, label=_("Model id"))
  45. # 段落列表
  46. paragraph_list = serializers.ListField(label=_("Paragraph List"))
  47. # 对话id
  48. chat_id = serializers.UUIDField(required=True, label=_("Conversation ID"))
  49. # 用户问题
  50. problem_text = serializers.CharField(required=True, label=_("User Questions"))
  51. # 后置处理器
  52. post_response_handler = InstanceField(model_type=PostResponseHandler,
  53. label=_("Post-processor"))
  54. # 补全问题
  55. padding_problem_text = serializers.CharField(required=False,
  56. label=_("Completion Question"))
  57. # 是否使用流的形式输出
  58. stream = serializers.BooleanField(required=False, label=_("Streaming Output"))
  59. chat_user_id = serializers.CharField(required=True, label=_("Chat user id"))
  60. chat_record_id = serializers.CharField(required=False, label=_("Chat record id"))
  61. chat_user_type = serializers.CharField(required=True, label=_("Chat user Type"))
  62. # 未查询到引用分段
  63. no_references_setting = NoReferencesSetting(required=True,
  64. label=_("No reference segment settings"))
  65. workspace_id = serializers.CharField(required=True, label=_("Workspace ID"))
  66. model_setting = serializers.DictField(required=True, allow_null=True,
  67. label=_("Model settings"))
  68. model_params_setting = serializers.DictField(required=False, allow_null=True,
  69. label=_("Model parameter settings"))
  70. mcp_tool_ids = serializers.JSONField(label="MCP工具ID列表", required=False, default=list)
  71. mcp_servers = serializers.JSONField(label="MCP服务列表", required=False, default=dict)
  72. mcp_source = serializers.CharField(label="MCP Source", required=False, default="referencing")
  73. tool_ids = serializers.JSONField(label="工具ID列表", required=False, default=list)
  74. application_ids = serializers.JSONField(label="应用ID列表", required=False, default=list)
  75. skill_tool_ids = serializers.JSONField(label="技能ID列表", required=False, default=list)
  76. mcp_output_enable = serializers.BooleanField(label="MCP输出是否启用", required=False, default=True)
  77. def is_valid(self, *, raise_exception=False):
  78. super().is_valid(raise_exception=True)
  79. message_list: List = self.initial_data.get('message_list')
  80. for message in message_list:
  81. if not isinstance(message, BaseMessage):
  82. raise Exception(_("message type error"))
  83. def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
  84. return self.InstanceSerializer
  85. def _run(self, manage: PipelineManage):
  86. chat_result = self.execute(**self.context['step_args'], manage=manage)
  87. manage.context['chat_result'] = chat_result
  88. @abstractmethod
  89. def execute(self, message_list: List[BaseMessage],
  90. chat_id, problem_text,
  91. post_response_handler: PostResponseHandler,
  92. model_id: str = None,
  93. workspace_id: str = None,
  94. paragraph_list=None,
  95. manage: PipelineManage = None,
  96. padding_problem_text: str = None, stream: bool = True, chat_user_id=None, chat_user_type=None,
  97. no_references_setting=None, model_params_setting=None, model_setting=None,
  98. mcp_tool_ids=None, mcp_servers='', mcp_source="referencing",
  99. tool_ids=None, application_ids=None, skill_tool_ids=None, mcp_output_enable=True,
  100. **kwargs):
  101. pass