Advertisement
Not a member of Pastebin yet?
Sign up —
it unlocks many cool features!
def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False):
    """Train `model`, tracking validation accuracy and keeping the best weights.

    Runs a train pass and a validation pass each epoch, logs per-epoch loss and
    accuracy to wandb, and restores the weights from the epoch with the highest
    validation accuracy before returning.

    Args:
        model: the network to train (already moved to `device` by the caller).
        dataloaders: dict with 'train' and 'val' DataLoader entries.
        criterion: loss function, e.g. nn.CrossEntropyLoss.
        optimizer: optimizer over model.parameters().
        num_epochs: number of epochs to run.
        is_inception: if True, the model returns (main, aux) outputs in train
            mode (e.g. torchvision Inception v3); the aux loss is weighted 0.4
            per the standard recipe.

    Returns:
        (model, val_acc_history): the model with best-val-acc weights loaded,
        and the list of per-epoch validation accuracies.

    NOTE: relies on module-level `device` and an initialized `wandb` run.
    """
    since = time.time()
    val_acc_history = []
    best_model_wts = copy.deepcopy(model.state_dict())
    val_best_acc = 0.0
    best_val_acc_epoch = 0
    # early_stopping = EarlyStopping(patience=patience, verbose=True)

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # ---------------- training phase ----------------
        model.train()
        train_running_loss = 0.0
        train_running_corrects = 0
        for inputs, labels in dataloaders['train']:
            train_inputs = inputs.to(device)
            train_labels = labels.to(device)
            # zero the parameter gradients
            optimizer.zero_grad()
            with torch.set_grad_enabled(True):
                if is_inception:
                    # In train mode Inception has an auxiliary output; sum the
                    # final and (0.4-weighted) auxiliary losses. See
                    # https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
                    train_outputs, aux_outputs = model(train_inputs)
                    loss1 = criterion(train_outputs, train_labels)
                    loss2 = criterion(aux_outputs, train_labels)
                    train_loss = loss1 + 0.4 * loss2
                else:
                    train_outputs = model(train_inputs)
                    train_loss = criterion(train_outputs, train_labels)
                _, train_preds = torch.max(train_outputs, 1)
                # backward + optimize (training phase only)
                train_loss.backward()
                optimizer.step()
            # statistics — weight batch loss by batch size for a dataset mean
            train_running_loss += train_loss.item() * train_inputs.size(0)
            # .item() keeps the accumulator a plain Python number, not a tensor
            train_running_corrects += torch.sum(train_preds == train_labels.data).item()

        # ---------------- validation phase ----------------
        model.eval()
        val_running_loss = 0.0
        val_running_corrects = 0
        with torch.no_grad():
            for inputs, labels in dataloaders['val']:
                val_inputs = inputs.to(device)
                val_labels = labels.to(device)
                # eval mode: a single output even for Inception models
                val_outputs = model(val_inputs)
                val_loss = criterion(val_outputs, val_labels)
                _, val_preds = torch.max(val_outputs, 1)
                val_running_loss += val_loss.item() * val_inputs.size(0)
                val_running_corrects += torch.sum(val_preds == val_labels.data).item()

        # ---------------- epoch statistics ----------------
        train_epoch_loss = train_running_loss / len(dataloaders['train'].dataset)
        train_epoch_acc = train_running_corrects / len(dataloaders['train'].dataset)
        val_epoch_loss = val_running_loss / len(dataloaders['val'].dataset)
        val_epoch_acc = val_running_corrects / len(dataloaders['val'].dataset)

        # deep copy the model if validation accuracy improved (update before
        # logging so "best val acc" reflects this epoch)
        if val_epoch_acc > val_best_acc:
            val_best_acc = val_epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            best_val_acc_epoch = epoch

        wandb.log({"train loss": train_epoch_loss,
                   "val loss": val_epoch_loss,
                   "epoch": epoch})
        wandb.log({"train acc": train_epoch_acc,
                   "val acc": val_epoch_acc,
                   "epoch": epoch})
        wandb.log({"best val acc": val_best_acc, "epoch": epoch})

        # early_stopping(val_loss, model)
        # if early_stopping.early_stop:
        #     print("Early stopping")
        #     break
        val_acc_history.append(val_epoch_acc)

        print('train loss: {:.4f} train acc: {:.4f}'.format(train_epoch_loss, train_epoch_acc))
        print('val loss: {:.4f} val acc: {:.4f}'.format(val_epoch_loss, val_epoch_acc))

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f} happened in epoch {}'.format(val_best_acc, best_val_acc_epoch))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement