Training without Gradient Checkpointing #124
Hi! We did have some additional gradient checkpointing when we trained the models, even when using a single autoregressive step, which we removed from this implementation for simplicity, as this implementation's main purpose is to serve as a reference. For the mesh model, we gradient-checkpointed every 3 message-passing steps. We also gradient-checkpointed the whole encoder GNN and the whole decoder GNN. Additionally, inside the edge model of the encoder/decoder GNNs we had a second level of gradient checkpointing to do the edge update in blocks of edges, rather than computing the edge update for all edges at once. Furthermore, XLA/JAX compilation is very efficient and fuses many operations together, which also works as an implicit form of gradient checkpointing, since for fused operations the forward pass does not store intermediate values. The XLA compiler also tries on its own to add extra rematerialization when there are opportunities to make the computation fit on the device. Hope this helps!
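A minimal sketch of the "checkpoint every 3 message-passing steps" idea using `jax.checkpoint` (a.k.a. `jax.remat`). Everything here (`message_passing_step`, the shapes, the number of blocks) is a hypothetical stand-in, not the actual GraphCast code: the point is only that wrapping each block of 3 steps makes the backward pass recompute that block's activations instead of storing them.

```python
import jax
import jax.numpy as jnp


def message_passing_step(nodes):
    # Stand-in for one GNN message-passing step (a toy MLP-style update).
    dim = nodes.shape[-1]
    weights = jnp.ones((dim, dim)) * 0.01  # illustrative fixed weights
    return jnp.tanh(nodes @ weights)


def three_steps(nodes):
    # One "block" of 3 message-passing steps.
    for _ in range(3):
        nodes = message_passing_step(nodes)
    return nodes


def processor(nodes, num_blocks=4):
    # jax.checkpoint: inside each block, intermediate activations are NOT
    # stored for backprop; they are recomputed during the backward pass.
    block = jax.checkpoint(three_steps)
    for _ in range(num_blocks):
        nodes = block(nodes)
    return nodes


loss = lambda x: processor(x).sum()
grads = jax.grad(loss)(jnp.ones((8, 16)))
```

The same `jax.checkpoint` wrapper can be applied around the whole encoder or decoder GNN call, and edge updates can be chunked (e.g. via `jax.lax.map` over blocks of edges) to get the second level of checkpointing described above.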
Hi Alvarosg, thanks for the quick reply; it helps a lot. Now I see how the model could fit in a TPU v4.
Hi there,
After going through the GraphCast paper and the code, it seems to me that training was done without gradient checkpointing when not using back-propagation through time. I am curious how this could be done with 32 GB of memory, since the activations are huge: the grid-node MLP encoder alone takes around 10 GB of memory in PyTorch.
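A back-of-envelope check of the ~10 GB figure. The numbers below are assumptions for illustration: a 0.25-degree lat-lon grid (721 × 1440 grid nodes) and a latent size of 512 stored in float32; with those, a single stored activation tensor is already about 2 GB, so keeping a handful of layer activations reaches the ~10 GB range.

```python
# Hypothetical back-of-envelope activation memory for a grid-node MLP encoder.
num_grid_nodes = 721 * 1440   # assumed 0.25-degree lat-lon grid: 1,038,240 nodes
latent_size = 512             # assumed latent dimension
bytes_per_float = 4           # float32

per_tensor_gb = num_grid_nodes * latent_size * bytes_per_float / 1e9
print(f"one activation tensor: {per_tensor_gb:.2f} GB")

num_layers_stored = 5         # assumed number of stored layer activations
total_gb = num_layers_stored * per_tensor_gb
print(f"{num_layers_stored} stored activations: {total_gb:.1f} GB")
```

This is why rematerialization and XLA's operation fusion (which avoids storing intermediates of fused ops) matter so much for fitting training on a single device.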