Parcourir la source

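修复已知bug:远程训练引擎(multimodal/text/vision)不再导入 app.config,直接从环境变量 COMPUTE_NODE_REMOTE_DATA_DIR 读取数据目录;FileProgressCallback 通过 __getattr__ 忽略未实现的 Trainer 回调;text_engine 中 ORPO/KTO 的 trl 导入改为按需导入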

lxylxy123321 il y a 1 semaine
Parent
commit
613455b7de

+ 11 - 12
backend/app/engines/multimodal_engine.py

@@ -1,19 +1,18 @@
+import os
 import json
 import json
 from pathlib import Path
 from pathlib import Path
 from typing import Any
 from typing import Any
 
 
-# 远程训练节点可能没有 pydantic-settings,用环境变量兜底
-try:
-    from app.config import get_settings
-    settings = get_settings()
-except ImportError:
-    from types import SimpleNamespace
-    settings = SimpleNamespace(
-        data_dir=Path("/root/Fine-tuning/backend/data"),
-        processed_dir=Path("/root/Fine-tuning/backend/data") / "processed",
-        adapters_dir=Path("/root/Fine-tuning/backend/data") / "adapters",
-        models_dir=Path("/root/Fine-tuning/backend/data") / "models",
-    )
+# 远程训练节点没有 pydantic-settings/数据库,直接用环境变量
+from types import SimpleNamespace
+
+_data_dir = Path(os.environ.get("COMPUTE_NODE_REMOTE_DATA_DIR", "/root/Fine-tuning/backend/data"))
+settings = SimpleNamespace(
+    data_dir=_data_dir,
+    processed_dir=_data_dir / "processed",
+    adapters_dir=_data_dir / "adapters",
+    models_dir=_data_dir / "models",
+)
 
 
 import logging
 import logging
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)

+ 8 - 1
backend/app/engines/remote_train.py

@@ -48,7 +48,10 @@ def _write_log(**kwargs):
 
 
 
 
 class FileProgressCallback:
 class FileProgressCallback:
-    """HuggingFace Trainer 回调 — 写进度到共享日志文件。"""
+    """HuggingFace Trainer 回调 — 写进度到共享日志文件。
+
+    只实现关心的回调,其余通过 __getattr__ 自动忽略。
+    """
 
 
     def __init__(self, job_id: str):
     def __init__(self, job_id: str):
         self.job_id = job_id
         self.job_id = job_id
@@ -84,6 +87,10 @@ class FileProgressCallback:
                        eval_loss=metrics.get("eval_loss"),
                        eval_loss=metrics.get("eval_loss"),
                        eval_accuracy=metrics.get("eval_accuracy"))
                        eval_accuracy=metrics.get("eval_accuracy"))
 
 
+    def __getattr__(self, name):
+        """Trainer 期望其他回调方法存在,返回一个空函数自动忽略。"""
+        return lambda *args, **kwargs: None
+
 
 
 async def run_training(job_id: str, model_id: str, model_type: str, dataset_path: str, config: dict):
 async def run_training(job_id: str, model_id: str, model_type: str, dataset_path: str, config: dict):
     """执行单个训练任务(远程调用入口)。"""
     """执行单个训练任务(远程调用入口)。"""

+ 15 - 20
backend/app/engines/text_engine.py

@@ -12,18 +12,16 @@ import logging
 from pathlib import Path
 from pathlib import Path
 from typing import Any
 from typing import Any
 
 
-# 远程训练节点可能没有 pydantic-settings,用环境变量兜底
-try:
-    from app.config import get_settings
-    settings = get_settings()
-except ImportError:
-    from types import SimpleNamespace
-    settings = SimpleNamespace(
-        data_dir=Path(os.environ.get("COMPUTE_NODE_REMOTE_DATA_DIR", "/root/Fine-tuning/backend/data")),
-        processed_dir=Path(os.environ.get("COMPUTE_NODE_REMOTE_DATA_DIR", "/root/Fine-tuning/backend/data")) / "processed",
-        adapters_dir=Path(os.environ.get("COMPUTE_NODE_REMOTE_DATA_DIR", "/root/Fine-tuning/backend/data")) / "adapters",
-        models_dir=Path(os.environ.get("COMPUTE_NODE_REMOTE_DATA_DIR", "/root/Fine-tuning/backend/data")) / "models",
-    )
+# 远程训练节点没有 pydantic-settings/数据库,直接用环境变量
+from types import SimpleNamespace
+
+_data_dir = Path(os.environ.get("COMPUTE_NODE_REMOTE_DATA_DIR", "/root/Fine-tuning/backend/data"))
+settings = SimpleNamespace(
+    data_dir=_data_dir,
+    processed_dir=_data_dir / "processed",
+    adapters_dir=_data_dir / "adapters",
+    models_dir=_data_dir / "models",
+)
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
@@ -193,14 +191,7 @@ class TextEngine(BaseEngine):
                 callbacks=all_callbacks,
                 callbacks=all_callbacks,
             )
             )
         else:
         else:
-            from trl import (
-                DPOConfig,
-                DPOTrainer,
-                KTOConfig,
-                KTOTrainer,
-                ORPOConfig,
-                ORPOTrainer,
-            )
+            from trl import DPOConfig, DPOTrainer
 
 
             base_trainer_kwargs = dict(
             base_trainer_kwargs = dict(
                 output_dir=output_dir,
                 output_dir=output_dir,
@@ -223,6 +214,8 @@ class TextEngine(BaseEngine):
                     processing_class=self._tokenizer,
                     processing_class=self._tokenizer,
                 )
                 )
             elif task_type == "orpo":
             elif task_type == "orpo":
+                from trl import ORPOConfig, ORPOTrainer
+
                 trainer = ORPOTrainer(
                 trainer = ORPOTrainer(
                     model=self._model,
                     model=self._model,
                     args=ORPOConfig(**base_trainer_kwargs),
                     args=ORPOConfig(**base_trainer_kwargs),
@@ -230,6 +223,8 @@ class TextEngine(BaseEngine):
                     processing_class=self._tokenizer,
                     processing_class=self._tokenizer,
                 )
                 )
             elif task_type == "kto":
             elif task_type == "kto":
+                from trl import KTOConfig, KTOTrainer
+
                 trainer = KTOTrainer(
                 trainer = KTOTrainer(
                     model=self._model,
                     model=self._model,
                     args=KTOConfig(**base_trainer_kwargs),
                     args=KTOConfig(**base_trainer_kwargs),

+ 11 - 12
backend/app/engines/vision_engine.py

@@ -1,19 +1,18 @@
+import os
 import json
 import json
 from pathlib import Path
 from pathlib import Path
 from typing import Any
 from typing import Any
 
 
-# 远程训练节点可能没有 pydantic-settings,用环境变量兜底
-try:
-    from app.config import get_settings
-    settings = get_settings()
-except ImportError:
-    from types import SimpleNamespace
-    settings = SimpleNamespace(
-        data_dir=Path("/root/Fine-tuning/backend/data"),
-        processed_dir=Path("/root/Fine-tuning/backend/data") / "processed",
-        adapters_dir=Path("/root/Fine-tuning/backend/data") / "adapters",
-        models_dir=Path("/root/Fine-tuning/backend/data") / "models",
-    )
+# 远程训练节点没有 pydantic-settings/数据库,直接用环境变量
+from types import SimpleNamespace
+
+_data_dir = Path(os.environ.get("COMPUTE_NODE_REMOTE_DATA_DIR", "/root/Fine-tuning/backend/data"))
+settings = SimpleNamespace(
+    data_dir=_data_dir,
+    processed_dir=_data_dir / "processed",
+    adapters_dir=_data_dir / "adapters",
+    models_dir=_data_dir / "models",
+)
 
 
 import logging
 import logging
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)