Slow Gencast Inference on GPUs after 1st run #106

Closed
v-weh opened this issue Dec 6, 2024 · 4 comments

Comments

v-weh commented Dec 6, 2024

Many thanks for making Gencast code and weights public!

I managed to tweak the code in “gencast_demo_cloud_vm.ipynb” and got it running on an 8-GPU (H100) cluster, generating forecasts up to 15 days ahead at 12-hour intervals with 8 ensemble members.

The first run took ~35 minutes, which is expected. However, the second run still took ~30-35 minutes. I'm not sure whether this is expected behaviour: I thought there was a fixed one-time cost on the first run only, and that further runs would take only ~8 minutes?

Or does that apply only when using TPUs, or only when I generate a single forecast (e.g. 15 days out) rather than the entire sequence?

@andrewlkd
Collaborator

Hello!

Apologies, the demo notebook implementations have an oversight here. You might notice that re-running the rollout cell (# @title Autoregressive rollout (loop in python)) triggers recompilation.

This is because xarray_jax.pmap(run_forward_jitted, dim="sample") should only be called once. To fix this, move the call out of the rollout cell into its own cell:

# %%
# Build the pmapped function once, in its own cell.
run_forward_pmap = xarray_jax.pmap(run_forward_jitted, dim="sample")

# %%
# New cell: re-running it reuses run_forward_pmap, so it no longer recompiles.
# ... code as before ...
for chunk in rollout.chunked_prediction_generator_multiple_runs(
    predictor_fn=run_forward_pmap,
    rngs=rngs,
    inputs=eval_inputs,
    targets_template=eval_targets * np.nan,
    forcings=eval_forcings,
    num_steps_per_chunk=1,
    num_samples=num_ensemble_members,
    pmap_devices=jax.local_devices(),
):
    chunks.append(chunk)
predictions = xarray.combine_by_coords(chunks)

Will send a fix to the repo ASAP, but thought I'd respond here first.

Thanks!

Andrew
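
The recompilation described above can be reproduced in plain JAX without the GenCast code. The sketch below is a minimal illustration, not the notebook's implementation: `step`, `make_wrapped` and `TRACE_COUNT` are hypothetical names, and `jax.jit` over a fresh closure stands in for `xarray_jax.pmap`, on the assumption that the wrapper builds a new function object each time it is called. JAX keys its compilation cache on the function object, so a freshly built wrapper is traced again while a reused one is not.

```python
import jax
import jax.numpy as jnp

TRACE_COUNT = 0  # incremented each time JAX traces `step`


def step(x):
    # Python side effects run only while JAX traces the function,
    # so this counter counts traces, not executions.
    global TRACE_COUNT
    TRACE_COUNT += 1
    return x * 2.0 + 1.0


def make_wrapped():
    # Hypothetical stand-in for a wrapper such as xarray_jax.pmap (assumption):
    # every call builds a new closure, hence a new jitted function
    # that starts with an empty cache.
    return jax.jit(lambda x: step(x))


x = jnp.ones((4,))

# Anti-pattern: rebuild the wrapper on every "cell run".
for _ in range(3):
    make_wrapped()(x).block_until_ready()
traces_when_rebuilt = TRACE_COUNT

# Fix: build the wrapper once, then reuse it.
TRACE_COUNT = 0
wrapped = make_wrapped()
for _ in range(3):
    wrapped(x).block_until_ready()
traces_when_reused = TRACE_COUNT

print(traces_when_rebuilt, traces_when_reused)
```

In this toy setup the rebuilt wrapper is traced on every iteration, while the reused one is traced only on its first call. The same reasoning is why keeping the pmapped function in a separate cell avoids the repeated compile cost.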

@v-weh
Author

v-weh commented Dec 9, 2024

Many thanks for the help! Just to confirm: does the 8 minutes stated in the paper refer to time taken to:

  1. generate only the single forecast 30 steps out, i.e. (+360h), or
  2. generate the full 30-step forecast with all the intermediate steps at 12-hour intervals, i.e. (+12h, +24h, +36h ... +360h)?

@andrewlkd
Collaborator

The latter (on a TPUv5 and without compilation/tracing costs).

(Note that since we produce our forecasts autoregressively, the time taken to generate the 30th step, i.e. your former option, is the same as the time to produce all the intermediate steps, since they are needed to feed back in as inputs!)
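
The point about autoregressive cost can be seen in a toy rollout. This is a sketch only: `rollout` and `toy_step` are hypothetical names, and the "state" is just the lead time in hours rather than a real atmospheric state. Each step consumes the previous step's output, so reaching step 30 requires 30 model calls whether or not the intermediate outputs are kept.

```python
def rollout(initial_state, step_fn, num_steps):
    """Autoregressive rollout: each output is fed back in as the next input."""
    state = initial_state
    trajectory = []
    for _ in range(num_steps):
        state = step_fn(state)  # depends on the previous step's output
        trajectory.append(state)
    return trajectory


calls = 0  # number of times the toy "model" has been evaluated


def toy_step(lead_time_hours):
    # Hypothetical stand-in for one 12-hour model step.
    global calls
    calls += 1
    return lead_time_hours + 12


trajectory = rollout(0, toy_step, num_steps=30)
final_only = trajectory[-1]  # the +360h forecast

print(calls, final_only, len(trajectory))
```

Keeping only `final_only` saves memory for the intermediate outputs but not model evaluations, since all 30 calls have already happened by the time it is available.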

@v-weh
Author

v-weh commented Dec 10, 2024

Thank you! I haven't yet seen an increase in run time tbh, but will re-run some experiments and see if I actually made mistakes, will get back to you!

v-weh closed this as completed on Jan 10, 2025