lamiastella

loocv

Nov 20th, 2018
339
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. from __future__ import print_function, division
  2.  
  3. import torch
  4. from torch.autograd import Variable
  5.  
  6.  
  7.  
  8. import torch
  9. import torch.nn as nn
  10. import torch.optim as optim
  11. from torch.optim import lr_scheduler
  12. import numpy as np
  13. import torchvision
  14. from torchvision import datasets, models, transforms
  15. import matplotlib.pyplot as plt
  16. import time
  17. import os
  18. import copy
  19.  
  20.  
  21.  
  22. import torch.utils.data as data_utils
  23. from torch.utils import data
  24.  
  25.  
  26. data_transforms = {
  27.     'train': transforms.Compose([
  28.         transforms.RandomResizedCrop(224),
  29.         transforms.RandomHorizontalFlip(),
  30.         transforms.RandomRotation(20),
  31.         transforms.ToTensor(),
  32.         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  33.     ])
  34. }
  35.  
  36.  
  37. data_dir = "images"
  38.  
  39. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  40.  
  41.  
  42. def imshow(inp, title=None):
  43.     """Imshow for Tensor."""
  44.     inp = inp.numpy().transpose((1, 2, 0))
  45.     mean = np.array([0.485, 0.456, 0.406])
  46.     std = np.array([0.229, 0.224, 0.225])
  47.     inp = std * inp + mean
  48.     inp = np.clip(inp, 0, 1)
  49.     plt.imshow(inp)
  50.     if title is not None:
  51.         plt.title(title)
  52.     plt.pause(0.001)  # pause a bit so that plots are updated
  53.  
  54.  
  55.  
  56. def train_model(model, criterion, optimizer, scheduler, dataloader, num_epochs=25):
  57.     since = time.time()
  58.  
  59.     best_model_wts = copy.deepcopy(model.state_dict())
  60.     best_acc = 0.0
  61.  
  62.     for epoch in range(num_epochs):
  63.         print('Epoch {}/{}'.format(epoch, num_epochs - 1))
  64.         print('-' * 10)
  65.  
  66.         # Each epoch has a training and validation phase
  67.         for phase in ['train']:
  68.             if phase == 'train':
  69.                 scheduler.step()
  70.                 model.train()  # Set model to training mode
  71.             else:
  72.                 model.eval()   # Set model to evaluate mode
  73.  
  74.             running_loss = 0.0
  75.             running_corrects = 0
  76.  
  77.             # Iterate over data.
  78.             #for inputs, labels in dataloaders[phase]:
  79.             for inputs, labels in dataloader:
  80.                 inputs = inputs.to(device)
  81.                 labels = labels.to(device)
  82.  
  83.                 # zero the parameter gradients
  84.                 optimizer.zero_grad()
  85.  
  86.                 # forward
  87.                 # track history if only in train
  88.                 with torch.set_grad_enabled(phase == 'train'):
  89.                     outputs = model(inputs)
  90.                     _, preds = torch.max(outputs, 1)
  91.                     loss = criterion(outputs, labels)
  92.  
  93.                     # backward + optimize only if in training phase
  94.                     if phase == 'train':
  95.                         loss.backward()
  96.                         optimizer.step()
  97.  
  98.                 # statistics
  99.                 running_loss += loss.item() * inputs.size(0)
  100.                 running_corrects += torch.sum(preds == labels.data)
  101.  
  102.             epoch_loss = running_loss / dataset_sizes[phase]
  103.             epoch_acc = running_corrects.double() / dataset_sizes[phase]
  104.  
  105.             print('{} Loss: {:.4f} Acc: {:.4f}'.format(
  106.                 phase, epoch_loss, epoch_acc))
  107.  
  108.             # deep copy the model
  109.  #           if phase == 'val' and epoch_acc > best_acc:
  110.  #               best_acc = epoch_acc
  111.  #               best_model_wts = copy.deepcopy(model.state_dict())
  112.  
  113.         print()
  114.  
  115.     time_elapsed = time.time() - since
  116.     print('Training complete in {:.0f}m {:.0f}s'.format(
  117.         time_elapsed // 60, time_elapsed % 60))
  118. #    print('Best val Acc: {:4f}'.format(best_acc))
  119.  
  120. #    model.load_state_dict(best_model_wts)
  121.     return model
  122.  
  123.  
  124. def visualize_model(model, num_images=6):
  125.     was_training = model.training
  126.     model.eval()
  127.     images_so_far = 0
  128.     fig = plt.figure()
  129.  
  130.     with torch.no_grad():
  131.         #for i, (inputs, labels) in enumerate(dataloaders['test]):
  132.         for i, (inputs, labels) in enumerate(dataloaders['train']):
  133.  
  134.             inputs = inputs.to(device)
  135.             labels = labels.to(device)
  136.  
  137.             outputs = model(inputs)
  138.             _, preds = torch.max(outputs, 1)
  139.  
  140.             for j in range(inputs.size()[0]):
  141.                 images_so_far += 1
  142.                 ax = plt.subplot(num_images//2, 2, images_so_far)
  143.                 ax.axis('off')
  144.                 ax.set_title('predicted: {}'.format(class_names[preds[j]]))
  145.                 imshow(inputs.cpu().data[j])
  146.  
  147.                 if images_so_far == num_images:
  148.                     model.train(mode=was_training)
  149.                     return
  150.         model.train(mode=was_training)
  151.  
  152.  
  153.  
  154. ######################################################################
  155. # Finetuning the convnet
  156. # ----------------------
  157. #
  158. # Load a pretrained model and reset final fully connected layer.
  159. #
  160.  
  161. #model_ft = models.resnet18(pretrained=True)
  162. model_ft = models.resnet50(pretrained=True)
  163.  
  164. num_ftrs = model_ft.fc.in_features
  165. model_ft.fc = nn.Linear(num_ftrs, 9)
  166.  
  167. model_ft = model_ft.to(device)
  168.  
  169. criterion = nn.CrossEntropyLoss()
  170.  
  171. # Observe that all parameters are being optimized
  172. optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
  173.  
  174. # Decay LR by a factor of 0.1 every 7 epochs
  175. exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
  176.  
  177.  
  178.  
  179. #model_ft = model_ft.cuda()
  180. nb_samples = 864
  181. nb_classes = 9
  182.  
  183.  
  184. data_transforms = {
  185.     'train': transforms.Compose([
  186.         transforms.RandomResizedCrop(224),
  187.         transforms.RandomHorizontalFlip(),
  188.         transforms.RandomRotation(20),
  189.         transforms.ToTensor(),
  190.         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  191.     ])
  192. }
  193.  
  194. '''val_loader = data.DataLoader(
  195.        image_datasets['train'],
  196.        num_workers=2,
  197.        batch_size=1
  198.    )
  199. val_loader = iter(val_loader)'''
  200.  
  201. image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
  202.                                           data_transforms[x])
  203.                   for x in ['train']}
  204.  
  205. dataset_sizes = {x: len(image_datasets[x]) for x in ['train']}
  206. class_names = image_datasets['train'].classes
  207.  
  208. # LOOCV
  209. loocv_preds = []
  210. loocv_targets = []
  211. for idx in range(nb_samples):
  212.    
  213.     print('Using sample {} as test data'.format(idx))
  214.    
  215.     # Get all indices and remove test sample
  216.     train_indices = list(range(len(image_datasets['train'])))
  217.     del train_indices[idx]
  218.    
  219.     # Create new sampler
  220.     sampler = data.SubsetRandomSampler(train_indices)
  221.  
  222.     dataloader = data.DataLoader(
  223.         image_datasets['train'],
  224.         num_workers=2,
  225.         batch_size=1,
  226.         sampler=sampler
  227.     )
  228.    
  229.     # Train model
  230.     for batch_idx, (samples, target) in enumerate(dataloader):
  231.         print('Batch {}'.format(batch_idx))
  232.         model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, dataloader, num_epochs=25) # do I add this line here?
  233.                
  234.     # Test on LOO sample
  235.     model_ft.eval()
  236.     test_data, test_target = image_datasets['train'][idx]
  237.     #test_data, test_target = dataloader.next()
  238.     test_data = test_data.cuda()
  239.     test_target = test_target.cuda()
  240.     test_data.unsqueeze_(1)
  241.     test_target.unsqueeze_(0)
  242.  
  243.     output = model_ft(test_data)
  244.     pred = torch.argmax(output, 1)
  245.     loocv_preds.append(pred)
  246.     loocv_targets.append(test_target.item())
RAW Paste Data Copied