I'd like to override the default gradient calculations in TensorFlow 2.0.
In TF1, each time we run something we explicitly create a Session and a Graph. We can then use the with graph.gradient_override_map({'GradName': 'CustGradName'}) context to apply a custom registered gradient. In TF2, graphs and sessions can still be created through the tf.compat.v1 API, so it can be done in the same way as follows:
# tf 2.0.0 with compatible API
import tensorflow as tf

# register a custom gradient that returns all zeros
@tf.RegisterGradient('CustZero')
def _custom_grad(op, grad):
    return tf.zeros_like(op.inputs[0])

with tf.compat.v1.Session() as sess:
    with sess.graph.as_default() as g:
        x = tf.convert_to_tensor([-1.0, 0.0, 1.0, 2.0])
        # every Relu op created inside this context uses the CustZero gradient
        with g.gradient_override_map({'Relu': 'CustZero'}):
            y = tf.nn.relu(x)
            dy = tf.gradients(y, x)
        print(sess.run(dy)[0])
        # output: [0. 0. 0. 0.]
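As a side note, this compat-based approach does not carry over to plain eager code: tf.gradients is not supported when eager execution is enabled, which is one reason the tape-based API below is the recommended route in TF2. A minimal illustration (x_eager is just an illustrative name, and the exact error message may differ across versions):

# in eager mode (the TF2 default), tf.gradients cannot be used at all
x_eager = tf.constant([-1.0, 0.0, 1.0, 2.0])
try:
    tf.gradients(tf.nn.relu(x_eager), x_eager)
except RuntimeError as e:
    print(e)  # "tf.gradients is not supported when eager execution is enabled ..."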
Now in TF2, tf.function and tf.GradientTape are recommended and are the default way to compute gradients, so I'd like to do the same thing with the native TF2 API instead of tf.compat.v1. Any function decorated with tf.function builds a graph (via AutoGraph) once it is called, and a different graph is created and called for each kind of input signature. If an input_signature is passed to tf.function, only one graph is created and it accepts only tensors matching that signature; this is called a concrete function in TF2. From the concrete function we can get access to its associated graph and apply the gradient override there. So I did the following:
# tf 2.0.0 with tf.function and tf.GradientTape
import tensorflow as tf

# register a custom gradient that returns all zeros
@tf.RegisterGradient('CustZero')
def _custom_grad(op, grad):
    return tf.zeros_like(op.inputs[0])

# create a concrete function with the specified input signature
@tf.function(input_signature=(tf.TensorSpec(shape=(None,), dtype=tf.float32),))
def my_relu(x):
    return tf.nn.relu(x)

my_relu_conc = my_relu.get_concrete_function()

x = tf.constant([-1.0, 0.0, 1.0, 2.0], dtype=tf.float32)
with tf.GradientTape() as tape:
    # apply the override to the concrete function's graph
    with my_relu_conc.graph.gradient_override_map({'Relu': 'CustZero'}):
        tape.watch(x)
        y = my_relu_conc(x)
dy = tape.gradient(y, x).numpy()
print(dy)
# output: array([0., 0., 1., 1.], dtype=float32)
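For comparison, the default gradient of tf.nn.relu on the same input, taken with a plain GradientTape and no override at all, gives exactly the same values (a quick sanity check reusing x from above; tape2 and y_default are just illustrative names):

# default Relu gradient, no override involved
with tf.GradientTape() as tape2:
    tape2.watch(x)
    y_default = tf.nn.relu(x)
print(tape2.gradient(y_default, x).numpy())
# expected output: [0. 0. 1. 1.] -- identical to the result above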
As we can see, the gradient override map somehow does not take effect in the second method. I checked my_relu_conc.graph._gradient_override_map and got {'Relu': 'CustZero'}, so the context manager works as expected; however, the custom gradient is never used.
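For completeness, the check mentioned above is just reading a private attribute inside the override context (a minimal sketch; _gradient_override_map is private and may change between versions):

# inspect the override map attached to the concrete function's graph
with my_relu_conc.graph.gradient_override_map({'Relu': 'CustZero'}):
    print(my_relu_conc.graph._gradient_override_map)
    # -> {'Relu': 'CustZero'}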
Does anyone know how to use gradient_override_map together with tf.function and tf.GradientTape correctly? Thanks!