Advertisement
Guest User

Untitled

a guest
Jul 19th, 2019
82
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.50 KB | None | 0 0
import torch
import torch.nn as nn

import VGG19
  3.  
  4. class VGGLoss(nn.Module):
  5. def __init__(self, gpu_ids):
  6. super(VGGLoss, self).__init__()
  7.  
  8. self.vgg = VGG19.cuda()
  9. self.criterion = nn.L1Loss()
  10. self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
  11.  
  12. def forward(self, x, y):
  13. x_vgg, y_vgg = self.vgg(x), self.vgg(y)
  14. loss = 0
  15. for i in range(len(x_vgg)):
  16. loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
  17.  
  18. return loss
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement