
I'm trying to use TensorFlow in my deep learning project.

When I use Momentum Gradient Descent, how do I set the weight cost strength?
(The λ in this formula.)

Guy Coder
Peter Yang

2 Answers


The weight cost/decay term is not part of the optimizers in TensorFlow.

It is easy to include, however, by adding an extra L2 penalty on the weights to the cost function:

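import tensorflow as tf  # TensorFlow 1.x graph API

C = ...  # placeholder: your initial cost function (a scalar tensor)
weight_cost = 1e-4  # the lambda from your formula; example value ("lambda" is a Python keyword)
l2_loss = tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables()])
C = C + weight_cost * l2_loss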

tf.nn.l2_loss(v) is simply 0.5 * tf.reduce_sum(v * v), so the gradient of the penalty with respect to an individual weight w is lambda * w, which should be equivalent to your linked equation.
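As a quick sanity check of that claim, here is a small plain-Python sketch (no TensorFlow required). It reimplements the 0.5 * sum(w * w) formula and compares a finite-difference gradient of lam * l2_loss(w) against lam * w. The weights and the value of lam are arbitrary examples, and the helper names are made up for this illustration.

```python
def l2_loss(w):
    """Half the sum of squares, the same formula as tf.nn.l2_loss."""
    return 0.5 * sum(x * x for x in w)


def numerical_grad(f, w, eps=1e-6):
    """Central finite-difference gradient of the scalar function f at w."""
    grad = []
    for i in range(len(w)):
        plus, minus = list(w), list(w)
        plus[i] += eps
        minus[i] -= eps
        grad.append((f(plus) - f(minus)) / (2 * eps))
    return grad


lam = 0.01            # example weight-cost strength ("lambda" is a Python keyword)
w = [0.5, -1.2, 3.0]  # example weights


def penalty(v):
    """The extra cost term: lam * l2_loss(v)."""
    return lam * l2_loss(v)


grad = numerical_grad(penalty, w)
expected = [lam * x for x in w]
for g, e in zip(grad, expected):
    print(f"numerical {g:.6f}  vs  lam * w = {e:.6f}")
```

Because the penalty is quadratic, the central difference matches lam * w up to floating-point rounding.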

Rafał Józefowicz
  • Thank you so much. I have also implemented this part in Theano, and it worked. But when I try this in TensorFlow, I still cannot get the expected result. What is different between them? Please check: http://stackoverflow.com/questions/35488019/whats-different-about-momentum-gradient-update-in-tensorflow-and-theano-like-th – Peter Yang Feb 19 '16 at 09:08
  • Be careful not to include biases in this loss, since they are also returned by tf.trainable_variables(). – gizzmole Feb 14 '17 at 00:45

Note that the formula you show does not actually describe true "weight decay", but rather L2-regularization. Many people mix these up, including well-known professors. Let me explain.

When using pure SGD (without momentum) as the optimizer, weight decay is the same thing as adding an L2-regularization term to the loss. When using any other optimizer, including Momentum, this is not true.

Weight decay (I don't know how to use TeX here, so excuse my pseudo-notation):

w[t+1] = w[t] - learning_rate * dw - weight_decay * w[t]

L2-regularization:

loss = actual_loss + lambda * 1/2 * sum(||w||_2^2 for w in network_params)

Computing the gradient of the extra L2-regularization term gives lambda * w, and inserting it into the SGD update equation

dloss_dw = dactual_loss_dw + lambda * w[t]
w[t+1] = w[t] - learning_rate * dloss_dw
       = w[t] - learning_rate * dactual_loss_dw - learning_rate * lambda * w[t]

gives the same update as weight decay, except that the decay coefficient becomes learning_rate * lambda, i.e. lambda gets mixed with the learning rate. Any other optimizer, even SGD with momentum, gives a different update rule for weight decay than for L2-regularization! See the paper "Fixing weight decay in Adam" for more details. (Edit: AFAIK, this 1987 Hinton paper introduced "weight decay", literally as "each time the weights are updated, their magnitude is also decremented by 0.4%", on page 10.)
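To make the difference concrete, here is a toy one-dimensional sketch in plain Python. It assumes a made-up quadratic loss 0.5 * (w - 3)**2 and arbitrary hyperparameters, and runs the two update rules written above: L2-regularization adds lambda * w to the gradient, while weight decay subtracts weight_decay * w[t] from the weight directly, with weight_decay set to learning_rate * lambda. The run helper is hypothetical, written only for this illustration.

```python
def grad_loss(w):
    """Gradient of the toy loss 0.5 * (w - 3) ** 2 (an assumption for this demo)."""
    return w - 3.0


def run(steps, lr, mu, lam, wd, w0=1.0):
    """SGD with momentum mu (mu = 0 gives plain SGD).

    lam: L2-regularization strength; lam * w is added to the gradient,
         so with momentum it is also accumulated in the velocity.
    wd:  weight decay; wd * w[t] is subtracted from the weight directly.
    """
    w, v = w0, 0.0
    for _ in range(steps):
        g = grad_loss(w) + lam * w
        v = mu * v + g
        w = w - lr * v - wd * w
    return w


lr, lam, steps = 0.1, 0.05, 500

# Plain SGD: L2 with lam matches weight decay with wd = lr * lam.
sgd_l2 = run(steps, lr, mu=0.0, lam=lam, wd=0.0)
sgd_wd = run(steps, lr, mu=0.0, lam=0.0, wd=lr * lam)

# SGD with momentum 0.9: the same substitution gives a different result.
mom_l2 = run(steps, lr, mu=0.9, lam=lam, wd=0.0)
mom_wd = run(steps, lr, mu=0.9, lam=0.0, wd=lr * lam)

print("plain SGD:", sgd_l2, sgd_wd)
print("momentum: ", mom_l2, mom_wd)
```

In this toy setup both plain-SGD runs settle at w ≈ 2.857 (= 3 / 1.05). With momentum, the L2 run still settles there, but the weight-decay run settles at w ≈ 2.985 (= 3 / 1.005), because the L2 term is accumulated in the velocity while the decoupled decay term is not.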

That being said, there doesn't seem to be support for "proper" weight decay in TensorFlow yet. There are a few issues discussing it, specifically because of the above paper.

One possible way to implement it is to write an op that performs the decay step manually after every optimizer step. Another way, which is what I'm currently doing, is to use an additional SGD optimizer just for the weight decay and "attach" it to your train_op. Both of these are crude workarounds, though. My current code:
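The first workaround (a separate decay step after every optimizer step) can be sketched in plain Python as below. The Momentum class and decay_step function are hypothetical stand-ins for illustration, not TensorFlow APIs; the point is that the decay acts on the weights directly and never enters the momentum buffer.

```python
class Momentum:
    """Minimal SGD-with-momentum optimizer over a list of scalar weights (illustrative)."""

    def __init__(self, lr, mu):
        self.lr = lr
        self.mu = mu
        self.velocity = None

    def step(self, weights, grads):
        if self.velocity is None:
            self.velocity = [0.0] * len(weights)
        for i, g in enumerate(grads):
            self.velocity[i] = self.mu * self.velocity[i] + g
            weights[i] -= self.lr * self.velocity[i]


def decay_step(weights, weight_decay):
    """Decoupled weight decay, run after the optimizer: w <- w - weight_decay * w."""
    for i in range(len(weights)):
        weights[i] -= weight_decay * weights[i]


weights = [1.0, -2.0]
opt = Momentum(lr=0.1, mu=0.9)
grads = [0.0, 0.0]  # zero gradients isolate the effect of the decay step
for _ in range(3):
    opt.step(weights, grads)
    decay_step(weights, weight_decay=0.01)

print(weights)        # each weight has been scaled by 0.99 three times
print(opt.velocity)   # the decay never touched the momentum buffer
```

Because the decay runs after the optimizer has already moved the weight, one step gives w[t+1] = (1 - weight_decay) * (w[t] - learning_rate * v[t+1]), which differs from the weight-decay formula above only by the small term weight_decay * learning_rate * v[t+1] (v being the momentum buffer).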

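import tensorflow as tf  # TensorFlow 1.x (tf.contrib was removed in 2.x)
from tensorflow.contrib import layers
from tensorflow.contrib.framework import arg_scope

# In the network definition:
with arg_scope([layers.conv2d, layers.fully_connected],
               weights_regularizer=layers.l2_regularizer(weight_decay)):
    net = ...  # define the network here

# Compute the actual loss of your problem here (without the regularization terms).
loss = ...
train_op = optimizer.minimize(loss, global_step=global_step)
if args.weight_decay not in (None, 0):
    # After each optimizer step, take one plain SGD step (learning rate 1.0)
    # on the summed L2 terms; this shrinks every weight by weight_decay * w.
    with tf.control_dependencies([train_op]):
        sgd = tf.train.GradientDescentOptimizer(learning_rate=1.0)
        train_op = sgd.minimize(tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)))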

This makes some use of TensorFlow's built-in bookkeeping. The arg_scope takes care of adding an L2-regularization term for every layer to the REGULARIZATION_LOSSES graph collection; I then sum them all up and minimize the sum with plain SGD, which, as shown above, corresponds to actual weight decay.
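Assuming layers.l2_regularizer(weight_decay) contributes weight_decay * tf.nn.l2_loss(w) for each weight tensor w (which is how tf.contrib.layers.l2_regularizer is defined), the extra SGD step with learning rate 1.0 works out as follows, writing wd for weight_decay:

```latex
R(w) = \mathrm{wd}\cdot\tfrac{1}{2}\lVert w\rVert_2^2
\quad\Longrightarrow\quad
\nabla_w R = \mathrm{wd}\cdot w
\quad\Longrightarrow\quad
w \leftarrow w - 1.0\cdot\nabla_w R = (1-\mathrm{wd})\,w
```

So each weight is scaled by (1 - wd) once per training step, independently of the main optimizer's learning rate and momentum, which matches the weight_decay * w term of the weight-decay formula.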

Hope that helps. If anyone has a nicer code snippet for this, or TensorFlow implements it better (i.e. in the optimizers), please share.

Edit: see also this PR which just got merged into TF.

LucasB