[Enhancement] Support ZeRO-3 when using BAdam #4352
Conversation
@@ -371,6 +371,12 @@ def _create_badam_optimizer(
         dict(params=decay_params, weight_decay=training_args.weight_decay),
     ]

+    ds_zero3_enabled = False
+    if hasattr(training_args, "deepspeed_plugin") and training_args.deepspeed_plugin is not None:
Why not use `from transformers.integrations import is_deepspeed_zero3_enabled`?
Thanks for the suggestion, I have changed it to use `is_deepspeed_zero3_enabled` for cleaner expressions.
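For context, a minimal sketch of what the suggested change amounts to, reusing the `ds_zero3_enabled` flag from the diff above; this is illustrative, not the exact PR code:

```python
from transformers.integrations import is_deepspeed_zero3_enabled

# Before (as in the diff): probe training_args for an attached DeepSpeed plugin.
# ds_zero3_enabled = False
# if hasattr(training_args, "deepspeed_plugin") and training_args.deepspeed_plugin is not None:
#     ...

# After (suggested): let transformers report whether a ZeRO-3 config is active.
ds_zero3_enabled = is_deepspeed_zero3_enabled()
```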
LGTM
What does this PR do?
This PR enables the BAdam algorithm to use model parallelism, based on the DeepSpeed ZeRO-3 implementation.
A sample script is provided in "examples/extras/badam/train_zero3.sh". The command generated by the current web UI works correctly as well.
When training Llama 3-8B on the "alpaca_en_demo" dataset with batch size 1, the maximum per-device allocated memory is about 13/10/8 GB when training with 2/3/4 RTX 3090 GPUs, respectively. I suppose it would be feasible to finetune a Llama 3-70B model given 8 RTX 3090s or 3 A100-80G GPUs using BAdam, though I haven't conducted a comprehensive test due to limited computation resources.
The main change in the code is to add a BAdam callback during the Trainer's initialization when `use_badam` and ZeRO-3 mode are both detected. Thanks for the review!
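As a rough illustration of that wiring (names like `BAdamCallback` and `finetuning_args.use_badam` are assumptions for this sketch, not necessarily the identifiers used in the PR):

```python
from transformers.integrations import is_deepspeed_zero3_enabled

def maybe_add_badam_callback(finetuning_args, callbacks: list) -> list:
    # Register the BAdam callback only when BAdam is requested and the run
    # uses DeepSpeed ZeRO-3, matching the condition described above.
    # "from badam import BAdamCallback" is a hypothetical import path.
    if getattr(finetuning_args, "use_badam", False) and is_deepspeed_zero3_enabled():
        from badam import BAdamCallback
        callbacks.append(BAdamCallback())
    return callbacks
```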
Before submitting