model_config_loader.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. 模型功能配置加载器
  5. 从 model_setting.yaml 加载模型功能配置
  6. 提供按功能获取模型配置的接口
  7. 使用方式:
  8. from foundation.ai.models.model_config_loader import get_model_for_function, get_thinking_mode_for_function
  9. model = get_model_for_function("doc_classification_secondary")
  10. thinking = get_thinking_mode_for_function("doc_classification_secondary")
  11. """
  12. from pathlib import Path
  13. from typing import Dict, Any, Optional
  14. from dataclasses import dataclass
  15. import yaml
  16. # 延迟导入 logger,避免循环依赖
  17. _logger = None
  18. def _get_logger():
  19. global _logger
  20. if _logger is None:
  21. try:
  22. from foundation.observability.logger.loggering import review_logger as logger
  23. _logger = logger
  24. except ImportError:
  25. import logging
  26. _logger = logging.getLogger(__name__)
  27. return _logger
  28. @dataclass
  29. class ModelFunctionConfig:
  30. """模型功能配置"""
  31. model: str
  32. enable_thinking: Optional[bool] = None
  33. description: str = ""
  34. class ModelConfigLoader:
  35. """模型配置加载器(单例)"""
  36. _instance: Optional["ModelConfigLoader"] = None
  37. _config: Optional[Dict[str, Any]] = None
  38. def __new__(cls) -> "ModelConfigLoader":
  39. if cls._instance is None:
  40. cls._instance = super().__new__(cls)
  41. cls._instance._initialized = False
  42. return cls._instance
  43. def __init__(self):
  44. if self._initialized:
  45. return
  46. self._initialized = True
  47. self._load_config()
  48. def _get_config_path(self) -> Path:
  49. """获取配置文件路径"""
  50. # 配置文件位于项目根目录 config/ 下
  51. # 本文件位于 foundation/ai/models/,需要向上3层到项目根目录
  52. return Path(__file__).parent.parent.parent.parent / "config" / "model_setting.yaml"
  53. def _load_config(self):
  54. """加载 YAML 配置文件"""
  55. config_path = self._get_config_path()
  56. if not config_path.exists():
  57. _get_logger().warning(f"[ModelConfig] 配置文件不存在: {config_path},使用默认配置")
  58. self._config = self._get_default_config()
  59. return
  60. try:
  61. with open(config_path, 'r', encoding='utf-8') as f:
  62. self._config = yaml.safe_load(f)
  63. _get_logger().info(f"[ModelConfig] 已加载模型配置: {config_path}")
  64. except Exception as e:
  65. _get_logger().error(f"[ModelConfig] 加载配置文件失败: {e}")
  66. self._config = self._get_default_config()
  67. def _get_default_config(self) -> Dict[str, Any]:
  68. """获取默认配置"""
  69. return {
  70. "default": {
  71. "model": "qwen3_5_35b_a3b",
  72. "enable_thinking": False
  73. },
  74. "model_settings": {}
  75. }
  76. def get_model_config(self, function_name: str) -> ModelFunctionConfig:
  77. """
  78. 获取指定功能的模型配置
  79. Args:
  80. function_name: 功能名称(如 doc_classification_secondary)
  81. Returns:
  82. ModelFunctionConfig: 模型配置
  83. """
  84. settings = self._config.get("model_settings", {})
  85. default = self._config.get("default", {})
  86. # 获取功能配置,如果不存在则使用默认
  87. func_config = settings.get(function_name, default)
  88. # 合并默认值
  89. model = func_config.get("model", default.get("model", "qwen3_5_35b_a3b"))
  90. enable_thinking = func_config.get("enable_thinking", default.get("enable_thinking", False))
  91. description = func_config.get("description", "")
  92. return ModelFunctionConfig(
  93. model=model,
  94. enable_thinking=enable_thinking,
  95. description=description
  96. )
  97. def get_model_name(self, function_name: str) -> str:
  98. """获取指定功能的模型名称"""
  99. return self.get_model_config(function_name).model
  100. def get_enable_thinking(self, function_name: str) -> Optional[bool]:
  101. """获取指定功能是否启用思考模式"""
  102. return self.get_model_config(function_name).enable_thinking
  103. def get_available_models(self) -> list:
  104. """获取可用模型列表"""
  105. return self._config.get("available_models", [])
  106. def list_functions(self) -> Dict[str, str]:
  107. """列出所有已配置的功能及其描述"""
  108. settings = self._config.get("model_settings", {})
  109. return {
  110. name: config.get("description", "无描述")
  111. for name, config in settings.items()
  112. }
  113. # 全局单例
  114. model_config_loader = ModelConfigLoader()
  115. # 便捷函数
  116. def get_model_for_function(function_name: str) -> str:
  117. """获取指定功能使用的模型名称"""
  118. return model_config_loader.get_model_name(function_name)
  119. def get_thinking_mode_for_function(function_name: str) -> Optional[bool]:
  120. """获取指定功能的思考模式配置"""
  121. return model_config_loader.get_enable_thinking(function_name)
  122. def get_full_config_for_function(function_name: str) -> ModelFunctionConfig:
  123. """获取指定功能的完整配置"""
  124. return model_config_loader.get_model_config(function_name)