Kaynağa Gözat

修复adalora报错问题

lxylxy123321 4 gün önce
ebeveyn
işleme
0b255898ad
1 değiştirilmiş dosya ile 6 ekleme ve 0 silme
  1. 6 0
      backend/app/engines/text_engine.py

+ 6 - 0
backend/app/engines/text_engine.py

@@ -191,9 +191,15 @@ class TextEngine(BaseEngine):
         self._model.print_trainable_parameters()
         self._model.print_trainable_parameters()
 
 
         output_dir = str(settings.adapters_dir / job_id)
         output_dir = str(settings.adapters_dir / job_id)
+
+        # AdaLoRA 需要 max_steps,同时也让进度计算更准确
+        dataset_len = len(dataset)
+        max_steps = (dataset_len * epochs) // (batch_size * gradient_accumulation)
+
         tr_args = TrainingArguments(
         tr_args = TrainingArguments(
             output_dir=output_dir,
             output_dir=output_dir,
             num_train_epochs=epochs,
             num_train_epochs=epochs,
+            max_steps=max_steps,
             per_device_train_batch_size=batch_size,
             per_device_train_batch_size=batch_size,
             gradient_accumulation_steps=gradient_accumulation,
             gradient_accumulation_steps=gradient_accumulation,
             learning_rate=learning_rate,
             learning_rate=learning_rate,