Fix LoRA device map (attempt)
oobabooga authored Mar 23, 2023
1 parent c5ebcc5 commit 9bf6ecf
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions modules/LoRA.py
@@ -16,12 +16,15 @@ def add_lora_to_model(lora_name):
         print(f"Adding the LoRA {lora_name} to the model...")
 
         params = {}
-        if shared.args.load_in_8bit:
-            params['device_map'] = {'': 0}
-        elif not shared.args.cpu:
-            params['device_map'] = 'auto'
+        if not shared.args.cpu:
             params['dtype'] = shared.model.dtype
+            if hasattr(shared.model, "hf_device_map"):
+                params['device_map'] = {"base_model.model."+k: v for k, v in shared.model.hf_device_map.items()}
+            elif shared.args.load_in_8bit:
+                params['device_map'] = {'': 0}
 
         shared.model = PeftModel.from_pretrained(shared.model, Path(f"loras/{lora_name}"), **params)
         if not shared.args.load_in_8bit and not shared.args.cpu:
             shared.model.half()
+            if not hasattr(shared.model, "hf_device_map"):
+                shared.model.cuda()
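The substance of the change is the device-map remapping: PeftModel.from_pretrained wraps the original model under a base_model.model. prefix, so a device map recorded for the bare model (shared.model.hf_device_map, attached by accelerate when the model is split across devices) no longer matches the wrapped module names and has to be re-keyed before being passed back in. A minimal sketch of that remapping, using a toy hf_device_map (the layer names and device assignments below are illustrative, not taken from the commit):

    # Toy stand-in for the hf_device_map that accelerate attaches to a model
    # spread across several devices; these entries are made up for illustration.
    hf_device_map = {
        'model.embed_tokens': 0,
        'model.layers.0': 0,
        'model.layers.1': 1,
        'lm_head': 'cpu',
    }

    # PeftModel nests the original model under `base_model.model`, so each key
    # must be re-prefixed to match the wrapped module paths.
    peft_device_map = {"base_model.model." + k: v for k, v in hf_device_map.items()}

    print(peft_device_map)
    # {'base_model.model.model.embed_tokens': 0,
    #  'base_model.model.model.layers.0': 0,
    #  'base_model.model.model.layers.1': 1,
    #  'base_model.model.lm_head': 'cpu'}

The second hunk follows the same logic: when hf_device_map exists, accelerate has already placed the layers, and calling .cuda() would pull the whole model onto a single device, so it is skipped.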
