
Training without Gradient Checkpointing #124

Closed · yonghanyu opened this issue Jan 13, 2025 · 2 comments

Comments

@yonghanyu

Hi there,

After going through the GraphCast paper and the code, it seems to me that training was done without gradient checkpointing when not using back-propagation through time. I am curious how this fits in 32 GB of memory, since the activations are huge: the grid-node MLP encoder alone takes around 10 GB of memory in PyTorch.
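For context, a rough back-of-envelope of the activation size (my own assumptions, not numbers from the code: the 0.25° grid with ~1.04M grid nodes, a latent size of 512, float32, and a few saved activation tensors per node):

```python
# Back-of-envelope activation memory for the grid-node encoder MLP.
# All numbers below are illustrative assumptions, not taken from the code.
num_grid_nodes = 721 * 1440        # 0.25 deg lat/lon grid ~ 1.04M nodes
latent_size = 512                  # assumed MLP hidden/output width
bytes_per_float = 4                # float32
num_saved_activations = 3          # e.g. hidden pre-act, hidden post-act, output

activation_bytes = (num_grid_nodes * latent_size * bytes_per_float
                    * num_saved_activations)
print(f"~{activation_bytes / 1e9:.1f} GB")   # roughly 6 GB for the encoder MLP alone
```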

@alvarosg
Collaborator

Hi! We did have some additional gradient checkpointing when we trained the models, even when using a single auto-regressive step, but we removed it from this implementation for simplicity, since this implementation's main purpose is to serve as a reference.

For the mesh model, we gradient-checkpointed every 3 message-passing steps. We also gradient-checkpointed the whole encoder GNN and the whole decoder GNN. Additionally, inside the edge model of the encoder/decoder GNNs we had a second level of gradient checkpointing that performed the edge update in blocks of edges, rather than computing the update for all edges at once. A sketch of both levels is below.
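This is not the actual training code, just a minimal JAX sketch of what those two levels could look like with jax.checkpoint (a.k.a. jax.remat); the toy MLPs, step counts and block counts are illustrative assumptions:

```python
# Minimal sketch (not the GraphCast implementation) of two levels of
# gradient checkpointing, using jax.checkpoint / jax.remat.
import functools
import jax
import jax.numpy as jnp

LATENT = 64  # illustrative latent size


def message_passing_step(node_feats, params):
  # Stand-in for one mesh GNN message-passing step.
  return jax.nn.relu(node_feats @ params)


def mesh_gnn(node_feats, params, num_steps=16, steps_per_checkpoint=3):
  # Level 1: wrap every 3 consecutive message-passing steps in jax.checkpoint,
  # so activations are only kept at block boundaries and are recomputed inside
  # each block during the backward pass.
  def run_block(feats, block_params, num_block_steps):
    for _ in range(num_block_steps):
      feats = message_passing_step(feats, block_params)
    return feats

  step = 0
  while step < num_steps:
    n = min(steps_per_checkpoint, num_steps - step)
    block_fn = jax.checkpoint(functools.partial(run_block, num_block_steps=n))
    node_feats = block_fn(node_feats, params)
    step += n
  return node_feats


def edge_mlp(edge_feats, params):
  # Stand-in for the learned edge-update MLP of the encoder/decoder GNN.
  return jax.nn.relu(edge_feats @ params)


def edge_update_in_blocks(edge_feats, params, num_blocks=8):
  # Level 2: update edges in blocks rather than all at once; jax.checkpoint
  # makes the backward pass recompute each block's intermediates instead of
  # storing them. Assumes the edge count divides evenly into num_blocks.
  blocks = edge_feats.reshape(num_blocks, -1, edge_feats.shape[-1])
  per_block = jax.checkpoint(edge_mlp)
  return jax.lax.map(lambda b: per_block(b, params), blocks).reshape(-1, LATENT)


if __name__ == "__main__":
  nodes = jnp.ones((1000, LATENT))
  edges = jnp.ones((8 * 500, LATENT))
  params = jnp.ones((LATENT, LATENT)) * 0.01
  loss = lambda p: mesh_gnn(nodes, p).sum() + edge_update_in_blocks(edges, p).sum()
  print(jax.grad(loss)(params).shape)  # (64, 64)
```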

Furthermore, XLA/JAX compilation is very efficient and fuses many operations together, which also acts as an implicit form of gradient checkpointing, since for fused operations the forward pass does not store intermediate values. The XLA compiler also adds additional rematerialization on its own when there are opportunities to make the computation fit on the device.
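As a rough illustration (again only a sketch, not how XLA fusion itself is implemented), jax.checkpoint also takes a policy argument that explicitly controls which intermediates are saved versus recomputed, which is the user-controllable analogue of the rematerialization the compiler applies automatically:

```python
# Sketch: explicit control over which intermediates are saved, via a
# checkpoint policy. The shapes and policy choice here are illustrative.
import jax
import jax.numpy as jnp


def block(x, w1, w2):
  h = jnp.dot(x, w1)       # matmul outputs may be saved under the policy below
  h = jax.nn.gelu(h)       # cheap elementwise op: recomputed in the backward pass
  return jnp.dot(h, w2)

# Save only matmul ("dot") outputs; recompute everything else when needed.
block_remat = jax.checkpoint(block, policy=jax.checkpoint_policies.dots_saveable)

x = jnp.ones((8, 128))
w1 = jnp.ones((128, 256)) * 0.01
w2 = jnp.ones((256, 128)) * 0.01
grads = jax.grad(lambda w: block_remat(x, w, w2).sum())(w1)
print(grads.shape)  # (128, 256)
```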

Hope this helps!

@yonghanyu
Author


Hi Alvarosg,

Thanks for the quick reply, it helps a lot. Now I see how the model can fit on a TPU v4.
