Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import argparse
import time

import torch
import torchvision

from apex import amp
from apex.parallel import DistributedDataParallel as DDP

# Input sizes are fixed for the whole run, so let cuDNN benchmark-select kernels.
torch.backends.cudnn.benchmark = True
assert torch.backends.cudnn.enabled

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=0, type=int)
# required=True: without it args.batch_size would be None and every use below
# (tensor allocation, LR scaling) would fail later with a confusing TypeError.
parser.add_argument("--batch-size", type=int, required=True)
args = parser.parse_args()
LOCAL_RANK = args.local_rank
BATCH_SIZE = args.batch_size
del args

# Older apex builds expose amp.init() returning a handle; newer builds expose
# amp.initialize()/amp.scale_loss on the module itself.
OLD = 'scale_loss' not in dir(amp)

CROP_SIZE = 224   # ImageNet-style input resolution
PRINT_FREQ = 10   # iterations per timing/printing chunk
PROF = 50         # total warm-up iterations before profiling starts

torch.cuda.set_device(LOCAL_RANK)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
WORLD_SIZE = torch.distributed.get_world_size()

# Linear LR scaling rule: 0.1 per 256 samples of global batch.
LR = 0.1 * float(BATCH_SIZE * WORLD_SIZE) / 256.0

model = torchvision.models.vgg16().cuda()
optimizer = torch.optim.SGD(model.parameters(), LR)
criterion = torch.nn.CrossEntropyLoss().cuda()

# Synthetic fixed batch reused every iteration (pure throughput benchmark,
# no data loading). NOTE: `input` shadows the builtin; kept as-is because
# train_batch() reads this global by name.
input = torch.randn(BATCH_SIZE, 3, CROP_SIZE, CROP_SIZE, device='cuda')
target = torch.zeros(BATCH_SIZE, device='cuda', dtype=torch.long)

# Initialize Amp. opt_level='O0' means pure FP32 — amp is wired in but inert,
# so the same training loop works with or without mixed precision.
if OLD:
    amp_handle = amp.init(enabled=False)
else:
    model, optimizer = amp.initialize(model, optimizer, opt_level='O0')
    amp_handle = amp

# Wrap AFTER amp so DDP sees the (possibly patched) model/optimizer.
model = DDP(model, delay_allreduce=True)
model.train()
def train_batch(i):
    """Run one forward/backward/optimizer step on the fixed synthetic batch.

    Each phase is bracketed with NVTX push/pop ranges so it shows up as a
    labeled region in an nvprof/nsys timeline. `i` is only used to label the
    iteration's outer range. Reads the module-level globals `model`, `input`,
    `target`, `criterion`, `optimizer`, and `amp_handle`.
    """
    torch.cuda.nvtx.range_push("Body of iteration {}".format(i))
    torch.cuda.nvtx.range_push("forward")
    output = model(input)
    torch.cuda.nvtx.range_pop()  # forward
    loss = criterion(output, target)
    optimizer.zero_grad()
    torch.cuda.nvtx.range_push("backward")
    # scale_loss works with both apex APIs: the old amp_handle from amp.init()
    # and the new module-level amp (see OLD above).
    with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    torch.cuda.nvtx.range_pop()  # backward
    torch.cuda.nvtx.range_push("optimizer.step()")
    optimizer.step()
    torch.cuda.nvtx.range_pop()  # optimizer.step()
    torch.cuda.nvtx.range_pop()  # body of iteration
def train_many_batches(base):
    """Run and time PRINT_FREQ consecutive iterations, reporting throughput.

    `base` numbers this chunk so the NVTX iteration labels inside
    train_batch() stay globally unique across chunks. Prints per-iteration
    time and aggregate images/sec from the rank-0 process of each node.
    """
    t0 = time.time()
    offset = base * PRINT_FREQ
    for k in range(PRINT_FREQ):
        train_batch(offset + k)
    # CUDA work is queued asynchronously — drain the GPU before reading the
    # clock, otherwise we'd only measure launch overhead.
    torch.cuda.synchronize()
    per_iter = (time.time() - t0) / PRINT_FREQ
    imgs_per_sec = WORLD_SIZE * BATCH_SIZE / per_iter
    if LOCAL_RANK == 0:
        print('Time {:.3f}\tSpeed {:.3f}\t'.format(per_iter, imgs_per_sec))
# Warm up: let cuDNN autotune, allocator caches, and NCCL settle before the
# profiled region so the measurement reflects steady-state throughput.
WARMUP_CHUNKS = PROF // PRINT_FREQ
for chunk in range(WARMUP_CHUNKS):
    train_many_batches(chunk)

# Profile exactly one more chunk of PRINT_FREQ iterations.
# Fix: the original passed the leaked loop variable (`i + 1`), which raises
# NameError whenever PROF < PRINT_FREQ makes the warm-up loop empty; using the
# explicit chunk count is equivalent when the loop does run.
torch.cuda.cudart().cudaProfilerStart()
train_many_batches(WARMUP_CHUNKS)
torch.cuda.cudart().cudaProfilerStop()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement