# i_reranker_node.py (3.8 KB)
  1. # coding=utf-8
  2. """
  3. @project: MaxKB
  4. @Author:虎
  5. @file: i_reranker_node.py
  6. @date:2024/9/4 10:40
  7. @desc:
  8. """
  9. from typing import Type
  10. from rest_framework import serializers
  11. from application.flow.common import WorkflowMode
  12. from application.flow.i_step_node import INode, NodeResult
  13. from django.utils.translation import gettext_lazy as _
  14. class RerankerSettingSerializer(serializers.Serializer):
  15. # 需要查询的条数
  16. top_n = serializers.IntegerField(required=True,
  17. label=_("Reference segment number"))
  18. # 相似度 0-1之间
  19. similarity = serializers.FloatField(required=True, max_value=2, min_value=0,
  20. label=_("Reference segment number"))
  21. max_paragraph_char_number = serializers.IntegerField(required=True,
  22. label=_("Maximum number of words in a quoted segment"))
  23. class RerankerStepNodeSerializer(serializers.Serializer):
  24. reranker_setting = RerankerSettingSerializer(required=True)
  25. question_reference_address = serializers.ListField(required=True)
  26. reranker_model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True)
  27. reranker_model_id_type = serializers.CharField(required=False, default='custom')
  28. reranker_model_id_reference = serializers.ListField(required=False, child=serializers.CharField(), allow_empty=True)
  29. reranker_reference_list = serializers.ListField(required=True, child=serializers.ListField(required=True))
  30. show_knowledge = serializers.BooleanField(required=True,
  31. label=_("The results are displayed in the knowledge sources"))
  32. def is_valid(self, *, raise_exception=False):
  33. super().is_valid(raise_exception=True)
  34. class IRerankerNode(INode):
  35. type = 'reranker-node'
  36. support = [WorkflowMode.APPLICATION, WorkflowMode.APPLICATION_LOOP, WorkflowMode.TOOL, WorkflowMode.TOOL_LOOP]
  37. def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
  38. return RerankerStepNodeSerializer
  39. def _run(self):
  40. question = self.workflow_manage.get_reference_field(
  41. self.node_params_serializer.data.get('question_reference_address')[0],
  42. self.node_params_serializer.data.get('question_reference_address')[1:])
  43. reranker_list = [self.workflow_manage.get_reference_field(
  44. reference[0],
  45. reference[1:]) for reference in
  46. self.node_params_serializer.data.get('reranker_reference_list')]
  47. node_params_data = dict(self.node_params_serializer.data)
  48. reranker_model_id_type = node_params_data.pop('reranker_model_id_type', None)
  49. reranker_model_id_reference = node_params_data.pop('reranker_model_id_reference', None)
  50. reranker_model_id = node_params_data.pop('reranker_model_id', None)
  51. # 处理引用类型
  52. if reranker_model_id_type == 'reference' and reranker_model_id_reference:
  53. reference_data = self.workflow_manage.get_reference_field(
  54. reranker_model_id_reference[0],
  55. reranker_model_id_reference[1:],
  56. )
  57. if reference_data and isinstance(reference_data, dict):
  58. reranker_model_id = reference_data.get('reranker_model_id',
  59. reference_data.get('model_id', reranker_model_id))
  60. if reranker_model_id is None or reranker_model_id == '':
  61. raise Exception(_('Model is not allowed to be empty'))
  62. return self.execute(**node_params_data, question=str(question),
  63. reranker_list=reranker_list, reranker_model_id=reranker_model_id)
  64. def execute(self, question, reranker_setting, reranker_list, reranker_model_id, show_knowledge,
  65. **kwargs) -> NodeResult:
  66. pass