This is the JAX computation that I wrote based on your description:
import numpy as np
import jax.numpy as jnp
import jax
N = 10
M = 20
rng = np.random.default_rng(0)
A = jnp.array(rng.random((N,)))
B = jnp.array(rng.random((N, M)))
theta = jnp.array(rng.random(M))
def f(A, B, theta, k=3):
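    C = B @ theta
    # top_k returns (values, indices); only the integer indices are used.
    _, i_upper = jax.lax.top_k(C, k)
    _, i_lower = jax.lax.top_k(-C, k)  # negate C to get the k smallest entries
    return A[i_lower], A[i_upper]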
x, y = f(A, B, theta)
dx_dtheta, dy_dtheta = jax.jacobian(f, argnums=2)(A, B, theta)
The derivatives are all zero, and I believe this is correct, because an infinitesimal change in theta does not change the value of the outputs.
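As a quick check, the sketch below re-runs the computation above and confirms that both Jacobians contain only zeros. It repeats the setup so that it runs on its own, and it assumes the same names as the snippet above, with `top_k` called through `jax.lax`.

```python
import numpy as np
import jax
import jax.numpy as jnp

N, M = 10, 20
rng = np.random.default_rng(0)
A = jnp.array(rng.random((N,)))
B = jnp.array(rng.random((N, M)))
theta = jnp.array(rng.random(M))

def f(A, B, theta, k=3):
    C = B @ theta
    _, i_upper = jax.lax.top_k(C, k)   # indices of the k largest entries of C
    _, i_lower = jax.lax.top_k(-C, k)  # indices of the k smallest entries of C
    return A[i_lower], A[i_upper]

dx_dtheta, dy_dtheta = jax.jacobian(f, argnums=2)(A, B, theta)

# Each Jacobian has one row per output entry and one column per entry of theta.
print(dx_dtheta.shape, dy_dtheta.shape)
# Check that every entry of both Jacobians is exactly zero.
print(bool(jnp.all(dx_dtheta == 0)), bool(jnp.all(dy_dtheta == 0)))
```

The only path from theta to the outputs runs through the integer indices returned by `top_k`, and integer outputs carry no gradient.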
But, you might ask, how can this be? After all, theta enters into the computation, and if you put in a different value for theta, you get different outputs. How could the gradient be zero?
What you must keep in mind, though, is that differentiation doesn't measure whether an input affects an output. It measures the change in output given an infinitesimal change in input.
Let's use a slightly simpler function as an example:
import jax
import jax.numpy as jnp
A = jnp.array([1.0, 2.0, 3.0])
theta = jnp.array([5.0, 1.0, 3.0])
def f(A, theta):
    return A[jnp.argmax(theta)]
x = f(A, theta)
dx_dtheta = jax.grad(f, argnums=1)(A, theta)
Here the result of differentiating f with respect to theta is all zero, for the same reasons as above. Why? If you make an infinitesimal change to theta, it will in general not affect the sort order of theta. Thus, the entries you choose from A do not change given an infinitesimal change in theta, and thus the derivative with respect to theta is zero.
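One way to see this numerically is to compare a small and a large perturbation of theta. The sketch below is only an illustration: the step sizes `1e-3` and `10.0` are arbitrary choices, and the finite-difference slope is a rough stand-in for the derivative.

```python
import jax
import jax.numpy as jnp

A = jnp.array([1.0, 2.0, 3.0])
theta = jnp.array([5.0, 1.0, 3.0])

def f(A, theta):
    return A[jnp.argmax(theta)]

# A small perturbation of theta[1] leaves the argmax (index 0) unchanged,
# so the finite-difference slope is zero.
eps = 1e-3
small = theta.at[1].add(eps)
fd_slope = (f(A, small) - f(A, theta)) / eps

# A large perturbation moves the argmax to index 1, so the output jumps
# from A[0] to A[1]. The input does affect the output, just not smoothly.
large = theta.at[1].add(10.0)

print(fd_slope)
print(f(A, theta), f(A, large))
print(jax.grad(f, argnums=1)(A, theta))
```

The output is a step function of theta: flat almost everywhere, with jumps where the argmax changes. The gradient reports the slope of the flat part.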
Now, you might argue that there are circumstances where this is not the case: for example, if two values in theta are tied or arbitrarily close together, then perturbing one of them even slightly could change their relative rank. This is true, but at such a point the gradient is undefined, because the output jumps rather than changing smoothly with the input. The good news is that this discontinuity is one-sided: if you perturb in the other direction, there is no change in rank and the gradient is well-defined. To avoid undefined gradients, most autodiff systems implicitly use this safer one-sided definition of the derivative for rank-based computations.
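The tie case can be probed directly. The following is a minimal sketch that assumes an exact tie between the first two entries and an arbitrary step of `1e-3`. It relies on `jnp.argmax` returning the first maximal index when values are tied.

```python
import jax
import jax.numpy as jnp

A = jnp.array([1.0, 2.0, 3.0])

def f(A, theta):
    return A[jnp.argmax(theta)]

# Exact tie between entries 0 and 1; argmax resolves it to index 0.
tied = jnp.array([5.0, 5.0, 3.0])
eps = 1e-3

up = f(A, tied.at[1].add(eps))     # entry 1 now wins: output switches to A[1]
down = f(A, tied.at[1].add(-eps))  # entry 0 still wins: output stays A[0]
grad_at_tie = jax.grad(f, argnums=1)(A, tied)

print(f(A, tied), up, down)
print(grad_at_tie)
```

Perturbing in one direction makes the output jump, while the other direction leaves it unchanged, and `jax.grad` still reports zeros at the tie.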
The result is that the value of the output does not change when you infinitesimally perturb the input, which is another way of saying the gradient is zero. And this is not a failure of autodiff – it is the correct gradient given the definition of differentiation that autodiff is built on. Moreover, were you to try changing to a different definition of the derivative at these discontinuities, the best you could hope for would be undefined outputs, so the definition that results in zeros is arguably more useful and correct.