Advertisement
Guest User

Untitled

a guest
Apr 7th, 2020
208
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.56 KB | None | 0 0
  1. import torch
  2. import torch.nn as nn
  3. import torchvision.datasets
  4. import torchvision.transforms
  5. import torchvision.models as models
  6. import numpy as np
  7.  
  8. # Device configuration
  9. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  10.  
  11. # Hyper-parameters
  12. num_epochs = 10
  13. learning_rate = 0.001
  14. num_classes = 10
  15. batch_size = 64
  16.  
  17. # arch = 'resnet'
  18. arch = 'mobilenet'
  19.  
  20. # Image preprocessing modules
  21. transform = torchvision.transforms.Compose(
  22.     [torchvision.transforms.ToTensor(),
  23.      torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
  24.  
  25. # CIFAR-10 dataset
  26. train_dataset = torchvision.datasets.CIFAR10(root='../../data/',
  27.                                              train=True,
  28.                                              transform=transform,
  29.                                              download=True)
  30.  
  31. test_dataset = torchvision.datasets.CIFAR10(root='../../data/',
  32.                                             train=False,
  33.                                             transform=transform)
  34.  
  35. # Data loader
  36. train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
  37.                                            batch_size=batch_size,
  38.                                            shuffle=True)
  39.  
  40. test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
  41.                                           batch_size=batch_size,
  42.                                           shuffle=False)
  43.  
  44. resnet50 = models.resnet50(pretrained=True)
  45. resnet50.fc = nn.Linear(resnet50.fc.in_features, num_classes)
  46.  
  47. mobilenetv2 = models.mobilenet_v2(pretrained=True)
  48. mobilenetv2.classifier[1] = nn.Linear(mobilenetv2.classifier[1].in_features, num_classes)
  49.  
  50. resnet50_cpt = torch.load('resnet50.pt')  # 77.36 %
  51. resnet50.load_state_dict(resnet50_cpt['state_dict'])
  52.  
  53. mobilenetv2_cpt = torch.load('mobilenetv2.pt')  # 81.78 %
  54. mobilenetv2.load_state_dict(mobilenetv2_cpt['state_dict'])
  55.  
  56. resnet50 = resnet50.to(device)
  57. mobilenetv2 = mobilenetv2.to(device)
  58.  
  59. models = [resnet50, mobilenetv2]
  60.  
  61. # print(f"resnet50: {resnet50}")
  62. # print(f"mobilenetv2: {mobilenetv2}")
  63.  
  64.  
  65. # define an objective function
  66. def fn_objective(W):
  67.     W = torch.Tensor(W) / sum(W)
  68.     [model.eval() for model in models]
  69.  
  70.     _softmax = nn.Softmax(dim=1)
  71.  
  72.     with torch.no_grad():
  73.         correct = 0
  74.         total = 0
  75.         for images, labels in test_loader:
  76.             images = images.to(device)
  77.             labels = labels.to(device)
  78.  
  79.             outputs = torch.zeros((labels.size(0), num_classes)).to(device)
  80.             for i, model in enumerate(models):
  81.                 outputs += W[i] * _softmax(model(images))
  82.             # res_output = resnet50(images)
  83.             # mob_output = mobilenetv2(images)
  84.             # outputs = w * _softmax(res_output) + (1 - w) * _softmax(mob_output)
  85.  
  86.             _, predicted = torch.max(outputs.data, 1)
  87.             total += labels.size(0)
  88.             correct += (predicted == labels).sum().item()
  89.  
  90.         acc = 100 * correct / total
  91.         print(f"[w: {W}] Accuracy of the model on the test images: {acc} %")
  92.  
  93.     return -1 * acc
  94.  
  95.  
  96. # define a search space
  97. # minimize the objective over the space
  98. from hyperopt import hp, fmin, tpe, space_eval
  99.  
  100. space = [hp.uniform(f'w{i}', 0, 1) for i in range(len(models))]
  101. best = fmin(lambda W: fn_objective(W),
  102.             space, algo=tpe.suggest, max_evals=20, verbose=True)
  103.  
  104. print(f"best: {best}")
  105. print(f"space_eval(space, best): {space_eval(space, best)}")
  106.  
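'''
Results notes:
- Single-weight variant,
  outputs = w * softmax(resnet50) + (1 - w) * softmax(mobilenetv2):
  best 'w': 0.39289687565992365
- One weight per model: 'w0': 0.6592927150223445, 'w1': 0.9741183172004064
  normalized (each divided by their sum) >> 0.4036294, 0.5963706
- best acc: 83.92
'''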
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement