SHARE
TWEET

Untitled

a guest Nov 22nd, 2019 92 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. from __future__ import print_function
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. import torch.optim as optim
  6. #from torchvision import datasets, transforms
  7. import matplotlib.pyplot as plt
  8. import numpy as np
  9. import torch.utils.data as utils
  10.  
  11. # The parts that you should complete are designated as TODO
  12. class ConvNet(nn.Module):
  13.     def __init__(self):
  14.         super(ConvNet, self).__init__()
  15.         # TODO: define the layers of the network
  16.  
  17.         self.Cn1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3))
  18.         self.Cn2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3))
  19.         self.MP = nn.MaxPool2d(kernel_size = (2,2), stride = 2)
  20.         self.dp = nn.Dropout2d(p = .25)
  21.         self.lin1 = nn.Linear(9216,128)
  22.         self.dp2 = nn.Dropout(p = .5)
  23.         self.fin = nn.Linear(128,10)
  24.         self.firstPrint = True;
  25.  
  26.        
  27.  
  28.     def forward(self, x):
  29.         # TODO: define the forward pass of the network using the layers you defined in constructor
  30.         output1 = torch.relu(self.Cn1(x))
  31.         #print(output1.shape)
  32.         output2 = torch.relu(self.Cn2(output1))
  33.         #print(output2.shape)
  34.         output3 = self.MP(output2)
  35.         #print(output3.shape)
  36.         dropPut = self.dp(output3)
  37.         #print(dropPut.shape)
  38.         flatPut = torch.flatten(dropPut, start_dim = 1)
  39.         #print(flatPut.shape)
  40.         #print(flatPut.shape)
  41.         linPut  = torch.relu(self.lin1(flatPut))
  42.         #print(linPut.shape)
  43.         dropPut = self.dp2(linPut)
  44.         #print(dropPut.shape)
  45.         finPut  = torch.softmax(self.fin(dropPut), dim = 1)
  46.         #print(finPut.shape)
  47.         return finPut;
  48.    
  49.    
  50. def train(model, device, train_loader, optimizer, epoch):
  51.     model.train()
  52.     correct = 0
  53.     for batch_idx, (data, target) in enumerate(train_loader):
  54.         data, target = data.to(device), target.to(device)
  55.         optimizer.zero_grad()
  56.         output = model(data)
  57.         loss = F.cross_entropy(output, target)
  58.         loss.backward()
  59.         optimizer.step()
  60.         if batch_idx % 100 == 0: #Print loss every 100 batch
  61.             print('Train Epoch: {}\tLoss: {:.6f}'.format(
  62.                 epoch, loss.item()))
  63.     accuracy = test(model, device, train_loader)
  64.     return accuracy
  65.  
  66. def test(model, device, test_loader):
  67.     model.eval()
  68.     correct = 0
  69.     with torch.no_grad():
  70.         for data, target in test_loader:
  71.             data, target = data.to(device), target.to(device)
  72.             output = model(data)
  73.             pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
  74.             correct += pred.eq(target.view_as(pred)).sum().item()
  75.  
  76.     accuracy = 100. * correct / len(test_loader.dataset)
  77.  
  78.     return accuracy
  79.  
  80.  
  81. def main():
  82.     torch.manual_seed(1)
  83.     np.random.seed(1)
  84.     # Training settings
  85.     use_cuda = False # Switch to False if you only want to use your CPU
  86.     learning_rate = 0.01
  87.     NumEpochs = 10
  88.     batch_size = 32
  89.  
  90.     device = torch.device("cuda" if use_cuda else "cpu")
  91.  
  92.     train_X = np.load('D:/589_MachineLearning/hw4/Data/X_train.npy')
  93.     train_Y = np.load('D:/589_MachineLearning/hw4/Data/y_train.npy')
  94.  
  95.     test_X = np.load('D:/589_MachineLearning/hw4/Data/X_test.npy')
  96.     test_Y = np.load('D:/589_MachineLearning/hw4/Data/y_test.npy')
  97.  
  98.     train_X = train_X.reshape([-1,1,28,28]) # the data is flatten so we reshape it here to get to the original dimensions of images
  99.     test_X = test_X.reshape([-1,1,28,28])
  100.  
  101.     # transform to torch tensors
  102.     tensor_x = torch.tensor(train_X, device=device)
  103.     tensor_y = torch.tensor(train_Y, dtype=torch.long, device=device)
  104.  
  105.     test_tensor_x = torch.tensor(test_X, device=device)
  106.     test_tensor_y = torch.tensor(test_Y, dtype=torch.long)
  107.  
  108.     train_dataset = utils.TensorDataset(tensor_x,tensor_y) # create your datset
  109.     train_loader = utils.DataLoader(train_dataset, batch_size=batch_size) # create your dataloader
  110.  
  111.     test_dataset = utils.TensorDataset(test_tensor_x,test_tensor_y) # create your datset
  112.     test_loader = utils.DataLoader(test_dataset) # create your dataloader if you get a error when loading test data you can set a batch_size here as well like train_dataloader
  113.  
  114.     model = ConvNet().to(device)
  115.     optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.0001)
  116.    
  117.     testyArr = []
  118.     trainyArr = []
  119.     for epoch in range(NumEpochs):
  120.         train_acc = train(model, device, train_loader, optimizer, epoch)
  121.         trainyArr.append(train_acc)
  122.         print('\nTrain set Accuracy: {:.0f}%\n'.format(train_acc))
  123.         test_acc = test(model, device, test_loader)
  124.         testyArr.append(test_acc)
  125.         print('\nTest set Accuracy: {:.0f}%\n'.format(test_acc))
  126.  
  127.     torch.save(model.state_dict(), "mnist_cnn.pt")
  128.  
  129.     #TODO: Plot train and test accuracy vs epoch
  130.     plt.plot(trainyArr, range(NumEpochs),label = "Train")
  131.     plt.plot(testyArr,  range(NumEpochs),label = "Test")
  132.     plt.xlabel("accuracy")
  133.     plt.ylabel("Epochs")
  134.     plt.legend()
  135.     plt.show();
  136.  
  137. if __name__ == '__main__':
  138.     main()
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top