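import torch.utils.model_zoo as model_zoo
# `resnet18` and `args` come from elsewhere in the project (not shown); this
# resnet18 takes `args`, so it is not torchvision's builder.
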
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}

if args.dataset == 'cifar100' or args.dataset == 'cifar10':
    args.stride = [2, 2]
resnet = resnet18(args, pretrained=False, num_classes=args.num_classes)
# state dict of ResNet18 pretrained on ImageNet
initial_weight = model_zoo.load_url(model_urls['resnet18'])
local_model = resnet
# state dict of the model that was just built
initial_weight_1 = local_model.state_dict()

# overwrite the conv1 / bn1 / fc entries of the pretrained dict with the new model's values
for key in initial_weight.keys():
    if key[0:3] == 'fc.' or key[0:5] == 'conv1' or key[0:3] == 'bn1':
        initial_weight[key] = initial_weight_1[key]

local_model.load_state_dict(initial_weight)

I don't understand this line: "initial_weight[key] = initial_weight_1[key]"

Could you please tell me why we need to do this?

Thanks

Joey

1 Answer


The function torch.utils.model_zoo.load_url downloads the serialized Torch object at the given URL (caching it locally) and loads it. In this case the URL hosts the state dict, i.e. the weight dictionary, of a ResNet18 pretrained on ImageNet.

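Therefore initial_weight is the state dict (a mapping from parameter names to tensors) of a pretrained ResNet18, while initial_weight_1 is the state dict of the model resnet that was just created in memory by resnet18(...). Since that call passes pretrained=False, initial_weight_1 holds randomly initialized tensors, not pretrained ones.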
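Both dicts are keyed by parameter names such as conv1.weight or layer1.0.bn1.weight. As a quick check of which names the condition in the loop picks out, here is a small pure-Python sketch. The key list is written out by hand and is only an assumed subset following torchvision's ResNet naming; the real state dict has many more entries.

```python
# Hand-written subset of ResNet18 state-dict keys (assumed torchvision naming).
sample_keys = [
    "conv1.weight",
    "bn1.weight",
    "bn1.bias",
    "bn1.running_mean",
    "bn1.running_var",
    "layer1.0.conv1.weight",
    "layer1.0.bn1.weight",
    "layer4.1.bn2.bias",
    "fc.weight",
    "fc.bias",
]


def is_replaced(key: str) -> bool:
    """Same test as the loop in the question: only the start of the key is compared."""
    return key[0:3] == "fc." or key[0:5] == "conv1" or key[0:3] == "bn1"


# Entries that end up taken from the freshly built model ...
replaced = [k for k in sample_keys if is_replaced(k)]
# ... and entries that keep the pretrained values.
kept = [k for k in sample_keys if not is_replaced(k)]

print("from the new model:", replaced)
print("from the pretrained file:", kept)
```

Because the slices look only at the beginning of the name, the conv1/bn1 layers inside the residual blocks (whose keys start with "layer") are not selected.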

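The loop then goes through every key of the pretrained dict, and whenever key[0:3] == 'fc.' or key[0:5] == 'conv1' or key[0:3] == 'bn1' holds, it replaces the pretrained entry with the corresponding entry from initial_weight_1. Note the direction of the copy: for the first convolution conv1, its batch norm bn1 and the classifier fc, the pretrained weights are discarded and the model's own fresh initialization is kept, while every other layer receives the pretrained weights through load_state_dict. The reason is shape compatibility: the pretrained fc has 1000 outputs, whereas this model has num_classes outputs (10 or 100 for CIFAR), and the custom resnet18 here probably also changes conv1 (CIFAR variants of ResNet commonly use a 3×3 stem for 32×32 inputs). Loading those tensors as-is would make load_state_dict fail with a size-mismatch error; bn1 is presumably reset together with conv1 because it normalizes conv1's output.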
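A minimal sketch of why the overwrite is needed, assuming PyTorch is installed. TinyNet is a hypothetical stand-in for ResNet18 that keeps only the top-level names conv1, bn1, layer1 and fc. The "pretrained" weights are simply another randomly initialized TinyNet with an ImageNet-style 7×7 stem and a 1000-class head, while the local model has a 3×3 stem and 10 classes (an assumed imitation of a CIFAR variant).

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for ResNet18 with the same top-level names."""

    def __init__(self, num_classes: int, stem_kernel: int):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=stem_kernel, padding=stem_kernel // 2, bias=False)
        self.bn1 = nn.BatchNorm2d(8)
        self.layer1 = nn.Conv2d(8, 8, kernel_size=3, padding=1, bias=False)
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))
        x = torch.relu(self.layer1(x))
        return self.fc(x.mean(dim=(2, 3)))


# "Pretrained" ImageNet-style weights: 7x7 stem, 1000-class classifier.
pretrained = TinyNet(num_classes=1000, stem_kernel=7).state_dict()
# Local CIFAR-style model: 3x3 stem, 10-class classifier.
local_model = TinyNet(num_classes=10, stem_kernel=3)
local = local_model.state_dict()

# Copies kept only to check the result afterwards.
fresh_fc = local["fc.weight"].clone()
pretrained_layer1 = pretrained["layer1.weight"].clone()

# Loading the untouched dict fails (tried on a throwaway instance so that
# local_model is not partially modified): conv1.weight and fc.* differ in shape.
try:
    TinyNet(num_classes=10, stem_kernel=3).load_state_dict(pretrained)
    mismatch = False
except RuntimeError:
    mismatch = True

# The loop from the question: take conv1/bn1/fc from the local model, then load.
for key in pretrained.keys():
    if key[0:3] == "fc." or key[0:5] == "conv1" or key[0:3] == "bn1":
        pretrained[key] = local[key]
local_model.load_state_dict(pretrained)

print("plain load failed:", mismatch)
print("fc weight shape:", tuple(local_model.fc.weight.shape))
```

After the swap, load_state_dict succeeds: layer1 now holds the "pretrained" tensor while conv1, bn1 and fc keep the local model's own initialization, which mirrors what the question's loop does for the real ResNet18.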

Ivan
  • Thanks for your reply. So initial_weight holds the weights of the pretrained ResNet18, while initial_weight_1 is just the structure of ResNet18 and does not contain any pretrained parameters, right? If so, why do we need to copy from initial_weight_1 to initial_weight and overwrite the pretrained parameters? Thanks a lot~ – Joey Jul 21 '22 at 07:26