
I want to constrain the parameters of an intermediate layer in a neural network to prefer discrete values: -1, 0, or 1. The idea is to add a custom objective function that would increase the loss if the parameters take any other value. Note that I want to constrain the parameters of a particular layer, not all layers.

How can I implement this in PyTorch? I want to add this custom loss to the total loss in the training loop, something like this:

custom_loss = constrain_parameters_to_be_discrete 
loss = other_loss + custom_loss 

Maybe using a Dirichlet prior might help; any pointers on this?

Rakib

2 Answers


You can use the loss function:

import torch

def custom_loss_function(x):
    # zero when every element of x is in {-1, 0, 1}, positive otherwise
    loss = torch.abs(x**2 - torch.abs(x))
    return loss.mean()

This graph plots the proposed loss for a single element:

[plot of the loss as a function of x]

As you can see, the proposed loss is zero for x={-1, 0, 1} and positive otherwise.
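
As a quick numeric check (a minimal sketch, reusing custom_loss_function and the torch import from the snippet above):

# zero exactly at the preferred values
print(custom_loss_function(torch.tensor([-1.0, 0.0, 1.0])))  # tensor(0.)
# positive in between, largest at |x| = 0.5
print(custom_loss_function(torch.tensor([-0.5, 0.5])))       # tensor(0.2500)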

Note that if you want to apply this loss to the weights of a specific layer, then your x here are the weights, not the activations of the layer.
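
For example, in the training loop you would feed the chosen layer's weight tensor into the loss. A minimal sketch (the toy model, data, and the lam strength are hypothetical; custom_loss_function is the one defined above):

import torch
import torch.nn as nn

# toy model; we only constrain the weights of the first Linear layer
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
lam = 1.0  # regularization strength (hypothetical)

x, y = torch.randn(32, 8), torch.randint(0, 2, (32,))
for _ in range(100):
    optimizer.zero_grad()
    other_loss = criterion(model(x), y)
    # penalty on the *parameters* of a specific layer, not on its outputs
    custom_loss = custom_loss_function(model[0].weight)
    loss = other_loss + lam * custom_loss
    loss.backward()
    optimizer.step()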

Shai
  • I thought about something similar, but how do you actually make sure that the *parameters* (and not the outputs) of *a specific layer* take these values? – dennlinger May 31 '21 at 12:37
  • 1
    @dennlinger you need to apply it to the _parameters_ rather than the _activations_ – Shai May 31 '21 at 12:38
  • 1
    @Shai might be worth mentioning this as right now it looks as if the `x` is some kind of output with the naming containing `loss`, IMO confusing a little. – Szymon Maszke May 31 '21 at 12:40
  • Thanks a lot, Shai, for this elegant solution. However, I tried that, and all the parameters ended up being very close to zero (akin to L2 regularization). Note that I am applying this cost to the first layer only, with no regularization for the other layers, so those parameters are free to vary, which should make it easier for the first-layer parameters to obey the constraints. I also tried l1_norm = `torch.abs(params**2-params).mean()`, which should constrain them to be 0 or 1, but that also did not work. – Rakib May 31 '21 at 19:18
  • Do you think taking the `sum` might be better than `mean`? – Rakib May 31 '21 at 19:19
  • 1
    @Rakib do you have a batch norm layer after this one? how do you initialize the weights of this layer? Is it possible you init with normal distribution with small variance, such that all the weights are very close to zero to begin with? You need to rethink the way you init the weights in this scenario – Shai May 31 '21 at 19:31
  • No batch normalization layer was used, and I kept the default random initialisation method. All layers are nn.Linear (with bias units) and nn.ReLU. The default initialisation seems to follow He et al. – Rakib May 31 '21 at 19:49
  • 2
    then what do you expect? the weoghts are init very close to zero, and you need to pay a large penalty to "cross the barrier" around -/+0.5. – Shai May 31 '21 at 19:52
  • I agree, and this is what I did. I set a lambda of 1000 to be multiplied with the loss, but no luck. do you think there might be any other reason for that? – Rakib May 31 '21 at 19:55
  • 2
    setting lambda to high value will make it even more difficult to change values from zero to -/+1. try to init the weights to {-1, 0, 1} - and see how that affect the convergence – Shai May 31 '21 at 20:02
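
For reference, a minimal sketch of such an initialization (the layer here is hypothetical; any nn.Linear you want to constrain works the same way):

import torch

layer = torch.nn.Linear(20, 10)  # the layer whose weights you constrain (hypothetical)

with torch.no_grad():
    # draw each weight uniformly from {-1, 0, 1}
    layer.weight.copy_(torch.randint(-1, 2, layer.weight.shape).float())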

Extending @Shai's answer and mixing it with this answer, one could do it more simply via a custom layer into which you pass your specific layer.

First, the derivative of torch.abs(x**2 - torch.abs(x)), taken from WolframAlpha (check here), is placed inside the regularize function (piecewise: sign(x) - 2x for |x| < 1, and 2x - sign(x) for |x| > 1).

Now the Constrainer layer:

import torch


class Constrainer(torch.nn.Module):
    def __init__(self, module, weight_decay=1.0):
        super().__init__()
        self.module = module
        self.weight_decay = weight_decay

        # Backward hook is registered on the specified module
        self.hook = self.module.register_full_backward_hook(self._weight_decay_hook)

    # Does not work with gradient accumulation; check the original answer
    # and the pointers there if that's needed
    def _weight_decay_hook(self, *_):
        for parameter in self.module.parameters():
            parameter.grad = self.regularize(parameter)

    def regularize(self, parameter):
        # Derivative of the regularization term proposed by @Shai
        sgn = torch.sign(parameter)
        return self.weight_decay * (
            (sgn - 2 * parameter) * torch.sign(1 - parameter * sgn)
        )

    def forward(self, *args, **kwargs):
        # Simply forward args and kwargs to the wrapped module
        return self.module(*args, **kwargs)

Usage is really simple (tune the weight_decay hyperparameter if you need more or less force on the params):

constrained_layer = Constrainer(torch.nn.Linear(20, 10), weight_decay=0.1)

Now you don't have to worry about different loss functions and can use your model normally.
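
For example, a minimal end-to-end sketch (the surrounding model, data, and optimizer are hypothetical):

import torch

model = torch.nn.Sequential(
    Constrainer(torch.nn.Linear(20, 10), weight_decay=0.1),  # constrained layer
    torch.nn.ReLU(),
    torch.nn.Linear(10, 2),  # ordinary, unconstrained layer
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x, y = torch.randn(32, 20), torch.randint(0, 2, (32,))
loss = torch.nn.functional.cross_entropy(model(x), y)
loss.backward()   # the backward hook applies the regularization gradient to the wrapped layer
optimizer.step()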

Szymon Maszke