I have installed jax and jaxlib using pip according to https://github.com/google/jax#installation:
Successfully installed jax-0.1.68 jaxlib-0.1.67+cuda11
But when I ran my project, it raised an ImportError:
Traceback (most recent call last):
  File "...", line 1, in <module>
    from jax import jit, jacfwd, jacrev, hessian, lax
  File "...", line 16, in <module>
    from .api import (
  File "...", line 38, in <module>
    from . import core
  File "...", line 30, in <module>
    from . import dtypes
  File "...", line 31, in <module>
    from .lib import xla_client
  File "...", line 51, in <module>
    from jaxlib import pytree
ImportError: cannot import name 'pytree' from 'jaxlib' (/home/control/.local/lib/python3.7/site-packages/jaxlib/__init__.py)
Could this problem come from incompatible jax/jaxlib versions? If not, how can I fix it?
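
For reference, the versions that the interpreter actually sees can be checked without importing jax, since that import is the step that fails. This is only a minimal sketch: it assumes Python 3.8+ for importlib.metadata (on Python 3.7 it assumes the importlib_metadata backport is installed), and installed_version is a hypothetical helper name, not part of jax.

```python
# Report the installed jax / jaxlib versions by reading package metadata,
# without importing jax itself (importing jax is what raises the ImportError).
try:
    from importlib.metadata import version, PackageNotFoundError  # Python 3.8+
except ImportError:
    # Assumption: on Python 3.7 the importlib_metadata backport is available.
    from importlib_metadata import version, PackageNotFoundError


def installed_version(package):
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


for name in ("jax", "jaxlib"):
    # Prints e.g. "jax <version>" or "jax None" if the package is not found.
    print(name, installed_version(name))
```

If the versions printed here differ from the ones pip reported, an older copy of one of the packages may be the one being picked up from site-packages.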