Both "Floor" and "Identity" are op type strings: the former corresponds to `tf.floor` and the latter to `tf.identity`. So what your code does, I guess, is substitute `tf.identity`'s back-propagated gradient (BPG for short) calculation for that of the `tf.floor` operations within graph `G`, while passing forward the output of `tf.reduce_mean`. This seems a little odd, because in all the applications of `gradient_override_map` I've found so far, the key of `op_type_map` is always identical to the type string of the operation that produces the output in the context. By this I mean I'm more familiar with scenarios where `tf.floor(SomeVals)` is returned, rather than `tf.reduce_mean(SomeVals)`.
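To make the split between forward and backward passes concrete, here is a toy pure-Python model of the substitution described above. It is not TensorFlow's implementation; `FORWARD`, `GRADIENT` and `run_op` are hypothetical names. With the `{"Floor": "Identity"}` mapping, the forward pass still floors the value, but the gradient is looked up under "Identity" and passes straight through.

```python
import math

# Hypothetical toy model (not TensorFlow code): every op type has a forward
# function and a registered gradient function.
FORWARD = {
    "Floor": math.floor,
    "Identity": lambda x: x,
}

# Gradient functions take (input value, upstream gradient) and return the
# gradient with respect to the input.
GRADIENT = {
    "Floor": lambda x, grad: 0.0 * grad,  # floor is piecewise constant
    "Identity": lambda x, grad: grad,     # pass the gradient straight through
}


def run_op(op_type, x, upstream_grad=1.0, override_map=None):
    """Return (forward output, input gradient) for a single op."""
    override_map = override_map or {}
    output = FORWARD[op_type](x)                    # forward: always the op's own
    grad_type = override_map.get(op_type, op_type)  # backward: possibly overridden
    input_grad = GRADIENT[grad_type](x, upstream_grad)
    return output, input_grad


print(run_op("Floor", 2.7))                                      # (2, 0.0)
print(run_op("Floor", 2.7, override_map={"Floor": "Identity"}))  # (2, 1.0)
```

Only the gradient lookup goes through the override map, which mirrors the point that the forward computation of the overridden op type is untouched.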
What `gradient_override_map({op_A_type: op_B_type})` does is replace op_A's BPG calculation with op_B's, while keeping op_A's own forward-propagation calculation. A common application of `gradient_override_map` is shown in lahwran's answer:
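```python
import tensorflow as tf

# Gradient function: scale the incoming gradient by 5.
@tf.RegisterGradient("CustomGrad")
def _const_mul_grad(unused_op, grad):
    return 5.0 * grad

g = tf.get_default_graph()
# `input` stands for an existing tensor in graph g.
with g.gradient_override_map({"Identity": "CustomGrad"}):
    output = tf.identity(input, name="Identity")
```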
By

```python
@tf.RegisterGradient("CustomGrad")
def _const_mul_grad(unused_op, grad):
    return 5.0 * grad
```

the decorator `tf.RegisterGradient("CustomGrad")` registers the gradient function defined by `_const_mul_grad(unused_op, grad)` for a customized op type, "CustomGrad",
while

```python
g = tf.get_default_graph()
with g.gradient_override_map({"Identity": "CustomGrad"}):
    output = tf.identity(input, name="Identity")
```

ensures that the outputs of the operations in graph `g` with type string "Identity" (i.e. `tf.identity`) created inside this `with` block stay as they were, whereas their BPG calculation is replaced by the gradient function registered under the type string "CustomGrad". `tf.identity` ops created outside the block keep their normal gradient.
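A quick way to check this behaviour is to compare an overridden `tf.identity` with a plain one in the same graph. The sketch below assumes the `tf.compat.v1` graph-mode API is available (it should be in TF 2.x and in late TF 1.x releases) and registers the gradient under a hypothetical name, "TimesFiveGrad", so it does not collide with an existing "CustomGrad" registration (registering the same name twice raises an error).

```python
import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode API (placeholders, sessions, tf.gradients)


# Hypothetical gradient name for this demo; same body as _const_mul_grad.
@tf.RegisterGradient("TimesFiveGrad")
def _times_five_grad(unused_op, grad):
    return 5.0 * grad


graph = tf.Graph()
with graph.as_default():
    x = tf1.placeholder(tf.float32, shape=[])
    # Only ops created inside this block pick up the override.
    with graph.gradient_override_map({"Identity": "TimesFiveGrad"}):
        y_overridden = tf.identity(x)
    y_plain = tf.identity(x)  # created outside the block: normal gradient
    grad_overridden = tf1.gradients(y_overridden, x)[0]
    grad_plain = tf1.gradients(y_plain, x)[0]

with tf1.Session(graph=graph) as sess:
    results = sess.run(
        [y_overridden, grad_overridden, y_plain, grad_plain],
        feed_dict={x: 3.0},
    )

# Both forward outputs equal the input; only the overridden gradient is scaled.
print([float(v) for v in results])  # [3.0, 5.0, 3.0, 1.0]
```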
P.S.
The type string of an op corresponds to the `OpDef.name` field of the proto that defines the operation. To find an op's `OpDef.name`, please refer to MingXing's answer under this question.
It is not necessary to name the `tf.identity` operation, since the `name` argument of `tf.identity` is optional and `gradient_override_map` matches on an op's type string, not on its name.
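If you just want to see the type string of a particular op, reading `.op.type` on the tensor it produced is usually enough. A minimal sketch (the input values are arbitrary, and no names are passed to the ops):

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    x = tf.constant([1.7, -0.3])
    floored = tf.floor(x)        # no name given
    same = tf.identity(floored)  # no name given either

# gradient_override_map is keyed on these type strings (OpDef.name).
print(floored.op.type, same.op.type)  # Floor Identity
# The op names are generated automatically and play no role in the lookup.
print(floored.op.name, same.op.name)
```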