Loss metrics dramatically change after resuming from checkpoint #809

Open
darkmirage opened this issue Jan 28, 2025 · 6 comments

darkmirage commented Jan 28, 2025

I am consistently getting bad loss values when resuming from a checkpoint.

Green is the original run; red and blue are resumed from different checkpoints.

On 16 GPUs:

[loss curve screenshots]

On 32 GPUs:

[loss curve screenshot]

Runs were done on the 8B config with these override flags:

--metrics.enable_wandb --training.compile --training.seq_len 8192 \
--training.batch_size 1 --training.tensor_parallel_degree 1 \
--training.data_parallel_shard_degree -1 --training.data_parallel_replicate_degree 1 \
--training.steps 5000 --training.warmup_steps 200 --activation_checkpoint.mode selective \
--checkpoint.interval 500 --training.seed 42
darkmirage (Author) commented:

In case it helps narrow down the issue, the shape of the loss curve for the resumed run does appear to be influenced by the training.steps config. It's not exact, but the point where the curves converge seems to be close to the total number of steps.

If I force the resumed run to go beyond training.steps, the loss curve goes up very steeply.

@tianyu-l tianyu-l added the bug Something isn't working label Jan 30, 2025
@tianyu-l tianyu-l self-assigned this Jan 30, 2025
tianyu-l (Contributor) commented:

@darkmirage
Thanks for reporting the issue.

Can you be more specific about what the configs were before and after, if you ever changed them? E.g., in the post I only see one set of configs, without checkpointing enabled. Please include the full config change.

Also, I didn't understand this:

> It's not exact, but the point where the curves converge seems to be close to the total number of steps.

In the current setup, we expect to resume with the exact same training.steps and warmup_steps, because the behavior of the LR scheduler depends on them. The behavior is undefined if you modify the steps config.
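For intuition only (this is a minimal sketch, not torchtitan's actual scheduler code): a linear-warmup/linear-decay schedule is typically written as a function of warmup_steps and the total step count, so the same step index maps to a different learning rate if training.steps changes, and a linear decay keeps falling (eventually below zero) once the step counter passes the configured total. The warmup/decay shape below is an assumption for illustration.

```python
# Minimal sketch (not torchtitan's actual scheduler): a linear warmup +
# linear decay whose shape is parameterized by the *total* step count.
import torch

def lr_lambda(step, warmup_steps, total_steps):
    # Multiplier applied to the base LR at `step`.
    if step < warmup_steps:
        return (step + 1) / warmup_steps  # linear warmup
    # Linear decay from 1.0 at the end of warmup to 0.0 at total_steps;
    # past total_steps this keeps decreasing and goes negative.
    return 1.0 - (step - warmup_steps) / max(1, total_steps - warmup_steps)

model = torch.nn.Linear(8, 8)

# Same step index, different `training.steps` -> different learning rate.
for total in (5000, 10000):
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
    sched = torch.optim.lr_scheduler.LambdaLR(
        opt, lambda s, t=total: lr_lambda(s, warmup_steps=200, total_steps=t)
    )
    for _ in range(3000):
        sched.step()
    print(total, sched.get_last_lr())  # ~1.25e-4 for 5000 vs ~2.14e-4 for 10000
```

If a resumed job restores the optimizer but rebuilds the scheduler with a different total step count, the LR applied at the restored step no longer matches the original run, which would show up as a sudden change in the loss curve.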

darkmirage (Author) commented:

The config used was the default llama3_8b.toml file with those override flags added on the CLI.

> In the current setup, we expect to resume with the exact same training.steps and warmup_steps, because the behavior of the LR scheduler depends on them. The behavior is undefined if you modify the steps config.

This is good to know, because it seems a bit unintuitive: it would be common to want to train a checkpoint from a previous job further. What would be the suggested workflow for that? Creating a new checkpoint and resetting the LR scheduler config manually?

That said, the first two 16-GPU runs above were restarted from checkpoints with the same steps and warmup_steps settings. I just meant that, for cases where we do change the steps parameter because we want to train the model further, the shape of the mismatch curve seems to be influenced by that change.

Overall, it sounds like the issue has something to do with the LR scheduler.


tianyu-l commented Feb 1, 2025

> This is good to know, because it seems a bit unintuitive: it would be common to want to train a checkpoint from a previous job further.

I believe the current setup is more from a fault-tolerance perspective -- we know how many steps to train from the very beginning, and in case the job fails, we can resume from a checkpoint we have saved.

> What would be the suggested workflow for that? Creating a new checkpoint and resetting the LR scheduler config manually?

Yeah, this is a valid use case that we don't support right now but should add. In fact, it relates to your other ask in #811.

We should make components like the data loader and LR scheduler optional in checkpointing.
@fegin can you help make a PR for this?
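To illustrate the idea (this is only a sketch; torchtitan's CheckpointManager uses torch.distributed.checkpoint with folder-based checkpoints, and the file path and helper name here are hypothetical): loading becomes selective by restoring only the requested components and leaving everything else in its freshly initialized state.

```python
# Sketch of optional components at load time -- not torchtitan's actual
# CheckpointManager API. Only the requested components are restored; the
# rest (e.g. lr_scheduler, dataloader) keep the state they were
# constructed with for the new job.
import torch

def load_selected(path, states, include=("model", "optimizer")):
    """`states` maps component name -> object exposing load_state_dict()."""
    checkpoint = torch.load(path, map_location="cpu")
    for name, obj in states.items():
        if name in include and name in checkpoint:
            obj.load_state_dict(checkpoint[name])

# Usage sketch (hypothetical path):
# load_selected(
#     "outputs/checkpoint/step-4500.pt",
#     {"model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler},
#     include=("model", "optimizer"),
# )
```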

@tianyu-l tianyu-l added the enhancement New feature or request label Feb 1, 2025

tianyu-l commented Feb 1, 2025

> That said, the first two 16-GPU runs above were restarted from checkpoints with the same steps and warmup_steps settings.

Meanwhile, let me try to reproduce this issue.
If it's due to the LR scheduler, #794 could have fixed it already.

@tianyu-l tianyu-l added this to the torchtitan v1.0.0 release milestone Feb 2, 2025

fegin commented Feb 4, 2025

There are several checkpointing use cases:

  1. Recovering from a failure. For this use case, we need to load everything (model, optimizer, lr_scheduler, dataloader) with the same settings, e.g., --training.steps.
  2. Loading pre-trained model weights. For this use case, we only need to load the model.
  3. Partially loading from a previously trained checkpoint (including model, optimizer, and lr_scheduler) while changing the settings (e.g., steps, the number of GPUs).

TorchTitan currently supports 1 and 2 very well, though there may be a bug according to this issue. @mori360 and I are working on it.

TorchTitan doesn't support 3 well. @darkmirage, your request and #811 belong to 3.

To support 3 well, TorchTitan needs partial checkpoint loading, which @mori360 is working on. We also need good documentation specifying which components must be excluded from loading when doing 3. For example, loading the dataloader checkpoint after changing the number of GPUs won't work.
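One way that exclusion policy could be made explicit is a small check at load time. The helper below is purely hypothetical (not an existing torchtitan API); it only encodes the rules described above: drop the LR scheduler state when the step budget changed, and drop the dataloader state when the number of ranks changed.

```python
# Hypothetical helper (not an existing torchtitan API) encoding the
# exclusion rules for use case 3: decide which checkpoint components are
# safe to restore given how the job settings changed between save and resume.
def components_to_load(saved: dict, current: dict) -> set[str]:
    components = {"model", "optimizer", "lr_scheduler", "dataloader"}
    # LR scheduler state only makes sense if the step budget is unchanged.
    if (saved["training_steps"] != current["training_steps"]
            or saved["warmup_steps"] != current["warmup_steps"]):
        components.discard("lr_scheduler")
    # Dataloader state is sharded per rank, so it cannot be restored when
    # the number of GPUs (data-parallel degree) changes.
    if saved["world_size"] != current["world_size"]:
        components.discard("dataloader")
    return components

# e.g. resuming the 16-GPU run on 32 GPUs with a larger step budget:
print(components_to_load(
    {"training_steps": 5000, "warmup_steps": 200, "world_size": 16},
    {"training_steps": 10000, "warmup_steps": 200, "world_size": 32},
))  # -> {'model', 'optimizer'}
```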
