I am trying to get the second derivative of the output w.r.t. the input of a neural network built using Flax. The network is structured as follows:
from typing import Sequence

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class MLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        for feat in self.features[:-1]:
            x = nn.tanh(nn.Dense(feat)(x))
        x = nn.Dense(self.features[-1])(x)
        return x

model = MLP([20, 20, 20, 20, 20, 1])
batch = jnp.ones((32, 3))  # Dummy input to initialize the NN
params = model.init(jax.random.PRNGKey(0), batch)

X = jnp.ones((32, 3))
output = model.apply(params, X)
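For reference, the forward pass returns one scalar prediction per input row:

print(output.shape)  # (32, 1)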
I can get the first derivative by applying vmap over grad:
from jax import grad, jit, vmap

@jit
def u_function(params, X):
    u = model.apply(params, X)
    return jnp.squeeze(u)

grad_fn = vmap(grad(u_function, argnums=1), in_axes=(None, 0), out_axes=0)
u_X = grad_fn(params, X)
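This works as expected; u_X has the same shape as X, i.e. one gradient of the scalar output per input row:

print(u_X.shape)  # (32, 3)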
However, when I try to do this again to obtain the second derivative:
u_X_func = vmap(grad(u_function, argnums=1), in_axes=(None, 0), out_axes=(0))
u_XX_func = vmap(grad(u_X_func, argnums=1), in_axes=(None, 0), out_axes=(0))(params, X)
I get the following error:
/usr/local/lib/python3.7/dist-packages/flax/linen/linear.py in __call__(self, inputs)
    186     kernel = self.param('kernel',
    187                         self.kernel_init,
--> 188                         (jnp.shape(inputs)[-1], self.features),
    189                         self.param_dtype)
    190     if self.use_bias:

IndexError: tuple index out of range
I tried using the hvp definition from the JAX autodiff cookbook, but with params being an input to the function I wasn't sure how to proceed.
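For concreteness, this is roughly how far I got, closing over params so that the cookbook-style hvp only differentiates w.r.t. a single input row (hvp_x, v and u_XX_v are just my own names; this is a sketch and I am not sure it gives me the second derivative I actually need):

from jax import jvp

def hvp_x(params, x, v):
    # Hessian-vector product of the scalar output w.r.t. one input row x,
    # with params frozen by the closure
    f = lambda xi: u_function(params, xi)
    return jvp(grad(f), (x,), (v,))[1]

v = jnp.ones(3)  # direction for the Hessian-vector product
u_XX_v = vmap(hvp_x, in_axes=(None, 0, None))(params, X, v)  # shape (32, 3)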
Any help on this would be really appreciated.