
I need to compute the gradient of a batched function using JAX. The following is a minimal example of what I would like to do:

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

x = jnp.expand_dims(jnp.linspace(-1, 1, 20), axis=1)

u = lambda x: jnp.sin(jnp.pi * x)
ux = jax.vmap(jax.grad(u))

plt.plot(x, u(x))
plt.plot(x, ux(x))
plt.show()

I have tried a variety of ways of making this work using vmap, but I can't seem to get the code to run without removing the batch dimension from the input x. I have seen some workarounds using the Jacobian, but this doesn't seem natural, since u is a scalar function of a single variable.

In the end u will be a neural network (implemented in Flax) that I need to differentiate with respect to the input (not the parameters of the network), so I cannot remove the batch dimension.


1 Answer


To ensure that the function u returns a scalar value (so that jax.grad makes sense), the batch dimension also needs to be mapped over.

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

x = jnp.expand_dims(jnp.linspace(-1, 1, 20), axis=1)

u = lambda x: jnp.sin(jnp.pi * x)
ux = jax.vmap(jax.vmap(jax.grad(u)))
# ux = lambda x : jax.lax.map(jax.vmap(jax.grad(u)), x) # sequential version
# ux = lambda x : jax.vmap(jax.grad(u))(x.reshape(-1)).reshape(x.shape) # flattened map version

plt.plot(x, u(x))
plt.plot(x, ux(x))
plt.show()

Which composition of maps to use depends on what's happening in the batched dimension.
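For instance (a minimal sketch, not from the original post, using a hypothetical function v): if the trailing axis were a genuine feature axis and the function were scalar-valued on a whole vector, you would map over the batch axis only, and jax.grad would then return one gradient vector per example.

import jax
import jax.numpy as jnp

# Hypothetical scalar-valued function of a vector input: it reduces over
# the feature axis, so jax.grad returns one gradient entry per feature.
v = lambda xi: jnp.sum(jnp.sin(jnp.pi * xi))   # xi has shape (features,)

xb = jnp.ones((20, 3))              # batch of 20 examples, 3 features each
vx = jax.vmap(jax.grad(v))          # map over the batch axis only
print(vx(xb).shape)                 # (20, 3): one gradient vector per example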

DavidJ
  • Great, this solves the issue! If you don't mind, how would this answer have to be modified, if the function u had some extra parameters (the weights of the neural network), but I still want the derivatives w.r.t. the input x? – al_cc Jul 29 '23 at 10:12
  • If you're not differentiating with respect to the extra arguments, you can either close over them or specify `in_axes` to `jax.vmap` and `argnums` to `jax.grad`. If you post another question specifying the extra parameters, the values for the previous arguments can be defined. – DavidJ Jul 29 '23 at 12:29
  • Many thanks! I think it worked just using your code snippet from before, just adding the parameters and not mapping them using vmap. I appreciate your help here! – al_cc Jul 29 '23 at 14:03
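Following up on the comment thread above, here is a minimal sketch of that approach with a small, hypothetical Flax model (the actual network from the question is not shown): the parameters are passed as an extra argument that is broadcast with in_axes=(None, 0), and the derivative is taken with respect to the input via argnums=1.

import jax
import jax.numpy as jnp
import flax.linen as nn

# Hypothetical stand-in for the network in the question.
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = jnp.tanh(nn.Dense(16)(x))
        return nn.Dense(1)(x)

model = MLP()
x = jnp.expand_dims(jnp.linspace(-1, 1, 20), axis=1)   # shape (20, 1)
params = model.init(jax.random.PRNGKey(0), x)

def u(params, xi):                          # xi is a single scalar input
    return model.apply(params, xi[None])[0] # scalar output

# Differentiate w.r.t. the input (argnums=1), not the parameters, and
# broadcast params over both mapped axes with in_axes=(None, 0).
ux = jax.vmap(jax.vmap(jax.grad(u, argnums=1), in_axes=(None, 0)), in_axes=(None, 0))
print(ux(params, x).shape)                  # (20, 1)

Alternatively, closing over the parameters (e.g. u = lambda xi: model.apply(params, xi[None])[0]) lets the original nested jax.vmap(jax.vmap(jax.grad(u))) from the answer be reused unchanged, which matches what the asker reports in the last comment.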