1- Using torch.autograd.grad
You can only get the different terms of your gradient by back-propagating multiple times through your network. To avoid having to perform multiple forward passes on your input, you can use the torch.autograd.grad utility function instead of a conventional backward call. Since it returns the gradients out of place instead of accumulating them, the gradients coming from the different terms won't pollute each other.
Here is a minimal example that shows the basic idea:
>>> import torch
>>> x = torch.rand(1, 10, requires_grad=True)
>>> lossA = x.pow(2).sum()
>>> lossB = x.mean()
Then compute the gradient of each term out of place. You have to retain the graph on all calls but the last:
>>> gradA = torch.autograd.grad(lossA, x, retain_graph=True)
>>> gradA
(tensor([[1.5810, 0.6684, 0.1467, 0.6618, 0.5067, 0.2368, 0.0971, 0.4533, 0.3511,
          1.9858]]),)
>>> gradB = torch.autograd.grad(lossB, x)
>>> gradB
(tensor([[0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
          0.1000]]),)
This method has one limitation: you receive the gradients as a tuple with one entry per input tensor, which is not that convenient.
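If you are differentiating with respect to the parameters of a whole module, one way to work around this is to zip the returned tuple back with the parameter names. A minimal sketch, assuming an illustrative nn.Linear model and two loss terms built from its output:
>>> import torch.nn as nn
>>> model = nn.Linear(10, 5)
>>> out = model(torch.rand(1, 10))
>>> lossA, lossB = out.pow(2).sum(), out.mean()
>>> params = dict(model.named_parameters())
>>> gradsA = torch.autograd.grad(lossA, list(params.values()), retain_graph=True)
>>> gradsB = torch.autograd.grad(lossB, list(params.values()))
>>> gradsA = dict(zip(params, gradsA))  # access as gradsA['weight'], gradsA['bias']
>>> gradsB = dict(zip(params, gradsB))
The gradients are returned in the same order as the inputs you pass, so zipping against the same dict keeps names and gradients aligned.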
2- Caching the results of backward
An alternative solution is to cache the gradient after each successive backward call:
>>> lossA = x.pow(2).sum()
>>> lossB = x.mean()
>>> lossA.backward(retain_graph=True)
Store the gradient and clear the .grad attribute (don't forget to do so, otherwise the gradient of lossA will pollute gradB). You will have to adapt this to the general case when handling multiple tensor parameters, as sketched after the example below:
>>> x.gradA = x.grad
>>> x.grad = None
Backward pass on the next loss term:
>>> lossB.backward()
>>> x.gradB = x.grad
Then you can interact with each gradient term locally (i.e. on each parameter separately):
>>> x.gradA, x.gradB
(tensor([[1.5810, 0.6684, 0.1467, 0.6618, 0.5067, 0.2368, 0.0971, 0.4533, 0.3511,
          1.9858]]),
 tensor([[0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
          0.1000]]))
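As for the general case mentioned above, a possible sketch for a model with multiple parameter tensors is to loop over the loss terms, back-propagate each one, snapshot every parameter's .grad with a clone, then clear the gradients before the next term (the model and the grads dictionary are illustrative):
>>> model = nn.Linear(10, 5)
>>> out = model(torch.rand(1, 10))
>>> lossA, lossB = out.pow(2).sum(), out.mean()
>>> grads = {}
>>> for name, term in [('gradA', lossA), ('gradB', lossB)]:
...     term.backward(retain_graph=True)  # the graph is shared by both terms here
...     grads[name] = {n: p.grad.clone() for n, p in model.named_parameters()}
...     model.zero_grad()  # clear .grad so the next term starts from scratch
>>> grads['gradA']['weight'].shape
torch.Size([5, 10])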
The latter method seems more practical.
This essentially comes down to torch.autograd.grad vs torch.autograd.backward, i.e. out of place vs in place... and will ultimately depend on your needs. You can read more about these two functions here.
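As a final note, having the terms separated lets you, for instance, step an optimizer on a weighted combination of them by writing that combination back into .grad. A short sketch reusing the cached terms from the example above (the weights alpha and beta and the SGD optimizer are arbitrary, for illustration):
>>> alpha, beta = 1.0, 0.5
>>> optim = torch.optim.SGD([x], lr=1e-2)
>>> x.grad = alpha * x.gradA + beta * x.gradB  # combined gradient, built by hand
>>> optim.step()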