
[Feature] Add BAdam algorithm #3287

Merged: 14 commits merged into hiyouga:main on Apr 16, 2024

Conversation

@Ledzy (Contributor) commented on Apr 15, 2024

What does this PR do?

This PR incorporates the BAdam algorithm into the repository. One can now use BAdam by setting the argument "--use_badam". It enables full-parameter finetuning of Llama 2-7B within 24 GB of GPU memory under mixed-precision training, while taking only about half the time of LoRA.

A sample running script is available at "examples/extras/badam/sft.sh". The script attains an eval_loss of 0.8591 after 900 steps, within a training time of 3 hours (including evaluation) on a single RTX 3090. Note that one needs to run pip install badam before running the script.
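For context, below is a minimal conceptual sketch of the block-coordinate idea behind BAdam, written in plain PyTorch. It is illustrative only and does not reproduce the badam package's actual API; it assumes a Llama-style model whose transformer layers live in model.model.layers and whose forward returns a .loss when the batch contains labels.

```python
import torch
from torch.optim import AdamW

def activate_block(model, blocks, active_idx):
    # Freeze the whole model, then unfreeze only the active block.
    for p in model.parameters():
        p.requires_grad_(False)
    for p in blocks[active_idx].parameters():
        p.requires_grad_(True)

def train_block_coordinate(model, dataloader, switch_every=50, lr=1e-5):
    blocks = list(model.model.layers)      # one "block" per transformer layer
    optimizer, active = None, -1
    for step, batch in enumerate(dataloader):
        if step % switch_every == 0:       # periodically switch the active block
            active = (active + 1) % len(blocks)
            activate_block(model, blocks, active)
            # Adam states exist only for the active block, which is what keeps
            # the optimizer memory small enough for a single 24 GB GPU.
            optimizer = AdamW(
                [p for p in blocks[active].parameters() if p.requires_grad], lr=lr
            )
        loss = model(**batch).loss         # assumes labels are included in the batch
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```

Because gradients and Adam states are kept for only one block at a time, the peak memory stays close to that of inference plus a single layer's optimizer state, which is what makes full-parameter finetuning of a 7B model feasible on 24 GB.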

Notes

  • The gradient_checkpointing_enable in "llmtuner/model/patcher.py" is modified to check whether any parameter of the checkpointed layer is trainable; if so, it sets the input's requires_grad to True. This allows a checkpointed layer to be trained even when its input arrives with requires_grad set to False. The change is essential for BAdam's acceleration: without it, model.enable_input_require_grads() forces the backward pass to always reach the first layer, and avoiding that roughly halves the training time (see the gradient-checkpointing sketch after this list).
  • BAdam uses mixed-precision training: fp32 master weights are created during the creation and update of the optimizer, so the model should be loaded in fp16. The file "llmtuner/model/adapter.py" is modified accordingly (see the mixed-precision sketch after this list).
  • Currently, BAdam's implementation does not support distributed training. DDP and ZeRO support is planned for future versions of the badam package.
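To make the gradient-checkpointing note concrete, here is a sketch of a block-aware checkpointing function. It is a sketch of the idea described above, not the verbatim patcher.py code; the name block_aware_checkpoint and the way it gets wired into the model are illustrative.

```python
import torch
from torch.utils.checkpoint import checkpoint

def block_aware_checkpoint(forward_fn, *args, use_reentrant: bool = False, **kwargs):
    # forward_fn is assumed to be the bound forward method of the layer being checkpointed.
    module = getattr(forward_fn, "__self__", None)
    if module is not None and any(p.requires_grad for p in module.parameters()):
        # Only a trainable layer needs gradients to flow into its inputs.
        for arg in args:
            if torch.is_tensor(arg) and torch.is_floating_point(arg):
                arg.requires_grad_(True)
    return checkpoint(forward_fn, *args, use_reentrant=use_reentrant, **kwargs)
```

Unlike model.enable_input_require_grads(), which unconditionally marks the embedding output as requiring grad, this check lets autograd stop at the earliest trainable block, which is where the reported speedup comes from.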
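And for the mixed-precision note, a minimal sketch of the fp32-master-weight scheme, assuming the model is loaded in fp16 and only the active block's parameters are passed in; the helper names are hypothetical and not part of the badam package's API.

```python
import torch
from torch.optim import AdamW

def build_master_optimizer(fp16_params, lr=1e-5):
    # fp32 master copies are created together with the optimizer.
    masters = [p.detach().clone().float().requires_grad_(True) for p in fp16_params]
    return masters, AdamW(masters, lr=lr)

def master_step(fp16_params, masters, optimizer):
    # The update happens in fp32, then the result is copied back into the fp16 model.
    for p, m in zip(fp16_params, masters):
        m.grad = p.grad.detach().float()
        p.grad = None
    optimizer.step()
    optimizer.zero_grad()
    with torch.no_grad():
        for p, m in zip(fp16_params, masters):
            p.copy_(m.half())
```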


@hiyouga added the pending label on Apr 15, 2024
@hiyouga self-requested a review on Apr 15, 2024 at 17:38
@hiyouga (Owner) left a comment


Thanks for adding this brilliant algorithm to LLaMA Factory. Please take a look at my comments, especially the implementation of gradient checkpointing.

Review comments (all resolved) were left on:

  • examples/extras/badam/sft.sh
  • requirements.txt
  • src/llmtuner/model/patcher.py
  • src/llmtuner/model/utils.py (three comments)
  • src/llmtuner/train/sft/trainer.py
@Ledzy requested a review from @hiyouga on April 16, 2024 at 04:27
@hiyouga (Owner) commented on Apr 16, 2024

Some necessary changes have been made; it can be merged now.

@hiyouga merged commit 4d660c5 into hiyouga:main on Apr 16, 2024
1 check passed
@hiyouga added the solved label and removed the pending label on Apr 16, 2024
@leekum2018 commented

How can I select BAdam in the web interface instead of the command line?
