I am working with JAX through NumPyro. Specifically, I want to use a B-spline function (e.g. the one implemented in `scipy.interpolate.BSpline`) to transform different points into a spline, where the input depends on some of the parameters in the model. Thus, I need to be able to differentiate the B-spline in JAX (only with respect to the input argument, and not with respect to the knots or the integer order, of course!).
I can easily use `jax.custom_vjp`, but not when JIT compilation is used, as it is in NumPyro. I looked at the following:
- https://github.com/google/jax/issues/1142
- https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html
and it seems like the best hope is to use a callback. However, I cannot entirely figure out how that would work. At https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html#using-call-to-call-a-jax-function-on-another-device-with-reverse-mode-autodiff-support, the TensorFlow example with reverse-mode autodiff seems not to use JIT.
The example
Here is Python code that works without JIT (see the `b_spline_basis()` function):
from scipy.interpolate import BSpline
import numpy as np
from numpy import typing as npt
from functools import partial
import jax
# Type alias for NumPy arrays of double-precision floats; used in the
# annotations of the functions below.
doubleArray = npt.NDArray[np.double]
# see
# https://stackoverflow.com/q/74699053/5861244
# https://en.wikipedia.org/wiki/B-spline#Derivative_expressions
def _b_spline_deriv_inner(spline: BSpline, deriv_basis: doubleArray) -> doubleArray: # type: ignore[no-any-unimported]
out = np.zeros((deriv_basis.shape[0], deriv_basis.shape[1] - 1))
for col_index in range(out.shape[1] - 1):
scale = spline.t[col_index + spline.k + 1] - spline.t[col_index + 1]
if scale != 0:
out[:, col_index] = -deriv_basis[:, col_index + 1] / scale
for col_index in range(1, out.shape[1]):
scale = spline.t[col_index + spline.k] - spline.t[col_index]
if scale != 0:
out[:, col_index] += deriv_basis[:, col_index] / scale
return float(spline.k) * out
def _b_spline_eval(spline: BSpline, x: doubleArray, deriv: int) -> doubleArray: # type: ignore[no-any-unimported]
if deriv == 0:
return spline.design_matrix(x=x, t=spline.t, k=spline.k).todense()
elif spline.k <= 0:
return np.zeros((x.shape[0], spline.t.shape[0] - spline.k - 1))
return _b_spline_deriv_inner(
spline=spline,
deriv_basis=_b_spline_eval(
BSpline(t=spline.t, k=spline.k - 1, c=np.zeros(spline.c.shape[0] + 1)), x=x, deriv=deriv - 1
),
)
@partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2))
def b_spline_basis(knots: doubleArray, order: int, deriv: int, x: doubleArray) -> doubleArray:
    """Return the B-spline basis (or its ``deriv``-th derivative) at ``x``.

    The first basis column is dropped from the result. The function is
    wrapped in ``jax.custom_vjp`` with ``knots``, ``order`` and ``deriv``
    declared non-differentiable, so only ``x`` takes part in differentiation.

    Args:
        knots: Knot vector passed to ``BSpline`` as ``t``.
        order: Spline degree passed to ``BSpline`` as ``k``.
        deriv: Derivative order to evaluate.
        x: Evaluation points.

    Returns:
        The evaluated basis with its first column removed.
    """
    # The coefficients are placeholders: only the basis is evaluated.
    placeholder_coefs = np.zeros((order + knots.shape[0] - 1))
    spline = BSpline(t=knots, k=order, c=placeholder_coefs)
    full_basis = _b_spline_eval(spline=spline, x=x, deriv=deriv)
    return full_basis[:, 1:]
def b_spline_basis_fwd(knots: doubleArray, order: int, deriv: int, x: doubleArray) -> tuple[doubleArray, doubleArray]:
    """Forward rule for ``b_spline_basis``.

    Args:
        knots: Knot vector passed to ``BSpline`` as ``t``.
        order: Spline degree passed to ``BSpline`` as ``k``.
        deriv: Derivative order to evaluate.
        x: Evaluation points.

    Returns:
        A pair ``(value, partials)``: the basis at derivative order ``deriv``
        and at order ``deriv + 1``, each with its first column dropped. The
        second element is what the backward rule receives as ``partials``.
    """
    # The coefficients are placeholders: only the basis is evaluated.
    placeholder_coefs = np.zeros(order + knots.shape[0] - 1)
    spline = BSpline(t=knots, k=order, c=placeholder_coefs)

    value = _b_spline_eval(spline=spline, x=x, deriv=deriv)[:, 1:]
    partials = _b_spline_eval(spline=spline, x=x, deriv=deriv + 1)[:, 1:]
    return value, partials
def b_spline_basis_bwd(
    knots: doubleArray, order: int, deriv: int, partials: doubleArray, grad: doubleArray
) -> tuple[doubleArray]:
    """Backward rule for ``b_spline_basis``.

    Args:
        knots: Unused here; non-differentiable argument.
        order: Unused here; non-differentiable argument.
        deriv: Unused here; non-differentiable argument.
        partials: Residual saved by the forward rule.
        grad: Incoming cotangent, multiplied element-wise with ``partials``.

    Returns:
        A one-element tuple holding the row-wise sum of ``partials * grad``,
        i.e. the cotangent for ``x``.
    """
    weighted = partials * grad
    x_cotangent = jax.numpy.sum(weighted, axis=1)
    return (x_cotangent,)
# Register the forward and backward rules that make up the custom VJP of
# b_spline_basis.
b_spline_basis.defvjp(b_spline_basis_fwd, b_spline_basis_bwd)
if __name__ == "__main__":
    # Self-test: compare b_spline_basis and its jax.grad against hard-coded
    # reference tables for derivative orders 0, 1 and 2.
    knots = np.array([0, 0, 0, 0, 0.25, 1, 1, 1, 1])
    x = np.array([0.1, 0.5, 0.9])
    order = 3

    def test_jax(basis: doubleArray, partials: doubleArray, deriv: int) -> None:
        """Check the value and the gradient of a weighted sum of the basis.

        ``basis`` is the expected basis at derivative order ``deriv`` and
        ``partials`` the expected basis at order ``deriv + 1``.
        """
        weights = jax.numpy.arange(1, basis.shape[1] + 1)

        def test_func(x: doubleArray) -> doubleArray:
            design = b_spline_basis(knots=knots, order=order, deriv=deriv, x=x)
            return jax.numpy.sum(jax.numpy.dot(design, weights))  # type: ignore[no-any-return]

        expected_value = np.sum(np.dot(basis, weights))
        assert np.allclose(test_func(x), expected_value)

        expected_grad = np.dot(partials, weights)
        assert np.allclose(jax.grad(test_func)(x), expected_grad)

    # Each literal row below holds one basis column at the three points in
    # `x`; the transpose turns the table into (points, basis columns).
    deriv0 = np.array(
        [
            0.684, 0.166666666666667, 0.00133333333333333,
            0.096, 0.444444444444444, 0.0355555555555555,
            0.004, 0.351851851851852, 0.312148148148148,
            0, 0.037037037037037, 0.650962962962963,
        ]
    ).reshape(-1, 3).T
    deriv1 = np.array(
        [
            2.52, -1, -0.04,
            1.68, -0.666666666666667, -0.666666666666667,
            0.12, 1.22222222222222, -2.29777777777778,
            0, 0.444444444444444, 3.00444444444444,
        ]
    ).reshape(-1, 3).T
    deriv2 = np.array(
        [
            -69.6, 4, 0.8,
            9.6, -5.33333333333333, 5.33333333333333,
            2.4, -2.22222222222222, -15.3777777777778,
            0, 3.55555555555556, 9.24444444444445,
        ]
    ).reshape(-1, 3).T
    deriv3 = np.array(
        [
            504, -8, -8,
            -144, 26.6666666666667, 26.6666666666667,
            24, -32.8888888888889, -32.8888888888889,
            0, 14.2222222222222, 14.2222222222222,
        ]
    ).reshape(-1, 3).T

    # Pair every table with the next one: (value at order d, value at d + 1).
    reference_tables = (deriv0, deriv1, deriv2, deriv3)
    for deriv_order, (basis_table, partials_table) in enumerate(
        zip(reference_tables, reference_tables[1:])
    ):
        test_jax(basis_table, partials_table, deriv=deriv_order)