Advertisement
Guest User

Untitled

a guest
Mar 22nd, 2019
95
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.90 KB | None | 0 0
  1. class CNN(torch.nn.Module):
  2. def __init__(self):
  3. super(CNN, self).__init__()
  4. # batch_size x 3 x 16 x 16
  5. self.conv1 = torch.nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1)
  6. self.conv2 = torch.nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1)
  7. # batch_size x 16 x 16 x 16
  8. self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
  9. self.norm1 = torch.nn.BatchNorm2d(16)
  10. #batch_size x 16 x 8 x 8
  11. self.conv3 = torch.nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
  12. # batch_size x 32 x 8 x 8
  13. #pool
  14. # batch_size x 32 x 4 x 4
  15. self.conv4 = torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
  16. self.norm2 = torch.nn.BatchNorm2d(64)
  17. self.conv5 = torch.nn.Conv2d(64, 96, kernel_size=3, stride=1, padding=1)
  18. self.norm3 = torch.nn.BatchNorm2d(96)
  19. self.conv6 = torch.nn.Conv2d(96, 128, kernel_size=3, stride=1, padding=1)
  20. # batch_size x 128 x 4 x 4
  21. # Deleted for now ##pool
  22. # Deleted for now ##batch_size x 128 x 2 x 2
  23. self.norm4 = torch.nn.BatchNorm1d(128 * 4 * 4)
  24. self.fc1 = torch.nn.Linear(128 * 4 * 4, 30 * 30)
  25. self.norm5 = torch.nn.BatchNorm1d(30 * 30)
  26. self.fc2 = torch.nn.Linear(30 * 30, 15 * 15)
  27.  
  28. def forward(self, x):
  29. x = x.view(-1, 3, 16, 16)
  30. x = F.relu(self.conv1(x))
  31. x = F.relu(self.conv2(x))
  32. x = self.pool(x)
  33. x = self.norm1(x)
  34. x = F.relu(self.conv3(x))
  35. x = self.pool(x)
  36. x = F.relu(self.conv4(x))
  37. x = self.norm2(x)
  38. x = F.relu(self.conv5(x))
  39. x = self.norm3(x)
  40. x = F.relu(self.conv6(x))
  41. #x = self.pool(x)
  42. x = x.view(-1, 128 * 4 * 4)
  43. x = self.norm4(x)
  44. x = F.relu(self.fc1(x))
  45. x = self.norm5(x)
  46. x = self.fc2(x)
  47. return(x)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement