|
|
@@ -41,14 +41,31 @@ class TextEngine(BaseEngine):
|
|
|
)
|
|
|
|
|
|
quantization = kwargs.get("quantization", None)
|
|
|
+
|
|
|
+ # 日志:检查 GPU 状态
|
|
|
+ logger.info(f"CUDA available: {torch.cuda.is_available()}")
|
|
|
+ logger.info(f"CUDA device count: {torch.cuda.device_count()}")
|
|
|
+ if torch.cuda.is_available():
|
|
|
+ for i in range(torch.cuda.device_count()):
|
|
|
+ logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
|
|
|
+ logger.info(f"GPU {i} memory: {torch.cuda.get_device_properties(i).total_memory / (1024**3):.2f} GB")
|
|
|
+ else:
|
|
|
+ logger.warning("No GPU detected! Training will run on CPU.")
|
|
|
+
|
|
|
+ max_memory = {i: "4GB" for i in range(torch.cuda.device_count())} if torch.cuda.is_available() else None
|
|
|
+
|
|
|
load_kwargs: dict[str, Any] = {
|
|
|
"torch_dtype": torch.float16,
|
|
|
"device_map": "auto",
|
|
|
+ "low_cpu_mem_usage": True,
|
|
|
+ "use_safetensors": True,
|
|
|
+ "max_memory": max_memory,
|
|
|
}
|
|
|
if quantization == "4bit" or quantization == "qlora":
|
|
|
load_kwargs["load_in_4bit"] = True
|
|
|
load_kwargs["bnb_4bit_quant_type"] = "nf4"
|
|
|
load_kwargs["bnb_4bit_use_double_quant"] = True
|
|
|
+ load_kwargs["bnb_4bit_compute_dtype"] = torch.float16
|
|
|
elif quantization == "8bit":
|
|
|
load_kwargs["load_in_8bit"] = True
|
|
|
|
|
|
@@ -135,6 +152,10 @@ class TextEngine(BaseEngine):
|
|
|
optim="adamw_torch",
|
|
|
remove_unused_columns=False,
|
|
|
report_to="none",
|
|
|
+ gradient_checkpointing=True,
|
|
|
+ gradient_checkpointing_kwargs={"use_reentrant": False},
|
|
|
+ dataloader_num_workers=1,
|
|
|
+ dataloader_pin_memory=False,
|
|
|
**({"deepspeed": deepspeed_config} if deepspeed_config else {}),
|
|
|
)
|
|
|
|