Guest User

Untitled

a guest
Mar 22nd, 2017
836
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 12.22 KB | None | 0 0
  1. import argparse
  2. import os
  3. import shutil
  4. import time
  5.  
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.parallel
  9. import torch.backends.cudnn as cudnn
  10. import torch.optim
  11. import torch.utils.data
  12. import torchvision.transforms as transforms
  13. import torchvision.datasets as datasets
  14. import torchvision.models as models
  15. from visdom import Visdom
  16. import numpy as np
  17.  
  18. model_names = sorted(name for name in models.__dict__
  19.     if name.islower() and not name.startswith("__")
  20.     and callable(models.__dict__[name]))
  21.  
  22.  
  23. parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
  24. parser.add_argument('data', metavar='DIR',
  25.                     help='path to dataset')
  26. parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
  27.                     choices=model_names,
  28.                     help='model architecture: ' +
  29.                         ' | '.join(model_names) +
  30.                         ' (default: resnet18)')
  31. parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
  32.                     help='number of data loading workers (default: 4)')
  33. parser.add_argument('--epochs', default=40, type=int, metavar='N',
  34.                     help='number of total epochs to run')
  35. parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
  36.                     help='manual epoch number (useful on restarts)')
  37. parser.add_argument('-b', '--batch-size', default=32, type=int,
  38.                     metavar='N', help='mini-batch size (default: 256)')
  39. parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
  40.                     metavar='LR', help='initial learning rate')
  41. parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
  42.                     help='momentum')
  43. parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
  44.                     metavar='W', help='weight decay (default: 1e-4)')
  45. parser.add_argument('--print-freq', '-p', default=10, type=int,
  46.                     metavar='N', help='print frequency (default: 10)')
  47. parser.add_argument('--resume', default='', type=str, metavar='PATH',
  48.                     help='path to latest checkpoint (default: none)')
  49. parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
  50.                     help='evaluate model on validation set')
  51. parser.add_argument('--pretrained', dest='pretrained', action='store_true',
  52.                     help='use pre-trained model')
  53. parser.add_argument('--env', default='main', type=str,
  54.                     help='name of experiment in Visdom and foldername for weights(default: main)')
  55.  
  56. best_prec1 = 0
  57. viz = Visdom()
  58.  
  59. opts = dict(
  60.         fillarea     =True,
  61.         legend       =['Train', 'Val'],
  62.         width        =400,
  63.         height       =400,
  64.         xlabel       ='Time',
  65.         ylabel       ='Value',
  66.         title        ='Train/Val Accuracy Crop')
  67.  
  68. def main():
  69.     global args, best_prec1
  70.     args = parser.parse_args()
  71.     if not os.path.isdir(args.env):
  72.         os.makedirs(args.env)
  73.  
  74.     # create model
  75.     if args.pretrained:
  76.         print("=> using pre-trained model '{}'".format(args.arch))
  77.         model = models.__dict__[args.arch](pretrained=True)
  78.     else:
  79.         print("=> creating model '{}'".format(args.arch))
  80.         model = models.__dict__[args.arch]()
  81.  
  82.     if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
  83.         model.features = torch.nn.DataParallel(model.features)
  84.         model.cuda()
  85.     else:
  86.         model = torch.nn.DataParallel(model).cuda()
  87.    
  88.     # optionally resume from a checkpoint
  89.     if args.resume:
  90.         if os.path.isfile(args.resume):
  91.             print("=> loading checkpoint '{}'".format(args.resume))
  92.             checkpoint = torch.load(args.resume)
  93.             args.start_epoch = checkpoint['epoch']
  94.             best_prec1 = checkpoint['best_prec1']
  95.             model.load_state_dict(checkpoint['state_dict'])
  96.             print("=> loaded checkpoint '{}' (epoch {})"
  97.                   .format(args.resume, checkpoint['epoch']))
  98.         else:
  99.             print("=> no checkpoint found at '{}'".format(args.resume))
  100.    
  101.     for param in model.parameters():
  102.       param.requires_grad = False
  103.    
  104.     # Replace the last fully-connected layer
  105.     # Parameters of newly constructed modules have requires_grad=True by default
  106.     model.fc = torch.nn.Linear(512, 3)
  107.  
  108.  
  109.     cudnn.benchmark = True
  110.  
  111.     # Data loading code
  112.     traindir = os.path.join(args.data, 'train')
  113.     valdir = os.path.join(args.data, 'val')
  114.     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
  115.                                      std=[0.229, 0.224, 0.225])
  116.  
  117.     train_loader = torch.utils.data.DataLoader(
  118.         datasets.ImageFolder(traindir, transforms.Compose([
  119.             transforms.RandomCrop(224),
  120.             transforms.RandomHorizontalFlip(),
  121.             transforms.ToTensor(),
  122.             normalize,
  123.         ])),
  124.         batch_size=args.batch_size, shuffle=True,
  125.         num_workers=args.workers, pin_memory=True)
  126.  
  127.     val_loader = torch.utils.data.DataLoader(
  128.         datasets.ImageFolder(valdir, transforms.Compose([
  129.             transforms.Scale(256),
  130.             transforms.CenterCrop(224),
  131.             transforms.ToTensor(),
  132.             normalize,
  133.         ])),
  134.         batch_size=args.batch_size, shuffle=False,
  135.         num_workers=args.workers, pin_memory=True)
  136.  
  137.     # define loss function (criterion) and pptimizer
  138.     criterion = nn.CrossEntropyLoss().cuda()
  139.    
  140.    
  141.     # Optimize only the classifier
  142.     optimizer = torch.optim.SGD(model.fc.parameters(), args.lr,
  143.                             momentum=args.momentum,
  144.                             weight_decay=args.weight_decay)
  145.  
  146.     if args.evaluate:
  147.         validate(val_loader, model, criterion)
  148.         return
  149.    
  150.  
  151.     # Add plot of loss for validation
  152.     win_prec1 = viz.line(
  153.                             Y    = np.column_stack(([0,0.001], [0,0.001])),
  154.                             X    = np.column_stack(([-1,0],[-1,0])),
  155.                             opts = opts,
  156.                             env  = args.env
  157.                         )
  158.     opts['title'] = 'Train/Val Loss Crop'
  159.     win_loss = viz.line(
  160.                             Y    = np.column_stack(([0,0.001], [0,0.001])),
  161.                             X    = np.column_stack(([-1,0],[-1,0])),
  162.                             opts = opts,
  163.                             env  = args.env
  164.                         )
  165.  
  166.     for epoch in range(args.start_epoch, args.epochs):
  167.         adjust_learning_rate(optimizer, epoch)
  168.  
  169.         # train for one epoch
  170.         train_top1, train_loss = train(train_loader, model, criterion, optimizer, epoch)
  171.  
  172.         # evaluate on validation set
  173.         val_top1, val_loss = validate(val_loader, model, criterion)
  174.    
  175.         viz.line(
  176.             X      = np.column_stack(([epoch],[epoch])),
  177.             Y      = np.column_stack(([train_top1], [val_top1])),
  178.             win    = win_prec1,
  179.             update = 'append',
  180.             env  = args.env
  181.         )
  182.  
  183.         viz.line(
  184.             X      = np.column_stack(([epoch],[epoch])),
  185.             Y      = np.column_stack(([train_loss], [val_loss])),
  186.             win    = win_loss,
  187.             update = 'append',
  188.                             env  = args.env
  189.         )
  190.  
  191.    
  192.         # remember best prec@1 and save checkpoint
  193.         is_best = val_top1 > best_prec1
  194.         best_prec1 = max(val_top1, best_prec1)
  195.         save_checkpoint({
  196.             'epoch': epoch + 1,
  197.             'arch': args.arch,
  198.             'state_dict': model.state_dict(),
  199.             'best_prec1': best_prec1,
  200.         }, is_best, folder = args.env)
  201.  
  202.  
  203. def train(train_loader, model, criterion, optimizer, epoch):
  204.     batch_time = AverageMeter()
  205.     data_time = AverageMeter()
  206.     losses = AverageMeter()
  207.     top1 = AverageMeter()
  208.     top5 = AverageMeter()
  209.  
  210.     # switch to train mode
  211.     model.train()
  212.  
  213.     end = time.time()
  214.     for i, (input, target) in enumerate(train_loader):
  215.         # measure data loading time
  216.         data_time.update(time.time() - end)
  217.  
  218.         target = target.cuda(async=True)
  219.         input_var = torch.autograd.Variable(input)
  220.         target_var = torch.autograd.Variable(target)
  221.  
  222.         # compute output
  223.         output = model(input_var)
  224.         loss = criterion(output, target_var)
  225.  
  226.         # measure accuracy and record loss
  227.         prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
  228.         losses.update(loss.data[0], input.size(0))
  229.         top1.update(prec1[0], input.size(0))
  230.         top5.update(prec5[0], input.size(0))
  231.  
  232.         # compute gradient and do SGD step
  233.         optimizer.zero_grad()
  234.         loss.backward()
  235.         optimizer.step()
  236.  
  237.         # measure elapsed time
  238.         batch_time.update(time.time() - end)
  239.         end = time.time()
  240.  
  241.         if i % args.print_freq == 0:
  242.             print('Epoch: [{0}][{1}/{2}]\t'
  243.                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
  244.                   'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
  245.                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
  246.                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
  247.                   'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
  248.                    epoch, i, len(train_loader), batch_time=batch_time,
  249.                    data_time=data_time, loss=losses, top1=top1, top5=top5))
  250.  
  251.     return top1.avg, losses.avg
  252.  
  253.  
  254. def validate(val_loader, model, criterion):
  255.     batch_time = AverageMeter()
  256.     losses = AverageMeter()
  257.     top1 = AverageMeter()
  258.     top5 = AverageMeter()
  259.  
  260.     # switch to evaluate mode
  261.     model.eval()
  262.  
  263.     end = time.time()
  264.     for i, (input, target) in enumerate(val_loader):
  265.         target = target.cuda(async=True)
  266.         input_var = torch.autograd.Variable(input, volatile=True)
  267.         target_var = torch.autograd.Variable(target, volatile=True)
  268.  
  269.         # compute output
  270.         output = model(input_var)
  271.         loss   = criterion(output, target_var)
  272.  
  273.         # measure accuracy and record loss
  274.         prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
  275.         losses.update(loss.data[0], input.size(0))
  276.         top1.update(prec1[0], input.size(0))
  277.         top5.update(prec5[0], input.size(0))
  278.  
  279.         # measure elapsed time
  280.         batch_time.update(time.time() - end)
  281.         end = time.time()
  282.        
  283.         if i % args.print_freq == 0:
  284.             print('Test: [{0}/{1}]\t'
  285.                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
  286.                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
  287.                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
  288.                   'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
  289.                    i, len(val_loader), batch_time=batch_time, loss=losses,
  290.                    top1=top1, top5=top5))
  291.  
  292.     print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
  293.           .format(top1=top1, top5=top5))
  294.    
  295.    
  296.     return top1.avg, losses.avg
  297.  
  298.  
  299. def save_checkpoint(state, is_best, folder = "", filename='checkpoint.pth.tar'):
  300.     torch.save(state, os.path.join(folder, filename))
  301.     if is_best:
  302.         shutil.copyfile(os.path.join(folder, filename), os.path.join(folder, 'model_best.pth.tar'))
  303.  
  304.  
  305. class AverageMeter(object):
  306.     """Computes and stores the average and current value"""
  307.     def __init__(self):
  308.         self.reset()
  309.  
  310.     def reset(self):
  311.         self.val = 0
  312.         self.avg = 0
  313.         self.sum = 0
  314.         self.count = 0
  315.  
  316.     def update(self, val, n=1):
  317.         self.val = val
  318.         self.sum += val * n
  319.         self.count += n
  320.         self.avg = self.sum / self.count
  321.  
  322.  
  323. def adjust_learning_rate(optimizer, epoch):
  324.     """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
  325.     lr = args.lr * (0.1 ** (epoch // 10))
  326.     for param_group in optimizer.param_groups:
  327.         param_group['lr'] = lr
  328.  
  329.  
  330. def accuracy(output, target, topk=(1,)):
  331.     """Computes the precision@k for the specified values of k"""
  332.     maxk = max(topk)
  333.     batch_size = target.size(0)
  334.  
  335.     _, pred = output.topk(maxk, 1, True, True)
  336.     pred = pred.t()
  337.     correct = pred.eq(target.view(1, -1).expand_as(pred))
  338.  
  339.     res = []
  340.     for k in topk:
  341.         correct_k = correct[:k].view(-1).float().sum(0)
  342.         res.append(correct_k.mul_(100.0 / batch_size))
  343.     return res
  344.  
  345.  
if __name__ == '__main__':
    # Run training/evaluation only when executed as a script, not when imported.
    main()
Add Comment
Please, Sign In to add comment