Advertisement
Guest User

Untitled

a guest
Jul 19th, 2018
75
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.40 KB | None | 0 0
import torch
from torchvision import models
  2. class Vgg19(torch.nn.Module):
  3.     def __init__(self, requires_grad=False):
  4.         super(Vgg19, self).__init__()
  5.         vgg_pretrained_features = models.vgg19(pretrained=True).features
  6.         self.slice1 = torch.nn.Sequential()
  7.         self.slice2 = torch.nn.Sequential()
  8.         self.slice3 = torch.nn.Sequential()
  9.         self.slice4 = torch.nn.Sequential()
  10.         self.slice5 = torch.nn.Sequential()
  11.         for x in range(2):
  12.             self.slice1.add_module(str(x), vgg_pretrained_features[x])
  13.         for x in range(2, 7):
  14.             self.slice2.add_module(str(x), vgg_pretrained_features[x])
  15.         for x in range(7, 12):
  16.             self.slice3.add_module(str(x), vgg_pretrained_features[x])
  17.         for x in range(12, 21):
  18.             self.slice4.add_module(str(x), vgg_pretrained_features[x])
  19.         for x in range(21, 30):
  20.             self.slice5.add_module(str(x), vgg_pretrained_features[x])
  21.         if not requires_grad:
  22.             for param in self.parameters():
  23.                 param.requires_grad = False
  24.  
  25.     def forward(self, X):
  26.         h_relu1 = self.slice1(X)
  27.         h_relu2 = self.slice2(h_relu1)        
  28.         h_relu3 = self.slice3(h_relu2)        
  29.         h_relu4 = self.slice4(h_relu3)        
  30.         h_relu5 = self.slice5(h_relu4)                
  31.         out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
  32.         return out
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement