50

I have a trained Keras model and I would like:

1) to replace a Conv2D layer with the same layer but without bias.

2) to add a BatchNormalization layer before the first Activation.

How can I do this?

def keras_simple_model():
    from keras.models import Model
    from keras.layers import Input, Dense,  GlobalAveragePooling2D
    from keras.layers import Conv2D, MaxPooling2D, Activation

    inputs1 = Input((28, 28, 1))
    x = Conv2D(4, (3, 3), activation=None, padding='same', name='conv1')(inputs1)
    x = Activation('relu')(x)
    x = Conv2D(4, (3, 3), activation=None, padding='same', name='conv2')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='pool1')(x)

    x = Conv2D(8, (3, 3), activation=None, padding='same', name='conv3')(x)
    x = Activation('relu')(x)
    x = Conv2D(8, (3, 3), activation=None, padding='same', name='conv4')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='pool2')(x)

    x = GlobalAveragePooling2D()(x)
    x = Dense(10, activation=None)(x)
    x = Activation('softmax')(x)

    model = Model(inputs=inputs1, outputs=x)
    return model


if __name__ == '__main__':
    model = keras_simple_model()
    print(model.summary())
Innat
ZFTurbo

4 Answers

37

The following function lets you insert a new layer before or after, or in place of, every layer in the original model whose name matches a regular expression. It also handles non-sequential models such as DenseNet or ResNet.

import re
from keras.models import Model

def insert_layer_nonseq(model, layer_regex, insert_layer_factory,
                        insert_layer_name=None, position='after'):

    # Auxiliary dictionary to describe the network graph
    network_dict = {'input_layers_of': {}, 'new_output_tensor_of': {}}

    # Set the input layers of each layer
    for layer in model.layers:
        for node in layer._outbound_nodes:
            layer_name = node.outbound_layer.name
            if layer_name not in network_dict['input_layers_of']:
                network_dict['input_layers_of'].update(
                        {layer_name: [layer.name]})
            else:
                network_dict['input_layers_of'][layer_name].append(layer.name)

    # Set the output tensor of the input layer
    network_dict['new_output_tensor_of'].update(
            {model.layers[0].name: model.input})

    # Iterate over all layers after the input
    model_outputs = []
    for layer in model.layers[1:]:

        # Determine input tensors
        layer_input = [network_dict['new_output_tensor_of'][layer_aux] 
                for layer_aux in network_dict['input_layers_of'][layer.name]]
        if len(layer_input) == 1:
            layer_input = layer_input[0]

        # Insert layer if name matches the regular expression
        if re.match(layer_regex, layer.name):
            if position == 'replace':
                x = layer_input
            elif position == 'after':
                x = layer(layer_input)
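            elif position == 'before':
                # The new layer consumes the original input; the matched
                # layer is applied to the new layer's output further down
                x = layer_input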
            else:
                raise ValueError('position must be: before, after or replace')

            new_layer = insert_layer_factory()
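            # Rename the new layer. In newer Keras versions `name` is a
            # read-only property; assign to `new_layer._name` there instead.
            if insert_layer_name:
                new_layer.name = insert_layer_name
            else:
                new_layer.name = '{}_{}'.format(layer.name,
                                                new_layer.name)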
            x = new_layer(x)
            print('New layer: {} Old layer: {} Type: {}'.format(new_layer.name,
                                                            layer.name, position))
            if position == 'before':
                x = layer(x)
        else:
            x = layer(layer_input)

        # Set new output tensor (the original one, or the one of the inserted
        # layer)
        network_dict['new_output_tensor_of'].update({layer.name: x})

        # Save tensor in output list if it is output in initial model
        if layer.name in model.output_names:
            model_outputs.append(x)

    return Model(inputs=model.inputs, outputs=model_outputs)

The difference from the simpler, purely sequential case is that, before iterating over the layers to find the matching ones, you first parse the graph and store the input layers of each layer in an auxiliary dictionary. Then, as you iterate over the layers, you also store the new output tensor of each layer; these stored tensors are looked up to assemble the inputs of every later layer when building the new model.
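The bookkeeping described above can be illustrated without Keras. The following is only a toy sketch: plain Python functions stand in for layers, and the names `rebuild`, `inputs_of`, `order`, `insert_after` and `new_layer` are made up for this illustration. `inputs_of` plays the role of `network_dict['input_layers_of']` and `outputs` the role of `network_dict['new_output_tensor_of']`.

```python
def rebuild(order, layers, inputs_of, source_value,
            insert_after=None, new_layer=None):
    """Re-run a small graph in topological order, optionally inserting
    `new_layer` after the node called `insert_after`."""
    # Output of the input node (counterpart of 'new_output_tensor_of')
    outputs = {order[0]: source_value}
    for name in order[1:]:
        # Look up the (possibly modified) outputs of all parent nodes
        args = [outputs[parent] for parent in inputs_of[name]]
        x = layers[name](*args)
        if name == insert_after:
            x = new_layer(x)
        # Later nodes see the output of the inserted layer, not the old one
        outputs[name] = x
    return outputs[order[-1]]


# A non-sequential toy graph: two branches that are merged by 'add'
layers = {
    'double': lambda v: 2 * v,
    'square': lambda v: v * v,
    'add': lambda a, b: a + b,
}
inputs_of = {
    'double': ['input'],
    'square': ['input'],
    'add': ['double', 'square'],
}
order = ['input', 'double', 'square', 'add']

plain = rebuild(order, layers, inputs_of, 3)
patched = rebuild(order, layers, inputs_of, 3,
                  insert_after='double', new_layer=lambda v: v + 1)
print(plain, patched)
```

Because every node reads its inputs from the dictionary of stored outputs rather than from a single running variable, the merge node picks up the modified branch automatically.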

A use case would be the following, where a Dropout layer is inserted after each activation layer of ResNet50:

from keras.applications.resnet50 import ResNet50
from keras.layers import Dropout
from keras.models import load_model

model = ResNet50()
def dropout_layer_factory():
    return Dropout(rate=0.2, name='dropout')
model = insert_layer_nonseq(model, '.*activation.*', dropout_layer_factory)

# Fix possible problems with new model
model.save('temp.h5')
model = load_model('temp.h5')

model.summary()
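A note on the pattern `'.*activation.*'` used in the call above: the function relies on `re.match`, which only matches at the beginning of a string, so a bare `'activation'` would miss layers whose names merely contain that word. A small sketch of the difference, using made-up layer names (they are not taken from any particular model):

```python
import re

# Hypothetical layer names, for illustration only
layer_names = ['conv1', 'activation_1', 'bn_conv1', 'block1_activation']


def matching(pattern, names):
    """Names selected by re.match, i.e. matched from the start of the name."""
    return [name for name in names if re.match(pattern, name)]


anchored = matching('activation', layer_names)
anywhere = matching('.*activation.*', layer_names)
print(anchored)
print(anywhere)
```

If you would rather pass plain substrings, replacing `re.match` with `re.search` inside the function would be one option.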
ZFTurbo
alexhg
  • 2
    Thank you! This worked for me. However, I noticed a problem when calling the `insert_layer_nonseq(...)` function twice on the same model: the `outbound_layers` and `inbound_layers` are appended to, instead of being replaced, when calling `x = layer(layer_input)`. My fix was to place `network_dict['input_layers_of'][layer_name].append(layer.name)` inside the condition `if (layer.name not in network_dict['input_layers_of'][layer_name]):` – mostafa.elhoushi Apr 18 '19 at 20:23
  • That is a good point, although I guess it could also be solved by inserting the subsequent layers with different layer names (`insert_layer_name`). In any case, modifying the graph multiple times can lead to serious memory leaks, so sometimes it is a good idea to re-load the model. PS: if you found my answer useful, please consider upvoting ;) – alexhg Apr 22 '19 at 09:30
  • 2
    I had to change `outbound_nodes` to `_outbound_nodes` – The Guy with The Hat Jul 31 '19 at 21:27
  • How can I call `insert_layer_nonseq(model, '.*activation.*', dropout_layer_factory)` multiple times, or for multiple specific layers? I get a dimension/shape error when I call `insert_layer_nonseq` on the same model two or more times. – TheJokerAEZ Oct 21 '19 at 13:11
  • 2
    Every time we insert a new layer we change the internal Keras graph connectivity, i.e. new inbound and outbound nodes are created, which can cause issues if we want to do further network "surgery". Complex solution: write a function to manually remove the old nodes (I did this, but it isn't trivial). Easy solution: save the model and load it again. This way, the unnecessary nodes are no longer part of the graph. – alexhg Oct 24 '19 at 09:26
  • Just a small comment. This is saying "inserting layer after" even when replacing. – guillefix Feb 07 '20 at 17:07
  • I agree with alexhg. Adding new intermediate layers is very different from adding new layers on top of pre-trained layers, as it messes with the inbound and outbound nodes. In particular, I've noticed that the model.get_layer method always points to the "old" model architecture rather than the new one. So I've also resorted to saving and loading the new model once I'm satisfied with the new layering, which creates a "clean" graph within your script/kernel. – Razorocean Apr 11 '20 at 17:49
  • Sometimes it is indeed necessary to save and reload the model. But these functions often come in very handy for me. – alexhg Apr 12 '20 at 16:55
  • It looks like the function will work incorrectly when the model has more than one output. – ZFTurbo May 14 '20 at 10:34
  • This is to illustrate how to insert layers in a Keras model and it works well in the example of the question. It is not meant to work in all possible scenarios. Still, multi-output can be handled with some if statements. – alexhg May 15 '20 at 10:50
  • Nevermind, I added support for multi-output models as well ) – ZFTurbo May 22 '20 at 11:08
  • Unfortunately, it does not work if I want to replace the input layer – volperossa Jul 17 '20 at 00:58
  • The question explicitly asks how to replace "intermediate" layers. To replace only the input layer you don't need all the functionality of this function, as it is much simpler. – alexhg Jul 17 '20 at 13:02
  • This is really great (upvoted). Strangely, I'm having trouble inserting the output model of this function into a TimeDistributed layer. When I try model = tf.keras.layers.TimeDistributed(model_with_intermediate_layer)(input), I get this error: TypeError: int() argument must be a string, a bytes-like object or a number, not 'TensorShape'. Any inkling as to what it might be? I don't get the error if I put the pretrained mobilenetv2 in there, but once mobilenetv2 has been edited by that function it no longer works. – michael_question_answerer Oct 10 '20 at 18:29
  • 1
    @michael_question_answerer unfortunately I can't say much without knowing more details about it (what model_with_intermediate_layer, input, etc.). Have you considered asking a new, more detailed question? Maybe link to this one for reference. – alexhg Oct 12 '20 at 17:55
  • @alexhg thank you, I have gone a different direction with my work but will ask a new question and link this one if I return to it and find I have the same problem. – michael_question_answerer Oct 12 '20 at 19:12
  • Hi! Could you improve your answer to also **remove** given layers? Otherwise I can ask a separate question – ibarrond Feb 23 '21 at 16:23
  • There is a typo: must be "layer.name" instead of "layer_name" in `if layer_name in model.output_names:`. Otherwise you are adding all layers to model_outputs. – yellowdolphin Jun 23 '22 at 13:03
  • layer's name is readonly property, so I replaced new_layer.name by new_layer._name, and it works – Andrew Veriga Jul 09 '22 at 05:33
18

You can use the following functions:

from keras.models import Model

def replace_intermediate_layer_in_keras(model, layer_id, new_layer):

    layers = [l for l in model.layers]

    x = layers[0].output
    for i in range(1, len(layers)):
        if i == layer_id:
            x = new_layer(x)
        else:
            x = layers[i](x)

    new_model = Model(inputs=layers[0].input, outputs=x)
    return new_model

def insert_intermediate_layer_in_keras(model, layer_id, new_layer):
 
    layers = [l for l in model.layers]

    x = layers[0].output
    for i in range(1, len(layers)):
        if i == layer_id:
            x = new_layer(x)
        x = layers[i](x)

    new_model = Model(inputs=layers[0].input, outputs=x)
    return new_model

Example:

from keras.layers import Conv2D, BatchNormalization

model = keras_simple_model()
print(model.summary())

model = replace_intermediate_layer_in_keras(
    model, 3, 
    Conv2D(
        4, (3, 3), 
        activation=None, 
        padding='same', 
        name='conv2_repl', 
        use_bias=False
    )
)
print(model.summary())

model = insert_intermediate_layer_in_keras(
    model, 4, BatchNormalization()
)
print(model.summary())

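There are some limitations on replacements due to layer shapes: the new layer must accept the output of the previous layer and produce an output that the next layer accepts.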
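To make the shape limitation concrete, here is a minimal sketch that computes the output shape of a 2D convolution with the usual formulas for 'same' and 'valid' padding, so that a candidate replacement can be compared with the layer it replaces. The helper `conv2d_output_shape` is a hypothetical name, not a Keras function, and it ignores dilation.

```python
import math


def conv2d_output_shape(height, width, filters, kernel_size,
                        strides=(1, 1), padding='valid'):
    """Return (height, width, channels) produced by a 2D convolution."""
    if padding == 'same':
        out_h = math.ceil(height / strides[0])
        out_w = math.ceil(width / strides[1])
    elif padding == 'valid':
        out_h = math.ceil((height - kernel_size[0] + 1) / strides[0])
        out_w = math.ceil((width - kernel_size[1] + 1) / strides[1])
    else:
        raise ValueError("padding must be 'same' or 'valid'")
    return (out_h, out_w, filters)


# conv2 in the question's model receives a 28x28x4 tensor
original = conv2d_output_shape(28, 28, 4, (3, 3), padding='same')
candidate_ok = conv2d_output_shape(28, 28, 4, (3, 3), padding='same')
candidate_bad = conv2d_output_shape(28, 28, 8, (3, 3), padding='valid')

print(original == candidate_ok)   # same shape: safe to swap in
print(original == candidate_bad)  # different size and channel count
```

A changed channel count is the more serious problem, because the following convolution has already built its kernel for a fixed number of input channels.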

Innat
ZFTurbo
  • 11
    This does not work for me, it seems to have problems with concatenate and merge layers: ValueError: A merge layer should be called on a list of inputs. – maniac Aug 27 '18 at 10:53
3

This is how I did it:

from keras.models import Model

def make_list(X):
    if isinstance(X, list):
        return X
    return [X]

def list_no_list(X):
    if len(X) == 1:
        return X[0]
    return X

def replace_layer(model, replace_layer_subname, replacement_fn, **kwargs):
    """
    args:
        model :: keras.models.Model instance
        replace_layer_subname :: str -- if str in layer name, replace it
        replacement_fn :: fn called as replacement_fn(old_layer=layer, **kwargs)
            for every layer to replace
            > the callable it returns is applied to the replaced layer's input
    returns:
        new model with replaced layers
    quick examples:
        want to just remove all layers with 'batch_norm' in the name:
            > new_model = replace_layer(model, 'batch_norm', lambda **kwargs: (lambda u: u))
        want to replace all Conv1D(N, m, padding='same') with an LSTM (let's say all have 'conv1d' in name)
            > new_model = replace_layer(model, 'conv1d', lambda old_layer, **kwargs: LSTM(units=old_layer.filters, return_sequences=True))
    """
    model_inputs = []
    model_outputs = []
    tsr_dict = {}

    model_output_names = [out.name for out in make_list(model.output)]

    for i, layer in enumerate(model.layers):
        ### Loop if layer is used multiple times
        for j in range(len(layer._inbound_nodes)):

            ### check layer inp/outp
            inpt_names = [inp.name for inp in make_list(layer.get_input_at(j))]
            outp_names = [out.name for out in make_list(layer.get_output_at(j))]

            ### setup model inputs
            if 'input' in layer.name:
                for inpt_tsr in make_list(layer.get_output_at(j)):
                    model_inputs.append(inpt_tsr)
                    tsr_dict[inpt_tsr.name] = inpt_tsr
                continue

            ### setup layer inputs
            inpt = list_no_list([tsr_dict[name] for name in inpt_names])

            ### remake layer 
            if replace_layer_subname in layer.name:
                print('replacing '+layer.name)
                x = replacement_fn(old_layer=layer, **kwargs)(inpt)
            else:
                x = layer(inpt)

            ### reinstantialize outputs into dict
            for name, out_tsr in zip(outp_names, make_list(x)):

                ### check if is an output
                if name in model_output_names:
                    model_outputs.append(out_tsr)
                tsr_dict[name] = out_tsr

    return Model(model_inputs, model_outputs)

I have a custom layer (taken from someone online) called BatchNormalizationFreeze, so an example of usage is this:

 new_model = replace_layer(model, 'batch_normal', lambda **kwargs: BatchNormalizationFreeze())

If you're going to replace one layer with multiple layers, just make the replacement function return a pseudo-model that applies them all at once.
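One way to read the "pseudo model" suggestion is a callable that applies several layers in sequence. The sketch below assumes that reading; `chain` is a hypothetical helper, and plain functions stand in for Keras layers so the snippet runs on its own.

```python
def chain(*layers):
    """Return a callable that applies `layers` one after another."""
    def apply(x):
        for layer in layers:
            x = layer(x)
        return x
    return apply


# Plain functions stand in for Keras layers in this illustration
scale = lambda v: v * 10
shift = lambda v: v + 1

# Same calling convention as in replace_layer: called with old_layer=...,
# returns something that can be applied to the replaced layer's input
replacement_fn = lambda old_layer=None, **kwargs: chain(scale, shift)

result = replacement_fn(old_layer=None)(2)
print(result)
```

With real layers, the same idea would be `chain(SomeLayer(...), OtherLayer(...))`, since Keras layers are themselves callables.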

mshlis
  • What are the `make_list` and `list_no_list` functions? This code is not reusable and the names are not very descriptive. Even if it works I cannot upvote it. – UpmostScarab Feb 21 '19 at 09:57
  • @UpmostScarab Oh, those were helper functions that do roughly what their names imply (one turns something into a list, the other turns a single-element list into a non-list), but I'll update it (the current version is also less buggy). Note that this only works correctly on a DAG. – mshlis Feb 25 '19 at 15:33
2

Unfortunately, replacing a layer is no small feat for models that do not follow the sequential pattern. For sequential models it is OK to just apply x = layer(x) and substitute new_layer where you see fit, as in the previous answer. However, for models that do not have a classic sequential pattern (say you have a simple "concatenation" of two columns) you have to actually "parse" the graph and use your "new_layer" (or layers) in the right places. Hope this is not too discouraging and happy graph parsing and reconstructing :)
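To decide which case applies, one could check whether the layer graph is a simple chain. The sketch below is an assumption-laden illustration: it takes a mapping from each layer name to the names of its input layers (the same format as the `input_layers_of` dictionary in the first answer), and `is_linear_chain` is a hypothetical helper name.

```python
def is_linear_chain(input_layers_of):
    """True if every layer has exactly one input layer and no layer
    feeds more than one other layer."""
    consumers = {}
    for layer, parents in input_layers_of.items():
        if len(parents) != 1:
            return False  # merge layer (e.g. Add, Concatenate)
        for parent in parents:
            consumers[parent] = consumers.get(parent, 0) + 1
    # A layer consumed more than once means the graph branches
    return all(count == 1 for count in consumers.values())


sequential = {'conv1': ['input'], 'act1': ['conv1'], 'pool1': ['act1']}
residual = {'conv1': ['input'], 'add': ['input', 'conv1']}

print(is_linear_chain(sequential))
print(is_linear_chain(residual))
```

If the check fails, the simple `x = layer(x)` loop is not enough and the graph has to be rebuilt from the stored inputs of each layer.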

Andrei Damian