
I was trying to modify the backward gradient calculation of the tf.stack op with tf.RegisterGradient and tf.gradient_override_map; here is my code:

import tensorflow as tf

class SynthGradBuilder(object):
    def __init__(self):
        self.num_calls = 0

    def __call__(self, x, l=1.0):
        op_name = "SynthGrad%d" % self.num_calls
        @tf.RegisterGradient(op_name)
        def _grad_synth(op, grad):
            return grad[0]

        g = tf.get_default_graph()
        with g.gradient_override_map({"stack": op_name}):
            y = tf.stack([x,x])

        self.num_calls += 1
        return y

GradSys = SynthGradBuilder()

In another script, I wrote:

import tensorflow as tf
from gradient_synthesizer import GradSys

x = tf.Variable([1,2])
y = GradSys(x, l=1)
z = tf.stack([x,x])


grad = tf.gradients(y, x, grad_ys=[[tf.convert_to_tensor([3, 4]), 
                              tf.convert_to_tensor([6, 8])]])
grad_stack = tf.gradients(z, x, grad_ys=[[tf.convert_to_tensor([3, 4]), 
                              tf.convert_to_tensor([6, 8])]])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    print "grad bp: ", sess.run(grad)
    print "grad_stack: ", sess.run(grad_stack)
    print "y: ", sess.run(y)

The expected output should be:

grad bp: [3,4];
grad_stack: [3+6, 4+8] = [9, 12];
y: [[1,2], [1,2]];

What I actually got from the code was:

[screenshot: my result]

indicating that tf.stack's backward gradients were not replaced at all, contrary to my expectation.

To check whether the discrepancy came from wrongly using "stack" as the type string of the tf.stack op, I carried out the following experiment:

[screenshot: my validation]

The first item describes tensor y; the "stack:0" suggests that the tf.stack op's registered name is "stack", which should also be its type string. So it does not seem to be the fault of "stack".

I am at a loss to figure out the cause of my code's problem. I wonder if anyone can help me with that.

Yilin He
  • This is really strange. It seems that `def _grad_synth(op, grad)` is actually never called. Would be curious to know if you could find out the cause and fix it. I will try from my side also. And will let you know if there is something. – SRC Oct 30 '17 at 13:17

1 Answer


Tl;dr: The correct code should be:

@tf.RegisterGradient(op_name)
def _grad_synth(op, grad):
  x, y = tf.unstack(grad)
  return [x, tf.zeros_like(y)]

g = tf.get_default_graph()
with g.gradient_override_map({"Pack": op_name}):
  y = tf.stack([x, x])

Because this is quite a common question, I want to explain in a bit more detail:

There are two main issues in your original code:

  1. Wrong usage of gradient_override_map:

The actual OP name for tf.stack is Pack (not Stack), so you need to override Pack instead of Stack:

`g.gradient_override_map({"Pack": op_name})`.

You may wonder how to find the actual OP name. A simple way is to probe the GraphDef by running the following code:

with tf.Graph().as_default():
  x = tf.constant(0)
  y = tf.stack([x, x])
  print(tf.get_default_graph().as_graph_def())
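
For a graph like the one above, the printed GraphDef should contain a node roughly like the following (the node name comes from the Python-level op name, while the op field is the registered type string that gradient_override_map matches against; exact attributes vary by TensorFlow version):

node {
  name: "stack"
  op: "Pack"
  ...
}

This is also why printing the tensor shows "stack:0": that is the name of the op's first output, not its type string.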

  2. Wrong gradient function:

The original gradient for Pack is a simple Unpack (official code). In your case, you still need to unstack the incoming gradient first, but only propagate the FIRST part:

@tf.RegisterGradient(op_name)
def _grad_synth(op, grad):
  x, y = tf.unstack(grad)
  return [x, tf.zeros_like(y)]
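
For comparison, the built-in gradient registered for Pack in TensorFlow 1.x is essentially just an unstack along the packing axis. A paraphrased sketch (the real code in array_grad.py uses ops.RegisterGradient and array_ops.unstack):

# Paraphrased sketch of TF's own gradient for "Pack" (see array_grad.py).
def _PackGrad(op, grad):
  # Split the incoming gradient back into one piece per stacked input.
  return tf.unstack(grad, num=op.get_attr("N"), axis=op.get_attr("axis"))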

Note that the custom _grad_synth above works for your exact case. However, if you want to support a stack of any length, you can use a slightly more general version:

@tf.RegisterGradient(op_name)
def _grad_synth(op, grad):
  x_list = tf.unstack(grad)
  for i in range(1, len(x_list)):
    x_list[i] = tf.zeros_like(x_list[i])
  return x_list
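
Putting both fixes together, a minimal sketch of the corrected gradient_synthesizer module might look like this (assuming TensorFlow 1.x graph mode; names mirror your original script):

import tensorflow as tf

class SynthGradBuilder(object):
    def __init__(self):
        self.num_calls = 0

    def __call__(self, x, l=1.0):
        op_name = "SynthGrad%d" % self.num_calls

        @tf.RegisterGradient(op_name)
        def _grad_synth(op, grad):
            # Unstack the incoming gradient, keep only the first slice,
            # and zero out all the remaining slices.
            x_list = tf.unstack(grad)
            for i in range(1, len(x_list)):
                x_list[i] = tf.zeros_like(x_list[i])
            return x_list

        g = tf.get_default_graph()
        # Override "Pack" (the real type string of tf.stack), not "stack".
        with g.gradient_override_map({"Pack": op_name}):
            y = tf.stack([x, x])

        self.num_calls += 1
        return y

GradSys = SynthGradBuilder()

With this in place, your second script should print grad bp: [3, 4] for the overridden stack, while grad_stack stays [9, 12] for the plain tf.stack.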
Max
  • Thanks for your help! Your method is a nice fit for my issue! Still curious about the "stack:0" I got by printing the op, what is it? When I use the name "stack" to retrieve the op with `tf.Graph.get_operation_by_name`, it returns the op correctly. – Yilin He Nov 02 '17 at 06:13