Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import argparse
- import os
- import shutil
- import time
- import torch
- import torch.nn as nn
- import torch.nn.parallel
- import torch.backends.cudnn as cudnn
- import torch.optim
- import torch.utils.data
- import torchvision.transforms as transforms
- import torchvision.datasets as datasets
- import torchvision.models as models
- from visdom import Visdom
- import numpy as np
# Every lowercase, callable factory exposed by torchvision.models is a
# selectable architecture name.
model_names = sorted(
    name for name in models.__dict__
    if name.islower() and not name.startswith("__") and callable(models.__dict__[name])
)

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR', help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' + ' | '.join(model_names) +
                         ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=40, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',
                    help='mini-batch size (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int, metavar='N',
                    help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--env', default='main', type=str,
                    help='name of experiment in Visdom and foldername for weights(default: main)')

best_prec1 = 0  # best validation top-1 accuracy seen so far

# NOTE(review): connects to a Visdom server at import time — presumably one
# is expected to already be running; verify before deploying.
viz = Visdom()

# Shared line-plot options; 'title' is mutated later for the loss window.
opts = {
    'fillarea': True,
    'legend': ['Train', 'Val'],
    'width': 400,
    'height': 400,
    'xlabel': 'Time',
    'ylabel': 'Value',
    'title': 'Train/Val Accuracy Crop',
}
def main():
    """Entry point: build a torchvision model, optionally resume, train/evaluate.

    Fine-tunes only the final fully-connected layer on a 3-class dataset
    (the rest of the network is frozen) and logs accuracy/loss curves to
    Visdom. Checkpoints are written under the folder named by ``--env``.
    """
    global args, best_prec1
    args = parser.parse_args()

    # Folder for checkpoints; also used as the Visdom environment name.
    if not os.path.isdir(args.env):
        os.makedirs(args.env)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Freeze the whole backbone; only the new classifier head will train.
    for param in model.parameters():
        param.requires_grad = False

    # Replace the last fully-connected layer.
    # BUGFIX: when the model is wrapped in DataParallel, assigning to
    # `model.fc` only attaches a new, *unused* submodule to the wrapper —
    # the forward pass keeps using the original 1000-way classifier. The
    # replacement must go on the wrapped module itself.
    # Parameters of newly constructed modules have requires_grad=True by default.
    base_model = model.module if isinstance(model, torch.nn.DataParallel) else model
    # 512 is the feature width of resnet18/34 — NOTE(review): confirm for
    # other architectures (alexnet/vgg use `.classifier`, not `.fc`).
    base_model.fc = torch.nn.Linear(512, 3)

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(traindir, transforms.Compose([
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            # `transforms.Scale` was renamed `Resize` and later removed.
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    # Optimize only the classifier head (the real one on the wrapped module,
    # not the dead attribute on the DataParallel wrapper).
    optimizer = torch.optim.SGD(base_model.fc.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    # Seed the two Visdom windows (accuracy and loss) with a dummy segment
    # so that later 'append' updates have something to extend.
    win_prec1 = viz.line(
        Y=np.column_stack(([0, 0.001], [0, 0.001])),
        X=np.column_stack(([-1, 0], [-1, 0])),
        opts=opts,
        env=args.env
    )
    opts['title'] = 'Train/Val Loss Crop'
    win_loss = viz.line(
        Y=np.column_stack(([0, 0.001], [0, 0.001])),
        X=np.column_stack(([-1, 0], [-1, 0])),
        opts=opts,
        env=args.env
    )

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train_top1, train_loss = train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        val_top1, val_loss = validate(val_loader, model, criterion)

        # Append this epoch's points to both Visdom curves.
        viz.line(
            X=np.column_stack(([epoch], [epoch])),
            Y=np.column_stack(([train_top1], [val_top1])),
            win=win_prec1,
            update='append',
            env=args.env
        )
        viz.line(
            X=np.column_stack(([epoch], [epoch])),
            Y=np.column_stack(([train_loss], [val_loss])),
            win=win_loss,
            update='append',
            env=args.env
        )

        # remember best prec@1 and save checkpoint
        is_best = val_top1 > best_prec1
        best_prec1 = max(val_top1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
        }, is_best, folder=args.env)
def train(train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch.

    Returns:
        (top1.avg, losses.avg): average top-1 accuracy (%) and average loss
        over the epoch.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # BUGFIX: `async=True` is a SyntaxError on Python >= 3.7 (`async`
        # became a keyword); `non_blocking=True` is the replacement.
        target = target.cuda(non_blocking=True)
        input_var = torch.autograd.Variable(input)
        target_var = torch.autograd.Variable(target)

        # compute output
        output = model(input_var)
        loss = criterion(output, target_var)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        # `.item()` replaces the old `tensor[0]` indexing, which fails on
        # 0-dim tensors in modern PyTorch.
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1, top5=top5))
    return top1.avg, losses.avg
def validate(val_loader, model, criterion):
    """Evaluate the model over the validation set.

    Returns:
        (top1.avg, losses.avg): average top-1 accuracy (%) and average loss.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    # `torch.no_grad()` replaces the removed `volatile=True` Variable flag:
    # no autograd graph is recorded during evaluation.
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            # BUGFIX: `async=True` is a SyntaxError on Python >= 3.7;
            # `non_blocking=True` is the replacement.
            target = target.cuda(non_blocking=True)

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            # `.item()` replaces old `tensor[0]` indexing (fails on 0-dim).
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            top5.update(prec5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time, loss=losses,
                          top1=top1, top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))
    return top1.avg, losses.avg
def save_checkpoint(state, is_best, folder="", filename='checkpoint.pth.tar'):
    """Serialize *state* to ``folder/filename``; if *is_best*, also copy it
    to ``folder/model_best.pth.tar``."""
    ckpt_path = os.path.join(folder, filename)
    torch.save(state, ckpt_path)
    if is_best:
        shutil.copyfile(ckpt_path, os.path.join(folder, 'model_best.pth.tar'))
class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Zero out the last value, running sum, sample count and mean.
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        # Record `val`, weighted as if it had been observed `n` times.
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 10 epochs.

    (The previous docstring said "every 30 epochs", contradicting the
    `epoch // 10` decay schedule actually implemented below.)

    Reads the initial LR from the module-level `args.lr`.
    """
    lr = args.lr * (0.1 ** (epoch // 10))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: (batch, num_classes) tensor of class scores.
        target: (batch,) tensor of ground-truth class indices.
        topk: iterable of k values; each k is clamped to num_classes so that
            e.g. prec@5 on a 3-class head cannot crash `topk`.

    Returns:
        List of tensors, one per k, each the percentage (0-100) of samples
        whose target appears among the top-k predictions.
    """
    num_classes = output.size(1)
    # Clamp: requesting more ranks than classes would make `topk` raise.
    maxk = min(max(topk), num_classes)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape (not view): the slice of the transposed tensor may be
        # non-contiguous, which `.view` rejects.
        correct_k = correct[:min(k, maxk)].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
# Script entry point: only run training when executed directly.
if __name__ == '__main__':
    main()
Add Comment
Please, Sign In to add comment