|
|
@@ -12,18 +12,16 @@ import logging
|
|
|
from pathlib import Path
|
|
|
from typing import Any
|
|
|
|
|
|
-# 远程训练节点可能没有 pydantic-settings,用环境变量兜底
|
|
|
-try:
|
|
|
- from app.config import get_settings
|
|
|
- settings = get_settings()
|
|
|
-except ImportError:
|
|
|
- from types import SimpleNamespace
|
|
|
- settings = SimpleNamespace(
|
|
|
- data_dir=Path(os.environ.get("COMPUTE_NODE_REMOTE_DATA_DIR", "/root/Fine-tuning/backend/data")),
|
|
|
- processed_dir=Path(os.environ.get("COMPUTE_NODE_REMOTE_DATA_DIR", "/root/Fine-tuning/backend/data")) / "processed",
|
|
|
- adapters_dir=Path(os.environ.get("COMPUTE_NODE_REMOTE_DATA_DIR", "/root/Fine-tuning/backend/data")) / "adapters",
|
|
|
- models_dir=Path(os.environ.get("COMPUTE_NODE_REMOTE_DATA_DIR", "/root/Fine-tuning/backend/data")) / "models",
|
|
|
- )
|
|
|
+# Remote training nodes have no pydantic-settings or database, so read settings directly from environment variables
|
|
|
+from types import SimpleNamespace
|
|
|
+
|
|
|
+_data_dir = Path(os.environ.get("COMPUTE_NODE_REMOTE_DATA_DIR", "/root/Fine-tuning/backend/data"))
|
|
|
+settings = SimpleNamespace(
|
|
|
+ data_dir=_data_dir,
|
|
|
+ processed_dir=_data_dir / "processed",
|
|
|
+ adapters_dir=_data_dir / "adapters",
|
|
|
+ models_dir=_data_dir / "models",
|
|
|
+)
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@@ -193,14 +191,7 @@ class TextEngine(BaseEngine):
|
|
|
callbacks=all_callbacks,
|
|
|
)
|
|
|
else:
|
|
|
- from trl import (
|
|
|
- DPOConfig,
|
|
|
- DPOTrainer,
|
|
|
- KTOConfig,
|
|
|
- KTOTrainer,
|
|
|
- ORPOConfig,
|
|
|
- ORPOTrainer,
|
|
|
- )
|
|
|
+ from trl import DPOConfig, DPOTrainer
|
|
|
|
|
|
base_trainer_kwargs = dict(
|
|
|
output_dir=output_dir,
|
|
|
@@ -223,6 +214,8 @@ class TextEngine(BaseEngine):
|
|
|
processing_class=self._tokenizer,
|
|
|
)
|
|
|
elif task_type == "orpo":
|
|
|
+ from trl import ORPOConfig, ORPOTrainer
|
|
|
+
|
|
|
trainer = ORPOTrainer(
|
|
|
model=self._model,
|
|
|
args=ORPOConfig(**base_trainer_kwargs),
|
|
|
@@ -230,6 +223,8 @@ class TextEngine(BaseEngine):
|
|
|
processing_class=self._tokenizer,
|
|
|
)
|
|
|
elif task_type == "kto":
|
|
|
+ from trl import KTOConfig, KTOTrainer
|
|
|
+
|
|
|
trainer = KTOTrainer(
|
|
|
model=self._model,
|
|
|
args=KTOConfig(**base_trainer_kwargs),
|