I am currently trying to implement my work in JAX. However, I am running into an error when I use JAX's linear solve function, `jnp.linalg.solve`.
Here is an example taken directly from the NumPy linear algebra documentation:
```python
import numpy as np

a = np.array([[1, 2], [3, 5]])
b = np.array([1, 2])
x = np.linalg.solve(a, b)
print(x)
```
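For reference, this is the result I expect JAX to reproduce: NumPy returns `x = [-1., 1.]`, since the integer inputs are promoted to float64 before the LAPACK call. A small NumPy-only check that the solution satisfies `a @ x = b`:

```python
import numpy as np

# Same 2x2 system as the NumPy documentation example.
a = np.array([[1, 2], [3, 5]])
b = np.array([1, 2])
x = np.linalg.solve(a, b)  # integer inputs are promoted to float64

# Multiplying the solution back by `a` should reproduce `b` (up to rounding).
assert np.allclose(a @ x, b)
print(x)
```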
Here is the same example using JAX's linear solve:
```python
import jax.numpy as jnp

a = jnp.array([[1, 2], [3, 5]])
b = jnp.array([1, 2])
x = jnp.linalg.solve(a, b)
print(x)
```
This produces the following error message:
```
JaxStackTraceBeforeTransformation: ValueError: DenseElementsAttr could not be constructed from the given buffer. This may mean that the Python buffer layout does not match that MLIR expected layout and is a bug.

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "C:\Users\mabso\Desktop\PhD\bayesian NMR modelling\prototype 1 modelling\sequential modelling\python implementation\idea bucket\jax attempt.py", line 96, in <module>
    x = jnp.linalg.solve(a, b)
  File "C:\ProgramData\Anaconda3\lib\site-packages\jax\_src\traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "C:\ProgramData\Anaconda3\lib\site-packages\jax\_src\api.py", line 466, in cache_miss
    out_flat = xla.xla_call(
  File "C:\ProgramData\Anaconda3\lib\site-packages\jax\core.py", line 1771, in bind
    return call_bind(self, fun, *args, **params)
  File "C:\ProgramData\Anaconda3\lib\site-packages\jax\core.py", line 1787, in call_bind
    outs = top_trace.process_call(primitive, fun_, tracers, params)
  File "C:\ProgramData\Anaconda3\lib\site-packages\jax\core.py", line 660, in process_call
    return primitive.impl(f, *tracers, **params)
  File "C:\ProgramData\Anaconda3\lib\site-packages\jax\_src\dispatch.py", line 149, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "C:\ProgramData\Anaconda3\lib\site-packages\jax\linear_util.py", line 285, in memoized_fun
    ans = call(fun, *args)
  File "C:\ProgramData\Anaconda3\lib\site-packages\jax\_src\dispatch.py", line 197, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars,
  File "C:\ProgramData\Anaconda3\lib\site-packages\jax\_src\profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "C:\ProgramData\Anaconda3\lib\site-packages\jax\_src\dispatch.py", line 296, in lower_xla_callable
    module = mlir.lower_jaxpr_to_module(
  File "C:\ProgramData\Anaconda3\lib\site-packages\jax\interpreters\mlir.py", line 524, in lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File "C:\ProgramData\Anaconda3\lib\site-packages\jax\interpreters\mlir.py", line 688, in lower_jaxpr_to_fun
    out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
  File "C:\ProgramData\Anaconda3\lib\site-packages\jax\interpreters\mlir.py", line 789, in jaxpr_subcomp
    ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes),
  File "C:\ProgramData\Anaconda3\lib\site-packages\jax\_src\lax\linalg.py", line 1228, in _lu_cpu_gpu_lowering
    lu, pivot, info = getrf_impl(operand_aval.dtype, operand)
  File "C:\ProgramData\Anaconda3\lib\site-packages\jaxlib\lapack.py", line 269, in getrf_mhlo
    ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
ValueError: DenseElementsAttr could not be constructed from the given buffer. This may mean that the Python buffer layout does not match that MLIR expected layout and is a bug.
```
I have no idea what this error actually means. I found the post "DenseElementsAttr could not be constructed from the given buffer", which seems to describe a similar issue. However, when I run the code from that post, it works without any problems.
I am running JAX and jaxlib, both at version 0.3.7, on Windows 10 (CPU only).
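One detail that may be relevant (my own guess, not confirmed): the last frame of the traceback passes `np.arange(num_bd, -1, -1)` to `ir.DenseIntElementsAttr.get` without an explicit dtype. On Windows, NumPy versions before 2.0 use a 32-bit default integer for this call, while on 64-bit Linux and macOS the default is 64-bit. If MLIR expects a 64-bit buffer there, that could explain the "buffer layout" mismatch. The snippet below only prints what NumPy produces on the current machine; `num_bd = 1` is a hypothetical stand-in value, not taken from jaxlib.

```python
import numpy as np

# Hypothetical stand-in for jaxlib's `num_bd`; the exact value does not matter here.
num_bd = 1

# Same call as in the failing jaxlib frame: no explicit dtype is given,
# so NumPy picks its platform default integer type.
layout = np.arange(num_bd, -1, -1)
print(layout, layout.dtype)  # int32 on Windows with NumPy < 2.0, int64 on 64-bit Linux/macOS

# With an explicit dtype the values are identical but the width no longer depends on the platform.
layout64 = np.arange(num_bd, -1, -1, dtype=np.int64)
print(layout64, layout64.dtype)
```

If the first line prints `int32` while the second prints `int64`, that would at least be consistent with this guess.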
Hope someone knows what is going on :)