Advertisement
lamiastella

loocv modified

Nov 25th, 2018
883
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 6.86 KB | None | 0 0
# Compatibility shims so the script runs under both Python 2 and 3
# (print-as-function, true division).
from __future__ import print_function, division

import torch
# NOTE(review): `Variable` is deprecated in modern PyTorch (tensors track
# gradients directly); it is imported here but never used below.
from torch.autograd import Variable
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
# NOTE(review): duplicate of the `import torch` above -- harmless but redundant.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

import torch.utils.data as data_utils
from torch.utils import data

# Fix the RNG seed so augmentation / weight init are reproducible per run.
torch.manual_seed(2809)
  24.  
  25.  
  26. def train_model(model, criterion, optimizer, scheduler,
  27.                 dataloader, num_epochs=25):
  28.     '''since = time.time()
  29.    model.train()  # Set model to training mode
  30.    for epoch in range(num_epochs):
  31.        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
  32.        print('-' * 10)
  33.  
  34.        scheduler.step()
  35.  
  36.        running_loss = 0.0
  37.        running_corrects = 0
  38.  
  39.        # Iterate over data.
  40.        train_input = train_input.to(device)
  41.        train_label = train_label.to(device)
  42.  
  43.        # zero the parameter gradients
  44.        optimizer.zero_grad()
  45.  
  46.        output = model(train_input)
  47.        _, pred = torch.max(output, 1)
  48.        loss = criterion(output, train_label)
  49.        loss.backward()
  50.        optimizer.step()
  51.  
  52.        # statistics
  53.        running_loss += loss.item() * train_input.size(0)
  54.        running_corrects += torch.sum(pred == train_label.data)
  55.  
  56.        epoch_loss = running_loss / dataset_size['train']
  57.        epoch_acc = running_corrects.double() / dataset_size['train']
  58.  
  59.        print('train Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))
  60.  
  61.        print()
  62.  
  63.    time_elapsed = time.time() - since
  64.    print('Training complete in {:.0f}m {:.0f}s'.format(
  65.        time_elapsed // 60, time_elapsed % 60))
  66.  
  67.    return model'''
  68.  
  69.     since = time.time()
  70.     for epoch in range(num_epochs):
  71.         print('Epoch {}/{}'.format(epoch, num_epochs - 1))
  72.         print('-' * 10)
  73.  
  74.         # Each epoch has a training and validation phase
  75.  
  76.         scheduler.step()
  77.         model.train()  # Set model to training mode
  78.  
  79.  
  80.         running_loss = 0.0
  81.         running_corrects = 0
  82.  
  83.         # Iterate over data.
  84.         for inputs, labels in dataloader:
  85.             inputs = inputs.to(device)
  86.             labels = labels.to(device)
  87.  
  88.             # zero the parameter gradients
  89.             optimizer.zero_grad()
  90.  
  91.             # forward
  92.             # track history if only in train
  93.             with torch.set_grad_enabled(True):
  94.                 outputs = model(inputs)
  95.                 _, preds = torch.max(outputs, 1)
  96.                 loss = criterion(outputs, labels)
  97.                 # backward + optimize only if in training phase
  98.                 loss.backward()
  99.                 optimizer.step()
  100.             # statistics
  101.             running_loss += loss.item() * inputs.size(0)
  102.             running_corrects += torch.sum(preds == labels.data)
  103.  
  104.         epoch_loss = running_loss / dataset_size['train']
  105.         epoch_acc = running_corrects.double() / dataset_size['train']
  106.  
  107.         print('Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))
  108.  
  109.  
  110.  
  111.     time_elapsed = time.time() - since
  112.     print('Training complete in {:.0f}m {:.0f}s'.format(
  113.         time_elapsed // 60, time_elapsed % 60))
  114.  
  115.  
  116.  
  117.     return model
  118.  
  119.  
  120. data_transforms = {
  121.     'train': transforms.Compose([
  122.         transforms.RandomResizedCrop(224),
  123.         transforms.RandomHorizontalFlip(),
  124.         transforms.RandomRotation(20),
  125.         transforms.ToTensor(),
  126.         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  127.     ]),
  128.     'test': transforms.Compose([
  129.         transforms.Resize(224),
  130.         transforms.ToTensor(),
  131.         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  132.     ])
  133. }
  134.  
  135.  
# Root folder for the data; expects an ImageFolder layout, i.e.
# test_images/train/<class_name>/<image files>.
data_dir = "test_images"
# Use the first GPU when available, otherwise run on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# Fine-tune a pretrained ResNet-50: replace the final fully-connected
# layer with a fresh 2-class classifier head.
model_ft = models.resnet50(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)
model_ft = model_ft.to(device)

# Save a clone of initial model to restore later: each LOOCV fold below
# restarts training from these pristine weights.
initial_model = copy.deepcopy(model_ft)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

#model_ft = model_ft.cuda()
# NOTE(review): nb_samples is hard-coded; presumably it equals the dataset
# size -- confirm it does not exceed len(image_datasets['train']), or the
# LOOCV loop will raise an IndexError.
nb_samples = 50
nb_classes = 2

# Only a 'train' split exists on disk; the 'test' transform is applied
# manually to the held-out sample in the LOOCV loop below.
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train']}

dataset_size = {x: len(image_datasets[x]) for x in ['train']}
class_names = image_datasets['train'].classes
  166.  
  167. # LOOCV
  168. loocv_preds = []
  169. loocv_targets = []
  170. for idx in range(nb_samples):
  171.  
  172.     print('Using sample {} as test data'.format(idx))
  173.  
  174.     print('Resetting model')
  175.     model_ft = copy.deepcopy(initial_model)
  176.     # Observe that all parameters are being optimized
  177.     optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
  178.  
  179.     # Decay LR by a factor of 0.1 every 7 epochs
  180.     exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
  181.  
  182.     # Get all indices and remove test sample
  183.     train_indices = list(range(len(image_datasets['train'])))
  184.     del train_indices[idx]
  185.  
  186.     # Create new sampler
  187.     sampler = data.SubsetRandomSampler(train_indices)
  188.     dataloader = data.DataLoader(
  189.         image_datasets['train'],
  190.         num_workers=2,
  191.         batch_size=1,
  192.         sampler=sampler
  193.     )
  194.  
  195.     # Train model
  196.     '''for batch_idx, (sample, target) in enumerate(dataloader):
  197.        print('Batch {}'.format(batch_idx))
  198.        model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, sample, target, num_epochs=10)'''
  199.  
  200.     model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, dataloader, num_epochs=25)
  201.  
  202.     # Test on LOO sample
  203.     model_ft.eval()
  204.     test_data, test_target = image_datasets['train'][idx]
  205.     # Apply test preprocessing on data
  206.     print(type(test_data))
  207.     test_data = data_transforms['test'](transforms.ToPILImage()(test_data))
  208.     test_data = test_data.cuda()
  209.     test_target = torch.tensor(test_target)
  210.     test_target = test_target.cuda()
  211.     test_data.unsqueeze_(0)
  212.     test_target.unsqueeze_(0)
  213.     print(test_data.shape)
  214.     output = model_ft(test_data)
  215.     pred = torch.argmax(output, 1)
  216.     loocv_preds.append(pred)
  217.     loocv_targets.append(test_target.item())
  218.  
  219.  
  220. print("loocv preds: ", loocv_preds)
  221. print("loocv targets: ", loocv_targets)
  222. print("acc score: ", accuracy_score(loocv_targets, loocv_preds))
  223. print("confusion matrix: \n", confusion_matrix(loocv_targets, loocv_preds))
  224. print('Class probabilities: {}'.format(F.softmax(output, 1)))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement