Loss metrics dramatically change after resuming from checkpoint #809
In case it helps to narrow down the issue, it does appear that the shape of the resumed run's loss curve is influenced by the config. If I force the resumed run to step beyond
@darkmirage Can you be more specific about what the configs were before and after, if you ever changed them? E.g., in the post I only see one set of configs, without checkpointing enabled. Please include the full config change. Also, I didn't understand
In the current setup, we expect resuming from the exact same
The config used with the default
This is good to know because it seems a bit unintuitive, as it would be common to want to further train a checkpoint from a previous job. What would be the suggested workflow for that? Creating a new checkpoint and resetting the LR scheduler config manually? That said, the first two 16 GPU runs above were restarted from the checkpoints with the same
Overall, it sounds like the issue here has something to do with the LR scheduler.
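To illustrate why the LR scheduler is a plausible culprit, here is a minimal sketch (not torchtitan's actual scheduler; the warmup length, peak LR, and schedule shape are assumptions): with a warmup-plus-decay schedule parameterized by the total step count from the config, resuming at the same step under a different total-step setting yields a different learning rate, and therefore a different loss trajectory.

```python
# Minimal sketch, not torchtitan's actual scheduler: a linear-warmup +
# linear-decay schedule whose value at any step depends on the assumed
# total number of training steps.

def lr_at(step: int, total_steps: int, warmup_steps: int = 200, peak_lr: float = 3e-4) -> float:
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    # Linear decay from peak_lr to 0 over the remaining steps.
    remaining = max(total_steps - warmup_steps, 1)
    progress = (step - warmup_steps) / remaining
    return peak_lr * max(1.0 - progress, 0.0)

# Resuming at step 1000 with a different total-step setting changes the LR,
# even though the model and optimizer states are identical:
print(lr_at(step=1000, total_steps=2000))  # ~1.67e-4
print(lr_at(step=1000, total_steps=4000))  # ~2.37e-4
```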
I believe the current setup is more from a fault-tolerance perspective: we know how many steps to train from the very beginning, and in case the job fails, we can resume from one of the checkpoints we have saved.
Yeah, this is a valid use case that we don't support right now but should add. In fact, it relates to your other ask, #811. We should make things like the data loader and LR scheduler optional in checkpointing.
Meanwhile, let me try to reproduce this issue.
There are several checkpointing use cases:
TorchTitan currently supports 1 and 2 very well, but there may be a bug according to this issue; @mori360 and I are working on it. TorchTitan doesn't support 3 well. @darkmirage, your request and #811 belong to 3. To support 3 well, TorchTitan needs partial checkpoint loading, which @mori360 is working on. We also need good documentation specifying which components must be excluded from loading in case 3. For example, loading the dataloader checkpoint after changing the number of GPUs won't work.
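For use case 3, a rough sketch of what "partial" loading could look like is below. The checkpoint keys ("model", "optimizer", "lr_scheduler", "dataloader") and file layout are assumptions for illustration only, not torchtitan's actual checkpoint format.

```python
# Illustrative sketch only: load model (and optionally optimizer) state while
# deliberately skipping dataloader and LR scheduler state, so the new job's
# config drives a fresh schedule. Key names and layout are assumed.
import torch

def load_for_continued_training(path: str, model, optimizer=None) -> int:
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    if optimizer is not None and "optimizer" in ckpt:
        optimizer.load_state_dict(ckpt["optimizer"])
    # Skip ckpt["dataloader"]: its state is tied to the old world size.
    # Skip ckpt["lr_scheduler"]: rebuild it from the new run's total steps.
    return ckpt.get("step", 0)
```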
I am consistently getting bad loss values when resuming from a checkpoint.
Green is the original run; red and blue are resumed from different checkpoints.
On 16 GPUs.
On 32 GPUs.
Runs were done on the 8B config with these overriding flags: