I am trying to code up an example of the Inverting Gradients method from "Deep Reinforcement Learning in Parameterized Action Space" (equation 11) in Lasagne/Theano. Basically, what I am trying to do is ensure the output of the network stays within some specified bounds, in this case [-1, 1].
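To make the rule concrete, here is a minimal NumPy sketch of how I understand equation 11. It is plain NumPy rather than Theano, and the function name `invert_gradients` and its `p_min`/`p_max` arguments are names I made up for illustration. It assumes that a positive gradient entry means "increase the parameter"; if the gradients come from a loss that is being minimized, the sign test would probably need to be flipped.

```python
import numpy as np

def invert_gradients(grad, p, p_min=-1.0, p_max=1.0):
    """Scale each gradient entry by the headroom left for its parameter.

    Hypothetical helper for illustration only, not a Lasagne/Theano API.
    """
    grad = np.asarray(grad, dtype=float)
    p = np.asarray(p, dtype=float)
    width = p_max - p_min
    # Gradient pushes p up: scale by the distance to the upper bound.
    up = (p_max - p) / width
    # Gradient pushes p down: scale by the distance to the lower bound.
    down = (p - p_min) / width
    # Element-wise choice between the two factors. A factor goes negative
    # once p has crossed the bound, which is what inverts the gradient.
    return grad * np.where(grad > 0, up, down)

scaled = invert_gradients(grad=[0.5, -0.5, 0.5], p=[0.0, 0.0, 1.0])
print(scaled)
```

The choice between the two scale factors is made per element with `np.where`, so no Python loop over the entries is needed. I suspect the symbolic version has to make the same element-wise choice.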
I have been looking at the example given here that inverts the gradient, which has helped, but at this point I am stuck. I think the best place to perform this operation is in the gradient computation method, so I copied rmsprop and am trying to edit the gradients before the updates are applied.
This is what I have so far:
def rmspropWithInvert(loss_or_grads, params, p, learning_rate=1.0, rho=0.9, epsilon=1e-6):
    clip = 2.0
    grads = lasagne.updates.get_or_compute_grads(loss_or_grads, params)
    # grads = theano.gradient.grad_clip(grads, -clip, clip)
    grads_ = []
    for grad in grads:
        grads_.append(theano.gradient.grad_clip(grad, -clip, clip))
    grads = grads_
    a, p_ = T.scalars('a', 'p_')
    z_lazy = ifelse(T.gt(a, 0.0), (1.0 - p_) / (2.0), (p_ - (-1.0)) / (2.0))
    f_lazyifelse = theano.function([a, p_], z_lazy,
                                   mode=theano.Mode(linker='vm'))
    # compute the parameter vector to invert the gradients by
    ps = theano.shared(
        np.zeros((3, 1), dtype=theano.config.floatX),
        broadcastable=(False, True))
    for i in range(3):
        ps[i] = f_lazyifelse(grads[-1][i], p[i])
    # Apply vector through computed gradients
    grads2 = []
    for grad in grads.reverse():
        grads2.append(theano.mul(ps, grad))
        ps = grad
    grads = grads2.reverse()
    print "Grad Update: " + str(grads[0])
    updates = OrderedDict()
    # Using theano constant to prevent upcasting of float32
    one = T.constant(1)
    for param, grad in zip(params, grads):
        value = param.get_value(borrow=True)
        accu = theano.shared(np.zeros(value.shape, dtype=value.dtype),
                             broadcastable=param.broadcastable)
        accu_new = rho * accu + (one - rho) * grad ** 2
        updates[accu] = accu_new
        updates[param] = param - (learning_rate * grad /
                                  T.sqrt(accu_new + epsilon))
    return updates
Maybe someone more skilled with Theano/Lasagne will see a solution? Conceptually I think the computation is easy, but coding everything in the update step symbolically has proven challenging for me. I am still getting used to Theano.