Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Validation branch: evaluate the model on dataloaders[phase] and accumulate
# the loss and correct-prediction totals into val_running_loss /
# val_running_corrects (both are expected to be initialised by the caller).
if phase == 'val':
    # Put the model in evaluation mode (presumably a torch.nn.Module, where
    # this changes dropout / batch-norm behaviour -- confirm against the
    # model definition).
    model.eval()
    # No gradients are needed while scoring, so skip autograd bookkeeping.
    with torch.no_grad():
        for inputs, labels in dataloaders[phase]:
            val_inputs = inputs.to(device)
            val_labels = labels.to(device)

            val_outputs = model(val_inputs)
            val_loss = criterion(val_outputs, val_labels)
            # Index of the largest output along dim 1 is taken as the
            # predicted class.
            _, val_preds = torch.max(val_outputs, 1)

            # statistics
            # Scale the batch loss by the batch size (assumes criterion
            # returns a per-batch mean -- TODO confirm its reduction).
            val_running_loss += val_loss.item() * val_inputs.size(0)
            # Compare against the labels directly: the legacy `.data`
            # accessor is unnecessary inside torch.no_grad().
            val_running_corrects += torch.sum(val_preds == val_labels)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement