Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import copy
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.models as models
from torchvision import datasets, models, transforms
def train_model(model, criterion, optimizer, scheduler, dataloader, dataset_size, num_epochs=25):
    """Train ``model`` and return it loaded with the best-validation weights.

    Args:
        model: ``torch.nn.Module`` to optimize (modified in place).
        criterion: loss function taking ``(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once after each training phase.
        dataloader: iterable yielding ``(inputs, labels)`` batches.
            NOTE(review): the same loader is iterated for both the 'train'
            and 'val' phases — usually a dict of per-phase loaders is used;
            confirm this is intended before relying on the 'val' metrics.
        dataset_size: total sample count used to average loss/accuracy.
        num_epochs: number of epochs to run (default 25).

    Returns:
        ``model`` with the weights that achieved the best 'val' accuracy.
    """
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()   # Set model to training mode
            else:
                model.eval()    # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloader:
                # zero the parameter gradients
                optimizer.zero_grad()

                # forward; track gradient history only while training
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics: de-average the batch loss, count correct preds
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_size
            epoch_acc = running_corrects.double() / dataset_size
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model whenever validation accuracy improves
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights before returning
    model.load_state_dict(best_model_wts)
    return model
# Adapt an ImageNet-pretrained ResNet-18 for 10-class grayscale input.
resnet18 = models.resnet18(pretrained=True)
# Replace the 1000-way ImageNet head with a 10-class head (MNIST digits).
resnet18.fc = nn.Linear(512, 10)
# Accept 1-channel (grayscale) images instead of 3-channel RGB.
# NOTE(review): this discards the pretrained conv1 weights entirely.
resnet18.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
print(resnet18)
# Preprocessing: resize MNIST digits and convert PIL images to tensors.
# NOTE(review): Resize((3, 64)) yields 3x64 (HxW) images — looks like a typo
# for a square size such as (64, 64); confirm intent before changing.
transform = transforms.Compose(
    [transforms.Resize((3, 64)),
     transforms.ToTensor()])

# NOTE(review): despite the name, train=True loads the TRAINING split.
mnist_testset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(mnist_testset, batch_size=4, shuffle=False, num_workers=2)
# Evaluate the modified network on the loader and count correct predictions.
resnet18.eval()  # inference mode: freeze batchnorm running statistics
corr = 0
with torch.no_grad():  # no gradients needed for evaluation
    for inputs, labels in testloader:
        outputs = resnet18(inputs)
        _, predicted = torch.max(outputs, 1)
        # Vectorized per-batch comparison; unlike the original hard-coded
        # range(4), this also handles a short final batch correctly.
        corr += (predicted == labels).sum().item()
print(corr)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement