
I'd like to override default gradient calculations in Tensorflow 2.0.

In TF1, each time we run something we explicitly create a Session and a Graph. We can then use the `with graph.gradient_override_map({'GradName': 'CustGradName'})` context to apply a custom-registered gradient. In TF2, graphs and sessions can still be created through the `tf.compat.v1` API, so the same approach works as follows:

# tf 2.0.0 with compatible API
import tensorflow as tf

# register a custom gradient that returns zeros
@tf.RegisterGradient('CustZero')
def _custom_grad(op, grad):
    return tf.zeros_like(op.inputs[0])

with tf.compat.v1.Session() as sess:
    with sess.graph.as_default() as g:
        x = tf.convert_to_tensor([-1.0, 0.0, 1.0, 2.0])
        # map the gradient of Relu to the registered 'CustZero' gradient
        with g.gradient_override_map({'Relu': 'CustZero'}):
            y = tf.nn.relu(x)
        dy = tf.gradients(y, x)
    print(sess.run(dy)[0])

# output: [0. 0. 0. 0.]

Now in TF2, `tf.function` and `tf.GradientTape` are the recommended, default way to compute gradients, so I'd like to do the same thing with the native TF2 API instead of `tf.compat.v1`. Any function decorated with `tf.function` builds a graph (via AutoGraph) when it is first called, and a separate graph is created for each distinct kind of input signature. If an `input_signature` is passed to `tf.function`, only one graph is created and the function accepts only tensors matching that signature; this is called a concrete function in TF2. From the concrete function we can access its associated graph and apply the gradient override, so I did the following:

# tf 2.0.0 with tf.function and tf.GradientTape
import tensorflow as tf

# register a custom gradient that returns zeros
@tf.RegisterGradient('CustZero')
def _custom_grad(op, grad):
    return tf.zeros_like(op.inputs[0])

# create a concrete function with the specified input signature
@tf.function(input_signature=(tf.TensorSpec(shape=(None,), dtype=tf.float32),))
def my_relu(x):
    return tf.nn.relu(x)
my_relu_conc = my_relu.get_concrete_function()

x = tf.constant([-1.0, 0.0, 1.0, 2.0], dtype=tf.float32)
with tf.GradientTape() as tape:
    # apply the override map on the concrete function's graph
    with my_relu_conc.graph.gradient_override_map({'Relu': 'CustZero'}):
        tape.watch(x)
        y = my_relu_conc(x)
dy = tape.gradient(y, x).numpy()
print(dy)

# output: array([0., 0., 1., 1.], dtype=float32)

As we can see, the gradient override map somehow doesn't take effect in the second method. I checked `my_relu_conc.graph._gradient_override_map` and got `{'Relu': 'CustZero'}`, so the context manager records the mapping as expected. However, the custom gradient is still not used.

Does anyone know how to use `gradient_override_map` correctly with `tf.function` and `tf.GradientTape`? Thanks!

  • You can see the answer from [mrry](https://stackoverflow.com/users/3574081/mrry) through this link: https://stackoverflow.com/a/55799378/11524628. As stated there, "there is no built-in mechanism in TensorFlow 2.0 to override all gradients for a built-in operator within a scope." You can use `@tf.custom_gradient` instead. – Hoa Nguyen Jul 30 '20 at 11:57
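
For reference, here is a minimal sketch of the `@tf.custom_gradient` approach mentioned in the comment above; the wrapper name `relu_with_zero_grad` is just illustrative:

import tensorflow as tf

# ReLU with a custom backward pass that always returns zeros,
# mirroring the 'CustZero' gradient registered in the question
@tf.custom_gradient
def relu_with_zero_grad(x):
    def grad(dy):
        # ignore the upstream gradient and propagate zeros instead
        return tf.zeros_like(x)
    return tf.nn.relu(x), grad

x = tf.constant([-1.0, 0.0, 1.0, 2.0], dtype=tf.float32)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = relu_with_zero_grad(x)
print(tape.gradient(y, x).numpy())  # expected: [0. 0. 0. 0.]

This keeps the forward ReLU but swaps in a zero backward pass for this one function, which should give the same result as the compat.v1 example without touching any graph's override map.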

0 Answers