|
|
@@ -353,7 +353,11 @@ class TextEngine(BaseEngine):
|
|
|
|
|
|
|
|
|
class _ProgressCallback:
|
|
|
- """自定义训练进度回调,通过 WebSocket 发送进度。"""
|
|
|
+ """自定义训练进度回调,通过 WebSocket 发送进度。
|
|
|
+
|
|
|
+ 不继承 TrainerCallback 基类,而是通过 __getattr__ 优雅处理
|
|
|
+ 所有未知的回调方法,避免新版本 transformers 新增回调导致的兼容问题。
|
|
|
+ """
|
|
|
|
|
|
def __init__(self, job_id: str):
|
|
|
self.job_id = job_id
|
|
|
@@ -385,35 +389,11 @@ class _ProgressCallback:
|
|
|
)
|
|
|
)
|
|
|
|
|
|
- def on_train_begin(self, args, state, control, **kwargs):
|
|
|
- pass
|
|
|
-
|
|
|
- def on_step_begin(self, args, state, control, **kwargs):
|
|
|
- pass
|
|
|
-
|
|
|
- def on_step_end(self, args, state, control, **kwargs):
|
|
|
- pass
|
|
|
-
|
|
|
- def on_substep_end(self, args, state, control, logs=None, **kwargs):
|
|
|
- pass
|
|
|
-
|
|
|
- def on_pre_optimizer_step(self, args, state, control, logs=None, **kwargs):
|
|
|
- pass
|
|
|
-
|
|
|
- def on_prediction_step(self, args, state, control, **kwargs):
|
|
|
- pass
|
|
|
-
|
|
|
- def on_save(self, args, state, control, **kwargs):
|
|
|
- pass
|
|
|
-
|
|
|
- def on_predict(self, args, state, control, metrics=None, **kwargs):
|
|
|
- pass
|
|
|
-
|
|
|
- def on_init_end(self, args, state, control, **kwargs):
|
|
|
- pass
|
|
|
-
|
|
|
- def on_epoch_begin(self, args, state, control, **kwargs):
|
|
|
- pass
|
|
|
+ # 其他所有回调方法统一用 __getattr__ 处理:返回空函数,避免 NotImplementedError
|
|
|
+ def __getattr__(self, name: str):
|
|
|
+ if name.startswith("on_"):
|
|
|
+ return lambda *a, **k: None
|
|
|
+ raise AttributeError(name)
|
|
|
|
|
|
|
|
|
# 全局单例
|