
Problem when running gencast_mini_demo Colab: XlaRuntimeError: Failed to deserialize the Mosaic module #132

Open
AndrzejP-RE opened this issue Jan 28, 2025 · 2 comments

Comments

@AndrzejP-RE

Hi,

I am trying to run the provided Colab notebook for GenCast Mini and encountered an issue when running the prediction cell (Autoregressive rollout (loop in python)). Below are the details:

Issue:
When executing the cell for autoregressive rollout, I encounter an error:

```
XlaRuntimeError: INTERNAL: Failed to deserialize the Mosaic module
```

Steps Taken:

  1. I ran the provided notebook in Google Colab without modifying it.
  2. Tried running the notebook on a v2-8 TPU, and also bought a Colab subscription and ran it on a v5e-1 TPU, with the same result.
  3. Tried both the model loaded from a checkpoint and the default randomly initialized model, and got the same error each time.

Traceback:

```
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
<ipython-input-20-e54246abb142> in <cell line: 0>()
     14 
     15 chunks = []
---> 16 for chunk in rollout.chunked_prediction_generator_multiple_runs(
     17     # Use pmapped version to parallelise across devices.
     18     predictor_fn=run_forward_pmap,

/usr/local/lib/python3.11/dist-packages/graphcast/rollout.py in chunked_prediction_generator_multiple_runs(predictor_fn, rngs, inputs, targets_template, forcings, num_samples, pmap_devices, **chunked_prediction_kwargs)
    161         sample_forcings = None
    162 
--> 163       for prediction_chunk in chunked_prediction_generator(
    164           predictor_fn=predictor_fn_pmap_named_args,
    165           rng=sample_group_rngs,

/usr/local/lib/python3.11/dist-packages/graphcast/rollout.py in chunked_prediction_generator(predictor_fn, rng, inputs, targets_template, forcings, num_steps_per_chunk, verbose, pmap_devices)
    343     # Make predictions for the chunk.
    344     rng, this_rng = split_rng_fn(rng)
--> 345     predictions = predictor_fn(
    346         rng=this_rng,
    347         inputs=current_inputs,

/usr/local/lib/python3.11/dist-packages/graphcast/rollout.py in predictor_fn_pmap_named_args(rng, inputs, targets_template, forcings)
    119           devices=pmap_devices,
    120       )
--> 121       return predictor_fn(rng, inputs, targets_template, forcings)
    122 
    123     for i in range(0, num_samples, len(pmap_devices)):

/usr/local/lib/python3.11/dist-packages/graphcast/xarray_jax.py in result_fn(*args)
    596     nonlocal input_treedef
    597     flat_args, input_treedef = jax.tree_util.tree_flatten(args)
--> 598     flat_result = pmapped_fn(*flat_args)
    599     assert output_treedef is not None
    600     # After the pmap an extra leading axis will be present, we need to add an

    [... skipping hidden 12 frame]

/usr/local/lib/python3.11/dist-packages/jax/_src/compiler.py in backend_compile(backend, module, options, host_callbacks)
    264     # TODO(sharadmv): remove this fallback when all backends allow `compile`
    265     # to take in `host_callbacks`
--> 266     return backend.compile(built_c, compile_options=options)
    267   except xc.XlaRuntimeError as e:
    268     for error_handler in _XLA_RUNTIME_ERROR_HANDLERS:

XlaRuntimeError: INTERNAL: Failed to deserialize the Mosaic module
```
@andrewlkd
Collaborator

Hey,

Unfortunately, this looks like an issue related to the Pallas code in splash_attention. I'll look into this and get back to you.

In the meantime, switching to triblockdiag_mha should unblock you if you'd like to press on. There are instructions on how to do this here.
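A minimal sketch of the switch, assuming the loaded checkpoint's config exposes `denoiser_architecture_config.sparse_transformer_config` with `attention_type` and `mask_type` fields. The dataclasses and the helper `use_triblockdiag_attention` below are hypothetical stand-ins so the snippet runs without `graphcast`. The default values `"splash_mha"` and `"lazy"` and the pairing of `triblockdiag_mha` with a `"full"` mask are also assumptions. In the notebook, the change would have to be made before `run_forward_pmap` is built, so that the new attention type is used when the model is compiled.

```python
from dataclasses import dataclass, field


# Hypothetical stand-ins for the checkpoint's config objects. In the demo these
# would come from the loaded GenCast checkpoint instead.
@dataclass
class SparseTransformerConfig:
    attention_type: str = "splash_mha"  # assumed default: the Pallas/Mosaic splash attention path
    mask_type: str = "lazy"             # assumed default mask paired with splash attention


@dataclass
class DenoiserArchitectureConfig:
    sparse_transformer_config: SparseTransformerConfig = field(
        default_factory=SparseTransformerConfig)


def use_triblockdiag_attention(cfg: DenoiserArchitectureConfig) -> DenoiserArchitectureConfig:
    """Hypothetical helper: swap splash attention for triblockdiag_mha in place."""
    cfg.sparse_transformer_config.attention_type = "triblockdiag_mha"
    # Assumption: triblockdiag_mha is used with a dense ("full") mask.
    cfg.sparse_transformer_config.mask_type = "full"
    return cfg


cfg = use_triblockdiag_attention(DenoiserArchitectureConfig())
print(cfg.sparse_transformer_config)
# SparseTransformerConfig(attention_type='triblockdiag_mha', mask_type='full')
```

Because triblockdiag_mha does not go through the splash attention kernel, the model should no longer compile the Pallas code that fails to deserialize here.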

-- Andrew

@AndrzejP-RE
Author

Thanks for your response! triblockdiag_mha resolves this issue.

Unfortunately, the runtimes available on Colab seem too weak to run the 0.25-degree model, which is a shame. It would be great to test the model's forecasts.

On a side-note, do you have a rough estimate for when the historical and real-time forecasts will be released? I know you mentioned "soon," but would that be closer to two weeks or two months?
