0

I am trying to train a custom U-Net model for multiclass segmentation on 25,000 training images. I run the code on a remote Ubuntu machine over a PuTTY SSH connection. The script starts running, but after some epochs the entire PuTTY session crashes, which makes it impossible to see the actual error that caused the crash.

I believe the problem is not in my main training script, but maybe something is wrong in the script that creates the Datasets and DataLoaders. It looks like this:

import pandas as pd
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
from datetime import datetime

import torch
import torch.nn as nn
from torch.nn import ConvTranspose2d
from torch.nn import Conv2d
from torch.nn import MaxPool2d
from torch.nn import Module
from torch.nn import ModuleList
from torch.nn import ReLU
from torch.nn import functional as F
from torchvision.transforms import CenterCrop
import torchvision.transforms as T
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import time
from tqdm import tqdm
import random
from PIL import Image

to_tensor = T.ToTensor()   # shared PIL -> tensor converter used by image_to_gray / mask_to_gray4

# Train and evaluate on the GPU whenever PyTorch can see one
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Page-locked host memory only speeds up host->GPU copies, so pin only on CUDA
PIN_MEMORY = DEVICE == "cuda"

BATCH_SIZE = 64   # samples per DataLoader batch

# # Data importing
df_train = pd.read_csv("../../datasets/unet_cropped/df_train.csv")
df_val = pd.read_csv("../../datasets/unet_cropped/df_val.csv")


def _image_to_mask_paths(df):
    """Map every image path in *df* to its [class_1, class_3, class_4] mask paths.

    Missing masks are expected as NaN in the CSV (mask_to_gray4 skips NaN entries).
    A repeated image path keeps the row that appears last.
    """
    return {
        image: [c1, c3, c4]
        for image, c1, c3, c4 in zip(
            df['image_path'], df['class_1_path'], df['class_3_path'], df['class_4_path']
        )
    }


# Dictionaries
d_train = _image_to_mask_paths(df_train)
d_val = _image_to_mask_paths(df_val)


def image_to_gray(image_path):
    """Load an image file and return it as a single-channel grayscale tensor.

    image_path : path of the image file to open.

    Returns the result of ``to_tensor`` followed by
    ``rgb_to_grayscale(..., num_output_channels=1)``, i.e. a channel-first
    tensor with one channel.
    """
    # The context manager releases the file handle as soon as the pixels have
    # been converted instead of leaving that to the garbage collector; this
    # function runs once per sample, every epoch.
    with Image.open(image_path) as img:
        img_tensor = to_tensor(img)

    # Collapse to one grayscale channel
    return T.functional.rgb_to_grayscale(img_tensor, num_output_channels=1)


def show(images, folder_path, title, subtitles = True):
    """Plot the grayscale versions of several images stacked vertically.

    images      : list of image file names; each is appended to folder_path
                  by plain string concatenation (so folder_path needs its
                  trailing separator).
    folder_path : directory prefix for every name in ``images``.
    title       : figure title.
    subtitles   : when True, label every panel with its image file name.
                  The image itself is drawn either way.
    """
    # squeeze=False always returns a 2-D array of Axes, so the single-image
    # case needs no separate branch.
    fig, axes = plt.subplots(nrows=len(images), ncols=1, figsize=(6, 7),
                             constrained_layout=True, squeeze=False)
    fig.suptitle(title, fontsize=10)

    # Pair panel k with image k (no index shifting)
    for ax, image_id in zip(axes[:, 0], images):
        im_for_plot = image_to_gray(folder_path + image_id)
        if subtitles:
            ax.set_title(image_id, fontsize=8)
        # image_to_gray returns a 1-channel, channel-first tensor; drop the
        # channel axis so imshow receives a plain 2-D (H, W) array.
        ax.imshow(im_for_plot.squeeze(0), cmap="gray")

    plt.show()
    


def mask_to_gray4(mask_path, size=(256, 256)):
    """Stack per-class mask images into one binary multi-channel tensor.

    mask_path : list of mask file paths, one per class. A missing entry
                (NaN from the CSV, or None) leaves that class channel all zeros.
    size      : (H, W) of every mask; defaults to the 256x256 used so far.

    Example:
        mask_path  = [mask_path_class_1, NaN, NaN]
        final_mask = [gray_scale_mask_class_1, zeros(H, W), zeros(H, W)]

    Returns a float tensor of shape (len(mask_path), H, W) holding only 0s and 1s.
    """
    final_mask = torch.zeros((len(mask_path), *size))

    for i, sample in enumerate(mask_path):
        # Skip absent classes; their channel stays zero
        if pd.isna(sample):
            continue

        # Close the mask file as soon as its pixels are converted
        with Image.open(sample) as img:
            img_tensor = to_tensor(img)

        img_torch_gray = T.functional.rgb_to_grayscale(img_tensor, num_output_channels=1)
        # Binarize: any non-zero pixel belongs to class i
        final_mask[i] = (img_torch_gray[0] > 0).float()

    return final_mask


# Split each {image_path: [mask paths]} dict into two aligned lists. Dicts keep
# insertion order, so index i of both lists always refers to the same sample.
image_train_paths = list(d_train.keys())
mask_train_paths = list(d_train.values())

image_val_paths = list(d_val.keys())
mask_val_paths = list(d_val.values())

# The lists now hold everything the datasets need
del d_train, d_val


# main function
class SegmentationDataset(Dataset):
    """Map-style dataset pairing grayscale images with stacked binary class masks.

    image_paths : sequence of image file paths.
    mask_paths  : sequence aligned with image_paths; item i is the list of
                  per-class mask paths for image i (NaN marks a missing class).
    """

    def __init__(self, image_paths, mask_paths):
        # keep only the paths; pixels are read lazily in __getitem__
        self.image_paths = image_paths
        self.mask_paths = mask_paths

    def __len__(self):
        # one sample per image path
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the (image, mask) tensor pair for sample ``idx``."""
        return (
            image_to_gray(self.image_paths[idx]),
            mask_to_gray4(self.mask_paths[idx]),
        )
    
    
    
# # Data Loading

trainDS = SegmentationDataset(image_train_paths, mask_train_paths)
trainLoader = DataLoader(trainDS, shuffle=True, batch_size=BATCH_SIZE,
                         pin_memory=PIN_MEMORY)

# The summed validation loss does not depend on order, so shuffling it only costs time
valDS = SegmentationDataset(image_val_paths, mask_val_paths)
valLoader = DataLoader(valDS, shuffle=False, batch_size=BATCH_SIZE,
                       pin_memory=PIN_MEMORY)


# Number of batches per epoch for training and validation.
# The loaders keep the default drop_last=False, so the final partial batch is
# also produced: len(loader) == ceil(len(ds) / BATCH_SIZE). Floor division would
# undercount it (inflating the averaged loss) and give 0 for a split smaller
# than one batch, i.e. a ZeroDivisionError when averaging.
trainSteps = len(trainLoader)
valSteps = len(valLoader)

Maybe the problem is some sort of memory leak, or RAM usage that keeps increasing during training, but I can't spot where exactly that could happen in this script.

The training script looks like this:

#!/usr/bin/env python
# coding: utf-8

# Import libraries

import pandas as pd
import numpy as np
import os
from datetime import datetime
import time

import random
from PIL import Image

import torch
import torch.nn as nn
from torch.nn import ConvTranspose2d
from torch.nn import Conv2d
from torch.nn import MaxPool2d, BatchNorm2d
from torch.nn import Module
from torch.nn import ModuleList
from torch.nn import ReLU
from torch.nn import functional as F
from torchvision.transforms import CenterCrop
import torchvision.transforms as T
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from torch.optim.lr_scheduler import StepLR

from tqdm import tqdm

torch.manual_seed(1)


from dataset import (
    trainLoader,
    valLoader,
    trainSteps,
    valSteps
)

from model import UNet

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Pinned host memory only helps host->GPU copies
PIN_MEMORY = DEVICE == "cuda"

#----------------------------------------------------------------------------

NUM_CHANNELS = 1   # input channels: grayscale image
NUM_CLASSES = 3    # segmentation classes (output channels)
NUM_LEVELS = 3     # U-Net depth

# Optimisation settings
INIT_LR = 1e-3
NUM_EPOCHS = 100
BATCH_SIZE = 64    # only written to the parameter file here; dataset.py sizes the loaders
threshold = 0.4    # in this script, only used to tag the output file names

# Every output file shares one prefix built from the threshold and the start
# time (spaces replaced by underscores).
dt_string = time.ctime()
model_name = ("models/tuesday_25/unet" + "_th_" + str(threshold)
              + "_" + dt_string.replace(" ", "_"))
path_model = model_name + '.pth'
path_param = model_name + '_model_param.txt'
eval_txt = model_name + "_eval.txt"

# With FILEPRINT on, fileprint() appends the per-epoch log to eval_txt
FILEPRINT = True
if FILEPRINT:
    EVAL_FILE = open(eval_txt, "a+")


def fileprint(*args):
    """Print *args* to the evaluation log (flushed immediately) or to stdout.

    The destination is chosen by the module flag FILEPRINT; EVAL_FILE is only
    touched when that flag is set.
    """
    target = EVAL_FILE if FILEPRINT else None
    # print(file=None) writes to the current sys.stdout
    print(*args, file=target)
    if target is not None:
        target.flush()



# # Training UNet model

class EarlyStopping:
    """Raise ``early_stop`` once validation loss has exceeded training loss by
    more than ``min_delta`` on ``tolerance`` calls.

    Adapted from https://stackoverflow.com/questions/71998978/early-stopping-in-pytorch
    NOTE: ``counter`` is cumulative; it is never reset when the gap closes, so
    the offending calls do not have to be consecutive.
    """

    def __init__(self, tolerance=50, min_delta=0):
        self.tolerance = tolerance
        self.min_delta = min_delta
        self.counter = 0
        self.early_stop = False

    def __call__(self, train_loss, validation_loss):
        gap = validation_loss - train_loss
        # Written as "not >" so a NaN gap is ignored, exactly like a small gap
        if not gap > self.min_delta:
            return
        self.counter += 1
        if self.counter >= self.tolerance:
            self.early_stop = True
                


# def bce_dice_loss(predicted, truth, threshold):
    
#     batch_size = len(truth)
    
#     # BCE
#     bce_loss = lossFunc(predicted, truth)
    
#     # DICE
#     predicted = torch.sigmoid(predicted.detach())
#     predicted[predicted  > threshold] = 1    
#     predicted[predicted <= threshold] = 0 
    
#     predicted = predicted.view(batch_size, -1)
#     truth     = truth.view(batch_size, -1)      
#     assert(predicted.shape == truth.shape)
    
#     tp = (predicted * truth).sum(-1)
#     fp = ((truth == 0.0).float() * predicted).sum(-1)
#     fn = ((truth >= 1.0).float() * (predicted == 0.0).float()).sum(-1)

#     dice_score = 2*tp / (2*tp + fp + fn)
    
#     # BCE DICE
#     bce_dice = 0.75 * bce_loss + 0.25 * (1 - dice_score)

#     batch_bce_dice_loss = torch.nanmean(bce_dice)
        
#     return batch_bce_dice_loss, (torch.nanmean(dice_score)).item()


# ---------------------------- Initialize UNet model -------------------------
model = UNet(nbClasses = NUM_CLASSES).to(DEVICE)

lossFunc = nn.BCEWithLogitsLoss()
opt = torch.optim.RAdam(model.parameters(), lr=INIT_LR)

# GradScaler only has an effect on CUDA; disabling it elsewhere turns
# scaler.step() into a plain opt.step() without the "CUDA not available" warning.
# NOTE(review): the training loop never enters torch.autocast, so the forward
# pass runs in fp32 and the scaler adds nothing. Either add autocast or drop it.
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))

# Anomaly detection stores a traceback for every autograd op and runs extra
# checks on each backward pass. PyTorch documents it as a debugging aid that
# slows execution, so keep it off for real training runs and flip this flag
# only while hunting NaN/inf gradients.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)


early_stopping = EarlyStopping(tolerance=500, min_delta=1e-5)


# initialize a dictionary to store training history
H = {"train_loss": [], "val_loss": [], "dice_score": []}


# ----------------------------- Training UNet model ----------------------------
startTime = time.time()

# number of trainable parameters
params = sum(p.numel() for p in model.parameters() if p.requires_grad)


# for e in tqdm(range(NUM_EPOCHS)):
for e in range(NUM_EPOCHS):

    model.train()                             # set the model in training mode

    # running sums of per-batch losses for this epoch
    totalTrainLoss = 0.0
    totalValLoss = 0.0

    # ---- training pass ----
    for (x, y) in trainLoader:
        opt.zero_grad()                       # clear gradients from the previous step

        (x, y) = (x.to(DEVICE), y.to(DEVICE)) # send the batch to the device

        pred = model(x)                       # forward pass
        loss = lossFunc(pred, y.float())      # alternative: bce_dice_loss(pred, y, threshold)

        scaler.scale(loss).backward()         # backpropagate the (possibly scaled) loss
        scaler.step(opt)                      # optimizer step
        scaler.update()                       # adjust the scale for the next step

        # .item() yields a plain Python number, so no graph is kept alive across iterations
        totalTrainLoss += loss.item()

    # ---- validation pass ----
    model.eval()                              # set the model in evaluation mode
    with torch.no_grad():                     # no autograd graph needed for evaluation
        for (x, y) in valLoader:
            (x, y) = (x.to(DEVICE), y.to(DEVICE))
            pred = model(x)
            loss = lossFunc(pred, y.float())
            totalValLoss += loss.item()

    # Average over the batches the loaders actually produced. The imported
    # trainSteps/valSteps may use floor division, which ignores the final
    # partial batch (and is 0 for a split smaller than one batch).
    avgTrainLoss = totalTrainLoss / max(len(trainLoader), 1)
    avgValLoss = totalValLoss / max(len(valLoader), 1)

    H["train_loss"].append(avgTrainLoss)
    H["val_loss"].append(avgValLoss)

    # early_stopping(avgTrainLoss, avgValLoss)
    # if early_stopping.early_stop:
    #     break

    # checkpoint every 5 epochs (epoch 0 skipped)
    if e > 0 and e % 5 == 0:
        torch.save(model.state_dict(), path_model)

        with open(path_param, 'wt') as f:
            f.write(f"Batch size used: {BATCH_SIZE}")
            f.write(f"\nNumber of epochs: {NUM_EPOCHS}")
            f.write(f"\nINIT_LR: {INIT_LR}")
            f.write("\nModel parameters: \n")
            # str(model) gives the same text without switching to eval mode mid-training
            f.write(str(model))

    text1 = "[INFO] EPOCH: {}/{}\n".format(e + 1, NUM_EPOCHS)
    text2 = "Train loss: {:.4f}, Val loss: {:.4f}\n".format(avgTrainLoss, avgValLoss)

    fileprint(text1)
    fileprint(text2)

    
    
# display the total time needed to perform the training
endTime = time.time()
fileprint("[INFO] total time taken to train the model: {:.2f}s".format(endTime - startTime))

# ------------------------------- Saving UNet model --------------------------------------------------
path_model = model_name + '.pth'
path_param = model_name + '_model_param.txt'
torch.save(model.state_dict(), path_model)

# Hyper-parameters go to their own text file; opening path_model here would
# overwrite the weights saved just above.
with open(path_param, 'wt') as f:
    f.write(f"Batch size used: {BATCH_SIZE}")
    f.write(f"\nNumber of epochs: {NUM_EPOCHS}")
    f.write(f"\nINIT_LR: {INIT_LR}")
    f.write("\nModel parameters: \n")
    f.write(str(model))

# The evaluation log was opened at start-up and is no longer needed
if FILEPRINT:
    EVAL_FILE.close()

Does anyone have any idea how to solve this? Thanks a lot.

Martin Prikryl
  • 188,800
  • 56
  • 490
  • 992
zmaj993
  • 1
  • 1

0 Answers0