xarray_jax does not support jax.jit().lower #98

Open
csubich opened this issue Sep 23, 2024 · 3 comments

@csubich
csubich commented Sep 23, 2024

The JAX API now includes more detailed control over the compilation process with jax.stages, but the xarray_jax wrapper here in graphcast does not seem to support jax.jit().lower:

import graphcast.xarray_jax as xarray_jax
import jax.numpy as jnp
import jax

def ident(a): # Trivial test function
    return a

# Sample variables
foo = jnp.ones(3)
foo_xr = xarray_jax.DataArray(foo)

print(jax.jit(ident)(foo)) # Works
# [1. 1. 1.]

print(jax.jit(ident)(foo_xr)) # Works
# <xarray.DataArray (dim_0: 3)>
# xarray_jax.JaxArrayWrapper(Array([1., 1., 1.], dtype=float32))
# Dimensions without coordinates: dim_0

jax.jit(ident).lower(foo) # Works
# <jax._src.stages.Lowered at 0x151bb5e04830>

jax.jit(ident).lower(foo_xr) # Fails
Traceback
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[13], line 1
----> 1 jax.jit(ident).lower(foo_xr) # Fails

    [... skipping hidden 5 frame]

File /fs/site5/eccc/mrd/rpnatm/csu001/ppp5/graphcast_dev/graphcast/xarray_jax.py:668, in _unflatten_variable(aux, children)
    666 dims_change_fn = _DIMS_CHANGE_ON_UNFLATTEN_FN.get(None)
    667 if dims_change_fn: dims = dims_change_fn(dims)
--> 668 return Variable(dims=dims, data=children[0])

File /fs/site5/eccc/mrd/rpnatm/csu001/ppp5/graphcast_dev/graphcast/xarray_jax.py:113, in Variable(dims, data, **kwargs)
    111 def Variable(dims, data, **kwargs) -> xarray.Variable:  # pylint:disable=invalid-name
    112   """Like xarray.Variable, but can wrap JAX arrays."""
--> 113   return xarray.Variable(dims, wrap(data), **kwargs)

File ~/data/ppp5/conda_env/gforecast_test/lib/python3.11/site-packages/xarray/core/variable.py:365, in Variable.__init__(self, dims, data, attrs, encoding, fastpath)
    338 def __init__(
    339     self,
    340     dims,
   (...)
    344     fastpath=False,
    345 ):
    346     """
    347     Parameters
    348     ----------
   (...)
    363         unrecognized encoding items.
    364     """
--> 365     super().__init__(
    366         dims=dims, data=as_compatible_data(data, fastpath=fastpath), attrs=attrs
    367     )
    369     self._encoding = None
    370     if encoding is not None:

File ~/data/ppp5/conda_env/gforecast_test/lib/python3.11/site-packages/xarray/namedarray/core.py:253, in NamedArray.__init__(self, dims, data, attrs)
    246 def __init__(
    247     self,
    248     dims: _DimsLike,
    249     data: duckarray[Any, _DType_co],
    250     attrs: _AttrsLike = None,
    251 ):
    252     self._data = data
--> 253     self._dims = self._parse_dimensions(dims)
    254     self._attrs = dict(attrs) if attrs else None

File ~/data/ppp5/conda_env/gforecast_test/lib/python3.11/site-packages/xarray/namedarray/core.py:481, in NamedArray._parse_dimensions(self, dims)
    479 dims = (dims,) if isinstance(dims, str) else tuple(dims)
    480 if len(dims) != self.ndim:
--> 481     raise ValueError(
    482         f"dimensions {dims} must have the same length as the "
    483         f"number of data dimensions, ndim={self.ndim}"
    484     )
    485 if len(set(dims)) < len(dims):
    486     repeated_dims = set([d for d in dims if dims.count(d) > 1])

ValueError: dimensions ('dim_0',) must have the same length as the number of data dimensions, ndim=0
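One plausible reading of this traceback, which is an assumption and not confirmed in the thread: during `lower`, JAX rebuilds the argument pytree with non-array placeholder leaves for its own bookkeeping. `_unflatten_variable` then hands such a placeholder to `xarray.Variable`, which treats it as a 0-d array and rejects the `('dim_0',)` dims. JAX's pytree documentation warns that custom unflatten functions may receive arbitrary objects (such as `object()` or `None`) as leaves. The sketch below uses hypothetical `StrictBox`/`LenientBox` classes, not graphcast code, to contrast an unflatten that validates its leaves with one that does not.

```python
import jax.numpy as jnp
from jax import tree_util


class StrictBox:
    """Hypothetical pytree whose constructor validates its leaf (like xarray)."""

    def __init__(self, data):
        if getattr(data, "ndim", 0) != 1:
            raise ValueError("expected a 1-D array leaf")
        self.data = data


class LenientBox(StrictBox):
    """Same validating constructor, but registered with a non-validating unflatten."""


def _flatten(box):
    return (box.data,), None  # (leaves, aux data)


def _unflatten_strict(aux, children):
    return StrictBox(children[0])  # runs validation on whatever leaf arrives


def _unflatten_lenient(aux, children):
    obj = object.__new__(LenientBox)  # bypass __init__: no checks on the leaf
    obj.data = children[0]
    return obj


tree_util.register_pytree_node(StrictBox, _flatten, _unflatten_strict)
tree_util.register_pytree_node(LenientBox, _flatten, _unflatten_lenient)

x = jnp.ones(3)
strict_def = tree_util.tree_structure(StrictBox(x))
lenient_def = tree_util.tree_structure(LenientBox(x))

# Unflatten with a non-array placeholder leaf, as JAX internals may do.
placeholder = object()
try:
    tree_util.tree_unflatten(strict_def, [placeholder])
    strict_ok = True
except ValueError:
    strict_ok = False

lenient = tree_util.tree_unflatten(lenient_def, [placeholder])
lenient_ok = lenient.data is placeholder
print(strict_ok, lenient_ok)
```

The design choice illustrated: keep validation in the public constructor, but let the registered unflatten function rebuild the object without re-checking its leaves.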

If the DataArray is instead created inside the jitted function, lowering seems to work:

def make_xr(a):
    return xarray_jax.DataArray(a)

def compose(a):
    return ident(make_xr(a))

print(jax.jit(compose).lower(foo).compile()(foo)) # Works
# <xarray.DataArray (dim_0: 3)>
# xarray_jax.JaxArrayWrapper(Array([1., 1., 1.], dtype=float32))
# Dimensions without coordinates: dim_0
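For reference, a minimal sketch of the `jax.stages` pipeline used above (lower, then compile, then call), on plain arrays rather than xarray objects. It also shows that `lower` accepts `jax.ShapeDtypeStruct` placeholders, so the lowering step needs only shapes and dtypes, and that `Lowered.as_text()` exposes the lowered module as text.

```python
import jax
import jax.numpy as jnp


def ident(a):  # Same trivial test function as above
    return a


foo = jnp.ones(3)
jitted = jax.jit(ident)

# Lower from an abstract shape/dtype description; no concrete data needed.
spec = jax.ShapeDtypeStruct(foo.shape, foo.dtype)
lowered = jitted.lower(spec)   # jax.stages.Lowered
hlo_text = lowered.as_text()   # text of the lowered module
compiled = lowered.compile()   # jax.stages.Compiled

# The compiled executable is called with concrete arrays of the same shape/dtype.
result = compiled(foo)
print(result)
```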

I'm not yet sure whether decomposing xarray arguments into a more pytree-friendly form (plain arrays), only to rebuild them inside a jitted wrapper, is a general solution, or whether doing so with graphcast would just expose another error further in.
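The "explode and recreate" idea can be sketched generically with `jax.tree_util`. In this sketch, `Labeled` is a hypothetical stand-in for an xarray object and `lower_via_leaves` is a hypothetical helper, not part of graphcast or JAX. It flattens the arguments outside `jit`, lowers on the plain array leaves, rebuilds the objects from tracers inside the traced function, and flattens the outputs again so they are only rebuilt from concrete arrays after the call. Whether this carries over to graphcast's models is not established here.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


class Labeled:
    """Hypothetical stand-in for an xarray object that validates its data."""

    def __init__(self, data, dim):
        # Like xarray, reject data whose rank does not match the one dim name.
        if getattr(data, "ndim", 0) != 1:
            raise ValueError(f"dimension {dim!r} needs 1-D data")
        self.data = data
        self.dim = dim


tree_util.register_pytree_node(
    Labeled,
    lambda x: ((x.data,), x.dim),  # (array leaves, static aux data)
    lambda dim, children: Labeled(children[0], dim),
)


def lower_via_leaves(fn, *args):
    """Hypothetical helper: AOT-lower and compile `fn` on flattened arguments."""
    in_leaves, in_def = tree_util.tree_flatten(args)
    out_defs = []  # output structure, recorded while tracing

    def flat_fn(*leaves):
        # Rebuilt from tracers, so the leaves really are arrays here.
        out = fn(*tree_util.tree_unflatten(in_def, leaves))
        out_leaves, out_def = tree_util.tree_flatten(out)
        out_defs.append(out_def)
        return out_leaves

    compiled = jax.jit(flat_fn).lower(*in_leaves).compile()

    def call(*new_args):
        new_leaves, new_def = tree_util.tree_flatten(new_args)
        if new_def != in_def:
            raise TypeError("argument structure differs from the lowered one")
        # Rebuild output objects from concrete arrays after the call.
        return tree_util.tree_unflatten(out_defs[-1], compiled(*new_leaves))

    return call


def ident(a):  # Trivial test function, as above
    return a


foo_boxed = Labeled(jnp.ones(3), "dim_0")
run = lower_via_leaves(ident, foo_boxed)
out = run(foo_boxed)
print(type(out).__name__, out.dim, out.data)
```

Because JAX only ever sees plain arrays at the `jit` boundary, no custom unflatten is invoked with placeholder leaves during lowering. The cost is that the structure check and the output rebuild happen in Python on every call.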

@allen-adastra

Self-plug here, but there is now a separate xarray_jax package on PyPI. I just tested it, and this works :)
https://github.com/allen-adastra/xarray_jax

import jax
import jax.numpy as jnp
import xarray as xr
import xarray_jax  # standalone PyPI package; assumed here to register xarray types with JAX on import

def ident(a):  # Trivial test function
    return a

# Sample variables
foo = jnp.ones(3)
foo_xr = xr.DataArray(foo)

print(jax.jit(ident)(foo))  # Works
# [1. 1. 1.]

print(jax.jit(ident)(foo_xr))  # Works
# <xarray.DataArray (dim_0: 3)>
# xarray_jax.JaxArrayWrapper(Array([1., 1., 1.], dtype=float32))
# Dimensions without coordinates: dim_0

jax.jit(ident).lower(foo)  # Works
# <jax._src.stages.Lowered at 0x151bb5e04830>

jax.jit(ident).lower(foo_xr)  # Works with the standalone package

@csubich
csubich commented Oct 16, 2024

That's potentially useful. How compatible is your package with graphcast?

@allen-adastra

I don't know. It does handle things differently.
