Advertisement
Guest User

Untitled

a guest
May 26th, 2018
89
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.44 KB | None | 0 0
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import torch.optim as optim
  5. import torchvision
  6. from torch.autograd import Variable
  7. from torchvision import datasets, transforms
  8. import os
  9. import shutil
  10. import random
  11. import numpy as np
  12. import matplotlib.pyplot as plt
  13.  
  14. '''
  15. Implementation of the logistic regression classifier which solves the problem of recognition of cats and dogs.
  16.  
  17. The dataset can be found here:
  18. https://www.kaggle.com/c/dogs-vs-cats
  19. '''
  20.  
  21. '''
  22. Splits train and test data into to the following folder structure:
  23. ---/train
  24. ------/cat
  25. ------/dog
  26. ---/test
  27. ------/cat
  28. ------/dog
  29.  
  30. Randomly selects 5000 images as test examples (2500 for each class).
  31. '''
  32.  
  33.  
  34. def split_dogs_and_cats(source):
  35. cats_test = './test/cat'
  36. dogs_test = './test/dog'
  37. cats_train = './train/cat'
  38. dogs_train = './train/dog'
  39.  
  40. files = os.listdir(source)
  41. os.makedirs(cats_test)
  42. os.makedirs(dogs_test)
  43. os.makedirs(cats_train)
  44. os.makedirs(dogs_train)
  45.  
  46. cats_test_index = [i for i in range(12500)]
  47. dogs_test_index = [i for i in range(12500)]
  48. random.shuffle(cats_test_index)
  49. random.shuffle(dogs_test_index)
  50. cats_test_index = cats_test_index[:2500]
  51. dogs_test_index = dogs_test_index[:2500]
  52.  
  53. for file in files:
  54. srcname = os.path.join(source, file)
  55. tag, number, _ = file.split('.')
  56. number = int(number)
  57. if tag == 'cat':
  58. if number in cats_test_index:
  59. dst = cats_test
  60. else:
  61. dst = cats_train
  62. else:
  63. if number in dogs_test_index:
  64. dst = dogs_test
  65. else:
  66. dst = dogs_train
  67.  
  68. dstname = os.path.join(dst, file)
  69. shutil.move(srcname, dstname)
  70.  
  71.  
  72. split_dogs_and_cats('./dataset')
  73.  
  74. batch_size = 32
  75. image_size = 128
  76.  
  77.  
  78. #Normalize the data.
  79.  
  80.  
  81. transformation = transforms.Compose([
  82. transforms.Resize((image_size, image_size)),
  83. transforms.ToTensor(),
  84. transforms.Normalize(mean=[0.485, 0.456, 0.406],
  85. std=[0.229, 0.224, 0.225])
  86. ])
  87.  
  88.  
  89. #Loads the data.
  90.  
  91.  
  92. train_data = datasets.ImageFolder(root='./train', transform=transformation)
  93. test_data = datasets.ImageFolder(root='./test', transform=transformation)
  94.  
  95. torchvision.utils.make_grid(train_data[0][0])
  96.  
  97. train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
  98. test_data_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True, num_workers=4)
  99.  
  100.  
  101. #Let's check that we loaded the data correctly.
  102.  
  103.  
  104.  
  105. def imshow(inp, title=None):
  106. inp = inp.numpy().transpose((1, 2, 0))
  107. mean = np.array([0.485, 0.456, 0.406])
  108. std = np.array([0.229, 0.224, 0.225])
  109. inp = std * inp + mean
  110. inp = np.clip(inp, 0, 1)
  111. plt.imshow(inp)
  112. if title is not None:
  113. plt.title(title)
  114. plt.pause(0.001)
  115.  
  116.  
# Labels: 1 - dog, 0 - cat (ImageFolder assigns indices alphabetically).


# Sanity check: pull one batch from the loader and display it as an image
# grid, using the batch's class indices as the plot title.
inputs, classes = next(iter(train_data_loader))

out = torchvision.utils.make_grid(inputs)

imshow(out, title=classes)
  125.  
  126. '''
  127. Class that represents logistic regression.
  128. '''
  129.  
  130.  
  131. class NeuralNetwork(nn.Module):
  132.  
  133. def __init__(self):
  134. super(NeuralNetwork, self).__init__()
  135. self.layer = nn.Linear(3 * image_size * image_size, 1)
  136.  
  137. def forward(self, x):
  138. return F.sigmoid(self.layer(x)).squeeze()
  139.  
  140.  
# Model and loss: BCELoss expects sigmoid probabilities in [0, 1] and
# float targets, matching NeuralNetwork.forward's output.
network = NeuralNetwork()
criterion = nn.BCELoss()
  143.  
  144. '''
  145. Main procedure.
  146. Firstly, we train our classifier then test.
  147. '''
  148.  
  149.  
  150. def run(learning_rate):
  151. optimizer = optim.Adam(network.parameters(), lr=learning_rate)
  152. for epoch in range(1, 3):
  153. for batch_idx, (data, target) in enumerate(train_data_loader):
  154. data, target = Variable(data), Variable(target)
  155. data = data.view(-1, image_size * image_size * 3)
  156. output = network(data)
  157. cost = criterion(output, target.float())
  158. optimizer.zero_grad()
  159. cost.backward()
  160. optimizer.step()
  161.  
  162. if batch_idx % 10 == 0:
  163. print('Train Epoch: {} [{}/{} ({:.0f}%)]tLoss: {:.6f}'.format(
  164. epoch, batch_idx * len(data), len(train_data_loader.dataset),
  165. 100. * batch_idx / len(train_data_loader), cost.item()))
  166.  
  167. '''
  168. Test set evaluation.
  169. '''
  170.  
  171. network.eval()
  172. test_loss = 0
  173. correct = 0
  174. for data, target in test_data_loader:
  175. data, target = Variable(data), Variable(target)
  176. data = data.view(-1, image_size * image_size * 3)
  177. output = network(data)
  178. test_loss += criterion(output, target.float()).item()
  179. pred = output.ge(0.5)
  180. correct += torch.eq(pred, target.byte()).sum()
  181.  
  182. test_loss /= len(test_data_loader.dataset)
  183.  
  184. print('nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)n'.format(
  185. test_loss, correct, len(test_data_loader.dataset),
  186. 100. * correct / len(test_data_loader.dataset)))
  187.  
  188. '''
  189. Train set evaluation.
  190. '''
  191.  
  192. train_loss = 0
  193. correct = 0
  194. for data, target in train_data_loader:
  195. data, target = Variable(data), Variable(target)
  196. data = data.view(-1, image_size * image_size * 3)
  197. output = network(data)
  198. train_loss += criterion(output, target.float()).item()
  199. pred = output.ge(0.5)
  200. correct += torch.eq(pred, target.byte()).sum()
  201.  
  202. train_loss /= len(train_data_loader.dataset)
  203.  
  204. print('nTrain set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)n'.format(
  205. train_loss, correct, len(train_data_loader.dataset),
  206. 100. * correct / len(train_data_loader.dataset)))
  207.  
  208.  
  209. run(0.0001)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement