Guest User

Untitled

a guest
Oct 15th, 2018
111
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.84 KB | None | 0 0
  # Train `model` for `num_epochs` epochs, running a validation pass after each
  # epoch and printing loss/accuracy summaries.
  # NOTE(review): relies on names defined elsewhere in the original script:
  # model, optimizer, scheduler, train_loader, val_loader, num_epochs,
  # mt_losses, mt_metrics, accuracy, threshold_predictions, tqdm. Requires
  # PyTorch >= 1.4 (scheduler.get_last_lr) and Python >= 3.7-compatible syntax.

  # The metric list does not change between epochs, so build it once.
  metric_fns = [mt_metrics.dice_score,
                mt_metrics.hausdorff_score,
                mt_metrics.precision_score,
                mt_metrics.recall_score,
                mt_metrics.specificity_score,
                mt_metrics.intersection_over_union,
                mt_metrics.accuracy_score]

  for epoch in tqdm(range(1, num_epochs + 1)):
      start_time = time.time()

      # Learning rate in effect for this epoch. The scheduler is stepped at the
      # END of the epoch (after optimizer.step()), as PyTorch >= 1.1 requires;
      # stepping it first skips the initial scheduled value.
      lr = scheduler.get_last_lr()[0]

      ### Training
      model.train()
      train_loss_total = 0.0
      train_acc_total = 0.0
      num_steps = 0

      for batch in train_loader:
          input_samples, gt_samples = batch["input"], batch["gt"]

          var_input = input_samples.cuda()
          # `async` is a reserved word since Python 3.7; PyTorch renamed this
          # argument to `non_blocking`.
          var_gt = gt_samples.cuda(non_blocking=True)

          preds = model(var_input)

          loss = mt_losses.dice_loss(preds, var_gt)
          train_loss_total += loss.item()

          optimizer.zero_grad()
          loss.backward()
          optimizer.step()

          # Accumulate per-batch accuracy so the reported training accuracy
          # covers the whole epoch instead of only the final batch.
          train_acc_total += accuracy(preds.detach().cpu().numpy(),
                                      var_gt.detach().cpu().numpy())
          num_steps += 1

      scheduler.step()

      # max(..., 1) avoids ZeroDivisionError if the loader yields no batches.
      train_loss_total_avg = train_loss_total / max(num_steps, 1)
      train_acc = train_acc_total / max(num_steps, 1)

      ### Validating
      model.eval()
      val_loss_total = 0.0
      num_steps = 0
      metric_mgr = mt_metrics.MetricManager(metric_fns)

      for batch in val_loader:
          input_samples, gt_samples = batch["input"], batch["gt"]

          # No autograd graph is needed for evaluation.
          with torch.no_grad():
              var_input = input_samples.cuda()
              var_gt = gt_samples.cuda(non_blocking=True)

              preds = model(var_input)
              loss = mt_losses.dice_loss(preds, var_gt)
              val_loss_total += loss.item()

          # Metrics computation on uint8 masks with the channel axis (1) removed.
          gt_npy = gt_samples.numpy().astype(np.uint8).squeeze(axis=1)

          preds_npy = threshold_predictions(preds.detach().cpu().numpy())
          preds_npy = preds_npy.astype(np.uint8).squeeze(axis=1)

          metric_mgr(preds_npy, gt_npy)
          num_steps += 1

      metrics_dict = metric_mgr.get_results()
      metric_mgr.reset()
      val_loss_total_avg = val_loss_total / max(num_steps, 1)

      print('\nTrain loss: {:.4f}, Training Accuracy: {:.4f} '.format(train_loss_total_avg, train_acc))
      print('Val Loss: {:.4f}, Validation Accuracy: {:.4f} '.format(val_loss_total_avg, metrics_dict["accuracy_score"]))
      print('Epoch time: {:.2f}s, LR: {:.6f}'.format(time.time() - start_time, lr))
Add Comment
Please sign in to add a comment.