GenCast Ensemble GPU Parallelization #108

Closed
jacob-t-radford opened this issue Dec 8, 2024 · 3 comments

Comments

@jacob-t-radford

I have a few different cloud GPU partitions I'd like to use to run a GenCast ensemble. I was wondering whether running with num_ensemble_members set to 4 on one GPU is equivalent to running with num_ensemble_members set to 1 on four different GPUs (and then combining the outputs at the end). If not, do you have any suggestions on how I could run a small GenCast ensemble with a controller and multiple GPU partitions? Sorry if this question is a bit unclear; I'm just trying to figure out whether we can feasibly run a GenCast ensemble in real time using our existing resources.

@blackdooo

Yeah, sounds interesting. Running a GenCast ensemble across multiple GPUs by splitting num_ensemble_members should work in theory, but combining the outputs correctly might depend on model consistency. To parallelize efficiently, consider using a controller to manage GPU assignments dynamically. Clarifying how the ensemble outputs are synchronized could also help. Would be great to hear thoughts on potential bottlenecks!

@andrewlkd
Collaborator

This should definitely be possible with some care around rngs!

As per the demo notebook:

```python
import jax
import numpy as np

num_ensemble_members = 4  # illustrative value; set elsewhere in the notebook

rng = jax.random.PRNGKey(0)
# We fold-in the ensemble member, this way the first N members should always
# match across different runs which take the same inputs, regardless of
# total ensemble size.
rngs = np.stack(
    [jax.random.fold_in(rng, i) for i in range(num_ensemble_members)], axis=0)
```
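The property described in the notebook comment can be checked directly. This is a minimal sketch, assuming JAX and NumPy are installed and using the same legacy `jax.random.PRNGKey` keys as the snippet. `ensemble_rngs` is a hypothetical wrapper around the notebook's construction, not a GenCast function. It builds keys for a 2-member and a 4-member ensemble and confirms that the first two keys match, because each key depends only on the base key and the member index.

```python
import jax
import numpy as np


def ensemble_rngs(rng, num_ensemble_members):
    # Hypothetical wrapper mirroring the demo-notebook construction:
    # key for member i is fold_in(rng, i).
    return np.stack(
        [jax.random.fold_in(rng, i) for i in range(num_ensemble_members)], axis=0)


rng = jax.random.PRNGKey(0)
small = ensemble_rngs(rng, 2)  # keys for a 2-member ensemble
large = ensemble_rngs(rng, 4)  # keys for a 4-member ensemble

# The first two members of the larger ensemble reuse the smaller ensemble's keys.
print(np.array_equal(small, large[:2]))  # True
```

This only compares keys; whether the resulting forecasts match also depends on feeding each run the same inputs, as the notebook comment notes.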

I.e., using this code as it stands, 4 runs x 1 ensemble member is not equivalent to 1 run x 4 ensemble members: in the former, each run only folds in index 0, so all four runs get the same rng.

Otherwise, if you ensure each GPU partition gets the correct rng(s) (i.e. the partition running global member i uses jax.random.fold_in(rng, i)), then this is indeed equivalent.
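A minimal sketch of that per-partition assignment, under the same assumptions (JAX and NumPy available, legacy `PRNGKey` keys). `member_rngs` and `split_members` are hypothetical helpers, not part of GenCast. Each partition folds in its *global* member indices instead of `range(num_ensemble_members)`, and the controller concatenates the per-partition results in index order. The sketch compares rngs only; it does not run the model.

```python
import jax
import numpy as np


def member_rngs(base_rng, member_indices):
    # Same fold_in construction as the notebook, but over global member
    # indices rather than range(num_ensemble_members).
    return np.stack(
        [jax.random.fold_in(base_rng, int(i)) for i in member_indices], axis=0)


def split_members(num_members, num_partitions):
    # Hypothetical helper: contiguous, near-equal blocks of global member
    # indices, one block per GPU partition.
    return np.array_split(np.arange(num_members), num_partitions)


rng = jax.random.PRNGKey(0)
num_ensemble_members = 4

# Reference: one run on one GPU with all four members.
single_run = member_rngs(rng, range(num_ensemble_members))

# Four partitions, each responsible for one global member index.
assignments = split_members(num_ensemble_members, 4)
combined = np.concatenate([member_rngs(rng, idx) for idx in assignments], axis=0)

# Naive split: each partition runs the unmodified code with
# num_ensemble_members=1, so every partition folds in index 0.
naive = np.concatenate([member_rngs(rng, range(1)) for _ in range(4)], axis=0)

print(np.array_equal(single_run, combined))  # True
print(np.array_equal(single_run, naive))     # False
```

Contiguous blocks keep reassembly simple: concatenating partition results in partition order restores the global member order, and the same helpers handle uneven splits such as 5 members across 3 partitions.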

Hope this helps!

Andrew

@jacob-t-radford
Author

Awesome, thank you for the quick response and clarification!


3 participants