Advertisement
lamiastella

working

Nov 13th, 2018
411
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.73 KB | None | 0 0
  1. #model_ft = model_ft.cuda()
  2. nb_samples = 931
  3. nb_classes = 9
  4.  
  5.  
  6. from __future__ import print_function, division
  7.  
  8. import torch
  9. import torch.nn as nn
  10. import torch.optim as optim
  11. from torch.optim import lr_scheduler
  12. import numpy as np
  13. import torchvision
  14. from torchvision import datasets, models, transforms
  15. import matplotlib.pyplot as plt
  16. import time
  17. import os
  18. import copy
  19.  
  20. import torch.utils.data as data_utils
  21. from torch.utils import data
  22.  
  23.  
  24.  
  25.  
  26.  
  27. data_transforms = {
  28.     'train': transforms.Compose([
  29.         transforms.RandomResizedCrop(224),
  30.         transforms.RandomHorizontalFlip(),
  31.         transforms.RandomRotation(20),
  32.         transforms.ToTensor(),
  33.         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  34.     ]),
  35.  
  36.         'test': transforms.Compose([
  37.         transforms.Resize(256),
  38.         transforms.CenterCrop(224),
  39.         transforms.ToTensor(),
  40.         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  41.     ]),
  42. }
  43.  
  44. val_loader = data.DataLoader(
  45.         image_datasets['test'],
  46.         num_workers=2,
  47.         batch_size=1
  48.     )
  49. val_loader = iter(val_loader)
  50.  
  51. data_dir = "images"
  52. image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
  53.                                           data_transforms[x])
  54.                   for x in ['train', 'test']}
  55.  
  56. dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
  57. print(dataset_sizes)
  58. class_names = image_datasets['train'].classes
  59.  
  60. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  61. # LOOCV
  62. loocv_preds = []
  63. loocv_targets = []
  64. for idx in range(nb_samples):
  65.    
  66.     print('Using sample {} as test data'.format(idx))
  67.    
  68.     # Get all indices and remove test sample
  69.     train_indices = list(range(len(image_datasets)))
  70.     del train_indices[idx]
  71.    
  72.     # Create new sampler
  73.     sampler = data.SubsetRandomSampler(train_indices)
  74.  
  75.     dataloader = data.DataLoader(
  76.         image_datasets['train'],
  77.         num_workers=2,
  78.         batch_size=1,
  79.         sampler=sampler
  80.     )
  81.    
  82.     # Train model
  83.     for batch_idx, (samples, target) in enumerate(dataloader):
  84.         print('Batch {}'.format(batch_idx))
  85.         model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25) # do I add this line here?
  86.                
  87.     # Test on LOO sample
  88.     model_ft.eval()
  89. #    test_data, test_target = image_datasets['train'][idx]
  90.     test_data, test_target = val_loader.next()
  91.     test_data = test_data.cuda()
  92.     test_target = test_target.cuda()
  93.     #test_data.unsqueeze_(1)
  94.     #test_target.unsqueeze_(0)
  95.  
  96.     output = model_ft(test_data)
  97.     pred = torch.argmax(output, 1)
  98.     loocv_preds.append(pred)
  99.     loocv_targets.append(test_target.item())
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement