瀏覽代碼

修复151上训练

lxylxy123321 2 天之前
父節點
當前提交
c2382c3c64

+ 5 - 2
backend/app/engines/multimodal_engine.py

@@ -182,10 +182,13 @@ class MultimodalEngine(BaseEngine):
         return {"model_type": "multimodal", "context_length": 4096}
 
 
-from transformers import TrainerCallback
+try:
+    from transformers import TrainerCallback as _TrainerCallbackBase
+except ImportError:
+    _TrainerCallbackBase = object  # 151 主节点无 transformers,仅做占位
 
 
-class _ProgressCallback(TrainerCallback):
+class _ProgressCallback(_TrainerCallbackBase):
     def __init__(self, job_id: str):
         super().__init__()
         self.job_id = job_id

+ 5 - 2
backend/app/engines/text_engine.py

@@ -530,10 +530,13 @@ class TextEngine(BaseEngine):
         return tokenized_dataset
 
 
-from transformers import TrainerCallback
+try:
+    from transformers import TrainerCallback as _TrainerCallbackBase
+except ImportError:
+    _TrainerCallbackBase = object  # 151 主节点无 transformers,仅做占位
 
 
-class _ProgressCallback(TrainerCallback):
+class _ProgressCallback(_TrainerCallbackBase):
     """自定义训练进度回调,通过 WebSocket 发送进度。"""
 
     def __init__(self, job_id: str):

+ 5 - 2
backend/app/engines/vision_engine.py

@@ -182,10 +182,13 @@ class VisionEngine(BaseEngine):
         return {"model_type": "vision", "context_length": 2048}
 
 
-from transformers import TrainerCallback
+try:
+    from transformers import TrainerCallback as _TrainerCallbackBase
+except ImportError:
+    _TrainerCallbackBase = object  # 151 主节点无 transformers,仅做占位
 
 
-class _ProgressCallback(TrainerCallback):
+class _ProgressCallback(_TrainerCallbackBase):
     def __init__(self, job_id: str):
         super().__init__()
         self.job_id = job_id