
I want to use a function that creates the weights for a normal dense layer; it basically behaves like an initialization function, except that it "initializes" the weights before every new forward pass.

The flow for my augmented linear layer looks like this:

input = (x, W)
W_new = g(x,W)
output = tf.matmul(x,W_new)

However, g(x,W) is not differentiable, as it involves some sampling. Luckily it also doesn't have any parameters I want to learn, so I just want to do the forward and backward pass as if I had never replaced W. Now I need to tell the automatic differentiation not to backpropagate through g(). I do this with:

W_new = tf.stop_gradient(g(x,W))

Unfortunately this does not work, as it complains about non-matching shapes. What does work is the following:

input = (x, W)
W_new = W + tf.stop_gradient(g(x,W) - W)
output = tf.matmul(x,W_new)

as suggested here: https://stackoverflow.com/a/36480182
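
For illustration, here is a minimal sketch of that working forward pass. The g below is just a placeholder for my actual sampling function, and the shapes are arbitrary:

import tensorflow as tf

def g(x, W):
    # placeholder for my non-differentiable sampling step
    return W + tf.random_normal(tf.shape(W))

x = tf.placeholder(tf.float32, [None, 4])
W = tf.Variable(tf.random_normal([4, 3]))

# forward pass uses the sampled weights, backward pass only sees W
W_new = W + tf.stop_gradient(g(x, W) - W)
output = tf.matmul(x, W_new)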

Now the forward pass seems to be OK, but I don't know how to override the gradient for the backward pass. I know that I have to use gradient_override_map for this, but I could not transfer the applications I have seen to my particular use case (I am still quite new to TF). I am not sure how to do this, and whether there isn't an easier way. I assume something similar has to be done in the very first forward pass of a given model, where all weights are initialized and we don't have to backpropagate through the init functions either.

Any help would be very much appreciated!


1 Answer


Hey @jhj, I faced the same problem too; fortunately I found this gist. Hope this helps :)

Working sample:

import tensorflow as tf
from tensorflow.python.framework import ops
import numpy as np

Define a custom py_func which also takes a grad op as argument:

def py_func(func, inp, Tout, stateful=True, name=None, grad=None):
    # Need to generate a unique name to avoid duplicates:
    rnd_name = 'PyFuncGrad' + str(np.random.randint(0, 1E+8))

    tf.RegisterGradient(rnd_name)(grad)  # see _MySquareGrad for grad example
    g = tf.get_default_graph()
    with g.gradient_override_map({"PyFunc": rnd_name, "PyFuncStateless": rnd_name}):
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)

Define a custom square function using np.square instead of tf.square:

def mysquare(x, name=None):
    with ops.name_scope(name, "Mysquare", [x]) as name:
        sqr_x = py_func(np.square,
                        [x],
                        [tf.float32],
                        name=name,
                        grad=_MySquareGrad)  # <-- here's the call to the gradient
        return sqr_x[0]

The actual gradient function:

def _MySquareGrad(op, grad):
    x = op.inputs[0]
    # the true gradient would be grad * 2 * x; 20 adds a deliberate error
    # so you can see that the custom gradient is actually being used
    return grad * 20 * x

with tf.Session() as sess:
    x = tf.constant([1., 2.])
    y = mysquare(x)
    tf.global_variables_initializer().run()

    print(x.eval(), y.eval(), tf.gradients(y, x)[0].eval())
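
For your specific case, I think you could wrap your sampling step the same way and give it a gradient that just passes the incoming gradient through to W (i.e. pretend g was the identity in W). This is only a rough sketch; g_numpy below is a hypothetical NumPy version of your g(x, W):

def g_numpy(x, W):
    # hypothetical NumPy implementation of the sampling step g(x, W)
    return (W + np.random.randn(*W.shape)).astype(np.float32)

def _identity_grad(op, grad):
    # no gradient w.r.t. x, pass the incoming gradient straight through to W
    return tf.zeros_like(op.inputs[0]), grad

def sampled_weights(x, W, name=None):
    with ops.name_scope(name, "SampledWeights", [x, W]) as name:
        W_new = py_func(g_numpy, [x, W], [tf.float32], name=name, grad=_identity_grad)[0]
        W_new.set_shape(W.get_shape())  # tf.py_func loses the static shape
        return W_new

# usage in your layer (x: layer input, W: weight variable):
# output = tf.matmul(x, sampled_weights(x, W))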