
I'm currently working on a multi-class classification problem which is highly imbalanced. I want to save my model weights for the best epoch, but I'm confused about which metric I should choose.

Here's my training progress bar:

[Image: training progress bar]

I am using the ModelCheckpoint callback in tf.keras and monitoring val_loss as the metric to save the best model weights.
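
For reference, a minimal sketch of that setup (toy model, data and file path are placeholders, not from the question):

```python
import numpy as np
import tensorflow as tf

# Toy stand-in for the real model and data (the actual model isn't shown in the question).
x_train = np.random.rand(100, 20).astype("float32")
y_train = np.random.randint(0, 3, size=(100,))
x_val = np.random.rand(30, 20).astype("float32")
y_val = np.random.randint(0, 3, size=(30,))

model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation="relu", input_shape=(20,)),
    tf.keras.layers.Dense(3, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

# Keep only the weights from the epoch with the lowest validation loss.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    "best_weights.h5",        # placeholder path
    monitor="val_loss",
    mode="min",
    save_best_only=True,
    save_weights_only=True,
)

model.fit(x_train, y_train,
          validation_data=(x_val, y_val),
          epochs=10,
          callbacks=[checkpoint])
```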

As you can see in the image,

  • At the 8th epoch I got val_acc = 0.9845 but val_loss = 0.629, and precision and recall are also high here.
  • But at the 3rd epoch I got val_acc = 0.9840 with val_loss = 0.590.

I understand the difference is not huge, but in such cases, what's the ideal metric to trust on an imbalanced dataset?

user_12
  • Accuracy is not useful in highly imbalanced datasets; you should focus on precision and recall. – desertnaut May 24 '20 at 11:14
  • @desertnaut Should I also ignore validation loss on imbalanced datasets? I mean, if you look at my training, compared to epoch 3 I have better precision and recall at epoch 8, but if we consider validation loss it's higher at epoch 8 than at epoch 3. – user_12 May 24 '20 at 11:33
  • "Ignore" is probably not the correct term, but you should arguably focus on your business metrics (precision & recall); own answers [here](https://stackoverflow.com/questions/47817424/loss-accuracy-are-these-reasonable-learning-curves/47819022#47819022) and [here](https://stackoverflow.com/questions/47508874/how-does-keras-calculate-the-accuracy/47515095#47515095) might be helpful for clarifying the relation between these metrics and loss (they're about accuracy, but the rationale is the same for precision & recall, too). – desertnaut May 24 '20 at 11:43
  • @desertnaut Thanks. Also, in tf.keras are recall and precision calculated for every batch and averaged at the end, or are they calculated over the full validation dataset? I've read that the first case is not the ideal way to calculate precision or recall, but I couldn't find any info about it. – user_12 May 24 '20 at 12:03
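
Regarding the last comment above: as far as I can tell, in tf.keras (TF 2.x) Precision and Recall are stateful metrics, so the value reported at the end of evaluation is computed over the whole validation set rather than averaged per batch. A toy check of that behaviour:

```python
import tensorflow as tf

# Precision in tf.keras is stateful: update_state() accumulates the true/false-positive
# counts across batches, and result() is computed from the accumulated totals.
m = tf.keras.metrics.Precision()
m.update_state([0, 1, 1, 1], [1, 1, 1, 0])   # "batch" 1: TP=2, FP=1
m.update_state([1, 0, 1, 1], [1, 1, 1, 1])   # "batch" 2: TP=3, FP=1
print(float(m.result()))                     # 5 / 7 ≈ 0.714, pooled over both batches
```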

2 Answers


The most important factors are the validation and training error. If the validation loss (error) starts to increase, that means overfitting. You must set the number of epochs as high as possible while avoiding overfitting, and terminate training based on the error rates: as long as the validation loss keeps dropping, training should continue, until the model starts to converge at some epoch. Ideally it should converge quite well to a low val_loss.

Just bear in mind an epoch is one learning cycle where the learner can see the whole training data set. If you have two batches, the learner needs to go through two iterations for one epoch.

This link can be helpful.

You can divide the data into 3 sets: training, validation and evaluation. Train each network for a large enough number of epochs to track when the training Mean Squared Error gets stuck in a minimum.

The training process uses the training data set and should be executed epoch by epoch; then calculate the Mean Squared Error of the network on the validation set in each epoch. The network from the epoch with the minimum validation MSE is selected for the evaluation process.
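
A rough sketch of that selection procedure (toy regression data, model and epoch count are placeholders; ModelCheckpoint with save_best_only=True automates the same idea):

```python
import numpy as np
import tensorflow as tf

# Toy regression data and model, purely illustrative.
x_train = np.random.rand(200, 10).astype("float32")
y_train = np.random.rand(200, 1).astype("float32")
x_val = np.random.rand(50, 10).astype("float32")
y_val = np.random.rand(50, 1).astype("float32")

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu", input_shape=(10,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

best_val, best_weights = float("inf"), None
for epoch in range(20):                              # epoch count is arbitrary here
    model.fit(x_train, y_train, epochs=1, verbose=0)
    val_mse = model.evaluate(x_val, y_val, verbose=0)
    if val_mse < best_val:                           # new minimum of the validation error
        best_val, best_weights = val_mse, model.get_weights()

model.set_weights(best_weights)   # network from the best epoch; evaluate it on the held-out set
```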

Mahsa Hassankashi
  • Since OP is in a classification (and an imbalanced one indeed) setting, references to MSE are not correct here. You should probably replace them with the more general term "loss". – desertnaut May 24 '20 at 11:46

This can happen for several reasons. Assuming you have used a proper separation of training, test and validation sets, and preprocessing such as min-max scaling and handling missing values, you can do the following.

First, run the model for several epochs and plot the validation loss graph.

If the loss first decreases and then, after reaching a certain point, starts increasing, i.e. the graph is U-shaped, you can use early stopping.
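
A minimal sketch of this using tf.keras's EarlyStopping callback (the patience value is illustrative, and the commented-out fit call assumes a compiled model and data as in the question):

```python
import tensorflow as tf

# Stop once val_loss has not improved for `patience` epochs and roll back to the
# best weights seen so far.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=5,
    restore_best_weights=True,
)

# model.fit(x_train, y_train, validation_data=(x_val, y_val),
#           epochs=100, callbacks=[early_stop])   # assumes a compiled model and data
```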

In the other scenario, when the loss is steadily increasing, early stopping won't work. In this case, add dropout layers of 0.2–0.3 between the major layers. This will introduce randomness into the layers and stop the model from memorising.
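
A minimal sketch of what that could look like (layer sizes, activations and number of classes are placeholders, not taken from the answer):

```python
import tensorflow as tf

# Dropout of 0.2-0.3 between the major Dense layers.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation="relu", input_shape=(20,)),
    tf.keras.layers.Dropout(0.3),    # randomly zeroes 30% of activations during training
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(3, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
```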

Once you add dropouts, your model may suddenly start to behave strangely. Tweak the activation functions and the number of output nodes or Dense layers, and it will eventually come right.

desertnaut