While implementing Neural Style Transfer in PyTorch (deep learning), I tried to compute the content loss and the style loss, but I ran into `KeyError: 'conv4_2'`, which I have not been able to solve.
Please have a look at the code below:
# Extract content and style features, and build the Gram matrices
def get_features(image, model, layers=None):
    """Run ``image`` through ``model`` and collect selected intermediate activations.

    Each direct child of ``model`` (``model._modules``, in insertion order) is
    applied to the running tensor. Whenever a child's name is a key of
    ``layers``, the output is stored under the mapped feature name.

    Args:
        image: Input tensor fed to the first child module.
        model: Module whose children are applied in order, e.g. an
            ``nn.Sequential`` such as torchvision's ``vgg19(...).features``.
            The default ``layers`` keys are child *indices* ('0', '5', ...).
            A model whose children carry other names matches nothing and
            yields an empty dict; a later lookup then raises ``KeyError``.
        layers: Optional mapping of child-module name to feature name.
            Defaults to the VGG-19 style layers plus ``conv4_2`` for content.

    Returns:
        dict mapping feature name to the activation tensor at that layer.
    """
    if layers is None:
        layers = {
            '0': 'conv1_1',
            '5': 'conv2_1',
            '10': 'conv3_1',
            '19': 'conv4_1',
            '21': 'conv4_2',  # content_feature
            '28': 'conv5_1',
        }
    x = image
    features = {}
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
    # Return only after the loop has visited every child. Returning inside
    # the loop yields just {'conv1_1': ...}, which is what produces
    # KeyError: 'conv4_2' in the caller.
    return features
# Activations of the content and style images, keyed by feature name.
content_f, style_f = get_features(content_p, vgg), get_features(style_p, vgg)
def gram_matrix(tensor):
    """Return the channel-by-channel Gram matrix of a feature map.

    ``tensor`` is unpacked as (b, c, h, w) and flattened with ``view`` to
    (c, h*w). The element counts therefore only agree when b == 1; otherwise
    ``view`` raises. The result has shape (c, c).
    """
    _, channels, height, width = tensor.size()
    flat = tensor.view(channels, height * width)
    return flat @ flat.t()
# Gram matrix of the style image at every extracted layer, computed once up front.
style_grams = {name: gram_matrix(feat) for name, feat in style_f.items()}
# Create the content and style loss functions
def content_loss(target_conv4_2, content_conv4_2):
    """Return the mean squared difference between two activation tensors.

    Used with the ``conv4_2`` activations of the target and content images.
    """
    diff = target_conv4_2 - content_conv4_2
    return diff.pow(2).mean()
# Per-layer weights for the style loss (conv1_1 weighted most heavily).
# conv4_2 is deliberately absent: it is only used for the content loss.
style_weights = dict(
    conv1_1=1.0,
    conv2_1=0.75,
    conv3_1=0.2,
    conv4_1=0.2,
    conv5_1=0.2,
)
def style_loss(style_weights, target_features, style_grams):
    """Return the weighted style loss of the target image.

    Args:
        style_weights: Mapping of layer name to weight; only these layers
            contribute to the loss.
        target_features: Mapping of layer name to the target image's
            activation (as returned by ``get_features``). It must contain
            every key of ``style_weights``.
        style_grams: Mapping of layer name to the style image's Gram matrix.
            It must contain every key of ``style_weights``.

    Returns:
        The sum over layers of
        ``weight * mean((G_target - G_style) ** 2) / (c * h * w)``.
    """
    loss = 0
    for layer, weight in style_weights.items():
        target_f = target_features[layer]
        target_gram = gram_matrix(target_f)
        # Read from the ``style_grams`` argument. Indexing the local
        # ``style_gram`` being assigned here would raise UnboundLocalError.
        style_gram = style_grams[layer]
        _, c, h, w = target_f.shape
        layer_loss = weight * torch.mean((target_gram - style_gram) ** 2)
        # Divide by the layer's activation count; this presumably keeps
        # layers of different sizes on a comparable scale.
        loss += layer_loss / (c * h * w)
    return loss
# Start the image being optimised from a copy of the content image.
# Move it to ``device`` *before* enabling gradients. Calling ``.to()`` on a
# tensor that already requires grad returns a new non-leaf tensor whenever
# the device changes, so ``target.grad`` would never be populated.
target = content_p.clone().to(device).requires_grad_(True)
target_f = get_features(target, vgg)
print("Content Loss : ", content_loss(target_f['conv4_2'], content_f['conv4_2']))
print("Style Loss : ", style_loss(style_weights, target_f, style_grams))
Output from the last two lines of code:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-22-3b32a5406c6a> in <module>()
----> 1 print("Content Loss : ",content_loss(target_f['conv4_2'],content_f['conv4_2']))
2 print("Style Loss : ",style_loss(style_weights, target_f , style_grams))
KeyError: 'conv4_2'
I would be very thankful for a quick response! Please let me know if you need any more of the code to help solve this.