tf.gradients computes the gradient of the sum of each output tensor with respect to each value in the input tensors. A gradient operation receives the op for which you are computing the gradient, op, and the gradient accumulated at this point, grad. In your example, grad would be a tensor with the same shape as y, and each value would be the gradient of the corresponding value in y; that is, if grad[0, 0] == 2, it means that increasing y[0, 0] by 1 will increase the sum of the output tensor by 2 (you are probably already clear on this). Now you have to compute the same thing for A and B. Say you figure out that increasing A[2, 3] by 1 increases y[0, 0] by 3 and has no effect on any other value in y. By the chain rule, that increases the sum of the output by 3 × 2 = 6, so the gradient for A[2, 3] would be 6.
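To see the "gradient of the sum" behavior concretely, here is a minimal sketch of mine (using the same TF 1.x API as the rest of this answer) showing that tf.gradients(y, x) returns the same thing as tf.gradients(tf.reduce_sum(y), x):

import tensorflow as tf

# Minimal sketch: tf.gradients differentiates the *sum* of y.
# The result has the shape of x, and entry [i, j] is d(sum(y))/dx[i, j].
x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
y = x * x  # elementwise square, so d(sum(y))/dx[i, j] = 2 * x[i, j]

g, = tf.gradients(y, x)
g_sum, = tf.gradients(tf.reduce_sum(y), x)

with tf.Session() as sess:
    print(sess.run(g))      # [[2. 4.] [6. 8.]]
    print(sess.run(g_sum))  # identical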
As an example, let's take the gradient of matrix multiplication (the MatMul op), which you can find in tensorflow/python/ops/math_grad.py:
@ops.RegisterGradient("MatMul")
def _MatMulGrad(op, grad):
  """Gradient for MatMul."""
  t_a = op.get_attr("transpose_a")
  t_b = op.get_attr("transpose_b")
  a = math_ops.conj(op.inputs[0])
  b = math_ops.conj(op.inputs[1])
  if not t_a and not t_b:
    # c = a @ b: grad_a = grad @ b^T, grad_b = a^T @ grad
    grad_a = gen_math_ops.mat_mul(grad, b, transpose_b=True)
    grad_b = gen_math_ops.mat_mul(a, grad, transpose_a=True)
  elif not t_a and t_b:
    # c = a @ b^T
    grad_a = gen_math_ops.mat_mul(grad, b)
    grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True)
  elif t_a and not t_b:
    # c = a^T @ b
    grad_a = gen_math_ops.mat_mul(b, grad, transpose_b=True)
    grad_b = gen_math_ops.mat_mul(a, grad)
  elif t_a and t_b:
    # c = a^T @ b^T
    grad_a = gen_math_ops.mat_mul(b, grad, transpose_a=True, transpose_b=True)
    grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True, transpose_b=True)
  return grad_a, grad_b
We will focus on the case where transpose_a and transpose_b are both False, so we are in the first branch, if not t_a and not t_b: (also ignore the conj, which only matters for complex values). a and b are the operands here and, as said before, grad holds the gradient of the sum of the output with respect to each value in the multiplication result. So how would things change if we increased a[0, 0] by one? Each element in the first row of the product matrix would be increased by the corresponding value in the first row of b. So the gradient for a[0, 0] is the dot product of the first row of b and the first row of grad; that is, how much each output value would increase, multiplied by the accumulated gradient of each of those values. If you think about it, the line grad_a = gen_math_ops.mat_mul(grad, b, transpose_b=True) is doing exactly that: grad_a[0, 0] will be the dot product of the first row of grad and the first row of b (because we are transposing b here), and, in general, grad_a[i, j] will be the dot product of the i-th row of grad and the j-th row of b. You can follow a similar reasoning for grad_b too.
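If you want to check this numerically, here is a small NumPy sketch of mine (not part of TensorFlow) that verifies both the grad @ b^T formula for grad_a and the analogous a^T @ grad formula for grad_b against a finite-difference estimate:

import numpy as np

# My own sketch: check the MatMul gradient formulas against finite
# differences. grad plays the role of the accumulated gradient
# backpropagated into the op.
rng = np.random.RandomState(0)
a = rng.randn(2, 3)
b = rng.randn(3, 4)
grad = rng.randn(2, 4)  # same shape as a @ b

# What the first branch of _MatMulGrad computes:
grad_a = grad @ b.T  # grad_a[i, j] = dot(grad[i, :], b[j, :])
grad_b = a.T @ grad

def numeric_grad(bump_a, i, j, eps=1e-6):
    # Change in the grad-weighted output sum when one entry is bumped.
    a2, b2 = a.copy(), b.copy()
    (a2 if bump_a else b2)[i, j] += eps
    return ((a2 @ b2 - a @ b) * grad).sum() / eps

print(np.allclose(grad_a[0, 0], numeric_grad(True, 0, 0), atol=1e-4))   # True
print(np.allclose(grad_a[0, 0], grad[0] @ b[0]))  # dot of the first rows
print(np.allclose(grad_b[0, 0], numeric_grad(False, 0, 0), atol=1e-4))  # True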
EDIT:
As an example, see how tf.gradients and the registered gradient relate to each other:
import tensorflow as tf

# Import the gradient registry to look up gradient functions
from tensorflow.python.framework.ops import _gradient_registry

# Gradient function for matrix multiplication
matmul_grad = _gradient_registry.lookup('MatMul')

# A matrix multiplication
a = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
b = tf.constant([[6, 7, 8], [9, 10, 11]], dtype=tf.float32)
c = tf.matmul(a, b)

# Gradient of sum(c) wrt each element of a
grad_c_a_1, = tf.gradients(c, a)
# The same is obtained by backpropagating an all-ones matrix
grad_c_a_2, _ = matmul_grad(c.op, tf.ones_like(c))

# Multiply each element of c by itself, but stop the gradients;
# this should scale the gradients by the values of c
cc = c * tf.stop_gradient(c)
# Regular gradients computation
grad_cc_a_1, = tf.gradients(cc, a)
# Gradient function called with c as the backpropagated gradient
grad_cc_a_2, _ = matmul_grad(c.op, c)

with tf.Session() as sess:
    print('a:')
    print(sess.run(a))
    print('b:')
    print(sess.run(b))
    print('c = matmul(a, b):')
    print(sess.run(c))
    print('tf.gradients(c, a)[0]:')
    print(sess.run(grad_c_a_1))
    print('matmul_grad(c.op, tf.ones_like(c))[0]:')
    print(sess.run(grad_c_a_2))
    print('tf.gradients(c * tf.stop_gradient(c), a)[0]:')
    print(sess.run(grad_cc_a_1))
    print('matmul_grad(c.op, c)[0]:')
    print(sess.run(grad_cc_a_2))
Output:
a:
[[1. 2.]
[3. 4.]]
b:
[[ 6. 7. 8.]
[ 9. 10. 11.]]
c = matmul(a, b):
[[24. 27. 30.]
[54. 61. 68.]]
tf.gradients(c, a)[0]:
[[21. 30.]
[21. 30.]]
matmul_grad(c.op, tf.ones_like(c))[0]:
[[21. 30.]
[21. 30.]]
tf.gradients(c * tf.stop_gradient(c), a)[0]:
[[ 573. 816.]
[1295. 1844.]]
matmul_grad(c.op, c)[0]:
[[ 573. 816.]
[1295. 1844.]]
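As a final cross-check of those last two results (just my own arithmetic, not more TensorFlow machinery): backpropagating c through the first branch amounts to computing c @ b^T.

import numpy as np

# The last two outputs above are simply c @ b.T: with c passed in as
# the accumulated gradient, the MatMul gradient for a is grad @ b.T = c @ b.T.
a = np.array([[1., 2.], [3., 4.]])
b = np.array([[6., 7., 8.], [9., 10., 11.]])
c = a @ b
print(c @ b.T)  # [[ 573.  816.]
                #  [1295. 1844.]]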