Przeglądaj źródła

增加日志,debug

lxylxy123321 1 tydzień temu
rodzic
commit
242e25caeb
2 zmienionych plików z 22 dodań i 1 usunięć
  1. 1 1
      backend/app/config.py
  2. 21 0
      backend/app/engines/text_engine.py

+ 1 - 1
backend/app/config.py

@@ -33,8 +33,8 @@ class Settings(BaseSettings):
     ):
         return (
             init_settings,
-            EnvSettingsSourceWithCommaLists(settings_cls),
             dotenv_settings,
+            EnvSettingsSourceWithCommaLists(settings_cls),
             file_secret_settings,
         )
 

+ 21 - 0
backend/app/engines/text_engine.py

@@ -41,14 +41,31 @@ class TextEngine(BaseEngine):
             )
 
         quantization = kwargs.get("quantization", None)
+        
+        # Log GPU availability, device count, and each device's name and total memory
+        logger.info(f"CUDA available: {torch.cuda.is_available()}")
+        logger.info(f"CUDA device count: {torch.cuda.device_count()}")
+        if torch.cuda.is_available():
+            for i in range(torch.cuda.device_count()):
+                logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
+                logger.info(f"GPU {i} memory: {torch.cuda.get_device_properties(i).total_memory / (1024**3):.2f} GB")
+        else:
+            logger.warning("No GPU detected! Training will run on CPU.")
+        
+        max_memory = {i: "4GB" for i in range(torch.cuda.device_count())} if torch.cuda.is_available() else None
+        
         load_kwargs: dict[str, Any] = {
             "torch_dtype": torch.float16,
             "device_map": "auto",
+            "low_cpu_mem_usage": True,
+            "use_safetensors": True,
+            "max_memory": max_memory,
         }
         if quantization == "4bit" or quantization == "qlora":
             load_kwargs["load_in_4bit"] = True
             load_kwargs["bnb_4bit_quant_type"] = "nf4"
             load_kwargs["bnb_4bit_use_double_quant"] = True
+            load_kwargs["bnb_4bit_compute_dtype"] = torch.float16
         elif quantization == "8bit":
             load_kwargs["load_in_8bit"] = True
 
@@ -135,6 +152,10 @@ class TextEngine(BaseEngine):
             optim="adamw_torch",
             remove_unused_columns=False,
             report_to="none",
+            gradient_checkpointing=True,
+            gradient_checkpointing_kwargs={"use_reentrant": False},
+            dataloader_num_workers=1,
+            dataloader_pin_memory=False,
             **({"deepspeed": deepspeed_config} if deepspeed_config else {}),
         )