Advertisement
Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
- import os
- import time
- import torch
- import datetime
- import numpy as np
- import torch.nn as nn
- from glob import glob
- from collections import deque
- from torchvision import models
- from ignite.engine import Events
- from torchsummary import summary
- from Code.Networks.TD import TD_CNN
- from torch.utils.data import DataLoader
- from Code.Utils.npyDataset import npyDataset
- from Code.Utils.utils import get_trainable, get_trainers_and_evaluators, create_summary_writer
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Configuration
    # ------------------------------------------------------------------
    # Dataset locations and clip shape
    TRAIN_DIR = "C:/Users/James/Dropbox/Work/Papers/EBD Echo big data/Data/classifier/train_video_3"
    TEST_DIR = "C:/Users/James/Dropbox/Work/Papers/EBD Echo big data/Data/classifier/test_video_3"
    N_CLASSES = 14  # number of output classes
    N_FRAMES = 12   # frames per video clip fed to the time-distributed model

    # Model & training hyper-parameters
    MODEL_NAME = "TD_DenseNet161"
    INNER_MODEL_TYPE = models.densenet161  # 2-D backbone constructor
    INNER_MODEL_WEIGHTS = "DenseNet161_nonorm_e005_a0.936.model"  # backbone checkpoint ("" / falsy to skip)
    MULTI_GPU = True
    BATCH_SIZE = 1
    EPOCHS = 5

    # Logging / reporting
    VERBOSE = True
    TENSORBOARD = False
    LOG_INTERVAL = 5
    LOG_DIR = f"./logs/{MODEL_NAME}"

    # Prefer the first CUDA device when one is available, else run on CPU.
    DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    print(f"Using device {DEVICE}")
- # Load data
- training_set = npyDataset(TRAIN_DIR)
- testing_set = npyDataset(TEST_DIR)
- train_loader = DataLoader(training_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=3)
- test_loader = DataLoader(testing_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=3)
- # Load model & set up training
- inner_model = INNER_MODEL_TYPE(pretrained=True)
- n_features = inner_model.classifier.in_features
- inner_model.classifier = nn.Linear(n_features, N_CLASSES)
- if INNER_MODEL_WEIGHTS:
- inner_model.load_state_dict({k.split('module.',1)[1]:v for k,v in torch.load(INNER_MODEL_WEIGHTS).items()})
- for param in inner_model.parameters():
- param.requires_grad = False
- model = TD_CNN(inner_model=inner_model, n_classes=N_CLASSES)
- model = model.to(DEVICE)
- summary(model, input_size=(N_FRAMES, 3, 299, 299))
- if MULTI_GPU:
- model = nn.DataParallel(model)
- loss = torch.nn.CrossEntropyLoss()
- optimizer = torch.optim.Adam(get_trainable(model.parameters()))
- trainer, evaluator = get_trainers_and_evaluators(model, optimizer, loss, DEVICE)
- writer = create_summary_writer(model, train_loader, LOG_DIR)
- @trainer.on(Events.STARTED)
- def initialise_custom_engine_vars(engine):
- engine.iteration_timings = deque(maxlen=100)
- engine.iteration_loss = deque(maxlen=100)
- engine.best_accuracy = 0
- @trainer.on(Events.ITERATION_COMPLETED)
- def log_training_loss(engine):
- engine.iteration_timings.append(time.time())
- engine.iteration_loss.append(engine.state.output)
- seconds_per_iteration = np.mean(np.gradient(engine.iteration_timings)) if len(engine.iteration_timings) > 1 else 0
- eta = seconds_per_iteration * (len(train_loader)-(engine.state.iteration % len(train_loader)))
- if TENSORBOARD:
- if ((engine.state.iteration - 1) % len(train_loader) + 1) % LOG_INTERVAL == 0:
- writer.add_scalar("training/loss", engine.state.output, engine.state.iteration)
- if VERBOSE:
- print(f"\rEPOCH: {engine.state.epoch:03d} | "
- f"BATCH: {engine.state.iteration % len(train_loader):03d} of {len(train_loader):03d} | "
- f"LOSS: {engine.state.output:.3f} ({np.mean(engine.iteration_loss):.3f}) | "
- f"({seconds_per_iteration:.2f} s/it; ETA {str(datetime.timedelta(seconds=int(eta)))})", end='')
- @trainer.on(Events.EPOCH_COMPLETED)
- def log_training_results(engine):
- evaluator.run(train_loader)
- metrics = evaluator.state.metrics
- acc, loss, precision = metrics['accuracy'], metrics['loss'], metrics['precision'].cpu()
- if acc > engine.best_accuracy:
- saved = True
- engine.best_accuracy = acc
- torch.save(model.state_dict(), f"./{MODEL_NAME}_e{engine.state.epoch:03d}_a{acc:.3f}.model")
- else:
- saved = False
- print(f"\rEnd of epoch {engine.state.epoch:03d}")
- print(f"TRAINING Accuracy: {acc:.3f} | Loss: {loss:.3f} {'-> SAVED' if saved else ''}")
- writer.add_scalar("training/accuracy", acc, engine.state.epoch)
- @trainer.on(Events.EPOCH_COMPLETED)
- def log_validation_results(engine):
- evaluator.run(test_loader)
- metrics = evaluator.state.metrics
- acc, loss, precision = metrics['accuracy'], metrics['loss'], metrics['precision'].cpu()
- print(f"TESTING Accuracy: {acc:.3f} | Loss: {loss:.3f}\n")
- writer.add_scalar("testing/loss", loss, engine.state.epoch)
- writer.add_scalar("testing/accuracy", acc, engine.state.epoch)
- trainer.run(train_loader, max_epochs=EPOCHS)
- writer.close()
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement