import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def setup(rank, world_size):
    # Each spawned process joins the same process group via a local rendezvous.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


def prepare_data(rank, CFG):
    # trn_df, transform, and collator are defined earlier in the pipeline.
    trn_dataset = Dataset(trn_df, transform)

    # DistributedSampler hands each rank a disjoint shard of the dataset.
    train_sampler = DistributedSampler(
        trn_dataset,
        num_replicas=CFG.world_size,
        rank=rank,
        shuffle=True,
        drop_last=True,
    )

    dataloaders = {}
    dataloaders['train'] = DataLoader(
        dataset=trn_dataset,
        batch_size=CFG.batch_size,
        pin_memory=True,
        num_workers=2,
        drop_last=True,
        collate_fn=collator,
        sampler=train_sampler,
    )
    return dataloaders


def train(rank, CFG):
    setup(rank, CFG.world_size)
    dataloaders = prepare_data(rank, CFG)

    model = ViT()
    model.to(rank)

    # param_groups and ttl_iters (total training iterations) come from the
    # surrounding training setup.
    optimizer = torch.optim.AdamW(param_groups, weight_decay=CFG.weight_decay)
    scheduler = CosineAnnealingLR(optimizer, T_max=ttl_iters)

    # Wrap the model so gradients are all-reduced across ranks on each backward pass.
    model = DDP(model, device_ids=[rank], output_device=rank)

    for epoch in range(CFG.num_epochs):
        # set_epoch reshuffles each sampler so every epoch sees a new ordering.
        dataloaders['train'].sampler.set_epoch(epoch)
        if 'val' in dataloaders:
            dataloaders['val'].sampler.set_epoch(epoch)

        model.train()
        for X, y in dataloaders['train']:
            X, y = X.to(rank), y.to(rank)
            optimizer.zero_grad()

            X_out = model(X)
            target = torch.ones(X.size(0)).to(rank)
            loss = criterion(X_out, y, target)

            loss.backward()
            optimizer.step()
            scheduler.step()

    cleanup()


if __name__ == '__main__':
    # Launch one training process per device; mp.spawn passes the rank as the
    # first argument to train().
    mp.spawn(train, args=(CFG,), nprocs=CFG.world_size)
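
The script above leans on several objects defined elsewhere in the pipeline: CFG, trn_df, transform, collator, ViT, param_groups, ttl_iters, and criterion. As a minimal sketch of the assumed context (the names and values below are illustrative assumptions, not part of the original code), a small config object and a cosine-embedding criterion are enough to drive the DDP loop:

    # Illustrative, assumed context for the DDP script above (not from the source).
    from dataclasses import dataclass

    import torch
    import torch.nn as nn


    @dataclass
    class Config:
        world_size: int = torch.cuda.device_count()  # one process per GPU
        batch_size: int = 64
        num_epochs: int = 5
        weight_decay: float = 1e-2


    CFG = Config()

    # The call criterion(X_out, y, target) with target = ones is consistent with
    # nn.CosineEmbeddingLoss, which pulls X_out and y together when target == 1.
    criterion = nn.CosineEmbeddingLoss()

With a config like this, mp.spawn starts CFG.world_size processes, each pinned to one GPU by its rank, and the NCCL process group synchronizes their gradients during loss.backward().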