I am new to Keras and have been learning it for about 3 weeks now. I apologize if my question sounds a bit stupid.

I am currently doing semantic segmentation of 512x512 medical images. I'm using the UNet from https://github.com/zhixuhao/unet . Basically, I want to segment the brain from an image (so two-class segmentation, background vs. foreground).

I have made a few modifications to the network and I'm getting some results which I am happy with. But I think I can improve the segmentation results by putting more weight on the foreground, because the number of brain pixels is much smaller than the number of background pixels. In some cases the brain does not appear in the image at all, especially in the bottom slices.

I don't know which part of the code I need to modify in https://github.com/zhixuhao/unet

I would really appreciate it if anyone could help me with this. Thanks a lot in advance!

import os
import numpy as np
import skimage.io as io
import skimage.transform as trans
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from keras import backend as keras


def unet(pretrained_weights=None, input_size=(256, 256, 1)):
  inputs = Input(input_size)
  conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
  conv1 = BatchNormalization()(conv1)
  conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv1)
  conv1 = BatchNormalization()(conv1)
  pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

  conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool1)
  conv2 = BatchNormalization()(conv2)
  conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
  conv2 = BatchNormalization()(conv2)
  pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

  conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool2)
  conv3 = BatchNormalization()(conv3)
  conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
  conv3 = BatchNormalization()(conv3)
  pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

  conv4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool3)
  conv4 = BatchNormalization()(conv4)
  conv4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
  conv4 = BatchNormalization()(conv4)
  drop4 = Dropout(0.5)(conv4)
  pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

  conv5 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool4)
  conv5 = BatchNormalization()(conv5)
  conv5 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
  conv5 = BatchNormalization()(conv5)
  drop5 = Dropout(0.5)(conv5)

  up6 = Conv2D(512, 2, activation='relu', padding='same', kernel_initializer='he_normal')(
      UpSampling2D(size=(2, 2))(drop5))
  merge6 = concatenate([drop4, up6], axis=3)
  conv6 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge6)
  conv6 = BatchNormalization()(conv6)
  conv6 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)
  conv6 = BatchNormalization()(conv6)

  up7 = Conv2D(256, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(conv6))
  merge7 = concatenate([conv3, up7], axis=3)
  conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7)
  conv7 = BatchNormalization()(conv7)
  conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv7)
  conv7 = BatchNormalization()(conv7)

  up8 = Conv2D(128, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(conv7))
  merge8 = concatenate([conv2, up8], axis=3)
  conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8)
  conv8 = BatchNormalization()(conv8)
  conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv8)
  conv8 = BatchNormalization()(conv8)

  up9 = Conv2D(64, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(conv8))
  merge9 = concatenate([conv1, up9], axis=3)
  conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge9)
  conv9 = BatchNormalization()(conv9)
  conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)
  conv9 = BatchNormalization()(conv9)
  conv9 = Conv2D(2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)
  conv9 = BatchNormalization()(conv9)


  conv10 = Conv2D(1, 1, activation='sigmoid')(conv9)

  model = Model(inputs=inputs, outputs=conv10)

  model.compile(optimizer=Adam(lr=1e-4), loss='binary_crossentropy', metrics=['accuracy'])

  # model.summary()

  if (pretrained_weights):
      model.load_weights(pretrained_weights)

  return model

Here's the main.py

from model2 import *
from data2 import *
from keras.models import load_model

class_weight= {0:0.10, 1:0.90}
myGene = trainGenerator(2,'data/brainTIF/trainNew','image','label',save_to_dir = None)
model = unet()
model_checkpoint = ModelCheckpoint('unet_brainTest_e10_s5.hdf5', monitor='loss')
model.fit_generator(myGene,steps_per_epoch=5,epochs=10,callbacks = [model_checkpoint])

testGene = testGenerator("data/brainTIF/test3")
results = model.predict_generator(testGene,18,verbose=1)
saveResult("data/brainTIF/test_results3",results)
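
Note that class_weight in main.py above is defined but never passed to fit_generator. One way such weights could be applied per pixel is a weighted binary cross-entropy. The following is only a NumPy sketch of the arithmetic, not code from the UNet repository: the function name weighted_bce_np and the toy 2x2 arrays are made up for illustration, and the same computation would have to be rewritten with Keras backend ops before it could be passed as loss= to model.compile.

```python
import numpy as np

def weighted_bce_np(y_true, y_pred, class_weight, eps=1e-7):
    """Mean binary cross-entropy with a per-pixel weight chosen by the true class.

    Hypothetical helper for illustration only.
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    # Clip predictions so log() never sees exactly 0 or 1.
    y_pred = np.clip(np.asarray(y_pred, dtype=np.float64), eps, 1.0 - eps)
    # Ordinary per-pixel binary cross-entropy.
    bce = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    # Foreground pixels get class_weight[1], background pixels class_weight[0].
    weights = np.where(y_true == 1, class_weight[1], class_weight[0])
    return float(np.mean(weights * bce))

class_weight = {0: 0.10, 1: 0.90}              # same values as in main.py
y_true = np.array([[0, 0], [0, 1]])            # 1 = brain, 0 = background
miss_fg = np.array([[0.1, 0.1], [0.1, 0.2]])   # brain pixel predicted as background
miss_bg = np.array([[0.8, 0.1], [0.1, 0.9]])   # one background pixel predicted as brain

loss_miss_fg = weighted_bce_np(y_true, miss_fg, class_weight)
loss_miss_bg = weighted_bce_np(y_true, miss_bg, class_weight)
print(loss_miss_fg, loss_miss_bg)
```

Both toy predictions make one equally confident mistake, so an unweighted loss scores them the same; with the weights above, missing the brain pixel costs more than a false positive on the background.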
Galigator
ErickYA
  • Can you add some of the code you have already tried and the specific snippets you need help with? Have a look here on how to create an MCVE: https://stackoverflow.com/help/mcve – nickyfot Apr 15 '19 at 14:03
  • I have edited the original post and included my code – ErickYA Apr 15 '19 at 14:23
  • You can try using a custom loss function, as was done [here](https://stackoverflow.com/questions/50294566/image-segmentation-custom-loss-function-in-keras) – Krunal V Apr 15 '19 at 14:40
  • Use `focal_loss`. It is exactly for that thing! – enterML Apr 15 '19 at 15:21
  • Thanks @Nain
    I found this code
    import tensorflow as tf
    from keras import backend as K

    def focal_loss(gamma=2., alpha=.25):
        def focal_loss_fixed(y_true, y_pred):
            # Predicted probability at foreground pixels (1 elsewhere, so those terms are near zero)
            pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
            # Predicted probability at background pixels (0 elsewhere)
            pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))

            # Clip to avoid log(0)
            pt_1 = K.clip(pt_1, 1e-3, .999)
            pt_0 = K.clip(pt_0, 1e-3, .999)

            return -K.sum(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1)) - K.sum((1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0))
        return focal_loss_fixed
    What are y_true, y_pred? Are they the class weights?
    – ErickYA Apr 15 '19 at 15:35
  • Thanks @Nain. In the code from [this link](https://github.com/mkocabas/focal-loss-keras/blob/master/focal_loss.py), what are y_true and y_pred? Are they the weights for the different classes? – ErickYA Apr 15 '19 at 15:44
  • Check it out here: https://www.kaggle.com/aakashnain/diving-deep-into-focal-loss – enterML Apr 15 '19 at 16:47
  • @Nain the link you have given is more of an image classification problem than a semantic image segmentation problem – ErickYA Apr 16 '19 at 16:39
  • @ErickYA the idea is the same. PS: Semantic segmentation is just pixel-wise classification, isn't it? – enterML Apr 17 '19 at 17:09
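
On the question of what y_true and y_pred are: in a Keras loss function they are not class weights. y_true is the ground-truth label (here, the mask) and y_pred is the model's output (here, the per-pixel sigmoid probability); the class weighting comes from alpha and gamma. The following is a NumPy sketch that follows the same focal-loss formula as the snippet quoted in the comments, assuming binary 0/1 masks. The name focal_loss_np and the toy 2x2 arrays are hypothetical, and unlike the quoted snippet it masks the two terms exactly instead of relying on clipping.

```python
import numpy as np

def focal_loss_np(y_true, y_pred, gamma=2.0, alpha=0.25, eps=1e-3):
    """NumPy illustration of binary focal loss, summed over all pixels."""
    y_true = np.asarray(y_true, dtype=np.float64)
    # Clip predictions so log() never sees exactly 0 or 1.
    y_pred = np.clip(np.asarray(y_pred, dtype=np.float64), eps, 1.0 - eps)
    # Foreground pixels (label 1): penalise a low predicted probability.
    fg = -alpha * (1.0 - y_pred) ** gamma * np.log(y_pred) * y_true
    # Background pixels (label 0): penalise a high predicted probability.
    bg = -(1.0 - alpha) * y_pred ** gamma * np.log(1.0 - y_pred) * (1.0 - y_true)
    return float(np.sum(fg + bg))

# y_true: ground-truth mask for a 2x2 image (1 = brain, 0 = background).
y_true = np.array([[0, 0], [0, 1]])
# y_pred: the network's sigmoid output for the same four pixels.
good = np.array([[0.1, 0.1], [0.1, 0.9]])   # confident and correct everywhere
bad = np.array([[0.1, 0.1], [0.1, 0.2]])    # misses the single brain pixel

loss_good = focal_loss_np(y_true, good)
loss_bad = focal_loss_np(y_true, bad)
print(loss_good, loss_bad)
```

The (1 - p)^gamma factor shrinks the contribution of pixels that are already classified well, so the many easy background pixels add little and the missed brain pixel dominates the loss.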

1 Answer

As an alternative to class_weight for binary classes, you can also handle imbalanced classes with the Synthetic Minority Oversampling Technique (SMOTE), which increases the size of the minority class:

from imblearn.over_sampling import SMOTE

sm = SMOTE()
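# fit_resample is the current name; the older fit_sample was removed in newer imbalanced-learn releases
x, y = sm.fit_resample(X_train, Y_train)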
razimbres
  • Thanks @Rubens_Zimbres but is that for image segmentation or image classification? In my case it is semantic image segmentation – ErickYA Apr 15 '19 at 15:41
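
For segmentation, where the imbalance is between pixels rather than between whole images, one option is to reweight pixels instead of resampling. The following is a minimal sketch that estimates such weights from the training masks, assuming binary 0/1 masks stacked in one array. The function name balanced_pixel_weights and the toy masks are hypothetical; the formula is the same "balanced" heuristic scikit-learn uses for class weights, n_samples / (n_classes * count).

```python
import numpy as np

def balanced_pixel_weights(masks):
    """Return {0: w_bg, 1: w_fg} so that both classes carry equal total weight."""
    masks = np.asarray(masks)
    n_total = masks.size
    n_fg = int(np.count_nonzero(masks))
    n_bg = n_total - n_fg
    if n_fg == 0 or n_bg == 0:
        raise ValueError("need at least one pixel of each class")
    # "Balanced" heuristic: n_samples / (n_classes * count_of_class).
    return {0: n_total / (2.0 * n_bg), 1: n_total / (2.0 * n_fg)}

# Toy data: three 4x4 masks; only the first slice contains any foreground,
# mimicking slices in which the brain does not appear.
masks = np.zeros((3, 4, 4), dtype=np.uint8)
masks[0, 1:3, 1:3] = 1   # 4 foreground pixels out of 48

weights = balanced_pixel_weights(masks)
print(weights)
```

Computing the counts over the whole training set rather than per image keeps empty slices from causing a division by zero.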