Advertisement
lamiastella

modified modified loocv tl

Nov 25th, 2018
391
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 6.60 KB | None | 0 0
  1. from __future__ import print_function, division
  2.  
  3. import torch
  4. from torch.autograd import Variable
  5. from sklearn.metrics import accuracy_score
  6. from sklearn.metrics import confusion_matrix
  7. import torch
  8. import torch.nn as nn
  9. import torch.optim as optim
  10. import torch.nn.functional as F
  11. from torch.optim import lr_scheduler
  12. import numpy as np
  13. import torchvision
  14. from torchvision import datasets, models, transforms
  15. import matplotlib.pyplot as plt
  16. import time
  17. import os
  18. import copy
  19.  
  20. import torch.utils.data as data_utils
  21. from torch.utils import data
  22.  
  23. torch.manual_seed(2809)
  24.  
  25.  
  26. def train_model(model, criterion, optimizer, scheduler,
  27.                 dataloader, num_epochs=25):
  28.     '''since = time.time()
  29.    model.train()  # Set model to training mode
  30.    for epoch in range(num_epochs):
  31.        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
  32.        print('-' * 10)
  33.  
  34.        scheduler.step()
  35.  
  36.        running_loss = 0.0
  37.        running_corrects = 0
  38.  
  39.        # Iterate over data.
  40.        train_input = train_input.to(device)
  41.        train_label = train_label.to(device)
  42.  
  43.        # zero the parameter gradients
  44.        optimizer.zero_grad()
  45.  
  46.        output = model(train_input)
  47.        _, pred = torch.max(output, 1)
  48.        loss = criterion(output, train_label)
  49.        loss.backward()
  50.        optimizer.step()
  51.  
  52.        # statistics
  53.        running_loss += loss.item() * train_input.size(0)
  54.        running_corrects += torch.sum(pred == train_label.data)
  55.  
  56.        epoch_loss = running_loss / dataset_size['train']
  57.        epoch_acc = running_corrects.double() / dataset_size['train']
  58.  
  59.        print('train Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))
  60.  
  61.        print()
  62.  
  63.    time_elapsed = time.time() - since
  64.    print('Training complete in {:.0f}m {:.0f}s'.format(
  65.        time_elapsed // 60, time_elapsed % 60))
  66.  
  67.    return model'''
  68.  
  69.     since = time.time()
  70.  
  71.     for epoch in range(num_epochs):
  72.         print('Epoch {}/{}'.format(epoch, num_epochs - 1))
  73.         print('-' * 10)
  74.  
  75.         # Each epoch has a training and validation phase
  76.  
  77.         scheduler.step()
  78.         model.train()  # Set model to training mode
  79.  
  80.  
  81.         running_loss = 0.0
  82.         running_corrects = 0
  83.  
  84.         # Iterate over data.
  85.         for inputs, labels in dataloader:
  86.             inputs = inputs.to(device)
  87.             labels = labels.to(device)
  88.  
  89.             # zero the parameter gradients
  90.             optimizer.zero_grad()
  91.  
  92.             # forward
  93.             # track history if only in train
  94.             with torch.set_grad_enabled(True):
  95.                 outputs = model(inputs)
  96.                 _, preds = torch.max(outputs, 1)
  97.                 loss = criterion(outputs, labels)
  98.                 # backward + optimize only if in training phase
  99.                 loss.backward()
  100.                 optimizer.step()
  101.             # statistics
  102.             running_loss += loss.item() * inputs.size(0)
  103.             running_corrects += torch.sum(preds == labels.data)
  104.  
  105.         epoch_loss = running_loss / dataset_size['train']
  106.         epoch_acc = running_corrects.double() / dataset_size['train']
  107.  
  108.         print('Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))
  109.  
  110.  
  111.  
  112.     time_elapsed = time.time() - since
  113.     print('Training complete in {:.0f}m {:.0f}s'.format(
  114.         time_elapsed // 60, time_elapsed % 60))
  115.  
  116.  
  117.  
  118.     return model
  119.  
  120.  
  121. data_transforms = {
  122.     'train': transforms.Compose([
  123.         transforms.RandomResizedCrop(224),
  124.         transforms.RandomHorizontalFlip(),
  125.         transforms.RandomRotation(20),
  126.         transforms.ToTensor(),
  127.         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  128.     ]),
  129.     'test': transforms.Compose([
  130.         transforms.Resize(224),
  131.         transforms.ToTensor(),
  132.         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  133.     ])
  134. }
  135.  
  136.  
  137. data_dir = "test_images"
  138. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  139.  
  140. model_ft = models.resnet50(pretrained=True)
  141. num_ftrs = model_ft.fc.in_features
  142. model_ft.fc = nn.Linear(num_ftrs, 2)
  143. model_ft = model_ft.to(device)
  144.  
  145. # Save a clone of initial model to restore later
  146. initial_model = copy.deepcopy(model_ft)
  147.  
  148. criterion = nn.CrossEntropyLoss()
  149.  
  150. # Observe that all parameters are being optimized
  151. optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
  152.  
  153. # Decay LR by a factor of 0.1 every 7 epochs
  154. exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
  155.  
  156. #model_ft = model_ft.cuda()
  157. nb_samples = 50
  158. nb_classes = 2
  159.  
  160. image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
  161.                                           data_transforms[x])
  162.                   for x in ['train']}
  163.  
  164. dataset_size = {x: len(image_datasets[x]) for x in ['train']}
  165. class_names = image_datasets['train'].classes
  166.  
  167. # LOOCV
  168. loocv_preds = []
  169. loocv_targets = []
  170. for idx in range(nb_samples):
  171.  
  172.     print('Using sample {} as test data'.format(idx))
  173.  
  174.     print('Resetting model')
  175.     model_ft = copy.deepcopy(initial_model)
  176.  
  177.     # Get all indices and remove test sample
  178.     train_indices = list(range(len(image_datasets['train'])))
  179.     del train_indices[idx]
  180.  
  181.     # Create new sampler
  182.     sampler = data.SubsetRandomSampler(train_indices)
  183.     dataloader = data.DataLoader(
  184.         image_datasets['train'],
  185.         num_workers=2,
  186.         batch_size=1,
  187.         sampler=sampler
  188.     )
  189.  
  190.     # Train model
  191.     '''for batch_idx, (sample, target) in enumerate(dataloader):
  192.        print('Batch {}'.format(batch_idx))
  193.        model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, sample, target, num_epochs=10)'''
  194.  
  195.     model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, dataloader, num_epochs=10)
  196.  
  197.     # Test on LOO sample
  198.     model_ft.eval()
  199.     test_data, test_target = image_datasets['train'][idx]
  200.     # Apply test preprocessing on data
  201.     print(type(test_data))
  202.     test_data = data_transforms['test'](transforms.ToPILImage()(test_data))
  203.     test_data = test_data.cuda()
  204.     test_target = torch.tensor(test_target)
  205.     test_target = test_target.cuda()
  206.     test_data.unsqueeze_(0)
  207.     test_target.unsqueeze_(0)
  208.     print(test_data.shape)
  209.     output = model_ft(test_data)
  210.     pred = torch.argmax(output, 1)
  211.     loocv_preds.append(pred)
  212.     loocv_targets.append(test_target.item())
  213.  
  214.  
  215. print("loocv preds: ", loocv_preds)
  216. print("loocv targets: ", loocv_targets)
  217. print("acc score: ", accuracy_score(loocv_targets, loocv_preds))
  218. print("confusion matrix: \n", confusion_matrix(loocv_targets, loocv_preds))
  219. print("confidence score for each image: ", F.softmax(output, 1))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement