tool_utils.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. import time
  2. from math import log
  3. import os
  4. from dotenv import load_dotenv
  5. from enums.common_enums import BusinessSceneEnum, ErrorCodeEnum, UserRoleEnum
  6. from functools import wraps
  7. from logger.loggering import server_logger
  8. from utils.common import handler_err
  9. from base.config import config_handler
  10. # 获取当前文件的目录
  11. current_dir = os.path.dirname(__file__)
  12. # 构建到 .env 的相对路径
  13. conf_file_path = os.path.join(current_dir , '../', '.env')
  14. #server_logger.info(f"当前目录: {conf_file_path}")
  15. # 加载环境变量
  16. load_dotenv(dotenv_path=conf_file_path)
  17. def verify_param(param: dict):
  18. """
  19. 验证请求参数
  20. """
  21. input_data = param.get("input")
  22. session_id = param.get("config").get("session_id")
  23. businessScene = param.get("businessScene")
  24. if input_data is None:
  25. raise ValueError(ErrorCodeEnum.INPUT_INFO_EMPTY.__str__)
  26. if session_id is None:
  27. raise ValueError(ErrorCodeEnum.SESSION_ID_EMPTY.__str__)
  28. # 是否可使用默认的通用模型查询 默认 False
  29. use_default_common_model_query = os.environ.get("USE_DEFAULT_COMMON_MODEL_QUERY" , False)
  30. server_logger.info(f"使用可默认的通用模型查询: {use_default_common_model_query}")
  31. if not use_default_common_model_query:
  32. if businessScene is None:
  33. raise ValueError(ErrorCodeEnum.BUSINSESS_SCENE_EMPTY.__str__)
  34. if not BusinessSceneEnum.get_item_by_code(param.get('businessScene')):
  35. raise ValueError(ErrorCodeEnum.BUSINSESS_SCENE_ERROR.__str__)
  36. def get_system_prompt() -> str:
  37. """
  38. 获取系统提示语
  39. """
  40. system_prompt = config_handler.get("system", "SYSTEM_PROMPT")
  41. server_logger.info(f"获取系统提示语: {system_prompt}")
  42. return str(system_prompt)
  43. def get_business_scene_prompt(business_scene):
  44. """
  45. 获取业务场景的提示语
  46. """
  47. # 默认公共查询提示语
  48. business_scene_enum = BusinessSceneEnum.COMMON_MODEL_QUERY
  49. prompt_file = business_scene_enum.prompt_file
  50. # 是否可使用默认的通用模型查询 默认 False
  51. use_default_common_model_query = os.environ.get("USE_DEFAULT_COMMON_MODEL_QUERY" , False)
  52. if not business_scene is None:
  53. business_scene_enum = BusinessSceneEnum.get_item_by_code(business_scene)
  54. if not business_scene_enum:
  55. raise ValueError("未找到枚举值")
  56. if business_scene_enum.prompt_file is None:
  57. raise ValueError("业务场景不存在")
  58. prompt_file = business_scene_enum.prompt_file
  59. prompt_file = os.path.join(current_dir , '../', 'config', 'prompt' , prompt_file)
  60. server_logger.info(f"获取业务场景提示语: {prompt_file}")
  61. if not os.path.exists(prompt_file):
  62. raise ValueError("业务场景不存在")
  63. try:
  64. with open(prompt_file, 'r', encoding='utf-8') as f:
  65. return business_scene_enum , '\n'.join(f.readlines())
  66. except Exception as e:
  67. handler_err(server_logger, e,err_name="get_business_scene_prompt")
  68. server_logger.error(f"获取业务场景提示语失败: {e}")
  69. raise e
  70. def get_fixed_problem_answer_txt_content(file_name: str):
  71. """
  72. 获取固定问题答案内容
  73. """
  74. file_name = file_name+".txt"
  75. answer_txt_file = os.path.join(current_dir , '../', 'config', 'fixed_answer' , file_name)
  76. server_logger.info(f"固定回答文本内容: {answer_txt_file}")
  77. if not os.path.exists(answer_txt_file):
  78. raise ValueError("固定回答文本不存在")
  79. try:
  80. result_list = []
  81. with open(answer_txt_file, 'r', encoding='utf-8') as f:
  82. result_list=f.readlines()
  83. return "".join(result_list)
  84. except Exception as e:
  85. handler_err(server_logger, e,err_name="get_fixed_problem_answer_txt_content")
  86. server_logger.error(f"获取固定回答文本失败: {e}")
  87. raise e
  88. def verify_user_role(user_role: str):
  89. """
  90. 验证用户角色
  91. 普通用户 common ,不能检索查询知识库
  92. 租户用户 tenant ,只有租户才能检索查询知识库
  93. """
  94. if user_role in [UserRoleEnum.TENANT.code]:
  95. return True
  96. return False