import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from timm.utils import ModelEmaV2  # EMA wrapper from timm

# Dataset, ViT, trn_df, transform and collator are assumed to be defined
# elsewhere in the project; they are not part of this paste.


def setup(rank, world_size):
    # One process per GPU: every process joins the same NCCL process group.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


def prepare_data(rank, CFG):
    # Each rank reads only its own DistributedSampler shard of the training data.
    dataloaders = {}

    trn_dataset = Dataset(trn_df, transform)
    train_sampler = DistributedSampler(trn_dataset, num_replicas=CFG.world_size,
                                       rank=rank, shuffle=True, drop_last=True)

    dataloaders['train'] = DataLoader(
        dataset=trn_dataset,
        batch_size=CFG.batch_size,
        pin_memory=True,
        num_workers=2,
        drop_last=True,
        collate_fn=collator,
        sampler=train_sampler
    )
    return dataloaders


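# Note: train() below also reads dataloaders['val'], which prepare_data() never
# builds. A minimal sketch of the missing piece, assuming a validation dataset
# built from a val_df that mirrors trn_df (these names are placeholders, not
# from the paste):
#
#     val_dataset = Dataset(val_df, transform)
#     val_sampler = DistributedSampler(val_dataset, num_replicas=CFG.world_size,
#                                      rank=rank, shuffle=False)
#     dataloaders['val'] = DataLoader(val_dataset, batch_size=CFG.batch_size,
#                                     pin_memory=True, num_workers=2,
#                                     collate_fn=collator, sampler=val_sampler)

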
def add_layer_wise_lr(rank, model, CFG):
    # Walk the parameters from the head down to the embeddings and decay the
    # learning rate by lr_mult every time a new block / top-level module starts.
    lr = CFG.base_learning_rate
    lr_mult = CFG.layer_wise_lr_decay

    layer_names = [name for name, _ in model.named_parameters()][::-1]
    param_groups = []
    prev_group_name = layer_names[0].split('.')[0]
    for idx, name in enumerate(layer_names):
        if 'blocks' in name:
            # group transformer parameters per block, e.g. 'blocks' + '11'
            cur_group_name = name.split('.')[0] + name.split('.')[1]
        else:
            cur_group_name = name.split('.')[0]

        if cur_group_name != prev_group_name:
            lr *= lr_mult

        layer_dict = {'params': [p for n, p in model.named_parameters()
                                 if n == name and p.requires_grad],
                      'lr': lr}

        param_groups.append(layer_dict)

        if rank == 0 and (prev_group_name != cur_group_name or idx == 0):
            print(f'{idx}: lr = {lr:.11f}, {name}')
        prev_group_name = cur_group_name
    return param_groups


'''For ViT-B the printed output looks like this:
0: lr = 0.00060000000, head.bias
2: lr = 0.00036000000, norm.bias
4: lr = 0.00021600000, blocks.11.mlp.fc2.bias
16: lr = 0.00012960000, blocks.10.mlp.fc2.bias
28: lr = 0.00007776000, blocks.9.mlp.fc2.bias
40: lr = 0.00004665600, blocks.8.mlp.fc2.bias
52: lr = 0.00002799360, blocks.7.mlp.fc2.bias
64: lr = 0.00001679616, blocks.6.mlp.fc2.bias
76: lr = 0.00001007770, blocks.5.mlp.fc2.bias
88: lr = 0.00000604662, blocks.4.mlp.fc2.bias
100: lr = 0.00000362797, blocks.3.mlp.fc2.bias
112: lr = 0.00000217678, blocks.2.mlp.fc2.bias
124: lr = 0.00000130607, blocks.1.mlp.fc2.bias
136: lr = 0.00000078364, blocks.0.mlp.fc2.bias
148: lr = 0.00000047018, norm_pre.bias
150: lr = 0.00000028211, patch_embed.proj.weight
151: lr = 0.00000016927, pos_embed
152: lr = 0.00000010156, cls_token
'''
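# The schedule above is a plain geometric decay: counting parameter groups from
# the head downwards, group k gets
#     lr_k = CFG.base_learning_rate * CFG.layer_wise_lr_decay ** k
# The printout is consistent with base_learning_rate = 6e-4 and
# layer_wise_lr_decay = 0.6 (values inferred from the log, not stated in the
# code): 6e-4 * 0.6**0 = 0.0006, 6e-4 * 0.6**1 = 0.00036,
# 6e-4 * 0.6**2 = 0.000216, 6e-4 * 0.6**3 = 0.0001296, and so on.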


def train(rank, CFG):
    setup(rank, CFG.world_size)

    dataloaders = prepare_data(rank, CFG)

    model = ViT()
    model.to(rank)

    model_ema = ModelEmaV2(model, decay=CFG.ema)

    param_groups = add_layer_wise_lr(rank, model, CFG)

    optimizer = torch.optim.AdamW(param_groups, weight_decay=CFG.weight_decay)

    ttl_iters = CFG.num_epochs * len(dataloaders['train'])
    scheduler = CosineAnnealingLR(optimizer, T_max=ttl_iters)
    criterion = nn.CosineEmbeddingLoss()

    model = DDP(model, device_ids=[rank], output_device=rank)

    for epoch in range(CFG.num_epochs):
        # Re-seed the samplers so every epoch shuffles the shards differently.
        dataloaders['train'].sampler.set_epoch(epoch)
        dataloaders['val'].sampler.set_epoch(epoch)

        model.train()
        model_ema.eval()
        for X, y in dataloaders['train']:
            X, y = X.to(rank), y.to(rank)

            optimizer.zero_grad()
            X_out = model(X)
            target = torch.ones(X.size(0)).to(rank)
            loss = criterion(X_out, y, target)
            loss.backward()

            optimizer.step()
            scheduler.step()

            model_ema.update(model)

        model.eval()
        for X, y in dataloaders['val']:
            # validation for model and model_ema
            pass

    cleanup()


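# A minimal sketch of the CFG object this script assumes. The field names are
# taken from the usage above; the values are illustrative placeholders, except
# base_learning_rate and layer_wise_lr_decay, which are inferred from the
# printed schedule.
class CFG:
    world_size = torch.cuda.device_count()
    num_epochs = 10                 # placeholder
    batch_size = 32                 # placeholder
    base_learning_rate = 6e-4       # matches the printed schedule
    layer_wise_lr_decay = 0.6       # matches the printed schedule
    weight_decay = 0.05             # placeholder
    ema = 0.9998                    # placeholder EMA decay for ModelEmaV2

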
if __name__ == '__main__':
    # Spawns CFG.world_size processes; each receives its rank as the first argument.
    mp.spawn(train, args=(CFG,), nprocs=CFG.world_size)