Say I have some custom operation binarizer used in a neural network. The operation takes a Tensor and constructs a new Tensor. I would like to modify that operation so that it is only used in the forward pass. In the backward pass, when gradients are calculated, it should simply pass through the gradients reaching it.
More concretely, say binarizer is:
import tensorflow as tf

def binarizer(input):
    # map the input from [-1, 1] to a sampling probability in [0, 1]
    prob = tf.truediv(tf.add(1.0, input), 2.0)
    # sample 0/1 with that probability and map back to {-1, +1}
    bernoulli = tf.contrib.distributions.Bernoulli(p=prob, dtype=tf.float32)
    return 2 * bernoulli.sample() - 1
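So an input close to 1 is mostly mapped to 1 and an input close to -1 mostly to -1; for example (the exact values are stochastic, of course):

x = tf.constant([0.9, -0.9, 0.0])
with tf.Session() as sess:
    print(sess.run(binarizer(x)))  # e.g. [ 1. -1.  1.], each entry a random ±1 sample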
and I set up my network:
# ...
h1_before_my_op = tf.nn.tanh(tf.matmul(x, W) + bias_h1)
h1 = binarizer(h1_before_my_op)
# ...
loss = tf.reduce_mean(tf.square(y - y_true))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)
How do I tell TensorFlow to skip gradient calculation in the backward pass?
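One way I can imagine expressing the pass-through behaviour is the stop_gradient identity trick sketched below, but I am not sure whether this is the right or idiomatic mechanism for overriding an op's gradient (the function name is just mine):

def binarizer_straight_through(input):
    binarized = binarizer(input)
    # Forward value is the binarized tensor; because (binarized - input) is
    # wrapped in stop_gradient, the gradient w.r.t. input is 1, i.e. upstream
    # gradients pass through unchanged.
    return input + tf.stop_gradient(binarized - input)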
I tried defining a custom operation as described in this answer; however, py_func cannot return Tensors (that's not what it is made for), and I get:
UnimplementedError (see above for traceback): Unsupported object type Tensor
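Roughly, the failing attempt looked like this (the helper is my paraphrase of the wrapper from that answer, and the names py_func_with_grad and passthrough_grad are my own):

import numpy as np

# Wrap tf.py_func and register a custom gradient for it via gradient_override_map.
def py_func_with_grad(func, inp, Tout, stateful=True, name=None, grad=None):
    rnd_name = 'PyFuncGrad' + str(np.random.randint(0, 1E+8))
    tf.RegisterGradient(rnd_name)(grad)
    g = tf.get_default_graph()
    with g.gradient_override_map({'PyFunc': rnd_name}):
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)

def passthrough_grad(op, grad):
    # backward pass: hand the incoming gradient through unchanged
    return grad

# Fails: tf.py_func expects func to consume and return numpy arrays,
# but binarizer builds graph ops and returns a Tensor.
h1 = py_func_with_grad(binarizer, [h1_before_my_op], [tf.float32],
                       grad=passthrough_grad)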