
This question is similar to the question here, but I can't work out what I should change in my own code.

I have a function

def elbo(variational_parameters, eps, a, b):
    ...
    return theta, _

elbo = jit(elbo, static_argnames=["a", "b"])

where variational_parameters is a vector (one-dimensional array) of length P, eps is a two-dimensional array of dimensions K by N, and a, b are fixed values.

The elbo function has been successfully vmapped over the rows of eps and jitted by passing a and b to static_argnames. It returns theta, a two-dimensional array of dimensions K by P.
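The body of elbo is elided above, so the sketches in this thread use a hypothetical stand-in that only matches the stated shapes. Each length-N row of eps is reduced to its mean and combined as a * variational_parameters + b * mean(row), which gives a length-P row of theta. The sizes P = 3, K = 4, N = 5 and the values a = 2.0, b = 0.5 are also illustrative assumptions.

```python
import jax.numpy as jnp
from jax import jit, vmap


def elbo(variational_parameters, eps, a, b):
    # Hypothetical body: turn each length-N row of eps into a length-P row of theta.
    def one_row(eps_row):
        return a * variational_parameters + b * jnp.mean(eps_row)

    theta = vmap(one_row)(eps)  # vmap over the K rows of eps -> shape (K, P)
    aux = jnp.sum(theta)        # placeholder for the second (auxiliary) return value
    return theta, aux


# a and b are plain Python floats, so they are hashable and can be static.
elbo = jit(elbo, static_argnames=["a", "b"])

P, K, N = 3, 4, 5
variational_parameters = jnp.arange(P, dtype=jnp.float32)
eps = jnp.ones((K, N))

theta, _ = elbo(variational_parameters, eps, 2.0, 0.5)
print(theta.shape)  # (4, 3), i.e. (K, P)
```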

I want to take the Jacobian of the output theta with respect to variational_parameters through the elbo function. The first value returned by

jacobian(elbo, argnums=0, has_aux=True)(variational_parameters, eps, a, b)

gives me a three-dimensional array of dimensions K by P by P. This is what I want. As soon as I try to jit this function
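jax.jacobian returns an array whose shape is the output's shape followed by the input's shape. Differentiating a (K, P) output with respect to a length-P input therefore gives a (K, P, P) array. Here is a sketch with a hypothetical elbo body (the real one is elided) and illustrative sizes P = 3, K = 4, N = 5:

```python
import jax.numpy as jnp
from jax import jacobian, jit, vmap


def elbo(variational_parameters, eps, a, b):
    # Hypothetical stand-in for the elided body of elbo.
    theta = vmap(lambda row: a * variational_parameters + b * jnp.mean(row))(eps)
    return theta, jnp.sum(theta)


elbo = jit(elbo, static_argnames=["a", "b"])

P, K, N = 3, 4, 5
variational_parameters = jnp.arange(P, dtype=jnp.float32)
eps = jnp.ones((K, N))

# No jit at this level: a and b reach elbo as Python floats, so they stay static.
J, aux = jacobian(elbo, argnums=0, has_aux=True)(variational_parameters, eps, 2.0, 0.5)
print(J.shape)  # (4, 3, 3): output shape (K, P) followed by input shape (P,)
```

With this stand-in, each of the K slices of J is a times the P by P identity matrix.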

jit(jacobian(elbo, argnums=0, has_aux=True))(variational_parameters, eps, a, b)

I get the error

ValueError: Non-hashable static arguments are not supported, which can lead to unexpected cache-misses. Static argument (index 2) of type <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> for function elbo is non-hashable.

Any help would be greatly appreciated; thanks!

hasco641

1 Answer


Any parameters you pass to a JIT-compiled function will no longer be static, unless you explicitly mark them as such. So this line:

jit(jacobian(elbo, argnums=0, has_aux=True))(variational_parameters, eps, a, b)

makes variational_parameters, eps, a, and b non-static. Then, within the transformed function, these non-static parameters are passed to this function:

elbo = jit(elbo, static_argnames=["a", "b"])

which means that you are attempting to pass non-static values as static arguments, which causes an error.
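Here is a toy illustration of that point. It uses a hypothetical function, scale, that is not from the question. A Python side effect records the type of a while JAX traces the function. Without static_argnums, a shows up as a JAX tracer. With it, a is the original Python float.

```python
import jax.numpy as jnp
from jax import jit

seen = []


def scale(x, a):
    # Runs only while JAX traces the function; records what `a` looks like then.
    seen.append(type(a).__name__)
    return a * x


x = jnp.ones(3)
jit(scale)(x, 2.0)                    # a not marked static: it arrives as a tracer
jit(scale, static_argnums=1)(x, 2.0)  # a marked static: it arrives as the float 2.0
print(seen)
```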

To fix this, you should mark the static parameters as static any time they enter a jit-compiled function. In your case it might look something like this:

jit(jacobian(elbo, argnums=0, has_aux=True),
    static_argnums=(2, 3))(variational_parameters, eps, a, b)
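Below is a runnable version of the fix, again with a hypothetical elbo body and illustrative sizes. Marking a and b static means they must be hashable: Python scalars are, JAX arrays are not. Each new (a, b) pair also triggers a fresh compilation.

```python
import jax.numpy as jnp
from jax import jacobian, jit, vmap


def elbo(variational_parameters, eps, a, b):
    # Hypothetical stand-in for the elided body of elbo.
    theta = vmap(lambda row: a * variational_parameters + b * jnp.mean(row))(eps)
    return theta, jnp.sum(theta)


elbo = jit(elbo, static_argnames=["a", "b"])

variational_parameters = jnp.arange(3, dtype=jnp.float32)
eps = jnp.ones((4, 5))

# Positions 2 and 3 (a and b) stay static in the outer jit, so they reach
# the inner jit as hashable Python floats rather than tracers.
jitted_jac = jit(jacobian(elbo, argnums=0, has_aux=True), static_argnums=(2, 3))
J, aux = jitted_jac(variational_parameters, eps, 2.0, 0.5)
print(J.shape)  # (4, 3, 3)
```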
jakevdp
  • Hi Jake, thank you for the response; it works as expected! I have a follow-up question. I neglected to say in my question that I had tried `jit(jacobian(elbo, argnums=0, has_aux=True), static_argnames=["a", "b"])(variational_parameters, eps, a, b)` and got the same error that I mentioned in the question. The solution was to pass the static arguments via `static_argnums` instead of `static_argnames`. Why is this? Do the variables not have the same names when called through a wrapped function like `jacobian`? – hasco641 Oct 04 '22 at 01:31
  • Thanks for letting me know about this – I think this should work, but the `jacobian` decorator is currently missing `functools.wraps` on its output. I'll plan to add a fix for that later today. – jakevdp Oct 04 '22 at 11:51
  • This should be fixed by https://github.com/google/jax/pull/12653 – jakevdp Oct 04 '22 at 17:23
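On a JAX release that includes the linked fix, static_argnames should work as well. The function returned by jacobian then carries elbo's metadata (via a functools.wraps-style wrapper, per the comment above), so jit can resolve the names a and b to positions 2 and 3 through the wrapped signature. The sketch below compares both spellings with a hypothetical elbo body. On older releases, the by_name call is expected to fail with the error from the question.

```python
import jax.numpy as jnp
from jax import jacobian, jit, vmap


def elbo(variational_parameters, eps, a, b):
    # Hypothetical stand-in for the elided body of elbo.
    theta = vmap(lambda row: a * variational_parameters + b * jnp.mean(row))(eps)
    return theta, jnp.sum(theta)


elbo = jit(elbo, static_argnames=["a", "b"])

variational_parameters = jnp.arange(3, dtype=jnp.float32)
eps = jnp.ones((4, 5))

jac = jacobian(elbo, argnums=0, has_aux=True)
by_name = jit(jac, static_argnames=["a", "b"])  # relies on jacobian preserving elbo's signature
by_position = jit(jac, static_argnums=(2, 3))

J_name, _ = by_name(variational_parameters, eps, 2.0, 0.5)
J_pos, _ = by_position(variational_parameters, eps, 2.0, 0.5)
print(jnp.allclose(J_name, J_pos))  # True
```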