ソースを参照

修复训练问题

lxylxy123321 1 週間 前
コミット
d7c14f0eca
1 ファイル変更10 行追加30 行削除
  1. 10 30
      backend/app/engines/text_engine.py

+ 10 - 30
backend/app/engines/text_engine.py

@@ -353,7 +353,11 @@ class TextEngine(BaseEngine):
 
 
 class _ProgressCallback:
-    """自定义训练进度回调,通过 WebSocket 发送进度。"""
+    """自定义训练进度回调,通过 WebSocket 发送进度。
+
+    不继承 TrainerCallback 基类,而是通过 __getattr__ 优雅处理
+    所有未知的回调方法,避免新版本 transformers 新增回调导致的兼容问题。
+    """
 
     def __init__(self, job_id: str):
         self.job_id = job_id
@@ -385,35 +389,11 @@ class _ProgressCallback:
             )
         )
 
-    def on_train_begin(self, args, state, control, **kwargs):
-        pass
-
-    def on_step_begin(self, args, state, control, **kwargs):
-        pass
-
-    def on_step_end(self, args, state, control, **kwargs):
-        pass
-
-    def on_substep_end(self, args, state, control, logs=None, **kwargs):
-        pass
-
-    def on_pre_optimizer_step(self, args, state, control, logs=None, **kwargs):
-        pass
-
-    def on_prediction_step(self, args, state, control, **kwargs):
-        pass
-
-    def on_save(self, args, state, control, **kwargs):
-        pass
-
-    def on_predict(self, args, state, control, metrics=None, **kwargs):
-        pass
-
-    def on_init_end(self, args, state, control, **kwargs):
-        pass
-
-    def on_epoch_begin(self, args, state, control, **kwargs):
-        pass
+    # 其他所有回调方法统一用 __getattr__ 处理:返回空函数,避免 NotImplementedError
+    def __getattr__(self, name: str):
+        if name.startswith("on_"):
+            return lambda *a, **k: None
+        raise AttributeError(name)
 
 
 # 全局单例