Advertisement
lamiastella

PyTorch forum code for LOOCV (leave-one-out cross-validation)

Nov 25th, 2018
388
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.90 KB | None | 0 0
  1. from __future__ import print_function, division
  2.  
  3. import torch
  4. from torch.autograd import Variable
  5. from sklearn.metrics import accuracy_score
  6. from sklearn.metrics import confusion_matrix
  7. import torch
  8. import torch.nn as nn
  9. import torch.optim as optim
  10. from torch.optim import lr_scheduler
  11. import numpy as np
  12. import torchvision
  13. from torchvision import datasets, models, transforms
  14. import matplotlib.pyplot as plt
  15. import time
  16. import os
  17. import copy
  18.  
  19. import torch.utils.data as data_utils
  20. from torch.utils import data
  21.  
  22. torch.manual_seed(2809)
  23.  
  24.  
  25. def train_model(model, criterion, optimizer, scheduler,
  26.                 train_input, train_label, num_epochs=25):
  27.     since = time.time()
  28.     model.train()  # Set model to training mode
  29.     for epoch in range(num_epochs):
  30.         print('Epoch {}/{}'.format(epoch, num_epochs - 1))
  31.         print('-' * 10)
  32.  
  33.         scheduler.step()
  34.        
  35.         running_loss = 0.0
  36.         running_corrects = 0
  37.  
  38.         # Iterate over data.
  39.         train_input = train_input.to(device)
  40.         train_label = train_label.to(device)
  41.  
  42.         # zero the parameter gradients
  43.         optimizer.zero_grad()
  44.  
  45.         output = model(train_input)
  46.         _, pred = torch.max(output, 1)
  47.         loss = criterion(output, train_label)
  48.         loss.backward()
  49.         optimizer.step()
  50.  
  51.         # statistics
  52.         running_loss += loss.item() * train_input.size(0)
  53.         running_corrects += torch.sum(pred == train_label.data)
  54.  
  55.         epoch_loss = running_loss / dataset_size['train']
  56.         epoch_acc = running_corrects.double() / dataset_size['train']
  57.  
  58.         print('train Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))
  59.  
  60.         print()
  61.  
  62.     time_elapsed = time.time() - since
  63.     print('Training complete in {:.0f}m {:.0f}s'.format(
  64.         time_elapsed // 60, time_elapsed % 60))
  65.  
  66.     return model
  67.  
  68.  
  69. data_transforms = {
  70.     'train': transforms.Compose([
  71.         transforms.RandomResizedCrop(224),
  72.         transforms.RandomHorizontalFlip(),
  73.         transforms.RandomRotation(20),
  74.         transforms.ToTensor(),
  75.         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  76.     ]),
  77.     'test': transforms.Compose([
  78.         transforms.Resize(224),
  79.         transforms.ToTensor(),
  80.         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  81.     ])
  82. }
  83.  
  84.  
  85. data_dir = "test_images"
  86. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  87.  
  88. model_ft = models.resnet50(pretrained=True)
  89. num_ftrs = model_ft.fc.in_features
  90. model_ft.fc = nn.Linear(num_ftrs, 2)
  91. model_ft = model_ft.to(device)
  92.  
  93. # Save a clone of initial model to restore later
  94. initial_model = copy.deepcopy(model_ft)
  95.  
  96. criterion = nn.CrossEntropyLoss()
  97.  
  98. # Observe that all parameters are being optimized
  99. optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
  100.  
  101. # Decay LR by a factor of 0.1 every 7 epochs
  102. exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
  103.  
  104. #model_ft = model_ft.cuda()
  105. nb_samples = 10
  106. nb_classes = 2
  107.  
  108. image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
  109.                                           data_transforms[x])
  110.                   for x in ['train']}
  111.  
  112. dataset_size = {x: len(image_datasets[x]) for x in ['train']}
  113. class_names = image_datasets['train'].classes
  114.  
  115. # LOOCV
  116. loocv_preds = []
  117. loocv_targets = []
  118. for idx in range(nb_samples):
  119.  
  120.     print('Using sample {} as test data'.format(idx))
  121.    
  122.     print('Resetting model')
  123.     model_ft = copy.deepcopy(initial_model)
  124.  
  125.     # Get all indices and remove test sample
  126.     train_indices = list(range(len(image_datasets['train'])))
  127.     del train_indices[idx]
  128.  
  129.     # Create new sampler
  130.     sampler = data.SubsetRandomSampler(train_indices)
  131.     dataloader = data.DataLoader(
  132.         image_datasets['train'],
  133.         num_workers=2,
  134.         batch_size=1,
  135.         sampler=sampler
  136.     )
  137.  
  138.     # Train model
  139.     for batch_idx, (sample, target) in enumerate(dataloader):
  140.         print('Batch {}'.format(batch_idx))
  141.         model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, sample, target, num_epochs=1)
  142.  
  143.     # Test on LOO sample
  144.     model_ft.eval()
  145.     test_data, test_target = image_datasets['train'][idx]
  146.     # Apply test preprocessing on data
  147.     print(type(test_data))
  148.     test_data = data_transforms['test'](transforms.ToPILImage()(test_data))
  149.     test_data = test_data.cuda()
  150.     test_target = torch.tensor(test_target)
  151.     test_target = test_target.cuda()
  152.     test_data.unsqueeze_(0)
  153.     test_target.unsqueeze_(0)
  154.     print(test_data.shape)
  155.     output = model_ft(test_data)
  156.     pred = torch.argmax(output, 1)
  157.     loocv_preds.append(pred)
  158.     loocv_targets.append(test_target.item())
  159.  
  160.  
  161. print("loocv preds: ", loocv_preds)
  162. print("loocv targets: ", loocv_targets)
  163. print(accuracy_score(loocv_targets, loocv_preds))
  164. print(confusion_matrix(loocv_targets, loocv_preds))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement