What you get is the gradient of o back-propagated through the computational graph to A. In the end you have a gradient value for every entry of A.
It is the same as doing the following:
import torch

A = torch.empty(2, 3).uniform_(-1, 1).requires_grad_()
B = torch.empty(3, 1).uniform_(-1, 1).requires_grad_()

o = torch.matmul(A, B).sum()
o.backward()

print("A : ", A)
print("B : ", B)
print(A.grad)
A.grad in this example and do_dinput in yours are the same. If you look at the grad tensor, it is just B^T repeated in both rows.
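A quick sketch to check that claim (variable names are illustrative): each row of A.grad should equal B transposed, since d(sum(AB))/dA broadcasts B^T over the rows of A.

```python
import torch

A = torch.empty(2, 3).uniform_(-1, 1).requires_grad_()
B = torch.empty(3, 1).uniform_(-1, 1).requires_grad_()

o = torch.matmul(A, B).sum()
o.backward()

# B.t() has shape (1, 3); expand it to the (2, 3) shape of A.grad.
expected = B.t().expand_as(A.grad).detach()
print(torch.allclose(A.grad, expected))  # True
```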
To make it a bit more visual, here is what happens: we have A and B as inputs and some function f(...) that takes all values of A and B and calculates a single value. In this case the function is sum(AB).
Note: The summation doesn't change the gradients in any way.
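A minimal check of that note (a sketch, not part of the original example): the gradient of sum() with respect to each input entry is 1, so the sum only fans the gradient out without scaling it.

```python
import torch

x = torch.empty(2, 1).uniform_(-1, 1).requires_grad_()
s = x.sum()
s.backward()

# The gradient of a sum is a tensor of ones with the same shape as x.
print(x.grad)
```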
A = | x_1  x_2  x_3 |
    | x_4  x_5  x_6 |

B = | y_1 |
    | y_2 |
    | y_3 |

AB = | x_1 * y_1 + x_2 * y_2 + x_3 * y_3 |
     | x_4 * y_1 + x_5 * y_2 + x_6 * y_3 |

o = f(x_1, ..., x_6, y_1, y_2, y_3) = x_1 * y_1 + x_2 * y_2 + x_3 * y_3 + x_4 * y_1 + x_5 * y_2 + x_6 * y_3
If you now calculate the gradient, you differentiate f(...) with respect to each variable. For x_1 that gives
df/dx_1 = y_1
So the grad value in A for x_1 is equal to y_1. The same is done for every other variable, so in the end you get a grad value for every entry of A and B.
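The hand-derived partials can be checked against autograd directly (a sketch with illustrative names): df/dx_1 = y_1 lands at A.grad[0, 0], and B also gets a gradient, e.g. df/dy_1 = x_1 + x_4, the first column sum of A.

```python
import torch

A = torch.empty(2, 3).uniform_(-1, 1).requires_grad_()
B = torch.empty(3, 1).uniform_(-1, 1).requires_grad_()

f = torch.matmul(A, B).sum()
f.backward()

# df/dx_1 = y_1: compare A.grad[0, 0] with B[0, 0].
print(torch.allclose(A.grad[0, 0], B[0, 0].detach()))  # True
# df/dy_1 = x_1 + x_4: compare B.grad[0, 0] with the first column sum of A.
print(torch.allclose(B.grad[0, 0], A[:, 0].sum().detach()))  # True
```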
It works the same way in your example; you just skip the summing of the tensor.