train/test acc for Transfer Learning
lamiastella | Nov 13th, 2018 | Python
%matplotlib inline
# Note: graphviz, Variable, and torchviz's make_dot are imported here but
# never used below; they are leftovers from a model-visualization snippet.
from graphviz import Digraph
import torch
from torch.autograd import Variable


# make_dot was moved to https://github.com/szagoruyko/pytorchviz
from torchviz import make_dot
# -*- coding: utf-8 -*-
"""
Transfer Learning Tutorial
==========================
**Author**: `Sasank Chilamkurthy <https://chsasank.github.io>`_

In this tutorial, you will learn how to train your network using
transfer learning. You can read more about transfer learning in the
`cs231n notes <http://cs231n.github.io/transfer-learning/>`__.

Quoting these notes,

   In practice, very few people train an entire Convolutional Network
   from scratch (with random initialization), because it is relatively
   rare to have a dataset of sufficient size. Instead, it is common to
   pretrain a ConvNet on a very large dataset (e.g. ImageNet, which
   contains 1.2 million images with 1000 categories), and then use the
   ConvNet either as an initialization or a fixed feature extractor for
   the task of interest.

These two major transfer learning scenarios look as follows:

-  **Finetuning the convnet**: Instead of random initialization, we
   initialize the network with a pretrained network, such as one trained
   on the ImageNet 1000-class dataset. The rest of the training proceeds
   as usual.
-  **ConvNet as fixed feature extractor**: Here, we freeze the weights of
   the whole network except the final fully connected layer. That last
   layer is replaced with a new one with random weights, and only this
   layer is trained. (A sketch of this variant appears near the end of
   this script.)

"""
# License: BSD
# Author: Sasank Chilamkurthy

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

plt.ion()   # interactive mode


######################################################################
# Load Data
# ---------
#
# We will use the torchvision and torch.utils.data packages for loading
# the data.
#
# The original tutorial trains a model to classify **ants** and **bees**,
# with about 120 training images and 75 validation images per class: far
# too small a dataset to generalize well on if trained from scratch, but
# workable with transfer learning. (This paste instead points ``data_dir``
# at a custom 9-class ``images`` folder with ``train`` and ``test``
# splits; the description above refers to the tutorial's data.)
#
# The tutorial's dataset is a very small subset of ImageNet.
#
# .. Note ::
#    Download the data from
#    `here <https://download.pytorch.org/tutorial/hymenoptera_data.zip>`_
#    and extract it to the current directory.

# Data augmentation and normalization for training
# Just normalization (deterministic resize + center crop) for testing
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
#    'val': transforms.Compose([
#        transforms.Resize(256),
#        transforms.CenterCrop(224),
#        transforms.ToTensor(),
#        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
#    ]),
    'test': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

#data_dir = 'hymenoptera_data'
#data_dir = "mona_data"
#data_dir = "shooting_data_2class"
#data_dir = "shooting_data_3cat"
data_dir = "images"
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'test']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                              shuffle=True, num_workers=4)
               for x in ['train', 'test']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
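
######################################################################
# A quick sanity check (an illustrative addition, not in the original
# tutorial). ImageFolder infers class names from the directory layout,
# e.g. images/train/<class_name>/*.jpg and images/test/<class_name>/*.jpg.
print('dataset sizes:', dataset_sizes)
print('classes:', class_names)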


######################################################################
# Visualize a few images
# ^^^^^^^^^^^^^^^^^^^^^^
# Let's visualize a few training images so as to understand the data
# augmentations.

def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from the batch
out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])


######################################################################
# Training the model
# ------------------
#
# Now, let's write a general function to train a model. Here, we will
# illustrate:
#
# -  Scheduling the learning rate
# -  Saving the best model (the best-model bookkeeping is commented out
#    in this paste, so the function returns the last-epoch weights)
#
# In the following, the ``scheduler`` parameter is an LR scheduler object
# from ``torch.optim.lr_scheduler``.


def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and a test phase
        for phase in ['train', 'test']:
            if phase == 'train':
                scheduler.step()
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode; without this,
                               # BatchNorm and Dropout stay in training mode
                               # and the test metrics are unreliable

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history only if in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model (disabled here; the original tutorial
            # keyed this on the 'val' phase, which this script renames 'test')
#            if phase == 'test' and epoch_acc > best_acc:
#                best_acc = epoch_acc
#                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
#    print('Best test Acc: {:4f}'.format(best_acc))

    # load best model weights (also disabled; the last-epoch model is returned)
#    model.load_state_dict(best_model_wts)
    return model
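

######################################################################
# A small illustrative addition (not in the original tutorial): a helper
# that computes accuracy with the model explicitly in eval mode, so that
# BatchNorm uses its running statistics and Dropout is disabled. The name
# ``evaluate_accuracy`` is made up for this sketch.
def evaluate_accuracy(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():  # no gradient tracking needed for evaluation
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            _, preds = torch.max(model(inputs), 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total

# Example usage (after training): evaluate_accuracy(model_ft, dataloaders['test'])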


######################################################################
# Visualizing the model predictions
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Generic function to display predictions for a few images
#

def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['test']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images//2, 2, images_so_far)
                ax.axis('off')
                ax.set_title('predicted: {}'.format(class_names[preds[j]]))
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)

######################################################################
# Finetuning the convnet
# ----------------------
#
# Load a pretrained model and reset the final fully connected layer.
#

#model_ft = models.resnet18(pretrained=True)
model_ft = models.resnet50(pretrained=True)

num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(class_names))  # 9 classes in the "images" dataset

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
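
######################################################################
# With ``step_size=7`` and ``gamma=0.1``, StepLR multiplies the learning
# rate by 0.1 every 7 epochs, so over the 25 epochs trained below the
# optimizer sees lr = 1e-3 (epochs 0-6), 1e-4 (epochs 7-13),
# 1e-5 (epochs 14-20), and 1e-6 (epochs 21-24).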


######################################################################
# Train and evaluate
# ^^^^^^^^^^^^^^^^^^
#
# The original tutorial (ResNet-18 on the ants/bees data) reports that
# training takes around 15-25 min on CPU and under a minute on GPU;
# timings for this ResNet-50 setup on a custom dataset will differ.
#

model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)

######################################################################
#

visualize_model(model_ft)
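

######################################################################
# The module docstring also describes a second scenario: ConvNet as a
# fixed feature extractor. A minimal sketch of that variant, following
# the original tutorial (``model_conv`` is illustrative and is not
# trained or used elsewhere in this paste):

model_conv = models.resnet18(pretrained=True)
for param in model_conv.parameters():
    param.requires_grad = False   # freeze all pretrained weights

# Parameters of newly constructed modules have requires_grad=True by
# default, so only the replacement fc layer will receive gradients.
model_conv.fc = nn.Linear(model_conv.fc.in_features, len(class_names))
model_conv = model_conv.to(device)

# Note that only the final layer's parameters are passed to the optimizer.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)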


Output:
---------------------------------------------------------------------
Epoch 0/24
----------
train Loss: 2.0849 Acc: 0.3047
test Loss: 1.9907 Acc: 0.3643

Epoch 1/24
----------
train Loss: 1.9912 Acc: 0.3262
test Loss: 1.9723 Acc: 0.3143

Epoch 2/24
----------
train Loss: 1.8772 Acc: 0.3451
test Loss: 1.9634 Acc: 0.3429

Epoch 3/24
----------
train Loss: 1.8997 Acc: 0.3477
test Loss: 2.2405 Acc: 0.2857

Epoch 4/24
----------
train Loss: 1.8376 Acc: 0.3869
test Loss: 2.0975 Acc: 0.2857

Epoch 5/24
----------
train Loss: 1.7459 Acc: 0.4121
test Loss: 2.0324 Acc: 0.3143

Epoch 6/24
----------
train Loss: 1.7635 Acc: 0.4046
test Loss: 2.0811 Acc: 0.4071

Epoch 7/24
----------
train Loss: 1.4989 Acc: 0.4829
test Loss: 1.9167 Acc: 0.4214

Epoch 8/24
----------
train Loss: 1.3821 Acc: 0.5145
test Loss: 1.9868 Acc: 0.3929

Epoch 9/24
----------
train Loss: 1.3185 Acc: 0.5575
test Loss: 1.9225 Acc: 0.4143

Epoch 10/24
----------
train Loss: 1.3083 Acc: 0.5436
test Loss: 1.9001 Acc: 0.4357

Epoch 11/24
----------
train Loss: 1.2618 Acc: 0.5638
test Loss: 1.9409 Acc: 0.4000

Epoch 12/24
----------
train Loss: 1.2696 Acc: 0.5765
test Loss: 1.9952 Acc: 0.3857

Epoch 13/24
----------
train Loss: 1.2782 Acc: 0.5638
test Loss: 1.8705 Acc: 0.4143

Epoch 14/24
----------
train Loss: 1.1628 Acc: 0.6233
test Loss: 1.9135 Acc: 0.4071

Epoch 15/24
----------
train Loss: 1.2363 Acc: 0.5904
test Loss: 1.9826 Acc: 0.4071

Epoch 16/24
----------
train Loss: 1.2247 Acc: 0.5879
test Loss: 1.9062 Acc: 0.4357

Epoch 17/24
----------
train Loss: 1.1758 Acc: 0.6157
test Loss: 1.9463 Acc: 0.4500

Epoch 18/24
----------
train Loss: 1.2133 Acc: 0.5942
test Loss: 1.9168 Acc: 0.4143

Epoch 19/24
----------
train Loss: 1.1976 Acc: 0.5828
test Loss: 1.9197 Acc: 0.4000

Epoch 20/24
----------
train Loss: 1.1934 Acc: 0.6119
test Loss: 1.8853 Acc: 0.4071

Epoch 21/24
----------
train Loss: 1.1578 Acc: 0.6068
test Loss: 1.9011 Acc: 0.3929

Epoch 22/24
----------
train Loss: 1.1713 Acc: 0.5967
test Loss: 1.8918 Acc: 0.4143

Epoch 23/24
----------
train Loss: 1.1726 Acc: 0.6068
test Loss: 1.9312 Acc: 0.4643

Epoch 24/24
----------
train Loss: 1.1744 Acc: 0.5891
test Loss: 1.8914 Acc: 0.4143

Training complete in 5m 0s