
I would like to remove all the batch normalization layers from a Keras model that includes short-skip connections. For example, let's consider EfficientNetB0 as follows:

import tensorflow as tf

model = tf.keras.applications.EfficientNetB0(weights=None, include_top=True)
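As a sanity check (just a small sketch), the BN layers in question can be listed by type, which is more robust than matching the substring 'bn' in layer names:

```python
import tensorflow as tf

model = tf.keras.applications.EfficientNetB0(weights=None, include_top=True)

# Collect the batch normalization layers by type rather than by name.
bn_layers = [l for l in model.layers
             if isinstance(l, tf.keras.layers.BatchNormalization)]
print(len(bn_layers), 'batch-norm layers found')
```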

I was actually using the 3D version of EfficientNet, which I don't think matters for the question, but I'm showing it anyway:


import keras 
from keras import layers
from keras.models import Model


input_shape = (32,32,32,3)

import efficientnet_3D.keras as efn 
model = efn.EfficientNetB0(input_shape=input_shape, weights='imagenet')

X = model.layers[-1].output
X = layers.Flatten()(X)
X = layers.Dense(16)(X)
X = layers.Dense(16)(X)
X = layers.Dense(1)(X)

model = Model(inputs=model.inputs, outputs=X)
model.compile(loss='mse',
              optimizer='adam',
              metrics=['mean_absolute_error']
              )
model.summary()

I tried to develop my own way of removing the BN layers, but it seems to be totally wrong: the shortcut connections in the output model end up a complete mess.



import keras
from keras import layers
from keras.models import Model

ind = [i for i, l in enumerate(model.layers) if 'bn' in l.name]


X = model.layers[0].output
for i in range(1, len(model.layers)):

    # Skipping Batch Normalization layers
    if i in ind:
        continue

    # If there is a short skip, the layer takes two inputs
    if isinstance(model.layers[i].input, list):
        input_names = [j.name for j in model.layers[i].input]
        assert len(input_names) == 2
        input_names.remove(X.name)
        input_name = input_names[0].split('/')[0]
        # This tensor still belongs to the ORIGINAL graph, which is
        # probably where the wiring goes wrong:
        X = [model.get_layer(input_name).output, X]

    X = model.layers[i](X)

new_model = Model(inputs=model.inputs, outputs=X)
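One direction that looks more promising (I've only verified it on a toy model with a skip connection, not on EfficientNet itself) is tf.keras.models.clone_model with a clone_function that swaps every BatchNormalization for a linear Activation. Since cloning rebuilds the whole graph, the skip connections are rewired automatically. Note the clone starts with fresh weights, so retraining (or copying weights layer-by-layer by name) would still be needed:

```python
import tensorflow as tf
from tensorflow.keras import layers, Model

def drop_bn(layer):
    """clone_function: replace BN with a no-op, clone everything else."""
    if isinstance(layer, layers.BatchNormalization):
        # A linear Activation keeps the tensor shape and graph wiring intact.
        return layers.Activation('linear', name=layer.name)
    return layer.__class__.from_config(layer.get_config())

# Toy model with a short-skip connection, standing in for EfficientNetB0
inp = layers.Input(shape=(8,))
x = layers.Dense(8)(inp)
x = layers.BatchNormalization(name='block1_bn')(x)
x = layers.Add()([x, inp])          # the skip connection
out = layers.Dense(1)(x)
model = Model(inp, out)

new_model = tf.keras.models.clone_model(model, clone_function=drop_bn)
new_model.summary()
```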

I think there should be a better way that I'm not aware of. I tried the approaches from a similar question about removing a single layer, but because my model includes skip connections, those methods don't work. Any help is appreciated.

  • Why remove them? This would effectively destroy the model unless you retrain it, or include the batch normalization operations in another way. – Dr. Snoopy Nov 16 '22 at 06:40
  • @Dr.Snoopy This model is doing poorly on my data set, and I believe the underlying reason is the BN layers. I think my images are so sparse, and there is not enough diversity in each batch that causes the model to overfit to each BATCH and do very poorly at the end. – Ali Nov 16 '22 at 17:11
