I want to use a function that creates the weights for a normal dense layer. It basically behaves like an initialization function, except that it "initializes" before every new forward pass.
The flow for my augmented linear layer looks like this:
input = (x, W)
W_new = g(x,W)
output = tf.matmul(x,W_new)
However, g(x,W) is not differentiable, because it involves sampling. Luckily, it also has no parameters I want to learn, so I want to run the forward and backward pass as if I had never replaced W. For that, I need to tell automatic differentiation not to backpropagate through g(). I tried this with:
W_new = tf.stop_gradient(g(x,W))
Unfortunately, this does not work: I get an error about non-matching shapes. What does work is the following:
input = (x, W)
W_new = W + tf.stop_gradient(g(x,W) - W)
output = tf.matmul(x,W_new)
as suggested here: https://stackoverflow.com/a/36480182
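To make the workaround concrete, here is a minimal sketch of it. It assumes TensorFlow 2.x in eager mode with `tf.GradientTape`, and it uses a hypothetical stand-in for g (Gaussian noise followed by rounding), since the real g is not shown here. It only illustrates what the gradient with respect to W looks like under this trick.

```python
import tensorflow as tf

def g(x, W):
    # Hypothetical stand-in for the real sampling function (an assumption):
    # add Gaussian noise to W and round, so the result is not usefully differentiable.
    return tf.round(W + tf.random.normal(tf.shape(W)))

x = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # shape (3, 2)
W = tf.Variable([[0.5, -1.0], [1.5, 0.25]])            # shape (2, 2)

with tf.GradientTape() as tape:
    sampled = g(x, W)
    # Forward value equals `sampled`; the gradient flows only through the leading W.
    W_new = W + tf.stop_gradient(sampled - W)
    output = tf.matmul(x, W_new)
    loss = tf.reduce_sum(output)

grad_W = tape.gradient(loss, W)

# Reference: gradient of the same loss with the plain, unreplaced W.
with tf.GradientTape() as ref_tape:
    ref_loss = tf.reduce_sum(tf.matmul(x, W))
ref_grad_W = ref_tape.gradient(ref_loss, W)
```

In this sketch, `grad_W` matches `ref_grad_W`, because the gradient of a matmul with respect to its second argument depends only on x and the upstream gradient. A gradient with respect to x, however, would be computed with the sampled `W_new` rather than with W.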
Now the forward pass seems to be OK, but I don't know how to override the gradient for the backward pass. As far as I understand, gradient_override_map is meant for this, but I could not transfer the applications I have seen to my particular use case (I am still quite new to TF), and I wonder whether there is an easier way. I assume something similar has to happen in the first forward pass of any model, where all weights are initialized and we don't backpropagate through the init functions either.
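For illustration, one possible alternative to gradient_override_map is the `tf.custom_gradient` decorator, which lets a function return its own backward pass. The following is only a sketch under assumptions: it targets TensorFlow 2.x in eager mode, `g` is again a hypothetical stand-in for the real sampling function, and `resampled_matmul` is a name made up for this example. The forward pass uses the sampled weights, while the backward pass returns the gradients of a plain `tf.matmul(x, W)`.

```python
import tensorflow as tf

def g(x, W):
    # Hypothetical stand-in for the real sampling function (an assumption).
    return tf.round(W + tf.random.normal(tf.shape(W)))

@tf.custom_gradient
def resampled_matmul(x, W):
    # Forward pass: multiply by the freshly sampled weights.
    output = tf.matmul(x, g(x, W))

    def grad(upstream):
        # Backward pass: pretend the forward pass was tf.matmul(x, W).
        grad_x = tf.matmul(upstream, W, transpose_b=True)
        grad_W = tf.matmul(x, upstream, transpose_a=True)
        return grad_x, grad_W

    return output, grad

x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
W = tf.constant([[0.5, -1.0], [1.5, 0.25]])

with tf.GradientTape() as tape:
    tape.watch([x, W])  # plain tensors must be watched explicitly
    loss = tf.reduce_sum(resampled_matmul(x, W))

grad_x, grad_W = tape.gradient(loss, [x, W])
```

Unlike the `W + tf.stop_gradient(...)` trick, this version also computes the gradient with respect to x from the original W. Whether that is the desired behavior depends on the model, so treat it as a starting point rather than a verified solution.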
Any help would be very much appreciated!