
I would like to replace the BatchNorm layers with GroupNorm in built-in Keras models, e.g. ResNet50. I'm trying to reset the nodes' layers to my new layer, but nothing changes when I call model.summary().

import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras import layers

model = tf.keras.applications.resnet.ResNet50(include_top=False, weights=None)
channels = 3

for i,layer in enumerate(model.layers[:]):
    if 'bn' in layer.name:
        inbound_nodes = layer.inbound_nodes
        outbound_nodes = layer.outbound_nodes
        
        new_name = layer.name.replace('bn','gn')
        new_layer =  tfa.layers.GroupNormalization(channels)
        new_layer._name = new_name 
        
        for j in range(len(inbound_nodes)):
            inbound_nodes[j].layer = new_layer #set end of node to this layer
        
        for k in range(len(outbound_nodes)):
            new_layer.outbound_nodes.append(outbound_nodes[k])
        
        layer = new_layer
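(For context, one reason nothing changes here is that `layer = new_layer` only rebinds the local loop variable; it does not touch the model's own layer list or graph. A minimal pure-Python illustration of this pitfall:)

```python
# Rebinding a loop variable does not modify the list it iterates over
layer_names = ["conv1_bn", "conv2_block1_1_bn"]
for name in layer_names:
    name = name.replace("bn", "gn")  # rebinds the local name only

print(layer_names)  # the list is unchanged
```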
kaycaborr

1 Answer


I've created the following code, adapting this answer to make it work for your case:

import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras import layers, Model 

model = tf.keras.applications.resnet.ResNet50(include_top=False, weights=None)
model.summary()
channels = 64

def insert_layer_nonseq(model, layer_regex, insert_layer_factory,
                        insert_layer_name=None, position='after'):

    # Auxiliary dictionary to describe the network graph
    network_dict = {'input_layers_of': {}, 'new_output_tensor_of': {}}

    # Set the input layers of each layer
    for layer in model.layers:
        for node in layer._outbound_nodes:
            layer_name = node.outbound_layer.name
            if layer_name not in network_dict['input_layers_of']:
                network_dict['input_layers_of'].update(
                        {layer_name: [layer.name]})
            else:
                network_dict['input_layers_of'][layer_name].append(layer.name)

    # Set the output tensor of the input layer
    network_dict['new_output_tensor_of'].update(
            {model.layers[0].name: model.input})

    # Iterate over all layers after the input
    model_outputs = []
    for layer in model.layers[1:]:

        # Determine input tensors
        layer_input = [network_dict['new_output_tensor_of'][layer_aux] 
                for layer_aux in network_dict['input_layers_of'][layer.name]]
        if len(layer_input) == 1:
            layer_input = layer_input[0]

        # Insert layer if name matches
        if (layer.name).endswith(layer_regex):
            if position == 'replace':
                x = layer_input
            else:
                raise ValueError('position must be: replace')

            new_layer = insert_layer_factory()
            # Layer.name has no public setter, so set the private attribute
            new_layer._name = '{}_{}'.format(layer.name, new_layer.name)
            x = new_layer(x)
            # print('New layer: {} Old layer: {} Type: {}'.format(new_layer.name, layer.name, position))
            
        else:
            x = layer(layer_input)

        # Set new output tensor (the original one, or the one of the inserted
        # layer)
        network_dict['new_output_tensor_of'].update({layer.name: x})

        # Save tensor in output list if it is output in initial model
        if layer.name in model.output_names:
            model_outputs.append(x)

    return Model(inputs=model.inputs, outputs=model_outputs)

def replace_layer():
  return tfa.layers.GroupNormalization(channels)

model = insert_layer_nonseq(model, 'bn', replace_layer, position="replace")
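The core of this approach is rebuilding the graph layer by layer from a dictionary that maps each layer to its input layers. A minimal pure-Python sketch of the same bookkeeping, using a toy edge list standing in for Keras' node graph (the layer names are made up for illustration):

```python
# Toy edge list: (source_layer, destination_layer) pairs
edges = [("input", "conv1"), ("conv1", "bn1"), ("bn1", "relu1")]

# Same structure as network_dict['input_layers_of'] above:
# for each layer, the list of layers feeding into it
input_layers_of = {}
for src, dst in edges:
    input_layers_of.setdefault(dst, []).append(src)

print(input_layers_of)
# {'conv1': ['input'], 'bn1': ['conv1'], 'relu1': ['bn1']}
```

Once this map exists, the function can walk the layers in order, look up each layer's inputs, and splice in a replacement wherever the name matches.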

Note: I've changed your channels variable from 3 to 64 for the following reason.

From the documentation of the argument group:

Integer, the number of groups for Group Normalization. Can be in the range [1, N] where N is the input dimension. The input dimension must be divisible by the number of groups. Defaults to 32.

You should choose the value most appropriate for your case: it must evenly divide the channel dimension of each layer the GroupNormalization is applied to.
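To make the divisibility constraint concrete, here is a small helper (my own, not part of any library) that lists the valid group counts for a given channel dimension:

```python
def valid_group_counts(num_channels):
    # All group counts that evenly divide the channel dimension,
    # as required by GroupNormalization's `groups` argument
    return [g for g in range(1, num_channels + 1) if num_channels % g == 0]

# ResNet50's first BatchNorm layer has 64 channels
print(valid_group_counts(64))  # [1, 2, 4, 8, 16, 32, 64]
```

Note that 3 is not in this list, which is why your original `channels = 3` fails for a 64-channel layer.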

ClaudiaR
  • The model summary looks good, however, `model.layers[2].outbound_nodes[0].layer` returns ``. I could not figure out why the nodes are still incorrect. – kaycaborr Aug 23 '22 at 09:48
  • If you loop over the nodes in `model.layers[2]._outbound_nodes` and print `node.layer`, you'll see both `batch_normalization` and `group_normalization` nodes now. My guess is that `_outbound_nodes` was not correctly updated to remove the old nodes that are still visible but not used. The real nodes connected are the ones in summary. – ClaudiaR Aug 23 '22 at 10:08
  • This might have some side effects. `layer_names = ['relu0', 'stage2_unit1_relu2', 'stage3_unit1_relu2', 'stage4_unit1_relu2', 'relu1'] layers = [model.get_layer(name).output for name in layer_names] new_model = tf.keras.Model(inputs=model.input, outputs=layers) new_model.summary()` This returns a model with the old BatchNorm layers:/ – kaycaborr Aug 24 '22 at 07:42
  • I'm not sure what you're trying to do here... you have already created a new model – ClaudiaR Aug 24 '22 at 13:06