I'm trying to use JAX with Haiku on Python 3.6 in a conda environment, but I get the error below and I'm stuck. I have tried upgrading JAX, but nothing changed. How can I fix it?
Traceback (most recent call last):
  File "train.py", line 14, in <module>
    import haiku as hk
  File "/home/young/.local/lib/python3.6/site-packages/haiku/__init__.py", line 19, in <module>
    from haiku import data_structures
  File "/home/young/.local/lib/python3.6/site-packages/haiku/data_structures.py", line 18, in <module>
    from haiku._src.data_structures import to_haiku_dict
  File "/home/young/.local/lib/python3.6/site-packages/haiku/_src/data_structures.py", line 176, in <module>
    class FlatComponents(NamedTuple):
  File "/home/young/.local/lib/python3.6/site-packages/haiku/_src/data_structures.py", line 178, in FlatComponents
    structure: jax.tree_util.PyTreeDef
AttributeError: module 'jax.tree_util' has no attribute 'PyTreeDef'
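The paths in the traceback point at the user site-packages (`/home/young/.local/...`) rather than the conda environment, so it may help to confirm which installations Python actually picks up. Below is a minimal diagnostic sketch, not a fix. It uses only the standard library plus `jax` itself. The `diagnose` helper is hypothetical and not part of JAX or Haiku. It locates `jax`, `jaxlib` and `haiku` with `importlib.util.find_spec`, which does not run Haiku's failing `__init__`, and then checks whether the importable `jax.tree_util` exposes `PyTreeDef`:

```python
import importlib.util
import sys


def diagnose():
    """Hypothetical debugging helper: report which jax/jaxlib/haiku get picked up.

    Haiku is only located, never imported, so its failing __init__ is not run.
    """
    report = {
        "python": sys.version.split()[0],
        "executable": sys.executable,  # should live inside the conda env
        "jax_version": None,
        "has_PyTreeDef": None,  # stays None if jax is missing or fails to import
    }
    for name in ("jax", "jaxlib", "haiku"):
        # find_spec resolves a top-level package's location without executing it.
        spec = importlib.util.find_spec(name)
        report[name + "_path"] = spec.origin if spec is not None else None

    if report["jax_path"] is not None:
        try:
            import jax
            import jax.tree_util

            report["jax_version"] = jax.__version__
            report["has_PyTreeDef"] = hasattr(jax.tree_util, "PyTreeDef")
        except Exception as exc:  # e.g. a missing or mismatched jaxlib
            report["jax_import_error"] = repr(exc)
    return report


if __name__ == "__main__":
    for key, value in diagnose().items():
        print(f"{key}: {value}")
```

Comparing `jax_path` and `haiku_path` with `executable` shows whether both packages come from the same environment. `has_PyTreeDef` shows whether the JAX that actually gets imported is the one this Haiku version expects.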