
I am using a function composed of compound TensorFlow operations. However, instead of letting TensorFlow automatically compute its derivatives with respect to one of the inputs, I would like to replace the gradients with a different computation on the same input. Moreover, part of the calculation is shared between the forward and backward pass. For example:

def func(in1, in2):
    # do something with inputs using only tf operations
    shared_rep = tf.op1(tf.op2(tf.op3(in1, in2))) # same computation for both forward and gradient pass
    # return output of forward computation
    return tf.op4(shared_rep)

def func_grad(in1, in2):
    shared_rep = tf.op1(tf.op2(tf.op3(in1, in2)))
    # explicitly calculate gradients with respect to in1, with the intention of replacing the gradients computed by Tensorflow
    mygrad1 = tf.op5(tf.op6(shared_rep))
    return mygrad1

in1 = tf.Variable([1.0, 2.0, 3.0])
in2 = tf.Variable([2.5, 0.01])
func_val = func(in1, in2)
my_grad1 = func_grad(in1, in2)
tf_grad1 = tf.gradients(func_val, in1)[0]  # tf.gradients returns a list
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # would like tf_grad1 to equal my_grad1
    val, my1, tf1 = sess.run([func_val, my_grad1, tf_grad1])
    assert (my1 == tf1).all()
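To make the intended behavior concrete, here is a runnable TF1-style instance of the sketch above, with tf.tanh and tf.square standing in for the unspecified tf.op1..tf.op6 (these stand-ins are assumptions, not the actual ops). Running it shows that TensorFlow's autodiff result differs from the hand-supplied gradient, which is exactly the mismatch in question:

```python
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

def func(in1, in2):
    # stand-in for tf.op1(tf.op2(tf.op3(in1, in2)))
    shared_rep = tf.tanh(in1 * tf.reduce_sum(in2))
    # stand-in for tf.op4(shared_rep)
    return tf.reduce_sum(tf.square(shared_rep))

def func_grad(in1, in2):
    shared_rep = tf.tanh(in1 * tf.reduce_sum(in2))
    # stand-in for tf.op5(tf.op6(shared_rep)): deliberately NOT the true gradient
    return 2.0 * shared_rep

in1 = tf.constant([1.0, 2.0, 3.0])
in2 = tf.constant([2.5, 0.01])
func_val = func(in1, in2)
my_grad1 = func_grad(in1, in2)
tf_grad1 = tf.gradients(func_val, in1)[0]

with tf.Session() as sess:
    my1, tf1 = sess.run([my_grad1, tf_grad1])
    # my1 and tf1 differ: TensorFlow applied the chain rule through tanh,
    # while func_grad supplies a different tensor entirely.
```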

NOTE: This is similar to the question How to replace or modify gradient?, with one key difference: I am not interested in TensorFlow computing gradients of a different function in the backward pass; rather, I would like to supply the gradients myself, based on alternative TensorFlow operations on the input.

I am trying to use the ideas proposed in the solution to the above question and in the following post, i.e. using tf.RegisterGradient and gradient_override_map to override the gradient of an identity op wrapping the forward function. This fails because, inside the registered alternative gradient for Identity, I have no access to the inputs of func_grad:

@tf.RegisterGradient("CustomGrad")
def alternate_identity_grad(op, grad):
    # op.inputs[0] is the output of func(in1, in2)
    # grad is of no use, because I would like to replace it with func_grad(in1, in2)
    return grad  # placeholder: in1 and in2 are not reachable from inside this function

g = tf.get_default_graph()
with g.gradient_override_map({"Identity": "CustomGrad"}):
    out_grad = tf.identity(func_val, name="Identity")

EDIT: After additional research, I believe this question is similar to the following question. I managed to achieve the desired behavior by combining gradient_override_map with the hack suggested here.
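For reference, one way this combination can look (a sketch under assumptions, not necessarily the exact solution referred to above): register a gradient for Identity that ignores the incoming gradient and instead returns the tensor built by func_grad, captured via a Python closure, and wrap the input in1 (rather than the output) in the overridden tf.identity. As before, tf.tanh and tf.square are assumed stand-ins for tf.op1..tf.op6:

```python
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

def func(x, y):
    shared_rep = tf.tanh(x * tf.reduce_sum(y))   # stand-in shared computation
    return tf.reduce_sum(tf.square(shared_rep))  # stand-in for tf.op4

def func_grad(x, y):
    shared_rep = tf.tanh(x * tf.reduce_sum(y))
    return 2.0 * shared_rep                      # the replacement gradient

in1 = tf.constant([1.0, 2.0, 3.0])
in2 = tf.constant([2.5, 0.01])

my_grad1 = func_grad(in1, in2)

# Closure hack: the registered gradient ignores the incoming grad and
# returns the tensor captured from the enclosing scope.
@tf.RegisterGradient("ReplaceWithFuncGrad")
def _replace_grad(op, grad):
    return my_grad1

g = tf.get_default_graph()
with g.gradient_override_map({"Identity": "ReplaceWithFuncGrad"}):
    in1_wrapped = tf.identity(in1)  # wrap the *input*, not the output

func_val = func(in1_wrapped, in2)
tf_grad1 = tf.gradients(func_val, in1)[0]

with tf.Session() as sess:
    my1, tf1 = sess.run([my_grad1, tf_grad1])
    # tf1 now equals my1: autodiff through func is cut off at the
    # overridden Identity and replaced with func_grad's output.
```

Wrapping the input matters: the overridden Identity sits directly on in1's gradient path, so tf.gradients(func_val, in1) returns exactly the tensor produced by func_grad, instead of feeding a custom gradient into further backprop through the graph below.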
