I am training a CNN in PyTorch with Adam and an initial learning rate of 1e-5. My epoch contains 5039 samples and the batch size is 1. I have observed a regular pattern of spikes in the training loss at the end of each epoch. Here is a plot of the training loss:
From the plot one can see a clear pattern of spikes that occur exactly at the end of each epoch (every 5039 samples). Interestingly, the spikes do not only shoot down; sometimes they shoot up as well.
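For context, my training loop is essentially the following; the model and dataset below are simplified stand-ins for my actual CNN and data, only the learning rate, batch size, and sample count match my real run:

```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

# Stand-ins for my real CNN and dataset; only the sizes matter here.
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(),
    nn.Linear(16, 10),
)
dataset = TensorDataset(torch.randn(5039, 3, 32, 32),      # 5039 samples, as in my run
                        torch.randint(0, 10, (5039,)))

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)        # initial learning rate 1e-5
loader = DataLoader(dataset, batch_size=1, shuffle=True)   # batch size 1, reshuffled every epoch

step_losses = []
for epoch in range(3):
    for x, y in loader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        step_losses.append(loss.item())                    # per-step loss, i.e. what is plotted above
```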
What I don't think it is:
Those spikes could be explained by not shuffling the dataset. However, I shuffle my dataset every epoch.
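To illustrate what I mean by shuffling every epoch: I use the standard shuffle=True route in the DataLoader, which draws a new random permutation of the sample order each epoch. The tiny dataset below is only there to make the reshuffling visible:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset of 8 indexed samples, just to show the per-epoch reshuffling.
data = TensorDataset(torch.arange(8).float().unsqueeze(1))
loader = DataLoader(data, batch_size=1, shuffle=True)  # shuffle=True draws a new permutation each epoch

for epoch in range(3):
    order = [int(x.item()) for (x,) in loader]
    print(f"epoch {epoch}: order = {order}")           # a different ordering every epoch
```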
This behavior is known to happen when the final batch of an epoch is smaller than the other batches, which produces a loss of a different magnitude (Why does my training loss have regular spikes?). However, that is not my case, since my batch size is 1.
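For completeness, the usual mitigation for that cause would be to drop the incomplete final batch, which is irrelevant when every batch has size 1; the numbers below only illustrate the idea for a hypothetical larger batch size:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(5039, 3, 32, 32), torch.randint(0, 10, (5039,)))

# With batch_size=1 every batch has the same size, so a short final batch cannot occur.
# For a larger, hypothetical batch size the usual fix would be drop_last=True:
loader = DataLoader(dataset, batch_size=32, drop_last=True)
print(len(loader))  # 157 full batches; the trailing 15 samples are dropped
```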
One potential hack could be to apply gradient clipping before the update step. However, that does not seem like a proper way to treat the issue to me.
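To be concrete, the kind of clipping I have in mind is torch.nn.utils.clip_grad_norm_ applied right before optimizer.step(); the tiny model and the max_norm value below are arbitrary placeholders:

```python
import torch
from torch import nn, optim

model = nn.Linear(10, 2)                        # stand-in for the CNN
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)

x, y = torch.randn(1, 10), torch.tensor([1])    # one dummy sample (batch size 1)
optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
# Clip the gradient norm right before the update; max_norm=1.0 is an arbitrary choice.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```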
- What are your thoughts on reasons for this spike pattern?
- How bad is it to have such a pattern?