
I am currently trying to implement my work in the JAX framework. However, I am now encountering an error when using the linear solve function from JAX.

Here is an example taken directly from the numpy linear algebra documentation page:

import numpy as np
a = np.array([[1, 2], [3, 5]])
b = np.array([1, 2])
x = np.linalg.solve(a, b)
print(x)
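
For reference, a short sketch of how the NumPy result can be checked against the original system. The name `residual_ok` is introduced here only for illustration; the rest repeats the example above.

```python
import numpy as np

a = np.array([[1, 2], [3, 5]])
b = np.array([1, 2])
x = np.linalg.solve(a, b)

# Solving by hand gives x0 = -1, x1 = 1:  1*(-1) + 2*1 = 1  and  3*(-1) + 5*1 = 2.
# Multiplying back should reproduce b up to floating-point error.
residual_ok = np.allclose(a @ x, b)
print(x, residual_ok)
```

If the JAX version worked, it would be expected to return the same values up to floating-point precision.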

Now the same example is carried out using the linear solve in JAX:

import jax.numpy as jnp
a = jnp.array([[1, 2], [3, 5]])
b = jnp.array([1, 2])
x = jnp.linalg.solve(a, b)
print(x)

The following error message is now generated:

JaxStackTraceBeforeTransformation: ValueError: DenseElementsAttr could not be constructed from the given buffer. This may mean that the Python buffer layout does not match that MLIR expected layout and is a bug.

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "C:\Users\mabso\Desktop\PhD\bayesian NMR modelling\prototype 1 modelling\sequential modelling\python implementation\idea bucket\jax attempt.py", line 96, in <module>
    x = jnp.linalg.solve(a, b)
  File "C:\ProgramData\Anaconda3\lib\site-packages\jax_src\traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "C:\ProgramData\Anaconda3\lib\site-packages\jax_src\api.py", line 466, in cache_miss
    out_flat = xla.xla_call(
  File "C:\ProgramData\Anaconda3\lib\site-packages\jax\core.py", line 1771, in bind
    return call_bind(self, fun, *args, **params)
  File "C:\ProgramData\Anaconda3\lib\site-packages\jax\core.py", line 1787, in call_bind
    outs = top_trace.process_call(primitive, fun_, tracers, params)
  File "C:\ProgramData\Anaconda3\lib\site-packages\jax\core.py", line 660, in process_call
    return primitive.impl(f, *tracers, **params)
  File "C:\ProgramData\Anaconda3\lib\site-packages\jax_src\dispatch.py", line 149, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "C:\ProgramData\Anaconda3\lib\site-packages\jax\linear_util.py", line 285, in memoized_fun
    ans = call(fun, *args)
  File "C:\ProgramData\Anaconda3\lib\site-packages\jax_src\dispatch.py", line 197, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars,
  File "C:\ProgramData\Anaconda3\lib\site-packages\jax_src\profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "C:\ProgramData\Anaconda3\lib\site-packages\jax_src\dispatch.py", line 296, in lower_xla_callable
    module = mlir.lower_jaxpr_to_module(
  File "C:\ProgramData\Anaconda3\lib\site-packages\jax\interpreters\mlir.py", line 524, in lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File "C:\ProgramData\Anaconda3\lib\site-packages\jax\interpreters\mlir.py", line 688, in lower_jaxpr_to_fun
    out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
  File "C:\ProgramData\Anaconda3\lib\site-packages\jax\interpreters\mlir.py", line 789, in jaxpr_subcomp
    ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes),
  File "C:\ProgramData\Anaconda3\lib\site-packages\jax_src\lax\linalg.py", line 1228, in _lu_cpu_gpu_lowering
    lu, pivot, info = getrf_impl(operand_aval.dtype, operand)
  File "C:\ProgramData\Anaconda3\lib\site-packages\jaxlib\lapack.py", line 269, in getrf_mhlo
    ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
ValueError: DenseElementsAttr could not be constructed from the given buffer. This may mean that the Python buffer layout does not match that MLIR expected layout and is a bug.

I have no idea what this error actually means. I saw the post "DenseElementsAttr could not be constructed from the given buffer", which seems to describe a similar issue. However, when I run the code from that post, it works without any problems.

I am running JAX and jaxlib version 0.3.7 on Windows 10, using only the CPU.

Hope someone knows something:)

mabso

2 Answers


I think this is a Windows-specific build issue in the jaxlib wheel that you have installed. Unfortunately, JAX does not have any official Windows support at the moment, so if you're running on Windows you're likely using a community-supported installation. Given the other similar question, it sounds like there are issues with the build for a recent jaxlib release.

I would recommend reporting this issue to the source of the Windows jaxlib installation you're using, and trying older versions of jax and jaxlib to see if you can find one that does not have this problem.
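
When reporting the issue, it helps to include the exact platform and package versions. Below is a minimal sketch of how that information could be collected using only the standard library. The helper names `package_version` and `environment_report` are made up for this illustration and are not part of JAX.

```python
import platform
import sys
from importlib import metadata


def package_version(name):
    """Return the installed version of `name`, or None if it is not installed."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def environment_report():
    """Collect the details a maintainer would typically ask for (hypothetical helper)."""
    return {
        "platform": platform.platform(),
        "python": sys.version.split()[0],
        "jax": package_version("jax"),
        "jaxlib": package_version("jaxlib"),
        "numpy": package_version("numpy"),
    }


for key, value in environment_report().items():
    print(f"{key}: {value}")
```

Running the same script after each downgrade makes it easy to record which jax and jaxlib combination first shows the error.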

Alternatively, if you're able to use WSL for your work, you can install and use an official Linux release, which would hopefully not have this problem.

jakevdp
  • I also encountered this error on Windows (with https://github.com/cloudhan/jax-windows-builder) for both Jax 0.3.7 and 0.3.5. Downgrading to 0.3.2 resolved it. Unfortunately performance is still very poor on Windows for me, which I was hoping was related to this but apparently isn't. I would definitely recommend WSL if you just need CPU support (I couldn't get CUDA working on WSL). – tgbrooks Apr 25 '22 at 13:36
  • It doesn't look like the issue reported here and in the linked question has been raised in the windows builder repository – the maintainers probably do not know about it. – jakevdp Apr 25 '22 at 14:17
  • Thanks for the reply. Yeah I figured it was something like a windows error, but just wanted to be sure before I moved on. I ended up installing a virtual machine with Linux and Jax works just fine now (at least for CPU). Now I just need to figure out how to get CUDA working - no idea if it is possible on a virtual machine though. I'll post the issues to the Windows Jax repository and see what happens. Thanks again for the quick reply. – mabso Apr 26 '22 at 07:36

See the bug report for the root cause.

It is a long-standing platform-dependent behaviour that has bitten once more: NumPy's default integer dtype is int32 on 64-bit Windows. See "numpy array dtype is coming as int32 by default in a windows 10 64 bit machine".
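
A minimal sketch of that platform difference, assuming NumPy is installed. On 64-bit Windows, NumPy releases before 2.0 use a 32-bit default integer, so an `np.arange` call without an explicit `dtype` (like the one in the last frame of the traceback) produces `int32` there and `int64` on Linux. The name `num_bd` is borrowed from the traceback, and the value 2 is only a placeholder. This illustrates the behaviour and is not a copy of the actual jaxlib code or its fix.

```python
import numpy as np

num_bd = 2  # placeholder value; only the name comes from the traceback

# Without a dtype, the result uses the platform's default integer type
# (int32 on 64-bit Windows with NumPy < 2.0, int64 on Linux).
implicit = np.arange(num_bd, -1, -1)
print("implicit dtype:", implicit.dtype)

# Passing dtype explicitly gives the same buffer layout on every platform.
explicit = np.arange(num_bd, -1, -1, dtype=np.int64)
print("explicit dtype:", explicit.dtype, explicit.tolist())
```

Code that hands NumPy buffers to a consumer expecting a fixed integer width is safer when it states the dtype explicitly rather than relying on the default.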

Cloud