Model class of a CNN in PyTorch

```python
# CNN with batch normalization, a dropout layer, and 3 convolutional layers

import torch
import torch.nn as nn
import torch.nn.functional as F

# define the CNN architecture
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        ## Define layers of a CNN
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        #self.conv4 = nn.Conv2d(64, 128, 3, padding=1)
        #self.conv5 = nn.Conv2d(128, 256, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # With the commented-out 5-conv variant, a 224x224 image passes through
        # five 2x2 max-pooling layers: 224/2/2/2/2/2 = 7, so the final feature
        # maps are 7x7 and the flattened size is 7*7*256.
        #self.fc1 = nn.Linear(7*7*256, 500)
        # With the active 3-conv variant, 224/2/2/2 = 28, so the final feature
        # maps are 28x28 and the flattened size is 28*28*64.
        self.fc1 = nn.Linear(28*28*64, 500)
        self.fc2 = nn.Linear(500, 133)
        self.dropout = nn.Dropout(0.25)

        self.batch_norm = nn.BatchNorm1d(num_features=500)

    def forward(self, x):
        ## Define forward behavior
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        #x = self.pool(F.relu(self.conv4(x)))
        #x = self.pool(F.relu(self.conv5(x)))
        # flatten the feature maps to a vector
        #x = x.view(-1, 7*7*256)
        x = x.view(-1, 28*28*64)
        x = self.dropout(x)
        x = F.relu(self.batch_norm(self.fc1(x)))
        #x = F.relu(self.fc1(x))

        x = self.dropout(x)
        x = self.fc2(x)
        return x


# check if CUDA is available
use_cuda = torch.cuda.is_available()

# instantiate the CNN
model_scratch = Net()

# move the model's parameters to the GPU if CUDA is available
if use_cuda:
    model_scratch.cuda()
```