I've written the following code, adapting this answer so that it works for your case:
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras import layers, Model
# Build an untrained ResNet50 backbone (no classification head) and show its layers.
model = tf.keras.applications.resnet.ResNet50(include_top=False, weights=None)
# summary() prints the table itself and returns None; wrapping it in print()
# would emit a stray "None" line.
model.summary()
# Number of groups passed to every GroupNormalization layer; it must evenly
# divide the channel count of each layer being normalized.
channels = 64
from keras.models import Model
def insert_layer_nonseq(model, layer_regex, insert_layer_factory,
                        insert_layer_name=None, position='after'):
    """Rebuild a functional Keras model, inserting or swapping in new layers.

    Every layer whose name ends with ``layer_regex`` gets a fresh layer from
    ``insert_layer_factory`` placed relative to it according to ``position``.
    All other layers are re-applied unchanged; the existing layer objects are
    called again, so they keep their weights.

    Args:
        model: A functional Keras model whose first layer is its input layer.
        layer_regex: Layer-name *suffix* to match. Despite the name, it is
            compared with ``str.endswith``, not as a regular expression.
        insert_layer_factory: Zero-argument callable returning a new layer.
            It is called once per matched layer, so each insertion gets its
            own layer instance.
        insert_layer_name: Accepted for compatibility but currently ignored;
            inserted layers are named ``'<matched layer name>_<new layer name>'``.
        position: ``'replace'`` substitutes the matched layer, ``'after'``
            applies the new layer to the matched layer's output, and
            ``'before'`` feeds the new layer's output into the matched layer.

    Returns:
        A new ``tf.keras.Model`` sharing ``model.inputs``.

    Raises:
        ValueError: If ``position`` is not ``'replace'``, ``'after'`` or
            ``'before'``.
    """
    # Validate up front so a bad value fails even when no layer matches.
    if position not in ('replace', 'after', 'before'):
        raise ValueError('position must be: before, after or replace')

    # Map each layer name to the names of the layers feeding into it.
    input_layers_of = {}
    for layer in model.layers:
        for node in layer._outbound_nodes:
            input_layers_of.setdefault(
                node.outbound_layer.name, []).append(layer.name)

    # Output tensor of every layer in the rebuilt graph, seeded with the input.
    new_output_tensor_of = {model.layers[0].name: model.input}

    model_outputs = []
    for layer in model.layers[1:]:
        layer_input = [new_output_tensor_of[name]
                       for name in input_layers_of[layer.name]]
        # Single-input layers take a tensor; merge layers (e.g. Add) take a list.
        if len(layer_input) == 1:
            layer_input = layer_input[0]

        if layer.name.endswith(layer_regex):
            new_layer = insert_layer_factory()
            # Prefix with the matched layer's name to keep layer names unique.
            new_layer._name = '{}_{}'.format(layer.name, new_layer.name)
            if position == 'replace':
                x = new_layer(layer_input)
            elif position == 'after':
                x = new_layer(layer(layer_input))
            else:  # 'before'
                x = layer(new_layer(layer_input))
        else:
            x = layer(layer_input)

        new_output_tensor_of[layer.name] = x

        # Compare the *current* layer's name. The previous code used the stale
        # `layer_name` from the first loop, which marked every layer as an output.
        if layer.name in model.output_names:
            model_outputs.append(x)

    # Use tf.keras.Model to match the tf.keras tensors the graph was built from.
    return tf.keras.Model(inputs=model.inputs, outputs=model_outputs)
def replace_layer():
    """Create a new GroupNormalization layer to stand in for a matched layer.

    The module-level ``channels`` value is used as the number of groups.
    A fresh instance is returned on every call, so each replaced layer gets
    its own weights.
    """
    group_count = channels
    return tfa.layers.GroupNormalization(groups=group_count)
# Swap every layer whose name ends in 'bn' (in Keras' ResNet50 these are the
# BatchNormalization layers, e.g. 'conv1_bn') for a GroupNormalization layer.
model = insert_layer_nonseq(
    model,
    layer_regex='bn',
    insert_layer_factory=replace_layer,
    position="replace",
)
Note: I've changed your `channels` variable from 3 to 64, for the following reason.
From the documentation of the `groups` argument:
Integer, the number of groups for Group Normalization. Can be in the
range [1, N] where N is the input dimension. The input dimension must
be divisible by the number of groups. Defaults to 32.
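Choose the value that suits your model. With 3 groups, the very first normalization layer of ResNet50 (64 channels) is not divisible by 3, so building the model fails. 64 evenly divides every channel count in the network (64, 128, 256, 512, 1024, 2048).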