Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Style-transfer optimization loop: repeatedly nudge the pixels of `target`
# so its VGG features match `content_features` (content) and its per-layer
# gram matrices match `style_grams` (style).
for ii in range(1, steps + 1):
    # Re-extract the target image's feature maps — they change every step
    # because the optimizer mutates the image itself.
    target_features = get_features(target, vgg)

    # Content objective: MSE between target and content features at conv4_2.
    content_loss = torch.mean(
        (target_features['conv4_2'] - content_features['conv4_2']) ** 2
    )

    # Style objective: weighted gram-matrix MSE per layer, each term
    # normalized by the layer's feature-map size (d * h * w).
    style_loss = 0
    for layer in style_weights:
        layer_feature = target_features[layer]
        _, d, h, w = layer_feature.shape
        gram_diff = gram_matrix(layer_feature) - style_grams[layer]
        weighted_term = style_weights[layer] * torch.mean(gram_diff ** 2)
        style_loss += weighted_term / (d * h * w)

    # Blend the two objectives and take one gradient step on the image.
    total_loss = content_weight * content_loss + style_weight * style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement