I am using a function composed of compound TensorFlow operations. However, instead of letting TensorFlow compute its derivatives with respect to one of the inputs automatically, I would like to replace that gradient with a different computation on the same inputs. Moreover, part of the computation is shared between the forward and backward passes. For example:
def func(in1, in2):
    # do something with inputs using only tf operations
    shared_rep = tf.op1(tf.op2(tf.op3(in1, in2)))  # same computation for both forward and gradient pass
    # return output of forward computation
    return tf.op4(shared_rep)

def func_grad(in1, in2):
    shared_rep = tf.op1(tf.op2(tf.op3(in1, in2)))
    # explicitly calculate gradients with respect to in1, with the intention of replacing the gradients computed by TensorFlow
    mygrad1 = tf.op5(tf.op6(shared_rep))
    return mygrad1
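in1 = tf.Variable([1.0, 2.0, 3.0])  # float dtype, so that gradients are defined
in2 = tf.Variable([2.5, 0.01])
func_val = func(in1, in2)
my_grad1 = func_grad(in1, in2)
tf_grad1 = tf.gradients(func_val, in1)[0]  # tf.gradients returns a list
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # would like tf_grad1 to equal my_grad1
    val, my1, tf1 = sess.run([func_val, my_grad1, tf_grad1])
    assert (my1 == tf1).all()  # my1 and tf1 are NumPy arrays at this point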
NOTE: This is similar to the question "How to replace or modify gradient?", with one key difference: I do not want TensorFlow to compute the gradients of a different function in the backward pass; rather, I would like to supply the gradients myself, computed by alternative TensorFlow operations on the inputs.
I am trying to use the ideas proposed in the solution to the above question and in the following post, that is, using tf.RegisterGradient and gradient_override_map to override the gradient of an identity op that wraps the forward function. This fails because, inside the alternate gradient registered for identity, I have no access to the inputs of func_grad:
@tf.RegisterGradient("CustomGrad")
def alternate_identity_grad(op, grad):
    # op.inputs[0] is the output of func(in1, in2)
    # grad is of no use, because I would like to replace it with func_grad(in1, in2)
    return grad  # placeholder: in1 and in2 are not reachable from here

g = tf.get_default_graph()
with g.gradient_override_map({"Identity": "CustomGrad"}):
    out_grad = tf.identity(func_val, name="Identity")
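One way around the missing access to in1 and in2, without registering a gradient at all, is a surrogate built with tf.stop_gradient. The sketch below is only an illustration under assumptions: it assumes TensorFlow 2.x with eager execution, a scalar forward output, and it uses tf.sin, tf.reduce_sum and a made-up gradient 2 * shared_rep as hypothetical stand-ins for the unspecified tf.op1 ... tf.op6. The inputs are also given equal shapes here, unlike in the example above.

```python
import tensorflow as tf

def func_with_replaced_grad(in1, in2):
    # Hypothetical stand-in for tf.op1(tf.op2(tf.op3(in1, in2))); computed once.
    shared_rep = tf.sin(in1 * in2)
    # Hypothetical stand-in for tf.op4: the forward value (a scalar).
    forward = tf.reduce_sum(shared_rep)
    # Hypothetical stand-in for tf.op5(tf.op6(shared_rep)): the gradient we
    # want to see for in1. It is deliberately NOT the true derivative.
    my_grad1 = 2.0 * shared_rep
    # The surrogate is linear in in1, so its derivative w.r.t. in1 is my_grad1.
    surrogate = tf.reduce_sum(in1 * tf.stop_gradient(my_grad1))
    # The value equals forward; gradients flow only through the surrogate.
    return surrogate + tf.stop_gradient(forward - surrogate)

in1 = tf.constant([1.0, 2.0, 3.0])
in2 = tf.constant([0.5, 0.25, 0.1])

with tf.GradientTape() as tape:
    tape.watch(in1)
    value = func_with_replaced_grad(in1, in2)

grad1 = tape.gradient(value, in1)
expected = 2.0 * tf.sin(in1 * in2)
```

The forward value is unchanged up to floating-point rounding, while the gradient reported for in1 is the hand-written one. For a non-scalar output the surrogate would need to be built differently.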
EDIT: After additional research, I believe this question is similar to the following question. I obtained the desired behavior by combining gradient_override_map with the hack suggested here.
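For readers on newer TensorFlow versions, the behavior asked for in this question (a hand-written gradient that reuses a representation from the forward pass) can also be sketched with the tf.custom_gradient decorator. This is not the gradient_override_map solution mentioned in the edit, just an alternative. The sketch assumes TensorFlow 2.x with eager execution, and tf.sin, tf.reduce_sum and the gradient 2 * shared_rep are hypothetical stand-ins for the unspecified tf.op1 ... tf.op6.

```python
import tensorflow as tf

@tf.custom_gradient
def func(in1, in2):
    # Hypothetical stand-in for tf.op1(tf.op2(tf.op3(in1, in2))).
    # Computed once and captured by the gradient function below.
    shared_rep = tf.sin(in1 * in2)
    # Hypothetical stand-in for tf.op4: the forward output.
    out = tf.reduce_sum(shared_rep)

    def grad(upstream):
        # Hypothetical stand-in for tf.op5(tf.op6(shared_rep)).
        # Deliberately not the true derivative, so the override is visible.
        my_grad1 = upstream * 2.0 * shared_rep
        # One gradient per input; in2 simply gets zeros in this sketch.
        return my_grad1, tf.zeros_like(in2)

    return out, grad

in1 = tf.constant([1.0, 2.0, 3.0])
in2 = tf.constant([0.5, 0.25, 0.1])

with tf.GradientTape() as tape:
    tape.watch(in1)
    func_val = func(in1, in2)

tf_grad1 = tape.gradient(func_val, in1)  # uses grad() above, not autodiff
my_grad1 = 2.0 * tf.sin(in1 * in2)       # the same formula, evaluated directly
```

Here tf_grad1 matches my_grad1 because the backward pass calls grad() instead of differentiating through tf.sin and tf.reduce_sum, and shared_rep is computed only once.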