peft to opendelta migration (#434) + memory optimization (#320) #486

Merged
merged 8 commits into CarperAI:main on Jun 23, 2023

Conversation

glerzing
Contributor

I closed #477 and opened this new PR.

Feel free to ask questions or make remarks. If you want the automated tests to be faster, I can reduce the number of models or configs to test. I mostly relied on unit tests for verification. A few things, like CausalILQLOutput, were added to make the methods more generic and the automated tests easier.

else:
    strict = True

self.accelerator.load_state(directory or self.config.train.checkpoint_dir, strict=strict, **kwargs)
Contributor Author

For this function I'm not confident that everything I wrote makes sense. I have not tested it in a truly distributed setting, and I'm not sure how useful it is, since it only sets the internal state of the accelerator without setting other things like self.model. Also, when not using peft, testing trainer.load(checkpoint_dir) failed because of an additional "base_model" prefix in the layer names.
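For reference, a minimal sketch of how the surrounding method could look, assuming the names in the excerpt above (self.config.train.checkpoint_dir, self.accelerator); the condition that relaxes strict is my guess, not the merged code:

def load(self, directory=None, **kwargs):
    """Sketch: restore the trainer's accelerator state from a checkpoint directory."""
    # Assumption (not shown in the excerpt): strict key matching is relaxed when a
    # peft adapter wraps the model, since the wrapped state dict keys can then
    # differ from the plain model's layer names.
    if getattr(self.config.model, "peft_config", None) is not None:
        strict = False
    else:
        strict = True
    self.accelerator.load_state(directory or self.config.train.checkpoint_dir, strict=strict, **kwargs)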

@LouisCastricato
Contributor

Can you elaborate on what you mean by memory optimization beyond just the integration of PEFT?

@glerzing
Contributor Author

Can you elaborate on what you mean by memory optimization beyond just the integration of PEFT?

If you check AutoModelForCausalLMWithHydraValueHead.__init__ and its Seq2Seq equivalent, when peft is used we don't create a frozen head, because we can simply bypass or temporarily disable the adapter instead. This comes from issue #320.
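A rough sketch of the idea (not the exact trlx code; reference_logits is a hypothetical helper): with a peft adapter attached, the frozen reference forward pass can reuse the same base weights by temporarily disabling the adapter, so no separate frozen head or model copy needs to be kept in memory.

import torch

def reference_logits(model, input_ids, attention_mask=None):
    """Hypothetical helper: compute reference (pre-finetuning) logits without a frozen copy."""
    with torch.no_grad():
        # peft's PeftModel.disable_adapter() is a context manager that temporarily
        # routes the forward pass through the original frozen base weights.
        with model.disable_adapter():
            return model(input_ids=input_ids, attention_mask=attention_mask).logits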

Collaborator

@jon-tow left a comment

Hi, @glerzing! Amazing contribution 🙂
I've left some initial requests for changes and pointed out a few issues I encountered while launching some of our examples. We'll need to sort out a few things that break backward compatibility before merging.

Review comments (now resolved) were left on:
trlx/models/modeling_ppo.py
trlx/data/configs.py
trlx/models/modeling_ilql.py
trlx/models/modeling_base.py
trlx/trainer/accelerate_ppo_trainer.py
@glerzing
Contributor Author

glerzing commented Jun 3, 2023

I tried to use the model tiny-gpt2. It would be much faster than gpt2 for the tests, but unfortunately it looks like it causes numerical instabilities in test_backpropagation_and_disabling, which can fail depending on the random seed. I did, however, replace t5-small with google/t5-efficient-tiny, which works well.

glerzing requested a review from jon-tow on June 8, 2023, 21:11
Collaborator

@maxreciprocate left a comment

LGTM!

Collaborator

@jon-tow left a comment

One final change! Per the Discord discussion on the 8-bit model forward-args issue, let's add a guard check for 8-bit loading and notify users that it is still an experimental feature.

super().__init__()
self.base_model = base_model
# cache `forward` args for general use (avoids incompatible args across architectures)
self.forward_kwargs = inspect.getfullargspec(self.base_model.forward).args
self.is_loaded_in_8bit = getattr(base_model, "is_loaded_in_8bit", False)
Collaborator

Hardcode this to False and log something along the lines of "8-bit loading is an experimental feature not yet fully tested; leaving for curious users to explore"
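A hedged sketch of that suggestion (not the merged diff; the Branch class name is a placeholder for the wrapper shown in the excerpt above):

import logging

logger = logging.getLogger(__name__)

class Branch:  # hypothetical wrapper name, for illustration only
    def __init__(self, base_model):
        self.base_model = base_model
        if getattr(base_model, "is_loaded_in_8bit", False):
            logger.warning(
                "8-bit loading is an experimental feature not yet fully tested; "
                "leaving for curious users to explore."
            )
        # Hardcoded to False per the suggestion above, until 8-bit support is verified.
        self.is_loaded_in_8bit = False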

Collaborator

@jon-tow left a comment

@glerzing Huge contribution and an amazing job! Thank you 🚀

@jon-tow jon-tow merged commit d47996d into CarperAI:main Jun 23, 2023