# Assumed to be defined elsewhere in the surrounding script/notebook:
# BASE_PATH, DEVICE, GRADIENT_CLIPPING, dataloaders, dataset_sizes,
# train_dataset, get_metrics
import time

import torch
from torch import nn
from tqdm import tqdm


def train_model(model, criterion, optimizer, history, scheduler=None,
                num_epochs=25, save_path='checkpoint',
                continue_training=False, start_epoch=0):
    # Resume from a previous run: restore the model, optimizer, history,
    # and (optionally) scheduler state saved after epoch `start_epoch - 1`.
    if continue_training:
        with open(BASE_PATH + 'weights/{}_{}.model'.format(save_path, start_epoch - 1), 'rb') as f:
            state = torch.load(f, map_location=DEVICE)
        model.load_state_dict(state)
        with open(BASE_PATH + 'weights/{}_{}.optimizer'.format(save_path, start_epoch - 1), 'rb') as f:
            state = torch.load(f, map_location=DEVICE)
        optimizer.load_state_dict(state)
        with open(BASE_PATH + 'weights/{}_{}.history'.format(save_path, start_epoch - 1), 'rb') as f:
            history = torch.load(f)
        if scheduler:
            with open(BASE_PATH + 'weights/{}_{}.scheduler'.format(save_path, start_epoch - 1), 'rb') as f:
                state = torch.load(f, map_location=DEVICE)
            scheduler.load_state_dict(state)

    for epoch in range(start_epoch, num_epochs):
        since = time.time()
        # Each epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()   # set model to training mode
            else:
                model.eval()    # set model to evaluation mode

            running_metrics = {}

            # Iterate over the data. `dataloaders` is a dict:
            # {'train': train_dataloader, 'val': validation_dataloader}
            iterator = tqdm(dataloaders[phase])
            for batch in iterator:
                # Each batch comes as a dict of tensors; move them to the device.
                for k in batch:
                    batch[k] = batch[k].to(DEVICE)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward pass; track gradients only in the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(batch['src'],
                                    batch['dst'],
                                    batch['src_lengths'],
                                    batch['dst_lengths'])
                    _, preds = outputs.max(dim=2)

                    loss = criterion(outputs.view(-1, len(train_dataset.src_token2id)),
                                     batch['dst'].view(-1))

                    # backward + optimize only in the training phase
                    if phase == 'train':
                        loss.backward()
                        nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIPPING)
                        optimizer.step()

                # statistics: accumulate the loss (weighted by batch size)
                # and the per-sample metrics
                running_metrics.setdefault('loss', 0.0)
                running_metrics['loss'] += loss.item() * batch['src'].size(0)
                for pred, ground_truth in zip(preds, batch['dst']):
                    # get_metrics is supposed to return a dictionary of metrics
                    metrics = get_metrics(pred, ground_truth)
                    for metric_name in metrics:
                        running_metrics.setdefault(metric_name, 0.0)
                        running_metrics[metric_name] += metrics[metric_name]

            # Average the accumulated metrics over the dataset and log them.
            for metric_name in running_metrics:
                average_metric = running_metrics[metric_name] / dataset_sizes[phase]
                history.setdefault(phase, {}).setdefault(metric_name, []).append(average_metric)

            print('{} Loss: {:.4f} Rouge: {:.4f}'.format(
                phase, history[phase]['loss'][-1], history[phase]['rouge-l'][-1]))

            # LR scheduler: step on the validation loss.
            if scheduler and phase == 'val':
                scheduler.step(history['val']['loss'][-1])

        # Save model, optimizer, history, and scheduler state after each epoch.
        with open(BASE_PATH + 'weights/{}_{}.model'.format(save_path, epoch), 'wb') as f:
            torch.save(model.state_dict(), f)
        with open(BASE_PATH + 'weights/{}_{}.optimizer'.format(save_path, epoch), 'wb') as f:
            torch.save(optimizer.state_dict(), f)
        with open(BASE_PATH + 'weights/{}_{}.history'.format(save_path, epoch), 'wb') as f:
            torch.save(history, f)
        if scheduler:
            with open(BASE_PATH + 'weights/{}_{}.scheduler'.format(save_path, epoch), 'wb') as f:
                torch.save(scheduler.state_dict(), f)

        time_elapsed = time.time() - since
        history.setdefault('times', []).append(time_elapsed)  # save per-epoch times
        print('Epoch {} complete in {:.0f}m {:.0f}s'.format(
            epoch, time_elapsed // 60, time_elapsed % 60))
        print()
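
The paste calls a get_metrics helper that isn't included; its comment only says it is supposed to return a dictionary of metrics, and the epoch printout expects a 'rouge-l' key. Below is a minimal sketch under those assumptions, scoring each prediction with a ROUGE-L F1 computed from the longest common subsequence of the predicted and ground-truth token ids. PAD_ID is a hypothetical constant, not part of the original code.

PAD_ID = 0  # hypothetical padding id; adjust to the dataset's vocabulary

def get_metrics(pred, ground_truth):
    # Strip padding and convert the 1-D tensors to plain lists of token ids.
    p = [t for t in pred.tolist() if t != PAD_ID]
    g = [t for t in ground_truth.tolist() if t != PAD_ID]
    if not p or not g:
        return {'rouge-l': 0.0}

    # Longest common subsequence length via dynamic programming.
    dp = [[0] * (len(g) + 1) for _ in range(len(p) + 1)]
    for i in range(1, len(p) + 1):
        for j in range(1, len(g) + 1):
            if p[i - 1] == g[j - 1]:
                dp[i][j] = dp[i - 1][j - 1] + 1
            else:
                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
    lcs = dp[-1][-1]

    # ROUGE-L as the F1 of LCS-based precision and recall.
    precision = lcs / len(p)
    recall = lcs / len(g)
    if precision + recall == 0:
        return {'rouge-l': 0.0}
    return {'rouge-l': 2 * precision * recall / (precision + recall)}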
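
The function also leans on several module-level names that the paste never defines (BASE_PATH, DEVICE, GRADIENT_CLIPPING, dataloaders, dataset_sizes, train_dataset). The following driver is one hypothetical way to wire them up; every concrete value is a placeholder, and train_dataset / val_dataset and the model are assumed to exist already. Note that the loop passes the validation loss to scheduler.step(), which matches ReduceLROnPlateau.

import torch
from torch import nn, optim

BASE_PATH = './'         # placeholder; the original path is not shown in the paste
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
GRADIENT_CLIPPING = 1.0  # placeholder clipping threshold

# train_dataset, val_dataset, and model are assumed to be built elsewhere.
dataloaders = {
    'train': torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True),
    'val': torch.utils.data.DataLoader(val_dataset, batch_size=32),
}
dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}

model = model.to(DEVICE)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID)  # ignore padding positions
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)

history = {}
train_model(model, criterion, optimizer, history,
            scheduler=scheduler, num_epochs=25, save_path='checkpoint')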