I am reproducing the GraphCast training phases, and it seems to take dramatically more time on Phase 3 (the autoregressive phase) than on Phases 1 and 2.
Is it possible to share the approximate time taken for each of training Phases 1, 2, and 3 of GraphCast?
Also, the model seems to use a LOT of device memory in Phase 3. Are there any tips on reducing device memory usage during training?
Thank you.
> Is it possible to share the approximate time taken for each of training Phases 1, 2, and 3 of GraphCast?
Out of the approximately 4 weeks that it took to train the model, Phase 3 took about 11 days. Phases 1 and 2 took time proportional to the relative number of training steps in each.
> Also, the model seems to use a LOT of device memory in Phase 3.
Yes, this is indeed hard to work around. At the time we were limited to 32GB of device memory per training example on TPUv4, and we had to leverage various gradient checkpointing strategies. With today's TPUv5 it should be a lot easier to get it to fit.
Typically, if you can make the 1-step part of training fit in 32GB by adding some hk.remat options (e.g. on TPUv4), then you should be able to fit the 12AR (12-step autoregressive) part of training on TPUv5, so long as you also enable gradient checkpointing for the autoregressive part of training.
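For reference, here is a minimal sketch of what wrapping a sub-network in hk.remat looks like in Haiku/JAX. The network below is a hypothetical stand-in for one forward step (the real model would be the GraphCast encoder/processor/decoder stack); only hk.remat itself is the real API being illustrated:

```python
import haiku as hk
import jax
import jax.numpy as jnp


def forward(inputs):
  # Hypothetical stand-in for one forward step of the model; the real
  # network would be the GraphCast encoder/processor/decoder stack.
  net = hk.nets.MLP([256, 256, inputs.shape[-1]])

  def step(x):
    return net(x)

  # hk.remat (Haiku's wrapper around jax.checkpoint) drops intermediate
  # activations on the forward pass and recomputes them during backprop,
  # trading extra compute for a large reduction in device memory.
  return hk.remat(step)(inputs)


model = hk.transform(forward)
params = model.init(jax.random.PRNGKey(0), jnp.zeros((8, 128)))
out = model.apply(params, None, jnp.zeros((8, 128)))
```

For the 12AR part, the equivalent knob (if I remember the released graphcast code correctly) is the gradient checkpointing flag on the autoregressive rollout wrapper, e.g. `autoregressive.Predictor(predictor, gradient_checkpointing=True)`, which remats each rollout step.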