JAX problems #89

Closed
agbruno-git opened this issue Aug 1, 2024 · 2 comments

Comments

@agbruno-git

I am running the model on my own data read from ERA5. The dataset I feed to the model looks identical to the one stored in the Google bucket.
The model worked fine when I ran it locally with the data used in the demo, but I am having problems with JAX when I run it on my own data.

I got this error message:

TypeError: Argument 'dask.array<getitem, shape=(1, 2, 37, 721, 1440), dtype=float32, chunksize=(1, 2, 37, 721, 1440), chunktype=numpy.ndarray>' of type <class 'dask.array.core.Array'> is not a valid JAX type.

Has anyone had the same problem?

@alvarosg
Collaborator

alvarosg commented Aug 6, 2024

I think this may be fixed by calling ".compute()" on all of the data (inputs, targets and forcings) before passing it to the model.
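
For anyone landing here later, a minimal sketch of that suggestion, assuming the data are xarray Datasets backed by dask (which is what the `dask.array` in the error message suggests). The `materialize` helper and the tiny `temperature` dataset are hypothetical stand-ins for illustration, not part of the project's code; `.compute()` itself is the standard xarray method.

```python
import numpy as np
import xarray as xr  # the .chunk() call below also requires dask to be installed


def materialize(*datasets):
    """Hypothetical helper: return copies of the given xarray objects
    with every lazy (dask-backed) array loaded into memory."""
    return tuple(ds.compute() for ds in datasets)


# Hypothetical stand-in for ERA5-derived data: a small dask-backed dataset.
lazy_inputs = xr.Dataset(
    {
        "temperature": (
            ("batch", "time", "level"),
            np.zeros((1, 2, 37), dtype=np.float32),
        )
    }
).chunk({"time": 1})

(inputs,) = materialize(lazy_inputs)

# Before .compute() the variable wraps a dask array; afterwards it wraps
# a plain numpy.ndarray, which JAX accepts as input.
print(type(lazy_inputs["temperature"].data))
print(type(inputs["temperature"].data))
```

With real data the same call would cover all three objects, e.g. `inputs, targets, forcings = materialize(inputs, targets, forcings)`, before they go to the model. Note that `.compute()` loads everything into memory: a single float32 array of the shape in the error message, (1, 2, 37, 721, 1440), is roughly 300 MB.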

@Zappandy

@alvarosg I can confirm this was fixed by calling compute before passing it to the predictor object. Feel free to close the issue.


3 participants