
I'm trying to use gradient_override_map with TensorFlow 2.0. There is an example in the documentation, which I will use here as well.

In 2.0, GradientTape can be used to compute gradients as follows:

import tensorflow as tf
print(tf.version.VERSION)  # 2.0.0-alpha0

x = tf.Variable(5.0)
with tf.GradientTape() as tape:
    s_1 = tf.square(x)
print(tape.gradient(s_1, x))
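# -> 10.0, i.e. d(x**2)/dx evaluated at x = 5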

There is also the tf.custom_gradient decorator, which can be used to define the gradient for a new function (again, using the example from the docs):

import tensorflow as tf
print(tf.version.VERSION)  # 2.0.0-alpha

@tf.custom_gradient
def log1pexp(x):
    e = tf.exp(x)

    def grad(dy):
        return dy * (1 - 1 / (1 + e))

    return tf.math.log(1 + e), grad

x = tf.Variable(100.)

with tf.GradientTape() as tape:
    y = log1pexp(x)

print(tape.gradient(y, x))
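# -> 1.0 (the hand-written gradient formula stays finite even though tf.exp(100.) overflows)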

However, I would like to replace the gradient for standard functions such as tf.square. I tried to use the following code:

@tf.RegisterGradient("CustomSquare")
def _custom_square_grad(op, grad):
  return tf.constant(0)

with tf.Graph().as_default() as g:
    x = tf.Variable(5.0)
    with g.gradient_override_map({"Square": "CustomSquare"}):
        with tf.GradientTape() as tape:
            s_2 = tf.square(x, name="Square")

    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())            
        print(sess.run(tape.gradient(s_2, x)))
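        # prints 10.0 -- the "CustomSquare" override is not applied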

However, there are two issues: the gradient override does not seem to take effect (the gradient evaluates to 10.0 instead of 0.0), and I have to resort to session.run() to execute the graph. Is there a way to achieve this in "native" TensorFlow 2.0?

In TensorFlow 1.12.0, the following produces the desired output:

import tensorflow as tf
print(tf.__version__)  # 1.12.0

@tf.RegisterGradient("CustomSquare")
def _custom_square_grad(op, grad):
  return tf.constant(0)

x = tf.Variable(5.0)

g = tf.get_default_graph()
with g.gradient_override_map({"Square": "CustomSquare"}):
    s_2 = tf.square(x, name="Square")
grad = tf.gradients(s_2, x)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(grad))
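  # desired output: [0] (the overridden gradient)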

2 Answers


There is no built-in mechanism in TensorFlow 2.0 to override all gradients for a built-in operator within a scope. However, if you are able to modify the call-site for each call to the built-in operator, you can use the tf.custom_gradient decorator as follows:

@tf.custom_gradient
def custom_square(x):
  def grad(dy):
    return tf.constant(0.0)
  return tf.square(x), grad

with tf.Graph().as_default() as g:
  x = tf.Variable(5.0)
  with tf.GradientTape() as tape:
    s_2 = custom_square(x)

  with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())            
    print(sess.run(tape.gradient(s_2, x)))
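
If many call sites need the same replacement, one option is to factor the wrapping into a small helper so that each call site only swaps tf.square for the wrapped version. The helper below is only a sketch building on the answer above (with_custom_grad and the grad_fn signature (x, y, dy) are names made up here, not TensorFlow APIs); it runs eagerly in TF 2:

import tensorflow as tf

def with_custom_grad(op, grad_fn):
  # Wrap a unary op so its gradient is computed by grad_fn(x, y, dy)
  # instead of the op's registered gradient.
  @tf.custom_gradient
  def wrapped(x):
    y = op(x)
    def grad(dy):
      return grad_fn(x, y, dy)
    return y, grad
  return wrapped

# Same behaviour as custom_square above: forward pass is tf.square, gradient is 0.
custom_square = with_custom_grad(tf.square, lambda x, y, dy: tf.zeros_like(dy))

x = tf.Variable(5.0)
with tf.GradientTape() as tape:
  s_2 = custom_square(x)
print(tape.gradient(s_2, x))  # expected: 0.0
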
    Do you happen to know whether `tf.compat.v1.Session()`/`sess.run()` will remain a part of TensorFlow for the foreseeable future? – IonicSolutions Apr 23 '19 at 09:06
    The `tf.compat.v1` compatibility module contains everything (apart from `tf.contrib`) from the `tf` module in the latest release of TF 1.x. There is no plan to remove it from TensorFlow in the foreseeable future, since many libraries still depend on it, although new feature development will focus on the main module, and there may be gaps in compatibility between old- and new-style APIs (though, fortunately, this case works!). – mrry Apr 23 '19 at 14:01
  • Don't use compat, it will be gone. – Dee Oct 03 '19 at 11:45

In addition to mrry's answer, there are two points I would like to add:

(1) In TF 2, we can use tf.GradientTape without building a graph, like this:

@tf.custom_gradient
def custom_square(x):
  def grad(dy):
    return tf.constant(0.0)
  return tf.square(x), grad

with tf.GradientTape() as tape:
  x = tf.Variable(5.0)
  s_2 = custom_square(x)

print(tape.gradient(s_2,x).numpy())
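# prints 0.0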

(2) Multiply your custom grad by the previous grad (dy)

Be careful: gradient computation is chained (the chain rule), so the custom grad must be multiplied by dy (the upstream gradient). Without this, the customized function breaks whenever it is composed with other operations. Here is an example:

@tf.custom_gradient
def custom_square(x):
  def grad(dy):
    return tf.constant(4.0)
  return tf.square(x), grad

with tf.GradientTape(persistent=True) as tape:
  x = tf.Variable(5.0)
  s_2 = custom_square(x)
  s_4 = custom_square(s_2)

print("Grad from s_4 to x: ",tape.gradient(s_4,x).numpy())
print("Grad from s_4 to s_2: ",tape.gradient(s_4,s_2).numpy())
print("Grad from s_2 to x: ",tape.gradient(s_2,x).numpy())

The result:

Grad from s_4 to x:  4.0
Grad from s_4 to s_2:  4.0
Grad from s_2 to x:  4.0

The grad from s_4 to x should be 16 (the grad from s_4 to s_2 multiplied by the grad from s_2 to x), but the result is 4. That means the gradient from the previous step was not propagated.

Multiplying the custom grad by dy solves the problem:

@tf.custom_gradient
def custom_square(x):
  def grad(dy):
    return tf.constant(4.0)*dy
  return tf.square(x), grad

with tf.GradientTape(persistent=True) as tape:
  x = tf.Variable(5.0)
  s_2 = custom_square(x)
  s_4 = custom_square(s_2)

print("Grad from s_4 to x: ",tape.gradient(s_4,x).numpy())
print("Grad from s_4 to s_2: ",tape.gradient(s_4,s_2).numpy())
print("Grad from s_2 to x: ",tape.gradient(s_2,x).numpy())

Here is the result:

Grad from s_4 to x:  16.0
Grad from s_4 to s_2:  4.0
Grad from s_2 to x:  4.0
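
The grad from s_4 to x is now 16.0, i.e. (grad from s_4 to s_2) × (grad from s_2 to x) = 4.0 × 4.0, as expected from the chain rule.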

You can try the implementation through Colab here: https://colab.research.google.com/drive/1gbLopOLJiyznDA-Cr473bZEeWkWh_KGG?usp=sharing
