Trying to import jax raises ImportError #7796

Closed
raghuramshankar opened this issue Sep 3, 2021 · 12 comments
Labels
bug Something isn't working

Comments

@raghuramshankar

Trying to import JAX raises an ImportError:

import jax

This gives me the error:

File "/home/user/.local/lib/python3.9/site-packages/jax/_src/api.py", line 52, in <module>
    from ..tree_util import (tree_map, tree_flatten, tree_unflatten, tree_structure,
ImportError: cannot import name 'Partial' from 'jax.tree_util' (/home/user/.local/lib/python3.9/site-packages/jax/tree_util.py)

Python and jax versions:

python --version
Python 3.9.6

pip show jax 
Name: jax
Version: 0.2.19

My system info:
Manjaro Linux (Arch-based) with the latest kernel:

uname -r
5.13.12-1-MANJARO

Adding Partial to the import list at line 40 of /home/user/.local/lib/python3.9/site-packages/jax/tree_util.py solved the issue for me, and I can now import jax successfully.
I also noticed that partial (lowercase) is already in that list, but replacing partial with Partial raises further errors for me.

@raghuramshankar raghuramshankar added the bug Something isn't working label Sep 3, 2021
@jakevdp
Collaborator

jakevdp commented Sep 3, 2021

Hi - thanks for the report!

I'm a bit confused about what's going on: you mention that you're using JAX 0.2.19 and that you added Partial at line 40 of jax/tree_util.py, but in the 0.2.19 source tree Partial is already present at that line: https://github.com/google/jax/blob/jax-v0.2.19/jax/tree_util.py#L40

I would do two things: first, check the output of this:

$ python -m pip show jax

Calling pip this way ensures you're using the pip associated with your python executable, which is not guaranteed if you call pip alone.

If that also shows jax v0.2.19, then I'd assume your installation has somehow been corrupted. In that case I'd run:

$ python -m pip uninstall jax jaxlib
$ python -m pip install "jax[cpu]"

(or the equivalent GPU installation command if relevant)
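To answer the same question as python -m pip show jax from inside Python (which interpreter is running, and which jax it would actually load), a small standard-library check like the one below can help. It is a sketch assuming Python 3.8+ (for importlib.metadata); the helper name describe_jax_install is hypothetical. The reported location is useful for spotting a jax that is being picked up from ~/.local rather than from the environment you expect.

```python
import importlib.util
import sys
from importlib import metadata


def describe_jax_install():
    """Report which interpreter is running and which jax it would use.

    Returns (python_executable, jax_version, jaxlib_version, jax_location);
    any piece that cannot be found is None instead of raising.
    """
    versions = {}
    for dist in ("jax", "jaxlib"):
        try:
            versions[dist] = metadata.version(dist)
        except metadata.PackageNotFoundError:
            versions[dist] = None
    # find_spec locates the top-level package without importing it,
    # so this still works when `import jax` itself fails.
    spec = importlib.util.find_spec("jax")
    location = spec.origin if spec is not None else None
    return sys.executable, versions["jax"], versions["jaxlib"], location


if __name__ == "__main__":
    exe, jax_version, jaxlib_version, location = describe_jax_install()
    print("python:", exe)
    print("jax:   ", jax_version, "at", location)
    print("jaxlib:", jaxlib_version)
```

If the printed python path or jax location is not the one you expected, the install commands above are probably going to a different environment than the one running your code.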

@jakevdp jakevdp self-assigned this Sep 3, 2021
@ioannis12

ioannis12 commented Sep 3, 2021

Hello,
I think I have the same problem: when I try to import jax, I get the error message "kernel died, restarting kernel".
I use Spyder 4.1.0.

I tried uninstalling and reinstalling different versions, but it doesn't seem to help.

@raghuramshankar
Author

Hi,

Thank you for your reply.
I am using JAX as a solver for PyBaMM simulations. I just found that the issue was with the installation of the PyBaMM library, not with JAX. Here is the issue that was raised there about the same problem.
Sorry for not catching this earlier; I did not expect that to be the cause.
Thanks!

Regards
Raghuram

@jakevdp
Collaborator

jakevdp commented Sep 13, 2021

Thanks! I'll close the issue then. Feel free to open another if you run into any other problems!

@jakevdp jakevdp closed this as completed Sep 13, 2021
@RylanSchaeffer

I received the same error from jax 0.2.21:

  File "/om2/user/rylansch/FieteLab-RIBP/ribp/lib/python3.7/site-packages/numpyro/infer/hmc.py", line 8, in <module>
    from jax import device_put, lax, partial, random, vmap
ImportError: cannot import name 'partial' from 'jax' (/om2/user/rylansch/FieteLab-RIBP/ribp/lib/python3.7/site-packages/jax/__init__.py)
python-BaseException

@RylanSchaeffer

I think the problem is with numpyro (pyro-ppl/numpyro#1165)

@jakevdp
Collaborator

jakevdp commented Oct 7, 2021

jax.partial was removed in version 0.2.21 (see the Changelog). You should import functools.partial instead, or install version 0.2.20 if you depend on other packages that attempt to import partial from the jax namespace.
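Since functools.partial is the suggested replacement, the fix in your own code is usually just the import line; call sites stay the same. A minimal sketch follows; the scale and double names are hypothetical and only illustrate partial application. The commented decorator shows the common @partial(jax.jit, ...) pattern, which is written the same way with functools.partial.

```python
# Before (fails with jax >= 0.2.21):
#   from jax import partial
# After: import the standard-library version instead.
from functools import partial


def scale(x, factor):
    # Hypothetical function used only to illustrate partial application.
    return x * factor


# Fix the `factor` keyword argument, leaving `x` to be supplied later.
double = partial(scale, factor=2)

# The same import works for the usual JAX decorator pattern, e.g.:
#   @partial(jax.jit, static_argnums=(1,))
#   def f(x, n): ...

if __name__ == "__main__":
    print(double(21))
```

For third-party packages that still do from jax import partial (such as the numpyro case above), editing their source is not a good fix; updating the package or pinning jax 0.2.20 is the safer route.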

@yukiweng62

I received a similar error with jax 0.4.11. Could you help me?

File C:\ProgramData\anaconda3\Lib\site-packages\flax\linen\linear.py:30
from jax import ShapedArray

ImportError: cannot import name 'ShapedArray' from 'jax' (C:\ProgramData\anaconda3\Lib\site-packages\jax\__init__.py)

@jakevdp
Collaborator

jakevdp commented Sep 20, 2023

jax.ShapedArray was removed in JAX v0.4.14 after being deprecated for several releases (see the JAX v0.4.14 changelog).

I would suggest updating flax to a more recent version that is compatible with recent JAX versions.
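Before choosing between upgrading flax and pinning an older JAX, it can help to check whether the installed jax is already at or past the release that removed a given name. The sketch below encodes only the removal versions stated in this thread (jax.partial in 0.2.21, jax.ShapedArray in 0.4.14, jax.abstract_arrays in 0.4.17). The helper names are hypothetical, and the version parsing is a simplification that keeps only the leading dotted numbers (so "0.4.14rc1" is treated as 0.4.14); it is not a full PEP 440 parser.

```python
import re
from importlib import metadata

# Removal versions as stated in the comments of this issue.
REMOVED_IN = {
    "jax.partial": "0.2.21",
    "jax.ShapedArray": "0.4.14",
    "jax.abstract_arrays": "0.4.17",
}


def parse_release(version):
    """Turn '0.4.14' (or '0.4.14rc1', '0.4.14.dev1') into (0, 4, 14).

    Only the leading dotted-integer part is kept (a simplification).
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def removed_names(installed):
    """Return, sorted, the names from REMOVED_IN that are gone in `installed`."""
    current = parse_release(installed)
    return sorted(
        name
        for name, removed in REMOVED_IN.items()
        if current >= parse_release(removed)
    )


if __name__ == "__main__":
    try:
        installed = metadata.version("jax")
    except metadata.PackageNotFoundError:
        print("jax is not installed")
    else:
        print(installed, removed_names(installed))
```

If the output lists jax.ShapedArray, any flax release that still does from jax import ShapedArray will fail with the error above.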

@yukiweng62

Thanks a lot!!!

@aiyb1314

ImportError: cannot import name 'abstract_arrays' from 'jax' (/home/xuj/anaconda3/envs/removal/lib/python3.10/site-packages/jax/__init__.py)
How can I solve this?

@jakevdp
Collaborator

jakevdp commented Dec 18, 2023

jax.abstract_arrays was deprecated in JAX v0.4.12 and removed in JAX v0.4.17.

To fix the error you're seeing, avoid importing jax.abstract_arrays. If you're using another package that imports this deprecated module, you could try updating that other package or, if that's not possible, install JAX v0.4.16 or older.
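To confirm whether the installed jax still provides a name that one of your dependencies imports, a quick probe can mirror what from module import name would do, without raising. This is a sketch using only the standard library; can_import is a hypothetical helper. Note that a successful probe really imports the module (and, for submodules such as jax.abstract_arrays, executes it).

```python
import importlib


def can_import(module_name, attr=None):
    """Return True if `module_name` imports (and, if given, provides `attr`).

    Approximates `import module_name` / `from module_name import attr`
    but returns False instead of raising ImportError.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return False
    if attr is None:
        return True
    if hasattr(module, attr):
        return True
    # `from pkg import sub` also succeeds when `sub` is an importable submodule.
    try:
        importlib.import_module(f"{module_name}.{attr}")
    except ImportError:
        return False
    return True
```

For example, can_import("jax", "abstract_arrays") should return False on JAX v0.4.17 or newer, which tells you the package doing that import needs updating (or JAX needs pinning).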


6 participants