I am interested in how to train a deep neural network with a custom loss function. I have seen posts on Stack Overflow, but they aren't answered. I have downloaded VGG16, frozen its weights, and added my own head. Now I want to train that network with a custom loss; how can I do that?
Please provide an example of some code you have tried so that we can help you out. – Bryce Wayne Aug 17 '20 at 06:26
3 Answers
Here is a custom RMSE loss in PyTorch. I hope this gives you a concrete idea of how to implement a custom loss function. You must create a class that inherits from nn.Module and define the initialization and the forward pass.
import torch
import torch.nn as nn

class RMSELoss(nn.Module):
    def __init__(self, eps=1e-9):
        super().__init__()
        self.mse = nn.MSELoss()
        self.eps = eps  # small constant so the square root is never taken at exactly zero

    def forward(self, yhat, y):
        # RMSE = sqrt(MSE + eps)
        loss = torch.sqrt(self.mse(yhat, y) + self.eps)
        return loss
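For completeness, here is a minimal usage sketch; the tensor shapes and random data below are placeholders for illustration, not part of the original answer.

criterion = RMSELoss()
yhat = torch.randn(16, 1, requires_grad=True)  # stand-in for model predictions
y = torch.randn(16, 1)                         # stand-in for targets

loss = criterion(yhat, y)
loss.backward()  # gradients flow through the custom loss like any built-in one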

-2 immediately, right? Did you see that I asked about TF and not PyTorch? – Dato Nefaridze Aug 17 '20 at 06:43
You can simply define a function with two input parameters (true values, predicted values) and then calculate the loss from them with your own method.
Here is a code sample:
import tensorflow as tf

def custom_loss(y_true, y_pred):
    return tf.keras.losses.mean_squared_error(y_true, y_pred)
I have used the built-in MSE from tf.keras.losses in this example, but you can compute the loss manually instead. Compile your model with this loss function:
model.compile(
    optimizer=your_optimizer,
    loss=custom_loss
)
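Since the question mentions a frozen VGG16 with a custom head, here is a minimal sketch of how such a model might be built and compiled with the custom loss above; the input size, head layers, and optimizer are assumptions for illustration.

from tensorflow.keras.applications import VGG16
from tensorflow.keras import layers, models, optimizers

# Assumed setup: frozen VGG16 backbone plus a small custom head
base = VGG16(weights="imagenet", include_top=False, input_shape=(224, 224, 3))
base.trainable = False  # freeze the pretrained weights

model = models.Sequential([
    base,
    layers.Flatten(),
    layers.Dense(256, activation="relu"),
    layers.Dense(1),  # hypothetical single-output head
])

model.compile(
    optimizer=optimizers.Adam(1e-4),
    loss=custom_loss,  # the function defined above
)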
You can also define your own custom metric to monitor during training.
def custom_metric(y_true, y_pred):
    return calculate_your_metric(y_true, y_pred)
Finally, compile with it:
model.compile(
    optimizer=your_optimizer,
    loss=custom_loss,
    metrics=[custom_metric]
)
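As a concrete example (this particular metric is an illustration, not from the original answer), an RMSE metric can be written directly with TensorFlow ops and passed in the metrics list in place of custom_metric:

def rmse_metric(y_true, y_pred):
    # root-mean-squared error over the batch
    return tf.sqrt(tf.reduce_mean(tf.square(y_pred - y_true)))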

There are several examples and repositories showing how to implement perceptual loss, which sounds like what you are referring to. Of course, you can generalize from these approaches and adapt them to different models depending on your problem. If you do so, I recommend writing about it and sharing; I don't see many examples beyond those using a pretrained VGG model, and breaking that mold would be a nice contribution! Anyway, you might find these other answers useful:
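As a rough illustration of the idea (a sketch, not taken from any of those repositories or answers), a perceptual loss in Keras can compare feature maps from a frozen, pretrained VGG16; the chosen layer, the input size, and the assumption that images are in the [0, 255] range are all illustrative choices.

import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.applications import VGG16
from tensorflow.keras.applications.vgg16 import preprocess_input

# Frozen VGG16 feature extractor; 'block3_conv3' is an arbitrary layer choice
vgg = VGG16(weights="imagenet", include_top=False, input_shape=(224, 224, 3))
vgg.trainable = False
feature_extractor = Model(vgg.input, vgg.get_layer("block3_conv3").output)

def perceptual_loss(y_true, y_pred):
    # compare VGG feature maps of the target image and the prediction
    true_features = feature_extractor(preprocess_input(y_true))
    pred_features = feature_extractor(preprocess_input(y_pred))
    return tf.reduce_mean(tf.square(true_features - pred_features))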
