Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class CNN(torch.nn.Module):
    """Small convolutional network mapping a 3x16x16 input to 225 outputs.

    ``forward`` first reshapes its input with ``view(-1, 3, 16, 16)``.
    Spatial resolution is halved twice by 2x2 max pooling (16 -> 8 -> 4).
    The resulting 128x4x4 feature map is flattened to 2048 features. Those
    go through two fully connected layers (2048 -> 900 -> 225). No
    activation is applied to the final layer's output.

    Submodule attribute names (conv1..conv6, pool, norm1..norm5, fc1, fc2)
    are part of the ``state_dict`` layout. They are created in a fixed order,
    so renaming or reordering them would break saved checkpoints.
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn

        # Stage 1: (N, 3, 16, 16) -> (N, 16, 16, 16); pooled to (N, 16, 8, 8).
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1)
        # One 2x2 / stride-2 pool module, reused after stage 1 and stage 2.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.norm1 = nn.BatchNorm2d(16)

        # Stage 2: (N, 16, 8, 8) -> (N, 32, 8, 8); pooled to (N, 32, 4, 4).
        self.conv3 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)

        # Stage 3: stays at 4x4 while channels grow 32 -> 64 -> 96 -> 128.
        self.conv4 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.BatchNorm2d(64)
        self.conv5 = nn.Conv2d(64, 96, kernel_size=3, stride=1, padding=1)
        self.norm3 = nn.BatchNorm2d(96)
        self.conv6 = nn.Conv2d(96, 128, kernel_size=3, stride=1, padding=1)

        # Head: flattened 128 * 4 * 4 = 2048 features -> 900 -> 225.
        flat_features = 128 * 4 * 4
        self.norm4 = nn.BatchNorm1d(flat_features)
        self.fc1 = nn.Linear(flat_features, 30 * 30)
        self.norm5 = nn.BatchNorm1d(30 * 30)
        self.fc2 = nn.Linear(30 * 30, 15 * 15)

    def forward(self, x):
        """Return a ``(N, 225)`` tensor for input viewable as ``(N, 3, 16, 16)``."""
        out = x.view(-1, 3, 16, 16)

        # Stage 1: two ReLU convs, then pool, then batch norm.
        out = torch.relu(self.conv1(out))
        out = torch.relu(self.conv2(out))
        out = self.norm1(self.pool(out))

        # Stage 2: ReLU conv, then pool down to 4x4.
        out = self.pool(torch.relu(self.conv3(out)))

        # Stage 3: conv4 and conv5 are each followed by ReLU then batch norm.
        for conv, norm in ((self.conv4, self.norm2), (self.conv5, self.norm3)):
            out = norm(torch.relu(conv(out)))
        out = torch.relu(self.conv6(out))

        # Head: (N, 128, 4, 4) -> (N, 2048) -> (N, 900) -> (N, 225).
        out = self.norm4(torch.flatten(out, 1))
        out = self.norm5(torch.relu(self.fc1(out)))
        return self.fc2(out)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement