model train separate train val
lamiastella · Mar 10th, 2022
import copy
import time

import torch
import wandb

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False):
    since = time.time()

    val_acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    val_best_acc = 0.0
    best_val_acc_epoch = 0

    # early_stopping = EarlyStopping(patience=patience, verbose=True)

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        train_running_loss = 0.0
        train_running_corrects = 0
        val_running_loss = 0.0
        val_running_corrects = 0

        # Training phase: gradients are tracked and weights are updated.
        model.train()
        for inputs, labels in dataloaders['train']:
            train_inputs = inputs.to(device)
            train_labels = labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward; history is tracked only in the training phase
            with torch.set_grad_enabled(True):
                if is_inception:
                    # Special case for Inception: in training mode it returns an
                    # auxiliary output, and the loss is the weighted sum of the
                    # final and auxiliary losses; at test time only the final
                    # output is considered.
                    # From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
                    train_outputs, aux_outputs = model(train_inputs)
                    loss1 = criterion(train_outputs, train_labels)
                    loss2 = criterion(aux_outputs, train_labels)
                    train_loss = loss1 + 0.4 * loss2
                else:
                    train_outputs = model(train_inputs)
                    train_loss = criterion(train_outputs, train_labels)
                _, train_preds = torch.max(train_outputs, 1)

                # backward + optimize only in the training phase
                train_loss.backward()
                optimizer.step()

            # statistics
            train_running_loss += train_loss.item() * train_inputs.size(0)
            train_running_corrects += torch.sum(train_preds == train_labels.data).item()

        # Validation phase: no gradient tracking, no weight updates.
        model.eval()
        with torch.no_grad():
            for inputs, labels in dataloaders['val']:
                val_inputs = inputs.to(device)
                val_labels = labels.to(device)

                val_outputs = model(val_inputs)
                val_loss = criterion(val_outputs, val_labels)
                _, val_preds = torch.max(val_outputs, 1)

                # statistics
                val_running_loss += val_loss.item() * val_inputs.size(0)
                val_running_corrects += torch.sum(val_preds == val_labels.data).item()

        train_epoch_loss = train_running_loss / len(dataloaders['train'].dataset)
        train_epoch_acc = train_running_corrects / len(dataloaders['train'].dataset)

        val_epoch_loss = val_running_loss / len(dataloaders['val'].dataset)
        val_epoch_acc = val_running_corrects / len(dataloaders['val'].dataset)

        # deep copy the model whenever validation accuracy improves
        if val_epoch_acc > val_best_acc:
            val_best_acc = val_epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            best_val_acc_epoch = epoch
        val_acc_history.append(val_epoch_acc)

        # early_stopping(val_epoch_loss, model)
        # if early_stopping.early_stop:
        #     print("Early stopping")
        #     break

        # log once per epoch, after both phases have finished
        wandb.log({"train loss": train_epoch_loss,
                   "val loss": val_epoch_loss,
                   "train acc": train_epoch_acc,
                   "val acc": val_epoch_acc,
                   "best val acc": val_best_acc,
                   "epoch": epoch})

        print('train loss: {:.4f} train acc: {:.4f}'.format(train_epoch_loss, train_epoch_acc))
        print('val loss: {:.4f} val acc: {:.4f}'.format(val_epoch_loss, val_epoch_acc))

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f} happened in epoch {}'.format(val_best_acc, best_val_acc_epoch))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history
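
The commented-out early-stopping hook presumes an EarlyStopping class that takes patience and verbose arguments and exposes an early_stop flag. A minimal sketch of such a helper is below, assuming the usual patience-on-validation-loss pattern; the internals are an assumption modelled on the common early-stopping-pytorch recipe, not the author's own code.

class EarlyStopping:
    """Signal a stop when validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=7, verbose=False):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss, model):
        if self.best_loss is None or val_loss < self.best_loss:
            # Improvement: remember the new best loss and reset the counter.
            self.best_loss = val_loss
            self.counter = 0
            if self.verbose:
                print(f'Validation loss improved to {val_loss:.6f}')
        else:
            # No improvement: count up and trip the flag once patience runs out.
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True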
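
For completeness, here is a sketch of how train_model might be driven end to end, with separate train and val ImageFolder directories as the paste's title suggests. The data root data/, batch size, learning rate, and wandb project name are placeholder assumptions, not values from the original paste; Inception v3 is used because the auxiliary-loss branch targets it.

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
import wandb

# Hypothetical layout: data/train/<class>/*.jpg and data/val/<class>/*.jpg
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(299),  # Inception v3 expects 299x299 input
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'val': transforms.Compose([
        transforms.Resize(342),
        transforms.CenterCrop(299),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}

image_datasets = {x: datasets.ImageFolder(f'data/{x}', data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=32,
                                              shuffle=(x == 'train'), num_workers=4)
               for x in ['train', 'val']}

num_classes = len(image_datasets['train'].classes)

# Replace both classifier heads so the output sizes match the dataset.
model = models.inception_v3(pretrained=True, aux_logits=True)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model.AuxLogits.fc = nn.Linear(model.AuxLogits.fc.in_features, num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

wandb.init(project='train-val-example')  # hypothetical project name
model, val_acc_history = train_model(model, dataloaders, criterion, optimizer,
                                     num_epochs=10, is_inception=True)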