
I have been trying to understand this official Flax example, running it on a Colab Pro+ account with a V100. When I execute the command `python main.py --workdir=./imagenet --config=configs/v100_x8.py`, it fails with the following error:

  File "/content/FlaxImageNet/main.py", line 29, in <module>
    import train
  File "/content/FlaxImageNet/train.py", line 30, in <module>
    from flax.training import checkpoints
  File "/usr/local/lib/python3.10/dist-packages/flax/training/checkpoints.py", line 34, in <module>
    from jax.experimental.global_device_array import GlobalDeviceArray
ModuleNotFoundError: No module named 'jax.experimental.global_device_array'

I am not sure whether `global_device_array` has been moved out of the `jax.experimental` package, is no longer needed, or has been replaced by an equivalent.

RanWang

1 Answer


GlobalDeviceArray was deprecated in JAX version 0.4.1 and removed in JAX version 0.4.7.

With that in mind, it seems the code in question requires JAX version 0.4.6 or older. You might consider reporting this incompatibility to the flax project: http://github.com/google/flax/.
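To check which side of that cutoff an environment is on before launching `main.py`, something like the following sketch may help. It relies only on the version numbers stated above; `has_global_device_array` and `REMOVED_IN` are hypothetical names introduced for illustration, and only the upper bound (removal in 0.4.7) is checked.

```python
import re
from importlib import metadata

# Per the answer above: GlobalDeviceArray was removed in JAX 0.4.7.
REMOVED_IN = (0, 4, 7)


def has_global_device_array(version: str) -> bool:
    """Hypothetical helper: True if `version` predates the removal in 0.4.7."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups()) < REMOVED_IN


try:
    installed = metadata.version("jax")
except metadata.PackageNotFoundError:
    print("jax is not installed in this environment")
else:
    if has_global_device_array(installed):
        print(f"jax {installed}: the GlobalDeviceArray import should still resolve")
    else:
        print(f"jax {installed}: GlobalDeviceArray is gone; this flax needs jax 0.4.6 or older")
```

If the check reports 0.4.7 or newer, the two options are to install an older JAX or to use a flax release that no longer imports `GlobalDeviceArray`.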

jakevdp
  • Hi! Then is there a way to run `main.py` by replacing it with some similar functionality? I will also report this. Thank you very much. – RanWang May 07 '23 at 07:05
  • I'm also facing the same problem. What would be the replacement for GlobalDeviceArray? Thanks @jakevdp – m0j1 Aug 05 '23 at 19:41
  • It depends on the context, but you should generally use `jax.Array`. – jakevdp Aug 05 '23 at 20:14
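
As a rough illustration of the `jax.Array` suggestion in the last comment, here is a minimal sketch of creating an array sharded across all available devices. It assumes a JAX version that provides `jax.Array` and `jax.sharding` (0.4.1 or newer); the mesh axis name `"data"` and the array shape are arbitrary choices for the example, not anything taken from the flax ImageNet code.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D device mesh over whatever devices are available (1 on CPU, 8 on v100_x8).
devices = np.array(jax.devices())
n = len(devices)
mesh = Mesh(devices, axis_names=("data",))

# Shard the leading (batch) axis across the "data" mesh axis.
sharding = NamedSharding(mesh, P("data"))

# The leading dimension is a multiple of the device count so it splits evenly.
x = jnp.arange(8 * n, dtype=jnp.float32).reshape(2 * n, 4)
x = jax.device_put(x, sharding)

# The result is an ordinary jax.Array; there is no separate global array type.
print(type(x).__name__, x.shape, isinstance(x, jax.Array))
```

The point is that a sharded, multi-device array and a single-device array share the same `jax.Array` type, so code that used to special-case `GlobalDeviceArray` can usually check for `jax.Array` instead.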