
What is a guideline for setting weight decay (e.g. the l2 penalty), and mainly, how do I track whether it's "working" throughout training? (i.e. whether weights are actually decaying, and by how much, compared to no l2 penalty).

OverLordGoldDragon

1 Answer

A common approach is "try a range of values and see what works", but its pitfall is a lack of orthogonality: l2=2e-4 may work best for network X, but not for network Y. A workaround is to guide weight decays in a subnetwork manner: (1) group layers (e.g. Conv1D stacks & LSTMs separately), (2) set a target weight norm, (3) track.

(1): See here; the same arguments and suggested weight values won't apply to convs, hence the need for separate groupings.

(2): A sound option is the l2-norm of the weight matrix being regularized; then there's the question of which axis to compute it with respect to. A feature-extraction-oriented approach is to select the channel axis (last in Keras), yielding a vector of length = number of channels/features, so that each element is the l2-norm of a channel.
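
For instance, a minimal sketch of this per-channel norm, assuming a Keras Conv1D kernel of shape (kernel_size, in_channels, out_channels); W and channel_l2 are just illustrative names:

import numpy as np

W = np.random.randn(8, 64, 32)                 # dummy Conv1D kernel; 32 output channels
channel_l2 = np.sqrt((W**2).sum(axis=(0, 1)))  # one l2-norm per output channel
print(channel_l2.shape)                        # (32,)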

(3): The l2-norm vectors can be appended to a list on each train iteration - or just their mean/max, as briefer aggregate statistics - then plotted at the end of training.
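
If you'd rather not depend on a helper library for the tracking itself, a bare-bones sketch could look like the following (track_norms and norm_history are illustrative names; the weights_norm-based example below does this more completely):

import numpy as np

def track_norms(model, norm_history):
    """Append mean & max of per-channel kernel l2-norms for each weighted layer."""
    for layer in model.layers:
        if not layer.weights:  # e.g. InputLayer
            continue
        W = layer.get_weights()[0]  # kernel; index 1 would be the bias
        channel_l2 = np.sqrt((W**2).sum(axis=tuple(range(W.ndim - 1))))
        norm_history.setdefault(layer.name, []).append(
            (channel_l2.mean(), channel_l2.max()))
    return norm_history

# usage, once per train iteration:
#   model.train_on_batch(x, y)
#   norm_history = track_norms(model, norm_history)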

A complete example is shown below; the key function, weights_norm, is given at the bottom and is taken from See RNN. I also recommend Keras AdamW for improved weight decay.

[Plot: per-layer kernel l2-norm histograms over training, for wd=2e-4 and wd=2e-3]

Interpretation:

  • wd=2e-3 decays the output layer more strongly than 2e-4, but not the input layer, suggesting a counterbalancing interaction with the bottleneck layer.
  • wd=2e-3 yields lower variance of weight norms than 2e-4
  • The output conv layer's norms grow even with 2e-3, suggesting stronger gradients toward the output
  • It would be interesting to explore the behavior with BatchNormalization added

Code & explanation; the following is done:

  1. Train & track progress

    • Make dummy model & data, select n_batches and wd (l2 penalty)
    • Set up train loop, select n_epochs
    • Create l2_stats dict to track progress
    • On each train iteration, compute weights_norm() and append to l2_stats
  2. Preprocess progress data for plotting

    • Get names of weight-decayed weights; include non-decayed in omit_names
    • l2_stats is convenient to append to, but must be converted to an np.ndarray of proper dims; unpack so that .shape == (n_epochs, n_layers, n_weights, n_batches) -> (n_rows, n_cols, hists_per_subplot). Note that this requires the number of weight matrices tracked to be the same for each layer
  3. Plot

    • Explicitly set xlims and ylim for even comparison among different wd values
    • Two stats are computed by default: np.mean (orange) and np.max. The latter is also how Keras handles the max_norm weight constraint.
import numpy as np
import tensorflow as tf
import random

np.random.seed(1)
random.seed(2)
tf.compat.v1.set_random_seed(3)

from keras.layers import Input, Conv1D
from keras.models import Model
from keras.regularizers import l2

from see_rnn import weights_norm, features_hist_v2

########### Model & data funcs ################################################
def make_model(batch_shape, layer_kw={}):
    """Conv1D autoencoder"""
    dim = batch_shape[-1]
    bdim = dim // 2

    ipt = Input(batch_shape=batch_shape)
    x   = Conv1D(dim,  8, activation='relu',   **layer_kw)(ipt)
    x   = Conv1D(bdim, 1, activation='relu',   **layer_kw)(x)  # bottleneck
    out = Conv1D(dim,  8, activation='linear', **layer_kw)(x)

    model = Model(ipt, out)
    model.compile('adam', 'mse')
    return model

def make_data(batch_shape, n_batches):
    X = Y = np.random.randn(n_batches, *batch_shape)
    return X, Y

########### Train setup #######################################################
batch_shape = (32, 100, 64)
n_epochs = 5
n_batches = 200
wd = 2e-3
layer_kw = dict(padding='same', kernel_regularizer=l2(wd))

model = make_model(batch_shape, layer_kw)
X, Y  = make_data(batch_shape, n_batches)

## Train ####################
l2_stats = {}
for epoch in range(n_epochs):
    l2_stats[epoch] = {}
    for i, (x, y) in enumerate(zip(X, Y)):
        model.train_on_batch(x, y)
        print(end='.')

        verbose = bool(i == len(X) - 1)  # print norms on the last batch of each epoch
        if verbose:
            print()
        # track kernel norms of layers 1 & 3 (input & output convs), skipping biases
        l2_stats[epoch] = weights_norm(model, [1, 3], l2_stats[epoch],
                                       omit_names='bias', verbose=verbose)
    print("Epoch", epoch + 1, "finished")
    print()

########### Preprocess funcs ##################################################
def _get_weight_names(model, layer_names, omit_names):
    if isinstance(omit_names, str):
        omit_names = [omit_names]  # avoid iterating a string character-by-character
    weight_names = []
    for name in layer_names:
        layer = model.get_layer(name=name)
        for w in layer.weights:
            if not any(to_omit in w.name for to_omit in omit_names):
                weight_names.append(w.name)
    return weight_names

def _merge_layers_and_weights(l2_stats):
    stats_merged = []
    for stats in l2_stats.values():
        x = np.array(list(stats.values()))  # (layers, weights, stats, batches)
        x = x.reshape(-1, *x.shape[2:])     # (layers-weights, stats, batches)
        stats_merged.append(x)
    return stats_merged  # (epochs, layer-weights, stats, batches)

########### Plot setup ########################################################
ylim = 5
xlims = (.4, 1.2)
omit_names = 'bias'
suptitle = "wd={:.0e}".format(wd).replace('0', '')
side_annot = "EP"
configs = {'side_annot': dict(xy=(.9, .9))}

layer_names = list(l2_stats[0].keys())
weight_names = _get_weight_names(model, layer_names, omit_names)
stats_merged = _merge_layers_and_weights(l2_stats)

## Plot ########
features_hist_v2(stats_merged, colnames=weight_names, title=suptitle,
                 xlims=xlims, ylim=ylim, side_annot=side_annot, 
                 pad_xticks=True, configs=configs)
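
Lastly, to address the "compared to no l2-penalty" part of the question, one option is to rerun the exact same loop with the regularizer disabled and compare the plots; a sketch reusing the setup above:

wd = 0
layer_kw = dict(padding='same')  # no kernel_regularizer -> no l2 penalty
model = make_model(batch_shape, layer_kw)
# ... then repeat the Train / Preprocess / Plot sections unchanged,
#     titling the figure e.g. "wd=0" for a side-by-side comparison
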
OverLordGoldDragon