I have a Haiku module whose `__call__` method looks like this:
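```python
import haiku as hk

class MyModule(hk.Module):
    def __call__(self, x):
        # self.Ashape, self.Bshape, A_init, B_init and easy_computation are defined elsewhere
        A = hk.get_parameter("A", shape=[self.Ashape], init=A_init)
        B = hk.get_parameter("B", shape=[self.Bshape], init=B_init)
        # Expensive step that depends only on the parameters, not on x
        C = self.demanding_computation(A, B)
        # Cheap step that combines C with the input x
        res = easy_computation(C, x)
        return res
```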
I use this module as follows:
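```python
import jax
import jax.numpy as jnp

def _forward(x):
    module = MyModule()
    return module(x)

forward = hk.without_apply_rng(hk.transform(_forward))
x_test = jnp.ones(1)
params = forward.init(jax.random.PRNGKey(42), x_test)
# Share the same params (None) and map over the leading axis of x (0)
f = jax.vmap(forward.apply, in_axes=(None, 0))
```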
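For reference, `in_axes=(None, 0)` tells `jax.vmap` to pass the first argument (the params pytree) unchanged to every call and to map over the leading axis of the second argument. A minimal plain-JAX sketch of that behaviour, where `apply_fn` and the `params` dict are hypothetical stand-ins for `forward.apply` and the Haiku params:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for forward.apply: params is a pytree, x is one example.
def apply_fn(params, x):
    return params["w"] * x + params["b"]

params = {"w": jnp.array(2.0), "b": jnp.array(1.0)}
xs = jnp.arange(4.0)  # four different inputs stacked along axis 0

# None: the whole params pytree is shared by every call; 0: map over axis 0 of xs.
f = jax.vmap(apply_fn, in_axes=(None, 0))
ys = f(params, xs)
print(ys)  # [1. 3. 5. 7.]
```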
Then I apply `f` with the same `params` to many different `x`. Is `demanding_computation`, which does not depend on `x`, then cached within the `jax.vmap` call? If not, what is the correct pattern for separating these computations so that `demanding_computation` is cached?
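One way to make the reuse explicit, sketched here in plain JAX rather than Haiku, is to split the function into an x-independent part that is run once per parameter set and a cheap per-example part that is vmapped. The bodies of `demanding_computation` and `easy_computation` and the shapes below are hypothetical placeholders; in Haiku the same split could presumably be expressed with two separately transformed functions (for example via `hk.multi_transform`), which is not shown here.

```python
import jax
import jax.numpy as jnp

# Hypothetical placeholder for the expensive step: depends only on A and B.
def demanding_computation(A, B):
    return jnp.linalg.inv(A @ A.T + jnp.eye(A.shape[0])) @ B

# Hypothetical placeholder for the cheap per-example step.
def easy_computation(C, x):
    return C @ x

# Run the expensive part once per parameter set.
precompute = jax.jit(demanding_computation)

# vmap only the cheap part; C is shared (in_axes=None) across all examples.
batched_easy = jax.jit(jax.vmap(easy_computation, in_axes=(None, 0)))

A = jnp.ones((3, 3))
B = jnp.ones((3, 3))
xs = jnp.arange(30.0).reshape(10, 3)

C = precompute(A, B)      # evaluated once; reuse C for as many batches as needed
ys = batched_easy(C, xs)  # one row per example, shape (10, 3)
```

Holding on to `C` between calls is what avoids repeating the expensive work; JAX itself does not memoize array values from one call to the next.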
I tried to test this by adding a print statement, `id_print` from `jax.experimental.host_callback`:
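```python
from jax.experimental.host_callback import id_print

# Method of MyModule; compute() stands for the actual expensive computation.
def demanding_computation(self, A, B):
    C = compute(A, B)
    id_print(C)  # host-side print of C
    return C
```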
It indeed printed only once. Is that sufficient evidence that the computation is actually cached, or is only the printing omitted in subsequent iterations?
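Instead of relying on a print side effect, one way to see what `jax.vmap` does with the x-independent part is to inspect the traced program with `jax.make_jaxpr`. Operations whose inputs are all unbatched (here because the parameter is passed with `in_axes=None`) are traced with their original shapes, so they appear once in the program rather than once per example. The function `g` is a hypothetical stand-in, with `jax.lax.sin` playing the role of the x-independent "demanding" step.

```python
import jax
import jax.numpy as jnp

def g(A, x):
    # x-independent step (stand-in for demanding_computation)
    C = jax.lax.sin(A)
    # x-dependent step (stand-in for easy_computation)
    return C @ x

A = jnp.ones((3, 3))
xs = jnp.ones((5, 3))

closed = jax.make_jaxpr(jax.vmap(g, in_axes=(None, 0)))(A, xs)
print(closed)

# Locate the sin equation and look at the shape of its output.
sin_eqns = [e for e in closed.jaxpr.eqns if e.primitive.name == "sin"]
print(sin_eqns[0].outvars[0].aval.shape)  # (3, 3): not batched over the 5 examples
```

This only describes what happens inside a single call of the vmapped function: the x-independent work is not replicated per example, but it is not a cache across calls, so each new call with the same parameters evaluates it again.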