Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import functools

import numpy as np
import torch
import torch.nn as nn
import torchvision.datasets
import torchvision.models as models
import torchvision.transforms
# Pick the compute device: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Hyper-parameters.
# NOTE(review): num_epochs, learning_rate and arch are not read anywhere in this
# script (no training happens here); they look like leftovers from a training
# script. Kept in case something else imports them.
num_epochs = 10
learning_rate = 0.001
num_classes = 10  # width of the replacement classifier heads built below
batch_size = 64
arch = 'mobilenet'  # other value used by the author: 'resnet'
# Image preprocessing: convert to a tensor, then normalize each of the three
# channels with mean 0.5 and std 0.5 (i.e. (x - 0.5) / 0.5, which maps values
# in [0, 1] to [-1, 1]).
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# CIFAR-10 splits, stored under a path relative to the current working
# directory. Both splits pass download=True so the test split no longer depends
# on the train split having been constructed (and downloaded) first; torchvision
# skips the download when the files are already present.
train_dataset = torchvision.datasets.CIFAR10(root='../../data/',
                                             train=True,
                                             transform=transform,
                                             download=True)
test_dataset = torchvision.datasets.CIFAR10(root='../../data/',
                                            train=False,
                                            transform=transform,
                                            download=True)

# Data loaders. Only test_loader is consumed by the ensemble evaluation; its
# order must stay fixed (shuffle=False) so every evaluation sees the same data.
# NOTE(review): train_loader is unused in this script; kept for compatibility.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)
def _load_finetuned(model, checkpoint_path):
    """Load the ``'state_dict'`` entry of *checkpoint_path* into *model*.

    The checkpoint is mapped onto the CPU first, so checkpoints saved from a GPU
    run also load on CPU-only machines. The model is then moved to ``device``
    and returned.
    """
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    return model.to(device)


# Build both architectures with a fresh ``num_classes``-way head. ImageNet
# weights are not requested: the strict (default) load_state_dict in
# _load_finetuned overwrites every parameter and buffer, so downloading them
# would be wasted work.
resnet50 = torchvision.models.resnet50()
resnet50.fc = nn.Linear(resnet50.fc.in_features, num_classes)
mobilenetv2 = torchvision.models.mobilenet_v2()
mobilenetv2.classifier[1] = nn.Linear(mobilenetv2.classifier[1].in_features, num_classes)

resnet50 = _load_finetuned(resnet50, 'resnet50.pt')  # 77.36 %
mobilenetv2 = _load_finetuned(mobilenetv2, 'mobilenetv2.pt')  # 81.78 %

# Ensemble members. The i-th weight passed to fn_objective applies to models[i].
# NOTE: this rebinds ``models``, shadowing the ``torchvision.models`` import
# alias; the constructors above go through ``torchvision.models`` directly so
# they do not depend on statement order.
models = [resnet50, mobilenetv2]
@functools.lru_cache(maxsize=1)
def _test_set_probabilities():
    """Run every model in ``models`` once over ``test_loader`` and cache the result.

    The per-model softmax outputs do not depend on the ensemble weights, so they
    are computed a single time and reused by every objective evaluation instead
    of re-running all networks over the whole test set on each call.

    Returns:
        ``(probs, labels)``. ``probs`` stacks each model's softmax output over the
        full test set along a new leading model axis,
        ``(len(models), n_samples, n_outputs)``. ``labels`` holds the matching
        labels concatenated in loader order. Both live on ``device``.

    NOTE: the result is cached for the lifetime of the process. It becomes stale
    if ``models`` or ``test_loader`` are reassigned after the first call.
    """
    for model in models:
        model.eval()
    per_model_chunks = [[] for _ in models]
    label_chunks = []
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            label_chunks.append(labels.to(device))
            for chunks, model in zip(per_model_chunks, models):
                chunks.append(torch.softmax(model(images), dim=1))
    probs = torch.stack([torch.cat(chunks) for chunks in per_model_chunks])
    return probs, torch.cat(label_chunks)


def fn_objective(W):
    """Hyperopt objective: negated test accuracy of a weighted softmax ensemble.

    Args:
        W: Sequence of weights, one per entry of ``models`` (``W[i]`` weights
            ``models[i]``). The weights are normalized to sum to 1, so only
            their ratios matter.

    Returns:
        Test-set accuracy in percent, negated so that ``fmin`` maximizes it.

    Raises:
        ValueError: if ``W`` does not have exactly one weight per model, or if
            the weights sum to zero (normalizing would produce NaNs).
    """
    raw = [float(w) for w in W]
    if len(raw) != len(models):
        raise ValueError(f"expected {len(models)} weights, got {len(raw)}")
    total_weight = sum(raw)
    if total_weight == 0:
        raise ValueError("ensemble weights must not sum to zero")
    W = torch.tensor(raw, dtype=torch.float32) / total_weight

    probs, labels = _test_set_probabilities()
    # Weighted sum over the model axis -> one blended probability row per sample.
    blended = (W.to(probs.device).view(-1, 1, 1) * probs).sum(dim=0)
    predicted = blended.argmax(dim=1)
    correct = (predicted == labels).sum().item()
    acc = 100 * correct / labels.size(0)
    print(f"[w: {W}] Accuracy of the model on the test images: {acc} %")
    return -acc
# Search the ensemble weights with hyperopt's TPE sampler by minimizing
# fn_objective (i.e. maximizing test accuracy).
from hyperopt import hp, fmin, tpe, space_eval

# One independent Uniform(0, 1) weight per model. fn_objective normalizes the
# weights, so only their ratios matter; the overall scale is a redundant degree
# of freedom for the sampler.
space = [hp.uniform(f'w{i}', 0, 1) for i in range(len(models))]
best = fmin(fn_objective, space, algo=tpe.suggest, max_evals=20, verbose=True)
print(f"best: {best}")
best_weights = space_eval(space, best)
print(f"space_eval(space, best): {best_weights}")
# Report the weights as fn_objective actually applies them (summing to 1).
weight_sum = sum(best_weights)
print(f"normalized weights: {[w / weight_sum for w in best_weights]}")

# Results recorded from earlier runs (not recomputed by this script):
#   'w': 0.39289687565992365
#   'w0': 0.6592927150223445, 'w1': 0.9741183172004064
#   >> 0.4036294, 0.5963706  (w0, w1 normalized to sum to 1)
#   best acc: 83.92
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement