
I'm trying to load checkpoints and populate model weights using the Faster R-CNN architecture (Faster R-CNN ResNet50 V1 640x640, to be precise, from here). I'm trying to load the weights for this network similarly to how it's done in the example notebook for RetinaNet, where they do the following:

import tensorflow as tf

# Map only the sub-modules whose weights should come from the checkpoint.
fake_box_predictor = tf.compat.v2.train.Checkpoint(
    _base_tower_layers_for_heads=detection_model._box_predictor._base_tower_layers_for_heads,
    _box_prediction_head=detection_model._box_predictor._box_prediction_head,
)

fake_model = tf.compat.v2.train.Checkpoint(
    _feature_extractor=detection_model._feature_extractor,
    _box_predictor=fake_box_predictor,
)

# expect_partial() suppresses warnings about checkpoint values that are not restored.
ckpt = tf.compat.v2.train.Checkpoint(model=fake_model)
ckpt.restore(checkpoint_path).expect_partial()
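The partial-restore mechanism in the snippet above can be reproduced on a toy model to see what it relies on: the keyword names passed to tf.train.Checkpoint must match the attribute names under which the weights were saved, and anything that is not mapped is left untouched. ToyDetector is a hypothetical stand-in for detection_model, not part of the Object Detection API; this is a sketch, not the notebook's code.

```python
import os
import tempfile

import tensorflow as tf


# Hypothetical stand-in for detection_model: two tracked sub-modules.
class ToyDetector(tf.Module):
    def __init__(self, value):
        super().__init__()
        self._feature_extractor = tf.Module()
        self._feature_extractor.kernel = tf.Variable(tf.fill([2, 2], value))
        self._box_predictor = tf.Module()
        self._box_predictor.bias = tf.Variable(tf.fill([2], value))


# Save a "pretrained" model whose weights are all 1.0.
ckpt_dir = tempfile.mkdtemp()
checkpoint_path = tf.train.Checkpoint(model=ToyDetector(1.0)).save(
    os.path.join(ckpt_dir, "ckpt"))

# Fresh model with weights 0.0; map only the feature extractor.
detection_model = ToyDetector(0.0)
fake_model = tf.train.Checkpoint(_feature_extractor=detection_model._feature_extractor)
ckpt = tf.train.Checkpoint(model=fake_model)
# The saved box predictor has no counterpart here, so expect_partial() is needed.
ckpt.restore(checkpoint_path).expect_partial()

print(detection_model._feature_extractor.kernel.numpy())  # restored: all 1.0
print(detection_model._box_predictor.bias.numpy())        # untouched: all 0.0
```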

I'm trying to get a similar checkpoint loading mechanism going for the Faster R-CNN network I want to use, but attributes like _base_tower_layers_for_heads and _box_prediction_head only exist for the architecture used in the example and not for any other.

I also couldn't find documentation on which parts of the model to populate using Checkpoint for my particular use case. Would greatly appreciate any help on how to approach this!

Rohan Gautam

1 Answer


As you said, the main problem is that you don't have a reference to the layers you want to do transfer learning on. This comes from the original implementation of the Faster R-CNN ResNet50 V1 640x640 model in the Zoo: they didn't name the layers, or maybe they did but didn't publish the names (or the code). To solve this, you need to find out which layers you want to keep and which you want to retrain. You can print out all the layers in a network using (ref):

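import tensorflow as tf

# TF1 graph API (under tf.compat.v1 in TF2); it does not see layers created eagerly.
node_names = [n.name for n in tf.compat.v1.get_default_graph().as_graph_def().node]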
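In TF2 eager mode, which the RetinaNet notebook uses, layers are not recorded in a default graph, so a graph-node listing is of little help there. A hedged alternative is to read the names stored in the checkpoint itself with tf.train.list_variables: in object-based TF2 checkpoints each key is the attribute path from the root object (for example model/_feature_extractor/...), and those attribute names are what tf.train.Checkpoint expects as keyword arguments. The sketch assumes the checkpoint was saved as Checkpoint(model=...), as in the question's code; ToyModel and top_level_attributes are hypothetical and only exist to make it runnable.

```python
import os
import tempfile

import tensorflow as tf


def top_level_attributes(checkpoint_path, root="model"):
    """Return the attribute names stored directly under `root` in a TF2 checkpoint."""
    names = set()
    for name, _shape in tf.train.list_variables(checkpoint_path):
        parts = name.split("/")
        # Object-based keys look like "model/<attr>/.../.ATTRIBUTES/VARIABLE_VALUE";
        # keys such as "save_counter/..." or the object graph entry are skipped.
        if len(parts) > 2 and parts[0] == root:
            names.add(parts[1])
    return sorted(names)


# Hypothetical stand-in for a detection model, only to make the sketch runnable.
class ToyModel(tf.Module):
    def __init__(self):
        super().__init__()
        self._feature_extractor = tf.Module()
        self._feature_extractor.kernel = tf.Variable(tf.ones([3, 3]))
        self._box_predictor = tf.Module()
        self._box_predictor.bias = tf.Variable(tf.zeros([4]))


ckpt_dir = tempfile.mkdtemp()
checkpoint_path = tf.train.Checkpoint(model=ToyModel()).save(
    os.path.join(ckpt_dir, "ckpt"))
print(top_level_attributes(checkpoint_path))  # ['_box_predictor', '_feature_extractor']
```

Pointing top_level_attributes at the Faster R-CNN checkpoint from the Zoo should show which attribute names are available to map; you can descend one level further the same way (using parts[2]) to see the sub-modules beneath each one.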

Names can be assigned to layers manually, but TF gives every node a default name. The list can be long and tedious to go through; however, you need it to find the tensor where your transfer learning starts. Go through the list and decide which layers you want to freeze and which ones should continue training. To freeze a layer (ref):

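# `model` is your Keras model; 'layer_name' is a name found in the listing above.
for layer in model.layers:
    if layer.name == 'layer_name':
        layer.trainable = False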
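A minimal sketch of the freezing step on a small Keras functional model, assuming your model exposes its layers through model.layers. The layer names and the backbone_ prefix are hypothetical placeholders for whatever names you find in the listing; matching on a prefix with str.startswith is swapped in for the exact-name comparison, which helps when many layers share a scope.

```python
import tensorflow as tf

# Hypothetical toy model: "backbone_*" layers stand in for pretrained layers,
# "head_*" layers for the part you want to retrain.
inputs = tf.keras.Input(shape=(4,))
x = tf.keras.layers.Dense(8, activation="relu", name="backbone_dense")(inputs)
outputs = tf.keras.layers.Dense(2, name="head_dense")(x)
model = tf.keras.Model(inputs, outputs)

print(len(model.trainable_variables))  # 4: kernel + bias for both Dense layers

# Freeze every layer whose name starts with one of the chosen prefixes.
frozen_prefixes = ("backbone_",)
for layer in model.layers:
    if layer.name.startswith(frozen_prefixes):
        layer.trainable = False

print(len(model.trainable_variables))  # 2: only head_dense's kernel + bias remain
```

Checking len(model.trainable_variables) before and after is a quick way to confirm the freeze took effect before you start training.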
Avizipi