lamiastella

LOOCV transfer learning

Nov 13th, 2018
206
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.71 KB | None | 0 0
  1. model_ft = model_ft.cuda()
  2. nb_samples = 931
  3. nb_classes = 9
  4.  
  5.  
  6. from __future__ import print_function, division
  7.  
  8. import torch
  9. import torch.nn as nn
  10. import torch.optim as optim
  11. from torch.optim import lr_scheduler
  12. import numpy as np
  13. import torchvision
  14. from torchvision import datasets, models, transforms
  15. import matplotlib.pyplot as plt
  16. import time
  17. import os
  18. import copy
  19.  
  20. import torch.utils.data as data_utils
  21. from torch.utils import data
  22.  
  23.  
  24. def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
  25.     since = time.time()
  26.  
  27.     best_model_wts = copy.deepcopy(model.state_dict())
  28.     best_acc = 0.0
  29.  
  30.     for epoch in range(num_epochs):
  31.         print('Epoch {}/{}'.format(epoch, num_epochs - 1))
  32.         print('-' * 10)
  33.  
  34.         # Each epoch has a training and validation phase
  35.         for phase in ['train', 'test']:
  36.             if phase == 'train':
  37.                 scheduler.step()
  38.                 model.train()  # Set model to training mode
  39.             else:
  40.                 model.eval()   # Set model to evaluate mode
  41.  
  42.             running_loss = 0.0
  43.             running_corrects = 0
  44.  
  45.             # Iterate over data.
  46.             for inputs, labels in dataloaders[phase]:
  47.                 inputs = inputs.to(device)
  48.                 labels = labels.to(device)
  49.  
  50.                 # zero the parameter gradients
  51.                 optimizer.zero_grad()
  52.  
  53.                 # forward
  54.                 # track history if only in train
  55.                 with torch.set_grad_enabled(phase == 'train'):
  56.                     outputs = model(inputs)
  57.                     _, preds = torch.max(outputs, 1)
  58.                     loss = criterion(outputs, labels)
  59.  
  60.                     # backward + optimize only if in training phase
  61.                     if phase == 'train':
  62.                         loss.backward()
  63.                         optimizer.step()
  64.  
  65.                 # statistics
  66.                 running_loss += loss.item() * inputs.size(0)
  67.                 running_corrects += torch.sum(preds == labels.data)
  68.  
  69.             epoch_loss = running_loss / dataset_sizes[phase]
  70.             epoch_acc = running_corrects.double() / dataset_sizes[phase]
  71.  
  72.             print('{} Loss: {:.4f} Acc: {:.4f}'.format(
  73.                 phase, epoch_loss, epoch_acc))
  74.  
  75.             # deep copy the model
  76.  #           if phase == 'val' and epoch_acc > best_acc:
  77.  #               best_acc = epoch_acc
  78.  #               best_model_wts = copy.deepcopy(model.state_dict())
  79.  
  80.         print()
  81.  
  82.     time_elapsed = time.time() - since
  83.     print('Training complete in {:.0f}m {:.0f}s'.format(
  84.         time_elapsed // 60, time_elapsed % 60))
  85. #    print('Best val Acc: {:4f}'.format(best_acc))
  86.  
  87.     # load best model weights
  88. #    model.load_state_dict(best_model_wts)
  89.     return model
  90.  
  91.  
  92.  
  93.  
  94. data_transforms = {
  95.     'train': transforms.Compose([
  96.         transforms.RandomResizedCrop(224),
  97.         transforms.RandomHorizontalFlip(),
  98.         transforms.RandomRotation(20),
  99.         transforms.ToTensor(),
  100.         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  101.     ]),
  102.  
  103.         'test': transforms.Compose([
  104.         transforms.Resize(256),
  105.         transforms.CenterCrop(224),
  106.         transforms.ToTensor(),
  107.         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  108.     ]),
  109. }
  110.  
  111. val_loader = data.DataLoader(
  112.         image_datasets['test'],
  113.         num_workers=2,
  114.         batch_size=1
  115.     )
  116. val_loader = iter(val_loader)
  117.  
  118. data_dir = "images"
  119. image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
  120.                                           data_transforms[x])
  121.                   for x in ['train', 'test']}
  122.  
  123. dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
  124. print(dataset_sizes)
  125. class_names = image_datasets['train'].classes
  126.  
  127. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  128. # LOOCV
  129. loocv_preds = []
  130. loocv_targets = []
  131. for idx in range(nb_samples):
  132.    
  133.     print('Using sample {} as test data'.format(idx))
  134.    
  135.     # Get all indices and remove test sample
  136.     train_indices = list(range(len(image_datasets)))
  137.     del train_indices[idx]
  138.    
  139.     # Create new sampler
  140.     sampler = data.SubsetRandomSampler(train_indices)
  141.  
  142.     dataloader = data.DataLoader(
  143.         image_datasets['train'],
  144.         num_workers=2,
  145.         batch_size=1,
  146.         sampler=sampler
  147.     )
  148.    
  149.     # Train model
  150.     for batch_idx, (samples, target) in enumerate(dataloader):
  151.         print('Batch {}'.format(batch_idx))
  152.         model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=5) # do I add this line here?
  153.                
  154.     # Test on LOO sample
  155.     model_ft.eval()
  156. #    test_data, test_target = image_datasets['train'][idx]
  157.     test_data, test_target = val_loader.next()
  158.     test_data = test_data.cuda()
  159.     test_target = test_target.cuda()
  160.     #test_data.unsqueeze_(1)
  161.     #test_target.unsqueeze_(0)
  162.  
  163.     output = model_ft(test_data)
  164.     pred = torch.argmax(output, 1)
  165.     loocv_preds.append(pred)
  166.     loocv_targets.append(test_target.item())
  167.  
  168.  
  169. -----------------------------------------------------------------------------------------------
  170.  
  171. {'train': 791, 'test': 140}
  172. Using sample 0 as test data
  173. Batch 0
  174. Using sample 1 as test data
  175. Batch 0
  176. Using sample 2 as test data
  177.  
  178. ---------------------------------------------------------------------------
  179. IndexError                                Traceback (most recent call last)
  180. <ipython-input-35-4a2b0fe05e20> in <module>()
  181.      65     # Get all indices and remove test sample
  182.      66     train_indices = list(range(len(image_datasets)))
  183. ---> 67     del train_indices[idx]
  184.      68
  185.      69     # Create new sampler
  186.  
  187. IndexError: list assignment index out of range
Add Comment
Please, Sign In to add comment