
Hey, so I'm trying my hand at image classification/transfer learning using the monkey species dataset and a resnet50 with a modified final fc layer to predict just the 10 classes. Everything works until I use model.train() and model.eval(): after the first epoch the loss starts to return nans and the accuracy drops off, as you'll see below. I'm curious why this only happens when switching between train/eval?

First I import the model, freeze its parameters, and attach the classifier:

%%capture
import numpy as np
import torch
import torch.nn as nn
from collections import OrderedDict
from torchvision import models

# Use the GPU if one is available
train_on_gpu = torch.cuda.is_available()
device = torch.device('cuda' if train_on_gpu else 'cpu')

resnet = models.resnet50(pretrained=True)

# Freeze the pretrained backbone
for param in resnet.parameters():
  param.requires_grad = False

in_features = resnet.fc.in_features


# Build custom classifier
classifier = nn.Sequential(OrderedDict([('fc1', nn.Linear(in_features, 512)),
                                        ('relu', nn.ReLU()),
                                        ('drop', nn.Dropout(0.05)),
                                        ('fc2', nn.Linear(512, 10)),
                                        ]))

# ('output', nn.LogSoftmax(dim=1))
# Replace the final fully-connected layer with the custom head
resnet.fc = classifier

resnet.to(device)
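As a quick sanity check (just a sketch, based on the setup above), I can list which parameters will actually be trained to confirm it is only the new head:

# Sketch: list the parameters that will receive gradient updates
trainable = [name for name, p in resnet.named_parameters() if p.requires_grad]
print(trainable)  # expect only the fc1/fc2 weights and biases of the new head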

Then I set my loss function, optimizer, and scheduler:

# Step : Define criterion and optimizer
criterion = nn.CrossEntropyLoss()
# pass the optimizer to the appended classifier layer
optimizer = torch.optim.SGD(resnet.parameters(), lr=0.01)
# Scheduler
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10], gamma=0.05)  
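The comment above is about optimizing only the appended classifier layer; a minimal sketch of that (assuming the head is attached as resnet.fc, as in the setup) would pass just the head's parameters, since the backbone is frozen:

# Sketch: optimize only the new head's parameters
optimizer = torch.optim.SGD(resnet.fc.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10], gamma=0.05)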

Then I set up the training and validation loops:

epochs = 20


tr_losses = []
avg_epoch_tr_loss = []
tr_accuracy = []


val_losses = []
avg_epoch_val_loss = []
val_accuracy = []
val_loss_min = np.inf


resnet.train()
for epoch in range(epochs):
  for i, batch in enumerate(train_loader):
    # Pull the data and labels from the batch
    data, label = batch
    # If available push data and label to GPU
    if train_on_gpu:
      data, label = data.to(device), label.to(device)
    # Compute the logit
    logit = resnet(data)
    # Compute the loss
    loss = criterion(logit, label)
    # Clearing the gradient
    resnet.zero_grad()
    # Backpropagate the gradients (accumulate the partial derivatives of the loss)
    loss.backward()
    # Apply the updates to the optimizer step in the opposite direction to the gradient
    optimizer.step()
    # Store the losses of each batch
    # loss.item() separates the loss from the computation graph
    tr_losses.append(loss.item())
    # Store the average accuracy of each batch
    tr_accuracy.append(label.eq(logit.argmax(dim=1)).float().mean().item())
    # Print the rolling batch training loss every 40 batches
    if i % 40 == 0 and not i == 1:
      print(f'Batch No: {i} \tAverage Training Batch Loss: {torch.tensor(tr_losses).mean():.2f}')
  # Print the running average training loss (over all batches so far)
  print(f'\nEpoch No: {epoch + 1},Training Loss: {torch.tensor(tr_losses).mean():.2f}')
  # Print the average accuracy for each epoch
  print(f'Epoch No: {epoch + 1}, Training Accuracy: {torch.tensor(tr_accuracy).mean():.2f}\n')
  # Store the avg epoch loss for plotting
  avg_epoch_tr_loss.append(torch.tensor(tr_losses).mean())


  resnet.eval()
  for i, batch in enumerate(val_loader):
    # Pull the data and labels from the batch
    data, label = batch
    # If available push data and label to GPU
    if train_on_gpu:
      data, label = data.to(device), label.to(device)
    # Compute the logits without computing the gradients
    with torch.no_grad():
      logit = resnet(data)
    # Compute the loss
    loss = criterion(logit, label)
    # Store test loss
    val_losses.append(loss.item())
    # Store the accuracy for each batch
    val_accuracy.append(label.eq(logit.argmax(dim=1)).float().mean().item())
    if i % 20 == 0 and not i == 1:
      print(f'Batch No: {i+1} \tAverage Val Batch Loss: {torch.tensor(val_losses).mean():.2f}')
  # Print the running average val loss (over all batches so far)
  print(f'\nEpoch No: {epoch + 1}, Epoch Val Loss: {torch.tensor(val_losses).mean():.2f}')
  # Print the average accuracy for each epoch    
  print(f'Epoch No: {epoch + 1}, Epoch Val Accuracy: {torch.tensor(val_accuracy).mean():.2f}\n')
  # Store the avg epoch loss for plotting
  avg_epoch_val_loss.append(torch.tensor(val_losses).mean())

  # Checkpoint the model when the epoch val loss improves
  if torch.tensor(val_losses).float().mean() <= val_loss_min:
    print("Epoch Val Loss Decreased... Saving model")
    # save current model
    torch.save(resnet.state_dict(), '/content/drive/MyDrive/1. Full Projects/Intel Image Classification/model_state.pt')
    val_loss_min = torch.tensor(val_losses).mean()
  # Step the scheduler for the next epoch
  scheduler.step()
  # Print the updated learning rate
  print('Learning Rate Set To: {:.5f}'.format(optimizer.state_dict()['param_groups'][0]['lr']),'\n')
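For reusing the checkpoint later, a short sketch (same path as in the save call above) would be:

# Sketch: restore the best checkpoint saved during training
resnet.load_state_dict(torch.load('/content/drive/MyDrive/1. Full Projects/Intel Image Classification/model_state.pt'))
resnet.eval()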

The model starts to train, but then the loss slowly turns into nan values:

Batch No: 0     Average Training Batch Loss: 9.51
Batch No: 40    Average Training Batch Loss: 1.71
Batch No: 80    Average Training Batch Loss: 1.15
Batch No: 120   Average Training Batch Loss: 0.94

Epoch No: 1,Training Loss: 0.83
Epoch No: 1, Training Accuracy: 0.78

Batch No: 1     Average Val Batch Loss: 0.39
Batch No: 21    Average Val Batch Loss: 0.56
Batch No: 41    Average Val Batch Loss: 0.54
Batch No: 61    Average Val Batch Loss: 0.54

Epoch No: 1, Epoch Val Loss: 0.55
Epoch No: 1, Epoch Val Accuracy: 0.81

Epoch Val Loss Decreased... Saving model
Learning Rate Set To: 0.01000 

Batch No: 0     Average Training Batch Loss: 0.83
Batch No: 40    Average Training Batch Loss: nan
Batch No: 80    Average Training Batch Loss: nan
Digital Moniker
  • is it possible one of your training samples has `nan` in it? – Shai Apr 20 '21 at 07:00
  • you might find useful information in [this answer](https://stackoverflow.com/a/33980220/1714410) – Shai Apr 20 '21 at 07:05
  • @Shai no nans in the data, it's just folders of images. I'll check the link too, thanks. – Digital Moniker Apr 20 '21 at 07:11
  • it might be that the transformations applied to the data create `nan`s. Does the `nan` always appear at the same time/iteration/epoch? What happens if you significantly reduce the learning rate? – Shai Apr 20 '21 at 07:18
  • @Shai so it seems my learning rate was set too high. I changed it based on that comment and it works, but I'm going to read a little more about it from the link. Yes, the nan appeared in the same place. I switched lr from 0.01 to 0.001. – Digital Moniker Apr 20 '21 at 07:19
  • @Shai I see that you're in Rehovot, I'm Irish but I live in Tel Aviv ;) – Digital Moniker Apr 20 '21 at 12:50

1 Answer


I see that `resnet.zero_grad()` is called after `logit = resnet(data)`, which causes the gradient to explode in your case.

Please do it as below:

# Clearing the gradient
optimizer.zero_grad()
logit = resnet(data)

# Compute loss
loss = criterion(logit, label)
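For completeness, a minimal sketch of the conventional per-batch order (zero the gradients, then forward, loss, backward, step), assuming the same resnet, criterion, and optimizer as in the question:

for data, label in train_loader:
  data, label = data.to(device), label.to(device)
  optimizer.zero_grad()            # clear gradients left over from the previous batch
  logit = resnet(data)             # forward pass
  loss = criterion(logit, label)   # compute the loss
  loss.backward()                  # backpropagate
  optimizer.step()                 # update the trainable parameters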
Prajot Kuvalekar
  • I'm not sure this is the issue here (based on the comments, it was a large learning rate). AFAIK you should `zero_grad()` before `backward()`, and it does not matter if it is before or after the forward pass. – Shai Apr 20 '21 at 13:10
  • Just think: he was computing gradients in the forward pass... and then he was clearing them using `.zero_grad()`... does that make sense? – Prajot Kuvalekar Apr 20 '21 at 14:26
  • He did forward, zero_grad, and then backward, step. It seems as good as zero_grad, forward, backward, step. – Shai Apr 20 '21 at 14:51
  • Just lowering the lr helped, but I also raised the lr again while calling `.zero_grad()` before the forward pass, and both worked... no more nans. I figure from reading it's safer to call it before the forward pass anyway... – Digital Moniker Apr 20 '21 at 15:56
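For reference, the change the comment threads report as the actual fix was lowering the learning rate by a factor of ten; as a sketch, with the same SGD setup as in the question:

# Sketch: the learning rate change the comments report as fixing the nans
optimizer = torch.optim.SGD(resnet.parameters(), lr=0.001)  # lowered from 0.01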