import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# Dataset, ViT, trn_df, transform, collator and CFG are assumed to be defined
# elsewhere in the full script; only the DDP training plumbing is shown here.


def setup(rank, world_size):
    # One process per GPU: every rank rendezvous at the same address/port.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


def prepare_data(rank, CFG):
    # DistributedSampler hands each rank a disjoint shard of the training set.
    trn_dataset = Dataset(trn_df, transform)
    train_sampler = DistributedSampler(trn_dataset, num_replicas=CFG.world_size,
                                       rank=rank, shuffle=True, drop_last=True)
    dataloaders = {}
    dataloaders['train'] = DataLoader(
        dataset=trn_dataset,
        batch_size=CFG.batch_size,
        pin_memory=True,
        num_workers=2,
        drop_last=True,
        collate_fn=collator,
        sampler=train_sampler,
    )
    return dataloaders


def train(rank, CFG):
    setup(rank, CFG.world_size)
    dataloaders = prepare_data(rank, CFG)

    model = ViT()
    model.to(rank)

    # The full script builds custom param_groups; plain model.parameters() is used here.
    optimizer = torch.optim.AdamW(model.parameters(), weight_decay=CFG.weight_decay)
    ttl_iters = CFG.num_epochs * len(dataloaders['train'])  # one scheduler step per batch
    scheduler = CosineAnnealingLR(optimizer, T_max=ttl_iters)

    # Wrap with DDP only after the model has been moved to its device.
    model = DDP(model, device_ids=[rank], output_device=rank)

    # The three-argument loss call matches nn.CosineEmbeddingLoss (assumed here).
    criterion = nn.CosineEmbeddingLoss()

    for epoch in range(CFG.num_epochs):
        # Reseed the sampler so every epoch shuffles differently across ranks.
        dataloaders['train'].sampler.set_epoch(epoch)
        # (A validation loader, not built in this snippet, would get the same call.)

        model.train()
        for X, y in dataloaders['train']:
            X, y = X.to(rank), y.to(rank)
            optimizer.zero_grad()
            X_out = model(X)
            target = torch.ones(X.size(0)).to(rank)
            loss = criterion(X_out, y, target)
            loss.backward()
            optimizer.step()
            scheduler.step()

    cleanup()


if __name__ == '__main__':
    # Spawn one training process per GPU; mp.spawn passes the rank automatically.
    mp.spawn(train, args=(CFG,), nprocs=CFG.world_size)
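

# The script above assumes a CFG object carrying the run configuration.
# A minimal, hypothetical sketch (field names taken from the references above;
# the values are placeholders, not the original settings):
from types import SimpleNamespace

CFG = SimpleNamespace(
    world_size=torch.cuda.device_count(),  # one training process per visible GPU
    batch_size=64,
    weight_decay=0.01,
    num_epochs=10,
)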