gencast_mini_demo.ipynb on AMD CPU #113
Comments
Hey,

This looks like a splash attention related error. Splash attention is only supported on TPU. You can try following the GPU instructions to change the attention mechanism; I believe this should work fine on CPU. Note that without knowing the memory specifications of your device, I can't guarantee it won't run out of memory. We've also never run GenCast on CPU, so we cannot make any guarantees about its correctness. Hope that helps!

Andrew
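For readers landing here: the GPU instructions amount to overriding the sparse transformer's attention settings before building the model. Below is a sketch of that kind of override using a stand-in dataclass; the field names and values mirror my recollection of the repo's config and are assumptions, not the verified GenCast API.

```python
import dataclasses

# Stand-in for the sparse transformer config; in the real repo this lives
# inside the checkpoint's denoiser architecture config. Field names here
# are assumptions for illustration only.
@dataclasses.dataclass(frozen=True)
class SparseTransformerConfig:
    attention_type: str = "splash_mha"  # TPU-only splash attention
    mask_type: str = "lazy"

cfg = SparseTransformerConfig()
# Off-TPU, swap in a standard masked MHA over a full mask.
cfg = dataclasses.replace(
    cfg, attention_type="triblockdiag_mha", mask_type="full"
)
print(cfg.attention_type, cfg.mask_type)
```

The `dataclasses.replace` pattern keeps the rest of the checkpointed config untouched while changing only the attention fields.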
I will try your suggestion and report back here.
I followed the suggestion in the "Running Inference on GPU" section of cloud_vm_setup.md (e.g. `task_config = ckpt.task_config`). The job (4 time steps and 8 members) ran for about 2h30m, using 17 GB of system RAM at an average CPU load of ~30 (I have 48 cores). Unfortunately, the results are all NaN:

GenCast/graphcast/GenCast/lib/python3.12/site-packages/numpy/lib/_nanfunctions_impl.py:1409: RuntimeWarning: All-NaN slice encountered
I can't say I've seen this warning before. Could you confirm whether the entire forecast was NaN? Note that we expect NaNs in the sea surface temperature variable, so I wonder if this is what you might be encountering.
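One quick way to answer that question is to measure what fraction of a flattened field is NaN before plotting: sea surface temperature is expected to be NaN over land, whereas an atmospheric field should be almost entirely finite. A minimal pure-Python sketch; `nan_fraction` is a hypothetical helper, not part of the repo.

```python
import math

def nan_fraction(values):
    """Return the fraction of NaNs in a flat sequence of floats."""
    flat = list(values)
    return sum(math.isnan(v) for v in flat) / len(flat)

# Toy stand-in for one flattened forecast field: 4 NaNs, 2 finite values.
field = [float("nan")] * 4 + [281.5, 280.9]
print(round(nan_fraction(field), 2))  # 0.67
```

A value near 1.0 for a field like 2m temperature would confirm the whole forecast is NaN rather than just a masked variable.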
I was plotting 2m_temp for all 8 ensemble members. All members had this same warning. I'll need to run it again to view other variables.
Specific humidity at 850 and 100, vertical speed at 850, geopotential at 500, and the u and v components of wind at 925 are also NaN. I did not look at the rest.
Any more ideas on how to investigate this issue? |
Unfortunately, we've never attempted to run the model on a CPU, as this is too slow for practical use. In principle there should be no reason why it should differ, but unexpected device-specific compilation issues may be manifesting here. In the meantime, hopefully the instructions on how to use free cloud compute are useful. Do let us know if you gain any insights into why this is happening.
If you've never attempted to run it on a decent CPU, then how do you know it won't be practical?
I also think it would be nice to be able to set up the model config and run it for one timestep on our own CPU systems, and then move it to a cloud GPU or TPU; CPU systems have very large RAM nowadays. I set this in the notebook, but if the CPU count is greater than 1, I get an AssertionError.
In the 'build jitted' section:
@andrewlkd Maybe #108 can be of some use; however, I admittedly don't understand how JAX is handling the CPUs here. When the CPU device count is set to 1, it uses all the CPUs anyway.
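For context, the usual way to make JAX expose multiple logical CPU devices (e.g. so a pmapped rollout can shard an ensemble across them) is the `--xla_force_host_platform_device_count` XLA flag, which must be set before JAX is first imported. A sketch, with the caveat that whether GenCast's pmapped rollout then runs cleanly on CPU devices is untested:

```python
import os

# Append the flag rather than clobbering any existing XLA_FLAGS.
# NOTE: this must run before `import jax`; once the CPU backend has
# initialized, it keeps its default single device.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_force_host_platform_device_count=8"
).strip()

# import jax
# jax.local_device_count()  # should now report 8 on the CPU backend

print(os.environ["XLA_FLAGS"])
```

With a single logical device, JAX's CPU backend still parallelizes individual ops across cores internally, which is consistent with seeing all CPUs busy even when the device count is 1.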
Results from debugging so far are attached. I put a breakpoint in the function chunked_prediction_generator() from rollout.py, before predictor_fn(). I then printed out some variables looking for NaNs, then hit continue. The stack trace is in the attached text file. Please review and let me know if this helps shed any light on how the NaNs are being generated.
Hm, I'm not so sure this does shed light. It just suggests something in the actual predictor function (i.e. the forward pass of GenCast) is producing NaNs when running on CPU. In case it was something to do with the pmapping, I tried on my end to run in the non-pmapped case, and it still produces NaNs. Let me know if you get any more data points from debugging.
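Since the NaNs appear somewhere inside the forward pass, one low-tech way to localize them is to check every intermediate you can reach and fail at the first offending stage. (JAX also has `jax.config.update("jax_debug_nans", True)`, which makes computations error out near where the first NaN is produced, though it can be slow.) A hypothetical pure-Python sketch of the stage-check idea; the helper and stage names are illustrative, not from the repo:

```python
import math

def check_nans(name, values):
    """Raise as soon as a pipeline stage emits a NaN, so the first
    offending step is identified instead of the final output."""
    bad = [i for i, v in enumerate(values) if math.isnan(v)]
    if bad:
        raise ValueError(f"{name}: NaN at indices {bad[:5]} ({len(bad)} total)")
    return values

# Illustrative stages standing in for inputs -> normalized -> predictions.
stages = [("inputs", [1.0, 2.0]), ("normalized", [0.5, float("nan")])]
for name, vals in stages:
    try:
        check_nans(name, vals)
    except ValueError as err:
        print(err)
        break
```

Applied between normalization, the denoiser, and the sampler loop, this would at least narrow down which stage first goes non-finite on CPU.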
I'm attempting to run the gencast_mini_demo.ipynb case on my home workstation without a GPU. The notebook recognizes that I don't have the correct software to run on the installed GPU and falls back to CPU (which is what I want to happen).
Output from cell 22:
WARNING:2024-12-21 14:22:21,184:jax._src.xla_bridge:969: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
I've attached the stack trace I get from cell 23 (Autoregressive rollout (loop in python)).
gencast.failure.txt
Is this expected? Does GenCast require a GPU or TPU to work?