|
|
@@ -191,9 +191,15 @@ class TextEngine(BaseEngine):
|
|
|
self._model.print_trainable_parameters()
|
|
|
|
|
|
output_dir = str(settings.adapters_dir / job_id)
|
|
|
+
|
|
|
+ # AdaLoRA requires max_steps; setting it also makes progress calculation more accurate
|
|
|
+ dataset_len = len(dataset)
|
|
|
+ max_steps = (dataset_len * epochs) // (batch_size * gradient_accumulation)
|
|
|
+
|
|
|
tr_args = TrainingArguments(
|
|
|
output_dir=output_dir,
|
|
|
num_train_epochs=epochs,
|
|
|
+ max_steps=max_steps,
|
|
|
per_device_train_batch_size=batch_size,
|
|
|
gradient_accumulation_steps=gradient_accumulation,
|
|
|
learning_rate=learning_rate,
|