I have been trying to understand this official Flax example, running it on a Colab Pro+ account with a V100. When I execute the command
python main.py --workdir=./imagenet --config=configs/v100_x8.py
the following error is returned:
  File "/content/FlaxImageNet/main.py", line 29, in <module>
    import train
  File "/content/FlaxImageNet/train.py", line 30, in <module>
    from flax.training import checkpoints
  File "/usr/local/lib/python3.10/dist-packages/flax/training/checkpoints.py", line 34, in <module>
    from jax.experimental.global_device_array import GlobalDeviceArray
ModuleNotFoundError: No module named 'jax.experimental.global_device_array'
I am not sure whether global_device_array has been moved out of the jax.experimental
package, is no longer needed, or has been replaced by an equivalent API.
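To narrow this down, a small check along the following lines could confirm which jax and flax versions are installed and whether the module named in the traceback is present at all. This is only a minimal sketch: the helper names module_available and installed_version are my own and not part of either library, and it assumes both packages were installed with pip under the distribution names "jax" and "flax".

```python
import importlib.util
from importlib import metadata


def module_available(name: str) -> bool:
    """Return True if `name` can be located.

    find_spec imports the parent packages but not the module itself.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        # Raised when a parent package (e.g. `jax` itself) is missing
        # or has no usable import spec.
        return False


def installed_version(dist: str) -> str:
    """Return the installed version of a distribution, or a placeholder."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return "not installed"


# Assumed distribution names; adjust if the packages were installed differently.
for dist in ("jax", "flax"):
    print(f"{dist}: {installed_version(dist)}")

# The module that flax/training/checkpoints.py tries to import in the traceback.
target = "jax.experimental.global_device_array"
print(f"{target} available: {module_available(target)}")
```

If the last line reports False while jax itself is installed, the installed jax release no longer ships that module, and the flax version importing it is presumably older than the jax version it is paired with.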