Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import VGG19
- import torch.nn as nn
class VGGLoss(nn.Module):
    """Weighted multi-level L1 feature-matching (perceptual) loss.

    Both inputs are passed through a feature extractor that returns a sequence
    of feature maps. The loss is the sum over those maps of
    ``weight_i * L1(x_feat_i, y_feat_i)``. Target-side features are detached,
    so gradients flow only into ``x``.

    Args:
        gpu_ids: Sequence of GPU ids. When non-empty, the extractor is moved to
            CUDA with ``.cuda()``. When empty or ``None``, it is left on its
            current device. Previously this argument was ignored and
            ``.cuda()`` was always called.
        vgg: Optional feature-extractor module returning a sequence of feature
            maps. Defaults to a newly constructed VGG19 network.
        weights: Optional per-feature-map weights. Defaults to
            ``DEFAULT_WEIGHTS``.
    """

    # One weight per returned feature map, increasing along the sequence
    # (presumably shallow -> deep layers; confirm against the VGG19 extractor).
    DEFAULT_WEIGHTS = (1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0)

    def __init__(self, gpu_ids=None, vgg=None, weights=None):
        super().__init__()
        if vgg is None:
            # ``import VGG19`` binds a *module*, so the former ``VGG19.cuda()``
            # never built a network. Instantiate the model class instead.
            # NOTE(review): assumes the module exports a class named ``VGG19``;
            # confirm the class name and adjust this attribute if it differs.
            vgg = VGG19.VGG19()
        if gpu_ids:
            vgg = vgg.cuda()
        self.vgg = vgg
        self.criterion = nn.L1Loss()
        # Copy so later mutation of the caller's sequence cannot affect the loss.
        self.weights = list(self.DEFAULT_WEIGHTS if weights is None else weights)

    def forward(self, x, y):
        """Return the weighted sum of L1 distances between features of ``x`` and ``y``.

        Args:
            x: Input passed through the feature extractor; gradients flow here.
            y: Target passed through the same extractor; its features are
                detached.

        Returns:
            The accumulated weighted loss.

        Raises:
            ValueError: If the extractor returns more feature maps than there
                are configured weights.
        """
        x_feats = self.vgg(x)
        y_feats = self.vgg(y)
        if len(x_feats) > len(self.weights):
            raise ValueError(
                f"feature extractor returned {len(x_feats)} maps but only "
                f"{len(self.weights)} weights are configured"
            )
        loss = 0
        for weight, x_feat, y_feat in zip(self.weights, x_feats, y_feats):
            # Target features are treated as constants: no gradient into y.
            loss += weight * self.criterion(x_feat, y_feat.detach())
        return loss
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement