SHARE
TWEET

Untitled

a guest Mar 22nd, 2019 61 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. class CNN(torch.nn.Module):    
  2.     def __init__(self):
  3.         super(CNN, self).__init__()
  4.         # batch_size x 3 x 16 x 16
  5.         self.conv1 = torch.nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1)
  6.         self.conv2 = torch.nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1)
  7.         # batch_size x 16 x 16 x 16
  8.         self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
  9.         self.norm1 = torch.nn.BatchNorm2d(16)
  10.         #batch_size x 16 x 8 x 8
  11.         self.conv3 = torch.nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
  12.         # batch_size x 32 x 8 x 8
  13.         #pool
  14.         # batch_size x 32 x 4 x 4
  15.         self.conv4 = torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
  16.         self.norm2 = torch.nn.BatchNorm2d(64)
  17.         self.conv5 = torch.nn.Conv2d(64, 96, kernel_size=3, stride=1, padding=1)
  18.         self.norm3 = torch.nn.BatchNorm2d(96)
  19.         self.conv6 = torch.nn.Conv2d(96, 128, kernel_size=3, stride=1, padding=1)
  20.         # batch_size x 128 x 4 x 4
  21.         # Deleted for now ##pool
  22.         # Deleted for now ##batch_size x 128 x 2 x 2
  23.         self.norm4 = torch.nn.BatchNorm1d(128 * 4 * 4)
  24.         self.fc1 = torch.nn.Linear(128 * 4 * 4, 30 * 30)
  25.         self.norm5 = torch.nn.BatchNorm1d(30 * 30)
  26.         self.fc2 = torch.nn.Linear(30 * 30, 15 * 15)
  27.        
  28.     def forward(self, x):
  29.         x = x.view(-1, 3, 16, 16)
  30.         x = F.relu(self.conv1(x))
  31.         x = F.relu(self.conv2(x))
  32.         x = self.pool(x)
  33.         x = self.norm1(x)
  34.         x = F.relu(self.conv3(x))
  35.         x = self.pool(x)
  36.         x = F.relu(self.conv4(x))
  37.         x = self.norm2(x)
  38.         x = F.relu(self.conv5(x))
  39.         x = self.norm3(x)
  40.         x = F.relu(self.conv6(x))
  41.         #x = self.pool(x)
  42.         x = x.view(-1, 128 * 4 * 4)
  43.         x = self.norm4(x)
  44.         x = F.relu(self.fc1(x))
  45.         x = self.norm5(x)
  46.         x = self.fc2(x)
  47.         return(x)
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top