Advertisement
Guest User

Untitled

a guest
Apr 29th, 2019
173
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.86 KB | None | 0 0
  1. import os
  2. import time
  3. import torch
  4. import datetime
  5. import numpy as np
  6. import torch.nn as nn
  7.  
  8. from glob import glob
  9. from collections import deque
  10. from torchvision import models
  11. from ignite.engine import Events
  12. from torchsummary import summary
  13. from Code.Networks.TD import TD_CNN
  14. from torch.utils.data import DataLoader
  15. from Code.Utils.npyDataset import npyDataset
  16. from Code.Utils.utils import get_trainable, get_trainers_and_evaluators, create_summary_writer
  17.  
  18.  
  19. if __name__ == "__main__":
  20.     # Dataset
  21.     TRAIN_DIR = "C:/Users/James/Dropbox/Work/Papers/EBD Echo big data/Data/classifier/train_video_3"
  22.     TEST_DIR = "C:/Users/James/Dropbox/Work/Papers/EBD Echo big data/Data/classifier/test_video_3"
  23.     N_CLASSES = 14
  24.     N_FRAMES = 12
  25.  
  26.     # Model & training
  27.     MODEL_NAME = "TD_DenseNet161"
  28.     INNER_MODEL_TYPE = models.densenet161
  29.     INNER_MODEL_WEIGHTS = "DenseNet161_nonorm_e005_a0.936.model"
  30.     DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  31.     MULTI_GPU = True
  32.     BATCH_SIZE = 1
  33.     EPOCHS = 5
  34.     VERBOSE = True
  35.     TENSORBOARD = False
  36.     LOG_INTERVAL = 5
  37.     LOG_DIR = f"./logs/{MODEL_NAME}"
  38.  
  39.     # Begin
  40.     print(f"Using device {DEVICE}")
  41.  
  42.     # Load data
  43.     training_set = npyDataset(TRAIN_DIR)
  44.     testing_set = npyDataset(TEST_DIR)
  45.     train_loader = DataLoader(training_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=3)
  46.     test_loader = DataLoader(testing_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=3)
  47.  
  48.     # Load model & set up training
  49.     inner_model = INNER_MODEL_TYPE(pretrained=True)
  50.     n_features = inner_model.classifier.in_features
  51.     inner_model.classifier = nn.Linear(n_features, N_CLASSES)
  52.     if INNER_MODEL_WEIGHTS:
  53.         inner_model.load_state_dict({k.split('module.',1)[1]:v for k,v in torch.load(INNER_MODEL_WEIGHTS).items()})
  54.     for param in inner_model.parameters():
  55.         param.requires_grad = False
  56.     model = TD_CNN(inner_model=inner_model, n_classes=N_CLASSES)
  57.    
  58.     model = model.to(DEVICE)
  59.     summary(model, input_size=(N_FRAMES, 3, 299, 299))
  60.     if MULTI_GPU:
  61.         model = nn.DataParallel(model)
  62.     loss = torch.nn.CrossEntropyLoss()
  63.     optimizer = torch.optim.Adam(get_trainable(model.parameters()))
  64.     trainer, evaluator = get_trainers_and_evaluators(model, optimizer, loss, DEVICE)
  65.     writer = create_summary_writer(model, train_loader, LOG_DIR)
  66.  
  67.     @trainer.on(Events.STARTED)
  68.     def initialise_custom_engine_vars(engine):
  69.         engine.iteration_timings = deque(maxlen=100)
  70.         engine.iteration_loss = deque(maxlen=100)
  71.         engine.best_accuracy = 0
  72.  
  73.     @trainer.on(Events.ITERATION_COMPLETED)
  74.     def log_training_loss(engine):
  75.         engine.iteration_timings.append(time.time())
  76.         engine.iteration_loss.append(engine.state.output)
  77.         seconds_per_iteration = np.mean(np.gradient(engine.iteration_timings)) if len(engine.iteration_timings) > 1 else 0
  78.         eta = seconds_per_iteration * (len(train_loader)-(engine.state.iteration % len(train_loader)))
  79.         if TENSORBOARD:
  80.             if ((engine.state.iteration - 1) % len(train_loader) + 1) % LOG_INTERVAL == 0:
  81.                 writer.add_scalar("training/loss", engine.state.output, engine.state.iteration)
  82.         if VERBOSE:
  83.             print(f"\rEPOCH: {engine.state.epoch:03d} | "
  84.                   f"BATCH: {engine.state.iteration % len(train_loader):03d} of {len(train_loader):03d} | "
  85.                   f"LOSS: {engine.state.output:.3f} ({np.mean(engine.iteration_loss):.3f}) | "
  86.                   f"({seconds_per_iteration:.2f} s/it; ETA {str(datetime.timedelta(seconds=int(eta)))})", end='')
  87.  
  88.     @trainer.on(Events.EPOCH_COMPLETED)
  89.     def log_training_results(engine):
  90.         evaluator.run(train_loader)
  91.         metrics = evaluator.state.metrics
  92.         acc, loss, precision = metrics['accuracy'], metrics['loss'], metrics['precision'].cpu()
  93.         if acc > engine.best_accuracy:
  94.             saved = True
  95.             engine.best_accuracy = acc
  96.             torch.save(model.state_dict(), f"./{MODEL_NAME}_e{engine.state.epoch:03d}_a{acc:.3f}.model")
  97.         else:
  98.             saved = False
  99.  
  100.         print(f"\rEnd of epoch {engine.state.epoch:03d}")
  101.         print(f"TRAINING Accuracy: {acc:.3f} | Loss: {loss:.3f} {'-> SAVED' if saved else ''}")
  102.         writer.add_scalar("training/accuracy", acc, engine.state.epoch)
  103.  
  104.  
  105.     @trainer.on(Events.EPOCH_COMPLETED)
  106.     def log_validation_results(engine):
  107.         evaluator.run(test_loader)
  108.         metrics = evaluator.state.metrics
  109.         acc, loss, precision = metrics['accuracy'], metrics['loss'], metrics['precision'].cpu()
  110.         print(f"TESTING  Accuracy: {acc:.3f} | Loss: {loss:.3f}\n")
  111.         writer.add_scalar("testing/loss", loss, engine.state.epoch)
  112.         writer.add_scalar("testing/accuracy", acc, engine.state.epoch)
  113.  
  114.  
  115.     trainer.run(train_loader, max_epochs=EPOCHS)
  116.     writer.close()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement