Opendelta to peft migration (#434) + memory optimization (#320) #486
Conversation
…CarperAI#434) + Collapse reference+learner hydra heads when using LoRa (CarperAI#320)
else:
    strict = True

self.accelerator.load_state(directory or self.config.train.checkpoint_dir, strict=strict, **kwargs)
For this function I'm not confident that everything I wrote makes sense. I haven't tested it in a truly distributed setting, and I'm not sure how useful it is, because it only sets the internal state of the accelerator without restoring other things like self.model. Also, when not using peft, testing trainer.load(checkpoint_dir) was not working because of an additional "base_model" prefix in the layer names.
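For reference, here is a minimal sketch of the kind of key remapping that could work around such a prefix mismatch (the prefix string, the helper name, and the commented usage lines are illustrative assumptions, not the actual trlX code):

```python
def strip_prefix_from_state_dict(state_dict, prefix="base_model."):
    """Return a copy of `state_dict` with `prefix` removed from matching keys."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

# Hypothetical usage (the path and `model` are placeholders, not trlX internals):
# state_dict = torch.load("checkpoint_dir/pytorch_model.bin", map_location="cpu")
# model.load_state_dict(strip_prefix_from_state_dict(state_dict), strict=False)
```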
Can you elaborate on what you mean by memory optimization beyond just the integration of PEFT?
If you check AutoModelForCausalLMWithHydraValueHead.__init__ and its Seq2Seq equivalent, when peft is used we don't create a frozen head, because we can simply bypass or temporarily deactivate the adapter. This comes from issue #320.
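As a rough illustration of that idea (not the actual hydra-head code; the model name and LoRA settings are arbitrary), peft can temporarily bypass the adapter so the same weights double as the frozen reference model:

```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = get_peft_model(
    AutoModelForCausalLM.from_pretrained("gpt2"),
    LoraConfig(task_type="CAUSAL_LM"),
)

inputs = tokenizer("Hello", return_tensors="pt")
with torch.no_grad():
    policy_logits = model(**inputs).logits    # adapter active: learner/policy
    with model.disable_adapter():             # adapter bypassed: frozen reference
        reference_logits = model(**inputs).logits
```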
Hi, @glerzing! Amazing contribution 🙂
I've left some initial requests for changes and pointed out a few issues I encountered while launching some of our examples. We'll need to sort out a few things that break backward compatibility before merging.
I tried to use the model tiny-gpt2, which would be much faster than gpt2 for the tests, but unfortunately it seems to cause numerical instabilities in test_backpropagation_and_disabling, which can fail depending on the random seed. I did, however, replace t5-small with google/t5-efficient-tiny, which works well.
LGTM!
One final change! Per the Discord discussion on the 8-bit model forward args issue, let's add a guard check for 8-bit loading and notify users that it is still an experimental feature.
super().__init__()
self.base_model = base_model
# cache `forward` args for general use (avoids incompatible args across architectures)
self.forward_kwargs = inspect.getfullargspec(self.base_model.forward).args
self.is_loaded_in_8bit = getattr(base_model, "is_loaded_in_8bit", False)
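For context, the cached argument names let the wrapper drop kwargs that a given architecture's forward does not accept; a minimal standalone sketch of that filtering (the helper name is an illustration, not the wrapper's actual method):

```python
import inspect

def filter_forward_kwargs(model, **kwargs):
    """Keep only the kwargs that `model.forward` accepts, so architecture-specific
    arguments (e.g. `position_ids`) don't raise a TypeError."""
    accepted = inspect.getfullargspec(model.forward).args
    return {k: v for k, v in kwargs.items() if k in accepted}
```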
Hardcode is_loaded_in_8bit to False and log something along the lines of "8-bit loading is an experimental feature not yet fully tested; leaving for curious users to explore".
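A minimal sketch of what that guard could look like (assuming a module-level logger; `self` and `base_model` come from the surrounding __init__, and this is illustrative rather than the final implementation):

```python
import logging

logger = logging.getLogger(__name__)

# Inside the wrapper's __init__, after `base_model` is set:
if getattr(base_model, "is_loaded_in_8bit", False):
    logger.warning(
        "8-bit loading is an experimental feature that is not yet fully tested; "
        "leaving it for curious users to explore."
    )
self.is_loaded_in_8bit = False  # hardcoded until 8-bit support is validated
```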
@glerzing Huge contribution and an amazing job! Thank you 🚀
I closed #477 and opened this new PR instead.
Feel free to ask questions or make remarks. If you want the automated tests to run faster, I can reduce the number of models or configs to test. I mostly relied on unit tests for verification. A few things, like CausalILQLOutput, were added to make the methods more generic and the automated tests easier to write.