model_config_loader.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. 模型功能配置加载器
  5. 从 model_setting.yaml 加载模型功能配置
  6. 提供按功能获取模型配置的接口
  7. 使用方式:
  8. from config.model_config_loader import get_model_for_function, get_thinking_mode_for_function
  9. model = get_model_for_function("doc_classification_secondary")
  10. thinking = get_thinking_mode_for_function("doc_classification_secondary")
  11. """
  12. from pathlib import Path
  13. from typing import Dict, Any, Optional
  14. from dataclasses import dataclass
  15. import yaml
  16. # 延迟导入 logger,避免循环依赖
  17. _logger = None
  18. def _get_logger():
  19. global _logger
  20. if _logger is None:
  21. try:
  22. from foundation.observability.logger.loggering import review_logger as logger
  23. _logger = logger
  24. except ImportError:
  25. import logging
  26. _logger = logging.getLogger(__name__)
  27. return _logger
  28. @dataclass
  29. class ModelFunctionConfig:
  30. """模型功能配置"""
  31. model: str
  32. enable_thinking: Optional[bool] = None
  33. description: str = ""
  34. class ModelConfigLoader:
  35. """模型配置加载器(单例)"""
  36. _instance: Optional["ModelConfigLoader"] = None
  37. _config: Optional[Dict[str, Any]] = None
  38. def __new__(cls) -> "ModelConfigLoader":
  39. if cls._instance is None:
  40. cls._instance = super().__new__(cls)
  41. cls._instance._initialized = False
  42. return cls._instance
  43. def __init__(self):
  44. if self._initialized:
  45. return
  46. self._initialized = True
  47. self._load_config()
  48. def _get_config_path(self) -> Path:
  49. """获取配置文件路径"""
  50. # 配置文件位于项目根目录 config/ 下
  51. return Path(__file__).parent / "model_setting.yaml"
  52. def _load_config(self):
  53. """加载 YAML 配置文件"""
  54. config_path = self._get_config_path()
  55. if not config_path.exists():
  56. _get_logger().warning(f"[ModelConfig] 配置文件不存在: {config_path},使用默认配置")
  57. self._config = self._get_default_config()
  58. return
  59. try:
  60. with open(config_path, 'r', encoding='utf-8') as f:
  61. self._config = yaml.safe_load(f)
  62. _get_logger().info(f"[ModelConfig] 已加载模型配置: {config_path}")
  63. except Exception as e:
  64. _get_logger().error(f"[ModelConfig] 加载配置文件失败: {e}")
  65. self._config = self._get_default_config()
  66. def _get_default_config(self) -> Dict[str, Any]:
  67. """获取默认配置"""
  68. return {
  69. "default": {
  70. "model": "qwen3_5_35b_a3b",
  71. "enable_thinking": False
  72. },
  73. "model_settings": {}
  74. }
  75. def get_model_config(self, function_name: str) -> ModelFunctionConfig:
  76. """
  77. 获取指定功能的模型配置
  78. Args:
  79. function_name: 功能名称(如 doc_classification_secondary)
  80. Returns:
  81. ModelFunctionConfig: 模型配置
  82. """
  83. settings = self._config.get("model_settings", {})
  84. default = self._config.get("default", {})
  85. # 获取功能配置,如果不存在则使用默认
  86. func_config = settings.get(function_name, default)
  87. # 合并默认值
  88. model = func_config.get("model", default.get("model", "qwen3_5_35b_a3b"))
  89. enable_thinking = func_config.get("enable_thinking", default.get("enable_thinking", False))
  90. description = func_config.get("description", "")
  91. return ModelFunctionConfig(
  92. model=model,
  93. enable_thinking=enable_thinking,
  94. description=description
  95. )
  96. def get_model_name(self, function_name: str) -> str:
  97. """获取指定功能的模型名称"""
  98. return self.get_model_config(function_name).model
  99. def get_enable_thinking(self, function_name: str) -> Optional[bool]:
  100. """获取指定功能是否启用思考模式"""
  101. return self.get_model_config(function_name).enable_thinking
  102. def get_available_models(self) -> list:
  103. """获取可用模型列表"""
  104. return self._config.get("available_models", [])
  105. def list_functions(self) -> Dict[str, str]:
  106. """列出所有已配置的功能及其描述"""
  107. settings = self._config.get("model_settings", {})
  108. return {
  109. name: config.get("description", "无描述")
  110. for name, config in settings.items()
  111. }
  112. # 全局单例
  113. model_config_loader = ModelConfigLoader()
  114. # 便捷函数
  115. def get_model_for_function(function_name: str) -> str:
  116. """获取指定功能使用的模型名称"""
  117. return model_config_loader.get_model_name(function_name)
  118. def get_thinking_mode_for_function(function_name: str) -> Optional[bool]:
  119. """获取指定功能的思考模式配置"""
  120. return model_config_loader.get_enable_thinking(function_name)
  121. def get_full_config_for_function(function_name: str) -> ModelFunctionConfig:
  122. """获取指定功能的完整配置"""
  123. return model_config_loader.get_model_config(function_name)