|
|
@@ -105,7 +105,7 @@ for cls, kw in [(AutoModelForCausalLM, {{'trust_remote_code': True}}), (AutoMode
|
|
|
for dtype_val, dtype_name in [(torch.float16, 'float16'), (torch.float32, 'float32')]:
|
|
|
try:
|
|
|
if has_accelerate:
|
|
|
- m = cls.from_pretrained(model_path, dtype=dtype_val, device_map='auto', **kw)
|
|
|
+ m = cls.from_pretrained(model_path, dtype=dtype_val, device_map={"": 0}, **kw)
|
|
|
else:
|
|
|
m = cls.from_pretrained(model_path, dtype=dtype_val, device_map=None, **kw)
|
|
|
m = m.to(device)
|
|
|
@@ -211,7 +211,7 @@ def _run_local_inference(model_dir: Path, prompt: str, max_new_tokens: int, temp
|
|
|
model = loader_cls.from_pretrained(
|
|
|
model_dir,
|
|
|
torch_dtype=torch.float16,
|
|
|
- device_map="auto",
|
|
|
+ device_map={"": 0},
|
|
|
**kwargs,
|
|
|
)
|
|
|
break
|