Jelajahi Sumber

确保 max_steps 至少为 1

lxylxy123321 3 hari lalu
induk
melakukan
4a1ae29f62

+ 4 - 0
backend/app/engines/multimodal_engine.py

@@ -122,9 +122,13 @@ class MultimodalEngine(BaseEngine):
         batch_size = training_args.get("batch_size", 4)
         learning_rate = training_args.get("learning_rate", 2e-4)
 
+        dataset_len = len(hf_dataset)
+        max_steps = max(1, (dataset_len * epochs) // batch_size)
+
         tr_args = TrainingArguments(
             output_dir=output_dir,
             num_train_epochs=epochs,
+            max_steps=max_steps,
             per_device_train_batch_size=batch_size,
             learning_rate=learning_rate,
             save_strategy="epoch",

+ 3 - 2
backend/app/engines/text_engine.py

@@ -192,9 +192,9 @@ class TextEngine(BaseEngine):
 
         output_dir = str(settings.adapters_dir / job_id)
 
-        # AdaLoRA 需要 max_steps,同时也让进度计算更准确
+        # AdaLoRA 需要 max_steps > 0,同时也让进度计算更准确
         dataset_len = len(dataset)
-        max_steps = (dataset_len * epochs) // (batch_size * gradient_accumulation)
+        max_steps = max(1, (dataset_len * epochs) // (batch_size * gradient_accumulation))
 
         tr_args = TrainingArguments(
             output_dir=output_dir,
@@ -236,6 +236,7 @@ class TextEngine(BaseEngine):
             base_trainer_kwargs = dict(
                 output_dir=output_dir,
                 num_train_epochs=epochs,
+                max_steps=max_steps,
                 per_device_train_batch_size=batch_size,
                 gradient_accumulation_steps=gradient_accumulation,
                 learning_rate=learning_rate,

+ 4 - 0
backend/app/engines/vision_engine.py

@@ -122,9 +122,13 @@ class VisionEngine(BaseEngine):
         batch_size = training_args.get("batch_size", 4)
         learning_rate = training_args.get("learning_rate", 2e-4)
 
+        dataset_len = len(hf_dataset)
+        max_steps = max(1, (dataset_len * epochs) // batch_size)
+
         tr_args = TrainingArguments(
             output_dir=output_dir,
             num_train_epochs=epochs,
+            max_steps=max_steps,
             per_device_train_batch_size=batch_size,
             learning_rate=learning_rate,
             save_strategy="epoch",