Using JAX, I am trying to compute per-sample gradients, process them, and then reduce them back to the usual form so that I can apply a standard parameter update. My working code looks like this:
differentiate_per_sample = jit(vmap(grad(loss), in_axes=(None, 0, 0)))
gradients = differentiate_per_sample(params, x, y)
# some code
gradients_summed_over_samples = []
for layer in gradients:
    (dw, db) = layer
    (dw, db) = (np.sum(dw, axis=0), np.sum(db, axis=0))
    gradients_summed_over_samples.append((dw, db))
where gradients is of the form list(tuple(DeviceArray(...), DeviceArray(...)), ...).
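For reference, here is a minimal self-contained sketch of this setup. The two-layer network, the squared-error loss, and all shapes are assumptions made up for illustration, not my real model. It also shows that the per-layer loop can be written with jax.tree_util.tree_map, which applies a function to every array in the nested list while keeping the list-of-tuples structure.

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap

# Hypothetical two-layer network; shapes are chosen only for illustration.
def predict(params, x):
    for w, b in params[:-1]:
        x = jnp.tanh(x @ w + b)
    w, b = params[-1]
    return x @ w + b

def loss(params, x, y):
    # Squared error for a single sample.
    return jnp.sum((predict(params, x) - y) ** 2)

key = jax.random.PRNGKey(0)
k1, k2, kx, ky = jax.random.split(key, 4)
params = [
    (jax.random.normal(k1, (3, 4)), jnp.zeros(4)),
    (jax.random.normal(k2, (4, 2)), jnp.zeros(2)),
]
x = jax.random.normal(kx, (5, 3))  # 5 samples, 3 features
y = jax.random.normal(ky, (5, 2))  # 5 samples, 2 targets

differentiate_per_sample = jit(vmap(grad(loss), in_axes=(None, 0, 0)))
gradients = differentiate_per_sample(params, x, y)

# Every array in `gradients` has a leading sample axis of size 5.
# Sum that axis away on each array; the list-of-tuples structure is preserved.
gradients_summed = jax.tree_util.tree_map(lambda g: jnp.sum(g, axis=0), gradients)
```

With these toy shapes, gradients[0][0] has shape (5, 3, 4) and gradients_summed[0][0] has shape (3, 4), which is the same result the explicit loop produces.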
I then tried to rewrite the loop using vmap (I am not sure whether it gives a speedup in the end):
def sum_samples(layer):
    (dw, db) = layer
    (dw, db) = (np.sum(dw, axis=0), np.sum(db, axis=0))
    return (dw, db)
vmap(sum_samples)(gradients)
but sum_samples is called only once and not for each element in the list.
Is the list the problem, or am I misunderstanding something else?
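A small sketch that reproduces the observation, using made-up arrays and a hypothetical helper show_shapes that only records what it receives. As far as I can tell, vmap treats the list as a single pytree argument: it traces the Python function once and maps over the leading axis of every array inside the list, rather than iterating over the list elements.

```python
import jax.numpy as jnp
from jax import vmap

calls = []  # records each time the Python function body actually runs

def show_shapes(layers):
    # `layers` is still the whole list; only the leading axis of each array is gone.
    calls.append([(dw.shape, db.shape) for dw, db in layers])
    return [(dw, db) for dw, db in layers]

# Toy per-sample gradients for two layers, 5 samples each.
gradients = [
    (jnp.ones((5, 3, 4)), jnp.ones((5, 4))),
    (jnp.ones((5, 4, 2)), jnp.ones((5, 2))),
]

out = vmap(show_shapes)(gradients)
print(len(calls))  # number of times the body ran (once, during tracing)
print(calls[0])    # per-sample shapes seen inside the function
```

Inside the function the first layer shows up with shapes (3, 4) and (4,), so the sample axis is the one that was mapped away, and an np.sum(..., axis=0) there would reduce over the wrong axis.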