Advertisement
Sam____

Train/Test

Oct 17th, 2022
50
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.24 KB | None | 0 0
  1. import warnings
  2. import torch
  3. from torch.autograd import Variable
  4. import torchvision.transforms as transforms
  5. from torch.utils.data import DataLoader
  6. from sklearn.model_selection import train_test_split
  7. from torch.utils.data import Subset
  8. from numpy import vstack
  9. from sklearn.metrics import accuracy_score
  10. import numpy as np
  11. from matplotlib import pyplot as plt
  12.  
  13.  
  14.  
  15. from dataloader import RandomFmriDataset
  16. from cnn_model import CNN_model
  17.  
  18. def fxn():
  19.     warnings.warn("deprecated", DeprecationWarning)
  20.  
  21. with warnings.catch_warnings():
  22.     warnings.simplefilter("ignore")
  23.     fxn()
  24.  
  25. if __name__ == '__main__':
  26.     learning_rate = 1e-3
  27.     batch_size = 6
  28.     num_epochs = 20
  29.     compose = transforms.Compose([
  30.         transforms.ToTensor(),
  31.         transforms.RandomHorizontalFlip(),
  32.         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  33.     ])
  34.     dataset = RandomFmriDataset(transform=compose)
  35.     labels = dataset.labels
  36.     model = CNN_model()
  37.  
  38.     device = torch.device("cuda" if torch.cuda.is_available()
  39.                           else "cpu")
  40.  
  41.     label_array = []
  42.     for i in labels:
  43.         for j in i:
  44.             label_array.append(j)
  45.  
  46.     data_and_labels = []
  47.     for g in range(len(dataset)):
  48.         data_and_labels.append([dataset[g], label_array[g]])
  49.  
  50.  
  51.     def train_val_dataset(dataset, val_split=0.25):
  52.         train_idx, val_idx = train_test_split(list(range(len(dataset))), test_size=val_split)
  53.         datasets = {'train': Subset(dataset, train_idx), 'val': Subset(dataset, val_idx)}
  54.         return datasets
  55.  
  56.     datasets = train_val_dataset(data_and_labels)
  57.     dataloaders = {x: DataLoader(datasets[x], 6, shuffle=True, num_workers=4) for x in ['train', 'val']}
  58.     x_train, y_train = next(iter(dataloaders['train']))
  59.     x_val, y_val = next(iter(dataloaders['val']))
  60.  
  61.     def train_model(train_dl, model):
  62.         criterion = torch.nn.BCELoss(size_average=True)
  63.         optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
  64.         loss_array = []
  65.         loss_values = []
  66.         acc_values = []
  67.         for epoch in range(num_epochs):
  68.             for ind, (data, img_label) in enumerate(train_dl):
  69.                 inputs = data.permute(0, 4, 1, 2, 3),
  70.  
  71.                 inputsV, labelsV = Variable(inputs[0]), Variable(img_label)
  72.                 inputsV = inputsV.to(torch.float32)
  73.                 labelsV = labelsV.to(torch.float32)
  74.                 y_pred = model(inputsV)
  75.  
  76.                 loss = criterion(y_pred.squeeze(), labelsV)
  77.                 running_loss = abs(loss.item())
  78.                 loss_array.append(running_loss)
  79.                 optimizer.zero_grad()
  80.                 loss.backward()
  81.                 optimizer.step()
  82.             avg_loss = np.average(loss_array)
  83.             loss_values.append(avg_loss)
  84.             acc = evaluate_model(dataloaders['val'], model)
  85.             acc = acc * 100
  86.             acc_values.append(acc)
  87.             print(f'Epoch [{epoch + 1}/{num_epochs}], \n Loss: {abs(avg_loss):.4f}, Accuracy: {acc:.4f}')
  88.         plot(acc_values, loss_values)
  89.  
  90.     def evaluate_model(test_dl, model):
  91.         predictions, actuals = list(), list()
  92.         accuracy_array = []
  93.         for ind, (data, img_label) in enumerate(test_dl):
  94.             inputs = data.permute(0, 4, 1, 2, 3),
  95.  
  96.             inputsV, labelsV = Variable(inputs[0]), Variable(img_label)
  97.             inputsV = inputsV.to(torch.float32)
  98.             labelsV = labelsV.to(torch.float32)
  99.  
  100.             yhat = model(inputsV)
  101.  
  102.             actual = labelsV.numpy()
  103.             actual = actual.reshape((len(actual), 1))
  104.  
  105.             yhat = yhat.detach().numpy()
  106.             yhat = yhat.round()
  107.  
  108.             predictions.append(yhat)
  109.             actuals.append(actual)
  110.             predictions, actuals = vstack(predictions), vstack(actuals)
  111.  
  112.             acc = accuracy_score(actuals, predictions)
  113.             accuracy_array.append(acc)
  114.  
  115.             return acc
  116.  
  117.  
  118.     def plot(x, y):
  119.         plt.figure(figsize=(16, 5))
  120.         plt.xlabel('EPOCHS')
  121.         plt.ylabel('LOSS/ACC')
  122.  
  123.         plt.plot(x, 'r', label='ACCURACY')
  124.         plt.plot(y, 'b', label='LOSS')
  125.         plt.legend()
  126.         plt.show()
  127.  
  128.     train_model(dataloaders['train'], model)
  129.     evaluate_model(dataloaders['val'], model)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement