Advertisement
Abhisek92

mlp.py

Nov 1st, 2023
674
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.49 KB | None | 0 0
  1. import numpy as np
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from train_utils import batchify_data, run_epoch, train_model, Flatten
  6. import utils_multiMNIST as U
  7. path_to_data_dir = '../Datasets/'
  8. use_mini_dataset = True
  9.  
  10. batch_size = 64
  11. nb_classes = 10
  12. nb_epoch = 30
  13. num_classes = 10
  14. img_rows, img_cols = 42, 28 # input image dimensions
  15.  
  16. class MLP(nn.Module):
  17.     def __init__(self, input_dimension):
  18.         super(MLP, self).__init__()
  19.         self.flatten = Flatten()
  20.  
  21.         H = 64
  22.         self.linear1 = nn.Linear(input_dimension, H)
  23.         self.linear_out1 = nn.Linear(H, num_classes)
  24.         self.linear_out2 = nn.Linear(num_classes, num_classes)
  25.  
  26.     def forward(self, x):
  27.         xf = self.flatten(x)  # need to flatten to use linear layers
  28.  
  29.         xf_out = F.relu(self.linear1(xf))
  30.         out_first_digit = self.linear_out1(xf_out)
  31.         out_second_digit = self.linear_out2(out_first_digit)
  32.  
  33.         return out_first_digit, out_second_digit
  34.  
  35.  
  36.  
  37. def main():
  38.     X_train, y_train, X_test, y_test = U.get_data(path_to_data_dir, use_mini_dataset)
  39.  
  40.     # Split into train and dev
  41.     dev_split_index = int(9 * len(X_train) / 10)
  42.     X_dev = X_train[dev_split_index:]
  43.     y_dev = [y_train[0][dev_split_index:], y_train[1][dev_split_index:]]
  44.     X_train = X_train[:dev_split_index]
  45.     y_train = [y_train[0][:dev_split_index], y_train[1][:dev_split_index]]
  46.  
  47.     permutation = np.array([i for i in range(len(X_train))])
  48.     np.random.shuffle(permutation)
  49.     X_train = [X_train[i] for i in permutation]
  50.     y_train = [[y_train[0][i] for i in permutation], [y_train[1][i] for i in permutation]]
  51.  
  52.     # Split dataset into batches
  53.     train_batches = batchify_data(X_train, y_train, batch_size)
  54.     dev_batches = batchify_data(X_dev, y_dev, batch_size)
  55.     test_batches = batchify_data(X_test, y_test, batch_size)
  56.  
  57.     # Load model
  58.     input_dimension = img_rows * img_cols
  59.     model = MLP(input_dimension)
  60.  
  61.     # Train
  62.     train_model(train_batches, dev_batches, model)
  63.  
  64.     ## Evaluate the model on test data
  65.     loss, acc = run_epoch(test_batches, model.eval(), None)
  66.     print('Test loss1: {:.6f}  accuracy1: {:.6f}  loss2: {:.6f}   accuracy2: {:.6f}'.format(loss[0], acc[0], loss[1], acc[1]))
  67.  
  68. if __name__ == '__main__':
  69.     # Specify seed for deterministic behavior, then shuffle. Do not change seed for official submissions to edx
  70.     np.random.seed(12321)  # for reproducibility
  71.     torch.manual_seed(12321)  # for reproducibility
  72.     main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement