I would like to remove all the batch normalization layers from a Keras model that includes short skip connections. For example, consider EfficientNetB0,
built as follows:
import tensorflow as tf
model = tf.keras.applications.EfficientNetB0(weights=None, include_top=True)
I am actually using the 3D version of EfficientNet. I don't think that matters for the question, but here is the code anyway:
import keras
from keras import layers
from keras.models import Model
input_shape = (32,32,32,3)
import efficientnet_3D.keras as efn
model = efn.EfficientNetB0(input_shape=input_shape, weights='imagenet')
X = model.layers[-1].output
X = layers.Flatten()(X)
X = layers.Dense(16)(X)
X = layers.Dense(16)(X)
X = layers.Dense(1)(X)
model = Model(inputs=model.inputs, outputs=X)
model.compile(loss='mse',
              optimizer='adam',
              metrics=['mean_absolute_error'])
model.summary()
I tried to write my own removal code, but it seems to be completely wrong: the shortcut connections in the resulting model are a mess.
import keras
from keras import layers
from keras.models import Model

# Indices of the batch normalization layers (matched by name)
ind = [i for i, l in enumerate(model.layers) if 'bn' in l.name]

X = model.layers[0].output
# Valid layer indices stop at len(model.layers) - 1
for i in range(1, len(model.layers)):
    # Skipping Batch Normalization layers
    if i in ind:
        # model.layers[i]._inbound_nodes = []
        # model.layers[i]._outbound_nodes = []
        continue
    # If there is a short skip
    if isinstance(model.layers[i].input, list):
        input_names = [j.name for j in model.layers[i].input]
        assert len(input_names) == 2
        input_names.remove(X.name)
        input_names = input_names[0].split('/')[0]
        # X = [model.get_layer(input_names).output, X]
        X = [model.layers[6].output, X]
    if isinstance(X, list):
        print(i)
    X = model.layers[i](X)

new_model = Model(inputs=model.inputs, outputs=X)
I think there should be a better way that I'm not aware of. I tried the approaches from a similar question about removing a layer, but I think those methods don't work because my model includes skip connections. Any help is appreciated.
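To make the goal concrete, here is a library-free sketch of the rewiring I think is needed, on a toy graph with one short skip. It is not Keras code: the `graph` dictionary, the node names, and the `remove_kind` helper are all made up for illustration, and it assumes the nodes are listed in topological order and that every removed node has exactly one input.

```python
# Toy graph: node name -> (kind, list of input node names).
# Hypothetical structure for illustration only; not a Keras model.
graph = {
    "input": ("input", []),
    "conv1": ("conv", ["input"]),
    "bn1": ("bn", ["conv1"]),
    "conv2": ("conv", ["bn1"]),
    "bn2": ("bn", ["conv2"]),
    "add": ("add", ["bn1", "bn2"]),  # short skip coming from bn1
    "dense": ("dense", ["add"]),
}


def remove_kind(graph, kind):
    """Return a new graph without nodes of `kind`, rewiring their consumers.

    Assumes `graph` is in topological order and that every removed node
    has exactly one input.
    """
    alias = {}      # removed node name -> surviving node that replaces it
    new_graph = {}
    for name, (node_kind, inputs) in graph.items():
        # Route every input through the alias table, so consumers of a
        # removed node read from that node's own (already resolved) input.
        resolved = [alias.get(i, i) for i in inputs]
        if node_kind == kind:
            alias[name] = resolved[0]
        else:
            new_graph[name] = (node_kind, resolved)
    return new_graph


pruned = remove_kind(graph, "bn")
for name, (node_kind, inputs) in pruned.items():
    print(name, node_kind, inputs)
```

The point is that each node looks up all of its inputs by identity, so the `add` node ends up fed by `conv1` and `conv2` instead of relying on a single running `X` as in my loop above. What I am missing is the equivalent bookkeeping for Keras layers and tensors.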