import os

import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from timm.utils import ModelEmaV2


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


def prepare_data(rank, CFG):
    # Dataset, trn_df, transform and collator are defined earlier in the pipeline.
    dataloaders = {}
    trn_dataset = Dataset(trn_df, transform)
    train_sampler = DistributedSampler(trn_dataset,
                                       num_replicas=CFG.world_size,
                                       rank=rank,
                                       shuffle=True,
                                       drop_last=True)
    dataloaders['train'] = DataLoader(
        dataset=trn_dataset,
        batch_size=CFG.batch_size,
        pin_memory=True,
        num_workers=2,
        drop_last=True,
        collate_fn=collator,
        sampler=train_sampler
    )
    # The validation loader ('val') is built the same way and is omitted here.
    return dataloaders


def add_layer_wise_lr(rank, model, CFG):
    """Build per-parameter groups whose LR decays from the head towards the embeddings."""
    lr = CFG.base_learning_rate
    lr_mult = CFG.layer_wise_lr_decay
    # Walk the parameters from the last layer to the first.
    layer_names = [name for name, _ in model.named_parameters()][::-1]
    param_groups = []
    prev_group_name = layer_names[0].split('.')[0]

    for idx, name in enumerate(layer_names):
        # Each transformer block gets its own group; other modules are grouped by their top-level name.
        if 'blocks' in name:
            cur_group_name = name.split('.')[0] + name.split('.')[1]
        else:
            cur_group_name = name.split('.')[0]

        # Decay the LR each time we move to a new group.
        if cur_group_name != prev_group_name:
            lr *= lr_mult

        layer_dict = {'params': [p for n, p in model.named_parameters()
                                 if n == name and p.requires_grad],
                      'lr': lr}
        param_groups.append(layer_dict)

        if rank == 0 and (prev_group_name != cur_group_name or idx == 0):
            print(f'{idx}: lr = {lr:.11f}, {name}')
        prev_group_name = cur_group_name

    return param_groups


'''For ViT-B the output looks like this:
0: lr = 0.00060000000, head.bias
2: lr = 0.00036000000, norm.bias
4: lr = 0.00021600000, blocks.11.mlp.fc2.bias
16: lr = 0.00012960000, blocks.10.mlp.fc2.bias
28: lr = 0.00007776000, blocks.9.mlp.fc2.bias
40: lr = 0.00004665600, blocks.8.mlp.fc2.bias
52: lr = 0.00002799360, blocks.7.mlp.fc2.bias
64: lr = 0.00001679616, blocks.6.mlp.fc2.bias
76: lr = 0.00001007770, blocks.5.mlp.fc2.bias
88: lr = 0.00000604662, blocks.4.mlp.fc2.bias
100: lr = 0.00000362797, blocks.3.mlp.fc2.bias
112: lr = 0.00000217678, blocks.2.mlp.fc2.bias
124: lr = 0.00000130607, blocks.1.mlp.fc2.bias
136: lr = 0.00000078364, blocks.0.mlp.fc2.bias
148: lr = 0.00000047018, norm_pre.bias
150: lr = 0.00000028211, patch_embed.proj.weight
151: lr = 0.00000016927, pos_embed
152: lr = 0.00000010156, cls_token
'''


def train(rank, CFG):
    setup(rank, CFG.world_size)
    dataloaders = prepare_data(rank, CFG)

    model = ViT()          # model wrapper defined elsewhere
    model.to(rank)
    model_ema = ModelEmaV2(model, decay=CFG.ema)

    param_groups = add_layer_wise_lr(rank, model, CFG)
    optimizer = torch.optim.AdamW(param_groups, weight_decay=CFG.weight_decay)

    # Cosine schedule stepped once per iteration, hence T_max = total number of iterations.
    ttl_iters = CFG.num_epochs * len(dataloaders['train'])
    scheduler = CosineAnnealingLR(optimizer, T_max=ttl_iters)
    criterion = nn.CosineEmbeddingLoss()

    model = DDP(model, device_ids=[rank], output_device=rank)

    for epoch in range(CFG.num_epochs):
        # Reshuffle the per-process shards for this epoch.
        dataloaders['train'].sampler.set_epoch(epoch)
        dataloaders['val'].sampler.set_epoch(epoch)

        model.train()
        model_ema.eval()
        for X, y in dataloaders['train']:
            X, y = X.to(rank), y.to(rank)
            optimizer.zero_grad()
            X_out = model(X)
            target = torch.ones(X.size(0)).to(rank)
            loss = criterion(X_out, y, target)
            loss.backward()
            optimizer.step()
            scheduler.step()
            model_ema.update(model)

        model.eval()
        for X, y in dataloaders['val']:
            ...  # validation loop omitted

    cleanup()


if __name__ == '__main__':
    mp.spawn(train, args=(CFG,), nprocs=CFG.world_size)
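The listing assumes a module-level CFG config with the fields referenced above (world_size, batch_size, num_epochs, base_learning_rate, layer_wise_lr_decay, weight_decay, ema). A minimal sketch follows: the field names come straight from the code, and the two LR values match the printed schedule (0.0006 with a 0.6 decay factor), while the remaining values are placeholders rather than the actual training settings.

from dataclasses import dataclass

import torch


@dataclass
class Config:
    world_size: int = torch.cuda.device_count()  # one process per GPU
    batch_size: int = 32              # placeholder
    num_epochs: int = 10              # placeholder
    base_learning_rate: float = 6e-4  # matches the 0.0006 printed for head.bias
    layer_wise_lr_decay: float = 0.6  # 0.0006 -> 0.00036 -> 0.000216 -> ...
    weight_decay: float = 0.05        # placeholder
    ema: float = 0.999                # placeholder EMA decay


CFG = Config()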