Adding LoRA fine tuning #187
base: master
Conversation
This looks elegant to me.
I like where this is going, but this looks like multiple PRs in one, and a little bit of sus code. I'll inline comment.
@wlamond It's a very simple change to your PR, you just need to reference bnb.nn.Linear4bit for the 4-bit quantization.
@vgoklani Oooo, I do like that idea. I think it would be better as a separate PR though. I'm not sure how Andrej feels about adding other dependencies, so I'd rather get this project finished and then add QLoRA as another option if there's interest. Thanks for the idea and feedback!
There's definitely interest. It adds possibilities to do more with less 🙇♂️
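For anyone curious, a rough sketch of the bitsandbytes swap discussed above (not part of this PR; it assumes bnb.nn.Linear4bit accepts in/out features plus bias, compute_dtype, and quant_type keywords, and that the copied weights are quantized when the module is moved to the GPU; the helper name is made up):

import bitsandbytes as bnb
import torch
import torch.nn as nn

def swap_linear_for_4bit(model: nn.Module, compute_dtype=torch.bfloat16):
    # Recursively replace nn.Linear modules with 4-bit quantized equivalents,
    # keeping shapes and bias settings; LoRA could then be applied on top.
    for name, child in model.named_children():
        if isinstance(child, nn.Linear):
            qlinear = bnb.nn.Linear4bit(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                compute_dtype=compute_dtype,
                quant_type="nf4",
            )
            qlinear.weight.data.copy_(child.weight.data)
            if child.bias is not None:
                qlinear.bias.data.copy_(child.bias.data)
            setattr(model, name, qlinear)
        else:
            swap_linear_for_4bit(child, compute_dtype)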
os.makedirs(out_dir, exist_ok=True)
for p in model.parameters():
    p.requires_grad = False
apply_lora(model, layer_types=lora_layer_types, rank=lora_rank, dropout=lora_dropout, alpha=lora_alpha)
Looks like you apply LoRA to every Linear layer? Could you provide the ability to register LoRA only for target modules (e.g. wq and wk, as suggested in the original paper)? The code might look like:
def apply_lora(model, ..., target_modules=['wq', 'wk']):
    for name, layer in model.named_modules():
        if name.split(".")[-1] not in target_modules:
            continue
        # register lora parameterization
I love this idea! I have a local implementation that does that, but I'll update this to follow suit.
@@ -332,5 +350,12 @@ def get_lr(it):
    if iter_num > max_iters:
        break

if init_from == "lora_finetune":
    print('merging lora')
    merge_lora(raw_model)
Maybe save lora parameters in a standalone file?
Agreed, the checkpoints actually already have the lora parameters in them (the parameterization is computed whenever the weights are referenced, including during exports). Saving them off to the side could enable hot swapping loras for different tasks at some point if folks are interested in that feature.
Another possible improvement: the original parameters don't need to be stored in the optimizer during lora finetuning.
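For concreteness, both ideas might look roughly like this (a sketch, not code from this PR; it assumes the parametrize-registered LoRA tensors have "lora_" in their state-dict keys, and borrows the raw_model, out_dir, and optimizer hyperparameter names from train.py):

import os
import torch

# Save only the LoRA tensors so they can be hot swapped per task later.
lora_state = {k: v for k, v in raw_model.state_dict().items() if "lora_" in k}
torch.save(lora_state, os.path.join(out_dir, "lora_params.pt"))

# Build the optimizer over trainable (LoRA) parameters only, so the frozen
# base weights never accumulate optimizer state (momentum/variance buffers).
optimizer = torch.optim.AdamW(
    [p for p in raw_model.parameters() if p.requires_grad],
    lr=learning_rate,
    betas=(beta1, beta2),
    weight_decay=weight_decay,
)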
Oh, I got it. Looking forward to your updates 🚀
        return weight + torch.matmul(self.lora_b, self.dropout(self.lora_a)) * self.scaling


def apply_lora(model: nn.Module, layer_types=[nn.Linear], rank=8, dropout=0.0, alpha=1.0):
I always get a little antsy seeing Lists in defaults
https://docs.python-guide.org/writing/gotchas/#what-you-wrote
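The usual None sentinel sidesteps that gotcha (a sketch, not the PR's code):

import torch.nn as nn

def apply_lora(model: nn.Module, layer_types=None, rank=8, dropout=0.0, alpha=1.0):
    # Resolve the default inside the function so no mutable object is shared
    # between calls; a tuple would also work as an immutable default.
    if layer_types is None:
        layer_types = (nn.Linear,)
    # ... rest of apply_lora unchanged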
def _apply_lora(module):
    if type(module) in layer_types and hasattr(module, 'weight'):
        fan_out, fan_in = module.weight.shape
        parametrize.register_parametrization(module, 'weight', LoraLinear(fan_in, fan_out, rank, dropout, alpha))
Would this fail if layer_types includes nn.Embedding? It shouldn't get replaced with LoraLinear, right?
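For what it's worth, register_parametrization only wraps module.weight, so nothing gets replaced by LoraLinear; the real question is whether the (fan_out, fan_in) unpacking is meaningful for an embedding's (num_embeddings, embedding_dim) weight. Handling the two cases explicitly might look like this (illustrative only; register_lora_on is a made-up name and LoraLinear is the PR's parametrization class):

import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

def register_lora_on(module, rank=8, dropout=0.0, alpha=1.0):
    # nn.Linear weights are (out_features, in_features) while nn.Embedding
    # weights are (num_embeddings, embedding_dim); both are 2-D, so the same
    # low-rank parameterization applies, but unpacking them explicitly keeps
    # the intent clear.
    if isinstance(module, nn.Linear):
        fan_out, fan_in = module.weight.shape
    elif isinstance(module, nn.Embedding):
        fan_out, fan_in = module.num_embeddings, module.embedding_dim
    else:
        return
    parametrize.register_parametrization(
        module, 'weight', LoraLinear(fan_in, fan_out, rank, dropout, alpha)
    )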
@@ -54,6 +55,12 @@
n_heads = 6
multiple_of = 32
dropout = 0.0
# LoRA
lora_layer_types = [nn.Linear, nn.Embedding]
This can't be overridden as an arg if it's a list like this, can it?
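One workaround (a sketch, not part of the PR): keep the config value as plain strings so it stays overridable, and resolve the strings to classes inside train.py; the mapping name below is made up.

import torch.nn as nn

# Strings are overridable via config files / the configurator; class objects are not.
lora_layer_types = ["Linear", "Embedding"]

_LAYER_TYPES = {"Linear": nn.Linear, "Embedding": nn.Embedding}
resolved_lora_layer_types = [_LAYER_TYPES[name] for name in lora_layer_types]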
best_val_loss = checkpoint["best_val_loss"]

if init_from == "lora_finetune":
    out_dir = out_dir + "_lora_finetune"
os.path.join?
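If the suffix is meant to create a sibling directory (e.g. "out_lora_finetune"), the concatenation is fine; if a subdirectory is wanted instead, the join spelling would be (sketch):

import os

out_dir = os.path.join(out_dir, "lora_finetune")  # e.g. "out/lora_finetune" rather than "out_lora_finetune"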
@wlamond I'd love to do some experimentation with LoRA on various types of smaller models. Any chance this PR could be revived/updated?
Adding an implementation of LoRA fine tuning, heavily inspired by minLoRA. I thought the use of pytorch parametrization was interesting and simple, and fits in nicely with the approach of this project. Let me know if you were thinking of explicitly implementing the modified forward pass rather than a factored/merged forward pass, or if you think this would be a better fit as a separate repo.
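For readers skimming the thread, a condensed sketch of that parametrization approach (based on the forward line quoted in the diff above; the initialization details here are standard LoRA and may differ from the PR):

import math
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class LoraLinear(nn.Module):
    # Parametrization: given the original weight, return it plus a low-rank
    # update B @ A scaled by alpha / rank.
    def __init__(self, fan_in, fan_out, rank=8, dropout=0.0, alpha=1.0):
        super().__init__()
        self.lora_a = nn.Parameter(torch.zeros(rank, fan_in))
        self.lora_b = nn.Parameter(torch.zeros(fan_out, rank))
        nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))  # B stays zero, so the update is a no-op at init
        self.dropout = nn.Dropout(dropout)
        self.scaling = alpha / rank

    def forward(self, weight):
        return weight + torch.matmul(self.lora_b, self.dropout(self.lora_a)) * self.scaling

# Every access to linear.weight now goes through forward(), so the rest of the
# model code is unchanged.
linear = nn.Linear(64, 32, bias=False)
parametrize.register_parametrization(linear, 'weight', LoraLinear(fan_in=64, fan_out=32))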
I added the tinyshakespeare dataset and default to fine tuning on that. I wanted to tune the tinystories models a small amount (~50-100 steps) to get Shakespearian tiny stories :) I had some mixed results, e.g.:
Still mostly story-like, but certainly leaning more towards the drama of Shakespeare. I like the commentary on how being exposed to new and original thoughts can leave you in a new state of being. ;)
I also tuned this for ~1k steps with the 15M param model to get something that more closely resembles Shakespeare.
I only have access to a 1080ti and a v100 16GB, so I wasn't able to do more thorough testing/experimentation on the actual Llama2 checkpoints. Let me know if you'd like to see more testing before making a decision on what to do with this.
Thanks for sharing this project! It's been fun to play with.