
I am training a CNN in PyTorch with Adam, with an initial learning rate of 1e-5. My epoch contains 5039 samples and the batch size is 1. I have observed a regular spike pattern in the training loss at the end of each epoch. Here is a plot of the training loss:

[Plot: training loss per step, with spikes at every epoch boundary]

From the plot one can see a clear pattern of spikes that occur exactly at the end of each epoch. Interestingly, the spikes do not only shoot down; sometimes they shoot up as well.
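
For reference, here is a simplified sketch of the kind of training loop I mean; MyCNN, my_dataset, and num_epochs stand in for my actual model, data, and epoch count, and the loss function shown is just an example:

import torch
from torch.utils.data import DataLoader

model = MyCNN()                                             # placeholder for the actual CNN
criterion = torch.nn.MSELoss()                              # placeholder for the actual loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)   # Adam, initial learning rate 1e-5
data_loader = DataLoader(my_dataset, batch_size=1, shuffle=True)  # batch size 1, reshuffled every epoch

for epoch in range(num_epochs):
    for inputs, targets in data_loader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()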

What I don't think it is:

  • Those spikes could be explained if the dataset were not shuffled. However, I shuffle my dataset every epoch.

  • This behavior is known to happen when the final batch of the epoch is smaller than the other batches, which results in a loss of different magnitude (Why does my training loss have regular spikes?). However, that is not my case, since my batch size is 1.

One potential hack could be to apply gradient clipping before the update step. However, that does not seem to me like a good way to treat this issue.
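
For clarity, the kind of clipping I have in mind would look roughly like this (a sketch, assuming torch is imported and model, loss, and optimizer are already defined; the max_norm value of 1.0 is arbitrary):

optimizer.zero_grad()
loss.backward()
# clip the global gradient norm before the optimizer update
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()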

  1. What are your thoughts on reasons for this spike pattern?
  2. How bad is it to have such a pattern?
tivan

1 Answer


Two possibilities that I can think of:

  1. Loss logging method that resets every epoch.
  2. Small dataset.

The first possibility is the way you are logging the loss. If, for example, you accumulate the loss at each step, log the running average, and reset the accumulator at the end of each epoch, then the logged value at the end of an epoch is still influenced by the first batches of that epoch. When the counter resets at the epoch boundary, the logged loss can jump.

import numpy as np
import matplotlib.pyplot as plt

all_losses = []
for e in range(epochs):
    epoch_losses = []  # <- jump in performance when you discard the earlier losses
    for i, batch in enumerate(data_loader):
        batch_loss = ...  # compute the loss for this batch
        epoch_losses.append(batch_loss)
        # the logged value is the running average over the current epoch only
        all_losses.append(np.mean(epoch_losses))

plt.plot(all_losses)
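
If this is what is happening, one way to check is to plot the raw per-batch loss, or to smooth it with a window that is never reset at an epoch boundary. A sketch along those lines, keeping the same loop structure (the window size of 100 is arbitrary):

from collections import deque

import numpy as np
import matplotlib.pyplot as plt

window = deque(maxlen=100)  # fixed-size window, never reset at the end of an epoch
all_losses = []
for e in range(epochs):
    for i, batch in enumerate(data_loader):
        batch_loss = ...  # compute the loss for this batch
        window.append(batch_loss)
        all_losses.append(np.mean(window))  # smoothed, but independent of epoch boundaries

plt.plot(all_losses)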

Another possibility is that your dataset is so small that there is a discernible jump in performance at the beginning of each epoch, because every sample has now been seen one extra time. With larger datasets there is more noise in the training process (later batches undoing progress made on the earlier batches), so you don't see that jump.

Jacob Stern