lamiastella

validation phase

Mar 9th, 2022
            if phase == 'val':
                model.eval()  # put dropout/batch-norm layers into inference mode
                with torch.no_grad():  # gradients are not needed for validation
                    for inputs, labels in dataloaders[phase]:
                        val_inputs = inputs.to(device)
                        val_labels = labels.to(device)

                        val_outputs = model(val_inputs)
                        val_loss = criterion(val_outputs, val_labels)
                        _, val_preds = torch.max(val_outputs, 1)  # predicted class index per sample

                        # statistics (both accumulators are assumed to be reset to 0 before this loop)
                        # weight the loss by batch size; this assumes criterion returns the batch mean
                        val_running_loss += val_loss.item() * val_inputs.size(0)
                        val_running_corrects += torch.sum(val_preds == val_labels)
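
The snippet above accumulates running totals but does not show how they become per-epoch metrics. Below is a minimal, self-contained sketch of one way to finish that step. It assumes the criterion returns the mean loss over a batch (the default for `nn.CrossEntropyLoss`). The helper name `run_validation`, the toy model, and the random data are hypothetical stand-ins and are not part of the original paste.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def run_validation(model, loader, criterion, device):
    """Return (mean loss per sample, accuracy) for one pass over `loader`.

    Hypothetical helper: it mirrors the validation loop above and adds the
    final division by the number of samples seen.
    """
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    n_samples = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            preds = outputs.argmax(dim=1)
            batch_size = inputs.size(0)
            # Undo the per-batch mean so a smaller last batch is weighted correctly.
            running_loss += loss.item() * batch_size
            running_corrects += (preds == labels).sum().item()
            n_samples += batch_size
    return running_loss / n_samples, running_corrects / n_samples


# Toy stand-ins (assumptions for illustration only): 10 samples, 4 features, 3 classes.
torch.manual_seed(0)
device = torch.device("cpu")
features = torch.randn(10, 4)
targets = torch.randint(0, 3, (10,))
val_loader = DataLoader(TensorDataset(features, targets), batch_size=4)
toy_model = nn.Linear(4, 3).to(device)
criterion = nn.CrossEntropyLoss()

val_epoch_loss, val_epoch_acc = run_validation(toy_model, val_loader, criterion, device)
print(f"val loss: {val_epoch_loss:.4f}  acc: {val_epoch_acc:.4f}")
```

With a batch size of 4 and 10 samples, the last batch holds only 2 samples. Multiplying each batch loss by its size before summing keeps the epoch loss equal to the mean over all samples, rather than a mean of batch means.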