Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
# Epoch loop: one optimisation pass over train_loader, one gradient-free pass
# over val_loader, then a printed loss/accuracy summary for the epoch.
# Relies on names defined elsewhere in the script: model, optimizer,
# scheduler, train_loader, val_loader, num_epochs, mt_losses, mt_metrics,
# vutils, accuracy and threshold_predictions.
for epoch in tqdm(range(1, num_epochs + 1)):
    start_time = time.time()
    # Learning rate in effect for this epoch, read straight from the optimizer
    # so it does not depend on which scheduler accessor the installed PyTorch
    # version provides.
    lr = optimizer.param_groups[0]["lr"]
    # NOTE(review): start_time and lr are not read anywhere in the visible
    # code; presumably they fed logging that is not part of this paste.

    ### Training
    model.train()
    train_loss_total = 0.0
    num_steps = 0
    for batch in train_loader:
        input_samples, gt_samples = batch["input"], batch["gt"]
        var_input = input_samples.cuda()
        # `async` is a reserved keyword since Python 3.7, so `.cuda(async=True)`
        # no longer parses; `non_blocking=True` is the replacement argument.
        var_gt = gt_samples.cuda(non_blocking=True)

        preds = model(var_input)
        loss = mt_losses.dice_loss(preds, var_gt)
        train_loss_total += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        num_steps += 1

    if epoch % 5 == 0:
        # Image grids of the last training batch. Each grid gets its own name;
        # previously all three were assigned to `grid_img`, so only the
        # ground-truth grid survived.
        # NOTE(review): nothing in the visible code consumes these grids.
        input_grid = vutils.make_grid(input_samples,
                                      normalize=True,
                                      scale_each=True)
        pred_grid = vutils.make_grid(preds.detach().cpu(),
                                     normalize=True,
                                     scale_each=True)
        gt_grid = vutils.make_grid(gt_samples,
                                   normalize=True,
                                   scale_each=True)

    train_loss_total_avg = train_loss_total / num_steps

    # NOTE(review): preds and var_gt hold only the final training batch here,
    # so this "training accuracy" describes that one batch, not the epoch.
    train_acc = accuracy(preds.cpu().detach().numpy(),
                         var_gt.cpu().detach().numpy())

    ### Validating
    model.eval()
    val_loss_total = 0.0
    num_steps = 0

    metric_fns = [mt_metrics.dice_score,
                  mt_metrics.hausdorff_score,
                  mt_metrics.precision_score,
                  mt_metrics.recall_score,
                  mt_metrics.specificity_score,
                  mt_metrics.intersection_over_union,
                  mt_metrics.accuracy_score]
    metric_mgr = mt_metrics.MetricManager(metric_fns)

    for batch in val_loader:
        input_samples, gt_samples = batch["input"], batch["gt"]

        with torch.no_grad():
            var_input = input_samples.cuda()
            var_gt = gt_samples.cuda(non_blocking=True)

            preds = model(var_input)
            loss = mt_losses.dice_loss(preds, var_gt)
            val_loss_total += loss.item()

        # Metrics computation: both arrays are cast to uint8 and have axis 1
        # squeezed out before being handed to the metric manager.
        gt_npy = gt_samples.numpy().astype(np.uint8)
        gt_npy = gt_npy.squeeze(axis=1)

        preds = preds.detach().cpu().numpy()
        preds = threshold_predictions(preds)
        preds = preds.astype(np.uint8)
        preds = preds.squeeze(axis=1)

        metric_mgr(preds, gt_npy)
        num_steps += 1

    metrics_dict = metric_mgr.get_results()
    metric_mgr.reset()

    val_loss_total_avg = val_loss_total / num_steps

    print('\nTrain loss: {:.4f}, Training Accuracy: {:.4f} '.format(train_loss_total_avg, train_acc))
    print('Val Loss: {:.4f}, Validation Accuracy: {:.4f} '.format(val_loss_total_avg, metrics_dict["accuracy_score"]))

    # Advance the LR schedule only after this epoch's optimizer updates.
    # Stepping the scheduler before optimizer.step() (as was done at the top
    # of the epoch) skips the first value of the schedule on PyTorch >= 1.1.
    scheduler.step()
Add Comment
Please, Sign In to add comment