Guest User

Untitled

a guest
Apr 8th, 2020
93
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.95 KB | None | 0 0
  1. import torch
  2.  
  3. class Net(torch.nn.Module):
  4.  
  5.     def __init__(self, n_features):
  6.         super(Net, self).__init__()
  7.         self.conv1 = torch.nn.Conv2d(in_channels = 1, out_channels = 64, kernel_size = 5, stride = 2, padding = 2)
  8.         self.bn1 = torch.nn.BatchNorm2d(num_features = 64)
  9.         self.relu1 = torch.nn.ReLU()
  10.  
  11.         self.conv2 = torch.nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = 3, stride = 1, padding = 1)
  12.         self.bn2 = torch.nn.BatchNorm2d(num_features = 64)
  13.         self.relu2 = torch.nn.ReLU()
  14.  
  15.         self.mp1 = torch.nn.MaxPool2d(kernel_size = 2)
  16.         self.do1 = torch.nn.Dropout2d(p = 0.3)
  17.        
  18.         self.conv3 = torch.nn.Conv2d(in_channels = 64, out_channels = 128, kernel_size = 3, stride = 1, padding = 1)
  19.         self.bn3 = torch.nn.BatchNorm2d(num_features = 128)
  20.         self.relu3 = torch.nn.ReLU()
  21.        
  22.         self.conv4 = torch.nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size = 3, stride = 1, padding = 1)
  23.         self.bn4 = torch.nn.BatchNorm2d(num_features = 128)
  24.         self.relu4 = torch.nn.ReLU()
  25.  
  26.         self.mp2 = torch.nn.MaxPool2d(kernel_size = 2)
  27.         self.do2 = torch.nn.Dropout2d(p = 0.3)
  28.        
  29.         self.conv5 = torch.nn.Conv2d(in_channels = 128, out_channels = 256, kernel_size = 3, stride = 1, padding = 1)
  30.         self.bn5 = torch.nn.BatchNorm2d(num_features = 256)
  31.         self.relu5 = torch.nn.ReLU()
  32.        
  33.         self.conv6 = torch.nn.Conv2d(in_channels = 256, out_channels = 256, kernel_size = 3, stride = 1, padding = 1)
  34.         self.bn6 = torch.nn.BatchNorm2d(num_features = 256)
  35.         self.relu6 = torch.nn.ReLU()
  36.        
  37.         self.conv7 = torch.nn.Conv2d(in_channels = 256, out_channels = 384, kernel_size = 3, stride = 1, padding = 1)
  38.         self.bn7 = torch.nn.BatchNorm2d(num_features = 384)
  39.         self.relu7 = torch.nn.ReLU()
  40.        
  41.         self.conv8 = torch.nn.Conv2d(in_channels = 384, out_channels = 512, kernel_size = 3, stride = 1, padding = 1)
  42.         self.bn8 = torch.nn.BatchNorm2d(num_features = 512)
  43.         self.relu8 = torch.nn.ReLU()
  44.        
  45.         self.conv9 = torch.nn.Conv2d(in_channels = 512, out_channels = 256, kernel_size = 3, stride = 1, padding = 0)
  46.         self.bn9 = torch.nn.BatchNorm2d(num_features = 256)
  47.         self.relu9 = torch.nn.ReLU()
  48.        
  49.         self.aap1 = torch.nn.AdaptiveAvgPool2d(output_size = (1,1))
  50.  
  51.         self.fc1 = torch.nn.Linear(in_features = 256, out_features = n_features)
  52.  
  53.     def forward(self, x):
  54.         # Two convolutional layers (all conv layers followed by batch normalization and ReLU activation)
  55.         x = self.conv1(x)
  56.         x = self.bn1(x)
  57.         x = self.relu1(x)
  58.        
  59.         x = self.conv2(x)
  60.         x = self.bn2(x)
  61.         x = self.relu2(x)
  62.  
  63.         # Max pooling and dropout
  64.         x = self.mp1(x)
  65.         x = self.do1(x)
  66.  
  67.         # Two conv layers
  68.         x = self.conv3(x)
  69.         x = self.bn3(x)
  70.         x = self.relu3(x)
  71.  
  72.         x = self.conv4(x)
  73.         x = self.bn4(x)
  74.         x = self.relu3(x)
  75.  
  76.         # Max pooling
  77.         x = self.mp1(x)
  78.         x = self.do1(x)
  79.  
  80.         # Five conv layers
  81.         x = self.conv5(x)
  82.         x = self.bn5(x)
  83.         x = self.relu5(x)
  84.  
  85.         x = self.conv6(x)
  86.         x = self.bn6(x)
  87.         x = self.relu6(x)
  88.  
  89.         x = self.conv7(x)
  90.         x = self.bn7(x)
  91.         x = self.relu7(x)
  92.  
  93.         x = self.conv8(x)
  94.         x = self.bn8(x)
  95.         x = self.relu8(x)
  96.  
  97.         x = self.conv9(x)
  98.         x = self.bn9(x)
  99.         x = self.relu9(x)
  100.  
  101.         x = self.aap1(x)
  102.  
  103.         # Fully connected layer
  104.         x = x.view(-1, self.num_flat_features(x)) # view all channels to a single vector for the fc layer
  105.         x = self.fc1(x)
  106.         return x
  107.  
  108.     def num_flat_features(self, x):
  109.         size = x.size()[1:]  # all dimensions except the batch dimension
  110.         num_features = 1
  111.         for s in size:
  112.             num_features *= s
  113.         return num_features
Add Comment
Please, Sign In to add comment