Advertisement
lamiastella

recreate optimizer loocv

Nov 25th, 2018
305
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 6.86 KB | None | 0 0
  1. from __future__ import print_function, division
  2.  
  3. import torch
  4. from torch.autograd import Variable
  5. from sklearn.metrics import accuracy_score
  6. from sklearn.metrics import confusion_matrix
  7. import torch
  8. import torch.nn as nn
  9. import torch.optim as optim
  10. import torch.nn.functional as F
  11. from torch.optim import lr_scheduler
  12. import numpy as np
  13. import torchvision
  14. from torchvision import datasets, models, transforms
  15. import matplotlib.pyplot as plt
  16. import time
  17. import os
  18. import copy
  19.  
  20. import torch.utils.data as data_utils
  21. from torch.utils import data
  22.  
  23. torch.manual_seed(2809)
  24.  
  25.  
  26. def train_model(model, criterion, optimizer, scheduler,
  27.                 dataloader, num_epochs=25):
  28.     '''since = time.time()
  29.    model.train()  # Set model to training mode
  30.    for epoch in range(num_epochs):
  31.        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
  32.        print('-' * 10)
  33.  
  34.        scheduler.step()
  35.  
  36.        running_loss = 0.0
  37.        running_corrects = 0
  38.  
  39.        # Iterate over data.
  40.        train_input = train_input.to(device)
  41.        train_label = train_label.to(device)
  42.  
  43.        # zero the parameter gradients
  44.        optimizer.zero_grad()
  45.  
  46.        output = model(train_input)
  47.        _, pred = torch.max(output, 1)
  48.        loss = criterion(output, train_label)
  49.        loss.backward()
  50.        optimizer.step()
  51.  
  52.        # statistics
  53.        running_loss += loss.item() * train_input.size(0)
  54.        running_corrects += torch.sum(pred == train_label.data)
  55.  
  56.        epoch_loss = running_loss / dataset_size['train']
  57.        epoch_acc = running_corrects.double() / dataset_size['train']
  58.  
  59.        print('train Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))
  60.  
  61.        print()
  62.  
  63.    time_elapsed = time.time() - since
  64.    print('Training complete in {:.0f}m {:.0f}s'.format(
  65.        time_elapsed // 60, time_elapsed % 60))
  66.  
  67.    return model'''
  68.  
  69.     since = time.time()
  70.        # Observe that all parameters are being optimized
  71.     optimizer = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
  72.  
  73.     # Decay LR by a factor of 0.1 every 7 epochs
  74.     exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
  75.  
  76.  
  77.     for epoch in range(num_epochs):
  78.         print('Epoch {}/{}'.format(epoch, num_epochs - 1))
  79.         print('-' * 10)
  80.  
  81.         # Each epoch has a training and validation phase
  82.  
  83.         scheduler.step()
  84.         model.train()  # Set model to training mode
  85.  
  86.  
  87.         running_loss = 0.0
  88.         running_corrects = 0
  89.  
  90.         # Iterate over data.
  91.         for inputs, labels in dataloader:
  92.             inputs = inputs.to(device)
  93.             labels = labels.to(device)
  94.  
  95.             # zero the parameter gradients
  96.             optimizer.zero_grad()
  97.  
  98.             # forward
  99.             # track history if only in train
  100.             with torch.set_grad_enabled(True):
  101.                 outputs = model(inputs)
  102.                 _, preds = torch.max(outputs, 1)
  103.                 loss = criterion(outputs, labels)
  104.                 # backward + optimize only if in training phase
  105.                 loss.backward()
  106.                 optimizer.step()
  107.             # statistics
  108.             running_loss += loss.item() * inputs.size(0)
  109.             running_corrects += torch.sum(preds == labels.data)
  110.  
  111.         epoch_loss = running_loss / dataset_size['train']
  112.         epoch_acc = running_corrects.double() / dataset_size['train']
  113.  
  114.         print('Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))
  115.  
  116.  
  117.  
  118.     time_elapsed = time.time() - since
  119.     print('Training complete in {:.0f}m {:.0f}s'.format(
  120.         time_elapsed // 60, time_elapsed % 60))
  121.  
  122.  
  123.  
  124.     return model
  125.  
  126.  
  127. data_transforms = {
  128.     'train': transforms.Compose([
  129.         transforms.RandomResizedCrop(224),
  130.         transforms.RandomHorizontalFlip(),
  131.         transforms.RandomRotation(20),
  132.         transforms.ToTensor(),
  133.         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  134.     ]),
  135.     'test': transforms.Compose([
  136.         transforms.Resize(224),
  137.         transforms.ToTensor(),
  138.         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  139.     ])
  140. }
  141.  
  142.  
# Dataset root: expects an ImageFolder layout test_images/train/<class>/...
data_dir = "test_images"
# Use the first GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# Pretrained ResNet-50 with its final FC layer swapped out for 2-way
# classification.
model_ft = models.resnet50(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)
model_ft = model_ft.to(device)

# Save a clone of initial model to restore later (one fresh copy is made per
# LOOCV fold so every fold starts from identical weights).
initial_model = copy.deepcopy(model_ft)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
# NOTE(review): this optimizer is bound to the ORIGINAL model_ft parameters;
# after the per-fold deepcopy below it points at stale tensors.  train_model
# recreates its own optimizer internally, so this object appears effectively
# unused — confirm before removing.
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

#model_ft = model_ft.cuda()
# Number of LOOCV folds; presumably equals len(image_datasets['train']) —
# TODO confirm against the actual folder contents.
nb_samples = 50
nb_classes = 2

# Only a 'train' split is loaded; the LOOCV loop carves the held-out sample
# out of it with a SubsetRandomSampler.
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train']}

dataset_size = {x: len(image_datasets[x]) for x in ['train']}
class_names = image_datasets['train'].classes
  173.  
  174. # LOOCV
  175. loocv_preds = []
  176. loocv_targets = []
  177. for idx in range(nb_samples):
  178.  
  179.     print('Using sample {} as test data'.format(idx))
  180.  
  181.     print('Resetting model')
  182.     model_ft = copy.deepcopy(initial_model)
  183.  
  184.     # Get all indices and remove test sample
  185.     train_indices = list(range(len(image_datasets['train'])))
  186.     del train_indices[idx]
  187.  
  188.     # Create new sampler
  189.     sampler = data.SubsetRandomSampler(train_indices)
  190.     dataloader = data.DataLoader(
  191.         image_datasets['train'],
  192.         num_workers=2,
  193.         batch_size=1,
  194.         sampler=sampler
  195.     )
  196.  
  197.     # Train model
  198.     '''for batch_idx, (sample, target) in enumerate(dataloader):
  199.        print('Batch {}'.format(batch_idx))
  200.        model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, sample, target, num_epochs=10)'''
  201.  
  202.     model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, dataloader, num_epochs=10)
  203.  
  204.     # Test on LOO sample
  205.     model_ft.eval()
  206.     test_data, test_target = image_datasets['train'][idx]
  207.     # Apply test preprocessing on data
  208.     print(type(test_data))
  209.     test_data = data_transforms['test'](transforms.ToPILImage()(test_data))
  210.     test_data = test_data.cuda()
  211.     test_target = torch.tensor(test_target)
  212.     test_target = test_target.cuda()
  213.     test_data.unsqueeze_(0)
  214.     test_target.unsqueeze_(0)
  215.     print(test_data.shape)
  216.     output = model_ft(test_data)
  217.     pred = torch.argmax(output, 1)
  218.     loocv_preds.append(pred)
  219.     loocv_targets.append(test_target.item())
  220.  
  221.  
  222. print("loocv preds: ", loocv_preds)
  223. print("loocv targets: ", loocv_targets)
  224. print("acc score: ", accuracy_score(loocv_targets, loocv_preds))
  225. print("confusion matrix: \n", confusion_matrix(loocv_targets, loocv_preds))
  226. print('Class probabilities: {}'.format(F.softmax(output, 1)))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement