I am trying to train custom UNET model for multiclass segmentation with 25000 training images. I am running code on remote ubuntu machine via Putty SSH connection. The scripts starts running and after some epochs the entire Putty session crashes making it impossible to get the actual error that caused the crash.
I believe that in my main training script is not a problem, but maybe I have some problems in my script that create Datasets and DataLoaders. It looks like this:
import pandas as pd
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
from datetime import datetime
import torch
import torch.nn as nn
from torch.nn import ConvTranspose2d
from torch.nn import Conv2d
from torch.nn import MaxPool2d
from torch.nn import Module
from torch.nn import ModuleList
from torch.nn import ReLU
from torch.nn import functional as F
from torchvision.transforms import CenterCrop
import torchvision.transforms as T
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import time
from tqdm import tqdm
import random
from PIL import Image
to_tensor = T.ToTensor()
# determine the device to be used for training and evaluation
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# determine if we will be pinning memory during data loading
PIN_MEMORY = True if DEVICE == "cuda" else False
BATCH_SIZE = 64
# # Data importing
df_train = pd.read_csv("../../datasets/unet_cropped/df_train.csv")
df_val = pd.read_csv("../../datasets/unet_cropped/df_val.csv")
# Dictionaries
d_train = {}
for row, data in df_train.iterrows():
d_train[data['image_path']] = [data['class_1_path'], data['class_3_path'], data['class_4_path']]
d_val = {}
for row, data in df_val.iterrows():
d_val[data['image_path']] = [data['class_1_path'], data['class_3_path'], data['class_4_path']]
def image_to_gray(image_path):
# Open image
img = Image.open(image_path)
# Apply transformation and convert to Pytorch tensor
img_tensor = to_tensor(img)
# Convert into Gray-scale
img_torch_gray = T.functional.rgb_to_grayscale(img_tensor, num_output_channels= 1)
return img_torch_gray
def show(images, folder_path, title, subtitles = True):
""" images : list of images
"""
fig,axes = plt.subplots(nrows = len(images), ncols = 1, figsize=(6,7), constrained_layout=True)
fig.suptitle(title, fontsize=10)
if len(images) >1:
for i, image_id in enumerate(images):
image_path = folder_path + image_id
im_for_plot = image_to_gray(image_path)
if subtitles:
axes[i-1].set_title(image_id, fontsize=8)
axes[i-1].imshow(im_for_plot, cmap = "gray")
else:
image_id = images[0]
image_path = folder_path + image_id
im_for_plot = image_to_gray(image_path)
if subtitles:
axes.set_title(image_id, fontsize=8)
axes.imshow(im_for_plot.permute((1,2,0)), cmap = "gray")
plt.show()
def mask_to_gray4(mask_path):
""" For given list of mask_paths create 3-dim tensor.
Example:
mask_path = [mask_path_class_1, NaN, NaN]
final_mask = [ gray_scale_mask_class_1, torch_zeros(1, W, H), torch_zeros(1, W, H) ]
"""
final_mask = torch.zeros((3, 256, 256))
for i, sample in enumerate(mask_path):
# if sample is not NaN, continue
if sample!=sample:
continue
# Open image
img = Image.open(sample)
# Apply transformation and convert to Pytorch tensor
img_tensor = to_tensor(img)
# Convert into Gray-scale
img_torch_gray = T.functional.rgb_to_grayscale(img_tensor, num_output_channels= 1)
img_torch_gray[img_torch_gray>0] = 1
final_mask[i, :, :] = img_torch_gray
return final_mask
image_train_paths = []
mask_train_paths = []
for key, value in d_train.items():
image_train_paths.append(key)
mask_train_paths.append(value)
image_val_paths = []
mask_val_paths = []
for key, value in d_val.items():
image_val_paths.append(key)
mask_val_paths.append(value)
del d_train, d_val
# main function
class SegmentationDataset(Dataset):
def __init__(self, image_paths, mask_paths):
# store the image and mask filepaths
self.image_paths = image_paths
self.mask_paths = mask_paths
def __len__(self):
# return the number of total samples contained in the dataset
return len(self.image_paths)
def __getitem__(self, idx):
""" Loads and returns a sample from the dataset at the given index idx. """
# grab the image path from the current index
image_path = self.image_paths[idx]
image = image_to_gray(image_path)
# grab the mask path from the current index
mask_path = self.mask_paths[idx]
mask = mask_to_gray4(mask_path)
return (image, mask)
# # Data Loading
trainDS = SegmentationDataset(image_train_paths, mask_train_paths)
trainLoader = DataLoader(trainDS, shuffle = True, batch_size = BATCH_SIZE,
pin_memory = PIN_MEMORY)
valDS = SegmentationDataset(image_val_paths, mask_val_paths)
valLoader = DataLoader(valDS, shuffle = True, batch_size = BATCH_SIZE,
pin_memory = PIN_MEMORY)
# calculate steps per epoch for training and validation set
trainSteps = len(trainDS) // BATCH_SIZE
valSteps = len(valDS) // BATCH_SIZE
Maybe the problem has to do with some sort of memory leak or RAM constantly increasing during training, but I can't spot where exactly can that happen in this script.
Training script looks like this:
#!/usr/bin/env python
# coding: utf-8
# Import libraries
import pandas as pd
import numpy as np
import os
from datetime import datetime
import time
import random
from PIL import Image
import torch
import torch.nn as nn
from torch.nn import ConvTranspose2d
from torch.nn import Conv2d
from torch.nn import MaxPool2d, BatchNorm2d
from torch.nn import Module
from torch.nn import ModuleList
from torch.nn import ReLU
from torch.nn import functional as F
from torchvision.transforms import CenterCrop
import torchvision.transforms as T
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
torch.manual_seed(1)
from dataset import (
trainLoader,
valLoader,
trainSteps,
valSteps
)
from model import UNet
# determine the device to be used for training and evaluation
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# determine if we will be pinning memory during data loading
PIN_MEMORY = True if DEVICE == "cuda" else False
#----------------------------------------------------------------------------
NUM_CHANNELS = 1 # number of channels in the input - grayscale image
NUM_CLASSES = 3 # number of classes
NUM_LEVELS = 3 # number of levels in the U-Net model
# initialize learning rate, number of epochs to train for, and the batch size
INIT_LR = 1e-3
NUM_EPOCHS = 100
BATCH_SIZE = 64
threshold = 0.4 ######## --------------------------
dt_string = time.ctime()
model_name = "models/tuesday_25/unet" + "_th_" + str(threshold) + "_" + dt_string.replace(" ", "_")
path_model = model_name + '.pth'
path_param = model_name + '_model_param.txt'
eval_txt = model_name + "_eval.txt"
FILEPRINT = True
if FILEPRINT:
EVAL_FILE = open(eval_txt, "a+")
# print("[INFO] follow training and validation loss in last 'n' epochs by running:")
# print(f" >watch tail -n {eval_txt}")
def fileprint(*args):
if FILEPRINT:
print(*args, file = EVAL_FILE)
EVAL_FILE.flush()
else:
print(*args)
# # Training UNet model
class EarlyStopping():
# https://stackoverflow.com/questions/71998978/early-stopping-in-pytorch
def __init__(self, tolerance=50, min_delta=0):
self.tolerance = tolerance
self.min_delta = min_delta
self.counter = 0
self.early_stop = False
def __call__(self, train_loss, validation_loss):
if (validation_loss - train_loss) > self.min_delta:
# if self.counter == 0:
# print("Validation loss increase detected")
self.counter +=1
if self.counter >= self.tolerance:
self.early_stop = True
# def bce_dice_loss(predicted, truth, threshold):
# batch_size = len(truth)
# # BCE
# bce_loss = lossFunc(predicted, truth)
# # DICE
# predicted = torch.sigmoid(predicted.detach())
# predicted[predicted > threshold] = 1
# predicted[predicted <= threshold] = 0
# predicted = predicted.view(batch_size, -1)
# truth = truth.view(batch_size, -1)
# assert(predicted.shape == truth.shape)
# tp = (predicted * truth).sum(-1)
# fp = ((truth == 0.0).float() * predicted).sum(-1)
# fn = ((truth >= 1.0).float() * (predicted == 0.0).float()).sum(-1)
# dice_score = 2*tp / (2*tp + fp + fn)
# # BCE DICE
# bce_dice = 0.75 * bce_loss + 0.25 * (1 - dice_score)
# batch_bce_dice_loss = torch.nanmean(bce_dice)
# return batch_bce_dice_loss, (torch.nanmean(dice_score)).item()
# ---------------------------- Initialize UNet model -------------------------
model = UNet(nbClasses = NUM_CLASSES).to(DEVICE)
lossFunc = nn.BCEWithLogitsLoss()
opt = torch.optim.RAdam(model.parameters(), lr=INIT_LR)
scaler = torch.cuda.amp.GradScaler()
torch.autograd.set_detect_anomaly(True)
early_stopping = EarlyStopping(tolerance=500, min_delta=1e-5)
# initialize a dictionary to store training history
H = {"train_loss": [], "val_loss": [], "dice_score": []}
# ----------------------------- Training UNet model ----------------------------
# print("[INFO] training the network...")
startTime = time.time()
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
# print(f"Num of params: {params}")
# for e in tqdm(range(NUM_EPOCHS)):
for e in range(NUM_EPOCHS):
model.train() # set the model in training mode
totalTrainLoss = 0 # initialize the total training and validation loss
totalValLoss = 0
# totalValDiceScore = 0
# totalTrainDiceScore = 0
# loop over the training set
for (i, (x, y)) in enumerate(trainLoader):
opt.zero_grad() # first, zero out any previously accumulated gradients,
(x, y) = (x.to(DEVICE), y.to(DEVICE)) # send the input to the device
pred = model(x) # perform a forward pass
# print("pred :", pred.size())
# print("y:" , y.size())
# loss, dice_score = bce_dice_loss(pred, y, threshold)
loss = lossFunc(pred, y.float()) # calculate the training loss
scaler.scale(loss).backward() # perform backpropagation
scaler.step(opt)
scaler.update() # update model parameters
totalTrainLoss += loss.item() # add the loss to the total training loss
# totalTrainDiceScore += float(dice_score)
# switch off autograd
with torch.no_grad():
model.eval() # set the model in evaluation mode
# loop over the validation set
for (x, y) in valLoader:
(x, y) = (x.to(DEVICE), y.to(DEVICE))
pred = model(x)
loss = lossFunc(pred, y.float())
# loss, dice_score = bce_dice_loss(pred, y, threshold)
totalValLoss += loss.item()
# totalValDiceScore += float(dice_score)
# calculate the average training and validation loss
avgTrainLoss = totalTrainLoss / trainSteps
# avgTrainDiceScore = totalTrainDiceScore/ trainSteps
avgValLoss = totalValLoss / valSteps
# avgValDiceScore = totalValDiceScore/ valSteps
# early_stopping(avgTrainLoss, avgValLoss)
# if early_stopping.early_stop:
if (e>0) & ((e%5) == 0):
# print("We are at epoch: ", e)
path_model = model_name + ".pth"
path_param = model_name + '_model_param.txt'
torch.save(model.state_dict(), path_model)
with open(path_param, 'wt') as f:
f.write(f"Batch size used: {BATCH_SIZE}")
f.write(f"\nNumber of epochs: {NUM_EPOCHS}")
f.write(f"\nINIT_LR: {INIT_LR}")
f.write("\nModel parameters: \n")
f.write(str(model.eval()))
# break
text1 = "[INFO] EPOCH: {}/{}\n".format(e + 1, NUM_EPOCHS)
# text2 = "Train loss: {:.4f}, Train dice score: {:.4f}, Val loss: {:.4f}, Val dice score: {:.4f}\n".format(avgTrainLoss, avgTrainDiceScore, avgValLoss, avgValDiceScore)
text2 = "Train loss: {:.4f}, Val loss: {:.4f}\n".format(avgTrainLoss, avgValLoss)
fileprint(text1)
fileprint(text2)
# display the total time needed to perform the training
endTime = time.time()
# print("[INFO] total time taken to train the model: {:.2f}s".format(endTime - startTime))
# ------------------------------- Saving UNet model --------------------------------------------------
path_model = model_name + '.pth'
path_param = model_name + '_model_param.txt'
torch.save(model.state_dict(), path_model)
with open(path_model, 'wt') as f:
f.write(f"Batch size used: {BATCH_SIZE}")
f.write(f"\nNumber of epochs: {NUM_EPOCHS}")
f.write(f"\nINIT_LR: {INIT_LR}")
f.write("\nModel parameters: \n")
f.write(str(model.eval()))
Does anyone have any ideas how to solve this? Thank you a lot