
I have installed jax and jaxlib using pip according to https://github.com/google/jax#installation:

Successfully installed jax-0.1.68 jaxlib-0.1.67+cuda11

But when I ran my project, it raised an ImportError:

Traceback (most recent call last):
  File "...", line 1, in <module>
    from jax import jit, jacfwd, jacrev, hessian, lax
  File "...", line 16, in <module>
    from .api import (
  File "...", line 38, in <module>
    from . import core
  File "...", line 30, in <module>
    from . import dtypes
  File "...", line 31, in <module>
    from .lib import xla_client
  File "...", line 51, in <module>
    from jaxlib import pytree
ImportError: cannot import name 'pytree' from 'jaxlib' (/home/control/.local/lib/python3.7/site-packages/jaxlib/__init__.py)

Could this problem come from incompatible jax/jaxlib versions for running the project? If not, how can I deal with it?

Samuel Leung

1 Answer


It appears that you are importing a much older jax version than you report in the question; jax.lib has not attempted to import pytree from jaxlib since version 0.2.8.

This probably indicates that you are running pip install in a different environment than the one you're using to execute code.

Assuming you are working at the command prompt, try this instead:

$ python -m pip install jax jaxlib
$ python -c "import jax; print(jax.__version__)"

(where you can replace python in both lines with whatever python executable you are using)
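To confirm which interpreter is actually running and where it would load jax and jaxlib from, a small diagnostic like the one below can be run with the same python you use for the project. It is a sketch using only the standard library (`sys.executable`, `importlib.util.find_spec`); the helper name `describe_module` is hypothetical. Comparing the printed locations with the path in the traceback (`/home/control/.local/lib/python3.7/site-packages/jaxlib/__init__.py`) shows whether that interpreter is picking up packages from the user site-packages directory rather than from the environment pip installed into.

```python
import importlib.util
import sys


def describe_module(name):
    """Return the file a top-level module would be loaded from, or None if missing.

    find_spec locates the module without importing it, so this still works
    when actually importing it would fail with an ImportError.
    """
    spec = importlib.util.find_spec(name)
    return None if spec is None else spec.origin


if __name__ == "__main__":
    # The interpreter running this script; pip must install into this one.
    print("interpreter:", sys.executable)
    for pkg in ("jax", "jaxlib"):
        print(f"{pkg}: {describe_module(pkg)}")
```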

If you're working in Jupyter with different kernels, this answer might help you understand how to proceed: Running Jupyter with multiple Python and IPython paths
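Inside a Jupyter kernel, one way to make sure packages land in the kernel's own environment is to invoke pip through `sys.executable`, which is the approach the pip documentation recommends when calling pip from a program. This is a minimal sketch; `pip_install_command` and `pip_install` are hypothetical helper names.

```python
import subprocess
import sys


def pip_install_command(packages, find_links=None):
    """Build a pip command that targets the interpreter running this code."""
    cmd = [sys.executable, "-m", "pip", "install", *packages]
    if find_links is not None:
        # Same as passing pip's -f/--find-links option on the command line.
        cmd += ["-f", find_links]
    return cmd


def pip_install(packages, find_links=None):
    """Run pip in the current interpreter's environment; raises CalledProcessError on failure."""
    subprocess.check_call(pip_install_command(packages, find_links))
```

For example, `pip_install(["jax==0.1.68", "jaxlib==0.1.50"], find_links="https://storage.googleapis.com/jax-releases/jax_releases.html")` installs the pair discussed in the comments into the running kernel's environment. Restart the kernel afterwards so the newly installed versions are the ones that get imported.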

jakevdp
  • Thanks for your answer and the heads-up. Yes, the versions I installed are the right ones, jax-0.1.68 and jaxlib-0.1.67+cuda11, since the project is built on them. – Samuel Leung Dec 13 '21 at 15:59
  • The author of that GitHub project told me that the project may not run with newer jax versions: https://github.com/YukunXia/Carla_iLQR_MPC/issues/3. I installed the new version of jax as you suggested, and it raises a TypeError: – Samuel Leung Dec 13 '21 at 16:00
  • Traceback (most recent call last): File "/home/control/Documents/Carla projects/Carla_iLQR_MPC/MPC/ilqr_jax_MPC.py", line 164, in jac_l, hes_l, jac_l_final, hes_l_final, jac_f = derivative_init() File "/home/control/.local/lib/python3.7/site-packages/jax/core.py", line 981, in concrete_aval raise TypeError(f"Value {repr(x)} with type {type(x)} is not a valid JAX " TypeError: Value .jacfun at 0x7ff1f81c07a0>> with type is not a valid JAX type – Samuel Leung Dec 13 '21 at 16:01
  • Could it be that the project really doesn't work with the new version of jax? Thanks for your time. – Samuel Leung Dec 13 '21 at 16:03
  • jax 0.1.68 is not compatible with jaxlib 0.1.67; the latter was released a full year after the former. You'll have to use something like jaxlib 0.1.50 which you can install with `pip install jax==0.1.68 jaxlib==0.1.50 -f https://storage.googleapis.com/jax-releases/jax_releases.html` – jakevdp Dec 13 '21 at 16:14
  • Thank you again. – Samuel Leung Dec 13 '21 at 22:27
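Because the project expects one specific jax/jaxlib pair, a small startup check can fail fast with a clear message instead of the ImportError shown in the question. This is a hedged sketch: the `EXPECTED` pins are taken from jakevdp's comment and are not an official compatibility table, `check_versions` is a hypothetical helper, and it relies on `importlib.metadata` (Python 3.8+; on Python 3.7, as in this question, the `importlib_metadata` backport from PyPI provides the same `version` function and `PackageNotFoundError`).

```python
try:
    from importlib.metadata import PackageNotFoundError, version
except ImportError:  # Python 3.7: requires `pip install importlib_metadata`
    from importlib_metadata import PackageNotFoundError, version

# Pins taken from the comment thread; not an official compatibility table.
EXPECTED = {"jax": "0.1.68", "jaxlib": "0.1.50"}


def check_versions(expected=EXPECTED, lookup=version):
    """Return a list of problems; an empty list means every pin matches.

    Local version labels such as "+cuda11" are ignored, so "0.1.50+cuda110"
    counts as matching "0.1.50".
    """
    problems = []
    for pkg, want in expected.items():
        try:
            have = lookup(pkg)
        except PackageNotFoundError:
            problems.append(f"{pkg} is not installed (expected {want})")
            continue
        base = have.split("+", 1)[0]  # drop the local version label
        if base != want:
            problems.append(f"{pkg} {have} installed, expected {want}")
    return problems


if __name__ == "__main__":
    for problem in check_versions():
        print("version mismatch:", problem)
```

With the combination from the question (jax 0.1.68 and jaxlib 0.1.67+cuda11), this reports that jaxlib 0.1.67+cuda11 is installed where 0.1.50 is expected. The `lookup` parameter exists only so the check can be exercised without touching the real environment.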