
I am trying to run code written with JAX. In one part of the code, the key for the training set is defined as

`key_train = random.PRNGKey(0)`

Here the type of the key is `jaxlib.xla_extension.DeviceArray`. Later in the code, the keys are created with `keys = random.split(key_train, N)`, where `N` is an integer equal to 10000. At that line I get the following error:

DenseElementsAttr could not be constructed from the given buffer. This may mean that the Python buffer layout does not match that MLIR expected layout and is a bug.

Could you please help me with this error?

Edit: I am running the code on Windows 10. The code I am trying to run is here: https://github.com/PredictiveIntelligenceLab/Physics-informed-DeepONets/blob/main/Antiderivative/DeepONet_antideriv.ipynb. For simplicity, you can also run the code below, which produces exactly the same error.

```python
from jax import random

N = 10000
key_train = random.PRNGKey(0)
keys = random.split(key_train, N)  # the DenseElementsAttr error is raised here
```

The `jax` and `jaxlib` versions are both 0.3.5, with CUDA 11.
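One way to collect this version information is sketched below, as a suggestion rather than an official JAX tool. It reads installed-package metadata through the standard library's `importlib.metadata` (Python 3.8 or later) instead of importing `jax`, so it still works when importing or running JAX fails. The helper name `installed_versions` is hypothetical.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(
    packages: Iterable[str] = ("jax", "jaxlib"),
) -> Dict[str, Optional[str]]:
    """Return the installed distribution version of each package, or None if it is missing.

    Only package metadata is read; jax/jaxlib are never imported.
    """
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

CUDA-enabled `jaxlib` wheels usually carry the CUDA/cuDNN variant in a local version suffix, so the printed string may include more than the bare release number.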

  • Could you please edit your question to include a [minimal reproducible example](https://stackoverflow.com/help/minimal-reproducible-example) of the problem, and also report what version of `jax` and `jaxlib` you have installed? – jakevdp Apr 11 '22 at 19:21
  • I built from source today on Windows and got exactly the same problem as the OP. Potentially a Windows bug in the latest release? Has anyone found a fix? – Alerra Apr 12 '22 at 03:04
  • Edited the question with a minimal reproducible example and added the version of jax. – burak ateş Apr 12 '22 at 08:00
  • Thanks - it sounds like it's a windows-specific build issue. You're not going to find an answer for this on StackOverflow: you might try opening an issue in the JAX repository, or perhaps in the repository associated with the windows binary you are using (the JAX team does not maintain any windows builds). – jakevdp Apr 12 '22 at 16:49

2 Answers


I had the same error. Deleting c:\python37\lib\site-packages\jaxlib\cuda_prng.py fixed the issue (replace the prefix with your own Python path). Possibly cuda_prng.py was a stale file left over from an older version.
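A minimal sketch for checking whether such a leftover `cuda_prng.py` is present, assuming the usual `site-packages\jaxlib` layout shown in the path above. It locates the `jaxlib` package with `importlib.util.find_spec`, which for a top-level package does not execute it. The helper names `jaxlib_dir` and `find_stale_module` are hypothetical, and the script only reports the file; deleting it (or renaming it as a backup) is left to you.

```python
import importlib.util
from pathlib import Path
from typing import Optional


def jaxlib_dir() -> Optional[Path]:
    """Return the installed jaxlib package directory, or None if jaxlib is not found."""
    spec = importlib.util.find_spec("jaxlib")
    if spec is None or not spec.submodule_search_locations:
        return None
    return Path(next(iter(spec.submodule_search_locations)))


def find_stale_module(package_dir: Path, filename: str = "cuda_prng.py") -> Optional[Path]:
    """Return the path of `filename` inside `package_dir` if it exists as a file, else None."""
    candidate = package_dir / filename
    return candidate if candidate.is_file() else None


if __name__ == "__main__":
    pkg = jaxlib_dir()
    if pkg is None:
        print("jaxlib is not installed in this environment")
    else:
        stale = find_stale_module(pkg)
        if stale is not None:
            print(f"Found {stale}; consider removing it")
        else:
            print(f"No cuda_prng.py in {pkg}")
```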

Erwin Coumans

I followed Erwin Coumans's advice, and the error indeed disappeared. The version installed on Windows is jaxlib 0.3.7 with CUDA 11.3 and cuDNN 8.2, on Anaconda Python 3.9.0, which is the latest one available at: https://whls.blob.core.windows.net/unstable/index.html