SHARE
TWEET

Untitled

a guest Nov 18th, 2019 107 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import torch
  2. import torch.nn as nn
  3. import torch.optim as optim
  4. import numpy as np
  5. import torchvision
  6. from torchvision import datasets, models, transforms
  7. import matplotlib.pyplot as plt
  8. import torchvision.models as models
  9. import matplotlib.pyplot as plt
  10. import numpy as np
  11. import copy
  12. import time
  13.  
  14.  
  15. def train_model(model, criterion, optimizer, scheduler, dataloader, dataset_size, num_epochs=25):
  16.     since = time.time()
  17.  
  18.     best_model_wts = copy.deepcopy(model.state_dict())
  19.     best_acc = 0.0
  20.  
  21.     for epoch in range(num_epochs):
  22.         print('Epoch {}/{}'.format(epoch, num_epochs - 1))
  23.         print('-' * 10)
  24.  
  25.         # Each epoch has a training and validation phase
  26.         for phase in ['train', 'val']:
  27.             if phase == 'train':
  28.                 model.train()  # Set model to training mode
  29.             else:
  30.                 model.eval()   # Set model to evaluate mode
  31.  
  32.             running_loss = 0.0
  33.             running_corrects = 0
  34.  
  35.             # Iterate over data.
  36.             for inputs, labels in dataloader:
  37.  
  38.                 # zero the parameter gradients
  39.                 optimizer.zero_grad()
  40.  
  41.                 # forward
  42.                 # track history if only in train
  43.                 with torch.set_grad_enabled(phase == 'train'):
  44.                     outputs = model(inputs)
  45.                     _, preds = torch.max(outputs, 1)
  46.                     loss = criterion(outputs, labels)
  47.  
  48.                     # backward + optimize only if in training phase
  49.                     if phase == 'train':
  50.                         loss.backward()
  51.                         optimizer.step()
  52.  
  53.                 # statistics
  54.                 running_loss += loss.item() * inputs.size(0)
  55.                 running_corrects += torch.sum(preds == labels.data)
  56.             if phase == 'train':
  57.                 scheduler.step()
  58.  
  59.             epoch_loss = running_loss / dataset_size
  60.             epoch_acc = running_corrects.double() / dataset_size
  61.  
  62.             print('{} Loss: {:.4f} Acc: {:.4f}'.format(
  63.                 phase, epoch_loss, epoch_acc))
  64.  
  65.             # deep copy the model
  66.             if phase == 'val' and epoch_acc > best_acc:
  67.                 best_acc = epoch_acc
  68.                 best_model_wts = copy.deepcopy(model.state_dict())
  69.  
  70.         print()
  71.  
  72.     time_elapsed = time.time() - since
  73.     print('Training complete in {:.0f}m {:.0f}s'.format(
  74.         time_elapsed // 60, time_elapsed % 60))
  75.     print('Best val Acc: {:4f}'.format(best_acc))
  76.  
  77.     # load best model weights
  78.     model.load_state_dict(best_model_wts)
  79.     return model
  80.  
  81.  
  82. resnet18 = models.resnet18(pretrained=True)
  83. resnet18.fc = nn.Linear(512, 10)
  84. resnet18.conv1 = nn.Conv2d(1,64, kernel_size=(7,7), stride = (2,2), padding=(3,3), bias=False)
  85. print(resnet18)
  86. transform = transforms.Compose(
  87.     [transforms.Resize((3,64)),
  88.     transforms.ToTensor()])
  89.  
  90. mnist_testset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
  91. # print((mnist_testset[0]))
  92. testloader = torch.utils.data.DataLoader(mnist_testset, batch_size=4,shuffle=False, num_workers=2)
  93.  
  94. dataiter = iter(testloader)
  95. # images, labels = dataiter.next()
  96. # outputs = resnet18(images)
  97. # _, predicted = torch.max(outputs, 1)
  98. corr = 0
  99. for ii, (inputs, labels) in enumerate(testloader):
  100.     outputs = resnet18(inputs)
  101.     _, predicted = torch.max(outputs, 1)
  102.     x = predicted.cpu().detach().numpy()
  103.     y = labels.cpu().detach().numpy()
  104.     for i in range(4):
  105.         if x[i] == y[i]:
  106.             corr += 1
  107. print(corr)
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top