import time

import torch
from torch import nn
from tqdm import tqdm

# BASE_PATH, DEVICE, dataloaders, dataset_sizes, train_dataset,
# GRADIENT_CLIPPING and get_metrics are expected to be defined at module level.


def train_model(model, criterion, optimizer, history, scheduler=None, num_epochs=25,
                save_path='checkpoint', continue_training=False, start_epoch=0):
    # Resume training: restore model, optimizer, history (and scheduler, if any)
    # from the checkpoint files written after epoch start_epoch - 1.
    if continue_training:
        with open(BASE_PATH + 'weights/{}_{}.model'.format(save_path, start_epoch - 1), 'rb') as f:
            state = torch.load(f, map_location=DEVICE)
        model.load_state_dict(state)
        with open(BASE_PATH + 'weights/{}_{}.optimizer'.format(save_path, start_epoch - 1), 'rb') as f:
            state = torch.load(f, map_location=DEVICE)
        optimizer.load_state_dict(state)
        with open(BASE_PATH + 'weights/{}_{}.history'.format(save_path, start_epoch - 1), 'rb') as f:
            history = torch.load(f)
        if scheduler:
            with open(BASE_PATH + 'weights/{}_{}.scheduler'.format(save_path, start_epoch - 1), 'rb') as f:
                state = torch.load(f, map_location=DEVICE)
            scheduler.load_state_dict(state)

    for epoch in range(start_epoch, num_epochs):
        since = time.time()

        # Each epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluation mode

            running_metrics = {}

            # Iterate over the data. `dataloaders` is a dict:
            # {'train': train_dataloader, 'val': validation_dataloader}.
            iterator = tqdm(dataloaders[phase])
            for batch in iterator:
                # Each batch comes as a dict of tensors; move them all to DEVICE.
                for k in batch:
                    batch[k] = batch[k].to(DEVICE)

                # Zero the parameter gradients.
                optimizer.zero_grad()

                # Forward pass; track gradients only in the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(batch['src'],
                                    batch['dst'],
                                    batch['src_lengths'],
                                    batch['dst_lengths'])
                    _, preds = outputs.max(dim=2)
                    # NOTE: the output layer is sized by the source vocabulary,
                    # which assumes src and dst share a single vocabulary.
                    loss = criterion(outputs.view(-1, len(train_dataset.src_token2id)),
                                     batch['dst'].view(-1))

                    # Backward pass + optimizer step only in the training phase.
                    if phase == 'train':
                        loss.backward()
                        nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIPPING)
                        optimizer.step()

                # Statistics: accumulate batch loss (weighted by batch size)
                # and per-example metrics.
                running_metrics.setdefault('loss', 0.0)
                running_metrics['loss'] += loss.item() * batch['src'].size(0)
                for pred, ground_truth in zip(preds, batch['dst']):
                    metrics = get_metrics(pred, ground_truth)  # returns a dict of metrics
                    for metric_name in metrics:
                        running_metrics.setdefault(metric_name, 0.0)
                        running_metrics[metric_name] += metrics[metric_name]

            # Average the accumulated metrics over the dataset and record them.
            for metric_name in running_metrics:
                average_metric = running_metrics[metric_name] / dataset_sizes[phase]
                history.setdefault(phase, {}).setdefault(metric_name, []).append(average_metric)
            print('{} Loss: {:.4f} Rouge: {:.4f}'.format(
                phase, history[phase]['loss'][-1], history[phase]['rouge-l'][-1]))

            # Step the LR scheduler on the validation loss.
            if scheduler and phase == 'val':
                scheduler.step(history['val']['loss'][-1])

        # Save model, optimizer, history (and scheduler) after every epoch.
        with open(BASE_PATH + 'weights/{}_{}.model'.format(save_path, epoch), 'wb') as f:
            torch.save(model.state_dict(), f)
        with open(BASE_PATH + 'weights/{}_{}.optimizer'.format(save_path, epoch), 'wb') as f:
            torch.save(optimizer.state_dict(), f)
        with open(BASE_PATH + 'weights/{}_{}.history'.format(save_path, epoch), 'wb') as f:
            torch.save(history, f)
        if scheduler:
            with open(BASE_PATH + 'weights/{}_{}.scheduler'.format(save_path, epoch), 'wb') as f:
                torch.save(scheduler.state_dict(), f)

        time_elapsed = time.time() - since
        history.setdefault('times', []).append(time_elapsed)  # per-epoch wall-clock time
        print('Epoch {} complete in {:.0f}m {:.0f}s'.format(
            epoch, time_elapsed // 60, time_elapsed % 60))
        print()
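

# A minimal usage sketch under the assumptions above. `Seq2SeqModel` is a
# hypothetical stand-in for the network this function trains, and the
# hyperparameters are placeholders. ReduceLROnPlateau matches the
# scheduler.step(val_loss) call inside train_model.
model = Seq2SeqModel(vocab_size=len(train_dataset.src_token2id)).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2)
history = {}
train_model(model, criterion, optimizer, history,
            scheduler=scheduler, num_epochs=25, save_path='checkpoint')

# To resume after an interruption at, say, epoch 10 (loads the epoch-9 files):
# train_model(model, criterion, optimizer, history, scheduler=scheduler,
#             num_epochs=25, save_path='checkpoint',
#             continue_training=True, start_epoch=10)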