#!/usr/bin/env python3
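# ImageNet classification training, adapted from the PyTorch examples/imagenet
# reference script and extended with an NVIDIA DALI file-reader pipeline
# (--dali-loader), a synthetic dataset mode (--fake), and an option to skip
# validation (--skip-val) for throughput benchmarking.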
import argparse
import os
import random
import shutil
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

import nvidia.dali.ops as ops
import nvidia.dali.types as types
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIClassificationIterator

model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=12, type=int, metavar='N',
                    help='number of data loading workers (default: 12)')
parser.add_argument('--epochs', default=1, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=1024, type=int,
                    metavar='N',
                    help='mini-batch size (default: 1024), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--world-size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int,
                    help='node rank for distributed training')
parser.add_argument('--dist-url', default='tcp://127.0.0.1:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training.')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--multiprocessing-distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')
parser.add_argument('--fake', action='store_true', help='Use fake dataset')
parser.add_argument('--skip-val', action='store_true', help='skip validation')
parser.add_argument('--dali-loader', action='store_true',
                    help='Use DALI dataset loader')

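# best_acc1 tracks the best validation top-1 accuracy across epochs. The
# fake_* constants approximate the scale of ImageNet-1k (~1.28M training
# images, 50k validation images, 1000 classes) so --fake runs exercise
# realistically sized epochs.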
best_acc1 = 0
fake_train_samples = 1300000
fake_val_samples = 50000
fake_num_classes = 1000
fake_image_size = (3, 224, 224)


def main():
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node,
                 args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args)


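# mp.spawn invokes main_worker once per local GPU, passing the process index
# (0..ngpus_per_node-1) as the first argument; that index doubles as the GPU id.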
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu  # gpu id of current process

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(
                (args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

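    # Input resolution is fixed (224x224), so cuDNN's autotuner can safely
    # benchmark and cache the fastest convolution algorithms.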
    cudnn.benchmark = True

    # Data loading code
    start = time.time()
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = None
    if args.fake:
        train_dataset = datasets.FakeData(
            fake_train_samples, fake_image_size, fake_num_classes,
            transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ]))
    else:
        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

    if args.distributed:
        # sampler will load a subset of the dataset
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = None
    if args.dali_loader:
        pp = FileReadPipeline(
            args.batch_size, args.workers, args.gpu, traindir)
        pp.build()
        # size is the number of samples this GPU sees per epoch; world_size
        # was already adjusted to the total process count. Note this assumes
        # distributed mode: with the default world_size of -1 the division
        # would yield a negative size.
        train_loader = ClassificationIterator(
            pp, size=pp.epoch_size("Reader") // args.world_size)
    else:
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=(
                train_sampler is None),
            num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    if not args.skip_val:
        val_dataset = None
        if args.fake:
            val_dataset = datasets.FakeData(
                fake_val_samples, fake_image_size, fake_num_classes,
                transforms.Compose([
                    transforms.ToTensor(),
                    normalize,
                ]))
        else:
            val_dataset = datasets.ImageFolder(
                valdir,
                transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    normalize,
                ]))

        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)

        if args.evaluate:
            validate(val_loader, model, criterion, args)
            return

    print('Dataloader prepare done, cost {:.3f}'.format(time.time() - start))

    epoch_time = AverageMeter('Time', ':6.3f')
    end = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        acc1 = 0
        if not args.skip_val:
            # evaluate on validation set
            acc1 = validate(val_loader, model, criterion, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                    and args.rank % ngpus_per_node == 0):
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best)

        epoch_time.update(time.time() - end)
        end = time.time()
        print('Epoch: [{}] done {}'.format(epoch, epoch_time))


def train(train_loader, model, criterion, optimizer, epoch, args):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
        target = target.cuda(args.gpu, non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


def validate(val_loader, model, criterion, args):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return top1.avg


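# save_checkpoint writes to a fixed filename in the current working directory;
# in multi-node runs on a shared filesystem, the per-node savers would
# overwrite each other's files.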
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


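# ProgressMeter prints one tab-separated status line per call to display(),
# e.g. (illustrative values): Epoch: [0][  10/1251]  Time 0.123 ( 0.456)  ...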
class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


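# Step decay: with the default --lr 0.1 the learning rate is 0.1 for epochs
# 0-29, 0.01 for epochs 30-59, 0.001 for epochs 60-89, and so on.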
def adjust_learning_rate(optimizer, epoch, args):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


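# A sample counts as correct for top-k if its target is among the k largest
# logits; e.g. if 3 of 4 samples in a batch have the target in the top 5,
# accuracy(output, target, topk=(5,)) returns [tensor([75.])].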
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape rather than view: correct[:k] is non-contiguous after t()
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


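# GPU-side DALI pipeline implementing ResNet-style scale augmentation: decode,
# resize the shorter side to a random length in [256, 480], then take a random
# 227x227 crop with per-channel mean subtraction (std of 1, i.e. no scaling).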
class CommonPipeline(Pipeline):
    def __init__(self, batch_size, num_threads, device_id):
        super(CommonPipeline, self).__init__(
            batch_size, num_threads, device_id)

        self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
        self.resize = ops.Resize(device="gpu",
                                 image_type=types.RGB,
                                 interp_type=types.INTERP_LINEAR)
        self.cmn = ops.CropMirrorNormalize(device="gpu",
                                           output_dtype=types.FLOAT,
                                           crop=(227, 227),
                                           image_type=types.RGB,
                                           mean=[128., 128., 128.],
                                           std=[1., 1., 1.])
        self.uniform = ops.Uniform(range=(0.0, 1.0))
        self.resize_rng = ops.Uniform(range=(256, 480))

    def base_define_graph(self, inputs, labels):
        images = self.decode(inputs)
        images = self.resize(images, resize_shorter=self.resize_rng())
        output = self.cmn(images, crop_pos_x=self.uniform(),
                          crop_pos_y=self.uniform())
        return (output, labels)


class FileReadPipeline(CommonPipeline):
    def __init__(self, batch_size, num_threads, device_id, image_dir):
        super(FileReadPipeline, self).__init__(
            batch_size, num_threads, device_id)
        # FileReader expects an ImageFolder-style layout: one subdirectory per
        # class under file_root, with labels assigned per subdirectory
        self.input = ops.FileReader(file_root=image_dir)

    def define_graph(self):
        images, labels = self.input(name="Reader")
        return self.base_define_graph(images, labels)


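# Adapts DALI's batch format (a list with one dict per pipeline, keyed by
# 'data' and 'label') to the (images, target) tuples the training loop expects.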
class ClassificationIterator(DALIClassificationIterator):
    def __init__(self,
                 pipelines,
                 size,
                 auto_reset=False,
                 fill_last_batch=True,
                 dynamic_shape=False,
                 last_batch_padded=False):
        self._first = True
        super().__init__(pipelines,
                         size,
                         auto_reset=auto_reset,
                         fill_last_batch=fill_last_batch,
                         dynamic_shape=dynamic_shape,
                         last_batch_padded=last_batch_padded)

    def __next__(self):
        # DALIGenericIterator.__init__ calls next(self) once to prefetch its
        # _first_batch, and that call must return the raw DALI format; the
        # _first flag skips the tuple unpacking exactly once for that call
        if self._first:
            self._first = False
            return super().__next__()
        data = super().__next__()
        return (data[0]['data'], data[0]['label'].squeeze().long())

    def __len__(self):
        # _size and batch_size are DALIGenericIterator members
        return self._size // self.batch_size


if __name__ == '__main__':
    start = time.time()
    main()
    print('All done, cost {:.3f}'.format(time.time() - start))
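
# Example invocations (assuming the script is saved as main.py; /data/imagenet
# is a placeholder for a dataset directory with train/ and val/ subfolders):
#   python main.py /data/imagenet --gpu 0                    # single GPU
#   python main.py /data/imagenet                            # DataParallel, all GPUs
#   python main.py /data/imagenet --multiprocessing-distributed \
#       --world-size 1 --rank 0                              # one process per GPU
#   python main.py /data/imagenet --dali-loader --multiprocessing-distributed \
#       --world-size 1 --rank 0                              # DALI input pipeline
#   python main.py /data/imagenet --fake --skip-val          # synthetic-data benchmark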