Advertisement
Pastehsjsjs

Untitled

Apr 22nd, 2023
949
0
Never
1
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.96 KB | Source Code | 0 0
  1. def setup(rank, world_size):
  2.     os.environ['MASTER_ADDR'] = 'localhost'
  3.     os.environ['MASTER_PORT'] = '12355'
  4.  
  5.     dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
  6.  
  7.  
  8. def cleanup():
  9.     dist.destroy_process_group()
  10.  
  11.  
  12. def prepare_data(rank, CFG):
  13.     trn_dataset = Dataset(trn_df, transform)
  14.     train_sampler = DistributedSampler(trn_dataset, num_replicas=CFG.world_size, rank=rank, shuffle=True,
  15.                                        drop_last=True)
  16.  
  17.     dataloaders['train'] = DataLoader(
  18.         dataset=trn_dataset,
  19.         batch_size=CFG.batch_size,
  20.         pin_memory=True,
  21.         num_workers=2,
  22.         drop_last=True,
  23.         collate_fn=collator,
  24.         sampler=train_sampler
  25.     )
  26.     return dataloaders
  27.  
  28.  
  29. def train(rank, CFG):
  30.     setup(rank, CFG.world_size)
  31.  
  32.     dataloaders = prepare_data(rank, CFG)
  33.  
  34.     model = ViT()
  35.     model.to(rank)
  36.  
  37.     optimizer = torch.optim.AdamW(param_groups, weight_decay=CFG.weight_decay)
  38.  
  39.     ttl_iters = CFG.num_epochs * len(dataloaders['train'])
  40.     scheduler = CosineAnnealingLR(optimizer, T_max=ttl_iters)
  41.     criterion = nn.CosineEmbeddingLoss()
  42.  
  43.     model = DDP(model, device_ids=[rank], output_device=rank)
  44.  
  45.     for epoch in range(CFG.num_epochs):
  46.         dataloaders['train'].sampler.set_epoch(epoch)
  47.         dataloaders['val'].sampler.set_epoch(epoch)
  48.  
  49.         model.train()
  50.         model_ema.eval()
  51.         for X, y in dataloaders['train']:
  52.             X, y = X.to(rank), y.to(rank)
  53.  
  54.             optimizer.zero_grad()
  55.             X_out = model(X)
  56.             target = torch.ones(X.size(0)).to(rank)
  57.             loss = criterion(X_out, y, target)
  58.             loss.backward()
  59.  
  60.             optimizer.step()
  61.             scheduler.step()
  62.  
  63.             model_ema.update(model)
  64.  
  65.         model.eval()
  66.         for X, y in dataloaders['val']:
  67.             #val stuff
  68.  
  69.     cleanup()
  70.  
  71.  
if __name__ == '__main__':
    # Launch one training process per device. mp.spawn prepends the process
    # rank to args, so each worker calls train(rank, CFG).
    # NOTE(review): CFG must be a module-level config object — not visible in
    # this chunk; confirm it is defined before spawn.
    mp.spawn(train, args=(CFG,), nprocs=CFG.world_size)
  74.  
Advertisement
Comments
Add Comment
Please, Sign In to add comment
Advertisement