Advertisement
Pastehsjsjs

Untitled

Apr 22nd, 2023
942
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.71 KB | Source Code | 0 0
  1. def setup(rank, world_size):
  2.     os.environ['MASTER_ADDR'] = 'localhost'
  3.     os.environ['MASTER_PORT'] = '12355'
  4.     dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
  5.  
  6. def cleanup():
  7.     dist.destroy_process_group()
  8.  
  9. def prepare_data(rank, CFG):
  10.     trn_dataset = Dataset(trn_df, transform)
  11.     train_sampler = DistributedSampler(trn_dataset, num_replicas=CFG.world_size, rank=rank, shuffle=True,
  12.                                        drop_last=True)
  13.  
  14.     dataloaders['train'] = DataLoader(
  15.         dataset=trn_dataset,
  16.         batch_size=CFG.batch_size,
  17.         pin_memory=True,
  18.         num_workers=2,
  19.         drop_last=True,
  20.         collate_fn=collator,
  21.         sampler=train_sampler
  22.     )
  23.     return dataloaders
  24.  
  25.  
  26. def train(rank, CFG):
  27.     setup(rank, CFG.world_size)
  28.  
  29.     dataloaders = prepare_data(rank, CFG)
  30.  
  31.     model = ViT()
  32.     model.to(rank)
  33.  
  34.     optimizer = torch.optim.AdamW(param_groups, weight_decay=CFG.weight_decay)
  35.     scheduler = CosineAnnealingLR(optimizer, T_max=ttl_iters)
  36.  
  37.     model = DDP(model, device_ids=[rank], output_device=rank)
  38.  
  39.     for epoch in range(CFG.num_epochs):
  40.         dataloaders['train'].sampler.set_epoch(epoch)
  41.         dataloaders['val'].sampler.set_epoch(epoch)
  42.  
  43.         model.train()
  44.         for X, y in dataloaders['train']:
  45.             X, y = X.to(rank), y.to(rank)
  46.  
  47.             optimizer.zero_grad()
  48.             X_out = model(X)
  49.             target = torch.ones(X.size(0)).to(rank)
  50.             loss = criterion(X_out, y, target)
  51.             loss.backward()
  52.  
  53.             optimizer.step()
  54.             scheduler.step()
  55.  
  56.     cleanup()
  57.  
  58.  
  59. if __name__ == '__main__':
  60.     mp.spawn(train, args=(CFG,), nprocs=CFG.world_size)
  61.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement