-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Give example on how to handle gradient accumulation with cross-entropy #3193
Give example on how to handle gradient accumulation with cross-entropy #3193
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! (We need to do the gather()
+ div by num processes in the trainer still).
Left a few nits, I think it'd be really cool if we can show full training graphs. After doing stuff with FP8 just taking "the end result is the same" at face value I don't fully trust :)
Results on a single device: | ||
``` | ||
initial model weight is tensor([-0.0075, 0.5364]) | ||
initial model clone weight is tensor([-0.0075, 0.5364]) | ||
Step 0 - Device 0 - num items in the local batch 36 | ||
Total num items 36 | ||
Device 0 - w/ accumulation, the final model weight is tensor([0.0953, 0.4337]) | ||
w/o accumulation, the final model weight is tensor([0.0953, 0.4337]) | ||
``` | ||
|
||
Results on a two devices set-up: | ||
``` | ||
initial model weight is tensor([-0.0075, 0.5364]) | ||
initial model clone weight is tensor([-0.0075, 0.5364]) | ||
Step 0 - Device 0 - num items in the local batch 52 | ||
Step 0 - Device 1 - num items in the local batch 84 | ||
Total num items 136 | ||
Device 1 - w/ accumulation, the final model weight is tensor([0.2117, 0.3172]) | ||
Device 0 - w/ accumulation, the final model weight is tensor([0.2117, 0.3172]) | ||
w/o accumulation, the final model weight is tensor([0.2117, 0.3172]) | ||
``` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Honestly if we can let's even toss up some wandb
graphs 🔥
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indeed, it'd be great, but here we do only one single global batch size, I don't think it's worth adding a graph. Maybe should I modify the current code snippet to do this with multiple global steps ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or add some wandb graphs from the upcoming modif of examples/by_feature/gradient_accumulation
?
model_optimizer.zero_grad() | ||
|
||
|
||
logger.warning(f"Device {accelerator.process_index} - w/ accumulation, the final model weight is {accelerator.unwrap_model(model).weight.detach().cpu().squeeze()}", main_process_only=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rather than logger.warning
, we can do print()
here or change the default logging level :) (Just logging.warning
rather than logging.info
weirds me out)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice job @ylacombe ! Left a few suggestions !
num_samples_in_epoch = len(dataloader) | ||
remainder = num_samples_in_epoch % gradient_accumulation_steps | ||
remainder = remainder if remainder != 0 else gradient_accumulation_steps | ||
total_gradient_updates = math.ceil(num_samples_in_epoch / gradient_accumulation_steps) | ||
|
||
total_batched_samples = 0 | ||
for update_step in range(total_gradient_updates): | ||
# In order to correctly the total number of non-padded tokens on which we'll compute the cross-entropy loss | ||
# we need to pre-load the full local batch - i.e the next per_device_batch_size * accumulation_steps samples | ||
batch_samples = [] | ||
num_batches_in_step = gradient_accumulation_steps if update_step != (total_gradient_updates - 1) else remainder | ||
for _ in range(num_batches_in_step): | ||
batch_samples += [next(training_iterator)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This only works when we know the size of the dataloader. Can we think of a solution that doesn't require this information ? I think we can just iter on the dataloader until we have gradient_accumulation_steps
to create the batch_sample. If we can't iter anymore, then we stop also. I think that code will be easier to understand.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes agreed :) (What we do in the Trainer)
|
||
Results on a single device: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we can precise the exact setup ? I think that we are doing the following ?
- dp=1 grad_acc= 2 batch_size = 4 vs dp=1 grad_acc= 1 batch_size = 8 ?
If we are only doing one update, then we won't be able to get a graph. Maybe we do this on a larger dataset where batch_size != len(data_loader) and add the graphs.
Results on a two devices set-up: | ||
``` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On a two devices set-up, the modification you did to take into account the dp won't be reflected here as we are only changing grad acc and batch_size. So the loss will be the same nevertheless. However, it's nice to see that the total_num_items really changed:
- dp=2 grad_acc= 2 batch_size = 4 vs dp=2 grad_acc=1 batch_size=8
Maybe we should probably do a separate section/experiment to show the following will have the same loss graph
- dp=2 batch_size =2 is the same as dp=1 batch_size=4. See this experiment for clarification
tests/test_examples.py
Outdated
def test_gradient_accumulation_for_autoregressive_models(self): | ||
testargs = ["examples/by_feature/gradient_accumulation_for_autoregressive_models.py"] | ||
run_command(self.launch_args + testargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a nit: this doesn't use gradient accumulation here since it uses the default of 1
examples/by_feature/gradient_accumulation_for_autoregressive_models.py
Outdated
Show resolved
Hide resolved
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored. |
any updates? |
…odels.py Co-authored-by: Zach Mueller <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Beautiful! LG2M, @SunMarc are you happy with it as well? :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Really nice PR ! Thanks for iterating !
Hi, I have a question. |
Could you detail a bit your point ? |
cc @ylacombe |
What does this PR do?
Following the recent highlights on how gradient accumulation with the cross-entropy loss is usually off, it could be great to have it mentioned in the doc. I've thus added some code and explanation of it in the gradient accumulation page.
cc @SunMarc and @muellerzr, let me know what you think of it or if I can make this any clearer!