Advertisement
Guest User

Untitled

a guest
Dec 14th, 2019
109
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 6.89 KB | None | 0 0
import os

import torch
import torchvision
from torchvision import datasets, transforms
  4.  
  5. ### TODO: Write data loaders for training, validation, and test sets
  6. ## Specify appropriate transforms, and batch_sizes
  7.  
  8. data_dir = '/data/dog_images'
  9. train_dir = data_dir + '/train'
  10. valid_dir = data_dir + '/valid'
  11. test_dir = data_dir + '/test'
  12.  
  13. data_transforms={
  14. 'train_transform':transforms.Compose([transforms.Resize(224),
  15. transforms.RandomRotation(30),
  16. #transforms.RandomResizedCrop(256),
  17. transforms.RandomHorizontalFlip(),
  18. transforms.ToTensor(),
  19. transforms.Normalize([0.485, 0.456, 0.406],
  20. [0.229, 0.224, 0.225])]),
  21. 'valid_transform':transforms.Compose([transforms.Resize(256),
  22. transforms.CenterCrop(256),
  23. transforms.ToTensor(),
  24. transforms.Normalize([0.485, 0.456, 0.406],
  25. [0.229, 0.224, 0.225])]),
  26. 'test_transform':transforms.Compose([transforms.Resize(256),
  27. transforms.CenterCrop(256),
  28. transforms.ToTensor(),
  29. transforms.Normalize([0.485, 0.456, 0.406],
  30. [0.229, 0.224, 0.225])])}
  31.  
  32. train_data = torchvision.datasets.CIFAR10ImageFolder(train_dir, transform=data_transforms['train_transform'])
  33. valid_data = torchvision.datasets.ImageFolder(valid_dir, transform=data_transforms['valid_transform'])
  34. test_data = torchvision.datasets.ImageFolder(test_dir, transform=data_transforms['test_transform'])
  35.  
  36. image_datasets={
  37. 'train_data':torchvision.datasets.ImageFolder(train_dir, train_data),
  38. 'valid_data':torchvision.datasets.ImageFolder(valid_dir, valid_data),
  39. 'test_data': torchvision.datasets.ImageFolder(train_dir, test_data)}
  40.  
  41. train_loader = torch.utils.data.DataLoader(image_datasets['train_data'], batch_size=20,shuffle=True)
  42. valid_loader = torch.utils.data.DataLoader(image_datasets['valid_data'], batch_size=20,shuffle=True)
  43. test_loader = torch.utils.data.DataLoader(image_datasets['test_data'], batch_size=20,shuffle=True)
  44.  
  45. loaders_scratch={
  46. 'train':train_loader,
  47. 'valid': train_loader,
  48. 'test': test_loader}
  49.  
  50.  
  51. import torch.nn as nn
  52. import torch.nn.functional as F
  53.  
  54. # define the CNN architecture
  55. class Net(nn.Module):
  56. ### TODO: choose an architecture, and complete the class
  57. def __init__(self):
  58. super(Net, self).__init__()
  59. ## Define layers of a CNN
  60. self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
  61. self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
  62. self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
  63. # max pooling layer
  64. self.max_pool = nn.MaxPool2d(2, 2)
  65. self.relu = nn.ReLU(inplace=True)
  66.  
  67. self.fc1 = nn.Linear(7 * 7 * 128, 500)
  68. self.fc2 = nn.Linear(500, 133)
  69. self.dropout = nn.Dropout(0.25)
  70.  
  71. def forward(self, x):
  72. ## Define forward behavior
  73. x = F.relu(self.conv1(x))
  74. x = self.max_pool(x)
  75.  
  76. x = F.relu(self.conv2(x))
  77. x = self.max_pool(x)
  78.  
  79. x = F.relu(self.conv3(x))
  80. x = self.max_pool(x)
  81.  
  82. # flatten image input
  83. x = x.view(-1, 7 * 7 * 128)
  84. # add dropout layer
  85. x = self.dropout(x)
  86.  
  87. x = F.relu(self.fc1(x))
  88. # add dropout layer
  89. x = self.dropout(x)
  90. x = self.fc2(x)
  91. return x
  92.  
  93. #-#-# You so NOT have to modify the code below this line. #-#-#
  94.  
  95. # instantiate the CNN
  96. model_scratch = Net()
  97.  
  98. # move tensors to GPU if CUDA is available
  99. if use_cuda:
  100. model_scratch.cuda()
  101.  
  102.  
  103. import torch.optim as optim
  104. ### TODO: select loss function
  105. criterion_scratch = nn.CrossEntropyLoss()
  106.  
  107. ### TODO: select optimizer
  108. optimizer_scratch = optim.SGD(model_scratch.parameters(), lr = 0.05)
  109.  
  110. def train(n_epochs, loaders, model, optimizer, criterion, use_cuda, save_path):
  111. """returns trained model"""
  112. # initialize tracker for minimum validation loss
  113. valid_loss_min = np.Inf
  114.  
  115. for epoch in range(1, n_epochs+1):
  116. # initialize variables to monitor training and validation loss
  117. train_loss = 0.0
  118. valid_loss = 0.0
  119.  
  120. ###################
  121. # train the model #
  122. ###################
  123. model.train()
  124. for batch_idx, (data, target) in enumerate(loaders['train']):
  125. # move to GPU
  126. if use_cuda:
  127. data, target = data.cuda(), target.cuda()
  128. ## find the loss and update the model parameters accordingly
  129. ## record the average training loss, using something like
  130. ## train_loss = train_loss + ((1 / (batch_idx + 1)) * (loss.data - train_loss))
  131.  
  132. optimizer.zero_grad()
  133. # forward pass
  134. output = model(data)
  135. # calculate the loss
  136. loss = criterion(output, target)
  137. # backward pass
  138. loss.backward()
  139. # perform optimization step
  140. optimizer.step()
  141. # update training loss
  142. train_loss += ((1 / (batch_idx + 1)) * (loss.data - train_loss))
  143. #train_loss += loss.item()*data.size(0)
  144.  
  145. ######################
  146. # validate the model #
  147. ######################
  148. model.eval()
  149. for batch_idx, (data, target) in enumerate(loaders['valid']):
  150. # move to GPU
  151. if use_cuda:
  152. data, target = data.cuda(), target.cuda()
  153. ## update the average validation loss
  154. output = model(data)
  155. loss = criterion(output, target)
  156. valid_loss += ((1 / (batch_idx + 1)) * (loss.data - valid_loss))
  157.  
  158. # print training/validation statistics
  159. print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
  160. epoch,
  161. train_loss,
  162. valid_loss
  163. ))
  164.  
  165. ## TODO: save the model if validation loss has decreased
  166. if valid_loss <= valid_loss_min:
  167. valid_loss_min = valid_loss
  168. print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...'.format(
  169. valid_loss_min,
  170. valid_loss))
  171. torch.save(model.state_dict(), save_path)
  172.  
  173. # return trained model
  174. return model
  175.  
  176. # train the model
  177. model_scratch = train(20, loaders_scratch, model_scratch, optimizer_scratch,criterion_scratch, use_cuda, 'model_scratch.pt')
  178.  
  179.  
  180. # load the model that got the best validation accuracy
  181. model_scratch.load_state_dict(torch.load('model_scratch.pt'))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement