16

Is there a way to set a global weight decay in Keras?

I know about the layer-wise one using regularizers (https://keras.io/regularizers/), but I could not find any information about a way to set a global weight decay.

Thomas Pinetz
  • I guess you have to do it one by one. There is no universal setting on weight decay for the whole model. – James Mar 07 '17 at 11:25

3 Answers

14

There is no way to directly apply a "global" weight decay to a whole Keras model at once.

However, as I describe here, you can employ weight decay on a model by looping through its layers and manually applying the regularizers on appropriate layers. Here's the relevant code snippet:

import keras

model = keras.applications.ResNet50(include_top=True, weights='imagenet')
alpha = 0.00002  # weight decay coefficient

for layer in model.layers:
    # add an L2 penalty on the kernel of every Conv2D and Dense layer
    if isinstance(layer, keras.layers.Conv2D) or isinstance(layer, keras.layers.Dense):
        layer.add_loss(lambda layer=layer: keras.regularizers.l2(alpha)(layer.kernel))
    # also penalize the bias on layers that have one
    if hasattr(layer, 'bias_regularizer') and layer.use_bias:
        layer.add_loss(lambda layer=layer: keras.regularizers.l2(alpha)(layer.bias))
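A minimal usage sketch (optimizer and loss are placeholders): after the loop, the registered terms show up in model.losses and are included in the training loss once the model is compiled and fit:

model.compile(optimizer='adam', loss='categorical_crossentropy')
# one loss term per regularized kernel/bias
print(len(model.losses))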
craymichael
jake
    This answer is currently incorrect and unfortunately will in fact add the L2 regularization to the last layer of the model only. Furthermore, the regularization will be added to the last layer for every single Conv/Dense in the entire network. This is due to [how lambda binds](https://stackoverflow.com/questions/10452770/python-lambdas-binding-to-local-values) in Python. I have submitted an edit using lambda binding. – craymichael Oct 23 '21 at 05:51
  • @craymichael nice catch. That's very subtle. I also didn't know you could bind a lambda parameter like that in python. – erikreed Dec 20 '22 at 21:03
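A minimal, Keras-free sketch of the late-binding behaviour described in the comment above:

# all three closures see the final value of i (late binding)
funcs_late = [lambda: i for i in range(3)]
print([f() for f in funcs_late])   # [2, 2, 2]

# binding i as a default argument captures its value at definition time
funcs_bound = [lambda i=i: i for i in range(3)]
print([f() for f in funcs_bound])  # [0, 1, 2]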
11

According to this GitHub issue (https://github.com/fchollet/keras/issues/2717), there is no way to do global weight decay. I answered it here so that others who have the same problem do not have to look further for an answer.

To get global weight decay in Keras, regularizers have to be added to every layer in the model. In my models these layers are batch normalization (beta/gamma regularizer) and dense/convolution (W_regularizer/b_regularizer) layers.

Layer-wise regularization is described here: https://keras.io/regularizers/.
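A minimal sketch of that per-layer approach, assuming a Keras 2-style API where the arguments are named kernel_regularizer/bias_regularizer (the W_regularizer/b_regularizer names above are from the older Keras 1 API); the architecture and coefficient are arbitrary:

import keras
from keras import layers, regularizers

alpha = 1e-4  # weight decay coefficient, chosen arbitrarily

model = keras.Sequential([
    keras.Input(shape=(28, 28, 1)),
    # every layer with trainable weights gets its own L2 regularizer
    layers.Conv2D(32, 3, activation='relu',
                  kernel_regularizer=regularizers.l2(alpha),
                  bias_regularizer=regularizers.l2(alpha)),
    layers.BatchNormalization(beta_regularizer=regularizers.l2(alpha),
                              gamma_regularizer=regularizers.l2(alpha)),
    layers.Flatten(),
    layers.Dense(10, activation='softmax',
                 kernel_regularizer=regularizers.l2(alpha),
                 bias_regularizer=regularizers.l2(alpha)),
])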

Thomas Pinetz
5

Posting the full code to apply weight decay on a Keras model (inspired by the above post):

import tensorflow as tf

# a utility function to add weight decay after the model is defined.
def add_weight_decay(model, weight_decay):
    if (weight_decay is None) or (weight_decay == 0.0):
        return

    # recursion inside the model
    def add_decay_loss(m, factor):
        if isinstance(m, tf.keras.Model):
            for layer in m.layers:
                add_decay_loss(layer, factor)
        else:
            for param in m.trainable_weights:
                with tf.keras.backend.name_scope('weight_regularizer'):
                    regularizer = lambda param=param: tf.keras.regularizers.l2(factor)(param)
                    m.add_loss(regularizer)

    # weight decay and L2 regularization differ by a factor of 2
    add_decay_loss(model, weight_decay/2.0)
    return
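A hypothetical usage sketch of the helper above (the model and decay coefficient are placeholders):

# build any Keras model, attach the decay losses, then compile and train as usual
model = tf.keras.applications.MobileNetV2(weights=None)
add_weight_decay(model, weight_decay=1e-4)
model.compile(optimizer='adam', loss='categorical_crossentropy')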
craymichael
mathmanu