# Download locations of the torchvision ImageNet-pretrained ResNet
# checkpoints, keyed by architecture name.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
# CIFAR-10/100 setup: build a ResNet-18 with args.num_classes outputs and
# initialize it from the ImageNet-pretrained checkpoint, except for the stem
# (conv1, bn1) and the classifier (fc), which keep the new model's own
# freshly initialized weights.
if args.dataset == 'cifar100' or args.dataset == 'cifar10':
    # NOTE(review): presumably read by resnet18() to configure the CIFAR
    # variant; its exact meaning is defined outside this snippet.
    args.stride = [2, 2]
    resnet = resnet18(args, pretrained=False, num_classes=args.num_classes)

    # State dict of the ImageNet-pretrained ResNet-18 checkpoint.
    initial_weight = model_zoo.load_url(model_urls['resnet18'])
    local_model = resnet
    # State dict of the freshly constructed (not pretrained) model.
    initial_weight_1 = local_model.state_dict()

    # Why `initial_weight[key] = initial_weight_1[key]`:
    # load_state_dict() raises on any tensor whose shape differs from the
    # model's. The checkpoint's fc layer is the 1000-class ImageNet head,
    # while this model has args.num_classes outputs. The CIFAR stem
    # (conv1/bn1) presumably differs from the ImageNet one too (TODO:
    # confirm in resnet18()). Overwriting those entries with the model's own
    # tensors makes the dict load cleanly. The net effect is that those layers
    # keep their fresh initialization and every other layer gets the
    # pretrained weights.
    # Only top-level keys match these prefixes. Nested keys such as
    # 'layer1.0.conv1.weight' start with 'layer' and are still loaded from
    # the checkpoint.
    for key in initial_weight:
        if key.startswith(('fc.', 'conv1', 'bn1')):
            initial_weight[key] = initial_weight_1[key]
    local_model.load_state_dict(initial_weight)
I don't understand this line: `initial_weight[key] = initial_weight_1[key]`
Could you please tell me why we need to do this?
Thanks!