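# Throughput benchmark: VGG16 on synthetic data, wrapped in apex
# DistributedDataParallel and Amp (kept at FP32 via opt_level 'O0'),
# with NVTX ranges marking each training phase for profiling.
# Presumably launched with one process per GPU, e.g. (script name is a
# placeholder):
#   python -m torch.distributed.launch --nproc_per_node=<NUM_GPUS> \
#       train_bench.py --batch-size 64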
import argparse
import time

import torch
import torchvision
from apex import amp
from apex.parallel import DistributedDataParallel as DDP


torch.backends.cudnn.benchmark = True
assert torch.backends.cudnn.enabled

parser = argparse.ArgumentParser()
# --local_rank is supplied automatically by the distributed launcher.
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument("--batch-size", type=int, required=True)
args = parser.parse_args()
LOCAL_RANK = args.local_rank
BATCH_SIZE = args.batch_size
del args

# Older apex builds expose amp.init() instead of amp.initialize/scale_loss.
OLD = 'scale_loss' not in dir(amp)
CROP_SIZE = 224
PRINT_FREQ = 10  # iterations per timing window
PROF = 50        # warm-up iterations before profiling

torch.cuda.set_device(LOCAL_RANK)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
WORLD_SIZE = torch.distributed.get_world_size()
# Linear LR scaling rule: 0.1 per 256 samples of global batch size.
LR = 0.1 * float(BATCH_SIZE * WORLD_SIZE) / 256.0

model = torchvision.models.vgg16().cuda()
optimizer = torch.optim.SGD(model.parameters(), LR)
criterion = torch.nn.CrossEntropyLoss().cuda()

# Synthetic data: one fixed random batch and all-zero labels, so the
# benchmark measures compute and communication, not data loading.
input = torch.randn(BATCH_SIZE, 3, CROP_SIZE, CROP_SIZE, device='cuda')
target = torch.zeros(BATCH_SIZE, device='cuda', dtype=torch.long)

# Initialize Amp; keep FP32 throughout (O0 on the new API, disabled on the old one).
if OLD:
    amp_handle = amp.init(enabled=False)
else:
    model, optimizer = amp.initialize(model, optimizer, opt_level='O0')
    amp_handle = amp

model = DDP(model, delay_allreduce=True)
model.train()


def train_batch(i):
    # One forward/backward/step, with NVTX ranges so each phase shows up
    # as a named region in the profiler timeline.
    torch.cuda.nvtx.range_push("Body of iteration {}".format(i))

    torch.cuda.nvtx.range_push("forward")
    output = model(input)
    torch.cuda.nvtx.range_pop()

    loss = criterion(output, target)
    optimizer.zero_grad()

    torch.cuda.nvtx.range_push("backward")
    with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    torch.cuda.nvtx.range_pop()

    torch.cuda.nvtx.range_push("optimizer.step()")
    optimizer.step()
    torch.cuda.nvtx.range_pop()

    torch.cuda.nvtx.range_pop()


def train_many_batches(base):
    # Time PRINT_FREQ iterations, then report per-iteration time and
    # aggregate images/sec across all ranks (rank 0 prints).
    start = time.time()
    for i in range(PRINT_FREQ):
        train_batch(base * PRINT_FREQ + i)
    torch.cuda.synchronize()
    end = time.time()

    batch_time = (end - start) / PRINT_FREQ
    speed = WORLD_SIZE * BATCH_SIZE / batch_time
    if LOCAL_RANK == 0:
        print('Time {:.3f}\tSpeed {:.3f}\t'.format(batch_time, speed))


# Warm up: lets cudnn.benchmark autotuning and memory allocation settle.
for i in range(PROF // PRINT_FREQ):
    train_many_batches(i)

# Profile one more timing window, bracketed by cudaProfilerStart/Stop so an
# attached profiler can capture just this region.
torch.cuda.cudart().cudaProfilerStart()
train_many_batches(i + 1)
torch.cuda.cudart().cudaProfilerStop()
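
# To capture only the bracketed window above, start the profiler with data
# collection initially off, e.g. `nvprof --profile-from-start off ...` or
# Nsight Systems with `--capture-range=cudaProfilerApi` (check the flags
# against your profiler version's docs).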