Advertisement
Abhisek92

conv.py

Nov 1st, 2023
788
0
Never
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
Python 2.85 KB | None | 0 0
  1. import numpy as np
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from train_utils import batchify_data, run_epoch, train_model, Flatten
  6. import utils_multiMNIST as U
  7. path_to_data_dir = '../Datasets/'
  8. use_mini_dataset = True
  9.  
  10. batch_size = 64
  11. nb_classes = 10
  12. nb_epoch = 15
  13. num_classes = 10
  14. img_rows, img_cols = 42, 28 # input image dimensions
  15.  
  16.  
  17. class CNN(nn.Module):
  18.     def __init__(self, input_dimension):
  19.         super(CNN, self).__init__()
  20.         self.linear1 = nn.Linear(input_dimension, 64)
  21.         self.linear2 = nn.Linear(64, 64)
  22.         self.linear_first_digit = nn.Linear(64, 10)
  23.         self.linear_second_digit = nn.Linear(64, 10)
  24.  
  25.         self.encoder = nn.Sequential(
  26.               nn.Conv2d(1, 8, (3, 3)),
  27.               nn.ReLU(),
  28.               nn.MaxPool2d((2, 2)),
  29.               nn.Conv2d(8, 16, (3, 3)),
  30.               nn.ReLU(),
  31.               nn.MaxPool2d((2, 2)),
  32.               Flatten(),
  33.               nn.Linear(720, 128),
  34.               nn.Dropout(0.5),
  35.         )
  36.  
  37.         self.first_digit_classifier = nn.Linear(128,10)
  38.         self.second_digit_classifier = nn.Linear(128,10)
  39.  
  40.     def forward(self, x):
  41.         out = self.encoder(x)
  42.         out_first_digit = self.first_digit_classifier(out)
  43.         out_second_digit = self.second_digit_classifier(out)
  44.         return out_first_digit, out_second_digit
  45.  
  46.  
  47. def main():
  48.     X_train, y_train, X_test, y_test = U.get_data(path_to_data_dir, use_mini_dataset)
  49.  
  50.     # Split into train and dev
  51.     dev_split_index = int(9 * len(X_train) / 10)
  52.     X_dev = X_train[dev_split_index:]
  53.     y_dev = [y_train[0][dev_split_index:], y_train[1][dev_split_index:]]
  54.     X_train = X_train[:dev_split_index]
  55.     y_train = [y_train[0][:dev_split_index], y_train[1][:dev_split_index]]
  56.  
  57.     permutation = np.array([i for i in range(len(X_train))])
  58.     np.random.shuffle(permutation)
  59.     X_train = [X_train[i] for i in permutation]
  60.     y_train = [[y_train[0][i] for i in permutation], [y_train[1][i] for i in permutation]]
  61.  
  62.     # Split dataset into batches
  63.     train_batches = batchify_data(X_train, y_train, batch_size)
  64.     dev_batches = batchify_data(X_dev, y_dev, batch_size)
  65.     test_batches = batchify_data(X_test, y_test, batch_size)
  66.  
  67.     # Load model
  68.     input_dimension = img_rows * img_cols
  69.     model = CNN(input_dimension)
  70.  
  71.     # Train
  72.     train_model(train_batches, dev_batches, model)
  73.  
  74.     ## Evaluate the model on test data
  75.     loss, acc = run_epoch(test_batches, model.eval(), None)
  76.     print('Test loss1: {:.6f}  accuracy1: {:.6f}  loss2: {:.6f}   accuracy2: {:.6f}'.format(loss[0], acc[0], loss[1], acc[1]))
  77.  
  78. if __name__ == '__main__':
  79.     # Specify seed for deterministic behavior, then shuffle. Do not change seed for official submissions to edx
  80.     np.random.seed(12321)  # for reproducibility
  81.     torch.manual_seed(12321)  # for reproducibility
  82.     main()
  83.  
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement