# test_utils.py
  1. import os
  2. import sys
  3. sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
  4. from datetime import datetime
  5. from typing import List, Dict, Optional
  6. from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
  7. from langchain_openai import ChatOpenAI
  8. from test_config import config_handler
  9. from logger.loggering import server_logger
  10. def get_models():
  11. """
  12. 获取模型,模型类型 默认为deepseek 、qwen
  13. """
  14. model_type = config_handler.get("model", "MODEL_TYPE")
  15. server_logger.info(f"get_models -> model_type:{model_type}")
  16. if model_type.upper() == "QWEN":
  17. return get_deploy_qwen_models()
  18. return get_deepseek_models()
  19. def get_deepseek_models():
  20. """
  21. 获取DeepSeek模型
  22. """
  23. deepseek_model_server_url = config_handler.get("deepseek", "DEEPSEEK_SERVER_URL")
  24. deepseek_chat_model_id = config_handler.get("deepseek", "DEEPSEEK_MODEL_ID")
  25. deepseek_api_key = config_handler.get("deepseek", "DEEPSEEK_API_KEY")
  26. server_logger.info(f"get_deepseek_models -> chat_model_id:{deepseek_chat_model_id},api_key:{deepseek_api_key}")
  27. if deepseek_model_server_url is None or deepseek_chat_model_id is None or deepseek_api_key is None:
  28. server_logger.error("请设置环境变量: DEEPSEEK_SERVER_URL, DEEPSEEK_MODEL_ID, DEEPSEEK_API_KEY")
  29. raise Exception("设置环境变量: DEEPSEEK_SERVER_URL, DEEPSEEK_MODEL_ID, DEEPSEEK_API_KEY")
  30. # llm 大模型
  31. llm = ChatOpenAI(base_url=deepseek_model_server_url,
  32. api_key=deepseek_api_key,
  33. model=deepseek_chat_model_id,
  34. max_tokens=4096,
  35. temperature=0.3,
  36. top_p=0.7,
  37. extra_body={
  38. "enable_thinking": False # 添加这个参数以避免报错
  39. })
  40. # chat 大模型
  41. chat = ChatOpenAI(base_url=deepseek_model_server_url,
  42. api_key=deepseek_api_key,
  43. model=deepseek_chat_model_id,
  44. max_tokens=4096,
  45. temperature=0.3,
  46. top_p=0.2,
  47. extra_body={
  48. "enable_thinking": False # 添加这个参数以避免报错
  49. })
  50. embed = None
  51. return llm, chat, embed
  52. # 获取千问模型
  53. def get_deploy_qwen_models():
  54. """
  55. 加载千问系列大模型-魔搭在线Qwen3 API服务
  56. """
  57. model_server_url = config_handler.get("qwen", "MODEL_SERVER_URL")
  58. chat_model_id = config_handler.get("qwen", "CHAT_MODEL_ID")
  59. api_key = config_handler.get("qwen", "API_KEY")
  60. embedding_model_id = config_handler.get("qwen", "EMBED_MODEL_ID")
  61. # temperature = os.getenv("CHAT_MODEL_TEMPERATURE")
  62. server_logger.info(
  63. f"get_qwen_chat_model -> chat_model_id:{chat_model_id},api_key:{api_key},embedding_model_id:{embedding_model_id}")
  64. if model_server_url is None or chat_model_id is None or api_key is None:
  65. server_logger.error("请设置环境变量: MODEL_SERVER_URL, CHAT_MODEL_ID, API_KEY")
  66. raise Exception("请设置环境变量: MODEL_SERVER_URL, CHAT_MODEL_ID, API_KEY")
  67. # llm 大模型
  68. llm = ChatOpenAI(base_url=model_server_url,
  69. api_key=api_key,
  70. model=chat_model_id,
  71. max_tokens=1024,
  72. temperature=0.5,
  73. top_p=0.7,
  74. extra_body={
  75. "enable_thinking": False # 添加这个参数以避免报错
  76. })
  77. # chat 大模型
  78. chat = ChatOpenAI(base_url=model_server_url,
  79. api_key=api_key,
  80. model=chat_model_id,
  81. max_tokens=1024,
  82. temperature=0.01,
  83. top_p=0.2,
  84. extra_body={
  85. "enable_thinking": False # 添加这个参数以避免报错
  86. })
  87. # embedding 大模型 text-embedding-v3 text-embedding-v4
  88. # from langchain_community.embeddings import DashScopeEmbeddings
  89. embed = None # DashScopeEmbeddings(model=embedding_model_id)
  90. return llm, chat, embed
  91. def test_qwen_chat_model():
  92. # 获取模型
  93. llm, chat, embed = get_deploy_qwen_models()
  94. example_query = "你好,你是谁?"
  95. result = llm.invoke(input=example_query)
  96. server_logger.info(f"result={result}")
  97. print(f"result={result}")
  98. def test_deepseek_chat_model():
  99. # 获取模型
  100. llm, chat, embed = get_deepseek_models()
  101. example_query = "你好,你是谁?"
  102. result = llm.invoke(input=example_query)
  103. server_logger.info(f"result={result}")
  104. print(f"result={result}")
  105. if __name__ == "__main__":
  106. test_qwen_chat_model() # 运行
  107. #test_deepseek_chat_model()